Skip to content

Commit

Permalink
Move test
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanx749 committed Apr 30, 2024
1 parent f0c2097 commit 5031be7
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 18 deletions.
19 changes: 19 additions & 0 deletions pandas/conftest.py
Expand Up @@ -1367,6 +1367,25 @@ def any_string_dtype(request):
return request.param


@pytest.fixture(
    params=[
        "object",
        "string[python]",
        # The remaining params carry the td.skip_if_no("pyarrow") mark.
        pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
        pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
        # NOTE(review): `pa` is not imported in the visible part of this file
        # and the F821 (undefined name) lint is suppressed below; the argument
        # is evaluated when this decorator runs, so confirm `pa` is actually
        # defined at module level.
        pytest.param(pd.ArrowDtype(pa.string()), marks=td.skip_if_no("pyarrow")), # noqa: F821
    ]
)
def any_string_dtype_2(request):
    """
    Parametrized fixture for string dtypes.

    * 'object'
    * 'string[python]'
    * 'string[pyarrow]'
    * 'string[pyarrow_numpy]'
    * pd.ArrowDtype(pa.string())
    """
    return request.param


@pytest.fixture(params=tm.DATETIME64_DTYPES)
def datetime64_dtype(request):
"""
Expand Down
10 changes: 0 additions & 10 deletions pandas/tests/extension/test_arrow.py
Expand Up @@ -2296,16 +2296,6 @@ def test_str_split_pat_none(method):
tm.assert_series_equal(result, expected)


def test_str_split_regex_none():
    """
    Check ``Series.str.split`` with ``regex=None`` on an ArrowDtype string Series.

    The pattern ``/|-`` is longer than one character, and the expected result
    has the values split on both "/" and "-", i.e. the pattern is expected to
    act as a regular expression rather than as a literal string.
    """
    # GH 58321
    ser = pd.Series(["230/270/270", "240-290-290"], dtype=ArrowDtype(pa.string()))
    result = ser.str.split(r"/|-", regex=None)
    # The expected Series wraps a pyarrow array holding one list of strings
    # per input row.
    expected = pd.Series(
        ArrowExtensionArray(pa.array([["230", "270", "270"], ["240", "290", "290"]]))
    )
    tm.assert_series_equal(result, expected)


def test_str_split():
# GH 52401
ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))
Expand Down
32 changes: 24 additions & 8 deletions pandas/tests/strings/test_split_partition.py
Expand Up @@ -17,6 +17,10 @@
object_pyarrow_numpy,
)

pa = pytest.importorskip("pyarrow")

from pandas.core.arrays.arrow.array import ArrowExtensionArray


@pytest.mark.parametrize("method", ["split", "rsplit"])
def test_split(any_string_dtype, method):
Expand Down Expand Up @@ -59,27 +63,39 @@ def test_split_regex(any_string_dtype):
tm.assert_series_equal(result, exp)


def test_split_regex_explicit(any_string_dtype):
def test_split_regex_explicit(any_string_dtype_2):
# explicit regex = True split with compiled regex
regex_pat = re.compile(r".jpg")
values = Series("xxxjpgzzz.jpg", dtype=any_string_dtype)
result = values.str.split(regex_pat)
exp = Series([["xx", "zzz", ""]])
tm.assert_series_equal(result, exp)
values = Series("xxxjpgzzz.jpg", dtype=any_string_dtype_2)

if not isinstance(any_string_dtype_2, pd.ArrowDtype):
# ArrowDtype does not support compiled regex
result = values.str.split(regex_pat)
exp = Series([["xx", "zzz", ""]])
tm.assert_series_equal(result, exp)

# explicit regex = False split
result = values.str.split(r"\.jpg", regex=False)
exp = Series([["xxxjpgzzz.jpg"]])
if not isinstance(any_string_dtype_2, pd.ArrowDtype):
exp = Series([["xxxjpgzzz.jpg"]])
else:
exp = Series(ArrowExtensionArray(pa.array([["xxxjpgzzz.jpg"]])))
tm.assert_series_equal(result, exp)

# non explicit regex split, pattern length == 1
result = values.str.split(r".")
exp = Series([["xxxjpgzzz", "jpg"]])
if not isinstance(any_string_dtype_2, pd.ArrowDtype):
exp = Series([["xxxjpgzzz", "jpg"]])
else:
exp = Series(ArrowExtensionArray(pa.array([["xxxjpgzzz", "jpg"]])))
tm.assert_series_equal(result, exp)

# non explicit regex split, pattern length != 1
result = values.str.split(r".jpg")
exp = Series([["xx", "zzz", ""]])
if not isinstance(any_string_dtype_2, pd.ArrowDtype):
exp = Series([["xx", "zzz", ""]])
else:
exp = Series(ArrowExtensionArray(pa.array([["xx", "zzz", ""]])))
tm.assert_series_equal(result, exp)

# regex=False with pattern compiled regex raises error
Expand Down

0 comments on commit 5031be7

Please sign in to comment.