Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Improve pit performance #1673

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -9,6 +9,7 @@ __pycache__/
_build
build/
dist/
tests/test_pit_data/

*.pkl
*.hd5
Expand Down
4 changes: 2 additions & 2 deletions qlib/data/base.py
Expand Up @@ -267,10 +267,10 @@ class PFeature(Feature):
def __str__(self):
    """Render the point-in-time feature name with its ``$$`` prefix (e.g. ``$$roewa_q``)."""
    prefix = "$$"
    return prefix + self._name

def _load_internal(self, instrument, start_index, end_index, cur_time, period=None):
def _load_internal(self, instrument, start_index, end_index, cur_time, period=None, start_time=None):
from .data import PITD # pylint: disable=C0415

return PITD.period_feature(instrument, str(self), start_index, end_index, cur_time, period)
return PITD.period_feature(instrument, str(self), start_index, end_index, cur_time, period, start_time)


class ExpressionOps(Expression):
Expand Down
4 changes: 4 additions & 0 deletions qlib/data/cache.py
Expand Up @@ -160,6 +160,7 @@ def __init__(self, mem_cache_size_limit=None, limit_type="length"):
self.__calendar_mem_cache = klass(size_limit)
self.__instrument_mem_cache = klass(size_limit)
self.__feature_mem_cache = klass(size_limit)
self.__pit_mem_cache = klass(size_limit)

def __getitem__(self, key):
if key == "c":
Expand All @@ -168,13 +169,16 @@ def __getitem__(self, key):
return self.__instrument_mem_cache
elif key == "f":
return self.__feature_mem_cache
elif key == "p":
return self.__pit_mem_cache
else:
raise KeyError("Unknown memcache unit")

def clear(self):
    """Empty every cache unit: calendar, instrument, feature and PIT."""
    units = (
        self.__calendar_mem_cache,
        self.__instrument_mem_cache,
        self.__feature_mem_cache,
        self.__pit_mem_cache,
    )
    for unit in units:
        unit.clear()


class MemCacheExpire:
Expand Down
157 changes: 94 additions & 63 deletions qlib/data/data.py
Expand Up @@ -33,8 +33,7 @@
normalize_cache_fields,
code_to_fname,
time_to_slc_point,
read_period_data,
get_period_list,
get_period_list_by_offset,
)
from ..utils.paral import ParallelExt
from .ops import Operators # pylint: disable=W0611 # noqa: F401
Expand All @@ -48,7 +47,10 @@ class ProviderBackendMixin:

def get_default_backend(self):
backend = {}
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
if hasattr(self, "provider_name"):
provider_name = getattr(self, "provider_name")
else:
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
# set default storage class
backend.setdefault("class", f"File{provider_name}Storage")
# set default storage module
Expand Down Expand Up @@ -336,6 +338,10 @@ def feature(self, instrument, field, start_time, end_time, freq):


class PITProvider(abc.ABC):
@property
def provider_name(self):
    # Fixed provider name consumed by ProviderBackendMixin.get_default_backend:
    # it builds the default storage class name as f"File{provider_name}Storage",
    # i.e. "FilePITStorage" here, instead of parsing it from the class name.
    return "PIT"

@abc.abstractmethod
def period_feature(
self,
Expand Down Expand Up @@ -742,29 +748,39 @@ def feature(self, instrument, field, start_index, end_index, freq):
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]


class LocalPITProvider(PITProvider):
class LocalPITProvider(PITProvider, ProviderBackendMixin):
# TODO: Add PIT backend file storage
# NOTE: This class is not multi-threading-safe!!!!

def period_feature(self, instrument, field, start_index, end_index, cur_time, period=None):
def __init__(self, remote=False, backend=None):
    """Initialize the local PIT provider.

    Parameters
    ----------
    remote : bool
        whether the provider is used in remote mode (kept as-is; only stored).
    backend : dict, optional
        storage backend configuration forwarded to ProviderBackendMixin;
        defaults to an empty dict.
    """
    super().__init__()
    self.remote = remote
    # Fix: the original signature used a mutable default (`backend={}`), which is
    # shared across every instance created without an explicit backend; give each
    # instance its own dict instead. Passing `{}` explicitly still behaves the same.
    self.backend = {} if backend is None else backend

