Skip to content

Commit

Permalink
feat: add client_cert_source_for_mtls argument to transports (#135)
Browse files Browse the repository at this point in the history
This PR was generated using Autosynth. 🌈

Synth log will be available here:
https://source.cloud.google.com/results/invocations/3b4457c8-4080-407a-9a6d-4a48ddcea154/targets

- [ ] To automatically regenerate this PR, check this box.

PiperOrigin-RevId: 354996675
Source-Link: googleapis/googleapis@20712b8
PiperOrigin-RevId: 352816749
Source-Link: googleapis/googleapis@ceaaf31
  • Loading branch information
yoshi-automation committed Feb 3, 2021
1 parent bc94422 commit 072850d
Show file tree
Hide file tree
Showing 13 changed files with 454 additions and 305 deletions.
18 changes: 7 additions & 11 deletions google/cloud/bigquery_storage_v1/services/big_query_read/client.py
Expand Up @@ -322,21 +322,17 @@ def __init__(
util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
)

ssl_credentials = None
client_cert_source_func = None
is_mtls = False
if use_client_cert:
if client_options.client_cert_source:
import grpc # type: ignore

cert, key = client_options.client_cert_source()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
is_mtls = True
client_cert_source_func = client_options.client_cert_source
else:
creds = SslCredentials()
is_mtls = creds.is_mtls
ssl_credentials = creds.ssl_credentials if is_mtls else None
is_mtls = mtls.has_default_client_cert_source()
client_cert_source_func = (
mtls.default_client_cert_source() if is_mtls else None
)

# Figure out which api endpoint to use.
if client_options.api_endpoint is not None:
Expand Down Expand Up @@ -379,7 +375,7 @@ def __init__(
credentials_file=client_options.credentials_file,
host=api_endpoint,
scopes=client_options.scopes,
ssl_channel_credentials=ssl_credentials,
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)
Expand Down
Expand Up @@ -59,6 +59,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand Down Expand Up @@ -89,6 +90,10 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -105,6 +110,11 @@ def __init__(
"""
self._ssl_channel_credentials = ssl_channel_credentials

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
Expand All @@ -114,11 +124,6 @@ def __init__(
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn(
"api_mtls_endpoint and client_cert_source are deprecated",
DeprecationWarning,
)

host = (
api_mtls_endpoint
if ":" in api_mtls_endpoint
Expand Down Expand Up @@ -162,12 +167,18 @@ def __init__(
scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_channel_credentials,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
Expand Down
Expand Up @@ -103,6 +103,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand Down Expand Up @@ -134,6 +135,10 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -150,6 +155,11 @@ def __init__(
"""
self._ssl_channel_credentials = ssl_channel_credentials

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
Expand All @@ -159,11 +169,6 @@ def __init__(
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn(
"api_mtls_endpoint and client_cert_source are deprecated",
DeprecationWarning,
)

host = (
api_mtls_endpoint
if ":" in api_mtls_endpoint
Expand Down Expand Up @@ -207,12 +212,18 @@ def __init__(
scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_channel_credentials,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
Expand Down
Expand Up @@ -324,21 +324,17 @@ def __init__(
util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
)

ssl_credentials = None
client_cert_source_func = None
is_mtls = False
if use_client_cert:
if client_options.client_cert_source:
import grpc # type: ignore

cert, key = client_options.client_cert_source()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
is_mtls = True
client_cert_source_func = client_options.client_cert_source
else:
creds = SslCredentials()
is_mtls = creds.is_mtls
ssl_credentials = creds.ssl_credentials if is_mtls else None
is_mtls = mtls.has_default_client_cert_source()
client_cert_source_func = (
mtls.default_client_cert_source() if is_mtls else None
)

# Figure out which api endpoint to use.
if client_options.api_endpoint is not None:
Expand Down Expand Up @@ -381,7 +377,7 @@ def __init__(
credentials_file=client_options.credentials_file,
host=api_endpoint,
scopes=client_options.scopes,
ssl_channel_credentials=ssl_credentials,
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)
Expand Down
Expand Up @@ -61,6 +61,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand Down Expand Up @@ -91,6 +92,10 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -107,6 +112,11 @@ def __init__(
"""
self._ssl_channel_credentials = ssl_channel_credentials

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
Expand All @@ -116,11 +126,6 @@ def __init__(
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn(
"api_mtls_endpoint and client_cert_source are deprecated",
DeprecationWarning,
)

host = (
api_mtls_endpoint
if ":" in api_mtls_endpoint
Expand Down Expand Up @@ -164,12 +169,18 @@ def __init__(
scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_channel_credentials,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
Expand Down
Expand Up @@ -105,6 +105,7 @@ def __init__(
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand Down Expand Up @@ -136,6 +137,10 @@ def __init__(
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -152,6 +157,11 @@ def __init__(
"""
self._ssl_channel_credentials = ssl_channel_credentials

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
Expand All @@ -161,11 +171,6 @@ def __init__(
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn(
"api_mtls_endpoint and client_cert_source are deprecated",
DeprecationWarning,
)

host = (
api_mtls_endpoint
if ":" in api_mtls_endpoint
Expand Down Expand Up @@ -209,12 +214,18 @@ def __init__(
scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_channel_credentials,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
Expand Down
Expand Up @@ -314,21 +314,17 @@ def __init__(
util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
)

ssl_credentials = None
client_cert_source_func = None
is_mtls = False
if use_client_cert:
if client_options.client_cert_source:
import grpc # type: ignore

cert, key = client_options.client_cert_source()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
is_mtls = True
client_cert_source_func = client_options.client_cert_source
else:
creds = SslCredentials()
is_mtls = creds.is_mtls
ssl_credentials = creds.ssl_credentials if is_mtls else None
is_mtls = mtls.has_default_client_cert_source()
client_cert_source_func = (
mtls.default_client_cert_source() if is_mtls else None
)

# Figure out which api endpoint to use.
if client_options.api_endpoint is not None:
Expand Down Expand Up @@ -371,7 +367,7 @@ def __init__(
credentials_file=client_options.credentials_file,
host=api_endpoint,
scopes=client_options.scopes,
ssl_channel_credentials=ssl_credentials,
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)
Expand Down

0 comments on commit 072850d

Please sign in to comment.