如何在C++中循环遍历PytorchTensor中的每个值?

px9o7tmv  于 2022-12-15  发布在  其他
关注(0)|答案(2)|浏览(248)

我尝试在C中为PyTorchTensor做类似[0 if i〈1 else 1]的事情。我尝试使用tensor.accessor(),但它似乎需要你事先知道维度。而我想动态地传递它。
有没有办法用C
为Pytorch做这些?

eeq64g8w

eeq64g8w1#

看看这是否有帮助:
使用Tensor迭代器:
https://labs.quansight.org/blog/2020/04/pytorch-tensoriterator-internals/
或者利用t.is_contiguous()/ t.contiguous()来简化横向:
https://discuss.pytorch.org/t/iterating-over-tensor-in-c/60333/2

vxbzzdmp

vxbzzdmp2#

c++ torch中的For循环是这样的。你也可以试试torch.where。注意,与内置操作相比,for循环可能会很慢。

auto answer = torch::zeros_like(x);
        auto batchCount = x.sizes()[0];
        auto pointCount = x.sizes()[1];
        
        // auto nextPoint = torch::zeros({batchCount, _pointDim});
        for (int i = 0; i < batchCount; ++i)
        {
            for (int j = 1; j < _pointCount; ++j)
            {
                answer[i][j] = answer[i][j - 1] + x[i][j - 1];
            } 
        }

相关问题