在tensorflow过程中确定批量大小,keras自定义类调用方法

2w3rbyxf  于 2022-11-13  发布在  其他
关注(0)|答案(1)|浏览(187)

我已经问过这个问题here,但我认为StackOverflow会有更多的流量/人可能知道答案。
我正在构建一个自定义keras层,类似于这里的一个例子。我希望类中的call方法能够知道流经该方法的inputs数据的batch_size是什么,但在模型预测期间,inputs.shape显示为(None, 3)。下面是一个具体的例子:
我初始化了一个简单的数据集,如下所示:

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model

# Create fake data to use for model testing
n = 1000
np.random.seed(123)
x1 = np.random.random(n)
x2 = np.random.normal(0, 1, size=n)
x3 = np.random.lognormal(0, 1, size=n)

X = pd.DataFrame(np.concatenate([
    np.reshape(x1, (-1, 1)),
    np.reshape(x2, (-1, 1)),
    np.reshape(x3, (-1, 1)),
], axis=1))

然后我定义一个自定义类来测试/显示我所谈论的内容:

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)

        print(inputs)
        record_count, n = inputs.shape
        print(f'inputs.shape = {inputs.shape}')

        return inputs

然后,当我创建一个简单的模型并强制它向前传递时...

input_layer = layers.Input(3)
test = TestClass()(input_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
model = Model(input_layer, test)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
model.predict(X.loc[:9, :])

...我将此输出打印到屏幕上

model.predict(X.loc[:9, :])
Tensor("model_1/Cast:0", shape=(None, 3), dtype=float32)
inputs.shape = (None, 3)
1/1 [==============================] - 0s 28ms/step
Out[34]: 
array([[ 0.5335418 ,  0.7788839 ,  0.64132416],
       [ 0.2924202 , -0.08321562,  0.412311  ],
       [ 0.5118007 , -0.6822934 ,  1.1782378 ],
       [ 0.03780456, -0.19350041,  0.7637337 ],
       [ 0.86494124, -3.196387  ,  4.8535166 ],
       [ 0.26708454, -0.49397194,  0.91296834],
       [ 0.49734482, -1.6618049 ,  0.50054324],
       [ 0.8563762 ,  0.7956695 ,  0.29466265],
       [ 0.7682351 ,  0.86538637,  0.6633331 ],
       [ 0.85322225,  0.868021  ,  0.1776046 ]], dtype=float32)

可以看到,在model.predict调用过程中,inputs.shape输出了一个值(None, 3),但显然这不是真的,因为call方法返回的输出具有(10, 3)的形状。在call方法中,如何在本例中捕获10值?

更新1

当我按照当前答案中的建议使用tf.shape时,我可以将该值打印到屏幕上,但当我试图在变量中捕获该值时,我会得到一个错误。

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        record_count, n = tf.shape(inputs)
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs

此代码会导致record_count, ...行出错。

Traceback (most recent call last):
  File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3378, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-22-104d812c32e6>", line 1, in <module>
    test = TestClass()(input_layer)
  File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 692, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: Exception encountered when calling layer "test_class_4" (type TestClass).
in user code:
    File "<ipython-input-21-2dec1d5b9547>", line 12, in call  *
        record_count, n = tf.shape(inputs)
    OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Call arguments received by layer "test_class_4" (type TestClass):
  • inputs=tf.Tensor(shape=(None, 3), dtype=float32)

我试着用@tf.function修饰call方法,但是我得到了同样的错误。

更新2

我尝试了一些其他的方法,发现奇怪的是,tensorflow似乎不喜欢元组赋值,如果像这样编码的话,它似乎可以正常工作。

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        shape = tf.shape(inputs)
        record_count = shape[0]
        n = shape[1]
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs
h79rfbju

h79rfbju1#

TL;DR--〉如果您想在call方法中捕获动态批处理大小,请使用tf.shape(inputs)[0],或者您可以只使用静态批处理大小(可在模型创建中指定)。

TensorFlow使用tf.function来修饰call__call__(这是call方法 * 调用的 *)方法。使用print.shape将无法按预期工作。
使用tf.function,可以跟踪python代码并将其转换为原生TensorFlow操作。然后,创建一个静态图,这只是tf.Graph的一个示例。最后,在该图中执行操作。
Python的print函数只在第一步中考虑,所以这不是在图形模式中打印内容的正确方法(用tf.function修饰)。
Tensor形状在运行时是动态的,因此您需要使用tf.shape(inputs)[0],它将为您提供该批的批大小。
如果你真的想在call中看到那个10

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs

正在运行:

input_layer = layers.Input(3)
test = TestClass()(input_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
model = Model(input_layer, test)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
model.predict(X.loc[:9, :])

将返回:

Dynamic batch size 10
1/1 [==============================] - 0s 65ms/step
array([[ 6.9646919e-01, -1.0032653e-02,  3.7556963e+00],
       [ 2.8613934e-01, -8.4564441e-01,  9.9685013e-01],
       [ 2.2685145e-01,  9.1146064e-01,  6.5008003e-01],
       [ 5.5131477e-01, -1.3744969e+00,  8.6379850e-01],
       [ 7.1946895e-01, -5.4706562e-01,  3.1904945e+00],
       [ 4.2310646e-01, -7.5526608e-05,  5.2649558e-01],
       [ 9.8076421e-01, -1.2116680e-01,  7.4064606e-01],
       [ 6.8482971e-01, -2.0085855e+00,  5.3138912e-01],
       [ 4.8093191e-01, -9.2064655e-01,  8.1520426e-01],
       [ 3.9211753e-01,  1.6823435e-01,  1.2382457e+00]], dtype=float32)

相关问题