
Commit 5a26a1e: init (0 parents)

File tree: 225 files changed, +30949 -0 lines


.gitignore

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
.vscode/*
.idea/*
.DS_Store
__pycache__/*
exps_stl
scripts

README.md

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
# Signal Temporal Logic Neural Predictive Control

[![Journal](https://img.shields.io/badge/RA--L2023-Accepted-success)](https://ieeexplore.ieee.org/iel7/7083369/7339444/10251585.pdf)
[![Conference](https://img.shields.io/badge/ICRA2024-Present-success)](https://2024.ieee-icra.org/index.html)

<!-- [![Arxiv](http://img.shields.io/badge/arxiv-cs:2309.05131-B31B1B.svg)](https://arxiv.org/abs/2309.05131.pdf) -->

[<ins>Reliable Autonomous Systems Lab @ MIT (REALM)</ins>](https://aeroastro.mit.edu/realm/)

[<ins>Yue Meng</ins>](https://mengyuest.github.io/), [<ins>Chuchu Fan</ins>](https://chuchu.mit.edu/)

![Alt Text](ral2023_teaser_v1.png)

> A differentiable learning framework for defining task requirements and learning control policies for robots.

This repository contains the original code and tutorial for our ICRA 2024 paper, "Signal Temporal Logic Neural Predictive Control." [[link]](https://arxiv.org/abs/2309.05131.pdf)

```
@article{meng2023signal,
  title={Signal Temporal Logic Neural Predictive Control},
  author={Meng, Yue and Fan, Chuchu},
  journal={IEEE Robotics and Automation Letters},
  year={2023},
  publisher={IEEE}
}
```

![Alt Text](ral2023.gif)

## Prerequisites
Ubuntu 20.04 (a GPU such as an NVIDIA RTX 2080 Ti is recommended).

Packages (steps 1 and 2 suffice if you only want to use our STL library; see the [tutorial](tutorial.ipynb)):
1. NumPy and Matplotlib: `conda install numpy matplotlib`
2. PyTorch v1.13.1 [[link]](https://pytorch.org/get-started/previous-versions/): `conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia` (other versions might also work)
3. CasADi, Gurobi, and RL libraries: `pip install casadi gurobipy stable-baselines3 && cd mbrl_bsl && pip install -e . && cd -`
4. (Only for the manipulation task) `pip install pytorch_kinematics mujoco forwardkinematics pybullet && sudo apt-get install libffi7`

## Tutorial
Basic usage is covered in our tutorial Jupyter notebook [here](tutorial.ipynb).

## Experimental results
Please refer to [`exp_scripts`](exp_scripts.sh) to reproduce the full experiments.
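
A quick way to confirm that the setup from the README's installation steps works before running the tutorial is a simple import probe. This is an editorial sketch, not a file from this commit; the package names simply mirror the steps listed above, and the optional ones are only needed for the corresponding experiments.

```python
# Sanity check for the README install steps (illustrative, not part of the repo).
import numpy as np
import matplotlib
import torch

print("numpy", np.__version__)
print("matplotlib", matplotlib.__version__)
print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())

# Optional dependencies (steps 3 and 4); missing ones only affect those experiments,
# not the STL library tutorial.
for name in ["casadi", "gurobipy", "stable_baselines3",
             "pytorch_kinematics", "mujoco", "pybullet"]:
    try:
        __import__(name)
        print(name, "OK")
    except ImportError:
        print(name, "not installed")
```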

envs/base_env.py

Lines changed: 179 additions & 0 deletions

@@ -0,0 +1,179 @@
from abc import (
    ABC,
    abstractmethod,
    abstractproperty,
)
from gym import Env
import numpy as np
import torch
from utils import to_np, to_torch
import csv


class BaseEnv(Env):
    def __init__(self, args):
        super(BaseEnv, self).__init__()
        self.args = args
        self.pid = None
        self.sample_idx = 0
        # TODO obs space and action space
        self.reward_list = []
        self.stl_reward_list = []
        self.acc_reward_list = []
        self.history = []
        if hasattr(args, "write_csv") and args.write_csv:
            self.epi = 0
            self.csvfile = open('%s/monitor_full.csv' % (args.exp_dir_full), 'w', newline='')
            self.csvwriter = csv.writer(self.csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        self.reward_fn = self.generate_reward_batch_fn()
        self.reward_fn_torch = self.wrap_reward_fn_torch(self.reward_fn)

    @abstractmethod
    def next_state(self, x, u):
        pass

    # @abstractmethod
    def dynamics(self, x0, u, include_first=False):
        # Roll the batch of states x0 forward for u.shape[1] steps and stack
        # the visited states along the time axis.
        args = self.args
        t = u.shape[1]
        x = x0.clone()
        segs = []
        if include_first:
            segs.append(x)
        for ti in range(t):
            new_x = self.next_state(x, u[:, ti])
            segs.append(new_x)
            x = new_x
        return torch.stack(segs, dim=1)

    @abstractmethod
    def init_x_cycle(self):
        pass

    @abstractmethod
    def init_x(self):
        pass

    @abstractmethod
    def generate_stl(self):
        pass

    @abstractmethod
    def generate_heur_loss(self):
        pass

    @abstractmethod
    def visualize(self):
        pass

    def transform(self, seg):
        # Used when the state trajectory needs to be augmented first,
        # e.g., for the Panda manipulation environment.
        return seg

    # @abstractmethod
    def step(self):
        pass

    def write_to_csv(self, env_steps):
        r_rs = self.get_rewards()
        r_rs = np.array(r_rs, dtype=np.float32)
        r_avg = np.mean(r_rs[0])
        rs_avg = np.mean(r_rs[1])
        racc_avg = np.mean(r_rs[2])
        self.csvwriter.writerow([self.epi, env_steps, r_avg, rs_avg, racc_avg])
        self.csvfile.flush()
        print("epi:%06d step:%06d r:%.3f %.3f %.3f" % (self.epi, env_steps, r_avg, rs_avg, racc_avg))
        self.epi += 1

    # @abstractmethod
    def reset(self):
        N = self.args.num_samples
        if self.sample_idx % N == 0:
            self.x0 = self.init_x(N)
            self.indices = torch.randperm(N)
        self.state = to_np(self.x0[self.indices[self.sample_idx % N]])
        self.sample_idx += 1
        self.t = 0
        if len(self.history) > self.args.nt:
            segs_np = np.stack(self.history, axis=0)
            segs = to_torch(segs_np[None, :])
            seg_aug = self.transform(segs)
            seg_aug_np = to_np(seg_aug)
            self.reward_list.append(np.sum(self.generate_reward_batch(seg_aug_np.squeeze())))
            self.stl_reward_list.append(self.stl_reward(seg_aug)[0, 0])
            self.acc_reward_list.append(self.acc_reward(seg_aug)[0, 0])
        self.history = [np.array(self.state)]
        return self.state

    def get_rewards(self):
        if len(self.reward_list) == 0:
            return 0, 0, 0
        else:
            return self.reward_list[-1], self.stl_reward_list[-1], self.acc_reward_list[-1]

    def generate_reward_batch(self, state):  # (n, 7)
        return self.reward_fn(None, state)

    def wrap_reward_fn_torch(self, reward_fn):
        # Wrap a numpy reward function so it accepts and returns torch tensors.
        def reward_fn_torch(act, state):
            act_np = act.detach().cpu().numpy()
            state_np = state.detach().cpu().numpy()
            reward_np = reward_fn(act_np, state_np)
            return torch.from_numpy(reward_np).float()[:, None].to(state.device)
        return reward_fn_torch

    @abstractmethod
    def generate_reward_batch_fn(self):
        pass

    # @abstractmethod
    def generate_reward(self, state):
        if self.args.stl_reward or self.args.acc_reward:
            last_one = (self.t + 1) >= self.args.nt
            if last_one:
                segs = to_torch(np.stack(self.history, axis=0)[None, :])
                segs_aug = self.transform(segs)
                if self.args.stl_reward:
                    return self.stl_reward(segs_aug)[0, 0]
                elif self.args.acc_reward:
                    return self.acc_reward(segs_aug)[0, 0]
                else:
                    raise NotImplementedError
            else:
                return np.zeros_like(0)
        else:
            return self.generate_reward_batch(state[None, :])[0]

    def stl_reward(self, segs):
        score = self.stl(segs, self.args.smoothing_factor)[:, :1]
        reward = to_np(score)
        return reward

    def acc_reward(self, segs):
        score = (self.stl(segs, self.args.smoothing_factor, d={"hard": True})[:, :1] >= 0).float()
        reward = 100 * to_np(score)
        return reward

    def print_stl(self):
        print(self.stl)
        self.stl.update_format("word")
        print(self.stl)

    def my_render(self):
        if self.pid == 0:
            self.render(None)

    def test(self):
        # `solve`, `test_reset`, `test_state`, `test_step`, `num_trials`, and `nt`
        # are expected to be provided by the concrete environment / caller.
        for trial_i in range(self.num_trials):
            obs = self.test_reset()
            trajs = [self.test_state()]
            for ti in range(self.nt):
                u = solve(obs)
                obs, reward, done, di = self.test_step(u)
                trajs.append(self.test_state())

        # save metrics result
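
`BaseEnv.dynamics` above rolls a batch of states forward by repeatedly applying `next_state` and stacking the visited states along the time axis. Below is a minimal, self-contained sketch of that rollout pattern using a toy single-integrator; the toy dynamics and shapes are assumptions for illustration, not an environment from this commit.

```python
# Illustrative only: the batched rollout pattern of BaseEnv.dynamics with a
# hypothetical single-integrator (states (B, D), controls (B, T, D)).
import torch

def next_state(x, u, dt=0.1):
    # one Euler step of a single-integrator; x: (B, D), u: (B, D)
    return x + dt * u

def dynamics(x0, u, include_first=False):
    # roll out T steps and return (B, T(+1), D), mirroring BaseEnv.dynamics
    T = u.shape[1]
    x = x0.clone()
    segs = [x] if include_first else []
    for ti in range(T):
        x = next_state(x, u[:, ti])
        segs.append(x)
    return torch.stack(segs, dim=1)

x0 = torch.zeros(4, 2)       # batch of 4 states in 2-D
u = torch.randn(4, 10, 2)    # 10-step control sequences
traj = dynamics(x0, u, include_first=True)
print(traj.shape)            # torch.Size([4, 11, 2])
```

Because the rollout is written in plain differentiable torch operations, a robustness score computed on `traj` can be backpropagated to the control sequence `u`, which is the premise of the differentiable framework described in the README.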

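`wrap_reward_fn_torch` bridges numpy reward functions and torch-side training code. Here is a small usage sketch of the same pattern with a made-up reward function; the reward itself is hypothetical and only shows the expected shapes.

```python
# Illustrative only: the numpy-to-torch reward wrapping pattern used by
# BaseEnv.wrap_reward_fn_torch, with a toy reward function.
import numpy as np
import torch

def reward_fn(act_np, state_np):
    # toy reward: negative distance of each state to the origin, shape (B,)
    return -np.linalg.norm(state_np, axis=-1)

def wrap_reward_fn_torch(reward_fn):
    def reward_fn_torch(act, state):
        act_np = act.detach().cpu().numpy()
        state_np = state.detach().cpu().numpy()
        reward_np = reward_fn(act_np, state_np)
        # (B,) numpy -> (B, 1) torch tensor on the state's device, as in the base class
        return torch.from_numpy(reward_np).float()[:, None].to(state.device)
    return reward_fn_torch

r = wrap_reward_fn_torch(reward_fn)(torch.zeros(4, 2), torch.randn(4, 2))
print(r.shape)  # torch.Size([4, 1])
```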
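The STL formula object `self.stl`, its `smoothing_factor`, and the `hard` option used by `stl_reward` / `acc_reward` come from the repository's STL library, which is not part of this diff. Purely to illustrate the idea (a smoothed robustness score for gradients, and a hard robustness thresholded at zero for success, scaled by 100), here is a generic sketch; `eventually_in_goal` and all of its parameters are hypothetical and do not reflect the library's actual API.

```python
# Illustrative only: a generic smoothed "eventually reach the goal" robustness,
# not the API of this repository's STL library.
import torch

def eventually_in_goal(traj, goal, radius, tau=0.1, hard=False):
    # traj: (B, T, D); margin > 0 once the state is inside the goal ball
    margin = radius - torch.norm(traj - goal, dim=-1)                 # (B, T)
    if hard:
        return margin.max(dim=1, keepdim=True).values                 # exact max over time
    return tau * torch.logsumexp(margin / tau, dim=1, keepdim=True)   # smooth max over time

traj = torch.randn(4, 11, 2, requires_grad=True)
goal = torch.tensor([1.0, 1.0])
score = eventually_in_goal(traj, goal, radius=0.5)                    # differentiable, (4, 1)
acc = 100.0 * (eventually_in_goal(traj, goal, 0.5, hard=True) >= 0).float()
score.sum().backward()                                                # gradients flow to traj
```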