pytorch代码:
class SpatialGroupEnhance(nn.Module):
def __init__(self, groups = 64):
super(SpatialGroupEnhance, self).__init__()
self.groups = groups
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.weight = Parameter(torch.zeros(1, groups, 1, 1))
self.bias = Parameter(torch.ones(1, groups, 1, 1))
self.sig = nn.Sigmoid()
def forward(self, x): # (b, c, h, w)
b, c, h, w = x.size()
x = x.view(b * self.groups, -1, h, w)
xn = x * self.avg_pool(x)
xn = xn.sum(dim=1, keepdim=True)
t = xn.view(b * self.groups, -1)
t = t - t.mean(dim=1, keepdim=True)
std = t.std(dim=1, keepdim=True) + 1e-5
t = t / std
t = t.view(b, self.groups, h, w)
t = t * self.weight + self.bias
t = t.view(b * self.groups, 1, h, w)
x = x * self.sig(t)
x = x.view(b, c, h, w)
return x
paddle代码:
def sge_block(self, x, groups=64, name=None):
weight = fluid.layers.create_parameter(shape=[1,groups,1,1], dtype='float32',
default_initializer=fluid.initializer.Constant(value=1.0))
bias = fluid.layers.create_parameter(shape=[1,groups,1,1], dtype='float32',
default_initializer=fluid.initializer.Constant(value=0.0))
batchsize, num_channels, height, width = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
channels_per_group = num_channels // groups
x = fluid.layers.reshape(x=x, shape=[-1, channels_per_group, height, width])
xn = x * fluid.layers.pool2d(input=x, pool_type='avg', global_pooling=True, use_cudnn=False)
xn = fluid.layers.reduce_sum(input=xn, dim=1, keep_dim=True)
t = fluid.layers.reshape(x=xn, shape=[-1, height * width])
t = fluid.layers.reshape(x=t, shape=[-1, groups, height, width])
t = t*weight + bias
t = fluid.layers.reshape(x=t, shape=[-1, 1, height, width])
x = x * fluid.layers.sigmoid(t)
x = fluid.layers.reshape(x=x, shape=[-1, num_channels, height, width])
return x
单卡batch_size = 32,
报错信息:Enforce failed. Expected x_dims[i + axis] == y_dims[i], but received x_dims[i + axis]:32 != y_dims[i]:1. Broadcast dimension mismatch. at [/paddle/paddle/fluid/operators/elementwise/elementwise_op_function.h:63]
6条答案
按热度按时间eivgtgni1#
补充:numpy模拟此过程不存在维度对不上的问题
ou6hu8tu2#
@cuicheng01 观察到问题出现在
elementwise_op_function.h:63
,且是维度问题,你能打印出x_dims
与y_dims
的维度嘛wj8zmpe13#
@cuicheng01
你的问题是出在
t = t*weight + bias
这里吗?能定位出嘛?我想的是单独计算
t*weight
并输出维度,另外在输出bias
的维度zujrkrfu4#
@cuicheng01
你的问题是出在
t = t*weight + bias
这里吗?能定位出嘛?我想的是单独计算
t*weight
并输出维度,另外在输出bias
的维度维度无法打印,目前不能定位错误在哪里
sigwle7e5#
@cuicheng01
你的问题是出在
t = t*weight + bias
这里吗?能定位出嘛?我想的是单独计算
t*weight
并输出维度,另外在输出bias
的维度维度无法打印,目前不能定位错误在哪里
我的建议是可以通过动态图的方式,打印中间值的维度信息来确定,动态图的用法参考:
http://paddlepaddle.org/documentation/docs/zh/1.4/user_guides/howto/dygraph/DyGraph.html
zpgglvta6#
好的,不过现在pytorch的动态图也是没有问题的