Keras -在忽略最后一层的同时增加中间层的损失

w9apscun  于 2023-04-06  发布在  其他
关注(0)|答案(3)|浏览(130)

我创建了以下Keras自定义模型:

import tensorflow as tf
from tensorflow.keras.layers import Layer

class MyModel(tf.keras.Model):
    def __init__(self, num_classes):
        super(MyModel, self).__init__()
        self.dense_layer = tf.keras.layers.Dense(num_classes,activation='softmax')
        self.lambda_layer = tf.keras.layers.Lambda(lambda x: tf.math.argmax(x, axis=-1))

    
    def call(self, inputs):
        x = self.dense_layer(inputs)
        x = self.lambda_layer(x)
        return x

    # A convenient way to get model summary 
    # and plot in subclassed api
    def build_graph(self, raw_shape):
        x = tf.keras.layers.Input(shape=(raw_shape))
        return tf.keras.Model(inputs=[x], 
                              outputs=self.call(x))

任务是多类分类。模型由一个带有softmax激活的密集层和一个lambda层组成,lambda层作为后处理单元,将密集输出向量转换为单个值(预测类)。
训练目标是一个独热编码矩阵,如下所示:

[
   [0,0,0,0,1]
   [0,0,1,0,0]
   [0,0,0,1,0]
   [0,0,0,0,1]
]

如果我可以在密集层定义一个categorical_crossentropy损失,忽略lambda层,同时仍然保持功能,并在调用model.predict(x)时输出一个值,那就太好了。

请注意

我的工作环境不允许我使用@alonetogether建议的自定义训练循环。

fruv7luv

fruv7luv1#

您可以尝试使用自定义训练循环,这是非常简单的IMO:

import tensorflow as tf
from tensorflow.keras.layers import Layer

class MyModel(tf.keras.Model):
    def __init__(self, num_classes):
        super(MyModel, self).__init__()
        self.dense_layer = tf.keras.layers.Dense(num_classes,activation='softmax')
        self.lambda_layer = tf.keras.layers.Lambda(lambda x: tf.math.argmax(x, axis=-1))

    
    def call(self, inputs):
        x = self.dense_layer(inputs)
        x = self.lambda_layer(x)
        return x

    # A convenient way to get model summary 
    # and plot in subclassed api
    def build_graph(self, raw_shape):
        x = tf.keras.layers.Input(shape=(raw_shape))
        return tf.keras.Model(inputs=[x], 
                              outputs=self.call(x))
        
n_classes = 5
model = MyModel(n_classes)
labels = tf.keras.utils.to_categorical(tf.random.uniform((50, 1), maxval=5, dtype=tf.int32))
train_dataset = tf.data.Dataset.from_tensor_slices((tf.random.normal((50, 1)), labels)).batch(2)
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.CategoricalCrossentropy()
epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model.layers[0](x_batch_train)
            loss_value = loss_fn(y_batch_train, logits)

        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

预测:
x一个一个一个一个x一个一个二个x

shyt4zoc

shyt4zoc2#

我认为有一个Model.predict_classes函数可以替代lambda层。但如果它不起作用:
似乎没有办法做到这一点,而不使用这些黑客之一:

  • 两个输入(一个是基真值Y)
  • 两个输出
  • 两种型号

我确信没有其他的解决方法,所以,我相信“两个模型”版本最适合你的情况,你似乎“需要”一个单输入,单输出和fit的模型。
然后我会这么做

inputs = tf.keras.layers.Input(input_shape_without_batch_size)    
loss_outputs = tf.keras.layers.Dense(num_classes,activation='softmax')(inputs)
final_outputs = tf.keras.layers.Lambda(lambda x: tf.math.argmax(x, axis=-1))(loss_outputs)

training_model = tf.keras.models.Model(inputs, loss_outputs)
final_model = tf.keras.models.Model(inputs, final_outputs)

training_model.compile(.....)
training_model.fit(....)

results = final_model.predict(...)
nhhxz33t

nhhxz33t3#

我遇到了一个类似的问题,我需要在标准化的地面实况数据上训练模型,但我希望模型输出“非标准化”的结果。
我所做的是在模型中添加“非规范化”层,并在损失函数周围添加一个 Package 器,该 Package 器将模型的输出重新规范化,以便仅计算损失。这可以在编译函数中完成。

MyModel(keras.Model):
    def __init__(output_mean, output_std, *args, **kwargs):
        self.output_mean = output_mean
        self.output_std = output_std
        super(MyModel, self).__init__(*args, **kwargs)

   def compile(optimizer, loss, *args, **kwargs):
        def loss_wrapper(y_true, y_pred):
            y_pred = (y_pred - self.output_mean) / self.output_std
            return loss(y_true, y_pred)
        super(MyModel, self).compile(optimizer, loss_wrapper, *args, **kwargs)

当我稍后加载模型时,我不编译它,以避免“Loss function 'loss_wrapper' not found”错误

model.load_model("/model_path", compile=False)

相关问题