当我运行我的代码时,train循环永远不会结束。当它打印当前进度时,显示已经处理了超过300个数据点——而300是我告诉程序的数据量上限(csv文件中实际上有42000条)。为什么它没有在300个样本后自动停止?
谢谢你们
我的代码:(为了可读性,我省略了Net和测试循环)
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn
from torchvision.transforms import ToTensor
#import os
import pandas as pd
#import numpy as np
import random
import time
#Hyperparameters
batch_size = 3
learning_rate = 8e-3
#DataSet
class CustomImageDataset(Dataset):
    """Map-style Dataset over an MNIST-style CSV (columns: label, pixel0..pixel783).

    Each ``__getitem__`` call returns ``batches`` consecutive samples starting
    at ``idx`` — the class does its own mini-batching, so it yields a tensor of
    shape (batches, 1, 28, 28) plus a (batches,) label tensor.
    """

    def __init__(self, img_dir, batches):
        # NOTE: despite the name, img_dir is the path to the CSV file itself.
        self.img_dir = img_dir
        self.batches = batches
        self.data = pd.read_csv(self.img_dir)

    def __len__(self):
        # Cap at 300 samples (the author's intended limit), but never report
        # more rows than the CSV actually contains.
        return min(300, len(self.data))

    def __getitem__(self, idx):
        # This IndexError is what terminates plain `for sample in dataset`
        # iteration: Python's sequence-iteration protocol (and DataLoader
        # samplers) stop on IndexError — __len__ is never consulted.  Without
        # this check the loop runs past 300 forever (the reported bug).
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range")
        images = []
        labels = torch.zeros(self.batches, dtype=torch.long)
        for x in range(self.batches):
            label = int(self.data.at[idx + x, 'label'])
            image = torch.zeros(1, 28, 28)
            for i in range(784):
                # Was `y = int((i - x) / 28)`: it used the batch offset x
                # instead of the column remainder, scrambling the pixels, and
                # the row/col indices were swapped on assignment.
                row = i // 28
                col = i % 28
                image[0, row, col] = self.data.at[idx + x, 'pixel' + str(i)]
            images.append(image)
            labels[x] = label
        return torch.stack(images), labels
# Creating instances — the dataset must exist before a DataLoader can wrap it.
Data = CustomImageDataset("01.Actual/02.NeuralNetwork/01.OffTopicTests/NumberRecognizer/train.csv", batch_size)
model = NeuralNetwork()

# DataLoader: the original passed the CLASS `CustomImageDataset` instead of an
# instance, which cannot be iterated.  Wrap the instance `Data` instead.
train_loader = DataLoader(Data, batch_size, shuffle=False, drop_last=True)

# Hyperparameters
epochs = 1
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
#Creating training loop
def train_loop(dataloader, model, loss_fn, optimizer, batch_size):
    """Run one epoch of training over ``dataloader``.

    ``dataloader`` may be a DataLoader or any iterable of (X, y) batches.
    ``batch_size`` is kept for interface compatibility only — the actual
    batch size is whatever the iterable yields.
    """
    # Sample count for progress reporting.  A DataLoader exposes its dataset
    # via .dataset; fall back to the iterable itself.  This replaces the
    # original `Data.__len__()`, which both called a dunder directly and
    # silently depended on the global `Data` instead of the parameter.
    source = getattr(dataloader, "dataset", dataloader)
    size = len(source)
    for batch, (X, y) in enumerate(dataloader):
        # Forward pass: prediction and loss.
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss_val, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss_val:>7f} [{current:>5d}/{size:>5d}]")
# Executing part of the script
for t in range(epochs):
    print(f"Epoch {t+1}\n ----------------")
    # NOTE(review): `Data` (the Dataset itself) is passed here instead of
    # `train_loader`.  Iterating a bare Dataset calls __getitem__ with
    # 0, 1, 2, ... and only stops when __getitem__ raises IndexError —
    # __len__ is never consulted.  That is why training ran past 300
    # samples.  Pass the DataLoader here once it wraps the instance.
    train_loop(Data,model,loss_fn,optimizer,batch_size)
    test_loop(Data,model,loss_fn)
1条答案
按热度按时间aiazj4mn1#
问题在于:直接在 Dataset 实例上迭代时,Python 只会不断调用 __getitem__(0), __getitem__(1), …,根本不会使用你定义的 __len__ 来决定何时停止。如果你想把样本数量限制在 300,应该在 __getitem__ 内部检查索引是否越界,并抛出 IndexError(Python 的序列迭代协议依赖它来终止循环;抛出 StopIteration 也能停止 for 循环)。示例:
结果: