Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-116664: Ensure thread-safe dict access in _warnings #116768

Merged
61 changes: 32 additions & 29 deletions Python/_warnings.c
@@ -1,5 +1,4 @@
#include "Python.h"
#include "pycore_dict.h" // _PyDict_GetItemWithError()
#include "pycore_interp.h" // PyInterpreterState.warnings
#include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_pyerrors.h" // _PyErr_Occurred()
Expand All @@ -8,6 +7,8 @@
#include "pycore_sysmodule.h" // _PySys_GetAttr()
#include "pycore_traceback.h" // _Py_DisplaySourceLine()

#include <stdbool.h>

#include "clinic/_warnings.c.h"

#define MODULE_NAME "_warnings"
Expand Down Expand Up @@ -397,7 +398,7 @@ static int
already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key,
int should_set)
{
PyObject *version_obj, *already_warned;
PyObject *already_warned;

if (key == NULL)
return -1;
Expand All @@ -406,14 +407,17 @@ already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key,
if (st == NULL) {
return -1;
}
version_obj = _PyDict_GetItemWithError(registry, &_Py_ID(version));
if (version_obj == NULL
PyObject *version_obj;
if (PyDict_GetItemRef(registry, &_Py_ID(version), &version_obj) < 0) {
return -1;
}
bool should_update_version = (
version_obj == NULL
|| !PyLong_CheckExact(version_obj)
|| PyLong_AsLong(version_obj) != st->filters_version)
{
if (PyErr_Occurred()) {
return -1;
}
|| PyLong_AsLong(version_obj) != st->filters_version
);
Py_XDECREF(version_obj);
if (should_update_version) {
PyDict_Clear(registry);
version_obj = PyLong_FromLong(st->filters_version);
if (version_obj == NULL)
Expand Down Expand Up @@ -911,13 +915,12 @@ setup_context(Py_ssize_t stack_level,
/* Setup registry. */
assert(globals != NULL);
assert(PyDict_Check(globals));
*registry = _PyDict_GetItemWithError(globals, &_Py_ID(__warningregistry__));
int rc = PyDict_GetItemRef(globals, &_Py_ID(__warningregistry__),
registry);
if (rc < 0) {
goto handle_error;
}
if (*registry == NULL) {
int rc;

if (_PyErr_Occurred(tstate)) {
goto handle_error;
}
*registry = PyDict_New();
if (*registry == NULL)
goto handle_error;
Expand All @@ -926,21 +929,21 @@ setup_context(Py_ssize_t stack_level,
if (rc < 0)
goto handle_error;
}
else
Py_INCREF(*registry);

/* Setup module. */
*module = _PyDict_GetItemWithError(globals, &_Py_ID(__name__));
if (*module == Py_None || (*module != NULL && PyUnicode_Check(*module))) {
Py_INCREF(*module);
}
else if (_PyErr_Occurred(tstate)) {
rc = PyDict_GetItemRef(globals, &_Py_ID(__name__), module);
if (rc < 0) {
goto handle_error;
}
else {
*module = PyUnicode_FromString("<string>");
if (*module == NULL)
goto handle_error;
if (rc > 0) {
if (Py_IsNone(*module) || PyUnicode_Check(*module)) {
return 1;
}
Py_DECREF(*module);
}
*module = PyUnicode_FromString("<string>");
if (*module == NULL) {
goto handle_error;
}

return 1;
Expand Down Expand Up @@ -1063,12 +1066,12 @@ get_source_line(PyInterpreterState *interp, PyObject *module_globals, int lineno
return NULL;
}

module_name = _PyDict_GetItemWithError(module_globals, &_Py_ID(__name__));
if (!module_name) {
int rc = PyDict_GetItemRef(module_globals, &_Py_ID(__name__),
&module_name);
if (rc < 0 || rc == 0) {
Py_DECREF(loader);
return NULL;
}
Py_INCREF(module_name);

/* Make sure the loader implements the optional get_source() method. */
(void)PyObject_GetOptionalAttr(loader, &_Py_ID(get_source), &get_source);
Expand Down