Skip to content

Commit

Permalink
fix: add mtls support
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 committed Oct 21, 2020
1 parent ec8f5f2 commit 0d269d3
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 21 deletions.
22 changes: 14 additions & 8 deletions google/cloud/pubsub_v1/publisher/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,20 @@ def __init__(self, batch_settings=(), publisher_options=(), **kwargs):
target=os.environ.get("PUBSUB_EMULATOR_HOST")
)

# The auto-generated layer's client has mTLS logic to determine the api
# endpoint and the ssl credentials to use. Here we create a client
# and extract the api endpoint and ssl credentials. The api endpoint
# will be used to set `self._target`, and ssl credentials will be
# passed to `grpc_helpers.create_channel` to establish a mTLS channel
# (if ssl credentials is not None).
client_options = kwargs.get("client_options", None)
if (
client_options
and "api_endpoint" in client_options
and isinstance(client_options["api_endpoint"], six.string_types)
):
self._target = client_options["api_endpoint"]
else:
self._target = publisher_client.PublisherClient.SERVICE_ADDRESS
credentials=kwargs.get("credentials", None)
publisher_client_instance = publisher_client.PublisherClient(
credentials=credentials,
client_options=client_options
)

self._target = publisher_client_instance._transport._host

# Use a custom channel.
# We need this in order to set appropriate default message size and
Expand All @@ -149,6 +154,7 @@ def __init__(self, batch_settings=(), publisher_options=(), **kwargs):
channel = grpc_helpers.create_channel(
credentials=kwargs.pop("credentials", None),
target=self.target,
ssl_credentials=publisher_client_instance._transport._ssl_channel_credentials,
scopes=publisher_client.PublisherClient._DEFAULT_SCOPES,
options={
"grpc.max_send_message_length": -1,
Expand Down
23 changes: 14 additions & 9 deletions google/cloud/pubsub_v1/subscriber/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,20 @@ def __init__(self, **kwargs):
target=os.environ.get("PUBSUB_EMULATOR_HOST")
)

# api_endpoint wont be applied if 'transport' is passed in.
# The auto-generated layer's client has mTLS logic to determine the api
# endpoint and the ssl credentials to use. Here we create a client
# and extract the api endpoint and ssl credentials. The api endpoint
# will be used to set `self._target`, and ssl credentials will be
# passed to `grpc_helpers.create_channel` to establish a mTLS channel
# (if ssl credentials is not None).
client_options = kwargs.get("client_options", None)
if (
client_options
and "api_endpoint" in client_options
and isinstance(client_options["api_endpoint"], six.string_types)
):
self._target = client_options["api_endpoint"]
else:
self._target = subscriber_client.SubscriberClient.SERVICE_ADDRESS
credentials=kwargs.get("credentials", None)
subscriber_client_instance = subscriber_client.SubscriberClient(
credentials=credentials,
client_options=client_options
)

self._target = subscriber_client_instance._transport._host

# Use a custom channel.
# We need this in order to set appropriate default message size and
Expand All @@ -102,6 +106,7 @@ def __init__(self, **kwargs):
channel = grpc_helpers.create_channel(
credentials=kwargs.pop("credentials", None),
target=self.target,
ssl_credentials=subscriber_client_instance._transport._ssl_channel_credentials,
scopes=subscriber_client.SubscriberClient._DEFAULT_SCOPES,
options={
"grpc.max_send_message_length": -1,
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/pubsub_v1/publisher/test_publisher_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_init_w_api_endpoint():
assert isinstance(client.api, publisher_client.PublisherClient)
assert (client.api._transport.grpc_channel._channel.target()).decode(
"utf-8"
) == "testendpoint.google.com"
) == "testendpoint.google.com:443"


def test_init_w_unicode_api_endpoint():
Expand All @@ -91,7 +91,7 @@ def test_init_w_unicode_api_endpoint():
assert isinstance(client.api, publisher_client.PublisherClient)
assert (client.api._transport.grpc_channel._channel.target()).decode(
"utf-8"
) == "testendpoint.google.com"
) == "testendpoint.google.com:443"


def test_init_w_empty_client_options():
Expand All @@ -106,6 +106,9 @@ def test_init_w_empty_client_options():
def test_init_client_options_pass_through():
def init(self, *args, **kwargs):
self.kwargs = kwargs
self._transport = mock.Mock()
self._transport._host = "testendpoint.google.com"
self._transport._ssl_channel_credentials = None

with mock.patch.object(publisher_client.PublisherClient, "__init__", init):
client = publisher.Client(
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/pubsub_v1/subscriber/test_subscriber_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_init_w_api_endpoint():
assert isinstance(client.api, subscriber_client.SubscriberClient)
assert (client.api._transport.grpc_channel._channel.target()).decode(
"utf-8"
) == "testendpoint.google.com"
) == "testendpoint.google.com:443"


def test_init_w_unicode_api_endpoint():
Expand All @@ -52,7 +52,7 @@ def test_init_w_unicode_api_endpoint():
assert isinstance(client.api, subscriber_client.SubscriberClient)
assert (client.api._transport.grpc_channel._channel.target()).decode(
"utf-8"
) == "testendpoint.google.com"
) == "testendpoint.google.com:443"


def test_init_w_empty_client_options():
Expand All @@ -67,6 +67,9 @@ def test_init_w_empty_client_options():
def test_init_client_options_pass_through():
def init(self, *args, **kwargs):
self.kwargs = kwargs
self._transport = mock.Mock()
self._transport._host = "testendpoint.google.com"
self._transport._ssl_channel_credentials = None

with mock.patch.object(subscriber_client.SubscriberClient, "__init__", init):
client = subscriber.Client(
Expand Down

0 comments on commit 0d269d3

Please sign in to comment.