python 当你的模型不能过拟合一小批数据时意味着什么?

sqxo8psd  于 2023-01-29  发布在  Python
关注(0)|答案(1)|浏览(75)

我尝试训练RNN模型将句子分为4类,但似乎不起作用。我尝试过拟合4个示例(蓝线),但即使只有8个示例(红线)也不起作用,更不用说整个数据集了。

我尝试了hidden_sizeembedding_size的不同学习速率和大小,但似乎没有帮助,我错过了什么?我知道,如果模型不能过拟合小批量,这意味着容量应该增加,但在这种情况下,增加容量没有效果。架构如下:

class RNN(nn.Module):
    def __init__(self, embedding_size=256, hidden_size=128, num_classes=4):
        super().__init__()
        self.embedding = nn.Embedding(len(vocab), embedding_size, 0)
        self.rnn = nn.RNN(embedding_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        #x=[batch_size, sequence_length]
        x = self.embedding(x) #x=[batch_size, sequence_length, embedding_size]
        _, h_n = self.rnn(x)  #h_n=[1, batch_size, hidden_size]
        h_n = h_n.squeeze(0)
        out = self.fc(h_n)  #out=[batch_size, num_classes]
        return out

输入数据是标记化的句子,用0填充到批处理中的最长句子,因此作为示例,一个样本将是:[2784,9544,1321,120,0,0]。数据来自 Torch 文本数据集中的AG_NEWS数据集。
培训代码:

model = RNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
model.train()

for epoch in range(NUM_EPOCHS):
    epoch_losses = []
    correct_predictions = []
    for batch_idx, (labels, texts) in enumerate(train_loader):
        scores = model(texts)
        loss = criterion(scores, labels)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
        epoch_losses.append(loss.item())
        correct = (scores.max(1).indices==labels).sum()
        correct_predictions.append(correct)
        
    epoch_avg_loss = sum(epoch_losses)/len(epoch_losses)
    epoch_avg_accuracy = float(sum(correct_predictions))/float(len(labels))
fslejnso

fslejnso1#

该问题是由于梯度消失造成的。

相关问题