keras 获得F1评分、召回率、混淆矩阵和精确度

wqlqzqxt  于 2023-08-06  发布在  其他
关注(0)|答案(3)|浏览(153)

我如何才能获得F1score,召回,混淆矩阵和precison在这个代码.我已经使用编译和获得的准确性,但我不知道如何写代码来获得这些指标从我的模型.我会很感激你帮助我.对于范围内的comm_round(comms_round):

global_weights = global_model.get_weights()

scaled_local_weight_list = list()

client_names= list(clients_batched.keys())
random.shuffle(client_names)

for client in client_names:
    local_model = Transformer
    local_model.compile(loss=tf.keras.losses.CategoricalCrossentropy(),
                        optimizer=tf.keras.optimizers.Adam(learning_rate = 0.001),
                        metrics='acc')

    global_model.set_weights(global_weights)

    local_model.set_weights(global_weights)

    history = local_model.fit(clients_batched[client], epochs=1, verbose=0, callbacks=[checkpoint_callback])

    scaling_factor = weight_scalling_factor(clients_batched, client)
    scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
    scaled_local_weight_list.append(scaled_weights)

    K.clear_session()

average_weights = sum_scaled_weights(scaled_local_weight_list)

global_model.set_weights(average_weights)

for(X_test, Y_test) in test_batched:
    global_acc, global_loss = test_model(test_x, test_y, global_model, comm_round + 1)

字符串
此外,我还想使用线图绘制模型在训练期间记录的训练集和测试集上的性能,每个损失和分类准确度各一个。

ajsxfq5m

ajsxfq5m1#

Keras的精度、AUC和其他指标列于:https://keras.io/api/metrics/classification_metrics/
请尝试按以下方式使用它们:

local_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(learning_rate = 0.001),
    metrics=['acc',
             tf.keras.metrics.Precision(thresholds=0),
             tf.keras.metrics.Recall(thresholds=0),
             tf.keras.metrics.AUC(from_logits=True),
             #could add more metrics...
            ]
)

字符串
threshold=0from_logits=True是如果您的模型返回logits,如上页所述。
要绘制指标,它类似于:

# list all data in available in history
keys = list(history.history.keys())
print('Info in history is:\n\t', keys)

# select some metrics to plot from history
metricA = keys[0]
metricB = keys[1]
metricC = keys[2]

#create plot
plt.plot(history.history[metricA])
plt.plot(history.history[metricB])
plt.plot(history.history[metricC])

plt.title('History')
plt.ylabel('value')
plt.xlabel('epoch')
plt.legend([metricA, metricB, metricC], loc='upper left')
plt.show()

kmb7vmvb

kmb7vmvb2#

这是不正确的。我把它放在K.clear_session()之前。首先我得到错误IndexError:列表索引超出范围,因为“历史中的信息是:['loss','accuracy']”。然后当我删除metricC = keys[2]时,我没有得到任何可视化

yvgpqqbh

yvgpqqbh3#

此代码为每个循环添加度量和绘图。注意,对于每个度量,每个时期将有一个测量。所以如果你有一个历元,你绘制的准确度,它只是1点。

from tf.keras import metrics

global_weights = global_model.get_weights()

scaled_local_weight_list = list()

client_names= list(clients_batched.keys())
random.shuffle(client_names)

for client in client_names:
    local_model = Transformer
    local_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(learning_rate = 0.001),
    metrics=['acc',
             metrics.Precision(),
             metrics.Recall(),
             metrics.AUC()]
    )

    global_model.set_weights(global_weights)

    local_model.set_weights(global_weights)

    history = local_model.fit(clients_batched[client], epochs=1, verbose=0, callbacks=[checkpoint_callback])

    scaling_factor = weight_scalling_factor(clients_batched, client)
    scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
    scaled_local_weight_list.append(scaled_weights)

    # list all data in available in history
    keys = list(history.history.keys())
    print(f'Information available in history for {client} is:\n\t', keys)
    
    # select some metrics to plot from history
    metricA = keys[0]
    metricB = keys[1]
    metricC = keys[2]
    
    #create plot
    plt.plot(history.history[metricA])
    plt.plot(history.history[metricB])
    plt.plot(history.history[metricC])

    plt.title('History')
    plt.ylabel('value')
    plt.xlabel('epoch')
    plt.legend([metricA, metricB, metricC], loc='upper left')
    plt.show()

    K.clear_session()

average_weights = sum_scaled_weights(scaled_local_weight_list)

global_model.set_weights(average_weights)

for(X_test, Y_test) in test_batched:
    global_acc, global_loss = test_model(test_x, test_y, global_model, comm_round + 1)

字符串

相关问题