Skip to content

Commit

Permalink
fix: Modify synth.py to update grpc transport options.
Browse files Browse the repository at this point in the history
  • Loading branch information
dpcollins-google committed Dec 23, 2020
1 parent 3d6a29d commit f316bb5
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 93 deletions.
46 changes: 4 additions & 42 deletions google/cloud/pubsub_v1/publisher/client.py
Expand Up @@ -25,7 +25,6 @@
import six

from google.api_core import gapic_v1
from google.api_core import grpc_helpers
from google.auth.credentials import AnonymousCredentials
from google.oauth2 import service_account

Expand All @@ -39,9 +38,6 @@
from google.cloud.pubsub_v1.publisher.flow_controller import FlowController
from google.pubsub_v1 import types as gapic_types
from google.pubsub_v1.services.publisher import client as publisher_client
from google.pubsub_v1.services.publisher.transports import (
grpc as publisher_grpc_transport,
)

__version__ = pkg_resources.get_distribution("google-cloud-pubsub").version

Expand Down Expand Up @@ -127,53 +123,19 @@ def __init__(self, batch_settings=(), publisher_options=(), **kwargs):
# If so, create a grpc insecure channel with the emulator host
# as the target.
if os.environ.get("PUBSUB_EMULATOR_HOST"):
kwargs["channel"] = grpc.insecure_channel(
target=os.environ.get("PUBSUB_EMULATOR_HOST")
)
kwargs["client_options"] = {
"api_endpoint": os.environ.get("PUBSUB_EMULATOR_HOST")
}
kwargs["credentials"] = AnonymousCredentials()

# The GAPIC client has mTLS logic to determine the api endpoint and the
# ssl credentials to use. Here we create a GAPIC client to help compute the
# api endpoint and ssl credentials. The api endpoint will be used to set
# `self._target`, and ssl credentials will be passed to
# `grpc_helpers.create_channel` to establish a mTLS channel (if ssl
# credentials is not None).
client_options = kwargs.get("client_options", None)
credentials = kwargs.get("credentials", None)
client_for_mtls_info = publisher_client.PublisherClient(
credentials=credentials, client_options=client_options
)

self._target = client_for_mtls_info._transport._host

# Use a custom channel.
# We need this in order to set appropriate default message size and
# keepalive options.
if "transport" not in kwargs:
channel = kwargs.pop("channel", None)
if channel is None:
channel = grpc_helpers.create_channel(
credentials=kwargs.pop("credentials", None),
target=self.target,
ssl_credentials=client_for_mtls_info._transport._ssl_channel_credentials,
scopes=publisher_client.PublisherClient._DEFAULT_SCOPES,
options={
"grpc.max_send_message_length": -1,
"grpc.max_receive_message_length": -1,
}.items(),
)
# cannot pass both 'channel' and 'credentials'
kwargs.pop("credentials", None)
transport = publisher_grpc_transport.PublisherGrpcTransport(channel=channel)
kwargs["transport"] = transport

# For a transient failure, retry publishing the message infinitely.
self.publisher_options = types.PublisherOptions(*publisher_options)
self._enable_message_ordering = self.publisher_options[0]

# Add the metrics headers, and instantiate the underlying GAPIC
# client.
self.api = publisher_client.PublisherClient(**kwargs)
self._target = self.api._transport._host
self._batch_class = thread.Batch
self.batch_settings = types.BatchSettings(*batch_settings)

Expand Down
52 changes: 5 additions & 47 deletions google/cloud/pubsub_v1/subscriber/client.py
Expand Up @@ -19,7 +19,6 @@

import grpc

from google.api_core import grpc_helpers
from google.auth.credentials import AnonymousCredentials
from google.oauth2 import service_account

Expand All @@ -28,9 +27,6 @@
from google.cloud.pubsub_v1.subscriber import futures
from google.cloud.pubsub_v1.subscriber._protocol import streaming_pull_manager
from google.pubsub_v1.services.subscriber import client as subscriber_client
from google.pubsub_v1.services.subscriber.transports import (
grpc as subscriber_grpc_transport,
)


__version__ = pkg_resources.get_distribution("google-cloud-pubsub").version
Expand Down Expand Up @@ -78,52 +74,14 @@ def __init__(self, **kwargs):
# If so, create a grpc insecure channel with the emulator host
# as the target.
if os.environ.get("PUBSUB_EMULATOR_HOST"):
kwargs["channel"] = grpc.insecure_channel(
target=os.environ.get("PUBSUB_EMULATOR_HOST")
)
kwargs["client_options"] = {
"api_endpoint": os.environ.get("PUBSUB_EMULATOR_HOST")
}
kwargs["credentials"] = AnonymousCredentials()

