sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0
# Create a Dataset that includes sample weights
# (3rd element in the return tuple).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, sample_weight))
# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)
model = get_compiled_model()
model.fit(train_dataset, epochs=1)
1条答案
按热度按时间euoag5mw1#
来自tf.keras
model.fit()
的文档:sample_weight
[...]当x是数据集、生成器或
keras.utils.Sequence
示例时,不支持此参数,而是提供sample_weights作为x的第三个元素。这是什么意思?这是在一个官方文档turorials中针对
Dataset
案例进行的演示:字符串
查看完整示例的链接。