Skip to content

Commit

Permalink
Merge pull request #3623 from Cheukting/pd_types
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Apr 26, 2023
2 parents 154577c + 9a939ff commit 6ed5054
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 20 deletions.
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
@@ -0,0 +1,5 @@
RELEASE_TYPE: minor

This release adds support for `nullable pandas dtypes <https://pandas.pydata.org/docs/user_guide/integer_na.html>`__
in :mod:`~hypothesis.extra.pandas` (:issue:`3604`).
Thanks to Cheuk Ting Ho for implementing this at the PyCon sprints!
58 changes: 45 additions & 13 deletions hypothesis-python/src/hypothesis/extra/pandas/impl.py
Expand Up @@ -44,6 +44,12 @@ def is_categorical_dtype(dt):
return dt == "category"


# Base class of the pandas nullable integer dtypes.  When this pandas build
# does not provide it, fall back to an empty tuple: the value is falsy (so it
# can drive ``skipif`` conditions) and ``isinstance(x, ())`` /
# ``issubclass(x, ())`` are always False, so nullable-dtype branches are
# simply never taken.
try:
    from pandas.core.arrays.integer import IntegerDtype
except ImportError:
    IntegerDtype = ()


def dtype_for_elements_strategy(s):
return st.shared(
s.map(lambda x: pandas.Series([x]).dtype),
Expand Down Expand Up @@ -79,6 +85,12 @@ def elements_and_dtype(elements, dtype, source=None):
f"{prefix}dtype is categorical, which is currently unsupported"
)

# Reject a nullable-integer dtype *class* (as opposed to an instance of it) up
# front, with a message telling the user to instantiate it.
if isinstance(dtype, type) and issubclass(dtype, IntegerDtype):
    raise InvalidArgument(
        # The trailing space matters: adjacent literals are concatenated
        # verbatim, and without it the message reads "class.Otherwise".
        f"Passed dtype={dtype!r} is a dtype class, please pass in an instance of this class. "
        "Otherwise it would be treated as dtype=object"
    )

if isinstance(dtype, type) and np.dtype(dtype).kind == "O" and dtype is not object:
note_deprecation(
f"Passed dtype={dtype!r} is not a valid Pandas dtype. We'll treat it as "
Expand All @@ -92,25 +104,36 @@ def elements_and_dtype(elements, dtype, source=None):
f"Passed dtype={dtype!r} is a strategy, but we require a concrete dtype "
"here. See https://stackoverflow.com/q/74355937 for workaround patterns."
)
dtype = try_convert(np.dtype, dtype, "dtype")

_get_subclasses = getattr(IntegerDtype, "__subclasses__", list)
dtype = {t.name: t() for t in _get_subclasses()}.get(dtype, dtype)

if isinstance(dtype, IntegerDtype):
is_na_dtype = True
dtype = np.dtype(dtype.name.lower())
elif dtype is not None:
is_na_dtype = False
dtype = try_convert(np.dtype, dtype, "dtype")
else:
is_na_dtype = False

if elements is None:
elements = npst.from_dtype(dtype)
if is_na_dtype:
elements = st.none() | elements
elif dtype is not None:

def convert_element(value):
if is_na_dtype and value is None:
return None
name = f"draw({prefix}elements)"
try:
return np.array([value], dtype=dtype)[0]
except TypeError:
except (TypeError, ValueError):
raise InvalidArgument(
"Cannot convert %s=%r of type %s to dtype %s"
% (name, value, type(value).__name__, dtype.str)
) from None
except ValueError:
raise InvalidArgument(
f"Cannot convert {name}={value!r} to type {dtype.str}"
) from None

elements = elements.map(convert_element)
assert elements is not None
Expand Down Expand Up @@ -282,9 +305,17 @@ def series(
else:
check_strategy(index, "index")

elements, dtype = elements_and_dtype(elements, dtype)
elements, np_dtype = elements_and_dtype(elements, dtype)
index_strategy = index

# if it is converted to an object, use object for series type
if (
np_dtype is not None
and np_dtype.kind == "O"
and not isinstance(dtype, IntegerDtype)
):
dtype = np_dtype

@st.composite
def result(draw):
index = draw(index_strategy)
Expand All @@ -293,13 +324,13 @@ def result(draw):
if dtype is not None:
result_data = draw(
npst.arrays(
dtype=dtype,
dtype=object,
elements=elements,
shape=len(index),
fill=fill,
unique=unique,
)
)
).tolist()
else:
result_data = list(
draw(
Expand All @@ -310,9 +341,8 @@ def result(draw):
fill=fill,
unique=unique,
)
)
).tolist()
)

