Skip to content

Commit

Permalink
implement timezone agnostic polars_engine.DateTime type (#1589)
Browse files Browse the repository at this point in the history
Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
  • Loading branch information
cosmicBboy committed Apr 19, 2024
1 parent 249cab2 commit c1e7c06
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 3 deletions.
43 changes: 40 additions & 3 deletions pandera/engines/polars_engine.py
Expand Up @@ -5,7 +5,16 @@
import decimal
import inspect
import warnings
from typing import Any, Union, Optional, Iterable, Literal, Sequence, Tuple
from typing import (
Any,
Union,
Optional,
Iterable,
Literal,
Sequence,
Tuple,
Type,
)


import polars as pl
Expand Down Expand Up @@ -416,16 +425,26 @@ class Date(DataType, dtypes.Date):
class DateTime(DataType, dtypes.DateTime):
"""Polars datetime data type."""

type = pl.Datetime
type: Type[pl.Datetime] = pl.Datetime
time_zone_agnostic: bool = False

def __init__( # pylint:disable=super-init-not-called
self,
time_zone: Optional[str] = None,
time_unit: Optional[str] = None,
time_zone_agnostic: bool = False,
) -> None:

_kwargs = {}
if time_unit is not None:
# avoid deprecated warning when initializing pl.Datetime:
# passing time_unit=None is deprecated.
_kwargs["time_unit"] = time_unit

object.__setattr__(
self, "type", pl.Datetime(time_zone=time_zone, time_unit=time_unit)
self, "type", pl.Datetime(time_zone=time_zone, **_kwargs)
)
object.__setattr__(self, "time_zone_agnostic", time_zone_agnostic)

@classmethod
def from_parametrized_dtype(cls, polars_dtype: pl.Datetime):
Expand All @@ -435,6 +454,24 @@ def from_parametrized_dtype(cls, polars_dtype: pl.Datetime):
time_zone=polars_dtype.time_zone, time_unit=polars_dtype.time_unit
)

def check(
self,
pandera_dtype: dtypes.DataType,
data_container: Optional[PolarsDataContainer] = None,
) -> Union[bool, Iterable[bool]]:
try:
pandera_dtype = Engine.dtype(pandera_dtype)
except TypeError:
return False

if self.time_zone_agnostic:
return (
isinstance(pandera_dtype.type, pl.Datetime)
and pandera_dtype.type.time_unit == self.type.time_unit
)

return self.type == pandera_dtype.type and super().check(pandera_dtype)


@Engine.register_dtype(
equivalents=[
Expand Down
36 changes: 36 additions & 0 deletions tests/polars/test_polars_container.py
Expand Up @@ -11,9 +11,14 @@
import polars as pl

import pytest
from hypothesis import given
from hypothesis import strategies as st
from polars.testing.parametric import dataframes, column

import pandera as pa
from pandera import Check as C
from pandera.api.polars.types import PolarsData
from pandera.engines import polars_engine as pe
from pandera.polars import Column, DataFrameSchema, DataFrameModel


Expand Down Expand Up @@ -528,3 +533,34 @@ class Config:
lf_with_nested_types, lazy=True
)
assert validated_lf.collect().equals(validated_lf.collect())


@pytest.mark.parametrize(
"time_zone",
[
None,
"UTC",
"GMT",
"EST",
],
)
@given(st.data())
def test_dataframe_schema_with_tz_agnostic_dates(time_zone, data):
strategy = dataframes(
column("datetime_col", dtype=pl.Datetime()),
lazy=True,
size=10,
)
lf = data.draw(strategy)
lf = lf.cast({"datetime_col": pl.Datetime(time_zone=time_zone)})
schema_tz_agnostic = DataFrameSchema(
{"datetime_col": Column(pe.DateTime(time_zone_agnostic=True))}
)
schema_tz_agnostic.validate(lf)

schema_tz_sensitive = DataFrameSchema(
{"datetime_col": Column(pe.DateTime(time_zone_agnostic=False))}
)
if time_zone:
with pytest.raises(pa.errors.SchemaError):
schema_tz_sensitive.validate(lf)
40 changes: 40 additions & 0 deletions tests/polars/test_polars_dtypes.py
@@ -1,4 +1,6 @@
"""Polars dtype tests."""

import datetime
import decimal
from decimal import Decimal
from typing import Union, Tuple, Sequence
Expand Down Expand Up @@ -403,3 +405,41 @@ def test_polars_nested_dtypes_try_coercion(
pe.Engine.dtype(noncoercible_dtype).try_coerce(PolarsData(data))
except pandera.errors.ParserError as exc:
assert exc.failure_cases.equals(data.collect())


@pytest.mark.parametrize(
"dtype",
[
"datetime",
datetime.datetime,
pl.Datetime,
pl.Datetime(),
pl.Datetime(time_unit="ns"),
pl.Datetime(time_unit="us"),
pl.Datetime(time_unit="ms"),
pl.Datetime(time_zone="UTC"),
],
)
def test_datetime_time_zone_agnostic(dtype):

tz_agnostic = pe.DateTime(time_zone_agnostic=True)
dtype = pe.Engine.dtype(dtype)

if tz_agnostic.type.time_unit == getattr(dtype.type, "time_unit", "us"):
# timezone agnostic pandera dtype should pass regardless of timezone
assert tz_agnostic.check(dtype)
else:
# but fail if the time units don't match
assert not tz_agnostic.check(dtype)

tz_sensitive = pe.DateTime()
if getattr(dtype.type, "time_zone", None) is not None:
assert not tz_sensitive.check(dtype)

tz_sensitive_utc = pe.DateTime(time_zone="UTC")
if getattr(
dtype.type, "time_zone", None
) is None and tz_sensitive_utc.type.time_zone != getattr(
dtype.type, "time_zone", None
):
assert not tz_sensitive_utc.check(dtype)

0 comments on commit c1e7c06

Please sign in to comment.