pytorch 在PytTorch中多次挤压的快速方法?

vwkv1x7d  于 2023-06-29  发布在  其他
关注(0)|答案(1)|浏览(134)

在pytorchTensor中,有没有什么有效的方法可以多次挤压(取消挤压)?
例如,Tensora具有形状:[4,1,1,2],

import torch 
import torch.nn as nn
a = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.float32)
a = a.reshape(4,1,1,2)
print(a.shape)

torch.Size([4, 1, 1, 2])

我想将a压缩到[4,2],但不手动压缩两次,例如,

a1 = a.squeeze(1).squeeze(1)
print(a1.shape)
torch.Size([4, 2])

有没有什么方法可以让我不用写(un)squeeze(1)两次/多次。谢谢

jv4diomz

jv4diomz1#

不要使用a.squeeze(1)。只要a.squeeze()就能给予你想要的结果。

相关问题