我正尝试通过this GitHub Repository微调GPT-J。运行training命令时,我遇到此错误:
Traceback (most recent call last):
File "device_train.py", line 13, in <module>
from mesh_transformer import util
File "/home/shreyjain/mesh-transformer-jax/mesh_transformer/util.py", line 36, in <module>
class ClipByGlobalNormState(OptState):
File "/usr/lib/python3.8/typing.py", line 317, in __new__
raise TypeError(f"Cannot subclass {cls!r}")
TypeError: Cannot subclass <class 'typing._SpecialForm'>
这看起来像一个源代码错误,但我不确定。我也提出了一个关于GitHub的问题。任何帮助将不胜感激!
1条答案
按热度按时间q5lcpyga1#
我也有同样的错误,这是由错误的Optax版本引起的。通过将Optax降级到0. 0. 9来修复这个错误(至少它对我有效!)。
pip install optax==0.0.9