pytorch 数据加载器/采样器/整理器,用于根据样品内容(序列长度)创建批次

jm81lzqq  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(101)

我正在使用数据集和数据加载器,整理函数和采样器将别人的代码转换成一个更整洁的 Torch 管道。虽然我以前做过这样的工作,但我不知道如何解决以下问题。
数据集包含句子作为样本。因此,每个样本都有一定数量的单词(或tokens),我们可以通过在白色空间(sample.split())上简单地分割样本来获得。这样一个虚拟数据集可能看起来像这样:

from random import randint

from torch.utils.data import Dataset

class DummyDataset(Dataset):
    def __init__(self):
        data = []
        for _ in range(128):
            data.append("hello " * randint(64, 176))
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        return self.data[idx]

现在我希望能够加载数据,以便最大。批量中的 * token * 数量不超过250。这意味着批量大小在迭代之间可能不同。一个批次可以包含两个样本,总共不超过250个令牌(例如127 + 77),另一个批次可以包含三个(66+66+66)。现在,它的核心功能相当简单。下面是完整的例子;没有通过长度排序之类的方法来优化,但这对这个例子来说是可以的。
问题是,如何将其集成到PyTorch生态系统中?批量大小经常被用来表示samples的数量(如在数据加载器中)。那么,我应该在哪里插入它,或者我应该子类化什么,让它像一个普通的数据加载器一样工作?

from random import randint

from torch.utils.data import Dataset

class DummyDataset(Dataset):
    def __init__(self):
        data = []
        for _ in range(128):
            data.append("hello " * randint(64, 176))
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        return self.data[idx]

if __name__ == '__main__':
    dataset = DummyDataset()

    def get_batch(max_tokens: int = 250):
        data_idxs = list(range(len(dataset)))

        batch = []
        total_batch_len = 0
        while data_idxs:
            sample = dataset[data_idxs[0]]
            sample_len = len(sample.split())

            if total_batch_len + sample_len <= max_tokens:
                batch.append(sample)
                total_batch_len += sample_len
                data_idxs.pop(0)
            elif batch:
                yield batch
                batch = []
                total_batch_len = 0

        yield batch

    # Sanity check that we indeed get all items from the dataset
    num_samples = 0
    num_batches = 0
    for b in get_batch():
        num_samples += len(b)
        num_batches += 1

    print(f"Created {num_batches} batches")
    assert num_samples == len(dataset)

也许torchtext的迭代器和它的batch_size_fn可以帮助,但我没有使用它的经验(我应该在哪里添加它;它本身是一个数据加载器,还是我应该在它周围 Package 一个数据加载器,等等)。

nmpmafwu

nmpmafwu1#

在阅读了一些源代码之后,您似乎可以在Dataloader的batch_sampler中使用任何迭代器。下面的工作如预期。

from random import randint

from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

class DummyDataset(Dataset):
    def __init__(self):
        data = []
        for _ in range(128):
            data.append("hello " * randint(64, 176))
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        return self.data[idx]

class TokenBatchSampler:
    def __init__(self, max_tokens: int = 250):
        self.max_tokens = max_tokens
        self.batches = []
        self._prepare_dataset()

    def __len__(self) -> int:
        return len(self.batches)

    def __iter__(self):
        return iter(self.batches)

    def _prepare_dataset(self):
        data_idxs = list(range(len(dataset)))

        batches = []
        batch_idxs = []
        total_batch_len = 0
        while data_idxs:
            sample_idx = data_idxs[0]
            sample = dataset[sample_idx]
            sample_len = len(sample.split())

            if total_batch_len + sample_len <= self.max_tokens:
                batch_idxs.append(sample_idx)
                total_batch_len += sample_len
                data_idxs.pop(0)
            elif batch_idxs:
                batches.append(batch_idxs)
                batch_idxs = []
                total_batch_len = 0

        batches.append(batch_idxs)

        self.batches = batches

if __name__ == "__main__":
    dataset = DummyDataset()

    sampler = TokenBatchSampler()
    dataloader = DataLoader(dataset, batch_sampler=sampler)
    # Sanity check that we indeed get all items from the dataset
    for epoch in range(3):
        num_samples = 0
        num_batches = 0
        for b in dataloader:
            num_samples += len(b)
            num_batches += 1

        print(f"Created {num_batches} batches in epoch {epoch}")
        assert num_samples == len(dataset)

    print(f"DataLoader length {len(dataloader)}")

相关问题