Skip to content

Commit

Permalink
Refactor the polars type engine to use LazyFrames, add support for Ne…
Browse files Browse the repository at this point in the history
…sted types (#1526)

* refactor polars type engine for LazyFrames, add nested types

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* add polars engine in the container/component backends

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* add tests and docs for nested types

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* fix mypy

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* fix pylint, mypy

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* handle annotated types

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* update docs

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* update nested dtype docstrings

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

---------

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
  • Loading branch information
cosmicBboy committed Mar 15, 2024
1 parent a6270df commit 7d1b1ba
Show file tree
Hide file tree
Showing 17 changed files with 571 additions and 369 deletions.
4 changes: 2 additions & 2 deletions docs/source/conf.py
Expand Up @@ -199,7 +199,7 @@

# this is a workaround to filter out forward reference issue in
# sphinx_autodoc_typehints
class FilterPandasTypeAnnotationWarning(pylogging.Filter):
class FilterTypeAnnotationWarnings(pylogging.Filter):
def filter(self, record: pylogging.LogRecord) -> bool:
# You probably should make this check more specific by checking
# that dataclass name is in the message, so that you don't filter out
Expand All @@ -225,7 +225,7 @@ def filter(self, record: pylogging.LogRecord) -> bool:


logging.getLogger("sphinx_autodoc_typehints").logger.addFilter(
FilterPandasTypeAnnotationWarning()
FilterTypeAnnotationWarnings()
)


Expand Down
51 changes: 47 additions & 4 deletions docs/source/polars.rst
Expand Up @@ -322,10 +322,10 @@ present in the data.
Supported Data Types
--------------------

``pandera`` currently supports all the `scalar data types <https://docs.pola.rs/py-polars/html/reference/datatypes.html>`__.
`Nested data types <https://docs.pola.rs/py-polars/html/reference/datatypes.html#nested>`__
are not yet supported. Built-in python types like ``str``, ``int``, ``float``,
and ``bool`` will be handled in the same way that ``polars`` handles them:
``pandera`` currently supports all of the
`polars data types <https://docs.pola.rs/py-polars/html/reference/datatypes.html>`__.
Built-in python types like ``str``, ``int``, ``float``, and ``bool`` will be
handled in the same way that ``polars`` handles them:

.. testcode:: polars

Expand All @@ -351,6 +351,49 @@ So the following schemas are equivalent:

assert schema1 == schema2

Nested Types
^^^^^^^^^^^^

Polars nested data types are also supported via :ref:`parameterized data types <parameterized dtypes>`.
See the examples below for the different ways to specify this through the
object-based and class-based APIs:

.. tabbed:: DataFrameSchema

.. testcode:: polars

schema = pa.DataFrameSchema(
{
"list_col": pa.Column(pl.List(pl.Int64())),
"array_col": pa.Column(pl.Array(pl.Int64(), 3)),
"struct_col": pa.Column(pl.Struct({"a": pl.Utf8(), "b": pl.Float64()})),
},
)

.. tabbed:: DataFrameModel (Annotated)

.. testcode:: polars

try:
from typing import Annotated # python 3.9+
except ImportError:
from typing_extensions import Annotated

class ModelWithAnnotated(pa.DataFrameModel):
list_col: Annotated[pl.List, pl.Int64()]
array_col: Annotated[pl.Array, pl.Int64(), 3]
struct_col: Annotated[pl.Struct, {"a": pl.Utf8(), "b": pl.Float64()}]

.. tabbed:: DataFrameModel (Field)

.. testcode:: polars

class ModelWithDtypeKwargs(pa.DataFrameModel):
list_col: pl.List = pa.Field(dtype_kwargs={"inner": pl.Int64()})
array_col: pl.Array = pa.Field(dtype_kwargs={"inner": pl.Int64(), "width": 3})
struct_col: pl.Struct = pa.Field(dtype_kwargs={"fields": {"a": pl.Utf8(), "b": pl.Float64()}})


Custom checks
-------------

Expand Down
35 changes: 35 additions & 0 deletions docs/source/reference/dtypes.rst
Expand Up @@ -92,6 +92,41 @@ Pydantic Dtypes

pandera.engines.pandas_engine.PydanticModel

Polars Dtypes
-------------

*new in 0.19.0*

.. autosummary::
:toctree: generated
:template: dtype.rst
:nosignatures:

pandera.engines.polars_engine.Int8
pandera.engines.polars_engine.Int16
pandera.engines.polars_engine.Int32
pandera.engines.polars_engine.Int64
pandera.engines.polars_engine.UInt8
pandera.engines.polars_engine.UInt16
pandera.engines.polars_engine.UInt32
pandera.engines.polars_engine.UInt64
pandera.engines.polars_engine.Float32
pandera.engines.polars_engine.Float64
pandera.engines.polars_engine.Decimal
pandera.engines.polars_engine.Date
pandera.engines.polars_engine.DateTime
pandera.engines.polars_engine.Time
pandera.engines.polars_engine.Timedelta
pandera.engines.polars_engine.Array
pandera.engines.polars_engine.List
pandera.engines.polars_engine.Struct
pandera.engines.polars_engine.Bool
pandera.engines.polars_engine.String
pandera.engines.polars_engine.Categorical
pandera.engines.polars_engine.Category
pandera.engines.polars_engine.Null
pandera.engines.polars_engine.Object


Utility functions
-----------------
Expand Down
3 changes: 2 additions & 1 deletion pandera/api/pandas/model.py
Expand Up @@ -95,7 +95,8 @@ def _build_columns_index( # pylint:disable=too-many-locals
dtype = None if dtype is Any else dtype

if (
annotation.origin is None
annotation.is_annotated_type
or annotation.origin is None
or annotation.origin in SERIES_TYPES
or annotation.raw_annotation in SERIES_TYPES
):
Expand Down
4 changes: 3 additions & 1 deletion pandera/api/polars/model.py
Expand Up @@ -68,7 +68,9 @@ def _build_columns( # pylint:disable=too-many-locals

dtype = None if dtype is Any else dtype

if annotation.origin is None:
if annotation.origin is None or isinstance(
annotation.origin, pl.datatypes.DataTypeClass
):
if check_name is False:
raise SchemaInitError(
f"'check_name' is not supported for {field_name}."
Expand Down
11 changes: 4 additions & 7 deletions pandera/backends/polars/components.py
Expand Up @@ -11,6 +11,7 @@
from pandera.backends.polars.base import PolarsSchemaBackend, is_float_dtype
from pandera.config import ValidationScope
from pandera.errors import (
ParserError,
SchemaDefinitionError,
SchemaError,
SchemaErrors,
Expand Down Expand Up @@ -153,12 +154,8 @@ def coerce_dtype(
return check_obj

try:
return (
check_obj.cast({schema.selector: schema.dtype.type})
.collect()
.lazy()
)
except (pl.ComputeError, pl.InvalidOperationError) as exc:
return schema.dtype.try_coerce(check_obj)
except ParserError as exc:
raise SchemaError(
schema=schema,
data=check_obj,
Expand Down Expand Up @@ -299,7 +296,7 @@ def check_dtype(
obj_dtype = check_obj_subset.schema[column]
results.append(
CoreCheckResult(
passed=obj_dtype.is_(schema.dtype.type),
passed=schema.dtype.check(obj_dtype),
check=f"dtype('{schema.dtype}')",
reason_code=SchemaErrorReason.WRONG_DATATYPE,
message=(
Expand Down
19 changes: 10 additions & 9 deletions pandera/backends/polars/container.py
Expand Up @@ -8,10 +8,12 @@

from pandera.api.base.error_handler import ErrorHandler
from pandera.api.polars.container import DataFrameSchema
from pandera.api.polars.types import PolarsData
from pandera.backends.base import CoreCheckResult, ColumnInfo
from pandera.backends.polars.base import PolarsSchemaBackend
from pandera.config import ValidationScope
from pandera.errors import (
ParserError,
SchemaError,
SchemaErrors,
SchemaErrorReason,
Expand Down Expand Up @@ -388,16 +390,15 @@ def _coerce_dtype_helper(
"""
error_handler = ErrorHandler(lazy=True)

if schema.dtype is not None:
obj = obj.cast(schema.dtype.type)
else:
obj = obj.cast(
{k: v.dtype.type for k, v in schema.columns.items()}
)

try:
obj = obj.collect().lazy()
except pl.exceptions.ComputeError as exc:
if schema.dtype is not None:
obj = schema.dtype.try_coerce(obj)
else:
for col_schema in schema.columns.values():
obj = col_schema.dtype.try_coerce(
PolarsData(obj, col_schema.selector)
)
except (ParserError, pl.ComputeError) as exc:
error_handler.collect_error(
validation_type(SchemaErrorReason.DATATYPE_COERCION),
SchemaErrorReason.DATATYPE_COERCION,
Expand Down
2 changes: 1 addition & 1 deletion pandera/dtypes.py
Expand Up @@ -413,7 +413,7 @@ class Decimal(_Number):
"""The number of digits after the decimal point."""

# pylint: disable=line-too-long
rounding: str = dataclasses.field(
rounding: Optional[str] = dataclasses.field(
default_factory=lambda: decimal.getcontext().rounding
)
"""
Expand Down
10 changes: 9 additions & 1 deletion pandera/engines/__init__.py
@@ -1,6 +1,14 @@
"""Pandera type engines."""

from pandera.engines.utils import pydantic_version
import pydantic

from packaging import version


def pydantic_version():
    """Return the installed ``pydantic`` version.

    The raw version string is parsed into a :class:`packaging.version.Version`
    so callers can compare releases numerically (e.g. ``>= (2, 0, 0)``).
    """
    installed = pydantic.__version__
    return version.parse(installed)


PYDANTIC_V2 = pydantic_version().release >= (2, 0, 0)

0 comments on commit 7d1b1ba

Please sign in to comment.