tensorflow 2.x - tf.cond cond的分支输出长度必须匹配

14ifxucb  于 2023-06-30  发布在  其他
关注(0)|答案(1)|浏览(131)

请参见下面的tensorflow 2.x代码,在图形模式下。du的梯度将是None,因为u不依赖于x。在我的应用程序中,du有时是None,有时是Tensor。因此,我希望在需要时将du设置为0.0,而不是None。我的代码不起作用,请参阅下面的错误。如何在TensorFlow图模式中正确实现这一点?

x = tf.Variable(1.0)
y = tf.Variable(1.0)

@tf.function
def func():
    with tf.GradientTape() as tape:
        u = y + 3.0

    du = tape.gradient(u, x)
    du = tf.cond(tf.cast(du == None, tf.bool), lambda: 0.0, lambda: du)
    return du

print(func())

错误:

ValueError: in user code:

    File "/tmp/ipykernel_5106/913668970.py", line 10, in func  *
        du = tf.cond(tf.cast(du == None, tf.bool), lambda: 0.0, lambda: du)
...
    ValueError: Lengths of branch outputs of cond must match.
    len(graphs[0].outputs): 1
    len(graphs[1].outputs): 0

请注意,下面的代码运行良好(du不是None,因为u = y + 3.0 + x):

x = tf.Variable(1.0)
y = tf.Variable(1.0)

@tf.function
def func():
    with tf.GradientTape() as tape:
        u = y + 3.0 + x

    du = tape.gradient(u, x)
    du = tf.cond(tf.cast(du == None, tf.bool), lambda: 0.0, lambda: du)
    return du

print(func())

输出:tf.Tensor(1.0, shape=(), dtype=float32)

sirbozc5

sirbozc51#

使用zero_if_none函数可以做到这一点:

x = tf.Variable(1.0)
y = tf.Variable(1.0)

@tf.function
def zero_if_none(du):
    if du == None:
        return 0.0
    else:
        return du

@tf.function
def func():
    with tf.GradientTape() as tape:
        u = y + 3.0

    du = tape.gradient(u, x)
    # du = tf.cond(tf.cast(du == None, tf.bool), lambda: 0.0, lambda: 0.0 + du)
    du = zero_if_none(du)

    return du

print(func())

相关问题