forked from ARISE-Initiative/robosuite
-
Notifications
You must be signed in to change notification settings - Fork 2
/
gym_wrapper.py
62 lines (51 loc) · 1.98 KB
/
gym_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
This file implements a wrapper for facilitating compatibility with OpenAI gym.
This is useful when using these environments with code that assumes a gym-like
interface.
"""
import numpy as np
from gym import spaces
from robosuite.wrappers import Wrapper
class GymWrapper(Wrapper):
    """
    Wraps an environment so that it exposes a gym-like interface.

    Observations returned by the wrapped environment are dictionaries; this
    wrapper selects the entries named in ``keys`` and concatenates them into a
    single flat array. It also builds ``observation_space`` and
    ``action_space`` attributes as ``gym.spaces.Box`` instances.
    """

    env = None

    def __init__(self, env, keys=None):
        """
        Initializes the Gym wrapper.

        Note that construction calls ``env.reset()`` once in order to measure
        the size of the flattened observation, and prints each observation key
        that is included.

        Args:
            env (MujocoEnv instance): The environment to wrap.
            keys (list of strings): If provided, each observation will
                consist of concatenated keys from the wrapped environment's
                observation dictionary. Defaults to robot-state and object-state.

        Raises:
            AssertionError: If ``keys`` is not given and the environment does
                not have object observations enabled.
            ValueError: If none of ``keys`` is present in the observation
                returned by ``env.reset()``.
        """
        self.env = env

        if keys is None:
            # Raised explicitly (rather than via an ``assert`` statement) so the
            # check is not stripped when Python runs with -O; the exception
            # type and message are the same as before.
            if not self.env.use_object_obs:
                raise AssertionError("Object observations need to be enabled.")
            keys = ["robot-state", "object-state"]
        self.keys = keys

        # set up observation and action spaces
        # The observation size is only known after flattening a real
        # observation, hence the reset here.
        flat_ob = self._flatten_obs(self.env.reset(), verbose=True)
        self.obs_dim = flat_ob.size

        # Observations are treated as unbounded in every dimension.
        high = np.inf * np.ones(self.obs_dim)
        low = -high
        self.observation_space = spaces.Box(low=low, high=high)

        # Action bounds come from the wrapped environment as a (low, high) pair.
        low, high = self.env.action_spec
        self.action_space = spaces.Box(low=low, high=high)

    def _flatten_obs(self, obs_dict, verbose=False):
        """
        Filters keys of interest out and concatenate the information.

        Entries are concatenated in the iteration order of ``obs_dict`` (not
        the order of ``self.keys``). Keys listed in ``self.keys`` that are
        absent from ``obs_dict`` are skipped.

        Args:
            obs_dict: ordered dictionary of observations
            verbose (bool): If True, print each key as it is added.

        Returns:
            The selected observation entries joined by ``np.concatenate``.

        Raises:
            ValueError: If none of the keys in ``self.keys`` is present in
                ``obs_dict``.
        """
        ob_lst = []
        for key in obs_dict:
            if key in self.keys:
                if verbose:
                    print("adding key: {}".format(key))
                ob_lst.append(obs_dict[key])

        # np.concatenate would raise an unhelpful "need at least one array"
        # ValueError on an empty list; fail with an actionable message instead.
        if not ob_lst:
            raise ValueError(
                "None of the requested observation keys {} were found in the "
                "observation (available keys: {}).".format(
                    self.keys, list(obs_dict)
                )
            )
        return np.concatenate(ob_lst)

    def reset(self):
        """
        Resets the wrapped environment.

        Returns:
            The flattened observation built from the dictionary returned by
            the wrapped environment's ``reset``.
        """
        ob_dict = self.env.reset()
        return self._flatten_obs(ob_dict)

    def step(self, action):
        """
        Steps the wrapped environment with the given action.

        Args:
            action: Passed through unchanged to the wrapped environment's
                ``step``.

        Returns:
            A 4-tuple ``(observation, reward, done, info)`` where
            ``observation`` is the flattened observation and the remaining
            entries are returned as-is from the wrapped environment.
        """
        ob_dict, reward, done, info = self.env.step(action)
        return self._flatten_obs(ob_dict), reward, done, info