pytorch 实现Gumbel sigmoid来重构数据Tensor

mzsu5hc0  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(201)

假设我们有一个logitTensor(形状:B,W,1),每个值代表一个需要采样的二进制预测,基于采样的输出,我想在网络中添加额外的维度(这也是一个离散操作)。然后,重构的数据Tensor被传递到网络的下一个组件中,依此类推。例如,如果采样是0,1,0,则到下一层(在网络中)的输入将是d1,d2,x.d3(其中,表示连接,d1,d2,d3是初始Tensor,x是基于采样引入(扩展)的)在此用例中是否有简单的方法应用Gumbel技巧等?在PyTorch中的解决方案将是伟大的!

jei2mxaa

jei2mxaa1#

我看不出这有什么用处,但这里有:

b, w, num_samples = 3, 5, 7
thresholds = torch.rand(b, w) 
noise = torch.randn(7,).view(1, 1, -1).expand(b, w, num_samples)  # (
samples = (thresholds[:, :, None] > noise).long()

相关问题