问题
根据official documentation的说法,Trainer
类“为大多数标准用例提供了PyTorch中功能完整训练的API”。
然而,当我尝试在实践中实际使用Trainer
时,我得到了以下错误消息,似乎表明TensorFlow目前正在幕后使用。
tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
那么它是哪一个呢?HuggingFace transformers库使用PyTorch还是TensorFlow来实现Trainer
的内部实现?是否可以切换到只使用PyTorch?我似乎无法在TrainingArguments
中找到相关参数。
为什么我的脚本总是打印出TensorFlow相关的错误?Trainer
不应该只使用PyTorch吗?
源码
from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel
from transformers import TextDataset
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer
from transformers import TrainingArguments
import torch
# Load the GPT-2 tokenizer and LM head model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
lmhead_model = GPT2LMHeadModel.from_pretrained('gpt2')
# Load the training dataset and divide blocksize
train_dataset = TextDataset(
tokenizer=tokenizer,
file_path='./datasets/tinyshakespeare.txt',
block_size=64
)
# Create a data collator for preprocessing batches
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False
)
# Defining the training arguments
training_args = TrainingArguments(
output_dir='./models/tinyshakespeare', # output directory for checkpoints
overwrite_output_dir=True, # overwrite any existing content
per_device_train_batch_size=4, # sample batch size for training
dataloader_num_workers=1, # number of workers for dataloader
max_steps=100, # maximum number of training steps
save_steps=50, # after # steps checkpoints are saved
save_total_limit=5, # maximum number of checkpoints to save
prediction_loss_only=True, # only compute loss during prediction
learning_rate=3e-4, # learning rate
fp16=False, # use 16-bit (mixed) precision
optim='adamw_torch', # define the optimizer for training
lr_scheduler_type='linear', # define the learning rate scheduler
logging_steps=5, # after # steps logs are printed
report_to='none', # report to wandb, tensorboard, etc.
)
if __name__ == '__main__':
torch.multiprocessing.freeze_support()
trainer = Trainer(
model=lmhead_model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
)
trainer.train()
2条答案
按热度按时间hzbexzde1#
我认为Hugging Face transformers库中的默认Trainer类构建在PyTorch之上。
当您创建Trainer类的示例时,它会在后台初始化PyTorch模型和优化器。然后在训练期间使用PyTorch执行向前和向后传递,并使用优化器更新模型的权重。
Hugging Face还为TensorFlow用户提供了一个TFTrainer类,这些用户希望使用与Trainer类相同的训练循环和实用程序,但使用TensorFlow而不是PyTorch作为后端。
kr98yfug2#
这取决于模型的训练方式和加载方式。
transformers
上最流行的模型同时支持PyTorch和Tensorflow(有时也支持JAX)。[out]:
可能是这样的:
Q:如果我使用
Trainer
,它就是PyTorch?答:是的,很可能模型有PyTorch后端,训练循环(优化器,损失等)使用PyTorch。但
Trainer()
不是模型,它是 Package 器对象。Q:如果我想在Tensorflow后端模型中使用
Trainer
,我应该使用TFTrainer
吗?在
transformers
的最新版本中,TFTrainer
对象已被弃用,请参阅https://github.com/huggingface/transformers/pull/12706如果您使用的是带有Tensorflow后端的模型,建议您使用Keras的sklearn风格的
.fit()
训练。问:为什么我的脚本总是打印出与TensorFlow相关的错误?Trainer不应该只使用PyTorch吗?
尝试检查您的
transformers
版本,很可能您使用的是过时的版本,该版本使用了一些已弃用的对象,例如TextDataset(请参阅如何解决“仅单个元素的整数Tensor可以转换为索引”错误,当创建Dataset以微调GPT2模型时?)在后来的版本中,很可能是
pip install transformers>=4.26.1
,Trainer不应该激活TF警告,使用TFTrainer会引发警告,建议用户使用Keras。