Skip to content

Commit

Permalink
add pandas pyarrow backend support (#1628)
Browse files Browse the repository at this point in the history
* feat: add basic pyarrow datatypes to pandas engine

Signed-off-by: Ajith Aravind <ajith.aravind100@gmail.com>

* test: exclude arrow types from strategies

Signed-off-by: Ajith Aravind <ajith.aravind100@gmail.com>

* fix: pandas 2 plus check & remove pyarrow string equivalent

Signed-off-by: Ajith Aravind <ajith.aravind100@gmail.com>

* chore: add missing imports

Signed-off-by: Ajith Aravind <ajith.aravind100@gmail.com>

* test: ✅ exclude string[pyarrow] from string to type parse test

`string[pyarrow]` gets parsed to type `string` by pandas

Signed-off-by: Ajith Aravind <ajith.aravind100@gmail.com>

* refactor: add more equivalents to pyarrow dtypes

Signed-off-by: Ajith Aravind <ajith.aravind100@gmail.com>

* fix: type linting for python 3.8

Signed-off-by: Ajith Aravind <ajith.aravind100@gmail.com>

---------

Signed-off-by: Ajith Aravind <ajith.aravind100@gmail.com>
  • Loading branch information
aaravind100 committed May 11, 2024
1 parent 95e412f commit 45f9d4a
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 1 deletion.
3 changes: 3 additions & 0 deletions pandera/__init__.py
Expand Up @@ -62,6 +62,7 @@
INT64,
PANDAS_1_2_0_PLUS,
PANDAS_1_3_0_PLUS,
PANDAS_2_0_0_PLUS,
STRING,
UINT8,
UINT16,
Expand Down Expand Up @@ -136,7 +137,9 @@
"INT16",
"INT32",
"INT64",
"PANDAS_1_2_0_PLUS",
"PANDAS_1_3_0_PLUS",
"PANDAS_2_0_0_PLUS",
"STRING",
"UINT8",
"UINT16",
Expand Down
261 changes: 261 additions & 0 deletions pandera/engines/pandas_engine.py
Expand Up @@ -75,6 +75,7 @@

# Version feature flags: each is ``True`` when the release tuple reported by
# ``pandas_version()`` is greater than or equal to the version in its name.
PANDAS_1_2_0_PLUS = pandas_version().release >= (1, 2, 0)
PANDAS_1_3_0_PLUS = pandas_version().release >= (1, 3, 0)
PANDAS_2_0_0_PLUS = pandas_version().release >= (2, 0, 0)


# register different TypedDict type depending on python version
Expand All @@ -101,6 +102,16 @@ def is_extension_dtype(
)


def is_pyarrow_dtype(
    pd_dtype: PandasDataType,
) -> Union[bool, Iterable[bool]]:
    """Check if a value is an instance of a pandas pyarrow-backed dtype.

    :param pd_dtype: candidate data type to inspect.
    :returns: ``True`` only when pyarrow is installed, the running pandas
        exposes ``pd.ArrowDtype`` and ``pd_dtype`` is an instance of it;
        ``False`` in every other case.
    """
    # This predicate is used as one branch of an ``if``/``elif`` chain that
    # receives arbitrary inputs, so it must answer ``False`` rather than raise
    # when the arrow backend is unavailable; raising here would prevent the
    # remaining branches from ever being tried.
    if not PYARROW_INSTALLED:
        return False

    # Look the class up defensively so that a pandas build without
    # ``ArrowDtype`` does not raise ``AttributeError``.
    arrow_dtype_cls = getattr(pd, "ArrowDtype", None)
    if arrow_dtype_cls is None:
        return False

    return isinstance(pd_dtype, arrow_dtype_cls)


@immutable(init=True)
class DataType(dtypes.DataType):
"""Base `DataType` for boxing Pandas data types."""
Expand Down Expand Up @@ -220,6 +231,8 @@ def dtype(cls, data_type: Any) -> dtypes.DataType:
"Usage Tip: Use an instance or a string "
"representation."
) from None
elif is_pyarrow_dtype(data_type):
np_or_pd_dtype = data_type.pyarrow_dtype
else:
# let pandas transform any acceptable value
# into a numpy or pandas dtype.
Expand Down Expand Up @@ -1570,3 +1583,251 @@ def __init__( # pylint:disable=super-init-not-called

def __str__(self) -> str:
    """Return the ``__name__`` of ``NamedTuple`` as the string form."""
    label = NamedTuple.__name__
    return str(label)


###############################################################################
# pyarrow types
###############################################################################

# The pyarrow-backed data types below are only defined (and registered with
# ``Engine``) when ``PYARROW_INSTALLED`` is truthy and the pandas release is
# at least 2.0.0.
if PYARROW_INSTALLED and PANDAS_2_0_0_PLUS:

    @Engine.register_dtype(
        equivalents=[
            "bool[pyarrow]",
            pyarrow.bool_,
            pd.ArrowDtype(pyarrow.bool_()),
        ]
    )
    @immutable
    class ArrowBool(BOOL):
        """Semantic representation of a :class:`pyarrow.bool_`.

        Subclasses ``BOOL`` and boxes ``pd.ArrowDtype(pyarrow.bool_())``.
        """

        type = pd.ArrowDtype(pyarrow.bool_())

    @Engine.register_dtype(
        equivalents=[
            "int64[pyarrow]",
            pyarrow.int64,
            pd.ArrowDtype(pyarrow.int64()),
        ]
    )
    @immutable
    class ArrowInt64(DataType, dtypes.Int):
        """Semantic representation of a :class:`pyarrow.int64`.

        Base class of the narrower signed Arrow integer types defined below.
        """

        type = pd.ArrowDtype(pyarrow.int64())
        bit_width: int = 64

    @Engine.register_dtype(
        equivalents=[
            "int32[pyarrow]",
            pyarrow.int32,
            pd.ArrowDtype(pyarrow.int32()),
        ]
    )
    @immutable
    class ArrowInt32(ArrowInt64):
        """Semantic representation of a :class:`pyarrow.int32`.

        Inherits from ``ArrowInt64``, overriding ``type`` and ``bit_width``.
        """

        type = pd.ArrowDtype(pyarrow.int32())
        bit_width: int = 32

    @Engine.register_dtype(
        equivalents=[
            "int16[pyarrow]",
            pyarrow.int16,
            pd.ArrowDtype(pyarrow.int16()),
        ]
    )
    @immutable
    class ArrowInt16(ArrowInt32):
        """Semantic representation of a :class:`pyarrow.int16`.

        Inherits from ``ArrowInt32``, overriding ``type`` and ``bit_width``.
        """

        type = pd.ArrowDtype(pyarrow.int16())
        bit_width: int = 16

    @Engine.register_dtype(
        equivalents=[
            "int8[pyarrow]",
            pyarrow.int8,
            pd.ArrowDtype(pyarrow.int8()),
        ]
    )
    @immutable
    class ArrowInt8(ArrowInt16):
        """Semantic representation of a :class:`pyarrow.int8`.

        Inherits from ``ArrowInt16``, overriding ``type`` and ``bit_width``.
        """

        type = pd.ArrowDtype(pyarrow.int8())
        bit_width: int = 8

    # Only the ``pyarrow.string`` factory is registered as an equivalent:
    # the ``string[pyarrow]`` alias gets parsed to type ``string`` by pandas,
    # so it is deliberately not listed here.
    @Engine.register_dtype(equivalents=[pyarrow.string])
    @immutable
    class ArrowString(DataType, dtypes.String):
        """Semantic representation of a :class:`pyarrow.string`."""

        type = pd.ArrowDtype(pyarrow.string())

    @Engine.register_dtype(
        equivalents=[
            "uint64[pyarrow]",
            pyarrow.uint64,
            pd.ArrowDtype(pyarrow.uint64()),
        ]
    )
    @immutable
    class ArrowUInt64(DataType, dtypes.UInt):
        """Semantic representation of a :class:`pyarrow.uint64`.

        Base class of the narrower unsigned Arrow integer types defined below.
        """

        type = pd.ArrowDtype(pyarrow.uint64())
        bit_width: int = 64

    @Engine.register_dtype(
        equivalents=[
            "uint32[pyarrow]",
            pyarrow.uint32,
            pd.ArrowDtype(pyarrow.uint32()),
        ]
    )
    @immutable
    class ArrowUInt32(ArrowUInt64):
        """Semantic representation of a :class:`pyarrow.uint32`.

        Inherits from ``ArrowUInt64``, overriding ``type`` and ``bit_width``.
        """

        type = pd.ArrowDtype(pyarrow.uint32())
        bit_width: int = 32

    @Engine.register_dtype(
        equivalents=[
            "uint16[pyarrow]",
            pyarrow.uint16,
            pd.ArrowDtype(pyarrow.uint16()),
        ]
    )
    @immutable
    class ArrowUInt16(ArrowUInt32):
        """Semantic representation of a :class:`pyarrow.uint16`.

        Inherits from ``ArrowUInt32``, overriding ``type`` and ``bit_width``.
        """

        type = pd.ArrowDtype(pyarrow.uint16())
        bit_width: int = 16

    @Engine.register_dtype(
        equivalents=[
            "uint8[pyarrow]",
            pyarrow.uint8,
            pd.ArrowDtype(pyarrow.uint8()),
        ]
    )
    @immutable
    class ArrowUInt8(ArrowUInt16):
        """Semantic representation of a :class:`pyarrow.uint8`.

        Inherits from ``ArrowUInt16``, overriding ``type`` and ``bit_width``.
        """

        type = pd.ArrowDtype(pyarrow.uint8())
        bit_width: int = 8

    @Engine.register_dtype(
        equivalents=[
            "double[pyarrow]",
            pyarrow.float64,
            pd.ArrowDtype(pyarrow.float64()),
        ]
    )
    @immutable
    class ArrowFloat64(DataType, dtypes.Float):
        """Semantic representation of a :class:`pyarrow.float64`.

        Registered under the ``double[pyarrow]`` string alias.
        """

        type = pd.ArrowDtype(pyarrow.float64())
        bit_width: int = 64

    @Engine.register_dtype(
        equivalents=[
            "float[pyarrow]",
            pyarrow.float32,
            pd.ArrowDtype(pyarrow.float32()),
        ]
    )
    @immutable
    class ArrowFloat32(ArrowFloat64):
        """Semantic representation of a :class:`pyarrow.float32`.

        Registered under the ``float[pyarrow]`` string alias. Inherits from
        ``ArrowFloat64``, overriding ``type`` and ``bit_width``.
        """

        type = pd.ArrowDtype(pyarrow.float32())
        bit_width: int = 32

    @Engine.register_dtype(
        equivalents=[pyarrow.decimal128, pyarrow.Decimal128Type]
    )
    @immutable(init=True)
    class ArrowDecimal128(DataType, dtypes.Decimal):
        """Semantic representation of a :class:`pyarrow.decimal128`.

        Parametrized by ``precision`` and ``scale``. The boxed ``type`` is
        not a constructor argument; it is derived in ``__post_init__``.
        """

        # Excluded from ``__init__`` (``init=False``); set in ``__post_init__``.
        type: Optional[pd.ArrowDtype] = dataclasses.field(
            default=None, init=False
        )
        # Both values are forwarded positionally to ``pyarrow.decimal128``.
        precision: int = 28
        scale: int = 0

        def __post_init__(self) -> None:
            """Build the boxed ``pd.ArrowDtype`` from precision and scale."""
            type_ = pd.ArrowDtype(
                pyarrow.decimal128(self.precision, self.scale)
            )
            # NOTE(review): ``object.__setattr__`` is presumably needed
            # because ``immutable`` freezes the instance -- confirm.
            object.__setattr__(self, "type", type_)

        @classmethod
        def from_parametrized_dtype(
            cls,
            pyarrow_dtype: pyarrow.Decimal128Type,
        ):
            """Create an instance copying ``precision`` and ``scale`` from
            ``pyarrow_dtype``."""
            return cls(precision=pyarrow_dtype.precision, scale=pyarrow_dtype.scale)  # type: ignore

    @Engine.register_dtype(
        equivalents=[pyarrow.timestamp, pyarrow.TimestampType]
    )
    @immutable(init=True)
    class ArrowTimestamp(DataType, dtypes.Timestamp):
        """Semantic representation of a :class:`pyarrow.timestamp`.

        Parametrized by ``unit`` and ``tz``. The boxed ``type`` is not a
        constructor argument; it is derived in ``__post_init__``.
        """

        # Excluded from ``__init__`` (``init=False``); set in ``__post_init__``.
        type: Optional[pd.ArrowDtype] = dataclasses.field(
            default=None, init=False
        )
        # Both values are forwarded positionally to ``pyarrow.timestamp``.
        unit: Optional[str] = "ns"
        tz: Optional[datetime.tzinfo] = None

        def __post_init__(self):
            """Build the boxed ``pd.ArrowDtype`` from unit and timezone."""
            type_ = pd.ArrowDtype(pyarrow.timestamp(self.unit, self.tz))
            # NOTE(review): ``object.__setattr__`` is presumably needed
            # because ``immutable`` freezes the instance -- confirm.
            object.__setattr__(self, "type", type_)

        @classmethod
        def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.TimestampType):
            """Create an instance copying ``unit`` and ``tz`` from
            ``pyarrow_dtype``."""
            return cls(unit=pyarrow_dtype.unit, tz=pyarrow_dtype.tz)  # type: ignore

    @Engine.register_dtype(
        equivalents=[pyarrow.dictionary, pyarrow.DictionaryType]
    )
    @immutable(init=True)
    class ArrowDictionary(DataType, dtypes.Category):
        """Semantic representation of a :class:`pyarrow.dictionary`.

        Parametrized by ``index_type``, ``value_type`` and ``ordered``. The
        boxed ``type`` is not a constructor argument; it is derived in
        ``__post_init__``.
        """

        # Excluded from ``__init__`` (``init=False``); set in ``__post_init__``.
        type: Optional[pd.ArrowDtype] = dataclasses.field(
            default=None, init=False
        )
        # The three values are forwarded positionally to
        # ``pyarrow.dictionary``; index and value types default to int64.
        index_type: Optional[pyarrow.DataType] = pyarrow.int64()
        value_type: Optional[pyarrow.DataType] = pyarrow.int64()
        ordered: bool = False

        def __post_init__(self):
            """Build the boxed ``pd.ArrowDtype`` from the dictionary
            parameters."""
            type_ = pd.ArrowDtype(
                pyarrow.dictionary(
                    self.index_type,
                    self.value_type,
                    self.ordered,
                )
            )
            # NOTE(review): ``object.__setattr__`` is presumably needed
            # because ``immutable`` freezes the instance -- confirm.
            object.__setattr__(self, "type", type_)

        @classmethod
        def from_parametrized_dtype(
            cls, pyarrow_dtype: pyarrow.DictionaryType
        ):
            """Create an instance copying ``index_type``, ``value_type`` and
            ``ordered`` from ``pyarrow_dtype``."""
            return cls(
                index_type=pyarrow_dtype.index_type,  # type: ignore
                value_type=pyarrow_dtype.value_type,  # type: ignore
                ordered=pyarrow_dtype.ordered,  # type: ignore
            )
