pytorch 如何从torch.utils.data.DataLoader返回numpy数组而不是Tensor

b1payxdu  于 2023-05-07  发布在  其他
关注(0)|答案(1)|浏览(276)

torch.utils.data。DataLoader返回torch.tensors。有没有办法返回numpy数组?下面的子类返回静态Tensor。我想将__iter()__改为返回numpy数组。(...my_tensor.numpy(),...)就是想不明白。

class CustomDataLoader(DataLoader):
    def __init__(self, dataset):
        super().__init__(dataset)
    
    def __iter__(self):
        it_ = super().__iter__()
        print( next(it_))
        print(super().__iter__().__dict__)
        return it_ 
        
c = CustomDataLoader(dataset)
next(iter(c))
wbrvyc0a

wbrvyc0a1#

是的,您可以定义自己的自定义排序函数并将其作为Dataloader(dataset,collate_fn=my_function)传递。collate函数负责将批处理中的各个元素聚合或“整理”成可索引或可迭代的批处理(例如,将大小为[100,100]nTensor列表转换为大小为[n,100,100]的单个Tensor。)如果你想以非平凡的方式整理数据,或者如果你的数据中有不寻常的类型,这通常是一种方法,因为pytorch只为最常见的用例提供默认的整理函数。在collate函数中,在最简单的情况下,您可以使用<tensor>.data.numpy().将任何Tensor转换为numpy数组
您可以查看文档或this StackOverflow关于定义自定义排序函数的问题。
希望这有帮助!

相关问题