TensorFlow: custom layer/gradient raises OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed

qq24tv8q  posted on 2022-11-16 in Other

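I am trying to implement a custom layer with a custom gradient, following the specifications referenced here and here.
For some reason, my code raises the following error:
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
My MWE is as follows: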

import tensorflow as tf
from tensorflow import keras
import sys

print("Python version")
print (sys.version)
print("Version info.")
print (sys.version_info)
print("Tensorflow version")
print(tf.__version__)

class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )

    @tf.custom_gradient
    def call(self, inputs):
        def grad(dy, variables=None):
            return tf.matmul(inputs, dy)
        return tf.matmul(inputs, self.w), grad

model = tf.keras.models.Sequential([
    Linear(1),
])
model.compile(optimizer='sgd',loss='mean_squared_error')

xs = tf.constant([[-1.0],  [0.0], [1.0], [2.0], [3.0], [4.0]], dtype=float) 
print(model(xs))

ys = tf.constant([[-3.0], [-1.0], [1.0], [3.0], [5.0], [7.0]], dtype=float) 

model.fit(xs, ys, epochs=10)

The output is:

Python version
3.9.10 (v3.9.10:f2f3f53782, Jan 13 2022, 17:02:14) 
[Clang 6.0 (clang-600.0.57)]
Version info.
sys.version_info(major=3, minor=9, micro=10, releaselevel='final', serial=0)
Tensorflow version
2.7.0
2022-11-10 17:21:03.514995: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
tf.Tensor(
[[ 0.02415443]
 [-0.        ]
 [-0.02415443]
 [-0.04830886]
 [-0.07246329]
 [-0.09661772]], shape=(6, 1), dtype=float32)
Epoch 1/10
Traceback (most recent call last):
  File "question.py", line 41, in <module>
    model.fit(xs, ys, epochs=10)  
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1129, in autograph_handler
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:

File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/keras/engine/training.py", line 878, in train_function  *
    return step_function(self, iterator)
File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/keras/engine/training.py", line 867, in step_function  **
    outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/keras/engine/training.py", line 860, in run_step  **
    outputs = model.train_step(data)
File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/keras/engine/training.py", line 816, in train_step
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py", line 530, in minimize
    grads_and_vars = self._compute_gradients(
File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py", line 583, in _compute_gradients
    grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)
File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py", line 464, in _get_gradients
    grads = tape.gradient(loss, var_list, grad_loss)

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

a8jjtwal #1

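The problem here is that when a `@tf.custom_gradient` function uses variables, its `grad` function needs to return two things: the gradient with respect to the inputs (dx) and the gradients with respect to the variables. You only returned the dx part and not the gradients for the variables, so TensorFlow fails when it tries to unpack the single returned tensor into those two parts. I have fixed this; try this...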
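The answer's own code is not preserved in this copy. As a minimal sketch of the return contract described above (not the answerer's code), the example below uses a standalone `tf.Variable` named `w` in place of the layer's `self.w`, and a hypothetical helper `loss_and_grads` that runs under `tf.function` to mimic the graph mode used by `model.fit`. The gradient formulas are the standard ones for a matrix product; they are an assumption about what the layer is meant to compute.

```python
import tensorflow as tf

# Hypothetical stand-in for the layer weight (self.w in the question).
w = tf.Variable([[2.0]], name="w")


@tf.custom_gradient
def linear(inputs):
    """Computes inputs @ w with a hand-written gradient."""

    def grad(dy, variables=None):
        # Gradient with respect to the function input: dy @ w^T.
        grad_inputs = tf.matmul(dy, w, transpose_b=True)
        # Gradient with respect to the captured variable w: inputs^T @ dy.
        grad_w = tf.matmul(inputs, dy, transpose_a=True)
        # Two parts: the input gradient, then a list with one
        # gradient per variable used by the function.
        return grad_inputs, [grad_w]

    return tf.matmul(inputs, w), grad


@tf.function  # run in graph mode, as model.fit does
def loss_and_grads(xs, ys):
    with tf.GradientTape() as tape:
        tape.watch(xs)
        loss = tf.reduce_mean(tf.square(linear(xs) - ys))
    return loss, tape.gradient(loss, [xs, w])


xs = tf.constant([[-1.0], [0.0], [1.0], [2.0], [3.0], [4.0]])
ys = tf.constant([[-3.0], [-1.0], [1.0], [3.0], [5.0], [7.0]])

loss, (grad_xs, grad_w) = loss_and_grads(xs, ys)
print(loss, grad_xs, grad_w)
```

In the question's layer, the same two-part return value would go in the `grad` function defined inside `call`. Note that the question's `tf.matmul(inputs, dy)` multiplies two `(6, 1)` tensors, so the transposes shown here are needed for the shapes to line up.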
