python 为什么当我尝试用另一个数组索引一个numpy ndarray时会抛出IndexError?

o8x7eapl  于 2022-12-17  发布在  Python
关注(0)|答案(1)|浏览(111)

我有一个5维的ndarray,叫做self.q_table。我有一个正则数组,长度是4。当我试图找出该行的最大值时,就像这样...

max = np.max(self.q_table[regular_array])

...我得到一个IndexError,即使regular_array的元素小于q_table的维度。
我尝试改变两个数组的维数,但效果并没有变好。
编辑:
错误为IndexError: index 11 is out of bounds for axis 0 with size 10
Numpy作为np导入,本示例中的最后一行抛出错误。

class AgentBase:
def __init__(self, observation_space):
    OBSERVATION_SPACE_RESOLUTION = [10, 15, 15, 15]
    self.q_table = np.zeros([*OBSERVATION_SPACE_RESOLUTION, 4])
    max_val = np.max(self.q_table[self.quantize_state(observation_space, [-150, 100, 3, 3])])
    print(max_val)

@staticmethod
def quantize_state(observation_space, state):
    state_quantized = np.zeros(len(state))
    lin_spaces = []
    for i in range(len(observation_space)):
        lin_spaces.append(np.linspace(observation_space[i][0], observation_space[i][1],
                                     OBSERVATION_SPACE_RESOLUTION[i] - 1, dtype=int))
    for i in range(len(lin_spaces)):
        state_quantized[i] = np.digitize(state[i], lin_spaces[i])
    return state_quantized.astype(int)

self.observation_space是具有以下值的参数:

px9o7tmv

px9o7tmv1#

当使用列表(或者引用numpy docs的话,“一个非元组序列对象”)索引ndarray时,您调用了高级索引,因此当使用[3,7,11,11]索引时,它会尝试将所有这些值应用于第一个维度。
我同意@hpaulj的评论,你可能想用元组来索引。

相关问题