Paddle 三层CNN, cifar10数据集,用momentum不收敛,用adam准确率比tf低3%~4%

332nm8kg  于 2021-11-30  发布在  Java
关注(0)|答案(4)|浏览(325)

2.0-alpha 版本的飞桨。
对比结果如下面两张图:

使用adam

使用momentum

BTW: mnist数据集,mlp网络用momentum可以收敛。

用来复现的notebook(可以在colab上,gpu运行)
https://gist.github.com/jzhang533/df107c3a91d896874b6437caf2907be0

zbq4xfa0

zbq4xfa01#

感谢使用Paddle。
请首先确认是否所有的超参数(学习率,初始化方法等等)都和benchmark的方法对齐?

zi8p0yeb

zi8p0yeb2#

肯定是对齐了的啊,我贴复现的代码了,你点开看了吗?

eufgjt7s

eufgjt7s3#

paddle 的momentum实现和tf 有区别,相同参数不收敛。Adam低3%的问题需要再看

xam8gpfp

xam8gpfp4#

更新一下:1.8.2版本的paddle,同样的任务和参数设置,用Momentum是可以收敛的。
复现地址:https://gist.github.com/jzhang533/cc74fbb9fa1f1604791accdd520f6def

2.0-alpha版本的paddle,做文本分类也有不收敛的问题。
复现地址:https://gist.github.com/jzhang533/78d7b9674a272e58cd56763b884f5ff6

相关问题