Error when training the SER model in PaddleOCR KIE

Posted by xvw2m8pv on 2022-11-02 in Other

win11
paddlenlp 2.3.1
paddleocr 2.5.0.3
paddlepaddle-gpu 0.0.0.post110

Config: ser_vi_layoutxlm_xfund_zh.yml

```yaml
Architecture:
  model_type: kie
  algorithm: &algorithm "LayoutXLM"
  Transform:
  Backbone:
    name: LayoutXLMForSer
    pretrained: true
    checkpoints:
    # one of base or vi
    mode: vi
    num_classes: &num_classes 7
```

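Because downloading the model over the network failed, I hard-coded a local model path (the pretrained model provided by PaddleOCR) directly in paddlenlp.transformers.model_utils:
'./inference/ser_vi_layoutxlm_xfund_pretrained/best_accuracy/model_state.pdparams'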

```
Traceback (most recent call last):
  File "tools/train.py", line 201, in <module>
    main(config, device, logger, vdl_writer)
  File "tools/train.py", line 174, in main
    program.train(config, train_dataloader, valid_dataloader, device, model,
  File "E:\自研OCR接口\PaddleOCR-release-2.6\tools\program.py", line 299, in train
    preds = model(batch)
  File "C:\software\Anaconda3\lib\site-packages\paddle\fluid\dygraph\layers.py", line 911, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "E:\自研OCR接口\PaddleOCR-release-2.6\ppocr\modeling\architectures\base_model.py", line 86, in forward
    x = self.backbone(x)
  File "C:\software\Anaconda3\lib\site-packages\paddle\fluid\dygraph\layers.py", line 911, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "E:\自研OCR接口\PaddleOCR-release-2.6\ppocr\modeling\backbones\vqa_layoutlm.py", line 177, in forward
    res.update(x[1])
IndexError: tuple index out of range
```
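The last frame fails because `x` is a tuple with fewer elements than the code expects: `x[0]` is read successfully, but `x[1]` does not exist, so the tuple evidently holds only one element. The pure-Python sketch below reproduces that failure without Paddle. `train_output` mirrors the training branch of `forward()` in vqa_layoutlm.py, and `fake_logits` is a hypothetical stand-in for the real logits tensor; the assumption is that the installed paddlenlp model returns a one-element tuple.

```python
def train_output(x):
    # Mirrors the training branch of forward() in vqa_layoutlm.py
    res = {"backbone_out": x[0]}
    res.update(x[1])  # expects a dict of extra outputs at index 1
    return res


fake_logits = [[0.1, 0.9]]  # hypothetical stand-in for the logits tensor

try:
    # A one-element tuple, as the traceback suggests the backbone returned
    train_output((fake_logits,))
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range
```

With a two-element tuple whose second item is a dict, the same function succeeds, which points at a mismatch between what the backbone returns and what this version of vqa_layoutlm.py expects.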

Answer #1 by ukxgm1gy

I'm not sure whether this change is correct, but training runs normally after making it.
Changes:

Modify the model's output layer

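ppocr.modeling.backbones.vqa_layoutlm, class LayoutLMv2ForSer(NLPBaseModel): comment out `res.update(x[1])`.

```python
# ppocr/modeling/backbones/vqa_layoutlm.py
class LayoutLMv2ForSer(NLPBaseModel):
    ...

    def forward(self, x):
        ...
        if self.training:
            res = {"backbone_out": x[0]}
            # res.update(x[1])  # commented out
            return res
        else:
            return x
```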
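Commenting the line out drops `x[1]` unconditionally, even when a model does return it. A slightly more defensive variant, sketched below as a hypothetical helper (`build_train_output` is not part of PaddleOCR), merges `x[1]` only when it is present and is a dict, so the training branch works with either output shape.

```python
def build_train_output(x):
    """Hypothetical guarded version of the training branch in forward().

    Keeps x[0] as "backbone_out" and merges x[1] only when it exists and
    is a dict, instead of skipping the update unconditionally.
    """
    res = {"backbone_out": x[0]}
    if len(x) > 1 and isinstance(x[1], dict):
        res.update(x[1])
    return res


logits = [[0.2, 0.8]]  # placeholder for the real logits tensor
print(build_train_output((logits,)))
print(build_train_output((logits, {"hidden_states": [1]})))
```

Either way, any loss that reads keys merged from `x[1]` will not find them when the backbone does not return them, so this only helps when the plain SER loss is used.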

Model loading

In paddlenlp.transformers.model_utils, set the model-loading path directly: path = './inference/model_state.pdparams'
Download the model from the web beforehand, from the URL in this map:
```python
pretrained_resource_files_map = {
    "model_state": {
        "layoutxlm-base-uncased":
        "https://bj.bcebos.com/paddlenlp/models/transformers/layoutxlm_base/model_state.pdparams",
    }
}
```
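On Windows, the last '/' in the download URL above gets turned into a '\', so the download fails. Pointing to a local model file also works around having no network access.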
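The post does not say which paddlenlp code builds the download URL. One plausible mechanism, assumed here rather than confirmed, is joining URL parts with `os.path.join`, which on Windows behaves like `ntpath.join` and inserts a backslash. The sketch calls the standard-library `ntpath` and `posixpath` modules directly so it behaves the same on any OS, and adds a hypothetical `resolve_weights` helper that prefers a pre-downloaded file over the network, in the spirit of the local-path workaround above.

```python
import ntpath
import os
import posixpath

URL_BASE = "https://bj.bcebos.com/paddlenlp/models/transformers/layoutxlm_base"
FILE_NAME = "model_state.pdparams"

# Windows path rules put a backslash before the file name: an invalid URL.
windows_style = ntpath.join(URL_BASE, FILE_NAME)

# POSIX rules keep the forward slash that a URL needs.
url = posixpath.join(URL_BASE, FILE_NAME)


def resolve_weights(url, local_dir):
    """Hypothetical helper: return a pre-downloaded file if present, else the URL."""
    local_path = os.path.join(local_dir, posixpath.basename(url))
    if os.path.isfile(local_path):
        return local_path
    return url


print(windows_style)  # ...layoutxlm_base\model_state.pdparams
print(url)            # ...layoutxlm_base/model_state.pdparams
print(resolve_weights(url, "./inference"))
```

Building URLs with `posixpath.join` (or plain string formatting) avoids the separator problem regardless of the host OS, while the local-first lookup avoids the download entirely.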

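Open question: loading the pretrained model provided in ser_vi_layoutxlm_xfund_pretrained raises an error.
I also tried resuming training from checkpoints, which didn't work either, and **_udml isn't supported yet. I wonder whether this will be addressed in a later release.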
