Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ensure rest unit tests have complete coverage #1098

Merged
merged 6 commits into from Dec 2, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -166,6 +166,31 @@ class {{service.name}}RestTransport({{service.name}}Transport):
{% endif %}{# service.has_lro #}
{% for method in service.methods.values() %}
{%- if method.http_options and not (method.server_streaming or method.client_streaming) %}

{% if method.input.required_fields %}
@staticmethod
def _{{ method.name | snake_case }}_populate_required_fields(message_dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer it if we moved mutation to the caller, something like

__{{ method.name | snake_case }}_required_fields_default_values =  {
    {% for req_field in method.input.required_fields if req_field.is_primitive %}
        "{{ req_field.name | camel_case }}" : {% if req_field.field_pb.default_value is string %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value }}{% endif %}{# default is str #}
    {% endfor %}
}


def _{{ method.name | snake_case }}_get_unset_required_fields(message_dict):
    return {k, v for k, v in __{{ method.name | snake_case }}_required_fields_default_values.items() if k not in message_dict}

    ....
    query_params.update(self._{{ method.name | snake_case }}_get_unset_required_fields(query_params)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can make this change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

required_fields = [
{% for req_field in method.input.required_fields %}
{% if req_field.is_primitive %}
(
"{{ req_field.name | camel_case }}",
{% if req_field.field_pb.default_value is string %}
"{{ req_field.field_pb.default_value }}"
{% else %}
{{ req_field.field_pb.default_value }}
{% endif %}{# default is str #}
),
{% endif %}{# is primitive #}
{% endfor %}{# required fields #}
]

for field_name, default_value in required_fields:
if field_name not in message_dict:
message_dict[field_name] = default_value

{% endif %}{# required fields #}

def _{{method.name | snake_case}}(self,
request: {{method.input.ident}}, *,
retry: OptionalRetry=gapic_v1.method.DEFAULT,
Expand Down Expand Up @@ -254,16 +279,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
))

{% if method.input.required_fields %}
# Ensure required fields have values in query_params.
# If a required field has a default value, it can get lost
# by the to_json call above.
orig_query_params = transcoded_request["query_params"]
for snake_case_name, camel_case_name in required_fields:
if snake_case_name in orig_query_params:
if camel_case_name not in query_params:
query_params[camel_case_name] = orig_query_params[snake_case_name]

{% endif %}
self._{{ method.name | snake_case }}_populate_required_fields(query_params)
{% endif %}{# required fields #}

# Send the request
headers = dict(metadata)
Expand Down
Expand Up @@ -7,6 +7,7 @@ import mock

import grpc
from grpc.experimental import aio
import json
import math
import pytest
from proto.marshal.rules.dates import DurationRule, TimestampRule
Expand Down Expand Up @@ -1187,6 +1188,7 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
{% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %}
{# Cheeser assertion to force code coverage for bad paginated methods #}
assert response.raw_page is response

{% endif %}

# Establish that the response is the type that we expect.
Expand All @@ -1210,6 +1212,61 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
{% endif %}


{% if method.input.required_fields %}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be reasonable to have a test or two just for the hidden required fields update method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you mean. Which method is hidden?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, once we refactor as you describe above.
This is quite tricky to test, which is a lot of motivation for these changes. The problem is that typically, though not always, a required field is going to have an expected template in the http rule, so the default value will cause the transcoding to fail. This can be worked around by mocking the transcoding function, but it gets convoluted and ugly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what to do here. No code is un-covered. I would argue that testing this logic in the context of the api method itself doesn't necessarily add anything.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I've added logic to test the actual api method with default-valued required fields.

def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ident }}):
transport_class = transports.{{ service.rest_transport_name }}

request_init = {}
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% if req_field.field_pb.default_value is string %}
request_init["{{ req_field.name }}"] = "{{ req_field.field_pb.default_value }}"
{% else %}
request_init["{{ req_field.name }}"] = {{ req_field.field_pb.default_value }}
{% endif %}{# default is str #}
{% endfor %}
request = request_type(request_init)
jsonified_request = json.loads(request_type.to_json(
request,
including_default_value_fields=False,
use_integers_for_enums=False
))

# verify fields with default values are dropped
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" not in jsonified_request
{% endfor %}

transport_class._{{ method.name | snake_case }}_populate_required_fields(jsonified_request)

# verify required fields with default values are now present
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
assert "{{ field_name }}" in jsonified_request
assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"]
{% endfor %}

{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
{% set mock_value = req_field.primitive_mock_as_str() %}
jsonified_request["{{ field_name }}"] = {{ mock_value }}
{% endfor %}

transport_class._{{ method.name | snake_case }}_populate_required_fields(jsonified_request)

# verify required fields with non-default values are left alone
{% for req_field in method.input.required_fields if req_field.is_primitive %}
{% set field_name = req_field.name | camel_case %}
{% set mock_value = req_field.primitive_mock_as_str() %}
assert "{{ field_name }}" in jsonified_request
assert jsonified_request["{{ field_name }}"] == {{ mock_value }}
{% endfor %}



{% endif %}{# required_fields #}


def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}):
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
Expand Down Expand Up @@ -1325,9 +1382,10 @@ def test_{{ method_name }}_rest_flattened_error(transport: str = 'rest'):


{% if method.paged_result_field %}
def test_{{ method_name }}_rest_pager():
def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
)

# Mock the http request call within the method and fake a response.
Expand Down Expand Up @@ -1446,25 +1504,35 @@ def test_{{ method_name }}_rest_error():
credentials=ga_credentials.AnonymousCredentials(),
transport='rest'
)
{%- if not method.http_options %}
# Since a `google.api.http` annotation is required for using a rest transport
# method, this should error.
with pytest.raises(RuntimeError) as runtime_error:
client.{{ method_name }}({})
assert ('Cannot define a method without a valid `google.api.http` annotation.'
in str(runtime_error.value))
{%- else %}

# TODO(yon-mg): Remove when this method has a working implementation
# or testing straegy
with pytest.raises(NotImplementedError):
client.{{ method_name }}({})

{%- endif %}

{% endif %}{% endwith %}{# method_name #}
{% endif %}{# not streaming #}{% endwith %}{# method_name #}

{% endfor -%} {#- method in methods for rest #}

{% for method in service.methods.values() if 'rest' in opts.transport and
not method.http_options %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}
def test_{{ method_name }}_rest_error():
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
transport='rest'
)
# Since a `google.api.http` annotation is required for using a rest transport
# method, this should error.
with pytest.raises(RuntimeError) as runtime_error:
client.{{ method_name }}({})
assert ("Cannot define a method without a valid 'google.api.http' annotation."
in str(runtime_error.value))


{% endwith %}{# method_name #}
{% endfor %}{# for methods without http_options #}

def test_credentials_transport_error():
# It is an error to provide credentials and a transport instance.
transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport(
Expand Down Expand Up @@ -1758,8 +1826,7 @@ def test_{{ service.name|snake_case }}_http_transport_client_cert_source_for_mtl
mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)


{# TODO(kbandes): re-enable this code when LRO is implmented for rest #}
{% if False and service.has_lro -%}
{% if service.has_lro -%}
def test_{{ service.name|snake_case }}_rest_lro_client():
client = {{ service.client_name }}(
credentials=ga_credentials.AnonymousCredentials(),
Expand All @@ -1770,7 +1837,7 @@ def test_{{ service.name|snake_case }}_rest_lro_client():
# Ensure that we have a api-core operations client.
assert isinstance(
transport.operations_client,
operations_v1.OperationsClient,
operations_v1.AbstractOperationsClient,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this changed?

Copy link
Contributor Author

@kbandes kbandes Dec 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was at the request of software-dov. The change was made to api-core in a prior PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding related PRs: I'm not sure. I believe other people are making these changes and releasing them, I'm not sure of the timeline or PR numbers. The change to noxfile.py is small, but I won't know exactly how to do it until I know what the actual release numbers for showcase and api-core are.

)

# Ensure that subsequent calls to the property send the exact same object.
Expand Down