pytorch 嵌入索引超出范围:

lb3vh1jj  于 2023-01-17  发布在  其他
关注(0)|答案(1)|浏览(142)

我理解为什么这个错误经常发生,这是因为输入〉= embedding_dim.
但是在我的例子中torch.max(输入)= embedding_dim - 1。

print('inputs: ', src_seq)
print('input_shape: ', src_seq.shape)
print(self.src_word_emb)
inputs:  tensor([[10,  6,  2,  4,  9, 14,  6,  2,  5,  0],
        [12,  6,  3,  8, 13,  2,  0,  1,  1,  1],
        [13,  8, 12,  7,  2,  4,  0,  1,  1,  1]])
input_shape: [3, 10]
Embedding(15, 512, padding_idx=1)
emb = self.src_word_emb(src_seq)

我试图让一个转换器模型工作,但由于某种原因,编码器嵌入只接受〈embedding_dim_decoder的输入,这没有意义,对吗?

oipij1gg

oipij1gg1#

找到错误源了!在转换器模型中,编码器和解码器可以设置为共享相同的嵌入权重。然而,我有一个翻译任务,一个嵌入用于解码器,一个嵌入用于编码器。在代码中,它通过以下方式初始化权重:

if emb_src_trg_weight_sharing:
            self.encoder.src_word_emb.weight = self.decoder.trg_word_emb.weight

emb_src_trg_weight_sharing设置为false解决了问题!

相关问题