Skip to content

Commit

Permalink
Merge pull request #26392 from ngoldbaum/strip-null-support
Browse files Browse the repository at this point in the history
BUG: support nan-like null strings in [l,r]strip
  • Loading branch information
ngoldbaum committed May 10, 2024
2 parents 2a9b913 + e438a86 commit e86c581
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 30 deletions.
86 changes: 60 additions & 26 deletions numpy/_core/src/umath/stringdtype_ufuncs.cpp
Expand Up @@ -1046,6 +1046,7 @@ string_lrstrip_chars_strided_loop(
PyArray_StringDTypeObject *s1descr = (PyArray_StringDTypeObject *)context->descriptors[0];
int has_null = s1descr->na_object != NULL;
int has_string_na = s1descr->has_string_na;
int has_nan_na = s1descr->has_nan_na;

const npy_static_string *default_string = &s1descr->default_string;
npy_intp N = dimensions[0];
Expand All @@ -1072,28 +1073,47 @@ string_lrstrip_chars_strided_loop(
s2 = *default_string;
}
}
else if (has_nan_na) {
if (s2_isnull) {
npy_gil_error(PyExc_ValueError,
"Cannot use a null string that is not a "
"string as the %s delimiter", ufunc_name);
}
if (s1_isnull) {
if (NpyString_pack_null(oallocator, ops) < 0) {
npy_gil_error(PyExc_MemoryError,
"Failed to deallocate string in %s",
ufunc_name);
goto fail;
}
goto next_step;
}
}
else {
npy_gil_error(PyExc_ValueError,
"Cannot strip null values that are not strings");
"Can only strip null values that are strings "
"or NaN-like values");
goto fail;
}
}
{
char *new_buf = (char *)PyMem_RawCalloc(s1.size, 1);
Buffer<ENCODING::UTF8> buf1((char *)s1.buf, s1.size);
Buffer<ENCODING::UTF8> buf2((char *)s2.buf, s2.size);
Buffer<ENCODING::UTF8> outbuf(new_buf, s1.size);
size_t new_buf_size = string_lrstrip_chars
(buf1, buf2, outbuf, striptype);

if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
ufunc_name);
PyMem_RawFree(new_buf);
goto fail;
}

char *new_buf = (char *)PyMem_RawCalloc(s1.size, 1);
Buffer<ENCODING::UTF8> buf1((char *)s1.buf, s1.size);
Buffer<ENCODING::UTF8> buf2((char *)s2.buf, s2.size);
Buffer<ENCODING::UTF8> outbuf(new_buf, s1.size);
size_t new_buf_size = string_lrstrip_chars
(buf1, buf2, outbuf, striptype);

if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
ufunc_name);
goto fail;
PyMem_RawFree(new_buf);
}

PyMem_RawFree(new_buf);
next_step:

in1 += strides[0];
in2 += strides[1];
Expand Down Expand Up @@ -1150,8 +1170,9 @@ string_lrstrip_whitespace_strided_loop(
const char *ufunc_name = ((PyUFuncObject *)context->caller)->name;
STRIPTYPE striptype = *(STRIPTYPE *)context->method->static_data;
PyArray_StringDTypeObject *descr = (PyArray_StringDTypeObject *)context->descriptors[0];
int has_string_na = descr->has_string_na;
int has_null = descr->na_object != NULL;
int has_string_na = descr->has_string_na;
int has_nan_na = descr->has_nan_na;
const npy_static_string *default_string = &descr->default_string;

npy_string_allocator *allocators[2] = {};
Expand Down Expand Up @@ -1181,26 +1202,39 @@ string_lrstrip_whitespace_strided_loop(
if (has_string_na || !has_null) {
s = *default_string;
}
else if (has_nan_na) {
if (NpyString_pack_null(oallocator, ops) < 0) {
npy_gil_error(PyExc_MemoryError,
"Failed to deallocate string in %s",
ufunc_name);
goto fail;
}
goto next_step;
}
else {
npy_gil_error(PyExc_ValueError,
"Cannot strip null values that are not strings");
"Can only strip null values that are strings or "
"NaN-like values");
goto fail;
}
}
{
char *new_buf = (char *)PyMem_RawCalloc(s.size, 1);
Buffer<ENCODING::UTF8> buf((char *)s.buf, s.size);
Buffer<ENCODING::UTF8> outbuf(new_buf, s.size);
size_t new_buf_size = string_lrstrip_whitespace(
buf, outbuf, striptype);

char *new_buf = (char *)PyMem_RawCalloc(s.size, 1);
Buffer<ENCODING::UTF8> buf((char *)s.buf, s.size);
Buffer<ENCODING::UTF8> outbuf(new_buf, s.size);
size_t new_buf_size = string_lrstrip_whitespace(
buf, outbuf, striptype);
if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
ufunc_name);
goto fail;
}

