有没有办法让TensorFlow的SISTTER_ND_UPDATE使用tf.int64类型的索引,而不是tf.int32?

ddrv8njm  于 2022-10-02  发布在  Python
关注(0)|答案(1)|浏览(150)

我正在使用一个tf.Variable,我需要使用SISTTER_ND_UPDATE函数更新它的值。问题是我的变量有一些元素等于512x512x2048x7=3 758 096 384,这比通过INT32(2147483647)转换的范围大。显然,SISTTER_ND_UPDATE函数只适用于tf.int32数字,而不适用于tf.int64。

以下是生成溢出错误的代码示例:

import tensorflow as tf

image = tf.Variable(tf.ones((512,512,2048,7), dtype=tf.float32))

update = tf.constant([0.0])
index = tf.constant([[0,0,0,0]], dtype=tf.int64)
image.scatter_nd_update(index, update)

print(f"image shape {tf.shape(image)}")

出现的错误如下:F./tensorflow/core/util/gpu_launch_config.h:129]检查失败:WORK_ELEMENT_COUNT>0(-536870912 vs.0)中止(核心转储)

你知道如何解决这个问题吗?

flvlnr44

flvlnr441#

你解出来了吗?我也遇到了这个问题。

相关问题