Paddle: batch_size dimension mismatch after converting PyTorch code to Paddle

kr98yfug posted on 2022-10-20 in Other

PyTorch code:

import torch
import torch.nn as nn
from torch.nn import Parameter

class SpatialGroupEnhance(nn.Module):
    def __init__(self, groups = 64):
        super(SpatialGroupEnhance, self).__init__()
        self.groups   = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.weight   = Parameter(torch.zeros(1, groups, 1, 1))
        self.bias     = Parameter(torch.ones(1, groups, 1, 1))
        self.sig      = nn.Sigmoid()

    def forward(self, x): # (b, c, h, w)
        b, c, h, w = x.size()
        x = x.view(b * self.groups, -1, h, w) 
        xn = x * self.avg_pool(x)
        xn = xn.sum(dim=1, keepdim=True)
        t = xn.view(b * self.groups, -1)
        t = t - t.mean(dim=1, keepdim=True)
        std = t.std(dim=1, keepdim=True) + 1e-5
        t = t / std
        t = t.view(b, self.groups, h, w)
        t = t * self.weight + self.bias
        t = t.view(b * self.groups, 1, h, w)
        x = x * self.sig(t)
        x = x.view(b, c, h, w)
        return x

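Paddle code:

import paddle.fluid as fluid

# Method of the network class (note the `self` parameter); the enclosing class is not shown.
def sge_block(self, x, groups=64, name=None):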
        weight = fluid.layers.create_parameter(shape=[1,groups,1,1], dtype='float32', 
                                               default_initializer=fluid.initializer.Constant(value=1.0))
        bias = fluid.layers.create_parameter(shape=[1,groups,1,1], dtype='float32', 
                                             default_initializer=fluid.initializer.Constant(value=0.0))

        batchsize, num_channels, height, width = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        channels_per_group = num_channels // groups
        x = fluid.layers.reshape(x=x, shape=[-1, channels_per_group, height, width])
        xn = x * fluid.layers.pool2d(input=x, pool_type='avg', global_pooling=True, use_cudnn=False)
        xn = fluid.layers.reduce_sum(input=xn, dim=1, keep_dim=True)
        t = fluid.layers.reshape(x=xn, shape=[-1, height * width])
        t = fluid.layers.reshape(x=t, shape=[-1, groups, height, width])
        t = t*weight + bias
        t = fluid.layers.reshape(x=t, shape=[-1, 1, height, width])
        x = x * fluid.layers.sigmoid(t)

        x = fluid.layers.reshape(x=x, shape=[-1, num_channels, height, width])
        return x

Single GPU, batch_size = 32.
Error message:
Enforce failed. Expected x_dims[i + axis] == y_dims[i], but received x_dims[i + axis]:32 != y_dims[i]:1. Broadcast dimension mismatch. at [/paddle/paddle/fluid/operators/elementwise/elementwise_op_function.h:63]
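
The message above has the form of a shape check between two operands. As a rough illustration, here is a minimal pure-Python sketch of the elementwise broadcasting rule as described in the Paddle 1.x documentation: Y must match a contiguous run of X's dimensions starting at `axis`, and trailing size-1 dimensions of Y are ignored. The helper name `legacy_broadcast_ok` and the 7x7 spatial size are hypothetical, and the sketch is an assumption about the framework's behavior, not a copy of its source.

```python
def legacy_broadcast_ok(x_dims, y_dims, axis=-1):
    """Shape check under the Paddle 1.x elementwise rule (assumed from the docs)."""
    x_dims, y_dims = list(x_dims), list(y_dims)
    if x_dims == y_dims:
        return True, "same shape"
    # axis=-1 aligns Y with the trailing dimensions of X.
    if axis == -1:
        axis = len(x_dims) - len(y_dims)
    # Trailing dimensions of size 1 in Y are dropped before comparing.
    while y_dims and y_dims[-1] == 1:
        y_dims.pop()
    # Every remaining Y dimension must equal the X dimension it is aligned with.
    for i, y_dim in enumerate(y_dims):
        if x_dims[i + axis] != y_dim:
            return False, "x_dims[%d]:%d != y_dims[%d]:%d" % (
                i + axis, x_dims[i + axis], i, y_dim)
    return True, "ok"


# t has shape (batch, groups, h, w); h = w = 7 is a made-up spatial size.
print(legacy_broadcast_ok([32, 64, 7, 7], [1, 64, 1, 1]))  # (False, 'x_dims[0]:32 != y_dims[0]:1')
print(legacy_broadcast_ok([32, 64, 7, 7], [64], axis=1))   # (True, 'ok')
```

If this reading of the rule is right, the leading 1 in the [1, groups, 1, 1] parameters is compared against the batch dimension (32), which would match the reported values. A parameter of shape [groups] combined through fluid.layers.elementwise_mul and fluid.layers.elementwise_add with axis=1 might avoid the check; this has not been tested against the poster's environment.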

eivgtgni  #1

Addendum: simulating this process in NumPy does not show any dimension mismatch.
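
For reference, a NumPy simulation of this kind could look like the sketch below. It retraces the reshape and multiply steps of the Paddle sge_block from the question, for shape checking only. The function name `sge_block_numpy`, the channel count of 128, and the 7x7 spatial size are assumptions made for illustration; only batch_size = 32 and groups = 64 come from the post.

```python
import numpy as np


def sge_block_numpy(x, weight, bias, groups=64):
    """Retrace the Paddle sge_block steps in NumPy (shape check only)."""
    b, c, h, w = x.shape
    x = x.reshape(-1, c // groups, h, w)           # (b*groups, c/groups, h, w)
    xn = x * x.mean(axis=(2, 3), keepdims=True)    # global average pool, broadcast over h, w
    xn = xn.sum(axis=1, keepdims=True)             # (b*groups, 1, h, w)
    t = xn.reshape(-1, h * w)
    t = t.reshape(-1, groups, h, w)                # (b, groups, h, w)
    t = t * weight + bias                          # (1, groups, 1, 1) broadcasts over b, h, w
    t = t.reshape(-1, 1, h, w)
    x = x * (1.0 / (1.0 + np.exp(-t)))             # sigmoid gate, broadcast over channels
    return x.reshape(-1, c, h, w)


x = np.random.rand(32, 128, 7, 7).astype("float32")   # hypothetical input
weight = np.ones((1, 64, 1, 1), dtype="float32")
bias = np.zeros((1, 64, 1, 1), dtype="float32")
out = sge_block_numpy(x, weight, bias)
print(out.shape)  # (32, 128, 7, 7)
```

NumPy aligns shapes from the right and stretches any size-1 dimension, so the (1, groups, 1, 1) parameters combine with a (32, groups, h, w) tensor without complaint. A framework with a stricter elementwise rule can reject the same shapes.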

ou6hu8tu  #2

@cuicheng01 The error is raised at elementwise_op_function.h:63 and it is a dimension problem. Could you print the dimensions of x_dims and y_dims?

wj8zmpe1  #3

@cuicheng01
Is the problem at t = t*weight + bias? Can you pin it down?
My idea is to compute t*weight on its own and print its dimensions, then print the dimensions of bias separately.

zujrkrfu  #4

The dimensions cannot be printed, so for now I cannot locate where the error is.

sigwle7e  #5

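My suggestion is to run the code in dynamic graph (DyGraph) mode and print the shapes of the intermediate values to pin down the failing step. For DyGraph usage, see:
http://paddlepaddle.org/documentation/docs/zh/1.4/user_guides/howto/dygraph/DyGraph.html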

zpgglvta  #6

OK. That said, the PyTorch version, which also runs as a dynamic graph, currently works without any problem.
