Skip to content

Commit

Permalink
fix: add retries for 'requests.ConnectionError' (#186)
Browse files Browse the repository at this point in the history
* Retry ConnectionError

* Retry ConnectionError with asyncio too

* address feedback

* lint

* fix import error in py2

* Restructure guarded import for test coverage

* respond to feedback; add ChunkedEncodingError
  • Loading branch information
andrewsg committed Dec 3, 2020
1 parent c0490dd commit 0d76eac
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 46 deletions.
41 changes: 20 additions & 21 deletions google/_async_resumable_media/_helpers.py
Expand Up @@ -19,21 +19,11 @@
import time


from six.moves import http_client


from google.resumable_media import common


RANGE_HEADER = u"range"
CONTENT_RANGE_HEADER = u"content-range"
RETRYABLE = (
common.TOO_MANY_REQUESTS,
http_client.INTERNAL_SERVER_ERROR,
http_client.BAD_GATEWAY,
http_client.SERVICE_UNAVAILABLE,
http_client.GATEWAY_TIMEOUT,
)

_SLOW_CRC32C_WARNING = (
"Currently using crcmod in pure python form. This is a slow "
Expand Down Expand Up @@ -162,24 +152,33 @@ async def wait_and_retry(func, get_status_code, retry_strategy):
object: The return value of ``func``.
"""

response = await func()

if get_status_code(response) not in RETRYABLE:
return response

total_sleep = 0.0
num_retries = 0
base_wait = 0.5 # When doubled will give 1.0
while retry_strategy.retry_allowed(total_sleep, num_retries):

while True: # return on success or when retries exhausted.
error = None
try:
response = await func()
except ConnectionError as e:
error = e
else:
if get_status_code(response) not in common.RETRYABLE:
return response

if not retry_strategy.retry_allowed(total_sleep, num_retries):
# Retries are exhausted and no acceptable response was received. Raise the
# retriable_error or return the unacceptable response.
if error:
raise error

return response

base_wait, wait_time = calculate_retry_wait(base_wait, retry_strategy.max_sleep)

num_retries += 1
total_sleep += wait_time
time.sleep(wait_time)
response = await func()
if get_status_code(response) not in RETRYABLE:
return response

return response


class _DoNothingHash(object):
Expand Down
67 changes: 48 additions & 19 deletions google/resumable_media/_helpers.py
Expand Up @@ -14,27 +14,20 @@

"""Shared utilities used by both downloads and uploads."""

from __future__ import absolute_import

import base64
import hashlib
import logging
import random
import time
import warnings

from six.moves import http_client

from google.resumable_media import common


RANGE_HEADER = u"range"
CONTENT_RANGE_HEADER = u"content-range"
RETRYABLE = (
common.TOO_MANY_REQUESTS,
http_client.INTERNAL_SERVER_ERROR,
http_client.BAD_GATEWAY,
http_client.SERVICE_UNAVAILABLE,
http_client.GATEWAY_TIMEOUT,
)

_SLOW_CRC32C_WARNING = (
"Currently using crcmod in pure python form. This is a slow "
Expand Down Expand Up @@ -162,23 +155,43 @@ def wait_and_retry(func, get_status_code, retry_strategy):
Returns:
object: The return value of ``func``.
"""
response = func()
if get_status_code(response) not in RETRYABLE:
return response

total_sleep = 0.0
num_retries = 0
base_wait = 0.5 # When doubled will give 1.0
while retry_strategy.retry_allowed(total_sleep, num_retries):

# Set the retriable_exception_type if possible. We expect requests to be
# present here and the transport to be using requests.exceptions errors,
# but due to loose coupling with the transport layer we can't guarantee it.
try:
connection_error_exceptions = _get_connection_error_classes()
except ImportError:
# We don't know the correct classes to use to catch connection errors,
# so an empty tuple here communicates "catch no exceptions".
connection_error_exceptions = ()

while True: # return on success or when retries exhausted.
error = None
try:
response = func()
except connection_error_exceptions as e:
error = e
else:
if get_status_code(response) not in common.RETRYABLE:
return response

if not retry_strategy.retry_allowed(total_sleep, num_retries):
# Retries are exhausted and no acceptable response was received. Raise the
# retriable_error or return the unacceptable response.
if error:
raise error

return response

base_wait, wait_time = calculate_retry_wait(base_wait, retry_strategy.max_sleep)

num_retries += 1
total_sleep += wait_time
time.sleep(wait_time)
response = func()
if get_status_code(response) not in RETRYABLE:
return response

return response


def _get_crc32c_object():
Expand Down Expand Up @@ -349,6 +362,22 @@ def _get_checksum_object(checksum_type):
raise ValueError("checksum must be ``'md5'``, ``'crc32c'`` or ``None``")


def _get_connection_error_classes():
"""Get the exception error classes.
Requests is a soft dependency here so that multiple transport layers can be
added in the future. This code is in a separate function here so that the
test framework can override its behavior to simulate requests being
unavailable."""

import requests.exceptions

return (
requests.exceptions.ConnectionError,
requests.exceptions.ChunkedEncodingError,
)


class _DoNothingHash(object):
"""Do-nothing hash object.
Expand Down
14 changes: 14 additions & 0 deletions google/resumable_media/common.py
Expand Up @@ -17,6 +17,7 @@
Includes custom exception types, useful constants and shared helpers.
"""

from six.moves import http_client

_SLEEP_RETRY_ERROR_MSG = (
u"At most one of `max_cumulative_retry` and `max_retries` " u"can be specified."
Expand Down Expand Up @@ -60,6 +61,19 @@
exceeds this limit, no more retries will occur.
"""

RETRYABLE = (
TOO_MANY_REQUESTS, # 429
http_client.INTERNAL_SERVER_ERROR, # 500
http_client.BAD_GATEWAY, # 502
http_client.SERVICE_UNAVAILABLE, # 503
http_client.GATEWAY_TIMEOUT, # 504
)
"""iterable: HTTP status codes that indicate a retryable error.
Connection errors are also retried, but are not listed as they are
exceptions, not status codes.
"""


class InvalidResponse(Exception):
"""Error class for responses which are not in the correct state.
Expand Down
89 changes: 86 additions & 3 deletions tests/unit/test__helpers.py
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import hashlib
import mock
import pytest
import requests.exceptions
from six.moves import http_client

from google.resumable_media import _helpers
Expand Down Expand Up @@ -151,7 +154,7 @@ def test_under_limit(self, randint_mock):
class Test_wait_and_retry(object):
def test_success_no_retry(self):
truthy = http_client.OK
assert truthy not in _helpers.RETRYABLE
assert truthy not in common.RETRYABLE
response = _make_response(truthy)

func = mock.Mock(return_value=response, spec=[])
Expand Down Expand Up @@ -179,7 +182,7 @@ def test_success_with_retry(self, randint_mock, sleep_mock):
ret_val = _helpers.wait_and_retry(func, _get_status_code, retry_strategy)

assert ret_val == responses[-1]
assert status_codes[-1] not in _helpers.RETRYABLE
assert status_codes[-1] not in common.RETRYABLE

assert func.call_count == 4
assert func.mock_calls == [mock.call()] * 4
Expand All @@ -192,6 +195,59 @@ def test_success_with_retry(self, randint_mock, sleep_mock):
sleep_mock.assert_any_call(2.625)
sleep_mock.assert_any_call(4.375)

@mock.patch(u"time.sleep")
@mock.patch(u"random.randint")
def test_success_with_retry_connection_error(self, randint_mock, sleep_mock):
randint_mock.side_effect = [125, 625, 375]

response = _make_response(http_client.NOT_FOUND)
responses = [
requests.exceptions.ConnectionError,
requests.exceptions.ConnectionError,
requests.exceptions.ConnectionError,
response,
]
func = mock.Mock(side_effect=responses, spec=[])

retry_strategy = common.RetryStrategy()
ret_val = _helpers.wait_and_retry(func, _get_status_code, retry_strategy)

assert ret_val == responses[-1]

assert func.call_count == 4
assert func.mock_calls == [mock.call()] * 4

assert randint_mock.call_count == 3
assert randint_mock.mock_calls == [mock.call(0, 1000)] * 3

assert sleep_mock.call_count == 3
sleep_mock.assert_any_call(1.125)
sleep_mock.assert_any_call(2.625)
sleep_mock.assert_any_call(4.375)

@mock.patch(u"time.sleep")
@mock.patch(u"random.randint")
def test_connection_import_error_failure(self, randint_mock, sleep_mock):
randint_mock.side_effect = [125, 625, 375]

response = _make_response(http_client.NOT_FOUND)
responses = [
requests.exceptions.ConnectionError,
requests.exceptions.ConnectionError,
requests.exceptions.ConnectionError,
response,
]

with mock.patch(
"google.resumable_media._helpers._get_connection_error_classes",
side_effect=ImportError,
):
with pytest.raises(requests.exceptions.ConnectionError):
func = mock.Mock(side_effect=responses, spec=[])

retry_strategy = common.RetryStrategy()
_helpers.wait_and_retry(func, _get_status_code, retry_strategy)

@mock.patch(u"time.sleep")
@mock.patch(u"random.randint")
def test_retry_exceeds_max_cumulative(self, randint_mock, sleep_mock):
Expand All @@ -214,7 +270,34 @@ def test_retry_exceeds_max_cumulative(self, randint_mock, sleep_mock):
ret_val = _helpers.wait_and_retry(func, _get_status_code, retry_strategy)

assert ret_val == responses[-1]
assert status_codes[-1] in _helpers.RETRYABLE
assert status_codes[-1] in common.RETRYABLE

assert func.call_count == 8
assert func.mock_calls == [mock.call()] * 8

assert randint_mock.call_count == 7
assert randint_mock.mock_calls == [mock.call(0, 1000)] * 7

assert sleep_mock.call_count == 7
sleep_mock.assert_any_call(1.875)
sleep_mock.assert_any_call(2.0)
sleep_mock.assert_any_call(4.375)
sleep_mock.assert_any_call(8.5)
sleep_mock.assert_any_call(16.5)
sleep_mock.assert_any_call(32.25)
sleep_mock.assert_any_call(64.125)

@mock.patch(u"time.sleep")
@mock.patch(u"random.randint")
def test_retry_exceeded_reraises_connection_error(self, randint_mock, sleep_mock):
randint_mock.side_effect = [875, 0, 375, 500, 500, 250, 125]

responses = [requests.exceptions.ConnectionError] * 8
func = mock.Mock(side_effect=responses, spec=[])

retry_strategy = common.RetryStrategy(max_cumulative_retry=100.0)
with pytest.raises(requests.exceptions.ConnectionError):
_helpers.wait_and_retry(func, _get_status_code, retry_strategy)

assert func.call_count == 8
assert func.mock_calls == [mock.call()] * 8
Expand Down

0 comments on commit 0d76eac

Please sign in to comment.