Skip to content

Commit

Permalink
fix: ensure enums are incomparable w other enum types (#248)
Browse files Browse the repository at this point in the history
Fixes #247.
  • Loading branch information
tseaver committed Sep 29, 2021
1 parent cb7e537 commit 5927c14
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 1 deletion.
40 changes: 39 additions & 1 deletion proto/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,45 @@ def __new__(mcls, name, bases, attrs):
class Enum(enum.IntEnum, metaclass=ProtoEnumMeta):
"""A enum object that also builds a protobuf enum descriptor."""

pass
def _comparable(self, other):
# Avoid 'isinstance' to prevent other IntEnums from matching
return type(other) in (type(self), int)

def __eq__(self, other):
if not self._comparable(other):
return NotImplemented

return self.value == int(other)

def __ne__(self, other):
if not self._comparable(other):
return NotImplemented

return self.value != int(other)

def __lt__(self, other):
if not self._comparable(other):
return NotImplemented

return self.value < int(other)

def __le__(self, other):
if not self._comparable(other):
return NotImplemented

return self.value <= int(other)

def __ge__(self, other):
if not self._comparable(other):
return NotImplemented

return self.value >= int(other)

def __gt__(self, other):
if not self._comparable(other):
return NotImplemented

return self.value > int(other)


class _EnumInfo:
Expand Down
28 changes: 28 additions & 0 deletions tests/enums_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (C) 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import proto

__protobuf__ = proto.module(package="test.proto", manifest={"Enums",},)


class OneEnum(proto.Enum):
UNSPECIFIED = 0
SOME_VALUE = 1


class OtherEnum(proto.Enum):
UNSPECIFIED = 0
APPLE = 1
BANANA = 2
87 changes: 87 additions & 0 deletions tests/test_enum_total_ordering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2021, Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import enums_test


def test_total_ordering_w_same_enum_type():
to_compare = enums_test.OneEnum.SOME_VALUE

for item in enums_test.OneEnum:
if item.value < to_compare.value:
assert not to_compare == item
assert to_compare != item
assert not to_compare < item
assert not to_compare <= item
assert to_compare > item
assert to_compare >= item
elif item.value > to_compare.value:
assert not to_compare == item
assert to_compare != item
assert to_compare < item
assert to_compare <= item
assert not to_compare > item
assert not to_compare >= item
else: # item.value == to_compare.value:
assert to_compare == item
assert not to_compare != item
assert not to_compare < item
assert to_compare <= item
assert not to_compare > item
assert to_compare >= item


def test_total_ordering_w_other_enum_type():
to_compare = enums_test.OneEnum.SOME_VALUE

for item in enums_test.OtherEnum:
assert not to_compare == item
assert to_compare.SOME_VALUE != item
with pytest.raises(TypeError):
assert not to_compare < item
with pytest.raises(TypeError):
assert not to_compare <= item
with pytest.raises(TypeError):
assert not to_compare > item
with pytest.raises(TypeError):
assert not to_compare >= item


@pytest.mark.parametrize("int_val", range(-1, 3))
def test_total_ordering_w_int(int_val):
to_compare = enums_test.OneEnum.SOME_VALUE

if int_val < to_compare.value:
assert not to_compare == int_val
assert to_compare != int_val
assert not to_compare < int_val
assert not to_compare <= int_val
assert to_compare > int_val
assert to_compare >= int_val
elif int_val > to_compare.value:
assert not to_compare == int_val
assert to_compare != int_val
assert to_compare < int_val
assert to_compare <= int_val
assert not to_compare > int_val
assert not to_compare >= int_val
else: # int_val == to_compare.value:
assert to_compare == int_val
assert not to_compare != int_val
assert not to_compare < int_val
assert to_compare <= int_val
assert not to_compare > int_val
assert to_compare >= int_val

0 comments on commit 5927c14

Please sign in to comment.