Why does numba jit report an error when indexing a NumPy array?

dy1byipe posted on 2023-08-05 in Other

Below is my code. It runs fine without numba, and without nopython=True it throws a "converting to object" warning.

@numba.jit(nopython=True)
def computation_numba(array_cmp: np.ndarray, y_hat: np.ndarray, y_df: np.ndarray):
    cold_start_ratio = 0.05
    opportunities = 0
    if array_cmp.shape[1] <= 100:
        idx = np.flip(np.argsort(y_hat))
        y_df = y_df[idx]
        opportunities = max(int(cold_start_ratio*len(y_df)), 1)
        pos = y_df[:opportunities]
        neg = y_df[-opportunities:]
    else:
        ys = np.append(array_cmp,np.stack((y_hat,y_df),axis=0),axis=1)
        # to remove outlier 5-sigma
        means = ys[0,:].mean()
        stds = ys[0,:].std()
        ys = ys[:,np.array(ys[0,:]<=means+5*stds)]
        ys = ys[:,np.array(ys[0,:]>=means-5*stds)]
        idx = np.flip(np.argsort(y_hat))
        ys = ys[:,idx]
        opportunities = int(0.1*ys.shape[1])
        pos_th = ys[0,opportunities]
        neg_th = ys[0,-opportunities]
        pos = ys[1,np.array(ys[0,:]>=pos_th)]
        neg = ys[1,np.array(ys[0,:]<=neg_th)]
    return ( pos.sum()-neg.sum() )

array_cmp = np.random.random([2,200])
y_hat = np.random.random(50)
y_df = pd.DataFrame.from_dict({'A': y_hat+0.3, 'B': y_hat*3})
print(computation_numba(array_cmp, y_hat, y_df['A'].to_numpy()))

It returns the following error:

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[21], line 30
     28 y_hat = np.random.random(50)
     29 y_df = pd.DataFrame.from_dict({'A': y_hat+0.3, 'B': y_hat*3})
---> 30 print(computation_numba(array_cmp, y_hat, y_df['A'].to_numpy()))

File /usr/local/lib64/python3.9/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File /usr/local/lib64/python3.9/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function array>) found for signature:
 
 >>> array(array(bool, 1d, C))
 
There are 4 candidate implementations:
      - Of which 4 did not match due to:
      Overload in function '_OverloadWrapper._build.<locals>.ol_generated': File: numba/core/overload_glue.py: Line 129.
        With argument(s): '(array(bool, 1d, C))':
       Rejected as the implementation raised a specific error:
         TypingError: array(bool, 1d, C) not allowed in a homogeneous sequence
  raised from /usr/local/lib64/python3.9/site-packages/numba/core/typing/npydecl.py:488

During: resolving callee type: Function(<built-in function array>)
During: typing of call at /tmp/ipykernel_1655694/2732175936.py (16)

File "../../../../tmp/ipykernel_1655694/2732175936.py", line 16:
<source missing, REPL/exec in use?>


I tried this post, but it did not solve the problem. I also read this post, but could not figure out how to adapt it to my case.
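For reference, a minimal sketch of the 5-sigma filter without the np.array(...) wrapper. A comparison such as ys[0, :] <= limit already returns a boolean NumPy array, and the first traceback shows that this numba build would not accept an existing array as the argument of np.array(). The two sequential filters in the question use thresholds computed once, so they are combined here into a single mask. The name drop_outliers and the optional-numba fallback are illustrative assumptions, not part of the original code; the subscript ys[:, keep] (a slice plus one boolean mask) is the form that should be accepted.

```python
import numpy as np

try:
    from numba import njit  # use numba when it is installed
except ImportError:
    def njit(func):
        """Fallback no-op decorator so the sketch also runs as plain NumPy."""
        return func


@njit
def drop_outliers(ys, n_sigma=5.0):
    """Keep only the columns of ys whose row-0 value lies within n_sigma standard deviations of its mean."""
    row0 = ys[0, :]
    mean = row0.mean()
    std = row0.std()
    # The comparisons already produce boolean arrays; no np.array(...) wrapper is needed.
    keep = np.logical_and(row0 <= mean + n_sigma * std, row0 >= mean - n_sigma * std)
    # A slice combined with a single boolean mask is one advanced index.
    return ys[:, keep]
```

With n_sigma=1.0 the extreme column of a small test array is dropped, while the default of 5.0 keeps every column in that example.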
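-- Update 1 --
If I remove the np.array type casts (as opposed to this post), the code returns a similar error. The main difference is that *4 candidate implementations* becomes *22 candidate implementations*.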

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[26], line 30
     28 y_hat = np.random.random(50)
     29 y_df = pd.DataFrame.from_dict({'A': y_hat+0.3, 'B': y_hat*3})
---> 30 print(computation_numba(array_cmp, y_hat, y_df['A'].to_numpy()))

File /usr/local/lib64/python3.9/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File /usr/local/lib64/python3.9/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(array(float64, 2d, C), Tuple(Literal[int](1), array(bool, 1d, C)))
 
There are 22 candidate implementations:
      - Of which 20 did not match due to:
      Overload of function 'getitem': File: <numerous>: Line N/A.
        With argument(s): '(array(float64, 2d, C), Tuple(int64, array(bool, 1d, C)))':
       No match.
      - Of which 2 did not match due to:
      Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 166.
        With argument(s): '(array(float64, 2d, C), Tuple(int64, array(bool, 1d, C)))':
       Rejected as the implementation raised a specific error:
         NumbaNotImplementedError: only one advanced index supported
  raised from /usr/local/lib64/python3.9/site-packages/numba/core/typing/arraydecl.py:69

During: typing of intrinsic-call at /tmp/ipykernel_1655694/1304761145.py (23)

File "../../../../tmp/ipykernel_1655694/1304761145.py", line 23:
<source missing, REPL/exec in use?>
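The second traceback points at a subscript of the form ys[1, ys[0,:] >= pos_th]: an integer and a boolean array in one subscript, typed as Tuple(Literal[int](1), array(bool, 1d, C)). NumPy treats that combination as advanced indexing, and numba 0.56.4 rejected it with "only one advanced index supported". A sketch of a workaround that should avoid the combination, assuming you have to stay on the older numba: take the row with plain integer indexing first, then apply the mask to the resulting 1-D array. tail_difference is a hypothetical helper mirroring the pos/neg lines of the question.

```python
import numpy as np

try:
    from numba import njit  # use numba when it is installed
except ImportError:
    def njit(func):
        """Fallback no-op decorator so the sketch also runs as plain NumPy."""
        return func


@njit
def tail_difference(ys, opportunities):
    """Sum of row 1 where row 0 >= the top threshold, minus the sum where row 0 <= the bottom threshold."""
    pos_th = ys[0, opportunities]
    neg_th = ys[0, -opportunities]
    row0 = ys[0]  # plain integer index: a 1-D array, no advanced indexing
    row1 = ys[1]
    # One boolean mask on a 1-D array, instead of ys[1, mask] which mixes
    # an integer with an array index in the same subscript.
    pos = row1[row0 >= pos_th]
    neg = row1[row0 <= neg_th]
    return pos.sum() - neg.sum()
```

The result matches the NumPy expression ys[1, mask].sum(); the split form is simply valid on both old and new numba versions.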

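-- Update 2 --
The code posted on July 14 is indeed the code I used, as this screenshot shows:
[screenshot]
After I removed np.array(), it still returns the "22 candidate implementations" error. (The error in the screenshot is almost identical to the traceback in *Update 1*.) Environment: Python 3.9 with numba==0.56.4, numpy==1.23.5, pandas==2.0.1.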

Answer:

Thanks to @Rutger Kassies in the comments. The problem came from a version conflict (possibly a bug): the code runs fine on numba==0.57.1.
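Since the answer attributes the failure to the installed numba version (0.56.4 fails, 0.57.1 works), a script can check the version up front instead of failing deep inside compilation. A minimal sketch using only the standard library; version_tuple and has_min_version are hypothetical helper names, and the 0.57.1 threshold comes from the answer, not from numba's release notes.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def version_tuple(text):
    """Turn '0.57.1' into (0, 57, 1), keeping the leading digits of the first three parts."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def has_min_version(package, minimum):
    """Return True if the package is installed and its version is at least `minimum`."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        return False
    return version_tuple(installed) >= minimum


if __name__ == "__main__":
    if not has_min_version("numba", (0, 57, 1)):
        print("numba >= 0.57.1 not found; try: pip install -U 'numba>=0.57.1'")
```

Pre-release strings such as '0.58.0rc1' are reduced to (0, 58, 0), which is adequate for a simple minimum-version guard.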
