Skip to content

Commit

Permalink
Add nx.config dict for configuring dispatching and backends (#7225)
Browse files Browse the repository at this point in the history
* Add `nx.backend_config` dict for configuring dispatching and backends

* Rename `nx.backend_config` to `nx.config`

* "fallback_to_nx" is for testing, not for user config

* Move config of backends to e.g. `nx.config["backends"]["cugraph"]`

* How do you like this mypy?!

* Rename `automatic_backends` to `backend_priority` (and env variables)

* Create a class to handle configuration

* Oops thanks mypy

* Fix to work with more strict config

* Support (and test) default values

* Remove `__class_getitem__` and add docstring

* Allow `strict=False` when defining subclasses.

This allows configs to be added and deleted.

* Move `__init_subclass__`
  • Loading branch information
eriknw committed Mar 15, 2024
1 parent 929b5ad commit f336cf2
Show file tree
Hide file tree
Showing 7 changed files with 452 additions and 23 deletions.
2 changes: 1 addition & 1 deletion networkx/__init__.py
Expand Up @@ -17,7 +17,7 @@
from networkx.exception import *

from networkx import utils
from networkx.utils.backends import _dispatchable
from networkx.utils.backends import _dispatchable, config

from networkx import classes
from networkx.classes import filters
Expand Down
2 changes: 1 addition & 1 deletion networkx/algorithms/operators/tests/test_binary.py
Expand Up @@ -53,7 +53,7 @@ def test_intersection():
assert set(I2.nodes()) == {1, 2, 3, 4}
assert sorted(I2.edges()) == [(2, 3)]
# Only test if not performing auto convert testing of backend implementations
if not nx.utils.backends._dispatchable._automatic_backends:
if not nx.config["backend_priority"]:
with pytest.raises(TypeError):
nx.intersection(G2, H)
with pytest.raises(TypeError):
Expand Down
4 changes: 2 additions & 2 deletions networkx/classes/tests/test_backends.py
Expand Up @@ -31,8 +31,8 @@ def test_pickle():


@pytest.mark.skipif(
"not nx._dispatchable._automatic_backends "
"or nx._dispatchable._automatic_backends[0] != 'nx-loopback'"
"not nx.config['backend_priority'] "
"or nx.config['backend_priority'][0] != 'nx-loopback'"
)
def test_graph_converter_needs_backend():
# When testing, `nx.from_scipy_sparse_array` will *always* call the backend
Expand Down
16 changes: 8 additions & 8 deletions networkx/conftest.py
Expand Up @@ -45,12 +45,6 @@ def pytest_configure(config):
backend = config.getoption("--backend")
if backend is None:
backend = os.environ.get("NETWORKX_TEST_BACKEND")
if backend:
networkx.utils.backends._dispatchable._automatic_backends = [backend]
fallback_to_nx = config.getoption("--fallback-to-nx")
if not fallback_to_nx:
fallback_to_nx = os.environ.get("NETWORKX_FALLBACK_TO_NX")
networkx.utils.backends._dispatchable._fallback_to_nx = bool(fallback_to_nx)
# nx-loopback backend is only available when testing
backends = entry_points(name="nx-loopback", group="networkx.backends")
if backends:
Expand All @@ -64,16 +58,22 @@ def pytest_configure(config):
" Try `pip install -e .`, or change your PYTHONPATH\n"
" Make sure python finds the networkx repo you are testing\n\n"
)
if backend:
networkx.config["backend_priority"] = [backend]
fallback_to_nx = config.getoption("--fallback-to-nx")
if not fallback_to_nx:
fallback_to_nx = os.environ.get("NETWORKX_FALLBACK_TO_NX")
networkx.utils.backends._dispatchable._fallback_to_nx = bool(fallback_to_nx)


def pytest_collection_modifyitems(config, items):
# Setting this to True here allows tests to be set up before dispatching
# any function call to a backend.
networkx.utils.backends._dispatchable._is_testing = True
if automatic_backends := networkx.utils.backends._dispatchable._automatic_backends:
if backend_priority := networkx.config["backend_priority"]:
# Allow pluggable backends to add markers to tests (such as skip or xfail)
# when running in auto-conversion test mode
backend = networkx.utils.backends.backends[automatic_backends[0]].load()
backend = networkx.utils.backends.backends[backend_priority[0]].load()
if hasattr(backend, "on_start_tests"):
getattr(backend, "on_start_tests")(items)

Expand Down
43 changes: 32 additions & 11 deletions networkx/utils/backends.py
Expand Up @@ -106,7 +106,7 @@ class WrappedSparse:
from ..exception import NetworkXNotImplemented
from .decorators import argmap

__all__ = ["_dispatchable"]
__all__ = ["_dispatchable", "config"]


