我目前正在训练一个模型来将图像分类为,比如说,10类。我有一个10 k图像的数据集(每个类1 k)。在训练过程中,NN能够达到99.6%的准确率,这意味着在10000张图像中有40张被错误分类。
我能以某种方式确切地知道是什么图像(或至少是批处理)导致了错误?我想这样做是为了直观地检查什么类型的图像,如果导致错误。我想它们可能是异常值,所以当我弄清楚它们的特征时,我可以用类似的图像来增强我的数据集以提高准确性。
我可以简单地在训练后对我的初始数据集运行“预测”,但也许有更优雅、更耗时的方法来找到罪魁祸首。
重要的是,我已经从目录tf.keras.utils.image_dataset_from_directory中获取了数据集
谢谢!
1条答案
按热度按时间6kkfgxo01#
您可以使用自定义回调,以便在每个epoch之后打印出错误分类图像的索引,但这需要重新训练模型。
现在将其添加到www.example.com中的回调model.fit