Skip to content

Commit

Permalink
bugfix: nullable check float dtype handles nan and null (#1627)
Browse files Browse the repository at this point in the history
Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
  • Loading branch information
cosmicBboy committed May 8, 2024
1 parent 0faae07 commit 63140c9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
5 changes: 2 additions & 3 deletions pandera/backends/polars/components.py
Expand Up @@ -207,10 +207,9 @@ def check_nullable(
)
]

expr = pl.col(schema.selector).is_not_null()
if is_float_dtype(check_obj, schema.selector):
expr = pl.col(schema.selector).is_not_nan()
else:
expr = pl.col(schema.selector).is_not_null()
expr = expr & pl.col(schema.selector).is_not_nan()

isna = check_obj.select(expr)
passed = isna.select([pl.col("*").all()]).collect()
Expand Down
8 changes: 3 additions & 5 deletions tests/polars/test_polars_components.py
Expand Up @@ -129,7 +129,7 @@ def test_coerce_dtype(data, from_dtype, to_dtype, exception_cls):
NULLABLE_DTYPES_AND_DATA = [
[pl.Int64, [1, 2, 3, None]],
[pl.Utf8, ["foo", "bar", "baz", None]],
[pl.Float64, [1.0, 2.0, 3.0, float("nan")]],
[pl.Float64, [1.0, 2.0, 3.0, float("nan"), None]],
[pl.Boolean, [True, False, True, None]],
]

Expand All @@ -138,7 +138,7 @@ def test_coerce_dtype(data, from_dtype, to_dtype, exception_cls):
@pytest.mark.parametrize("nullable", [True, False])
def test_check_nullable(dtype, data, nullable):
data = pl.LazyFrame({"column": pl.Series(data, dtype=dtype)})
column_schema = pa.Column(pl.Int64, nullable=nullable, name="column")
column_schema = pa.Column(dtype, nullable=nullable, name="column")
backend = ColumnBackend()
check_results: List[CoreCheckResult] = backend.check_nullable(
data, column_schema
Expand All @@ -153,9 +153,7 @@ def test_check_nullable_regex(dtype, data, nullable):
data = pl.LazyFrame(
{f"column_{i}": pl.Series(data, dtype=dtype) for i in range(3)}
)
column_schema = pa.Column(
pl.Int64, nullable=nullable, name=r"^column_\d+$"
)
column_schema = pa.Column(dtype, nullable=nullable, name=r"^column_\d+$")
backend = ColumnBackend()
check_results = backend.check_nullable(data, column_schema)
for result in check_results:
Expand Down

0 comments on commit 63140c9

Please sign in to comment.