Backend registration happens at schema initialization (#1548)
* make backend registration more robust

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* register backends on schema init

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* add joblib to dev environment

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* fix pandas unit test

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* update backend registration for pyspark

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

---------

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
cosmicBboy committed Apr 1, 2024
1 parent e22db33 commit 58c5e45
Showing 45 changed files with 244 additions and 118 deletions.
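
Summary of the change: schema backends are now registered lazily and idempotently when a schema object is constructed, rather than via import-time side effects. Below is a minimal sketch of the pattern, simplified from the diffs that follow (registry value types are elided; the joblib-worker scenario is an inference from the dev dependency added in this commit):

```python
from typing import Dict, Tuple, Type


class BaseSchema:
    """Minimal sketch of the registration pattern in this commit."""

    BACKEND_REGISTRY: Dict[Tuple[Type, Type], Type] = {}

    def __init__(self) -> None:
        # Registration now happens on schema initialization, so backends are
        # available even in a fresh process (e.g. a joblib worker) that never
        # imported pandera's backend modules directly.
        self._register_default_backends()

    @classmethod
    def register_backend(cls, type_: Type, backend: Type) -> None:
        # First registration wins; repeated calls are no-ops, so re-running
        # registration on every schema init is cheap and safe.
        if (cls, type_) not in cls.BACKEND_REGISTRY:
            cls.BACKEND_REGISTRY[(cls, type_)] = backend

    def _register_default_backends(self) -> None:
        # No-op in the base class; dataframe-specific subclasses (pandas,
        # polars, pyspark) override this, as the diffs below show.
        ...
```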
1 change: 1 addition & 0 deletions ci/requirements-py3.10-pandas1.5.3-pydantic1.10.11.txt
@@ -192,6 +192,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.10-pandas1.5.3-pydantic2.3.0.txt
@@ -194,6 +194,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.10-pandas2.0.3-pydantic1.10.11.txt
@@ -192,6 +192,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.10-pandas2.0.3-pydantic2.3.0.txt
@@ -194,6 +194,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.10-pandas2.2.0-pydantic1.10.11.txt
@@ -190,6 +190,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.17
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.10-pandas2.2.0-pydantic2.3.0.txt
@@ -192,6 +192,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.17
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.11-pandas1.5.3-pydantic1.10.11.txt
@@ -186,6 +186,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.11-pandas1.5.3-pydantic2.3.0.txt
@@ -188,6 +188,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.11-pandas2.0.3-pydantic1.10.11.txt
@@ -186,6 +186,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.11-pandas2.0.3-pydantic2.3.0.txt
@@ -188,6 +188,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.11-pandas2.2.0-pydantic1.10.11.txt
@@ -184,6 +184,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.17
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.11-pandas2.2.0-pydantic2.3.0.txt
@@ -186,6 +186,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.17
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.8-pandas1.5.3-pydantic1.10.11.txt
@@ -207,6 +207,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.8-pandas1.5.3-pydantic2.3.0.txt
@@ -209,6 +209,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.8-pandas2.0.3-pydantic1.10.11.txt
@@ -207,6 +207,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.8-pandas2.0.3-pydantic2.3.0.txt
@@ -209,6 +209,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.9-pandas1.5.3-pydantic1.10.11.txt
@@ -199,6 +199,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.9-pandas1.5.3-pydantic2.3.0.txt
@@ -201,6 +201,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.9-pandas2.0.3-pydantic1.10.11.txt
@@ -199,6 +199,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.9-pandas2.0.3-pydantic2.3.0.txt
@@ -201,6 +201,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.9-pandas2.2.0-pydantic1.10.11.txt
@@ -197,6 +197,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.17
# via
# asv
1 change: 1 addition & 0 deletions ci/requirements-py3.9-pandas2.2.0-pydantic2.3.0.txt
@@ -199,6 +199,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.17
# via
# asv
1 change: 1 addition & 0 deletions dev/requirements-3.10.txt
@@ -192,6 +192,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions dev/requirements-3.11.txt
@@ -186,6 +186,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions dev/requirements-3.8.txt
@@ -207,6 +207,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions dev/requirements-3.9.txt
@@ -199,6 +199,7 @@ jinja2==3.1.3
# myst-parser
# nbconvert
# sphinx
joblib==1.3.2
json5==0.9.14
# via
# asv
1 change: 1 addition & 0 deletions environment.yml
@@ -50,6 +50,7 @@ dependencies:

# testing
- isort >= 5.7.0
- joblib
- mypy = 0.982
- pylint <= 2.17.3
- pytest
3 changes: 2 additions & 1 deletion pandera/api/base/checks.py
@@ -183,7 +183,8 @@ def from_builtin_check_name(
@classmethod
def register_backend(cls, type_: Type, backend: Type[BaseCheckBackend]):
"""Register a backend for the specified type."""
cls.BACKEND_REGISTRY[(cls, type_)] = backend
if (cls, type_) not in cls.BACKEND_REGISTRY:
cls.BACKEND_REGISTRY[(cls, type_)] = backend

@classmethod
def get_backend(cls, check_obj: Any) -> Type[BaseCheckBackend]:
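The same first-wins guard applies to check backends: a backend registered for a `(class, type)` key before any schema is constructed is not clobbered by the defaults registered later at init time. A hedged sketch of that consequence (`MyCheckBackend` is hypothetical; the `PandasCheckBackend` import path is an assumption about pandera's backend layout):

```python
import pandas as pd
import pandera as pa
from pandera.backends.pandas.checks import PandasCheckBackend


class MyCheckBackend(PandasCheckBackend):
    """Hypothetical custom check backend for pandas Series."""


# Because register_backend is first-wins, this early registration survives
# the default registration triggered by later schema initialization.
pa.Check.register_backend(pd.Series, MyCheckBackend)
```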
12 changes: 11 additions & 1 deletion pandera/api/base/schema.py
@@ -45,6 +45,7 @@ def __init__(
self.description = description
self.metadata = metadata
self.drop_invalid_rows = drop_invalid_rows
self._register_default_backends()

def validate(
self,
@@ -94,7 +95,8 @@ def properties(self):
@classmethod
def register_backend(cls, type_: Type, backend: Type[BaseSchemaBackend]):
"""Register a schema backend for this class."""
cls.BACKEND_REGISTRY[(cls, type_)] = backend
if (cls, type_) not in cls.BACKEND_REGISTRY:
cls.BACKEND_REGISTRY[(cls, type_)] = backend

@classmethod
def get_backend(
@@ -122,6 +124,14 @@ def get_backend(
f"Looked up the following base classes: {classes}"
)

def _register_default_backends(self):
"""Register default backends.
This method is invoked in the `__init__` method for subclasses that
implement the API for a specific dataframe object, and should be
overridden in those subclasses.
"""


def inferred_schema_guard(method):
"""
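With the `_register_default_backends` hook above, constructing a schema is enough to register its backends, so validation works in subprocesses that never ran pandera's import-time registration. A usage sketch under assumptions (the schema definition uses pandera's public pandas API; the joblib round-trip illustrates the kind of scenario this guards against, suggested by the joblib dev dependency added above):

```python
import joblib
import pandas as pd
import pandera as pa


def validate(df: pd.DataFrame) -> pd.DataFrame:
    # Constructing the schema inside the worker calls
    # _register_default_backends(), so the pandas backends are registered
    # even in a process that never imported them explicitly.
    schema = pa.DataFrameSchema({"x": pa.Column(int)})
    return schema.validate(df)


if __name__ == "__main__":
    frames = [pd.DataFrame({"x": [1, 2]}), pd.DataFrame({"x": [3, 4]})]
    results = joblib.Parallel(n_jobs=2)(
        joblib.delayed(validate)(df) for df in frames
    )
    print([len(r) for r in results])
```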
4 changes: 4 additions & 0 deletions pandera/api/pandas/array.py
@@ -12,6 +12,7 @@
from pandera.api.checks import Check
from pandera.api.hypotheses import Hypothesis
from pandera.api.pandas.types import PandasDtypeInputTypes, is_field
from pandera.backends.pandas.register import register_pandas_backends
from pandera.config import get_config_context
from pandera.dtypes import DataType, UniqueSettings
from pandera.engines import pandas_engine, PYDANTIC_V2
@@ -111,6 +112,9 @@ def __init__(
"DataFrameSchema dtype."
)

def _register_default_backends(self):
register_pandas_backends()

# the _is_inferred getter and setter methods are not public
@property
def _is_inferred(self):
4 changes: 4 additions & 0 deletions pandera/api/pandas/container.py
@@ -18,6 +18,7 @@
from pandera.api.checks import Check
from pandera.api.hypotheses import Hypothesis
from pandera.api.pandas.types import PandasDtypeInputTypes
from pandera.backends.pandas.register import register_pandas_backends
from pandera.dtypes import DataType, UniqueSettings
from pandera.engines import pandas_engine, PYDANTIC_V2

@@ -171,6 +172,9 @@ def _validate_attributes(self):
"or `'filter'`."
)

def _register_default_backends(self):
register_pandas_backends()

@property
def coerce(self) -> bool:
"""Whether to coerce series to specified type."""
Empty file removed pandera/api/polars/array.py
4 changes: 4 additions & 0 deletions pandera/api/polars/components.py
@@ -8,6 +8,7 @@
from pandera.api.base.types import CheckList
from pandera.api.pandas.components import Column as _Column
from pandera.api.polars.types import PolarsDtypeInputTypes, PolarsCheckObjects
from pandera.backends.polars.register import register_polars_backends
from pandera.config import config_context, get_config_context
from pandera.engines import polars_engine
from pandera.utils import is_regex
@@ -99,6 +100,9 @@ def __init__(
)
self.set_regex()

def _register_default_backends(self):
register_polars_backends()

def validate(
self,
check_obj: PolarsCheckObjects,
4 changes: 4 additions & 0 deletions pandera/api/polars/container.py
@@ -8,6 +8,7 @@
from pandera.api.pandas.container import DataFrameSchema as _DataFrameSchema
from pandera.api.polars.types import PolarsCheckObjects
from pandera.api.polars.utils import get_validation_depth
from pandera.backends.polars.register import register_polars_backends
from pandera.config import config_context
from pandera.dtypes import DataType
from pandera.engines import polars_engine
@@ -33,6 +34,9 @@ def _validate_attributes(self):
"polars backend, all duplicate values will be reported."
)

def _register_default_backends(self):
register_polars_backends()

def validate(
self,
check_obj: PolarsCheckObjects,
4 changes: 4 additions & 0 deletions pandera/api/pyspark/column_schema.py
@@ -9,6 +9,7 @@
from pandera.api.checks import Check
from pandera.api.base.error_handler import ErrorHandler
from pandera.api.pyspark.types import CheckList, PySparkDtypeInputTypes
from pandera.backends.pyspark.register import register_pyspark_backends
from pandera.dtypes import DataType
from pandera.engines import pyspark_engine

@@ -69,6 +70,9 @@ def __init__(
self.description = description
self.metadata = metadata

def _register_default_backends(self):
register_pyspark_backends()

@property
def dtype(self) -> DataType:
"""Get the pyspark dtype"""
6 changes: 5 additions & 1 deletion pandera/api/pyspark/container.py
@@ -11,12 +11,13 @@
from pyspark.sql import DataFrame

from pandera import errors
from pandera.config import get_config_context
from pandera.api.base.schema import BaseSchema
from pandera.api.base.types import StrictType
from pandera.api.checks import Check
from pandera.api.base.error_handler import ErrorHandler
from pandera.api.pyspark.types import CheckList, PySparkDtypeInputTypes
from pandera.backends.pyspark.register import register_pyspark_backends
from pandera.config import get_config_context
from pandera.dtypes import DataType, UniqueSettings
from pandera.engines import pyspark_engine

@@ -153,6 +154,9 @@ def __init__(
self._IS_INFERRED = False
self.metadata = metadata

def _register_default_backends(self):
register_pyspark_backends()

@property
def coerce(self) -> bool:
"""Whether to coerce series to specified type."""