在pytorchTensor中,有没有什么有效的方法可以多次挤压(取消挤压)?
例如,Tensora具有形状:[4,1,1,2],
import torch
import torch.nn as nn
a = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.float32)
a = a.reshape(4,1,1,2)
print(a.shape)
torch.Size([4, 1, 1, 2])
我想将a压缩到[4,2],但不手动压缩两次,例如,
a1 = a.squeeze(1).squeeze(1)
print(a1.shape)
torch.Size([4, 2])
有没有什么方法可以让我不用写(un)squeeze(1)
两次/多次。谢谢
1条答案
按热度按时间jv4diomz1#
不要使用
a.squeeze(1)
。只要a.squeeze()
就能给予你想要的结果。