Paddle tensor slicing raises an error

Posted by lsmepo6l 2 months ago in: Other

Describe the Bug

Slicing a Paddle tensor fails when the index is a Python list of slice objects.

Device:

NVIDIA A100 40GB

CUDA:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_Mar__8_18:18:20_PST_2022
Cuda compilation tools, release 11.6, V11.6.124
Build cuda_11.6.r11.6/compiler.31057947_0

Code:

import paddle
x = paddle.zeros([1, 64, 72, 72, 37])
slices_x = [slice(None, None, None), slice(None, None, None), slice(20, -20, None), slice(20, -20, None), slice(None, -20, None)]  # one slice per dimension, held in a list
print("\n slices_x", slices_x)
print("\n x", x.shape)
print("\n x[slices_x]", x[slices_x])  # line 6: raises the ValueError shown in the log
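For reference, the shape these slices are meant to produce can be worked out without Paddle. The sketch below uses only the built-in slice.indices method; the helper name sliced_shape is hypothetical and not part of any Paddle API.

```python
def sliced_shape(shape, slices):
    """Hypothetical helper: length of each dimension after basic slicing."""
    # slice.indices(n) resolves None and negative bounds against a dimension
    # of length n; range() over the result has exactly the sliced length.
    return [len(range(*s.indices(n))) for n, s in zip(shape, slices)]


shape_x = [1, 64, 72, 72, 37]
slices_x = [slice(None, None, None), slice(None, None, None), slice(20, -20, None),
            slice(20, -20, None), slice(None, -20, None)]
print(sliced_shape(shape_x, slices_x))  # [1, 64, 32, 32, 17]
```

So a successful slice should return a tensor of shape [1, 64, 32, 32, 17]: 72 - 20 - 20 = 32 for the two middle dimensions and 37 - 20 = 17 for the last.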

Log:

grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
W0618 06:54:43.306851  2740 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 8.0, Driver API Version: 12.0, Runtime API Version: 11.6
W0618 06:54:43.314005  2740 gpu_resources.cc:164] device: 0, cuDNN Version: 8.4.
 slices_x [slice(None, None, None), slice(None, None, None), slice(20, -20, None), slice(20, -20, None), slice(None, -20, None)]
 x [1, 64, 72, 72, 37]
Traceback (most recent call last):
  File "/workspace/wangguan/gino_paddle_ahmed/gino/src/neuralop/models/test_tensorslice.py", line 6, in <module>
    print("\n x[slices_x]", x[slices_x])
  File "/root/miniconda3/envs/gino_conda/lib/python3.9/site-packages/paddle/base/dygraph/tensor_patch_methods.py", line 986, in __getitem__
    item = pre_deal_index(self, item)
  File "/root/miniconda3/envs/gino_conda/lib/python3.9/site-packages/paddle/base/dygraph/tensor_patch_methods.py", line 979, in pre_deal_index
    item[i] = paddle.to_tensor(slice_item)
  File "/root/miniconda3/envs/gino_conda/lib/python3.9/site-packages/paddle/tensor/creation.py", line 806, in to_tensor
    return _to_tensor_non_static(data, dtype, place, stop_gradient)
  File "/root/miniconda3/envs/gino_conda/lib/python3.9/site-packages/paddle/tensor/creation.py", line 588, in _to_tensor_non_static
    raise ValueError(
ValueError: 
        Failed to convert input data to a regular ndarray :
         - Usually this means the input data contains nested lists with different lengths.
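A likely cause, inferred from the traceback rather than confirmed against Paddle's source: Python passes the subscript object to __getitem__ unchanged, so x[slices_x] hands Paddle a list, whereas x[a, b, ...] hands it a tuple. The traceback shows pre_deal_index calling paddle.to_tensor on each list element, which cannot convert a slice object. The pure-Python probe below shows what __getitem__ receives in each case; KeyProbe is a hypothetical stand-in class, not part of Paddle.

```python
class KeyProbe:
    """Hypothetical stand-in for a tensor: returns the raw subscript it receives."""

    def __getitem__(self, key):
        return key


probe = KeyProbe()
slices_x = [slice(None, None, None), slice(None, None, None), slice(20, -20, None),
            slice(20, -20, None), slice(None, -20, None)]

key_from_list = probe[slices_x]                      # the list object itself
key_from_tuple = probe[tuple(slices_x)]              # a tuple of slices
key_from_syntax = probe[:, :, 20:-20, 20:-20, :-20]  # comma syntax builds a tuple

print(type(key_from_list).__name__, type(key_from_tuple).__name__)  # list tuple
print(key_from_tuple == key_from_syntax)  # True
```

If this diagnosis holds, indexing with x[tuple(slices_x)], or the equivalent literal x[:, :, 20:-20, 20:-20, :-20], should take the basic-slicing path; this has not been checked against the reported Paddle version. NumPy behaves similarly: since version 1.23 it no longer interprets a non-tuple sequence as a multidimensional index.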
