|
| 1 | + |
| 2 | +import logging |
| 3 | +from typing import Callable, Dict, List, Optional, cast |
| 4 | + |
| 5 | +from pandas import Timestamp |
| 6 | + |
| 7 | +from market_simulation.states.trans_state import TransState |
| 8 | +from market_simulation.wd.wd_order import WdOrder |
| 9 | +from mlib.core.action import Action |
| 10 | +from mlib.core.base_agent import BaseAgent |
| 11 | +from mlib.core.base_order import BaseOrder |
| 12 | +from mlib.core.observation import Observation |
| 13 | +from mlib.core.limit_order import LimitOrder |
| 14 | +from mlib.core.orderbook import Orderbook |
| 15 | +from mlib.core.state import State |
| 16 | +from mlib.core.transaction import Transaction |
| 17 | + |
| 18 | + |
class ReplayAgent(BaseAgent):
    """An agent used to replay a market with recorded orders and verify it with transactions.

    The agent schedules one wakeup per recorded order (at that order's time),
    validates the order against its own tracked book orders and submits it.
    `check_new_transactions_match` compares the transactions produced by the
    simulation against the recorded ones.
    """

    def __init__(
        self,
        symbol: str,
        orders: List[BaseOrder],
        transactions: List[Transaction],
        on_order_submit: Optional[Callable[["ReplayAgent", BaseOrder], None]] = None,
    ) -> None:
        """Create a replay agent.

        Args:
            symbol: Symbol whose orders are replayed.
            orders: Recorded orders to replay; must be non-empty. The wakeup
                schedule follows the order of this list.
            transactions: Recorded transactions used as the reference when
                checking the replayed ones.
            on_order_submit: Optional callback invoked with this agent and the
                raw (not yet validated) order right before it is validated.
        """
        super().__init__(init_cash=0, communication_delay=0, computation_delay=0)
        self.symbol: str = symbol
        self.orders: List[BaseOrder] = orders
        self.transactions = transactions
        # Index of the next order whose time will be handed out as a wakeup time.
        self._next_wakeup_order_index = 0
        # Number of replayed transactions that were already compared.
        self._num_check_transactions = 0
        self.on_order_submit = on_order_submit
        assert self.orders, "ReplayAgent requires at least one order to replay"

    def get_next_wakeup_time(self, time: Timestamp) -> Optional[Timestamp]:
        """Return the time of the next order to replay and advance the cursor.

        Returns:
            The next order's time, or None once every order has been scheduled.
        """
        if self._next_wakeup_order_index >= len(self.orders):
            return None
        next_time = self.orders[self._next_wakeup_order_index].time
        self._next_wakeup_order_index += 1
        assert next_time >= time, f"next wakeup time {next_time} is earlier than current time {time}"
        return next_time

    def get_action(self, observation: Observation, orderbook: Orderbook) -> Action:
        """Get action given observation.

        It delegates its main functions to:
        - `get_next_wakeup_time` to get the next wakeup time, and
        - `get_orders` to get orders based on observation. `get_orders` will not be called for the first-time wakeup,
            when it's the market open wakeup.

        """
        assert self.agent_id == observation.agent.agent_id
        time = observation.time
        # return empty order for the market open wakeup
        # NOTE: `get_orders` must run before `get_next_wakeup_time`, because it
        # reads the cursor that `get_next_wakeup_time` advances.
        orders: List[BaseOrder] = [] if observation.is_market_open_wakup else self.get_orders(time, orderbook)
        return Action(
            agent_id=self.agent_id,
            time=time,
            orders=orders,
            next_wakeup_time=self.get_next_wakeup_time(time),
        )

    def get_orders(self, time: Timestamp, orderbook: Orderbook) -> List[BaseOrder]:
        """Return the order scheduled for `time`, or an empty list if it is invalid.

        The order is the one whose time was most recently returned by
        `get_next_wakeup_time`. `on_order_submit` (if set) receives the raw
        order before validation.
        """
        cur_order_index = self._next_wakeup_order_index - 1
        assert cur_order_index >= 0, "get_orders called before any wakeup was scheduled"
        order = self.orders[cur_order_index]
        assert time == order.time, f"woken up at {time} but the current order is for {order.time}"
        if self.on_order_submit is not None:
            self.on_order_submit(self, order)
        validated = self.validate_order(order, orderbook)
        return [] if validated is None else [validated]

    def on_states_update(self, time: Timestamp, symbol_states: Dict[str, Dict[str, State]]) -> None:
        """Forward the state update to the base agent."""
        super().on_states_update(time, symbol_states)

    def check_new_transactions_match(self) -> None:
        """Compare transactions replayed since the last check with the recorded ones.

        Mismatches are reported through logging by `_check_transactions_match`;
        its boolean result is not acted upon here.
        """
        state_name = TransState.__name__
        assert state_name in self.symbol_states[self.symbol]
        state = cast(TransState, self.symbol_states[self.symbol][state_name])
        new_trans = state.transactons[self._num_check_transactions :]
        _check_transactions_match(self.transactions, new_trans, False, self._num_check_transactions)
        self._num_check_transactions = len(state.transactons)

    def on_market_close(self, time: Timestamp) -> None:
        """Run the base close handling, then check tracked orders against the close orderbook."""
        super().on_market_close(time)
        _check_same_symbol_orders(self.agent_id, self.lob_orders, self.lob_price_orders, self.symbol_states)

    def validate_order(self, order: WdOrder, orderbook: Orderbook) -> Optional[BaseOrder]:
        """Turn a recorded order into a submittable order, or None if it is invalid.

        - Non-cancel orders (type other than 'C') are converted with
          `get_limit_orders` and the first resulting order is used.
        - Cancel orders are only kept when the order they refer to is still
          tracked by this agent with a non-zero volume; the cancel volume is
          overwritten with that remaining volume.
        - Any resulting order with a non-positive price is dropped.

        Invalid orders are logged as warnings and None is returned.
        """
        candidate: Optional[BaseOrder]
        if order.type != "C":
            limit_orders = order.get_limit_orders(orderbook)
            if not limit_orders:
                # Nothing to submit; treat like every other invalid order
                # instead of failing with IndexError.
                logging.warning(f"Invalid order {order}.")
                return None
            candidate = limit_orders[0]
        else:
            to_cancel = self.lob_orders[self.symbol].get(order.cancel_id)
            valid_cancel_vol = to_cancel.volume if to_cancel is not None else 0
            if valid_cancel_vol != 0:
                # Cancel exactly what is still resting in the book.
                order.volume = valid_cancel_vol
                candidate = order
            else:
                logging.warning(f"Invalid order {order}.")
                candidate = None

        if candidate is not None and candidate.price <= 0:
            logging.warning(f"Invalid order {candidate}.")
            candidate = None

        return candidate
