tensorflow 无法为多输出模型添加标准度量

hc8w905p  于 2022-11-25  发布在  其他
关注(0)|答案(1)|浏览(118)

我有一个基于MobileNet v2的猫和狗的分类+检测模型。它训练得很好,但是现在我想为它添加指标,我不能这样做。下面是代码的主要部分:

def localization_loss(y_true, yhat):            
    delta_coord = tf.reduce_sum(tf.square(y_true[:,:2] - yhat[:,:2]))          
    h_true = y_true[:,3] - y_true[:,1] 
    w_true = y_true[:,2] - y_true[:,0]
    h_pred = yhat[:,3] - yhat[:,1] 
    w_pred = yhat[:,2] - yhat[:,0] 
    delta_size = tf.reduce_sum(tf.square(w_true - w_pred) + tf.square(h_true-h_pred))
    return delta_coord + delta_size

classloss = tf.keras.losses.BinaryCrossentropy()
regressloss = localization_loss

opt = tf.keras.optimizers.Adam(learning_rate=0.0001, decay=0.001)

model.compile(
    optimizer = opt,
    loss=[classloss, regressloss],
    # metrics=["accuracy", "meaniou"],
)
hist = model.fit(train, epochs=10, validation_data=valid)

它工作正常,但如果我取消对度量行的注解,就会出现以下错误:

ValueError: as_list() is not defined on an unknown TensorShape.

如果我使用对象而不是字符串(metrics=[Accuracy(), MeanIoU(2)]),则会出现以下错误:

TypeError: '>' not supported between instances of 'NoneType' and 'int'

我做错了什么?我该如何解决?
UPD:如果我对两个输出都使用accuracy(metrics=[[Accuracy()], [Accuracy()]]),我的训练没有任何错误,所以我得出结论,我的代码中的MeanIoU有问题。
顺便说一句,对批次(8)样本进行预测(两个输出:类+坐标作为4个数字):

(array([[0.7866989 ],
        [0.973974  ],
        [0.9148978 ],
        [0.28471756],
        [0.9899457 ],
        [0.99033797],
        [0.7237025 ],
        [0.81942046]], dtype=float32),
 array([[0.2515184 , 0.25495493, 0.3642715 , 0.09299589],
        [0.87964845, 0.3134839 , 0.54833114, 0.36701256],
        [0.0304133 , 0.45813853, 0.19692126, 0.244534  ],
        [0.22500503, 0.70299083, 0.00123629, 0.41123846],
        [0.37099576, 0.6092719 , 0.13407992, 0.40188596],
        [0.32103425, 0.6240243 , 0.02281341, 0.03058532],
        [0.28678325, 0.19885723, 0.50342166, 0.57963324],
        [0.41590106, 0.21439987, 0.94105315, 0.3379435 ]], dtype=float32))

我想MeanIoU的格式可能是错误的,但4个数字的数组似乎对MeanIoU有效,不是吗?

oxalkeyp

oxalkeyp1#

正如我在这里回答的,正确的衡量标准是:BinaryAccuracycustom MeanIoU(据我所知,default MeanIoU不适用于bboxes回归)。工作代码片段在第一个链接中。

相关问题