我正在尝试运行以下代码。下面的代码在 Google Colab 上运行良好,但是在我的系统上它抛出了一个错误。我的系统上安装的 TensorFlow 版本是 1.12.0,Keras 版本是 2.2.4。非常感谢您的帮助。
def profiler(layer, test_input):
    """Return the wall-clock time of one ``layer.predict`` call, in milliseconds.

    Args:
        layer: Any object exposing ``predict(data)`` (here, a Keras model).
        test_input: Input handed unchanged to ``layer.predict``.

    Returns:
        Elapsed time of the single ``predict`` call as a float, in
        milliseconds.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted and is coarse on some platforms.
    start = time.perf_counter()
    layer.predict(test_input)  # the prediction itself is discarded
    elapsed_seconds = time.perf_counter() - start
    return elapsed_seconds * 1000
def dense_layer(input_dim, dense_size):
    """Build a Keras model consisting of a single ``Dense`` layer.

    Args:
        input_dim: Number of input features as an int, or an already-formed
            shape tuple (without the batch dimension).
        dense_size: Number of units passed to ``Dense``.

    Returns:
        A ``tf.keras.models.Model`` mapping the input tensor to the dense
        layer's output.
    """
    # ``(input_dim)`` is only a parenthesised int, not a tuple. TF 1.12's
    # Input does ``tuple(input_shape)`` and fails with "'int' object is not
    # iterable", so always pass a real tuple, and pass it by keyword.
    try:
        shape = tuple(input_dim)
    except TypeError:
        shape = (input_dim,)
    x = tf.keras.layers.Input(shape=shape)
    dense = tf.keras.layers.Dense(dense_size)(x)
    return tf.keras.models.Model(inputs=x, outputs=dense)
def process_config(config):
    """Expand a comma-separated size specification into a list of ints.

    Each comma-separated token is either a single integer (``"2000"``) or an
    inclusive range written ``"start-end"`` (``"1-3"`` yields 1, 2, 3).
    Whitespace around tokens and around range bounds is ignored, and empty
    tokens (an empty string, a trailing comma) are skipped.

    Args:
        config: The specification string, e.g. ``"1000, 4096"`` or
            ``"1-3, 7"``.

    Returns:
        The expanded integers as a list, in the order they appear in
        ``config``.

    Raises:
        ValueError: If a token is not an integer or a ``start-end`` pair of
            integers.
    """
    values = []
    for token in config.split(","):
        token = token.strip()
        if not token:
            # Tolerate "" and stray commas instead of failing in int("").
            continue
        if "-" in token:
            start, end = token.split("-")
            # The range is inclusive of ``end``; extend in place rather than
            # rebuilding the whole list for every range.
            values.extend(range(int(start.strip()), int(end.strip()) + 1))
        else:
            values.append(int(token))
    return values
def evaluate_dense(input_shapes_range, dense_size_range):
    """Time a single-``Dense``-layer model for every size combination.

    For each pair ``(input_shape, dense_size)`` this builds a model with
    ``dense_layer``, feeds it one random sample of shape ``(1, input_shape)``
    and measures the prediction time with ``profiler``.

    Args:
        input_shapes_range: Iterable of input feature counts (ints).
        dense_size_range: Iterable of ``Dense`` unit counts (ints).
    """
    for input_shape in input_shapes_range:
        for dense_size in dense_size_range:
            # Use a context manager so the append-mode handle is closed on
            # every iteration; previously one open file leaked per pair.
            with open("dense_data.csv", "a+") as to_write:
                model = dense_layer(input_shape, dense_size)
                random_input = np.random.randn(1, input_shape)
                running_time = profiler(model, random_input)
                # NOTE(review): the visible code never writes ``running_time``
                # to ``to_write``; the write step appears to be missing from
                # this excerpt - confirm the intended CSV row format.
                del model
# Benchmark configuration: comma-separated size specifications that are
# expanded into lists of ints before the sweep is run.
input_size, dense_size = "2000", "1000, 4096"
input_size_range = process_config(input_size)
dense_size_range = process_config(dense_size)
evaluate_dense(
    input_shapes_range=input_size_range,
    dense_size_range=dense_size_range,
)
错误跟踪
File "C:/Users/Dense-layer.py", line 59, in <module>
evaluate_dense(input_size_range, dense_size_range)
File "C:/Users/Dense-layer.py", line 44, in evaluate_dense
model = dense_layer(input_shape, dense_size)
File "C:/Users/Dense-layer.py", line 16, in dense_layer
x = tf.keras.layers.Input((input_dim))
File "C:\Users\learn\miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\keras\engine\input_layer.py", line 229, in Input
input_tensor=tensor)
File "C:\Users\learn\miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\keras\engine\input_layer.py", line 91, in __init__
batch_input_shape = (batch_size,) + tuple(input_shape)
TypeError: 'int' object is not iterable
1条答案
按热度按时间zzzyeukh1#
`input_shape` 应该是一个元组,但 `input_dim` 是一个整数。你传入了 `input_dim`,并且由于没有按名称指定参数,它被当作 `input_shape` 处理。因此,只需按名称指定它;或者,如果要指定形状,请传入一个元组,例如:`x = tf.keras.layers.Input(shape=(input_dim,))`