class ResponseIterator:
    """Iterator over REST API responses.

    The response body is expected to be one JSON array of objects
    (``[{...}, {...}, ...]``).  The body is scanned incrementally, and every
    top-level object is passed to ``response_message_cls.from_json`` as soon
    as its closing brace has been received.  Whitespace outside of JSON
    strings is dropped while scanning.

    Args:
        response (requests.Response): An API response object.
        response_message_cls (Callable[proto.Message]): A proto
            class expected to be returned from an API.
    """

    def __init__(self, response: "requests.Response", response_message_cls):
        import codecs

        self._response = response
        self._response_message_cls = response_message_cls
        # Inner iterator over HTTP response's content.
        self._response_itr = self._response.iter_content(decode_unicode=True)
        # ``iter_content`` hands back raw bytes when requests cannot determine
        # the response encoding.  An incremental decoder copes with multi-byte
        # characters that are split across two chunks.
        self._decoder = codecs.getincrementaldecoder("utf-8")()
        # Contains a list of JSON responses ready to be sent to user.
        self._ready_objs: Deque[str] = deque()
        # Current JSON response being built.
        self._obj = ""
        # Keeps track of the nesting level within a JSON object.
        # 0: outside the array, 1: inside the array but between objects,
        # 2 and above: inside an object.
        self._level = 0
        # Keeps track whether HTTP response is currently sending values
        # inside of a string value.
        self._in_string = False
        # Whether an escape symbol "\" was encountered.
        self._escape_next = False

    def cancel(self):
        """Cancel existing streaming operation by closing the response."""
        self._response.close()

    def _process_chunk(self, chunk: str):
        """Scan ``chunk`` and queue every top-level object it completes.

        Args:
            chunk (str): The next piece of the decoded response body.

        Raises:
            ValueError: If, outside of the array, the first non-whitespace
                character of ``chunk`` is not ``[``.
        """
        if self._level == 0:
            # Empty and whitespace-only chunks (for example a newline that
            # follows the closing bracket) carry no information; only a
            # visible character other than "[" is an error here.
            head = chunk.lstrip(string.whitespace)
            if head and head[0] != "[":
                raise ValueError(
                    "Can only parse array of JSON objects, instead got %s" % chunk
                )
        for char in chunk:
            if char == "{":
                if not self._in_string:
                    if self._level == 1:
                        # Level 1 corresponds to the outermost JSON object
                        # (i.e. the one we care about).
                        self._obj = ""
                    self._level += 1
                self._obj += char
            elif char == "}":
                self._obj += char
                if not self._in_string:
                    self._level -= 1
                    if self._level == 1:
                        # Back between objects: the current one is complete.
                        self._ready_objs.append(self._obj)
            elif char == '"':
                # An escaped quote is part of the string, not its delimiter.
                if not self._escape_next:
                    self._in_string = not self._in_string
                self._obj += char
            elif char in string.whitespace:
                # Whitespace is only significant inside of a string value.
                if self._in_string:
                    self._obj += char
            elif char == "[":
                if self._level == 0:
                    self._level += 1
                else:
                    self._obj += char
            elif char == "]":
                if self._level == 1:
                    self._level -= 1
                else:
                    self._obj += char
            else:
                self._obj += char
            # Two consecutive backslashes cancel each other out.
            self._escape_next = not self._escape_next if char == "\\" else False

    def __next__(self):
        """Return the next parsed message.

        Raises:
            StopIteration: When the stream ended cleanly.
            ValueError: When the stream ended in the middle of the array or
                does not contain a JSON array.
        """
        while not self._ready_objs:
            try:
                chunk = next(self._response_itr)
            except StopIteration:
                if self._level > 0:
                    raise ValueError("Unfinished stream: %s" % self._obj) from None
                raise
            if isinstance(chunk, bytes):
                chunk = self._decoder.decode(chunk)
            self._process_chunk(chunk)
        return self._grab()

    def _grab(self):
        """Pop the oldest completed JSON object and convert it to a message."""
        return self._response_message_cls.from_json(self._ready_objs.popleft())

    def __iter__(self):
        return self
SEED = int(time.time())
# The seed is logged so that a failing randomized run can be reproduced.
logging.info("Starting rest streaming tests with random seed: %s", SEED)
random.seed(SEED)


class Genre(proto.Enum):
    """Enum used as a field of ``Song``."""

    GENRE_UNSPECIFIED = 0
    CLASSICAL = 1
    JAZZ = 2
    ROCK = 3


class Composer(proto.Message):
    """Message nested inside ``Song``; has a repeated and a map field."""

    given_name = proto.Field(proto.STRING, number=1)
    family_name = proto.Field(proto.STRING, number=2)
    relateds = proto.RepeatedField(proto.STRING, number=3)
    indices = proto.MapField(proto.STRING, proto.STRING, number=4)


class Song(proto.Message):
    """Message with scalar, enum, nested-message and well-known-type fields."""

    composer = proto.Field(Composer, number=1)
    title = proto.Field(proto.STRING, number=2)
    lyrics = proto.Field(proto.STRING, number=3)
    year = proto.Field(proto.INT32, number=4)
    genre = proto.Field(Genre, number=5)
    is_five_mins_longer = proto.Field(proto.BOOL, number=6)
    score = proto.Field(proto.DOUBLE, number=7)
    likes = proto.Field(proto.INT64, number=8)
    duration = proto.Field(duration_pb2.Duration, number=9)
    date_added = proto.Field(timestamp_pb2.Timestamp, number=10)


