tensorflow: Adding new layers to a stax.serial object

t40tm48m posted on 2023-03-19 in Other

I want to "convert" the following TensorFlow code to JAX:

import math
import tensorflow as tf

def mlp(L, n_list, activation, Cb, Cw):
    # n_list[0]: input width, n_list[1..L]: hidden widths, n_list[L+1]: output width
    model = tf.keras.Sequential()

    # Kernel initializer for layer l: N(0, Cw / n_list[l]), i.e. variance Cw / fan_in
    kernel_initializers_list = []
    kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw/n_list[0])))
    for l in range(1, L):
        kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw/n_list[l])))
    kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw/n_list[L])))
    # Every bias is drawn from N(0, Cb)
    bias_initializer = tf.keras.initializers.RandomNormal(stddev=math.sqrt(Cb))

    # First layer: n_list[0] -> n_list[1], no activation
    model.add(tf.keras.layers.Dense(n_list[1], input_shape=[n_list[0]], use_bias=True, kernel_initializer=kernel_initializers_list[0],
              bias_initializer=bias_initializer))
    # Hidden layers: n_list[l] -> n_list[l+1], followed by the activation
    for l in range(1, L):
        model.add(tf.keras.layers.Dense(n_list[l+1], activation=activation, use_bias=True, kernel_initializer=kernel_initializers_list[l],
                  bias_initializer=bias_initializer))
    # Output layer: n_list[L] -> n_list[L+1], no activation
    model.add(tf.keras.layers.Dense(n_list[L+1], use_bias=True, kernel_initializer=kernel_initializers_list[L],
              bias_initializer=bias_initializer))
    model.summary()  # summary() prints the table itself and returns None
    return model

In JAX, is there an equivalent of TensorFlow's model.add() that lets me add a stax.Dense layer when calling stax.serial()? How can I do it?
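One part of the conversion that is independent of model.add() is the weight initialization. A minimal sketch, assuming a recent JAX: each Keras RandomNormal(0, stddev) roughly corresponds to jax.nn.initializers.normal(stddev), which also has mean 0 and returns an init(key, shape) function. The helper name make_inits and the example sizes are hypothetical, chosen only to check that the weight variance comes out near Cw / fan_in.

```python
import math

import jax
import jax.numpy as jnp
from jax.nn import initializers


def make_inits(fan_in, Cw, Cb):
    # Hypothetical helper: kernel ~ N(0, Cw / fan_in), bias ~ N(0, Cb), as in the Keras code
    w_init = initializers.normal(stddev=math.sqrt(Cw / fan_in))
    b_init = initializers.normal(stddev=math.sqrt(Cb))
    return w_init, b_init


# Empirical check for one layer with fan_in = 100 and fan_out = 200
fan_in, fan_out, Cw, Cb = 100, 200, 2.0, 0.1
w_init, b_init = make_inits(fan_in, Cw, Cb)
key_w, key_b = jax.random.split(jax.random.PRNGKey(0))
W = w_init(key_w, (fan_in, fan_out))  # stax.Dense uses the (in_dim, out_dim) layout
b = b_init(key_b, (fan_out,))
print(W.shape, float(jnp.var(W)))  # the variance should be close to Cw / fan_in = 0.02
```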

vawmfj5a


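Yes, you can. stax.serial has no add() method; instead, you pass all the layers to it as arguments, in order: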

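from jax import random
from jax.example_libraries import stax
from jax.example_libraries.stax import (Conv, Dense, Flatten, LogSoftmax,
                                        MaxPool, Relu)

# Create a new model with stax: the layers are passed to stax.serial in order
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'),
    Relu,
    Conv(64, (3, 3), padding='SAME'),
    Relu,
    Conv(128, (3, 3), padding='SAME'),
    Relu,
    Conv(256, (3, 3), padding='SAME'),
    Relu,
    MaxPool((2, 2)),
    Flatten,
    Dense(128),
    Relu,
    Dense(10),
    LogSoftmax,
)

# net_init returns the output shape and the initial parameters
out_shape, params = net_init(random.PRNGKey(111), input_shape=(-1, 32, 32, 3))

# Feedforward on a batch of NHWC images (random stand-in for your own batch_data)
inputs = random.normal(random.PRNGKey(0), (8, 32, 32, 3))
logits = net_apply(params, inputs)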
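Applied to the mlp function from the question, the same idea can be sketched like this: collect the layers in an ordinary Python list (the list.append calls play the role of model.add()) and unpack it into stax.serial at the end. This is a sketch under some assumptions: stax is imported from jax.example_libraries (older JAX versions had it under jax.experimental.stax), activation is passed as a stax layer such as stax.Tanh rather than a Keras activation string, and a Keras Dense(..., activation=act) is modelled as a Dense layer followed by the activation layer. stax has no summary(), so the shape check at the end stands in for it.

```python
import math

import jax
import jax.numpy as jnp
from jax.example_libraries import stax
from jax.nn import initializers


def mlp(L, n_list, activation, Cb, Cw):
    """stax sketch of the Keras mlp() above; `activation` is a stax layer, e.g. stax.Tanh."""
    b_init = initializers.normal(stddev=math.sqrt(Cb))  # every bias ~ N(0, Cb)
    layers = []  # stands in for the Keras Sequential container

    # First layer: n_list[0] -> n_list[1], no activation (as in the Keras code)
    layers.append(stax.Dense(n_list[1],
                             W_init=initializers.normal(math.sqrt(Cw / n_list[0])),
                             b_init=b_init))
    # Hidden layers: Dense(n_list[l] -> n_list[l+1]) followed by the activation layer
    for l in range(1, L):
        layers.append(stax.Dense(n_list[l + 1],
                                 W_init=initializers.normal(math.sqrt(Cw / n_list[l])),
                                 b_init=b_init))
        layers.append(activation)
    # Output layer: n_list[L] -> n_list[L+1], no activation
    layers.append(stax.Dense(n_list[L + 1],
                             W_init=initializers.normal(math.sqrt(Cw / n_list[L])),
                             b_init=b_init))

    # Turn the list into a single (init_fun, apply_fun) pair
    return stax.serial(*layers)


init_fun, apply_fun = mlp(L=3, n_list=[4, 16, 16, 16, 2], activation=stax.Tanh, Cb=0.05, Cw=1.0)
out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 4))
y = apply_fun(params, jnp.ones((5, 4)))
print(out_shape, y.shape)  # (-1, 2) (5, 2)
```

Because the list is plain Python, you can append, insert or remove layers freely before calling stax.serial; once it is called, the resulting pair is fixed.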

I hope this example helps.
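If the model already exists as a stax.serial pair, you can also "add" layers to it afterwards: stax.serial returns an (init_fun, apply_fun) pair, which is itself a valid stax layer, so it can be nested inside another stax.serial. A minimal sketch, assuming jax.example_libraries.stax; the names base and extended are only for illustration. Note that the new init_fun re-initializes the nested model, so already-trained parameters have to be put back into the corresponding slot of the params list.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# An existing model: a (init_fun, apply_fun) pair ...
base = stax.serial(stax.Dense(32), stax.Relu)

# ... which can be used as a single layer inside a new serial model
extended = stax.serial(base, stax.Dense(10), stax.LogSoftmax)

init_fun, apply_fun = extended
out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 8))

# params is a list with one entry per layer; params[0] belongs to `base`.
# Here base is freshly initialized to stand in for already-trained parameters.
_, base_params = base[0](jax.random.PRNGKey(1), (-1, 8))
params[0] = base_params

logits = apply_fun(params, jnp.ones((3, 8)))
print(out_shape, logits.shape)  # (-1, 10) (3, 10)
```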