def _do_nothing():
Expand Down Expand Up @@ -142,6 +142,31 @@ def _get_backends(group, *, load_and_call=False):
backends = _get_backends("networkx.backends")
backend_info = _get_backends("networkx.backend_info", load_and_call=True)

# We must import from config after defining `backends` above
from .configs import Config, config

# Get default configuration from environment variables at import time
config.backend_priority = [
x.strip()
for x in os.environ.get(
"NETWORKX_BACKEND_PRIORITY",
os.environ.get("NETWORKX_AUTOMATIC_BACKENDS", ""),
).split(",")
if x.strip()
]
# Initialize default configuration for backends
config.backends = Config(
**{
backend: (
cfg if isinstance(cfg := info["default_config"], Config) else Config(**cfg)
)
if "default_config" in info
else Config()
for backend, info in backend_info.items()
}
)
type(config.backends).__doc__ = "All installed NetworkX backends and their configs."

# Load and cache backends on-demand
_loaded_backends = {} # type: ignore[var-annotated]

Expand Down Expand Up @@ -180,11 +205,6 @@ class _dispatchable:
_fallback_to_nx = (
os.environ.get("NETWORKX_FALLBACK_TO_NX", "true").strip().lower() == "true"
)
_automatic_backends = [
x.strip()
for x in os.environ.get("NETWORKX_AUTOMATIC_BACKENDS", "").split(",")
if x.strip()
]

def __new__(
cls,
Expand Down Expand Up @@ -532,11 +552,12 @@ def __call__(self, /, *args, backend=None, **kwargs):
for g in graphs_resolved.values()
}

