图2.5
@tf.function
def weight_fn():
w = tf.Variable(tf.truncated_normal())
我有一个类似上面的函数,它被调用了大约50次,每次都应该生成一个新的变量并返回。
ValueError: A tf.Variable created inside your tf.function has been garbage-collected. Your code needs to keep Python references to variables created inside `tf.function`s.
A common way to raise this error is to create and return a variable only referenced inside your function:
@tf.function
def f():
v = tf.Variable(1.0)
return v
v = f() # Crashes with this error message!
The reason this crashes is that @tf.function annotated function returns a**`tf.Tensor`**with the**value**of the variable when the function is called rather than the variable instance itself. As such there is no code holding a reference to the `v` created inside the function and Python garbage collects it.
The simplest way to fix this issue is to create variables outside the function and capture them:
v = tf.Variable(1.0)
@tf.function
def f():
return v
f() # <tf.Tensor: numpy=1.>
v.assign_add(1.)
f() # <tf.Tensor: numpy=2.>
我应该在tf.function之外定义weight变量,这意味着我应该手动定义50多个weight变量,每行都有一个weight变量。
w1 = tf.Variable(tf.truncated_normal())
w2 = tf.Variable(tf.truncated_normal())
w3 = tf.Variable(tf.truncated_normal())
......
w50 = tf.Variable(tf.truncated_normal())
无疑,这种行为实在愚蠢,对这种不合理的统治有什么解决办法吗?
3条答案
按热度按时间fgw7neuy1#
不确定您是否看到了#49310(注解),但
ALLOW_DYNAMIC_VARIABLE_CREATION
可能会有所帮助:(未测试)zf9nrax12#
@sumanthratna谢谢你的提醒。你真是太好了。看来这是我们目前能得到的最好的解决方案了。
carvr3hs3#
/cc @管理员