def period_feature(self, instrument, field, start_offset, end_offset, cur_time, period=None, start_time=None):
    """Get raw data from PIT.

    There are 3 modes to query data from PIT; all of them need the current datetime.

    1. given `period`: return the values observed for that period,
       as a series indexed by datetime.
    2. given `start_time`: return the value **observed by each day** from
       `start_time` to the current datetime, as a series indexed by datetime.
    3. given `start_offset` and `end_offset`: return period data between
       [-start_offset, end_offset] observed at the current datetime,
       as a series indexed by period.

    Raises
    ------
    ValueError
        if `cur_time` is not a pd.Timestamp, or the field suffix is invalid.
    FileNotFoundError
        if the underlying PIT data file does not exist.
    """
    if not isinstance(cur_time, pd.Timestamp):
        raise ValueError(
            f"Expected pd.Timestamp for `cur_time`, got '{cur_time}'. Advices: you can't query PIT data directly(e.g. '$$roewa_q'), you must use `P` operator to convert data to each day (e.g. 'P($$roewa_q)')"
        )

    assert end_offset <= 0  # PIT don't support querying future data

    field = str(field).lower()[2:]
    instrument = code_to_fname(instrument)

    backend_obj = self.backend_obj(instrument=instrument, field=field)

    # NOTE: the previous index-file-based acceleration is deprecated in favor of
    # caching the parsed DataFrame in the in-memory PIT cache (H["p"]).
    key = (instrument, field)
    quarterly = field.endswith("_q")
    if key in H["p"]:
        df = H["p"][key]
    else:
        if not field.endswith("_q") and not field.endswith("_a"):
            raise ValueError("period field must ends with '_q' or '_a'")
        data_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.data"
        if not data_path.exists():
            raise FileNotFoundError("No file is found.")
        data = backend_obj.np_data()
        df = pd.DataFrame(data)
        df.sort_values(by=["date", "period"], inplace=True)
        df["date"] = pd.to_datetime(df["date"].astype(str))
        # BUG FIX: the parsed frame was stored under H["f"] while the lookup
        # above probes H["p"], so the cache could never hit (and polluted the
        # feature cache). Store it under the same unit that is read.
        H["p"][key] = df

    if period is not None:
        # NOTE: `period` has higher priority than `start_offset` & `end_offset`
        retur = df[df["period"] == period].set_index("date")["value"]
    elif start_time is not None:
        # df is sorted by date; keep only rows whose period is a running maximum,
        # i.e. the term whose period is monotonically non-decreasing.
        # A row is a running maximum exactly when its period equals the cumulative
        # max — vectorized replacement of the original per-row Python loop.
        s_sign = df["period"] == df["period"].cummax()
        df_sim = df[s_sign].drop_duplicates(subset=["date"], keep="last")
        s_part = df_sim.set_index("date")[start_time:]["value"]
        if s_part.empty:
            return pd.Series(dtype="float64")
        if start_time != s_part.index[0] and start_time >= df["date"].iloc[0]:
            # add previous value to result to avoid nan in the first period
            pre_value = pd.Series(df[df["date"] < start_time]["value"].iloc[-1], index=[start_time])
            s_part = pd.concat([pre_value, s_part])
        return s_part
    else:
        df_remain = df[(df["date"] <= cur_time)]
        if df_remain.empty:
            return pd.Series(dtype="float64")
        last_observe_date = df_remain["date"].iloc[-1]
        # keep only the latest period value
        df_remain = df_remain.sort_values(by=["period"]).drop_duplicates(subset=["period"], keep="last")
        df_remain = df_remain.set_index("period")

        cache_key = (
            instrument,
            field,
            last_observe_date,
            start_offset,
            end_offset,
            quarterly,
        )
        if cache_key in H["p"]:
            retur = H["p"][cache_key]
        else:
            last_period = df_remain.index[-1]
            period_list = get_period_list_by_offset(last_period, start_offset, end_offset, quarterly)
            retur = df_remain["value"].reindex(period_list, fill_value=np.nan)
            H["p"][cache_key] = retur
    return retur


class LocalExpressionProvider(ExpressionProvider):
Expand Down
46 changes: 30 additions & 16 deletions qlib/data/pit.py
Expand Up @@ -24,31 +24,45 @@ class P(ElemOperator):
def _load_internal(self, instrument, start_index, end_index, freq):
    """Collapse PIT data into a daily series over [start_index, end_index] of the calendar.

    Fast path: when the expression needs no history (window sizes are 0), one
    batched PIT query covering the whole range is issued and forward-filled onto
    the calendar. Otherwise the value is recomputed day by day.
    """
    _calendar = Cal.calendar(freq=freq)
    resample_data = np.empty(end_index - start_index + 1, dtype="float32")

    # To load the expression accurately, more historical data may be required.
    start_ws, end_ws = self.feature.get_extended_window_size()
    # FIX: hoisted out of the per-day loop — the check is loop-invariant, and
    # inside the loop it never fired when the index range was empty.
    if end_ws > 0:
        raise ValueError(
            "PIT database does not support referring to future period (e.g. expressions like `Ref('$$roewa_q', -1)` are not supported"
        )

    # if start_ws == 0, the expression uses only current data, so PIT history is not required
    if start_ws == 0 and end_ws == 0:
        try:
            # one batched query over [start_time, cur_time] instead of one per day
            s = self._load_feature(instrument, 0, 0, _calendar[end_index], None, _calendar[start_index])
            if len(s) == 0:
                return pd.Series(dtype="float32", name=str(self))
            # index of s may not be in the calendar, so reindex to continuous dates first
            # (`.ffill()` replaces the deprecated `fillna(method="ffill")`)
            s = s.reindex(pd.date_range(start=s.index[0], end=_calendar[end_index])).ffill()
            resample_data = s.reindex(_calendar[start_index : end_index + 1]).ffill().values
        except FileNotFoundError:
            get_module_logger("base").warning(f"WARN: period data not found for {str(self)}")
            return pd.Series(dtype="float32", name=str(self))
    else:
        for cur_index in range(start_index, end_index + 1):
            cur_time = _calendar[cur_index]
            # The calculated value will always be the last element, so the end offset is zero.
            try:
                s = self._load_feature(instrument, -start_ws, 0, cur_time)
                resample_data[cur_index - start_index] = s.iloc[-1] if len(s) > 0 else np.nan
            except FileNotFoundError:
                get_module_logger("base").warning(f"WARN: period data not found for {str(self)}")
                return pd.Series(dtype="float32", name=str(self))

    resample_series = pd.Series(
        resample_data, index=pd.RangeIndex(start_index, end_index + 1), dtype="float32", name=str(self)
    )
    return resample_series

def _load_feature(self, instrument, start_index, end_index, cur_time):
return self.feature.load(instrument, start_index, end_index, cur_time)
def _load_feature(self, instrument, start_index, end_index, cur_time, period=None, start_time=None):
    # Delegate to the wrapped PIT feature, forwarding `period` and `start_time`
    # so the provider can select its query mode (fixed period / observed-by-day
    # from start_time / offset range, per PITProvider.period_feature).
    return self.feature.load(instrument, start_index, end_index, cur_time, period, start_time)

def get_longest_back_rolling(self):
# The period data will collapse as a normal feature. So no extending and looking back
Expand All @@ -67,5 +81,5 @@ def __init__(self, feature, period):
def __str__(self):
    """Render the referenced PIT feature followed by its period, e.g. ``P($$roewa_q)[201902]``."""
    return "{}[{}]".format(super().__str__(), self.period)

def _load_feature(self, instrument, start_index, end_index, cur_time):
def _load_feature(self, instrument, start_index, end_index, cur_time, period=None, start_time=None):
PaleNeutron marked this conversation as resolved.
Show resolved Hide resolved
return self.feature.load(instrument, start_index, end_index, cur_time, self.period)