这是我第一次使用检查点,我遇到了一个我无法理解的问题。我使用检查点的原因是因为我在使用GPU时有时间限制,所以我需要我的代码工作一段时间,保存检查点,当我有机会再次启动它时从它离开的地方继续。总结一下我的代码加载模型,加载imagenet验证数据集,创建一个数据加载器,并启动一个循环,将批量图像传递给我编写的函数。请记住,我没有做任何训练,只是对图像进行操作,将它们传递给模型以测试。
正如你所看到的(理论上),当代码再次启动时,它应该进入保存检查点的文件夹,并检索最后一个已创建的检查点。它成功地做到了这一点,到目前为止没有问题。所以让我们看看代码,然后我会说什么不工作。
from torch.utils.data.dataloader import default_collate
def create_data_loader(dataset, batch_size, device):
val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, \
collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
return val_loader
def main(args):
# Loading pre-trained model
model = ...
config = ...
T = ...
# Obtaining ImageNet validation dataset from folder
val = ImageFolder(
root='path_to_folder_here',
transform=T
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
val_loader = create_data_loader(val, args.batchsize, device)
print("Number of batches:", len(val_loader))
# Initializing progress variables
start_batch = 0
# Creating variables that I want to save in checkpoints here
# ...
# Checking if checkpoint directory exists
if not os.path.exists(args.checkpoint_dir):
os.makedirs(args.checkpoint_dir)
print("Created checkpoints directory", args.checkpoint_dir)
# If the checkpoint directory is not empty
if os.listdir(args.checkpoint_dir):
# Getting the name of the latest checkpoint file added to the directory
files_in_dir = os.listdir(args.checkpoint_dir)
sorted_files = sorted(files_in_dir, key=lambda x: os.path.getctime(os.path.join(args.checkpoint_dir, x)), reverse=True)
latest_checkpoint = sorted_files[0]
print("Found latest checkpoint created:", latest_checkpoint)
checkpoint = load_checkpoint(os.path.join(args.checkpoint_dir, latest_checkpoint))
start_batch = checkpoint['batch']
# Retrieving other variables from checkpoint here
total_batches = len(val_loader)
for j, (images, _) in enumerate(val_loader, start=start_batch):
print(f"Processing batch {j}/{total_batches} - Bacth size: {len(images)}, Number of images: {images.shape[0]}")
batches_left = total_batches - j
print(f"{batches_left} remaining")
result = myFunction(...)
#Doing things with result here to use after out of the for loop
#...
print("Batch:" + str(j) + "finished")
j+=1
# Checkpoint every 'args.freq' batches
if j % args.freq == 0:
save_checkpoint({
'batch': j,
# Saving other variables ...
}, f"{args.checkpoint_dir}/checkpoint_{j}.pth")
好吧,Imagenet验证数据集有50000张图片,我的batchsize是256,总batches是196。我的代码第一次设法做了168个batches,当它第二次被调用时,它似乎从那里开始,因为我被打印出来:
找到创建的最新检查点:checkpoint_168.pth检查点从检查点加载/checkpoint_168.pth start处理批次168/196 - Bacth大小:256,图像数量:256 28剩余批次:168完成…
所以我让它继续...问题是?它没有停止。它还在继续,目前最后打印的东西是:
处理批次195/196 - Bacth规格:256,图像数量:256 1剩余批次:195成品处理批次196/196 - Bacth规格:256,图像数量:256 0剩余批次:196成品处理批次197/196 - Bacth规格:256,图像数量:256剩余批次:197成品处理批次198/196 - Bacth大小:256,图像数量:256 - 2剩余
如果我让它继续,它会继续。你能帮助我理解发生了什么吗?我的代码不应该在数据加载器中处理第一个图像,然后在恢复时继续处理剩余的图像吗?很明显,它正在做更多的事情,或者可能正在重做应该已经完成的事情,我不明白我看到了什么。
非常感谢大家!
1条答案
按热度按时间bttbmeg01#
我认为你的问题是你如何处理数据加载器。
val_loader = create_data_loader(val, args.batchsize, device)
行创建了一个新的数据加载器,它还没有生成任何批处理。如果在for循环中使用它,它将生成196个批处理,每个迭代一个,从batch 0开始,到batch 195结束。稍后,当你执行
for j, (images, _) in enumerate(val_loader, start=start_batch):
时,这个循环无论如何都会运行196次。这里enumerate
的start
参数并没有跳过val_loader
的第一个start_batch
es,它只是让j
从start_batch
开始,而不是从0
开始,并在start_batch + 195
而不是195处结束。val_loader
仍在泵送196个批次,而不是196 - start_batch
。要跳过
start_batch
之前的所有内容,可以执行以下操作:循环仍然会运行196次,但它只会处理从批次
start_batch
开始的数据。