In TensorFlow computer vision modelling, where should the batch size be defined?

Posted by w8ntj3qf on 2021-08-20 in Java

I prepared the data as follows, specifying batch_size as 32 in both generators:


# Preparing and preprocessing the data

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_dir = '/content/pizza_steak/train'
test_dir = '/content/pizza_steak/test'

train_data_gen_aug = ImageDataGenerator(rotation_range=0.2, 
                                        width_shift_range=0.2, 
                                        height_shift_range=0.2, 
                                        shear_range=0.2, 
                                        zoom_range=0.2,
                                        horizontal_flip=True, 
                                        vertical_flip=True, 
                                        rescale=1./255)
test_data_gen = ImageDataGenerator(rescale=1./255)

train_data_aug = train_data_gen_aug.flow_from_directory(train_dir, 
                                                        target_size=(224, 224), 
                                                        class_mode='binary', 
                                                        batch_size=32, 
                                                        seed=42)
test_data = test_data_gen.flow_from_directory(test_dir, 
                                              target_size=(224, 224), 
                                              class_mode='binary', 
                                              batch_size=32, 
                                              seed=42)

This returns:

Found 1500 images belonging to 2 classes.
Found 500 images belonging to 2 classes

When I explore the data as follows:


# Explore the data

train_images, train_labels = train_data.next()
train_images_aug, train_labels_aug = train_data_aug.next()
test_images_aug, test_labels_aug = test_data.next()
print('train_data:     ', len(train_data), train_images.shape, train_labels.shape)
print('train_data_aug: ', len(train_data_aug), train_images_aug.shape, train_labels_aug.shape)
print('test_data:      ', len(test_data), test_images_aug.shape, test_labels_aug.shape)

it returns:

train_data:      47 (32, 224, 224, 3) (32,)
train_data_aug:  47 (32, 224, 224, 3) (32,)
test_data:       16 (32, 224, 224, 3) (32,)
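
The lengths 47 and 16 printed above are consistent with the number of batches per epoch being the image count divided by the batch size, rounded up. A minimal sketch of that arithmetic in plain Python follows; the helper names `batches_per_epoch` and `last_batch_size` are made up for illustration and are not Keras functions.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    """Batches needed to see every sample once (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples: int, batch_size: int) -> int:
    """Size of the final batch when num_samples is not a multiple of batch_size."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


print(batches_per_epoch(1500, 32))  # 47
print(batches_per_epoch(500, 32))   # 16
print(last_batch_size(1500, 32))    # 28
print(last_batch_size(500, 32))     # 20
```

Under this reading, the final batch of each epoch would hold fewer than 32 images, so the batch dimension is not constant across an epoch.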

Then I build and compile the model, specifying batch_size as None in the InputLayer:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import InputLayer, Conv2D, MaxPool2D, Flatten, Dense
from tensorflow.keras.activations import relu, sigmoid
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import binary_accuracy

# create the model

model = Sequential()

# Add the input layer

INPUT_SHAPE = (224, 224, 3)
model.add(InputLayer(input_shape=INPUT_SHAPE,
                     batch_size=None,  # I entered the batch size here as None
                     ))

# Add the hidden layers

model.add(Conv2D(filters=10,
                 kernel_size=3,
                 strides=1,
                 padding='valid',
                 activation=relu))
model.add(MaxPool2D(pool_size=(2, 2), strides=None, padding='valid'))

# Add the flatten layer

model.add(Flatten())

# Add the output layer

model.add(Dense(units=1, activation=sigmoid))

# Compile the model

model.compile(optimizer=Adam(),
              loss=BinaryCrossentropy(),
              metrics=[binary_accuracy])

Then I fit the model, again specifying batch_size as None:

history = model.fit(train_data_aug,
                    batch_size=None,  # batch_size defined as None
                    epochs=5, 
                    verbose=1, 
                    validation_data=test_data, 
                    steps_per_epoch=len(train_data_aug), 
                    validation_steps=len(test_data))

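The model runs fine and performs well, reaching a val_binary_accuracy of 81% after training for only 5 epochs.
In which cases should the batch size be set in the other two places (the InputLayer and model.fit())? Can batch_size be defined in all three places, or would that cause a problem?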

Answer #1 by chy5wohz

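The batch size is the number of samples per gradient update. If it is left unspecified, it defaults to 32, as stated in the documentation for model.fit(). However, your data comes from a generator, which already yields batches, so you do not have to specify the batch size in model.fit().
From the TensorFlow documentation: https://www.tensorflow.org/api_docs/python/tf/keras/model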

Do not specify the batch_size if your data is in the form of datasets, generators, or 
keras.utils.Sequence instances (since they generate batches).
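
To illustrate why a separate batch_size is redundant in this situation, here is a rough plain-Python sketch of a batching generator and a loop that consumes it. It is not how Keras is implemented; `batch_generator` and `consume` are hypothetical names, and plain integers stand in for images.

```python
from typing import Iterator, List, Sequence


def batch_generator(samples: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    """Yield consecutive slices of `samples`, each at most `batch_size` long."""
    for start in range(0, len(samples), batch_size):
        yield list(samples[start:start + batch_size])


def consume(batches: Iterator[List[int]]) -> List[int]:
    """Stand-in for a training loop: one 'update' per batch, recording its size."""
    sizes = []
    for batch in batches:
        sizes.append(len(batch))
    return sizes


# 1500 stand-in samples, batched by the generator itself.
sizes = consume(batch_generator(range(1500), batch_size=32))
print(len(sizes))   # 47 batches
print(sizes[0])     # 32
print(sizes[-1])    # 28 (final, partial batch)
```

The consuming loop never receives a batch size of its own; it simply takes whatever the generator yields, which is the sense in which the batch size is already fixed upstream.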
