pytorch RuntimeError:使用Trainer API进行微调时,发现dtype为Long,但预期为Float

6xfqseft  于 11个月前  发布在  其他
关注(0)|答案(4)|浏览(142)

我正在尝试使用Huggingface Trainer API微调BERT模型用于情感分析(将文本分类为积极/消极)。我的数据集有两列,TextSentiment,它看起来像这样。

Text                     Sentiment
This was good place          1
This was bad place           0

字符串
下面是我的代码:

from datasets import load_dataset
from datasets import load_dataset_builder
from datasets import Dataset
import datasets
import transformers
from transformers import TrainingArguments
from transformers import Trainer

dataset = load_dataset('csv', data_files='./train/test.csv', sep=';')
tokenizer = transformers.BertTokenizer.from_pretrained("TurkuNLP/bert-base-finnish-cased-v1")
model = transformers.BertForSequenceClassification.from_pretrained("TurkuNLP/bert-base-finnish-cased-v1", num_labels=1) 
def tokenize_function(examples):
    return tokenizer(examples["Text"], truncation=True, padding='max_length')

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.rename_column('Sentiment', 'label')
tokenized_datasets = tokenized_datasets.remove_columns('Text')
training_args = TrainingArguments("test_trainer")
trainer = Trainer(
    model=model, args=training_args, train_dataset=tokenized_datasets['train']
)
trainer.train()


运行此命令会抛出错误:

Variable._execution_engine.run_backward(
RuntimeError: Found dtype Long but expected Float


这个错误可能来自数据集本身,但我可以用我的代码来修复它吗?我在互联网上搜索了一下,这个错误似乎已经通过“将Tensor转换为浮点数”解决了,但我如何使用Trainer API来解决这个问题?任何建议都非常感谢。
一些参考:
https://discuss.pytorch.org/t/run-backward-expected-dtype-float-but-got-dtype-long/61650/10

7gs2gvoe

7gs2gvoe1#

最有可能的是,这个问题与损失函数有关。如果你正确地设置了模型,这是可以修复的,主要是通过指定要使用的正确损失。请参阅此代码以查看确定正确损失的逻辑。
你的问题有二进制标签,因此应该被框定为单标签分类问题。因此,你共享的代码将被推断为回归问题,这解释了它期望float但发现目标标签的长类型的错误。
您需要传递正确的问题类型。

model = transformers.BertForSequenceClassification.from_pretrained(
    "TurkuNLP/bert-base-finnish-cased-v1", 
    num_labels=1, 
    problem_type = "single_label_classification"
)

字符串
这将利用BCE损失。对于BCE损失,你需要目标浮动,所以你也必须将标签转换为浮动。我认为你可以用数据集API来做到这一点。参见this
另一种方法是使用多类分类器或CE loss。为此,只需修复num_labels即可。

model = transformers.BertForSequenceClassification.from_pretrained(
    "TurkuNLP/bert-base-finnish-cased-v1", 
    num_labels=2,
)

9jyewag0

9jyewag02#

在这里,我假设你正在尝试做一个标签分类,也就是说,预测一个结果,而不是预测多个结果。
但是你使用的损失函数(我不知道你用的是什么,但它可能是BCE)需要一个来自你的向量作为标签。
因此,要么你需要像人们在评论中建议的那样将标签转换为向量,要么你可以将损失函数替换为交叉熵损失,并将标签参数的数量更改为2(或其他)。
如果你想将你的模型训练成多标签分类器,你可以使用sklearn将标签转换为向量。预处理:

from sklearn.preprocessing import OneHotEncoder
import pandas as pd
import numpy as np

dataset = pd.read_csv("filename.csv", encoding="utf-8")
enc_labels = preprocessing.LabelEncoder()
int_encoded = enc_labels.fit_transform(np.array(dataset["Sentiment"].to_list()))

onehot_encoder = OneHotEncoder(sparse = False)
int_encoded = int_encoded.reshape(len(int_encoded),1)
onehot_encoded = onehot_encoder.fit_transform(int_encoded)
for index, cat in dataset.iterrows():
    dataset.at[index , 'Sentiment'] = onehot_encoded[index]

字符串

h43kikqp

h43kikqp3#

你可以把你的数据。
如果你有Pandas的文件,你可以这样做:

df['column_name'] = df['column_name'].astype(float)

字符串
如果你有HuggingFace的文件,你应该这样做:

from datasets import load_dataset
dataset = load_dataset('glue', 'mrpc', split='train')
from datasets import Value, ClassLabel

new_features = dataset.features.copy()
new_features["idx"] = Value('int64')
new_features["label"] = ClassLabel(names=['negative', 'positive'])
new_features["idx"] = Value('int64')
dataset = dataset.cast(new_features)


使用前:

dataset.features
{'idx': Value(dtype='int32', id=None),
 'label': ClassLabel(num_classes=2, names=['not_equivalent', 'equivalent'], id=None),
 'sentence1': Value(dtype='string', id=None),
 'sentence2': Value(dtype='string', id=None)}

的字符串
之后:

dataset.features
{'idx': Value(dtype='int64', id=None),
 'label': ClassLabel(num_classes=2, names=['negative', 'positive'], id=None),
 'sentence1': Value(dtype='string', id=None),
 'sentence2': Value(dtype='string', id=None)}

的字符串

cwtwac6a

cwtwac6a4#

默认情况下,分类模型进行二进制分类,num_labels类设置为None。将num_labels设置为1会使其成为回归问题,因此会出现错误。您可以在此处阅读更多信息:https://simpletransformers.ai/docs/classification-models/#classificationmodel

相关问题