diff --git a/docs/messages.rst b/docs/messages.rst index 18dd4db6..af23ea9d 100644 --- a/docs/messages.rst +++ b/docs/messages.rst @@ -165,3 +165,13 @@ via the :meth:`~.Message.to_json` and :meth:`~.Message.from_json` methods. new_song = Song.from_json(json) +Similarly, messages can be converted into dictionaries via the +:meth:`~.Message.to_dict` helper method. +There is no :meth:`~.Message.from_dict` method because the Message constructor +already allows construction from mapping types. + +.. code-block:: python + + song_dict = Song.to_dict(song) + + new_song = Song(song_dict) diff --git a/docs/reference/message.rst b/docs/reference/message.rst index d9752b5d..34da8376 100644 --- a/docs/reference/message.rst +++ b/docs/reference/message.rst @@ -10,6 +10,7 @@ Message and Field .. automethod:: deserialize .. automethod:: to_json .. automethod:: from_json + .. automethod:: to_dict .. automodule:: proto.fields diff --git a/noxfile.py b/noxfile.py index 2f323155..3a2fff28 100644 --- a/noxfile.py +++ b/noxfile.py @@ -48,7 +48,7 @@ def unitcpp(session): return unit(session, proto="cpp") -@nox.session(python="3.6") +@nox.session(python="3.7") def docs(session): """Build the docs.""" diff --git a/proto/message.py b/proto/message.py index 9de5630a..f1c3bbbe 100644 --- a/proto/message.py +++ b/proto/message.py @@ -20,7 +20,7 @@ from google.protobuf import descriptor_pb2 from google.protobuf import message -from google.protobuf.json_format import MessageToJson, Parse +from google.protobuf.json_format import MessageToDict, MessageToJson, Parse from proto import _file_info from proto import _package_info @@ -347,21 +347,45 @@ def to_json(cls, instance, *, use_integers_for_enums=True) -> str: including_default_value_fields=True, ) - def from_json(cls, payload) -> "Message": + def from_json(cls, payload, *, ignore_unknown_fields=False) -> "Message": """Given a json string representing an instance, parse it into a message. Args: paylod: A json string representing a message. 
+ ignore_unknown_fields (Optional(bool)): If True, do not raise errors + for unknown fields. Returns: ~.Message: An instance of the message class against which this method was called. """ instance = cls() - Parse(payload, instance._pb) + Parse(payload, instance._pb, ignore_unknown_fields=ignore_unknown_fields) return instance + def to_dict(cls, instance, *, use_integers_for_enums=True) -> dict: + """Given a message instance, return its representation as a python dict. + + Args: + instance: An instance of this message type, or something + compatible (accepted by the type's constructor). + use_integers_for_enums (Optional(bool)): An option that determines whether enum + values should be represented by strings (False) or integers (True). + Default is True. + + Returns: + dict: A representation of the protocol buffer using pythonic data structures. + Messages and map fields are represented as dicts, + repeated fields are represented as lists. + """ + return MessageToDict( + cls.pb(instance), + including_default_value_fields=True, + preserving_proto_field_name=True, + use_integers_for_enums=use_integers_for_enums, + ) + class Message(metaclass=MessageMeta): """The abstract base class for a message. @@ -369,17 +393,19 @@ class Message(metaclass=MessageMeta): Args: mapping (Union[dict, ~.Message]): A dictionary or message to be used to determine the values for this message. + ignore_unknown_fields (Optional(bool)): If True, do not raise errors for + unknown fields. Only applied if `mapping` is a mapping type or there + are keyword parameters. kwargs (dict): Keys and values corresponding to the fields of the message. """ - def __init__(self, mapping=None, **kwargs): + def __init__(self, mapping=None, *, ignore_unknown_fields=False, **kwargs): # We accept several things for `mapping`: # * An instance of this class. # * An instance of the underlying protobuf descriptor class. # * A dict # * Nothing (keyword arguments only). 
- if mapping is None: if not kwargs: # Special fast path for empty construction. @@ -405,24 +431,33 @@ def __init__(self, mapping=None, **kwargs): # Just use the above logic on mapping's underlying pb. self.__init__(mapping=mapping._pb, **kwargs) return - elif not isinstance(mapping, collections.abc.Mapping): + elif isinstance(mapping, collections.abc.Mapping): + # Can't have side effects on mapping. + mapping = copy.copy(mapping) + # kwargs entries take priority for duplicate keys. + mapping.update(kwargs) + else: # Sanity check: Did we get something not a map? Error if so. raise TypeError( "Invalid constructor input for %s: %r" % (self.__class__.__name__, mapping,) ) - else: - # Can't have side effects on mapping. - mapping = copy.copy(mapping) - # kwargs entries take priority for duplicate keys. - mapping.update(kwargs) params = {} # Update the mapping to address any values that need to be # coerced. marshal = self._meta.marshal for key, value in mapping.items(): - pb_type = self._meta.fields[key].pb_type + try: + pb_type = self._meta.fields[key].pb_type + except KeyError: + if ignore_unknown_fields: + continue + + raise ValueError( + "Unknown field for {}: {}".format(self.__class__.__name__, key) + ) + pb_value = marshal.to_proto(pb_type, value) if pb_value is not None: params[key] = pb_value diff --git a/tests/test_json.py b/tests/test_json.py index 58195e89..34c2d438 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -15,7 +15,7 @@ import pytest import proto -from google.protobuf.json_format import MessageToJson, Parse +from google.protobuf.json_format import MessageToJson, Parse, ParseError def test_message_to_json(): @@ -34,7 +34,7 @@ class Squid(proto.Message): json = """{ "massKg": 100 - } + } """ s = Squid.from_json(json) @@ -95,3 +95,21 @@ class Zone(proto.Enum): .replace("\n", "") ) assert json2 == '{"zone":"EPIPELAGIC"}' + + +def test_json_unknown_field(): + # Note that 'lengthCm' is unknown in the local definition. 
+ # This could happen if the client is using an older proto definition + # than the server. + json_str = '{\n "massKg": 20,\n "lengthCm": 100\n}' + + class Octopus(proto.Message): + mass_kg = proto.Field(proto.INT32, number=1) + + o = Octopus.from_json(json_str, ignore_unknown_fields=True) + assert not hasattr(o, "length_cm") + assert not hasattr(o, "lengthCm") + + # Don't permit unknown fields by default + with pytest.raises(ParseError): + o = Octopus.from_json(json_str) diff --git a/tests/test_message.py b/tests/test_message.py index b5f59857..32a61802 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import pytest import proto @@ -228,3 +229,69 @@ class Squid(proto.Message): s1._pb = s2._pb assert s1.mass_kg == 20 + + +def test_serialize_to_dict(): + class Squid(proto.Message): + # Test primitives, enums, and repeated fields. + class Chromatophore(proto.Message): + class Color(proto.Enum): + UNKNOWN = 0 + RED = 1 + BROWN = 2 + WHITE = 3 + BLUE = 4 + + color = proto.Field(Color, number=1) + + mass_kg = proto.Field(proto.INT32, number=1) + chromatophores = proto.RepeatedField(Chromatophore, number=2) + + s = Squid(mass_kg=20) + colors = ["RED", "BROWN", "WHITE", "BLUE"] + s.chromatophores = [ + {"color": c} for c in itertools.islice(itertools.cycle(colors), 10) + ] + + s_dict = Squid.to_dict(s) + assert s_dict["chromatophores"][0]["color"] == 1 + + new_s = Squid(s_dict) + assert new_s == s + + s_dict = Squid.to_dict(s, use_integers_for_enums=False) + assert s_dict["chromatophores"][0]["color"] == "RED" + + new_s = Squid(s_dict) + assert new_s == s + + +def test_unknown_field_deserialize(): + # This is a somewhat common setup: a client uses an older proto definition, + # while the server sends the newer definition. The client still needs to be + # able to interact with the protos it receives from the server. 
+ + class Octopus_Old(proto.Message): + mass_kg = proto.Field(proto.INT32, number=1) + + class Octopus_New(proto.Message): + mass_kg = proto.Field(proto.INT32, number=1) + length_cm = proto.Field(proto.INT32, number=2) + + o_new = Octopus_New(mass_kg=20, length_cm=100) + o_ser = Octopus_New.serialize(o_new) + + o_old = Octopus_Old.deserialize(o_ser) + assert not hasattr(o_old, "length_cm") + + +def test_unknown_field_from_dict(): + class Squid(proto.Message): + mass_kg = proto.Field(proto.INT32, number=1) + + # By default we don't permit unknown fields + with pytest.raises(ValueError): + s = Squid({"mass_kg": 20, "length_cm": 100}) + + s = Squid({"mass_kg": 20, "length_cm": 100}, ignore_unknown_fields=True) + assert not hasattr(s, "length_cm")