我正在使用数据集和数据加载器,整理函数和采样器将别人的代码转换成一个更整洁的 Torch 管道。虽然我以前做过这样的工作,但我不知道如何解决以下问题。
数据集包含句子作为样本。因此,每个样本都有一定数量的单词(或tokens
),我们可以通过在白色空间(sample.split()
)上简单地分割样本来获得。这样一个虚拟数据集可能看起来像这样:
from random import randint
from torch.utils.data import Dataset
class DummyDataset(Dataset):
def __init__(self):
data = []
for _ in range(128):
data.append("hello " * randint(64, 176))
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx: int):
return self.data[idx]
现在我希望能够加载数据,以便最大。批量中的 * token * 数量不超过250。这意味着批量大小在迭代之间可能不同。一个批次可以包含两个样本,总共不超过250个令牌(例如127 + 77),另一个批次可以包含三个(66+66+66)。现在,它的核心功能相当简单。下面是完整的例子;没有通过长度排序之类的方法来优化,但这对这个例子来说是可以的。
问题是,如何将其集成到PyTorch生态系统中?批量大小经常被用来表示samples
的数量(如在数据加载器中)。那么,我应该在哪里插入它,或者我应该子类化什么,让它像一个普通的数据加载器一样工作?
from random import randint
from torch.utils.data import Dataset
class DummyDataset(Dataset):
def __init__(self):
data = []
for _ in range(128):
data.append("hello " * randint(64, 176))
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx: int):
return self.data[idx]
if __name__ == '__main__':
dataset = DummyDataset()
def get_batch(max_tokens: int = 250):
data_idxs = list(range(len(dataset)))
batch = []
total_batch_len = 0
while data_idxs:
sample = dataset[data_idxs[0]]
sample_len = len(sample.split())
if total_batch_len + sample_len <= max_tokens:
batch.append(sample)
total_batch_len += sample_len
data_idxs.pop(0)
elif batch:
yield batch
batch = []
total_batch_len = 0
yield batch
# Sanity check that we indeed get all items from the dataset
num_samples = 0
num_batches = 0
for b in get_batch():
num_samples += len(b)
num_batches += 1
print(f"Created {num_batches} batches")
assert num_samples == len(dataset)
也许torchtext的迭代器和它的batch_size_fn
可以帮助,但我没有使用它的经验(我应该在哪里添加它;它本身是一个数据加载器,还是我应该在它周围 Package 一个数据加载器,等等)。
1条答案
按热度按时间nmpmafwu1#
在阅读了一些源代码之后,您似乎可以在Dataloader的
batch_sampler
中使用任何迭代器。下面的工作如预期。