Skip to content

Commit

Permalink
bugfix: add index validation to SeriesSchema (#1524)
Browse files Browse the repository at this point in the history
* bugfix: add index validation to SeriesSchema

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* fix tests

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

---------

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
  • Loading branch information
cosmicBboy committed Mar 11, 2024
1 parent 6c11fbb commit 17c558f
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 23 deletions.
10 changes: 10 additions & 0 deletions pandera/api/pandas/array.py
Expand Up @@ -469,6 +469,16 @@ def validate( # type: ignore [override]
lazy=lazy,
inplace=inplace,
)
if self.index is not None:
validated_obj = self.index.validate(
check_obj,
head=head,
tail=tail,
sample=sample,
random_state=random_state,
lazy=lazy,
inplace=inplace,
)
return cast(pd.Series, validated_obj)

def example(self, size=None) -> pd.Series:
Expand Down
1 change: 1 addition & 0 deletions pandera/backends/pandas/__init__.py
Expand Up @@ -72,6 +72,7 @@
for t in series_datatypes:
SeriesSchema.register_backend(t, SeriesSchemaBackend)
Column.register_backend(t, ColumnBackend)
MultiIndex.register_backend(t, MultiIndexBackend)
Index.register_backend(t, IndexBackend)

for t in index_datatypes:
Expand Down
44 changes: 31 additions & 13 deletions pandera/backends/pandas/components.py
Expand Up @@ -275,28 +275,46 @@ def validate(
reason_code=SchemaErrorReason.MISMATCH_INDEX,
)

error_handler = ErrorHandler(lazy)

if schema.coerce:
check_obj.index = schema.coerce_dtype(check_obj.index)
obj_to_validate = schema.dtype.coerce(
check_obj.index.to_series().reset_index(drop=True)
)
else:
obj_to_validate = check_obj.index.to_series().reset_index(
drop=True
)
try:
check_obj.index = schema.coerce_dtype(check_obj.index)
except SchemaError as exc:
error_handler.collect_error(
validation_type(exc.reason_code),
exc.reason_code,
exc,
)

assert is_field(
super().validate(
obj_to_validate,
try:
_validated_obj = super().validate(
check_obj.index.to_series().reset_index(drop=True),
schema,
head=head,
tail=tail,
sample=sample,
random_state=random_state,
lazy=lazy,
inplace=inplace,
),
)
)
assert is_field(_validated_obj)
except SchemaError as exc:
error_handler.collect_error(
validation_type(exc.reason_code),
exc.reason_code,
exc,
)
except SchemaErrors as exc:
error_handler.collect_errors(exc.schema_errors, exc)

if lazy and error_handler.collected_errors:
raise SchemaErrors(
schema=schema,
schema_errors=error_handler.schema_errors,
data=check_obj,
)

return check_obj


Expand Down
39 changes: 29 additions & 10 deletions tests/core/test_schemas.py
Expand Up @@ -667,6 +667,27 @@ def test_series_schema_with_index(coerce: bool) -> None:
assert (validated_series_multiindex.index == multi_index).all()


def test_series_schema_with_index_errors() -> None:
"""Test that SeriesSchema raises errors for invalid index."""
schema_with_index = SeriesSchema(dtype=int, index=Index(int))
data = pd.Series([1, 2, 3], index=[1.0, 2.0, 3.0])
with pytest.raises(errors.SchemaError):
schema_with_index(data)

schema_with_index_check = SeriesSchema(
dtype=int, index=Index(float, Check(lambda x: x == 1.0))
)
with pytest.raises(errors.SchemaError):
schema_with_index_check(data)

schema_with_index_coerce = SeriesSchema(
dtype=int, index=Index(int, coerce=True)
)
expected = pd.Series([1, 2, 3], index=[1, 2, 3])
schema_with_index_coerce(data)
assert schema_with_index_coerce(data).equals(expected)


class SeriesGreaterCheck:
# pylint: disable=too-few-public-methods
"""Class creating callable objects to check if series elements exceed a
Expand Down Expand Up @@ -1622,9 +1643,9 @@ def test_lazy_dataframe_unique() -> None:
Index(str, checks=Check.isin(["a", "b", "c"])),
pd.DataFrame({"col": [1, 2, 3]}, index=["a", "b", "d"]),
{
# expect that the data in the SchemaError is the pd.Index cast
# into a Series
"data": pd.Series(["a", "b", "d"]),
"data": pd.DataFrame(
{"col": [1, 2, 3]}, index=["a", "b", "d"]
),
"schema_errors": {
"Index": {"isin(['a', 'b', 'c'])": ["d"]},
},
Expand All @@ -1645,8 +1666,6 @@ def test_lazy_dataframe_unique() -> None:
),
),
{
# expect that the data in the SchemaError is the pd.MultiIndex
# cast into a DataFrame
"data": pd.DataFrame(
{"column": [1, 2, 3]},
index=pd.MultiIndex.from_arrays(
Expand Down Expand Up @@ -1724,12 +1743,12 @@ def fail_without_msg(data):
@pytest.mark.parametrize(
"from_dtype,to_dtype",
[
# [float, int],
# [int, float],
# [object, int],
# [object, float],
[float, int],
[int, float],
[object, int],
[object, float],
[int, object],
# [float, object],
[float, object],
],
)
def test_schema_coerce_inplace_validation(
Expand Down

0 comments on commit 17c558f

Please sign in to comment.