PPO doesn't work with MultiDiscrete observation space #1836

Open

elisavio opened this issue Feb 13, 2024 · 6 comments
Labels: custom gym env · documentation · help wanted

Comments

@elisavio

🐛 Bug

I am implementing a simple custom environment to use PPO with a MultiDiscrete observation space.
It works if I use MultiDiscrete([5, 2, 2]), but it fails when the space becomes a multi-dimensional array. In the code below I use the MultiDiscrete observation space given as an example in https://gymnasium.farama.org/api/spaces/fundamental/#gymnasium.spaces.MultiDiscrete .

Code example

import numpy as np
import gymnasium as gym
from gymnasium.spaces import MultiDiscrete
from stable_baselines3 import PPO

class CustomEnv(gym.Env):
    def __init__(self):
        self.observation_space = MultiDiscrete(np.array([[1,2], [3,4]]), seed=42)  # Example multi-discrete observation space
        
        self.action_space = MultiDiscrete(np.array([3, 4, 3, 4]), seed=42)
        self.reset()

    def reset(self, seed=None, options=None):
        self.state = self.observation_space.sample()
        return self.state, {}

    def step(self, action):
        self.state = self.observation_space.sample()
        reward = 1    # Example reward function
        done = False  # Example termination condition
        info = {}     # Additional information (optional)
        return self.state, reward, done, False, info
    
env = CustomEnv()
model = PPO('MlpPolicy', env, verbose=1)

Relevant log output / Error message

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 25
     22         return self.state, reward, done, False, info
     24 env = CustomEnv()
---> 25 model = PPO('MlpPolicy', env, verbose=1)

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\ppo\ppo.py:171, in PPO.__init__(self, policy, env, learning_rate, n_steps, batch_size, n_epochs, gamma, gae_lambda, clip_range, clip_range_vf, normalize_advantage, ent_coef, vf_coef, max_grad_norm, use_sde, sde_sample_freq, rollout_buffer_class, rollout_buffer_kwargs, target_kl, stats_window_size, tensorboard_log, policy_kwargs, verbose, seed, device, _init_setup_model)
    168 self.target_kl = target_kl
    170 if _init_setup_model:
--> 171     self._setup_model()

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\ppo\ppo.py:174, in PPO._setup_model(self)
    173 def _setup_model(self) -> None:
--> 174     super()._setup_model()
    176     # Initialize schedules for policy/value clipping
    177     self.clip_range = get_schedule_fn(self.clip_range)

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\common\on_policy_algorithm.py:133, in OnPolicyAlgorithm._setup_model(self)
    121         self.rollout_buffer_class = RolloutBuffer
    123 self.rollout_buffer = self.rollout_buffer_class(
    124     self.n_steps,
    125     self.observation_space,  # type: ignore[arg-type]
   (...)
    131     **self.rollout_buffer_kwargs,
    132 )
--> 133 self.policy = self.policy_class(  # type: ignore[assignment]
    134     self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
    135 )
    136 self.policy = self.policy.to(self.device)

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\common\policies.py:505, in ActorCriticPolicy.__init__(self, observation_space, action_space, lr_schedule, net_arch, activation_fn, ortho_init, use_sde, log_std_init, full_std, use_expln, squash_output, features_extractor_class, features_extractor_kwargs, share_features_extractor, normalize_images, optimizer_class, optimizer_kwargs)
    502 self.ortho_init = ortho_init
    504 self.share_features_extractor = share_features_extractor
--> 505 self.features_extractor = self.make_features_extractor()
    506 self.features_dim = self.features_extractor.features_dim
    507 if self.share_features_extractor:

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\common\policies.py:120, in BaseModel.make_features_extractor(self)
    118 def make_features_extractor(self) -> BaseFeaturesExtractor:
    119     """Helper method to create a features extractor."""
--> 120     return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\common\torch_layers.py:41, in FlattenExtractor.__init__(self, observation_space)
     40 def __init__(self, observation_space: gym.Space) -> None:
---> 41     super().__init__(observation_space, get_flattened_obs_dim(observation_space))
     42     self.flatten = nn.Flatten()

File c:\Users\Elisa\anaconda3\envs\RL_env\lib\site-packages\stable_baselines3\common\torch_layers.py:23, in BaseFeaturesExtractor.__init__(self, observation_space, features_dim)
     21 def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
     22     super().__init__()
---> 23     assert features_dim > 0
     24     self._observation_space = observation_space
     25     self._features_dim = features_dim

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

System Info

  • OS: Windows-10-10.0.22631-SP0 10.0.22631
  • Python: 3.10.0
  • Stable-Baselines3: 2.2.1
  • PyTorch: 2.1.2+cpu
  • GPU Enabled: False
  • Numpy: 1.26.2
  • Cloudpickle: 3.0.0
  • Gymnasium: 0.29.1
  • OpenAI Gym: 0.26.2

elisavio added the custom gym env label on Feb 13, 2024
@qgallouedec
Collaborator

The simplest workaround is to flatten the observation space:

from gymnasium.wrappers import FlattenObservation

env = FlattenObservation(CustomEnv())

@elisavio
Author

Thank you very much for your answer.
If I try your suggestion with the above example and then sample a random observation, I get something quite different from what I want:

env1 = CustomEnv()
env1.observation_space.shape
env1.observation_space.sample()
env2 = FlattenObservation(CustomEnv())
env2.observation_space.shape
env2.observation_space.sample()

The two shapes and the sampled observations are different: env1 has shape (2, 2), while env2 has shape (10,).

A question naturally arises: does the performance of an algorithm depend on how the observation is represented (in this case, flattened vs. not flattened)?

@qgallouedec
Collaborator

Indeed, it's different from what I expected too. It seems that flatten in the multi-discrete case works in a very counter-intuitive way (at least for me).
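
For context, Gymnasium's flatten utilities one-hot encode MultiDiscrete spaces, which is presumably where the (10,) shape comes from (1 + 2 + 3 + 4 categories in nvec). A minimal check of that behaviour, assuming the space from the example above:

from gymnasium.spaces import MultiDiscrete
from gymnasium.spaces.utils import flatdim, flatten

space = MultiDiscrete([[1, 2], [3, 4]])
# flatdim sums the number of categories of each sub-space: 1 + 2 + 3 + 4 = 10
print(flatdim(space))                        # 10
# flatten one-hot encodes every entry and concatenates the results
print(flatten(space, space.sample()).shape)  # (10,)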

As far as I can see, there's no built-in wrapper for this, so you'll have to create your own:

from gymnasium import ObservationWrapper
from gymnasium.spaces import MultiDiscrete

class FlattenMultiDiscrete(ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        # Flatten the 2D nvec so the observation space becomes a 1D MultiDiscrete
        self.observation_space = MultiDiscrete(env.observation_space.nvec.flatten())

    def observation(self, observation):
        # Flatten each observation to match the new space
        return observation.flatten()

env = FlattenMultiDiscrete(CustomEnv())
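
Continuing from the snippet above (and the PPO import from the original report), a quick sanity check sketch, not verified on this exact setup:

print(env.observation_space)    # MultiDiscrete([1 2 3 4])
print(env.reset()[0].shape)     # (4,)

# PPO can now build its policy, since the flattened space is one-dimensional
model = PPO("MlpPolicy", env, verbose=1)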

araffin added the documentation and help wanted labels on Feb 14, 2024
@araffin
Member

araffin commented Feb 14, 2024

Note: the env checker should be updated to warn users that we don't support multi-dimensional MultiDiscrete observation spaces and to propose a fix (the one from @qgallouedec).
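
For illustration only, a minimal sketch of what such a check could look like (a hypothetical helper, not the actual SB3 env checker code):

import warnings

from gymnasium import spaces

def warn_if_multidim_multidiscrete(observation_space: spaces.Space) -> None:
    # Hypothetical check: only 1D MultiDiscrete observation spaces are supported
    if isinstance(observation_space, spaces.MultiDiscrete) and observation_space.nvec.ndim > 1:
        warnings.warn(
            "Multi-dimensional MultiDiscrete observation spaces are not supported. "
            "Consider flattening the space, e.g. with a wrapper that uses "
            "MultiDiscrete(observation_space.nvec.flatten())."
        )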

@elisavio
Author

Thank you very much for the answer.
Tell me if I should close the issue, or whether I should leave it open until the bug is fixed.

@qgallouedec
Collaborator

Please leave it open until the env checker is updated :)
