pytorch 在Python中计算F1分数和其他指标时出错

fhg3lkii  于 2023-10-20  发布在  Python
关注(0)|答案(1)|浏览(171)

我做了一个Python项目,使用PyTorch进行深度学习。我在计算F1分数时收到以下错误消息:

'Classification metrics can't handle a mix of multiclass-multioutput
and multilabel-indicator targets'

我的代码:

model = nn.Sequential(
    nn.Linear(135, 50),
    nn.ReLU(),
    nn.Linear(50, 50),
    nn.ReLU(),
    nn.Linear(50, max_length),
    nn.Sigmoid()
)

epochs = 1000
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
model.train()
for epoch in range(epochs):
  for X_train, y_train in Dataloader:
    y_pred = model(X_train)
    # Convert the target tensor to torch.float32 data type
    y_train = y_train.float()
    loss = loss_fn(y_pred, y_train)
    optimizer.zero_grad()
    loss.backward()
    print(loss.item())
    optimizer.step()

model.eval()
y_pred = model(X_test)
y_pred = (y_pred > 0.5).float()  # Threshold the probabilities to get binary predictions
acc = (y_pred == y_test).float().mean()
print("Model accuracy: %.2f%%" % (acc*100))

任何帮助都是感激不尽的。谢谢.

7nbnzgx9

7nbnzgx91#

之所以会出现这个错误,是因为您将这些指标全部计算在一起。它们应按类计算。

相关问题