Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
dataset.py 3.71 KiB
"""
The shape of data is (batch_size, seq_len) for transformer models.
This is different from the convention we used for RNNs where the shape was (seq_len, batch_size).
"""

import torch
from torch.utils.data import Dataset
from typing import List, Tuple
from pathlib import Path
import json
from .tokenizer import Tokenizer


def read_data(file_paths: List[Path]) -> tuple[List[str], List[str]]:
    """Load parallel sentences from one or more JSON Lines files.

    Every non-blank line must be a JSON object with a ``"de"`` entry (appended
    to the source list) and an ``"en"`` entry (appended to the target list).
    Files are read in the order given and their examples are concatenated.

    Args:
        file_paths: Paths of the JSON Lines files to read (UTF-8 encoded).

    Returns:
        A ``(src_data, tgt_data)`` pair of equally long lists, where the
        i-th entries of both lists come from the same line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a JSON object lacks the ``"de"`` or ``"en"`` key.
    """
    src_data: List[str] = []
    tgt_data: List[str] = []
    for file_path in file_paths:
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. trailing newlines at end of file)
                    # carry no example and would make json.loads fail.
                    continue
                json_obj = json.loads(line)
                # Look up both keys before appending so the two lists can
                # never get out of step if one key is missing.
                src_text, tgt_text = json_obj["de"], json_obj["en"]
                src_data.append(src_text)
                tgt_data.append(tgt_text)

    return src_data, tgt_data


class Multi30kDataset(Dataset):
    """Dataset of parallel sentence pairs that are tokenized on access.

    The raw sentences are loaded once at construction time through
    ``read_data``; encoding into token ids happens per item in
    ``__getitem__``.
    """

    def __init__(
        self, file_paths: List[Path], src_tokenizer: Tokenizer, tgt_tokenizer: Tokenizer
    ):
        """Read all examples and remember the tokenizers.

        Args:
            file_paths: Files handed to ``read_data``.
            src_tokenizer: Tokenizer used to encode source sentences.
            tgt_tokenizer: Tokenizer used to encode target sentences.
        """
        sources, targets = read_data(file_paths)
        self.src_data = sources
        self.tgt_data = targets
        assert len(sources) == len(
            targets
        ), "Number of source and target examples do not match"

        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer

        # Special token ids (pad, start, end, unk) of the source tokenizer
        self.special_tokens = [
            getattr(src_tokenizer, attr_name)
            for attr_name in ("pad_id", "start_id", "end_id", "unk_id")
        ]

    def __len__(self) -> int:
        """Return the number of source sentences."""
        return len(self.src_data)

    def __getitem__(self, idx: int) -> Tuple[List[int], List[int]]:
        """Return the encoded ``(source, target)`` pair at position ``idx``.

        The source ids are returned exactly as produced by the source
        tokenizer; the target ids are wrapped in the target tokenizer's
        start and end ids.
        """
        src_text, tgt_text = self.src_data[idx], self.tgt_data[idx]

        src_ids = self.src_tokenizer.encode(src_text)
        tgt_ids = self.tgt_tokenizer.encode(tgt_text)

        # Only the target side gets explicit start/end markers
        tgt_tok = self.tgt_tokenizer
        wrapped_tgt = [tgt_tok.start_id] + tgt_ids + [tgt_tok.end_id]
        return src_ids, wrapped_tgt


def to_padded_tensor(
    sequences: List[List[int]], pad_id: int, batch_first=True
) -> torch.Tensor:
    """Right-pad token sequences to a common length and stack them.

    Args:
        sequences: Token id lists, one per example; lengths may differ.
        pad_id: Id appended to shorter sequences until they reach the
            length of the longest one.
        batch_first: When True the result has shape (batch_size, seq_len);
            when False it is transposed to (seq_len, batch_size).

    Returns:
        torch.Tensor: A ``torch.long`` tensor holding the padded batch.
    """
    # Every row is extended up to the longest sequence in the batch
    target_len = max(map(len, sequences))

    # List concatenation builds new rows, so the inputs are left untouched
    rows = [tokens + [pad_id] * (target_len - len(tokens)) for tokens in sequences]

    # Stacked rows give shape (batch_size, seq_len)
    batch = torch.tensor(rows, dtype=torch.long)

    # Transposing yields (seq_len, batch_size) for callers that want it
    return batch if batch_first else batch.T


class BatchCollator:
    """Callable that pads a batch of (source, target) pairs into two tensors."""

    def __init__(self, src_pad_id: int, tgt_pad_id: int, batch_first: bool):
        """Store the padding ids and the layout flag.

        Args:
            src_pad_id: Padding id used for the source sequences.
            tgt_pad_id: Padding id used for the target sequences.
            batch_first: Forwarded to ``to_padded_tensor`` for both sides.
        """
        self.src_pad_id, self.tgt_pad_id = src_pad_id, tgt_pad_id
        self.batch_first = batch_first

    def __call__(
        self, batch: List[Tuple[List[int], List[int]]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pad the source and target sides of ``batch`` independently.

        Returns:
            A ``(source_tensor, target_tensor)`` pair, each produced by
            ``to_padded_tensor`` with its own padding id.
        """
        # Split the list of pairs into all sources and all targets
        sources, targets = zip(*batch)
        src_tensor = to_padded_tensor(list(sources), self.src_pad_id, self.batch_first)
        tgt_tensor = to_padded_tensor(list(targets), self.tgt_pad_id, self.batch_first)
        return src_tensor, tgt_tensor