diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ff23c2..750b994 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.0.2版本 + +1. 使用zoneinfo替换pytz库 +2. 调整安装脚本setup.cfg,添加Python版本限制 + # 1.0.1版本 1. 将模块的图标文件信息,改为完整路径字符串 diff --git a/README.md b/README.md index b8daa6f..d521886 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@

- + @@ -17,7 +17,7 @@ PaperAccount是用于本地仿真交易的功能模块,用户可以通过其UI ## 安装 -安装环境推荐基于3.0.0版本以上的【[**VeighNa Studio**](https://www.vnpy.com)】。 +安装环境推荐基于3.3.0版本以上的【[**VeighNa Studio**](https://www.vnpy.com)】。 直接使用pip命令: diff --git a/setup.cfg b/setup.cfg index c79ebed..db6388d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = vnpy_paperaccount -version = 1.0.1 +version = 1.0.2 url = https://www.vnpy.com license = MIT author = Xiaoyou Chen @@ -30,6 +30,7 @@ classifiers = [options] packages = find: zip_safe = False +python_requires = >=3.7 install_requires = importlib_metadata diff --git a/vnpy_paperaccount/__init__.py b/vnpy_paperaccount/__init__.py index f8827ed..3eb7354 100644 --- a/vnpy_paperaccount/__init__.py +++ b/vnpy_paperaccount/__init__.py @@ -22,8 +22,8 @@ from pathlib import Path - import importlib_metadata + from vnpy.trader.app import BaseApp from .engine import PaperEngine, APP_NAME @@ -37,10 +37,10 @@ class PaperAccountApp(BaseApp): """""" - app_name = APP_NAME - app_module = __module__ - app_path = Path(__file__).parent - display_name = "模拟交易" - engine_class = PaperEngine - widget_name = "PaperManager" - icon_name = str(app_path.joinpath("ui", "paper.ico")) + app_name: str = APP_NAME + app_module: str = __module__ + app_path: Path = Path(__file__).parent + display_name: str = "模拟交易" + engine_class: PaperEngine = PaperEngine + widget_name: str = "PaperManager" + icon_name: str = str(app_path.joinpath("ui", "paper.ico")) diff --git a/vnpy_paperaccount/engine.py b/vnpy_paperaccount/engine.py index caddc08..0617abb 100644 --- a/vnpy_paperaccount/engine.py +++ b/vnpy_paperaccount/engine.py @@ -1,10 +1,10 @@ from copy import copy from typing import Any, Dict, Tuple, Optional, List from datetime import datetime -from tzlocal import get_localzone +from tzlocal import get_localzone_name from vnpy.event import Event, EventEngine -from vnpy.trader.utility import extract_vt_symbol, save_json, load_json +from vnpy.trader.utility import extract_vt_symbol, save_json, load_json, ZoneInfo from vnpy.trader.engine import BaseEngine, MainEngine from vnpy.trader.object import ( OrderRequest, CancelRequest, QuoteData, QuoteRequest, SubscribeRequest, @@ -29,17 +29,17 @@ ) -LOCAL_TZ = get_localzone() +LOCAL_TZ = ZoneInfo(get_localzone_name()) APP_NAME = "PaperAccount" GATEWAY_NAME = "PAPER" class PaperEngine(BaseEngine): """""" - setting_filename = "paper_account_setting.json" - data_filename = "paper_account_data.json" + setting_filename: str = "paper_account_setting.json" + data_filename: str = "paper_account_data.json" - def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + def __init__(self, main_engine: MainEngine, event_engine: EventEngine) -> None: """""" super().__init__(main_engine, event_engine, APP_NAME) @@ -73,7 +73,7 @@ def __init__(self, main_engine: MainEngine, event_engine: EventEngine): self.load_data() self.register_event() - def register_event(self): + def register_event(self) -> None: """""" self.event_engine.register(EVENT_CONTRACT, self.process_contract_event) self.event_engine.register(EVENT_TICK, self.process_tick_event) @@ -86,9 +86,9 @@ def process_contract_event(self, event: Event) -> None: contract.gateway_name = GATEWAY_NAME for direciton in Direction: - key = (contract.vt_symbol, direciton) + key: tuple = (contract.vt_symbol, direciton) if key in self.positions: - position = self.positions[key] + position: PositionData = self.positions[key] self.put_event(EVENT_POSITION, position) def process_tick_event(self, event: Event) -> None: @@ -98,7 +98,7 @@ def process_tick_event(self, event: Event) -> None: self.ticks[tick.vt_symbol] = tick - active_orders = self.active_orders.get(tick.vt_symbol, None) + active_orders: Optional[dict] = self.active_orders.get(tick.vt_symbol, None) if active_orders: for orderid, order in list(active_orders.items()): self.cross_order(order, tick) @@ -106,7 +106,7 @@ def process_tick_event(self, event: Event) -> None: if not order.is_active(): active_orders.pop(orderid) - quote = self.active_quotes.get(tick.vt_symbol, None) + quote: Optional[QuoteData] = self.active_quotes.get(tick.vt_symbol, None) if quote: self.cross_quote(quote, tick) @@ -121,29 +121,29 @@ def process_timer_event(self, event: Event) -> None: self.timer_count = 0 for position in self.positions.values(): - contract = self.main_engine.get_contract(position.vt_symbol) + contract: Optional[ContractData] = self.main_engine.get_contract(position.vt_symbol) if contract: self.calculate_pnl(position) self.put_event(EVENT_POSITION, copy(position)) def calculate_pnl(self, position: PositionData) -> None: """""" - tick = self.ticks.get(position.vt_symbol, None) + tick: Optional[TickData] = self.ticks.get(position.vt_symbol, None) if tick: - contract = self.main_engine.get_contract(position.vt_symbol) + contract: Optional[ContractData] = self.main_engine.get_contract(position.vt_symbol) if position.direction == Direction.SHORT: - multiplier = -position.volume * contract.size + multiplier: float = -position.volume * contract.size else: - multiplier = position.volume * contract.size + multiplier: float = position.volume * contract.size position.pnl = (tick.last_price - position.price) * multiplier position.pnl = round(position.pnl, 2) def subscribe(self, req: SubscribeRequest, gateway_name: str) -> None: """""" - original_gateway_name = self.gateway_map.get(req.vt_symbol, "") + original_gateway_name: str = self.gateway_map.get(req.vt_symbol, "") if original_gateway_name: self._subscribe(req, original_gateway_name) else: @@ -151,7 +151,7 @@ def subscribe(self, req: SubscribeRequest, gateway_name: str) -> None: def query_history(self, req: HistoryRequest, gateway_name: str) -> List[BarData]: """""" - original_gateway_name = self.gateway_map.get(req.vt_symbol, "") + original_gateway_name: str = self.gateway_map.get(req.vt_symbol, "") if original_gateway_name: return self._query_history(req, original_gateway_name) else: @@ -159,28 +159,28 @@ def query_history(self, req: HistoryRequest, gateway_name: str) -> List[BarData] def send_order(self, req: OrderRequest, gateway_name: str) -> str: """""" - contract: ContractData = self.main_engine.get_contract(req.vt_symbol) + contract: Optional[ContractData] = self.main_engine.get_contract(req.vt_symbol) if not contract: self.write_log(f"委托失败,找不到该合约{req.vt_symbol}") return "" self.order_count += 1 - now = datetime.now().strftime("%y%m%d%H%M%S") - orderid = now + str(self.order_count) - vt_orderid = f"{GATEWAY_NAME}.{orderid}" + now: str = datetime.now().strftime("%y%m%d%H%M%S") + orderid: str = now + str(self.order_count) + vt_orderid: str = f"{GATEWAY_NAME}.{orderid}" # Put simulated order update event from gateway - order = req.create_order_data(orderid, GATEWAY_NAME) + order: OrderData = req.create_order_data(orderid, GATEWAY_NAME) self.put_event(EVENT_ORDER, copy(order)) # Check if order is valid - updated_position = self.check_order_valid(order, contract) + updated_position: PositionData = self.check_order_valid(order, contract) # Put simulated order update event from exchange if order.status != Status.REJECTED: order.datetime = datetime.now(LOCAL_TZ) order.status = Status.NOTTRADED - active_orders = self.active_orders.setdefault(order.vt_symbol, {}) + active_orders: dict = self.active_orders.setdefault(order.vt_symbol, {}) active_orders[orderid] = order self.put_event(EVENT_ORDER, copy(order)) @@ -191,12 +191,12 @@ def send_order(self, req: OrderRequest, gateway_name: str) -> str: # Cross order immediately with last tick data if self.instant_trade and order.status != Status.REJECTED: - tick = self.ticks.get(order.vt_symbol, None) + tick: Optional[TickData] = self.ticks.get(order.vt_symbol, None) if tick: self.cross_order(order, tick) if not order.is_active(): - active_orders = self.active_orders[order.vt_symbol] + active_orders: dict = self.active_orders[order.vt_symbol] active_orders.pop(orderid) return vt_orderid @@ -211,7 +211,7 @@ def cancel_order(self, req: CancelRequest, gateway_name: str) -> None: self.put_event(EVENT_ORDER, copy(order)) # Free frozen position volume - contract = self.main_engine.get_contract(order.vt_symbol) + contract: Optional[ContractData] = self.main_engine.get_contract(order.vt_symbol) if contract.net_position: return @@ -219,32 +219,32 @@ def cancel_order(self, req: CancelRequest, gateway_name: str) -> None: return if order.direction == Direction.LONG: - position = self.get_position(order.vt_symbol, Direction.SHORT) + position: PositionData = self.get_position(order.vt_symbol, Direction.SHORT) else: - position = self.get_position(order.vt_symbol, Direction.LONG) + position: PositionData = self.get_position(order.vt_symbol, Direction.LONG) position.frozen -= order.volume self.put_event(EVENT_POSITION, copy(position)) def send_quote(self, req: QuoteRequest, gateway_name: str) -> str: """""" - contract: ContractData = self.main_engine.get_contract(req.vt_symbol) + contract: Optional[ContractData] = self.main_engine.get_contract(req.vt_symbol) if not contract: self.write_log(f"报价失败,找不到该合约{req.vt_symbol}") return "" self.quote_count += 1 - now = datetime.now().strftime("%y%m%d%H%M%S") - quoteid = now + str(self.quote_count) - vt_quoteid = f"{GATEWAY_NAME}.{quoteid}" + now: str = datetime.now().strftime("%y%m%d%H%M%S") + quoteid: str = now + str(self.quote_count) + vt_quoteid: str = f"{GATEWAY_NAME}.{quoteid}" # Put simulated quote update event from gateway - quote = req.create_quote_data(quoteid, GATEWAY_NAME) + quote: QuoteData = req.create_quote_data(quoteid, GATEWAY_NAME) self.put_event(EVENT_QUOTE, copy(quote)) # Put old quote cancel event if quote.vt_symbol in self.active_quotes: - old_quote = self.active_quotes.pop(quote.vt_symbol) + old_quote: QuoteData = self.active_quotes.pop(quote.vt_symbol) old_quote.status = Status.CANCELLED self.put_event(EVENT_QUOTE, old_quote) @@ -259,7 +259,7 @@ def send_quote(self, req: QuoteRequest, gateway_name: str) -> str: def cancel_quote(self, req: CancelRequest, gateway_name: str) -> None: """""" - quote: QuoteData = self.active_quotes.get(req.vt_symbol, None) + quote: Optional[QuoteData] = self.active_quotes.get(req.vt_symbol, None) if not quote: return @@ -272,7 +272,7 @@ def cancel_quote(self, req: CancelRequest, gateway_name: str) -> None: def put_event(self, event_type: str, data: Any) -> None: """""" - event = Event(event_type, data) + event: Event = Event(event_type, data) self.event_engine.put(event) def check_order_valid(self, order: OrderData, contract: ContractData) -> Optional[PositionData]: @@ -291,8 +291,8 @@ def check_order_valid(self, order: OrderData, contract: ContractData) -> Optiona return if order.direction == Direction.LONG: - short_position = self.get_position(order.vt_symbol, Direction.SHORT) - available = short_position.volume - short_position.frozen + short_position: PositionData = self.get_position(order.vt_symbol, Direction.SHORT) + available: float = short_position.volume - short_position.frozen if order.volume > available: order.status = Status.REJECTED @@ -301,8 +301,8 @@ def check_order_valid(self, order: OrderData, contract: ContractData) -> Optiona short_position.frozen += order.volume return short_position else: - long_position = self.get_position(order.vt_symbol, Direction.LONG) - available = long_position.volume - long_position.frozen + long_position: PositionData = self.get_position(order.vt_symbol, Direction.LONG) + available: float = long_position.volume - long_position.frozen if order.volume > available: order.status = Status.REJECTED @@ -311,9 +311,9 @@ def check_order_valid(self, order: OrderData, contract: ContractData) -> Optiona long_position.frozen += order.volume return long_position - def cross_order(self, order: OrderData, tick: TickData): + def cross_order(self, order: OrderData, tick: TickData) -> None: """""" - contract = self.main_engine.get_contract(order.vt_symbol) + contract: Optional[ContractData] = self.main_engine.get_contract(order.vt_symbol) trade_price = 0 @@ -345,7 +345,7 @@ def cross_order(self, order: OrderData, tick: TickData): order.traded = order.volume self.put_event(EVENT_ORDER, order) - trade = TradeData( + trade: TradeData = TradeData( symbol=order.symbol, exchange=order.exchange, orderid=order.orderid, @@ -361,25 +361,25 @@ def cross_order(self, order: OrderData, tick: TickData): self.update_position(trade, contract) - def cross_quote(self, quote: QuoteData, tick: TickData): + def cross_quote(self, quote: QuoteData, tick: TickData) -> None: """""" - contract = self.main_engine.get_contract(quote.vt_symbol) + contract: Optional[ContractData] = self.main_engine.get_contract(quote.vt_symbol) trade_price = 0 if tick.last_price >= quote.ask_price and quote.ask_volume: trade_price = quote.ask_price - direction = Direction.SHORT - offset = Offset.CLOSE + direction: Direction = Direction.SHORT + offset: Offset = Offset.CLOSE volume = quote.ask_volume quote.ask_volume = 0 elif tick.last_price <= quote.bid_price and quote.bid_volume: trade_price = quote.bid_price - direction = Direction.LONG - offset = Offset.OPEN + direction: Direction = Direction.LONG + offset: Offset = Offset.OPEN volume = quote.bid_volume quote.bid_volume = 0 @@ -392,7 +392,7 @@ def cross_quote(self, quote: QuoteData, tick: TickData): self.put_event(EVENT_QUOTE, quote) self.trade_count += 1 - trade = TradeData( + trade: TradeData = TradeData( symbol=quote.symbol, exchange=quote.exchange, orderid=str(self.trade_count), @@ -408,23 +408,23 @@ def cross_quote(self, quote: QuoteData, tick: TickData): self.update_position(trade, contract) - def update_position(self, trade: TradeData, contract: ContractData): + def update_position(self, trade: TradeData, contract: ContractData) -> None: """""" - vt_symbol = trade.vt_symbol + vt_symbol: str = trade.vt_symbol # Net position mode if contract.net_position: - position = self.get_position(vt_symbol, Direction.NET) + position: PositionData = self.get_position(vt_symbol, Direction.NET) - old_volume = position.volume - old_cost = position.volume * position.price + old_volume: float = position.volume + old_cost: float = position.volume * position.price if trade.direction == Direction.LONG: pos_change = trade.volume else: pos_change = -trade.volume - new_volume = position.volume + pos_change + new_volume: float = position.volume + pos_change # No position holding, clear price if not new_volume: @@ -448,8 +448,8 @@ def update_position(self, trade: TradeData, contract: ContractData): self.put_event(EVENT_POSITION, copy(position)) # Long/Short position mode else: - long_position = self.get_position(vt_symbol, Direction.LONG) - short_position = self.get_position(vt_symbol, Direction.SHORT) + long_position: PositionData = self.get_position(vt_symbol, Direction.LONG) + short_position: PositionData = self.get_position(vt_symbol, Direction.SHORT) if trade.direction == Direction.LONG: if trade.offset == Offset.OPEN: @@ -486,15 +486,15 @@ def update_position(self, trade: TradeData, contract: ContractData): self.save_data() - def get_position(self, vt_symbol: str, direction: Direction): + def get_position(self, vt_symbol: str, direction: Direction) -> PositionData: """""" - key = (vt_symbol, direction) + key: tuple = (vt_symbol, direction) if key in self.positions: return self.positions[key] else: symbol, exchange = extract_vt_symbol(vt_symbol) - position = PositionData( + position: PositionData = PositionData( symbol=symbol, exchange=exchange, direction=direction, @@ -506,18 +506,18 @@ def get_position(self, vt_symbol: str, direction: Direction): def write_log(self, msg: str) -> None: """""" - log = LogData(msg=msg, gateway_name=GATEWAY_NAME) + log: LogData = LogData(msg=msg, gateway_name=GATEWAY_NAME) self.put_event(EVENT_LOG, log) def save_data(self) -> None: """""" - position_data = [] + position_data: list = [] for position in self.positions.values(): if not position.volume: continue - d = { + d: dict = { "vt_symbol": position.vt_symbol, "volume": position.volume, "price": position.price, @@ -529,19 +529,19 @@ def save_data(self) -> None: def load_data(self) -> None: """""" - position_data = load_json(self.data_filename) + position_data: dict = load_json(self.data_filename) for d in position_data: - vt_symbol = d["vt_symbol"] - direction = Direction(d["direction"]) + vt_symbol: str = d["vt_symbol"] + direction: Direction = Direction(d["direction"]) - position = self.get_position(vt_symbol, direction) + position: PositionData = self.get_position(vt_symbol, direction) position.volume = d["volume"] position.price = d["price"] def load_setting(self) -> None: """""" - setting = load_json(self.setting_filename) + setting: dict = load_json(self.setting_filename) if setting: self.trade_slippage = setting["trade_slippage"] @@ -550,7 +550,7 @@ def load_setting(self) -> None: def save_setting(self) -> None: """""" - setting = { + setting: dict = { "trade_slippage": self.trade_slippage, "timer_interval": self.timer_interval, "instant_trade": self.instant_trade diff --git a/vnpy_paperaccount/ui/widget.py b/vnpy_paperaccount/ui/widget.py index 677f32d..7dbfe75 100644 --- a/vnpy_paperaccount/ui/widget.py +++ b/vnpy_paperaccount/ui/widget.py @@ -2,7 +2,6 @@ from vnpy.trader.engine import MainEngine from vnpy.trader.ui import QtWidgets - from ..engine import ( PaperEngine, APP_NAME, @@ -12,7 +11,7 @@ class PaperManager(QtWidgets.QWidget): """""" - def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + def __init__(self, main_engine: MainEngine, event_engine: EventEngine) -> None: """""" super().__init__() @@ -23,39 +22,39 @@ def __init__(self, main_engine: MainEngine, event_engine: EventEngine): self.init_ui() - def init_ui(self): + def init_ui(self) -> None: """""" self.setWindowTitle("模拟交易") self.setFixedHeight(200) self.setFixedWidth(500) - interval_spin = QtWidgets.QSpinBox() + interval_spin: QtWidgets.QSpinBox = QtWidgets.QSpinBox() interval_spin.setMinimum(1) interval_spin.setValue(self.paper_engine.timer_interval) interval_spin.setSuffix(" 秒") interval_spin.valueChanged.connect(self.paper_engine.set_timer_interval) - slippage_spin = QtWidgets.QSpinBox() + slippage_spin: QtWidgets.QSpinBox = QtWidgets.QSpinBox() slippage_spin.setMinimum(0) slippage_spin.setValue(self.paper_engine.trade_slippage) slippage_spin.setSuffix(" 跳") slippage_spin.valueChanged.connect(self.paper_engine.set_trade_slippage) - instant_check = QtWidgets.QCheckBox() + instant_check: QtWidgets.QCheckBox = QtWidgets.QCheckBox() instant_check.setChecked(self.paper_engine.instant_trade) instant_check.stateChanged.connect(self.paper_engine.set_instant_trade) - clear_button = QtWidgets.QPushButton("清空所有持仓") + clear_button: QtWidgets.QPushButton = QtWidgets.QPushButton("清空所有持仓") clear_button.clicked.connect(self.paper_engine.clear_position) clear_button.setFixedHeight(clear_button.sizeHint().height() * 2) - form = QtWidgets.QFormLayout() + form: QtWidgets.QFormLayout = QtWidgets.QFormLayout() form.addRow("市价委托和停止委托的成交滑点", slippage_spin) form.addRow("模拟交易持仓盈亏的计算频率", interval_spin) form.addRow("下单后立即使用当前盘口撮合", instant_check) form.addRow(clear_button) - vbox = QtWidgets.QVBoxLayout() + vbox: QtWidgets.QVBoxLayout = QtWidgets.QVBoxLayout() vbox.addStretch() vbox.addLayout(form) vbox.addStretch()