Skip to content
This repository has been archived by the owner on Sep 20, 2023. It is now read-only.

feat: update ReviewDocumentRequest to allow set priority and enable validation #172

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -25,6 +25,7 @@
from google.api_core import retry as retries # type: ignore
from google.api_core import operations_v1 # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

from google.cloud.documentai_v1.types import document_processor_service
from google.longrunning import operations_pb2 # type: ignore
Expand All @@ -47,8 +48,6 @@
except pkg_resources.DistributionNotFound: # pragma: NO COVER
_GOOGLE_AUTH_VERSION = None

_API_CORE_VERSION = google.api_core.__version__


class DocumentProcessorServiceTransport(abc.ABC):
"""Abstract transport class for DocumentProcessorService."""
Expand All @@ -66,6 +65,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -89,6 +89,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ":" not in host:
Expand All @@ -98,7 +100,7 @@ def __init__(
scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)

# Save the scopes.
self._scopes = scopes or self.AUTH_SCOPES
self._scopes = scopes

# If no credentials are provided, then determine the appropriate
# defaults.
Expand All @@ -117,13 +119,20 @@ def __init__(
**scopes_kwargs, quota_project_id=quota_project_id
)

# If the credentials is service account credentials, then always try to use self signed JWT.
if (
always_use_jwt_access
and isinstance(credentials, service_account.Credentials)
and hasattr(service_account.Credentials, "with_always_use_jwt_access")
):
credentials = credentials.with_always_use_jwt_access(True)

# Save the credentials.
self._credentials = credentials

# TODO(busunkim): These two class methods are in the base transport
# TODO(busunkim): This method is in the base transport
# to avoid duplicating code across the transport classes. These functions
# should be deleted once the minimum required versions of google-api-core
# and google-auth are increased.
# should be deleted once the minimum required versions of google-auth is increased.

