Paddle RNNBase 模型内参数冗余导致基于RNN的训练模型文件体积大

2fjabf4q  于 5个月前  发布在  其他
关注(0)|答案(1)|浏览(44)

bug描述 Describe the Bug

bug描述 Describe the Bug:

所有基于RNN 的模型都有大量的参数冗余,导致训练模型文件体积过大。以PaddleOCR为例,ch_PP-OCRv3_rec 训练模型

以下代码中的断言全部通过,证明了形如( rnn.{bw}_{ih}_l{cn} 和rnn.{cn}.cell.layer.{bw}{ih}')两个名称的参数完全一致。相比之下,去掉冗余参数之后保存后的模型文件如下图,体积减少了约11.7%。

问题应该是出在,RNNBase的 init 方法中,虽然推理模型中没有这个问题,但是训练模型里也应该被优化掉。

import paddle

# 加载保存的训练模型
static_model_path = r'model/ch_PP-OCRv3_rec_slim_train/best_accuracy.pdparams'
state_dict = paddle.load(static_model_path)

# 下面证明每个参数文件中都有大量重复的参数
for role in ('Teacher',"Student"):
    for bw in ('weight','bias'):
        for ih in ('ih','hh'):
            for cn in (0,1):
                print(a:=state_dict[f"{role}.head.sar_head.decoder.rnn_decoder.{bw}_{ih}_l{cn}"])
                print(b:=state_dict[f'{role}.head.sar_head.decoder.rnn_decoder.{cn}.cell._layer.{bw}_{ih}'])
                assert paddle.all(a == b)
                print(paddle.all(a == b))
                state_dict.pop(f"{role}.head.sar_head.decoder.rnn_decoder.{bw}_{ih}_l{cn}")

paddle.save(state_dict,'model/ch_PP-OCRv3_rec_slim_train/new.pdparams')

下面是最小可证明代码段:

import paddle
from paddle.nn import LSTM

m = LSTM(3, 5, 2, 'bidirectional')
state_dict = m.state_dict()
print(m.state_dict())
for k in m.state_dict():
    print(k)

for bw in ('weight', 'bias'):
    for ih in ('ih', 'hh'):
        for cn in (0, 1):
            for fw in (0, 1):
                print(a := state_dict[f"{bw}_{ih}_l{cn}{'' if fw else '_reverse'}"])
                print(b := state_dict[f"{cn}.cell_{'fw' if fw else 'bw'}.{bw}_{ih}"])
                assert paddle.all(a == b)
                print(paddle.all(a == b))

其他补充信息 Additional Supplementary Information

No response

xmakbtuz

xmakbtuz1#

已经联系相关同学来看该问题,请您稍等

相关问题