Skip to content

Commit

Permalink
fix: enable self signed jwt for grpc (#458)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 386504689

Source-Link: googleapis/googleapis@762094a

Source-Link: googleapis/googleapis-gen@6bfc480
  • Loading branch information
gcf-owl-bot[bot] committed Jul 24, 2021
1 parent 911829d commit c6e0ff6
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 44 deletions.
1 change: 0 additions & 1 deletion google/pubsub_v1/services/publisher/async_client.py
Expand Up @@ -29,7 +29,6 @@

from google.iam.v1 import iam_policy_pb2 # type: ignore
from google.iam.v1 import policy_pb2 # type: ignore
from google.protobuf import duration_pb2 # type: ignore
from google.pubsub_v1.services.publisher import pagers
from google.pubsub_v1.types import pubsub
from google.pubsub_v1.types import TimeoutType
Expand Down
5 changes: 4 additions & 1 deletion google/pubsub_v1/services/publisher/client.py
Expand Up @@ -34,7 +34,6 @@

from google.iam.v1 import iam_policy_pb2 # type: ignore
from google.iam.v1 import policy_pb2 # type: ignore
from google.protobuf import duration_pb2 # type: ignore
from google.pubsub_v1.services.publisher import pagers
from google.pubsub_v1.types import pubsub
from google.pubsub_v1.types import TimeoutType
Expand Down Expand Up @@ -399,6 +398,10 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
)

def create_topic(
Expand Down
4 changes: 4 additions & 0 deletions google/pubsub_v1/services/schema_service/client.py
Expand Up @@ -342,6 +342,10 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
)

def create_schema(
Expand Down
4 changes: 4 additions & 0 deletions google/pubsub_v1/services/subscriber/client.py
Expand Up @@ -411,6 +411,10 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
always_use_jwt_access=(
Transport == type(self).get_transport_class("grpc")
or Transport == type(self).get_transport_class("grpc_asyncio")
),
)

def create_subscription(
Expand Down
105 changes: 101 additions & 4 deletions owlbot.py
Expand Up @@ -299,14 +299,111 @@
),
)

# Add development feature `message_retention_duration` from pubsub_dev branch of googleapis
# See PR https://github.com/googleapis/python-pubsub/pull/456
count = s.replace(
library / f"google/pubsub_{library.name}/types/pubsub.py",
"""satisfies_pzs \(bool\):
Reserved for future use. This field is set
only in responses from the server; it is ignored
if it is set in any requests.""",
"""satisfies_pzs (bool):
Reserved for future use. This field is set
only in responses from the server; it is ignored
if it is set in any requests.
message_retention_duration (google.protobuf.duration_pb2.Duration):
Indicates the minimum duration to retain a message after it
is published to the topic. If this field is set, messages
published to the topic in the last
``message_retention_duration`` are always available to
subscribers. For instance, it allows any attached
subscription to `seek to a
timestamp <https://cloud.google.com/pubsub/docs/replay-overview#seek_to_a_time>`__
that is up to ``message_retention_duration`` in the past. If
this field is not set, message retention is controlled by
settings on individual subscriptions. Cannot be more than 7
days or less than 10 minutes."""
)

# Add development feature `message_retention_duration` from pubsub_dev branch of googleapis
# See PR https://github.com/googleapis/python-pubsub/pull/456
count += s.replace(
library / f"google/pubsub_{library.name}/types/pubsub.py",
"""satisfies_pzs = proto.Field\(
proto.BOOL,
number=7,
\)""",
"""satisfies_pzs = proto.Field(
proto.BOOL,
number=7,
)
message_retention_duration = proto.Field(
proto.MESSAGE, number=8, message=duration_pb2.Duration,
)"""
)