if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
ufunc_name);
goto fail;
PyMem_RawFree(new_buf);
}

PyMem_RawFree(new_buf);
next_step:

in += strides[0];
out += strides[1];
Expand Down
38 changes: 34 additions & 4 deletions numpy/_core/tests/test_stringdtype.py
Expand Up @@ -1080,7 +1080,13 @@ def unicode_array():
"capitalize",
"expandtabs",
"lower",
"splitlines" "swapcase" "title" "upper",
"lstrip",
"rstrip",
"splitlines",
"strip",
"swapcase",
"title",
"upper",
]

BOOL_OUTPUT_FUNCTIONS = [
Expand All @@ -1107,7 +1113,10 @@ def unicode_array():
"istitle",
"isupper",
"lower",
"lstrip",
"rstrip",
"splitlines",
"strip",
"swapcase",
"title",
"upper",
Expand All @@ -1129,10 +1138,20 @@ def unicode_array():
"upper",
]

ONLY_IN_NP_CHAR = [
"join",
"split",
"rsplit",
"splitlines"
]


@pytest.mark.parametrize("function_name", UNARY_FUNCTIONS)
def test_unary(string_array, unicode_array, function_name):
func = getattr(np.char, function_name)
if function_name in ONLY_IN_NP_CHAR:
func = getattr(np.char, function_name)
else:
func = getattr(np.strings, function_name)
dtype = string_array.dtype
sres = func(string_array)
ures = func(unicode_array)
Expand Down Expand Up @@ -1173,6 +1192,10 @@ def test_unary(string_array, unicode_array, function_name):
with pytest.raises(ValueError):
func(na_arr)
return
if not (is_nan or is_str):
with pytest.raises(ValueError):
func(na_arr)
return
res = func(na_arr)
if is_nan and function_name in NAN_PRESERVING_FUNCTIONS:
assert res[0] is dtype.na_object
Expand All @@ -1197,13 +1220,17 @@ def test_unary(string_array, unicode_array, function_name):
("index", (None, "e")),
("join", ("-", None)),
("ljust", (None, 12)),
("lstrip", (None, "A")),
("partition", (None, "A")),
("replace", (None, "A", "B")),
("rfind", (None, "A")),
("rindex", (None, "e")),
("rjust", (None, 12)),
("rsplit", (None, "A")),
("rstrip", (None, "A")),
("rpartition", (None, "A")),
("split", (None, "A")),
("strip", (None, "A")),
("startswith", (None, "A")),
("zfill", (None, 12)),
]
Expand Down Expand Up @@ -1260,10 +1287,13 @@ def call_func(func, args, array, sanitize=True):

@pytest.mark.parametrize("function_name, args", BINARY_FUNCTIONS)
def test_binary(string_array, unicode_array, function_name, args):
func = getattr(np.char, function_name)
if function_name in ONLY_IN_NP_CHAR:
func = getattr(np.char, function_name)
else:
func = getattr(np.strings, function_name)
sres = call_func(func, args, string_array)
ures = call_func(func, args, unicode_array, sanitize=False)
if sres.dtype == StringDType():
if not isinstance(sres, tuple) and sres.dtype == StringDType():
ures = ures.astype(StringDType())
assert_array_equal(sres, ures)

Expand Down

0 comments on commit e86c581

Please sign in to comment.