在PyTorch中使用分布式数据并行(DDP)时,在训练期间检查点的正确方式是什么?

fjnneemd  于 2022-12-18  发布在  其他
关注(0)|答案(1)|浏览(142)

我想(正确和官方的-错误免费的方式)做:
1.从检查点恢复以继续在多个GPU上训练
1.在使用多个gpu进行培训期间正确保存检查点
对此,我的猜测如下:
1.第一步,我们让所有进程从文件中加载检查点,然后为每个进程调用DDP(mdl),我假设检查点保存了一个ddp_mdl.module.state_dict()
1.要执行2,只需检查谁的排名= 0,并让那个人执行torch.save({'model':ddp_mdl.module.state_dict()})
近似代码:

def save_ckpt(rank, ddp_model, path):
    if rank == 0:
        state = {'model': ddp_model.module.state_dict(),
             'optimizer': optimizer.state_dict(),
            }
        torch.save(state, path)

def load_ckpt(path, distributed, map_location=map_location=torch.device('cpu')):
    # loads to
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...)
    optimizer = ...
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if distributed:
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    return model

是这样吗?
我提出这个问题的原因之一是分布式代码可能会出现微妙的错误。我希望确保这种情况不会发生在我身上。当然,我希望避免死锁,但如果发生在我身上,那将是显而易见的(例如,如果所有进程都试图同时打开同一个ckpt文件,则可能会发生这种情况。在这种情况下,d以某种方式确保它们中只有一个一次加载一个,或者使秩0仅加载它,然后将它发送到其余进程)。
我问这个问题也是因为官方文档对我来说没有意义。我会粘贴他们的代码和解释,因为链接有时会死:
保存和加载检查点torch.save在训练和从检查点恢复期间,通常使用www.example.com和torch.load to checkpoint模块。有关详细信息,请参见保存和加载模型。使用DDP时,一种优化是仅在一个进程中保存模型,然后将其加载到所有进程。这是正确的,因为所有的处理从相同的参数开始,并且梯度在反向传递中被同步,因此优化器应该保持参数值的一致性,如果使用这种优化,请确保在保存完成之前所有进程都不会开始加载,另外,在加载模块时,需要提供一个合适的map_location参数,以防止进程进入其他设备,如果map_location缺失,torch.load会先将模块加载到CPU,然后将每个参数复制到保存位置,这将导致同一台机器上的所有进程使用同一组设备。有关更高级的故障恢复和弹性支持,请参阅TorchElastic。

def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn = nn.MSELoss()
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()

相关:

brgchamk

brgchamk1#

我正在看官方的ImageNet example,下面是他们是如何做到这一点的。首先,他们在DDP模式下创建模型:

model = ResNet50(...)
model = DDP(model,...)

在保存检查点,他们检查它是否是主进程,然后保存state_dict

import torch.distributed as dist

if dist.get_rank() == 0:  # check if main process, a simpler way compared to the link
    torch.save({'state_dict': model.state_dict(), ...},
                '/path/to/checkpoint.pth.tar')

在加载过程中,他们加载模型并像往常一样将其置于DDP模式,而无需检查秩:

checkpoint = torch.load('/path/to/checkpoint.pth.tar')
model = ResNet50(...).load_state_dict(checkpoint['state_dict'])
model = DDP(...)


如果你想在DDP模式下加载它,但不是,这有点棘手,因为出于某些原因,他们用额外的后缀module保存它。

state_dict = torch.load(checkpoint['state_dict'])
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

相关问题