Add a new class called StudyJoblib
#5301
Labels: feature (Change that does not break compatibility, but affects the public interfaces.)
Comments
cheginit added the feature label on Mar 10, 2024
Never mind, I came up with a more elegant solution:

import contextlib
import pickle
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, cast

import joblib
import optuna
from optuna.exceptions import DuplicatedStudyError, ExperimentalWarning
from optuna.pruners import BasePruner, HyperbandPruner
from optuna.samplers import BaseSampler, TPESampler
from optuna.storages import JournalFileStorage, JournalStorage
from optuna.study import MaxTrialsCallback, Study
from optuna.trial import Trial, TrialState


@dataclass
class StudyConfig:
    study_name: str
    sampler: BaseSampler
    pruner: BasePruner
    directions: list[Literal["minimize", "maximize"]]
    storage: JournalStorage
    n_trials: int
    n_cores: int = 1
    log_path: Path = Path("optuna_journal.log")
    study_path: Path = Path("optuna_study.pkl")

    @property
    def study_args(self) -> dict[str, Any]:
        return {
            "study_name": self.study_name,
            "sampler": self.sampler,
            "pruner": self.pruner,
            "directions": self.directions,
            "storage": self.storage,
        }


def objective(trial: Trial) -> float:
    x = trial.suggest_float("x", -100, 100)
    y = trial.suggest_categorical("y", [-1, 0, 1])
    return x**2 + y
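

def optimize(study_cfg: StudyConfig, worker_id: int) -> None:
    # Every worker attaches to the same study through the shared journal storage.
    study = optuna.create_study(**study_cfg.study_args, load_if_exists=True)
    # Request slightly more than an equal share per worker; MaxTrialsCallback
    # stops the study once n_trials trials have completed overall.
    n_trials = study_cfg.n_trials // study_cfg.n_cores
    n_trials += study_cfg.n_cores - (study_cfg.n_trials % study_cfg.n_cores)
    study.optimize(
        objective,
        n_trials=n_trials,
        callbacks=[MaxTrialsCallback(study_cfg.n_trials, states=(TrialState.COMPLETE,))],
    )
    # Only one worker writes the pickled study to disk.
    if worker_id == 0:
        with study_cfg.study_path.open("wb") as f:
            pickle.dump(study, f)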


n_trials = 6000
n_cores = 12

# Start from a clean slate: remove the old journal, its lock file, and the old pickle.
log_path = Path("optuna_journal.log")
log_path.unlink(missing_ok=True)
Path(f"{log_path}.lock").unlink(missing_ok=True)
study_path = Path("optuna_study.pkl")
study_path.unlink(missing_ok=True)

with warnings.catch_warnings():
    warnings.simplefilter("ignore", ExperimentalWarning)
    study_cfg = StudyConfig(
        "test",
        TPESampler(seed=42),
        HyperbandPruner(),
        ["minimize"],
        JournalStorage(JournalFileStorage(str(log_path))),
        n_trials,
        n_cores,
        log_path,
        study_path,
    )

# Create the study once up front so that workers only need to load it.
with contextlib.suppress(DuplicatedStudyError):
    _ = optuna.create_study(**study_cfg.study_args)

# NOTE: this condition is always true as written; the loop exits only
# through the `else: break` branch after a successful parallel run.
while study_cfg.n_trials >= min(100, study_cfg.n_trials):
    try:
        _ = joblib.Parallel(n_jobs=n_cores)(
            joblib.delayed(optimize)(study_cfg, i) for i in range(n_cores)
        )
    except Exception:
        # Clear a possibly stale lock file and retry with half as many trials.
        Path(f"{log_path}.lock").unlink(missing_ok=True)
        study_cfg.n_trials //= 2
    else:
        break

with study_cfg.study_path.open("rb") as f:
    study = cast("Study", pickle.load(f))
best_params = study.best_trial.params
best_params

Note that, depending on the number of trials, it sometimes fails with …
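
The per-worker trial budget used in `optimize` can be checked in isolation. The sketch below reproduces that arithmetic in a helper named `per_worker_trials` (a hypothetical name, not part of the snippet above or of Optuna) and compares it with a plain ceiling division. It only illustrates the numbers and assumes, as in the snippet, that `MaxTrialsCallback` is what stops the study at the global target.

```python
def per_worker_trials(n_trials: int, n_cores: int) -> int:
    """Same arithmetic as in `optimize`: an equal share plus some headroom."""
    share = n_trials // n_cores
    return share + n_cores - (n_trials % n_cores)


def ceil_share(n_trials: int, n_cores: int) -> int:
    """Smallest per-worker budget whose total still covers n_trials."""
    return -(-n_trials // n_cores)


if __name__ == "__main__":
    for total, cores in [(6000, 12), (10, 3)]:
        budget = per_worker_trials(total, cores)
        print(
            f"n_trials={total}, n_cores={cores}: "
            f"per-worker budget={budget}, "
            f"combined budget={budget * cores}, "
            f"ceiling share={ceil_share(total, cores)}"
        )
```

With the values from the snippet (6000 trials, 12 workers), each worker requests 512 trials, so the combined budget of 6144 exceeds the target and the callback decides when the study actually stops.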
Motivation
The current parallelization with Python's built-in `ThreadPool` is very limited. Adding a new class called `StudyJoblib`, with `Study` as its base class, would let us replace only `.optimize` and use `joblib` instead. We could then add a new argument to `create_study`, called, for example, `use_joblib`, that creates a `StudyJoblib` instead of a `Study`. My implementation does not add `joblib` as a new dependency; rather, it is a soft dependency: if `use_joblib=True`, it checks whether `joblib` is installed and raises an exception if it is not. I can submit a PR if there is interest.
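
As a rough illustration of the soft-dependency idea, the sketch below shows one way such a check and switch could look. It is not the implementation referred to in this issue and does not use Optuna's API: `BaseStudy` and `JoblibStudy` are stand-in classes, and `import_optional`, `make_study`, and the `module_name` parameter are hypothetical names introduced only for this example.

```python
import importlib
import importlib.util
from types import ModuleType


def import_optional(module_name: str, feature: str) -> ModuleType:
    """Import an optional package, or raise a descriptive ImportError."""
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"{feature} requires the optional package '{module_name}'. "
            f"Install it with `pip install {module_name}`."
        )
    return importlib.import_module(module_name)


class BaseStudy:
    """Stand-in for a study class; NOT Optuna's real `Study`."""

    backend = "threads"


class JoblibStudy(BaseStudy):
    """Stand-in for the proposed `StudyJoblib` subclass."""

    backend = "joblib"


def make_study(use_joblib: bool = False, module_name: str = "joblib") -> BaseStudy:
    """Hypothetical factory mirroring the proposed `use_joblib` argument.

    `module_name` exists only so the check can be exercised without joblib.
    """
    if use_joblib:
        # Fail early, and only when the optional feature is requested.
        import_optional(module_name, "use_joblib=True")
        return JoblibStudy()
    return BaseStudy()
```

Checking at the point of use keeps the default code path free of the extra package, so users who never pass `use_joblib=True` would not need `joblib` installed.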
Description
I have already implemented this and tested it for my use-case. It works nicely and runs much faster than the default method. This is what I have:
Alternatives (optional)
No response
Additional context (optional)
No response