Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-116664: Ensure thread-safe dict access in _warnings #116768

Merged
53 changes: 27 additions & 26 deletions Python/_warnings.c
@@ -1,5 +1,4 @@
#include "Python.h"
#include "pycore_dict.h" // _PyDict_GetItemWithError()
#include "pycore_interp.h" // PyInterpreterState.warnings
#include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_pyerrors.h" // _PyErr_Occurred()
Expand All @@ -8,6 +7,8 @@
#include "pycore_sysmodule.h" // _PySys_GetAttr()
#include "pycore_traceback.h" // _Py_DisplaySourceLine()

#include <stdbool.h>

#include "clinic/_warnings.c.h"

#define MODULE_NAME "_warnings"
Expand Down Expand Up @@ -397,7 +398,7 @@ static int
already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key,
int should_set)
{
PyObject *version_obj, *already_warned;
PyObject *already_warned;

if (key == NULL)
return -1;
Expand All @@ -406,14 +407,17 @@ already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key,
if (st == NULL) {
return -1;
}
version_obj = _PyDict_GetItemWithError(registry, &_Py_ID(version));
if (version_obj == NULL
PyObject *version_obj;
if (PyDict_GetItemRef(registry, &_Py_ID(version), &version_obj) < 0) {
return -1;
}
bool should_update_version = (
version_obj == NULL
|| !PyLong_CheckExact(version_obj)
|| PyLong_AsLong(version_obj) != st->filters_version)
{
if (PyErr_Occurred()) {
return -1;
}
|| PyLong_AsLong(version_obj) != st->filters_version
);
Py_XDECREF(version_obj);
if (should_update_version) {
PyDict_Clear(registry);
version_obj = PyLong_FromLong(st->filters_version);
if (version_obj == NULL)
Expand Down Expand Up @@ -911,13 +915,12 @@ setup_context(Py_ssize_t stack_level,
/* Setup registry. */
assert(globals != NULL);
assert(PyDict_Check(globals));
*registry = _PyDict_GetItemWithError(globals, &_Py_ID(__warningregistry__));
int rc = PyDict_GetItemRef(globals, &_Py_ID(__warningregistry__),
registry);
if (rc < 0) {
goto handle_error;
}
if (*registry == NULL) {
int rc;

if (_PyErr_Occurred(tstate)) {
goto handle_error;
}
*registry = PyDict_New();
if (*registry == NULL)
goto handle_error;
Expand All @@ -926,22 +929,20 @@ setup_context(Py_ssize_t stack_level,
if (rc < 0)
goto handle_error;
}
else
Py_INCREF(*registry);

/* Setup module. */
*module = _PyDict_GetItemWithError(globals, &_Py_ID(__name__));
if (*module == Py_None || (*module != NULL && PyUnicode_Check(*module))) {
Py_INCREF(*module);
}
else if (_PyErr_Occurred(tstate)) {
rc = PyDict_GetItemRef(globals, &_Py_ID(__name__), module);
if (rc < 0) {
goto handle_error;
}
else {
if (rc == 0) {
*module = PyUnicode_FromString("<string>");
if (*module == NULL)
goto handle_error;
}
else {
assert(Py_IsNone(*module) || PyUnicode_Check(*module));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to handle the case where *module (i.e., __name__) is not None and not a unicode object because Python code can set __name__ arbitrarily.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4f3cd02 should fix that case:

If None or a unicode object, return successfully; else decref and fall through to the no-key case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, there is no test case for this in the test suite.

}

return 1;

Expand Down Expand Up @@ -1063,12 +1064,12 @@ get_source_line(PyInterpreterState *interp, PyObject *module_globals, int lineno
return NULL;
}

module_name = _PyDict_GetItemWithError(module_globals, &_Py_ID(__name__));
if (!module_name) {
int rc = PyDict_GetItemRef(module_globals, &_Py_ID(__name__),
&module_name);
if (rc < 0 || rc == 0) {
Py_DECREF(loader);
return NULL;
}
Py_INCREF(module_name);

/* Make sure the loader implements the optional get_source() method. */
(void)PyObject_GetOptionalAttr(loader, &_Py_ID(get_source), &get_source);
Expand Down