torch.utils.data。DataLoader返回torch.tensors。有没有办法返回numpy数组?下面的子类返回静态Tensor。我想将__iter()__
改为返回numpy数组。(...my_tensor.numpy(),...)就是想不明白。
class CustomDataLoader(DataLoader):
def __init__(self, dataset):
super().__init__(dataset)
def __iter__(self):
it_ = super().__iter__()
print( next(it_))
print(super().__iter__().__dict__)
return it_
c = CustomDataLoader(dataset)
next(iter(c))
1条答案
按热度按时间wbrvyc0a1#
是的,您可以定义自己的自定义排序函数并将其作为
Dataloader(dataset,collate_fn=my_function)
传递。collate函数负责将批处理中的各个元素聚合或“整理”成可索引或可迭代的批处理(例如,将大小为[100,100]
的n
Tensor列表转换为大小为[n,100,100]
的单个Tensor。)如果你想以非平凡的方式整理数据,或者如果你的数据中有不寻常的类型,这通常是一种方法,因为pytorch只为最常见的用例提供默认的整理函数。在collate函数中,在最简单的情况下,您可以使用<tensor>.data.numpy().
将任何Tensor转换为numpy数组您可以查看文档或this StackOverflow关于定义自定义排序函数的问题。
希望这有帮助!