Commit 967fd17

Merge pull request #5 from YukhoY/main
Upload DiGA codes
2 parents 5a9d42b + dad934f

23 files changed: +3988 −1 lines changed

DiGA/README.md

Lines changed: 88 additions & 0 deletions
## Code for Diffusion Guided Meta Agent (DiGA)

### Introduction

In this repository, we provide the core code of [DiGA](https://arxiv.org/pdf/2408.12991), including the training and generation scripts.

`train.py` is the Python script for training the meta controller. `generate.py` is the Python script for running the meta agent, guided by the trained meta controller, to obtain synthetic market trading records. The code for training an RL agent in the simulated environment is provided in the `rltask/train_test_rl.py` script.
### Prerequisites

We recommend using a conda environment. The required packages can be installed with:

```bash
conda env create --file environment.yaml
```

After installation, use `conda activate diga` to activate the environment.
After that, install the required packages following the instructions from [MarS](https://github.com/microsoft/MarS).

Data should be processed into the shape $(N, C, T)$, where $N$ is the number of samples, $C$ is the number of market state parameters, and $T$ is the effective number of minutes in each sample. In our case, $C=2$, where the first dimension is the mid-price log return rate and the second is the number of orders within each minute, and $T=236$, the effective number of minutes in one trading day.
> Original data can be purchased from licensed data vendors (e.g., Wind, Thomson Reuters).
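To make the expected format concrete, here is a minimal sketch (assuming NumPy, with synthetic placeholder values rather than real market data) of producing arrays of this shape with the file naming that `train.py` expects:

```python
import numpy as np

# Hypothetical sizes: 1000 samples, C=2 market state parameters, T=236 minutes.
N, C, T = 1000, 2, 236
rng = np.random.default_rng(0)

data = np.zeros((N, C, T), dtype=np.float32)
data[:, 0, :] = rng.normal(0.0, 1e-4, size=(N, T))  # channel 0: mid-price log return rate
data[:, 1, :] = rng.poisson(50, size=(N, T))        # channel 1: number of orders per minute

# train.py looks for {data_name}_train.npy and {data_name}_vali.npy in --data_path.
np.save("SZAMain_train.npy", data[:800])
np.save("SZAMain_vali.npy", data[800:])
```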
### Code Examples

To train a meta controller:

```bash
python train.py --data_name "SZAMain" --ctrl_type "continuous" --ctrl_target "return" --n_bins 5 --diffsteps 200 --epochs 10 --checkpoints 3 --data_path {your_data_path} --output_path {your_output_path} --seed 0
```

The command above trains a meta controller on the "SZAMain" dataset, controlling on return with the continuous control encoder. By default, the control target (return) is divided into 5 bins, indicating 5 control classes: lower, low, medium, high, higher. The diffusion model in the meta controller is configured to perform 200 diffusion steps and is trained for at most 10 epochs. Make sure the dataset is stored in `{your_data_path}` and named `{data_name}_train.npy` and `{data_name}_vali.npy`. The trained model is saved to `{your_output_path}/DiGA_{data_name}_{ctrl_type}_{ctrl_target}_{seed}/` by default.
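Putting the path conventions together, the layout for this example might look as follows (the checkpoint file name `last.ckpt` is taken from the generation example below; everything else follows the naming described above):

```
{your_data_path}/
├── SZAMain_train.npy
└── SZAMain_vali.npy

{your_output_path}/
└── DiGA_SZAMain_continuous_return_0/
    └── last.ckpt
```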
To generate with the meta agent guided by the meta controller:

```bash
python generate.py --data_name "SZAMain" --ctrl_type "continuous" --ctrl_target "return" --ctrl_class 0 --cond_scale 1 --samsteps 20 --data_path {your_data_path} --output_path {your_output_path} --exp_name {your_exp_name} --checkpoint_path "last.ckpt" --save_name "DiGA_generation" --seed 0
```

The command above first samples from the trained meta controller, conditioned on the `ctrl_class` of `ctrl_target`. If `ctrl_type` is "discrete", `ctrl_class` refers to the selected bin; if `ctrl_type` is "continuous", it refers to the relative strength of `ctrl_target`. `cond_scale` controls the strength of classifier-free guidance during diffusion model sampling. After sampling, the meta agent generates the trading records of one trading day, which are saved to `{output_path}/{exp_name}/{save_name}.pkl`; in the example above, `exp_name` would be `"DiGA_{data_name}_{ctrl_type}_{ctrl_target}_{seed}"`.
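For a quick sanity check of the output, a minimal sketch of loading the generated records (the path follows the example above, with `output` as a placeholder for `{your_output_path}`; the exact structure of the records depends on the simulator):

```python
import pickle
from pathlib import Path

# Placeholder path following the {output_path}/{exp_name}/{save_name}.pkl convention.
path = Path("output/DiGA_SZAMain_continuous_return_0/DiGA_generation.pkl")
with path.open("rb") as f:
    records = pickle.load(f)

print(type(records))  # inspect the structure of the generated trading records
```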
To run RL training in the generated market:

```bash
python rltask/train_test_rl.py --market "DiGA" --data_path {your_data_path} --test_replay_path {your_test_replay_path} --output_path {your_output_path} --save_name {your_save_name}
```

The command above trains an RL agent in a market environment generated by DiGA. For efficiency, it takes pre-computed meta controller samples from `data_path`. The trained agent and testing results are stored in `{output_path}/{save_name}`. For training with the DiGA environment, the file in `data_path` should contain a dict in which each item stores one sample generated by the meta controller. For training (or testing) with the Replay environment, the data in `data_path` (or `test_replay_path`) should contain paths to order and transaction records (both in CSV, preprocessed using the market_simulation library).
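As an illustration of the DiGA-environment input format, a minimal sketch of packing pre-computed meta controller samples into a dict. The file name, dict keys, and per-sample shape $(C, T) = (2, 236)$ here are assumptions for illustration, not the repository's exact schema:

```python
import pickle

import numpy as np

# Hypothetical: 100 pre-computed meta controller samples, one dict item per sample.
rng = np.random.default_rng(0)
samples = {i: rng.normal(size=(2, 236)).astype(np.float32) for i in range(100)}

with open("diga_meta_samples.pkl", "wb") as f:
    pickle.dump(samples, f)  # pass this file via --data_path for the DiGA market
```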
### Argument details

`train.py` accepts the following arguments:
- `--data_name`: The name of the dataset to use for training.
- `--ctrl_type`: The type of control to use (continuous or discrete).
- `--ctrl_target`: The target of the control (e.g., return, volatility).
- `--n_bins`: The number of bins to use for the discretization of `ctrl_target`.
- `--diffsteps`: The number of diffusion steps.
- `--samsteps`: The number of sampling steps.
- `--epochs`: The number of training epochs. Either `epochs` or `maxsteps` should be set; training stops when either one is reached.
- `--maxsteps`: The maximum number of steps for the trainer. Either `epochs` or `maxsteps` should be set; training stops when either one is reached.
- `--batch_size`: The batch size for training.
- `--learning_rate`: The learning rate for the optimizer.
- `--checkpoints`: The number of checkpoints to save.
- `--data_path`: The path to your training data.
- `--output_path`: The path where you want to save your trained model and other output files.
- `--seed`: The seed for random number generation.
- `--num_workers`: The number of workers to use for data loading.
`generate.py` accepts the following arguments:

- `--data_name`: The name of the dataset used for training.
- `--ctrl_type`: The type of control to use (continuous or discrete).
- `--ctrl_target`: The target of the control (e.g., return, volatility).
- `--n_bins`: The number of bins to use for the discretization of `ctrl_target`.
- `--diffsteps`: The number of diffusion steps.
- `--samsteps`: The number of sampling steps.
- `--seed`: The seed for random number generation.
- `--data_path`: The path to your training data.
- `--output_path`: The path where you want to save your trained model and other output files.
- `--exp_name`: The name of the experiment.
- `--checkpoint_path`: The path to the checkpoint file under your `output_path`.
- `--save_name`: The name of the file to save the generated data.
- `--random_price`: Whether to generate a random initial price.
- `--pseudo_price`: The initial price to use if `random_price` is not set.
`rltask/train_test_rl.py` accepts the following arguments:

- `--market`: The type of market environment to use for training. Options are "DiGA" and "Replay".
- `--max_steps`: The maximum number of steps for the trainer.
- `--save_name`: The folder name for saving the RL run.
- `--eval_eps`: The number of evaluation episodes.
- `--data_path`: The path to the data for generating the training environment.
- `--test_replay_path`: The path to the data for generating the testing environment.
- `--output_path`: The path where you want to save your trained model and other output files.
- `--seed`: The seed for random number generation.
Lines changed: 183 additions & 0 deletions
import logging
from typing import Callable, Dict, List, Optional, cast

from pandas import Timestamp

from market_simulation.states.trans_state import TransState
from market_simulation.wd.wd_order import WdOrder
from mlib.core.action import Action
from mlib.core.base_agent import BaseAgent
from mlib.core.base_order import BaseOrder
from mlib.core.observation import Observation
from mlib.core.limit_order import LimitOrder
from mlib.core.orderbook import Orderbook
from mlib.core.state import State
from mlib.core.transaction import Transaction


class ReplayAgent(BaseAgent):
    """An agent that replays the market with recorded orders and verifies the resulting transactions."""

    def __init__(
        self,
        symbol: str,
        orders: List[BaseOrder],
        transactions: List[Transaction],
        on_order_submit: Optional[Callable[["ReplayAgent", BaseOrder], None]] = None,
    ) -> None:
        super().__init__(init_cash=0, communication_delay=0, computation_delay=0)
        self.symbol: str = symbol
        self.orders: List[BaseOrder] = orders
        self.transactions = transactions
        self._next_wakeup_order_index = 0
        self._num_check_transactions = 0
        self.on_order_submit = on_order_submit
        assert self.orders

    def get_next_wakeup_time(self, time: Timestamp) -> Optional[Timestamp]:
        if self._next_wakeup_order_index >= len(self.orders):
            return None
        next_time = self.orders[self._next_wakeup_order_index].time
        self._next_wakeup_order_index += 1
        assert next_time >= time
        return next_time

    def get_action(self, observation: Observation, orderbook: Orderbook) -> Action:
        """Get action given observation.

        It delegates its main functions to:
        - `get_next_wakeup_time` to get the next wakeup time, and
        - `get_orders` to get orders based on observation. `get_orders` will not be called for the first-time wakeup,
          when it's the market open wakeup.

        """
        assert self.agent_id == observation.agent.agent_id
        time = observation.time
        # return empty order for the market open wakeup
        orders: List[BaseOrder] = [] if observation.is_market_open_wakup else self.get_orders(time, orderbook)
        action = Action(
            agent_id=self.agent_id,
            time=time,
            orders=orders,
            next_wakeup_time=self.get_next_wakeup_time(time),
        )
        return action

    def get_orders(self, time: Timestamp, orderbook: Orderbook):
        # The wakeup index was advanced in get_next_wakeup_time, so the current order is one behind.
        cur_order_index = self._next_wakeup_order_index - 1
        assert cur_order_index >= 0
        order = self.orders[cur_order_index]
        assert time == order.time
        if self.on_order_submit is not None:
            self.on_order_submit(self, order)
        validated = [self.validate_order(order, orderbook)]
        return [order for order in validated if order is not None]

    def on_states_update(self, time: Timestamp, symbol_states: Dict[str, Dict[str, State]]):
        super().on_states_update(time, symbol_states)

    def check_new_transactions_match(self):
        # Compare the transactions generated since the last check against the recorded labels.
        state_name = TransState.__name__
        assert state_name in self.symbol_states[self.symbol]
        state = cast(TransState, self.symbol_states[self.symbol][state_name])
        new_trans = state.transactons[self._num_check_transactions :]
        _check_transactions_match(self.transactions, new_trans, False, self._num_check_transactions)
        self._num_check_transactions = len(state.transactons)

    def on_market_close(self, time: Timestamp):
        super().on_market_close(time)
        _check_same_symbol_orders(self.agent_id, self.lob_orders, self.lob_price_orders, self.symbol_states)

    def validate_order(self, order: WdOrder, orderbook: Orderbook):
        if order.type != 'C':
            order = order.get_limit_orders(orderbook)[0]
        else:
            # Type 'C' denotes a cancel order: keep it only if the target order is still in the book.
            valid_cancel_vol = 0
            if order.cancel_id in self.lob_orders[self.symbol].keys():
                to_cancel = self.lob_orders[self.symbol][order.cancel_id]
                valid_cancel_vol = to_cancel.volume

            if valid_cancel_vol != 0:
                order.volume = valid_cancel_vol
            else:
                logging.warning(f"Invalid order {order}.")
                order = None
        if order is not None and order.price <= 0:
            logging.warning(f"Invalid order {order}.")
            order = None

        return order


def _check_same_symbol_orders(
    agent_id: int,
    lob_orders: Dict[str, Dict[int, LimitOrder]],
    lob_price_orders: Dict[str, Dict[int, Dict[int, LimitOrder]]],
    symbol_states: Dict[str, Dict[str, State]],
):
    symbols = lob_orders.keys()
    state_name: str = State.__name__
    for symbol in symbols:
        close_orderbook = symbol_states[symbol][state_name].close_orderbook
        if close_orderbook is None:
            # skip checking as the close orderbook is empty; this happens when there is no close auction.
            continue

        _check_same_orders_on_symbol(
            agent_id=agent_id,
            lob_orders=lob_orders[symbol],
            lob_price_orders=lob_price_orders[symbol],
            orderbook=close_orderbook,
        )


def _check_same_orders_on_symbol(
    agent_id: int,
    lob_orders: Dict[int, LimitOrder],
    lob_price_orders: Dict[int, Dict[int, LimitOrder]],
    orderbook: Orderbook,
):
    remaining_orders: List[LimitOrder] = []
    for level in orderbook.asks + orderbook.bids:
        remaining_orders.extend([x for x in level.orders if x.agent_id == agent_id])
    _check_same_orders(lob_orders, remaining_orders)

    price_orders: List[LimitOrder] = []
    for value in lob_price_orders.values():
        price_orders.extend(value.values())
    _check_same_orders(lob_orders, price_orders)


def _check_same_orders(lob_orders: Dict[int, LimitOrder], orders: List[LimitOrder]):
    assert len(orders) == len(lob_orders)
    for order in orders:
        assert order.order_id in lob_orders
        my_order = lob_orders[order.order_id]
        assert str(order) == str(my_order)


def _check_transactions_match(trans_label: List[Transaction], trans_replay: List[Transaction], output_details: bool = True, label_start: int = 0):
    end = label_start + len(trans_replay)
    assert label_start >= 0
    if len(trans_label) < end:
        logging.error(f"not enough transactions [{label_start}, {end}), only {len(trans_label)}.")
        return False

    if len(trans_replay) == 0:
        return True

    for index in range(len(trans_replay)):
        str_label = str(trans_label[label_start + index])
        str_replay = str(trans_replay[index])
        if str_label == str_replay:
            if output_details:
                logging.info(f"same for {label_start + index}|{index}th trans: {str_label}")
            continue
        logging.error(f"diff for {label_start + index}|{index}th trans")
        logging.info(f"  label: {str_label}")
        logging.info(f"  replay: {str_replay}")
        return False
    return True
