Skip to content

Commit

Permalink
feat: allow gRPC metadata to be passed to operations client (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
busunkim96 committed Jan 14, 2021
1 parent c5fee89 commit 73854e8
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 31 deletions.
16 changes: 10 additions & 6 deletions google/api_core/operation.py
Expand Up @@ -287,7 +287,7 @@ def _cancel_grpc(operations_stub, operation_name):
operations_stub.CancelOperation(request_pb)


def from_grpc(operation, operations_stub, result_type, **kwargs):
def from_grpc(operation, operations_stub, result_type, grpc_metadata=None, **kwargs):
"""Create an operation future using a gRPC client.
This interacts with the long-running operations `service`_ (specific
Expand All @@ -302,18 +302,20 @@ def from_grpc(operation, operations_stub, result_type, **kwargs):
operations_stub (google.longrunning.operations_pb2.OperationsStub):
The operations stub.
result_type (:func:`type`): The protobuf result type.
grpc_metadata (Optional[List[Tuple[str, str]]]): Additional metadata to pass
to the rpc.
kwargs: Keyword args passed into the :class:`Operation` constructor.
Returns:
~.api_core.operation.Operation: The operation future to track the given
operation.
"""
refresh = functools.partial(_refresh_grpc, operations_stub, operation.name)
cancel = functools.partial(_cancel_grpc, operations_stub, operation.name)
refresh = functools.partial(_refresh_grpc, operations_stub, operation.name, metadata=grpc_metadata)
cancel = functools.partial(_cancel_grpc, operations_stub, operation.name, metadata=grpc_metadata)
return Operation(operation, refresh, cancel, result_type, **kwargs)


def from_gapic(operation, operations_client, result_type, **kwargs):
def from_gapic(operation, operations_client, result_type, grpc_metadata=None, **kwargs):
"""Create an operation future from a gapic client.
This interacts with the long-running operations `service`_ (specific
Expand All @@ -328,12 +330,14 @@ def from_gapic(operation, operations_client, result_type, **kwargs):
operations_client (google.api_core.operations_v1.OperationsClient):
The operations client.
result_type (:func:`type`): The protobuf result type.
grpc_metadata (Optional[List[Tuple[str, str]]]): Additional metadata to pass
to the rpc.
kwargs: Keyword args passed into the :class:`Operation` constructor.
Returns:
~.api_core.operation.Operation: The operation future to track the given
operation.
"""
refresh = functools.partial(operations_client.get_operation, operation.name)
cancel = functools.partial(operations_client.cancel_operation, operation.name)
refresh = functools.partial(operations_client.get_operation, operation.name, metadata=grpc_metadata)
cancel = functools.partial(operations_client.cancel_operation, operation.name, metadata=grpc_metadata)
return Operation(operation, refresh, cancel, result_type, **kwargs)
8 changes: 5 additions & 3 deletions google/api_core/operation_async.py
Expand Up @@ -189,7 +189,7 @@ async def cancelled(self):
)


def from_gapic(operation, operations_client, result_type, **kwargs):
def from_gapic(operation, operations_client, result_type, grpc_metadata=None, **kwargs):
"""Create an operation future from a gapic client.
This interacts with the long-running operations `service`_ (specific
Expand All @@ -204,12 +204,14 @@ def from_gapic(operation, operations_client, result_type, **kwargs):
operations_client (google.api_core.operations_v1.OperationsClient):
The operations client.
result_type (:func:`type`): The protobuf result type.
grpc_metadata (Optional[List[Tuple[str, str]]]): Additional metadata to pass
to the rpc.
kwargs: Keyword args passed into the :class:`Operation` constructor.
Returns:
~.api_core.operation.Operation: The operation future to track the given
operation.
"""
refresh = functools.partial(operations_client.get_operation, operation.name)
cancel = functools.partial(operations_client.cancel_operation, operation.name)
refresh = functools.partial(operations_client.get_operation, operation.name, metadata=grpc_metadata)
cancel = functools.partial(operations_client.cancel_operation, operation.name, metadata=grpc_metadata)
return AsyncOperation(operation, refresh, cancel, result_type, **kwargs)
35 changes: 28 additions & 7 deletions google/api_core/operations_v1/operations_async_client.py
Expand Up @@ -77,7 +77,11 @@ def __init__(self, channel, client_config=operations_client_config.config):
)

async def get_operation(
self, name, retry=gapic_v1.method_async.DEFAULT, timeout=gapic_v1.method_async.DEFAULT
self,
name,
retry=gapic_v1.method_async.DEFAULT,
timeout=gapic_v1.method_async.DEFAULT,
metadata=None,
):
"""Gets the latest state of a long-running operation.
Expand All @@ -103,6 +107,8 @@ async def get_operation(
unspecified, the the default timeout in the client
configuration is used. If ``None``, then the RPC method will
not time out.
metadata (Optional[List[Tuple[str, str]]]):
Additional gRPC metadata.
Returns:
google.longrunning.operations_pb2.Operation: The state of the
Expand All @@ -114,14 +120,15 @@ async def get_operation(
subclass will be raised.
"""
request = operations_pb2.GetOperationRequest(name=name)
return await self._get_operation(request, retry=retry, timeout=timeout)
return await self._get_operation(request, retry=retry, timeout=timeout, metadata=metadata)

