diff --git a/google/api_core/operation.py b/google/api_core/operation.py index 55adbdd8..9af9c4e6 100644 --- a/google/api_core/operation.py +++ b/google/api_core/operation.py @@ -192,7 +192,7 @@ def cancelled(self): ) -def _refresh_http(api_request, operation_name): +def _refresh_http(api_request, operation_name, retry=None): """Refresh an operation using a JSON/HTTP client. Args: @@ -200,11 +200,16 @@ def _refresh_http(api_request, operation_name): should generally be :meth:`google.cloud._http.Connection.api_request`. operation_name (str): The name of the operation. + retry (google.api_core.retry.Retry): (Optional) retry policy Returns: google.longrunning.operations_pb2.Operation: The operation. """ path = "operations/{}".format(operation_name) + + if retry is not None: + api_request = retry(api_request) + api_response = api_request(method="GET", path=path) return json_format.ParseDict(api_response, operations_pb2.Operation()) @@ -249,19 +254,25 @@ def from_http_json(operation, api_request, result_type, **kwargs): return Operation(operation_proto, refresh, cancel, result_type, **kwargs) -def _refresh_grpc(operations_stub, operation_name): +def _refresh_grpc(operations_stub, operation_name, retry=None): """Refresh an operation using a gRPC client. Args: operations_stub (google.longrunning.operations_pb2.OperationsStub): The gRPC operations stub. operation_name (str): The name of the operation. + retry (google.api_core.retry.Retry): (Optional) retry policy Returns: google.longrunning.operations_pb2.Operation: The operation. """ request_pb = operations_pb2.GetOperationRequest(name=operation_name) - return operations_stub.GetOperation(request_pb) + + rpc = operations_stub.GetOperation + if retry is not None: + rpc = retry(rpc) + + return rpc(request_pb) def _cancel_grpc(operations_stub, operation_name): diff --git a/tests/unit/test_operation.py b/tests/unit/test_operation.py index 829a3f3b..2229c2d4 100644 --- a/tests/unit/test_operation.py +++ b/tests/unit/test_operation.py @@ -177,17 +177,39 @@ def test_unexpected_result(): def test__refresh_http(): - api_request = mock.Mock(return_value={"name": TEST_OPERATION_NAME, "done": True}) + json_response = {"name": TEST_OPERATION_NAME, "done": True} + api_request = mock.Mock(return_value=json_response) result = operation._refresh_http(api_request, TEST_OPERATION_NAME) + assert isinstance(result, operations_pb2.Operation) assert result.name == TEST_OPERATION_NAME assert result.done is True + api_request.assert_called_once_with( method="GET", path="operations/{}".format(TEST_OPERATION_NAME) ) +def test__refresh_http_w_retry(): + json_response = {"name": TEST_OPERATION_NAME, "done": True} + api_request = mock.Mock() + retry = mock.Mock() + retry.return_value.return_value = json_response + + result = operation._refresh_http(api_request, TEST_OPERATION_NAME, retry=retry) + + assert isinstance(result, operations_pb2.Operation) + assert result.name == TEST_OPERATION_NAME + assert result.done is True + + api_request.assert_not_called() + retry.assert_called_once_with(api_request) + retry.return_value.assert_called_once_with( + method="GET", path="operations/{}".format(TEST_OPERATION_NAME) + ) + + def test__cancel_http(): api_request = mock.Mock() @@ -224,6 +246,21 @@ def test__refresh_grpc(): operations_stub.GetOperation.assert_called_once_with(expected_request) +def test__refresh_grpc_w_retry(): + operations_stub = mock.Mock(spec=["GetOperation"]) + expected_result = make_operation_proto(done=True) + retry = mock.Mock() + retry.return_value.return_value = expected_result + + result = operation._refresh_grpc(operations_stub, TEST_OPERATION_NAME, retry=retry) + + assert result == expected_result + expected_request = operations_pb2.GetOperationRequest(name=TEST_OPERATION_NAME) + operations_stub.GetOperation.assert_not_called() + retry.assert_called_once_with(operations_stub.GetOperation) + retry.return_value.assert_called_once_with(expected_request) + + def test__cancel_grpc(): operations_stub = mock.Mock(spec=["CancelOperation"])