在Pytorch中使用检查点恢复计算的问题

rseugnpd  于 2023-04-12  发布在  其他
关注(0)|答案(1)|浏览(147)

这是我第一次使用检查点,我遇到了一个我无法理解的问题。我使用检查点的原因是因为我在使用GPU时有时间限制,所以我需要我的代码工作一段时间,保存检查点,当我有机会再次启动它时从它离开的地方继续。总结一下我的代码加载模型,加载imagenet验证数据集,创建一个数据加载器,并启动一个循环,将批量图像传递给我编写的函数。请记住,我没有做任何训练,只是对图像进行操作,将它们传递给模型以测试。
正如你所看到的(理论上),当代码再次启动时,它应该进入保存检查点的文件夹,并检索最后一个已创建的检查点。它成功地做到了这一点,到目前为止没有问题。所以让我们看看代码,然后我会说什么不工作。

from torch.utils.data.dataloader import default_collate

def create_data_loader(dataset, batch_size, device):
    val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, \
                            collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
    return val_loader

 def main(args):
    # Loading pre-trained model
    model = ...
    config = ...
    T = ...

    # Obtaining ImageNet validation dataset from folder
    val = ImageFolder(
        root='path_to_folder_here',
        transform=T
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    val_loader = create_data_loader(val, args.batchsize, device)
    print("Number of batches:", len(val_loader))
    # Initializing progress variables
    start_batch = 0
    
    # Creating variables that I want to save in checkpoints here
    # ...
    
    # Checking if checkpoint directory exists
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
        print("Created checkpoints directory", args.checkpoint_dir)

    # If the checkpoint directory is not empty
    if os.listdir(args.checkpoint_dir):
        # Getting the name of the latest checkpoint file added to the directory
        files_in_dir = os.listdir(args.checkpoint_dir)
        sorted_files = sorted(files_in_dir, key=lambda x: os.path.getctime(os.path.join(args.checkpoint_dir, x)), reverse=True)
        latest_checkpoint = sorted_files[0]
        print("Found latest checkpoint created:", latest_checkpoint)
        checkpoint = load_checkpoint(os.path.join(args.checkpoint_dir, latest_checkpoint))
        start_batch = checkpoint['batch']
        # Retrieving other variables from checkpoint here

    total_batches = len(val_loader)
    for j, (images, _) in enumerate(val_loader, start=start_batch):
        print(f"Processing batch {j}/{total_batches} - Bacth size: {len(images)}, Number of images: {images.shape[0]}")
        batches_left = total_batches - j
        print(f"{batches_left} remaining")
        result = myFunction(...)
        #Doing things with result here to use after out of the for loop
        #...
        
        print("Batch:" + str(j) + "finished")
        j+=1

        # Checkpoint every 'args.freq' batches
        if j % args.freq == 0:
            save_checkpoint({
                'batch': j,
                # Saving other variables ...
            }, f"{args.checkpoint_dir}/checkpoint_{j}.pth")

好吧,Imagenet验证数据集有50000张图片,我的batchsize是256,总batches是196。我的代码第一次设法做了168个batches,当它第二次被调用时,它似乎从那里开始,因为我被打印出来:
找到创建的最新检查点:checkpoint_168.pth检查点从检查点加载/checkpoint_168.pth start处理批次168/196 - Bacth大小:256,图像数量:256 28剩余批次:168完成…
所以我让它继续...问题是?它没有停止。它还在继续,目前最后打印的东西是:
处理批次195/196 - Bacth规格:256,图像数量:256 1剩余批次:195成品处理批次196/196 - Bacth规格:256,图像数量:256 0剩余批次:196成品处理批次197/196 - Bacth规格:256,图像数量:256剩余批次:197成品处理批次198/196 - Bacth大小:256,图像数量:256 - 2剩余
如果我让它继续,它会继续。你能帮助我理解发生了什么吗?我的代码不应该在数据加载器中处理第一个图像,然后在恢复时继续处理剩余的图像吗?很明显,它正在做更多的事情,或者可能正在重做应该已经完成的事情,我不明白我看到了什么。
非常感谢大家!

bttbmeg0

bttbmeg01#

我认为你的问题是你如何处理数据加载器。
val_loader = create_data_loader(val, args.batchsize, device)行创建了一个新的数据加载器,它还没有生成任何批处理。如果在for循环中使用它,它将生成196个批处理,每个迭代一个,从batch 0开始,到batch 195结束。
稍后,当你执行for j, (images, _) in enumerate(val_loader, start=start_batch):时,这个循环无论如何都会运行196次。这里enumeratestart参数并没有跳过val_loader的第一个start_batch es,它只是让jstart_batch开始,而不是从0开始,并在start_batch + 195而不是195处结束。val_loader仍在泵送196个批次,而不是196 - start_batch
要跳过start_batch之前的所有内容,可以执行以下操作:

for j, (images, _) in enumerate(val_loader):
    if j < start_batch:
      continue
    ...

循环仍然会运行196次,但它只会处理从批次start_batch开始的数据。

相关问题