目前,我将train.py
与model.py
分开用于我的深度学习项目。
因此,对于数据集,它们被发送到**epoch for loop
**内的cuda设备,如下所示。train.py
个
...
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = MyNet(~).to(device)
...
for batch_data in train_loader:
s0 = batch_data[0].to(device)
s1 = batch_data[1].to(device)
pred = model(s0, s1)
字符串
然而,在我的模型中(在model.py中),它还需要访问device变量以实现skip connection like方法。model.py
个
class MyNet(nn.Module):
def __init__(self, in_feats, hid_feats, out_feats):
super(MyNet, self).__init__()
self.conv1 = GCNConv(in_feats, hid_feats)
...
def forward(self, data):
x, edge_index = data.x, data.edge_index
x1 = copy.copy(x.float())
x = self.conv1(x, edge_index)
skip_conn = torch.zeros(len(data.batch), x1.size(1)).to(device) # <--
(some opps for x1 -> skip_conn)
x = torch.cat((x, skip_conn), 1)
型
在本例中,我当前将device
作为参数传递,但我认为这不是最佳实践。
1.将数据集发送到CUDA的最佳做法应该是什么?
1.如果有多个脚本需要访问device
,我应该如何处理?(参数,全局变量?)
2条答案
按热度按时间zlwx9yxi1#
您可以向
MyModel
添加一个新属性来存储device
信息,并在skip_conn
初始化中使用它。字符串
注意,在这个例子中,
MyNet
负责所有的设备逻辑,包括.to(device)
调用。这样,我们将所有与模型相关的设备管理封装在模型类本身中。vnjpjtjt2#
我不能100%确定这是否适用于你的情况,但你也可以在模型初始化后使用
.to(device)
:字符串
将
device
变量作为参数包含在模型类中也是可以的。下面是另一个实现选项:型