diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index ba3ab41185..64aa67213b 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -166,6 +166,22 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% endif %}{# service.has_lro #} {% for method in service.methods.values() %} {%- if method.http_options and not (method.server_streaming or method.client_streaming) %} + + {% if method.input.required_fields %} + __{{ method.name | snake_case }}_required_fields_default_values = { + {% for req_field in method.input.required_fields if req_field.is_primitive %} + "{{ req_field.name | camel_case }}" : {% if req_field.field_pb.default_value is string %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value }}{% endif %}{# default is str #} + {% endfor %} + } + + + @staticmethod + def _{{ method.name | snake_case }}_get_unset_required_fields(message_dict): + return {k: v for k, v in {{service.name}}RestTransport.__{{ method.name | snake_case }}_required_fields_default_values.items() if k not in message_dict} + + + {% endif %}{# required fields #} + def _{{method.name | snake_case}}(self, request: {{method.input.ident}}, *, retry: OptionalRetry=gapic_v1.method.DEFAULT, @@ -206,21 +222,6 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% endfor %} ] - {% if method.input.required_fields %} - required_fields = [ - # (snake_case_name, camel_case_name) - {% for req_field in method.input.required_fields %} - {% if req_field.is_primitive %} - ( - "{{ req_field.name | snake_case }}", - "{{ req_field.name | camel_case }}" - ), - {% endif %}{# is primitive #} - {% endfor %}{# required fields #} - ] - - {% endif %} - request_kwargs = {{method.input.ident}}.to_dict(request) transcoded_request = path_template.transcode( http_options, **request_kwargs) @@ -254,16 +255,8 @@ class {{service.name}}RestTransport({{service.name}}Transport): )) {% if method.input.required_fields %} - # Ensure required fields have values in query_params. - # If a required field has a default value, it can get lost - # by the to_json call above. - orig_query_params = transcoded_request["query_params"] - for snake_case_name, camel_case_name in required_fields: - if snake_case_name in orig_query_params: - if camel_case_name not in query_params: - query_params[camel_case_name] = orig_query_params[snake_case_name] - - {% endif %} + query_params.update(self._{{ method.name | snake_case }}_get_unset_required_fields(query_params)) + {% endif %}{# required fields #} # Send the request headers = dict(metadata) diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index a6d1bd2562..b31ff4df9b 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -7,6 +7,7 @@ import mock import grpc from grpc.experimental import aio +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule @@ -1187,6 +1188,7 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method. {% if "next_page_token" in method.output.fields.values()|map(attribute='name') and not method.paged_result_field %} {# Cheeser assertion to force code coverage for bad paginated methods #} assert response.raw_page is response + {% endif %} # Establish that the response is the type that we expect. @@ -1210,6 +1212,130 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method. {% endif %} + {% if method.input.required_fields %} +def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ident }}): + transport_class = transports.{{ service.rest_transport_name }} + + request_init = {} + {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% if req_field.field_pb.default_value is string %} + request_init["{{ req_field.name }}"] = "{{ req_field.field_pb.default_value }}" + {% else %} + request_init["{{ req_field.name }}"] = {{ req_field.field_pb.default_value }} + {% endif %}{# default is str #} + {% endfor %} + request = request_type(request_init) + jsonified_request = json.loads(request_type.to_json( + request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% set field_name = req_field.name | camel_case %} + assert "{{ field_name }}" not in jsonified_request + {% endfor %} + + unset_fields = transport_class._{{ method.name | snake_case }}_get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% set field_name = req_field.name | camel_case %} + assert "{{ field_name }}" in jsonified_request + assert jsonified_request["{{ field_name }}"] == request_init["{{ req_field.name }}"] + {% endfor %} + + {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% set field_name = req_field.name | camel_case %} + {% set mock_value = req_field.primitive_mock_as_str() %} + jsonified_request["{{ field_name }}"] = {{ mock_value }} + {% endfor %} + + unset_fields = transport_class._{{ method.name | snake_case }}_get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + {% for req_field in method.input.required_fields if req_field.is_primitive %} + {% set field_name = req_field.name | camel_case %} + {% set mock_value = req_field.primitive_mock_as_str() %} + assert "{{ field_name }}" in jsonified_request + assert jsonified_request["{{ field_name }}"] == {{ mock_value }} + {% endfor %} + + + client = {{ service.client_name }}( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(request_init) + + # Designate an appropriate value for the returned response. + {% if method.void %} + return_value = None + {% elif method.lro %} + return_value = operations_pb2.Operation(name='operations/spam') + {% elif method.server_streaming %} + return_value = iter([{{ method.output.ident }}()]) + {% else %} + return_value = {{ method.output.ident }}() + {% endif %} + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + transcode_result = { + 'uri': 'v1/sample_method', + 'method': "{{ method.http_options[0].method }}", + 'query_params': request_init, + } + {% if method.http_options[0].body %} + transcode_result['body'] = {} + {% endif %} + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + {% if method.void %} + json_return_value = '' + {% elif method.lro %} + json_return_value = json_format.MessageToJson(return_value) + {% else %} + json_return_value = {{ method.output.ident }}.to_json(return_value) + {% endif %} + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + {% if method.client_streaming %} + response = client.{{ method.name|snake_case }}(iter(requests)) + {% else %} + response = client.{{ method_name }}(request) + {% endif %} + + expected_params = [ + {% for req_field in method.input.required_fields if req_field.is_primitive %} + ( + "{{ req_field.name }}", + {% if req_field.field_pb.default_value is string %} + "{{ req_field.field_pb.default_value }}" + {% else %} + {{ req_field.field_pb.default_value }} + {% endif %}{# default is str #} + ) + {% endfor %} + ] + actual_params = req.call_args.kwargs['params'] + assert expected_params == actual_params + + + {% endif %}{# required_fields #} + + def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}): client = {{ service.client_name }}( credentials=ga_credentials.AnonymousCredentials(), @@ -1325,9 +1451,10 @@ def test_{{ method_name }}_rest_flattened_error(transport: str = 'rest'): {% if method.paged_result_field %} -def test_{{ method_name }}_rest_pager(): +def test_{{ method_name }}_rest_pager(transport: str = 'rest'): client = {{ service.client_name }}( credentials=ga_credentials.AnonymousCredentials(), + transport=transport, ) # Mock the http request call within the method and fake a response. @@ -1446,25 +1573,35 @@ def test_{{ method_name }}_rest_error(): credentials=ga_credentials.AnonymousCredentials(), transport='rest' ) - {%- if not method.http_options %} - # Since a `google.api.http` annotation is required for using a rest transport - # method, this should error. - with pytest.raises(RuntimeError) as runtime_error: - client.{{ method_name }}({}) - assert ('Cannot define a method without a valid `google.api.http` annotation.' - in str(runtime_error.value)) - {%- else %} # TODO(yon-mg): Remove when this method has a working implementation # or testing straegy with pytest.raises(NotImplementedError): client.{{ method_name }}({}) - {%- endif %} -{% endif %}{% endwith %}{# method_name #} +{% endif %}{# not streaming #}{% endwith %}{# method_name #} {% endfor -%} {#- method in methods for rest #} + +{% for method in service.methods.values() if 'rest' in opts.transport and + not method.http_options %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %} +def test_{{ method_name }}_rest_error(): + client = {{ service.client_name }}( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + # Since a `google.api.http` annotation is required for using a rest transport + # method, this should error. + with pytest.raises(RuntimeError) as runtime_error: + client.{{ method_name }}({}) + assert ("Cannot define a method without a valid 'google.api.http' annotation." + in str(runtime_error.value)) + + +{% endwith %}{# method_name #} +{% endfor %}{# for methods without http_options #} + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport( @@ -1758,8 +1895,7 @@ def test_{{ service.name|snake_case }}_http_transport_client_cert_source_for_mtl mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) -{# TODO(kbandes): re-enable this code when LRO is implmented for rest #} -{% if False and service.has_lro -%} +{% if service.has_lro -%} def test_{{ service.name|snake_case }}_rest_lro_client(): client = {{ service.client_name }}( credentials=ga_credentials.AnonymousCredentials(), @@ -1770,7 +1906,7 @@ def test_{{ service.name|snake_case }}_rest_lro_client(): # Ensure that we have a api-core operations client. assert isinstance( transport.operations_client, - operations_v1.OperationsClient, + operations_v1.AbstractOperationsClient, ) # Ensure that subsequent calls to the property send the exact same object. diff --git a/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py b/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py index 8b0f15491d..95115e9d14 100644 --- a/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py +++ b/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py @@ -18,6 +18,7 @@ import grpc from grpc.experimental import aio +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule diff --git a/tests/integration/goldens/credentials/tests/unit/gapic/credentials_v1/test_iam_credentials.py b/tests/integration/goldens/credentials/tests/unit/gapic/credentials_v1/test_iam_credentials.py index 54bebbee75..2bd38142f6 100644 --- a/tests/integration/goldens/credentials/tests/unit/gapic/credentials_v1/test_iam_credentials.py +++ b/tests/integration/goldens/credentials/tests/unit/gapic/credentials_v1/test_iam_credentials.py @@ -18,6 +18,7 @@ import grpc from grpc.experimental import aio +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule diff --git a/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_config_service_v2.py b/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_config_service_v2.py index f8c963c4af..e8bd895fb8 100644 --- a/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_config_service_v2.py +++ b/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_config_service_v2.py @@ -18,6 +18,7 @@ import grpc from grpc.experimental import aio +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule diff --git a/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_logging_service_v2.py b/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_logging_service_v2.py index 47cc2177ef..84911953c1 100644 --- a/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_logging_service_v2.py +++ b/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_logging_service_v2.py @@ -18,6 +18,7 @@ import grpc from grpc.experimental import aio +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule diff --git a/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_metrics_service_v2.py b/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_metrics_service_v2.py index 7d8951e95a..9c97b20308 100644 --- a/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_metrics_service_v2.py +++ b/tests/integration/goldens/logging/tests/unit/gapic/logging_v2/test_metrics_service_v2.py @@ -18,6 +18,7 @@ import grpc from grpc.experimental import aio +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule diff --git a/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py b/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py index d6a487d8dc..8e60bc092e 100644 --- a/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py +++ b/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py @@ -18,6 +18,7 @@ import grpc from grpc.experimental import aio +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule