tensorflow 为什么所有10次折叠的测试准确度和平衡准确度值相同?

ldfqzlk8  于 2022-12-27  发布在  其他
关注(0)|答案(1)|浏览(143)

对于每个折叠测试精度和平衡测试精度是不同的,但数值是相同的。例如,折叠1,测试精度是86,平衡测试精度是86。对于折叠2,测试精度是90,平衡测试精度是90对于折叠3,测试精度是70. 555,测试精度是70. 555 ...这里是我的代码

fold_no = 1
reports = []
accuracies = []
sensitivities = []
specificities = []
test_accuracy = []
for train, test in kfold.split(X_train, y_train):

  model = Sequential()
  model.add(Conv3D(128, kernel_size=(3, 3, 3))
  model.add(Flatten())
  model.add(Dense(256, activation='relu', kernel_regularizer='l2'))
  model.add(Dense(4096, activation='relu', kernel_regularizer='l2')) 
  model.add(Dropout(0.3))
  model.add(Dense(1, activation='sigmoid', kernel_regularizer='l2'))

  # Compile the model
  model.compile(loss=tensorflow.keras.losses.mean_squared_error,
                optimizer=tensorflow.keras.optimizers.Adam(learning_rate=learning_rate),
                metrics=['accuracy'])```

  history = model.fit(X_train[train], y_train[train],
                      batch_size=batch_size,
                      epochs=no_epochs,
                      verbose=verbosity, validation_data=(X_train[test], y_train[test]))

  # Compute the classification report for the testing set
  y_pred = model.predict(X_test, verbose = 0)
  c = model.evaluate(X_test, y_test)
  test_accuracy.append(c[1])
  report = classification_report(y_test, (y_pred>0.5), output_dict=True)
  from sklearn.metrics import balanced_accuracy_score
  bal_acc=balanced_accuracy_score(y_test,(y_pred>0.5))
  print("balenced acc is " + str(bal_acc))

  # Extract the sensitivity and specificity values from the report
  sensitivity = report["1"]["recall"]
  specificity = report["0"]["recall"]
  sensitivities.append(sensitivity)
  specificities.append(specificity)

  print(specificity))  
  print(sensitivity))
8yparm6h

8yparm6h1#

当类从一开始就平衡时,平衡的准确度和准确度是相同的:

from sklearn.metrics import accuracy_score, balanced_accuracy_score

y_true = [0, 0, 0, 0, 1, 1, 1, 1]    # 4 negatives, 4 positives
y_pred = [0, 0, 1, 0, 1, 0, 1, 1]

print(accuracy_score(y_true, y_pred), balanced_accuracy_score(y_true, y_pred))
# 0.75 0.75

相关问题