python-3.x 如何选择图层中输入和输出调光的值?基于什么?

thtygnil  于 2022-12-05  发布在  Python
关注(0)|答案(1)|浏览(143)

我尝试在pytorch_geometric中构建一个图形编码器,它将图形特征作为输入,并在无监督学习中以低维度生成嵌入。那么这个编码器模型是否是完成这项工作的正确模型?我不知道如何正确设置输入维度,使其适用于所有类型的图形数据。最后,它给出了一个错误,我不知道它从哪里来的,因为我想做投影,以便将数据投影到一个更低的维度,有人能帮我吗?

class Encoder(torch.nn.Module):
    def __init__(self, node_features_dim, out_features):
        super(GNN, self).__init__()

        # Base model
        self.conv1 = GCNConv(node_features_dim, 2 * out_features)
        self.conv2 = GCNConv(2 * out_features, 2 * out_features)
        self.conv3 = GCNConv(2 * out_features, out_features)
        # projection model
        self.projection = Linear(node_features_dim, out_features, bias=False)

    def forward(self, x, edge_index):
        emb = self.conv1(x, edge_index)
        emb.relu()
        emb = self.conv2(emb, edge_index)
        emb.relu()
        emb = self.conv3(emb, edge_index)
        emb.relu()

        emb = self.projection(emb)  

        return emb

data = next(iter(dataloader))

x, edge_index = data.x, data.edge_index

num_features = data.x.shape[-1]
out_features = 4  # based on what I should set this? 

hidden_dim = num_features // 4
    
    
        Traceback (most recent call last):
          File "C:\Users\marl\Desktop\files\model.py", line 204, in <module>
            emb = model(x, edge_index)
          File "C:\Users\marl\miniconda3\envs\tensor\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
            return forward_call(*input, **kwargs)
          File "C:\Users\marl\Desktop\files\model.py", line 101, in forward
            embeddings = self.projection(embeddings)  
          File "C:\Users\marl\miniconda3\envs\tensor\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
            return forward_call(*input, **kwargs)
          File "C:\Users\marl\miniconda3\envs\tensor\lib\site-packages\torch_geometric\nn\dense\linear.py", line 136, in forward
            return F.linear(x, self.weight, self.bias)
        RuntimeError: mat1 and mat2 shapes cannot be multiplied (441x4 and 44x4)
        
        Process finished with exit code 1
c7rzv4ha

c7rzv4ha1#

看起来您正尝试在Encoder类的forward方法中将形状为41x4的Tensor与形状为44x4的Tensor相乘。这是不可能的,因为内部维度(第二维和第三维)必须匹配才能执行矩阵乘法。
通过确保要相乘的Tensor具有兼容的形状,可以修复此错误。例如,可以使用torch.mm方法对形状为**(m,n)(n,p)的两个Tensor执行矩阵乘法,以生成形状为(m,p)的Tensor。**
下面是如何修改代码以修复此错误的示例:

class Encoder(torch.nn.Module):
    def __init__(self, node_features_dim, out_features):
        super(GNN, self).__init__()

        # Base model
        self.conv1 = GCNConv(node_features_dim, 2 * out_features)
        self.conv2 = GCNConv(2 * out_features, 2 * out_features)
        self.conv3 = GCNConv(2 * out_features, out_features)
        # projection model
        self.projection = Linear(node_features_dim, out_features, bias=False)

    def forward(self, x, edge_index):
        emb = self.conv1(x, edge_index)
        emb.relu()
        emb = self.conv2(emb, edge_index)
        emb.relu()
        emb = self.conv3(emb, edge_index)
        emb.relu()

        # Ensure that the tensors have compatible shapes before multiplying
        emb = torch.mm(emb, self.projection)

        return emb

相关问题