Python: Why can't I set TrainingArguments.device in Hugging Face?

kyks70gy  posted on 2023-04-10  in Python

Question

When I try to set the .device attribute to torch.device('cpu'), I get an error. How am I supposed to set the device, then?

Python code

from transformers import TrainingArguments
from transformers import Trainer
import torch

training_args = TrainingArguments(
    output_dir="./some_local_dir",
    overwrite_output_dir=True,

    per_device_train_batch_size=4,
    dataloader_num_workers=2,

    max_steps=500,
    logging_steps=1,

    evaluation_strategy="steps",
    eval_steps=5
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

training_args.device = torch.device('cpu')

Python error

AttributeError                            Traceback (most recent call last)
<ipython-input-11-30a92c0570b8> in <cell line: 28>()
     26 )
     27 
---> 28 training_args.device = torch.device('cpu')

AttributeError: can't set attribute

mec1mxoz  #1

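According to the documentation of the TrainingArguments object, it has no settable device attribute.
Interestingly, though, device is initialized but immutable: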

import torch
from transformers import TrainingArguments

args = TrainingArguments('./')

args.device  # [out]: device(type='cpu')

args.device = torch.device(type='cpu')

[out]:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-12-dcb5ef23be68> in <cell line: 8>()
      6 
      7 
----> 8 args.device = torch.device(type='cpu')

AttributeError: can't set attribute

From the source, the device appears to be set after initialization: https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py#L1113
and it is only resolved when TrainingArguments.device is accessed:

  • https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py#L1678
  • https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py#L1524
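The behaviour described above (a value that is computed when read and cannot be assigned) is what a Python property without a setter gives you. Here is a minimal sketch of that mechanism. The class `Args` and its `prefer_cpu` parameter are hypothetical stand-ins made up for illustration, not the real TrainingArguments code:

```python
class Args:
    """Hypothetical stand-in for TrainingArguments; not the real class."""

    def __init__(self, prefer_cpu: bool = True):
        # Internal state that the property reads from.
        self._prefer_cpu = prefer_cpu

    @property
    def device(self) -> str:
        # Computed each time the attribute is read. No setter is defined,
        # so the attribute is read-only from the caller's point of view.
        return "cpu" if self._prefer_cpu else "cuda"


args = Args()
print(args.device)  # reading works

try:
    args.device = "cuda"  # assigning fails because there is no setter
except AttributeError as exc:
    # The exact message depends on the Python version.
    print(f"AttributeError: {exc}")
```

The assignment fails for the same reason as in the question: the attribute can be read, but nothing defines how to write to it.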

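Perhaps what you actually want is to train the Hugging Face model on the CPU.

By default, TrainingArguments already falls back to the CPU when no GPU is detected (in the source, _n_gpu starts at -1 until the devices are set up).

But if you want to place the model on the CPU explicitly, try: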

model = model.to('cpu')
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

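Or:

trainer = Trainer(
    model=model.to('cpu'),
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    # place_model_on_device=True is omitted here: as far as I know, Trainer.__init__
    # does not accept that keyword, and passing it raises a TypeError.
)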

gwbalxhn  #2

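You don't need to set the device in the training arguments. Training runs on the device the model is on. The following code should help you train the model on the CPU: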

device = torch.device('cpu')
model = model.to(device)

training_args.device is a property that can only be read, not set, which is why you get the error.
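If the goal is simply to keep training off the GPU, one option that neither answer mentions is to hide the CUDA devices from the process before torch is imported, so that the device resolves to the CPU on its own. This can matter because, as far as I can tell, the Trainer moves the model to args.device by itself, so model.to('cpu') alone may not be enough when a GPU is visible. The sketch below assumes a CUDA machine. CUDA_VISIBLE_DEVICES is standard CUDA behaviour rather than something taken from the answers, and the torch import is treated as optional so the snippet runs either way:

```python
import os

# Hide all CUDA devices from this process. This has to happen before
# torch (or anything that imports torch) initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

try:
    import torch
except ImportError:
    torch = None  # torch is optional for this sketch

# With no visible GPUs, torch should report CUDA as unavailable, so code
# that picks its device from the environment ends up on the CPU.
cuda_available = torch.cuda.is_available() if torch is not None else None
print("CUDA available:", cuda_available)
```

TrainingArguments also has a no_cuda flag (renamed use_cpu in newer releases, if I remember correctly), so it is worth checking the documentation for your installed version before relying on the environment variable.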
