keras 如何从Rllib的PPO算法中获得一系列观测值的值函数/临界值?

anauzrmj  于 2023-01-30  发布在  其他
关注(0)|答案(1)|浏览(120)

**目标:**我想训练PPO代理处理某个问题,并确定其在某个观察范围内的最佳值函数。以后我计划使用此值函数(经济不平等研究)。该问题非常复杂,以至于动态编程技术不再起作用。
**方法:**为了检查我是否得到了正确的值函数输出,我用一个简单的问题训练了PPO,这个问题的解析解是已知的。然而,值函数的结果是垃圾,这就是为什么我怀疑我做错了什么。
密码:

from keras import backend as k_util
...

parser = argparse.ArgumentParser()

# Define framework to use
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.",
)
...

def get_rllib_config(seeds, debug=False, framework="tf") -> Dict:
...

def get_value_function(agent, min_state, max_state):
    policy = agent.get_policy()
    value_function = []
    for i in np.arange(min_state, max_state, 1):
        model_out, _ = policy.model({"obs": np.array([[i]], dtype=np.float32)})
        value = k_util.eval(policy.model.value_function())[0]
        value_function.append(value)
        print(i, value)
    return value_function

def train_schedule(config, reporter):
    rllib_config = config["config"]
    iterations = rllib_config.pop("training_iteration", 10)

    agent = PPOTrainer(env=rllib_config["env"], config=rllib_config)
    for _ in range(iterations):
        result = agent.train()
        reporter(**result)
    values = get_value_function(agent, 0, 100)
    print(values)
    agent.stop()

...

resources = PPO.default_resource_request(exp_config)
tune_analysis = tune.Tuner(tune.with_resources(train_schedule, resources=resources), param_space=exp_config).fit()
ray.shutdown()

首先,我得到策略(policy = agent.get_policy()),并对100个值(model_out, _ = policy.model({"obs": np.array([[i]], dtype=np.float32)}))中的每一个值运行一次前向传递,然后,在每次前向传递之后,我使用value_function()方法获得评论网络的输出,并通过keras后端计算Tensor。

结果:True VF (analytical solution)VF output of Rllib

不幸的是,你可以看到结果并不乐观。也许我错过了一个预处理或后处理步骤?value_function()方法甚至返回评论家网络的最后一层吗?
我非常感谢任何帮助!

kiz8lqtg

kiz8lqtg1#

它不是脚本的一部分,但是我假设您在尝试从策略中获取有用的值之前已经对策略进行了训练。
您假设value_function是正确的()返回RLlib实现中评论者网络最后一层的输出,请查看值函数度量,看看它是否真的学到了什么(RLlib记录.../learner_stats/vf_loss.../learner_stats/vf_explained_var)!训练完模型后,我还尝试直接查询模型。你贴在这里的代码可能有问题。

相关问题