From ea83083c315d4a97c29df35955f9547e2f869114 Mon Sep 17 00:00:00 2001 From: Yoshi Automation Bot Date: Fri, 19 Jun 2020 18:56:11 -0700 Subject: [PATCH] feat: add async client (#26) This PR was generated using Autosynth. :rainbow: Synth log will be available here: https://source.cloud.google.com/results/invocations/6a02e8e8-a66f-42c9-b48a-8df33bd9ffe3/targets - [ ] To automatically regenerate this PR, check this box. PiperOrigin-RevId: 317199748 Source-Link: https://github.com/googleapis/googleapis/commit/ff1b4ff2cfb09b66e26510d2c659705607563c1b --- docs/documentai_v1beta2/services.rst | 6 +- docs/documentai_v1beta2/types.rst | 4 +- google/cloud/documentai/__init__.py | 5 +- google/cloud/documentai_v1beta2/__init__.py | 1 - .../__init__.py | 6 +- .../async_client.py | 254 +++++ .../document_understanding_service/client.py | 96 +- .../transports/__init__.py | 3 + .../transports/base.py | 15 +- .../transports/grpc.py | 21 +- .../transports/grpc_asyncio.py | 247 +++++ noxfile.py | 1 + setup.py | 2 +- synth.metadata | 4 +- .../test_document_understanding_service.py | 471 --------- .../test_document_understanding_service.py | 906 ++++++++++++++++++ 16 files changed, 1512 insertions(+), 530 deletions(-) create mode 100644 google/cloud/documentai_v1beta2/services/document_understanding_service/async_client.py create mode 100644 google/cloud/documentai_v1beta2/services/document_understanding_service/transports/grpc_asyncio.py delete mode 100644 tests/unit/documentai_v1beta2/test_document_understanding_service.py create mode 100644 tests/unit/gapic/documentai_v1beta2/test_document_understanding_service.py diff --git a/docs/documentai_v1beta2/services.rst b/docs/documentai_v1beta2/services.rst index ea9bfbe4..b1f00952 100644 --- a/docs/documentai_v1beta2/services.rst +++ b/docs/documentai_v1beta2/services.rst @@ -1,6 +1,6 @@ -Client for Google Cloud Documentai API -====================================== +Services for Google Cloud Documentai v1beta2 API +================================================ -.. automodule:: google.cloud.documentai_v1beta2 +.. automodule:: google.cloud.documentai_v1beta2.services.document_understanding_service :members: :inherited-members: diff --git a/docs/documentai_v1beta2/types.rst b/docs/documentai_v1beta2/types.rst index d116ddab..2a437e9d 100644 --- a/docs/documentai_v1beta2/types.rst +++ b/docs/documentai_v1beta2/types.rst @@ -1,5 +1,5 @@ -Types for Google Cloud Documentai API -===================================== +Types for Google Cloud Documentai v1beta2 API +============================================= .. automodule:: google.cloud.documentai_v1beta2.types :members: diff --git a/google/cloud/documentai/__init__.py b/google/cloud/documentai/__init__.py index 137d58e6..a59e3b40 100644 --- a/google/cloud/documentai/__init__.py +++ b/google/cloud/documentai/__init__.py @@ -15,7 +15,9 @@ # limitations under the License. # - +from google.cloud.documentai_v1beta2.services.document_understanding_service.async_client import ( + DocumentUnderstandingServiceAsyncClient, +) from google.cloud.documentai_v1beta2.services.document_understanding_service.client import ( DocumentUnderstandingServiceClient, ) @@ -64,6 +66,7 @@ "BatchProcessDocumentsResponse", "BoundingPoly", "Document", + "DocumentUnderstandingServiceAsyncClient", "DocumentUnderstandingServiceClient", "EntityExtractionParams", "FormExtractionParams", diff --git a/google/cloud/documentai_v1beta2/__init__.py b/google/cloud/documentai_v1beta2/__init__.py index 63c78042..7a10da73 100644 --- a/google/cloud/documentai_v1beta2/__init__.py +++ b/google/cloud/documentai_v1beta2/__init__.py @@ -15,7 +15,6 @@ # limitations under the License. # - from .services.document_understanding_service import DocumentUnderstandingServiceClient from .types.document import Document from .types.document_understanding import AutoMlParams diff --git a/google/cloud/documentai_v1beta2/services/document_understanding_service/__init__.py b/google/cloud/documentai_v1beta2/services/document_understanding_service/__init__.py index 52e41717..be18edfd 100644 --- a/google/cloud/documentai_v1beta2/services/document_understanding_service/__init__.py +++ b/google/cloud/documentai_v1beta2/services/document_understanding_service/__init__.py @@ -16,5 +16,9 @@ # from .client import DocumentUnderstandingServiceClient +from .async_client import DocumentUnderstandingServiceAsyncClient -__all__ = ("DocumentUnderstandingServiceClient",) +__all__ = ( + "DocumentUnderstandingServiceClient", + "DocumentUnderstandingServiceAsyncClient", +) diff --git a/google/cloud/documentai_v1beta2/services/document_understanding_service/async_client.py b/google/cloud/documentai_v1beta2/services/document_understanding_service/async_client.py new file mode 100644 index 00000000..a5984e35 --- /dev/null +++ b/google/cloud/documentai_v1beta2/services/document_understanding_service/async_client.py @@ -0,0 +1,254 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation +from google.api_core import operation_async +from google.cloud.documentai_v1beta2.types import document +from google.cloud.documentai_v1beta2.types import document_understanding +from google.rpc import status_pb2 as status # type: ignore + +from .transports.base import DocumentUnderstandingServiceTransport +from .transports.grpc_asyncio import DocumentUnderstandingServiceGrpcAsyncIOTransport +from .client import DocumentUnderstandingServiceClient + + +class DocumentUnderstandingServiceAsyncClient: + """Service to parse structured information from unstructured or + semi-structured documents using state-of-the-art Google AI such + as natural language, computer vision, and translation. + """ + + _client: DocumentUnderstandingServiceClient + + DEFAULT_ENDPOINT = DocumentUnderstandingServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = DocumentUnderstandingServiceClient.DEFAULT_MTLS_ENDPOINT + + from_service_account_file = ( + DocumentUnderstandingServiceClient.from_service_account_file + ) + from_service_account_json = from_service_account_file + + get_transport_class = functools.partial( + type(DocumentUnderstandingServiceClient).get_transport_class, + type(DocumentUnderstandingServiceClient), + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, DocumentUnderstandingServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + ) -> None: + """Instantiate the document understanding service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.DocumentUnderstandingServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint, this is the default value for + the environment variable) and "auto" (auto switch to the default + mTLS endpoint if client SSL credentials is present). However, + the ``api_endpoint`` property takes precedence if provided. + (2) The ``client_cert_source`` property is used to provide client + SSL credentials for mutual TLS transport. If not provided, the + default SSL credentials will be used if present. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = DocumentUnderstandingServiceClient( + credentials=credentials, transport=transport, client_options=client_options + ) + + async def batch_process_documents( + self, + request: document_understanding.BatchProcessDocumentsRequest = None, + *, + requests: Sequence[document_understanding.ProcessDocumentRequest] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""LRO endpoint to batch process many documents. The output is + written to Cloud Storage as JSON in the [Document] format. + + Args: + request (:class:`~.document_understanding.BatchProcessDocumentsRequest`): + The request object. Request to batch process documents + as an asynchronous operation. The output is written to + Cloud Storage as JSON in the [Document] format. + requests (:class:`Sequence[~.document_understanding.ProcessDocumentRequest]`): + Required. Individual requests for + each document. + This corresponds to the ``requests`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:``~.document_understanding.BatchProcessDocumentsResponse``: + Response to an batch document processing request. This + is returned in the LRO Operation after the operation is + complete. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([requests]): + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = document_understanding.BatchProcessDocumentsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if requests is not None: + request.requests = requests + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.batch_process_documents, + default_timeout=None, + client_info=_client_info, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + document_understanding.BatchProcessDocumentsResponse, + metadata_type=document_understanding.OperationMetadata, + ) + + # Done; return the response. + return response + + async def process_document( + self, + request: document_understanding.ProcessDocumentRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> document.Document: + r"""Processes a single document. + + Args: + request (:class:`~.document_understanding.ProcessDocumentRequest`): + The request object. Request to process one document. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.document.Document: + Document represents the canonical + document resource in Document + Understanding AI. It is an interchange + format that provides insights into + documents and allows for collaboration + between users and Document Understanding + AI to iterate and optimize for quality. + + """ + # Create or coerce a protobuf request object. + + request = document_understanding.ProcessDocumentRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.process_document, + default_timeout=None, + client_info=_client_info, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata) + + # Done; return the response. + return response + + +try: + _client_info = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution("google-cloud-documentai").version + ) +except pkg_resources.DistributionNotFound: + _client_info = gapic_v1.client_info.ClientInfo() + + +__all__ = ("DocumentUnderstandingServiceAsyncClient",) diff --git a/google/cloud/documentai_v1beta2/services/document_understanding_service/client.py b/google/cloud/documentai_v1beta2/services/document_understanding_service/client.py index e251fd78..1aefee2f 100644 --- a/google/cloud/documentai_v1beta2/services/document_understanding_service/client.py +++ b/google/cloud/documentai_v1beta2/services/document_understanding_service/client.py @@ -16,6 +16,7 @@ # from collections import OrderedDict +import os import re from typing import Callable, Dict, Sequence, Tuple, Type, Union import pkg_resources @@ -25,15 +26,19 @@ from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore from google.api_core import operation +from google.api_core import operation_async from google.cloud.documentai_v1beta2.types import document from google.cloud.documentai_v1beta2.types import document_understanding from google.rpc import status_pb2 as status # type: ignore from .transports.base import DocumentUnderstandingServiceTransport from .transports.grpc import DocumentUnderstandingServiceGrpcTransport +from .transports.grpc_asyncio import DocumentUnderstandingServiceGrpcAsyncIOTransport class DocumentUnderstandingServiceClientMeta(type): @@ -48,6 +53,9 @@ class DocumentUnderstandingServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[DocumentUnderstandingServiceTransport]] _transport_registry["grpc"] = DocumentUnderstandingServiceGrpcTransport + _transport_registry[ + "grpc_asyncio" + ] = DocumentUnderstandingServiceGrpcAsyncIOTransport def get_transport_class( cls, label: str = None @@ -150,21 +158,49 @@ def __init__( transport (Union[str, ~.DocumentUnderstandingServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. - (2) If ``transport`` argument is None, ``client_options`` can be - used to create a mutual TLS transport. If ``client_cert_source`` - is provided, mutual TLS transport will be created with the given - ``api_endpoint`` or the default mTLS endpoint, and the client - SSL credentials obtained from ``client_cert_source``. + default endpoint provided by the client. GOOGLE_API_USE_MTLS + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint, this is the default value for + the environment variable) and "auto" (auto switch to the default + mTLS endpoint if client SSL credentials is present). However, + the ``api_endpoint`` property takes precedence if provided. + (2) The ``client_cert_source`` property is used to provide client + SSL credentials for mutual TLS transport. If not provided, the + default SSL credentials will be used if present. Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): client_options = ClientOptions.from_dict(client_options) + if client_options is None: + client_options = ClientOptions.ClientOptions() + + if client_options.api_endpoint is None: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "never") + if use_mtls_env == "never": + client_options.api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + has_client_cert_source = ( + client_options.client_cert_source is not None + or mtls.has_default_client_cert_source() + ) + client_options.api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT + if has_client_cert_source + else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport @@ -177,38 +213,12 @@ def __init__( "provide its credentials directly." ) self._transport = transport - elif client_options is None or ( - client_options.api_endpoint is None - and client_options.client_cert_source is None - ): - # Don't trigger mTLS if we get an empty ClientOptions. + else: Transport = type(self).get_transport_class(transport) self._transport = Transport( - credentials=credentials, host=self.DEFAULT_ENDPOINT - ) - else: - # We have a non-empty ClientOptions. If client_cert_source is - # provided, trigger mTLS with user provided endpoint or the default - # mTLS endpoint. - if client_options.client_cert_source: - api_mtls_endpoint = ( - client_options.api_endpoint - if client_options.api_endpoint - else self.DEFAULT_MTLS_ENDPOINT - ) - else: - api_mtls_endpoint = None - - api_endpoint = ( - client_options.api_endpoint - if client_options.api_endpoint - else self.DEFAULT_ENDPOINT - ) - - self._transport = DocumentUnderstandingServiceGrpcTransport( credentials=credentials, - host=api_endpoint, - api_mtls_endpoint=api_mtls_endpoint, + host=client_options.api_endpoint, + api_mtls_endpoint=client_options.api_endpoint, client_cert_source=client_options.client_cert_source, ) @@ -278,6 +288,12 @@ def batch_process_documents( client_info=_client_info, ) + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata) @@ -335,6 +351,12 @@ def process_document( client_info=_client_info, ) + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata) diff --git a/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/__init__.py b/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/__init__.py index 7030f9c6..ce42f2ab 100644 --- a/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/__init__.py +++ b/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/__init__.py @@ -20,6 +20,7 @@ from .base import DocumentUnderstandingServiceTransport from .grpc import DocumentUnderstandingServiceGrpcTransport +from .grpc_asyncio import DocumentUnderstandingServiceGrpcAsyncIOTransport # Compile a registry of transports. @@ -27,9 +28,11 @@ OrderedDict() ) # type: Dict[str, Type[DocumentUnderstandingServiceTransport]] _transport_registry["grpc"] = DocumentUnderstandingServiceGrpcTransport +_transport_registry["grpc_asyncio"] = DocumentUnderstandingServiceGrpcAsyncIOTransport __all__ = ( "DocumentUnderstandingServiceTransport", "DocumentUnderstandingServiceGrpcTransport", + "DocumentUnderstandingServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/base.py b/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/base.py index 6809b5da..17cc0ba0 100644 --- a/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/base.py +++ b/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/base.py @@ -27,7 +27,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore -class DocumentUnderstandingServiceTransport(metaclass=abc.ABCMeta): +class DocumentUnderstandingServiceTransport(abc.ABC): """Abstract transport class for DocumentUnderstandingService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -37,6 +37,7 @@ def __init__( *, host: str = "us-documentai.googleapis.com", credentials: credentials.Credentials = None, + **kwargs, ) -> None: """Instantiate the transport. @@ -64,23 +65,25 @@ def __init__( @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property def batch_process_documents( self ) -> typing.Callable[ - [document_understanding.BatchProcessDocumentsRequest], operations.Operation + [document_understanding.BatchProcessDocumentsRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], ]: - raise NotImplementedError + raise NotImplementedError() @property def process_document( self ) -> typing.Callable[ - [document_understanding.ProcessDocumentRequest], document.Document + [document_understanding.ProcessDocumentRequest], + typing.Union[document.Document, typing.Awaitable[document.Document]], ]: - raise NotImplementedError + raise NotImplementedError() __all__ = ("DocumentUnderstandingServiceTransport",) diff --git a/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/grpc.py b/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/grpc.py index b021b9e5..e5acb31c 100644 --- a/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/grpc.py +++ b/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/grpc.py @@ -15,10 +15,11 @@ # limitations under the License. # -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -47,6 +48,8 @@ class DocumentUnderstandingServiceGrpcTransport(DocumentUnderstandingServiceTran top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] + def __init__( self, *, @@ -78,8 +81,8 @@ def __init__( is None. Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. """ if channel: # Sanity check: Ensure that channel and credentials are not both @@ -95,6 +98,9 @@ def __init__( else api_mtls_endpoint + ":443" ) + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + # Create SSL credentials with client_cert_source or application # default SSL credentials. if client_cert_source: @@ -106,7 +112,7 @@ def __init__( ssl_credentials = SslCredentials().ssl_credentials # create a new channel. The provided one is ignored. - self._grpc_channel = grpc_helpers.create_channel( + self._grpc_channel = type(self).create_channel( host, credentials=credentials, ssl_credentials=ssl_credentials, @@ -122,6 +128,7 @@ def create_channel( cls, host: str = "us-documentai.googleapis.com", credentials: credentials.Credentials = None, + scopes: Optional[Sequence[str]] = None, **kwargs ) -> grpc.Channel: """Create and return a gRPC channel object. @@ -132,13 +139,17 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. """ + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs + host, credentials=credentials, scopes=scopes, **kwargs ) @property diff --git a/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/grpc_asyncio.py b/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/grpc_asyncio.py new file mode 100644 index 00000000..063370da --- /dev/null +++ b/google/cloud/documentai_v1beta2/services/document_understanding_service/transports/grpc_asyncio.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.documentai_v1beta2.types import document +from google.cloud.documentai_v1beta2.types import document_understanding +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import DocumentUnderstandingServiceTransport +from .grpc import DocumentUnderstandingServiceGrpcTransport + + +class DocumentUnderstandingServiceGrpcAsyncIOTransport( + DocumentUnderstandingServiceTransport +): + """gRPC AsyncIO backend transport for DocumentUnderstandingService. + + Service to parse structured information from unstructured or + semi-structured documents using state-of-the-art Google AI such + as natural language, computer vision, and translation. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "us-documentai.googleapis.com", + credentials: credentials.Credentials = None, + scopes: Optional[Sequence[str]] = None, + **kwargs + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, credentials=credentials, scopes=scopes, **kwargs + ) + + def __init__( + self, + *, + host: str = "us-documentai.googleapis.com", + credentials: credentials.Credentials = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If + provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A + callback to provide client SSL certificate bytes and private key + bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` + is None. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + ssl_credentials=ssl_credentials, + scopes=self.AUTH_SCOPES, + ) + + # Run the base constructor. + super().__init__(host=host, credentials=credentials) + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Sanity check: Only create a new channel if we do not already + # have one. + if not hasattr(self, "_grpc_channel"): + self._grpc_channel = self.create_channel( + self._host, credentials=self._credentials + ) + + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__["operations_client"] + + @property + def batch_process_documents( + self + ) -> Callable[ + [document_understanding.BatchProcessDocumentsRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the batch process documents method over gRPC. + + LRO endpoint to batch process many documents. The output is + written to Cloud Storage as JSON in the [Document] format. + + Returns: + Callable[[~.BatchProcessDocumentsRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_process_documents" not in self._stubs: + self._stubs["batch_process_documents"] = self.grpc_channel.unary_unary( + "/google.cloud.documentai.v1beta2.DocumentUnderstandingService/BatchProcessDocuments", + request_serializer=document_understanding.BatchProcessDocumentsRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["batch_process_documents"] + + @property + def process_document( + self + ) -> Callable[ + [document_understanding.ProcessDocumentRequest], Awaitable[document.Document] + ]: + r"""Return a callable for the process document method over gRPC. + + Processes a single document. + + Returns: + Callable[[~.ProcessDocumentRequest], + Awaitable[~.Document]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "process_document" not in self._stubs: + self._stubs["process_document"] = self.grpc_channel.unary_unary( + "/google.cloud.documentai.v1beta2.DocumentUnderstandingService/ProcessDocument", + request_serializer=document_understanding.ProcessDocumentRequest.serialize, + response_deserializer=document.Document.deserialize, + ) + return self._stubs["process_document"] + + +__all__ = ("DocumentUnderstandingServiceGrpcAsyncIOTransport",) diff --git a/noxfile.py b/noxfile.py index a2d177fa..2e726c77 100644 --- a/noxfile.py +++ b/noxfile.py @@ -66,6 +66,7 @@ def lint_setup_py(session): def default(session): # Install all test dependencies, then install this package in-place. + session.install("asyncmock", "pytest-asyncio") session.install("mock", "pytest", "pytest-cov") session.install("-e", ".") diff --git a/setup.py b/setup.py index ab94cd88..1aec73d4 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ platforms="Posix; MacOS X; Windows", include_package_data=True, install_requires=( - "google-api-core[grpc] >= 1.17.0, < 2.0.0dev", + "google-api-core[grpc] >= 1.17.2, < 2.0.0dev", "proto-plus >= 0.4.0", ), python_requires=">=3.6", diff --git a/synth.metadata b/synth.metadata index 88a38fd2..3f79a2d8 100644 --- a/synth.metadata +++ b/synth.metadata @@ -11,8 +11,8 @@ "git": { "name": "googleapis", "remote": "https://github.com/googleapis/googleapis.git", - "sha": "d5fe42c39cd35f95131a0267314ae108ab1bef8d", - "internalRef": "314471006" + "sha": "ff1b4ff2cfb09b66e26510d2c659705607563c1b", + "internalRef": "317199748" } }, { diff --git a/tests/unit/documentai_v1beta2/test_document_understanding_service.py b/tests/unit/documentai_v1beta2/test_document_understanding_service.py deleted file mode 100644 index 2d9ff8f0..00000000 --- a/tests/unit/documentai_v1beta2/test_document_understanding_service.py +++ /dev/null @@ -1,471 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import grpc_helpers -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.documentai_v1beta2.services.document_understanding_service import ( - DocumentUnderstandingServiceClient, -) -from google.cloud.documentai_v1beta2.services.document_understanding_service import ( - transports, -) -from google.cloud.documentai_v1beta2.types import document -from google.cloud.documentai_v1beta2.types import document_understanding -from google.cloud.documentai_v1beta2.types import geometry -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.rpc import status_pb2 as status # type: ignore - - -def client_cert_source_callback(): - return b"cert bytes", b"key bytes" - - -def test__get_default_mtls_endpoint(): - api_endpoint = "example.googleapis.com" - api_mtls_endpoint = "example.mtls.googleapis.com" - sandbox_endpoint = "example.sandbox.googleapis.com" - sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" - non_googleapi = "api.example.com" - - assert DocumentUnderstandingServiceClient._get_default_mtls_endpoint(None) is None - assert ( - DocumentUnderstandingServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - DocumentUnderstandingServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - DocumentUnderstandingServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - DocumentUnderstandingServiceClient._get_default_mtls_endpoint( - sandbox_mtls_endpoint - ) - == sandbox_mtls_endpoint - ) - assert ( - DocumentUnderstandingServiceClient._get_default_mtls_endpoint(non_googleapi) - == non_googleapi - ) - - -def test_document_understanding_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = DocumentUnderstandingServiceClient.from_service_account_file( - "dummy/file/path.json" - ) - assert client._transport._credentials == creds - - client = DocumentUnderstandingServiceClient.from_service_account_json( - "dummy/file/path.json" - ) - assert client._transport._credentials == creds - - assert client._transport._host == "us-documentai.googleapis.com:443" - - -def test_document_understanding_service_client_client_options(): - # Check that if channel is provided we won't create a new one. - with mock.patch( - "google.cloud.documentai_v1beta2.services.document_understanding_service.DocumentUnderstandingServiceClient.get_transport_class" - ) as gtc: - transport = transports.DocumentUnderstandingServiceGrpcTransport( - credentials=credentials.AnonymousCredentials() - ) - client = DocumentUnderstandingServiceClient(transport=transport) - gtc.assert_not_called() - - # Check mTLS is not triggered with empty client options. - options = client_options.ClientOptions() - with mock.patch( - "google.cloud.documentai_v1beta2.services.document_understanding_service.DocumentUnderstandingServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = DocumentUnderstandingServiceClient(client_options=options) - transport.assert_called_once_with( - credentials=None, host=client.DEFAULT_ENDPOINT - ) - - # Check mTLS is not triggered if api_endpoint is provided but - # client_cert_source is None. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.documentai_v1beta2.services.document_understanding_service.transports.DocumentUnderstandingServiceGrpcTransport.__init__" - ) as grpc_transport: - grpc_transport.return_value = None - client = DocumentUnderstandingServiceClient(client_options=options) - grpc_transport.assert_called_once_with( - api_mtls_endpoint=None, - client_cert_source=None, - credentials=None, - host="squid.clam.whelk", - ) - - # Check mTLS is triggered if client_cert_source is provided. - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch( - "google.cloud.documentai_v1beta2.services.document_understanding_service.transports.DocumentUnderstandingServiceGrpcTransport.__init__" - ) as grpc_transport: - grpc_transport.return_value = None - client = DocumentUnderstandingServiceClient(client_options=options) - grpc_transport.assert_called_once_with( - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=client_cert_source_callback, - credentials=None, - host=client.DEFAULT_ENDPOINT, - ) - - # Check mTLS is triggered if api_endpoint and client_cert_source are provided. - options = client_options.ClientOptions( - api_endpoint="squid.clam.whelk", client_cert_source=client_cert_source_callback - ) - with mock.patch( - "google.cloud.documentai_v1beta2.services.document_understanding_service.transports.DocumentUnderstandingServiceGrpcTransport.__init__" - ) as grpc_transport: - grpc_transport.return_value = None - client = DocumentUnderstandingServiceClient(client_options=options) - grpc_transport.assert_called_once_with( - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=client_cert_source_callback, - credentials=None, - host="squid.clam.whelk", - ) - - -def test_document_understanding_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.documentai_v1beta2.services.document_understanding_service.transports.DocumentUnderstandingServiceGrpcTransport.__init__" - ) as grpc_transport: - grpc_transport.return_value = None - client = DocumentUnderstandingServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - grpc_transport.assert_called_once_with( - api_mtls_endpoint=None, - client_cert_source=None, - credentials=None, - host="squid.clam.whelk", - ) - - -def test_batch_process_documents(transport: str = "grpc"): - client = DocumentUnderstandingServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = document_understanding.BatchProcessDocumentsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.batch_process_documents), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.batch_process_documents(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_batch_process_documents_flattened(): - client = DocumentUnderstandingServiceClient( - credentials=credentials.AnonymousCredentials() - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.batch_process_documents), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.batch_process_documents( - requests=[ - document_understanding.ProcessDocumentRequest(parent="parent_value") - ] - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].requests == [ - document_understanding.ProcessDocumentRequest(parent="parent_value") - ] - - -def test_batch_process_documents_flattened_error(): - client = DocumentUnderstandingServiceClient( - credentials=credentials.AnonymousCredentials() - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.batch_process_documents( - document_understanding.BatchProcessDocumentsRequest(), - requests=[ - document_understanding.ProcessDocumentRequest(parent="parent_value") - ], - ) - - -def test_process_document(transport: str = "grpc"): - client = DocumentUnderstandingServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = document_understanding.ProcessDocumentRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.process_document), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = document.Document( - uri="uri_value", - content=b"content_blob", - mime_type="mime_type_value", - text="text_value", - ) - - response = client.process_document(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, document.Document) - assert response.uri == "uri_value" - assert response.content == b"content_blob" - assert response.mime_type == "mime_type_value" - assert response.text == "text_value" - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.DocumentUnderstandingServiceGrpcTransport( - credentials=credentials.AnonymousCredentials() - ) - with pytest.raises(ValueError): - client = DocumentUnderstandingServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.DocumentUnderstandingServiceGrpcTransport( - credentials=credentials.AnonymousCredentials() - ) - client = DocumentUnderstandingServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = DocumentUnderstandingServiceClient( - credentials=credentials.AnonymousCredentials() - ) - assert isinstance( - client._transport, transports.DocumentUnderstandingServiceGrpcTransport - ) - - -def test_document_understanding_service_base_transport(): - # Instantiate the base transport. - transport = transports.DocumentUnderstandingServiceTransport( - credentials=credentials.AnonymousCredentials() - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ("batch_process_documents", "process_document") - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_document_understanding_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - DocumentUnderstandingServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",) - ) - - -def test_document_understanding_service_host_no_port(): - client = DocumentUnderstandingServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="us-documentai.googleapis.com" - ), - transport="grpc", - ) - assert client._transport._host == "us-documentai.googleapis.com:443" - - -def test_document_understanding_service_host_with_port(): - client = DocumentUnderstandingServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="us-documentai.googleapis.com:8000" - ), - transport="grpc", - ) - assert client._transport._host == "us-documentai.googleapis.com:8000" - - -def test_document_understanding_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() - transport = transports.DocumentUnderstandingServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - assert not callback.called - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_document_understanding_service_grpc_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. - mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.DocumentUnderstandingServiceGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - ssl_credentials=mock_ssl_cred, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ) - assert transport.grpc_channel == mock_grpc_channel - - -@pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] -) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_document_understanding_service_grpc_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint -): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. - mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - mock_cred = mock.Mock() - transport = transports.DocumentUnderstandingServiceGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - ssl_credentials=mock_ssl_cred, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ) - assert transport.grpc_channel == mock_grpc_channel - - -def test_document_understanding_service_grpc_lro_client(): - client = DocumentUnderstandingServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc" - ) - transport = client._transport - - # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient) - - # Ensure that subsequent calls to the property send the exact same object. - assert transport.operations_client is transport.operations_client diff --git a/tests/unit/gapic/documentai_v1beta2/test_document_understanding_service.py b/tests/unit/gapic/documentai_v1beta2/test_document_understanding_service.py new file mode 100644 index 00000000..9036afdb --- /dev/null +++ b/tests/unit/gapic/documentai_v1beta2/test_document_understanding_service.py @@ -0,0 +1,906 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest + +from google import auth +from google.api_core import client_options +from google.api_core import future +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.documentai_v1beta2.services.document_understanding_service import ( + DocumentUnderstandingServiceAsyncClient, +) +from google.cloud.documentai_v1beta2.services.document_understanding_service import ( + DocumentUnderstandingServiceClient, +) +from google.cloud.documentai_v1beta2.services.document_understanding_service import ( + transports, +) +from google.cloud.documentai_v1beta2.types import document +from google.cloud.documentai_v1beta2.types import document_understanding +from google.cloud.documentai_v1beta2.types import geometry +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.rpc import status_pb2 as status # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert DocumentUnderstandingServiceClient._get_default_mtls_endpoint(None) is None + assert ( + DocumentUnderstandingServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + DocumentUnderstandingServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + DocumentUnderstandingServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DocumentUnderstandingServiceClient._get_default_mtls_endpoint( + sandbox_mtls_endpoint + ) + == sandbox_mtls_endpoint + ) + assert ( + DocumentUnderstandingServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class", + [DocumentUnderstandingServiceClient, DocumentUnderstandingServiceAsyncClient], +) +def test_document_understanding_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client._transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client._transport._credentials == creds + + assert client._transport._host == "us-documentai.googleapis.com:443" + + +def test_document_understanding_service_client_get_transport_class(): + transport = DocumentUnderstandingServiceClient.get_transport_class() + assert transport == transports.DocumentUnderstandingServiceGrpcTransport + + transport = DocumentUnderstandingServiceClient.get_transport_class("grpc") + assert transport == transports.DocumentUnderstandingServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + DocumentUnderstandingServiceClient, + transports.DocumentUnderstandingServiceGrpcTransport, + "grpc", + ), + ( + DocumentUnderstandingServiceAsyncClient, + transports.DocumentUnderstandingServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_document_understanding_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object( + DocumentUnderstandingServiceClient, "get_transport_class" + ) as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object( + DocumentUnderstandingServiceClient, "get_transport_class" + ) as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + api_mtls_endpoint="squid.clam.whelk", + client_cert_source=None, + credentials=None, + host="squid.clam.whelk", + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # "never". + os.environ["GOOGLE_API_USE_MTLS"] = "never" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_ENDPOINT, + client_cert_source=None, + credentials=None, + host=client.DEFAULT_ENDPOINT, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # "always". + os.environ["GOOGLE_API_USE_MTLS"] = "always" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, + client_cert_source=None, + credentials=None, + host=client.DEFAULT_MTLS_ENDPOINT, + ) + + # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is + # "auto", and client_cert_source is provided. + os.environ["GOOGLE_API_USE_MTLS"] = "auto" + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, + client_cert_source=client_cert_source_callback, + credentials=None, + host=client.DEFAULT_MTLS_ENDPOINT, + ) + + # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is + # "auto", and default_client_cert_source is provided. + os.environ["GOOGLE_API_USE_MTLS"] = "auto" + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, + client_cert_source=None, + credentials=None, + host=client.DEFAULT_MTLS_ENDPOINT, + ) + + # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is + # "auto", but client_cert_source and default_client_cert_source are None. + os.environ["GOOGLE_API_USE_MTLS"] = "auto" + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + api_mtls_endpoint=client.DEFAULT_ENDPOINT, + client_cert_source=None, + credentials=None, + host=client.DEFAULT_ENDPOINT, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has + # unsupported value. + os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported" + with pytest.raises(MutualTLSChannelError): + client = client_class() + + del os.environ["GOOGLE_API_USE_MTLS"] + + +def test_document_understanding_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.documentai_v1beta2.services.document_understanding_service.transports.DocumentUnderstandingServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = DocumentUnderstandingServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + api_mtls_endpoint="squid.clam.whelk", + client_cert_source=None, + credentials=None, + host="squid.clam.whelk", + ) + + +def test_batch_process_documents(transport: str = "grpc"): + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = document_understanding.BatchProcessDocumentsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.batch_process_documents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.batch_process_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_batch_process_documents_async(transport: str = "grpc_asyncio"): + client = DocumentUnderstandingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = document_understanding.BatchProcessDocumentsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._client._transport.batch_process_documents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.batch_process_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_batch_process_documents_field_headers(): + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials() + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = document_understanding.BatchProcessDocumentsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.batch_process_documents), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.batch_process_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value") in kw["metadata"] + + +@pytest.mark.asyncio +async def test_batch_process_documents_field_headers_async(): + client = DocumentUnderstandingServiceAsyncClient( + credentials=credentials.AnonymousCredentials() + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = document_understanding.BatchProcessDocumentsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._client._transport.batch_process_documents), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.batch_process_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value") in kw["metadata"] + + +def test_batch_process_documents_flattened(): + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials() + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.batch_process_documents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.batch_process_documents( + requests=[ + document_understanding.ProcessDocumentRequest(parent="parent_value") + ] + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].requests == [ + document_understanding.ProcessDocumentRequest(parent="parent_value") + ] + + +def test_batch_process_documents_flattened_error(): + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials() + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_process_documents( + document_understanding.BatchProcessDocumentsRequest(), + requests=[ + document_understanding.ProcessDocumentRequest(parent="parent_value") + ], + ) + + +@pytest.mark.asyncio +async def test_batch_process_documents_flattened_async(): + client = DocumentUnderstandingServiceAsyncClient( + credentials=credentials.AnonymousCredentials() + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._client._transport.batch_process_documents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.batch_process_documents( + requests=[ + document_understanding.ProcessDocumentRequest(parent="parent_value") + ] + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].requests == [ + document_understanding.ProcessDocumentRequest(parent="parent_value") + ] + + +@pytest.mark.asyncio +async def test_batch_process_documents_flattened_error_async(): + client = DocumentUnderstandingServiceAsyncClient( + credentials=credentials.AnonymousCredentials() + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.batch_process_documents( + document_understanding.BatchProcessDocumentsRequest(), + requests=[ + document_understanding.ProcessDocumentRequest(parent="parent_value") + ], + ) + + +def test_process_document(transport: str = "grpc"): + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = document_understanding.ProcessDocumentRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.process_document), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = document.Document( + uri="uri_value", + content=b"content_blob", + mime_type="mime_type_value", + text="text_value", + ) + + response = client.process_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, document.Document) + assert response.uri == "uri_value" + assert response.content == b"content_blob" + assert response.mime_type == "mime_type_value" + assert response.text == "text_value" + + +@pytest.mark.asyncio +async def test_process_document_async(transport: str = "grpc_asyncio"): + client = DocumentUnderstandingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = document_understanding.ProcessDocumentRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._client._transport.process_document), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + document.Document( + uri="uri_value", + content=b"content_blob", + mime_type="mime_type_value", + text="text_value", + ) + ) + + response = await client.process_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, document.Document) + assert response.uri == "uri_value" + assert response.content == b"content_blob" + assert response.mime_type == "mime_type_value" + assert response.text == "text_value" + + +def test_process_document_field_headers(): + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials() + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = document_understanding.ProcessDocumentRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.process_document), "__call__" + ) as call: + call.return_value = document.Document() + + client.process_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value") in kw["metadata"] + + +@pytest.mark.asyncio +async def test_process_document_field_headers_async(): + client = DocumentUnderstandingServiceAsyncClient( + credentials=credentials.AnonymousCredentials() + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = document_understanding.ProcessDocumentRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._client._transport.process_document), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(document.Document()) + + await client.process_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value") in kw["metadata"] + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.DocumentUnderstandingServiceGrpcTransport( + credentials=credentials.AnonymousCredentials() + ) + with pytest.raises(ValueError): + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.DocumentUnderstandingServiceGrpcTransport( + credentials=credentials.AnonymousCredentials() + ) + client = DocumentUnderstandingServiceClient(transport=transport) + assert client._transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.DocumentUnderstandingServiceGrpcTransport( + credentials=credentials.AnonymousCredentials() + ) + channel = transport.grpc_channel + assert channel + + transport = transports.DocumentUnderstandingServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials() + ) + channel = transport.grpc_channel + assert channel + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials() + ) + assert isinstance( + client._transport, transports.DocumentUnderstandingServiceGrpcTransport + ) + + +def test_document_understanding_service_base_transport(): + # Instantiate the base transport. + transport = transports.DocumentUnderstandingServiceTransport( + credentials=credentials.AnonymousCredentials() + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ("batch_process_documents", "process_document") + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_document_understanding_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + DocumentUnderstandingServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",) + ) + + +def test_document_understanding_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.DocumentUnderstandingServiceGrpcTransport(host="squid.clam.whelk") + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",) + ) + + +def test_document_understanding_service_host_no_port(): + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="us-documentai.googleapis.com" + ), + ) + assert client._transport._host == "us-documentai.googleapis.com:443" + + +def test_document_understanding_service_host_with_port(): + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="us-documentai.googleapis.com:8000" + ), + ) + assert client._transport._host == "us-documentai.googleapis.com:8000" + + +def test_document_understanding_service_grpc_transport_channel(): + channel = grpc.insecure_channel("http://localhost/") + + # Check that if channel is provided, mtls endpoint and client_cert_source + # won't be used. + callback = mock.MagicMock() + transport = transports.DocumentUnderstandingServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=callback, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert not callback.called + + +def test_document_understanding_service_grpc_asyncio_transport_channel(): + channel = aio.insecure_channel("http://localhost/") + + # Check that if channel is provided, mtls endpoint and client_cert_source + # won't be used. + callback = mock.MagicMock() + transport = transports.DocumentUnderstandingServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=callback, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert not callback.called + + +@mock.patch("grpc.ssl_channel_credentials", autospec=True) +@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) +def test_document_understanding_service_grpc_transport_channel_mtls_with_client_cert_source( + grpc_create_channel, grpc_ssl_channel_cred +): + # Check that if channel is None, but api_mtls_endpoint and client_cert_source + # are provided, then a mTLS channel will be created. + mock_cred = mock.Mock() + + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + transport = transports.DocumentUnderstandingServiceGrpcTransport( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + ssl_credentials=mock_ssl_cred, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ) + assert transport.grpc_channel == mock_grpc_channel + + +@mock.patch("grpc.ssl_channel_credentials", autospec=True) +@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) +def test_document_understanding_service_grpc_asyncio_transport_channel_mtls_with_client_cert_source( + grpc_create_channel, grpc_ssl_channel_cred +): + # Check that if channel is None, but api_mtls_endpoint and client_cert_source + # are provided, then a mTLS channel will be created. + mock_cred = mock.Mock() + + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + transport = transports.DocumentUnderstandingServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + ssl_credentials=mock_ssl_cred, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ) + assert transport.grpc_channel == mock_grpc_channel + + +@pytest.mark.parametrize( + "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] +) +@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) +def test_document_understanding_service_grpc_transport_channel_mtls_with_adc( + grpc_create_channel, api_mtls_endpoint +): + # Check that if channel and client_cert_source are None, but api_mtls_endpoint + # is provided, then a mTLS channel will be created with SSL ADC. + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + # Mock google.auth.transport.grpc.SslCredentials class. + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + mock_cred = mock.Mock() + transport = transports.DocumentUnderstandingServiceGrpcTransport( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint=api_mtls_endpoint, + client_cert_source=None, + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + ssl_credentials=mock_ssl_cred, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ) + assert transport.grpc_channel == mock_grpc_channel + + +@pytest.mark.parametrize( + "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] +) +@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) +def test_document_understanding_service_grpc_asyncio_transport_channel_mtls_with_adc( + grpc_create_channel, api_mtls_endpoint +): + # Check that if channel and client_cert_source are None, but api_mtls_endpoint + # is provided, then a mTLS channel will be created with SSL ADC. + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + # Mock google.auth.transport.grpc.SslCredentials class. + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + mock_cred = mock.Mock() + transport = transports.DocumentUnderstandingServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint=api_mtls_endpoint, + client_cert_source=None, + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + ssl_credentials=mock_ssl_cred, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_document_understanding_service_grpc_lro_client(): + client = DocumentUnderstandingServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc" + ) + transport = client._transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_document_understanding_service_grpc_lro_async_client(): + client = DocumentUnderstandingServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio" + ) + transport = client._client._transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client