Skip to content

Commit

Permalink
MNT: add locking for PyArrayIdentityHash
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoldbaum committed Apr 30, 2024
1 parent 6261524 commit 86e39a0
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
44 changes: 41 additions & 3 deletions numpy/_core/src/common/npy_hashtable.c
Expand Up @@ -29,6 +29,33 @@
#define _NpyHASH_XXROTATE(x) ((x << 13) | (x >> 19)) /* Rotate left 13 bits */
#endif

#ifdef Py_GIL_DISABLED
// TODO: replace with PyMutex when it is public
#define LOCK_TABLE(tb) \
if (!PyThread_acquire_lock(tb->mutex, NOWAIT_LOCK)) { \
PyThread_acquire_lock(tb->mutex, WAIT_LOCK); \
}
#define UNLOCK_TABLE(tb) PyThread_release_lock(tb->mutex);
#define INITIALIZE_LOCK(tb) \
tb->mutex = PyThread_allocate_lock(); \
if (tb->mutex == NULL) { \
PyErr_NoMemory(); \
PyMem_Free(res); \
return NULL; \
}
#define FREE_LOCK(tb) \
if (tb->mutex != NULL) { \
PyThread_free_lock(tb->mutex); \
}
#else
// the GIL serializes access to the table so no need
// for locking if it is enabled
#define LOCK_TABLE(tb)
#define UNLOCK_TABLE(tb)
#define INITIALIZE_LOCK(tb)
#define FREE_LOCK(tb)
#endif

/*
* This hashing function is basically the Python tuple hash with the type
* identity hash inlined. The tuple hash itself is a reduced version of xxHash.
Expand Down Expand Up @@ -100,6 +127,8 @@ PyArrayIdentityHash_New(int key_len)
res->size = 4; /* Start with a size of 4 */
res->nelem = 0;

INITIALIZE_LOCK(res);

res->buckets = PyMem_Calloc(4 * (key_len + 1), sizeof(PyObject *));
if (res->buckets == NULL) {
PyErr_NoMemory();
Expand All @@ -114,6 +143,7 @@ NPY_NO_EXPORT void
PyArrayIdentityHash_Dealloc(PyArrayIdentityHash *tb)
{
PyMem_Free(tb->buckets);
FREE_LOCK(tb);
PyMem_Free(tb);
}

Expand Down Expand Up @@ -160,8 +190,9 @@ _resize_if_necessary(PyArrayIdentityHash *tb)
for (npy_intp i = 0; i < prev_size; i++) {
PyObject **item = &old_table[i * (tb->key_len + 1)];
if (item[0] != NULL) {
tb->nelem -= 1; /* Decrement, setitem will increment again */
PyArrayIdentityHash_SetItem(tb, item+1, item[0], 1);
PyObject **tb_item = find_item(tb, item + 1);
tb_item[0] = item[0];
memcpy(tb_item+1, item+1, tb->key_len * sizeof(PyObject *));
}
}
PyMem_Free(old_table);
Expand All @@ -188,14 +219,17 @@ NPY_NO_EXPORT int
PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
PyObject *const *key, PyObject *value, int replace)
{
LOCK_TABLE(tb);
if (value != NULL && _resize_if_necessary(tb) < 0) {
/* Shrink, only if a new value is added. */
UNLOCK_TABLE(tb);
return -1;
}

PyObject **tb_item = find_item(tb, key);
if (value != NULL) {
if (tb_item[0] != NULL && !replace) {
UNLOCK_TABLE(tb);
PyErr_SetString(PyExc_RuntimeError,
"Identity cache already includes the item.");
return -1;
Expand All @@ -209,12 +243,16 @@ PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb,
memset(tb_item, 0, (tb->key_len + 1) * sizeof(PyObject *));
}

UNLOCK_TABLE(tb);
return 0;
}


NPY_NO_EXPORT PyObject *
PyArrayIdentityHash_GetItem(PyArrayIdentityHash const *tb, PyObject *const *key)
{
return find_item(tb, key)[0];
LOCK_TABLE(tb);
PyObject *res = find_item(tb, key)[0];
UNLOCK_TABLE(tb);
return res;
}
3 changes: 3 additions & 0 deletions numpy/_core/src/common/npy_hashtable.h
Expand Up @@ -13,6 +13,9 @@ typedef struct {
PyObject **buckets;
npy_intp size; /* current size */
npy_intp nelem; /* number of elements */
#ifdef Py_GIL_DISABLED
PyThread_type_lock *mutex;
#endif
} PyArrayIdentityHash;


Expand Down

0 comments on commit 86e39a0

Please sign in to comment.