# The GAPIC client has mTLS logic to determine the api endpoint and the
# ssl credentials to use. Here we create a GAPIC client to help compute the
# api endpoint and ssl credentials. The api endpoint will be used to set
# `self._target`, and ssl credentials will be passed to
# `grpc_helpers.create_channel` to establish a mTLS channel (if ssl
# credentials is not None).
client_options = kwargs.get("client_options", None)
credentials = kwargs.get("credentials", None)
client_for_mtls_info = subscriber_client.SubscriberClient(
credentials=credentials, client_options=client_options
)

self._target = client_for_mtls_info._transport._host

# Use a custom channel.
# We need this in order to set appropriate default message size and
# keepalive options.
if "transport" not in kwargs:
channel = kwargs.pop("channel", None)
if channel is None:
channel = grpc_helpers.create_channel(
credentials=kwargs.pop("credentials", None),
target=self.target,
ssl_credentials=client_for_mtls_info._transport._ssl_channel_credentials,
scopes=subscriber_client.SubscriberClient._DEFAULT_SCOPES,
options={
"grpc.max_send_message_length": -1,
"grpc.max_receive_message_length": -1,
"grpc.keepalive_time_ms": 30000,
}.items(),
)
# cannot pass both 'channel' and 'credentials'
kwargs.pop("credentials", None)
transport = subscriber_grpc_transport.SubscriberGrpcTransport(
channel=channel
)
kwargs["transport"] = transport

# Add the metrics headers, and instantiate the underlying GAPIC
# client.
# Instantiate the underlying GAPIC client.
self._api = subscriber_client.SubscriberClient(**kwargs)
self._target = self._api._transport._host

@classmethod
def from_service_account_file(cls, filename, **kwargs):
Expand Down
12 changes: 12 additions & 0 deletions synth.py
Expand Up @@ -63,6 +63,18 @@
\g<0>""",
)

# Modify GRPC options in transports.
s.replace(
["google/pubsub_v1/services/*/transports/grpc*",
"tests/unit/gapic/pubsub_v1/*"],
"options=[.*]",
"""options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
("grpc.keepalive_time_ms": 30000),
]"""
)

# Monkey patch the streaming_pull() GAPIC method to disable pre-fetching stream
# results.
s.replace(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/pubsub_v1/publisher/test_publisher_client.py
Expand Up @@ -130,7 +130,7 @@ def init(self, *args, **kwargs):


def test_init_emulator(monkeypatch):
monkeypatch.setenv("PUBSUB_EMULATOR_HOST", "/foo/bar/")
monkeypatch.setenv("PUBSUB_EMULATOR_HOST", "/foo/bar:123")
# NOTE: When the emulator host is set, a custom channel will be used, so
# no credentials (mock ot otherwise) can be passed in.
client = publisher.Client()
Expand All @@ -140,7 +140,7 @@ def test_init_emulator(monkeypatch):
# Sadly, there seems to be no good way to do this without poking at
# the private API of gRPC.
channel = client.api._transport.publish._channel
assert channel.target().decode("utf8") == "/foo/bar/"
assert channel.target().decode("utf8") == "/foo/bar:123"


def test_message_ordering_enabled():
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/pubsub_v1/subscriber/test_subscriber_client.py
Expand Up @@ -91,7 +91,7 @@ def init(self, *args, **kwargs):


def test_init_emulator(monkeypatch):
monkeypatch.setenv("PUBSUB_EMULATOR_HOST", "/baz/bacon/")
monkeypatch.setenv("PUBSUB_EMULATOR_HOST", "/baz/bacon:123")
# NOTE: When the emulator host is set, a custom channel will be used, so
# no credentials (mock ot otherwise) can be passed in.
client = subscriber.Client()
Expand All @@ -101,7 +101,7 @@ def test_init_emulator(monkeypatch):
# Sadly, there seems to be no good way to do this without poking at
# the private API of gRPC.
channel = client.api._transport.pull._channel
assert channel.target().decode("utf8") == "/baz/bacon/"
assert channel.target().decode("utf8") == "/baz/bacon:123"


def test_class_method_factory():
Expand Down

0 comments on commit f316bb5

Please sign in to comment.