使用Keras图像数据生成器可视化图像增强-'NumpyArrayIterator'中的输入数据的秩应为4

mklgxw1f  于 2022-12-04  发布在  其他
关注(0)|答案(1)|浏览(159)

我已经建立了一个CNN模型,并使用德国交通标志图像来训练它。我已经尝试了对图像应用数据增强,但使用matplotlib和Keras图像数据生成器显示这些图像时遇到了问题。
我已经导入了流程所需的库,下面是我获取pickle道路类标志的位置:

# The pickle module implements binary protocols for serializing and de-serializing a Python object structure.
with open("./traffic-signs-data/train.p", mode='rb') as training_data:
    train = pickle.load(training_data)
with open("./traffic-signs-data/valid.p", mode='rb') as validation_data:
    valid = pickle.load(validation_data)
with open("./traffic-signs-data/test.p", mode='rb') as testing_data:
    test = pickle.load(testing_data)

X_train, y_train = train['features'], train['labels']
X_validation, y_validation = valid['features'], valid['labels']
X_test, y_test = test['features'], test['labels']

# Shuffling the dataset
from sklearn.utils import shuffle
X_train, y_train = shuffle(X_train, y_train)

创建灰度图像

X_train_gray = np.sum(X_train / 3, axis = 3, keepdims = True)
X_test_gray  = np.sum(X_test / 3, axis = 3, keepdims = True)
X_validation_gray  = np.sum(X_validation / 3, axis = 3, keepdims = True) 

X_train_gray_norm = (X_train_gray - 128) / 128 
X_test_gray_norm = (X_test_gray - 128) / 128
X_validation_gray_norm = (X_validation_gray - 128) / 128

下面我将对图像进行数据增强

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
                            rotation_range = 90,
                            width_shift_range = 0.1,
                            vertical_flip = True,
                             )

用增强方法将灰度图像拟合到Keras数据发生器

datagen.fit(X_train_gray_norm)

使数据生成器适合我构建的CNN模型,但没有显示

cnn_model.fit_generator(datagen.flow(X_train_gray_norm, y_train, batch_size = 250), epochs = 100)

尝试展示应用了数据增强的图像

i = 100

pic = datagen.flow(X_train_gray[i], batch_size = 1)
plt.figure(figsize=(10,8))

    
plt.show()

遇到此错误:
ValueError:('NumpyArrayIterator中的输入数据的秩应为4。您传递了一个形状为的数组',(32,32,1))

3pvhb19x

3pvhb19x1#

在0轴上展开数组的维度

datagen.flow(np.expand_dims(X_train_gray[i], 0), batch_size = 1)

#<keras.preprocessing.image.NumpyArrayIterator at 0x23b5ff0f5b0>

相关问题