From 5df0f0965c541ca546d3851be1ab7782dc80a11b Mon Sep 17 00:00:00 2001 From: Yoshi Automation Bot Date: Wed, 3 Feb 2021 09:12:28 -0800 Subject: [PATCH] feat: add `client_cert_source_for_mtls` argument to transports (#78) * changes without context autosynth cannot find the source of changes triggered by earlier changes in this repository, or by version upgrades to tools such as linters. * chore: update Go generator, rules_go, and protobuf PiperOrigin-RevId: 352816749 Source-Author: Google APIs Source-Date: Wed Jan 20 10:06:23 2021 -0800 Source-Repo: googleapis/googleapis Source-Sha: ceaaf31b3d13badab7cf9d3b570f5639db5593d9 Source-Link: https://github.com/googleapis/googleapis/commit/ceaaf31b3d13badab7cf9d3b570f5639db5593d9 * chore: upgrade gapic-generator-python to 0.40.5 PiperOrigin-RevId: 354996675 Source-Author: Google APIs Source-Date: Mon Feb 1 12:11:49 2021 -0800 Source-Repo: googleapis/googleapis Source-Sha: 20712b8fe95001b312f62c6c5f33e3e3ec92cfaf Source-Link: https://github.com/googleapis/googleapis/commit/20712b8fe95001b312f62c6c5f33e3e3ec92cfaf * revert unblacken Co-authored-by: Tim Swast --- .../services/reservation_service/client.py | 18 +- .../reservation_service/transports/grpc.py | 23 ++- .../transports/grpc_asyncio.py | 23 ++- synth.metadata | 6 +- .../test_reservation_service.py | 187 +++++++++++------- 5 files changed, 156 insertions(+), 101 deletions(-) diff --git a/google/cloud/bigquery_reservation_v1/services/reservation_service/client.py b/google/cloud/bigquery_reservation_v1/services/reservation_service/client.py index f6c0490a..ad850098 100644 --- a/google/cloud/bigquery_reservation_v1/services/reservation_service/client.py +++ b/google/cloud/bigquery_reservation_v1/services/reservation_service/client.py @@ -363,21 +363,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -420,7 +416,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) diff --git a/google/cloud/bigquery_reservation_v1/services/reservation_service/transports/grpc.py b/google/cloud/bigquery_reservation_v1/services/reservation_service/transports/grpc.py index 1f3f9a28..95c1117f 100644 --- a/google/cloud/bigquery_reservation_v1/services/reservation_service/transports/grpc.py +++ b/google/cloud/bigquery_reservation_v1/services/reservation_service/transports/grpc.py @@ -75,6 +75,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -105,6 +106,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -121,6 +126,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -130,11 +140,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -178,12 +183,18 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ diff --git a/google/cloud/bigquery_reservation_v1/services/reservation_service/transports/grpc_asyncio.py b/google/cloud/bigquery_reservation_v1/services/reservation_service/transports/grpc_asyncio.py index 216a7845..66365d18 100644 --- a/google/cloud/bigquery_reservation_v1/services/reservation_service/transports/grpc_asyncio.py +++ b/google/cloud/bigquery_reservation_v1/services/reservation_service/transports/grpc_asyncio.py @@ -119,6 +119,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -150,6 +151,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -166,6 +171,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -175,11 +185,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -223,12 +228,18 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, options=[ diff --git a/synth.metadata b/synth.metadata index 70085e2c..a91989c3 100644 --- a/synth.metadata +++ b/synth.metadata @@ -4,15 +4,15 @@ "git": { "name": ".", "remote": "https://github.com/googleapis/python-bigquery-reservation.git", - "sha": "555c4eeb6daa09b713632e92fbc8dc2a7aa37495" + "sha": "c30f0d02e4818306ee9f00a40a9e5b17657f70f2" } }, { "git": { "name": "googleapis", "remote": "https://github.com/googleapis/googleapis.git", - "sha": "520682435235d9c503983a360a2090025aa47cd1", - "internalRef": "350246057" + "sha": "20712b8fe95001b312f62c6c5f33e3e3ec92cfaf", + "internalRef": "354996675" } }, { diff --git a/tests/unit/gapic/bigquery_reservation_v1/test_reservation_service.py b/tests/unit/gapic/bigquery_reservation_v1/test_reservation_service.py index 58e735ee..faba4a02 100644 --- a/tests/unit/gapic/bigquery_reservation_v1/test_reservation_service.py +++ b/tests/unit/gapic/bigquery_reservation_v1/test_reservation_service.py @@ -181,7 +181,7 @@ def test_reservation_service_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -197,7 +197,7 @@ def test_reservation_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -213,7 +213,7 @@ def test_reservation_service_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -241,7 +241,7 @@ def test_reservation_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -302,29 +302,25 @@ def test_reservation_service_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -333,66 +329,53 @@ def test_reservation_service_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -418,7 +401,7 @@ def test_reservation_service_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -448,7 +431,7 @@ def test_reservation_service_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -467,7 +450,7 @@ def test_reservation_service_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -5578,6 +5561,56 @@ def test_reservation_service_transport_auth_adc(): ) +@pytest.mark.parametrize( + "transport_class", + [ + transports.ReservationServiceGrpcTransport, + transports.ReservationServiceGrpcAsyncIOTransport, + ], +) +def test_reservation_service_grpc_transport_client_cert_source_for_mtls( + transport_class, +): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/cloud-platform", + ), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + def test_reservation_service_host_no_port(): client = ReservationServiceClient( credentials=credentials.AnonymousCredentials(), @@ -5622,6 +5655,8 @@ def test_reservation_service_grpc_asyncio_transport_channel(): assert transport._ssl_channel_credentials == None +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [ @@ -5677,6 +5712,8 @@ def test_reservation_service_transport_channel_mtls_with_client_cert_source( assert transport._ssl_channel_credentials == mock_ssl_cred +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [