WGAN-GP implementation in TensorFlow

pgx2nnw8, posted on 2023-03-24 in Other

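Using TensorFlow, I am trying to re-implement the following architecture (for now I am focusing on the Generator part):

So far I have defined the generator as follows: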

import tensorflow as tf

N_Z = 128  # size of the latent noise vector fed to the generator

# Layer stack for the generator; it is wrapped in a Sequential model below.
generator = [
    # 6144 = 6 * 4 * 256, so the dense output can be reshaped to (6, 4, 256).
    tf.keras.layers.Dense(units=6144, activation="relu"),
    tf.keras.layers.Reshape(target_shape=(6, 4, 256)),
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=(5, 5), strides=(2, 2), padding="same", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=(3, 3), strides=(2, 1), padding="same", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=(3, 3), strides=(1, 1), padding="same", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=(3, 3), strides=(2, 1), padding="same", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=(3, 3), strides=(1, 1), padding="same", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=(3, 3), strides=(2, 1), padding="same", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=1, kernel_size=(3, 3), strides=(1, 1), padding="same", activation="relu"
    ),
]

Generator = tf.keras.models.Sequential(generator)

But if I take some random noise and pass it through the model, this is the final shape I get:

noise = tf.random.normal((64,128))

result = Generator(noise)

result.shape

TensorShape([64, 28, 28, 1])

What am I doing wrong here? I also checked the original implementation for further details, but I could not find anything that helped me understand.

pftdvrlh #1

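This is straightforward once you look at the input and output shape of every layer: the transposed convolutions need some help to land on the intended sizes.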
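The per-layer check can also be done by hand, without running TensorFlow. The sketch below is not part of the original answer; the helper names `deconv_len` and `trace_shapes` are made up for illustration. It relies on the Keras rule that a transposed convolution with `padding="same"` multiplies each spatial size by its stride (with `padding="valid"` it adds `max(kernel - stride, 0)` on top), and applies it to the layer list from the question.

```python
# Pure-Python shape arithmetic; the helper names are hypothetical.
def deconv_len(size, kernel, stride, padding="same"):
    """Output length of a Keras transposed convolution along one axis."""
    padding = padding.lower()
    if padding == "same":
        return size * stride
    if padding == "valid":
        return size * stride + max(kernel - stride, 0)
    raise ValueError(f"unsupported padding: {padding!r}")

# (filters, kernel_size, strides) of each Conv2DTranspose in the question.
QUESTION_LAYERS = [
    (128, (5, 5), (2, 2)),
    (128, (3, 3), (2, 1)),
    (64, (3, 3), (1, 1)),
    (64, (3, 3), (2, 1)),
    (32, (3, 3), (1, 1)),
    (32, (3, 3), (2, 1)),
    (1, (3, 3), (1, 1)),
]

def trace_shapes(start, layers, padding="same"):
    """Return the (height, width, channels) shape after every layer."""
    h, w, _ = start
    shapes = [start]
    for filters, (kh, kw), (sh, sw) in layers:
        h = deconv_len(h, kh, sh, padding)
        w = deconv_len(w, kw, sw, padding)
        shapes.append((h, w, filters))
    return shapes

# Start from the Reshape output (6, 4, 256).
shapes = trace_shapes((6, 4, 256), QUESTION_LAYERS)
for shape in shapes:
    print(shape)
# Last line printed: (96, 8, 1)
```

By this arithmetic the layers as listed end at 96 x 8 x 1 per sample, because every stride-2 step exactly doubles the height (6, 12, 24, 48, 96). A different shape from a real run would suggest that the model being called is not the one listed. Doubling alone can never produce an odd height, which is why the sample that follows adds an explicit resize after each transposed convolution.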

[Sample]:

import tensorflow as tf

# Model initialization: every transposed convolution is followed by a Resizing
# layer that forces the feature map to the height wanted at that stage.
# The Reshape after each Resizing keeps the same shape (a no-op); it is kept so
# that the layer list matches the summary printed below.
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(6144,)),
    # 48 * 128 = 6144 = 6 * 4 * 256
    tf.keras.layers.Dense(48 * 128, activation="linear"),
    tf.keras.layers.BatchNormalization(momentum=0.99, epsilon=0.00001),
    tf.keras.layers.Reshape(target_shape=(6, 4, 256)),
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=(5, 5), strides=(2, 2), padding="same", activation="relu"
    ),
    tf.keras.layers.Resizing(11, 8, interpolation="bilinear", crop_to_aspect_ratio=False),
    tf.keras.layers.Reshape(target_shape=(11, 8, 128)),
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=(3, 3), strides=(2, 1), padding="same", activation="relu"
    ),
    tf.keras.layers.Resizing(22, 8, interpolation="bilinear", crop_to_aspect_ratio=False),
    tf.keras.layers.Reshape(target_shape=(22, 8, 128)),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=(3, 3), strides=(1, 1), padding="same", activation="relu"
    ),
    tf.keras.layers.Resizing(22, 8, interpolation="bilinear", crop_to_aspect_ratio=False),
    tf.keras.layers.Reshape(target_shape=(22, 8, 64)),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=(3, 3), strides=(2, 1), padding="same", activation="relu"
    ),
    tf.keras.layers.Resizing(43, 8, interpolation="bilinear", crop_to_aspect_ratio=False),
    tf.keras.layers.Reshape(target_shape=(43, 8, 64)),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=(3, 3), strides=(1, 1), padding="same", activation="relu"
    ),
    tf.keras.layers.Resizing(43, 8, interpolation="bilinear", crop_to_aspect_ratio=False),
    tf.keras.layers.Reshape(target_shape=(43, 8, 32)),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=(3, 3), strides=(2, 1), padding="same", activation="relu"
    ),
    tf.keras.layers.Resizing(85, 8, interpolation="bilinear", crop_to_aspect_ratio=False),
    tf.keras.layers.Reshape(target_shape=(85, 8, 32)),
])

model.summary()

[Output]:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense (Dense)               (None, 6144)              37754880

 batch_normalization (BatchN  (None, 6144)             24576
 ormalization)

 reshape (Reshape)           (None, 6, 4, 256)         0

 conv2d_transpose (Conv2DTra  (None, 12, 8, 128)       819328
 nspose)

 resizing (Resizing)         (None, 11, 8, 128)        0

 reshape_1 (Reshape)         (None, 11, 8, 128)        0

 conv2d_transpose_1 (Conv2DT  (None, 22, 8, 128)       147584
 ranspose)

 resizing_1 (Resizing)       (None, 22, 8, 128)        0

 reshape_2 (Reshape)         (None, 22, 8, 128)        0

 conv2d_transpose_2 (Conv2DT  (None, 22, 8, 64)        73792
 ranspose)

 resizing_2 (Resizing)       (None, 22, 8, 64)         0

 reshape_3 (Reshape)         (None, 22, 8, 64)         0

 conv2d_transpose_3 (Conv2DT  (None, 44, 8, 64)        36928
 ranspose)

 resizing_3 (Resizing)       (None, 43, 8, 64)         0

 reshape_4 (Reshape)         (None, 43, 8, 64)         0

 conv2d_transpose_4 (Conv2DT  (None, 43, 8, 32)        18464
 ranspose)

 resizing_4 (Resizing)       (None, 43, 8, 32)         0

 reshape_5 (Reshape)         (None, 43, 8, 32)         0

 conv2d_transpose_5 (Conv2DT  (None, 86, 8, 32)        9248
 ranspose)

 resizing_5 (Resizing)       (None, 85, 8, 32)         0

 reshape_6 (Reshape)         (None, 85, 8, 32)         0

=================================================================
Total params: 38,884,800
Trainable params: 38,872,512
Non-trainable params: 12,288
_________________________________________________________________
2022-04-03 03:37:10.354570: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
(1, 85, 8, 32)
1/1 [==============================] - 2s 2s/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
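The heights in the summary above can be reproduced with a few lines of plain Python, which also shows which of the Resizing layers actually change anything. This is only an illustrative sketch: `trace_height` and `ANSWER_STEPS` are hypothetical names, and the model is reduced to (vertical stride, Resizing target height) pairs read off the sample code.

```python
# Each pair is (vertical stride of the Conv2DTranspose, height set by Resizing).
ANSWER_STEPS = [(2, 11), (2, 22), (1, 22), (2, 43), (1, 43), (2, 85)]

def trace_height(start, steps):
    """Follow the feature-map height through conv + resize pairs."""
    h = start
    heights = [h]
    effective = []  # resizes that really change the height
    for stride, target in steps:
        conv_out = h * stride  # padding="same": height is multiplied by the stride
        heights.append(conv_out)
        if conv_out != target:
            effective.append((conv_out, target))
        h = target             # Resizing forces the height to the target
        heights.append(h)
    return heights, effective

heights, effective = trace_height(6, ANSWER_STEPS)
print(heights)    # [6, 12, 11, 22, 22, 22, 22, 44, 43, 43, 43, 86, 85]
print(effective)  # [(12, 11), (44, 43), (86, 85)]
```

Only three of the six resizes do real work, each trimming a single row by interpolation; the other three leave the tensor unchanged. Note also that the sample ends with 32 channels, so a final single-filter layer like the last one in the question's generator would presumably still be needed to obtain a one-channel output.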