# Add development feature `topic_message_retention_duration` from pubsub_dev branch of googleapis
# See PR https://github.com/googleapis/python-pubsub/pull/456
count += s.replace(
library / f"google/pubsub_{library.name}/types/pubsub.py",
"""detached \(bool\):
Indicates whether the subscription is detached from its
topic. Detached subscriptions don't receive messages from
their topic and don't retain any backlog. ``Pull`` and
``StreamingPull`` requests will return FAILED_PRECONDITION.
If the subscription is a push subscription, pushes to the
endpoint will not be made.""",
"""detached (bool):
Indicates whether the subscription is detached from its
topic. Detached subscriptions don't receive messages from
their topic and don't retain any backlog. ``Pull`` and
``StreamingPull`` requests will return FAILED_PRECONDITION.
If the subscription is a push subscription, pushes to the
endpoint will not be made.
topic_message_retention_duration (google.protobuf.duration_pb2.Duration):
Output only. Indicates the minimum duration for which a
message is retained after it is published to the
subscription's topic. If this field is set, messages
published to the subscription's topic in the last
``topic_message_retention_duration`` are always available to
subscribers. See the ``message_retention_duration`` field in
``Topic``. This field is set only in responses from the
server; it is ignored if it is set in any requests."""
)

# Add development feature `topic_message_retention_duration` from pubsub_dev branch of googleapis
# See PR https://github.com/googleapis/python-pubsub/pull/456
count += s.replace(
library / f"google/pubsub_{library.name}/types/pubsub.py",
"""detached = proto.Field\(
proto.BOOL,
number=15,
\)""",
"""detached = proto.Field(
proto.BOOL,
number=15,
)
topic_message_retention_duration = proto.Field(
proto.MESSAGE, number=17, message=duration_pb2.Duration,
)
"""
)

if count != 4:
raise Exception("Pub/Sub topic retention feature not added")

