Describe the Bug
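In Paddle, every RNN-based model stores a large number of redundant parameters, which makes the saved training model files unnecessarily large. Take the PaddleOCR ch_PP-OCRv3_rec training model as an example.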
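All assertions in the code below pass, which shows that the parameters named rnn.{bw}_{ih}_l{cn} and rnn.{cn}.cell._layer.{bw}_{ih} hold identical values. After removing the redundant copies and saving again, the model file is about 11.7% smaller, as shown in the figure below.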
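The problem most likely lies in the __init__ method of RNNBase. Inference models do not have this issue, but the redundancy should be removed from training models as well.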
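To illustrate the suspected mechanism, here is a toy sketch in plain Python (no Paddle needed). ToyLayer, ToyCell and ToyRNN are hypothetical classes, not Paddle's Layer or RNNBase. The sketch only assumes two things: state_dict walks a layer's own parameters and then its sublayers', and the RNN re-registers each cell's parameter object under a flat weight_ih_l{n} alias. Under those assumptions the same object is reported under two names, so a serializer that writes every entry stores it twice.

```python
class ToyLayer:
    """Hypothetical stand-in for a framework layer (not Paddle's Layer)."""

    def __init__(self):
        self._parameters = {}
        self._sublayers = {}

    def state_dict(self, prefix=""):
        # Own parameters first, then each sublayer's, with dotted prefixes.
        out = {prefix + name: p for name, p in self._parameters.items()}
        for name, sub in self._sublayers.items():
            out.update(sub.state_dict(prefix + name + "."))
        return out


class ToyCell(ToyLayer):
    def __init__(self):
        super().__init__()
        # A plain list stands in for a weight tensor.
        self._parameters["weight_ih"] = [0.1, 0.2]


class ToyRNN(ToyLayer):
    def __init__(self, num_layers):
        super().__init__()
        for i in range(num_layers):
            cell = ToyCell()
            self._sublayers[str(i)] = cell
            # Hypothetical: re-register the *same* object under a flat alias,
            # which is what this issue suspects RNNBase.__init__ does.
            self._parameters[f"weight_ih_l{i}"] = cell._parameters["weight_ih"]


sd = ToyRNN(num_layers=2).state_dict()
print(list(sd))  # ['weight_ih_l0', 'weight_ih_l1', '0.weight_ih', '1.weight_ih']
print(len(sd), len({id(v) for v in sd.values()}))  # 4 names, 2 distinct objects
```

In a live model the duplicates are the same object, so they cost nothing in memory; the cost only appears once every name is written out as a separate array.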
import paddle

# Load the saved training model (a .pdparams state dict)
static_model_path = r'model/ch_PP-OCRv3_rec_slim_train/best_accuracy.pdparams'
state_dict = paddle.load(static_model_path)

# Show that each parameter file contains many duplicated parameters
for role in ('Teacher', 'Student'):
    for bw in ('weight', 'bias'):
        for ih in ('ih', 'hh'):
            for cn in (0, 1):
                # Flat alias name, e.g. ...rnn_decoder.weight_ih_l0
                print(a := state_dict[f"{role}.head.sar_head.decoder.rnn_decoder.{bw}_{ih}_l{cn}"])
                # Per-cell name, e.g. ...rnn_decoder.0.cell._layer.weight_ih
                print(b := state_dict[f"{role}.head.sar_head.decoder.rnn_decoder.{cn}.cell._layer.{bw}_{ih}"])
                assert paddle.all(a == b)
                print(paddle.all(a == b))
                # Drop the redundant alias entry
                state_dict.pop(f"{role}.head.sar_head.decoder.rnn_decoder.{bw}_{ih}_l{cn}")

# Save the de-duplicated state dict
paddle.save(state_dict, 'model/ch_PP-OCRv3_rec_slim_train/new.pdparams')
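The loop above hard-codes the sar_head.decoder.rnn_decoder prefix. A more general sketch scans every key instead. It assumes only the naming seen in the two scripts in this issue: flat aliases {bw}_{ih}_l{n}, optionally with _reverse, and per-cell names of the form {n}.<cell path>.{bw}_{ih}, with cell_bw marking the backward direction. An alias is dropped only when a matching per-cell entry holds equal values. find_redundant_aliases, drop_redundant_aliases and ALIAS_RE are hypothetical helpers, not Paddle APIs; the equal callback lets you plug in a tensor comparison.

```python
import re
from typing import Any, Callable, Dict, List

# Hypothetical pattern for the flat alias names seen in this issue,
# e.g. "weight_ih_l0" or "bias_hh_l1_reverse", optionally with a module prefix.
ALIAS_RE = re.compile(
    r"^(?P<prefix>.*?)(?P<kind>weight|bias)_(?P<gate>ih|hh)_l(?P<layer>\d+)(?P<rev>_reverse)?$"
)


def find_redundant_aliases(
    state: Dict[str, Any],
    equal: Callable[[Any, Any], bool] = lambda a, b: a == b,
) -> List[str]:
    """Return alias keys whose value equals that of a matching per-cell key."""
    redundant = []
    for key, value in state.items():
        m = ALIAS_RE.match(key)
        if m is None:
            continue
        head = f"{m['prefix']}{m['layer']}."  # e.g. "...rnn_decoder.0."
        tail = f".{m['kind']}_{m['gate']}"   # e.g. ".weight_ih"
        for other, other_value in state.items():
            if other == key or not (other.startswith(head) and other.endswith(tail)):
                continue
            # "_reverse" aliases pair with cell_bw entries; the others must not.
            if ("cell_bw" in other) != (m["rev"] is not None):
                continue
            if equal(value, other_value):
                redundant.append(key)
                break
    return redundant


def drop_redundant_aliases(state, equal=lambda a, b: a == b):
    """Remove redundant alias keys in place and return the same dict."""
    for key in find_redundant_aliases(state, equal):
        del state[key]
    return state


demo = {
    "weight_ih_l0": [1, 2], "0.cell_fw.weight_ih": [1, 2],
    "weight_ih_l0_reverse": [3, 4], "0.cell_bw.weight_ih": [3, 4],
}
print(sorted(drop_redundant_aliases(demo)))  # ['0.cell_bw.weight_ih', '0.cell_fw.weight_ih']
```

With a state dict loaded by paddle.load, one could pass something like equal=lambda a, b: a.shape == b.shape and bool(paddle.all(a == b)) before calling paddle.save; that comparison is an assumption and has not been checked against a real checkpoint. Comparing values rather than object identity is deliberate, because tensors loaded from a file are separate objects even when they were shared in the live model.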
Here is a minimal reproducible example:
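import paddle
from paddle.nn import LSTM

# 2-layer bidirectional LSTM: input_size=3, hidden_size=5
m = LSTM(3, 5, 2, 'bidirectional')
state_dict = m.state_dict()
print(state_dict)
for k in state_dict:
    print(k)

for bw in ('weight', 'bias'):
    for ih in ('ih', 'hh'):
        for cn in (0, 1):          # layer index
            for fw in (0, 1):      # 1 = forward direction, 0 = backward
                # Flat alias, e.g. weight_ih_l0 / weight_ih_l0_reverse
                print(a := state_dict[f"{bw}_{ih}_l{cn}{'' if fw else '_reverse'}"])
                # Per-cell name, e.g. 0.cell_fw.weight_ih / 0.cell_bw.weight_ih
                print(b := state_dict[f"{cn}.cell_{'fw' if fw else 'bw'}.{bw}_{ih}"])
                assert paddle.all(a == b)
                print(paddle.all(a == b))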
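For the minimal example, a rough count shows what the duplication costs. This sketch assumes Paddle's standard LSTM weight shapes (weight_ih of shape [4*hidden_size, in_size], weight_hh of shape [4*hidden_size, hidden_size], and two bias vectors of length 4*hidden_size), that the second layer of a bidirectional LSTM takes 2*hidden_size inputs, and that each alias is serialized as a separate copy. lstm_param_count is a hypothetical helper, not a Paddle API.

```python
def lstm_param_count(input_size, hidden_size, num_layers, num_directions):
    """Count distinct LSTM weight values under the shape assumptions above."""
    gates = 4 * hidden_size  # input, forget, cell and output gates stacked
    total = 0
    for layer in range(num_layers):
        # Layers after the first consume the concatenated outputs of all directions.
        in_size = input_size if layer == 0 else hidden_size * num_directions
        per_direction = gates * in_size + gates * hidden_size + 2 * gates
        total += per_direction * num_directions
    return total


unique = lstm_param_count(3, 5, 2, num_directions=2)  # LSTM(3, 5, 2, 'bidirectional')
# If every tensor is listed under both an alias and a per-cell name,
# the serialized state dict holds each value twice.
serialized = 2 * unique
print(unique, serialized)  # 1080 2160
```

For a model made only of RNN layers the stored weights would double; in ch_PP-OCRv3_rec the RNN decoder is only part of the network, which would explain why the overall saving reported above is around 11.7% rather than 50%.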
Additional Supplementary Information
No response
We have asked the relevant team members to look into this issue. Please bear with us.