class EchoResponse(proto.Message):
    """Minimal message with a single string field."""

    content = proto.Field(proto.STRING, number=1)


class ResponseMock(requests.Response):
    """``requests.Response`` stand-in that streams messages as a JSON array.

    Args:
        responses (List[proto.Message]): Messages to serialize into the body.
        response_cls: Message class whose ``to_json`` serializes each message.
        random_split (bool): When true the body is delivered in pieces of
            random length, otherwise one character at a time.
    """

    class _ResponseItr:
        """Iterator yielding the body piece by piece as text."""

        def __init__(self, _response_bytes: bytes, random_split=False):
            # Decode once up front: slicing the raw bytes and decoding every
            # slice would fail whenever a multi-byte character is cut in two.
            self._text = _response_bytes.decode("utf-8")
            self._i = 0
            self._random_split = random_split

        def __iter__(self):
            return self

        def __next__(self):
            remaining = len(self._text) - self._i
            if remaining == 0:
                raise StopIteration
            n = random.randint(1, remaining) if self._random_split else 1
            piece = self._text[self._i : self._i + n]
            self._i += n
            return piece

    def __init__(
        self, responses: List[proto.Message], response_cls, random_split=False,
    ):
        super().__init__()
        self._responses = responses
        self._random_split = random_split
        self._response_message_cls = response_cls

    def _parse_responses(self, responses: List[proto.Message]) -> bytes:
        """Serialize ``responses`` into one UTF-8 encoded JSON array."""
        json_responses = [self._response_message_cls.to_json(r) for r in responses]
        logging.info("Sending JSON stream: %s", json_responses)
        return "[{}]".format(",".join(json_responses)).encode("utf-8")

    def close(self):
        # Tests that exercise closing patch this method.
        raise NotImplementedError()

    def iter_content(self, *args, **kwargs):
        """Return a fresh iterator over the serialized body."""
        return self._ResponseItr(
            self._parse_responses(self._responses), random_split=self._random_split,
        )


@pytest.mark.parametrize("random_split", [False])
def test_next_simple(random_split):
    """Flat messages round-trip through the iterator."""
    responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
    resp = ResponseMock(
        responses=responses, random_split=random_split, response_cls=EchoResponse
    )
    itr = rest_streaming.ResponseIterator(resp, EchoResponse)
    assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_nested(random_split):
    """Messages containing nested objects round-trip through the iterator."""
    responses = [
        Song(title="some song", composer=Composer(given_name="some name")),
        Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
    ]
    resp = ResponseMock(
        responses=responses, random_split=random_split, response_cls=Song
    )
    itr = rest_streaming.ResponseIterator(resp, Song)
    assert list(itr) == responses
@pytest.mark.parametrize("random_split", [True, False])
def test_next_stress(random_split):
    """Many consecutive objects are all returned, in order."""
    n = 50
    responses = [
        Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
        for i in range(n)
    ]
    resp = ResponseMock(
        responses=responses, random_split=random_split, response_cls=Song
    )
    itr = rest_streaming.ResponseIterator(resp, Song)
    assert list(itr) == responses


@pytest.mark.parametrize("random_split", [True, False])
def test_next_escaped_characters_in_string(random_split):
    """Quotes, braces, brackets and backslashes inside strings are data."""
    composer_with_relateds = Composer(relateds=["Artist A", "Artist B"])

    responses = [
        Song(title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")),
        Song(
            title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\")
        ),
        Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
    ]
    resp = ResponseMock(
        responses=responses, random_split=random_split, response_cls=Song
    )
    itr = rest_streaming.ResponseIterator(resp, Song)
    assert list(itr) == responses


def test_next_not_array():
    """A body that is a bare JSON object is rejected."""
    with patch.object(
        ResponseMock, "iter_content", return_value=iter('{"hello": 0}')
    ) as mock_method:
        resp = ResponseMock(responses=[], response_cls=EchoResponse)
        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
        with pytest.raises(ValueError):
            next(itr)
        # Kept outside of the ``raises`` block so that it actually runs.
        mock_method.assert_called_once()


def test_cancel():
    """``cancel`` closes the underlying response."""
    with patch.object(ResponseMock, "close", return_value=None) as mock_method:
        resp = ResponseMock(responses=[], response_cls=EchoResponse)
        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
        itr.cancel()
        mock_method.assert_called_once()


def test_check_buffer():
    """A stream that ends inside an object raises after the complete ones."""
    with patch.object(
        ResponseMock,
        "_parse_responses",
        return_value=bytes('[{"content": "hello"}, {', "utf-8"),
    ):
        resp = ResponseMock(responses=[], response_cls=EchoResponse)
        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
        # The first object is complete and must be delivered normally; only
        # the truncated second one is an error.
        assert next(itr) == EchoResponse(content="hello")
        with pytest.raises(ValueError):
            next(itr)


def test_next_html():
    """An HTML body (for example an error page) is rejected."""
    # An empty body would merely end the iteration, so the mock has to
    # deliver actual markup for the ValueError path to be reached.
    with patch.object(
        ResponseMock,
        "iter_content",
        return_value=iter("<!DOCTYPE html><html></html>"),
    ) as mock_method:
        resp = ResponseMock(responses=[], response_cls=EchoResponse)
        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
        with pytest.raises(ValueError):
            next(itr)
        # Kept outside of the ``raises`` block so that it actually runs.
        mock_method.assert_called_once()