Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Engine plugin API and engine entry point for Lloyd's KMeans #24497

Closed
wants to merge 79 commits into from
Closed
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
e209cb9
wip engines
ogrisel Mar 29, 2022
1647eed
wip
ogrisel Apr 1, 2022
87654c1
fixes
ogrisel Apr 4, 2022
b05d4ab
wip
ogrisel Apr 4, 2022
828f797
more specific assertion
ogrisel Apr 4, 2022
5819524
Add docstring to the config context
ogrisel Apr 5, 2022
5df598c
add default kwarg
ogrisel Apr 6, 2022
44cbd6c
Various fixes
ogrisel Apr 15, 2022
a693d2e
empty doc
ogrisel Apr 19, 2022
e57eae7
Merge branch 'main' into wip-engines
ogrisel May 19, 2022
1dbae5b
WIP
ogrisel May 21, 2022
ec6baa0
Merge branch 'main' into wip-engines
ogrisel Jun 2, 2022
dd586f1
wip
ogrisel Jun 3, 2022
e3c1056
Move tolerance computation to the engine
ogrisel Jun 8, 2022
bd280ef
wip
ogrisel Jun 8, 2022
4913f9b
Merge branch 'main' into wip-engines
ogrisel Sep 22, 2022
3352c58
wip
ogrisel Sep 22, 2022
4753143
wip
ogrisel Sep 22, 2022
2794d26
linting
fcharras Sep 23, 2022
cc36c6e
fix MBKMeans and linting
fcharras Sep 23, 2022
80d4ba4
Merge branch 'main' into wip-engines
ogrisel Sep 23, 2022
f92d63f
Draft changelog entry
ogrisel Sep 23, 2022
dbf607b
doc reorg
ogrisel Sep 23, 2022
abb278a
fix changelog entry to add the pr number
ogrisel Sep 23, 2022
c1e8510
fix test name
ogrisel Sep 23, 2022
9408280
attempt at sphinx the sphinx warning
ogrisel Sep 23, 2022
1fc4d79
fix MBKMeans test
fcharras Sep 26, 2022
acc4b47
fix: skip entry points that do not match the requested provider names…
fcharras Sep 26, 2022
d0d7f95
update config_context unit test to account for new engine_provider ke…
fcharras Sep 26, 2022
dcb3140
for python < 3.10 returns an empty list when the slearn engines entry…
fcharras Sep 26, 2022
8df26c7
Link to user guide from docstring
ogrisel Sep 26, 2022
db7b98c
add a verbosity parameter to get_engine_class
fcharras Sep 26, 2022
4a379d2
add pytest plugin that can be used by engine providers to run sklearn…
fcharras Sep 28, 2022
12fc503
Add plugin methods for predict, transform and score methods for KMeans
fcharras Sep 28, 2022
da5fe85
ad _engine.base module
fcharras Sep 28, 2022
53cb8f1
linting
fcharras Sep 28, 2022
7cb4712
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Sep 28, 2022
d72050f
fix a bug that caused by the pytest plugin where sklearn tests would …
fcharras Sep 28, 2022
a299bd8
__all__fixup and renaming FeatureNotCoveredByPluginError -> NotSuppor…
fcharras Sep 28, 2022
2d08a91
exception __all__ fixup
fcharras Sep 28, 2022
70e4ecf
register the _engine subpackage in setup.py
fcharras Sep 29, 2022
54d1924
add __init__.py
fcharras Sep 29, 2022
1a805d7
fix test_engines test
fcharras Sep 29, 2022
e8c5193
fix _config.py conflicts
fcharras Nov 2, 2022
44cc071
Merge branch 'main' into wip-engines
jjerphan Nov 16, 2022
3b99c29
fixup! Merge branch 'main' into wip-engines
jjerphan Nov 16, 2022
39d10cd
add engine method to get nb of distinct clusters
fcharras Nov 25, 2022
a13efc9
Merge branch 'wip-engines' of https://github.com/ogrisel/scikit-learn…
fcharras Nov 25, 2022
57f095f
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Nov 25, 2022
8b974c7
Apply suggestions from code review
fcharras Nov 25, 2022
5550f40
linting
fcharras Nov 25, 2022
f570f45
fix
fcharras Nov 25, 2022
fdaf97b
fix
fcharras Nov 25, 2022
ab00912
Merge branch 'main' into wip-engines
betatim Dec 6, 2022
b770bea
Add support for trying engines in turn
betatim Dec 6, 2022
76fd41b
Fix formatting
betatim Dec 6, 2022
dd0d0ff
Update mean_variance computation
betatim Dec 6, 2022
10c599c
Move changelog entry to v1.3.rst
ogrisel Dec 13, 2022
ab1e34d
Merge main
ogrisel Dec 13, 2022
ff191e2
Pass sample_weight parameter
betatim Dec 14, 2022
d74d652
Add engine aware mixin to factor out engine stuff
betatim Dec 20, 2022
dfb3fe0
Add attribute conversion decorator
betatim Dec 21, 2022
e531a6d
Tweak attribute conversion
betatim Jan 6, 2023
737c74f
Use Array API to get unique cluster count
betatim Jan 11, 2023
a1ff1ec
Rename cluster counting method
betatim Jan 12, 2023
519ae67
Update conversion related bits
betatim Jan 13, 2023
1a19de0
Fix engine provider at fit time
betatim Jan 16, 2023
96c5f9b
Combine engine selection and validation
betatim Jan 17, 2023
0c29cc2
Rename argument
betatim Jan 17, 2023
c7be19a
Merge pull request #14 from betatim/engines-cluster-count
ogrisel Jan 17, 2023
56bf7b1
Merge branch 'main' into wip-engines
jjerphan Jan 18, 2023
48803d9
Rename loop variable
betatim Jan 18, 2023
6d1b390
Switch back to using `accepts` for engine selection
betatim Jan 20, 2023
157a9c6
Allow "runtime" engines to be passed as well as provider names
betatim Jan 20, 2023
1441088
Update tests
betatim Jan 25, 2023
68a64a8
Update docstring for `accepts`
betatim Jan 25, 2023
d9020e7
Update comment for engine class config
betatim Jan 25, 2023
183068d
Add `engine_name` attribute to ad-hoc engine classes
betatim Jan 26, 2023
39a39ad
Merge pull request #13 from betatim/engine-mixin
betatim Jan 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions doc/modules/engine.rst
@@ -0,0 +1,11 @@
.. Places parent toc into the sidebar
:parenttoc: True

