PyTorch将嵌套模块迁移到GPU?

rdlzhqv9  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(118)

我是PyTorch的新手,有一些自定义的nn.Module,我想在GPU上运行。我们称它们为M_outerM_innerM_sub。一般来说,结构看起来像:

class M_outer(nn.Module):
    def __init__(self, **kwargs):
        self.inner_1 = M_inner(**kwargs)
        self.inner_2 = M_inner(**kwargs)
        # ...
    
    def forward(self, input):
        result = self.inner_1(input)
        result = self.inner_2(result)
        # ...
        return result
    

class M_inner(nn.Module):
    def __init__(self, **kwargs):
        self.sub_1 = M_sub(**kwargs)
        self.sub_2 = M_sub(**kwargs)
        # ...
    
    def forward(self, input):
        result = self.sub_1(input)
        result = self.sub_2(result)
        # ...
        return result

class M_sub(nn.Module):
    def __init__(self, **kwargs):
        self.emb_1 = nn.Embedding(x, y)
        self.emb_2 = nn.Embedding(x, y)
        # ...
        self.norm  = nn.LayerNorm()
    
    def forward(self, input):
        emb = (self.emb_1(input) + self.emb_2(input))        
        # ...
        return self.norm(emb)

我试着把我的模块放到GPU上,通过:

model = M_outer(params).to(device)

然而,我仍然从嵌入层中得到错误,说有些操作是在cpu上进行的。
我已阅读文档。我读过一些有用的讨论文章,比如this和相关的StackOverflow文章,比如this
我无法通过nn.Parameter注册nn.EmbeddingLayer。我错过了什么?

6yt4nkrj

6yt4nkrj1#

PyTorch会将所有子模块移动到指定的设备中。你的例子应该很好。为了重现性,我做了一些修改:

import torch
from torch import nn

class M_outer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.fc = M_inner()
    def forward(self, input):
        return self.fc(input)
    

class M_inner(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.fc = M_sub()
    def forward(self, input):
        return self.fc(input)

class M_sub(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.fc = nn.Linear(1, 1)
    def forward(self, input):
        return self.fc(input)

model = M_outer().to("cuda")
t = torch.randn(1).unsqueeze(0).to("cuda")
model(t)

需要注意的是,PyTorch不会移动不是nn.Module示例的类成员。因此,如果在推理过程中使用静态Tensor进行类内的计算,则需要手动将其移动到设备中。

相关问题