replay.py
import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch

from mushroom_rl.core import Agent, Core, Logger
from mushroom_rl.environments import Gym
from mushroom_rl.utils.dataset import compute_J, parse_dataset


class PauseCallback:
    """Step callback that sleeps for dt seconds after every environment
    step, slowing a rendered rollout down to roughly real time."""

    def __init__(self, dt):
        self._dt = dt

    def __call__(self, *args, **kwargs):
        time.sleep(self._dt)
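

# PauseCallback is passed to Core as callback_step in replay() below, so the
# sleep runs once per environment step. A standalone sketch (the dt value here
# is illustrative):
#
#   pause = PauseCallback(dt=1. / 60.)
#   pause()  # sleeps ~16.7 ms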


def replay(path, env_id, n_episodes, seed, save, dt):
    logger = Logger(log_name='MetricRL', results_dir='logs' if save else None)
    logger.info(f'Replaying MetricRL agent in {path}')

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.set_num_threads(1)

    mdp = Gym(env_id)
    agent = Agent.load(path)

    # PyBullet environments must have render() called once before reset;
    # for other environments, rendering is left to Core during evaluation.
    if 'BulletEnv-v0' in env_id:
        mdp.render()
        render = False
    else:
        render = True

    # Set environment seed
    mdp.env.seed(seed)
    mdp.env.reset()

    # Position the PyBullet debug camera to follow the agent
    distance = 4
    pitch = -5
    if 'BulletEnv-v0' in env_id:
        mdp.env.env._p.resetDebugVisualizerCamera(cameraTargetPosition=[4.5, 0, 1.],
                                                  cameraDistance=distance,
                                                  cameraYaw=0.,
                                                  cameraPitch=pitch)

    # Set experiment
    core = Core(agent, mdp, callback_step=PauseCallback(dt))
    dataset = core.evaluate(n_episodes=n_episodes, render=render, quiet=False)

    # J: discounted return; R: undiscounted return (compute_J defaults to gamma=1)
    J = np.mean(compute_J(dataset, mdp.info.gamma))
    R = np.mean(compute_J(dataset))
    logger.epoch_info(0, J=J, R=R)

    # Inspect the policy's cluster memberships over the visited states
    s, *_ = parse_dataset(dataset)
    w = torch.mean(agent.policy._regressor.get_membership(torch.tensor(s)), axis=0)
    _, top_w = torch.topk(w, 5)
    c = agent.policy._regressor.get_c_weights()
    _, top_c = torch.topk(c, 5)

    logger.info(f'w: {w.detach().numpy()}')
    logger.info(f'top w: {top_w.detach().numpy()}')
    logger.info(f'c: {c.detach().numpy()}')
    logger.info(f'top c: {top_c.detach().numpy()}')

    if env_id == 'Pendulum-v0':
        w = agent.policy._regressor.get_membership(torch.tensor(s)).detach().numpy()
        # The default cluster receives whatever membership mass is left over
        w_default = np.expand_dims(1 - np.sum(w, axis=1), -1)
        w_tot = np.concatenate([w, w_default], axis=1)
        # Log the dominant cluster at each step; assumes a fixed episode
        # length of 100 steps
        for run in range(n_episodes):
            logger.info(f'w_tot: {np.argmax(w_tot[100*run:100*(run+1), :], axis=1)}')

    logger.strong_line()

    if save:
        logger.log_dataset(dataset)
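

# A minimal programmatic call, mirroring the CLI entry point at the bottom of
# this file; the agent path is illustrative and must point at a saved agent:
#
#   replay(Path('logs/agent-0.msh'), 'Pendulum-v0', n_episodes=1, seed=0,
#          save=False, dt=0.05)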


def load_policy(log_name, iteration, seed):
    """Load a raw torch policy network from a training run.

    Not used by the replay entry point below; kept as a helper."""
    policy_path = os.path.join(log_name,
                               'net/network-' + str(seed) + '-' + str(iteration) + '.pth')
    policy_torch = torch.load(policy_path)
    return policy_torch


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-p', type=str,
                        default='Results/final_medium/HopperBulletEnv-v0/metricrl_c10hcovr_expdTruet1.0snone')
    parser.add_argument('--env-id', '-e', type=str,
                        default='HopperBulletEnv-v0')
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--n-episodes', '-n', type=int, default=1)
    parser.add_argument('--save', action='store_true')
    parser.add_argument('--dt', type=float, default=4. / 240.,
                        help='pause between rendered steps, in seconds')
    args = parser.parse_args()

    path = Path(args.path) / f'agent-{args.seed}.msh'
    replay(path, args.env_id, n_episodes=args.n_episodes, seed=args.seed,
           save=args.save, dt=args.dt)
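

# Example invocations (the run directory under --path defaults to the one
# above; <path-to-run-dir> is a placeholder for your own results folder):
#
#   python replay.py -e HopperBulletEnv-v0 -s 0 -n 3 --save
#   python replay.py -p <path-to-run-dir> -e Pendulum-v0 --dt 0.05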