.. _engine:


==================================
Computation Engines (experimental)
==================================


24 changes: 24 additions & 0 deletions sklearn/_config.py
Expand Up @@ -14,6 +14,7 @@
),
"enable_cython_pairwise_dist": True,
"array_api_dispatch": False,
"engine_provider": (),
}
_threadlocal = threading.local()

Expand Down Expand Up @@ -52,6 +53,7 @@ def set_config(
pairwise_dist_chunk_size=None,
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
engine_provider=None,
):
"""Set global scikit-learn configuration

Expand Down Expand Up @@ -120,6 +122,15 @@ def set_config(

.. versionadded:: 1.2

engine_provider : str or sequence of str, default=None
Enable computational engine implementation provided by third party
packages to leverage specific hardware platforms using frameworks or
libraries outside of the usual scikit-learn project dependencies.

TODO: add link to doc

.. versionadded:: 1.3

See Also
--------
config_context : Context manager for global scikit-learn configuration.
Expand All @@ -141,6 +152,8 @@ def set_config(
local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist
if array_api_dispatch is not None:
local_config["array_api_dispatch"] = array_api_dispatch
if engine_provider is not None:
local_config["engine_provider"] = engine_provider


@contextmanager
Expand All @@ -153,6 +166,7 @@ def config_context(
pairwise_dist_chunk_size=None,
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
engine_provider=None,
):
"""Context manager for global scikit-learn configuration.

