numpy: Q-learning code raises an error when run; how can I fix it?

i2byvkas  posted on 2023-04-21 in Other

I am trying to write a simple Python program that implements Q-learning on the OpenAI Gym FrozenLake environment. I found the code on the DataCamp website; the code and the link are below:
Link: Q_Learning_Code

import numpy as np
import gym
import random
from tqdm import trange

env = gym.make("FrozenLake-v1", render_mode="rgb_array")
env.reset()
env.render()

print("Observation Space", env.observation_space)
print("Sample Observation", env.observation_space.sample())

print("Action Space Shape", env.action_space.n)
print("Action Space Sample", env.action_space.sample())

state_space = env.observation_space.n
print("There are ", state_space, " possible states")

action_space = env.action_space.n
print("There are ", action_space, " possible actions")

def initialize_q_table(state_space, action_space):
    Qtable = np.zeros((state_space, action_space))
    return Qtable

Qtable_frozenlake = initialize_q_table(state_space, action_space)

def epsilon_greedy_policy(Qtable, state, epsilon):
    random_init = random.uniform(0, 1)
    if(random_init > epsilon):
        action = np.argmax(Qtable[state])
    else:
        action = env.action_space.sample()

    return action

def greedy_policy(Qtable, state):
    action = np.argmax(Qtable[state])
    return action

n_training_episodes = 10000
learning_rate = 0.7

n_eval_episodes = 100

env_id = "FrozenLake-v1"
max_steps = 99
gamma = 0.95
eval_seed = []

max_epsilon = 1.0
min_epsilon = 0.05
decay_rate = 0.0005

def train(n_training_episodes, min_epsilon, max_epsilon, decay_rate, env, max_steps, Qtable):
    for episode in trange(n_training_episodes):
    
        epsilon = min_epsilon + (max_epsilon - min_epsilon)*np.exp(-decay_rate*episode)
    
        state = env.reset()
        step = 0
        done = False
    
        for step in range(max_steps):
        
            action = epsilon_greedy_policy(Qtable, state, epsilon)
        
            new_state, reward, done, trunc, info = env.step(action)
        
            Qtable[state][action] = Qtable[state][action] + learning_rate * (reward + gamma * np.max(Qtable[new_state]) - Qtable[state][action])
        
            if(done):
                break
        
            state = new_state
        
    return Qtable

Qtable_frozenlake = train(n_training_episodes, min_epsilon, max_epsilon, decay_rate, env, max_steps, Qtable_frozenlake)

When I run the program, I get the following error:
Traceback (most recent call last):

File "/tmp/ipykernel_15859/3962363982.py", line 80, in <module>
Qtable_frozenlake = train(n_training_episodes, min_epsilon, max_epsilon, decay_rate, env, max_steps, Qtable_frozenlake)

 File "/tmp/ipykernel_15859/3962363982.py", line 71, in train
Qtable[state][action] = Qtable[state][action] + learning_rate * (reward + gamma * np.max(Qtable[new_state]) - Qtable[state][action])

 IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

What does this error mean, and how can I fix it?

643ylb08 #1

env.reset() typically returns a tuple of (state, info). That is the case here as well:

>>> env.reset()
(0, {'prob': 1})

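NumPy cannot use this tuple as an index (its second element is a dict), so it raises the IndexError shown in the question. What you want to do is one of the following: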

state, info = env.reset()
# or
state, _ = env.reset()
# or
state = env.reset()[0]
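
To see why the tuple breaks indexing, here is a minimal sketch that reproduces the error with NumPy alone. It assumes a 16 x 4 Q-table (the size for the default 4x4 FrozenLake map) and uses the tuple shown above as a stand-in for the return value of env.reset(); the names reset_result and raised are only for illustration.

```python
import numpy as np

Qtable = np.zeros((16, 4))  # 16 states x 4 actions, as in 4x4 FrozenLake

# Stand-in for the value returned by env.reset() in the answer above.
reset_result = (0, {"prob": 1})

try:
    Qtable[reset_result]  # indexing with (int, dict) is not valid
    raised = False
except IndexError as exc:
    raised = True
    print("IndexError:", exc)

# Unpacking the tuple gives a plain integer state, which indexes fine.
state, info = reset_result
print(Qtable[state])  # the row of Q-values for state 0
```

Once the state is unpacked, Qtable[state] selects one row of four Q-values, which is what the update rule in the question expects.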

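Another point I would add is that you should also check whether the environment was truncated. To decide when to end the episode and reset the environment, check `if done or trunc` rather than only `if done`. This may not matter in this scenario, but it is good practice for other environments.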
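
Putting both points together, the training loop could look roughly like the sketch below. To keep it runnable without Gym installed, it uses TinyChainEnv, a hypothetical stand-in class (not part of Gym) that only mimics the reset() and step() return shapes of Gym 0.26+. It also uses a fixed epsilon instead of the decaying one in the question, purely for brevity.

```python
import random
import numpy as np


class TinyChainEnv:
    """Hypothetical stand-in for a Gym environment (not part of Gym).

    States 0..3 lie on a line; action 1 moves right, action 0 moves left.
    Reaching state 3 terminates the episode with reward 1. The episode is
    truncated after max_steps steps.
    """

    n_states = 4
    n_actions = 2

    def __init__(self, max_steps=20):
        self.max_steps = max_steps
        self.state = 0
        self.steps = 0

    def reset(self):
        self.state = 0
        self.steps = 0
        return self.state, {}  # (observation, info)

    def step(self, action):
        self.steps += 1
        move = 1 if action == 1 else -1
        self.state = min(max(self.state + move, 0), self.n_states - 1)
        terminated = self.state == self.n_states - 1
        truncated = (not terminated) and self.steps >= self.max_steps
        reward = 1.0 if terminated else 0.0
        return self.state, reward, terminated, truncated, {}


def train(env, n_episodes=500, learning_rate=0.7, gamma=0.95, epsilon=0.5, seed=0):
    rng = random.Random(seed)
    Qtable = np.zeros((env.n_states, env.n_actions))
    for _ in range(n_episodes):
        state, _info = env.reset()  # unpack the tuple; keep the integer state
        while True:
            # Epsilon-greedy action selection with a fixed epsilon.
            if rng.random() > epsilon:
                action = int(np.argmax(Qtable[state]))
            else:
                action = rng.randrange(env.n_actions)

            new_state, reward, done, trunc, _info = env.step(action)

            # Same Q-learning update as in the question.
            Qtable[state][action] += learning_rate * (
                reward + gamma * np.max(Qtable[new_state]) - Qtable[state][action]
            )

            if done or trunc:  # stop on termination OR truncation
                break
            state = new_state
    return Qtable


Qtable = train(TinyChainEnv())
print(Qtable)
```

With the real environment, the same two changes apply to the train function in the question: unpack the result of env.reset(), and break on `done or trunc`.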
