将PyTorch设备名称传递给模型的最佳实践

9vw9lbht  于 11个月前  发布在  其他
关注(0)|答案(2)|浏览(101)

目前,我将train.pymodel.py分开用于我的深度学习项目。
因此,对于数据集,它们被发送到**epoch for loop**内的cuda设备,如下所示。
train.py

...
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = MyNet(~).to(device)
...
for batch_data in train_loader:
    s0 = batch_data[0].to(device)
    s1 = batch_data[1].to(device)
    pred = model(s0, s1)

字符串
然而,在我的模型中(在model.py中),它还需要访问device变量以实现skip connection like方法。
model.py

class MyNet(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super(MyNet, self).__init__()
        self.conv1 = GCNConv(in_feats, hid_feats)
        ...

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x1 = copy.copy(x.float())
        x = self.conv1(x, edge_index)
        skip_conn = torch.zeros(len(data.batch), x1.size(1)).to(device)  # <--
        (some opps for x1 -> skip_conn)
        x = torch.cat((x, skip_conn), 1)


在本例中,我当前将device作为参数传递,但我认为这不是最佳实践。
1.将数据集发送到CUDA的最佳做法应该是什么?
1.如果有多个脚本需要访问device,我应该如何处理?(参数,全局变量?)

zlwx9yxi

zlwx9yxi1#

您可以向MyModel添加一个新属性来存储device信息,并在skip_conn初始化中使用它。

class MyNet(nn.Module):
def __init__(self, in_feats, hid_feats, out_feats, device): # <--
    super(MyNet, self).__init__()
    self.conv1 = GCNConv(in_feats, hid_feats)
    self.device = device # <--
    self.to(self.device) # <--
    ...

def forward(self, data):
    x, edge_index = data.x, data.edge_index
    x1 = copy.copy(x.float())
    x = self.conv1(x, edge_index)
    skip_conn = torch.zeros(len(data.batch), x1.size(1), device=self.device)  # <--
    (some opps for x1 -> skip_conn)
    x = torch.cat((x, skip_conn), 1)

字符串
注意,在这个例子中,MyNet负责所有的设备逻辑,包括.to(device)调用。这样,我们将所有与模型相关的设备管理封装在模型类本身中。

vnjpjtjt

vnjpjtjt2#

我不能100%确定这是否适用于你的情况,但你也可以在模型初始化后使用.to(device)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class myModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = nn.Linear(100, 100)

model = myModel().to(device)

print(next(model.parameters()).device) # "device(type='cuda', index=0)" if on GPU else "device(type='cpu')"

字符串
device变量作为参数包含在模型类中也是可以的。下面是另一个实现选项:

class myModel(nn.Module):
    def __init__(self):
        super().__init__()

    @property
    def device(self):
        return next(self.parameters()).device

model = myModel()
print(model.device) # "device(type='cuda', index=0)" if on GPU else "device(type='cpu')"

相关问题