Keras TypeError: call() missing 1 required positional argument: 'neg' — custom training loop with a custom loss

wh6knrhe · asked on 2023-04-30

Below is my triplet loss. Its call method takes 3 arguments:

import tensorflow as tf
from tensorflow import keras

class TripletLoss(keras.losses.Loss):
    def __init__(self, alpha=0.2, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        
    @staticmethod
    def dist_sqr(x1, x2):
        return tf.reduce_sum(tf.square(tf.subtract(x1, x2)), axis=-1)
    
    def call(self, anchor, pos, neg):
        dist_pos = TripletLoss.dist_sqr(anchor, pos)
        dist_neg = TripletLoss.dist_sqr(anchor, neg)
        loss = tf.maximum(dist_pos - dist_neg + self.alpha, 0)
        return tf.reduce_sum(loss)
    
    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "alpha": self.alpha}

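Independent of the Keras plumbing, the margin formula itself can be sanity-checked in plain NumPy; this is a minimal sketch mirroring `dist_sqr` and `alpha` from the class above, with made-up embedding values:

```python
import numpy as np

def dist_sqr(x1, x2):
    # squared Euclidean distance along the last axis
    return np.sum(np.square(x1 - x2), axis=-1)

def triplet_loss(anchor, pos, neg, alpha=0.2):
    # hinge on d(a, p) - d(a, n) + alpha, summed over the batch
    dist_pos = dist_sqr(anchor, pos)
    dist_neg = dist_sqr(anchor, neg)
    return np.sum(np.maximum(dist_pos - dist_neg + alpha, 0.0))

anchor = np.array([[0.0, 0.0]])
pos    = np.array([[0.1, 0.0]])   # close to the anchor
neg    = np.array([[1.0, 0.0]])   # far from the anchor
# dist_pos = 0.01, dist_neg = 1.0 -> max(0.01 - 1.0 + 0.2, 0) = 0
print(triplet_loss(anchor, pos, neg))  # 0.0
```

Swapping `pos` and `neg` gives a violating triplet, and the loss becomes `1.0 - 0.01 + 0.2 = 1.19` as expected.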
In the code below, the loss is called with 3 arguments:

n_epochs = 30
n_steps = 267 // BATCH_SIZE
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = TripletLoss(alpha=0.2)
mean_loss = keras.metrics.Mean()

for epoch in range(1, n_epochs + 1):
    for step, (X_batch, y_batch) in enumerate(train_ds):
        pos, neg = select_all_triplets(images=X_batch, labels=y_batch)
        with tf.GradientTape() as tape:
            anchor_embed, pos_embed, neg_embed = model(X_batch), model(pos), model(neg)
            loss = loss_fn(anchor_embed, pos_embed, neg_embed)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        mean_loss(loss)
        print_status_bar(step, n_steps, mean_loss)

But it raises this error:

----> 6 loss = loss_fn(anchor_embed, pos_embed, neg_embed)
TypeError: TripletLoss.call() missing 1 required positional argument: 'neg'

even though I have supplied all 3 arguments.

iqjalb3h (answer 1):

You are invoking the `__call__` method defined in Keras's losses.py, which expects `(y_true, y_pred, sample_weight)` and forwards only two positional arguments on to `call()`. If you want to use your own three-argument `call` method, invoke it directly:

loss_fn.call(anchor_embed, pos_embed, neg_embed)
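To see why the original call fails, here is a tiny pure-Python mock of the relevant forwarding behavior (an assumption for illustration; the real `keras.losses.Loss.__call__` also handles reduction and weighting), which reproduces the same `TypeError` without TensorFlow:

```python
class MockLoss:
    # Mimics only the argument forwarding of keras.losses.Loss.__call__.
    def __call__(self, y_true, y_pred, sample_weight=None):
        # Exactly two positional arguments are passed on to call().
        return self.call(y_true, y_pred)

class MockTripletLoss(MockLoss):
    def call(self, anchor, pos, neg):
        return 0.0  # body irrelevant; the 3-arg signature is what breaks

loss_fn = MockTripletLoss()
try:
    # The third argument lands in sample_weight, so call() receives only 2.
    loss_fn(1, 2, 3)
except TypeError as e:
    print(e)  # ...missing 1 required positional argument: 'neg'
```

So `loss_fn(anchor_embed, pos_embed, neg_embed)` never reaches your three-argument `call`; either call `loss_fn.call(...)` directly as shown above, or keep the `(y_true, y_pred)` interface that `keras.losses.Loss` expects.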
