在训练过程中,我可以只使用PytorchTensor中的某些子Tensor吗?

wi3ka0sx  于 2023-04-21  发布在  其他
关注(0)|答案(1)|浏览(163)

例如,我有一个大小为(B,4,H,W)的PytorchTensor,我想在训练过程中只使用轴1的某些子Tensor(具体来说,索引为0和3)作为模型。剩下的两个子Tensor将不会在训练中使用。
当然,我可以将Tensor更改为(B,2,H,W)大小,但我很好奇,如果我使用原始Tensor,训练和推理过程是否仍然稳定?

z9gpfhce

z9gpfhce1#

当然!这种方法通常用于以不同的方式处理数据的特定组件。例如:

import torch
from torch import nn
B, C, H, W = 100, 4, 224, 224
inputs = torch.randn(B, C, H, W)
    
class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 3, 1, 1)
        self.conv2 = nn.Conv2d(2, 2, 3, 1, 1)
    
    def forward(self, x):
        x1 = x[:, :2, ...]
        x2 = x[:, 2:, ...]
        res1 = self.conv1(x1)
        res2 = self.conv2(x2)
        res = torch.cat([res1, res2], dim=1)
        return res
    
my_model = model()
res = my_model(inputs)
res.shape

相关问题