使用 torchsummary 打印 PyTorch 模型摘要时出现运行时错误。下面是我的模型;我使用 torchsummary.summary 来获取模型摘要,但在打印摘要时发生了运行时错误。
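完整错误消息:RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)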
class SentimentalModelV3(nn.Module):
    """Sequence classifier: embedding -> LSTM -> three fully connected layers.

    The LSTM outputs of every time step are concatenated and fed to the
    dense head, so inputs must always be padded/truncated to exactly
    ``padded_seq_len`` tokens.

    Args:
        output_size: Number of output units produced by the last linear layer.
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector.
        hidden_dim: Hidden size of the LSTM (per direction).
        batch_size: Batch size used by ``initCellState`` to size the states.
        padded_seq_len: Fixed sequence length the dense head is sized for.
        n_layers: Number of stacked LSTM layers.
        drop_prob: Dropout applied between stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
    """

    def __init__(self, output_size, vocab_size, embedding_dim=128, hidden_dim=64,
                 batch_size=32, padded_seq_len=10, n_layers=1, drop_prob=0.3,
                 bidirectional=False):
        super().__init__()
        self.batch_size = batch_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.padded_seq_len = padded_seq_len
        # A bidirectional LSTM emits 2 * hidden_dim features per time step and
        # keeps one hidden/cell state per direction and layer.
        self.num_directions = 2 if bidirectional else 1

        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            # Inter-layer dropout is a no-op for a single layer and only
            # triggers a UserWarning there, so it is disabled in that case.
            dropout=drop_prob if n_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=bidirectional,
        )
        # NOTE: this rate is fixed at 0.3 and is independent of drop_prob.
        self.dropout = nn.Dropout(0.3)

        # Linear and activation layers. fc1 consumes the flattened LSTM output
        # of all time steps (and both directions when bidirectional).
        self.fc1 = nn.Linear(hidden_dim * self.num_directions * padded_seq_len, 64)
        self.fc2 = nn.Linear(64, 16)
        self.fc3 = nn.Linear(16, output_size)
        self.Relu = nn.ReLU()

    def forward(self, one_hot, hn=None, cn=None):
        """Compute the output scores for a batch of token-index sequences.

        Args:
            one_hot: Integer (Long/Int) tensor of token indices with shape
                ``(batch, padded_seq_len)``; a float tensor makes
                ``nn.Embedding`` raise a RuntimeError.
            hn: Accepted for interface compatibility; not used. The LSTM
                always starts from its default zero hidden state.
            cn: Accepted for interface compatibility; not used.

        Returns:
            Tensor of shape ``(batch, output_size)`` with raw (unnormalised)
            scores.
        """
        embed = self.embedding(one_hot)
        lstm_out, _ = self.lstm(embed)
        # Stack up the LSTM output: (batch, seq, features) -> (batch, seq * features).
        lstm_out = lstm_out.reshape(lstm_out.shape[0], -1)
        # Dropout and fully connected layers.
        out = self.dropout(lstm_out)
        out = self.Relu(out)
        out = self.Relu(self.fc1(out))
        out = self.Relu(self.fc2(out))
        return self.fc3(out)

    def initCellState(self):
        """Return zero-initialised ``(h, c)`` LSTM states.

        Each tensor has shape
        ``(n_layers * num_directions, batch_size, hidden_dim)`` and is created
        on the same device as the model's parameters.
        """
        state_shape = (self.n_layers * self.num_directions,
                       self.batch_size,
                       self.hidden_dim)
        # Follow the model's own device instead of relying on a global name.
        model_device = next(self.parameters()).device
        h = torch.zeros(state_shape, device=model_device)
        c = torch.zeros(state_shape, device=model_device)
        return h, c
但是,在打印模型摘要时,我得到了上面的错误。调用代码如下所示:
from torchsummary import summary
# Build the model and print a per-layer summary with torchsummary.
n_layers = 5
batch_size = 32
hidden_dim = 64
padded_seq_len = 10
model_v3 = SentimentalModelV3(
    output_size=3,
    vocab_size=UNIQUE_WORD_COUNT,
    embedding_dim=128,
    hidden_dim=hidden_dim,
    n_layers=n_layers,
    drop_prob=0.3,
    padded_seq_len=padded_seq_len,
    batch_size=batch_size,
    bidirectional=False,
).to(device)
hn, cn = model_v3.initCellState()
# summary() builds random FloatTensor inputs by default, but nn.Embedding only
# accepts Long/Int indices, so the dtype of each input is given explicitly.
# Input sizes exclude the batch dimension (summary() prepends its own), hence
# the token input is (padded_seq_len,) rather than (1, padded_seq_len).
# NOTE(review): `dtypes` is only accepted by torchsummary versions whose
# summary() has that parameter; torchsummary's output hook may also not
# handle nn.LSTM's nested tuple output -- confirm against the installed version.
summary(
    model_v3,
    [
        (padded_seq_len,),
        (n_layers, batch_size, hidden_dim),
        (n_layers, batch_size, hidden_dim),
    ],
    dtypes=[torch.long, torch.float, torch.float],
)
1条答案
按热度按时间ndh0cuux1#
默认情况下,`summary` 会用给定形状的 FloatTensor 作为输入来运行 `model_v3`。但是您的代码期望 `one_hot` 是 LongTensor(`nn.Embedding` 的索引必须是整数类型)。修复方法是向 `summary` 指定正确的 Tensor 类型。值得庆幸的是,`summary` 为此接受一个 Tensor 类型的序列(通过 `dtypes` 参数):