if self._is_testing and self._automatic_backends and backend_name is None:
backend_priority = config.backend_priority
if self._is_testing and backend_priority and backend_name is None:
# Special path if we are running networkx tests with a backend.
# This even runs for (and handles) functions that mutate input graphs.
return self._convert_and_call_for_tests(
self._automatic_backends[0],
backend_priority[0],
args,
kwargs,
fallback_to_nx=self._fallback_to_nx,
Expand All @@ -563,7 +584,7 @@ def __call__(self, /, *args, backend=None, **kwargs):
raise ImportError(f"Unable to load backend: {graph_backend_name}")
if (
"networkx" in graph_backend_names
and graph_backend_name not in self._automatic_backends
and graph_backend_name not in backend_priority
):
# Not configured to convert networkx graphs to this backend
raise TypeError(
Expand All @@ -584,7 +605,7 @@ def __call__(self, /, *args, backend=None, **kwargs):
)
# All graphs are backend graphs--no need to convert!
return getattr(backend, self.name)(*args, **kwargs)
# Future work: try to convert and run with other backends in self._automatic_backends
# Future work: try to convert and run with other backends in backend_priority
raise NetworkXNotImplemented(
f"'{self.name}' not implemented by {graph_backend_name}"
)
Expand Down Expand Up @@ -622,7 +643,7 @@ def __call__(self, /, *args, backend=None, **kwargs):
)
):
# Should we warn or log if we don't convert b/c the input will be mutated?
for backend_name in self._automatic_backends:
for backend_name in backend_priority:
if self._should_backend_run(backend_name, *args, **kwargs):
return self._convert_and_call(
backend_name,
Expand Down
228 changes: 228 additions & 0 deletions networkx/utils/configs.py
@@ -0,0 +1,228 @@
import collections
import typing
from dataclasses import dataclass

__all__ = ["Config", "config"]


@dataclass(init=False, eq=False, slots=True, kw_only=True, match_args=False)
class Config:
"""The base class for NetworkX configuration.
There are two ways to use this to create configurations. The first is to
simply pass the initial configuration as keyword arguments to ``Config``:
>>> cfg = Config(eggs=1, spam=5)
>>> cfg
Config(eggs=1, spam=5)
The second--and preferred--way is to subclass ``Config`` with docs and annotations.
>>> class MyConfig(Config):
... '''Breakfast!'''
...
... eggs: int
... spam: int
...
... def _check_config(self, key, value):
... assert isinstance(value, int) and value >= 0
>>> cfg = MyConfig(eggs=1, spam=5)
Once defined, config items may be modified, but can't be added or deleted by default.
``Config`` is a ``Mapping``, and can get and set configs via attributes or brackets:
>>> cfg.eggs = 2
>>> cfg.eggs
2
>>> cfg["spam"] = 42
>>> cfg["spam"]
42
Subclasses may also define ``_check_config`` (as done in the example above)
to ensure the value being assigned is valid:
>>> cfg.spam = -1
Traceback (most recent call last):
...
AssertionError
If a more flexible configuration object is needed that allows adding and deleting
configurations, then pass ``strict=False`` when defining the subclass:
>>> class FlexibleConfig(Config, strict=False):
... default_greeting: str = "Hello"
>>> flexcfg = FlexibleConfig()
>>> flexcfg.name = "Mr. Anderson"
>>> flexcfg
FlexibleConfig(default_greeting='Hello', name='Mr. Anderson')
"""

def __init_subclass__(cls, strict=True):
cls._strict = strict

def __new__(cls, **kwargs):
orig_class = cls
if cls is Config:
# Enable the "simple" case of accepting config definition as keywords
cls = type(
cls.__name__,
(cls,),
{"__annotations__": {key: typing.Any for key in kwargs}},
)
cls = dataclass(
eq=False,
repr=cls._strict,
slots=cls._strict,
kw_only=True,
match_args=False,
)(cls)
if not cls._strict:
cls.__repr__ = _flexible_repr
cls._orig_class = orig_class # Save original class so we can pickle
instance = object.__new__(cls)
instance.__init__(**kwargs)
return instance

def _check_config(self, key, value):
"""Check whether config value is valid. This is useful for subclasses."""

# Control behavior of attributes
def __dir__(self):
return self.__dataclass_fields__.keys()

def __setattr__(self, key, value):
if self._strict and key not in self.__dataclass_fields__:
raise AttributeError(f"Invalid config name: {key!r}")
self._check_config(key, value)
object.__setattr__(self, key, value)

def __delattr__(self, key):
if self._strict:
raise TypeError(
f"Configuration items can't be deleted (can't delete {key!r})."
)
object.__delattr__(self, key)

# Be a `collection.abc.Collection`
def __contains__(self, key):
return (
key in self.__dataclass_fields__ if self._strict else key in self.__dict__
)

def __iter__(self):
return iter(self.__dataclass_fields__ if self._strict else self.__dict__)

def __len__(self):
return len(self.__dataclass_fields__ if self._strict else self.__dict__)

def __reversed__(self):
return reversed(self.__dataclass_fields__ if self._strict else self.__dict__)

# Add dunder methods for `collections.abc.Mapping`
def __getitem__(self, key):
try:
return getattr(self, key)
except AttributeError as err:
raise KeyError(*err.args) from None

def __setitem__(self, key, value):
try:
setattr(self, key, value)
except AttributeError as err:
raise KeyError(*err.args) from None

__delitem__ = __delattr__
_ipython_key_completions_ = __dir__ # config["<TAB>

# Go ahead and make it a `collections.abc.Mapping`
def get(self, key, default=None):
return getattr(self, key, default)

def items(self):
return collections.abc.ItemsView(self)

def keys(self):
return collections.abc.KeysView(self)

def values(self):
return collections.abc.ValuesView(self)

# dataclass can define __eq__ for us, but do it here so it works after pickling
def __eq__(self, other):
if not isinstance(other, Config):
return NotImplemented
return self._orig_class == other._orig_class and self.items() == other.items()

# Make pickle work
def __reduce__(self):
return self._deserialize, (self._orig_class, dict(self))

@staticmethod
def _deserialize(cls, kwargs):
return cls(**kwargs)


def _flexible_repr(self):
return (
f"{self.__class__.__qualname__}("
+ ", ".join(f"{key}={val!r}" for key, val in self.__dict__.items())
+ ")"
)


# Register, b/c `Mapping.__subclasshook__` returns `NotImplemented`
collections.abc.Mapping.register(Config)


class NetworkXConfig(Config):
"""Configuration for NetworkX that controls behaviors such as how to use backends.
Attribute and bracket notation are supported for getting and setting configurations:
>>> nx.config.backend_priority == nx.config["backend_priority"]
True
Config Parameters
-----------------
backend_priority : list of backend names
Enable automatic conversion of graphs to backend graphs for algorithms
implemented by the backend. Priority is given to backends listed earlier.
backends : Config mapping of backend names to backend Config
The keys of the Config mapping are names of all installed NetworkX backends,
and the values are their configurations as Config mappings.
"""

backend_priority: list[str]
backends: Config

def _check_config(self, key, value):
from .backends import backends

if key == "backend_priority":
if not (isinstance(value, list) and all(isinstance(x, str) for x in value)):
raise TypeError(
f"{key!r} config must be a list of backend names; got {value!r}"
)
if missing := {x for x in value if x not in backends}:
missing = ", ".join(map(repr, sorted(missing)))
raise ValueError(f"Unknown backend when setting {key!r}: {missing}")
elif key == "backends":
if not (
isinstance(value, Config)
and all(isinstance(key, str) for key in value)
and all(isinstance(val, Config) for val in value.values())
):
raise TypeError(
f"{key!r} config must be a Config of backend configs; got {value!r}"
)
if missing := {x for x in value if x not in backends}:
missing = ", ".join(map(repr, sorted(missing)))
raise ValueError(f"Unknown backend when setting {key!r}: {missing}")


# Backend configuration will be updated in backends.py
config = NetworkXConfig(
backend_priority=[],
backends=Config(),
)

0 comments on commit f336cf2

Please sign in to comment.