async def list_operations(
self,
name,
filter_,
retry=gapic_v1.method_async.DEFAULT,
timeout=gapic_v1.method_async.DEFAULT,
metadata=None,
):
"""
Lists operations that match the specified filter in the request.
Expand Down Expand Up @@ -157,6 +164,8 @@ async def list_operations(
unspecified, the the default timeout in the client
configuration is used. If ``None``, then the RPC method will
not time out.
metadata (Optional[List[Tuple[str, str]]]): Additional gRPC
metadata.
Returns:
google.api_core.page_iterator.Iterator: An iterator that yields
Expand All @@ -174,7 +183,7 @@ async def list_operations(
request = operations_pb2.ListOperationsRequest(name=name, filter=filter_)

# Create the method used to fetch pages
method = functools.partial(self._list_operations, retry=retry, timeout=timeout)
method = functools.partial(self._list_operations, retry=retry, timeout=timeout, metadata=metadata)

iterator = page_iterator_async.AsyncGRPCIterator(
client=None,
Expand All @@ -188,7 +197,11 @@ async def list_operations(
return iterator

async def cancel_operation(
self, name, retry=gapic_v1.method_async.DEFAULT, timeout=gapic_v1.method_async.DEFAULT
self,
name,
retry=gapic_v1.method_async.DEFAULT,
timeout=gapic_v1.method_async.DEFAULT,
metadata=None,
):
"""Starts asynchronous cancellation on a long-running operation.
Expand Down Expand Up @@ -228,13 +241,19 @@ async def cancel_operation(
google.api_core.exceptions.GoogleAPICallError: If an error occurred
while invoking the RPC, the appropriate ``GoogleAPICallError``
subclass will be raised.
metadata (Optional[List[Tuple[str, str]]]): Additional gRPC
metadata.
"""
# Create the request object.
request = operations_pb2.CancelOperationRequest(name=name)
await self._cancel_operation(request, retry=retry, timeout=timeout)
await self._cancel_operation(request, retry=retry, timeout=timeout, metadata=metadata)

async def delete_operation(
self, name, retry=gapic_v1.method_async.DEFAULT, timeout=gapic_v1.method_async.DEFAULT
self,
name,
retry=gapic_v1.method_async.DEFAULT,
timeout=gapic_v1.method_async.DEFAULT,
metadata=None,
):
"""Deletes a long-running operation.
Expand All @@ -260,6 +279,8 @@ async def delete_operation(
unspecified, the the default timeout in the client
configuration is used. If ``None``, then the RPC method will
not time out.
metadata (Optional[List[Tuple[str, str]]]): Additional gRPC
metadata.
Raises:
google.api_core.exceptions.MethodNotImplemented: If the server
Expand All @@ -271,4 +292,4 @@ async def delete_operation(
"""
# Create the request object.
request = operations_pb2.DeleteOperationRequest(name=name)
await self._delete_operation(request, retry=retry, timeout=timeout)
await self._delete_operation(request, retry=retry, timeout=timeout, metadata=metadata)
35 changes: 28 additions & 7 deletions google/api_core/operations_v1/operations_client.py
Expand Up @@ -91,7 +91,11 @@ def __init__(self, channel, client_config=operations_client_config.config):

# Service calls
def get_operation(
self, name, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT
self,
name,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
metadata=None,
):
"""Gets the latest state of a long-running operation.
Expand All @@ -117,6 +121,8 @@ def get_operation(
unspecified, the the default timeout in the client
configuration is used. If ``None``, then the RPC method will
not time out.
metadata (Optional[List[Tuple[str, str]]]):
Additional gRPC metadata.
Returns:
google.longrunning.operations_pb2.Operation: The state of the
Expand All @@ -128,14 +134,15 @@ def get_operation(
subclass will be raised.
"""
request = operations_pb2.GetOperationRequest(name=name)
return self._get_operation(request, retry=retry, timeout=timeout)
return self._get_operation(request, retry=retry, timeout=timeout, metadata=metadata)

def list_operations(
self,
name,
filter_,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
metadata=None,
):
"""
Lists operations that match the specified filter in the request.
Expand Down Expand Up @@ -171,6 +178,8 @@ def list_operations(
unspecified, the the default timeout in the client
configuration is used. If ``None``, then the RPC method will
not time out.
metadata (Optional[List[Tuple[str, str]]]): Additional gRPC
metadata.
Returns:
google.api_core.page_iterator.Iterator: An iterator that yields
Expand All @@ -188,7 +197,7 @@ def list_operations(
request = operations_pb2.ListOperationsRequest(name=name, filter=filter_)

# Create the method used to fetch pages
method = functools.partial(self._list_operations, retry=retry, timeout=timeout)
method = functools.partial(self._list_operations, retry=retry, timeout=timeout, metadata=metadata)

iterator = page_iterator.GRPCIterator(
client=None,
Expand All @@ -202,7 +211,11 @@ def list_operations(
return iterator

def cancel_operation(
self, name, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT
self,
name,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
metadata=None,
):
"""Starts asynchronous cancellation on a long-running operation.
Expand Down Expand Up @@ -234,6 +247,8 @@ def cancel_operation(
unspecified, the the default timeout in the client
configuration is used. If ``None``, then the RPC method will
not time out.
metadata (Optional[List[Tuple[str, str]]]): Additional gRPC
metadata.
Raises:
google.api_core.exceptions.MethodNotImplemented: If the server
Expand All @@ -245,10 +260,14 @@ def cancel_operation(
"""
# Create the request object.
request = operations_pb2.CancelOperationRequest(name=name)
self._cancel_operation(request, retry=retry, timeout=timeout)
self._cancel_operation(request, retry=retry, timeout=timeout, metadata=metadata)

def delete_operation(
self, name, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT
self,
name,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
metadata=None,
):
"""Deletes a long-running operation.
Expand All @@ -274,6 +293,8 @@ def delete_operation(
unspecified, the the default timeout in the client
configuration is used. If ``None``, then the RPC method will
not time out.
metadata (Optional[List[Tuple[str, str]]]): Additional gRPC
metadata.
Raises:
google.api_core.exceptions.MethodNotImplemented: If the server
Expand All @@ -285,4 +306,4 @@ def delete_operation(
"""
# Create the request object.
request = operations_pb2.DeleteOperationRequest(name=name)
self._delete_operation(request, retry=retry, timeout=timeout)
self._delete_operation(request, retry=retry, timeout=timeout, metadata=metadata)
12 changes: 8 additions & 4 deletions tests/asyncio/operations_v1/test_operations_async_client.py
Expand Up @@ -36,9 +36,10 @@ async def test_get_operation():
operations_pb2.Operation(name="meep"))
client = operations_v1.OperationsAsyncClient(mocked_channel)

response = await client.get_operation("name")
response = await client.get_operation("name", metadata=[("x-goog-request-params", "foo")])
assert method.call_count == 1
assert tuple(method.call_args_list[0])[0][0].name == "name"
assert ("x-goog-request-params", "foo") in tuple(method.call_args_list[0])[1]["metadata"]
assert response == fake_call.response


Expand All @@ -53,7 +54,7 @@ async def test_list_operations():
mocked_channel, method, fake_call = _mock_grpc_objects(list_response)
client = operations_v1.OperationsAsyncClient(mocked_channel)

pager = await client.list_operations("name", "filter")
pager = await client.list_operations("name", "filter", metadata=[("x-goog-request-params", "foo")])

assert isinstance(pager, page_iterator_async.AsyncIterator)
responses = []
Expand All @@ -63,6 +64,7 @@ async def test_list_operations():
assert responses == operations

assert method.call_count == 1
assert ("x-goog-request-params", "foo") in tuple(method.call_args_list[0])[1]["metadata"]
request = tuple(method.call_args_list[0])[0][0]
assert isinstance(request, operations_pb2.ListOperationsRequest)
assert request.name == "name"
Expand All @@ -75,10 +77,11 @@ async def test_delete_operation():
empty_pb2.Empty())
client = operations_v1.OperationsAsyncClient(mocked_channel)

await client.delete_operation("name")
await client.delete_operation("name", metadata=[("x-goog-request-params", "foo")])

assert method.call_count == 1
assert tuple(method.call_args_list[0])[0][0].name == "name"
assert ("x-goog-request-params", "foo") in tuple(method.call_args_list[0])[1]["metadata"]


@pytest.mark.asyncio
Expand All @@ -87,7 +90,8 @@ async def test_cancel_operation():
empty_pb2.Empty())
client = operations_v1.OperationsAsyncClient(mocked_channel)

await client.cancel_operation("name")
await client.cancel_operation("name", metadata=[("x-goog-request-params", "foo")])

assert method.call_count == 1
assert tuple(method.call_args_list[0])[0][0].name == "name"
assert ("x-goog-request-params", "foo") in tuple(method.call_args_list[0])[1]["metadata"]
3 changes: 3 additions & 0 deletions tests/asyncio/test_operation_async.py
Expand Up @@ -177,12 +177,15 @@ def test_from_gapic():
operations_client,
struct_pb2.Struct,
metadata_type=struct_pb2.Struct,
grpc_metadata=[('x-goog-request-params', 'foo')]
)

assert future._result_type == struct_pb2.Struct
assert future._metadata_type == struct_pb2.Struct
assert future.operation.name == TEST_OPERATION_NAME
assert future.done
assert future._refresh.keywords["metadata"] == [('x-goog-request-params', 'foo')]
assert future._cancel.keywords["metadata"] == [('x-goog-request-params', 'foo')]


def test_deserialize():
Expand Down

0 comments on commit 73854e8

Please sign in to comment.