# The namespace package declaration in google/cloud/__init__.py should be excluded
# from coverage.
s.replace(
".coveragerc",
r"((?P<indent>[^\n\S]+)google/pubsub/__init__\.py)",
"\g<indent>google/cloud/__init__.py\n\g<0>",
count = s.replace(
library / ".coveragerc",
"google/pubsub/__init__.py",
"""google/cloud/__init__.py
google/pubsub/__init__.py""",
)

if count < 1:
raise Exception(".coveragerc replacement failed.")

s.move(
library,
excludes=[
Expand Down
4 changes: 2 additions & 2 deletions scripts/fixup_pubsub_v1_keywords.py
Expand Up @@ -42,8 +42,8 @@ class pubsubCallTransformer(cst.CSTTransformer):
'acknowledge': ('subscription', 'ack_ids', ),
'create_schema': ('parent', 'schema', 'schema_id', ),
'create_snapshot': ('name', 'subscription', 'labels', ),
'create_subscription': ('name', 'topic', 'push_config', 'ack_deadline_seconds', 'retain_acked_messages', 'message_retention_duration', 'labels', 'enable_message_ordering', 'expiration_policy', 'filter', 'dead_letter_policy', 'retry_policy', 'detached', 'topic_message_retention_duration', ),
'create_topic': ('name', 'labels', 'message_storage_policy', 'kms_key_name', 'schema_settings', 'satisfies_pzs', 'message_retention_duration', ),
'create_subscription': ('name', 'topic', 'push_config', 'ack_deadline_seconds', 'retain_acked_messages', 'message_retention_duration', 'labels', 'enable_message_ordering', 'expiration_policy', 'filter', 'dead_letter_policy', 'retry_policy', 'detached', ),
'create_topic': ('name', 'labels', 'message_storage_policy', 'kms_key_name', 'schema_settings', 'satisfies_pzs', ),
'delete_schema': ('name', ),
'delete_snapshot': ('snapshot', ),
'delete_subscription': ('subscription', ),
Expand Down
30 changes: 18 additions & 12 deletions tests/unit/gapic/pubsub_v1/test_publisher.py
Expand Up @@ -35,7 +35,6 @@
from google.iam.v1 import options_pb2 # type: ignore
from google.iam.v1 import policy_pb2 # type: ignore
from google.oauth2 import service_account
from google.protobuf import duration_pb2 # type: ignore
from google.protobuf import field_mask_pb2 # type: ignore
from google.protobuf import timestamp_pb2 # type: ignore
from google.pubsub_v1.services.publisher import PublisherAsyncClient
Expand Down Expand Up @@ -116,24 +115,14 @@ def test_publisher_client_from_service_account_info(client_class):
assert client.transport._host == "pubsub.googleapis.com:443"


@pytest.mark.parametrize("client_class", [PublisherClient, PublisherAsyncClient,])
def test_publisher_client_service_account_always_use_jwt(client_class):
with mock.patch.object(
service_account.Credentials, "with_always_use_jwt_access", create=True
) as use_jwt:
creds = service_account.Credentials(None, None, None)
client = client_class(credentials=creds)
use_jwt.assert_not_called()


@pytest.mark.parametrize(
"transport_class,transport_name",
[
(transports.PublisherGrpcTransport, "grpc"),
(transports.PublisherGrpcAsyncIOTransport, "grpc_asyncio"),
],
)
def test_publisher_client_service_account_always_use_jwt_true(
def test_publisher_client_service_account_always_use_jwt(
transport_class, transport_name
):
with mock.patch.object(
Expand All @@ -143,6 +132,13 @@ def test_publisher_client_service_account_always_use_jwt_true(
transport = transport_class(credentials=creds, always_use_jwt_access=True)
use_jwt.assert_called_once_with(True)

with mock.patch.object(
service_account.Credentials, "with_always_use_jwt_access", create=True
) as use_jwt:
creds = service_account.Credentials(None, None, None)
transport = transport_class(credentials=creds, always_use_jwt_access=False)
use_jwt.assert_not_called()


@pytest.mark.parametrize("client_class", [PublisherClient, PublisherAsyncClient,])
def test_publisher_client_from_service_account_file(client_class):
Expand Down Expand Up @@ -217,6 +213,7 @@ def test_publisher_client_client_options(client_class, transport_class, transpor
client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
always_use_jwt_access=True,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
Expand All @@ -233,6 +230,7 @@ def test_publisher_client_client_options(client_class, transport_class, transpor
client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
always_use_jwt_access=True,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
Expand All @@ -249,6 +247,7 @@ def test_publisher_client_client_options(client_class, transport_class, transpor
client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
always_use_jwt_access=True,
)

# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
Expand Down Expand Up @@ -277,6 +276,7 @@ def test_publisher_client_client_options(client_class, transport_class, transpor
client_cert_source_for_mtls=None,
quota_project_id="octopus",
client_info=transports.base.DEFAULT_CLIENT_INFO,
always_use_jwt_access=True,
)


Expand Down Expand Up @@ -341,6 +341,7 @@ def test_publisher_client_mtls_env_auto(
client_cert_source_for_mtls=expected_client_cert_source,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
always_use_jwt_access=True,
)

# Check the case ADC client cert is provided. Whether client cert is used depends on
Expand Down Expand Up @@ -374,6 +375,7 @@ def test_publisher_client_mtls_env_auto(
client_cert_source_for_mtls=expected_client_cert_source,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
always_use_jwt_access=True,
)

# Check the case client_cert_source and ADC client cert are not provided.
Expand All @@ -395,6 +397,7 @@ def test_publisher_client_mtls_env_auto(
client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
always_use_jwt_access=True,
)


Expand Down Expand Up @@ -425,6 +428,7 @@ def test_publisher_client_client_options_scopes(
client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
always_use_jwt_access=True,
)


Expand Down Expand Up @@ -455,6 +459,7 @@ def test_publisher_client_client_options_credentials_file(
client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
always_use_jwt_access=True,
)


Expand All @@ -472,6 +477,7 @@ def test_publisher_client_client_options_from_dict():
client_cert_source_for_mtls=None,
quota_project_id=None,
client_info=transports.base.DEFAULT_CLIENT_INFO,
always_use_jwt_access=True,
)


Expand Down

0 comments on commit c6e0ff6

Please sign in to comment.