Skip to content

Commit

Permalink
feat(bigquery): add __eq__ method for class PartitionRange and RangeP…
Browse files Browse the repository at this point in the history
…artitioning (#162)

* feat(bigquery): add __eq__ method for class PartitionRange and RangePartitioning

* feat(bigquery): change class object to unhashable

* feat(bigquery): change the assertion
  • Loading branch information
HemangChothani committed Jul 13, 2020
1 parent dbaf3bd commit 0d2a88d
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
20 changes: 20 additions & 0 deletions google/cloud/bigquery/table.py
Expand Up @@ -1891,10 +1891,20 @@ def interval(self, value):
def _key(self):
return tuple(sorted(self._properties.items()))

def __eq__(self, other):
if not isinstance(other, PartitionRange):
return NotImplemented
return self._key() == other._key()

def __ne__(self, other):
return not self == other

def __repr__(self):
key_vals = ["{}={}".format(key, val) for key, val in self._key()]
return "PartitionRange({})".format(", ".join(key_vals))

__hash__ = None


class RangePartitioning(object):
"""Range-based partitioning configuration for a table.
Expand Down Expand Up @@ -1961,10 +1971,20 @@ def field(self, value):
def _key(self):
return (("field", self.field), ("range_", self.range_))

def __eq__(self, other):
if not isinstance(other, RangePartitioning):
return NotImplemented
return self._key() == other._key()

def __ne__(self, other):
return not self == other

def __repr__(self):
key_vals = ["{}={}".format(key, repr(val)) for key, val in self._key()]
return "RangePartitioning({})".format(", ".join(key_vals))

__hash__ = None


class TimePartitioningType(object):
"""Specifies the type of time partitioning to perform."""
Expand Down
82 changes: 82 additions & 0 deletions tests/unit/test_table.py
Expand Up @@ -3525,6 +3525,37 @@ def test_constructor_w_resource(self):
assert object_under_test.end == 1234567890
assert object_under_test.interval == 1000000

def test___eq___start_mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=2, end=10, interval=2)
self.assertNotEqual(object_under_test, other)

def test___eq___end__mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=1, end=11, interval=2)
self.assertNotEqual(object_under_test, other)

def test___eq___interval__mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=1, end=11, interval=3)
self.assertNotEqual(object_under_test, other)

def test___eq___hit(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=1, end=10, interval=2)
self.assertEqual(object_under_test, other)

def test__eq___type_mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
self.assertNotEqual(object_under_test, object())
self.assertEqual(object_under_test, mock.ANY)

def test_unhashable_object(self):
object_under_test1 = self._make_one(start=1, end=10, interval=2)

with six.assertRaisesRegex(self, TypeError, r".*unhashable type.*"):
hash(object_under_test1)

def test_repr(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
assert repr(object_under_test) == "PartitionRange(end=10, interval=2, start=1)"
Expand Down Expand Up @@ -3574,6 +3605,57 @@ def test_range_w_wrong_type(self):
with pytest.raises(ValueError, match="PartitionRange"):
object_under_test.range_ = object()

def test___eq___field_mismatch(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
other = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="float_col"
)
self.assertNotEqual(object_under_test, other)

def test___eq___range__mismatch(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
other = self._make_one(
range_=PartitionRange(start=2, end=20, interval=2), field="float_col"
)
self.assertNotEqual(object_under_test, other)

def test___eq___hit(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
other = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
self.assertEqual(object_under_test, other)

def test__eq___type_mismatch(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
self.assertNotEqual(object_under_test, object())
self.assertEqual(object_under_test, mock.ANY)

def test_unhashable_object(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test1 = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
with six.assertRaisesRegex(self, TypeError, r".*unhashable type.*"):
hash(object_under_test1)

def test_repr(self):
from google.cloud.bigquery.table import PartitionRange

Expand Down

0 comments on commit 0d2a88d

Please sign in to comment.