pytorch我不知道如何定义多个模型

kwvwclae  于 2023-03-23  发布在  其他
关注(0)|答案(1)|浏览(104)

我想在pytorch中使用两个不同的模型。因此,我执行了下面的代码,但我无法成功运行第二个模型。我该怎么做呢?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(2, 64)
        self.linear2 = nn.Linear(64, 3)

    def forward(self, x):   
        x = self.linear1(x)
        x = torch.sigmoid(x)  
        x = self.linear2(x)

        return x
                
class Model2(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(2, 64)
        self.linear2 = nn.Linear(64, 1)

    def forward(self, x):   
        x = self.linear1(x)
        x = torch.sigmoid(x)  
        x = self.linear2(x)

        return x
                
net = Model()
net2 = Model2()

错误

类型错误跟踪(最后调用最近的)

/tmp/ipykernel_477/2280223066.py in <module>
     26 
     27 net = Model()
---> 28 net2 = Model2()

/tmp/ipykernel_477/2280223066.py in __init__(self)
     14 class Model2(nn.Module):
     15     def __init__(self):
---> 16         super(Model, self).__init__()
     17         self.linear1 = nn.Linear(2, 64)
     18         self.linear2 = nn.Linear(64, 1)

TypeError:super(type,obj):obj必须是类型的示例或子类型

xwbd5t1u

xwbd5t1u1#

这是因为你在第16行调用了super(Model,self).init()而不是super(Model2,self).init()。你必须每次都更改模型的名称以匹配当前类的名称。第3行运行是因为在这种情况下Model实际上是类的名称。

相关问题