14 changes: 13 additions & 1 deletion tests/core/test_pandas_engine.py
@@ -1,6 +1,7 @@
"""Test pandas engine."""

from datetime import date
from typing import Any, Set

import hypothesis
import hypothesis.extra.pandas as pd_st
Expand All @@ -14,9 +15,20 @@
from pandera.engines import pandas_engine
from pandera.errors import ParserError

# Registered dtype classes to leave out of the parametrized dtype test.
UNSUPPORTED_DTYPE_CLS: Set[Any] = set()

# `string[pyarrow]` gets parsed to type `string` by pandas, so ArrowString
# is excluded. The class only exists when pyarrow is installed and the
# pandas 2.0.0+ flag is set, hence the guard.
if pandas_engine.PYARROW_INSTALLED and pandas_engine.PANDAS_2_0_0_PLUS:
    UNSUPPORTED_DTYPE_CLS.add(pandas_engine.ArrowString)


@pytest.mark.parametrize(
"data_type", list(pandas_engine.Engine.get_registered_dtypes())
"data_type",
[
data_type
for data_type in pandas_engine.Engine.get_registered_dtypes()
if data_type not in UNSUPPORTED_DTYPE_CLS
],
)
def test_pandas_data_type(data_type):
"""Test numpy engine DataType base class."""
Expand Down
22 changes: 22 additions & 0 deletions tests/strategies/test_strategies.py
Expand Up @@ -45,6 +45,28 @@
pandas_engine.PythonNamedTuple,
]
)

# Exclude the pyarrow-backed dtypes from the strategy tests. The classes are
# referenced only under this guard, which checks that pyarrow is installed
# and that the pandas 2.0.0+ flag is set.
if pandas_engine.PYARROW_INSTALLED and pandas_engine.PANDAS_2_0_0_PLUS:
    UNSUPPORTED_DTYPE_CLS.update(
        [
            pandas_engine.ArrowBool,
            pandas_engine.ArrowDecimal128,
            pandas_engine.ArrowDictionary,
            pandas_engine.ArrowFloat32,
            pandas_engine.ArrowFloat64,
            pandas_engine.ArrowInt8,
            pandas_engine.ArrowInt16,
            pandas_engine.ArrowInt32,
            pandas_engine.ArrowInt64,
            pandas_engine.ArrowString,
            pandas_engine.ArrowTimestamp,
            pandas_engine.ArrowUInt8,
            pandas_engine.ArrowUInt16,
            pandas_engine.ArrowUInt32,
            pandas_engine.ArrowUInt64,
        ]
    )

SUPPORTED_DTYPES = set()
for data_type in pandas_engine.Engine.get_registered_dtypes():
if (
Expand Down

0 comments on commit 45f9d4a

Please sign in to comment.