Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add support for nan-like null strings in string replace #26355

Merged
merged 4 commits into from Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
95 changes: 67 additions & 28 deletions numpy/_core/src/umath/stringdtype_ufuncs.cpp
Expand Up @@ -1300,7 +1300,9 @@ string_replace_strided_loop(

PyArray_StringDTypeObject *descr0 =
(PyArray_StringDTypeObject *)context->descriptors[0];
int has_null = descr0->na_object != NULL;
int has_string_na = descr0->has_string_na;
int has_nan_na = descr0->has_nan_na;
const npy_static_string *default_string = &descr0->default_string;


Expand Down Expand Up @@ -1330,11 +1332,29 @@ string_replace_strided_loop(
goto fail;
}
else if (i1_isnull || i2_isnull || i3_isnull) {
if (!has_string_na) {
npy_gil_error(PyExc_ValueError,
"Null values are not supported as replacement arguments "
"for replace");
goto fail;
if (has_null && !has_string_na) {
if (i2_isnull || i3_isnull) {
npy_gil_error(PyExc_ValueError,
"Null values are not supported as search "
"patterns or replacement strings for "
"replace");
goto fail;
}
else if (i1_isnull) {
if (has_nan_na) {
if (NpyString_pack_null(oallocator, ops) < 0) {
npy_gil_error(PyExc_MemoryError,
"Failed to deallocate string in replace");
goto fail;
}
goto next_step;
}
else {
npy_gil_error(PyExc_ValueError,
"Only string or NaN-like null strings can "
"be used as search strings for replace");
}
}
}
else {
if (i1_isnull) {
Expand All @@ -1349,32 +1369,51 @@ string_replace_strided_loop(
}
}

// conservatively overallocate
// TODO check overflow
size_t max_size;
if (i2s.size == 0) {
// interleaving
max_size = i1s.size + (i1s.size + 1)*(i3s.size);
}
else {
// replace i2 with i3
max_size = i1s.size * (i3s.size/i2s.size + 1);
}
char *new_buf = (char *)PyMem_RawCalloc(max_size, 1);
Buffer<ENCODING::UTF8> buf1((char *)i1s.buf, i1s.size);
Buffer<ENCODING::UTF8> buf2((char *)i2s.buf, i2s.size);
Buffer<ENCODING::UTF8> buf3((char *)i3s.buf, i3s.size);
Buffer<ENCODING::UTF8> outbuf(new_buf, max_size);
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the new indentation? It already is in the loop.

(And it makes reviewing harder...)

Copy link
Member Author

@ngoldbaum ngoldbaum Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because of the new use of `goto next_step`: I need to either introduce a new lexical scope or declare a bunch of variables at the top of the for loop that are only used at the bottom of it; otherwise the compiler complains about jumping over variable declarations.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably have gone for the top of the for loop myself, but no big deal...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I don't hate the `while (N--)` loop in general, I do think using a `goto` for loop control flow isn't nice, and I would much prefer a long `for` statement instead.
But this file already uses this pattern in a few other places, so it doesn't matter here.

Buffer<ENCODING::UTF8> buf1((char *)i1s.buf, i1s.size);
Buffer<ENCODING::UTF8> buf2((char *)i2s.buf, i2s.size);

size_t new_buf_size = string_replace(
buf1, buf2, buf3, *(npy_int64 *)in4, outbuf);
npy_int64 in_count = *(npy_int64*)in4;
if (in_count == -1) {
in_count = NPY_MAX_INT64;
}

if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in replace");
goto fail;
}
npy_int64 found_count = string_count<ENCODING::UTF8>(
buf1, buf2, 0, NPY_MAX_INT64);
if (found_count < 0) {
goto fail;
}

PyMem_RawFree(new_buf);
npy_intp count = Py_MIN(in_count, found_count);

Buffer<ENCODING::UTF8> buf3((char *)i3s.buf, i3s.size);

// conservatively overallocate
// TODO check overflow
size_t max_size;
if (i2s.size == 0) {
// interleaving
max_size = i1s.size + (i1s.size + 1)*(i3s.size);
}
else {
// replace i2 with i3
size_t change = i2s.size >= i3s.size ? 0 : i3s.size - i2s.size;
max_size = i1s.size + count * change;
}
char *new_buf = (char *)PyMem_RawCalloc(max_size, 1);
Buffer<ENCODING::UTF8> outbuf(new_buf, max_size);

size_t new_buf_size = string_replace(
buf1, buf2, buf3, count, outbuf);

if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in replace");
goto fail;
}

PyMem_RawFree(new_buf);
}
next_step:

in1 += strides[0];
in2 += strides[1];
Expand Down
8 changes: 4 additions & 4 deletions numpy/_core/strings.py
Expand Up @@ -1153,15 +1153,15 @@ def replace(a, old, new, count=-1):
a_dt = arr.dtype
old = np.asanyarray(old, dtype=getattr(old, 'dtype', a_dt))
new = np.asanyarray(new, dtype=getattr(new, 'dtype', a_dt))
count = np.asanyarray(count)

if arr.dtype.char == "T":
return _replace(arr, old, new, count)

max_int64 = np.iinfo(np.int64).max
counts = _count_ufunc(arr, old, 0, max_int64)
count = np.asanyarray(count)
counts = np.where(count < 0, counts, np.minimum(counts, count))

if arr.dtype.char == "T":
return _replace(arr, old, new, counts)

buffersizes = str_len(arr) + counts * (str_len(new) - str_len(old))
out_dtype = f"{arr.dtype.char}{buffersizes.max()}"
out = np.empty_like(arr, shape=buffersizes.shape, dtype=out_dtype)
Expand Down
2 changes: 1 addition & 1 deletion numpy/_core/tests/test_stringdtype.py
Expand Up @@ -1218,6 +1218,7 @@ def test_unary(string_array, unicode_array, function_name):
"strip",
"lstrip",
"rstrip",
"replace"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seberg this change makes sure the error paths are tested.

"zfill",
]

Expand All @@ -1230,7 +1231,6 @@ def test_unary(string_array, unicode_array, function_name):
"count",
"find",
"rfind",
"replace",
]

SUPPORTS_NULLS = (
Expand Down