pytorch 无法从aitextgen库导入TextDataset类

vyswwuz2  于 2023-05-29  发布在  其他
关注(0)|答案(1)|浏览(173)

我正在尝试构建一个机器学习模型,通过用我的职位描述数据集对GPT-NEO模型进行微调来生成职位描述
我试图从aitextgen导入TextDataset类,我一直面临这个错误

ImportError: cannot import name 'DeepSpeedPlugin' from 'pytorch_lightning.plugins

我确实尝试了很多chat gpt的建议,其中包括单独安装DeepSpeedPluging,降级我的aitextgen库升级我的pytorch库卸载和重新安装pytorch和aitextgen,但这里没有一个建议的方法是我的代码片段

!pip install aitextgen
import aitextgen.TokenDataset

在这个层次上我得到了错误

ImportError                               Traceback (most recent call last)
<ipython-input-3-24eff681443c> in <cell line: 1>()
----> 1 import aitextgen.TokenDataset

1 frames
/usr/local/lib/python3.10/dist-packages/aitextgen/aitextgen.py in <module>
     12 import torch
     13 from pkg_resources import resource_filename
---> 14 from pytorch_lightning.plugins import DeepSpeedPlugin
     15 from tqdm.auto import trange
     16 from transformers import (

ImportError: cannot import name 'DeepSpeedPlugin' from 'pytorch_lightning.plugins' (/usr/local/lib/python3.10/dist-packages/pytorch_lightning/plugins/__init__.py)"
xt0899hw

xt0899hw1#

您需要设置pytorch-lightning、transformers和aitextgen的版本。
像这样:

!pip install -qq pytorch-lightning==1.7.0 transformers==4.21.3 aitextgen==0.6.0

有一个开放的问题来解决这个问题:https://github.com/minimaxir/aitextgen/issues/215

相关问题