tensorflow 无法在用@tf.function修饰的函数内创建tf.variables

mwngjboj  于 2022-10-29  发布在  其他
关注(0)|答案(3)|浏览(206)

图2.5

@tf.function
def weight_fn():
    w = tf.Variable(tf.truncated_normal())

我有一个类似上面的函数,它被调用了大约50次,每次都应该生成一个新的变量并返回。

ValueError: A tf.Variable created inside your tf.function has been garbage-collected. Your code needs to keep Python references to variables created inside `tf.function`s.

    A common way to raise this error is to create and return a variable only referenced inside your function:

    @tf.function
    def f():
      v = tf.Variable(1.0)
      return v

    v = f()  # Crashes with this error message!

    The reason this crashes is that @tf.function annotated function returns a**`tf.Tensor`**with the**value**of the variable when the function is called rather than the variable instance itself. As such there is no code holding a reference to the `v` created inside the function and Python garbage collects it.

    The simplest way to fix this issue is to create variables outside the function and capture them:

    v = tf.Variable(1.0)

    @tf.function
    def f():
      return v

    f()  # <tf.Tensor: numpy=1.>
    v.assign_add(1.)
    f()  # <tf.Tensor: numpy=2.>

我应该在tf.function之外定义weight变量,这意味着我应该手动定义50多个weight变量,每行都有一个weight变量。

w1 = tf.Variable(tf.truncated_normal())
w2 = tf.Variable(tf.truncated_normal())
w3 = tf.Variable(tf.truncated_normal())
......
w50 = tf.Variable(tf.truncated_normal())

无疑,这种行为实在愚蠢,对这种不合理的统治有什么解决办法吗?

fgw7neuy

fgw7neuy1#

不确定您是否看到了#49310(注解),但ALLOW_DYNAMIC_VARIABLE_CREATION可能会有所帮助:(未测试)

from tensorflow.python.eager import def_function  # def_function.function is the same as tf.function
from tensorflow.python.ops import variables

def_function.ALLOW_DYNAMIC_VARIABLE_CREATION = True

vars = {}

@def_function.function
def weight_fn(val, key):
    if key not in vars:
      vars[key] = variables.Variable(val)

weights = [weight_fn(tf.truncated_normal(), f"w{ind+1}") for ind in range(50)]
zf9nrax1

zf9nrax12#

@sumanthratna谢谢你的提醒。你真是太好了。看来这是我们目前能得到的最好的解决方案了。

相关问题