我想把下面的 TensorFlow 代码“转换”成 JAX 代码:
def mlp(L, n_list, activation, Cb, Cw):
    """Build a fully connected Keras model with fan-in-scaled Gaussian init.

    The model has ``L + 1`` Dense layers. Layer ``l`` (0-based) maps
    ``n_list[l]`` units to ``n_list[l + 1]`` units. Its kernel is drawn from
    N(0, Cw / n_list[l]) (variance scaled by the layer's input width) and its
    bias from N(0, Cb).

    Args:
        L: Index of the output layer; must be >= 1. Layers ``1 .. L-1`` apply
            ``activation``; layer 0 and layer ``L`` are linear.
        n_list: Layer widths; needs at least ``L + 2`` entries. ``n_list[0]``
            is the input dimension and ``n_list[L + 1]`` the output dimension.
        activation: Passed as ``activation=`` to the hidden Dense layers.
        Cb: Variance of the bias initializer.
        Cw: Weight-variance scale; each kernel has variance ``Cw / fan_in``.

    Returns:
        The constructed ``tf.keras.Sequential`` model. Its summary is printed
        as a side effect.

    Raises:
        ValueError: If ``L < 1``.
        IndexError: If ``n_list`` has fewer than ``L + 2`` entries.
    """
    if L < 1:
        # With L == 0 the old construction appended a second, duplicate
        # n_list[1] -> n_list[1] layer whose init used n_list[0] as fan-in.
        raise ValueError(f"L must be >= 1, got {L}")

    def _dense(l, **layer_kwargs):
        """Return Dense layer ``l`` (n_list[l] -> n_list[l + 1]) with its initializers."""
        return tf.keras.layers.Dense(
            n_list[l + 1],
            use_bias=True,
            kernel_initializer=tf.keras.initializers.RandomNormal(
                0, math.sqrt(Cw / n_list[l])
            ),
            # Fresh instance per layer: Keras warns that reusing one unseeded
            # initializer instance returns identical values on every call.
            bias_initializer=tf.keras.initializers.RandomNormal(
                stddev=math.sqrt(Cb)
            ),
            **layer_kwargs,
        )

    model = tf.keras.Sequential()
    # NOTE(review): Dense applies its activation *after* the matmul, so layers
    # 0 and 1 compose with no nonlinearity in between. If the intended form is
    # z^{l+1} = W sigma(z^l) + b, the activation may belong on layers 0..L-1
    # instead of 1..L-1 -- confirm with the author. Kept as originally written.
    model.add(_dense(0, input_shape=[n_list[0]]))
    for l in range(1, L):
        model.add(_dense(l, activation=activation))
    model.add(_dense(L))

    # summary() prints by itself and returns None; wrapping it in print()
    # emitted a stray "None".
    model.summary()
    return model
在 JAX 中调用 stax.serial() 时,能否像 TensorFlow 的 model.add() 那样逐个添加 stax.Dense 层?应该怎么做?
1 条回答

回答者:vawmfj5a1
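是的,可以。stax.serial(*layers) 接受任意数量的层,因此可以先把各个 stax.Dense(及激活层,如 stax.Relu)依次追加到一个 Python 列表中(相当于 model.add()),最后用 stax.serial(*layers) 一次性组合成模型。权重和偏置的初始化可以通过 stax.Dense(out_dim, W_init=..., b_init=...) 指定,例如传入 jax.nn.initializers.normal(stddev)。
这里有一份可供参考的资料。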