假设我们有一个logitTensor(形状:B,W,1),每个值代表一个需要采样的二进制预测,基于采样的输出,我想在网络中添加额外的维度(这也是一个离散操作)。然后,重构的数据Tensor被传递到网络的下一个组件中,依此类推。例如,如果采样是0,1,0,则到下一层(在网络中)的输入将是d1,d2,x.d3(其中,表示连接,d1,d2,d3是初始Tensor,x是基于采样引入(扩展)的)在此用例中是否有简单的方法应用Gumbel技巧等?在PyTorch中的解决方案将是伟大的!
jei2mxaa1#
我看不出这有什么用处,但这里有:
b, w, num_samples = 3, 5, 7 thresholds = torch.rand(b, w) noise = torch.randn(7,).view(1, 1, -1).expand(b, w, num_samples) # ( samples = (thresholds[:, :, None] > noise).long()
1条答案
按热度按时间jei2mxaa1#
我看不出这有什么用处,但这里有: