from abc import (
    ABC,
    abstractmethod,
    abstractproperty,
)
from gym import Env
import numpy as np
import torch
from utils import to_np, to_torch
import csv

class BaseEnv(Env):
    def __init__(self, args):
        super(BaseEnv, self).__init__()
        self.args = args
        self.pid = None
        self.sample_idx = 0
        # TODO obs space and action space
        self.reward_list = []
        self.stl_reward_list = []
        self.acc_reward_list = []
        self.history = []
        if hasattr(args, "write_csv") and args.write_csv:
            self.epi = 0
            self.csvfile = open('%s/monitor_full.csv' % (args.exp_dir_full), 'w', newline='')
            self.csvwriter = csv.writer(self.csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        self.reward_fn = self.generate_reward_batch_fn()
        self.reward_fn_torch = self.wrap_reward_fn_torch(self.reward_fn)

    @abstractmethod
    def next_state(self, x, u):
        pass

    # @abstractmethod
    def dynamics(self, x0, u, include_first=False):
        args = self.args
        t = u.shape[1]
        x = x0.clone()
        segs = []
        if include_first:
            segs.append(x)
        for ti in range(t):
            new_x = self.next_state(x, u[:, ti])
            segs.append(new_x)
            x = new_x
        return torch.stack(segs, dim=1)
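
    # Illustrative usage (added commentary, not part of the original file): `dynamics`
    # rolls a batch of control sequences through the subclass-defined `next_state`.
    # Shapes implied by the code above, assuming x0 is (B, state_dim) and u is
    # (B, T, action_dim):
    #   x0 = env.init_x(batch_size)                      # (B, S)
    #   segs = env.dynamics(x0, u, include_first=True)   # (B, T + 1, S)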

    @abstractmethod
    def init_x_cycle(self):
        pass

    @abstractmethod
    def init_x(self):
        pass

    @abstractmethod
    def generate_stl(self):
        pass

    @abstractmethod
    def generate_heur_loss(self):
        pass

    @abstractmethod
    def visualize(self):
        pass

    def transform(self, seg):
        # Used when the state trajectory needs to be augmented before reward/STL
        # evaluation (for example, in the Panda environment).
        return seg

    # @abstractmethod
    def step(self):
        pass

    def write_to_csv(self, env_steps):
        r_rs = self.get_rewards()
        r_rs = np.array(r_rs, dtype=np.float32)
        r_avg = np.mean(r_rs[0])
        rs_avg = np.mean(r_rs[1])
        racc_avg = np.mean(r_rs[2])
        self.csvwriter.writerow([self.epi, env_steps, r_avg, rs_avg, racc_avg])
        self.csvfile.flush()
        print("epi:%06d step:%06d r:%.3f %.3f %.3f" % (self.epi, env_steps, r_avg, rs_avg, racc_avg))
        self.epi += 1

    # @abstractmethod
    def reset(self):
        N = self.args.num_samples
        if self.sample_idx % N == 0:
            self.x0 = self.init_x(N)
            self.indices = torch.randperm(N)
        self.state = to_np(self.x0[self.indices[self.sample_idx % N]])
        self.sample_idx += 1
        self.t = 0
        if len(self.history) > self.args.nt:
            segs_np = np.stack(self.history, axis=0)
            segs = to_torch(segs_np[None, :])
            seg_aug = self.transform(segs)
            seg_aug_np = to_np(seg_aug)
            self.reward_list.append(np.sum(self.generate_reward_batch(seg_aug_np.squeeze())))
            self.stl_reward_list.append(self.stl_reward(seg_aug)[0, 0])
            self.acc_reward_list.append(self.acc_reward(seg_aug)[0, 0])
        self.history = [np.array(self.state)]
        return self.state
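
    # Note (added commentary, not in the original file): `reset` draws initial states from a
    # pool of `num_samples` states produced by `init_x`, consuming one shuffled sample per
    # reset and regenerating the pool once it is exhausted. If the previous episode ran for
    # more than `args.nt` steps, its trajectory in `self.history` is scored once: a summed
    # dense reward, the smoothed STL robustness (`stl_reward`), and the hard STL
    # satisfaction indicator scaled to 0/100 (`acc_reward`).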

    def get_rewards(self):
        if len(self.reward_list) == 0:
            return 0, 0, 0
        else:
            return self.reward_list[-1], self.stl_reward_list[-1], self.acc_reward_list[-1]

    def generate_reward_batch(self, state):  # (n, 7)
        return self.reward_fn(None, state)

    def wrap_reward_fn_torch(self, reward_fn):
        def reward_fn_torch(act, state):
            act_np = act.detach().cpu().numpy()
            state_np = state.detach().cpu().numpy()
            reward_np = reward_fn(act_np, state_np)
            return torch.from_numpy(reward_np).float()[:, None].to(state.device)
        return reward_fn_torch
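
    # Illustrative usage (assumed, not from the original file): the wrapper lifts the
    # numpy reward function returned by `generate_reward_batch_fn` to torch tensors,
    # keeping the result on the same device as the input state, e.g.
    #   reward_fn_torch = env.wrap_reward_fn_torch(env.generate_reward_batch_fn())
    #   r = reward_fn_torch(act, state)   # (N, 1) float tensor on state.device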

    @abstractmethod
    def generate_reward_batch_fn(self):
        pass

    # @abstractmethod
    def generate_reward(self, state):
        if self.args.stl_reward or self.args.acc_reward:
            last_one = (self.t + 1) >= self.args.nt
            if last_one:
                segs = to_torch(np.stack(self.history, axis=0)[None, :])
                segs_aug = self.transform(segs)
                if self.args.stl_reward:
                    return self.stl_reward(segs_aug)[0, 0]
                elif self.args.acc_reward:
                    return self.acc_reward(segs_aug)[0, 0]
                else:
                    raise NotImplementedError
            else:
                return np.zeros_like(0)
        else:
            return self.generate_reward_batch(state[None, :])[0]

    def stl_reward(self, segs):
        score = self.stl(segs, self.args.smoothing_factor)[:, :1]
        reward = to_np(score)
        return reward

    def acc_reward(self, segs):
        score = (self.stl(segs, self.args.smoothing_factor, d={"hard": True})[:, :1] >= 0).float()
        reward = 100 * to_np(score)
        return reward

    def print_stl(self):
        print(self.stl)
        self.stl.update_format("word")
        print(self.stl)

    def my_render(self):
        if self.pid == 0:
            self.render(None)

    def test(self):
        for trial_i in range(self.num_trials):
            obs = self.test_reset()
            trajs = [self.test_state()]
            for ti in range(self.nt):
                u = solve(obs)  # `solve` is not defined in this file; presumably provided elsewhere
                obs, reward, done, di = self.test_step(u)
                trajs.append(self.test_state())

        # save metrics result
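
    # Minimal sketch of a concrete subclass (added for illustration; the names and the
    # toy dynamics below are assumptions, not part of the original repository). A real
    # environment would implement all abstract methods above:
    #
    #   class SingleIntegratorEnv(BaseEnv):
    #       def next_state(self, x, u):
    #           return x + 0.1 * u                       # x_{t+1} = x_t + dt * u_t
    #       def init_x(self, n):
    #           return torch.zeros(n, 1)                 # batch of initial states
    #       def init_x_cycle(self):
    #           return self.init_x(self.args.num_samples)
    #       def generate_stl(self):
    #           ...                                      # build the STL formula used by self.stl
    #       def generate_reward_batch_fn(self):
    #           return lambda act, state: -np.abs(state[..., 0])
    #       def generate_heur_loss(self):
    #           ...
    #       def visualize(self):
    #           ...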