Paddle 2.5.1 AdamW: custom param_group beta has no effect

ubof19bj posted 5 months ago in Other

Describe the Bug

With AdamW, I want to set a different beta for each param_group, as shown in the documentation ( https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/AdamW_cn.html ):
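import paddle

# Two small layers so the snippet runs on its own.
linear_1 = paddle.nn.Linear(10, 10)
linear_2 = paddle.nn.Linear(10, 10)

opt = paddle.optimizer.AdamW(
    learning_rate=0.1,
    parameters=[{
        # First group: uses the optimizer-level defaults.
        'params': linear_1.parameters()
    }, {
        # Second group: overrides weight_decay, learning_rate and beta1.
        'params': linear_2.parameters(),
        'weight_decay': 0.001,
        'learning_rate': 0.1,
        'beta1': 0.8
    }],
    weight_decay=0.01,
    beta1=0.9
)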
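AMP and DDP are both enabled.
The run raises no errors, but the results do not reflect the different beta values. The runs are reproducible because the same seed is set each time.
A special beta is set for the last layer only; all other layers use the default beta1 and beta2 of 0.9 and 0.999.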
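The behavior the report expects can be written down as a small sketch: any key set on a param_group overrides the optimizer-level value, and every other key falls back to the default. The helper name resolve_group_options and the use of parameter-name strings in place of real parameters are hypothetical; this models the expected semantics only and is not a description of Paddle's internals.

```python
# Hypothetical helper: models the *expected* override semantics only,
# not how Paddle actually resolves param_group options.
def resolve_group_options(group, defaults):
    resolved = dict(defaults)
    for key, value in group.items():
        if key != "params":  # "params" lists the parameters, it is not a hyperparameter
            resolved[key] = value
    return resolved


# Optimizer-level values from the snippet above, plus the default beta2.
defaults = {"learning_rate": 0.1, "weight_decay": 0.01, "beta1": 0.9, "beta2": 0.999}

# Placeholder strings stand in for the real parameter tensors.
groups = [
    {"params": ["linear_1.weight", "linear_1.bias"]},
    {"params": ["linear_2.weight", "linear_2.bias"],
     "weight_decay": 0.001, "learning_rate": 0.1, "beta1": 0.8},
]

resolved = [resolve_group_options(g, defaults) for g in groups]
for options in resolved:
    print(options)
```

Under these semantics the second group should train with beta1 = 0.8 while the first keeps 0.9, which is the difference the measurements below were meant to reveal.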
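In the logs below:
glc is the gradient norm of the last layer, computed from model.param.grad
glcm0 is the norm of the first moment of the last layer's gradient, read from adamw.state_dict()
glcz is the last layer's first moment divided by the square root of its second moment (first moment / second moment.sqrt()), further divided by the shape of the last layer; it gives the average magnitude of the first moment / second moment.sqrt() matrix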
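The three quantities can be sketched in plain Python as follows. This is an illustration under assumptions, not the reporter's script: the tensors are taken as flat lists of floats, "norm" is taken to mean the L2 norm, "divided by the shape" is read as dividing by the number of elements, and the small eps in the denominator is added here only to avoid division by zero. The function names are hypothetical.

```python
import math


def l2_norm(values):
    """L2 norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in values))


def layer_metrics(grad, moment1, moment2, eps=1e-8):
    """Return (glc, glcm0, glcz) for one layer given flat lists of floats.

    grad    : current gradient of the layer
    moment1 : first-moment estimate from the optimizer state
    moment2 : second-moment estimate from the optimizer state
    """
    glc = l2_norm(grad)
    glcm0 = l2_norm(moment1)
    ratio = [m / (math.sqrt(v) + eps) for m, v in zip(moment1, moment2)]
    # Assumption: "divided by the shape" means divided by the element count.
    glcz = l2_norm(ratio) / len(ratio)
    return glc, glcm0, glcz


# Tiny made-up example: two elements, moments equal to g and g**2.
glc, glcm0, glcz = layer_metrics([3.0, 4.0], [3.0, 4.0], [9.0, 16.0])
print(glc, glcm0, glcz)
```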
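In theory, as beta1 decreases, momentum contributes less to the first moment, so the first moment should follow the current gradient more closely.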
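A minimal sketch of that expectation, using the standard Adam first-moment recurrence m_t = beta1 * m_{t-1} + (1 - beta1) * g_t on a scalar. The gradient sequence is synthetic and chosen only for illustration (small for 20 steps, then a jump); it is not taken from the logs, and first_moment_trace is a hypothetical helper name.

```python
def first_moment_trace(grads, beta1):
    """Run m_t = beta1 * m_{t-1} + (1 - beta1) * g_t from m_0 = 0; return every m_t."""
    m = 0.0
    trace = []
    for g in grads:
        m = beta1 * m + (1.0 - beta1) * g
        trace.append(m)
    return trace


# Synthetic gradient: small for 20 steps, then it jumps to a larger value.
grads = [0.25] * 20 + [3.0] * 5

finals = {}
for beta1 in (0.9, 0.5, 0.1):
    finals[beta1] = first_moment_trace(grads, beta1)[-1]
    print(f"beta1={beta1}: first moment 5 steps after the jump = {finals[beta1]:.3f}")
```

With beta1 = 0.1 the first moment has almost reached the new gradient value five steps after the jump, while with beta1 = 0.9 it is still less than halfway there. A per-group beta1 that takes effect should therefore change how quickly glcm0 reacts when glc rises.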
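The results do not show this, however: the first moment (glcm0) keeps dropping quickly even while the gradient (glc column) rises.
It looks as though the configured beta1 is not taking effect.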
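One way to test this suspicion directly, sketched here under assumptions, is to recover the beta1 the optimizer actually applied from two consecutive first-moment snapshots and the gradient consumed in between. Rearranging m_t = beta1 * m_{t-1} + (1 - beta1) * g_t gives m_t - g_t = beta1 * (m_{t-1} - g_t), so beta1 can be fitted by least squares over all elements. The function estimate_beta1 is a hypothetical helper, and the inputs are assumed to be flat lists of floats. The gradient must be exactly the one the optimizer saw (with AMP and DDP, that means after unscaling and after gradient synchronization), otherwise the estimate is meaningless.

```python
def estimate_beta1(moment_prev, moment_next, grad):
    """Least-squares estimate of beta1 from one optimizer step.

    Uses m_next - g = beta1 * (m_prev - g), element by element.
    """
    num = 0.0
    den = 0.0
    for m0, m1, g in zip(moment_prev, moment_next, grad):
        num += (m1 - g) * (m0 - g)
        den += (m0 - g) ** 2
    if den == 0.0:
        raise ValueError("moment equals gradient everywhere; beta1 cannot be identified")
    return num / den


# Synthetic check: build one step with a known beta1 and recover it.
true_beta1 = 0.9
m_prev = [0.5, -0.2, 0.1]
g = [1.0, 0.3, -0.4]
m_next = [true_beta1 * m + (1.0 - true_beta1) * gi for m, gi in zip(m_prev, g)]

estimated = estimate_beta1(m_prev, m_next, g)
print(f"estimated beta1 = {estimated:.6f}")
```

If the value recovered for the last layer stays near 0.9 no matter which beta1 is set on its param_group, that would support the conclusion that the per-group setting is being ignored.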
beta1:0.8,beta2:0.9
update:50, Apr 17 06:47:11 lr:1.6e-05, glc:3.2, glcm0:3.3, glcz:3.9,
update:100, Apr 17 06:58:06 lr:2.3e-05, glc:0.53, glcm0:0.54, glcz:0.98,
update:150, Apr 17 07:09:01 lr:3.0e-05, glc:0.25, glcm0:0.28, glcz:0.58,
update:200, Apr 17 07:19:56 lr:3.6e-05, glc:0.9, glcm0:0.18, glcz:0.32,
update:250, Apr 17 07:30:51 lr:4.3e-05, glc:1.7, glcm0:0.14, glcz:0.22,
update:300, Apr 17 07:41:46 lr:4.9e-05, glc:0.92, glcm0:0.12, glcz:0.14,
update:350, Apr 17 07:52:41 lr:5.6e-05, glc:2.6, glcm0:0.2, glcz:0.17,
update:400, Apr 17 08:03:37 lr:6.3e-05, glc:3.1, glcm0:0.19, glcz:0.14,
update:450, Apr 17 08:14:33 lr:6.9e-05, glc:3.3, glcm0:0.2, glcz:0.13,
update:500, Apr 17 08:25:29 lr:7.6e-05, glc:4, glcm0:0.32, glcz:0.23,

beta1:0.5,beta2:0.6
update:50, Apr 17 08:44:06 lr:1.6e-05, glc:3.2, glcm0:3.3, glcz:3.9,
update:100, Apr 17 08:55:01 lr:2.3e-05, glc:0.53, glcm0:0.54, glcz:0.98,
update:150, Apr 17 09:05:56 lr:3.0e-05, glc:0.25, glcm0:0.28, glcz:0.58,
update:200, Apr 17 09:16:51 lr:3.6e-05, glc:0.77, glcm0:0.18, glcz:0.31,
update:250, Apr 17 09:27:45 lr:4.3e-05, glc:2, glcm0:0.15, glcz:0.23,
update:300, Apr 17 09:38:41 lr:4.9e-05, glc:1.4, glcm0:0.15, glcz:0.15,
update:350, Apr 17 09:49:37 lr:5.6e-05, glc:3.1, glcm0:0.18, glcz:0.18,
update:400, Apr 17 10:00:33 lr:6.3e-05, glc:2.3, glcm0:0.13, glcz:0.12,
update:450, Apr 17 10:11:29 lr:6.9e-05, glc:3.5, glcm0:0.18, glcz:0.12,
update:500, Apr 17 10:22:24 lr:7.6e-05, glc:3.6, glcm0:0.18, glcz:0.11,

beta1:0.3,beta2:0.4
update:50, Apr 17 10:45:34 lr:1.6e-05, glc:3.2, glcm0:3.3, glcz:3.9,
update:100, Apr 17 10:56:29 lr:2.3e-05, glc:0.53, glcm0:0.54, glcz:0.98,
update:150, Apr 17 11:07:23 lr:3.0e-05, glc:0.25, glcm0:0.28, glcz:0.58,
update:200, Apr 17 11:18:18 lr:3.6e-05, glc:0.77, glcm0:0.18, glcz:0.31,
update:250, Apr 17 11:29:13 lr:4.3e-05, glc:1.8, glcm0:0.14, glcz:0.22,
update:300, Apr 17 11:40:09 lr:4.9e-05, glc:2.3, glcm0:0.19, glcz:0.2,
update:350, Apr 17 11:51:05 lr:5.6e-05, glc:2.4, glcm0:0.15, glcz:0.13,
update:400, Apr 17 12:02:01 lr:6.3e-05, glc:2.8, glcm0:0.16, glcz:0.14,
update:450, Apr 17 12:12:56 lr:6.9e-05, glc:3, glcm0:0.17, glcz:0.12,
update:500, Apr 17 12:23:51 lr:7.6e-05, glc:2.9, glcm0:0.14, glcz:0.12,

beta1:0.1,beta2:0.2
update:50, Apr 17 13:13:29 lr:1.6e-05, af:0.027, gWa:0.013, glc:3.2, glcm0:3.3, glcz:3.9,
update:100, Apr 17 13:24:25 lr:2.3e-05, af:0.031, gWa:0.012, glc:0.53, glcm0:0.54, glcz:0.98,
update:150, Apr 17 13:35:21 lr:3.0e-05, af:0.035, gWa:0.012, glc:0.25, glcm0:0.28, glcz:0.58,
update:200, Apr 17 13:46:18 lr:3.6e-05, af:0.039, gWa:0.01, glc:0.76, glcm0:0.18, glcz:0.31,
update:250, Apr 17 13:57:14 lr:4.3e-05, af:0.044, gWa:0.0063, glc:2.1, glcm0:0.16, glcz:0.24,
update:300, Apr 17 14:08:10 lr:4.9e-05, af:0.048, gWa:0.0049, glc:2.9, glcm0:0.2, glcz:0.24,
update:350, Apr 17 14:19:06 lr:5.6e-05, af:0.052, gWa:0.0032, glc:2.1, glcm0:0.14, glcz:0.15,
update:400, Apr 17 14:30:03 lr:6.3e-05, af:0.056, gWa:0.0031, glc:3, glcm0:0.16, glcz:0.14,
update:450, Apr 17 14:40:59 lr:6.9e-05, af:0.06, gWa:0.003, glc:3.6, glcm0:0.19, glcz:0.14,
update:500, Apr 17 14:51:56 lr:7.6e-05, af:0.064, gWa:0.0023, glc:2.6, glcm0:0.16, glcz:0.1,

Additional Supplementary Information

full_version = '2.5.1'
major = '2'
minor = '5'
patch = '1'
rc = '0'
cuda_version = '12.0'
cudnn_version = '8.9.1'
xpu_version = 'False'
xpu_xccl_version = 'False'
istaged = False
commit = '41ba14f30600373df53839dbf763405cfacb5c92'
with_mkl = 'ON'
cinn_version = 'False'
