问题类型
Bug
你是否在TensorFlow Nightly版本中复现了这个bug?
否
问题来源
source
TensorFlow版本
2.14.0
自定义代码
是
OS平台和发行版
- 无响应*
移动设备
- 无响应*
Python版本
- 无响应*
Bazel版本
- 无响应*
GCC/编译器版本
- 无响应*
CUDA/cuDNN版本
11.8
GPU型号和内存
GPU 0: NVIDIA GeForce RTX 2070 GPU 1: NVIDIA GeForce RTX 2070 GPU 2: NVIDIA GeForce RTX 2070 GPU 3: NVIDIA GeForce RTX 2070
当前行为?
当tf.raw_ops.SqrtGrad操作在一个启用了JIT编译的tf.function中被调用(jit_compile=True)时,它产生的结果与没有启用JIT编译的相同操作产生的结果不同。这种不一致性在GPU设备上执行代码时被观察到。
重现问题的独立代码
import tensorflow as tf
import traceback
class Network(tf.Module):
def __init__(self):
super().__init__()
@tf.function(jit_compile=True)
def __call__(self, x):
real_part = tf.random.normal([], dtype=tf.float64)
imag_part = tf.random.normal([], dtype=tf.float64)
tensor = tf.complex(real_part, imag_part)
tensor = tf.cast(tensor,dtype=tf.complex128)
x = tf.raw_ops.SqrtGrad(y=x, dy=tensor)
return x
m = Network()
real_part = tf.random.normal([], dtype=tf.float64)
imag_part = tf.random.normal([], dtype=tf.float64)
tensor = tf.complex(real_part, imag_part)
tensor = tf.cast(tensor,dtype=tf.complex128)
inp = {
"x": tensor,
}
with tf.device('/GPU:0'):
tf.config.run_functions_eagerly(True)
no_op_res = m(**inp)
tf.config.run_functions_eagerly(False)
with tf.device('/GPU:0'):
op_res = m(**inp)
tf.debugging.assert_near(tf.cast(no_op_res, tf.float64), tf.cast(op_res, tf.float64), atol=0.001, rtol=0.001)
相关日志输出
File "/home/guihuan/LLM/results/tf-2/2023-10-22-20-21/test.py", line 33, in <module>
tf.debugging.assert_near(tf.cast(no_op_res, tf.float64), tf.cast(op_res, tf.float64), atol=0.001, rtol=0.001)
File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/guihuan/.conda/envs/night/lib/python3.9/site-packages/tensorflow/python/ops/control_flow_assert.py", line 102, in Assert
raise errors.InvalidArgumentError(
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected 'tf.Tensor(False, shape=(), dtype=bool)' to be true. Summarized data: b''
b'x and y not equal to tolerance rtol = tf.Tensor(0.001, shape=(), dtype=float64), atol = tf.Tensor(0.001, shape=(), dtype=float64)'
b'x (shape=() dtype=float64) = '
-0.006697387971180855
b'y (shape=() dtype=float64) = '
0.07167101474792367
2条答案
按热度按时间edqdpe6u1#
你好,@zoux1a!
我能够使用jit_compile=True和jit_compile=False来复现这个问题。在这里,我附上了一张gist的图片。
谢谢!
xt0899hw2#
似乎与tf.function有关,因为在有和没有
jit_compile
的情况下都无法复制。