使用TensorFLow / Keras时,在自定义成本函数中获取元数据的最有效方法是什么?

qni6mghb  于 2022-12-19  发布在  其他
关注(0)|答案(1)|浏览(115)

在我的数据集中,我有一个二进制Target列、一些Features列和一个Date列。我想编写一个自定义成本函数,首先计算按日期计算的成本数量,然后将所有成本相加。但要做到这一点,我需要在成本函数中知道y_predy_true中每个数据点的对应日期。
要最大限度地提高性能,最好的方法是什么?我有几个想法:

  • 使目标变量成为元组(target, date),具有提取元组的第一条目的定制第一层,并且使成本函数提取元组y_true的第二条目
  • 使目标列变量成为索引,并让自定义第一层以及自定义成本函数根据索引从全局变量中提取相关值

在自定义成本函数中获取此信息的最有效方法是什么?

cdmah0mi

cdmah0mi1#

我刚刚找到了一种方法,可以做到这一点。我不太确定它的性能如何,但有一种方法是使用以下形式的CustomLoss

def myLossWithDate(date_col):
    def customBinaryCrossEntropy(y_true, y_pred):
        print(list(zip(date_col, y_true.numpy())))
        # do smth here
        # return custom_loss or
        return tf.keras.losses.binary_crossentropy(y_true,y_pred)
    return customBinaryCrossEntropy

然后可以在模型中使用此损失,如下所示:

mod = tf.keras.models.Sequential([
    tf.keras.layers.Dense(1, activation="sigmoid")
])
mod.compile(optimizer="sgd", loss=myLossWithDate(date_col=X[:,1]), run_eagerly=True)
mod.fit(X, Y, epochs=1, verbose=False)

这里最主要的是

run_eagerly=True

否则你会得到Iterator Tensors(https://www.tensorflow.org/guide/intro_to_graphs)。根据输出的数据,由于print(list(zip(...)))的缘故,看起来像这样

[(1, array([0])), (2, array([1])), (3, array([1]))]

我曾经

Y = np.random.binomial(1, 0.5, 3).reshape(-1,1)
X = np.column_stack((np.array([1,2,3]), np.array([1,2,3]))) # data, date as int

作为数据。
显然这只是一个假人,但也许它会帮助你。
编辑:使用小批次
函数的变化如下

def myLossWithDate():
    def customBinaryCrossEntropy(y_true, y_pred):
        y_true_ = y_true[:,0]
        batch_size = y_true.shape[0]
        y_true_ = tf.reshape(y_true_, shape=(batch_size, 1))
        date_col = y_true[:,1]
        # do smth here
        # return custom_loss or
        return tf.keras.losses.binary_crossentropy(y_true_,y_pred)
    return customBinaryCrossEntropy

并通过

Y = np.column_stack((Y, date_col))

因为在backprop你通常不使用Y除了计算损失,你会做手动。
模型变成

batches = 2
batch_size = int(X.shape[0] / batches)

mod = tf.keras.models.Sequential([
    tf.keras.layers.Dense(1, activation="sigmoid")
])
mod.compile(optimizer="sgd", loss=myLossWithDate(), run_eagerly=True)
mod.fit(X, Y, epochs=1, verbose=False, batch_size=batch_size)

相关问题