pytorch 除顶点k外的所有向量元素为零?

a64a0gku  于 2023-01-09  发布在  其他
关注(0)|答案(1)|浏览(175)

我正在尝试创建一个新的激活层,我们称之为topk,它的工作原理如下:它将一个大小为n的向量x作为输入(前一层输出乘以权重矩阵并加上偏差的结果)和一个正整数k,并将输出一个大小为n的向量topk(x),其元素为:

x_i (if x_i is one of the top k elements of x) 
topk(x)_i = 
              0 (otherwise)

在计算topk(x)的梯度时,x的前k个元素的梯度应为1,其他元素的梯度均为0。
我该如何实施呢?任何帮助都将不胜感激。

x7rlezfr

x7rlezfr1#

您可以使用torch.topk来执行以下操作:

k = 2
output = torch.randn(5)
vals, idx = output.topk(k)

topk = torch.zeros_like(output)
topk[idx] = vals
>>> topk
tensor([1.0557, 0.0000, 0.0000, 1.4562, 0.0000])

请注意,虽然topk()'values'是可微的,但'indices'are not(类似于argmax不是可微函数)。

相关问题