PyTorch training loop never stops

eqoofvh9 · asked on 2023-04-21

When I run my code, the train loop never ends. When it prints out to show where it is, it has already gone past the 300 data points I told the program there are, and it keeps going through the 42000 rows that are actually in the CSV file. Why doesn't it stop automatically after 300 samples?
Thanks, everyone.
My code (I've omitted the Net and the test loop for readability):

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import pandas as pd

#Hyperparameters
batch_size = 3
learning_rate = 8e-3


#DataSet
class CustomImageDataset(Dataset):
    def __init__(self, img_dir, batches):
        self.img_dir = img_dir
        self.batches = batches
        self.data = pd.read_csv(self.img_dir)

    def __len__(self):
        #return len(self.data)
        return 300

    def __getitem__(self, idx):
        images = []
        labels = torch.zeros(self.batches, dtype=int)
        for x in range(self.batches):
            label = self.data.at[(idx + x), 'label']
            label = label.item()
            image = torch.zeros(1, 28, 28)

            for i in range(784):
                z = int(i % 28)
                y = int((i - x) / 28)
                column = 'pixel' + str(i)
                image[0, z, y] = self.data.at[(idx + x), column]

            images.append(image)
            labels[x] = label

        return torch.stack(images), labels
        
#DataLoader
train_loader = DataLoader(CustomImageDataset, batch_size, shuffle=False, drop_last=True)

#Creating Instances
Data = CustomImageDataset("01.Actual/02.NeuralNetwork/01.OffTopicTests/NumberRecognizer/train.csv", batch_size)
model = NeuralNetwork()

#Hyperparameters

epochs = 1
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

#Creating training loop
def train_loop(dataloader, model, loss_fn, optimizer, batch_size):
    size = Data.__len__()

    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

#Executing Part of the script
for t in range(epochs):
    print(f"Epoch {t+1}\n ----------------")
    train_loop(Data, model, loss_fn, optimizer, batch_size)
    test_loop(Data, model, loss_fn)

aiazj4mn #1

The problem is that once you define `__getitem__`, iterating over the class does not use `__len__` at all. If you want to limit the number of samples, you have to do it inside `__getitem__` itself, by raising StopIteration.
Example:

class Example(object):
    def __init__(self, a, b, c):
        self.data = list(range(a))  # a actual items
        self.len = b                # what __len__ reports
        self.stop = c               # where iteration is actually cut off

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        if idx >= self.stop:
            raise StopIteration  # __len__ is ignored; this is what ends the loop
        return self.data[idx]

Result:

In [1]: for x in Example(8, 4, 4):
   ...:     print(x, end=', ')
0, 1, 2, 3, 

In [2]: for x in Example(8, 4, 8):
   ...:     print(x, end=', ')
0, 1, 2, 3, 4, 5, 6, 7, 

In [3]: for x in Example(4, 4, 8):
   ...:     print(x, end=', ')
0, 1, 2, 3,
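
The fix above raises StopIteration from inside `__getitem__`. The complementary fix, and the one PyTorch expects, is to iterate over a `DataLoader` built from a dataset instance: the loader's default sampler draws indices from `range(len(dataset))`, so `__len__` is honored there even though a plain `for x in dataset` loop ignores it. Below is a minimal sketch of that idea; the `CappedDataset` class, its `n_total`/`cap` parameters, and the element values are illustrative stand-ins, not code from the question.

import torch
from torch.utils.data import Dataset, DataLoader

class CappedDataset(Dataset):
    # A dataset holding more rows than it advertises via __len__.
    def __init__(self, n_total, cap):
        self.values = torch.arange(n_total)  # stand-in for the 42000 CSV rows
        self.cap = cap                       # stand-in for the 300-sample limit

    def __len__(self):
        # DataLoader's default sampler only draws indices 0 .. cap - 1
        return self.cap

    def __getitem__(self, idx):
        # Returns ONE sample; DataLoader takes care of batching
        return self.values[idx]

loader = DataLoader(CappedDataset(42000, 300), batch_size=3, shuffle=False)
print(sum(batch.numel() for batch in loader))  # 300, not 42000

Two details from the question are worth noting here. First, `train_loader = DataLoader(CustomImageDataset, ...)` passes the class rather than an instance and is never actually used; `train_loop` receives `Data` directly, which triggers exactly the `__getitem__`-driven fallback iteration described above. Second, the conventional end-of-sequence signal for that fallback protocol is `IndexError`, which Python's sequence iterator catches; raising `StopIteration`, as in the example, also ends the `for` loop.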
