我有一个Tensor,在tensorflow 中的形状是(16, 512, 4096)
,我想从Tensor中计算出k
的最小元素。
请注意,我可以使用以下代码片段在pytorch中获取它-
#inputs.shape (16L, 512L, 4096L)
dists, inputs_idx = torch.topk(inputs, 64, 2, largest=False, sorted=False)
#dists.shape (16L, 512L, 64L), inputs_idx.shape (16L, 512L, 64L)
有什么解决办法吗?
1条答案
按热度按时间epggiuax1#
因为
torch.topk
可以用来得到k
最大的元素,所以你可以对这些值取反,执行tne操作,然后再次对它们取反以得到值: