IsaacGym Preview4 with Stable Baselines3? #1768

Open
5 tasks done
LyuJZ opened this issue Nov 27, 2023 · 10 comments
Labels: custom gym env, documentation, help wanted

Comments


LyuJZ commented Nov 27, 2023

Hi, thanks a lot for the well-documented Stable Baselines3. I am now using Isaac Gym Preview 4. Could you give some examples of wrapping IsaacGymEnvs into a VecEnv? I noticed this was raised before, and some tips were given in issue #772. However, those seem to be for Isaac Gym Preview 3. Could you give one example for Isaac Gym Preview 4?

LyuJZ added the "custom gym env" label on Nov 27, 2023
Member

araffin commented Nov 27, 2023

Hello,
What have you tried so far (see also the links in our doc; please provide minimal, working code), and what is the current issue?

araffin added the "check the checklist" label on Nov 27, 2023
Author

LyuJZ commented Nov 27, 2023

I tried to create a wrapper based on this template. Here is my code:

import isaacgym
from isaacgymenvs.tasks.base.vec_task import VecTask
from isaacgymenvs.tasks.kuka_reaching import KukaReaching

import hydra
from omegaconf import DictConfig, OmegaConf

import gymnasium as gym
import numpy as np
import torch
from typing import Any, Dict, List

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn


class Sb3VecEnvWrapper(gym.Wrapper, VecEnv):
    
    def __init__(self, env: VecTask):
        gym.Wrapper.__init__(self, env)
        VecEnv.__init__(self, self.env.num_envs, self.env.observation_space, self.env.action_space)
        self._ep_rew_buf = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.env.device)
        self._ep_len_buf = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.env.device)
        
    def get_episode_rewards(self) -> List[float]:
        """Returns the rewards of all the episodes."""
        return self._ep_rew_buf.cpu().tolist()

    def get_episode_lengths(self) -> List[int]:
        """Returns the number of time-steps of all the episodes."""
        return self._ep_len_buf.cpu().tolist()
    
    def reset(self) -> VecEnvObs:  # noqa: D102
        obs_dict = self.env.reset()
        # convert data types to numpy depending on backend
        return self._process_obs(obs_dict)
    
    def step(self, actions: np.ndarray) -> VecEnvStepReturn:  # noqa: D102
        # convert input to numpy array
        actions = np.asarray(actions)
        # convert to tensor
        actions = torch.from_numpy(actions).to(device=self.env.device)
        # record step information
        obs_dict, rew, dones, extras = self.env.step(actions)

        # update episode un-discounted return and length
        self._ep_rew_buf += rew
        self._ep_len_buf += 1
        reset_ids = (dones > 0).nonzero(as_tuple=False)

        # convert data types to numpy depending on backend
        # Note: IsaacEnv uses torch backend (by default).
        obs = self._process_obs(obs_dict)
        rew = rew.cpu().numpy()
        dones = dones.cpu().numpy()
        # convert extra information to list of dicts
        infos = self._process_extras(obs, dones, extras, reset_ids)

        # reset info for terminated environments
        self._ep_rew_buf[reset_ids] = 0
        self._ep_len_buf[reset_ids] = 0

        return obs, rew, dones, infos

    """
    Unused methods.
    """

    def step_async(self, actions):  # noqa: D102
        self._async_actions = actions

    def step_wait(self):  # noqa: D102
        return self.step(self._async_actions)

    def get_attr(self, attr_name, indices):  # noqa: D102
        raise NotImplementedError

    def set_attr(self, attr_name, value, indices=None):  # noqa: D102
        raise NotImplementedError

    def env_method(self, method_name: str, *method_args, indices=None, **method_kwargs):  # noqa: D102
        raise NotImplementedError

    def env_is_wrapped(self, wrapper_class, indices=None):  # noqa: D102
        raise NotImplementedError

    def get_images(self):  # noqa: D102
        raise NotImplementedError
    
    """
    Helper functions.
    """

    def _process_obs(self, obs_dict) -> np.ndarray:
        """Convert observations into NumPy data type."""
        # Sb3 doesn't support asymmetric observation spaces, so we only use "policy"
        obs = obs_dict["policy"]
        # Note: IsaacEnv uses torch backend (by default).
        if self.env.sim.backend == "torch":
            if isinstance(obs, dict):
                for key, value in obs.items():
                    obs[key] = value.detach().cpu().numpy()
            else:
                obs = obs.detach().cpu().numpy()
        elif self.env.sim.backend == "numpy":
            pass
        else:
            raise NotImplementedError(f"Unsupported backend for simulation: {self.env.sim.backend}")
        return obs
    
    def _process_extras(self, obs, dones, extras, reset_ids) -> List[Dict[str, Any]]:
        """Convert miscellaneous information into dictionary for each sub-environment."""
        # create empty list of dictionaries to fill
        infos: List[Dict[str, Any]] = [dict.fromkeys(extras.keys()) for _ in range(self.env.num_envs)]
        # fill-in information for each sub-environment
        # Note: This loop becomes slow when number of environments is large.
        for idx in range(self.env.num_envs):
            # fill-in episode monitoring info
            if idx in reset_ids:
                infos[idx]["episode"] = dict()
                infos[idx]["episode"]["r"] = float(self._ep_rew_buf[idx])
                infos[idx]["episode"]["l"] = float(self._ep_len_buf[idx])
            else:
                infos[idx]["episode"] = None
            # fill-in information from extras
            for key, value in extras.items():
                # 1. remap the key for time-outs for what SB3 expects
                # 2. remap extra episodes information safely
                # 3. for others just store their values
                if key == "time_outs":
                    infos[idx]["TimeLimit.truncated"] = bool(value[idx])
                elif key == "episode":
                    # only log this data for episodes that are terminated
                    if infos[idx]["episode"] is not None:
                        for sub_key, sub_value in value.items():
                            infos[idx]["episode"][sub_key] = sub_value
                else:
                    infos[idx][key] = value[idx]
            # add information about terminal observation separately
            if dones[idx] == 1:
                # extract terminal observations
                if isinstance(obs, dict):
                    terminal_obs = dict.fromkeys(obs.keys())
                    for key, value in obs.items():
                        terminal_obs[key] = value[idx]
                else:
                    terminal_obs = obs[idx]
                # add info to dict
                infos[idx]["terminal_observation"] = terminal_obs
            else:
                infos[idx]["terminal_observation"] = None
        # return list of dictionaries
        return infos


@hydra.main(version_base="1.1", config_name="config", config_path="./cfg")
def launch_rlg_hydra(cfg: DictConfig):

    import logging
    import os
    from datetime import datetime

    # noinspection PyUnresolvedReferences
    import isaacgym
    from isaacgymenvs.pbt.pbt import PbtAlgoObserver, initial_pbt_check
    from isaacgymenvs.utils.rlgames_utils import multi_gpu_get_rank
    from hydra.utils import to_absolute_path
    from isaacgymenvs.tasks import isaacgym_task_map
    import gym
    from isaacgymenvs.utils.reformat import omegaconf_to_dict, print_dict
    from isaacgymenvs.utils.utils import set_np_formatting, set_seed

    if cfg.pbt.enabled:
        initial_pbt_check(cfg)

    from isaacgymenvs.utils.rlgames_utils import RLGPUEnv, RLGPUAlgoObserver, MultiObserver, ComplexObsRLGPUEnv
    from isaacgymenvs.utils.wandb_utils import WandbAlgoObserver
    from rl_games.common import env_configurations, vecenv
    from rl_games.torch_runner import Runner
    from rl_games.algos_torch import model_builder
    from isaacgymenvs.learning import amp_continuous
    from isaacgymenvs.learning import amp_players
    from isaacgymenvs.learning import amp_models
    from isaacgymenvs.learning import amp_network_builder
    import isaacgymenvs


    time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"{cfg.wandb_name}_{time_str}"

    # ensure checkpoints can be specified as relative paths
    if cfg.checkpoint:
        cfg.checkpoint = to_absolute_path(cfg.checkpoint)

    cfg_dict = omegaconf_to_dict(cfg)
    task_config = omegaconf_to_dict(cfg.task)
    task_name = "KukaReaching"
    
    cuda = True
    
    # create native task and pass custom config
    envs = isaacgymenvs.make(
        seed=1,
        task=task_name,
        num_envs=8192,
        sim_device="cuda:0",
        rl_device="cuda:0",
        graphics_device_id=0 if torch.cuda.is_available() and cuda else -1,
        headless=False if torch.cuda.is_available() and cuda else True,
        multi_gpu=False,
        virtual_screen_capture=False,
        force_render=False,  # if False, no viewer rendering will happen
    )
    
    sb3Env = Sb3VecEnvWrapper(envs)
        
if __name__ == "__main__":
    launch_rlg_hydra()

KukaReaching is the environment from isaacgymenvs that I want to wrap. Is this the correct way to create the wrapper? The call VecEnv.__init__(self, self.env.num_envs, self.env.observation_space, self.env.action_space) always fails with the error: get_attr() missing 1 required positional argument: 'indices'

Member

araffin commented Nov 27, 2023

Could you please give the full traceback?

Author

LyuJZ commented Nov 27, 2023

The error seems to come from self.get_attr("render_mode"). The full traceback is:

Traceback (most recent call last):
  File "/src/handover-kuka-reaching/wrappers.py", line 519, in launch_rlg_hydra
    sb3Env = Sb3VecEnvWrapper(envs)
  File "/src/handover-rl/kuka-reaching/wrappers.py", line 325, in __init__
    VecEnv.__init__(self, self.env.num_envs, self.env.observation_space, self.env.action_space)
  File "/anaconda3/envs/acmpc/lib/python3.8/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 75, in __init__
    render_modes = self.get_attr("render_mode")
TypeError: get_attr() missing 1 required positional argument: 'indices'
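
The traceback shows the base `__init__` calling `self.get_attr("render_mode")` with only the attribute name, while the wrapper's override declares `indices` as a required positional parameter. The sketch below reproduces that mismatch in isolation. `BaseVecEnv`, `StrictWrapper`, `DefaultedWrapper` and `try_build` are hypothetical stand-ins written for this illustration, not SB3 classes.

```python
class BaseVecEnv:
    """Hypothetical stand-in for a base class that queries an attribute in __init__."""

    def __init__(self, num_envs):
        self.num_envs = num_envs
        # Mirrors the call seen in the traceback: only the attribute name is passed.
        self.render_modes = self.get_attr("render_mode")

    def get_attr(self, attr_name, indices=None):
        raise NotImplementedError


class StrictWrapper(BaseVecEnv):
    # `indices` has no default here, so the one-argument call above fails.
    def get_attr(self, attr_name, indices):
        return [None] * self.num_envs


class DefaultedWrapper(BaseVecEnv):
    # Matching the base signature (indices=None) lets the base __init__ run.
    def get_attr(self, attr_name, indices=None):
        return [None] * self.num_envs


def try_build(cls):
    """Return the TypeError message raised while constructing cls, or None."""
    try:
        cls(num_envs=4)
    except TypeError as err:
        return str(err)
    return None


strict_error = try_build(StrictWrapper)
defaulted_error = try_build(DefaultedWrapper)
print(strict_error)
```

Keeping the override's signature compatible with the base method is enough to get past this particular error. It does not guarantee that the rest of the base constructor succeeds.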

Author

LyuJZ commented Nov 27, 2023

I found the problem. It was caused by the way I had redefined the abstract methods. I have now changed them as follows:

    def step_async(self, actions):  # noqa: D102
        self._async_actions = actions

    def step_wait(self):  # noqa: D102
        return self.step(self._async_actions)

    def get_attr(self, attr_name, indices=None):  # noqa: D102
        return self.env.get_attr(attr_name, indices)
        # raise NotImplementedError

    def set_attr(self, attr_name, value, indices=None):  # noqa: D102
        return self.env.set_attr(attr_name, value, indices)

    def env_method(self, method_name: str, *method_args, indices=None, **method_kwargs):  # noqa: D102
        return self.env.env_method(method_name, *method_args, indices=indices, **method_kwargs)

    def env_is_wrapped(self, wrapper_class, indices=None):  # noqa: D102
        return self.env.env_is_wrapped(wrapper_class, indices=indices)

    def get_images(self):  # noqa: D102
        return self.env.get_images()

Now I get a new error:

Traceback (most recent call last):
  File "/src/handover-rl/kuka-reaching/wrappers.py", line 215, in launch_rlg_hydra
    sb3Env = Sb3VecEnvWrapper(envs)
  File "/src/handover-rl/kuka-reaching/wrappers.py", line 20, in __init__
    VecEnv.__init__(self, self.env.num_envs, self.env.observation_space, self.env.action_space)
  File "/anaconda3/envs/acmpc/lib/python3.8/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 83, in __init__
    self.render_mode = render_modes[0]
AttributeError: can't set attribute

Member

araffin commented Nov 28, 2023

Hmm, it's weird that the wrapper derives from both the gym Wrapper and the SB3 VecEnv, and I guess that's the issue.

See https://github.com/Farama-Foundation/Gymnasium/blob/2f19f3ed69b9423444ec6f7cfa46c5fb079168c5/gymnasium/core.py#L440-L443

So if you only derive from VecEnv it should make things simpler.
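
The second traceback fails on a plain assignment, `self.render_mode = render_modes[0]`. Assuming the linked Gymnasium lines define `render_mode` on the wrapper class as a property without a setter, the conflict can be sketched with stand-in classes. `WrapperLike`, `VecEnvLike`, `DummyEnv`, `BothBases`, `VecEnvOnly` and `construction_error` are hypothetical names used only here.

```python
class WrapperLike:
    """Hypothetical stand-in: exposes render_mode as a read-only property."""

    def __init__(self, env):
        self.env = env

    @property
    def render_mode(self):
        return getattr(self.env, "render_mode", None)


class VecEnvLike:
    """Hypothetical stand-in: assigns render_mode as a plain attribute in __init__."""

    def __init__(self, num_envs):
        self.num_envs = num_envs
        self.render_mode = None


class DummyEnv:
    num_envs = 4
    render_mode = None


class BothBases(WrapperLike, VecEnvLike):
    def __init__(self, env):
        WrapperLike.__init__(self, env)
        # The inherited property has no setter, so the assignment inside
        # VecEnvLike.__init__ raises AttributeError.
        VecEnvLike.__init__(self, env.num_envs)


class VecEnvOnly(VecEnvLike):
    def __init__(self, env):
        # Keep the wrapped env as an ordinary attribute instead of inheriting from a wrapper base.
        self.env = env
        VecEnvLike.__init__(self, env.num_envs)


def construction_error(cls):
    """Return the AttributeError raised while constructing cls, or None."""
    try:
        cls(DummyEnv())
    except AttributeError as err:
        return err
    return None


both_error = construction_error(BothBases)
only_error = construction_error(VecEnvOnly)
```

With only one base class there is no property shadowing the attribute, so the assignment goes through.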

Author

LyuJZ commented Nov 28, 2023

Is it possible to give one example to wrap IsaacGym VecEnv into sb3 VecEnv?

Member

araffin commented Nov 28, 2023

Remove gym.Wrapper and gym.Wrapper.__init__(self, env)
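
Once gym.Wrapper is removed, the delegating methods posted earlier (`self.env.get_attr(...)` and so on) only work if the wrapped environment provides those methods itself. Assuming it does not (this has not been checked against IsaacGymEnvs), one possible shape for the attribute helpers is sketched below. `FakeBatchedEnv` and `BatchedEnvAdapter` are hypothetical names. A real wrapper would derive from VecEnv and also implement reset, step_async and step_wait.

```python
class FakeBatchedEnv:
    """Hypothetical stand-in for one object that simulates many environments."""

    def __init__(self, num_envs):
        self.num_envs = num_envs
        self.max_episode_length = 500


class BatchedEnvAdapter:
    """Sketch of the attribute helpers only, not a complete VecEnv."""

    def __init__(self, env):
        self.env = env
        self.num_envs = env.num_envs

    def _count(self, indices):
        # indices=None means all sub-environments; an int means a single one.
        if indices is None:
            return self.num_envs
        if isinstance(indices, int):
            return 1
        return len(indices)

    def get_attr(self, attr_name, indices=None):
        # One shared value, repeated once per requested sub-environment.
        # Raises AttributeError if the wrapped env lacks the attribute.
        value = getattr(self.env, attr_name)
        return [value] * self._count(indices)

    def set_attr(self, attr_name, value, indices=None):
        # The attribute lives on the single batched env, so indices are ignored.
        setattr(self.env, attr_name, value)

    def env_is_wrapped(self, wrapper_class, indices=None):
        # A batched simulator has no per-environment wrappers to report.
        return [False] * self._count(indices)


adapter = BatchedEnvAdapter(FakeBatchedEnv(num_envs=3))
lengths = adapter.get_attr("max_episode_length")
wrapped = adapter.env_is_wrapped(object, indices=[0, 2])
```

Letting `get_attr` raise AttributeError for a missing attribute mirrors plain `getattr`. Whether the base constructor tolerates that for `render_mode` depends on the installed SB3 version.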

Author

LyuJZ commented Nov 29, 2023

I revised the code a little bit and now it works! Thanks a lot

araffin added the "documentation" and "help wanted" labels and removed the "check the checklist" label on Nov 29, 2023

yinzikang commented Mar 3, 2024

Hello, I want to combine Isaac Gym Preview 4 with SB3 and drop rl_games, as I guess you did.

So I am trying to reproduce what you have done. Specifically, I:

  1. created a Python file next to train.py, which is used for rl_games-based training,
  2. copied the code you provided in this issue and edited it according to araffin's advice,
  3. chose the Cartpole task and ran the code.

However, some other modifications seem to be needed to make it work. Could you please share the whole code?

In addition, although I was able to get the code to run, it wasn't as fast as rl_games. I guess this is caused by the conversion between NumPy and torch that SB3 involves. Is there a way to speed it up?
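
Besides the tensor-to-array conversion, the wrapper posted earlier in this thread notes that its per-environment Python loop becomes slow for many environments, and its `idx in reset_ids` check rescans the reset ids on every iteration. One possible way to reduce that overhead is to iterate only over the environments that just finished. The sketch below uses plain lists in place of tensors. `build_infos` is a hypothetical helper, and how much of the gap to rl_games this would close is untested.

```python
def build_infos(num_envs, reset_ids, ep_returns, ep_lengths):
    """Build one info dict per environment, touching only the finished ones."""
    # Every environment gets its own (initially empty) dict.
    infos = [{} for _ in range(num_envs)]
    # Loop over finished environments only, instead of all num_envs.
    for idx in reset_ids:
        infos[idx]["episode"] = {
            "r": float(ep_returns[idx]),
            "l": int(ep_lengths[idx]),
        }
    return infos


infos = build_infos(
    num_envs=5,
    reset_ids=[1, 3],
    ep_returns=[0.0, 2.5, 0.0, -1.0, 0.0],
    ep_lengths=[7, 10, 7, 4, 7],
)
```

In a torch-based wrapper, the finished indices would be moved to the CPU once per step (for example with a single `.tolist()` call on the index tensor) before this loop runs.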

3 participants