我使用TensorFlow后端在Keras中实现了一个图像分类器。对于具有两个输出类的数据集,我检查了预测标签:
if result[0][0] == 1: prediction ='adathodai' else: prediction ='thamarathtai'
Full code。对于三个类,我得到[[0. 0. 1.]]。如何检查if else格式中两个以上类的预测标签?
[[0. 0. 1.]]
uz75evzq1#
对于具有k个标签的多类分类问题,可以使用model.predict_classes()检索预测类的索引。玩具示例:
model.predict_classes()
import keras import numpy as np # Simpel model, 3 output nodes model = keras.Sequential() model.add(keras.layers.Dense(3, input_shape=(10,), activation='softmax')) # 10 random input data points x = np.random.rand(10, 10) model.predict_classes(x) > array([1, 1, 2, 1, 2, 1, 2, 1, 1, 1])
如果你有一个列表中的标签,你可以使用预测的类来获得预测的标签:
labels = ['label1', 'label2', 'label3'] [labels[i] for i in model.predict_classes(x)] > ['label2', 'label2', 'label3', 'label2', 'label3', 'label2', 'label3', 'label2', 'label2', 'label2']
在后台,model.predict_classes返回预测中每行的最大预测类概率的索引:
model.predict_classes
model.predict_classes(x) > array([1, 1, 2, 1, 2, 1, 2, 1, 1, 1]) model.predict(x).argmax(axis=-1) # same thing > array([1, 1, 2, 1, 2, 1, 2, 1, 1, 1])
1条答案
按热度按时间uz75evzq1#
对于具有k个标签的多类分类问题,可以使用
model.predict_classes()
检索预测类的索引。玩具示例:如果你有一个列表中的标签,你可以使用预测的类来获得预测的标签:
在后台,
model.predict_classes
返回预测中每行的最大预测类概率的索引: