diff --git a/google/cloud/pubsub_v1/publisher/client.py b/google/cloud/pubsub_v1/publisher/client.py index 9284420f5..caa784407 100644 --- a/google/cloud/pubsub_v1/publisher/client.py +++ b/google/cloud/pubsub_v1/publisher/client.py @@ -133,6 +133,16 @@ def __init__(self, batch_settings=(), publisher_options=(), **kwargs): target=os.environ.get("PUBSUB_EMULATOR_HOST") ) + client_options = kwargs.pop("client_options", None) + if ( + client_options + and "api_endpoint" in client_options + and isinstance(client_options["api_endpoint"], six.string_types) + ): + self._target = client_options["api_endpoint"] + else: + self._target = publisher_client.PublisherClient.SERVICE_ADDRESS + # Use a custom channel. # We need this in order to set appropriate default message size and # keepalive options. @@ -217,7 +227,7 @@ def target(self): Returns: str: The location of the API. """ - return publisher_client.PublisherClient.SERVICE_ADDRESS + return self._target def _get_or_create_sequencer(self, topic, ordering_key): """ Get an existing sequencer or create a new one given the (topic, diff --git a/google/cloud/pubsub_v1/subscriber/client.py b/google/cloud/pubsub_v1/subscriber/client.py index 00d97231e..718e69083 100644 --- a/google/cloud/pubsub_v1/subscriber/client.py +++ b/google/cloud/pubsub_v1/subscriber/client.py @@ -16,6 +16,7 @@ import os import pkg_resources +import six import grpc @@ -79,6 +80,17 @@ def __init__(self, **kwargs): target=os.environ.get("PUBSUB_EMULATOR_HOST") ) + # api_endpoint wont be applied if 'transport' is passed in. + client_options = kwargs.pop("client_options", None) + if ( + client_options + and "api_endpoint" in client_options + and isinstance(client_options["api_endpoint"], six.string_types) + ): + self._target = client_options["api_endpoint"] + else: + self._target = subscriber_client.SubscriberClient.SERVICE_ADDRESS + # Use a custom channel. # We need this in order to set appropriate default message size and # keepalive options. @@ -133,7 +145,7 @@ def target(self): Returns: str: The location of the API. """ - return subscriber_client.SubscriberClient.SERVICE_ADDRESS + return self._target @property def api(self): diff --git a/tests/unit/pubsub_v1/publisher/test_publisher_client.py b/tests/unit/pubsub_v1/publisher/test_publisher_client.py index 69c854b47..4ca979892 100644 --- a/tests/unit/pubsub_v1/publisher/test_publisher_client.py +++ b/tests/unit/pubsub_v1/publisher/test_publisher_client.py @@ -53,6 +53,35 @@ def test_init_w_custom_transport(): assert client.batch_settings.max_messages == 100 +def test_init_w_api_endpoint(): + client_options = {"api_endpoint": "testendpoint.google.com"} + client = publisher.Client(client_options=client_options) + + assert isinstance(client.api, publisher_client.PublisherClient) + assert (client.api.transport._channel._channel.target()).decode( + "utf-8" + ) == "testendpoint.google.com" + + +def test_init_w_unicode_api_endpoint(): + client_options = {"api_endpoint": u"testendpoint.google.com"} + client = publisher.Client(client_options=client_options) + + assert isinstance(client.api, publisher_client.PublisherClient) + assert (client.api.transport._channel._channel.target()).decode( + "utf-8" + ) == "testendpoint.google.com" + + +def test_init_w_empty_client_options(): + client = publisher.Client(client_options={}) + + assert isinstance(client.api, publisher_client.PublisherClient) + assert (client.api.transport._channel._channel.target()).decode( + "utf-8" + ) == publisher_client.PublisherClient.SERVICE_ADDRESS + + def test_init_emulator(monkeypatch): monkeypatch.setenv("PUBSUB_EMULATOR_HOST", "/foo/bar/") # NOTE: When the emulator host is set, a custom channel will be used, so diff --git a/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py b/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py index 19ec194ce..d8f671157 100644 --- a/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py +++ b/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py @@ -34,6 +34,35 @@ def test_init_w_custom_transport(): assert client.api.transport is transport +def test_init_w_api_endpoint(): + client_options = {"api_endpoint": "testendpoint.google.com"} + client = subscriber.Client(client_options=client_options) + + assert isinstance(client.api, subscriber_client.SubscriberClient) + assert (client.api.transport._channel._channel.target()).decode( + "utf-8" + ) == "testendpoint.google.com" + + +def test_init_w_unicode_api_endpoint(): + client_options = {"api_endpoint": u"testendpoint.google.com"} + client = subscriber.Client(client_options=client_options) + + assert isinstance(client.api, subscriber_client.SubscriberClient) + assert (client.api.transport._channel._channel.target()).decode( + "utf-8" + ) == "testendpoint.google.com" + + +def test_init_w_empty_client_options(): + client = subscriber.Client(client_options={}) + + assert isinstance(client.api, subscriber_client.SubscriberClient) + assert (client.api.transport._channel._channel.target()).decode( + "utf-8" + ) == subscriber_client.SubscriberClient.SERVICE_ADDRESS + + def test_init_emulator(monkeypatch): monkeypatch.setenv("PUBSUB_EMULATOR_HOST", "/baz/bacon/") # NOTE: When the emulator host is set, a custom channel will be used, so