Commit 62a5f8a

Chore: Make release 1.0.46

martinroberson authored and DominicCYK committed Nov 7, 2023
1 parent 5565677 commit 62a5f8a
Showing 8 changed files with 479 additions and 65 deletions.
84 changes: 84 additions & 0 deletions gs_quant/api/api_cache.py
@@ -0,0 +1,84 @@
"""
Copyright 2023 Goldman Sachs.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Tuple

import cachetools
import pandas as pd

from gs_quant.base import Base
from gs_quant.session import GsSession


class CacheEvent(Enum):
PUT = 'Put'
GET = 'Get'


class ApiRequestCache(ABC):

def get(self, session: GsSession, key: Any, **kwargs):
cache_lookup = self._get(session, key, **kwargs)
if cache_lookup is not None:
self.record(session, key, CacheEvent.GET, **kwargs)
return cache_lookup

@abstractmethod
def _get(self, session: GsSession, key: Any, **kwargs):
pass

def record(self, session: GsSession, key: Any, method: CacheEvent, **kwargs):
pass

def put(self, session: GsSession, key: Any, value, **kwargs):
self._put(session, key, value, **kwargs)
self.record(session, key, CacheEvent.PUT, **kwargs)

@abstractmethod
def _put(self, session: GsSession, key: Any, value, **kwargs):
pass


class InMemoryApiRequestCache(ApiRequestCache):

def __init__(self, max_size=1000, ttl_in_seconds=3600):
self._cache = cachetools.TTLCache(max_size, ttl_in_seconds)
self._records = []

def get_events(self) -> Tuple[Tuple[CacheEvent, Any], ...]:
return tuple(self._records)

def clear_events(self):
self._records.clear()

def _make_str_key(self, key: Any):
if isinstance(key, (list, tuple)):
return "_".join(self._make_str_key(k) for k in key)
elif isinstance(key, (Base, pd.DataFrame)):
return key.to_json()
elif isinstance(key, dict):
return self._make_str_key(list(key.items()))
return str(key)

def _get(self, session: GsSession, key: Any, **kwargs):
return self._cache.get(self._make_str_key(key))

def record(self, session: GsSession, key: Any, method: CacheEvent, **kwargs):
self._records.append((method, key))

def _put(self, session: GsSession, key, value, **kwargs):
self._cache[self._make_str_key(key)] = value
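
The new module defines a read-through cache contract: ApiRequestCache.get and put wrap the abstract _get and _put and record a CacheEvent for each hit or store, while InMemoryApiRequestCache backs the contract with a cachetools TTLCache and normalizes arbitrary keys (lists, tuples, dicts, Base objects, DataFrames) to strings. A minimal usage sketch, not part of the commit — the sizes below are illustrative, and it assumes an authenticated GsSession plus the GsDataApi.set_api_request_cache hook added later in this commit:

    from gs_quant.api.api_cache import InMemoryApiRequestCache
    from gs_quant.api.gs.data import GsDataApi

    # Keep up to 500 responses for 10 minutes (illustrative numbers).
    cache = InMemoryApiRequestCache(max_size=500, ttl_in_seconds=600)
    GsDataApi.set_api_request_cache(cache)

    # ... run data queries as usual; identical POSTs within the TTL are
    # now answered from memory instead of re-hitting the API ...

    # Every hit and store was recorded as a (CacheEvent, key) tuple.
    for event, key in cache.get_events():
        print(event.value, key)
    cache.clear_events()

Setting a cache is optional; when none is set, _post_with_cache_check in gs/data.py below falls straight through to session._post.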
3 changes: 2 additions & 1 deletion gs_quant/api/data.py
@@ -21,6 +21,7 @@
import inflection
import pandas as pd

+from gs_quant.api.api_session import ApiWithCustomSession
from gs_quant.api.fred.fred_query import FredQuery
from gs_quant.base import Base
from gs_quant.target.coordinates import MDAPIDataQuery
@@ -29,7 +30,7 @@
_logger = logging.getLogger(__name__)


-class DataApi(metaclass=ABCMeta):
+class DataApi(ApiWithCustomSession, metaclass=ABCMeta):
@classmethod
def query_data(cls, query: Union[DataQuery, FredQuery], dataset_id: str = None) -> Union[list, tuple]:
raise NotImplementedError('Must implement get_data')
87 changes: 54 additions & 33 deletions gs_quant/api/gs/data.py
@@ -30,12 +30,12 @@
from gs_quant.data.log import log_debug, log_warning
from gs_quant.errors import MqValueError
from gs_quant.markets import MarketDataCoordinate
-from gs_quant.session import GsSession
from gs_quant.target.common import MarketDataVendor, PricingLocation, Format
from gs_quant.target.coordinates import MDAPIDataBatchResponse, MDAPIDataQuery, MDAPIDataQueryResponse, MDAPIQueryField
from gs_quant.target.data import DataQuery, DataQueryResponse, DataSetCatalogEntry
from gs_quant.target.data import DataSetEntity, DataSetFieldEntity
from .assets import GsIdType
+from ..api_cache import ApiRequestCache
from ...target.assets import EntityQuery, FieldFilterMap

_logger = logging.getLogger(__name__)
@@ -164,10 +164,28 @@ class QueryType(Enum):
class GsDataApi(DataApi):
__definitions = {}
__asset_coordinates_cache = TTLCache(10000, 86400)
+    _api_request_cache: ApiRequestCache = None
DEFAULT_SCROLL = '30s'

# DataApi interface

+    @classmethod
+    def set_api_request_cache(cls, cache: ApiRequestCache):
+        cls._api_request_cache = cache
+
+    @classmethod
+    def _post_with_cache_check(cls, url, **kwargs):
+        session = cls.get_session()
+        if cls._api_request_cache:
+            cache_key = (url, 'POST', kwargs)
+            cached_val = cls._api_request_cache.get(session, cache_key)
+            if cached_val is not None:
+                return cached_val
+        result = session._post(url, **kwargs)
+        if cls._api_request_cache:
+            cls._api_request_cache.put(session, cache_key, result)
+        return result
+
@classmethod
def query_data(cls, query: Union[DataQuery, MDAPIDataQuery], dataset_id: str = None,
asset_id_type: Union[GsIdType, str] = None) \
@@ -182,12 +200,12 @@ def query_data(cls, query: Union[DataQuery, MDAPIDataQuery], dataset_id: str = None,
response: Union[DataQueryResponse, dict] = cls.execute_query(dataset_id, query)
return cls.get_results(dataset_id, response, query)

-    @staticmethod
-    def execute_query(dataset_id: str, query: Union[DataQuery, MDAPIDataQuery]):
+    @classmethod
+    def execute_query(cls, dataset_id: str, query: Union[DataQuery, MDAPIDataQuery]):
kwargs = {'payload': query}
if getattr(query, 'format', None) in (Format.MessagePack, 'MessagePack'):
kwargs['request_headers'] = {'Accept': 'application/msgpack'}
-        return GsSession.current._post('/data/{}/query'.format(dataset_id), **kwargs)
+        return cls._post_with_cache_check('/data/{}/query'.format(dataset_id), **kwargs)

@staticmethod
def get_results(dataset_id: str, response: Union[DataQueryResponse, dict], query: DataQuery) -> \
@@ -226,10 +244,10 @@ def last_data(cls, query: Union[DataQuery, MDAPIDataQuery], dataset_id: str = None,
if timeout is not None:
kwargs['timeout'] = timeout
if getattr(query, 'marketDataCoordinates', None):
-            result = GsSession.current._post('/data/coordinates/query/last', payload=query, **kwargs)
+            result = cls._post_with_cache_check('/data/coordinates/query/last', payload=query, **kwargs)
return result.get('responses', ())
else:
-            result = GsSession.current._post('/data/{}/last/query'.format(dataset_id), payload=query, **kwargs)
+            result = cls._post_with_cache_check('/data/{}/last/query'.format(dataset_id), payload=query, **kwargs)
return result.get('data', ())

@classmethod
@@ -270,13 +288,14 @@ def get_coverage(
include_history: bool = False,
**kwargs
) -> List[dict]:
+        session = cls.get_session()
params = cls._build_params(scroll, scroll_id, limit, offset, fields, include_history, **kwargs)
-        body = GsSession.current._get(f'/data/{dataset_id}/coverage', payload=params)
+        body = session._get(f'/data/{dataset_id}/coverage', payload=params)
results = scroll_results = body['results']
total_results = body['totalResults']
while len(scroll_results) and len(results) < total_results:
params['scrollId'] = body['scrollId']
-            body = GsSession.current._get(f'/data/{dataset_id}/coverage', payload=params)
+            body = session._get(f'/data/{dataset_id}/coverage', payload=params)
scroll_results = body['results']
results += scroll_results

@@ -294,36 +313,37 @@ async def get_coverage_async(
include_history: bool = False,
**kwargs
) -> List[dict]:
+        session = cls.get_session()
params = cls._build_params(scroll, scroll_id, limit, offset, fields, include_history, **kwargs)
-        body = await GsSession.current._get_async(f'/data/{dataset_id}/coverage', payload=params)
+        body = await session._get_async(f'/data/{dataset_id}/coverage', payload=params)
results = scroll_results = body['results']
total_results = body['totalResults']
while len(scroll_results) and len(results) < total_results:
params['scrollId'] = body['scrollId']
-            body = await GsSession.current._get_async(f'/data/{dataset_id}/coverage', payload=params)
+            body = await session._get_async(f'/data/{dataset_id}/coverage', payload=params)
scroll_results = body['results']
if scroll_results:
results += scroll_results
return results

@classmethod
def create(cls, definition: Union[DataSetEntity, dict]) -> DataSetEntity:
-        result = GsSession.current._post('/data/datasets', payload=definition)
+        result = cls.get_session()._post('/data/datasets', payload=definition)
return result

@classmethod
def delete_dataset(cls, dataset_id: str) -> dict:
-        result = GsSession.current._delete(f'/data/datasets/{dataset_id}')
+        result = cls.get_session()._delete(f'/data/datasets/{dataset_id}')
return result

@classmethod
def undelete_dataset(cls, dataset_id: str) -> dict:
-        result = GsSession.current._put(f'/data/datasets/{dataset_id}/undelete')
+        result = cls.get_session()._put(f'/data/datasets/{dataset_id}/undelete')
return result

@classmethod
def update_definition(cls, dataset_id: str, definition: Union[DataSetEntity, dict]) -> DataSetEntity:
-        result = GsSession.current._put('/data/datasets/{}'.format(dataset_id), payload=definition, cls=DataSetEntity)
+        result = cls.get_session()._put('/data/datasets/{}'.format(dataset_id), payload=definition, cls=DataSetEntity)
return result

@classmethod
@@ -333,8 +353,9 @@ def upload_data(cls, dataset_id: str, data: Union[pd.DataFrame, list, tuple]) ->
# https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_json.html
data = data.to_json(orient='records')
# Don't use msgpack for MDS
-        headers = None if 'us-east' in GsSession.current.domain else {'Content-Type': 'application/x-msgpack'}
-        result = GsSession.current._post('/data/{}'.format(dataset_id), payload=data, request_headers=headers)
+        session = cls.get_session()
+        headers = None if 'us-east' in session.domain else {'Content-Type': 'application/x-msgpack'}
+        result = session._post('/data/{}'.format(dataset_id), payload=data, request_headers=headers)
return result

@classmethod
@@ -343,13 +364,13 @@ def delete_data(cls, dataset_id: str, delete_query: Dict) -> Dict:
Delete data from dataset. You must have admin access to the dataset to delete data.
All data deleted is not recoverable.
"""
-        return GsSession.current._delete(f'/data/{dataset_id}', payload=delete_query, use_body=True)
+        return cls.get_session()._delete(f'/data/{dataset_id}', payload=delete_query, use_body=True)

@classmethod
def get_definition(cls, dataset_id: str) -> DataSetEntity:
definition = cls.__definitions.get(dataset_id)
if not definition:
-            definition = GsSession.current._get('/data/datasets/{}'.format(dataset_id), cls=DataSetEntity)
+            definition = cls.get_session()._get('/data/datasets/{}'.format(dataset_id), cls=DataSetEntity)
if not definition:
raise MqValueError('Unknown dataset {}'.format(dataset_id))

@@ -369,13 +390,13 @@ def get_many_definitions(cls,
dict(limit=limit, offset=offset, scroll=scroll, scrollId=scroll_id,
enablePagination='true').items()))

-        body = GsSession.current._get('/data/datasets', payload=params, cls=DataSetEntity)
+        body = cls.get_session()._get('/data/datasets', payload=params, cls=DataSetEntity)
results = scroll_results = body['results']
total_results = body['totalResults']

while len(scroll_results) and len(results) < total_results:
params['scrollId'] = body['scrollId']
-            body = GsSession.current._get('/data/datasets', payload=params, cls=DataSetEntity)
+            body = cls.get_session()._get('/data/datasets', payload=params, cls=DataSetEntity)
scroll_results = body['results']
results = results + scroll_results

Expand All @@ -391,20 +412,21 @@ def get_catalog(cls,
) -> Tuple[DataSetCatalogEntry]:

query = f'dataSetId={"&dataSetId=".join(dataset_ids)}' if dataset_ids else ''
+        gs_session = cls.get_session()
if len(query):
-            return GsSession.current._get(f'/data/catalog?{query}', cls=DataSetCatalogEntry)['results']
+            return gs_session._get(f'/data/catalog?{query}', cls=DataSetCatalogEntry)['results']
else:
params = dict(filter(lambda item: item[1] is not None,
dict(limit=limit, offset=offset, scroll=scroll, scrollId=scroll_id,
enablePagination='true').items()))

-            body = GsSession.current._get('/data/catalog', payload=params, cls=DataSetEntity)
+            body = gs_session._get('/data/catalog', payload=params, cls=DataSetEntity)
results = scroll_results = body['results']
total_results = body['totalResults']

while len(scroll_results) and len(results) < total_results:
params['scrollId'] = body['scrollId']
-                body = GsSession.current._get('/data/catalog', payload=params, cls=DataSetEntity)
+                body = gs_session._get('/data/catalog', payload=params, cls=DataSetEntity)
scroll_results = body['results']
results = results + scroll_results

@@ -434,7 +456,7 @@ def get_many_coordinates(
where=where,
limit=limit
)
-        results = GsSession.current._post('/data/mdapi/query', query)['results']
+        results = cls._post_with_cache_check('/data/mdapi/query', payload=query)['results']

if return_type is str:
return tuple(coordinate['name'] for coordinate in results)
@@ -494,7 +516,7 @@ def get_data_providers(cls,
Return a dictionary containing a set of dataset providers for each available data field.
For each field will return a dict of daily and real-time dataset providers where available.
"""
-        response = availability if availability else GsSession.current._get(f'/data/measures/{entity_id}/availability')
+        response = availability if availability else cls.get_session()._get(f'/data/measures/{entity_id}/availability')
if 'errorMessages' in response:
raise MqValueError(f"Data availability request {response['requestId']} "
f"failed: {response.get('errorMessages', '')}")
@@ -522,10 +544,9 @@ def get_data_providers(cls,

@classmethod
def get_market_data(cls, query, request_id=None, ignore_errors: bool = False) -> pd.DataFrame:
-        GsSession.current: GsSession
start = time.perf_counter()
try:
-            body = GsSession.current._post('/data/measures', payload=query)
+            body = cls._post_with_cache_check('/data/measures', payload=query)
except Exception as e:
log_warning(request_id, _logger, f'Market data query {query} failed due to {e}')
raise e
@@ -781,10 +802,10 @@ def coordinates_data_series(
else:
return ret

-    @staticmethod
+    @classmethod
    @cachetools.cached(TTLCache(ttl=3600, maxsize=128))
-    def get_types(dataset_id: str):
-        results = GsSession.current._get(f'/data/catalog/{dataset_id}')
+    def get_types(cls, dataset_id: str):
+        results = cls.get_session()._get(f'/data/catalog/{dataset_id}')
fields = results.get("fields")
if fields:
field_types = {}
@@ -853,7 +874,7 @@ def get_dataset_fields(
"""

where = dict(filter(lambda item: item[1] is not None, dict(id=ids, name=names).items()))
-        response = GsSession.current._post('/data/fields/query',
+        response = cls.get_session()._post('/data/fields/query',
payload={'where': where, 'limit': limit},
cls=DataSetFieldEntity)
return response['results']
@@ -881,7 +902,7 @@ def create_dataset_fields(
>>> GsDataApi.create_dataset_fields(fields)
"""
params = {'fields': fields}
-        response = GsSession.current._post('/data/fields/bulk', payload=params, cls=DataSetFieldEntity)
+        response = cls.get_session()._post('/data/fields/bulk', payload=params, cls=DataSetFieldEntity)
return response['results']

@classmethod
Expand Down Expand Up @@ -909,7 +930,7 @@ def update_dataset_fields(
>>> GsDataApi.update_dataset_fields(fields)
"""
params = {'fields': fields}
-        response = GsSession.current._put('/data/fields/bulk', payload=params, cls=DataSetFieldEntity)
+        response = cls.get_session()._put('/data/fields/bulk', payload=params, cls=DataSetFieldEntity)
return response['results']
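
Because ApiRequestCache only requires _get and _put (record is an optional no-op hook), swapping in a different backend is a small exercise. A hypothetical sketch of a custom subclass — the DictApiRequestCache name is an invention for illustration, its str(key) normalization is a simplification of InMemoryApiRequestCache._make_str_key, and nothing here evicts or expires entries:

    from typing import Any

    from gs_quant.api.api_cache import ApiRequestCache, CacheEvent
    from gs_quant.session import GsSession


    class DictApiRequestCache(ApiRequestCache):
        # Toy backend: a plain dict, no TTL, no size bound.

        def __init__(self):
            self._store = {}
            self._hits = 0

        def _get(self, session: GsSession, key: Any, **kwargs):
            return self._store.get(str(key))

        def _put(self, session: GsSession, key: Any, value, **kwargs):
            self._store[str(key)] = value

        def record(self, session: GsSession, key: Any, method: CacheEvent, **kwargs):
            # Count hits instead of keeping the full (event, key) history.
            if method is CacheEvent.GET:
                self._hits += 1

GsDataApi.set_api_request_cache(DictApiRequestCache()) would then route every request made through _post_with_cache_check above into this store.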


Expand Down

0 comments on commit 62a5f8a

Please sign in to comment.