PyTorch RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead

Posted by 4xrmg8kj on 2022-12-18 in Other

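Printing a PyTorch model summary raises a runtime error. The model is below. I am using torchsummary.summary to get the model summary, and the error occurs when the summary is printed.
Full error message: RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)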

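import torch
import torch.nn as nn

# device is used by initCellState and by .to(device) when the model is built
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SentimentalModelV3(nn.Module):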
  def __init__(self, output_size, vocab_size, embedding_dim = 128, hidden_dim = 64, batch_size = 32, padded_seq_len = 10, n_layers = 1, drop_prob = 0.3, bidirectional = False):
    super().__init__()
    self.batch_size = batch_size
    self.output_size = output_size
    self.n_layers = n_layers
    self.hidden_dim = hidden_dim
    self.padded_seq_len = padded_seq_len

    self.embedding = nn.Embedding(num_embeddings = vocab_size, embedding_dim = embedding_dim)
    self.lstm = nn.LSTM(input_size = embedding_dim, hidden_size = hidden_dim, num_layers = n_layers, dropout = drop_prob, batch_first = True, bidirectional = bidirectional)

    self.dropout = nn.Dropout(drop_prob)

    #Linear and activation layer
    self.fc1=nn.Linear(self.hidden_dim * self.padded_seq_len, 64)
    self.fc2=nn.Linear(64, 16)
    self.fc3=nn.Linear(16,output_size)
    self.Relu = nn.ReLU()

  def forward(self, one_hot, hn, cn):
    embed = self.embedding(one_hot)
    lstm_out, hidden = self.lstm(embed)

    #stack up the lstm output
    lstm_out = lstm_out.reshape(shape = (lstm_out.shape[0], lstm_out.shape[1] * lstm_out.shape[2]))
    # dropout and fully connected layers
    out = self.dropout(lstm_out)
    out = self.Relu(out)
    out = self.Relu(self.fc1(out))
    out = self.Relu(self.fc2(out))
    out = self.fc3(out)

    return out

  def initCellState(self):
    h =  torch.zeros(self.n_layers , self.batch_size , self.hidden_dim).to(device)
    c =  torch.zeros(self.n_layers , self.batch_size , self.hidden_dim).to(device)
    return h, c
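The forward pass above flattens the LSTM output from (batch, seq_len, hidden_dim) to (batch, seq_len * hidden_dim), which is why fc1 takes hidden_dim * padded_seq_len inputs. A minimal sketch tracing those shapes with the same layer types in isolation (not the question's model): vocab_size = 1000 is a hypothetical stand-in for UNIQUE_WORD_COUNT, the batch of 4 is arbitrary, LSTM dropout is omitted, and, as in the question's forward, no (h_0, c_0) is passed, so nn.LSTM starts from zero states.

```python
import torch
import torch.nn as nn

# Sizes follow the question's settings; vocab_size is a hypothetical stand-in for UNIQUE_WORD_COUNT
vocab_size, embedding_dim, hidden_dim, seq_len, n_layers = 1000, 128, 64, 10, 5

embedding = nn.Embedding(vocab_size, embedding_dim)
lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, batch_first=True)
fc1 = nn.Linear(hidden_dim * seq_len, 64)

token_ids = torch.randint(0, vocab_size, (4, seq_len))  # int64 token indices, batch of 4
embed = embedding(token_ids)                             # (4, 10, 128)
lstm_out, (h_n, c_n) = lstm(embed)                       # (4, 10, 64); h_n and c_n: (5, 4, 64)
flat = lstm_out.reshape(lstm_out.shape[0], -1)           # (4, 640) = (batch, seq_len * hidden_dim)
hidden = fc1(flat)                                       # (4, 64)
print(embed.shape, lstm_out.shape, flat.shape, hidden.shape)
```

Note that h_n and c_n keep the (num_layers, batch, hidden_dim) layout even with batch_first=True; only the input and output tensors are batch-first.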

When I print the summary with the following code, I get the error above:

from torchsummary import summary

n_layers = 5
batch_size = 32
hidden_dim = 64

model_v3 = SentimentalModelV3(output_size = 3, 
                              vocab_size = UNIQUE_WORD_COUNT, 
                              embedding_dim = 128, 
                              hidden_dim = hidden_dim, 
                              n_layers = n_layers, 
                              drop_prob = 0.3, 
                              padded_seq_len = 10, 
                              batch_size = batch_size,
                              bidirectional = False).to(device)

hn, cn = model_v3.initCellState()
summary(model_v3, [(1, 10), (n_layers, batch_size, hidden_dim), (n_layers, batch_size, hidden_dim)])
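The error is raised by nn.Embedding itself, not by the summary code: the embedding lookup only accepts integer (Long or Int) index tensors. A minimal sketch reproducing it outside the model, with an arbitrary table size chosen for illustration:

```python
import torch
import torch.nn as nn

# Small embedding table: 100 token ids, 8-dimensional vectors (arbitrary sizes)
embedding = nn.Embedding(num_embeddings=100, embedding_dim=8)

float_ids = torch.zeros(2, 10)  # float32, the default dtype for torch.zeros / torch.rand
long_ids = float_ids.long()     # same values cast to int64 (torch.long)

try:
    embedding(float_ids)        # float indices are rejected
    float_failed = False
except RuntimeError as err:
    float_failed = True
    print(err)

vectors = embedding(long_ids)   # integer indices work: one 8-d vector per id
print(vectors.shape)            # torch.Size([2, 10, 8])
```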

Answer 1, by ndh0cuux

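By default, summary runs model_v3 on FloatTensor inputs of the given shapes, but your code expects one_hot to be a LongTensor, because nn.Embedding only accepts integer indices. The fix is to give summary inputs of the correct tensor type. Fortunately, summary accepts a sequence of tensors for this: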

summary(model_v3, [
    torch.zeros(1, 10).long(),  # ensure one_hot is a LongTensor
    torch.zeros(n_layers, batch_size, hidden_dim),
    torch.zeros(n_layers, batch_size, hidden_dim),
])
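Before calling summary, it can help to confirm that the model accepts exactly these inputs by calling it directly. The sketch below uses a hypothetical, cut-down TinySentimentModel with small, arbitrary layer sizes instead of SentimentalModelV3, so it runs without UNIQUE_WORD_COUNT or a GPU; it only keeps the same forward(one_hot, hn, cn) signature and the embedding, LSTM, flatten, linear path.

```python
import torch
import torch.nn as nn


class TinySentimentModel(nn.Module):
    """Hypothetical, cut-down stand-in for SentimentalModelV3 with the same forward signature."""

    def __init__(self, vocab_size=1000, embedding_dim=16, lstm_hidden=8, seq_len=10, output_size=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, lstm_hidden, batch_first=True)
        self.fc = nn.Linear(lstm_hidden * seq_len, output_size)

    def forward(self, one_hot, hn, cn):
        # hn and cn are accepted but unused, mirroring the question's forward
        embed = self.embedding(one_hot)
        lstm_out, _ = self.lstm(embed)
        return self.fc(lstm_out.reshape(lstm_out.shape[0], -1))


n_layers, batch_size, hidden_dim = 5, 32, 64
model = TinySentimentModel()
inputs = [
    torch.zeros(1, 10).long(),                      # int64 token indices, as in the answer
    torch.zeros(n_layers, batch_size, hidden_dim),  # float placeholder for hn
    torch.zeros(n_layers, batch_size, hidden_dim),  # float placeholder for cn
]

with torch.no_grad():
    out = model(*inputs)
print(out.shape)  # torch.Size([1, 3])
```

If a direct call like this fails with the same RuntimeError, the first input is still floating point; if it succeeds but summary fails, the problem is in how the inputs are handed to summary.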
