Chore: Make release 1.0.54
martinroberson authored and Pang, Stephen S C. [GBM Public] committed Jan 18, 2024
1 parent 094828f commit 2177d12
Showing 7 changed files with 98 additions and 19 deletions.
2 changes: 1 addition & 1 deletion gs_quant/api/data.py
@@ -49,7 +49,7 @@ def time_field(cls, dataset_id: str) -> str:

@classmethod
def construct_dataframe_with_types(cls, dataset_id: str, data: Union[Base, list, tuple, pd.Series],
- schema_varies=False) -> pd.DataFrame:
+ schema_varies=False, standard_fields=False) -> pd.DataFrame:
raise NotImplementedError('Must implement construct_dataframe_with_types')

@staticmethod
3 changes: 2 additions & 1 deletion gs_quant/api/fred/data.py
@@ -130,7 +130,8 @@ def __handle_response(response: str) -> pd.Series:
data = data.sort_index()
return data

- def construct_dataframe_with_types(self, dataset_id: str, data: pd.Series, schema_varies=False) -> pd.DataFrame:
+ def construct_dataframe_with_types(self, dataset_id: str, data: pd.Series, schema_varies=False,
+                                    standard_fields=False) -> pd.DataFrame:
"""
Constructs a dataframe with correct date types.
27 changes: 25 additions & 2 deletions gs_quant/api/gs/data.py
@@ -816,21 +816,44 @@ def get_types(cls, dataset_id: str):
return field_types
raise RuntimeError(f"Unable to get Dataset schema for {dataset_id}")

+ @classmethod
+ def get_field_types(cls, field_names: Union[str, List[str]]):
+     try:
+         fields = cls.get_dataset_fields(names=field_names, limit=len(field_names))
+     except Exception:
+         return {}
+     if fields:
+         field_types = {}
+         field: DataSetFieldEntity
+         for field in fields:
+             field_name = field.name
+             field_type = field.type_
+             field_format = field.parameters.get('format')
+             field_types[field_name] = field_format or field_type
+         return field_types
+     return {}
+
@classmethod
def construct_dataframe_with_types(cls, dataset_id: str, data: Union[Base, List, Tuple],
- schema_varies=False) -> pd.DataFrame:
+ schema_varies=False, standard_fields=False) -> pd.DataFrame:
"""
Constructs a dataframe with correct date types.
:param dataset_id: id of the dataset
:param data: data to convert with correct types
:param schema_varies: if set, method will not assume that all rows have the same columns
+ :param standard_fields: if set, will use fields api instead of catalog api to get fieldTypes
:return: dataframe with correct types
"""
if len(data):
- dataset_types = cls.get_types(dataset_id)
# Use first row to infer fields from data
sample = data if schema_varies else [data[0]]
incoming_data_data_types = pd.DataFrame(sample).dtypes.to_dict()
+ dataset_types = cls.get_types(dataset_id) if not standard_fields \
+     else cls.get_field_types(field_names=list(incoming_data_data_types.keys()))
+
+ # fallback approach in case fields api doesn't return results
+ if not dataset_types and standard_fields:
+     dataset_types = cls.get_types(dataset_id)

df = pd.DataFrame(data, columns={**dataset_types, **incoming_data_data_types})

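For orientation, a minimal usage sketch of the new standard-fields path (assuming an authenticated GsSession; the field names are illustrative):

from gs_quant.api.gs.data import GsDataApi

# Resolve types via the fields API rather than the dataset catalog.
# A field's 'format' parameter (e.g. 'date', 'date-time') takes
# precedence over its raw type (e.g. 'string').
field_types = GsDataApi.get_field_types(field_names=['price', 'adjDate'])
# e.g. {'price': 'number', 'adjDate': 'date'}; an empty dict means the
# fields API returned nothing, and callers fall back to get_types(dataset_id)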
22 changes: 15 additions & 7 deletions gs_quant/api/gs/risk.py
@@ -163,9 +163,9 @@ async def handle_websocket():
complete, pending = await asyncio.wait(listeners, return_when=asyncio.FIRST_COMPLETED)

# Check results before sending more requests. Results can be lost otherwise

if result_listener in complete:
# New results have been received
+ request_id = None
try:
request_id, status_result_str = result_listener.result().split(';', 1)
status, result_str = status_result_str[0], status_result_str[1:]
@@ -178,15 +178,23 @@
result = RuntimeError(result_str)
else:
# Unpack the result

try:
result = msgpack.unpackb(base64.b64decode(result_str), raw=False) \
if cls.USE_MSGPACK else json.loads(result_str)
except Exception as ee:
result = ee

- # Enqueue the request and result for the listener to handle
- results.put_nowait((pending_requests.pop(request_id), result))
+ if request_id is None:
+     # Certain fatal websocket errors (e.g. ConnectionClosed) that are caught above will mean
+     # we have no request_id - In this case we abort and set the error on all results
+     result_listener.cancel()
+     for req in pending_requests.values():
+         results.put_nowait((req, result))
+     # Give up
+     pending_requests.clear()
+     all_requests_dispatched = True
+ else:
+     # Enqueue the request and result for the listener to handle
+     results.put_nowait((pending_requests.pop(request_id), result))
else:
result_listener.cancel()

@@ -197,8 +205,8 @@ async def handle_websocket():
all_requests_dispatched, items = request_listener.result()
if items:
if not all([isinstance(i[1], dict) for i in items]):
- error = next(i[1] for i in items if not isinstance(i[1], dict))
- raise RuntimeError(error[0][0][0]['errorString'])
+ error_item = next(i[1] for i in items if not isinstance(i[1], dict))
+ raise RuntimeError(error_item[0][0][0]['errorString'])

# ... extract the request IDs ...
request_ids = [i[1]['reportId'] for i in items]
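The request_id-is-None branch above is a simple error fan-out: a fatal websocket failure that arrives without a request id is surfaced on every pending request rather than leaving them to hang. A self-contained sketch of the same pattern (names are illustrative, not the gs_quant internals):

import asyncio

def fail_all_pending(pending_requests: dict, results: asyncio.Queue, error: Exception):
    # Complete every in-flight request with the fatal error...
    for request in pending_requests.values():
        results.put_nowait((request, error))
    # ...then clear the pending map so nothing else is dispatched
    pending_requests.clear()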
16 changes: 12 additions & 4 deletions gs_quant/data/dataset.py
@@ -116,6 +116,7 @@ def get_data(
fields: Optional[Iterable[Union[str, Fields]]] = None,
asset_id_type: Optional[str] = None,
empty_intervals: Optional[bool] = None,
+ standard_fields: Optional[bool] = False,
**kwargs
) -> pd.DataFrame:
"""
@@ -127,6 +128,7 @@ def get_data(
:param since: Request data since
:param fields: DataSet fields to include
:param empty_intervals: whether to request empty intervals
+ :param standard_fields: If set, will use fields api instead of catalog api to get fieldTypes
:param kwargs: Extra query arguments, e.g. ticker='EDZ19'
:return: A Dataframe of the requested data
@@ -162,10 +164,12 @@
)
data = self.provider.query_data(query, self.id, asset_id_type=asset_id_type)
if type(data) is tuple:
- df = self.provider.construct_dataframe_with_types(self.id, data[0], schema_varies)
+ df = self.provider.construct_dataframe_with_types(self.id, data[0], schema_varies,
+                                                   standard_fields=standard_fields)
return df.groupby(data[1], group_keys=True).apply(lambda x: x)
else:
- return self.provider.construct_dataframe_with_types(self.id, data, schema_varies)
+ return self.provider.construct_dataframe_with_types(self.id, data, schema_varies,
+                                                     standard_fields=standard_fields)

def get_data_series(
self,
@@ -175,6 +179,7 @@ def get_data_series(
as_of: Optional[dt.datetime] = None,
since: Optional[dt.datetime] = None,
dates: Optional[List[dt.date]] = None,
+ standard_fields: Optional[bool] = False,
**kwargs
) -> pd.Series:
"""
@@ -185,6 +190,7 @@
:param end: Requested end date/datetime for data
:param as_of: Request data as_of
:param since: Request data since
+ :param standard_fields: If set, will use fields api instead of catalog api to get fieldTypes
:param kwargs: Extra query arguments, e.g. ticker='EDZ19'
:return: A Series of the requested data, indexed by date or time, depending on the DataSet
@@ -216,7 +222,7 @@

symbol_dimension = symbol_dimensions[0]
data = self.provider.query_data(query, self.id)
- df = self.provider.construct_dataframe_with_types(self.id, data)
+ df = self.provider.construct_dataframe_with_types(self.id, data, standard_fields=standard_fields)

from gs_quant.api.gs.data import GsDataApi

@@ -237,6 +243,7 @@ def get_data_last(
as_of: Optional[Union[dt.date, dt.datetime]],
start: Optional[Union[dt.date, dt.datetime]] = None,
fields: Optional[Iterable[str]] = None,
+ standard_fields: Optional[bool] = False,
**kwargs
) -> pd.DataFrame:
"""
@@ -245,6 +252,7 @@
:param as_of: The date or time as of which to query
:param start: The start of the range to query
:param fields: The fields for which to query
+ :param standard_fields: If set, will use fields api instead of catalog api to get fieldTypes
:param kwargs: Additional query parameters, e.g., city='Boston'
:return: A Dataframe of values
@@ -266,7 +274,7 @@
query.format = None # "last" endpoint does not support MessagePack

data = self.provider.last_data(query, self.id)
- return self.provider.construct_dataframe_with_types(self.id, data)
+ return self.provider.construct_dataframe_with_types(self.id, data, standard_fields=standard_fields)

def get_coverage(
self,
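A minimal usage sketch of the new keyword (the dataset id and date range are illustrative; an authenticated GsSession is assumed):

import datetime as dt
from gs_quant.data import Dataset

dataset = Dataset('EDRVOL_PERCENT_STANDARD')
# standard_fields=True resolves column types via the fields API;
# the default (False) keeps the existing catalog behaviour
df = dataset.get_data(dt.date(2024, 1, 2), dt.date(2024, 1, 5), standard_fields=True)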
4 changes: 2 additions & 2 deletions gs_quant/session.py
@@ -473,8 +473,8 @@ def _connect_websocket(self, path: str, headers: Optional[dict] = None):
extra_headers = self._headers() + list((headers or {}).items())
return websockets.connect(url,
extra_headers=extra_headers,
- max_size=2 ** 64,
- read_limit=2 ** 64,
+ max_size=2 ** 32,
+ read_limit=2 ** 32,
ssl=self.__ssl_context() if url.startswith('wss') else None)

def _headers(self):
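This tightens the websocket limits from an effectively unbounded 2 ** 64 bytes to 4 GiB. A sketch of the equivalent settings on a raw websockets client (the URL is a placeholder, and read_limit is an option of the legacy websockets client used here):

import websockets

async def connect_example():
    # max_size caps a single inbound message; read_limit is the
    # high-water mark of the read buffer (both in bytes)
    async with websockets.connect('wss://example.com/stream',
                                  max_size=2 ** 32,
                                  read_limit=2 ** 32) as ws:
        return await ws.recv()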
43 changes: 41 additions & 2 deletions gs_quant/test/api/test_data.py
@@ -20,6 +20,7 @@
from pandas.testing import assert_frame_equal, assert_series_equal

from gs_quant.api.gs.data import GsDataApi
+ from gs_quant.base import DictBase
from gs_quant.context_base import ContextMeta
from gs_quant.errors import MqValueError
from gs_quant.markets import MarketDataCoordinate
@@ -359,8 +360,8 @@ def test_auto_scroll_on_pages(mocker):
assert len(response) == 5


- def test_get_dataset_fields(mocker):
-     mock_response = {
+ def mock_fields_response():
+     return {
"totalResults": 2,
"results": [
{
@@ -431,6 +432,9 @@ def test_get_dataset_fields(mocker):
]
}


+ def test_get_dataset_fields(mocker):
+     mock_response = mock_fields_response()
mocker.patch.object(GsSession.__class__, 'default_value',
return_value=GsSession.get(Environment.QA, 'client_id', 'secret'))
mocker.patch.object(GsSession.current, '_post', return_value=mock_response)
@@ -445,5 +449,40 @@
cls=DataSetFieldEntity)


+ def test_get_field_types(mocker):
+     mock_response = {
+         "totalResults": 4,
+         "results": [
+             DataSetFieldEntity(name="price", type_="number", parameters=DictBase({})),
+             DataSetFieldEntity(name="strikeReference", type_="string", parameters=DictBase({})),
+             DataSetFieldEntity(name="adjDate", type_="string", field_java_type='DateField',
+                                parameters=DictBase({'format': 'date'})),
+             DataSetFieldEntity(name="time", type_="string", field_java_type='DateTimeField',
+                                parameters=DictBase({'format': 'date-time'}))
+         ]
+     }
+
+     mock_field_types = {
+         "price": "number",
+         "strikeReference": "string",
+         "adjDate": "date",
+         "time": "date-time"
+     }
+
+     mocker.patch.object(GsSession.__class__, 'default_value',
+                         return_value=GsSession.get(Environment.QA, 'client_id', 'secret'))
+     mocker.patch.object(GsSession.current, '_post', return_value=mock_response)
+
+     response = GsDataApi.get_field_types(field_names=['price', 'strikeReference', 'adjDate', 'time'])
+     assert len(response) == 4
+     assert response == mock_field_types
+
+     GsSession.current._post.assert_called_once_with('/data/fields/query',
+                                                     payload={'where': {'name': ['price', 'strikeReference',
+                                                                                 'adjDate', 'time']},
+                                                              'limit': 4},
+                                                     cls=DataSetFieldEntity)


if __name__ == "__main__":
pytest.main(args=["test_data.py"])
