pytorch 使用Iterable Dataset将本地parquet文件流式传输到huggingface trainer

kwvwclae  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(222)

我想流一个大的.parquet文件,我在本地训练分类模型。我的脚本似乎只加载了第一个小批:即使文件非常大,时期的数量也会非常快地增加,1个时期应该持续大约10个小时。下面是我使用的代码:

import pyarrow.parquet as pq
import torch
import pandas as pd
import evaluate
from transformers import (AutoTokenizer, CamembertForSequenceClassification,
                          EarlyStoppingCallback, Trainer, TrainingArguments,
                          pipeline)
import numpy as np

class MyIterableDataset(torch.utils.data.IterableDataset):

    def __init__(self, parquet_file_path: str, tokenizer, label_encoder, batch_size: int = 8):
        self.parquet_file = pq.ParquetFile(parquet_file_path)
        self.generator = self.parquet_file.iter_batches(batch_size=batch_size)
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder

    def __iter__(self):
        """ """
        data = next(self.generator)
        encodings = self.tokenizer(data['text'].tolist(), truncation=True, padding=True, max_length=512)
        items = []
        for idx in range(len(data)):
            item = {key: torch.tensor(val[idx]) for key, val in encodings.items()}
            item["labels"] = torch.tensor(self.label_encoder.transform([str(data['target'][idx])]))
            items.append(item)
        return iter(items)

个字符

fcipmucu

fcipmucu1#

__iter__方法不会迭代整个数据集,因为它缺少一个循环来重复获取下一批数据并处理它。相反,它使用next(self.generator)加载第一个批处理,处理它,然后 * 返回一个迭代器 *,其中包含该批处理中的项。由于它只执行一次,因此您只能获得数据集中的第一个批处理。
你可以尝试这样的东西:

def __iter__(self):
    while True:
        try:
            data = next(self.generator) #try and get the next bit of data
        except StopIteration:
            # End of the dataset, break
            break

        encodings = self.tokenizer(data['text'].tolist(), truncation=True, padding=True, max_length=512)
        items = []
        for idx in range(len(data)): #for index encode and yield
            item = {key: torch.tensor(val[idx]) for key, val in encodings.items()}
            item["labels"] = torch.tensor(self.label_encoder.transform([str(data['target'][idx])]))
            items.append(item)
        yield from items

字符串
这个版本应该让__iter__方法继续从self.generator中获取批处理,处理它们,并从每个批处理中产生单独的项,直到.parquet文件中没有更多的批处理。
https://www.datacamp.com/tutorial/python-iterators-generators-tutorial
https://anandology.com/python-practice-book/iterators.html
https://www.geeksforgeeks.org/difference-between-iterator-vs-generator/

相关问题