我如何才能获得F1score,召回,混淆矩阵和precison在这个代码.我已经使用编译和获得的准确性,但我不知道如何写代码来获得这些指标从我的模型.我会很感激你帮助我.对于范围内的comm_round(comms_round):
global_weights = global_model.get_weights()
scaled_local_weight_list = list()
client_names= list(clients_batched.keys())
random.shuffle(client_names)
for client in client_names:
local_model = Transformer
local_model.compile(loss=tf.keras.losses.CategoricalCrossentropy(),
optimizer=tf.keras.optimizers.Adam(learning_rate = 0.001),
metrics='acc')
global_model.set_weights(global_weights)
local_model.set_weights(global_weights)
history = local_model.fit(clients_batched[client], epochs=1, verbose=0, callbacks=[checkpoint_callback])
scaling_factor = weight_scalling_factor(clients_batched, client)
scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
scaled_local_weight_list.append(scaled_weights)
K.clear_session()
average_weights = sum_scaled_weights(scaled_local_weight_list)
global_model.set_weights(average_weights)
for(X_test, Y_test) in test_batched:
global_acc, global_loss = test_model(test_x, test_y, global_model, comm_round + 1)
字符串
此外,我还想使用线图绘制模型在训练期间记录的训练集和测试集上的性能,每个损失和分类准确度各一个。
3条答案
按热度按时间ajsxfq5m1#
Keras的精度、AUC和其他指标列于:https://keras.io/api/metrics/classification_metrics/
请尝试按以下方式使用它们:
字符串
threshold=0
和from_logits=True
是如果您的模型返回logits,如上页所述。要绘制指标,它类似于:
型
kmb7vmvb2#
这是不正确的。我把它放在K.clear_session()之前。首先我得到错误IndexError:列表索引超出范围,因为“历史中的信息是:['loss','accuracy']”。然后当我删除metricC = keys[2]时,我没有得到任何可视化
yvgpqqbh3#
此代码为每个循环添加度量和绘图。注意,对于每个度量,每个时期将有一个测量。所以如果你有一个历元,你绘制的准确度,它只是1点。
字符串