Skip to content

Commit

Permalink
[SPARK-36438][PYTHON] Support list-like Python objects for Series com…
Browse files Browse the repository at this point in the history
…parison

### What changes were proposed in this pull request?

This PR proposes to implement `Series` comparison with list-like Python objects.

Currently `Series` doesn't support the comparison to list-like Python objects such as `list`, `tuple`, `dict`, `set`.

**Before**

```python
>>> psser
0    1
1    2
2    3
dtype: int64

>>> psser == [3, 2, 1]
Traceback (most recent call last):
...
TypeError: The operation can not be applied to list.
...
```

**After**

```python
>>> psser
0    1
1    2
2    3
dtype: int64

>>> psser == [3, 2, 1]
0    False
1     True
2    False
dtype: bool
```

This was originally proposed in databricks/koalas#2022, and all reviews in origin PR has been resolved.

### Why are the changes needed?

To follow pandas' behavior.

### Does this PR introduce _any_ user-facing change?

Yes, the `Series` comparison with list-like Python objects now possible.

### How was this patch tested?

Unittests

Closes #34114 from itholic/SPARK-36438.

Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
itholic authored and HyukjinKwon committed Oct 13, 2021
1 parent f678c75 commit 46bcef7
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 6 deletions.
6 changes: 5 additions & 1 deletion python/pyspark/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,11 @@ def __abs__(self: IndexOpsLike) -> IndexOpsLike:

# comparison operators
def __eq__(self, other: Any) -> SeriesOrIndex: # type: ignore[override]
return self._dtype_op.eq(self, other)
# pandas always returns False for all items with dict and set.
if isinstance(other, (dict, set)):
return self != self
else:
return self._dtype_op.eq(self, other)

def __ne__(self, other: Any) -> SeriesOrIndex: # type: ignore[override]
return self._dtype_op.ne(self, other)
Expand Down
95 changes: 91 additions & 4 deletions python/pyspark/pandas/data_type_ops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,98 @@ def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
raise TypeError(">= can not be applied to %s." % self.pretty_name)

def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op

_sanitize_list_like(right)
if isinstance(right, (list, tuple)):
from pyspark.pandas.series import first_series, scol_for
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME, InternalField

len_right = len(right)
if len(left) != len(right):
raise ValueError("Lengths must be equal")

sdf = left._internal.spark_frame
structed_scol = F.struct(
sdf[NATURAL_ORDER_COLUMN_NAME],
*left._internal.index_spark_columns,
left.spark.column
)
# The size of the list is expected to be small.
collected_structed_scol = F.collect_list(structed_scol)
# Sort the array by NATURAL_ORDER_COLUMN so that we can guarantee the order.
collected_structed_scol = F.array_sort(collected_structed_scol)
right_values_scol = F.array([F.lit(x) for x in right]) # type: ignore
index_scol_names = left._internal.index_spark_column_names
scol_name = left._internal.spark_column_name_for(left._internal.column_labels[0])
# Compare the values of left and right by using zip_with function.
cond = F.zip_with(
collected_structed_scol,
right_values_scol,
lambda x, y: F.struct(
*[
x[index_scol_name].alias(index_scol_name)
for index_scol_name in index_scol_names
],
F.when(x[scol_name].isNull() | y.isNull(), False)
.otherwise(
x[scol_name] == y,
)
.alias(scol_name)
),
).alias(scol_name)
# 1. `sdf_new` here looks like the below (the first field of each set is Index):
# +----------------------------------------------------------+
# |0 |
# +----------------------------------------------------------+
# |[{0, false}, {1, true}, {2, false}, {3, true}, {4, false}]|
# +----------------------------------------------------------+
sdf_new = sdf.select(cond)
# 2. `sdf_new` after the explode looks like the below:
# +----------+
# | col|
# +----------+
# |{0, false}|
# | {1, true}|
# |{2, false}|
# | {3, true}|
# |{4, false}|
# +----------+
sdf_new = sdf_new.select(F.explode(scol_name))
# 3. Here, the final `sdf_new` looks like the below:
# +-----------------+-----+
# |__index_level_0__| 0|
# +-----------------+-----+
# | 0|false|
# | 1| true|
# | 2|false|
# | 3| true|
# | 4|false|
# +-----------------+-----+
sdf_new = sdf_new.select("col.*")

index_spark_columns = [
scol_for(sdf_new, index_scol_name) for index_scol_name in index_scol_names
]
data_spark_columns = [scol_for(sdf_new, scol_name)]

