From 0d2a88d8072154cfc9152afd6d26a60ddcdfbc73 Mon Sep 17 00:00:00 2001 From: HemangChothani <50404902+HemangChothani@users.noreply.github.com> Date: Mon, 13 Jul 2020 18:54:47 +0530 Subject: [PATCH] feat(bigquery): add __eq__ method for class PartitionRange and RangePartitioning (#162) * feat(bigquery): add __eq__ method for class PartitionRange and RangePartitioning * feat(bigquery): change class object to unhashable * feat(bigquery): change the assertion --- google/cloud/bigquery/table.py | 20 +++++++++ tests/unit/test_table.py | 82 ++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py index 5766f5fbe..f1575ffb2 100644 --- a/google/cloud/bigquery/table.py +++ b/google/cloud/bigquery/table.py @@ -1891,10 +1891,20 @@ def interval(self, value): def _key(self): return tuple(sorted(self._properties.items())) + def __eq__(self, other): + if not isinstance(other, PartitionRange): + return NotImplemented + return self._key() == other._key() + + def __ne__(self, other): + return not self == other + def __repr__(self): key_vals = ["{}={}".format(key, val) for key, val in self._key()] return "PartitionRange({})".format(", ".join(key_vals)) + __hash__ = None + class RangePartitioning(object): """Range-based partitioning configuration for a table. @@ -1961,10 +1971,20 @@ def field(self, value): def _key(self): return (("field", self.field), ("range_", self.range_)) + def __eq__(self, other): + if not isinstance(other, RangePartitioning): + return NotImplemented + return self._key() == other._key() + + def __ne__(self, other): + return not self == other + def __repr__(self): key_vals = ["{}={}".format(key, repr(val)) for key, val in self._key()] return "RangePartitioning({})".format(", ".join(key_vals)) + __hash__ = None + class TimePartitioningType(object): """Specifies the type of time partitioning to perform.""" diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py index 94a326617..3aabebb77 100644 --- a/tests/unit/test_table.py +++ b/tests/unit/test_table.py @@ -3525,6 +3525,37 @@ def test_constructor_w_resource(self): assert object_under_test.end == 1234567890 assert object_under_test.interval == 1000000 + def test___eq___start_mismatch(self): + object_under_test = self._make_one(start=1, end=10, interval=2) + other = self._make_one(start=2, end=10, interval=2) + self.assertNotEqual(object_under_test, other) + + def test___eq___end__mismatch(self): + object_under_test = self._make_one(start=1, end=10, interval=2) + other = self._make_one(start=1, end=11, interval=2) + self.assertNotEqual(object_under_test, other) + + def test___eq___interval__mismatch(self): + object_under_test = self._make_one(start=1, end=10, interval=2) + other = self._make_one(start=1, end=11, interval=3) + self.assertNotEqual(object_under_test, other) + + def test___eq___hit(self): + object_under_test = self._make_one(start=1, end=10, interval=2) + other = self._make_one(start=1, end=10, interval=2) + self.assertEqual(object_under_test, other) + + def test__eq___type_mismatch(self): + object_under_test = self._make_one(start=1, end=10, interval=2) + self.assertNotEqual(object_under_test, object()) + self.assertEqual(object_under_test, mock.ANY) + + def test_unhashable_object(self): + object_under_test1 = self._make_one(start=1, end=10, interval=2) + + with six.assertRaisesRegex(self, TypeError, r".*unhashable type.*"): + hash(object_under_test1) + def test_repr(self): object_under_test = self._make_one(start=1, end=10, interval=2) assert repr(object_under_test) == "PartitionRange(end=10, interval=2, start=1)" @@ -3574,6 +3605,57 @@ def test_range_w_wrong_type(self): with pytest.raises(ValueError, match="PartitionRange"): object_under_test.range_ = object() + def test___eq___field_mismatch(self): + from google.cloud.bigquery.table import PartitionRange + + object_under_test = self._make_one( + range_=PartitionRange(start=1, end=10, interval=2), field="integer_col" + ) + other = self._make_one( + range_=PartitionRange(start=1, end=10, interval=2), field="float_col" + ) + self.assertNotEqual(object_under_test, other) + + def test___eq___range__mismatch(self): + from google.cloud.bigquery.table import PartitionRange + + object_under_test = self._make_one( + range_=PartitionRange(start=1, end=10, interval=2), field="integer_col" + ) + other = self._make_one( + range_=PartitionRange(start=2, end=20, interval=2), field="float_col" + ) + self.assertNotEqual(object_under_test, other) + + def test___eq___hit(self): + from google.cloud.bigquery.table import PartitionRange + + object_under_test = self._make_one( + range_=PartitionRange(start=1, end=10, interval=2), field="integer_col" + ) + other = self._make_one( + range_=PartitionRange(start=1, end=10, interval=2), field="integer_col" + ) + self.assertEqual(object_under_test, other) + + def test__eq___type_mismatch(self): + from google.cloud.bigquery.table import PartitionRange + + object_under_test = self._make_one( + range_=PartitionRange(start=1, end=10, interval=2), field="integer_col" + ) + self.assertNotEqual(object_under_test, object()) + self.assertEqual(object_under_test, mock.ANY) + + def test_unhashable_object(self): + from google.cloud.bigquery.table import PartitionRange + + object_under_test1 = self._make_one( + range_=PartitionRange(start=1, end=10, interval=2), field="integer_col" + ) + with six.assertRaisesRegex(self, TypeError, r".*unhashable type.*"): + hash(object_under_test1) + def test_repr(self): from google.cloud.bigquery.table import PartitionRange