Skip to content

Commit

Permalink
pythongh-116664: Make module state Py_SETREF's in _warnings thread-sa…
Browse files Browse the repository at this point in the history
…fe (python#116959)

Mark the swap operations as critical sections.

Add an internal Py_BEGIN_CRITICAL_SECTION_MUT API that takes a PyMutex
pointer instead of a PyObject pointer.
  • Loading branch information
erlend-aasland authored and diegorusso committed Apr 17, 2024
1 parent 3ca45ff commit 8250cb6
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 23 deletions.
8 changes: 6 additions & 2 deletions Include/internal/pycore_critical_section.h
Expand Up @@ -87,10 +87,13 @@ extern "C" {
#define _Py_CRITICAL_SECTION_MASK 0x3

#ifdef Py_GIL_DISABLED
# define Py_BEGIN_CRITICAL_SECTION(op) \
# define Py_BEGIN_CRITICAL_SECTION_MUT(mutex) \
{ \
_PyCriticalSection _cs; \
_PyCriticalSection_Begin(&_cs, &_PyObject_CAST(op)->ob_mutex)
_PyCriticalSection_Begin(&_cs, mutex)

# define Py_BEGIN_CRITICAL_SECTION(op) \
Py_BEGIN_CRITICAL_SECTION_MUT(&_PyObject_CAST(op)->ob_mutex)

# define Py_END_CRITICAL_SECTION() \
_PyCriticalSection_End(&_cs); \
Expand Down Expand Up @@ -138,6 +141,7 @@ extern "C" {

#else /* !Py_GIL_DISABLED */
// The critical section APIs are no-ops with the GIL.
# define Py_BEGIN_CRITICAL_SECTION_MUT(mut)
# define Py_BEGIN_CRITICAL_SECTION(op)
# define Py_END_CRITICAL_SECTION()
# define Py_XBEGIN_CRITICAL_SECTION(op)
Expand Down
1 change: 1 addition & 0 deletions Include/internal/pycore_warnings.h
Expand Up @@ -14,6 +14,7 @@ struct _warnings_runtime_state {
PyObject *filters; /* List */
PyObject *once_registry; /* Dict */
PyObject *default_action; /* String */
struct _PyMutex mutex;
long filters_version;
};

Expand Down
58 changes: 37 additions & 21 deletions Python/_warnings.c
@@ -1,4 +1,5 @@
#include "Python.h"
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION_MUT()
#include "pycore_interp.h" // PyInterpreterState.warnings
#include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_pyerrors.h" // _PyErr_Occurred()
Expand Down Expand Up @@ -235,14 +236,12 @@ get_warnings_attr(PyInterpreterState *interp, PyObject *attr, int try_import)
static PyObject *
get_once_registry(PyInterpreterState *interp)
{
PyObject *registry;

WarningsState *st = warnings_get_state(interp);
if (st == NULL) {
return NULL;
}
assert(st != NULL);

_Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&st->mutex);

registry = GET_WARNINGS_ATTR(interp, onceregistry, 0);
PyObject *registry = GET_WARNINGS_ATTR(interp, onceregistry, 0);
if (registry == NULL) {
if (PyErr_Occurred())
return NULL;
Expand All @@ -265,14 +264,12 @@ get_once_registry(PyInterpreterState *interp)
static PyObject *
get_default_action(PyInterpreterState *interp)
{
PyObject *default_action;

WarningsState *st = warnings_get_state(interp);
if (st == NULL) {
return NULL;
}
assert(st != NULL);

default_action = GET_WARNINGS_ATTR(interp, defaultaction, 0);
_Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&st->mutex);

PyObject *default_action = GET_WARNINGS_ATTR(interp, defaultaction, 0);
if (default_action == NULL) {
if (PyErr_Occurred()) {
return NULL;
Expand All @@ -299,15 +296,12 @@ get_filter(PyInterpreterState *interp, PyObject *category,
PyObject *text, Py_ssize_t lineno,
PyObject *module, PyObject **item)
{
PyObject *action;
Py_ssize_t i;
PyObject *warnings_filters;
WarningsState *st = warnings_get_state(interp);
if (st == NULL) {
return NULL;
}
assert(st != NULL);

warnings_filters = GET_WARNINGS_ATTR(interp, filters, 0);
_Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&st->mutex);

PyObject *warnings_filters = GET_WARNINGS_ATTR(interp, filters, 0);
if (warnings_filters == NULL) {
if (PyErr_Occurred())
return NULL;
Expand All @@ -324,7 +318,7 @@ get_filter(PyInterpreterState *interp, PyObject *category,
}

/* WarningsState.filters could change while we are iterating over it. */
for (i = 0; i < PyList_GET_SIZE(filters); i++) {
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(filters); i++) {
PyObject *tmp_item, *action, *msg, *cat, *mod, *ln_obj;
Py_ssize_t ln;
int is_subclass, good_msg, good_mod;
Expand Down Expand Up @@ -384,7 +378,7 @@ get_filter(PyInterpreterState *interp, PyObject *category,
Py_DECREF(tmp_item);
}

action = get_default_action(interp);
PyObject *action = get_default_action(interp);
if (action != NULL) {
*item = Py_NewRef(Py_None);
return action;
Expand Down Expand Up @@ -1000,8 +994,13 @@ do_warn(PyObject *message, PyObject *category, Py_ssize_t stack_level,
&filename, &lineno, &module, &registry))
return NULL;

WarningsState *st = warnings_get_state(tstate->interp);
assert(st != NULL);

Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex);
res = warn_explicit(tstate, category, message, filename, lineno, module, registry,
NULL, source);
Py_END_CRITICAL_SECTION();
Py_DECREF(filename);
Py_DECREF(registry);
Py_DECREF(module);
Expand Down Expand Up @@ -1149,8 +1148,14 @@ warnings_warn_explicit_impl(PyObject *module, PyObject *message,
return NULL;
}
}

WarningsState *st = warnings_get_state(tstate->interp);
assert(st != NULL);

Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex);
returned = warn_explicit(tstate, category, message, filename, lineno,
mod, registry, source_line, sourceobj);
Py_END_CRITICAL_SECTION();
Py_XDECREF(source_line);
return returned;
}
Expand Down Expand Up @@ -1290,8 +1295,14 @@ PyErr_WarnExplicitObject(PyObject *category, PyObject *message,
if (tstate == NULL) {
return -1;
}

WarningsState *st = warnings_get_state(tstate->interp);
assert(st != NULL);

Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex);
res = warn_explicit(tstate, category, message, filename, lineno,
module, registry, NULL, NULL);
Py_END_CRITICAL_SECTION();
if (res == NULL)
return -1;
Py_DECREF(res);
Expand Down Expand Up @@ -1356,8 +1367,13 @@ PyErr_WarnExplicitFormat(PyObject *category,
PyObject *res;
PyThreadState *tstate = get_current_tstate();
if (tstate != NULL) {
WarningsState *st = warnings_get_state(tstate->interp);
assert(st != NULL);

Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex);
res = warn_explicit(tstate, category, message, filename, lineno,
module, registry, NULL, NULL);
Py_END_CRITICAL_SECTION();
Py_DECREF(message);
if (res != NULL) {
Py_DECREF(res);
Expand Down

0 comments on commit 8250cb6

Please sign in to comment.