internal = left._internal.copy(
spark_frame=sdf_new,
index_spark_columns=index_spark_columns,
data_spark_columns=data_spark_columns,
index_fields=[
InternalField.from_struct_field(index_field)
for index_field in sdf_new.select(index_spark_columns).schema.fields
],
data_fields=[
InternalField.from_struct_field(
sdf_new.select(data_spark_columns).schema.fields[0]
)
],
)
return first_series(DataFrame(internal))
else:
from pyspark.pandas.base import column_op

return column_op(Column.__eq__)(left, right)
return column_op(Column.__eq__)(left, right)

def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ def rfloordiv(self, other: Any) -> "Series":
koalas = CachedAccessor("koalas", PandasOnSparkSeriesMethods)

# Comparison Operators
def eq(self, other: Any) -> bool:
def eq(self, other: Any) -> "Series":
"""
Compare if the current value is equal to the other.
Expand Down
37 changes: 37 additions & 0 deletions python/pyspark/pandas/tests/test_ops_on_diff_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -1845,6 +1845,29 @@ def _test_cov(self, pser1, pser2):
pscov = psser1.cov(psser2, min_periods=3)
self.assert_eq(pcov, pscov, almost=True)

def test_series_eq(self):
pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
psser = ps.from_pandas(pser)

# other = Series
pandas_other = pd.Series([np.nan, 1, 3, 4, np.nan, 6], name="x")
pandas_on_spark_other = ps.from_pandas(pandas_other)
self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
self.assert_eq(pser == pandas_other, (psser == pandas_on_spark_other).sort_index())

# other = Series with different Index
pandas_other = pd.Series(
[np.nan, 1, 3, 4, np.nan, 6], index=[10, 20, 30, 40, 50, 60], name="x"
)
pandas_on_spark_other = ps.from_pandas(pandas_other)
self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())

# other = Index
pandas_other = pd.Index([np.nan, 1, 3, 4, np.nan, 6], name="x")
pandas_on_spark_other = ps.from_pandas(pandas_other)
self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
self.assert_eq(pser == pandas_other, (psser == pandas_on_spark_other).sort_index())


class OpsOnDiffFramesDisabledTest(PandasOnSparkTestCase, SQLTestUtils):
@classmethod
Expand Down Expand Up @@ -2039,6 +2062,20 @@ def test_combine_first(self):
with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
psdf1.combine_first(psdf2)

def test_series_eq(self):
pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
psser = ps.from_pandas(pser)

others = (
ps.Series([np.nan, 1, 3, 4, np.nan, 6], name="x"),
ps.Index([np.nan, 1, 3, 4, np.nan, 6], name="x"),
)
for other in others:
with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
psser.eq(other)
with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
psser == other


if __name__ == "__main__":
from pyspark.pandas.tests.test_ops_on_diff_frames import * # noqa: F401
Expand Down
50 changes: 50 additions & 0 deletions python/pyspark/pandas/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3071,6 +3071,56 @@ def _test_cov(self, pdf):
pscov = psdf["s1"].cov(psdf["s2"], min_periods=4)
self.assert_eq(pcov, pscov, almost=True)

def test_eq(self):
pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
psser = ps.from_pandas(pser)

# other = Series
self.assert_eq(pser.eq(pser), psser.eq(psser))
self.assert_eq(pser == pser, psser == psser)

# other = dict
other = {1: None, 2: None, 3: None, 4: None, np.nan: None, 6: None}
self.assert_eq(pser.eq(other), psser.eq(other))
self.assert_eq(pser == other, psser == other)

# other = set
other = {1, 2, 3, 4, np.nan, 6}
self.assert_eq(pser.eq(other), psser.eq(other))
self.assert_eq(pser == other, psser == other)

# other = list
other = [np.nan, 1, 3, 4, np.nan, 6]
if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
self.assert_eq(pser.eq(other), psser.eq(other).sort_index())
self.assert_eq(pser == other, (psser == other).sort_index())
else:
self.assert_eq(pser.eq(other).rename("x"), psser.eq(other).sort_index())
self.assert_eq((pser == other).rename("x"), (psser == other).sort_index())

# other = tuple
other = (np.nan, 1, 3, 4, np.nan, 6)
if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
self.assert_eq(pser.eq(other), psser.eq(other).sort_index())
self.assert_eq(pser == other, (psser == other).sort_index())
else:
self.assert_eq(pser.eq(other).rename("x"), psser.eq(other).sort_index())
self.assert_eq((pser == other).rename("x"), (psser == other).sort_index())

# other = list with the different length
other = [np.nan, 1, 3, 4, np.nan]
with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
psser.eq(other)
with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
psser == other

# other = tuple with the different length
other = (np.nan, 1, 3, 4, np.nan)
with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
psser.eq(other)
with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
psser == other


if __name__ == "__main__":
from pyspark.pandas.tests.test_series import * # noqa: F401
Expand Down

0 comments on commit 46bcef7

Please sign in to comment.