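I am trying to train a ViT model to classify some images. During training the script gets stuck, and I can't figure out where the error is. The images have only two target classes, 0 and 1 (false and true). The batch size is 32 and the number of epochs is only 3.
Below is the training part of the script: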
import torch
import torch.utils.data as data
from torch.autograd import Variable
import numpy as np

train_loader = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# Train the model
for epoch in range(EPOCHS):
    for step, (x, y) in enumerate(train_loader):
        # Change input array into list with each batch being one element
        x = np.split(np.squeeze(np.array(x)), BATCH_SIZE)
        # Remove unnecessary dimension
        for index, array in enumerate(x):
            x[index] = np.squeeze(array)
        # Apply feature extractor, stack back into 1 array and then convert to tensor
        x = torch.tensor(np.stack(feature_extractor(x)['pixel_values'], axis=0))
        # Send to GPU if available
        x = x.to(device)
        y = y.to(device)
        b_x = Variable(x)  # batch x (image)
        b_y = Variable(y)  # batch y (target)
        # Feed through model
        output = model(b_x, None)
        loss = output[0]
        # Calculate loss
        if loss is None:
            loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            # Get the next batch for testing purposes
            test = next(iter(test_loader))
            test_x = test[0]
            # Reshape and get feature matrices as needed
            test_x = np.split(np.squeeze(np.array(test_x)), BATCH_SIZE)
            for index, array in enumerate(test_x):
                test_x[index] = np.squeeze(array)
            test_x = torch.tensor(np.stack(feature_extractor(test_x)['pixel_values'], axis=0))
            # Send to appropriate computing device
            test_x = test_x.to(device)
            test_y = test[1].to(device)
            # Get output (+ respective class) and compare to target
            test_output, loss = model(test_x, test_y)
            test_output = test_output.argmax(1)
            # Calculate accuracy
            accuracy = (test_output == test_y).sum().item() / BATCH_SIZE
            print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
The error message is: ValueError: array split does not result in an equal division. It points to the line x = np.split(np.squeeze(np.array(x)), BATCH_SIZE).
1 Answer
iklwldmw #1
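I ran into the same error recently. Check the length of train_ds, and change BATCH_SIZE so that the remainder of len(train_ds) divided by BATCH_SIZE is zero. Otherwise the last batch holds fewer than BATCH_SIZE samples, and np.split cannot divide it into BATCH_SIZE equal parts.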
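To illustrate, here is a minimal NumPy-only sketch of the failure and one possible workaround. The dataset size of 100, the 3x4x4 dummy "images", and the helper name split_batch are assumptions made up for this example, not taken from the question.

```python
import numpy as np

BATCH_SIZE = 32


def split_batch(batch, sections):
    """Split a batch array into `sections` equal parts along axis 0."""
    return np.split(batch, sections)


# A full batch of 32 dummy images (hypothetical shape 3x4x4) splits cleanly.
full_batch = np.zeros((BATCH_SIZE, 3, 4, 4))
full_parts = split_batch(full_batch, BATCH_SIZE)

# Assume 100 training samples: the last batch then holds 100 % 32 = 4 samples.
last_batch = np.zeros((100 % BATCH_SIZE, 3, 4, 4))
try:
    split_batch(last_batch, BATCH_SIZE)
    raised = False
    message = ""
except ValueError as err:
    # Same error as in the question: 4 samples cannot form 32 equal parts.
    raised = True
    message = str(err)

# Workaround: split by the actual batch length instead of BATCH_SIZE.
parts = split_batch(last_batch, len(last_batch))
```

Splitting by len(x) instead of BATCH_SIZE works for a batch of any size. Alternatively, passing drop_last=True to DataLoader discards the incomplete final batch, so BATCH_SIZE does not have to divide len(train_ds) exactly.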