如何方便地获取pytorch模块的设备类型?

elcex8rz  于 2022-11-09  发布在  其他
关注(0)|答案(4)|浏览(167)

我必须用不同的设备在不同种类的pytorch模型上堆叠一些我自己的层。
例如,A是一个cuda模型,B是一个cpu模型(但在得到设备类型之前我并不知道它)。那么新的模型分别是CD,其中

class NewModule(torch.nn.Module):
    def __init__(self, base):
        super(NewModule, self).__init__()
        self.base = base
        self.extra = my_layer() # e.g. torch.nn.Linear()

    def forward(self,x):
        y = self.base(x)
        z = self.extra(y)
        return z

...

C = NewModule(A) # cuda
D = NewModule(B) # cpu

但是我必须将baseextra移到相同的设备上,也就是说C的baseextracuda模型,D的是cpu模型。所以我尝试了__inin__

def __init__(self, base):
    super(NewModule, self).__init__()
    self.base = base
    self.extra = my_layer().to(base.device)

不幸的是,torch.nn.Module中没有属性device(引发AttributeError)。
我应该怎么做才能获得base的设备类型?或者使用任何其他方法使baseextra自动位于同一个设备上,即使base的结构是不特定的?

vh0rcniy

vh0rcniy1#

这个问题已经被问过很多次了(12)。引用一个PyTorch开发人员的回答:
That’s not possible. Modules can hold parameters of different types on different devices, and so it’s not always possible to unambiguously determine the device.
推荐的工作流程(如PyTorch博客中所述)是单独创建device对象并在任何地方使用它。


# at beginning of the script

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

...

# then whenever you get a new Tensor or Module

# this won't copy if they are already on the desired device

input = data.to(device)
model = MyModule(...).to(device)

请注意,没有什么可以阻止您向模型添加.device属性。
正如Kani(在注解中)提到的,如果模型中的所有参数都在同一个设备上,则可以使用next(model.parameters()).device

h9vpoimq

h9vpoimq2#

我的解决方案,在99%的情况下有效。

class Net(nn.Module):
  def __init__()
    super().__init__()
    self.dummy_param = nn.Parameter(torch.empty(0))

  def forward(x):
    device = self.dummy_param.device
    ... etc

此后,dummy_param将始终与模块Net具有相同的设备,因此您可以随时获取它。例如:

net = Net()
net.dummy_param.device

'cpu'

net = net.to('cuda')
net.dummy_param.device

'cuda:0'
00jrzges

00jrzges3#

@Duane的答案在模型中创建了一个参数(尽管是一个小Tensor)。
我觉得这个答案稍微有点

class Model(nn.Module):
    def __init__(self, *args,**kwargs):
        super().__init__()
        self.device = torch.device('cpu') # device parameter not defined by default for modules

    def _apply(self, fn):
        # https://stackoverflow.com/questions/54706146/moving-member-tensors-with-module-to-in-pytorch
        # override apply by moving the attribute device of the class object as well.
        # This allows to directly know where the class is when creating new attribute for the class object.
        super()._apply(fn)
        self.device = fn(self.device)
        return self

net.cuda()net.float()等都将同样工作,因为它们都调用_apply而不是to(如在the source中所见)。
来自@Kani的评论的替代解决方案(接受的答案)也非常优雅:

class Model(nn.Module):
    def __init__(self, *args,**kwargs):
        """
        Constructor for Neural Network.
        """
        super().__init__()

    @property
    def device(self):
        return next(self.parameters()).device

对于参数,您通过model.device访问设备。当模型中没有参数时,此解决方案不起作用。

e0uiprwp

e0uiprwp4#

您可以使用next(model.parameters()).device

相关问题