Pytorch 1.0:net.to(设备)在nn.DataParallel中做什么?

42fyovps  于 2023-02-19  发布在  其他
关注(0)|答案(2)|浏览(210)

pytorch data paraleelism教程中的以下代码对我来说很奇怪:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

model.to(device)

据我所知,mode.to(device)将数据复制到GPU。
DataParallel自动分割数据,并将作业订单发送到多个GPU上的多个模型。每个模型完成作业后,DataParallel会收集并合并结果,然后将结果返回给您。
如果DataParallel执行复制工作,那么to(device)在这里做什么呢?

wydwbb8l

wydwbb8l1#

他们在教程中添加了几行来解释nn.DataParallel
DataParallel自动分割数据,并使用数据将作业订单发送到不同GPU上的多个型号。每个型号完成作业后,DataParallel会为您收集并合并结果。
上面的引用可以理解为nn.DataParallel只是一个 Package 类,用于通知model.cuda()应该向GPU进行多个复制。
在我的情况下,我的笔记本电脑上没有任何GPU,我仍然调用nn.DataParallel()没有任何问题。

import torch
import torchvision

model = torchvision.models.alexnet()
model = torch.nn.DataParallel(model)
# No error appears if I don't move the model to `cuda`
ohfgkhjo

ohfgkhjo2#

以下内容是否正确?

model = MyModel(args).cuda()
model = torch.nn.DataParallel(model).cuda()

相关问题