| 111 | + |
| 112 | + |
| 113 | + |
| 114 | + |
def _check_same_symbol_orders(
    agent_id: int,
    lob_orders: Dict[str, Dict[int, LimitOrder]],
    lob_price_orders: Dict[str, Dict[int, Dict[int, LimitOrder]]],
    symbol_states: Dict[str, Dict[str, State]],
):
    """Check, symbol by symbol, that the agent's tracked orders match the close orderbook.

    Args:
        agent_id: Id of the agent whose orders are checked.
        lob_orders: Per-symbol mapping of order id to tracked order.
        lob_price_orders: Per-symbol nested mapping whose innermost values are
            the tracked orders.
        symbol_states: Per-symbol states; the state stored under
            `State.__name__` provides `close_orderbook`.
    """
    state_key: str = State.__name__
    for symbol, tracked_orders in lob_orders.items():
        closing_book = symbol_states[symbol][state_key].close_orderbook
        if closing_book is not None:
            _check_same_orders_on_symbol(
                agent_id=agent_id,
                lob_orders=tracked_orders,
                lob_price_orders=lob_price_orders[symbol],
                orderbook=closing_book,
            )
        # else: skip checking as close orderbook is empty, this happens when no close-auction.
| 135 | + |
| 136 | + |
def _check_same_orders_on_symbol(
    agent_id: int,
    lob_orders: Dict[int, LimitOrder],
    lob_price_orders: Dict[int, Dict[int, LimitOrder]],
    orderbook: Orderbook,
):
    """Check one symbol's tracked orders against the orderbook and the nested index.

    Two comparisons are made against `lob_orders`:
    1. the orders of `agent_id` found on the ask and bid levels of `orderbook`;
    2. all orders stored in the inner mappings of `lob_price_orders`.
    """
    # 1. Orders of this agent still resting in the orderbook.
    book_levels = orderbook.asks + orderbook.bids
    resting_orders: List[LimitOrder] = [
        resting for level in book_levels for resting in level.orders if resting.agent_id == agent_id
    ]
    _check_same_orders(lob_orders, resting_orders)

    # 2. Orders held in the nested index, flattened.
    indexed_orders: List[LimitOrder] = [
        indexed for bucket in lob_price_orders.values() for indexed in bucket.values()
    ]
    _check_same_orders(lob_orders, indexed_orders)
| 152 | + |
| 153 | + |
| 154 | +def _check_same_orders(lob_orders: Dict[int, LimitOrder], orders: List[LimitOrder]): |
| 155 | + assert len(orders) == len(lob_orders) |
| 156 | + for order in orders: |
| 157 | + assert order.order_id in lob_orders |
| 158 | + my_order = lob_orders[order.order_id] |
| 159 | + assert str(order) == str(my_order) |
| 160 | + |
| 161 | + |
| 162 | +def _check_transactions_match(trans_label: List[Transaction], trans_replay: List[Transaction], output_details: bool = True, label_start: int = 0): |
| 163 | + end = label_start + len(trans_replay) |
| 164 | + assert label_start >= 0 |
| 165 | + if len(trans_label) < end: |
| 166 | + logging.error(f"not enough transactoin [{label_start}, {end}), only {len(trans_label)}.") |
| 167 | + return False |
| 168 | + |
| 169 | + if len(trans_replay) == 0: |
| 170 | + return True |
| 171 | + |
| 172 | + for index in range(len(trans_replay)): |
| 173 | + str_label = str(trans_label[label_start + index]) |
| 174 | + str_replay = str(trans_replay[index]) |
| 175 | + if str_label == str_replay: |
| 176 | + if output_details: |
| 177 | + logging.info(f"same for {label_start + index}|{index}th trans: {str_label}") |
| 178 | + continue |
| 179 | + logging.error(f"diff for {label_start + index}|{index}th trans") |
| 180 | + logging.info(f" label: {str_label}") |
| 181 | + logging.info(f" reply: {str_replay}") |
| 182 | + return False |
| 183 | + return True |
0 commit comments