# TODO: Remove this function once google-auth >= 1.25.0 is required
@classmethod
Expand All @@ -144,27 +153,6 @@ def _get_scopes_kwargs(

return scopes_kwargs

# TODO: Remove this function once google-api-core >= 1.26.0 is required
@classmethod
def _get_self_signed_jwt_kwargs(
cls, host: str, scopes: Optional[Sequence[str]]
) -> Dict[str, Union[Optional[Sequence[str]], str]]:
"""Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version"""

self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {}

if _API_CORE_VERSION and (
packaging.version.parse(_API_CORE_VERSION)
>= packaging.version.parse("1.26.0")
):
self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
self_signed_jwt_kwargs["scopes"] = scopes
self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
else:
self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES

return self_signed_jwt_kwargs

def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
Expand Down
Expand Up @@ -63,6 +63,7 @@ def __init__(
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -103,6 +104,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.

Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
Expand Down Expand Up @@ -156,6 +159,7 @@ def __init__(
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=always_use_jwt_access,
)

if not self._grpc_channel:
Expand Down Expand Up @@ -211,14 +215,14 @@ def create_channel(
and ``credentials_file`` are passed.
"""

self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)

return grpc_helpers.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
quota_project_id=quota_project_id,
**self_signed_jwt_kwargs,
default_scopes=cls.AUTH_SCOPES,
scopes=scopes,
default_host=cls.DEFAULT_HOST,
**kwargs,
)

Expand Down
Expand Up @@ -84,14 +84,14 @@ def create_channel(
aio.Channel: A gRPC AsyncIO channel object.
"""

self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)

return grpc_helpers_async.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
quota_project_id=quota_project_id,
**self_signed_jwt_kwargs,
default_scopes=cls.AUTH_SCOPES,
scopes=scopes,
default_host=cls.DEFAULT_HOST,
**kwargs,
)

Expand All @@ -109,6 +109,7 @@ def __init__(
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
) -> None:
"""Instantiate the transport.

Expand Down Expand Up @@ -150,6 +151,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.

Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
Expand Down Expand Up @@ -202,6 +205,7 @@ def __init__(
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=always_use_jwt_access,
)

if not self._grpc_channel:
Expand Down
19 changes: 18 additions & 1 deletion google/cloud/documentai_v1/types/document.py
Expand Up @@ -212,6 +212,8 @@ class Page(proto.Message):
form_fields (Sequence[google.cloud.documentai_v1.types.Document.Page.FormField]):
A list of visually detected form fields on
the page.
provenance (google.cloud.documentai_v1.types.Document.Provenance):
The history of this page.
"""

class Dimension(proto.Message):
Expand Down Expand Up @@ -552,6 +554,8 @@ class FormField(proto.Message):
- blank (this indicates the field_value is normal text)
- "unfilled_checkbox"
- "filled_checkbox".
provenance (google.cloud.documentai_v1.types.Document.Provenance):
The history of this annotation.
"""

field_name = proto.Field(
Expand All @@ -567,6 +571,9 @@ class FormField(proto.Message):
proto.MESSAGE, number=4, message="Document.Page.DetectedLanguage",
)
value_type = proto.Field(proto.STRING, number=5,)
provenance = proto.Field(
proto.MESSAGE, number=8, message="Document.Provenance",
)

class DetectedLanguage(proto.Message):
r"""Detected language for a structural component.
Expand Down Expand Up @@ -615,6 +622,9 @@ class DetectedLanguage(proto.Message):
form_fields = proto.RepeatedField(
proto.MESSAGE, number=11, message="Document.Page.FormField",
)
provenance = proto.Field(
proto.MESSAGE, number=16, message="Document.Provenance",
)

class Entity(proto.Message):
r"""A phrase in the text that is a known entity type, such as a
Expand Down Expand Up @@ -819,7 +829,9 @@ class PageRef(proto.Message):
Required. Index into the
[Document.pages][google.cloud.documentai.v1.Document.pages]
element, for example using [Document.pages][page_refs.page]
to locate the related page element.
to locate the related page element. This field is skipped
when its value is the default 0. See
https://developers.google.com/protocol-buffers/docs/proto3#json.
layout_type (google.cloud.documentai_v1.types.Document.PageAnchor.PageRef.LayoutType):
Optional. The type of the layout element that
is being referenced if any.
Expand Down Expand Up @@ -899,11 +911,16 @@ class Parent(proto.Message):
revision (int):
The index of the [Document.revisions] identifying the parent
revision.
index (int):
The index of the parent revisions
corresponding collection of items (eg. list of
entities, properties within entities, etc.)
id (int):
The id of the parent provenance.
"""

revision = proto.Field(proto.INT32, number=1,)
index = proto.Field(proto.INT32, number=3,)
id = proto.Field(proto.INT32, number=2,)

revision = proto.Field(proto.INT32, number=1,)
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/documentai_v1/types/document_processor_service.py
Expand Up @@ -212,12 +212,24 @@ class ReviewDocumentRequest(proto.Message):
Required. The resource name of the
HumanReviewConfig that the document will be
reviewed with.
enable_schema_validation (bool):
Whether the validation should be performed on
the ad-hoc review request.
priority (google.cloud.documentai_v1.types.ReviewDocumentRequest.Priority):
The priority of the human review task.
"""

class Priority(proto.Enum):
r"""The priority level of the human review task."""
DEFAULT = 0
URGENT = 1

inline_document = proto.Field(
proto.MESSAGE, number=4, oneof="source", message=gcd_document.Document,
)
human_review_config = proto.Field(proto.STRING, number=1,)
enable_schema_validation = proto.Field(proto.BOOL, number=3,)
priority = proto.Field(proto.ENUM, number=5, enum=Priority,)


class ReviewDocumentResponse(proto.Message):
Expand Down
Expand Up @@ -25,6 +25,7 @@
from google.api_core import retry as retries # type: ignore
from google.api_core import operations_v1 # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

from google.cloud.documentai_v1beta2.types import document
from google.cloud.documentai_v1beta2.types import document_understanding
Expand All @@ -48,8 +49,6 @@
except pkg_resources.DistributionNotFound: # pragma: NO COVER
_GOOGLE_AUTH_VERSION = None

_API_CORE_VERSION = google.api_core.__version__


class DocumentUnderstandingServiceTransport(abc.ABC):
"""Abstract transport class for DocumentUnderstandingService."""
Expand All @@ -67,6 +66,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -90,6 +90,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT should
be used for service account credentials.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ":" not in host:
Expand All @@ -99,7 +101,7 @@ def __init__(
scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)

# Save the scopes.
self._scopes = scopes or self.AUTH_SCOPES
self._scopes = scopes

# If no credentials are provided, then determine the appropriate
# defaults.
Expand All @@ -118,13 +120,20 @@ def __init__(
**scopes_kwargs, quota_project_id=quota_project_id
)

# If the credentials is service account credentials, then always try to use self signed JWT.
if (
always_use_jwt_access
and isinstance(credentials, service_account.Credentials)
and hasattr(service_account.Credentials, "with_always_use_jwt_access")
):
credentials = credentials.with_always_use_jwt_access(True)

# Save the credentials.
self._credentials = credentials

# TODO(busunkim): These two class methods are in the base transport
# TODO(busunkim): This method is in the base transport
# to avoid duplicating code across the transport classes. These functions
# should be deleted once the minimum required versions of google-api-core
# and google-auth are increased.
# should be deleted once the minimum required versions of google-auth is increased.

# TODO: Remove this function once google-auth >= 1.25.0 is required
@classmethod
Expand All @@ -139,33 +148,20 @@ def _get_scopes_kwargs(
packaging.version.parse(_GOOGLE_AUTH_VERSION)
>= packaging.version.parse("1.25.0")
):
scopes_kwargs = {"scopes": scopes, "default_scopes": cls.AUTH_SCOPES}
# Documentai uses a regional host (us-documentai.googleapis.com) as the default
# so self-signed JWT cannot be used.
# Intentionally pass default scopes as user scopes so the auth library
# does not use the self-signed JWT flow.
# https://github.com/googleapis/python-documentai/issues/174
scopes_kwargs = {
"scopes": scopes or cls.AUTH_SCOPES,
"default_scopes": cls.AUTH_SCOPES,
}
else:
scopes_kwargs = {"scopes": scopes or cls.AUTH_SCOPES}

return scopes_kwargs

# TODO: Remove this function once google-api-core >= 1.26.0 is required
@classmethod
def _get_self_signed_jwt_kwargs(
cls, host: str, scopes: Optional[Sequence[str]]
) -> Dict[str, Union[Optional[Sequence[str]], str]]:
"""Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version"""

self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {}

if _API_CORE_VERSION and (
packaging.version.parse(_API_CORE_VERSION)
>= packaging.version.parse("1.26.0")
):
self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
self_signed_jwt_kwargs["scopes"] = scopes
self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
else:
self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES

return self_signed_jwt_kwargs

def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
Expand Down