pytorch 基于BERT和高度不平衡训练数据的多标签文本分类

tquggr8v  于 2023-01-30  发布在  其他
关注(0)|答案(1)|浏览(236)

我正在尝试使用BERT训练一个多标签文本分类模型。每段文本可以属于485个类中的0个或更多。我的模型由一个丢弃层和一个线性层组成,添加到Hugging Face的bert-base-uncased模型的合并输出之上。我使用的损失函数是PyTorch中的BCEWithLogitsLoss。
我有数百万个带标签的观测值需要训练,但是训练数据是高度不平衡的,一些标 checkout 现在不到10个观测值中,而另一些标 checkout 现在超过10万个观测值中!我希望得到一个“好的”回忆。
我第一次尝试在不调整数据不平衡的情况下进行训练,得到的micro recall rate为70%(足够好),但macro recall rate为45%(不够好),这些数字表明该模型在代表性不足的类上表现不佳。
如何在训练过程中有效地调整数据不平衡以提高宏召回率?我知道我们可以为BCEWithLogitsLoss损失函数提供标签权重。但是,假设数据中存在非常高的不平衡,导致权重在1到1 M的范围内,我真的能让模型收敛吗?我的初始实验表明,加权损失函数在训练过程中会上下波动。
或者,对于此类任务,是否有比使用BERT +丢弃+线性图层更好的方法?

vnzz0bqm

vnzz0bqm1#

在您的情况下,平衡训练数据中的标签可能会有所帮助。您有大量数据,因此您可以通过平衡来承担丢失一部分数据的后果。但在您这样做之前,我建议您阅读this answer about balancing classes in traing data
如果你真的只关心回忆,你可以试着调整你的模型,使回忆最大化。

相关问题