我正在尝试创建一个新的激活层,我们称之为topk,它的工作原理如下:它将一个大小为n的向量x作为输入(前一层输出乘以权重矩阵并加上偏差的结果)和一个正整数k,并将输出一个大小为n的向量topk(x),其元素为:
x_i (if x_i is one of the top k elements of x)
topk(x)_i =
0 (otherwise)
在计算topk(x)的梯度时,x的前k个元素的梯度应为1,其他元素的梯度均为0。
我该如何实施呢?任何帮助都将不胜感激。
1条答案
按热度按时间x7rlezfr1#
您可以使用
torch.topk
来执行以下操作:请注意,虽然
topk()
的'values'
是可微的,但'indices'
are not(类似于argmax不是可微函数)。