Python:使用 Stable-Baselines3 和具有 Dict 观测空间的自定义环境创建 PPO 模型时出错

xjreopfe  于 2023-10-14  发布在  Python
关注(0)|答案(1)|浏览(143)

我一直在尝试使用 Stable-Baselines3 和一个已通过其环境检查器(check_env)的自定义环境来训练 PPO 模型。但报错的回溯信息中似乎没有任何一处指向我自己的代码,所以我不知道到底发生了什么。
错误信息:

TypeError                                 Traceback (most recent call last)
Cell In[7], line 4
      1 env = DominoTrainEnv(6)
      3 # Initialize the PPO agent
----> 4 model = PPO("MultiInputPolicy", env)
      6 # Train the agent
      7 model.learn(total_timesteps=10000)

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\ppo\ppo.py:164, in PPO.__init__(self, policy, env, learning_rate, n_steps, batch_size, n_epochs, gamma, gae_lambda, clip_range, clip_range_vf, normalize_advantage, ent_coef, vf_coef, max_grad_norm, use_sde, sde_sample_freq, target_kl, stats_window_size, tensorboard_log, policy_kwargs, verbose, seed, device, _init_setup_model)
    161 self.target_kl = target_kl
    163 if _init_setup_model:
--> 164     self._setup_model()

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\ppo\ppo.py:167, in PPO._setup_model(self)
    166 def _setup_model(self) -> None:
--> 167     super()._setup_model()
    169     # Initialize schedules for policy/value clipping
    170     self.clip_range = get_schedule_fn(self.clip_range)

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\on_policy_algorithm.py:123, in OnPolicyAlgorithm._setup_model(self)
    113 self.rollout_buffer = buffer_cls(
    114     self.n_steps,
    115     self.observation_space,
   (...)
    120     n_envs=self.n_envs,
    121 )
    122 # pytype:disable=not-instantiable
--> 123 self.policy = self.policy_class(  # type: ignore[assignment]
    124     self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
    125 )
    126 # pytype:enable=not-instantiable
    127 self.policy = self.policy.to(self.device)

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\policies.py:853, in MultiInputActorCriticPolicy.__init__(self, observation_space, action_space, lr_schedule, net_arch, activation_fn, ortho_init, use_sde, log_std_init, full_std, use_expln, squash_output, features_extractor_class, features_extractor_kwargs, share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs)
    833 def __init__(
    834     self,
    835     observation_space: spaces.Dict,
   (...)
    851     optimizer_kwargs: Optional[Dict[str, Any]] = None,
    852 ):
--> 853     super().__init__(
    854         observation_space,
    855         action_space,
    856         lr_schedule,
    857         net_arch,
    858         activation_fn,
    859         ortho_init,
    860         use_sde,
    861         log_std_init,
    862         full_std,
    863         use_expln,
    864         squash_output,
    865         features_extractor_class,
    866         features_extractor_kwargs,
    867         share_features_extractor,
    868         normalize_images,
    869         optimizer_class,
    870         optimizer_kwargs,
    871     )

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\policies.py:507, in ActorCriticPolicy.__init__(self, observation_space, action_space, lr_schedule, net_arch, activation_fn, ortho_init, use_sde, log_std_init, full_std, use_expln, squash_output, features_extractor_class, features_extractor_kwargs, share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs)
    504 # Action distribution
    505 self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
--> 507 self._build(lr_schedule)

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\policies.py:577, in ActorCriticPolicy._build(self, lr_schedule)
    573     self.action_net, self.log_std = self.action_dist.proba_distribution_net(
    574         latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
    575     )
    576 elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
--> 577     self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
    578 else:
    579     raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")

File ~\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\distributions.py:336, in MultiCategoricalDistribution.proba_distribution_net(self, latent_dim)
    325 def proba_distribution_net(self, latent_dim: int) -> nn.Module:
    326     """
    327     Create the layer that represents the distribution:
    328     it will be the logits (flattened) of the MultiCategorical distribution.
   (...)
    333     :return:
    334     """
--> 336     action_logits = nn.Linear(latent_dim, sum(self.action_dims))
    337     return action_logits

File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\linear.py:96, in Linear.__init__(self, in_features, out_features, bias, device, dtype)
     94 self.in_features = in_features
     95 self.out_features = out_features
---> 96 self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
     97 if bias:
     98     self.bias = Parameter(torch.empty(out_features, **factory_kwargs))

TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

我尝试过 PPO 以外的算法(如 A2C),但仍然出现相同的错误。我使用 MultiInputPolicy,是因为我的环境的观测空间是 Dict。
进一步测试后发现,问题似乎与同时使用 Dict 观测空间和 MultiDiscrete 动作空间有关。
最小可复现示例:

import gymnasium as gym
from gymnasium import Env
from gymnasium.spaces import Discrete, MultiDiscrete, Box, Dict
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

class TestEnv(Env):
    """Minimal placeholder environment that reproduces the reported TypeError.

    The observation space is a ``Dict`` with a single ``"hand"`` entry of
    shape (1, 158); the action space is a ``MultiDiscrete`` built from a
    2-D array. Every observation is the same fixed placeholder hand, every
    step returns reward 1, and the episode is never flagged as finished.
    """

    def __init__(self,numPlayers:int):
        """Set up the action/observation spaces and the initial state.

        :param numPlayers: accepted but not used anywhere in this example.
        """
        # Actions we can take, 13,13 for possible domino sides, [9,13] for possible domino placements
        # NOTE: nvec is a 2x2 array here. This 2-D shape is what makes
        # PPO("MultiInputPolicy", env) fail with the torch.empty() TypeError;
        # Stable-Baselines3 needs a flat nvec such as [13, 13, 9, 13].
        # Presumably sum() over the 2-D dims yields an array instead of an int
        # for nn.Linear's out_features -- TODO confirm against SB3 source.
        self.action_space = MultiDiscrete(np.array([[13, 13], [9, 13]]))
        # observation space
        obsv =  {
        # low/high are (1, 158) arrays: 79 (side, side) pairs flattened,
        # each value bounded by -1 (the padding value) and 13.
        "hand": Box(high=np.array([[13, 13]*79]), dtype=np.int8,low=np.array([[-1, -1]*79]))
        }
        self.observation_space = Dict(obsv)
        self.state = self.getState()
    @staticmethod
    def __padArray(array,len:int):
        """Append ``len`` rows filled with -1 to the end of a 2-D array.

        Only the end of axis 0 is padded; axis 1 is left unchanged.
        Note that the parameter name ``len`` shadows the builtin inside
        this function.
        """
        return np.pad(array,((0,len),(0,0)),mode='constant',constant_values=-1)
    def getState(self):
        """Return the placeholder observation dict.

        The two hard-coded dominoes are padded with -1 rows up to 79 rows,
        then flattened and reshaped to (1, 158) to match the ``"hand"`` Box.
        """
        #array values are placeholders
        handarray = np.array([(11,11),(12,11)], dtype=np.int8)
        
        # The double-underscore name is mangled to _TestEnv__padArray; the
        # lookup works because this call sits inside the class body.
        hand_padding = TestEnv.__padArray(handarray, 79-len(handarray))
        state =  {
            "hand": hand_padding.ravel().reshape((1,158)),
        }
        return state
    def step(self, action):
        """Ignore ``action`` and return a fixed transition.

        Returns a 5-tuple ``(state, 1, False, False, {})``; presumably this
        follows Gymnasium's (obs, reward, terminated, truncated, info)
        ordering -- verify against the Gymnasium ``Env.step`` contract.
        """
        self.state = self.getState()
        # Return step information
        reward = 1
        done = False
        info = {}
        return self.state, reward, done,False, info

    def render(self):
        """No-op; rendering is not implemented in this example."""
        # Implement viz
        pass
    
    def reset(self, seed=None):
        """Regenerate the placeholder state and return ``(state, {})``.

        ``seed`` is accepted but not used.
        """
        self.state = self.getState()
        return self.state, {}

# Build the environment; the argument 6 is not used by TestEnv.__init__.
env = TestEnv(6)
# Run the Stable-Baselines3 environment checker (render check skipped) and
# print whatever it returns. The question reports that the environment passes.
print(check_env(env,skip_render_check=True))
# This constructor call is where the TypeError from the traceback is raised.
model = PPO("MultiInputPolicy", env)
rjzwgtxy

rjzwgtxy1#

问题出在 MultiDiscrete 的形状上。SB3 不支持二维数组,而环境检查器又无法发现这一点,因此我不得不将

MultiDiscrete([[13, 13], [9, 13]])

改为

MultiDiscrete([13, 13, 9, 13])

相关问题