NOGIL: Make loop data cache and dispatch cache thread-safe in nogil build #26348

Merged
merged 3 commits on Apr 30, 2024
44 changes: 41 additions & 3 deletions numpy/_core/src/common/npy_hashtable.c
@@ -29,6 +29,33 @@
#define _NpyHASH_XXROTATE(x) ((x << 13) | (x >> 19)) /* Rotate left 13 bits */
#endif

#ifdef Py_GIL_DISABLED
// TODO: replace with PyMutex when it is public
#define LOCK_TABLE(tb)                                      \
    if (!PyThread_acquire_lock(tb->mutex, NOWAIT_LOCK)) {   \
        PyThread_acquire_lock(tb->mutex, WAIT_LOCK);        \
    }
#define UNLOCK_TABLE(tb) PyThread_release_lock(tb->mutex);
#define INITIALIZE_LOCK(tb)                                 \
    tb->mutex = PyThread_allocate_lock();                   \
    if (tb->mutex == NULL) {                                \
        PyErr_NoMemory();                                   \
        PyMem_Free(res);                                    \
        return NULL;                                        \
    }
#define FREE_LOCK(tb)                                       \
    if (tb->mutex != NULL) {                                \
        PyThread_free_lock(tb->mutex);                      \
    }
Member Author commented:
These last two don't need to be macros, but writing it this way lets me avoid adding more Py_GIL_DISABLED checks below.

#else
// the GIL serializes access to the table so no need
// for locking if it is enabled
#define LOCK_TABLE(tb)
#define UNLOCK_TABLE(tb)
#define INITIALIZE_LOCK(tb)
#define FREE_LOCK(tb)
#endif
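
As the TODO above notes, the PyThread lock is a stopgap until PyMutex is public. A minimal sketch of what the nogil branch could look like once the PyMutex API (public in CPython 3.13) is available; this is a hypothetical follow-up, not part of this PR:

/* Hypothetical PyMutex variant (CPython 3.13+), not part of this PR.
 * A PyMutex is zero-initialized and needs no allocate/free step, so the
 * struct field would become `PyMutex mutex;` and INITIALIZE_LOCK/FREE_LOCK
 * would reduce to no-ops. */
#define LOCK_TABLE(tb)   PyMutex_Lock(&(tb)->mutex)
#define UNLOCK_TABLE(tb) PyMutex_Unlock(&(tb)->mutex)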

/*
* This hashing function is basically the Python tuple hash with the type
* identity hash inlined. The tuple hash itself is a reduced version of xxHash.
@@ -100,6 +127,8 @@ PyArrayIdentityHash_New(int key_len)
    res->size = 4;  /* Start with a size of 4 */
    res->nelem = 0;

    INITIALIZE_LOCK(res);

    res->buckets = PyMem_Calloc(4 * (key_len + 1), sizeof(PyObject *));
    if (res->buckets == NULL) {
        PyErr_NoMemory();
@@ -114,6 +143,7 @@ NPY_NO_EXPORT void
PyArrayIdentityHash_Dealloc(PyArrayIdentityHash *tb)
{
    PyMem_Free(tb->buckets);
    FREE_LOCK(tb);
    PyMem_Free(tb);
}

@@ -160,8 +190,9 @@ _resize_if_necessary(PyArrayIdentityHash *tb)
    for (npy_intp i = 0; i < prev_size; i++) {
        PyObject **item = &old_table[i * (tb->key_len + 1)];
        if (item[0] != NULL) {
            tb->nelem -= 1;  /* Decrement, setitem will increment again */
            PyArrayIdentityHash_SetItem(tb, item+1, item[0], 1);
            PyObject **tb_item = find_item(tb, item + 1);
            tb_item[0] = item[0];
            memcpy(tb_item+1, item+1, tb->key_len * sizeof(PyObject *));
Member Author commented:
This avoids a recursive call to _resize_if_necessary, and with it acquiring the lock twice. It might also be a teeny bit faster.

        }
    }
    PyMem_Free(old_table);
@@ -188,14 +219,17 @@ NPY_NO_EXPORT int
PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
        PyObject *const *key, PyObject *value, int replace)
{
    LOCK_TABLE(tb);
    if (value != NULL && _resize_if_necessary(tb) < 0) {
        /* Shrink, only if a new value is added. */
        UNLOCK_TABLE(tb);
        return -1;
    }

    PyObject **tb_item = find_item(tb, key);
    if (value != NULL) {
        if (tb_item[0] != NULL && !replace) {
            UNLOCK_TABLE(tb);
            PyErr_SetString(PyExc_RuntimeError,
                    "Identity cache already includes the item.");
            return -1;
@@ -209,12 +243,16 @@ PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
        memset(tb_item, 0, (tb->key_len + 1) * sizeof(PyObject *));
    }

    UNLOCK_TABLE(tb);
    return 0;
}


NPY_NO_EXPORT PyObject *
PyArrayIdentityHash_GetItem(PyArrayIdentityHash const *tb, PyObject *const *key)
{
    return find_item(tb, key)[0];
    LOCK_TABLE(tb);
    PyObject *res = find_item(tb, key)[0];
    UNLOCK_TABLE(tb);
    return res;
}
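
For orientation, a minimal usage sketch of this table API as it stands after the change (illustrative only; op1, op2, and value are hypothetical PyObject pointers). Because SetItem and GetItem now lock internally, a sequence like this is safe to run concurrently from several threads in a free-threaded build:

/* Illustrative sketch, not from the PR: a cache keyed on two object
 * identities.  SetItem/GetItem lock the table under Py_GIL_DISABLED. */
PyArrayIdentityHash *tb = PyArrayIdentityHash_New(2);
if (tb == NULL) {
    return NULL;  /* error (e.g. PyErr_NoMemory) is already set */
}
PyObject *key[2] = {op1, op2};  /* hypothetical key objects */
if (PyArrayIdentityHash_SetItem(tb, key, value, /* replace */ 0) < 0) {
    PyArrayIdentityHash_Dealloc(tb);
    return NULL;
}
PyObject *cached = PyArrayIdentityHash_GetItem(tb, key);  /* NULL if missing */
PyArrayIdentityHash_Dealloc(tb);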
3 changes: 3 additions & 0 deletions numpy/_core/src/common/npy_hashtable.h
@@ -13,6 +13,9 @@ typedef struct {
    PyObject **buckets;
    npy_intp size;  /* current size */
    npy_intp nelem;  /* number of elements */
#ifdef Py_GIL_DISABLED
    PyThread_type_lock *mutex;
#endif
} PyArrayIdentityHash;


19 changes: 13 additions & 6 deletions numpy/_core/src/umath/legacy_array_method.c
@@ -33,37 +33,43 @@ typedef struct {


/* Use a free list, since we should normally only need one at a time */
#ifndef Py_GIL_DISABLED
#define NPY_LOOP_DATA_CACHE_SIZE 5
static int loop_data_num_cached = 0;
static legacy_array_method_auxdata *loop_data_cache[NPY_LOOP_DATA_CACHE_SIZE];

#else
#define NPY_LOOP_DATA_CACHE_SIZE 0
#endif

static void
legacy_array_method_auxdata_free(NpyAuxData *data)
{
#if NPY_LOOP_DATA_CACHE_SIZE > 0
    if (loop_data_num_cached < NPY_LOOP_DATA_CACHE_SIZE) {
        loop_data_cache[loop_data_num_cached] = (
                (legacy_array_method_auxdata *)data);
        loop_data_num_cached++;
    }
    else {
    else
#endif
    {
        PyMem_Free(data);
    }
}

#undef NPY_LOOP_DATA_CACHE_SIZE


NpyAuxData *
get_new_loop_data(
        PyUFuncGenericFunction loop, void *user_data, int pyerr_check)
{
    legacy_array_method_auxdata *data;
#if NPY_LOOP_DATA_CACHE_SIZE > 0
    if (NPY_LIKELY(loop_data_num_cached > 0)) {
        loop_data_num_cached--;
        data = loop_data_cache[loop_data_num_cached];
    }
    else {
    else
#endif
    {
        data = PyMem_Malloc(sizeof(legacy_array_method_auxdata));
        if (data == NULL) {
            return NULL;
@@ -77,6 +83,7 @@ get_new_loop_data(
    return (NpyAuxData *)data;
}

#undef NPY_LOOP_DATA_CACHE_SIZE

/*
* This is a thin wrapper around the legacy loop signature.
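The free list is compiled out under Py_GIL_DISABLED rather than locked: the unsynchronized check-then-act on loop_data_num_cached would be a data race once threads run concurrently, and this is a hot path. For contrast, a hedged sketch of what a locked variant might have looked like (hypothetical, not the PR's approach; loop_data_lock is an assumed lock allocated at module init):

/* Hypothetical locked push, NOT what the PR does; shown only to
 * illustrate the per-call locking overhead the PR avoids by
 * disabling the cache entirely in free-threaded builds. */
static PyThread_type_lock loop_data_lock;  /* assumed: allocated at init */

static void
legacy_array_method_auxdata_free_locked(NpyAuxData *data)
{
    PyThread_acquire_lock(loop_data_lock, WAIT_LOCK);
    if (loop_data_num_cached < NPY_LOOP_DATA_CACHE_SIZE) {
        loop_data_cache[loop_data_num_cached++] =
                (legacy_array_method_auxdata *)data;
        PyThread_release_lock(loop_data_lock);
        return;
    }
    PyThread_release_lock(loop_data_lock);
    PyMem_Free(data);
}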
27 changes: 22 additions & 5 deletions numpy/_core/tests/test_multithreading.py
@@ -9,13 +9,30 @@
pytest.skip(allow_module_level=True, reason="no threading support in wasm")


def test_parallel_errstate_creation():
def run_threaded(func, iters, pass_count=False):
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
        if pass_count:
            futures = [tpe.submit(func, i) for i in range(iters)]
        else:
            futures = [tpe.submit(func) for _ in range(iters)]
        for f in futures:
            f.result()


def test_parallel_randomstate_creation():
    # if the coercion cache is enabled and not thread-safe, creating
    # RandomState instances simultaneously leads to a data race
    def func(seed):
        np.random.RandomState(seed)

    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
        futures = [tpe.submit(func, i) for i in range(500)]
        for f in futures:
            f.result()
    run_threaded(func, 500, pass_count=True)

def test_parallel_ufunc_execution():
    # if the loop data cache or dispatch cache are not thread-safe
    # computing ufuncs simultaneously in multiple threads leads
    # to a data race
    def func():
        arr = np.random.random((25,))
        np.isnan(arr)

    run_threaded(func, 500)