我已经问过这个问题here,但我认为StackOverflow会有更多的流量/人可能知道答案。
我正在构建一个自定义keras层,类似于这里的一个例子。我希望类中的call
方法能够知道流经该方法的inputs
数据的batch_size
是什么,但在模型预测期间,inputs.shape
显示为(None, 3)
。下面是一个具体的例子:
我初始化了一个简单的数据集,如下所示:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
# Create fake data to use for model testing
n = 1000
np.random.seed(123)
x1 = np.random.random(n)
x2 = np.random.normal(0, 1, size=n)
x3 = np.random.lognormal(0, 1, size=n)
X = pd.DataFrame(np.concatenate([
np.reshape(x1, (-1, 1)),
np.reshape(x2, (-1, 1)),
np.reshape(x3, (-1, 1)),
], axis=1))
然后我定义一个自定义类来测试/显示我所谈论的内容:
class TestClass(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(TestClass, self).__init__(**kwargs)
def get_config(self):
config = super(TestClass, self).get_config()
return config
def call(self, inputs: tf.Tensor):
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
print(inputs)
record_count, n = inputs.shape
print(f'inputs.shape = {inputs.shape}')
return inputs
然后,当我创建一个简单的模型并强制它向前传递时...
input_layer = layers.Input(3)
test = TestClass()(input_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
model = Model(input_layer, test)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
model.predict(X.loc[:9, :])
...我将此输出打印到屏幕上
model.predict(X.loc[:9, :])
Tensor("model_1/Cast:0", shape=(None, 3), dtype=float32)
inputs.shape = (None, 3)
1/1 [==============================] - 0s 28ms/step
Out[34]:
array([[ 0.5335418 , 0.7788839 , 0.64132416],
[ 0.2924202 , -0.08321562, 0.412311 ],
[ 0.5118007 , -0.6822934 , 1.1782378 ],
[ 0.03780456, -0.19350041, 0.7637337 ],
[ 0.86494124, -3.196387 , 4.8535166 ],
[ 0.26708454, -0.49397194, 0.91296834],
[ 0.49734482, -1.6618049 , 0.50054324],
[ 0.8563762 , 0.7956695 , 0.29466265],
[ 0.7682351 , 0.86538637, 0.6633331 ],
[ 0.85322225, 0.868021 , 0.1776046 ]], dtype=float32)
可以看到,在model.predict
调用过程中,inputs.shape
输出了一个值(None, 3)
,但显然这不是真的,因为call
方法返回的输出具有(10, 3)
的形状。在call
方法中,如何在本例中捕获10
值?
更新1
当我按照当前答案中的建议使用tf.shape
时,我可以将该值打印到屏幕上,但当我试图在变量中捕获该值时,我会得到一个错误。
class TestClass(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(TestClass, self).__init__(**kwargs)
def get_config(self):
config = super(TestClass, self).get_config()
return config
def call(self, inputs: tf.Tensor):
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
record_count, n = tf.shape(inputs)
tf.print("Dynamic batch size", tf.shape(inputs)[0])
return inputs
此代码会导致record_count, ...
行出错。
Traceback (most recent call last):
File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3378, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-22-104d812c32e6>", line 1, in <module>
test = TestClass()(input_layer)
File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 692, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: Exception encountered when calling layer "test_class_4" (type TestClass).
in user code:
File "<ipython-input-21-2dec1d5b9547>", line 12, in call *
record_count, n = tf.shape(inputs)
OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Call arguments received by layer "test_class_4" (type TestClass):
• inputs=tf.Tensor(shape=(None, 3), dtype=float32)
我试着用@tf.function
修饰call
方法,但是我得到了同样的错误。
更新2
我尝试了一些其他的方法,发现奇怪的是,tensorflow似乎不喜欢元组赋值,如果像这样编码的话,它似乎可以正常工作。
class TestClass(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(TestClass, self).__init__(**kwargs)
def get_config(self):
config = super(TestClass, self).get_config()
return config
def call(self, inputs: tf.Tensor):
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
shape = tf.shape(inputs)
record_count = shape[0]
n = shape[1]
tf.print("Dynamic batch size", tf.shape(inputs)[0])
return inputs
1条答案
按热度按时间h79rfbju1#
TL;DR--〉如果您想在
call
方法中捕获动态批处理大小,请使用tf.shape(inputs)[0]
,或者您可以只使用静态批处理大小(可在模型创建中指定)。TensorFlow使用
tf.function
来修饰call
和__call__
(这是call
方法 * 调用的 *)方法。使用print
和.shape
将无法按预期工作。使用
tf.function
,可以跟踪python代码并将其转换为原生TensorFlow操作。然后,创建一个静态图,这只是tf.Graph的一个示例。最后,在该图中执行操作。Python的
print
函数只在第一步中考虑,所以这不是在图形模式中打印内容的正确方法(用tf.function
修饰)。Tensor形状在运行时是动态的,因此您需要使用
tf.shape(inputs)[0]
,它将为您提供该批的批大小。如果你真的想在
call
中看到那个10
:正在运行:
将返回: