Skip to content

Commit

Permalink
Merge pull request #7 from noranhe/main
Browse files Browse the repository at this point in the history
[Add] 类型声明
  • Loading branch information
vnpy committed Aug 3, 2022
2 parents d8d34d5 + eba3ee5 commit 61830d2
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 132 deletions.
14 changes: 7 additions & 7 deletions vnpy_datamanager/__init__.py
Expand Up @@ -37,10 +37,10 @@
class DataManagerApp(BaseApp):
""""""

app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent
display_name = "数据管理"
engine_class = ManagerEngine
widget_name = "ManagerWidget"
icon_name = str(app_path.joinpath("ui", "manager.ico"))
app_name: str = APP_NAME
app_module: str = __module__
app_path: Path = Path(__file__).parent
display_name: str = "数据管理"
engine_class: ManagerEngine = ManagerEngine
widget_name: str = "ManagerWidget"
icon_name: str = str(app_path.joinpath("ui", "manager.ico"))
52 changes: 26 additions & 26 deletions vnpy_datamanager/engine.py
@@ -1,10 +1,10 @@
import csv
from datetime import datetime
from typing import List, Tuple
from typing import List, Optional

from vnpy.trader.engine import BaseEngine, MainEngine, EventEngine
from vnpy.trader.constant import Interval, Exchange
from vnpy.trader.object import BarData, HistoryRequest
from vnpy.trader.object import BarData, TickData, ContractData, HistoryRequest
from vnpy.trader.database import BaseDatabase, get_database, BarOverview, DB_TZ
from vnpy.trader.datafeed import BaseDatafeed, get_datafeed
from vnpy.trader.utility import ZoneInfo
Expand All @@ -19,7 +19,7 @@ def __init__(
self,
main_engine: MainEngine,
event_engine: EventEngine,
):
) -> None:
""""""
super().__init__(main_engine, event_engine, APP_NAME)

Expand All @@ -42,29 +42,29 @@ def import_data_from_csv(
turnover_head: str,
open_interest_head: str,
datetime_format: str
) -> Tuple:
) -> tuple:
""""""
with open(file_path, "rt") as f:
buf = [line.replace("\0", "") for line in f]
buf: list = [line.replace("\0", "") for line in f]

reader = csv.DictReader(buf, delimiter=",")
reader: csv.DictReader = csv.DictReader(buf, delimiter=",")

bars = []
start = None
count = 0
bars: List[BarData] = []
start: datetime = None
count: int = 0
tz = ZoneInfo(tz_name)

for item in reader:
if datetime_format:
dt = datetime.strptime(item[datetime_head], datetime_format)
dt: datetime = datetime.strptime(item[datetime_head], datetime_format)
else:
dt = datetime.fromisoformat(item[datetime_head])
dt: datetime = datetime.fromisoformat(item[datetime_head])
dt = dt.replace(tzinfo=tz)

turnover = item.get(turnover_head, 0)
open_interest = item.get(open_interest_head, 0)

bar = BarData(
bar: BarData = BarData(
symbol=symbol,
exchange=exchange,
datetime=dt,
Expand All @@ -86,7 +86,7 @@ def import_data_from_csv(
if not start:
start = bar.datetime

end = bar.datetime
end: datetime = bar.datetime

# insert into database
self.database.save_bar_data(bars)
Expand All @@ -103,9 +103,9 @@ def output_data_to_csv(
end: datetime
) -> bool:
""""""
bars = self.load_bar_data(symbol, exchange, interval, start, end)
bars: List[BarData] = self.load_bar_data(symbol, exchange, interval, start, end)

fieldnames = [
fieldnames: list = [
"symbol",
"exchange",
"datetime",
Expand All @@ -120,11 +120,11 @@ def output_data_to_csv(

try:
with open(file_path, "w") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames, lineterminator="\n")
writer: csv.DictWriter = csv.DictWriter(f, fieldnames=fieldnames, lineterminator="\n")
writer.writeheader()

for bar in bars:
d = {
d: dict = {
"symbol": bar.symbol,
"exchange": bar.exchange.value,
"datetime": bar.datetime.strftime("%Y-%m-%d %H:%M:%S"),
Expand Down Expand Up @@ -155,7 +155,7 @@ def load_bar_data(
end: datetime
) -> List[BarData]:
""""""
bars = self.database.load_bar_data(
bars: List[BarData] = self.database.load_bar_data(
symbol,
exchange,
interval,
Expand All @@ -172,7 +172,7 @@ def delete_bar_data(
interval: Interval
) -> int:
""""""
count = self.database.delete_bar_data(
count: int = self.database.delete_bar_data(
symbol,
exchange,
interval
Expand All @@ -190,25 +190,25 @@ def download_bar_data(
"""
Query bar data from datafeed.
"""
req = HistoryRequest(
req: HistoryRequest = HistoryRequest(
symbol=symbol,
exchange=exchange,
interval=Interval(interval),
start=start,
end=datetime.now(DB_TZ)
)

vt_symbol = f"{symbol}.{exchange.value}"
contract = self.main_engine.get_contract(vt_symbol)
vt_symbol: str = f"{symbol}.{exchange.value}"
contract: Optional[ContractData] = self.main_engine.get_contract(vt_symbol)

# If history data provided in gateway, then query
if contract and contract.history_data:
data = self.main_engine.query_history(
data: List[BarData] = self.main_engine.query_history(
req, contract.gateway_name
)
# Otherwise use datafeed to query data
else:
data = self.datafeed.query_bar_history(req)
data: List[BarData] = self.datafeed.query_bar_history(req)

if data:
self.database.save_bar_data(data)
Expand All @@ -225,14 +225,14 @@ def download_tick_data(
"""
Query tick data from datafeed.
"""
req = HistoryRequest(
req: HistoryRequest = HistoryRequest(
symbol=symbol,
exchange=exchange,
start=start,
end=datetime.now(DB_TZ)
)

data = self.datafeed.query_tick_history(req)
data: List[TickData] = self.datafeed.query_tick_history(req)

if data:
self.database.save_tick_data(data)
Expand Down

0 comments on commit 61830d2

Please sign in to comment.