Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable config to be used as context manager #7363

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 36 additions & 0 deletions networkx/utils/configs.py
Expand Up @@ -38,6 +38,14 @@ class Config:
>>> cfg["spam"]
42

For convenience, it can also set configs within a context with the "with" statement:

>>> with cfg(spam=3):
... print("spam (in context):", cfg.spam)
spam (in context): 3
>>> print("spam (after context):", cfg.spam)
spam (after context): 42

Subclasses may also define ``_check_config`` (as done in the example above)
to ensure the value being assigned is valid:

Expand Down Expand Up @@ -79,6 +87,8 @@ def __new__(cls, **kwargs):
if not cls._strict:
cls.__repr__ = _flexible_repr
cls._orig_class = orig_class # Save original class so we can pickle
cls._prev = None # Stage previous configs to enable use as context manager
cls._context_stack = [] # Stack of previous configs when used as context
instance = object.__new__(cls)
instance.__init__(**kwargs)
return instance
Expand All @@ -95,13 +105,15 @@ def __setattr__(self, key, value):
raise AttributeError(f"Invalid config name: {key!r}")
self._check_config(key, value)
object.__setattr__(self, key, value)
self.__class__._prev = None

def __delattr__(self, key):
if self._strict:
raise TypeError(
f"Configuration items can't be deleted (can't delete {key!r})."
)
object.__delattr__(self, key)
self.__class__._prev = None

# Be a `collection.abc.Collection`
def __contains__(self, key):
Expand Down Expand Up @@ -161,6 +173,30 @@ def __reduce__(self):
def _deserialize(cls, kwargs):
return cls(**kwargs)

# Allow to be used as context manager
def __call__(self, **kwargs):
for key, val in kwargs.items():
self._check_config(key, val)
prev = dict(self)
for key, val in kwargs.items():
setattr(self, key, val)
self.__class__._prev = prev
return self

def __enter__(self):
self.__class__._context_stack.append(self.__class__._prev)
self.__class__._prev = None
return self

def __exit__(self, exc_type, exc_value, traceback):
prev = self.__class__._context_stack.pop()
if not prev:
# Be defensive. This branch may occur from `with cfg:` (forgot to call)
self.__class__._prev = None
return
for key, val in prev.items():
setattr(self, key, val)


def _flexible_repr(self):
return (
Expand Down
38 changes: 38 additions & 0 deletions networkx/utils/tests/test_config.py
Expand Up @@ -178,3 +178,41 @@ class FlexibleConfigWithDefault(Config, strict=False):

assert FlexibleConfigWithDefault().x == 0
assert FlexibleConfigWithDefault(x=1)["x"] == 1


def test_context():
cfg = Config(x=1)
with cfg(x=2) as c:
assert c.x == 2
c.x = 3
assert cfg.x == 3
assert cfg.x == 1

with cfg(x=2) as c:
assert c == cfg
assert cfg.x == 2
with cfg(x=3) as c2:
assert c2 == cfg
assert cfg.x == 3
with cfg as c3: # Forgot to call `cfg(...)`
assert c3 == cfg
assert cfg.x == 3
assert cfg.x == 3
assert cfg.x == 2
assert cfg.x == 1

c = cfg(x=4) # Not yet as context (not recommended, but possible)
assert c == cfg
assert cfg.x == 4
# Cheat by looking at internal data; context stack should only grow with __enter__
assert cfg._prev is not None
assert cfg._context_stack == []
with c:
assert c == cfg
assert cfg.x == 4
assert cfg.x == 1
# Cheat again; there was no preceding `cfg(...)` call this time
assert cfg._prev is None
with cfg:
assert cfg.x == 1
assert cfg.x == 1