我按照tutorial here来尝试使用CIFAR-100训练我的模型。但是我得到了这个错误。我该怎么办?
ValueError: Data Params Error: The dataset label shape (100,) does not match the
number of classes (10) in the dataset. Please ensure the dataset
labels have 10 classes, or change the number of classes
to match the dataset.
这是我的设置,改编自教程,但适用于CIFAR-100。
import tensorflow as tf
import tensorflow_datasets as tfds
import masterful
masterful = masterful.register()
TRAINING_PERCENTAGE = 5
(training_dataset,
test_dataset) = tfds.load('cifar100',
as_supervised=True,
split=[f'train[:{TRAINING_PERCENTAGE}%]', 'test'],
with_info=False)
def sparse_to_dense(image, label):
label = tf.cast(label, tf.int32)
one_hot_label = tf.one_hot(label, depth=100)
return image, one_hot_label
training_dataset = training_dataset.map(sparse_to_dense,
num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.map(sparse_to_dense,
num_parallel_calls=tf.data.AUTOTUNE)
def get_model():
model = tf.keras.models.Sequential()
model.add(
tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255,
input_shape=(32, 32,
3)))
model.add(tf.keras.layers.Conv2D(
16,
(3, 3),
activation='relu',
))
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(10))
return model
model = get_model()
model_params = masterful.architecture.learn_architecture_params(
model=model,
task=masterful.enums.Task.CLASSIFICATION,
input_range=masterful.enums.ImageRange.CIFAR10_TORCH,
num_classes=10,
prediction_logits=True,
)
training_dataset_params = masterful.data.learn_data_params(
dataset=training_dataset,
task=masterful.enums.Task.CLASSIFICATION,
image_range=masterful.enums.ImageRange.CIFAR10_TORCH,
num_classes=10,
sparse_labels=False,
)
optimization_params = masterful.optimization.learn_optimization_params(
model,
model_params,
training_dataset,
training_dataset_params,
)
# This is a set of parameters learned on CIFAR10 for
# small sized models.
regularization_params = masterful.regularization.parameters.CIFAR10_SMALL
training_report = masterful.training.train(
model,
model_params,
optimization_params,
regularization_params,
None,
training_dataset,
training_dataset_params,
)
2条答案
按热度按时间vx6bjr1n1#
你们两个都是对的,正则化和分类是支持10类的,你们可以考虑多步的方法。
[样品]:
[枚举数]:
[统计数据]:
[参数]:
t2a7ltrp2#
此值错误解释了消息中的问题。
作为CIFAR-100的数据集标签是长度为
100
的一个热向量。但是对“masterful.data.learn_data_params”的调用传递了值
10
。使用以下内容更新模型体系结构:
以及您对learn_data_params的调用: