TensorFlow: jit-compiled tfnp.take_along_axis returns a tensor with the wrong shape

Posted by vwkv1x7d 22 days ago in Other

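Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

binary

TensorFlow version

tf_nightly-2.16.0.dev20231113-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64

Custom code

Yes

OS platform and distribution

colab

Mobile device

No response

Python version

3.10

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response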

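Current behavior?

Colab demonstrating the issue. When wrapped in tf.function(jit_compile=True), tf.experimental.numpy.take_along_axis returns a Tensor with an incorrect shape.

Standalone code to reproduce the issue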

import tensorflow as tf

x = tf.random.normal((5, 3, 2))
indices = tf.constant([[[-1]]], dtype="int32")

def f(x, i):
    return tf.squeeze(tf.experimental.numpy.take_along_axis(x, i, axis=-2), axis=-2)

z = f(x, indices)
print(z.shape)  # (5, 2)

z1 = tf.function(f, jit_compile=True)(x, indices)  # raises InvalidArgumentError
print(z1.shape)
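For reference, NumPy defines take_along_axis so that indices (which must have the same rank as arr) is broadcast against arr in every dimension except axis, and the output's length along axis comes from indices. The pure-Python sketch below encodes that rule; take_along_axis_shape and broadcast_dim are hypothetical helpers written for illustration, not TensorFlow or NumPy APIs. Under those NumPy semantics the eager result (5, 2) is the expected one, which points at the jit-compiled path as the outlier.

```python
def broadcast_dim(a: int, b: int) -> int:
    """Broadcast two dimension sizes with NumPy's rules."""
    if a == b or b == 1:
        return a
    if a == 1:
        return b
    raise ValueError(f"cannot broadcast dimensions {a} and {b}")


def take_along_axis_shape(arr_shape, indices_shape, axis):
    """Hypothetical helper: output shape of take_along_axis under NumPy semantics."""
    rank = len(arr_shape)
    if len(indices_shape) != rank:
        raise ValueError("indices must have the same number of dimensions as arr")
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    axis %= rank  # e.g. -2 becomes 1 for a rank-3 array
    # Along `axis` the length comes from indices; every other dimension broadcasts.
    return tuple(
        indices_shape[d] if d == axis else broadcast_dim(arr_shape[d], indices_shape[d])
        for d in range(rank)
    )


# Shapes from the repro: x is (5, 3, 2), indices is [[[-1]]] with shape (1, 1, 1).
gathered = take_along_axis_shape((5, 3, 2), (1, 1, 1), axis=-2)
print(gathered)  # (5, 1, 2)

# tf.squeeze(..., axis=-2) then drops the size-1 dimension at index 1.
squeezed = gathered[:1] + gathered[2:]
print(squeezed)  # (5, 2)
```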

Relevant log output

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-4-97d5eab12d23> in <cell line: 1>()
----> 1 z1 = tf.function(f, jit_compile=True)(x, indices)
      2 print(z1.shape)

1 frames
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51   try:
     52     ctx.ensure_initialized()
---> 53     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     54                                         inputs, attrs, num_outputs)
     55   except core._NotOkStatusException as e:

InvalidArgumentError: Tried to explicitly squeeze dimension 1 but dimension was not 1: 2

Stack trace for op definition: 
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>
File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start
File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>
File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
File "<ipython-input-4-97d5eab12d23>", line 1, in <cell line: 1>
File "<ipython-input-2-4f365be98111>", line 8, in f

	 [[{{node Squeeze}}]]
	tf2xla conversion failed while converting __inference_f_283[_XlaMustCompile=true,config_proto=3175580994766145631,executor_type=11160318154034397263]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions. [Op:__inference_f_283]
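The message names dimension 1 because axis=-2 is normalized against the rank-3 input (-2 + 3 = 1). In other words, the tensor that reached the Squeeze node under XLA had size 2 along that axis, where eager mode gives size 1. A minimal pure-Python sketch of that check follows, assuming only the documented tf.squeeze behavior (fail if an explicitly squeezed dimension is not 1); squeeze_shape is a hypothetical helper, and in the failing call only "dimension 1 has size 2" comes from the log, the other sizes are illustrative.

```python
def squeeze_shape(shape, axis):
    """Hypothetical helper mimicking tf.squeeze's shape check for a single axis."""
    rank = len(shape)
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    dim = axis % rank  # -2 on a rank-3 shape becomes 1, as in the error message
    if shape[dim] != 1:
        raise ValueError(
            f"Tried to explicitly squeeze dimension {dim} "
            f"but dimension was not 1: {shape[dim]}"
        )
    return shape[:dim] + shape[dim + 1:]


print(squeeze_shape((5, 1, 2), axis=-2))  # eager-mode shape: (5, 2)

try:
    # Illustrative shape: only "size 2 at dimension 1" is taken from the log.
    squeeze_shape((5, 2, 2), axis=-2)
except ValueError as e:
    print(e)  # Tried to explicitly squeeze dimension 1 but dimension was not 1: 2
```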

axkjgtzd1#

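Hi @jackd!

I was able to reproduce this issue with jit_compile=True, but not with jit_compile=False, in both TF 2.15 and tf-nightly.

Please see this gist for reference.

As a workaround, you can use tf.function(jit_compile=False) or avoid tfnp.take_along_axis to get around this error. Could you please confirm?

Thanks!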
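One possible way to avoid tfnp.take_along_axis for this specific repro: because indices has size 1 in every dimension, take_along_axis here just selects one position along axis -2 for every batch element, so tf.gather can do the same job. This is a sketch under that assumption (the index tensor holds exactly one value); f_gather is a hypothetical name, and it has not been checked against the nightly build from the report. tf.gather does not wrap negative indices the way NumPy indexing does, so the index is normalized with tf.where first.

```python
import tensorflow as tf


def f_gather(x, i):
    """Hypothetical workaround for the repro: gather a single index along axis -2.

    Assumes `i` holds exactly one index (as in the report's [[[-1]]]).
    """
    idx = tf.reshape(i, [-1])  # shape (1,)
    axis_len = tf.shape(x, out_type=idx.dtype)[-2]
    # tf.gather does not wrap negative indices, so map -1 -> axis_len - 1 first.
    idx = tf.where(idx < 0, idx + axis_len, idx)
    # For x of shape (5, 3, 2) the gather yields (5, 1, 2); squeeze drops the size-1 axis.
    return tf.squeeze(tf.gather(x, idx, axis=-2), axis=-2)


x = tf.random.normal((5, 3, 2))
indices = tf.constant([[[-1]]], dtype="int32")

eager = f_gather(x, indices)
compiled = tf.function(f_gather, jit_compile=True)(x, indices)
print(eager.shape, compiled.shape)
```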

mwkjh3gx2#

This issue is stale because it has been open for 7 days with no activity. It will be closed if no further activity occurs. Thank you.

q43xntqr4#

Thanks, jackd, for confirming. If the workaround works for you, can we move this issue to closed status?
Thanks!

yhqotfr85#

Having a workaround is not the same as fixing a bug...

lb3vh1jj6#

@sachinprasadhs, could you please take a look at this issue?
Thanks!