我正在尝试使用Huggingface Trainer API微调BERT模型用于情感分析(将文本分类为积极/消极)。我的数据集有两列,Text
和Sentiment
,它看起来像这样。
Text Sentiment
This was good place 1
This was bad place 0
字符串
下面是我的代码:
from datasets import load_dataset
from datasets import load_dataset_builder
from datasets import Dataset
import datasets
import transformers
from transformers import TrainingArguments
from transformers import Trainer
dataset = load_dataset('csv', data_files='./train/test.csv', sep=';')
tokenizer = transformers.BertTokenizer.from_pretrained("TurkuNLP/bert-base-finnish-cased-v1")
model = transformers.BertForSequenceClassification.from_pretrained("TurkuNLP/bert-base-finnish-cased-v1", num_labels=1)
def tokenize_function(examples):
return tokenizer(examples["Text"], truncation=True, padding='max_length')
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.rename_column('Sentiment', 'label')
tokenized_datasets = tokenized_datasets.remove_columns('Text')
training_args = TrainingArguments("test_trainer")
trainer = Trainer(
model=model, args=training_args, train_dataset=tokenized_datasets['train']
)
trainer.train()
型
运行此命令会抛出错误:
Variable._execution_engine.run_backward(
RuntimeError: Found dtype Long but expected Float
型
这个错误可能来自数据集本身,但我可以用我的代码来修复它吗?我在互联网上搜索了一下,这个错误似乎已经通过“将Tensor转换为浮点数”解决了,但我如何使用Trainer API来解决这个问题?任何建议都非常感谢。
一些参考:
https://discuss.pytorch.org/t/run-backward-expected-dtype-float-but-got-dtype-long/61650/10
4条答案
按热度按时间7gs2gvoe1#
最有可能的是,这个问题与损失函数有关。如果你正确地设置了模型,这是可以修复的,主要是通过指定要使用的正确损失。请参阅此代码以查看确定正确损失的逻辑。
你的问题有二进制标签,因此应该被框定为单标签分类问题。因此,你共享的代码将被推断为回归问题,这解释了它期望float但发现目标标签的长类型的错误。
您需要传递正确的问题类型。
字符串
这将利用BCE损失。对于BCE损失,你需要目标浮动,所以你也必须将标签转换为浮动。我认为你可以用数据集API来做到这一点。参见this。
另一种方法是使用多类分类器或CE loss。为此,只需修复
num_labels
即可。型
9jyewag02#
在这里,我假设你正在尝试做一个标签分类,也就是说,预测一个结果,而不是预测多个结果。
但是你使用的损失函数(我不知道你用的是什么,但它可能是BCE)需要一个来自你的向量作为标签。
因此,要么你需要像人们在评论中建议的那样将标签转换为向量,要么你可以将损失函数替换为交叉熵损失,并将标签参数的数量更改为2(或其他)。
如果你想将你的模型训练成多标签分类器,你可以使用sklearn将标签转换为向量。预处理:
字符串
h43kikqp3#
你可以把你的数据。
如果你有Pandas的文件,你可以这样做:
字符串
如果你有HuggingFace的文件,你应该这样做:
型
使用前:
的字符串
之后:
的字符串
cwtwac6a4#
默认情况下,分类模型进行二进制分类,num_labels类设置为None。将num_labels设置为1会使其成为回归问题,因此会出现错误。您可以在此处阅读更多信息:https://simpletransformers.ai/docs/classification-models/#classificationmodel