字符串上的Numpy“where”

cgfeq70w  于 2023-01-09  发布在  其他
关注(0)|答案(4)|浏览(180)

我想在字符串数组上使用numpy.where函数。但是,我没有成功。有人能帮我解决这个问题吗?
例如,当我在下面的例子中使用numpy.where时,我得到一个错误:

import numpy as np

A = ['apple', 'orange', 'apple', 'banana']

arr_index = np.where(A == 'apple',1,0)

我得到了以下结果:

>>> arr_index
array(0)
>>> print A[arr_index]
>>> apple

但是,我想知道字符串数组A中与字符串'apple'匹配的索引,在上面的字符串中,这发生在0和2处,但是,np.where只返回0而不返回2。
那么,如何让numpy.where处理字符串呢?先谢谢了。

1szpjjfi

1szpjjfi1#

print(a[arr_index])

不是array_index!!

a = np.array(['apple', 'orange', 'apple', 'banana'])

arr_index = np.where(a == 'apple')

print(arr_index)

print(a[arr_index])
ulydmbyx

ulydmbyx2#

问题是你需要使用数组而不是列表来正确地使用where(同样,使用True和False而不是1和0来获得一个掩码来查找索引):

A = ['apple', 'orange', 'apple', 'banana']
arr_mask = np.where(np.array(A) == 'apple',True,False)
arr_index = np.arange(0, len(A))[arr_mask]

这样,您将得到arr_index:n个数组([0,2])
注意,要使用掩码arr_mask或索引arr_index查找A中的值,A必须是一个数组:

In [55]: A = ['apple', 'orange', 'apple', 'banana'] 
    ...: arr_mask = np.where(np.array(A) == 'apple',True,False) 
    ...: arr_index = np.arange(0, len(A))[arr_mask]                                                             

In [56]: A[arr_mask]                                                                                            
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-56-f8b153319425> in <module>
----> 1 A[arr_mask]

TypeError: only integer scalar arrays can be converted to a scalar index

In [57]: A[arr_index]                                                                                           
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-57-91c260fe71ab> in <module>
----> 1 A[arr_index]

TypeError: only integer scalar arrays can be converted to a scalar index

In [58]: B = np.array(A)                                                                                        

In [59]: B[arr_mask]                                                                                            
Out[59]: array(['apple', 'apple'], dtype='<U6')

In [60]: B[arr_index]                                                                                           
Out[60]: array(['apple', 'apple'], dtype='<U6')

如果只使用列表,函数np.where()找不到满足条件的地方,如果尝试:

A = ['apple', 'orange', 'apple', 'banana']
arr_index = np.where(A == 'orange',1,0)

您将再次获得array(0)作为输出。

at0kjp5o

at0kjp5o3#

还有另一种方法:

def GetIndexOfStr(npArray,theStr): 
    #npArray is from type of numpy.ndarray where each item is of type np.str
    return np.where(npArray == theStr)[0][0]

A = np.array(['apple', 'orange', 'apple', 'banana'])
print(A[GetIndexOfStr(A,"apple")]) # ==> this will result in "apple"
print(A[GetIndexOfStr(A,"appleX")]) # ==> this will throw IndexError
li9yvcax

li9yvcax4#

我认为更简单的方法是:

A = np.array(['apple', 'orange', 'apple', 'banana'])
arr_index = np.where(A == 'apple')
print(arr_index)

你会得到:

(array([0, 2]),)

相关问题