return pandas.Series(result_data, index=index, dtype=dtype, name=draw(name))
else:
return pandas.Series(
Expand Down Expand Up @@ -549,7 +579,7 @@ def row():

column_names.add(c.name)

c.elements, c.dtype = elements_and_dtype(c.elements, c.dtype, label)
c.elements, _ = elements_and_dtype(c.elements, c.dtype, label)

if c.dtype is None and rows is not None:
raise InvalidArgument(
Expand Down Expand Up @@ -589,7 +619,9 @@ def just_draw_columns(draw):
if columns_without_fill:
for c in columns_without_fill:
data[c.name] = pandas.Series(
np.zeros(shape=len(index), dtype=c.dtype), index=index
np.zeros(shape=len(index), dtype=object),
index=index,
dtype=c.dtype,
)
seen = {c.name: set() for c in columns_without_fill if c.unique}

Expand Down
2 changes: 1 addition & 1 deletion hypothesis-python/tests/common/arguments.py
Expand Up @@ -19,7 +19,7 @@ def e(a, *args, **kwargs):


def e_to_str(elt):
    """Render an argument-validation case as ``func(arg, ..., key=value)``.

    ``elt`` is either a plain ``(f, args, kwargs)`` triple, or an object that
    carries such a triple in its ``values`` attribute (for example a case
    wrapped in ``pytest.param(...)``).  Positional arguments are shown first
    by ``repr``, followed by the ``key=value`` keyword entries in sorted order.
    """
    # Unwrap via ``values`` when present; plain triples are used as-is.
    f, args, kwargs = getattr(elt, "values", elt)
    bits = [repr(arg) for arg in args]
    bits.extend(sorted(f"{k}={v!r}" for k, v in kwargs.items()))
    return f"{f.__name__}({', '.join(bits)})"
Expand Down
33 changes: 30 additions & 3 deletions hypothesis-python/tests/pandas/test_argument_validation.py
Expand Up @@ -11,11 +11,15 @@
from datetime import datetime

import pandas as pd
import pytest

from hypothesis import given, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.extra import pandas as pdst
from hypothesis.extra.pandas.impl import IntegerDtype

from tests.common.arguments import argument_validation_test, e
from tests.common.debug import find_any
from tests.common.utils import checks_deprecated_behaviour

BAD_ARGS = [
Expand All @@ -30,7 +34,11 @@
e(pdst.data_frames, pdst.columns(1, dtype=float, elements=1)),
e(pdst.data_frames, pdst.columns(1, fill=1, dtype=float)),
e(pdst.data_frames, pdst.columns(["A", "A"], dtype=float)),
e(pdst.data_frames, pdst.columns(1, elements=st.none(), dtype=int)),
pytest.param(
*e(pdst.data_frames, pdst.columns(1, elements=st.none(), dtype=int)),
marks=pytest.mark.skipif(IntegerDtype, reason="works with integer NA"),
),
e(pdst.data_frames, pdst.columns(1, elements=st.text(), dtype=int)),
e(pdst.data_frames, 1),
e(pdst.data_frames, [1]),
e(pdst.data_frames, pdst.columns(1, dtype="category")),
Expand Down Expand Up @@ -64,7 +72,11 @@
e(pdst.indexes, dtype="not a dtype"),
e(pdst.indexes, elements="not a strategy"),
e(pdst.indexes, elements=st.text(), dtype=float),
e(pdst.indexes, elements=st.none(), dtype=int),
pytest.param(
*e(pdst.indexes, elements=st.none(), dtype=int),
marks=pytest.mark.skipif(IntegerDtype, reason="works with integer NA"),
),
e(pdst.indexes, elements=st.text(), dtype=int),
e(pdst.indexes, elements=st.integers(0, 10), dtype=st.sampled_from([int, float])),
e(pdst.indexes, dtype=int, max_size=0, min_size=1),
e(pdst.indexes, dtype=int, unique="true"),
Expand All @@ -77,7 +89,11 @@
e(pdst.series),
e(pdst.series, dtype="not a dtype"),
e(pdst.series, elements="not a strategy"),
e(pdst.series, elements=st.none(), dtype=int),
pytest.param(
*e(pdst.series, elements=st.none(), dtype=int),
marks=pytest.mark.skipif(IntegerDtype, reason="works with integer NA"),
),
e(pdst.series, elements=st.text(), dtype=int),
e(pdst.series, dtype="category"),
e(pdst.series, index="not a strategy"),
]
Expand All @@ -99,3 +115,14 @@ def test_timestamp_as_datetime_bounds(dt):
@checks_deprecated_behaviour
def test_confusing_object_dtype_aliases():
pdst.series(elements=st.tuples(st.integers()), dtype=tuple).example()


@pytest.mark.skipif(
    not IntegerDtype, reason="Nullable types not available in this version of Pandas"
)
def test_pandas_nullable_types_class():
    """Passing a nullable dtype *class* instead of an instance is rejected."""
    with pytest.raises(
        InvalidArgument, match="Otherwise it would be treated as dtype=object"
    ):
        # Named ``strategy`` so the module-level ``st`` (hypothesis.strategies)
        # import is not shadowed inside this test.
        strategy = pdst.series(dtype=pd.core.arrays.integer.Int8Dtype)
        # NOTE(review): the draw is kept inside the ``raises`` block because
        # argument validation presumably only runs once the strategy is
        # actually used - confirm against ``pdst.series``.
        find_any(strategy, lambda s: s.isna().any())
12 changes: 12 additions & 0 deletions hypothesis-python/tests/pandas/test_data_frame.py
Expand Up @@ -9,10 +9,12 @@
# obtain one at https://mozilla.org/MPL/2.0/.

import numpy as np
import pandas as pd
import pytest

from hypothesis import HealthCheck, given, reject, settings, strategies as st
from hypothesis.extra import numpy as npst, pandas as pdst
from hypothesis.extra.pandas.impl import IntegerDtype

from tests.common.debug import find_any
from tests.pandas.helpers import supported_by_pandas
Expand Down Expand Up @@ -267,3 +269,13 @@ def works_with_object_dtype(df):
assert dtype is None
with pytest.raises(ValueError, match="Maybe passing dtype=object would help"):
works_with_object_dtype()


@pytest.mark.skipif(
    not IntegerDtype, reason="Nullable types not available in this version of Pandas"
)
def test_pandas_nullable_types():
    """Nullable-integer columns can hold NA and keep their Int8 dtype."""
    int8_dtype = pd.core.arrays.integer.Int8Dtype
    # Named ``frames`` so the module-level ``st`` (hypothesis.strategies)
    # import is not shadowed inside this test.
    frames = pdst.data_frames(pdst.columns(2, dtype=int8_dtype()))
    # Look for a frame in which at least one cell in any column is NA.
    df = find_any(frames, lambda frame: frame.isna().any().any())
    for column in df.columns:
        assert isinstance(df[column].dtype, int8_dtype)
28 changes: 25 additions & 3 deletions hypothesis-python/tests/pandas/test_series.py
Expand Up @@ -9,12 +9,14 @@
# obtain one at https://mozilla.org/MPL/2.0/.

import numpy as np
import pandas
import pandas as pd
import pytest

from hypothesis import assume, given, strategies as st
from hypothesis.extra import numpy as npst, pandas as pdst
from hypothesis.extra.pandas.impl import IntegerDtype

from tests.common.debug import find_any
from tests.common.debug import assert_all_examples, assert_no_examples, find_any
from tests.pandas.helpers import supported_by_pandas


Expand All @@ -25,7 +27,7 @@ def test_can_create_a_series_of_any_dtype(data):
# Use raw data to work around pandas bug in repr. See
# https://github.com/pandas-dev/pandas/issues/27484
series = data.conjecture_data.draw(pdst.series(dtype=dtype))
assert series.dtype == pandas.Series([], dtype=dtype).dtype
assert series.dtype == pd.Series([], dtype=dtype).dtype


@given(pdst.series(dtype=float, index=pdst.range_indexes(min_size=2, max_size=5)))
Expand Down Expand Up @@ -61,3 +63,23 @@ def test_unique_series_are_unique(s):
@given(pdst.series(dtype="int8", name=st.just("test_name")))
def test_name_passed_on(s):
assert s.name == "test_name"


@pytest.mark.skipif(
    not IntegerDtype, reason="Nullable types not available in this version of Pandas"
)
@pytest.mark.parametrize(
    # The dtype is exercised both by name and as a dtype instance; the
    # instance is only built when nullable dtypes are importable.
    "dtype", ["Int8", pd.core.arrays.integer.Int8Dtype() if IntegerDtype else None]
)
def test_pandas_nullable_types(dtype):
    """Check NA handling of ``pdst.series`` for the nullable Int8 dtype."""
    # Explicit non-null elements must never yield NA values...
    assert_no_examples(
        pdst.series(dtype=dtype, elements=st.just(0)),
        lambda s: s.isna().any(),
    )
    # ...while explicit ``None`` elements must yield only NA values.
    assert_all_examples(
        pdst.series(dtype=dtype, elements=st.none()),
        lambda s: s.isna().all(),
    )
    # With inferred elements, both NA-free and NA-containing series are reachable.
    find_any(pdst.series(dtype=dtype), lambda s: not s.isna().any())
    with_na = find_any(pdst.series(dtype=dtype), lambda s: s.isna().any())
    assert isinstance(with_na.dtype, pd.core.arrays.integer.Int8Dtype)

0 comments on commit 6ed5054

Please sign in to comment.