Expand Down Expand Up @@ -220,6 +234,15 @@ def config_context(

.. versionadded:: 1.2

engine_provider : str or sequence of str, default=None
Enable computational engine implementation provided by third party
packages to leverage specific hardware platforms using frameworks or
libraries outside of the usual scikit-learn project dependencies.

TODO: add link to doc

.. versionadded:: 1.3

Yields
------
None.
Expand Down Expand Up @@ -256,6 +279,7 @@ def config_context(
pairwise_dist_chunk_size=pairwise_dist_chunk_size,
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
array_api_dispatch=array_api_dispatch,
engine_provider=engine_provider,
)

try:
Expand Down
111 changes: 111 additions & 0 deletions sklearn/_engine.py
@@ -0,0 +1,111 @@
from importlib.metadata import entry_points
from importlib import import_module
from contextlib import contextmanager
from functools import lru_cache
from ssl import ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
import warnings

from sklearn._config import get_config

SKLEARN_ENGINES_ENTRY_POINT = "sklearn_engines"


class EngineSpec:

__slots__ = ["name", "provider_name", "module_name", "engine_qualname"]

def __init__(self, name, provider_name, module_name, engine_qualname):
self.name = name
self.provider_name = provider_name
self.module_name = module_name
self.engine_qualname = engine_qualname

def get_engine_class(self):
engine = import_module(self.module_name)
for attr in self.engine_qualname.split("."):
engine = getattr(engine, attr)
return engine


def _parse_entry_point(entry_point):
module_name, engine_qualname = entry_point.value.split(":")
provider_name = next(iter(module_name.split(".", 1)))
return EngineSpec(entry_point.name, provider_name, module_name, engine_qualname)


@lru_cache
def _parse_entry_points(provider_names=None):
specs = []
all_entry_points = entry_points()
if hasattr(all_entry_points, "select"):
engine_entry_points = all_entry_points.select(group=SKLEARN_ENGINES_ENTRY_POINT)
else:
engine_entry_points = all_entry_points[SKLEARN_ENGINES_ENTRY_POINT]
for entry_point in engine_entry_points:
try:
spec = _parse_entry_point(entry_point)
if provider_names is not None and spec.provider_name in provider_names:
# Skip entry points that do not match the requested provider names.
continue
specs.append(spec)
except Exception as e:
# Do not raise an exception in case an invalid package has been
# installed in the same Python env as scikit-learn: just warn and
# skip.
warnings.warn(
f"Invalid {SKLEARN_ENGINES_ENTRY_POINT} entry point"
f" {entry_point.name} with value {entry_point.value}: {e}"
)
if provider_names is not None:
observed_provider_names = {spec.provider_name for spec in specs}
missing_providers = set(provider_names) - observed_provider_names
if missing_providers:
raise RuntimeError(
"Could not find any provider for the"
f" {SKLEARN_ENGINES_ENTRY_POINT} entry point with name(s):"
f" {', '.join(repr(p) for p in sorted(missing_providers))}"
)
return specs


def list_engine_provider_names():
"""Find the list of sklearn_engine provider names

This function only inspects the metadata and should trigger any module import.
"""
return sorted({spec.provider_name for spec in _parse_entry_points()})


def _get_engine_class(engine_name, provider_names, engine_specs, default=None):
specs_by_provider = {}
for spec in engine_specs:
if spec.name != engine_name:
continue
specs_by_provider.setdefault(spec.provider_name, spec)

for provider_name in provider_names:
spec = specs_by_provider.get(provider_name)
if spec is not None:
# XXX: should we return an instance or the class itself?
return spec.get_engine_class()

return default


def get_engine_class(engine_name, default=None):
provider_names = get_config()["engine_provider"]
if isinstance(provider_names, str):
provider_names = (provider_names,)
elif not isinstance(provider_names, tuple):
# Make sure the provider names are a tuple to make it possible for the
# lru cache to hash them.
provider_names = tuple(provider_names)
if not provider_names:
return default
engine_specs = _parse_entry_points(provider_names=provider_names)
return _get_engine_class(
engine_name=engine_name,
provider_names=provider_names,
engine_specs=engine_specs,
default=default,
)