From 6a99b125922b8fca7c997150b81b6925376e9d1d Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Thu, 29 Jul 2021 20:48:26 +0000 Subject: [PATCH] fix: enable self signed jwt for grpc (#570) PiperOrigin-RevId: 386504689 Source-Link: https://github.com/googleapis/googleapis/commit/762094a99ac6e03a17516b13dfbef37927267a70 Source-Link: https://github.com/googleapis/googleapis-gen/commit/6bfc480e1a161d5de121c2bcc3745885d33b265a feat: Removes breaking change from v1 version of AI Platform protos feat: Adds a new API method FindMostStableBuild feat: Adds BigQuery output table field to batch prediction job output config feat: Adds Feature Store features feat: Adds CustomJob.web_access_uris field feat: Adds CustomJob.enable_web_access field feat: Adds Endpoint.network, Endpoint.private_endpoints fields and PrivateEndpoints message feat: Adds Execution.State constants: CACHED and CANCELLED feat: Adds IndexEndpoint.private_ip_ranges field feat: Adds IndexEndpointService.deployed_index_id field feat: Adds MetadataService.DeleteArtifact and DeleteExecution methods feat: Adds ModelMonitoringObjectConfig.explanation_config field feat: Adds ModelMonitoringObjectConfig.ExplanationConfig message field feat: Adds attribution_score_skew_thresholds field feat: Adds attribution_score_drift_threshold field feat: Adds fields to Study message --- .../services/dataset_service/client.py | 4 + .../dataset_service/transports/base.py | 42 +- .../dataset_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/endpoint_service/client.py | 4 + .../endpoint_service/transports/base.py | 42 +- .../endpoint_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/job_service/client.py | 4 + .../services/job_service/transports/base.py | 42 +- .../services/job_service/transports/grpc.py | 10 +- .../job_service/transports/grpc_asyncio.py | 10 +- .../services/migration_service/client.py | 
4 + .../migration_service/transports/base.py | 42 +- .../migration_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/model_service/client.py | 4 + .../services/model_service/transports/base.py | 42 +- .../services/model_service/transports/grpc.py | 10 +- .../model_service/transports/grpc_asyncio.py | 10 +- .../services/pipeline_service/client.py | 4 + .../pipeline_service/transports/base.py | 42 +- .../pipeline_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../prediction_service/async_client.py | 26 +- .../services/prediction_service/client.py | 34 +- .../prediction_service/transports/base.py | 42 +- .../prediction_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../specialist_pool_service/client.py | 4 + .../transports/base.py | 42 +- .../transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- google/cloud/aiplatform_v1beta1/__init__.py | 24 + .../aiplatform_v1beta1/gapic_metadata.json | 50 + .../services/dataset_service/client.py | 4 + .../dataset_service/transports/base.py | 42 +- .../dataset_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/endpoint_service/async_client.py | 2 + .../services/endpoint_service/client.py | 19 + .../endpoint_service/transports/base.py | 42 +- .../endpoint_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../client.py | 4 + .../transports/base.py | 42 +- .../transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/featurestore_service/client.py | 4 + .../featurestore_service/transports/base.py | 42 +- .../featurestore_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/index_endpoint_service/client.py | 4 + .../index_endpoint_service/transports/base.py | 42 +- .../index_endpoint_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/index_service/client.py | 4 + 
.../services/index_service/transports/base.py | 42 +- .../services/index_service/transports/grpc.py | 10 +- .../index_service/transports/grpc_asyncio.py | 10 +- .../services/job_service/client.py | 4 + .../services/job_service/transports/base.py | 42 +- .../services/job_service/transports/grpc.py | 10 +- .../job_service/transports/grpc_asyncio.py | 10 +- .../services/metadata_service/async_client.py | 440 +- .../services/metadata_service/client.py | 444 +- .../metadata_service/transports/base.py | 102 +- .../metadata_service/transports/grpc.py | 140 +- .../transports/grpc_asyncio.py | 150 +- .../services/migration_service/client.py | 26 +- .../migration_service/transports/base.py | 42 +- .../migration_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/model_service/client.py | 4 + .../services/model_service/transports/base.py | 42 +- .../services/model_service/transports/grpc.py | 10 +- .../model_service/transports/grpc_asyncio.py | 10 +- .../services/pipeline_service/client.py | 4 + .../pipeline_service/transports/base.py | 42 +- .../pipeline_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/prediction_service/client.py | 4 + .../prediction_service/transports/base.py | 42 +- .../prediction_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../specialist_pool_service/client.py | 4 + .../transports/base.py | 42 +- .../transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/tensorboard_service/client.py | 4 + .../tensorboard_service/transports/base.py | 42 +- .../tensorboard_service/transports/grpc.py | 10 +- .../transports/grpc_asyncio.py | 10 +- .../services/vizier_service/client.py | 4 + .../vizier_service/transports/base.py | 42 +- .../vizier_service/transports/grpc.py | 10 +- .../vizier_service/transports/grpc_asyncio.py | 10 +- .../aiplatform_v1beta1/types/__init__.py | 24 + .../types/batch_prediction_job.py | 12 +- 
.../aiplatform_v1beta1/types/custom_job.py | 12 + .../aiplatform_v1beta1/types/endpoint.py | 48 +- .../aiplatform_v1beta1/types/execution.py | 2 + .../types/featurestore_monitoring.py | 11 + .../types/featurestore_online_service.py | 12 +- .../types/featurestore_service.py | 30 +- .../types/index_endpoint.py | 13 + .../types/index_endpoint_service.py | 3 + google/cloud/aiplatform_v1beta1/types/io.py | 1 + .../types/machine_resources.py | 17 +- .../types/metadata_service.py | 230 +- .../types/model_monitoring.py | 85 + .../types/pipeline_service.py | 44 +- .../cloud/aiplatform_v1beta1/types/study.py | 40 + owlbot.py | 26 +- schema/predict/instance/noxfile.py | 2 +- schema/predict/instance/setup.py | 3 +- schema/predict/params/noxfile.py | 2 +- schema/predict/params/setup.py | 3 +- schema/predict/prediction/noxfile.py | 2 +- schema/predict/prediction/setup.py | 3 +- setup.py | 2 +- testing/constraints-3.6.txt | 2 +- .../aiplatform_v1/test_dataset_service.py | 133 +- .../aiplatform_v1/test_endpoint_service.py | 133 +- .../gapic/aiplatform_v1/test_job_service.py | 133 +- .../aiplatform_v1/test_migration_service.py | 133 +- .../gapic/aiplatform_v1/test_model_service.py | 133 +- .../aiplatform_v1/test_pipeline_service.py | 133 +- .../aiplatform_v1/test_prediction_service.py | 196 +- .../test_specialist_pool_service.py | 133 +- .../test_dataset_service.py | 133 +- .../test_endpoint_service.py | 187 +- ...est_featurestore_online_serving_service.py | 142 +- .../test_featurestore_service.py | 133 +- .../test_index_endpoint_service.py | 133 +- .../aiplatform_v1beta1/test_index_service.py | 133 +- .../aiplatform_v1beta1/test_job_service.py | 133 +- .../test_metadata_service.py | 3666 +++++++++++------ .../test_migration_service.py | 161 +- .../aiplatform_v1beta1/test_model_service.py | 133 +- .../test_pipeline_service.py | 133 +- .../test_prediction_service.py | 259 +- .../test_specialist_pool_service.py | 133 +- .../test_tensorboard_service.py | 133 +- 
.../aiplatform_v1beta1/test_vizier_service.py | 133 +- tests/unit/gapic/definition_v1/__init__.py | 15 + .../unit/gapic/definition_v1beta1/__init__.py | 15 + tests/unit/gapic/instance_v1/__init__.py | 15 + tests/unit/gapic/instance_v1beta1/__init__.py | 15 + tests/unit/gapic/params_v1/__init__.py | 15 + tests/unit/gapic/params_v1beta1/__init__.py | 15 + tests/unit/gapic/prediction_v1/__init__.py | 15 + .../unit/gapic/prediction_v1beta1/__init__.py | 15 + 153 files changed, 6062 insertions(+), 4456 deletions(-) create mode 100644 tests/unit/gapic/definition_v1/__init__.py create mode 100644 tests/unit/gapic/definition_v1beta1/__init__.py create mode 100644 tests/unit/gapic/instance_v1/__init__.py create mode 100644 tests/unit/gapic/instance_v1beta1/__init__.py create mode 100644 tests/unit/gapic/params_v1/__init__.py create mode 100644 tests/unit/gapic/params_v1beta1/__init__.py create mode 100644 tests/unit/gapic/prediction_v1/__init__.py create mode 100644 tests/unit/gapic/prediction_v1beta1/__init__.py diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py index 26bc6ef650..009561d163 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py @@ -418,6 +418,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_dataset( diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py index c049ed37ba..978f4cb517 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py +++ 
b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import annotation_spec from google.cloud.aiplatform_v1.types import dataset @@ -50,8 +51,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class DatasetServiceTransport(abc.ABC): """Abstract transport class for DatasetService.""" @@ -69,6 +68,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -92,6 +92,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -101,7 +103,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -120,13 +122,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. 
+ if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. # TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -147,27 +156,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py index 5d952d5674..4e0b73aa2e 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py @@ -60,6 +60,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -100,6 +101,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -153,6 +156,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -208,14 +212,14 @@ def create_channel( and ``credentials_file`` are passed. 
""" - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py index a42416c366..7e5222a73b 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py @@ -81,14 +81,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -106,6 +106,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -147,6 +148,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. 
Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -199,6 +202,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py index 6b0ffb9a90..f9d4f6d071 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -369,6 +369,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_endpoint( diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py index a760eddfef..96e5abc877 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import endpoint from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint @@ -49,8 +50,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class EndpointServiceTransport(abc.ABC): """Abstract transport class for EndpointService.""" @@ -68,6 +67,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: 
Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -91,6 +91,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -100,7 +102,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -119,13 +121,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -146,27 +155,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py index 933465f7c8..69629abb49 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py @@ -59,6 +59,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -99,6 +100,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -152,6 +155,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -207,14 +211,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py index 9d1d1dd1e4..c698ca070a 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py @@ -80,14 +80,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -105,6 +105,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -146,6 +147,8 @@ def __init__( API requests. If ``None``, then default info will be used. 
Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -198,6 +201,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1/services/job_service/client.py b/google/cloud/aiplatform_v1/services/job_service/client.py index a6273cd93e..4a4986190e 100644 --- a/google/cloud/aiplatform_v1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1/services/job_service/client.py @@ -489,6 +489,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_custom_job( diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1/services/job_service/transports/base.py index c8b47f54c6..10622dedbe 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import batch_prediction_job from google.cloud.aiplatform_v1.types import ( @@ -60,8 +61,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class JobServiceTransport(abc.ABC): 
"""Abstract transport class for JobService.""" @@ -79,6 +78,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -102,6 +102,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -111,7 +113,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -130,13 +132,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -157,27 +166,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py index 5187314007..9025fab621 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py @@ -72,6 +72,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -112,6 +113,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. 
Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -165,6 +168,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -220,14 +224,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py index 2772069ea9..24c7aa72c9 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py @@ -93,14 +93,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -118,6 +118,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -159,6 +160,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -211,6 +214,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 162818246f..48d73db89d 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -444,6 +444,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def search_migratable_resources( diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1/services/migration_service/transports/base.py index 6dc3d69c17..24daacbec2 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import migration_service from google.longrunning import operations_pb2 # type: ignore @@ -47,8 +48,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class MigrationServiceTransport(abc.ABC): """Abstract 
transport class for MigrationService.""" @@ -66,6 +65,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -89,6 +89,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -98,7 +100,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -117,13 +119,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -144,27 +153,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py index 3eba0a40c6..c5a765d169 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py @@ -60,6 +60,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -100,6 +101,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -153,6 +156,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -208,14 +212,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py index 516f66d2c5..71e00310db 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py @@ -81,14 +81,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -106,6 +106,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -147,6 +148,8 @@ def __init__( API requests. If ``None``, then default info will be used. 
Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -199,6 +202,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1/services/model_service/client.py b/google/cloud/aiplatform_v1/services/model_service/client.py index 0011d2541c..5ea2f91b18 100644 --- a/google/cloud/aiplatform_v1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1/services/model_service/client.py @@ -429,6 +429,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def upload_model( diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1/services/model_service/transports/base.py index 70481a015a..303a08b7f0 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import model from google.cloud.aiplatform_v1.types import model as gca_model @@ -51,8 +52,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class 
ModelServiceTransport(abc.ABC): """Abstract transport class for ModelService.""" @@ -70,6 +69,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -93,6 +93,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -102,7 +104,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -121,13 +123,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -148,27 +157,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py index 294c83b547..00a4d2cd8a 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py @@ -63,6 +63,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -103,6 +104,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. 
Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -156,6 +159,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -211,14 +215,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py index 58f52db314..7dd1ac6b44 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py @@ -84,14 +84,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -109,6 +109,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -150,6 +151,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -202,6 +205,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1/services/pipeline_service/client.py index 3266a00e97..ec3fd4ccf8 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/client.py @@ -506,6 +506,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_training_pipeline( diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py index 395b266fdd..04cef84b9d 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import pipeline_job from google.cloud.aiplatform_v1.types import pipeline_job as gca_pipeline_job @@ -52,8 +53,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class PipelineServiceTransport(abc.ABC): """Abstract 
transport class for PipelineService.""" @@ -71,6 +70,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -94,6 +94,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -103,7 +105,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -122,13 +124,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -149,27 +158,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py index e9fafa4613..0f47c72e8b 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py @@ -67,6 +67,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -107,6 +108,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -160,6 +163,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -215,14 +219,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py index 6a0ccb7545..a8b26c1a93 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py @@ -88,14 +88,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -113,6 +113,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -154,6 +155,8 @@ def __init__( API requests. If ``None``, then default info will be used. 
Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -206,6 +209,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py index b5cde52017..fe69ba8c58 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py @@ -166,8 +166,8 @@ async def predict( request: prediction_service.PredictRequest = None, *, endpoint: str = None, - parameters: struct_pb2.Value = None, instances: Sequence[struct_pb2.Value] = None, + parameters: struct_pb2.Value = None, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), @@ -186,17 +186,6 @@ async def predict( This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - parameters (:class:`google.protobuf.struct_pb2.Value`): - The parameters that govern the prediction. The schema of - the parameters may be specified via Endpoint's - DeployedModels' [Model's - ][google.cloud.aiplatform.v1.DeployedModel.model] - [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] - [parameters_schema_uri][google.cloud.aiplatform.v1.PredictSchemata.parameters_schema_uri]. - - This corresponds to the ``parameters`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. instances (:class:`Sequence[google.protobuf.struct_pb2.Value]`): Required. 
The instances that are the input to the prediction call. A DeployedModel may have an upper limit @@ -213,6 +202,17 @@ async def predict( This corresponds to the ``instances`` field on the ``request`` instance; if ``request`` is provided, this should not be set. + parameters (:class:`google.protobuf.struct_pb2.Value`): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + [parameters_schema_uri][google.cloud.aiplatform.v1.PredictSchemata.parameters_schema_uri]. + + This corresponds to the ``parameters`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -228,7 +228,7 @@ async def predict( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- has_flattened_params = any([endpoint, parameters, instances]) + has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " diff --git a/google/cloud/aiplatform_v1/services/prediction_service/client.py b/google/cloud/aiplatform_v1/services/prediction_service/client.py index d1a1198435..1140b50be9 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/client.py @@ -346,6 +346,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def predict( @@ -353,8 +357,8 @@ def predict( request: prediction_service.PredictRequest = None, *, endpoint: str = None, - parameters: struct_pb2.Value = None, instances: Sequence[struct_pb2.Value] = None, + parameters: struct_pb2.Value = None, retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), @@ -373,17 +377,6 @@ def predict( This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - parameters (google.protobuf.struct_pb2.Value): - The parameters that govern the prediction. The schema of - the parameters may be specified via Endpoint's - DeployedModels' [Model's - ][google.cloud.aiplatform.v1.DeployedModel.model] - [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] - [parameters_schema_uri][google.cloud.aiplatform.v1.PredictSchemata.parameters_schema_uri]. - - This corresponds to the ``parameters`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. 
instances (Sequence[google.protobuf.struct_pb2.Value]): Required. The instances that are the input to the prediction call. A DeployedModel may have an upper limit @@ -400,6 +393,17 @@ def predict( This corresponds to the ``instances`` field on the ``request`` instance; if ``request`` is provided, this should not be set. + parameters (google.protobuf.struct_pb2.Value): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + [parameters_schema_uri][google.cloud.aiplatform.v1.PredictSchemata.parameters_schema_uri]. + + This corresponds to the ``parameters`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. @@ -415,7 +419,7 @@ def predict( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([endpoint, parameters, instances]) + has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -432,10 +436,10 @@ def predict( # request, apply these. if endpoint is not None: request.endpoint = endpoint - if parameters is not None: - request.parameters = parameters if instances is not None: request.instances.extend(instances) + if parameters is not None: + request.parameters = parameters # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py index 804afddd08..5aa4c006bd 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py @@ -24,6 +24,7 @@ from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import prediction_service @@ -45,8 +46,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class PredictionServiceTransport(abc.ABC): """Abstract transport class for PredictionService.""" @@ -64,6 +63,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -87,6 +87,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -96,7 +98,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. 
@@ -115,13 +117,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. # TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -142,27 +151,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py index e005541e66..3d5d68a608 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py @@ -57,6 +57,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -97,6 +98,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -149,6 +152,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -204,14 +208,14 @@ def create_channel( and ``credentials_file`` are passed. 
""" - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py index 1cc41c6669..38a9173c23 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py @@ -78,14 +78,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -103,6 +103,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -144,6 +145,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. 
Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -195,6 +198,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py index 96ce2a35cb..022afe5dc4 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py @@ -359,6 +359,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_specialist_pool( diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py index b4edeb5a73..99cbc55279 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1.types import specialist_pool from google.cloud.aiplatform_v1.types import specialist_pool_service @@ -48,8 +49,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class SpecialistPoolServiceTransport(abc.ABC): """Abstract transport class for SpecialistPoolService.""" @@ -67,6 +66,7 
@@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -90,6 +90,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -99,7 +101,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -118,13 +120,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -145,27 +154,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py index 0f1bad07d6..2e849673c7 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py @@ -65,6 +65,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -105,6 +106,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -158,6 +161,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -213,14 +217,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py index ec9a57206c..b0ac09dcef 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -86,14 +86,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -111,6 +111,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -152,6 +153,8 @@ def __init__( API requests. 
If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -204,6 +207,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index a9c7df2b17..6481a4c0bd 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -93,6 +93,7 @@ from .types.encryption_spec import EncryptionSpec from .types.endpoint import DeployedModel from .types.endpoint import Endpoint +from .types.endpoint import PrivateEndpoints from .types.endpoint_service import CreateEndpointOperationMetadata from .types.endpoint_service import CreateEndpointRequest from .types.endpoint_service import DeleteEndpointRequest @@ -267,7 +268,9 @@ from .types.metadata_service import CreateMetadataSchemaRequest from .types.metadata_service import CreateMetadataStoreOperationMetadata from .types.metadata_service import CreateMetadataStoreRequest +from .types.metadata_service import DeleteArtifactRequest from .types.metadata_service import DeleteContextRequest +from .types.metadata_service import DeleteExecutionRequest from .types.metadata_service import DeleteMetadataStoreOperationMetadata from .types.metadata_service import DeleteMetadataStoreRequest from .types.metadata_service import GetArtifactRequest @@ -285,6 +288,15 @@ from .types.metadata_service import ListMetadataSchemasResponse from .types.metadata_service import ListMetadataStoresRequest from .types.metadata_service import ListMetadataStoresResponse +from .types.metadata_service import PurgeArtifactsMetadata +from 
.types.metadata_service import PurgeArtifactsRequest +from .types.metadata_service import PurgeArtifactsResponse +from .types.metadata_service import PurgeContextsMetadata +from .types.metadata_service import PurgeContextsRequest +from .types.metadata_service import PurgeContextsResponse +from .types.metadata_service import PurgeExecutionsMetadata +from .types.metadata_service import PurgeExecutionsRequest +from .types.metadata_service import PurgeExecutionsResponse from .types.metadata_service import QueryArtifactLineageSubgraphRequest from .types.metadata_service import QueryContextLineageSubgraphRequest from .types.metadata_service import QueryExecutionInputsAndOutputsRequest @@ -561,6 +573,7 @@ "Dataset", "DatasetServiceClient", "DedicatedResources", + "DeleteArtifactRequest", "DeleteBatchPredictionJobRequest", "DeleteContextRequest", "DeleteCustomJobRequest", @@ -568,6 +581,7 @@ "DeleteDatasetRequest", "DeleteEndpointRequest", "DeleteEntityTypeRequest", + "DeleteExecutionRequest", "DeleteFeatureRequest", "DeleteFeaturestoreRequest", "DeleteHyperparameterTuningJobRequest", @@ -797,6 +811,16 @@ "PredictResponse", "PredictSchemata", "PredictionServiceClient", + "PrivateEndpoints", + "PurgeArtifactsMetadata", + "PurgeArtifactsRequest", + "PurgeArtifactsResponse", + "PurgeContextsMetadata", + "PurgeContextsRequest", + "PurgeContextsResponse", + "PurgeExecutionsMetadata", + "PurgeExecutionsRequest", + "PurgeExecutionsResponse", "PythonPackageSpec", "QueryArtifactLineageSubgraphRequest", "QueryContextLineageSubgraphRequest", diff --git a/google/cloud/aiplatform_v1beta1/gapic_metadata.json b/google/cloud/aiplatform_v1beta1/gapic_metadata.json index 605e95582d..55544714a5 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1beta1/gapic_metadata.json @@ -938,11 +938,21 @@ "create_metadata_store" ] }, + "DeleteArtifact": { + "methods": [ + "delete_artifact" + ] + }, "DeleteContext": { "methods": [ "delete_context" ] }, + 
"DeleteExecution": { + "methods": [ + "delete_execution" + ] + }, "DeleteMetadataStore": { "methods": [ "delete_metadata_store" @@ -998,6 +1008,21 @@ "list_metadata_stores" ] }, + "PurgeArtifacts": { + "methods": [ + "purge_artifacts" + ] + }, + "PurgeContexts": { + "methods": [ + "purge_contexts" + ] + }, + "PurgeExecutions": { + "methods": [ + "purge_executions" + ] + }, "QueryArtifactLineageSubgraph": { "methods": [ "query_artifact_lineage_subgraph" @@ -1073,11 +1098,21 @@ "create_metadata_store" ] }, + "DeleteArtifact": { + "methods": [ + "delete_artifact" + ] + }, "DeleteContext": { "methods": [ "delete_context" ] }, + "DeleteExecution": { + "methods": [ + "delete_execution" + ] + }, "DeleteMetadataStore": { "methods": [ "delete_metadata_store" @@ -1133,6 +1168,21 @@ "list_metadata_stores" ] }, + "PurgeArtifacts": { + "methods": [ + "purge_artifacts" + ] + }, + "PurgeContexts": { + "methods": [ + "purge_contexts" + ] + }, + "PurgeExecutions": { + "methods": [ + "purge_executions" + ] + }, "QueryArtifactLineageSubgraph": { "methods": [ "query_artifact_lineage_subgraph" diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 8582d349eb..a03c0fd0b6 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -418,6 +418,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_dataset( diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py index f7998abdf1..b043f64b6a 100644 --- 
a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import annotation_spec from google.cloud.aiplatform_v1beta1.types import dataset @@ -50,8 +51,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class DatasetServiceTransport(abc.ABC): """Abstract transport class for DatasetService.""" @@ -69,6 +68,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -92,6 +92,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -101,7 +103,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -120,13 +122,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. 
+ if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. # TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -147,27 +156,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 089c59004a..18d63c2988 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -60,6 +60,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -100,6 +101,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -153,6 +156,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -208,14 +212,14 @@ def create_channel( and ``credentials_file`` are passed. 
""" - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py index 7ce9afb46c..36d5a65e9e 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py @@ -81,14 +81,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -106,6 +106,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -147,6 +148,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. 
Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -199,6 +202,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index 77d2abab83..1daadb7439 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -54,6 +54,8 @@ class EndpointServiceAsyncClient: parse_endpoint_path = staticmethod(EndpointServiceClient.parse_endpoint_path) model_path = staticmethod(EndpointServiceClient.model_path) parse_model_path = staticmethod(EndpointServiceClient.parse_model_path) + network_path = staticmethod(EndpointServiceClient.network_path) + parse_network_path = staticmethod(EndpointServiceClient.parse_network_path) common_billing_account_path = staticmethod( EndpointServiceClient.common_billing_account_path ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 6a5371b98b..d7660217b8 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -196,6 +196,21 @@ def parse_model_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def network_path(project: str, network: str,) -> str: + """Returns a fully-qualified network string.""" + return "projects/{project}/global/networks/{network}".format( + project=project, network=network, + ) + + @staticmethod + def parse_network_path(path: str) -> Dict[str, str]: + """Parses a network path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/global/networks/(?P.+?)$", path + ) 
+ return m.groupdict() if m else {} + @staticmethod def common_billing_account_path(billing_account: str,) -> str: """Returns a fully-qualified billing_account string.""" @@ -369,6 +384,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_endpoint( diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py index 12d5e3d32b..c6a2bbf5d8 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint @@ -49,8 +50,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class EndpointServiceTransport(abc.ABC): """Abstract transport class for EndpointService.""" @@ -68,6 +67,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -91,6 +91,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -100,7 +102,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -119,13 +121,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -146,27 +155,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index 5f4f6274f4..877a79b7e6 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -59,6 +59,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -99,6 +100,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -152,6 +155,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -207,14 +211,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py index ffc37c4b14..ecdaab95b0 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py @@ -80,14 +80,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -105,6 +105,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -146,6 +147,8 @@ def __init__( API requests. 
If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -198,6 +201,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py index 322e19e043..71d6c6193d 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/client.py @@ -359,6 +359,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def read_feature_values( diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py index b4e26a18c0..349146d525 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/base.py @@ -24,6 +24,7 @@ from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from 
google.cloud.aiplatform_v1beta1.types import featurestore_online_service @@ -45,8 +46,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class FeaturestoreOnlineServingServiceTransport(abc.ABC): """Abstract transport class for FeaturestoreOnlineServingService.""" @@ -64,6 +63,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -87,6 +87,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -96,7 +98,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -115,13 +117,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. 
These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. # TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -142,27 +151,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py index b8ee74ba5a..b102a11b6e 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc.py @@ -59,6 +59,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -99,6 +100,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -151,6 +154,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -206,14 +210,14 @@ def create_channel( and ``credentials_file`` are passed. 
""" - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py index 1ead473e29..d3a8f66bb9 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_online_serving_service/transports/grpc_asyncio.py @@ -80,14 +80,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -105,6 +105,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -146,6 +147,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. 
Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -197,6 +200,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py index 5e1597a5b0..0dfe99140f 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/client.py @@ -405,6 +405,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_featurestore( diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py index 7d9162a2fa..c2c03a28de 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import entity_type from google.cloud.aiplatform_v1beta1.types import entity_type as gca_entity_type @@ -52,8 +53,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class FeaturestoreServiceTransport(abc.ABC): """Abstract transport class for 
FeaturestoreService.""" @@ -71,6 +70,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -94,6 +94,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -103,7 +105,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -122,13 +124,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -149,27 +158,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py index 6b645a9720..4b96f1b547 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc.py @@ -65,6 +65,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -105,6 +106,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -158,6 +161,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -213,14 +217,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py index edb1c8c51d..fdd73956b2 100644 --- a/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/featurestore_service/transports/grpc_asyncio.py @@ -86,14 +86,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -111,6 +111,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -152,6 +153,8 @@ def __init__( API requests. 
If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -204,6 +207,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py index 2952bbd94b..ca692ee019 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/client.py @@ -370,6 +370,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_index_endpoint( diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py index cbc16d1221..91d6e5783d 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import index_endpoint from google.cloud.aiplatform_v1beta1.types import index_endpoint as 
gca_index_endpoint @@ -49,8 +50,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class IndexEndpointServiceTransport(abc.ABC): """Abstract transport class for IndexEndpointService.""" @@ -68,6 +67,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -91,6 +91,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -100,7 +102,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -119,13 +121,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. 
+ # should be deleted once the minimum required versions of google-auth is increased. # TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -146,27 +155,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py index dcfa1050ba..0b2cc3c865 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc.py @@ -61,6 +61,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -101,6 +102,8 @@ def __init__( API requests. If ``None``, then default info will be used. 
Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -154,6 +157,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -209,14 +213,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py index 6edd0afe21..92fdafcd70 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/index_endpoint_service/transports/grpc_asyncio.py @@ -82,14 +82,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. 
""" - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -107,6 +107,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -148,6 +149,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -200,6 +203,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/client.py b/google/cloud/aiplatform_v1beta1/services/index_service/client.py index b5894ff8d6..67d8a97c6c 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/client.py @@ -370,6 +370,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_index( diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py 
b/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py index 731cd90024..98cd4c01ce 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import index from google.cloud.aiplatform_v1beta1.types import index_service @@ -48,8 +49,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class IndexServiceTransport(abc.ABC): """Abstract transport class for IndexService.""" @@ -67,6 +66,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -90,6 +90,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -99,7 +101,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. 
@@ -118,13 +120,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. # TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -145,27 +154,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py index 4cc2ac2cd7..83a4e4a1c0 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc.py @@ -61,6 +61,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -101,6 +102,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -154,6 +157,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -209,14 +213,14 @@ def create_channel( and ``credentials_file`` are passed. 
""" - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py index e48a839b3c..fd527ccf2c 100644 --- a/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/index_service/transports/grpc_asyncio.py @@ -82,14 +82,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -107,6 +107,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -148,6 +149,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. 
Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -200,6 +203,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index 53ba235acd..4dd5f9fcea 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -552,6 +552,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_custom_job( diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py index f4adf20483..8698bed1e9 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job from google.cloud.aiplatform_v1beta1.types import ( @@ -66,8 +67,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class JobServiceTransport(abc.ABC): """Abstract transport class for JobService.""" @@ -85,6 +84,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = 
None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -108,6 +108,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -117,7 +119,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -136,13 +138,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -163,27 +172,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index c4c837fbc3..9007ac7f32 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -78,6 +78,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -118,6 +119,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -171,6 +174,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -226,14 +230,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py index a747f8ff05..df12c26e38 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py @@ -99,14 +99,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -124,6 +124,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -165,6 +166,8 @@ def __init__( API requests. If ``None``, then default info will be used. 
Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -217,6 +220,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py index c886cbda34..8a1f529dfc 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/async_client.py @@ -885,6 +885,183 @@ async def update_artifact( # Done; return the response. return response + async def delete_artifact( + self, + request: metadata_service.DeleteArtifactRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes an Artifact. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteArtifactRequest`): + The request object. Request message for + [MetadataService.DeleteArtifact][google.cloud.aiplatform.v1beta1.MetadataService.DeleteArtifact]. + name (:class:`str`): + Required. The resource name of the + Artifact to delete. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = metadata_service.DeleteArtifactRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_artifact, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. 
+ response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty_pb2.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def purge_artifacts( + self, + request: metadata_service.PurgeArtifactsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Purges Artifacts. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.PurgeArtifactsRequest`): + The request object. Request message for + [MetadataService.PurgeArtifacts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeArtifacts]. + parent (:class:`str`): + Required. The metadata store to purge + Artifacts from. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.PurgeArtifactsResponse` + Response message for + [MetadataService.PurgeArtifacts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeArtifacts]. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = metadata_service.PurgeArtifactsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.purge_artifacts, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + metadata_service.PurgeArtifactsResponse, + metadata_type=metadata_service.PurgeArtifactsMetadata, + ) + + # Done; return the response. + return response + async def create_context( self, request: metadata_service.CreateContextRequest = None, @@ -1234,7 +1411,7 @@ async def delete_context( [MetadataService.DeleteContext][google.cloud.aiplatform.v1beta1.MetadataService.DeleteContext]. name (:class:`str`): Required. The resource name of the - Context to retrieve. Format: + Context to delete. Format: projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} This corresponds to the ``name`` field @@ -1310,6 +1487,90 @@ async def delete_context( # Done; return the response. 
return response + async def purge_contexts( + self, + request: metadata_service.PurgeContextsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Purges Contexts. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.PurgeContextsRequest`): + The request object. Request message for + [MetadataService.PurgeContexts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeContexts]. + parent (:class:`str`): + Required. The metadata store to purge + Contexts from. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.PurgeContextsResponse` + Response message for + [MetadataService.PurgeContexts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeContexts]. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = metadata_service.PurgeContextsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
+ if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.purge_contexts, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + metadata_service.PurgeContextsResponse, + metadata_type=metadata_service.PurgeContextsMetadata, + ) + + # Done; return the response. + return response + async def add_context_artifacts_and_executions( self, request: metadata_service.AddContextArtifactsAndExecutionsRequest = None, @@ -1911,6 +2172,183 @@ async def update_execution( # Done; return the response. return response + async def delete_execution( + self, + request: metadata_service.DeleteExecutionRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes an Execution. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteExecutionRequest`): + The request object. Request message for + [MetadataService.DeleteExecution][google.cloud.aiplatform.v1beta1.MetadataService.DeleteExecution]. + name (:class:`str`): + Required. The resource name of the + Execution to delete. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = metadata_service.DeleteExecutionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_execution, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty_pb2.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def purge_executions( + self, + request: metadata_service.PurgeExecutionsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Purges Executions. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.PurgeExecutionsRequest`): + The request object. Request message for + [MetadataService.PurgeExecutions][google.cloud.aiplatform.v1beta1.MetadataService.PurgeExecutions]. + parent (:class:`str`): + Required. The metadata store to purge + Executions from. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.PurgeExecutionsResponse` + Response message for + [MetadataService.PurgeExecutions][google.cloud.aiplatform.v1beta1.MetadataService.PurgeExecutions]. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = metadata_service.PurgeExecutionsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.purge_executions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + metadata_service.PurgeExecutionsResponse, + metadata_type=metadata_service.PurgeExecutionsMetadata, + ) + + # Done; return the response. 
+ return response + async def add_execution_events( self, request: metadata_service.AddExecutionEventsRequest = None, diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py index 4f4ed9360e..1b813d80a8 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/client.py @@ -448,6 +448,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_metadata_store( @@ -1142,6 +1146,183 @@ def update_artifact( # Done; return the response. return response + def delete_artifact( + self, + request: metadata_service.DeleteArtifactRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes an Artifact. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteArtifactRequest): + The request object. Request message for + [MetadataService.DeleteArtifact][google.cloud.aiplatform.v1beta1.MetadataService.DeleteArtifact]. + name (str): + Required. The resource name of the + Artifact to delete. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.DeleteArtifactRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.DeleteArtifactRequest): + request = metadata_service.DeleteArtifactRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_artifact] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. 
+ response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty_pb2.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def purge_artifacts( + self, + request: metadata_service.PurgeArtifactsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Purges Artifacts. + + Args: + request (google.cloud.aiplatform_v1beta1.types.PurgeArtifactsRequest): + The request object. Request message for + [MetadataService.PurgeArtifacts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeArtifacts]. + parent (str): + Required. The metadata store to purge + Artifacts from. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.PurgeArtifactsResponse` + Response message for + [MetadataService.PurgeArtifacts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeArtifacts]. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.PurgeArtifactsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.PurgeArtifactsRequest): + request = metadata_service.PurgeArtifactsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.purge_artifacts] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + metadata_service.PurgeArtifactsResponse, + metadata_type=metadata_service.PurgeArtifactsMetadata, + ) + + # Done; return the response. + return response + def create_context( self, request: metadata_service.CreateContextRequest = None, @@ -1491,7 +1672,7 @@ def delete_context( [MetadataService.DeleteContext][google.cloud.aiplatform.v1beta1.MetadataService.DeleteContext]. name (str): Required. The resource name of the - Context to retrieve. Format: + Context to delete. 
Format: projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} This corresponds to the ``name`` field @@ -1567,6 +1748,90 @@ def delete_context( # Done; return the response. return response + def purge_contexts( + self, + request: metadata_service.PurgeContextsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Purges Contexts. + + Args: + request (google.cloud.aiplatform_v1beta1.types.PurgeContextsRequest): + The request object. Request message for + [MetadataService.PurgeContexts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeContexts]. + parent (str): + Required. The metadata store to purge + Contexts from. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.PurgeContextsResponse` + Response message for + [MetadataService.PurgeContexts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeContexts]. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.PurgeContextsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.PurgeContextsRequest): + request = metadata_service.PurgeContextsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.purge_contexts] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + metadata_service.PurgeContextsResponse, + metadata_type=metadata_service.PurgeContextsMetadata, + ) + + # Done; return the response. + return response + def add_context_artifacts_and_executions( self, request: metadata_service.AddContextArtifactsAndExecutionsRequest = None, @@ -2174,6 +2439,183 @@ def update_execution( # Done; return the response. return response + def delete_execution( + self, + request: metadata_service.DeleteExecutionRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Deletes an Execution. 
+ + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteExecutionRequest): + The request object. Request message for + [MetadataService.DeleteExecution][google.cloud.aiplatform.v1beta1.MetadataService.DeleteExecution]. + name (str): + Required. The resource name of the + Execution to delete. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.DeleteExecutionRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
+ if not isinstance(request, metadata_service.DeleteExecutionRequest): + request = metadata_service.DeleteExecutionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_execution] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + empty_pb2.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def purge_executions( + self, + request: metadata_service.PurgeExecutionsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gac_operation.Operation: + r"""Purges Executions. + + Args: + request (google.cloud.aiplatform_v1beta1.types.PurgeExecutionsRequest): + The request object. Request message for + [MetadataService.PurgeExecutions][google.cloud.aiplatform.v1beta1.MetadataService.PurgeExecutions]. + parent (str): + Required. The metadata store to purge + Executions from. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.PurgeExecutionsResponse` + Response message for + [MetadataService.PurgeExecutions][google.cloud.aiplatform.v1beta1.MetadataService.PurgeExecutions]. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a metadata_service.PurgeExecutionsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, metadata_service.PurgeExecutionsRequest): + request = metadata_service.PurgeExecutionsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.purge_executions] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. 
+ response = gac_operation.from_gapic( + response, + self._transport.operations_client, + metadata_service.PurgeExecutionsResponse, + metadata_type=metadata_service.PurgeExecutionsMetadata, + ) + + # Done; return the response. + return response + def add_execution_events( self, request: metadata_service.AddExecutionEventsRequest = None, diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py index 284a96558e..5b74127d6b 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import artifact from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact @@ -57,8 +58,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class MetadataServiceTransport(abc.ABC): """Abstract transport class for MetadataService.""" @@ -76,6 +75,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -99,6 +99,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. 
Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -108,7 +110,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -127,13 +129,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -154,27 +163,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { @@ -206,6 +194,12 @@ def _prep_wrapped_messages(self, client_info): self.update_artifact: gapic_v1.method.wrap_method( self.update_artifact, default_timeout=5.0, client_info=client_info, ), + self.delete_artifact: gapic_v1.method.wrap_method( + self.delete_artifact, default_timeout=None, client_info=client_info, + ), + self.purge_artifacts: gapic_v1.method.wrap_method( + self.purge_artifacts, default_timeout=None, client_info=client_info, + ), self.create_context: gapic_v1.method.wrap_method( self.create_context, default_timeout=5.0, client_info=client_info, ), @@ -221,6 +215,9 @@ def _prep_wrapped_messages(self, client_info): self.delete_context: gapic_v1.method.wrap_method( self.delete_context, default_timeout=5.0, client_info=client_info, ), + self.purge_contexts: gapic_v1.method.wrap_method( + self.purge_contexts, default_timeout=None, client_info=client_info, + ), self.add_context_artifacts_and_executions: gapic_v1.method.wrap_method( self.add_context_artifacts_and_executions, default_timeout=5.0, @@ -246,6 +243,12 @@ def _prep_wrapped_messages(self, client_info): self.update_execution: gapic_v1.method.wrap_method( self.update_execution, default_timeout=5.0, client_info=client_info, ), + self.delete_execution: gapic_v1.method.wrap_method( + self.delete_execution, default_timeout=None, client_info=client_info, + ), + self.purge_executions: gapic_v1.method.wrap_method( + self.purge_executions, default_timeout=None, client_info=client_info, + ), self.add_execution_events: gapic_v1.method.wrap_method( self.add_execution_events, default_timeout=5.0, client_info=client_info, ), @@ -357,6 +360,24 @@ def update_artifact( ]: raise NotImplementedError() + @property + def delete_artifact( + self, + ) -> Callable[ + [metadata_service.DeleteArtifactRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def purge_artifacts( + self, + ) -> Callable[ + 
[metadata_service.PurgeArtifactsRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + @property def create_context( self, @@ -405,6 +426,15 @@ def delete_context( ]: raise NotImplementedError() + @property + def purge_contexts( + self, + ) -> Callable[ + [metadata_service.PurgeContextsRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + @property def add_context_artifacts_and_executions( self, @@ -480,6 +510,24 @@ def update_execution( ]: raise NotImplementedError() + @property + def delete_execution( + self, + ) -> Callable[ + [metadata_service.DeleteExecutionRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def purge_executions( + self, + ) -> Callable[ + [metadata_service.PurgeExecutionsRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + @property def add_execution_events( self, diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py index e123bb59b4..1ce2ff9986 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc.py @@ -69,6 +69,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -109,6 +110,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -162,6 +165,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -217,14 +221,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -468,6 +472,58 @@ def update_artifact( ) return self._stubs["update_artifact"] + @property + def delete_artifact( + self, + ) -> Callable[[metadata_service.DeleteArtifactRequest], operations_pb2.Operation]: + r"""Return a callable for the delete artifact method over gRPC. + + Deletes an Artifact. + + Returns: + Callable[[~.DeleteArtifactRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_artifact" not in self._stubs: + self._stubs["delete_artifact"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/DeleteArtifact", + request_serializer=metadata_service.DeleteArtifactRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_artifact"] + + @property + def purge_artifacts( + self, + ) -> Callable[[metadata_service.PurgeArtifactsRequest], operations_pb2.Operation]: + r"""Return a callable for the purge artifacts method over gRPC. 
+ + Purges Artifacts. + + Returns: + Callable[[~.PurgeArtifactsRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "purge_artifacts" not in self._stubs: + self._stubs["purge_artifacts"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/PurgeArtifacts", + request_serializer=metadata_service.PurgeArtifactsRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["purge_artifacts"] + @property def create_context( self, @@ -600,6 +656,32 @@ def delete_context( ) return self._stubs["delete_context"] + @property + def purge_contexts( + self, + ) -> Callable[[metadata_service.PurgeContextsRequest], operations_pb2.Operation]: + r"""Return a callable for the purge contexts method over gRPC. + + Purges Contexts. + + Returns: + Callable[[~.PurgeContextsRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "purge_contexts" not in self._stubs: + self._stubs["purge_contexts"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/PurgeContexts", + request_serializer=metadata_service.PurgeContextsRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["purge_contexts"] + @property def add_context_artifacts_and_executions( self, @@ -807,6 +889,58 @@ def update_execution( ) return self._stubs["update_execution"] + @property + def delete_execution( + self, + ) -> Callable[[metadata_service.DeleteExecutionRequest], operations_pb2.Operation]: + r"""Return a callable for the delete execution method over gRPC. + + Deletes an Execution. + + Returns: + Callable[[~.DeleteExecutionRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_execution" not in self._stubs: + self._stubs["delete_execution"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/DeleteExecution", + request_serializer=metadata_service.DeleteExecutionRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_execution"] + + @property + def purge_executions( + self, + ) -> Callable[[metadata_service.PurgeExecutionsRequest], operations_pb2.Operation]: + r"""Return a callable for the purge executions method over gRPC. + + Purges Executions. + + Returns: + Callable[[~.PurgeExecutionsRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "purge_executions" not in self._stubs: + self._stubs["purge_executions"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/PurgeExecutions", + request_serializer=metadata_service.PurgeExecutionsRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["purge_executions"] + @property def add_execution_events( self, diff --git a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py index a9fa784a46..19646a3071 100644 --- a/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/metadata_service/transports/grpc_asyncio.py @@ -90,14 +90,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -115,6 +115,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -156,6 +157,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. 
Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -208,6 +211,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -481,6 +485,62 @@ def update_artifact( ) return self._stubs["update_artifact"] + @property + def delete_artifact( + self, + ) -> Callable[ + [metadata_service.DeleteArtifactRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the delete artifact method over gRPC. + + Deletes an Artifact. + + Returns: + Callable[[~.DeleteArtifactRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_artifact" not in self._stubs: + self._stubs["delete_artifact"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/DeleteArtifact", + request_serializer=metadata_service.DeleteArtifactRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_artifact"] + + @property + def purge_artifacts( + self, + ) -> Callable[ + [metadata_service.PurgeArtifactsRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the purge artifacts method over gRPC. + + Purges Artifacts. + + Returns: + Callable[[~.PurgeArtifactsRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "purge_artifacts" not in self._stubs: + self._stubs["purge_artifacts"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/PurgeArtifacts", + request_serializer=metadata_service.PurgeArtifactsRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["purge_artifacts"] + @property def create_context( self, @@ -620,6 +680,34 @@ def delete_context( ) return self._stubs["delete_context"] + @property + def purge_contexts( + self, + ) -> Callable[ + [metadata_service.PurgeContextsRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the purge contexts method over gRPC. + + Purges Contexts. + + Returns: + Callable[[~.PurgeContextsRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "purge_contexts" not in self._stubs: + self._stubs["purge_contexts"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/PurgeContexts", + request_serializer=metadata_service.PurgeContextsRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["purge_contexts"] + @property def add_context_artifacts_and_executions( self, @@ -833,6 +921,62 @@ def update_execution( ) return self._stubs["update_execution"] + @property + def delete_execution( + self, + ) -> Callable[ + [metadata_service.DeleteExecutionRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the delete execution method over gRPC. + + Deletes an Execution. + + Returns: + Callable[[~.DeleteExecutionRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. 
+ """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_execution" not in self._stubs: + self._stubs["delete_execution"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/DeleteExecution", + request_serializer=metadata_service.DeleteExecutionRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_execution"] + + @property + def purge_executions( + self, + ) -> Callable[ + [metadata_service.PurgeExecutionsRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the purge executions method over gRPC. + + Purges Executions. + + Returns: + Callable[[~.PurgeExecutionsRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "purge_executions" not in self._stubs: + self._stubs["purge_executions"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MetadataService/PurgeExecutions", + request_serializer=metadata_service.PurgeExecutionsRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["purge_executions"] + @property def add_execution_events( self, diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index 47d85e73f2..c357f45d02 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -179,32 +179,32 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]: return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str, dataset: str,) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + return "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, 
) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod @@ -444,6 +444,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def search_migratable_resources( diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py index d2b54f2c66..3405e5f216 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import migration_service from google.longrunning import operations_pb2 # type: ignore @@ -47,8 +48,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class MigrationServiceTransport(abc.ABC): """Abstract transport class for MigrationService.""" @@ -66,6 +65,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. 
@@ -89,6 +89,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -98,7 +100,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -117,13 +119,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -144,27 +153,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py index 3148709fae..31e0b81fdd 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -60,6 +60,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -100,6 +101,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -153,6 +156,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -208,14 +212,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py index 1777eddaaf..70c630c53c 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py @@ -81,14 +81,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -106,6 +106,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -147,6 +148,8 @@ def __init__( API requests. 
If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -199,6 +202,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index c8474b7e64..b883c31e0f 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -430,6 +430,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def upload_model( diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py index 3dd6e890c7..e96cb0c0de 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import model as gca_model @@ -51,8 +52,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER 
_GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class ModelServiceTransport(abc.ABC): """Abstract transport class for ModelService.""" @@ -70,6 +69,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -93,6 +93,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -102,7 +104,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -121,13 +123,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -148,27 +157,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index 4814f5c5bc..6b8a181480 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -63,6 +63,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -103,6 +104,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -156,6 +159,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -211,14 +215,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py index c530f93914..81be2706a6 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py @@ -84,14 +84,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -109,6 +109,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -150,6 +151,8 @@ def __init__( API requests. If ``None``, then default info will be used. 
Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -202,6 +205,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index ac589a4a24..06eba87b60 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -508,6 +508,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_training_pipeline( diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py index cc09aa7551..d862d4e558 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_job from google.cloud.aiplatform_v1beta1.types import pipeline_job as gca_pipeline_job @@ -54,8 +55,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO 
COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class PipelineServiceTransport(abc.ABC): """Abstract transport class for PipelineService.""" @@ -73,6 +72,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -96,6 +96,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -105,7 +107,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -124,13 +126,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -151,27 +160,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 35957d464d..77bbb88dcb 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -69,6 +69,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -109,6 +110,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -162,6 +165,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -217,14 +221,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py index 081f2a2f55..dc777ae601 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -90,14 +90,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -115,6 +115,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -156,6 +157,8 @@ def __init__( API requests. 
If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -208,6 +211,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index 29b849b09c..7ef1133b40 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -347,6 +347,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def predict( diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py index 9006c8335a..ad5ea5f387 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py @@ -24,6 +24,7 @@ from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import prediction_service @@ -45,8 +46,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None 
-_API_CORE_VERSION = google.api_core.__version__ - class PredictionServiceTransport(abc.ABC): """Abstract transport class for PredictionService.""" @@ -64,6 +63,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -87,6 +87,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -96,7 +98,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -115,13 +117,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -142,27 +151,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index 6776380a8f..d901a40452 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -57,6 +57,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -97,6 +98,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -149,6 +152,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -204,14 +208,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py index 727333f09e..e12a2caf5e 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -78,14 +78,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -103,6 +103,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -144,6 +145,8 @@ def __init__( API requests. 
If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -195,6 +198,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index c26bb7b712..31cb4142ee 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -359,6 +359,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_specialist_pool( diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py index cbb0dee1cf..1a42154dca 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import specialist_pool from google.cloud.aiplatform_v1beta1.types import 
specialist_pool_service @@ -48,8 +49,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class SpecialistPoolServiceTransport(abc.ABC): """Abstract transport class for SpecialistPoolService.""" @@ -67,6 +66,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -90,6 +90,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -99,7 +101,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -118,13 +120,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. 
+ # should be deleted once the minimum required versions of google-auth is increased. # TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -145,27 +154,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index bdc5719d97..47e194de7f 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -65,6 +65,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -105,6 +106,8 @@ def __init__( API requests. If ``None``, then default info will be used. 
Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -158,6 +161,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -213,14 +217,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py index 3d8af32950..b50396a5ce 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -86,14 +86,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. 
""" - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -111,6 +111,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -152,6 +153,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -204,6 +207,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py index 045c69b7d0..80cd787912 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/client.py @@ -437,6 +437,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_tensorboard( diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py 
b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py index 71b612167c..68b3f921a9 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import tensorboard from google.cloud.aiplatform_v1beta1.types import tensorboard_experiment @@ -58,8 +59,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class TensorboardServiceTransport(abc.ABC): """Abstract transport class for TensorboardService.""" @@ -77,6 +76,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -100,6 +100,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -109,7 +111,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. 
@@ -128,13 +130,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. # TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -155,27 +164,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py index ad3913ec39..b956c014d5 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc.py @@ -70,6 +70,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -110,6 +111,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -163,6 +166,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -218,14 +222,14 @@ def create_channel( and ``credentials_file`` are passed. 
""" - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py index 4af4514351..5526bb09d5 100644 --- a/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/tensorboard_service/transports/grpc_asyncio.py @@ -91,14 +91,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -116,6 +116,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -157,6 +158,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. 
Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -209,6 +212,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py index 4b3611e6c4..5b458087a6 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py @@ -383,6 +383,10 @@ def __init__( client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, + always_use_jwt_access=( + Transport == type(self).get_transport_class("grpc") + or Transport == type(self).get_transport_class("grpc_asyncio") + ), ) def create_study( diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py index 45cd82c0ab..144c60512c 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py @@ -25,6 +25,7 @@ from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import study from google.cloud.aiplatform_v1beta1.types import study as gca_study @@ -50,8 +51,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class VizierServiceTransport(abc.ABC): """Abstract transport class for VizierService.""" @@ -69,6 +68,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, 
quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -92,6 +92,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -101,7 +103,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -120,13 +122,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + # Save the credentials. self._credentials = credentials - # TODO(busunkim): These two class methods are in the base transport + # TODO(busunkim): This method is in the base transport # to avoid duplicating code across the transport classes. These functions - # should be deleted once the minimum required versions of google-api-core - # and google-auth are increased. + # should be deleted once the minimum required versions of google-auth is increased. 
# TODO: Remove this function once google-auth >= 1.25.0 is required @classmethod @@ -147,27 +156,6 @@ def _get_scopes_kwargs( return scopes_kwargs - # TODO: Remove this function once google-api-core >= 1.26.0 is required - @classmethod - def _get_self_signed_jwt_kwargs( - cls, host: str, scopes: Optional[Sequence[str]] - ) -> Dict[str, Union[Optional[Sequence[str]], str]]: - """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version""" - - self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {} - - if _API_CORE_VERSION and ( - packaging.version.parse(_API_CORE_VERSION) - >= packaging.version.parse("1.26.0") - ): - self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES - self_signed_jwt_kwargs["scopes"] = scopes - self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST - else: - self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES - - return self_signed_jwt_kwargs - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py index 85cbd69259..03ee800123 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py @@ -65,6 +65,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -105,6 +106,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -158,6 +161,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: @@ -213,14 +217,14 @@ def create_channel( and ``credentials_file`` are passed. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py index 376561514b..29c92b49ec 100644 --- a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py @@ -86,14 +86,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -111,6 +111,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -152,6 +153,8 @@ def __init__( API requests. 
If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -204,6 +207,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 7b54b4f06d..932c8be0b1 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -65,6 +65,7 @@ from .endpoint import ( DeployedModel, Endpoint, + PrivateEndpoints, ) from .endpoint_service import ( CreateEndpointOperationMetadata, @@ -262,7 +263,9 @@ CreateMetadataSchemaRequest, CreateMetadataStoreOperationMetadata, CreateMetadataStoreRequest, + DeleteArtifactRequest, DeleteContextRequest, + DeleteExecutionRequest, DeleteMetadataStoreOperationMetadata, DeleteMetadataStoreRequest, GetArtifactRequest, @@ -280,6 +283,15 @@ ListMetadataSchemasResponse, ListMetadataStoresRequest, ListMetadataStoresResponse, + PurgeArtifactsMetadata, + PurgeArtifactsRequest, + PurgeArtifactsResponse, + PurgeContextsMetadata, + PurgeContextsRequest, + PurgeContextsResponse, + PurgeExecutionsMetadata, + PurgeExecutionsRequest, + PurgeExecutionsResponse, QueryArtifactLineageSubgraphRequest, QueryContextLineageSubgraphRequest, QueryExecutionInputsAndOutputsRequest, @@ -520,6 +532,7 @@ "EncryptionSpec", "DeployedModel", "Endpoint", + "PrivateEndpoints", "CreateEndpointOperationMetadata", "CreateEndpointRequest", "DeleteEndpointRequest", @@ -694,7 +707,9 @@ "CreateMetadataSchemaRequest", "CreateMetadataStoreOperationMetadata", "CreateMetadataStoreRequest", + "DeleteArtifactRequest", "DeleteContextRequest", + 
"DeleteExecutionRequest", "DeleteMetadataStoreOperationMetadata", "DeleteMetadataStoreRequest", "GetArtifactRequest", @@ -712,6 +727,15 @@ "ListMetadataSchemasResponse", "ListMetadataStoresRequest", "ListMetadataStoresResponse", + "PurgeArtifactsMetadata", + "PurgeArtifactsRequest", + "PurgeArtifactsResponse", + "PurgeContextsMetadata", + "PurgeContextsRequest", + "PurgeContextsResponse", + "PurgeExecutionsMetadata", + "PurgeExecutionsRequest", + "PurgeExecutionsResponse", "QueryArtifactLineageSubgraphRequest", "QueryContextLineageSubgraphRequest", "QueryExecutionInputsAndOutputsRequest", diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index cc0d88d68b..4fbb04c5ab 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -256,7 +256,7 @@ class OutputConfig(proto.Message): files are created (N depends on total number of failed predictions). These files contain the failed instances, as per their schema, followed by an additional ``error`` field - which as value has ```google.rpc.Status`` `__ + which as value has [google.rpc.Status][google.rpc.Status] containing only ``code`` and ``message`` fields. bigquery_destination (google.cloud.aiplatform_v1beta1.types.BigQueryDestination): The BigQuery project or dataset location where the output is @@ -278,8 +278,8 @@ class OutputConfig(proto.Message): ``errors`` table contains rows for which the prediction has failed, it has instance columns, as per the instance schema, followed by a single "errors" column, which as values has - ```google.rpc.Status`` `__ represented as a STRUCT, - and containing only ``code`` and ``message``. + [google.rpc.Status][google.rpc.Status] represented as a + STRUCT, and containing only ``code`` and ``message``. predictions_format (str): Required. 
The format in which Vertex AI gives the predictions, must be one of the @@ -311,6 +311,11 @@ class OutputInfo(proto.Message): Output only. The path of the BigQuery dataset created, in ``bq://projectId.bqDatasetId`` format, into which the prediction output is written. + bigquery_output_table (str): + Output only. The name of the BigQuery table created, in + ``predictions_`` format, into which the + prediction output is written. Can be used by UI to generate + the BigQuery output path, for example. """ gcs_output_directory = proto.Field( @@ -319,6 +324,7 @@ class OutputInfo(proto.Message): bigquery_output_dataset = proto.Field( proto.STRING, number=2, oneof="output_location", ) + bigquery_output_table = proto.Field(proto.STRING, number=4,) name = proto.Field(proto.STRING, number=1,) display_name = proto.Field(proto.STRING, number=2,) diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 9932a3c9ac..4a810ce781 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -86,6 +86,12 @@ class CustomJob(proto.Message): CustomJob. If this is set, then all resources created by the CustomJob will be encrypted with the provided encryption key. + web_access_uris (Sequence[google.cloud.aiplatform_v1beta1.types.CustomJob.WebAccessUrisEntry]): + Output only. The web access URIs for the + training job. The keys are the node names in the + training jobs, e.g. workerpool0-0. The values + are the URIs for each node's web portal in the + job. 
""" name = proto.Field(proto.STRING, number=1,) @@ -101,6 +107,7 @@ class CustomJob(proto.Message): encryption_spec = proto.Field( proto.MESSAGE, number=12, message=gca_encryption_spec.EncryptionSpec, ) + web_access_uris = proto.MapField(proto.STRING, proto.STRING, number=16,) class CustomJobSpec(proto.Message): @@ -167,6 +174,10 @@ class CustomJobSpec(proto.Message): resource to which this CustomJob will upload Tensorboard logs. Format: ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + enable_web_access (bool): + Optional. Vertex AI will enable web portal access to the + containers. The portals can be accessed on web via the URLs + given by [web_access_uris][]. """ worker_pool_specs = proto.RepeatedField( @@ -179,6 +190,7 @@ class CustomJobSpec(proto.Message): proto.MESSAGE, number=6, message=io.GcsDestination, ) tensorboard = proto.Field(proto.STRING, number=7,) + enable_web_access = proto.Field(proto.BOOL, number=10,) class WorkerPoolSpec(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 7b35657bcc..a9f2a9e1f5 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -22,7 +22,8 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"Endpoint", "DeployedModel",}, + package="google.cloud.aiplatform.v1beta1", + manifest={"Endpoint", "DeployedModel", "PrivateEndpoints",}, ) @@ -82,6 +83,19 @@ class Endpoint(proto.Message): Endpoint. If set, this Endpoint and all sub- resources of this Endpoint will be secured by this key. + network (str): + The full name of the Google Compute Engine + `network `__ + to which the Endpoint should be peered. + + Private services access must already be configured for the + network. If left unspecified, the Endpoint is not peered + with any network. + + `Format `__: + projects/{project}/global/networks/{network}. 
Where + {project} is a project number, as in '12345', and {network} + is network name. """ name = proto.Field(proto.STRING, number=1,) @@ -98,6 +112,7 @@ class Endpoint(proto.Message): encryption_spec = proto.Field( proto.MESSAGE, number=10, message=gca_encryption_spec.EncryptionSpec, ) + network = proto.Field(proto.STRING, number=13,) class DeployedModel(proto.Message): @@ -170,6 +185,13 @@ class DeployedModel(proto.Message): requests at a high queries per second rate (QPS). Estimate your costs before enabling this option. + private_endpoints (google.cloud.aiplatform_v1beta1.types.PrivateEndpoints): + Output only. Provide paths for users to send + predict/explain/health requests directly to the deployed + model services running on Cloud via private services access. + This field is populated if + [network][google.cloud.aiplatform.v1beta1.Endpoint.network] + is configured. """ dedicated_resources = proto.Field( @@ -194,6 +216,30 @@ class DeployedModel(proto.Message): service_account = proto.Field(proto.STRING, number=11,) enable_container_logging = proto.Field(proto.BOOL, number=12,) enable_access_logging = proto.Field(proto.BOOL, number=13,) + private_endpoints = proto.Field( + proto.MESSAGE, number=14, message="PrivateEndpoints", + ) + + +class PrivateEndpoints(proto.Message): + r"""PrivateEndpoints is used to provide paths for users to send + requests via private services access. + + Attributes: + predict_http_uri (str): + Output only. Http(s) path to send prediction + requests. + explain_http_uri (str): + Output only. Http(s) path to send explain + requests. + health_http_uri (str): + Output only. Http(s) path to send health + check requests. 
+ """ + + predict_http_uri = proto.Field(proto.STRING, number=1,) + explain_http_uri = proto.Field(proto.STRING, number=2,) + health_http_uri = proto.Field(proto.STRING, number=3,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/execution.py b/google/cloud/aiplatform_v1beta1/types/execution.py index a2762f1a1c..1ac953399c 100644 --- a/google/cloud/aiplatform_v1beta1/types/execution.py +++ b/google/cloud/aiplatform_v1beta1/types/execution.py @@ -88,6 +88,8 @@ class State(proto.Enum): RUNNING = 2 COMPLETE = 3 FAILED = 4 + CACHED = 5 + CANCELLED = 6 name = proto.Field(proto.STRING, number=1,) display_name = proto.Field(proto.STRING, number=2,) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py b/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py index 76e6f6b5a4..2ac554c0f2 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_monitoring.py @@ -55,12 +55,23 @@ class SnapshotAnalysis(proto.Message): Configuration of the snapshot analysis based monitoring pipeline running interval. The value is rolled up to full day. + monitoring_interval_days (int): + Configuration of the snapshot analysis based monitoring + pipeline running interval. The value indicates number of + days. If both + [FeaturestoreMonitoringConfig.SnapshotAnalysis.monitoring_interval_days][google.cloud.aiplatform.v1beta1.FeaturestoreMonitoringConfig.SnapshotAnalysis.monitoring_interval_days] + and + [FeaturestoreMonitoringConfig.SnapshotAnalysis.monitoring_interval][google.cloud.aiplatform.v1beta1.FeaturestoreMonitoringConfig.SnapshotAnalysis.monitoring_interval] + are set when creating/updating EntityTypes/Features, + [FeaturestoreMonitoringConfig.SnapshotAnalysis.monitoring_interval_days][google.cloud.aiplatform.v1beta1.FeaturestoreMonitoringConfig.SnapshotAnalysis.monitoring_interval_days] + will be used. 
""" disabled = proto.Field(proto.BOOL, number=1,) monitoring_interval = proto.Field( proto.MESSAGE, number=2, message=duration_pb2.Duration, ) + monitoring_interval_days = proto.Field(proto.INT32, number=3,) snapshot_analysis = proto.Field(proto.MESSAGE, number=1, message=SnapshotAnalysis,) diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py index 87688e56e5..e139820aa6 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_online_service.py @@ -167,12 +167,14 @@ class StreamingReadFeatureValuesRequest(proto.Message): For example, for a machine learning model predicting user clicks on a website, an EntityType ID could be "user". entity_ids (Sequence[str]): - Required. IDs of entities to read Feature values of. For - example, for a machine learning model predicting user clicks - on a website, an entity ID could be "user_123". + Required. IDs of entities to read Feature values of. The + maximum number of IDs is 100. For example, for a machine + learning model predicting user clicks on a website, an + entity ID could be "user_123". feature_selector (google.cloud.aiplatform_v1beta1.types.FeatureSelector): Required. Selector choosing Features of the - target EntityType. + target EntityType. Feature IDs will be + deduplicated. """ entity_type = proto.Field(proto.STRING, number=1,) @@ -206,7 +208,7 @@ class FeatureValue(proto.Message): bytes_value (bytes): Bytes feature value. metadata (google.cloud.aiplatform_v1beta1.types.FeatureValue.Metadata): - Output only. Metadata of feature value. + Metadata of feature value. 
""" class Metadata(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py index fcbe2516e6..08863af44a 100644 --- a/google/cloud/aiplatform_v1beta1/types/featurestore_service.py +++ b/google/cloud/aiplatform_v1beta1/types/featurestore_service.py @@ -396,6 +396,8 @@ class BatchReadFeatureValuesRequest(proto.Message): Values in the timestamp column must use the RFC 3339 format, e.g. ``2012-07-30T10:43:17.123Z``. + bigquery_read_instances (google.cloud.aiplatform_v1beta1.types.BigQuerySource): + Similar to csv_read_instances, but from BigQuery source. featurestore (str): Required. The resource name of the Featurestore from which to query Feature values. Format: @@ -403,14 +405,34 @@ class BatchReadFeatureValuesRequest(proto.Message): destination (google.cloud.aiplatform_v1beta1.types.FeatureValueDestination): Required. Specifies output location and format. + pass_through_fields (Sequence[google.cloud.aiplatform_v1beta1.types.BatchReadFeatureValuesRequest.PassThroughField]): + When not empty, the specified fields in the + \*_read_instances source will be joined as-is in the output, + in addition to those fields from the Featurestore Entity. + + For BigQuery source, the type of the pass-through values + will be automatically inferred. For CSV source, the + pass-through values will be passed as opaque bytes. entity_type_specs (Sequence[google.cloud.aiplatform_v1beta1.types.BatchReadFeatureValuesRequest.EntityTypeSpec]): Required. Specifies EntityType grouping Features to read values of and settings. Each EntityType referenced in [BatchReadFeatureValuesRequest.entity_type_specs] must have - a column specifying entity IDs in tha EntityType in + a column specifying entity IDs in the EntityType in [BatchReadFeatureValuesRequest.request][] . """ + class PassThroughField(proto.Message): + r"""Describe pass-through fields in read_instance source. 
+ Attributes: + field_name (str): + Required. The name of the field in the CSV header or the + name of the column in BigQuery table. The naming restriction + is the same as + [Feature.name][google.cloud.aiplatform.v1beta1.Feature.name]. + """ + + field_name = proto.Field(proto.STRING, number=1,) + class EntityTypeSpec(proto.Message): r"""Selects Features of an EntityType to read values of and specifies read settings. @@ -439,10 +461,16 @@ class EntityTypeSpec(proto.Message): csv_read_instances = proto.Field( proto.MESSAGE, number=3, oneof="read_option", message=io.CsvSource, ) + bigquery_read_instances = proto.Field( + proto.MESSAGE, number=5, oneof="read_option", message=io.BigQuerySource, + ) featurestore = proto.Field(proto.STRING, number=1,) destination = proto.Field( proto.MESSAGE, number=4, message="FeatureValueDestination", ) + pass_through_fields = proto.RepeatedField( + proto.MESSAGE, number=8, message=PassThroughField, + ) entity_type_specs = proto.RepeatedField( proto.MESSAGE, number=7, message=EntityTypeSpec, ) diff --git a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py index 1f7a7be1b5..c04ed7dcba 100644 --- a/google/cloud/aiplatform_v1beta1/types/index_endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/index_endpoint.py @@ -171,6 +171,18 @@ class DeployedIndex(proto.Message): deployed_index_auth_config (google.cloud.aiplatform_v1beta1.types.DeployedIndexAuthConfig): Optional. If set, the authentication is enabled for the private endpoint. + reserved_ip_ranges (Sequence[str]): + Optional. A list of reserved ip ranges under + the VPC network that can be used for this + DeployedIndex. + If set, we will deploy the index within the + provided ip ranges. Otherwise, the index might + be deployed to any ip ranges under the provided + VPC network. 
+
+        The value should be the name of the address
+        (https://cloud.google.com/compute/docs/reference/rest/v1/addresses)
+        Example: 'vertex-ai-ip-range'.
     """

     id = proto.Field(proto.STRING, number=1,)
@@ -190,6 +202,7 @@ class DeployedIndex(proto.Message):
     deployed_index_auth_config = proto.Field(
         proto.MESSAGE, number=9, message="DeployedIndexAuthConfig",
     )
+    reserved_ip_ranges = proto.RepeatedField(proto.STRING, number=10,)


 class DeployedIndexAuthConfig(proto.Message):
diff --git a/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py
index 591071752a..b4049d4826 100644
--- a/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py
+++ b/google/cloud/aiplatform_v1beta1/types/index_endpoint_service.py
@@ -240,11 +240,14 @@ class DeployIndexOperationMetadata(proto.Message):
     Attributes:
         generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata):
             The operation generic information.
+        deployed_index_id (str):
+            The unique index id specified by user
     """

     generic_metadata = proto.Field(
         proto.MESSAGE, number=1, message=operation.GenericOperationMetadata,
     )
+    deployed_index_id = proto.Field(proto.STRING, number=2,)


 class UndeployIndexRequest(proto.Message):
diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py
index 6d456e7b6b..0b7cb85666 100644
--- a/google/cloud/aiplatform_v1beta1/types/io.py
+++ b/google/cloud/aiplatform_v1beta1/types/io.py
@@ -108,6 +108,7 @@ class BigQueryDestination(proto.Message):
         Accepted forms:

         -  BigQuery path. For example: ``bq://projectId`` or
+           ``bq://projectId.bqDatasetId`` or
            ``bq://projectId.bqDatasetId.bqTableId``.
""" diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index 329b22213b..896b29ab16 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -80,15 +80,14 @@ class DedicatedResources(proto.Message): Required. Immutable. The specification of a single machine used by the prediction. min_replica_count (int): - Required. Immutable. The minimum number of machine replicas - this DeployedModel will be always deployed on. If traffic - against it increases, it may dynamically be deployed onto - more replicas, and as traffic decreases, some of these extra - replicas may be freed. Note: if - [machine_spec.accelerator_count][google.cloud.aiplatform.v1beta1.MachineSpec.accelerator_count] - is above 0, currently the model will be always deployed - precisely on - [min_replica_count][google.cloud.aiplatform.v1beta1.DedicatedResources.min_replica_count]. + Required. Immutable. The minimum number of + machine replicas this DeployedModel will be + always deployed on. This value must be greater + than or equal to 1. + If traffic against the DeployedModel increases, + it may dynamically be deployed onto more + replicas, and as traffic decreases, some of + these extra replicas may be freed. max_replica_count (int): Immutable. The maximum number of replicas this DeployedModel may be deployed on when the traffic against it increases. 
If diff --git a/google/cloud/aiplatform_v1beta1/types/metadata_service.py b/google/cloud/aiplatform_v1beta1/types/metadata_service.py index 12b3af0b0f..5d6ff87acd 100644 --- a/google/cloud/aiplatform_v1beta1/types/metadata_service.py +++ b/google/cloud/aiplatform_v1beta1/types/metadata_service.py @@ -40,12 +40,19 @@ "ListArtifactsRequest", "ListArtifactsResponse", "UpdateArtifactRequest", + "DeleteArtifactRequest", + "PurgeArtifactsRequest", + "PurgeArtifactsResponse", + "PurgeArtifactsMetadata", "CreateContextRequest", "GetContextRequest", "ListContextsRequest", "ListContextsResponse", "UpdateContextRequest", "DeleteContextRequest", + "PurgeContextsRequest", + "PurgeContextsResponse", + "PurgeContextsMetadata", "AddContextArtifactsAndExecutionsRequest", "AddContextArtifactsAndExecutionsResponse", "AddContextChildrenRequest", @@ -56,6 +63,10 @@ "ListExecutionsRequest", "ListExecutionsResponse", "UpdateExecutionRequest", + "DeleteExecutionRequest", + "PurgeExecutionsRequest", + "PurgeExecutionsResponse", + "PurgeExecutionsMetadata", "AddExecutionEventsRequest", "AddExecutionEventsResponse", "QueryExecutionInputsAndOutputsRequest", @@ -369,6 +380,81 @@ class UpdateArtifactRequest(proto.Message): allow_missing = proto.Field(proto.BOOL, number=3,) +class DeleteArtifactRequest(proto.Message): + r"""Request message for + [MetadataService.DeleteArtifact][google.cloud.aiplatform.v1beta1.MetadataService.DeleteArtifact]. + + Attributes: + name (str): + Required. The resource name of the Artifact + to delete. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/artifacts/{artifact} + etag (str): + Optional. The etag of the Artifact to delete. If this is + provided, it must match the server's etag. Otherwise, the + request will fail with a FAILED_PRECONDITION. 
+ """ + + name = proto.Field(proto.STRING, number=1,) + etag = proto.Field(proto.STRING, number=2,) + + +class PurgeArtifactsRequest(proto.Message): + r"""Request message for + [MetadataService.PurgeArtifacts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeArtifacts]. + + Attributes: + parent (str): + Required. The metadata store to purge + Artifacts from. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + filter (str): + Required. A required filter matching the Artifacts to be + purged. E.g., update_time <= 2020-11-19T11:30:00-04:00. + force (bool): + Optional. Flag to indicate to actually perform the purge. If + ``force`` is set to false, the method will return a sample + of Artifact names that would be deleted. + """ + + parent = proto.Field(proto.STRING, number=1,) + filter = proto.Field(proto.STRING, number=2,) + force = proto.Field(proto.BOOL, number=3,) + + +class PurgeArtifactsResponse(proto.Message): + r"""Response message for + [MetadataService.PurgeArtifacts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeArtifacts]. + + Attributes: + purge_count (int): + The number of Artifacts that this request deleted (or, if + ``force`` is false, the number of Artifacts that will be + deleted). This can be an estimate. + purge_sample (Sequence[str]): + A sample of the Artifact names that will be deleted. Only + populated if ``force`` is set to false. The maximum number + of samples is 100 (it is possible to return fewer). + """ + + purge_count = proto.Field(proto.INT64, number=1,) + purge_sample = proto.RepeatedField(proto.STRING, number=2,) + + +class PurgeArtifactsMetadata(proto.Message): + r"""Details of operations that perform + [MetadataService.PurgeArtifacts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeArtifacts]. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for purging Artifacts. 
+ """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + class CreateContextRequest(proto.Message): r"""Request message for [MetadataService.CreateContext][google.cloud.aiplatform.v1beta1.MetadataService.CreateContext]. @@ -535,17 +621,76 @@ class DeleteContextRequest(proto.Message): Attributes: name (str): Required. The resource name of the Context to - retrieve. Format: + delete. Format: projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context} force (bool): - If set to true, any child resources of this Context will be - deleted. (Otherwise, the request will fail with a - FAILED_PRECONDITION error if the Context has any child - resources, such as another Context, Artifact, or Execution). + The force deletion semantics is still + undefined. Users should not use this field. + etag (str): + Optional. The etag of the Context to delete. If this is + provided, it must match the server's etag. Otherwise, the + request will fail with a FAILED_PRECONDITION. """ name = proto.Field(proto.STRING, number=1,) force = proto.Field(proto.BOOL, number=2,) + etag = proto.Field(proto.STRING, number=3,) + + +class PurgeContextsRequest(proto.Message): + r"""Request message for + [MetadataService.PurgeContexts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeContexts]. + + Attributes: + parent (str): + Required. The metadata store to purge + Contexts from. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + filter (str): + Required. A required filter matching the Contexts to be + purged. E.g., update_time <= 2020-11-19T11:30:00-04:00. + force (bool): + Optional. Flag to indicate to actually perform the purge. If + ``force`` is set to false, the method will return a sample + of Context names that would be deleted. 
+ """ + + parent = proto.Field(proto.STRING, number=1,) + filter = proto.Field(proto.STRING, number=2,) + force = proto.Field(proto.BOOL, number=3,) + + +class PurgeContextsResponse(proto.Message): + r"""Response message for + [MetadataService.PurgeContexts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeContexts]. + + Attributes: + purge_count (int): + The number of Contexts that this request deleted (or, if + ``force`` is false, the number of Contexts that will be + deleted). This can be an estimate. + purge_sample (Sequence[str]): + A sample of the Context names that will be deleted. Only + populated if ``force`` is set to false. The maximum number + of samples is 100 (it is possible to return fewer). + """ + + purge_count = proto.Field(proto.INT64, number=1,) + purge_sample = proto.RepeatedField(proto.STRING, number=2,) + + +class PurgeContextsMetadata(proto.Message): + r"""Details of operations that perform + [MetadataService.PurgeContexts][google.cloud.aiplatform.v1beta1.MetadataService.PurgeContexts]. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for purging Contexts. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) class AddContextArtifactsAndExecutionsRequest(proto.Message): @@ -778,6 +923,81 @@ class UpdateExecutionRequest(proto.Message): allow_missing = proto.Field(proto.BOOL, number=3,) +class DeleteExecutionRequest(proto.Message): + r"""Request message for + [MetadataService.DeleteExecution][google.cloud.aiplatform.v1beta1.MetadataService.DeleteExecution]. + + Attributes: + name (str): + Required. The resource name of the Execution + to delete. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution} + etag (str): + Optional. The etag of the Execution to delete. If this is + provided, it must match the server's etag. 
Otherwise, the + request will fail with a FAILED_PRECONDITION. + """ + + name = proto.Field(proto.STRING, number=1,) + etag = proto.Field(proto.STRING, number=2,) + + +class PurgeExecutionsRequest(proto.Message): + r"""Request message for + [MetadataService.PurgeExecutions][google.cloud.aiplatform.v1beta1.MetadataService.PurgeExecutions]. + + Attributes: + parent (str): + Required. The metadata store to purge + Executions from. Format: + projects/{project}/locations/{location}/metadataStores/{metadatastore} + filter (str): + Required. A required filter matching the Executions to be + purged. E.g., update_time <= 2020-11-19T11:30:00-04:00. + force (bool): + Optional. Flag to indicate to actually perform the purge. If + ``force`` is set to false, the method will return a sample + of Execution names that would be deleted. + """ + + parent = proto.Field(proto.STRING, number=1,) + filter = proto.Field(proto.STRING, number=2,) + force = proto.Field(proto.BOOL, number=3,) + + +class PurgeExecutionsResponse(proto.Message): + r"""Response message for + [MetadataService.PurgeExecutions][google.cloud.aiplatform.v1beta1.MetadataService.PurgeExecutions]. + + Attributes: + purge_count (int): + The number of Executions that this request deleted (or, if + ``force`` is false, the number of Executions that will be + deleted). This can be an estimate. + purge_sample (Sequence[str]): + A sample of the Execution names that will be deleted. Only + populated if ``force`` is set to false. The maximum number + of samples is 100 (it is possible to return fewer). + """ + + purge_count = proto.Field(proto.INT64, number=1,) + purge_sample = proto.RepeatedField(proto.STRING, number=2,) + + +class PurgeExecutionsMetadata(proto.Message): + r"""Details of operations that perform + [MetadataService.PurgeExecutions][google.cloud.aiplatform.v1beta1.MetadataService.PurgeExecutions]. 
+ + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for purging Executions. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + class AddExecutionEventsRequest(proto.Message): r"""Request message for [MetadataService.AddExecutionEvents][google.cloud.aiplatform.v1beta1.MetadataService.AddExecutionEvents]. diff --git a/google/cloud/aiplatform_v1beta1/types/model_monitoring.py b/google/cloud/aiplatform_v1beta1/types/model_monitoring.py index baa837c6eb..94ea052f35 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_monitoring.py +++ b/google/cloud/aiplatform_v1beta1/types/model_monitoring.py @@ -42,6 +42,9 @@ class ModelMonitoringObjectiveConfig(proto.Message): prediction data. prediction_drift_detection_config (google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig): The config for drift of prediction data. + explanation_config (google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.ExplanationConfig): + The config for integrated with Explainable + AI. """ class TrainingDataset(proto.Message): @@ -102,11 +105,19 @@ class TrainingPredictionSkewDetectionConfig(proto.Message): for that feature. The threshold here is against feature distribution distance between the training and prediction feature. + attribution_score_skew_thresholds (Sequence[google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig.AttributionScoreSkewThresholdsEntry]): + Key is the feature name and value is the + threshold. The threshold here is against + attribution score distance between the training + and prediction feature. 
""" skew_thresholds = proto.MapField( proto.STRING, proto.MESSAGE, number=1, message="ThresholdConfig", ) + attribution_score_skew_thresholds = proto.MapField( + proto.STRING, proto.MESSAGE, number=2, message="ThresholdConfig", + ) class PredictionDriftDetectionConfig(proto.Message): r"""The config for Prediction data drift detection. @@ -118,11 +129,82 @@ class PredictionDriftDetectionConfig(proto.Message): for that feature. The threshold here is against feature distribution distance between different time windws. + attribution_score_drift_thresholds (Sequence[google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig.AttributionScoreDriftThresholdsEntry]): + Key is the feature name and value is the + threshold. The threshold here is against + attribution score distance between different + time windows. """ drift_thresholds = proto.MapField( proto.STRING, proto.MESSAGE, number=1, message="ThresholdConfig", ) + attribution_score_drift_thresholds = proto.MapField( + proto.STRING, proto.MESSAGE, number=2, message="ThresholdConfig", + ) + + class ExplanationConfig(proto.Message): + r"""The config for integrated with Explainable AI. Only applicable if + the Model has explanation_spec populated. + + Attributes: + enable_feature_attributes (bool): + If want to analyze the Explainable AI feature + attribute scores or not. If set to true, Vertex + AI will log the feature attributions from + explain response and do the skew/drift detection + for them. + explanation_baseline (google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.ExplanationConfig.ExplanationBaseline): + Predictions generated by the + BatchPredictionJob using baseline dataset. + """ + + class ExplanationBaseline(proto.Message): + r"""Output from + [BatchPredictionJob][google.cloud.aiplatform.v1beta1.BatchPredictionJob] + for Model Monitoring baseline dataset, which can be used to generate + baseline attribution scores. 
+ + Attributes: + gcs (google.cloud.aiplatform_v1beta1.types.GcsDestination): + Cloud Storage location for BatchExplain + output. + bigquery (google.cloud.aiplatform_v1beta1.types.BigQueryDestination): + BigQuery location for BatchExplain output. + prediction_format (google.cloud.aiplatform_v1beta1.types.ModelMonitoringObjectiveConfig.ExplanationConfig.ExplanationBaseline.PredictionFormat): + The storage format of the predictions + generated BatchPrediction job. + """ + + class PredictionFormat(proto.Enum): + r"""The storage format of the predictions generated + BatchPrediction job. + """ + PREDICTION_FORMAT_UNSPECIFIED = 0 + JSONL = 2 + BIGQUERY = 3 + + gcs = proto.Field( + proto.MESSAGE, number=2, oneof="destination", message=io.GcsDestination, + ) + bigquery = proto.Field( + proto.MESSAGE, + number=3, + oneof="destination", + message=io.BigQueryDestination, + ) + prediction_format = proto.Field( + proto.ENUM, + number=1, + enum="ModelMonitoringObjectiveConfig.ExplanationConfig.ExplanationBaseline.PredictionFormat", + ) + + enable_feature_attributes = proto.Field(proto.BOOL, number=1,) + explanation_baseline = proto.Field( + proto.MESSAGE, + number=2, + message="ModelMonitoringObjectiveConfig.ExplanationConfig.ExplanationBaseline", + ) training_dataset = proto.Field(proto.MESSAGE, number=1, message=TrainingDataset,) training_prediction_skew_detection_config = proto.Field( @@ -131,6 +213,9 @@ class PredictionDriftDetectionConfig(proto.Message): prediction_drift_detection_config = proto.Field( proto.MESSAGE, number=3, message=PredictionDriftDetectionConfig, ) + explanation_config = proto.Field( + proto.MESSAGE, number=5, message=ExplanationConfig, + ) class ModelMonitoringAlertConfig(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py index a4c868011d..31792d0f9a 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py +++ 
b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py @@ -83,21 +83,35 @@ class ListTrainingPipelinesRequest(proto.Message): TrainingPipelines from. Format: ``projects/{project}/locations/{location}`` filter (str): - The standard list filter. Supported fields: - - - ``display_name`` supports = and !=. - - - ``state`` supports = and !=. - - Some examples of using the filter are: - - - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` - - - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` - - - ``NOT display_name="my_pipeline"`` - - - ``state="PIPELINE_STATE_FAILED"`` + Lists the PipelineJobs that match the filter expression. The + following fields are supported: + + - ``pipeline_name``: Supports ``=`` and ``!=`` comparisons. + - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``, + ``<=``, and ``>=`` comparisons. Values must be in RFC + 3339 format. + - ``labels``: Supports key-value equality and key presence. + + Filter expressions can be combined together using logical + operators (``AND`` & ``OR``). For example: + ``pipeline_name="test" AND create_time>"2020-05-18T13:30:00Z"``. + + The syntax to define filter expression is based on + https://google.aip.dev/160. + + Examples: + + - ``create_time>"2021-05-18T00:00:00Z" OR update_time>"2020-05-18T00:00:00Z"`` + PipelineJobs created or updated after 2020-05-18 00:00:00 + UTC. + - ``labels.env = "prod"`` PipelineJobs with label "env" set + to "prod". page_size (int): The standard list page size. 
page_token (str): diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index 63e58e5f42..f43eae65f2 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -111,6 +111,12 @@ class Trial(proto.Message): Output only. The CustomJob name linked to the Trial. It's set for a HyperparameterTuningJob's Trial. + web_access_uris (Sequence[google.cloud.aiplatform_v1beta1.types.Trial.WebAccessUrisEntry]): + Output only. The web access URIs for the + training job. The keys are the node names in the + training jobs, e.g. workerpool0-0. The values + are the URIs for each node's web portal in the + job. """ class State(proto.Enum): @@ -151,6 +157,7 @@ class Parameter(proto.Message): client_id = proto.Field(proto.STRING, number=9,) infeasible_reason = proto.Field(proto.STRING, number=10,) custom_job = proto.Field(proto.STRING, number=11,) + web_access_uris = proto.MapField(proto.STRING, proto.STRING, number=12,) class StudySpec(proto.Message): @@ -277,10 +284,18 @@ class DoubleValueSpec(proto.Message): max_value (float): Required. Inclusive maximum value of the parameter. + default_value (float): + A default value for a ``DOUBLE`` parameter that is assumed + to be a relatively good starting point. Unset value signals + that there is no offered starting point. + + Currently only supported by the Vizier service. Not + supported by HyperparamterTuningJob or TrainingPipeline. """ min_value = proto.Field(proto.DOUBLE, number=1,) max_value = proto.Field(proto.DOUBLE, number=2,) + default_value = proto.Field(proto.DOUBLE, number=4, optional=True,) class IntegerValueSpec(proto.Message): r"""Value specification for a parameter in ``INTEGER`` type. @@ -291,19 +306,35 @@ class IntegerValueSpec(proto.Message): max_value (int): Required. Inclusive maximum value of the parameter. 
+ default_value (int): + A default value for an ``INTEGER`` parameter that is assumed + to be a relatively good starting point. Unset value signals + that there is no offered starting point. + + Currently only supported by the Vizier service. Not + supported by HyperparamterTuningJob or TrainingPipeline. """ min_value = proto.Field(proto.INT64, number=1,) max_value = proto.Field(proto.INT64, number=2,) + default_value = proto.Field(proto.INT64, number=4, optional=True,) class CategoricalValueSpec(proto.Message): r"""Value specification for a parameter in ``CATEGORICAL`` type. Attributes: values (Sequence[str]): Required. The list of possible categories. + default_value (str): + A default value for a ``CATEGORICAL`` parameter that is + assumed to be a relatively good starting point. Unset value + signals that there is no offered starting point. + + Currently only supported by the Vizier service. Not + supported by HyperparamterTuningJob or TrainingPipeline. """ values = proto.RepeatedField(proto.STRING, number=1,) + default_value = proto.Field(proto.STRING, number=3, optional=True,) class DiscreteValueSpec(proto.Message): r"""Value specification for a parameter in ``DISCRETE`` type. @@ -315,9 +346,18 @@ class DiscreteValueSpec(proto.Message): might have possible settings of 1.5, 2.5, and 4.0. This list should not contain more than 1,000 values. + default_value (float): + A default value for a ``DISCRETE`` parameter that is assumed + to be a relatively good starting point. Unset value signals + that there is no offered starting point. It automatically + rounds to the nearest feasible discrete point. + + Currently only supported by the Vizier service. Not + supported by HyperparamterTuningJob or TrainingPipeline. 
""" values = proto.RepeatedField(proto.DOUBLE, number=1,) + default_value = proto.Field(proto.DOUBLE, number=3, optional=True,) class ConditionalParameterSpec(proto.Message): r"""Represents a parameter spec with condition from its parent diff --git a/owlbot.py b/owlbot.py index 04ada1d6a4..e8df4c990b 100644 --- a/owlbot.py +++ b/owlbot.py @@ -14,7 +14,7 @@ """This script is used to synthesize generated parts of this library.""" -import os +import re import synthtool as s import synthtool.gcp as gcp @@ -36,11 +36,22 @@ "request.instances.extend(instances)", ) - # https://github.com/googleapis/gapic-generator-python/issues/672 + # Remove test_predict_flattened/test_predict_flattened_async due to gapic generator bug + # https://github.com/googleapis/gapic-generator-python/issues/414 s.replace( - library / f"google/cloud/aiplatform_{library.name}/services/endpoint_service/client.py", - "request.traffic_split.extend\(traffic_split\)", - "request.traffic_split = traffic_split", + library / f"tests/unit/gapic/aiplatform_{library.name}/test_prediction_service.py", + """def test_predict_flattened.*?def test_predict_flattened_error""", + "def test_predict_flattened_error", + flags=re.MULTILINE | re.DOTALL, + ) + + # Remove test_explain_flattened/test_explain_flattened_async due to gapic generator bug + # https://github.com/googleapis/gapic-generator-python/issues/414 + s.replace( + library / f"tests/unit/gapic/aiplatform_{library.name}/test_prediction_service.py", + """def test_explain_flattened.*?def test_explain_flattened_error""", + "def test_explain_flattened_error", + flags=re.MULTILINE | re.DOTALL, ) s.move( @@ -61,11 +72,6 @@ f"scripts/fixup_prediction_{library.name}_keywords.py", "google/cloud/aiplatform/__init__.py", f"google/cloud/aiplatform/{library.name}/schema/**/services/", - f"tests/unit/gapic/aiplatform_{library.name}/test_prediction_service.py", - f"tests/unit/gapic/definition_{library.name}/", - f"tests/unit/gapic/instance_{library.name}/", - 
f"tests/unit/gapic/params_{library.name}/", - f"tests/unit/gapic/prediction_{library.name}/", ], ) diff --git a/schema/predict/instance/noxfile.py b/schema/predict/instance/noxfile.py index f9e24efc02..c8d234927c 100644 --- a/schema/predict/instance/noxfile.py +++ b/schema/predict/instance/noxfile.py @@ -70,7 +70,7 @@ def cover(session): @nox.session(python=['3.6', '3.7']) def mypy(session): """Run the type checker.""" - session.install('mypy') + session.install('mypy', 'types-pkg_resources') session.install('.') session.run( 'mypy', diff --git a/schema/predict/instance/setup.py b/schema/predict/instance/setup.py index 42a274ed13..c6f5e50d64 100644 --- a/schema/predict/instance/setup.py +++ b/schema/predict/instance/setup.py @@ -34,7 +34,7 @@ platforms='Posix; MacOS X; Windows', include_package_data=True, install_requires=( - 'google-api-core[grpc] >= 1.22.2, < 2.0.0dev', + 'google-api-core[grpc] >= 1.27.0, < 3.0.0dev', 'libcst >= 0.2.5', 'proto-plus >= 1.15.0', 'packaging >= 14.3', ), @@ -46,6 +46,7 @@ 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', 'Topic :: Internet', 'Topic :: Software Development :: Libraries :: Python Modules', ], diff --git a/schema/predict/params/noxfile.py b/schema/predict/params/noxfile.py index 90e4bfa1c9..e886ab4759 100644 --- a/schema/predict/params/noxfile.py +++ b/schema/predict/params/noxfile.py @@ -70,7 +70,7 @@ def cover(session): @nox.session(python=['3.6', '3.7']) def mypy(session): """Run the type checker.""" - session.install('mypy') + session.install('mypy', 'types-pkg_resources') session.install('.') session.run( 'mypy', diff --git a/schema/predict/params/setup.py b/schema/predict/params/setup.py index befe8ab40c..e382f09fe5 100644 --- a/schema/predict/params/setup.py +++ b/schema/predict/params/setup.py @@ -34,7 +34,7 @@ platforms='Posix; MacOS X; Windows', include_package_data=True, install_requires=( - 
'google-api-core[grpc] >= 1.22.2, < 2.0.0dev', + 'google-api-core[grpc] >= 1.27.0, < 3.0.0dev', 'libcst >= 0.2.5', 'proto-plus >= 1.15.0', 'packaging >= 14.3', ), @@ -46,6 +46,7 @@ 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', 'Topic :: Internet', 'Topic :: Software Development :: Libraries :: Python Modules', ], diff --git a/schema/predict/prediction/noxfile.py b/schema/predict/prediction/noxfile.py index dba807db56..8ea8ef006a 100644 --- a/schema/predict/prediction/noxfile.py +++ b/schema/predict/prediction/noxfile.py @@ -70,7 +70,7 @@ def cover(session): @nox.session(python=['3.6', '3.7']) def mypy(session): """Run the type checker.""" - session.install('mypy') + session.install('mypy', 'types-pkg_resources') session.install('.') session.run( 'mypy', diff --git a/schema/predict/prediction/setup.py b/schema/predict/prediction/setup.py index 99756ad5a4..f4ab906809 100644 --- a/schema/predict/prediction/setup.py +++ b/schema/predict/prediction/setup.py @@ -34,7 +34,7 @@ platforms='Posix; MacOS X; Windows', include_package_data=True, install_requires=( - 'google-api-core[grpc] >= 1.22.2, < 2.0.0dev', + 'google-api-core[grpc] >= 1.27.0, < 3.0.0dev', 'libcst >= 0.2.5', 'proto-plus >= 1.15.0', 'packaging >= 14.3', ), @@ -46,6 +46,7 @@ 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', 'Topic :: Internet', 'Topic :: Software Development :: Libraries :: Python Modules', ], diff --git a/setup.py b/setup.py index 3cbf6c309a..b16e2ca8f3 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ # NOTE: Maintainers, please do not require google-api-core>=2.x.x # Until this issue is closed # https://github.com/googleapis/google-cloud-python/issues/10566 - "google-api-core[grpc] >= 1.22.2, <3.0.0dev", + "google-api-core[grpc] >= 1.26.0, <3.0.0dev", 
"proto-plus >= 1.10.1", "packaging >= 14.3", "google-cloud-storage >= 1.32.0, < 2.0.0dev", diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt index 9631174f70..fc7d641771 100644 --- a/testing/constraints-3.6.txt +++ b/testing/constraints-3.6.txt @@ -5,7 +5,7 @@ # # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", # Then this file should have foo==1.14.0 -google-api-core==1.22.2 +google-api-core==1.26.0 libcst==0.2.5 proto-plus==1.10.1 mock==4.0.2 diff --git a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py index 6ebf22cd26..fa699fb54a 100644 --- a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py @@ -40,9 +40,6 @@ from google.cloud.aiplatform_v1.services.dataset_service import DatasetServiceClient from google.cloud.aiplatform_v1.services.dataset_service import pagers from google.cloud.aiplatform_v1.services.dataset_service import transports -from google.cloud.aiplatform_v1.services.dataset_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1.services.dataset_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -63,8 +60,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). 
requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -75,16 +73,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -147,6 +135,31 @@ def test_dataset_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.DatasetServiceGrpcTransport, "grpc"), + (transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_dataset_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,] ) @@ -226,6 +239,7 @@ def test_dataset_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) 
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -242,6 +256,7 @@ def test_dataset_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -258,6 +273,7 @@ def test_dataset_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -286,6 +302,7 @@ def test_dataset_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -352,6 +369,7 @@ def test_dataset_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. Whether client cert is used depends on @@ -385,6 +403,7 @@ def test_dataset_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. 
@@ -406,6 +425,7 @@ def test_dataset_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -436,6 +456,7 @@ def test_dataset_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -466,6 +487,7 @@ def test_dataset_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -485,6 +507,7 @@ def test_dataset_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -3302,7 +3325,6 @@ def test_dataset_service_transport_auth_adc_old_google_auth(transport_class): (transports.DatasetServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_dataset_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. @@ -3331,79 +3353,6 @@ def test_dataset_service_transport_create_channel(transport_class, grpc_helpers) ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.DatasetServiceGrpcTransport, grpc_helpers), - (transports.DatasetServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_dataset_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.DatasetServiceGrpcTransport, grpc_helpers), - (transports.DatasetServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_dataset_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -3426,7 +3375,7 @@ def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_cl "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3535,7 +3484,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3582,7 +3531,7 @@ def test_dataset_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py index 707ddf6fc0..03906b6fd9 100644 --- a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py @@ -40,9 +40,6 @@ from google.cloud.aiplatform_v1.services.endpoint_service import EndpointServiceClient 
from google.cloud.aiplatform_v1.services.endpoint_service import pagers from google.cloud.aiplatform_v1.services.endpoint_service import transports -from google.cloud.aiplatform_v1.services.endpoint_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1.services.endpoint_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -60,8 +57,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -72,16 +70,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -144,6 +132,31 @@ def test_endpoint_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.EndpointServiceGrpcTransport, "grpc"), + (transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_endpoint_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as 
use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,] ) @@ -223,6 +236,7 @@ def test_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -239,6 +253,7 @@ def test_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -255,6 +270,7 @@ def test_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -283,6 +299,7 @@ def test_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -359,6 +376,7 @@ def test_endpoint_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -392,6 +410,7 @@ def test_endpoint_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -413,6 +432,7 @@ def test_endpoint_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -443,6 +463,7 @@ def test_endpoint_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -473,6 +494,7 @@ def test_endpoint_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -492,6 +514,7 @@ def test_endpoint_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -2421,7 +2444,6 @@ def test_endpoint_service_transport_auth_adc_old_google_auth(transport_class): (transports.EndpointServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_endpoint_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -2450,79 +2472,6 @@ def test_endpoint_service_transport_create_channel(transport_class, grpc_helpers ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.EndpointServiceGrpcTransport, grpc_helpers), - (transports.EndpointServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_endpoint_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.EndpointServiceGrpcTransport, grpc_helpers), - (transports.EndpointServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_endpoint_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -2545,7 +2494,7 @@ def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_c "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2654,7 +2603,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2701,7 +2650,7 @@ def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py index a4e3d484e3..4f744eb39c 100644 --- a/tests/unit/gapic/aiplatform_v1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -38,9 +38,6 @@ from google.cloud.aiplatform_v1.services.job_service import JobServiceClient from 
google.cloud.aiplatform_v1.services.job_service import pagers from google.cloud.aiplatform_v1.services.job_service import transports -from google.cloud.aiplatform_v1.services.job_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1.services.job_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -79,8 +76,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -91,16 +89,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -158,6 +146,31 @@ def test_job_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.JobServiceGrpcTransport, "grpc"), + (transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_job_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, 
None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient,]) def test_job_service_client_from_service_account_file(client_class): creds = ga_credentials.AnonymousCredentials() @@ -233,6 +246,7 @@ def test_job_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -249,6 +263,7 @@ def test_job_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -265,6 +280,7 @@ def test_job_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -293,6 +309,7 @@ def test_job_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -357,6 +374,7 @@ def test_job_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -390,6 +408,7 @@ def test_job_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -411,6 +430,7 @@ def test_job_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -441,6 +461,7 @@ def test_job_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -471,6 +492,7 @@ def test_job_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -488,6 +510,7 @@ def test_job_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -5827,7 +5850,6 @@ def test_job_service_transport_auth_adc_old_google_auth(transport_class): (transports.JobServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_job_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -5856,79 +5878,6 @@ def test_job_service_transport_create_channel(transport_class, grpc_helpers): ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.JobServiceGrpcTransport, grpc_helpers), - (transports.JobServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_job_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.JobServiceGrpcTransport, grpc_helpers), - (transports.JobServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_job_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], @@ -5948,7 +5897,7 @@ def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class) "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -6052,7 +6001,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -6096,7 +6045,7 @@ def test_job_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 5645588f27..95535bb8b3 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -40,9 +40,6 @@ 
from google.cloud.aiplatform_v1.services.migration_service import MigrationServiceClient from google.cloud.aiplatform_v1.services.migration_service import pagers from google.cloud.aiplatform_v1.services.migration_service import transports -from google.cloud.aiplatform_v1.services.migration_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1.services.migration_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -53,8 +50,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -65,16 +63,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -138,6 +126,31 @@ def test_migration_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.MigrationServiceGrpcTransport, "grpc"), + (transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_migration_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with 
mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] ) @@ -217,6 +230,7 @@ def test_migration_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -233,6 +247,7 @@ def test_migration_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -249,6 +264,7 @@ def test_migration_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -277,6 +293,7 @@ def test_migration_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -353,6 +370,7 @@ def test_migration_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC 
client cert is provided. Whether client cert is used depends on @@ -386,6 +404,7 @@ def test_migration_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -407,6 +426,7 @@ def test_migration_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -437,6 +457,7 @@ def test_migration_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -467,6 +488,7 @@ def test_migration_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -486,6 +508,7 @@ def test_migration_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -1387,7 +1410,6 @@ def test_migration_service_transport_auth_adc_old_google_auth(transport_class): (transports.MigrationServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_migration_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -1416,79 +1438,6 @@ def test_migration_service_transport_create_channel(transport_class, grpc_helper ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.MigrationServiceGrpcTransport, grpc_helpers), - (transports.MigrationServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_migration_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.MigrationServiceGrpcTransport, grpc_helpers), - (transports.MigrationServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_migration_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -1511,7 +1460,7 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_ "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1620,7 +1569,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1667,7 +1616,7 @@ def test_migration_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1/test_model_service.py b/tests/unit/gapic/aiplatform_v1/test_model_service.py index bfc0d94a16..5ae709d4d2 100644 --- a/tests/unit/gapic/aiplatform_v1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_model_service.py @@ -38,9 +38,6 @@ from google.cloud.aiplatform_v1.services.model_service import ModelServiceClient from 
google.cloud.aiplatform_v1.services.model_service import pagers from google.cloud.aiplatform_v1.services.model_service import transports -from google.cloud.aiplatform_v1.services.model_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1.services.model_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -62,8 +59,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -74,16 +72,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -141,6 +129,31 @@ def test_model_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.ModelServiceGrpcTransport, "grpc"), + (transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_model_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = 
service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,]) def test_model_service_client_from_service_account_file(client_class): creds = ga_credentials.AnonymousCredentials() @@ -216,6 +229,7 @@ def test_model_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -232,6 +246,7 @@ def test_model_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -248,6 +263,7 @@ def test_model_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -276,6 +292,7 @@ def test_model_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -340,6 +357,7 @@ def test_model_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -373,6 +391,7 @@ def test_model_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -394,6 +413,7 @@ def test_model_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -424,6 +444,7 @@ def test_model_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -454,6 +475,7 @@ def test_model_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -471,6 +493,7 @@ def test_model_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -3384,7 +3407,6 @@ def test_model_service_transport_auth_adc_old_google_auth(transport_class): (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_model_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -3413,79 +3435,6 @@ def test_model_service_transport_create_channel(transport_class, grpc_helpers): ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.ModelServiceGrpcTransport, grpc_helpers), - (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_model_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.ModelServiceGrpcTransport, grpc_helpers), - (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_model_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], @@ -3505,7 +3454,7 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_clas "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3609,7 +3558,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_ "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3653,7 +3602,7 @@ def test_model_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py index 0a6dec9583..3a09d7a0c0 100644 --- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py @@ -40,9 +40,6 @@ 
from google.cloud.aiplatform_v1.services.pipeline_service import PipelineServiceClient from google.cloud.aiplatform_v1.services.pipeline_service import pagers from google.cloud.aiplatform_v1.services.pipeline_service import transports -from google.cloud.aiplatform_v1.services.pipeline_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1.services.pipeline_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -72,8 +69,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -84,16 +82,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -156,6 +144,31 @@ def test_pipeline_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.PipelineServiceGrpcTransport, "grpc"), + (transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_pipeline_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( 
+ service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,] ) @@ -235,6 +248,7 @@ def test_pipeline_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -251,6 +265,7 @@ def test_pipeline_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -267,6 +282,7 @@ def test_pipeline_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -295,6 +311,7 @@ def test_pipeline_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -371,6 +388,7 @@ def test_pipeline_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -404,6 +422,7 @@ def test_pipeline_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -425,6 +444,7 @@ def test_pipeline_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -455,6 +475,7 @@ def test_pipeline_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -485,6 +506,7 @@ def test_pipeline_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -504,6 +526,7 @@ def test_pipeline_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -3331,7 +3354,6 @@ def test_pipeline_service_transport_auth_adc_old_google_auth(transport_class): (transports.PipelineServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_pipeline_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -3360,79 +3382,6 @@ def test_pipeline_service_transport_create_channel(transport_class, grpc_helpers ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.PipelineServiceGrpcTransport, grpc_helpers), - (transports.PipelineServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_pipeline_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.PipelineServiceGrpcTransport, grpc_helpers), - (transports.PipelineServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_pipeline_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -3455,7 +3404,7 @@ def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_c "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3564,7 +3513,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3611,7 +3560,7 @@ def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1/test_prediction_service.py index 60deba2f05..ef42abe57e 100644 --- a/tests/unit/gapic/aiplatform_v1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_prediction_service.py @@ -38,9 +38,6 @@ PredictionServiceClient, ) from 
google.cloud.aiplatform_v1.services.prediction_service import transports -from google.cloud.aiplatform_v1.services.prediction_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1.services.prediction_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -50,8 +47,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -62,16 +60,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -135,6 +123,31 @@ def test_prediction_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.PredictionServiceGrpcTransport, "grpc"), + (transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_prediction_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + 
transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [PredictionServiceClient, PredictionServiceAsyncClient,] ) @@ -214,6 +227,7 @@ def test_prediction_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -230,6 +244,7 @@ def test_prediction_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -246,6 +261,7 @@ def test_prediction_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -274,6 +290,7 @@ def test_prediction_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -350,6 +367,7 @@ def test_prediction_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -383,6 +401,7 @@ def test_prediction_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -404,6 +423,7 @@ def test_prediction_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -434,6 +454,7 @@ def test_prediction_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -464,6 +485,7 @@ def test_prediction_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -483,6 +505,7 @@ def test_prediction_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -624,33 +647,6 @@ async def test_predict_field_headers_async(): assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] -def test_predict_flattened(): - client = PredictionServiceClient(credentials=ga_credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.predict), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.PredictResponse() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- client.predict( - endpoint="endpoint_value", - instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], - parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [ - struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - ] - # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - - def test_predict_flattened_error(): client = PredictionServiceClient(credentials=ga_credentials.AnonymousCredentials(),) @@ -659,45 +655,11 @@ def test_predict_flattened_error(): with pytest.raises(ValueError): client.predict( prediction_service.PredictRequest(), - endpoint="endpoint_value", - parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), - instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], - ) - - -@pytest.mark.asyncio -async def test_predict_flattened_async(): - client = PredictionServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.predict), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.PredictResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - prediction_service.PredictResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- response = await client.predict( endpoint="endpoint_value", instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), ) - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [ - struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - ] - # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - @pytest.mark.asyncio async def test_predict_flattened_error_async(): @@ -711,8 +673,8 @@ async def test_predict_flattened_error_async(): await client.predict( prediction_service.PredictRequest(), endpoint="endpoint_value", - parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], + parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), ) @@ -943,7 +905,6 @@ def test_prediction_service_transport_auth_adc_old_google_auth(transport_class): (transports.PredictionServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_prediction_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -972,79 +933,6 @@ def test_prediction_service_transport_create_channel(transport_class, grpc_helpe ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.PredictionServiceGrpcTransport, grpc_helpers), - (transports.PredictionServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_prediction_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.PredictionServiceGrpcTransport, grpc_helpers), - (transports.PredictionServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_prediction_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -1067,7 +955,7 @@ def test_prediction_service_grpc_transport_client_cert_source_for_mtls(transport "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1176,7 +1064,7 @@ def test_prediction_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1223,7 +1111,7 @@ def test_prediction_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py index 2f54d4b68c..a0658f77af 100644 --- a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py @@ -42,9 +42,6 @@ ) from 
google.cloud.aiplatform_v1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1.services.specialist_pool_service import transports -from google.cloud.aiplatform_v1.services.specialist_pool_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1.services.specialist_pool_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -58,8 +55,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -70,16 +68,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -143,6 +131,31 @@ def test_specialist_pool_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_specialist_pool_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, 
"with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,] ) @@ -226,6 +239,7 @@ def test_specialist_pool_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -242,6 +256,7 @@ def test_specialist_pool_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -258,6 +273,7 @@ def test_specialist_pool_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -286,6 +302,7 @@ def test_specialist_pool_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -362,6 +379,7 @@ def test_specialist_pool_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client 
cert is provided. Whether client cert is used depends on @@ -395,6 +413,7 @@ def test_specialist_pool_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -416,6 +435,7 @@ def test_specialist_pool_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -450,6 +470,7 @@ def test_specialist_pool_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -484,6 +505,7 @@ def test_specialist_pool_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -503,6 +525,7 @@ def test_specialist_pool_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -2100,7 +2123,6 @@ def test_specialist_pool_service_transport_auth_adc_old_google_auth(transport_cl (transports.SpecialistPoolServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_specialist_pool_service_transport_create_channel( transport_class, grpc_helpers ): @@ -2131,79 +2153,6 @@ def test_specialist_pool_service_transport_create_channel( ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.SpecialistPoolServiceGrpcTransport, grpc_helpers), - (transports.SpecialistPoolServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_specialist_pool_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers 
-): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.SpecialistPoolServiceGrpcTransport, grpc_helpers), - (transports.SpecialistPoolServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_specialist_pool_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -2228,7 +2177,7 @@ def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2337,7 +2286,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2384,7 +2333,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index 03b3c97547..0d91446271 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -42,9 +42,6 @@ ) from google.cloud.aiplatform_v1beta1.services.dataset_service 
import pagers from google.cloud.aiplatform_v1beta1.services.dataset_service import transports -from google.cloud.aiplatform_v1beta1.services.dataset_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.dataset_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -65,8 +62,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -77,16 +75,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -149,6 +137,31 @@ def test_dataset_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.DatasetServiceGrpcTransport, "grpc"), + (transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_dataset_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, 
None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,] ) @@ -228,6 +241,7 @@ def test_dataset_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -244,6 +258,7 @@ def test_dataset_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -260,6 +275,7 @@ def test_dataset_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -288,6 +304,7 @@ def test_dataset_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -354,6 +371,7 @@ def test_dataset_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -387,6 +405,7 @@ def test_dataset_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -408,6 +427,7 @@ def test_dataset_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -438,6 +458,7 @@ def test_dataset_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -468,6 +489,7 @@ def test_dataset_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -487,6 +509,7 @@ def test_dataset_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -3304,7 +3327,6 @@ def test_dataset_service_transport_auth_adc_old_google_auth(transport_class): (transports.DatasetServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_dataset_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -3333,79 +3355,6 @@ def test_dataset_service_transport_create_channel(transport_class, grpc_helpers) ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.DatasetServiceGrpcTransport, grpc_helpers), - (transports.DatasetServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_dataset_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.DatasetServiceGrpcTransport, grpc_helpers), - (transports.DatasetServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_dataset_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -3428,7 +3377,7 @@ def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_cl "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3537,7 +3486,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3584,7 +3533,7 @@ def test_dataset_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index de53fedc34..4fdc429249 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -42,9 +42,6 @@ ) from google.cloud.aiplatform_v1beta1.services.endpoint_service 
import pagers from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports -from google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -65,8 +62,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -77,16 +75,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -149,6 +137,31 @@ def test_endpoint_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.EndpointServiceGrpcTransport, "grpc"), + (transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_endpoint_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = 
service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,] ) @@ -228,6 +241,7 @@ def test_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -244,6 +258,7 @@ def test_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -260,6 +275,7 @@ def test_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -288,6 +304,7 @@ def test_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -364,6 +381,7 @@ def test_endpoint_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -397,6 +415,7 @@ def test_endpoint_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -418,6 +437,7 @@ def test_endpoint_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -448,6 +468,7 @@ def test_endpoint_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -478,6 +499,7 @@ def test_endpoint_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -497,6 +519,7 @@ def test_endpoint_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -729,6 +752,7 @@ def test_get_endpoint( display_name="display_name_value", description="description_value", etag="etag_value", + network="network_value", ) response = client.get_endpoint(request) @@ -743,6 +767,7 @@ def test_get_endpoint( assert response.display_name == "display_name_value" assert response.description == "description_value" assert response.etag == "etag_value" + assert response.network == "network_value" def test_get_endpoint_from_dict(): @@ -785,6 +810,7 @@ async def test_get_endpoint_async( display_name="display_name_value", description="description_value", etag="etag_value", + network="network_value", ) ) response = await client.get_endpoint(request) @@ -800,6 +826,7 @@ async def test_get_endpoint_async( assert response.display_name == "display_name_value" 
assert response.description == "description_value" assert response.etag == "etag_value" + assert response.network == "network_value" @pytest.mark.asyncio @@ -1299,6 +1326,7 @@ def test_update_endpoint( display_name="display_name_value", description="description_value", etag="etag_value", + network="network_value", ) response = client.update_endpoint(request) @@ -1313,6 +1341,7 @@ def test_update_endpoint( assert response.display_name == "display_name_value" assert response.description == "description_value" assert response.etag == "etag_value" + assert response.network == "network_value" def test_update_endpoint_from_dict(): @@ -1355,6 +1384,7 @@ async def test_update_endpoint_async( display_name="display_name_value", description="description_value", etag="etag_value", + network="network_value", ) ) response = await client.update_endpoint(request) @@ -1370,6 +1400,7 @@ async def test_update_endpoint_async( assert response.display_name == "display_name_value" assert response.description == "description_value" assert response.etag == "etag_value" + assert response.network == "network_value" @pytest.mark.asyncio @@ -2426,7 +2457,6 @@ def test_endpoint_service_transport_auth_adc_old_google_auth(transport_class): (transports.EndpointServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_endpoint_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -2455,79 +2485,6 @@ def test_endpoint_service_transport_create_channel(transport_class, grpc_helpers ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.EndpointServiceGrpcTransport, grpc_helpers), - (transports.EndpointServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_endpoint_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.EndpointServiceGrpcTransport, grpc_helpers), - (transports.EndpointServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_endpoint_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -2550,7 +2507,7 @@ def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_c "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2659,7 +2616,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2706,7 +2663,7 @@ def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2791,8 +2748,30 @@ def test_parse_model_path(): assert expected == actual +def test_network_path(): + project = "squid" + network = "clam" + expected = "projects/{project}/global/networks/{network}".format( + project=project, network=network, + ) + actual = EndpointServiceClient.network_path(project, network) + assert expected == actual + + +def test_parse_network_path(): + expected = 
{ + "project": "whelk", + "network": "octopus", + } + path = EndpointServiceClient.network_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_network_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "squid" + billing_account = "oyster" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -2802,7 +2781,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "nudibranch", } path = EndpointServiceClient.common_billing_account_path(**expected) @@ -2812,7 +2791,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "whelk" + folder = "cuttlefish" expected = "folders/{folder}".format(folder=folder,) actual = EndpointServiceClient.common_folder_path(folder) assert expected == actual @@ -2820,7 +2799,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "mussel", } path = EndpointServiceClient.common_folder_path(**expected) @@ -2830,7 +2809,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "oyster" + organization = "winkle" expected = "organizations/{organization}".format(organization=organization,) actual = EndpointServiceClient.common_organization_path(organization) assert expected == actual @@ -2838,7 +2817,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "nautilus", } path = EndpointServiceClient.common_organization_path(**expected) @@ -2848,7 +2827,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "cuttlefish" + project = "scallop" expected = "projects/{project}".format(project=project,) actual = 
EndpointServiceClient.common_project_path(project) assert expected == actual @@ -2856,7 +2835,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "abalone", } path = EndpointServiceClient.common_project_path(**expected) @@ -2866,8 +2845,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "winkle" - location = "nautilus" + project = "squid" + location = "clam" expected = "projects/{project}/locations/{location}".format( project=project, location=location, ) @@ -2877,8 +2856,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "whelk", + "location": "octopus", } path = EndpointServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py index 6645e05944..8a4859b834 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_online_serving_service.py @@ -40,9 +40,6 @@ from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import ( transports, ) -from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -52,8 +49,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). 
requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -64,16 +62,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -151,6 +139,34 @@ def test_featurestore_online_serving_service_client_from_service_account_info( assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.FeaturestoreOnlineServingServiceGrpcTransport, "grpc"), + ( + transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_featurestore_online_serving_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [ @@ -244,6 +260,7 @@ def test_featurestore_online_serving_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, 
client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -260,6 +277,7 @@ def test_featurestore_online_serving_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -276,6 +294,7 @@ def test_featurestore_online_serving_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -304,6 +323,7 @@ def test_featurestore_online_serving_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -380,6 +400,7 @@ def test_featurestore_online_serving_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. Whether client cert is used depends on @@ -413,6 +434,7 @@ def test_featurestore_online_serving_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. 
@@ -434,6 +456,7 @@ def test_featurestore_online_serving_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -468,6 +491,7 @@ def test_featurestore_online_serving_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -502,6 +526,7 @@ def test_featurestore_online_serving_service_client_client_options_credentials_f client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -521,6 +546,7 @@ def test_featurestore_online_serving_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -1231,7 +1257,6 @@ def test_featurestore_online_serving_service_transport_auth_adc_old_google_auth( ), ], ) -@requires_api_core_gte_1_26_0 def test_featurestore_online_serving_service_transport_create_channel( transport_class, grpc_helpers ): @@ -1262,85 +1287,6 @@ def test_featurestore_online_serving_service_transport_create_channel( ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.FeaturestoreOnlineServingServiceGrpcTransport, grpc_helpers), - ( - transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, - grpc_helpers_async, - ), - ], -) -@requires_api_core_lt_1_26_0 -def test_featurestore_online_serving_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.FeaturestoreOnlineServingServiceGrpcTransport, grpc_helpers), - ( - transports.FeaturestoreOnlineServingServiceGrpcAsyncIOTransport, - grpc_helpers_async, - ), - ], -) -@requires_api_core_lt_1_26_0 -def test_featurestore_online_serving_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -1365,7 +1311,7 @@ def test_featurestore_online_serving_service_grpc_transport_client_cert_source_f "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1474,7 +1420,7 @@ def test_featurestore_online_serving_service_transport_channel_mtls_with_client_ "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1523,7 +1469,7 @@ def test_featurestore_online_serving_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py index 5721f569ac..869271f31f 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_featurestore_service.py @@ -42,9 +42,6 @@ ) from 
google.cloud.aiplatform_v1beta1.services.featurestore_service import pagers from google.cloud.aiplatform_v1beta1.services.featurestore_service import transports -from google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.featurestore_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -68,8 +65,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -80,16 +78,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -153,6 +141,31 @@ def test_featurestore_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.FeaturestoreServiceGrpcTransport, "grpc"), + (transports.FeaturestoreServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_featurestore_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, 
"with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [FeaturestoreServiceClient, FeaturestoreServiceAsyncClient,] ) @@ -236,6 +249,7 @@ def test_featurestore_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -252,6 +266,7 @@ def test_featurestore_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -268,6 +283,7 @@ def test_featurestore_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -296,6 +312,7 @@ def test_featurestore_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -372,6 +389,7 @@ def test_featurestore_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -405,6 +423,7 @@ def test_featurestore_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -426,6 +445,7 @@ def test_featurestore_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -460,6 +480,7 @@ def test_featurestore_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -494,6 +515,7 @@ def test_featurestore_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -513,6 +535,7 @@ def test_featurestore_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -5876,7 +5899,6 @@ def test_featurestore_service_transport_auth_adc_old_google_auth(transport_class (transports.FeaturestoreServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_featurestore_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -5905,79 +5927,6 @@ def test_featurestore_service_transport_create_channel(transport_class, grpc_hel ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.FeaturestoreServiceGrpcTransport, grpc_helpers), - (transports.FeaturestoreServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_featurestore_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.FeaturestoreServiceGrpcTransport, grpc_helpers), - (transports.FeaturestoreServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_featurestore_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -6002,7 +5951,7 @@ def test_featurestore_service_grpc_transport_client_cert_source_for_mtls( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -6111,7 +6060,7 @@ def test_featurestore_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -6158,7 +6107,7 @@ def test_featurestore_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py index 8387d2a3b1..5d88a0353f 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_endpoint_service.py @@ -42,9 +42,6 @@ ) from 
google.cloud.aiplatform_v1beta1.services.index_endpoint_service import pagers from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import transports -from google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.index_endpoint_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -60,8 +57,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -72,16 +70,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -145,6 +133,31 @@ def test_index_endpoint_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.IndexEndpointServiceGrpcTransport, "grpc"), + (transports.IndexEndpointServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_index_endpoint_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + 
service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [IndexEndpointServiceClient, IndexEndpointServiceAsyncClient,] ) @@ -228,6 +241,7 @@ def test_index_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -244,6 +258,7 @@ def test_index_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -260,6 +275,7 @@ def test_index_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -288,6 +304,7 @@ def test_index_endpoint_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -364,6 +381,7 @@ def test_index_endpoint_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check 
the case ADC client cert is provided. Whether client cert is used depends on @@ -397,6 +415,7 @@ def test_index_endpoint_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -418,6 +437,7 @@ def test_index_endpoint_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -452,6 +472,7 @@ def test_index_endpoint_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -486,6 +507,7 @@ def test_index_endpoint_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -505,6 +527,7 @@ def test_index_endpoint_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -2571,7 +2594,6 @@ def test_index_endpoint_service_transport_auth_adc_old_google_auth(transport_cla (transports.IndexEndpointServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_index_endpoint_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -2600,79 +2622,6 @@ def test_index_endpoint_service_transport_create_channel(transport_class, grpc_h ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.IndexEndpointServiceGrpcTransport, grpc_helpers), - (transports.IndexEndpointServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_index_endpoint_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.IndexEndpointServiceGrpcTransport, grpc_helpers), - (transports.IndexEndpointServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_index_endpoint_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -2697,7 +2646,7 @@ def test_index_endpoint_service_grpc_transport_client_cert_source_for_mtls( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2806,7 +2755,7 @@ def test_index_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2853,7 +2802,7 @@ def test_index_endpoint_service_transport_channel_mtls_with_adc(transport_class) "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py index 4996d1a173..e9c885be1e 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_index_service.py @@ -40,9 +40,6 @@ from google.cloud.aiplatform_v1beta1.services.index_service import 
IndexServiceClient from google.cloud.aiplatform_v1beta1.services.index_service import pagers from google.cloud.aiplatform_v1beta1.services.index_service import transports -from google.cloud.aiplatform_v1beta1.services.index_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.index_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -59,8 +56,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -71,16 +69,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -138,6 +126,31 @@ def test_index_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.IndexServiceGrpcTransport, "grpc"), + (transports.IndexServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_index_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True 
+ ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize("client_class", [IndexServiceClient, IndexServiceAsyncClient,]) def test_index_service_client_from_service_account_file(client_class): creds = ga_credentials.AnonymousCredentials() @@ -213,6 +226,7 @@ def test_index_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -229,6 +243,7 @@ def test_index_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -245,6 +260,7 @@ def test_index_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -273,6 +289,7 @@ def test_index_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -337,6 +354,7 @@ def test_index_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC 
client cert is provided. Whether client cert is used depends on @@ -370,6 +388,7 @@ def test_index_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -391,6 +410,7 @@ def test_index_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -421,6 +441,7 @@ def test_index_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -451,6 +472,7 @@ def test_index_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -468,6 +490,7 @@ def test_index_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -1835,7 +1858,6 @@ def test_index_service_transport_auth_adc_old_google_auth(transport_class): (transports.IndexServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_index_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -1864,79 +1886,6 @@ def test_index_service_transport_create_channel(transport_class, grpc_helpers): ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.IndexServiceGrpcTransport, grpc_helpers), - (transports.IndexServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_index_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.IndexServiceGrpcTransport, grpc_helpers), - (transports.IndexServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_index_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [transports.IndexServiceGrpcTransport, transports.IndexServiceGrpcAsyncIOTransport], @@ -1956,7 +1905,7 @@ def test_index_service_grpc_transport_client_cert_source_for_mtls(transport_clas "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2060,7 +2009,7 @@ def test_index_service_transport_channel_mtls_with_client_cert_source(transport_ "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2104,7 +2053,7 @@ def test_index_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index 700ebee54b..9c13ce8f9e 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -38,9 +38,6 @@ 
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient from google.cloud.aiplatform_v1beta1.services.job_service import pagers from google.cloud.aiplatform_v1beta1.services.job_service import transports -from google.cloud.aiplatform_v1beta1.services.job_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.job_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -87,8 +84,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -99,16 +97,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -166,6 +154,31 @@ def test_job_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.JobServiceGrpcTransport, "grpc"), + (transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_job_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + 
service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient,]) def test_job_service_client_from_service_account_file(client_class): creds = ga_credentials.AnonymousCredentials() @@ -241,6 +254,7 @@ def test_job_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -257,6 +271,7 @@ def test_job_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -273,6 +288,7 @@ def test_job_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -301,6 +317,7 @@ def test_job_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -365,6 +382,7 @@ def test_job_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + 
always_use_jwt_access=True, ) # Check the case ADC client cert is provided. Whether client cert is used depends on @@ -398,6 +416,7 @@ def test_job_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -419,6 +438,7 @@ def test_job_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -449,6 +469,7 @@ def test_job_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -479,6 +500,7 @@ def test_job_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -496,6 +518,7 @@ def test_job_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -8137,7 +8160,6 @@ def test_job_service_transport_auth_adc_old_google_auth(transport_class): (transports.JobServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_job_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -8166,79 +8188,6 @@ def test_job_service_transport_create_channel(transport_class, grpc_helpers): ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.JobServiceGrpcTransport, grpc_helpers), - (transports.JobServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_job_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.JobServiceGrpcTransport, grpc_helpers), - (transports.JobServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_job_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], @@ -8258,7 +8207,7 @@ def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class) "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -8362,7 +8311,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -8406,7 +8355,7 @@ def test_job_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py index 2b2ce80c5c..c7f2bc84c7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_metadata_service.py @@ 
-42,9 +42,6 @@ ) from google.cloud.aiplatform_v1beta1.services.metadata_service import pagers from google.cloud.aiplatform_v1beta1.services.metadata_service import transports -from google.cloud.aiplatform_v1beta1.services.metadata_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.metadata_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -71,8 +68,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -83,16 +81,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -155,6 +143,31 @@ def test_metadata_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.MetadataServiceGrpcTransport, "grpc"), + (transports.MetadataServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_metadata_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, 
"with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [MetadataServiceClient, MetadataServiceAsyncClient,] ) @@ -234,6 +247,7 @@ def test_metadata_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -250,6 +264,7 @@ def test_metadata_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -266,6 +281,7 @@ def test_metadata_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -294,6 +310,7 @@ def test_metadata_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -370,6 +387,7 @@ def test_metadata_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -403,6 +421,7 @@ def test_metadata_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -424,6 +443,7 @@ def test_metadata_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -454,6 +474,7 @@ def test_metadata_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -484,6 +505,7 @@ def test_metadata_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -503,6 +525,7 @@ def test_metadata_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -2664,8 +2687,8 @@ async def test_update_artifact_flattened_error_async(): ) -def test_create_context( - transport: str = "grpc", request_type=metadata_service.CreateContextRequest +def test_delete_artifact( + transport: str = "grpc", request_type=metadata_service.DeleteArtifactRequest ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2676,40 +2699,25 @@ def test_create_context( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_context), "__call__") as call: + with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = gca_context.Context( - name="name_value", - display_name="display_name_value", - etag="etag_value", - parent_contexts=["parent_contexts_value"], - schema_title="schema_title_value", - schema_version="schema_version_value", - description="description_value", - ) - response = client.create_context(request) + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_artifact(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.CreateContextRequest() + assert args[0] == metadata_service.DeleteArtifactRequest() # Establish that the response is the type that we expect. - assert isinstance(response, gca_context.Context) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.etag == "etag_value" - assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == "schema_title_value" - assert response.schema_version == "schema_version_value" - assert response.description == "description_value" + assert isinstance(response, future.Future) -def test_create_context_from_dict(): - test_create_context(request_type=dict) +def test_delete_artifact_from_dict(): + test_delete_artifact(request_type=dict) -def test_create_context_empty_call(): +def test_delete_artifact_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( @@ -2717,16 +2725,16 @@ def test_create_context_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.create_context), "__call__") as call: - client.create_context() + with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: + client.delete_artifact() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.CreateContextRequest() + assert args[0] == metadata_service.DeleteArtifactRequest() @pytest.mark.asyncio -async def test_create_context_async( - transport: str = "grpc_asyncio", request_type=metadata_service.CreateContextRequest +async def test_delete_artifact_async( + transport: str = "grpc_asyncio", request_type=metadata_service.DeleteArtifactRequest ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2737,55 +2745,40 @@ async def test_create_context_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_context), "__call__") as call: + with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_context.Context( - name="name_value", - display_name="display_name_value", - etag="etag_value", - parent_contexts=["parent_contexts_value"], - schema_title="schema_title_value", - schema_version="schema_version_value", - description="description_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.create_context(request) + response = await client.delete_artifact(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.CreateContextRequest() + assert args[0] == metadata_service.DeleteArtifactRequest() # Establish that the response is the type that we expect. 
- assert isinstance(response, gca_context.Context) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.etag == "etag_value" - assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == "schema_title_value" - assert response.schema_version == "schema_version_value" - assert response.description == "description_value" + assert isinstance(response, future.Future) @pytest.mark.asyncio -async def test_create_context_async_from_dict(): - await test_create_context_async(request_type=dict) +async def test_delete_artifact_async_from_dict(): + await test_delete_artifact_async(request_type=dict) -def test_create_context_field_headers(): +def test_delete_artifact_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.CreateContextRequest() + request = metadata_service.DeleteArtifactRequest() - request.parent = "parent/value" + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_context), "__call__") as call: - call.return_value = gca_context.Context() - client.create_context(request) + with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.delete_artifact(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -2794,25 +2787,27 @@ def test_create_context_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio -async def test_create_context_field_headers_async(): +async def test_delete_artifact_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.CreateContextRequest() + request = metadata_service.DeleteArtifactRequest() - request.parent = "parent/value" + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_context), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) - await client.create_context(request) + with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.delete_artifact(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -2821,78 +2816,65 @@ async def test_create_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] -def test_create_context_flattened(): +def test_delete_artifact_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.create_context), "__call__") as call: + with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = gca_context.Context() + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.create_context( - parent="parent_value", - context=gca_context.Context(name="name_value"), - context_id="context_id_value", - ) + client.delete_artifact(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].context == gca_context.Context(name="name_value") - assert args[0].context_id == "context_id_value" + assert args[0].name == "name_value" -def test_create_context_flattened_error(): +def test_delete_artifact_flattened_error(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.create_context( - metadata_service.CreateContextRequest(), - parent="parent_value", - context=gca_context.Context(name="name_value"), - context_id="context_id_value", + client.delete_artifact( + metadata_service.DeleteArtifactRequest(), name="name_value", ) @pytest.mark.asyncio -async def test_create_context_flattened_async(): +async def test_delete_artifact_flattened_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.create_context), "__call__") as call: + with mock.patch.object(type(client.transport.delete_artifact), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = gca_context.Context() + call.return_value = operations_pb2.Operation(name="operations/op") - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.create_context( - parent="parent_value", - context=gca_context.Context(name="name_value"), - context_id="context_id_value", - ) + response = await client.delete_artifact(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].context == gca_context.Context(name="name_value") - assert args[0].context_id == "context_id_value" + assert args[0].name == "name_value" @pytest.mark.asyncio -async def test_create_context_flattened_error_async(): +async def test_delete_artifact_flattened_error_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -2900,16 +2882,13 @@ async def test_create_context_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - await client.create_context( - metadata_service.CreateContextRequest(), - parent="parent_value", - context=gca_context.Context(name="name_value"), - context_id="context_id_value", + await client.delete_artifact( + metadata_service.DeleteArtifactRequest(), name="name_value", ) -def test_get_context( - transport: str = "grpc", request_type=metadata_service.GetContextRequest +def test_purge_artifacts( + transport: str = "grpc", request_type=metadata_service.PurgeArtifactsRequest ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2920,40 +2899,25 @@ def test_get_context( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_context), "__call__") as call: + with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = context.Context( - name="name_value", - display_name="display_name_value", - etag="etag_value", - parent_contexts=["parent_contexts_value"], - schema_title="schema_title_value", - schema_version="schema_version_value", - description="description_value", - ) - response = client.get_context(request) + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.purge_artifacts(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.GetContextRequest() + assert args[0] == metadata_service.PurgeArtifactsRequest() # Establish that the response is the type that we expect. 
- assert isinstance(response, context.Context) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.etag == "etag_value" - assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == "schema_title_value" - assert response.schema_version == "schema_version_value" - assert response.description == "description_value" + assert isinstance(response, future.Future) -def test_get_context_from_dict(): - test_get_context(request_type=dict) +def test_purge_artifacts_from_dict(): + test_purge_artifacts(request_type=dict) -def test_get_context_empty_call(): +def test_purge_artifacts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( @@ -2961,16 +2925,16 @@ def test_get_context_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_context), "__call__") as call: - client.get_context() + with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: + client.purge_artifacts() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.GetContextRequest() + assert args[0] == metadata_service.PurgeArtifactsRequest() @pytest.mark.asyncio -async def test_get_context_async( - transport: str = "grpc_asyncio", request_type=metadata_service.GetContextRequest +async def test_purge_artifacts_async( + transport: str = "grpc_asyncio", request_type=metadata_service.PurgeArtifactsRequest ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2981,55 +2945,40 @@ async def test_get_context_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.get_context), "__call__") as call: + with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - context.Context( - name="name_value", - display_name="display_name_value", - etag="etag_value", - parent_contexts=["parent_contexts_value"], - schema_title="schema_title_value", - schema_version="schema_version_value", - description="description_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.get_context(request) + response = await client.purge_artifacts(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.GetContextRequest() + assert args[0] == metadata_service.PurgeArtifactsRequest() # Establish that the response is the type that we expect. - assert isinstance(response, context.Context) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.etag == "etag_value" - assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == "schema_title_value" - assert response.schema_version == "schema_version_value" - assert response.description == "description_value" + assert isinstance(response, future.Future) @pytest.mark.asyncio -async def test_get_context_async_from_dict(): - await test_get_context_async(request_type=dict) +async def test_purge_artifacts_async_from_dict(): + await test_purge_artifacts_async(request_type=dict) -def test_get_context_field_headers(): +def test_purge_artifacts_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
- request = metadata_service.GetContextRequest() + request = metadata_service.PurgeArtifactsRequest() - request.name = "name/value" + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_context), "__call__") as call: - call.return_value = context.Context() - client.get_context(request) + with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.purge_artifacts(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -3038,25 +2987,27 @@ def test_get_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio -async def test_get_context_field_headers_async(): +async def test_purge_artifacts_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.GetContextRequest() + request = metadata_service.PurgeArtifactsRequest() - request.name = "name/value" + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.get_context), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context()) - await client.get_context(request) + with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.purge_artifacts(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -3065,63 +3016,65 @@ async def test_get_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] -def test_get_context_flattened(): +def test_purge_artifacts_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_context), "__call__") as call: + with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = context.Context() + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_context(name="name_value",) + client.purge_artifacts(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].parent == "parent_value" -def test_get_context_flattened_error(): +def test_purge_artifacts_flattened_error(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.get_context( - metadata_service.GetContextRequest(), name="name_value", + client.purge_artifacts( + metadata_service.PurgeArtifactsRequest(), parent="parent_value", ) @pytest.mark.asyncio -async def test_get_context_flattened_async(): +async def test_purge_artifacts_flattened_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_context), "__call__") as call: + with mock.patch.object(type(client.transport.purge_artifacts), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = context.Context() + call.return_value = operations_pb2.Operation(name="operations/op") - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_context(name="name_value",) + response = await client.purge_artifacts(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].parent == "parent_value" @pytest.mark.asyncio -async def test_get_context_flattened_error_async(): +async def test_purge_artifacts_flattened_error_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -3129,13 +3082,13 @@ async def test_get_context_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.get_context( - metadata_service.GetContextRequest(), name="name_value", + await client.purge_artifacts( + metadata_service.PurgeArtifactsRequest(), parent="parent_value", ) -def test_list_contexts( - transport: str = "grpc", request_type=metadata_service.ListContextsRequest +def test_create_context( + transport: str = "grpc", request_type=metadata_service.CreateContextRequest ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3146,28 +3099,40 @@ def test_list_contexts( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.ListContextsResponse( - next_page_token="next_page_token_value", + call.return_value = gca_context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) - response = client.list_contexts(request) + response = client.create_context(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.ListContextsRequest() + assert args[0] == metadata_service.CreateContextRequest() # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListContextsPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, gca_context.Context) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.etag == "etag_value" + assert response.parent_contexts == ["parent_contexts_value"] + assert response.schema_title == "schema_title_value" + assert response.schema_version == "schema_version_value" + assert response.description == "description_value" -def test_list_contexts_from_dict(): - test_list_contexts(request_type=dict) +def test_create_context_from_dict(): + test_create_context(request_type=dict) -def test_list_contexts_empty_call(): +def test_create_context_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( @@ -3175,16 +3140,16 @@ def test_list_contexts_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: - client.list_contexts() + with mock.patch.object(type(client.transport.create_context), "__call__") as call: + client.create_context() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.ListContextsRequest() + assert args[0] == metadata_service.CreateContextRequest() @pytest.mark.asyncio -async def test_list_contexts_async( - transport: str = "grpc_asyncio", request_type=metadata_service.ListContextsRequest +async def test_create_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.CreateContextRequest ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3195,43 +3160,55 @@ async def test_list_contexts_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.ListContextsResponse( - next_page_token="next_page_token_value", + gca_context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) ) - response = await client.list_contexts(request) + response = await client.create_context(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.ListContextsRequest() + assert args[0] == metadata_service.CreateContextRequest() # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListContextsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, gca_context.Context) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.etag == "etag_value" + assert response.parent_contexts == ["parent_contexts_value"] + assert response.schema_title == "schema_title_value" + assert response.schema_version == "schema_version_value" + assert response.description == "description_value" @pytest.mark.asyncio -async def test_list_contexts_async_from_dict(): - await test_list_contexts_async(request_type=dict) +async def test_create_context_async_from_dict(): + await test_create_context_async(request_type=dict) -def test_list_contexts_field_headers(): +def test_create_context_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.ListContextsRequest() + request = metadata_service.CreateContextRequest() request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: - call.return_value = metadata_service.ListContextsResponse() - client.list_contexts(request) + with mock.patch.object(type(client.transport.create_context), "__call__") as call: + call.return_value = gca_context.Context() + client.create_context(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) == 1 @@ -3244,23 +3221,21 @@ def test_list_contexts_field_headers(): @pytest.mark.asyncio -async def test_list_contexts_field_headers_async(): +async def test_create_context_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.ListContextsRequest() + request = metadata_service.CreateContextRequest() request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.ListContextsResponse() - ) - await client.list_contexts(request) + with mock.patch.object(type(client.transport.create_context), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) + await client.create_context(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -3272,62 +3247,75 @@ async def test_list_contexts_field_headers_async(): assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] -def test_list_contexts_flattened(): +def test_create_context_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.ListContextsResponse() + call.return_value = gca_context.Context() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.list_contexts(parent="parent_value",) + client.create_context( + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] assert args[0].parent == "parent_value" + assert args[0].context == gca_context.Context(name="name_value") + assert args[0].context_id == "context_id_value" -def test_list_contexts_flattened_error(): +def test_create_context_flattened_error(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.list_contexts( - metadata_service.ListContextsRequest(), parent="parent_value", + client.create_context( + metadata_service.CreateContextRequest(), + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", ) @pytest.mark.asyncio -async def test_list_contexts_flattened_async(): +async def test_create_context_flattened_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + with mock.patch.object(type(client.transport.create_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.ListContextsResponse() + call.return_value = gca_context.Context() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.ListContextsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- response = await client.list_contexts(parent="parent_value",) + response = await client.create_context( + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] assert args[0].parent == "parent_value" + assert args[0].context == gca_context.Context(name="name_value") + assert args[0].context_id == "context_id_value" @pytest.mark.asyncio -async def test_list_contexts_flattened_error_async(): +async def test_create_context_flattened_error_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -3335,214 +3323,91 @@ async def test_list_contexts_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.list_contexts( - metadata_service.ListContextsRequest(), parent="parent_value", + await client.create_context( + metadata_service.CreateContextRequest(), + parent="parent_value", + context=gca_context.Context(name="name_value"), + context_id="context_id_value", ) -def test_list_contexts_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_get_context( + transport: str = "grpc", request_type=metadata_service.GetContextRequest +): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: - # Set the response to a series of pages. 
- call.side_effect = ( - metadata_service.ListContextsResponse( - contexts=[context.Context(), context.Context(), context.Context(),], - next_page_token="abc", - ), - metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), - metadata_service.ListContextsResponse( - contexts=[context.Context(),], next_page_token="ghi", - ), - metadata_service.ListContextsResponse( - contexts=[context.Context(), context.Context(),], - ), - RuntimeError, + with mock.patch.object(type(client.transport.get_context), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", ) + response = client.get_context(request) - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_contexts(request={}) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.GetContextRequest() - assert pager._metadata == metadata + # Establish that the response is the type that we expect. 
+ assert isinstance(response, context.Context) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.etag == "etag_value" + assert response.parent_contexts == ["parent_contexts_value"] + assert response.schema_title == "schema_title_value" + assert response.schema_version == "schema_version_value" + assert response.description == "description_value" - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, context.Context) for i in results) + +def test_get_context_from_dict(): + test_get_context(request_type=dict) -def test_list_contexts_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) +def test_get_context_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: - # Set the response to a series of pages. 
- call.side_effect = ( - metadata_service.ListContextsResponse( - contexts=[context.Context(), context.Context(), context.Context(),], - next_page_token="abc", - ), - metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), - metadata_service.ListContextsResponse( - contexts=[context.Context(),], next_page_token="ghi", - ), - metadata_service.ListContextsResponse( - contexts=[context.Context(), context.Context(),], - ), - RuntimeError, - ) - pages = list(client.list_contexts(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token + with mock.patch.object(type(client.transport.get_context), "__call__") as call: + client.get_context() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.GetContextRequest() @pytest.mark.asyncio -async def test_list_contexts_async_pager(): +async def test_get_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.GetContextRequest +): client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials, + credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), "__call__", new_callable=mock.AsyncMock - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - metadata_service.ListContextsResponse( - contexts=[context.Context(), context.Context(), context.Context(),], - next_page_token="abc", - ), - metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), - metadata_service.ListContextsResponse( - contexts=[context.Context(),], next_page_token="ghi", - ), - metadata_service.ListContextsResponse( - contexts=[context.Context(), context.Context(),], - ), - RuntimeError, - ) - async_pager = await client.list_contexts(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all(isinstance(i, context.Context) for i in responses) - - -@pytest.mark.asyncio -async def test_list_contexts_async_pages(): - client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials, - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_contexts), "__call__", new_callable=mock.AsyncMock - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - metadata_service.ListContextsResponse( - contexts=[context.Context(), context.Context(), context.Context(),], - next_page_token="abc", - ), - metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), - metadata_service.ListContextsResponse( - contexts=[context.Context(),], next_page_token="ghi", - ), - metadata_service.ListContextsResponse( - contexts=[context.Context(), context.Context(),], - ), - RuntimeError, - ) - pages = [] - async for page_ in (await client.list_contexts(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_update_context( - transport: str = "grpc", request_type=metadata_service.UpdateContextRequest -): - client = MetadataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_context), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_context.Context( - name="name_value", - display_name="display_name_value", - etag="etag_value", - parent_contexts=["parent_contexts_value"], - schema_title="schema_title_value", - schema_version="schema_version_value", - description="description_value", - ) - response = client.update_context(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.UpdateContextRequest() - - # Establish that the response is the type that we expect. 
- assert isinstance(response, gca_context.Context) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.etag == "etag_value" - assert response.parent_contexts == ["parent_contexts_value"] - assert response.schema_title == "schema_title_value" - assert response.schema_version == "schema_version_value" - assert response.description == "description_value" - - -def test_update_context_from_dict(): - test_update_context(request_type=dict) - - -def test_update_context_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = MetadataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), transport="grpc", - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_context), "__call__") as call: - client.update_context() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.UpdateContextRequest() - - -@pytest.mark.asyncio -async def test_update_context_async( - transport: str = "grpc_asyncio", request_type=metadata_service.UpdateContextRequest -): - client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_context), "__call__") as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_context.Context( + context.Context( name="name_value", display_name="display_name_value", etag="etag_value", @@ -3552,15 +3417,15 @@ async def test_update_context_async( description="description_value", ) ) - response = await client.update_context(request) + response = await client.get_context(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.UpdateContextRequest() + assert args[0] == metadata_service.GetContextRequest() # Establish that the response is the type that we expect. - assert isinstance(response, gca_context.Context) + assert isinstance(response, context.Context) assert response.name == "name_value" assert response.display_name == "display_name_value" assert response.etag == "etag_value" @@ -3571,23 +3436,23 @@ async def test_update_context_async( @pytest.mark.asyncio -async def test_update_context_async_from_dict(): - await test_update_context_async(request_type=dict) +async def test_get_context_async_from_dict(): + await test_get_context_async(request_type=dict) -def test_update_context_field_headers(): +def test_get_context_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.UpdateContextRequest() + request = metadata_service.GetContextRequest() - request.context.name = "context.name/value" + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.update_context), "__call__") as call: - call.return_value = gca_context.Context() - client.update_context(request) + with mock.patch.object(type(client.transport.get_context), "__call__") as call: + call.return_value = context.Context() + client.get_context(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -3596,27 +3461,25 @@ def test_update_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "context.name=context.name/value",) in kw[ - "metadata" - ] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio -async def test_update_context_field_headers_async(): +async def test_get_context_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.UpdateContextRequest() + request = metadata_service.GetContextRequest() - request.context.name = "context.name/value" + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_context), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) - await client.update_context(request) + with mock.patch.object(type(client.transport.get_context), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context()) + await client.get_context(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -3625,75 +3488,63 @@ async def test_update_context_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "context.name=context.name/value",) in kw[ - "metadata" - ] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] -def test_update_context_flattened(): +def test_get_context_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_context), "__call__") as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = gca_context.Context() + call.return_value = context.Context() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.update_context( - context=gca_context.Context(name="name_value"), - update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), - ) + client.get_context(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == gca_context.Context(name="name_value") - assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"]) + assert args[0].name == "name_value" -def test_update_context_flattened_error(): +def test_get_context_flattened_error(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.update_context( - metadata_service.UpdateContextRequest(), - context=gca_context.Context(name="name_value"), - update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + client.get_context( + metadata_service.GetContextRequest(), name="name_value", ) @pytest.mark.asyncio -async def test_update_context_flattened_async(): +async def test_get_context_flattened_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_context), "__call__") as call: + with mock.patch.object(type(client.transport.get_context), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = gca_context.Context() + call.return_value = context.Context() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(context.Context()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.update_context( - context=gca_context.Context(name="name_value"), - update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), - ) + response = await client.get_context(name="name_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == gca_context.Context(name="name_value") - assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"]) + assert args[0].name == "name_value" @pytest.mark.asyncio -async def test_update_context_flattened_error_async(): +async def test_get_context_flattened_error_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -3701,15 +3552,13 @@ async def test_update_context_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.update_context( - metadata_service.UpdateContextRequest(), - context=gca_context.Context(name="name_value"), - update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + await client.get_context( + metadata_service.GetContextRequest(), name="name_value", ) -def test_delete_context( - transport: str = "grpc", request_type=metadata_service.DeleteContextRequest +def test_list_contexts( + transport: str = "grpc", request_type=metadata_service.ListContextsRequest ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3720,25 +3569,28 @@ def test_delete_context( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - response = client.delete_context(request) + call.return_value = metadata_service.ListContextsResponse( + next_page_token="next_page_token_value", + ) + response = client.list_contexts(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.DeleteContextRequest() + assert args[0] == metadata_service.ListContextsRequest() # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + assert isinstance(response, pagers.ListContextsPager) + assert response.next_page_token == "next_page_token_value" -def test_delete_context_from_dict(): - test_delete_context(request_type=dict) +def test_list_contexts_from_dict(): + test_list_contexts(request_type=dict) -def test_delete_context_empty_call(): +def test_list_contexts_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( @@ -3746,16 +3598,16 @@ def test_delete_context_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_context), "__call__") as call: - client.delete_context() + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + client.list_contexts() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.DeleteContextRequest() + assert args[0] == metadata_service.ListContextsRequest() @pytest.mark.asyncio -async def test_delete_context_async( - transport: str = "grpc_asyncio", request_type=metadata_service.DeleteContextRequest +async def test_list_contexts_async( + transport: str = "grpc_asyncio", request_type=metadata_service.ListContextsRequest ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3766,40 +3618,1258 @@ async def test_delete_context_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + metadata_service.ListContextsResponse( + next_page_token="next_page_token_value", + ) ) - response = await client.delete_context(request) + response = await client.list_contexts(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.DeleteContextRequest() + assert args[0] == metadata_service.ListContextsRequest() # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) + assert isinstance(response, pagers.ListContextsAsyncPager) + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio -async def test_delete_context_async_from_dict(): - await test_delete_context_async(request_type=dict) +async def test_list_contexts_async_from_dict(): + await test_list_contexts_async(request_type=dict) -def test_delete_context_field_headers(): +def test_list_contexts_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListContextsRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + call.return_value = metadata_service.ListContextsResponse() + client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_contexts_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.ListContextsRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListContextsResponse() + ) + await client.list_contexts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_contexts_flattened(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListContextsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_contexts(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_contexts_flattened_error(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_contexts( + metadata_service.ListContextsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_contexts_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListContextsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListContextsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_contexts(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_contexts_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_contexts( + metadata_service.ListContextsRequest(), parent="parent_value", + ) + + +def test_list_contexts_pager(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListContextsResponse( + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", + ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), + metadata_service.ListContextsResponse( + contexts=[context.Context(),], next_page_token="ghi", + ), + metadata_service.ListContextsResponse( + contexts=[context.Context(), context.Context(),], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_contexts(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, context.Context) for i in results) + + +def test_list_contexts_pages(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_contexts), "__call__") as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + metadata_service.ListContextsResponse( + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", + ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), + metadata_service.ListContextsResponse( + contexts=[context.Context(),], next_page_token="ghi", + ), + metadata_service.ListContextsResponse( + contexts=[context.Context(), context.Context(),], + ), + RuntimeError, + ) + pages = list(client.list_contexts(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_contexts_async_pager(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + metadata_service.ListContextsResponse( + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", + ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), + metadata_service.ListContextsResponse( + contexts=[context.Context(),], next_page_token="ghi", + ), + metadata_service.ListContextsResponse( + contexts=[context.Context(), context.Context(),], + ), + RuntimeError, + ) + async_pager = await client.list_contexts(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, context.Context) for i in responses) + + +@pytest.mark.asyncio +async def test_list_contexts_async_pages(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_contexts), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + metadata_service.ListContextsResponse( + contexts=[context.Context(), context.Context(), context.Context(),], + next_page_token="abc", + ), + metadata_service.ListContextsResponse(contexts=[], next_page_token="def",), + metadata_service.ListContextsResponse( + contexts=[context.Context(),], next_page_token="ghi", + ), + metadata_service.ListContextsResponse( + contexts=[context.Context(), context.Context(),], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_contexts(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_update_context( + transport: str = "grpc", request_type=metadata_service.UpdateContextRequest +): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_context), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + response = client.update_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.UpdateContextRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, gca_context.Context) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.etag == "etag_value" + assert response.parent_contexts == ["parent_contexts_value"] + assert response.schema_title == "schema_title_value" + assert response.schema_version == "schema_version_value" + assert response.description == "description_value" + + +def test_update_context_from_dict(): + test_update_context(request_type=dict) + + +def test_update_context_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_context), "__call__") as call: + client.update_context() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.UpdateContextRequest() + + +@pytest.mark.asyncio +async def test_update_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.UpdateContextRequest +): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_context), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_context.Context( + name="name_value", + display_name="display_name_value", + etag="etag_value", + parent_contexts=["parent_contexts_value"], + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + ) + response = await client.update_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.UpdateContextRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_context.Context) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.etag == "etag_value" + assert response.parent_contexts == ["parent_contexts_value"] + assert response.schema_title == "schema_title_value" + assert response.schema_version == "schema_version_value" + assert response.description == "description_value" + + +@pytest.mark.asyncio +async def test_update_context_async_from_dict(): + await test_update_context_async(request_type=dict) + + +def test_update_context_field_headers(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.UpdateContextRequest() + + request.context.name = "context.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_context), "__call__") as call: + call.return_value = gca_context.Context() + client.update_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "context.name=context.name/value",) in kw[ + "metadata" + ] + + +@pytest.mark.asyncio +async def test_update_context_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.UpdateContextRequest() + + request.context.name = "context.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_context), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) + await client.update_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "context.name=context.name/value",) in kw[ + "metadata" + ] + + +def test_update_context_flattened(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_context), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_context.Context() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_context( + context=gca_context.Context(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].context == gca_context.Context(name="name_value") + assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"]) + + +def test_update_context_flattened_error(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_context( + metadata_service.UpdateContextRequest(), + context=gca_context.Context(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_context_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_context), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_context.Context() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_context.Context()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_context( + context=gca_context.Context(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].context == gca_context.Context(name="name_value") + assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_context_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_context( + metadata_service.UpdateContextRequest(), + context=gca_context.Context(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_context( + transport: str = "grpc", request_type=metadata_service.DeleteContextRequest +): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.DeleteContextRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_context_from_dict(): + test_delete_context(request_type=dict) + + +def test_delete_context_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. 
request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + client.delete_context() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.DeleteContextRequest() + + +@pytest.mark.asyncio +async def test_delete_context_async( + transport: str = "grpc_asyncio", request_type=metadata_service.DeleteContextRequest +): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.DeleteContextRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_context_async_from_dict(): + await test_delete_context_async(request_type=dict) + + +def test_delete_context_field_headers(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = metadata_service.DeleteContextRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_context_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = metadata_service.DeleteContextRequest() - request.name = "name/value" + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.delete_context(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_context_flattened(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_context(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_context_flattened_error(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_context( + metadata_service.DeleteContextRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_context_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_context(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_context_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_context( + metadata_service.DeleteContextRequest(), name="name_value", + ) + + +def test_purge_contexts( + transport: str = "grpc", request_type=metadata_service.PurgeContextsRequest +): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.purge_contexts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.PurgeContextsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_purge_contexts_from_dict(): + test_purge_contexts(request_type=dict) + + +def test_purge_contexts_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + client.purge_contexts() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.PurgeContextsRequest() + + +@pytest.mark.asyncio +async def test_purge_contexts_async( + transport: str = "grpc_asyncio", request_type=metadata_service.PurgeContextsRequest +): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.purge_contexts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.PurgeContextsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_purge_contexts_async_from_dict(): + await test_purge_contexts_async(request_type=dict) + + +def test_purge_contexts_field_headers(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.PurgeContextsRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.purge_contexts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_purge_contexts_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.PurgeContextsRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.purge_contexts(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_purge_contexts_flattened(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.purge_contexts(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_purge_contexts_flattened_error(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.purge_contexts( + metadata_service.PurgeContextsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_purge_contexts_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.purge_contexts), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.purge_contexts(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_purge_contexts_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.purge_contexts( + metadata_service.PurgeContextsRequest(), parent="parent_value", + ) + + +def test_add_context_artifacts_and_executions( + transport: str = "grpc", + request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, +): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() + response = client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + + # Establish that the response is the type that we expect. + assert isinstance( + response, metadata_service.AddContextArtifactsAndExecutionsResponse + ) + + +def test_add_context_artifacts_and_executions_from_dict(): + test_add_context_artifacts_and_executions(request_type=dict) + + +def test_add_context_artifacts_and_executions_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. 
request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: + client.add_context_artifacts_and_executions() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + + +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, +): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextArtifactsAndExecutionsResponse() + ) + response = await client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance( + response, metadata_service.AddContextArtifactsAndExecutionsResponse + ) + + +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_async_from_dict(): + await test_add_context_artifacts_and_executions_async(request_type=dict) + + +def test_add_context_artifacts_and_executions_field_headers(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.AddContextArtifactsAndExecutionsRequest() + + request.context = "context/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: + call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() + client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_field_headers_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = metadata_service.AddContextArtifactsAndExecutionsRequest() + + request.context = "context/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextArtifactsAndExecutionsResponse() + ) + await client.add_context_artifacts_and_executions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] + + +def test_add_context_artifacts_and_executions_flattened(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.add_context_artifacts_and_executions( + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].context == "context_value" + assert args[0].artifacts == ["artifacts_value"] + assert args[0].executions == ["executions_value"] + + +def test_add_context_artifacts_and_executions_flattened_error(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.add_context_artifacts_and_executions( + metadata_service.AddContextArtifactsAndExecutionsRequest(), + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], + ) + + +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_artifacts_and_executions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextArtifactsAndExecutionsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.add_context_artifacts_and_executions( + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].context == "context_value" + assert args[0].artifacts == ["artifacts_value"] + assert args[0].executions == ["executions_value"] + + +@pytest.mark.asyncio +async def test_add_context_artifacts_and_executions_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.add_context_artifacts_and_executions( + metadata_service.AddContextArtifactsAndExecutionsRequest(), + context="context_value", + artifacts=["artifacts_value"], + executions=["executions_value"], + ) + + +def test_add_context_children( + transport: str = "grpc", request_type=metadata_service.AddContextChildrenRequest +): + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_children), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.AddContextChildrenResponse() + response = client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.AddContextChildrenRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, metadata_service.AddContextChildrenResponse) + + +def test_add_context_children_from_dict(): + test_add_context_children(request_type=dict) + + +def test_add_context_children_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MetadataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.add_context_children), "__call__" + ) as call: + client.add_context_children() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.AddContextChildrenRequest() + + +@pytest.mark.asyncio +async def test_add_context_children_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.AddContextChildrenRequest, +): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_context), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - client.delete_context(request) + with mock.patch.object( + type(client.transport.add_context_children), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.AddContextChildrenResponse() + ) + response = await client.add_context_children(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == metadata_service.AddContextChildrenRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, metadata_service.AddContextChildrenResponse) + + +@pytest.mark.asyncio +async def test_add_context_children_async_from_dict(): + await test_add_context_children_async(request_type=dict) + + +def test_add_context_children_field_headers(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. 
Set these to a non-empty value. + request = metadata_service.AddContextChildrenRequest() + + request.context = "context/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_context_children), "__call__" + ) as call: + call.return_value = metadata_service.AddContextChildrenResponse() + client.add_context_children(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -3808,27 +4878,29 @@ def test_delete_context_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] @pytest.mark.asyncio -async def test_delete_context_field_headers_async(): +async def test_add_context_children_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.DeleteContextRequest() + request = metadata_service.AddContextChildrenRequest() - request.name = "name/value" + request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + with mock.patch.object( + type(client.transport.add_context_children), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") + metadata_service.AddContextChildrenResponse() ) - await client.delete_context(request) + await client.add_context_children(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) @@ -3837,65 +4909,77 @@ async def test_delete_context_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] -def test_delete_context_flattened(): +def test_add_context_children_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + with mock.patch.object( + type(client.transport.add_context_children), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = metadata_service.AddContextChildrenResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_context(name="name_value",) + client.add_context_children( + context="context_value", child_contexts=["child_contexts_value"], + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].context == "context_value" + assert args[0].child_contexts == ["child_contexts_value"] -def test_delete_context_flattened_error(): +def test_add_context_children_flattened_error(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.delete_context( - metadata_service.DeleteContextRequest(), name="name_value", + client.add_context_children( + metadata_service.AddContextChildrenRequest(), + context="context_value", + child_contexts=["child_contexts_value"], ) @pytest.mark.asyncio -async def test_delete_context_flattened_async(): +async def test_add_context_children_flattened_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_context), "__call__") as call: + with mock.patch.object( + type(client.transport.add_context_children), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = metadata_service.AddContextChildrenResponse() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + metadata_service.AddContextChildrenResponse() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_context(name="name_value",) + response = await client.add_context_children( + context="context_value", child_contexts=["child_contexts_value"], + ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].context == "context_value" + assert args[0].child_contexts == ["child_contexts_value"] @pytest.mark.asyncio -async def test_delete_context_flattened_error_async(): +async def test_add_context_children_flattened_error_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -3903,14 +4987,16 @@ async def test_delete_context_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.delete_context( - metadata_service.DeleteContextRequest(), name="name_value", + await client.add_context_children( + metadata_service.AddContextChildrenRequest(), + context="context_value", + child_contexts=["child_contexts_value"], ) -def test_add_context_artifacts_and_executions( +def test_query_context_lineage_subgraph( transport: str = "grpc", - request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, + request_type=metadata_service.QueryContextLineageSubgraphRequest, ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3922,28 +5008,26 @@ def test_add_context_artifacts_and_executions( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), "__call__" + type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() - response = client.add_context_artifacts_and_executions(request) + call.return_value = lineage_subgraph.LineageSubgraph() + response = client.query_context_lineage_subgraph(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() # Establish that the response is the type that we expect. - assert isinstance( - response, metadata_service.AddContextArtifactsAndExecutionsResponse - ) + assert isinstance(response, lineage_subgraph.LineageSubgraph) -def test_add_context_artifacts_and_executions_from_dict(): - test_add_context_artifacts_and_executions(request_type=dict) +def test_query_context_lineage_subgraph_from_dict(): + test_query_context_lineage_subgraph(request_type=dict) -def test_add_context_artifacts_and_executions_empty_call(): +def test_query_context_lineage_subgraph_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( @@ -3952,18 +5036,18 @@ def test_add_context_artifacts_and_executions_empty_call(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), "__call__" + type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: - client.add_context_artifacts_and_executions() + client.query_context_lineage_subgraph() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() @pytest.mark.asyncio -async def test_add_context_artifacts_and_executions_async( +async def test_query_context_lineage_subgraph_async( transport: str = "grpc_asyncio", - request_type=metadata_service.AddContextArtifactsAndExecutionsRequest, + request_type=metadata_service.QueryContextLineageSubgraphRequest, ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3975,45 +5059,43 @@ async def test_add_context_artifacts_and_executions_async( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), "__call__" + type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.AddContextArtifactsAndExecutionsResponse() + lineage_subgraph.LineageSubgraph() ) - response = await client.add_context_artifacts_and_executions(request) + response = await client.query_context_lineage_subgraph(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.AddContextArtifactsAndExecutionsRequest() + assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() # Establish that the response is the type that we expect. 
- assert isinstance( - response, metadata_service.AddContextArtifactsAndExecutionsResponse - ) + assert isinstance(response, lineage_subgraph.LineageSubgraph) @pytest.mark.asyncio -async def test_add_context_artifacts_and_executions_async_from_dict(): - await test_add_context_artifacts_and_executions_async(request_type=dict) +async def test_query_context_lineage_subgraph_async_from_dict(): + await test_query_context_lineage_subgraph_async(request_type=dict) -def test_add_context_artifacts_and_executions_field_headers(): +def test_query_context_lineage_subgraph_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.AddContextArtifactsAndExecutionsRequest() + request = metadata_service.QueryContextLineageSubgraphRequest() request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), "__call__" + type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: - call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() - client.add_context_artifacts_and_executions(request) + call.return_value = lineage_subgraph.LineageSubgraph() + client.query_context_lineage_subgraph(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -4026,25 +5108,25 @@ def test_add_context_artifacts_and_executions_field_headers(): @pytest.mark.asyncio -async def test_add_context_artifacts_and_executions_field_headers_async(): +async def test_query_context_lineage_subgraph_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
- request = metadata_service.AddContextArtifactsAndExecutionsRequest() + request = metadata_service.QueryContextLineageSubgraphRequest() request.context = "context/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), "__call__" + type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.AddContextArtifactsAndExecutionsResponse() + lineage_subgraph.LineageSubgraph() ) - await client.add_context_artifacts_and_executions(request) + await client.query_context_lineage_subgraph(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -4056,81 +5138,67 @@ async def test_add_context_artifacts_and_executions_field_headers_async(): assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] -def test_add_context_artifacts_and_executions_flattened(): +def test_query_context_lineage_subgraph_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), "__call__" + type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() + call.return_value = lineage_subgraph.LineageSubgraph() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.add_context_artifacts_and_executions( - context="context_value", - artifacts=["artifacts_value"], - executions=["executions_value"], - ) + client.query_context_lineage_subgraph(context="context_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] assert args[0].context == "context_value" - assert args[0].artifacts == ["artifacts_value"] - assert args[0].executions == ["executions_value"] -def test_add_context_artifacts_and_executions_flattened_error(): +def test_query_context_lineage_subgraph_flattened_error(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.add_context_artifacts_and_executions( - metadata_service.AddContextArtifactsAndExecutionsRequest(), + client.query_context_lineage_subgraph( + metadata_service.QueryContextLineageSubgraphRequest(), context="context_value", - artifacts=["artifacts_value"], - executions=["executions_value"], ) @pytest.mark.asyncio -async def test_add_context_artifacts_and_executions_flattened_async(): +async def test_query_context_lineage_subgraph_flattened_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.add_context_artifacts_and_executions), "__call__" + type(client.transport.query_context_lineage_subgraph), "__call__" ) as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddContextArtifactsAndExecutionsResponse() + call.return_value = lineage_subgraph.LineageSubgraph() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.AddContextArtifactsAndExecutionsResponse() + lineage_subgraph.LineageSubgraph() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- response = await client.add_context_artifacts_and_executions( - context="context_value", - artifacts=["artifacts_value"], - executions=["executions_value"], - ) + response = await client.query_context_lineage_subgraph(context="context_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] assert args[0].context == "context_value" - assert args[0].artifacts == ["artifacts_value"] - assert args[0].executions == ["executions_value"] @pytest.mark.asyncio -async def test_add_context_artifacts_and_executions_flattened_error_async(): +async def test_query_context_lineage_subgraph_flattened_error_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4138,16 +5206,14 @@ async def test_add_context_artifacts_and_executions_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.add_context_artifacts_and_executions( - metadata_service.AddContextArtifactsAndExecutionsRequest(), + await client.query_context_lineage_subgraph( + metadata_service.QueryContextLineageSubgraphRequest(), context="context_value", - artifacts=["artifacts_value"], - executions=["executions_value"], ) -def test_add_context_children( - transport: str = "grpc", request_type=metadata_service.AddContextChildrenRequest +def test_create_execution( + transport: str = "grpc", request_type=metadata_service.CreateExecutionRequest ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4158,27 +5224,40 @@ def test_add_context_children( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.add_context_children), "__call__" - ) as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddContextChildrenResponse() - response = client.add_context_children(request) + call.return_value = gca_execution.Execution( + name="name_value", + display_name="display_name_value", + state=gca_execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + response = client.create_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.AddContextChildrenRequest() + assert args[0] == metadata_service.CreateExecutionRequest() # Establish that the response is the type that we expect. - assert isinstance(response, metadata_service.AddContextChildrenResponse) + assert isinstance(response, gca_execution.Execution) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.state == gca_execution.Execution.State.NEW + assert response.etag == "etag_value" + assert response.schema_title == "schema_title_value" + assert response.schema_version == "schema_version_value" + assert response.description == "description_value" -def test_add_context_children_from_dict(): - test_add_context_children(request_type=dict) +def test_create_execution_from_dict(): + test_create_execution(request_type=dict) -def test_add_context_children_empty_call(): +def test_create_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. 
client = MetadataServiceClient( @@ -4186,19 +5265,17 @@ def test_add_context_children_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.add_context_children), "__call__" - ) as call: - client.add_context_children() + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + client.create_execution() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.AddContextChildrenRequest() + assert args[0] == metadata_service.CreateExecutionRequest() @pytest.mark.asyncio -async def test_add_context_children_async( +async def test_create_execution_async( transport: str = "grpc_asyncio", - request_type=metadata_service.AddContextChildrenRequest, + request_type=metadata_service.CreateExecutionRequest, ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4209,44 +5286,55 @@ async def test_add_context_children_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.add_context_children), "__call__" - ) as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.AddContextChildrenResponse() + gca_execution.Execution( + name="name_value", + display_name="display_name_value", + state=gca_execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) ) - response = await client.add_context_children(request) + response = await client.create_execution(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.AddContextChildrenRequest() + assert args[0] == metadata_service.CreateExecutionRequest() # Establish that the response is the type that we expect. - assert isinstance(response, metadata_service.AddContextChildrenResponse) + assert isinstance(response, gca_execution.Execution) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.state == gca_execution.Execution.State.NEW + assert response.etag == "etag_value" + assert response.schema_title == "schema_title_value" + assert response.schema_version == "schema_version_value" + assert response.description == "description_value" @pytest.mark.asyncio -async def test_add_context_children_async_from_dict(): - await test_add_context_children_async(request_type=dict) +async def test_create_execution_async_from_dict(): + await test_create_execution_async(request_type=dict) -def test_add_context_children_field_headers(): +def test_create_execution_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.AddContextChildrenRequest() + request = metadata_service.CreateExecutionRequest() - request.context = "context/value" + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.add_context_children), "__call__" - ) as call: - call.return_value = metadata_service.AddContextChildrenResponse() - client.add_context_children(request) + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + call.return_value = gca_execution.Execution() + client.create_execution(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) == 1 @@ -4255,29 +5343,27 @@ def test_add_context_children_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio -async def test_add_context_children_field_headers_async(): +async def test_create_execution_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.AddContextChildrenRequest() + request = metadata_service.CreateExecutionRequest() - request.context = "context/value" + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.add_context_children), "__call__" - ) as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.AddContextChildrenResponse() + gca_execution.Execution() ) - await client.add_context_children(request) + await client.create_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -4286,77 +5372,80 @@ async def test_add_context_children_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] -def test_add_context_children_flattened(): +def test_create_execution_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.add_context_children), "__call__" - ) as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddContextChildrenResponse() + call.return_value = gca_execution.Execution() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.add_context_children( - context="context_value", child_contexts=["child_contexts_value"], + client.create_execution( + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == "context_value" - assert args[0].child_contexts == ["child_contexts_value"] + assert args[0].parent == "parent_value" + assert args[0].execution == gca_execution.Execution(name="name_value") + assert args[0].execution_id == "execution_id_value" -def test_add_context_children_flattened_error(): +def test_create_execution_flattened_error(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.add_context_children( - metadata_service.AddContextChildrenRequest(), - context="context_value", - child_contexts=["child_contexts_value"], + client.create_execution( + metadata_service.CreateExecutionRequest(), + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) @pytest.mark.asyncio -async def test_add_context_children_flattened_async(): +async def test_create_execution_flattened_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.add_context_children), "__call__" - ) as call: + with mock.patch.object(type(client.transport.create_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.AddContextChildrenResponse() + call.return_value = gca_execution.Execution() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.AddContextChildrenResponse() + gca_execution.Execution() ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.add_context_children( - context="context_value", child_contexts=["child_contexts_value"], + response = await client.create_execution( + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == "context_value" - assert args[0].child_contexts == ["child_contexts_value"] + assert args[0].parent == "parent_value" + assert args[0].execution == gca_execution.Execution(name="name_value") + assert args[0].execution_id == "execution_id_value" @pytest.mark.asyncio -async def test_add_context_children_flattened_error_async(): +async def test_create_execution_flattened_error_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4364,16 +5453,16 @@ async def test_add_context_children_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - await client.add_context_children( - metadata_service.AddContextChildrenRequest(), - context="context_value", - child_contexts=["child_contexts_value"], + await client.create_execution( + metadata_service.CreateExecutionRequest(), + parent="parent_value", + execution=gca_execution.Execution(name="name_value"), + execution_id="execution_id_value", ) -def test_query_context_lineage_subgraph( - transport: str = "grpc", - request_type=metadata_service.QueryContextLineageSubgraphRequest, +def test_get_execution( + transport: str = "grpc", request_type=metadata_service.GetExecutionRequest ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4384,27 +5473,40 @@ def test_query_context_lineage_subgraph( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), "__call__" - ) as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = lineage_subgraph.LineageSubgraph() - response = client.query_context_lineage_subgraph(request) + call.return_value = execution.Execution( + name="name_value", + display_name="display_name_value", + state=execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) + response = client.get_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() + assert args[0] == metadata_service.GetExecutionRequest() # Establish that the response is the type that we expect. - assert isinstance(response, lineage_subgraph.LineageSubgraph) + assert isinstance(response, execution.Execution) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.state == execution.Execution.State.NEW + assert response.etag == "etag_value" + assert response.schema_title == "schema_title_value" + assert response.schema_version == "schema_version_value" + assert response.description == "description_value" -def test_query_context_lineage_subgraph_from_dict(): - test_query_context_lineage_subgraph(request_type=dict) +def test_get_execution_from_dict(): + test_get_execution(request_type=dict) -def test_query_context_lineage_subgraph_empty_call(): +def test_get_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( @@ -4412,19 +5514,16 @@ def test_query_context_lineage_subgraph_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), "__call__" - ) as call: - client.query_context_lineage_subgraph() + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + client.get_execution() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() + assert args[0] == metadata_service.GetExecutionRequest() @pytest.mark.asyncio -async def test_query_context_lineage_subgraph_async( - transport: str = "grpc_asyncio", - request_type=metadata_service.QueryContextLineageSubgraphRequest, +async def test_get_execution_async( + transport: str = "grpc_asyncio", request_type=metadata_service.GetExecutionRequest ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4435,44 +5534,55 @@ async def test_query_context_lineage_subgraph_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), "__call__" - ) as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - lineage_subgraph.LineageSubgraph() + execution.Execution( + name="name_value", + display_name="display_name_value", + state=execution.Execution.State.NEW, + etag="etag_value", + schema_title="schema_title_value", + schema_version="schema_version_value", + description="description_value", + ) ) - response = await client.query_context_lineage_subgraph(request) + response = await client.get_execution(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.QueryContextLineageSubgraphRequest() + assert args[0] == metadata_service.GetExecutionRequest() # Establish that the response is the type that we expect. - assert isinstance(response, lineage_subgraph.LineageSubgraph) + assert isinstance(response, execution.Execution) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.state == execution.Execution.State.NEW + assert response.etag == "etag_value" + assert response.schema_title == "schema_title_value" + assert response.schema_version == "schema_version_value" + assert response.description == "description_value" @pytest.mark.asyncio -async def test_query_context_lineage_subgraph_async_from_dict(): - await test_query_context_lineage_subgraph_async(request_type=dict) +async def test_get_execution_async_from_dict(): + await test_get_execution_async(request_type=dict) -def test_query_context_lineage_subgraph_field_headers(): +def test_get_execution_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.QueryContextLineageSubgraphRequest() - - request.context = "context/value" + request = metadata_service.GetExecutionRequest() - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), "__call__" - ) as call: - call.return_value = lineage_subgraph.LineageSubgraph() - client.query_context_lineage_subgraph(request) + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + call.return_value = execution.Execution() + client.get_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -4481,29 +5591,25 @@ def test_query_context_lineage_subgraph_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio -async def test_query_context_lineage_subgraph_field_headers_async(): +async def test_get_execution_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.QueryContextLineageSubgraphRequest() + request = metadata_service.GetExecutionRequest() - request.context = "context/value" + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - lineage_subgraph.LineageSubgraph() - ) - await client.query_context_lineage_subgraph(request) + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution()) + await client.get_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -4512,70 +5618,63 @@ async def test_query_context_lineage_subgraph_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "context=context/value",) in kw["metadata"] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] -def test_query_context_lineage_subgraph_flattened(): +def test_get_execution_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), "__call__" - ) as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = lineage_subgraph.LineageSubgraph() + call.return_value = execution.Execution() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.query_context_lineage_subgraph(context="context_value",) + client.get_execution(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].context == "context_value" + assert args[0].name == "name_value" -def test_query_context_lineage_subgraph_flattened_error(): +def test_get_execution_flattened_error(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.query_context_lineage_subgraph( - metadata_service.QueryContextLineageSubgraphRequest(), - context="context_value", + client.get_execution( + metadata_service.GetExecutionRequest(), name="name_value", ) @pytest.mark.asyncio -async def test_query_context_lineage_subgraph_flattened_async(): +async def test_get_execution_flattened_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.query_context_lineage_subgraph), "__call__" - ) as call: + with mock.patch.object(type(client.transport.get_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = lineage_subgraph.LineageSubgraph() + call.return_value = execution.Execution() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - lineage_subgraph.LineageSubgraph() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.query_context_lineage_subgraph(context="context_value",) + response = await client.get_execution(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].context == "context_value" + assert args[0].name == "name_value" @pytest.mark.asyncio -async def test_query_context_lineage_subgraph_flattened_error_async(): +async def test_get_execution_flattened_error_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -4583,14 +5682,13 @@ async def test_query_context_lineage_subgraph_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - await client.query_context_lineage_subgraph( - metadata_service.QueryContextLineageSubgraphRequest(), - context="context_value", + await client.get_execution( + metadata_service.GetExecutionRequest(), name="name_value", ) -def test_create_execution( - transport: str = "grpc", request_type=metadata_service.CreateExecutionRequest +def test_list_executions( + transport: str = "grpc", request_type=metadata_service.ListExecutionsRequest ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4601,40 +5699,28 @@ def test_create_execution( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = gca_execution.Execution( - name="name_value", - display_name="display_name_value", - state=gca_execution.Execution.State.NEW, - etag="etag_value", - schema_title="schema_title_value", - schema_version="schema_version_value", - description="description_value", + call.return_value = metadata_service.ListExecutionsResponse( + next_page_token="next_page_token_value", ) - response = client.create_execution(request) + response = client.list_executions(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.CreateExecutionRequest() + assert args[0] == metadata_service.ListExecutionsRequest() # Establish that the response is the type that we expect. 
- assert isinstance(response, gca_execution.Execution) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.state == gca_execution.Execution.State.NEW - assert response.etag == "etag_value" - assert response.schema_title == "schema_title_value" - assert response.schema_version == "schema_version_value" - assert response.description == "description_value" + assert isinstance(response, pagers.ListExecutionsPager) + assert response.next_page_token == "next_page_token_value" -def test_create_execution_from_dict(): - test_create_execution(request_type=dict) +def test_list_executions_from_dict(): + test_list_executions(request_type=dict) -def test_create_execution_empty_call(): +def test_list_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( @@ -4642,17 +5728,16 @@ def test_create_execution_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.create_execution), "__call__") as call: - client.create_execution() + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + client.list_executions() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.CreateExecutionRequest() + assert args[0] == metadata_service.ListExecutionsRequest() @pytest.mark.asyncio -async def test_create_execution_async( - transport: str = "grpc_asyncio", - request_type=metadata_service.CreateExecutionRequest, +async def test_list_executions_async( + transport: str = "grpc_asyncio", request_type=metadata_service.ListExecutionsRequest ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4663,55 +5748,43 @@ async def test_create_execution_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_execution.Execution( - name="name_value", - display_name="display_name_value", - state=gca_execution.Execution.State.NEW, - etag="etag_value", - schema_title="schema_title_value", - schema_version="schema_version_value", - description="description_value", + metadata_service.ListExecutionsResponse( + next_page_token="next_page_token_value", ) ) - response = await client.create_execution(request) + response = await client.list_executions(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.CreateExecutionRequest() + assert args[0] == metadata_service.ListExecutionsRequest() # Establish that the response is the type that we expect. 
- assert isinstance(response, gca_execution.Execution) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.state == gca_execution.Execution.State.NEW - assert response.etag == "etag_value" - assert response.schema_title == "schema_title_value" - assert response.schema_version == "schema_version_value" - assert response.description == "description_value" + assert isinstance(response, pagers.ListExecutionsAsyncPager) + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio -async def test_create_execution_async_from_dict(): - await test_create_execution_async(request_type=dict) +async def test_list_executions_async_from_dict(): + await test_list_executions_async(request_type=dict) -def test_create_execution_field_headers(): +def test_list_executions_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.CreateExecutionRequest() + request = metadata_service.ListExecutionsRequest() request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_execution), "__call__") as call: - call.return_value = gca_execution.Execution() - client.create_execution(request) + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + call.return_value = metadata_service.ListExecutionsResponse() + client.list_executions(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) == 1 @@ -4724,23 +5797,23 @@ def test_create_execution_field_headers(): @pytest.mark.asyncio -async def test_create_execution_field_headers_async(): +async def test_list_executions_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.CreateExecutionRequest() + request = metadata_service.ListExecutionsRequest() request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_execution.Execution() + metadata_service.ListExecutionsResponse() ) - await client.create_execution(request) + await client.list_executions(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -4752,94 +5825,226 @@ async def test_create_execution_field_headers_async(): assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] -def test_create_execution_flattened(): +def test_list_executions_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_execution), "__call__") as call: + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = gca_execution.Execution() + call.return_value = metadata_service.ListExecutionsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.create_execution( - parent="parent_value", - execution=gca_execution.Execution(name="name_value"), - execution_id="execution_id_value", + client.list_executions(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_executions_flattened_error(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_executions( + metadata_service.ListExecutionsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_executions_flattened_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = metadata_service.ListExecutionsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + metadata_service.ListExecutionsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_executions(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_executions_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.list_executions( + metadata_service.ListExecutionsRequest(), parent="parent_value", + ) + + +def test_list_executions_pager(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + execution.Execution(), + ], + next_page_token="abc", + ), + metadata_service.ListExecutionsResponse( + executions=[], next_page_token="def", + ), + metadata_service.ListExecutionsResponse( + executions=[execution.Execution(),], next_page_token="ghi", + ), + metadata_service.ListExecutionsResponse( + executions=[execution.Execution(), execution.Execution(),], + ), + RuntimeError, ) - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].execution_id == "execution_id_value" + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_executions(request={}) + + assert pager._metadata == metadata + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, execution.Execution) for i in results) -def test_create_execution_flattened_error(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - client.create_execution( - metadata_service.CreateExecutionRequest(), - parent="parent_value", - execution=gca_execution.Execution(name="name_value"), - execution_id="execution_id_value", +def test_list_executions_pages(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + execution.Execution(), + ], + next_page_token="abc", + ), + metadata_service.ListExecutionsResponse( + executions=[], next_page_token="def", + ), + metadata_service.ListExecutionsResponse( + executions=[execution.Execution(),], next_page_token="ghi", + ), + metadata_service.ListExecutionsResponse( + executions=[execution.Execution(), execution.Execution(),], + ), + RuntimeError, ) + pages = list(client.list_executions(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token @pytest.mark.asyncio -async def test_create_execution_flattened_async(): +async def test_list_executions_async_pager(): client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials, ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_execution), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_execution.Execution() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_execution.Execution() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- response = await client.create_execution( - parent="parent_value", - execution=gca_execution.Execution(name="name_value"), - execution_id="execution_id_value", + with mock.patch.object( + type(client.transport.list_executions), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + execution.Execution(), + ], + next_page_token="abc", + ), + metadata_service.ListExecutionsResponse( + executions=[], next_page_token="def", + ), + metadata_service.ListExecutionsResponse( + executions=[execution.Execution(),], next_page_token="ghi", + ), + metadata_service.ListExecutionsResponse( + executions=[execution.Execution(), execution.Execution(),], + ), + RuntimeError, ) + async_pager = await client.list_executions(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].execution_id == "execution_id_value" + assert len(responses) == 6 + assert all(isinstance(i, execution.Execution) for i in responses) @pytest.mark.asyncio -async def test_create_execution_flattened_error_async(): +async def test_list_executions_async_pages(): client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), + credentials=ga_credentials.AnonymousCredentials, ) - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - await client.create_execution( - metadata_service.CreateExecutionRequest(), - parent="parent_value", - execution=gca_execution.Execution(name="name_value"), - execution_id="execution_id_value", + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_executions), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + metadata_service.ListExecutionsResponse( + executions=[ + execution.Execution(), + execution.Execution(), + execution.Execution(), + ], + next_page_token="abc", + ), + metadata_service.ListExecutionsResponse( + executions=[], next_page_token="def", + ), + metadata_service.ListExecutionsResponse( + executions=[execution.Execution(),], next_page_token="ghi", + ), + metadata_service.ListExecutionsResponse( + executions=[execution.Execution(), execution.Execution(),], + ), + RuntimeError, ) + pages = [] + async for page_ in (await client.list_executions(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token -def test_get_execution( - transport: str = "grpc", request_type=metadata_service.GetExecutionRequest +def test_update_execution( + transport: str = "grpc", request_type=metadata_service.UpdateExecutionRequest ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4850,40 +6055,40 @@ def test_get_execution( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = execution.Execution( + call.return_value = gca_execution.Execution( name="name_value", display_name="display_name_value", - state=execution.Execution.State.NEW, + state=gca_execution.Execution.State.NEW, etag="etag_value", schema_title="schema_title_value", schema_version="schema_version_value", description="description_value", ) - response = client.get_execution(request) + response = client.update_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.GetExecutionRequest() + assert args[0] == metadata_service.UpdateExecutionRequest() # Establish that the response is the type that we expect. - assert isinstance(response, execution.Execution) + assert isinstance(response, gca_execution.Execution) assert response.name == "name_value" assert response.display_name == "display_name_value" - assert response.state == execution.Execution.State.NEW + assert response.state == gca_execution.Execution.State.NEW assert response.etag == "etag_value" assert response.schema_title == "schema_title_value" assert response.schema_version == "schema_version_value" assert response.description == "description_value" -def test_get_execution_from_dict(): - test_get_execution(request_type=dict) +def test_update_execution_from_dict(): + test_update_execution(request_type=dict) -def test_get_execution_empty_call(): +def test_update_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( @@ -4891,16 +6096,17 @@ def test_get_execution_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.get_execution), "__call__") as call: - client.get_execution() + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + client.update_execution() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.GetExecutionRequest() + assert args[0] == metadata_service.UpdateExecutionRequest() @pytest.mark.asyncio -async def test_get_execution_async( - transport: str = "grpc_asyncio", request_type=metadata_service.GetExecutionRequest +async def test_update_execution_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.UpdateExecutionRequest, ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4911,31 +6117,31 @@ async def test_get_execution_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - execution.Execution( + gca_execution.Execution( name="name_value", display_name="display_name_value", - state=execution.Execution.State.NEW, + state=gca_execution.Execution.State.NEW, etag="etag_value", schema_title="schema_title_value", schema_version="schema_version_value", description="description_value", ) ) - response = await client.get_execution(request) + response = await client.update_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.GetExecutionRequest() + assert args[0] == metadata_service.UpdateExecutionRequest() # Establish that the response is the type that we expect. 
- assert isinstance(response, execution.Execution) + assert isinstance(response, gca_execution.Execution) assert response.name == "name_value" assert response.display_name == "display_name_value" - assert response.state == execution.Execution.State.NEW + assert response.state == gca_execution.Execution.State.NEW assert response.etag == "etag_value" assert response.schema_title == "schema_title_value" assert response.schema_version == "schema_version_value" @@ -4943,23 +6149,23 @@ async def test_get_execution_async( @pytest.mark.asyncio -async def test_get_execution_async_from_dict(): - await test_get_execution_async(request_type=dict) +async def test_update_execution_async_from_dict(): + await test_update_execution_async(request_type=dict) -def test_get_execution_field_headers(): +def test_update_execution_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.GetExecutionRequest() + request = metadata_service.UpdateExecutionRequest() - request.name = "name/value" + request.execution.name = "execution.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_execution), "__call__") as call: - call.return_value = execution.Execution() - client.get_execution(request) + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + call.return_value = gca_execution.Execution() + client.update_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -4968,25 +6174,29 @@ def test_get_execution_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ("x-goog-request-params", "execution.name=execution.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio -async def test_get_execution_field_headers_async(): +async def test_update_execution_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.GetExecutionRequest() + request = metadata_service.UpdateExecutionRequest() - request.name = "name/value" + request.execution.name = "execution.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_execution), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution()) - await client.get_execution(request) + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution() + ) + await client.update_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -4995,63 +6205,77 @@ async def test_get_execution_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ("x-goog-request-params", "execution.name=execution.name/value",) in kw[ + "metadata" + ] -def test_get_execution_flattened(): +def test_update_execution_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = execution.Execution() + call.return_value = gca_execution.Execution() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_execution(name="name_value",) + client.update_execution( + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].execution == gca_execution.Execution(name="name_value") + assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"]) -def test_get_execution_flattened_error(): +def test_update_execution_flattened_error(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.get_execution( - metadata_service.GetExecutionRequest(), name="name_value", + client.update_execution( + metadata_service.UpdateExecutionRequest(), + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio -async def test_get_execution_flattened_async(): +async def test_update_execution_flattened_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.get_execution), "__call__") as call: + with mock.patch.object(type(client.transport.update_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = execution.Execution() + call.return_value = gca_execution.Execution() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(execution.Execution()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_execution.Execution() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_execution(name="name_value",) + response = await client.update_execution( + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].execution == gca_execution.Execution(name="name_value") + assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio -async def test_get_execution_flattened_error_async(): +async def test_update_execution_flattened_error_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -5059,13 +6283,15 @@ async def test_get_execution_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - await client.get_execution( - metadata_service.GetExecutionRequest(), name="name_value", + await client.update_execution( + metadata_service.UpdateExecutionRequest(), + execution=gca_execution.Execution(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) -def test_list_executions( - transport: str = "grpc", request_type=metadata_service.ListExecutionsRequest +def test_delete_execution( + transport: str = "grpc", request_type=metadata_service.DeleteExecutionRequest ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5076,28 +6302,25 @@ def test_list_executions( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.ListExecutionsResponse( - next_page_token="next_page_token_value", - ) - response = client.list_executions(request) + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.ListExecutionsRequest() + assert args[0] == metadata_service.DeleteExecutionRequest() # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListExecutionsPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, future.Future) -def test_list_executions_from_dict(): - test_list_executions(request_type=dict) +def test_delete_execution_from_dict(): + test_delete_execution(request_type=dict) -def test_list_executions_empty_call(): +def test_delete_execution_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( @@ -5105,16 +6328,17 @@ def test_list_executions_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_executions), "__call__") as call: - client.list_executions() + with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: + client.delete_execution() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.ListExecutionsRequest() + assert args[0] == metadata_service.DeleteExecutionRequest() @pytest.mark.asyncio -async def test_list_executions_async( - transport: str = "grpc_asyncio", request_type=metadata_service.ListExecutionsRequest +async def test_delete_execution_async( + transport: str = "grpc_asyncio", + request_type=metadata_service.DeleteExecutionRequest, ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5125,43 +6349,40 @@ async def test_list_executions_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.ListExecutionsResponse( - next_page_token="next_page_token_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.list_executions(request) + response = await client.delete_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.ListExecutionsRequest() + assert args[0] == metadata_service.DeleteExecutionRequest() # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListExecutionsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, future.Future) @pytest.mark.asyncio -async def test_list_executions_async_from_dict(): - await test_list_executions_async(request_type=dict) +async def test_delete_execution_async_from_dict(): + await test_delete_execution_async(request_type=dict) -def test_list_executions_field_headers(): +def test_delete_execution_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.ListExecutionsRequest() + request = metadata_service.DeleteExecutionRequest() - request.parent = "parent/value" + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_executions), "__call__") as call: - call.return_value = metadata_service.ListExecutionsResponse() - client.list_executions(request) + with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.delete_execution(request) # Establish that the underlying gRPC stub method was called. 
assert len(call.mock_calls) == 1 @@ -5170,27 +6391,27 @@ def test_list_executions_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio -async def test_list_executions_field_headers_async(): +async def test_delete_execution_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.ListExecutionsRequest() + request = metadata_service.DeleteExecutionRequest() - request.parent = "parent/value" + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.ListExecutionsResponse() + operations_pb2.Operation(name="operations/op") ) - await client.list_executions(request) + await client.delete_execution(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -5199,229 +6420,79 @@ async def test_list_executions_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_executions_flattened(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_executions), "__call__") as call: - # Designate an appropriate return value for the call. 
- call.return_value = metadata_service.ListExecutionsResponse() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_executions(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] -def test_list_executions_flattened_error(): +def test_delete_execution_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_executions( - metadata_service.ListExecutionsRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_executions_flattened_async(): - client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_executions), "__call__") as call: + with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = metadata_service.ListExecutionsResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - metadata_service.ListExecutionsResponse() - ) + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_executions(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_executions_flattened_error_async(): - client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.list_executions( - metadata_service.ListExecutionsRequest(), parent="parent_value", - ) - - -def test_list_executions_pager(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_executions), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - execution.Execution(), - ], - next_page_token="abc", - ), - metadata_service.ListExecutionsResponse( - executions=[], next_page_token="def", - ), - metadata_service.ListExecutionsResponse( - executions=[execution.Execution(),], next_page_token="ghi", - ), - metadata_service.ListExecutionsResponse( - executions=[execution.Execution(), execution.Execution(),], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_executions(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, execution.Execution) for i in results) - - -def test_list_executions_pages(): - client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.list_executions), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - execution.Execution(), - ], - next_page_token="abc", - ), - metadata_service.ListExecutionsResponse( - executions=[], next_page_token="def", - ), - metadata_service.ListExecutionsResponse( - executions=[execution.Execution(),], next_page_token="ghi", - ), - metadata_service.ListExecutionsResponse( - executions=[execution.Execution(), execution.Execution(),], - ), - RuntimeError, - ) - pages = list(client.list_executions(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_executions_async_pager(): - client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials, - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), "__call__", new_callable=mock.AsyncMock - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - execution.Execution(), - ], - next_page_token="abc", - ), - metadata_service.ListExecutionsResponse( - executions=[], next_page_token="def", - ), - metadata_service.ListExecutionsResponse( - executions=[execution.Execution(),], next_page_token="ghi", - ), - metadata_service.ListExecutionsResponse( - executions=[execution.Execution(), execution.Execution(),], - ), - RuntimeError, - ) - async_pager = await client.list_executions(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) + # using the keyword arguments to the method. 
+ client.delete_execution(name="name_value",) - assert len(responses) == 6 - assert all(isinstance(i, execution.Execution) for i in responses) + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_execution_flattened_error(): + client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_execution( + metadata_service.DeleteExecutionRequest(), name="name_value", + ) @pytest.mark.asyncio -async def test_list_executions_async_pages(): +async def test_delete_execution_flattened_async(): client = MetadataServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials, + credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_executions), "__call__", new_callable=mock.AsyncMock - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - metadata_service.ListExecutionsResponse( - executions=[ - execution.Execution(), - execution.Execution(), - execution.Execution(), - ], - next_page_token="abc", - ), - metadata_service.ListExecutionsResponse( - executions=[], next_page_token="def", - ), - metadata_service.ListExecutionsResponse( - executions=[execution.Execution(),], next_page_token="ghi", - ), - metadata_service.ListExecutionsResponse( - executions=[execution.Execution(), execution.Execution(),], - ), - RuntimeError, + with mock.patch.object(type(client.transport.delete_execution), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") ) - pages = [] - async for page_ in (await client.list_executions(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_execution(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" -def test_update_execution( - transport: str = "grpc", request_type=metadata_service.UpdateExecutionRequest +@pytest.mark.asyncio +async def test_delete_execution_flattened_error_async(): + client = MetadataServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_execution( + metadata_service.DeleteExecutionRequest(), name="name_value", + ) + + +def test_purge_executions( + transport: str = "grpc", request_type=metadata_service.PurgeExecutionsRequest ): client = MetadataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5432,40 +6503,25 @@ def test_update_execution( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = gca_execution.Execution( - name="name_value", - display_name="display_name_value", - state=gca_execution.Execution.State.NEW, - etag="etag_value", - schema_title="schema_title_value", - schema_version="schema_version_value", - description="description_value", - ) - response = client.update_execution(request) + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.purge_executions(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.UpdateExecutionRequest() + assert args[0] == metadata_service.PurgeExecutionsRequest() # Establish that the response is the type that we expect. - assert isinstance(response, gca_execution.Execution) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.state == gca_execution.Execution.State.NEW - assert response.etag == "etag_value" - assert response.schema_title == "schema_title_value" - assert response.schema_version == "schema_version_value" - assert response.description == "description_value" + assert isinstance(response, future.Future) -def test_update_execution_from_dict(): - test_update_execution(request_type=dict) +def test_purge_executions_from_dict(): + test_purge_executions(request_type=dict) -def test_update_execution_empty_call(): +def test_purge_executions_empty_call(): # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. client = MetadataServiceClient( @@ -5473,17 +6529,17 @@ def test_update_execution_empty_call(): ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.update_execution), "__call__") as call: - client.update_execution() + with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: + client.purge_executions() call.assert_called() _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.UpdateExecutionRequest() + assert args[0] == metadata_service.PurgeExecutionsRequest() @pytest.mark.asyncio -async def test_update_execution_async( +async def test_purge_executions_async( transport: str = "grpc_asyncio", - request_type=metadata_service.UpdateExecutionRequest, + request_type=metadata_service.PurgeExecutionsRequest, ): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5494,55 +6550,40 @@ async def test_update_execution_async( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_execution.Execution( - name="name_value", - display_name="display_name_value", - state=gca_execution.Execution.State.NEW, - etag="etag_value", - schema_title="schema_title_value", - schema_version="schema_version_value", - description="description_value", - ) + operations_pb2.Operation(name="operations/spam") ) - response = await client.update_execution(request) + response = await client.purge_executions(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == metadata_service.UpdateExecutionRequest() + assert args[0] == metadata_service.PurgeExecutionsRequest() # Establish that the response is the type that we expect. 
- assert isinstance(response, gca_execution.Execution) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.state == gca_execution.Execution.State.NEW - assert response.etag == "etag_value" - assert response.schema_title == "schema_title_value" - assert response.schema_version == "schema_version_value" - assert response.description == "description_value" + assert isinstance(response, future.Future) @pytest.mark.asyncio -async def test_update_execution_async_from_dict(): - await test_update_execution_async(request_type=dict) +async def test_purge_executions_async_from_dict(): + await test_purge_executions_async(request_type=dict) -def test_update_execution_field_headers(): +def test_purge_executions_field_headers(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.UpdateExecutionRequest() + request = metadata_service.PurgeExecutionsRequest() - request.execution.name = "execution.name/value" + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_execution), "__call__") as call: - call.return_value = gca_execution.Execution() - client.update_execution(request) + with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.purge_executions(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) == 1 @@ -5551,29 +6592,27 @@ def test_update_execution_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "execution.name=execution.name/value",) in kw[ - "metadata" - ] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio -async def test_update_execution_field_headers_async(): +async def test_purge_executions_field_headers_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. - request = metadata_service.UpdateExecutionRequest() + request = metadata_service.PurgeExecutionsRequest() - request.execution.name = "execution.name/value" + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_execution.Execution() + operations_pb2.Operation(name="operations/op") ) - await client.update_execution(request) + await client.purge_executions(request) # Establish that the underlying gRPC stub method was called. assert len(call.mock_calls) @@ -5582,77 +6621,65 @@ async def test_update_execution_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "execution.name=execution.name/value",) in kw[ - "metadata" - ] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] -def test_update_execution_flattened(): +def test_purge_executions_flattened(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = gca_execution.Execution() + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.update_execution( - execution=gca_execution.Execution(name="name_value"), - update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), - ) + client.purge_executions(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"]) + assert args[0].parent == "parent_value" -def test_update_execution_flattened_error(): +def test_purge_executions_flattened_error(): client = MetadataServiceClient(credentials=ga_credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.update_execution( - metadata_service.UpdateExecutionRequest(), - execution=gca_execution.Execution(name="name_value"), - update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + client.purge_executions( + metadata_service.PurgeExecutionsRequest(), parent="parent_value", ) @pytest.mark.asyncio -async def test_update_execution_flattened_async(): +async def test_purge_executions_flattened_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.update_execution), "__call__") as call: + with mock.patch.object(type(client.transport.purge_executions), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = gca_execution.Execution() + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_execution.Execution() + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.update_execution( - execution=gca_execution.Execution(name="name_value"), - update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), - ) + response = await client.purge_executions(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].execution == gca_execution.Execution(name="name_value") - assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"]) + assert args[0].parent == "parent_value" @pytest.mark.asyncio -async def test_update_execution_flattened_error_async(): +async def test_purge_executions_flattened_error_async(): client = MetadataServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), ) @@ -5660,10 +6687,8 @@ async def test_update_execution_flattened_error_async(): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - await client.update_execution( - metadata_service.UpdateExecutionRequest(), - execution=gca_execution.Execution(name="name_value"), - update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + await client.purge_executions( + metadata_service.PurgeExecutionsRequest(), parent="parent_value", ) @@ -7342,11 +8367,14 @@ def test_metadata_service_base_transport(): "get_artifact", "list_artifacts", "update_artifact", + "delete_artifact", + "purge_artifacts", "create_context", "get_context", "list_contexts", "update_context", "delete_context", + "purge_contexts", "add_context_artifacts_and_executions", "add_context_children", "query_context_lineage_subgraph", @@ -7354,6 +8382,8 @@ def test_metadata_service_base_transport(): "get_execution", "list_executions", "update_execution", + "delete_execution", + "purge_executions", "add_execution_events", "query_execution_inputs_and_outputs", "create_metadata_schema", @@ -7496,7 +8526,6 @@ def test_metadata_service_transport_auth_adc_old_google_auth(transport_class): (transports.MetadataServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_metadata_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. @@ -7525,79 +8554,6 @@ def test_metadata_service_transport_create_channel(transport_class, grpc_helpers ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.MetadataServiceGrpcTransport, grpc_helpers), - (transports.MetadataServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_metadata_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.MetadataServiceGrpcTransport, grpc_helpers), - (transports.MetadataServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_metadata_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -7620,7 +8576,7 @@ def test_metadata_service_grpc_transport_client_cert_source_for_mtls(transport_c "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -7729,7 +8685,7 @@ def test_metadata_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -7776,7 +8732,7 @@ def test_metadata_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 279de3296e..2158facfaf 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -42,9 +42,6 @@ ) from 
google.cloud.aiplatform_v1beta1.services.migration_service import pagers from google.cloud.aiplatform_v1beta1.services.migration_service import transports -from google.cloud.aiplatform_v1beta1.services.migration_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.migration_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -55,8 +52,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -67,16 +65,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -140,6 +128,31 @@ def test_migration_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.MigrationServiceGrpcTransport, "grpc"), + (transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_migration_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", 
create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] ) @@ -219,6 +232,7 @@ def test_migration_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -235,6 +249,7 @@ def test_migration_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -251,6 +266,7 @@ def test_migration_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -279,6 +295,7 @@ def test_migration_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -355,6 +372,7 @@ def test_migration_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -388,6 +406,7 @@ def test_migration_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -409,6 +428,7 @@ def test_migration_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -439,6 +459,7 @@ def test_migration_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -469,6 +490,7 @@ def test_migration_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -488,6 +510,7 @@ def test_migration_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -1389,7 +1412,6 @@ def test_migration_service_transport_auth_adc_old_google_auth(transport_class): (transports.MigrationServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_migration_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -1418,79 +1440,6 @@ def test_migration_service_transport_create_channel(transport_class, grpc_helper ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.MigrationServiceGrpcTransport, grpc_helpers), - (transports.MigrationServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_migration_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.MigrationServiceGrpcTransport, grpc_helpers), - (transports.MigrationServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_migration_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -1513,7 +1462,7 @@ def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_ "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1622,7 +1571,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1669,7 +1618,7 @@ def test_migration_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1734,18 +1683,20 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - dataset = "mussel" - expected = "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + location = "mussel" + dataset = "winkle" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, 
location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = MigrationServiceClient.dataset_path(**expected) @@ -1755,20 +1706,18 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "scallop" - location = "abalone" - dataset = "squid" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + project = "squid" + dataset = "clam" + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "clam", - "location": "whelk", + "project": "whelk", "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py index d28fff4307..1a0d64c322 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py @@ -40,9 +40,6 @@ from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.services.model_service import transports -from google.cloud.aiplatform_v1beta1.services.model_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.model_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -66,8 +63,9 @@ import 
google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -78,16 +76,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -145,6 +133,31 @@ def test_model_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.ModelServiceGrpcTransport, "grpc"), + (transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_model_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, 
always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,]) def test_model_service_client_from_service_account_file(client_class): creds = ga_credentials.AnonymousCredentials() @@ -220,6 +233,7 @@ def test_model_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -236,6 +250,7 @@ def test_model_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -252,6 +267,7 @@ def test_model_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -280,6 +296,7 @@ def test_model_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -344,6 +361,7 @@ def test_model_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. Whether client cert is used depends on @@ -377,6 +395,7 @@ def test_model_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. 
@@ -398,6 +417,7 @@ def test_model_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -428,6 +448,7 @@ def test_model_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -458,6 +479,7 @@ def test_model_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -475,6 +497,7 @@ def test_model_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -3388,7 +3411,6 @@ def test_model_service_transport_auth_adc_old_google_auth(transport_class): (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_model_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. @@ -3417,79 +3439,6 @@ def test_model_service_transport_create_channel(transport_class, grpc_helpers): ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.ModelServiceGrpcTransport, grpc_helpers), - (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_model_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.ModelServiceGrpcTransport, grpc_helpers), - (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_model_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], @@ -3509,7 +3458,7 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_clas "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3613,7 +3562,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_ "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3657,7 +3606,7 @@ def test_model_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py index 7c78097cd6..74bfdce09b 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ 
b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -42,9 +42,6 @@ ) from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports -from google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -78,8 +75,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -90,16 +88,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -162,6 +150,31 @@ def test_pipeline_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.PipelineServiceGrpcTransport, "grpc"), + (transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_pipeline_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with 
mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,] ) @@ -241,6 +254,7 @@ def test_pipeline_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -257,6 +271,7 @@ def test_pipeline_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -273,6 +288,7 @@ def test_pipeline_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -301,6 +317,7 @@ def test_pipeline_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -377,6 +394,7 @@ def test_pipeline_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client 
cert is provided. Whether client cert is used depends on @@ -410,6 +428,7 @@ def test_pipeline_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -431,6 +450,7 @@ def test_pipeline_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -461,6 +481,7 @@ def test_pipeline_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -491,6 +512,7 @@ def test_pipeline_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -510,6 +532,7 @@ def test_pipeline_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -3337,7 +3360,6 @@ def test_pipeline_service_transport_auth_adc_old_google_auth(transport_class): (transports.PipelineServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_pipeline_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -3366,79 +3388,6 @@ def test_pipeline_service_transport_create_channel(transport_class, grpc_helpers ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.PipelineServiceGrpcTransport, grpc_helpers), - (transports.PipelineServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_pipeline_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.PipelineServiceGrpcTransport, grpc_helpers), - (transports.PipelineServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_pipeline_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -3461,7 +3410,7 @@ def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_c "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -3570,7 +3519,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -3617,7 +3566,7 @@ def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py index a36714bb87..559531c0e1 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -38,9 +38,6 @@ PredictionServiceClient, ) from 
google.cloud.aiplatform_v1beta1.services.prediction_service import transports -from google.cloud.aiplatform_v1beta1.services.prediction_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.prediction_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -51,8 +48,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -63,16 +61,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -136,6 +124,31 @@ def test_prediction_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.PredictionServiceGrpcTransport, "grpc"), + (transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_prediction_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, 
None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [PredictionServiceClient, PredictionServiceAsyncClient,] ) @@ -215,6 +228,7 @@ def test_prediction_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -231,6 +245,7 @@ def test_prediction_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -247,6 +262,7 @@ def test_prediction_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -275,6 +291,7 @@ def test_prediction_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -351,6 +368,7 @@ def test_prediction_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -384,6 +402,7 @@ def test_prediction_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -405,6 +424,7 @@ def test_prediction_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -435,6 +455,7 @@ def test_prediction_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -465,6 +486,7 @@ def test_prediction_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -484,6 +506,7 @@ def test_prediction_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -625,33 +648,6 @@ async def test_predict_field_headers_async(): assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] -def test_predict_flattened(): - client = PredictionServiceClient(credentials=ga_credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.predict), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.PredictResponse() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- client.predict( - endpoint="endpoint_value", - instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], - parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [ - struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - ] - # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - - def test_predict_flattened_error(): client = PredictionServiceClient(credentials=ga_credentials.AnonymousCredentials(),) @@ -666,40 +662,6 @@ def test_predict_flattened_error(): ) -@pytest.mark.asyncio -async def test_predict_flattened_async(): - client = PredictionServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.predict), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.PredictResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - prediction_service.PredictResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.predict( - endpoint="endpoint_value", - instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], - parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [ - struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - ] - # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - - @pytest.mark.asyncio async def test_predict_flattened_error_async(): client = PredictionServiceAsyncClient( @@ -855,35 +817,6 @@ async def test_explain_field_headers_async(): assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] -def test_explain_flattened(): - client = PredictionServiceClient(credentials=ga_credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.explain), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.ExplainResponse() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.explain( - endpoint="endpoint_value", - instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], - parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), - deployed_model_id="deployed_model_id_value", - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [ - struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - ] - # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - assert args[0].deployed_model_id == "deployed_model_id_value" - - def test_explain_flattened_error(): client = PredictionServiceClient(credentials=ga_credentials.AnonymousCredentials(),) @@ -899,42 +832,6 @@ def test_explain_flattened_error(): ) -@pytest.mark.asyncio -async def test_explain_flattened_async(): - client = PredictionServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.explain), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.ExplainResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - prediction_service.ExplainResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.explain( - endpoint="endpoint_value", - instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], - parameters=struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), - deployed_model_id="deployed_model_id_value", - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [ - struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - ] - # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - assert args[0].deployed_model_id == "deployed_model_id_value" - - @pytest.mark.asyncio async def test_explain_flattened_error_async(): client = PredictionServiceAsyncClient( @@ -1183,7 +1080,6 @@ def test_prediction_service_transport_auth_adc_old_google_auth(transport_class): (transports.PredictionServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_prediction_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. @@ -1212,79 +1108,6 @@ def test_prediction_service_transport_create_channel(transport_class, grpc_helpe ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.PredictionServiceGrpcTransport, grpc_helpers), - (transports.PredictionServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_prediction_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.PredictionServiceGrpcTransport, grpc_helpers), - (transports.PredictionServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_prediction_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -1307,7 +1130,7 @@ def test_prediction_service_grpc_transport_client_cert_source_for_mtls(transport "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -1416,7 +1239,7 @@ def test_prediction_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -1463,7 +1286,7 @@ def test_prediction_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py index d9f0d11522..bc1acf4603 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -42,9 +42,6 @@ ) from 
google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -58,8 +55,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -70,16 +68,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -143,6 +131,31 @@ def test_specialist_pool_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_specialist_pool_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + 
service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,] ) @@ -226,6 +239,7 @@ def test_specialist_pool_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -242,6 +256,7 @@ def test_specialist_pool_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -258,6 +273,7 @@ def test_specialist_pool_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -286,6 +302,7 @@ def test_specialist_pool_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -362,6 +379,7 @@ def test_specialist_pool_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # 
Check the case ADC client cert is provided. Whether client cert is used depends on @@ -395,6 +413,7 @@ def test_specialist_pool_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -416,6 +435,7 @@ def test_specialist_pool_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -450,6 +470,7 @@ def test_specialist_pool_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -484,6 +505,7 @@ def test_specialist_pool_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -503,6 +525,7 @@ def test_specialist_pool_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -2100,7 +2123,6 @@ def test_specialist_pool_service_transport_auth_adc_old_google_auth(transport_cl (transports.SpecialistPoolServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_specialist_pool_service_transport_create_channel( transport_class, grpc_helpers ): @@ -2131,79 +2153,6 @@ def test_specialist_pool_service_transport_create_channel( ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.SpecialistPoolServiceGrpcTransport, grpc_helpers), - (transports.SpecialistPoolServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_specialist_pool_service_transport_create_channel_old_api_core( - 
transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.SpecialistPoolServiceGrpcTransport, grpc_helpers), - (transports.SpecialistPoolServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_specialist_pool_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -2228,7 +2177,7 @@ def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -2337,7 +2286,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -2384,7 +2333,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py index aab827e031..15caab1a47 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_tensorboard_service.py @@ -42,9 +42,6 @@ ) from 
google.cloud.aiplatform_v1beta1.services.tensorboard_service import pagers from google.cloud.aiplatform_v1beta1.services.tensorboard_service import transports -from google.cloud.aiplatform_v1beta1.services.tensorboard_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.tensorboard_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -71,8 +68,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -83,16 +81,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -156,6 +144,31 @@ def test_tensorboard_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.TensorboardServiceGrpcTransport, "grpc"), + (transports.TensorboardServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_tensorboard_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, 
"with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [TensorboardServiceClient, TensorboardServiceAsyncClient,] ) @@ -235,6 +248,7 @@ def test_tensorboard_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -251,6 +265,7 @@ def test_tensorboard_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -267,6 +282,7 @@ def test_tensorboard_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -295,6 +311,7 @@ def test_tensorboard_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -371,6 +388,7 @@ def test_tensorboard_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -404,6 +422,7 @@ def test_tensorboard_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -425,6 +444,7 @@ def test_tensorboard_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -455,6 +475,7 @@ def test_tensorboard_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -485,6 +506,7 @@ def test_tensorboard_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -504,6 +526,7 @@ def test_tensorboard_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -7542,7 +7565,6 @@ def test_tensorboard_service_transport_auth_adc_old_google_auth(transport_class) (transports.TensorboardServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_tensorboard_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -7571,79 +7593,6 @@ def test_tensorboard_service_transport_create_channel(transport_class, grpc_help ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.TensorboardServiceGrpcTransport, grpc_helpers), - (transports.TensorboardServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_tensorboard_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.TensorboardServiceGrpcTransport, grpc_helpers), - (transports.TensorboardServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_tensorboard_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -7668,7 +7617,7 @@ def test_tensorboard_service_grpc_transport_client_cert_source_for_mtls( "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -7777,7 +7726,7 @@ def test_tensorboard_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -7824,7 +7773,7 @@ def test_tensorboard_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py index e4df5b0517..e00fabb437 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py @@ -40,9 +40,6 @@ from google.cloud.aiplatform_v1beta1.services.vizier_service import 
VizierServiceClient from google.cloud.aiplatform_v1beta1.services.vizier_service import pagers from google.cloud.aiplatform_v1beta1.services.vizier_service import transports -from google.cloud.aiplatform_v1beta1.services.vizier_service.transports.base import ( - _API_CORE_VERSION, -) from google.cloud.aiplatform_v1beta1.services.vizier_service.transports.base import ( _GOOGLE_AUTH_VERSION, ) @@ -57,8 +54,9 @@ import google.auth -# TODO(busunkim): Once google-api-core >= 1.26.0 is required: -# - Delete all the api-core and auth "less than" test cases +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases # - Delete these pytest markers (Make the "greater than or equal to" tests the default). requires_google_auth_lt_1_25_0 = pytest.mark.skipif( packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), @@ -69,16 +67,6 @@ reason="This test requires google-auth >= 1.25.0", ) -requires_api_core_lt_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"), - reason="This test requires google-api-core < 1.26.0", -) - -requires_api_core_gte_1_26_0 = pytest.mark.skipif( - packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"), - reason="This test requires google-api-core >= 1.26.0", -) - def client_cert_source_callback(): return b"cert bytes", b"key bytes" @@ -141,6 +129,31 @@ def test_vizier_service_client_from_service_account_info(client_class): assert client.transport._host == "aiplatform.googleapis.com:443" +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.VizierServiceGrpcTransport, "grpc"), + (transports.VizierServiceGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_vizier_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", 
create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + @pytest.mark.parametrize( "client_class", [VizierServiceClient, VizierServiceAsyncClient,] ) @@ -220,6 +233,7 @@ def test_vizier_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -236,6 +250,7 @@ def test_vizier_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is @@ -252,6 +267,7 @@ def test_vizier_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has @@ -280,6 +296,7 @@ def test_vizier_service_client_client_options( client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -346,6 +363,7 @@ def test_vizier_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case ADC client cert is provided. 
Whether client cert is used depends on @@ -379,6 +397,7 @@ def test_vizier_service_client_mtls_env_auto( client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) # Check the case client_cert_source and ADC client cert are not provided. @@ -400,6 +419,7 @@ def test_vizier_service_client_mtls_env_auto( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -430,6 +450,7 @@ def test_vizier_service_client_client_options_scopes( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -460,6 +481,7 @@ def test_vizier_service_client_client_options_credentials_file( client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -479,6 +501,7 @@ def test_vizier_service_client_client_options_from_dict(): client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, ) @@ -3838,7 +3861,6 @@ def test_vizier_service_transport_auth_adc_old_google_auth(transport_class): (transports.VizierServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) -@requires_api_core_gte_1_26_0 def test_vizier_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
@@ -3867,79 +3889,6 @@ def test_vizier_service_transport_create_channel(transport_class, grpc_helpers): ) -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.VizierServiceGrpcTransport, grpc_helpers), - (transports.VizierServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_vizier_service_transport_create_channel_old_api_core( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class(quota_project_id="octopus") - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.VizierServiceGrpcTransport, grpc_helpers), - (transports.VizierServiceGrpcAsyncIOTransport, grpc_helpers_async), - ], -) -@requires_api_core_lt_1_26_0 -def test_vizier_service_transport_create_channel_user_scopes( - transport_class, grpc_helpers -): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - - create_channel.assert_called_with( - "aiplatform.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - scopes=["1", "2"], - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - @pytest.mark.parametrize( "transport_class", [ @@ -3962,7 +3911,7 @@ def test_vizier_service_grpc_transport_client_cert_source_for_mtls(transport_cla "squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_channel_creds, quota_project_id=None, options=[ @@ -4069,7 +4018,7 @@ def test_vizier_service_transport_channel_mtls_with_client_cert_source(transport "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ @@ -4116,7 +4065,7 @@ def test_vizier_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=None, ssl_credentials=mock_ssl_cred, quota_project_id=None, options=[ diff --git a/tests/unit/gapic/definition_v1/__init__.py b/tests/unit/gapic/definition_v1/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/unit/gapic/definition_v1/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/definition_v1beta1/__init__.py b/tests/unit/gapic/definition_v1beta1/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/unit/gapic/definition_v1beta1/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/instance_v1/__init__.py b/tests/unit/gapic/instance_v1/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/unit/gapic/instance_v1/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/instance_v1beta1/__init__.py b/tests/unit/gapic/instance_v1beta1/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/unit/gapic/instance_v1beta1/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/params_v1/__init__.py b/tests/unit/gapic/params_v1/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/unit/gapic/params_v1/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/params_v1beta1/__init__.py b/tests/unit/gapic/params_v1beta1/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/unit/gapic/params_v1beta1/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/prediction_v1/__init__.py b/tests/unit/gapic/prediction_v1/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/unit/gapic/prediction_v1/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/tests/unit/gapic/prediction_v1beta1/__init__.py b/tests/unit/gapic/prediction_v1beta1/__init__.py new file mode 100644 index 0000000000..4de65971c2 --- /dev/null +++ b/tests/unit/gapic/prediction_v1beta1/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#