From 56a2ac1d69c7c61a847c678879a67f5d3672b3e8 Mon Sep 17 00:00:00 2001 From: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com> Date: Tue, 23 Aug 2022 18:43:15 +0200 Subject: [PATCH] [RELEASE] v0.2.1 (#475) * [FIX] Documentation and docker workflow file (#449) * fixes to documentation and docker * fix to docker * Apply suggestions from code review * add change log for release (#450) * [FIX] release docs (#452) * Release 0.2 * Release 0.2.0 * fix docs new line * [FIX] ADD forecasting init design to pip data files (#459) * add forecasting_init.json to data files under setup * avoid undefined reference in scale_value * checks for time series dataset split (#464) * checks for time series dataset split * maint * Update autoPyTorch/datasets/time_series_dataset.py Co-authored-by: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com> Co-authored-by: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com> * [FIX] Numerical stability scaling for timeseries forecasting tasks (#467) * resolve rebase conflict * add checks for scaling factors * flake8 fix * resolve conflict * [FIX] pipeline options in `fit_pipeline` (#466) * fix update of pipeline config options in fit pipeline * fix flake and test * suggestions from review * [FIX] results management and visualisation with missing test data (#465) * add flexibility to avoid checking for test scores * fix flake and test * fix bug in tests * suggestions from review * [ADD] Robustly refit models in final ensemble in parallel (#471) * add parallel model runner and update running traditional classifiers * update pipeline config to pipeline options * working refit function * fix mypy and flake * suggestions from review * fix mypy and flake * suggestions from review * finish documentation * fix tests * add test for parallel model runner * fix flake * fix tests * fix traditional prediction for refit * suggestions from review * add warning for failed processing of results * remove unnecessary change * update autopytorch 
version number * update autopytorch version number and the example file * [DOCS] Release notes v0.2.1 (#476) * Release 0.2.1 * add release docs * Update docs/releases.rst Co-authored-by: Difan Deng <33290713+dengdifan@users.noreply.github.com> --- autoPyTorch/__version__.py | 2 +- autoPyTorch/api/base_task.py | 421 +++-- autoPyTorch/constants.py | 7 + autoPyTorch/datasets/time_series_dataset.py | 11 + autoPyTorch/ensemble/abstract_ensemble.py | 12 + autoPyTorch/evaluation/abstract_evaluator.py | 28 +- autoPyTorch/evaluation/tae.py | 18 +- autoPyTorch/evaluation/test_evaluator.py | 10 +- ...time_series_forecasting_train_evaluator.py | 12 +- autoPyTorch/evaluation/train_evaluator.py | 12 +- autoPyTorch/optimizer/smbo.py | 8 +- .../scaling/utils.py | 18 +- .../setup/forecasting_target_scaling/utils.py | 20 +- .../setup/network/forecasting_architecture.py | 4 +- .../setup/traditional_ml/base_model.py | 11 +- .../base_traditional_learner.py | 27 +- .../traditional_learner/learners.py | 46 +- autoPyTorch/utils/parallel_model_runner.py | 300 ++++ autoPyTorch/utils/results_manager.py | 61 +- autoPyTorch/utils/results_visualizer.py | 9 + docs/releases.rst | 16 + .../example_tabular_classification.py | 40 +- .../20_basics/example_tabular_regression.py | 39 +- setup.py | 5 +- test/test_api/test_api.py | 128 +- test/test_api/test_base_api.py | 18 +- test/test_api/utils.py | 13 +- .../test_time_series_datasets.py | 18 +- .../test_abstract_evaluator.py | 2 +- test/test_evaluation/test_evaluators.py | 12 +- .../test_forecasting_evaluators.py | 15 +- .../preprocessing/forecasting/test_scaling.py | 20 + .../test_forecasting_target_scaling.py | 20 + test/test_utils/runhistory_no_test.json | 1582 +++++++++++++++++ test/test_utils/test_parallel_model_runner.py | 66 + test/test_utils/test_results_manager.py | 53 + test/test_utils/test_results_visualizer.py | 36 + 37 files changed, 2814 insertions(+), 306 deletions(-) create mode 100644 autoPyTorch/utils/parallel_model_runner.py create 
mode 100644 test/test_utils/runhistory_no_test.json create mode 100644 test/test_utils/test_parallel_model_runner.py diff --git a/autoPyTorch/__version__.py b/autoPyTorch/__version__.py index 94b9a71f5..36509b4a7 100644 --- a/autoPyTorch/__version__.py +++ b/autoPyTorch/__version__.py @@ -1,4 +1,4 @@ """Version information.""" # The following line *must* be the last in the module, exactly as formatted: -__version__ = "0.2" +__version__ = "0.2.1" diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index c5468eae7..167264306 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -52,7 +52,6 @@ ) from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager from autoPyTorch.ensemble.singlebest_ensemble import SingleBest -from autoPyTorch.evaluation.abstract_evaluator import fit_and_suppress_warnings from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash from autoPyTorch.evaluation.utils import DisableFileOutputParameters from autoPyTorch.optimizer.smbo import AutoMLSMBO @@ -69,6 +68,7 @@ start_log_server, ) from autoPyTorch.utils.parallel import preload_modules +from autoPyTorch.utils.parallel_model_runner import run_models_on_dataset from autoPyTorch.utils.pipeline import get_configuration_space, get_dataset_requirements from autoPyTorch.utils.results_manager import MetricResults, ResultsManager, SearchResults from autoPyTorch.utils.results_visualizer import ColorLabelSettings, PlotSettingParams, ResultsVisualizer @@ -443,14 +443,14 @@ def ensemble_performance_history(self) -> List[Dict[str, Any]]: def trajectory(self) -> Optional[List]: return self._results_manager.trajectory - def set_pipeline_config(self, **pipeline_config_kwargs: Any) -> None: + def set_pipeline_options(self, **pipeline_options_kwargs: Any) -> None: """ Check whether arguments are valid and then sets them to the current pipeline configuration. 
Args: - **pipeline_config_kwargs: Valid config options include "num_run", + **pipeline_options_kwargs: Valid config options include "num_run", "device", "budget_type", "epochs", "runtime", "torch_num_threads", "early_stopping", "use_tensorboard_logger", "metrics_during_training" @@ -459,7 +459,7 @@ def set_pipeline_config(self, **pipeline_config_kwargs: Any) -> None: None """ unknown_keys = [] - for option, value in pipeline_config_kwargs.items(): + for option, value in pipeline_options_kwargs.items(): if option in self.pipeline_options.keys(): pass else: @@ -470,7 +470,7 @@ def set_pipeline_config(self, **pipeline_config_kwargs: Any) -> None: " expected arguments to be in {}". format(unknown_keys, self.pipeline_options.keys())) - self.pipeline_options.update(pipeline_config_kwargs) + self.pipeline_options.update(pipeline_options_kwargs) def get_pipeline_options(self) -> dict: """ @@ -634,7 +634,9 @@ def _close_dask_client(self) -> None: self._is_dask_client_internally_created = False del self._is_dask_client_internally_created - def _load_models(self) -> bool: + def _load_models( + self, + ) -> bool: """ Loads the models saved in the temporary directory @@ -645,6 +647,7 @@ def _load_models(self) -> bool: """ if self.resampling_strategy is None: raise ValueError("Resampling strategy is needed to determine what models to load") + self.ensemble_ = self._backend.load_ensemble(self.seed) # If no ensemble is loaded, try to get the best performing model @@ -799,113 +802,37 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: assert self._dask_client is not None self._logger.info("Starting to create traditional classifier predictions.") - starttime = time.time() # Initialise run history for the traditional classifiers - run_history = RunHistory() memory_limit = self._memory_limit if memory_limit is not None: memory_limit = int(math.ceil(memory_limit)) available_classifiers = get_available_traditional_learners() - dask_futures = [] - - 
total_number_classifiers = len(available_classifiers) - for n_r, classifier in enumerate(available_classifiers): - - # Only launch a task if there is time - start_time = time.time() - if time_left >= func_eval_time_limit_secs: - self._logger.info(f"{n_r}: Started fitting {classifier} with cutoff={func_eval_time_limit_secs}") - scenario_mock = unittest.mock.Mock() - scenario_mock.wallclock_limit = time_left - # This stats object is a hack - maybe the SMAC stats object should - # already be generated here! - stats = Stats(scenario_mock) - stats.start_timing() - ta = ExecuteTaFuncWithQueue( - pynisher_context=self._multiprocessing_context, - backend=self._backend, - seed=self.seed, - multi_objectives=["cost"], - metric=self._metric, - logger_port=self._logger_port, - cost_for_crash=get_cost_of_crash(self._metric), - abort_on_first_run_crash=False, - initial_num_run=self._backend.get_next_num_run(), - stats=stats, - memory_limit=memory_limit, - disable_file_output=self._disable_file_output, - all_supported_metrics=self._all_supported_metrics, - ) - dask_futures.append([ - classifier, - self._dask_client.submit( - ta.run, config=classifier, - cutoff=func_eval_time_limit_secs, - ) - ]) - - # When managing time, we need to take into account the allocated time resources, - # which are dependent on the number of cores. 'dask_futures' is a proxy to the number - # of workers /n_jobs that we have, in that if there are 4 cores allocated, we can run at most - # 4 task in parallel. Every 'cutoff' seconds, we generate up to 4 tasks. - # If we only have 4 workers and there are 4 futures in dask_futures, it means that every - # worker has a task. We would not like to launch another job until a worker is available. To this - # end, the following if-statement queries the number of active jobs, and forces to wait for a job - # completion via future.result(), so that a new worker is available for the next iteration. 
- if len(dask_futures) >= self.n_jobs: - - # How many workers to wait before starting fitting the next iteration - workers_to_wait = 1 - if n_r >= total_number_classifiers - 1 or time_left <= func_eval_time_limit_secs: - # If on the last iteration, flush out all tasks - workers_to_wait = len(dask_futures) - - while workers_to_wait >= 1: - workers_to_wait -= 1 - # We launch dask jobs only when there are resources available. - # This allow us to control time allocation properly, and early terminate - # the traditional machine learning pipeline - cls, future = dask_futures.pop(0) - status, cost, runtime, additional_info = future.result() - if status == StatusType.SUCCESS: - self._logger.info( - "Fitting {} took {} [sec] and got performance: {}.\n" - "additional info:\n{}".format(cls, runtime, cost, dict_repr(additional_info)) - ) - configuration = additional_info['pipeline_configuration'] - origin = additional_info['configuration_origin'] - additional_info.pop('pipeline_configuration') - run_history.add(config=configuration, cost=cost, - time=runtime, status=status, seed=self.seed, - starttime=starttime, endtime=starttime + runtime, - origin=origin, additional_info=additional_info) - else: - if additional_info.get('exitcode') == -6: - self._logger.error( - "Traditional prediction for {} failed with run state {},\n" - "because the provided memory limits were too tight.\n" - "Please increase the 'ml_memory_limit' and try again.\n" - "If you still get the problem, please open an issue\n" - "and paste the additional info.\n" - "Additional info:\n{}".format(cls, str(status), dict_repr(additional_info)) - ) - else: - self._logger.error( - "Traditional prediction for {} failed with run state {}.\nAdditional info:\n{}".format( - cls, str(status), dict_repr(additional_info) - ) - ) - - # In the case of a serial execution, calling submit halts the run for a resource - # dynamically adjust time in this case - time_left -= int(time.time() - start_time) - - # Exit if no more time 
is available for a new classifier - if time_left < func_eval_time_limit_secs: - self._logger.warning("Not enough time to fit all traditional machine learning models." - "Please consider increasing the run time to further improve performance.") - break + model_configs = [(classifier, self.pipeline_options[self.pipeline_options['budget_type']]) + for classifier in available_classifiers] + + run_history = run_models_on_dataset( + time_left=time_left, + func_eval_time_limit_secs=func_eval_time_limit_secs, + model_configs=model_configs, + logger=self._logger, + logger_port=self._logger_port, + metric=self._metric, + dask_client=self._dask_client, + backend=self._backend, + memory_limit=memory_limit, + disable_file_output=self._disable_file_output, + all_supported_metrics=self._all_supported_metrics, + include=self.include_components, + exclude=self.exclude_components, + search_space_updates=self.search_space_updates, + pipeline_options=self.pipeline_options, + seed=self.seed, + multiprocessing_context=self._multiprocessing_context, + n_jobs=self.n_jobs, + current_search_space=self.search_space, + initial_num_run=self._backend.get_next_num_run() + ) self._logger.debug("Run history traditional: {}".format(run_history)) # add run history of traditional to api run history @@ -1272,7 +1199,7 @@ def _search( all_supported_metrics=self._all_supported_metrics, smac_scenario_args=smac_scenario_args, get_smac_object_callback=get_smac_object_callback, - pipeline_config=self.pipeline_options, + pipeline_options=self.pipeline_options, min_budget=min_budget, max_budget=max_budget, ensemble_callback=proc_ensemble, @@ -1378,38 +1305,144 @@ def _get_fit_dictionary( def refit( self, - dataset: BaseDataset, - split_id: int = 0 + dataset: Optional[BaseDataset] = None, + X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, + y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, + X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, + y_test: 
Optional[Union[List, pd.DataFrame, np.ndarray]] = None, + dataset_name: Optional[str] = None, + resampling_strategy: ResamplingStrategies = NoResamplingStrategyTypes.no_resampling, + resampling_strategy_args: Optional[Dict[str, Any]] = None, + total_walltime_limit: int = 120, + run_time_limit_secs: int = 60, + memory_limit: Optional[int] = None, + eval_metric: Optional[str] = None, + all_supported_metrics: bool = False, + budget_type: Optional[str] = None, + budget: Optional[float] = None, + pipeline_options: Optional[Dict] = None, + disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, ) -> "BaseTask": """ - Refit all models found with fit to new data. + Fit all the models found in the ensemble on the whole training set X_train. + Therefore, we recommend using `NoResamplingStrategy` to be able to do that. Nevertheless, it + is still able to fit using other splitting techniques such as hold out or cross validation. - Necessary when using cross-validation. During training, autoPyTorch - fits each model k times on the dataset, but does not keep any trained - model and can therefore not be used to predict for new data points. - This methods fits all models found during a call to fit on the data - given. This method may also be used together with holdout to avoid - only using 66% of the training data to fit the final model. - - Refit uses the estimator pipeline_config attribute, which the user - can interact via the get_pipeline_config()/set_pipeline_config() + Refit uses the estimator pipeline_options attribute, which the user + can interact via the get_pipeline_options()/set_pipeline_options() methods. Args: - dataset (Dataset): - The argument that will provide the dataset splits. It can either - be a dictionary with the splits, or the dataset object which can - generate the splits based on different restrictions. - split_id (int): - split id to fit on. 
+ dataset (BaseDataset): + An object of the appropriate child class of `BaseDataset`, + that will be used to fit the pipeline + X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame] + A pair of features (X_train) and targets (y_train) used to fit a + pipeline. Additionally, a holdout of these pairs (X_test, y_test) can + be provided to track the generalization performance of each stage. + dataset_name (Optional[str]): + Name of the dataset, if None, random value is used. + resampling_strategy (ResamplingStrategies): + Strategy to split the training data. Defaults to + NoResamplingStrategyTypes.no_resampling. + resampling_strategy_args (Optional[Dict[str, Any]]): + Arguments required for the chosen resampling strategy. If None, uses + the default values provided in DEFAULT_RESAMPLING_PARAMETERS + in ```datasets/resampling_strategy.py```. + dataset_name (Optional[str]): + name of the dataset, used as experiment name. + total_walltime_limit (int): + Total time that can be used by all the models to be refitted. Defaults to 120. + run_time_limit_secs (int: default=60): + Time limit for a single call to the machine learning model. + Model fitting will be terminated if the machine learning algorithm + runs over the time limit. Set this value high enough so that + typical machine learning algorithms can be fit on the training + data. + memory_limit (Optional[int]): + Memory limit in MB for the machine learning algorithm. autopytorch + will stop fitting the machine learning algorithm if it tries + to allocate more than memory_limit MB. If None is provided, + no memory limit is set. In case of multi-processing, memory_limit + will be per job. This memory limit also applies to the ensemble + creation process. + eval_metric (Optional[str]): + Name of the metric that is used to evaluate a pipeline. 
+ all_supported_metrics (bool: default=False): + if True, all metrics supporting current task will be calculated + for each pipeline and results will be available via cv_results + budget_type (str): + Type of budget to be used when fitting the pipeline. + It can be one of: + + + `epochs`: The training of each pipeline will be terminated after + a number of epochs have passed. This number of epochs is determined by the + budget argument of this method. + + `runtime`: The training of each pipeline will be terminated after + a number of seconds have passed. This number of seconds is determined by the + budget argument of this method. The overall fitting time of a pipeline is + controlled by run_time_limit_secs. 'runtime' only controls the allocated + time to train a pipeline, but it does not consider the overall time it takes + to create a pipeline (data loading and preprocessing, other i/o operations, etc.). + budget (Optional[float]): + Budget to fit a single run of the pipeline. If not + provided, uses the default in the pipeline config + pipeline_options (Optional[Dict]): + Valid config options include "device", + "torch_num_threads", "early_stopping", "use_tensorboard_logger", + "metrics_during_training" + disable_file_output (Optional[List[Union[str, DisableFileOutputParameters]]]): + Used as a list to pass more fine-grained + information on what to save. Must be a member of `DisableFileOutputParameters`. + Allowed elements in the list are: + + + `y_optimization`: + do not save the predictions for the optimization set, + which would later on be used to build an ensemble. Note that SMAC + optimizes a metric evaluated on the optimization set. + + `pipeline`: + do not save any individual pipeline files + + `pipelines`: + In case of cross validation, disables saving the joint model of the + pipelines fit on each fold. + + `y_test`: + do not save the predictions for the test set. + + `all`: + do not save any of the above. 
+ For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`. + Returns: self """ + if dataset is None: + if ( + X_train is None + and y_train is None + ): + raise ValueError("No dataset provided, must provide X_train, y_train tensors") + dataset = self.get_dataset(X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + resampling_strategy=resampling_strategy, + resampling_strategy_args=resampling_strategy_args, + dataset_name=dataset_name + ) + self.dataset_name = dataset.dataset_name - if self._logger is None: - self._logger = self._get_logger(str(self.dataset_name)) + # Used when loading models + self.resampling_strategy = resampling_strategy + + self._logger = self._get_logger("RefitLogger") + + self._logger.debug("Starting refit") + + if self.n_jobs == 1: + self._dask_client = SingleThreadedClient() + else: + self._create_dask_client() dataset_requirements = get_dataset_requirements( info=dataset.get_required_dataset_info(), @@ -1417,29 +1450,102 @@ def refit( exclude=self.exclude_components, search_space_updates=self.search_space_updates) dataset_properties = dataset.get_dataset_properties(dataset_requirements) + self._backend.save_datamanager(dataset) + scenario_mock = unittest.mock.Mock() + scenario_mock.wallclock_limit = run_time_limit_secs + # This stats object is a hack - maybe the SMAC stats object should + # already be generated here! 
+ stats = Stats(scenario_mock) + + if memory_limit is None and getattr(self, '_memory_limit', None) is not None: + memory_limit = self._memory_limit + + metric = get_metrics(dataset_properties=dataset_properties, + names=[eval_metric] if eval_metric is not None else None, + all_supported_metrics=False).pop() + + pipeline_options = self.pipeline_options.copy().update(pipeline_options) if pipeline_options is not None \ + else self.pipeline_options.copy() + + assert pipeline_options is not None + + if budget_type is not None: + pipeline_options.update({'budget_type': budget_type}) + else: + budget_type = pipeline_options['budget_type'] + + budget = budget if budget is not None else pipeline_options[budget_type] + + if disable_file_output is None: + disable_file_output = getattr(self, '_disable_file_output', []) + + stats.start_timing() + if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None: self._load_models() - # Refit is not applicable when ensemble_size is set to zero. - if self.ensemble_ is None: - raise ValueError("Refit can only be called if 'ensemble_size != 0'") - + model_configs = [] for identifier in self.models_: model = self.models_[identifier] - # this updates the model inplace, it can then later be used in - # predict method - - # try to fit the model. If it fails, shuffle the data. This - # could alleviate the problem in algorithms that depend on - # the ordering of the data. 
- X = self._get_fit_dictionary( - dataset_properties=copy.copy(dataset_properties), - dataset=dataset, - split_id=split_id) - fit_and_suppress_warnings(self._logger, model, X, y=None) + budget = identifier[-1] # identifier is seed, num_run, budget + model_configs.append((model.config, budget)) + + self._logger.debug(f"Refitting {model_configs}") + run_history = run_models_on_dataset( + time_left=total_walltime_limit, + func_eval_time_limit_secs=run_time_limit_secs, + model_configs=model_configs, + logger=self._logger, + logger_port=self._logger_port, + metric=metric, + dask_client=self._dask_client, + backend=self._backend, + memory_limit=memory_limit, + disable_file_output=disable_file_output, + all_supported_metrics=all_supported_metrics, + include=self.include_components, + exclude=self.exclude_components, + search_space_updates=self.search_space_updates, + pipeline_options=pipeline_options, + seed=self.seed, + multiprocessing_context=self._multiprocessing_context, + n_jobs=self.n_jobs, + current_search_space=self.search_space, + initial_num_run=self._backend.get_next_num_run() + ) + + replace_old_identifiers_to_refit_identifiers = {} + + self._logger.debug("Finished refit training") + old_identifier_index = None + for _, run_value in run_history.data.items(): + config = run_value.additional_info['configuration'] + for i, (configuration, _) in enumerate(model_configs): + if isinstance(configuration, Configuration): + configuration = configuration.get_dictionary() + self._logger.debug(f"Matching {config} with {configuration}") + if config == configuration: + old_identifier_index = i + break + if old_identifier_index is not None: + old_identifier = list(self.models_.keys())[old_identifier_index] + refit_identifier = (self.seed, run_value.additional_info['num_run'], old_identifier[2]) + replace_old_identifiers_to_refit_identifiers[old_identifier] = refit_identifier + else: + warnings.warn(f"Refit for {config} failed. 
Model fitted during search will be used instead.") + old_identifier_index = None + self.ensemble_.update_identifiers(replace_old_identifiers_to_refit_identifiers) + + self.run_history.update(run_history, DataOrigin.EXTERNAL_SAME_INSTANCES) + run_history.save_json(os.path.join(self._backend.internals_directory, 'refit_run_history.json'), + save_external=True) + + # store ensemble for later use, with large iteration + self._backend.save_ensemble(self.ensemble_, 10**8, self.seed) + self._load_models() self._clean_logger() return self @@ -1473,8 +1579,8 @@ def fit_pipeline( A pipeline configuration can be specified if None, uses default - Fit uses the estimator pipeline_config attribute, which the user - can interact via the get_pipeline_config()/set_pipeline_config() + Fit uses the estimator pipeline_options attribute, which the user + can interact via the get_pipeline_options()/set_pipeline_options() methods. Args: @@ -1581,8 +1687,8 @@ def fit_pipeline( if dataset is None: if ( - X_train is not None - and y_train is not None + X_train is None + and y_train is None ): raise ValueError("No dataset provided, must provide X_train, y_train tensors") dataset = self.get_dataset(X_train=X_train, @@ -1603,22 +1709,22 @@ def fit_pipeline( # search process, it makes sense to set it to 0 configuration.__setattr__('config_id', 0) + include_components = self.include_components if include_components is None else include_components + exclude_components = self.exclude_components if exclude_components is None else exclude_components + search_space_updates = self.search_space_updates if search_space_updates is None else search_space_updates + # get dataset properties dataset_requirements = get_dataset_requirements( info=dataset.get_required_dataset_info(), - include=self.include_components, - exclude=self.exclude_components, - search_space_updates=self.search_space_updates) + include=include_components, + exclude=exclude_components, + search_space_updates=search_space_updates) 
dataset_properties = dataset.get_dataset_properties(dataset_requirements) self._backend.save_datamanager(dataset) if self._logger is None: self._logger = self._get_logger(dataset.dataset_name) - include_components = self.include_components if include_components is None else include_components - exclude_components = self.exclude_components if exclude_components is None else exclude_components - search_space_updates = self.search_space_updates if search_space_updates is None else search_space_updates - scenario_mock = unittest.mock.Mock() scenario_mock.wallclock_limit = run_time_limit_secs # This stats object is a hack - maybe the SMAC stats object should @@ -1632,7 +1738,7 @@ def fit_pipeline( names=[eval_metric] if eval_metric is not None else None, all_supported_metrics=False).pop() - pipeline_options = self.pipeline_options.copy().update(pipeline_options) if pipeline_options is not None \ + pipeline_options = {**self.pipeline_options, **pipeline_options} if pipeline_options is not None \ else self.pipeline_options.copy() assert pipeline_options is not None @@ -1666,7 +1772,7 @@ def fit_pipeline( include=include_components, exclude=exclude_components, search_space_updates=search_space_updates, - pipeline_config=pipeline_options, + pipeline_options=pipeline_options, pynisher_context=self._multiprocessing_context, ) @@ -1742,8 +1848,7 @@ def predict( # Parallelize predictions across models with n_jobs processes. # Each process computes predictions in chunks of batch_size rows. - if self._logger is None: - self._logger = self._get_logger("Predict-Logger") + self._logger = self._get_logger("Predict-Logger") if self.ensemble_ is None and not self._load_models(): raise ValueError("No ensemble found. Either fit has not yet " diff --git a/autoPyTorch/constants.py b/autoPyTorch/constants.py index bfd56d27f..3d77f77bc 100644 --- a/autoPyTorch/constants.py +++ b/autoPyTorch/constants.py @@ -58,6 +58,10 @@ "forecasting tasks! 
Please run \n pip install autoPyTorch[forecasting] \n to "\ "install the corresponding dependencies!" +# This value is applied to ensure numerical stability: Sometimes we want to rescale some values: value / scale. +# We make the scale value to be 1 if it is smaller than this value to ensure that the scaled value will not result in +# overflow +VERY_SMALL_VALUE = 1e-12 # The constant values for time series forecasting comes from # https://github.com/rakshitha123/TSForecasting/blob/master/experiments/deep_learning_experiments.py @@ -78,3 +82,6 @@ # To avoid that we get a sequence that is too long to be fed to a network MAX_WINDOW_SIZE_BASE = 500 + +# AutoPyTorch optionally allows network inference or metrics calculation for the following datasets +OPTIONAL_INFERENCE_CHOICES = ('test',) diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py index 670eb44c9..4c3565172 100644 --- a/autoPyTorch/datasets/time_series_dataset.py +++ b/autoPyTorch/datasets/time_series_dataset.py @@ -693,6 +693,17 @@ def __init__(self, self.splits = self.get_splits_from_resampling_strategy() # type: ignore[assignment] + valid_splits = [] + for i, split in enumerate(self.splits): + if len(split[0]) > 0: + valid_splits.append(split) + + if len(valid_splits) == 0: + raise ValueError(f'The passed value for {n_prediction_steps} is unsuited for the current dataset, please ' + 'consider reducing n_prediction_steps') + + self.splits = valid_splits + # TODO doing experiments to give the most proper way of defining these two values if lagged_value is None: try: diff --git a/autoPyTorch/ensemble/abstract_ensemble.py b/autoPyTorch/ensemble/abstract_ensemble.py index 072b6d260..0f04fe38a 100644 --- a/autoPyTorch/ensemble/abstract_ensemble.py +++ b/autoPyTorch/ensemble/abstract_ensemble.py @@ -9,6 +9,9 @@ class AbstractEnsemble(object): __metaclass__ = ABCMeta + def __init__(self): + self.identifiers_: List[Tuple[int, int, float]] = [] + @abstractmethod def 
fit( self, @@ -76,3 +79,12 @@ def get_validation_performance(self) -> float: Returns: Score """ + + def update_identifiers( + self, + replace_identifiers_mapping: Dict[Tuple[int, int, float], Tuple[int, int, float]] + ) -> None: + identifiers = self.identifiers_.copy() + for i, identifier in enumerate(self.identifiers_): + identifiers[i] = replace_identifiers_mapping.get(identifier, identifier) + self.identifiers_ = identifiers diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py index d20a96b75..069228726 100644 --- a/autoPyTorch/evaluation/abstract_evaluator.py +++ b/autoPyTorch/evaluation/abstract_evaluator.py @@ -195,7 +195,8 @@ def get_additional_run_info(self) -> Dict[str, Any]: Can be found in autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs """ return {'pipeline_configuration': self.configuration, - 'trainer_configuration': self.pipeline.named_steps['model_trainer'].choice.model.get_config()} + 'trainer_configuration': self.pipeline.named_steps['model_trainer'].choice.model.get_config(), + 'configuration_origin': 'traditional'} def get_pipeline_representation(self) -> Dict[str, str]: return self.pipeline.get_pipeline_representation() @@ -347,7 +348,7 @@ class AbstractEvaluator(object): An evaluator is an object that: + constructs a pipeline (i.e. a classification or regression estimator) for a given - pipeline_config and run settings (budget, seed) + pipeline_options and run settings (budget, seed) + Fits and trains this pipeline (TrainEvaluator) or tests a given configuration (TestEvaluator) @@ -369,7 +370,7 @@ class AbstractEvaluator(object): The amount of epochs/time a configuration is allowed to run. budget_type (str): The budget type. Currently, only epoch and time are allowed. - pipeline_config (Optional[Dict[str, Any]]): + pipeline_options (Optional[Dict[str, Any]]): Defines the content of the pipeline being evaluated. 
For example, it contains pipeline specific settings like logging name, or whether or not to use tensorboard. @@ -430,7 +431,7 @@ def __init__(self, backend: Backend, budget: float, configuration: Union[int, str, Configuration], budget_type: str = None, - pipeline_config: Optional[Dict[str, Any]] = None, + pipeline_options: Optional[Dict[str, Any]] = None, seed: int = 1, output_y_hat_optimization: bool = True, num_run: Optional[int] = None, @@ -523,10 +524,10 @@ def __init__(self, backend: Backend, self._init_params = init_params assert self.pipeline_class is not None, "Could not infer pipeline class" - pipeline_config = pipeline_config if pipeline_config is not None \ + pipeline_options = pipeline_options if pipeline_options is not None \ else self.pipeline_class.get_default_pipeline_options() - self.budget_type = pipeline_config['budget_type'] if budget_type is None else budget_type - self.budget = pipeline_config[self.budget_type] if budget == 0 else budget + self.budget_type = pipeline_options['budget_type'] if budget_type is None else budget_type + self.budget = pipeline_options[self.budget_type] if budget == 0 else budget self.num_run = 0 if num_run is None else num_run @@ -539,7 +540,7 @@ def __init__(self, backend: Backend, port=logger_port, ) - self._init_fit_dictionary(logger_port=logger_port, pipeline_config=pipeline_config, metrics_dict=metrics_dict) + self._init_fit_dictionary(logger_port=logger_port, pipeline_options=pipeline_options, metrics_dict=metrics_dict) self.Y_optimization: Optional[np.ndarray] = None self.Y_actual_train: Optional[np.ndarray] = None self.pipelines: Optional[List[BaseEstimator]] = None @@ -597,7 +598,7 @@ def _init_datamanager_info( def _init_fit_dictionary( self, logger_port: int, - pipeline_config: Dict[str, Any], + pipeline_options: Dict[str, Any], metrics_dict: Optional[Dict[str, List[str]]] = None, ) -> None: """ @@ -608,7 +609,7 @@ def _init_fit_dictionary( Logging is performed using a socket-server scheme to be robust 
against many parallel entities that want to write to the same file. This integer states the socket port for the communication channel. - pipeline_config (Dict[str, Any]): + pipeline_options (Dict[str, Any]): Defines the content of the pipeline being evaluated. For example, it contains pipeline specific settings like logging name, or whether or not to use tensorboard. @@ -634,7 +635,7 @@ def _init_fit_dictionary( 'optimize_metric': self.metric.name }) - self.fit_dictionary.update(pipeline_config) + self.fit_dictionary.update(pipeline_options) # If the budget is epochs, we want to limit that in the fit dictionary if self.budget_type == 'epochs': self.fit_dictionary['epochs'] = self.budget @@ -805,6 +806,11 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float], if test_loss is not None: additional_run_info['test_loss'] = test_loss + # Add information to additional info that can be useful for other functionalities + additional_run_info['configuration'] = self.configuration \ + if not isinstance(self.configuration, Configuration) else self.configuration.get_dictionary() + additional_run_info['budget'] = self.budget + rval_dict = {'loss': cost, 'additional_run_info': additional_run_info, 'status': status} diff --git a/autoPyTorch/evaluation/tae.py b/autoPyTorch/evaluation/tae.py index b1650113c..3eaea6720 100644 --- a/autoPyTorch/evaluation/tae.py +++ b/autoPyTorch/evaluation/tae.py @@ -123,7 +123,7 @@ def __init__( abort_on_first_run_crash: bool, pynisher_context: str, multi_objectives: List[str], - pipeline_config: Optional[Dict[str, Any]] = None, + pipeline_options: Optional[Dict[str, Any]] = None, initial_num_run: int = 1, stats: Optional[Stats] = None, run_obj: str = 'quality', @@ -198,13 +198,13 @@ def __init__( self.disable_file_output = disable_file_output self.init_params = init_params - self.budget_type = pipeline_config['budget_type'] if pipeline_config is not None else budget_type + self.budget_type = pipeline_options['budget_type'] if 
pipeline_options is not None else budget_type - self.pipeline_config: Dict[str, Union[int, str, float]] = dict() - if pipeline_config is None: - pipeline_config = replace_string_bool_to_bool(json.load(open( + self.pipeline_options: Dict[str, Union[int, str, float]] = dict() + if pipeline_options is None: + pipeline_options = replace_string_bool_to_bool(json.load(open( os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json')))) - self.pipeline_config.update(pipeline_config) + self.pipeline_options.update(pipeline_options) self.logger_port = logger_port if self.logger_port is None: @@ -225,7 +225,7 @@ def __init__( def _check_and_get_default_budget(self) -> float: budget_type_choices_tabular = ('epochs', 'runtime') budget_choices = { - budget_type: float(self.pipeline_config.get(budget_type, np.inf)) + budget_type: float(self.pipeline_options.get(budget_type, np.inf)) for budget_type in budget_type_choices_tabular } @@ -234,7 +234,7 @@ def _check_and_get_default_budget(self) -> float: budget_type_choices = budget_type_choices_tabular + FORECASTING_BUDGET_TYPE # budget is defined by epochs by default - budget_type = str(self.pipeline_config.get('budget_type', 'epochs')) + budget_type = str(self.pipeline_options.get('budget_type', 'epochs')) if self.budget_type is not None: budget_type = self.budget_type @@ -361,7 +361,7 @@ def run( init_params=init_params, budget=budget, budget_type=self.budget_type, - pipeline_config=self.pipeline_config, + pipeline_options=self.pipeline_options, logger_port=self.logger_port, all_supported_metrics=self.all_supported_metrics, search_space_updates=self.search_space_updates diff --git a/autoPyTorch/evaluation/test_evaluator.py b/autoPyTorch/evaluation/test_evaluator.py index 4d5b0ae91..12b7bc31d 100644 --- a/autoPyTorch/evaluation/test_evaluator.py +++ b/autoPyTorch/evaluation/test_evaluator.py @@ -51,7 +51,7 @@ class TestEvaluator(AbstractEvaluator): The amount of epochs/time a configuration is allowed to 
run. budget_type (str): The budget type, which can be epochs or time - pipeline_config (Optional[Dict[str, Any]]): + pipeline_options (Optional[Dict[str, Any]]): Defines the content of the pipeline being evaluated. For example, it contains pipeline specific settings like logging name, or whether or not to use tensorboard. @@ -113,7 +113,7 @@ def __init__( budget: float, configuration: Union[int, str, Configuration], budget_type: str = None, - pipeline_config: Optional[Dict[str, Any]] = None, + pipeline_options: Optional[Dict[str, Any]] = None, seed: int = 1, output_y_hat_optimization: bool = False, num_run: Optional[int] = None, @@ -141,7 +141,7 @@ def __init__( budget_type=budget_type, logger_port=logger_port, all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config, + pipeline_options=pipeline_options, search_space_updates=search_space_updates ) @@ -206,7 +206,7 @@ def eval_test_function( include: Optional[Dict[str, Any]], exclude: Optional[Dict[str, Any]], disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, - pipeline_config: Optional[Dict[str, Any]] = None, + pipeline_options: Optional[Dict[str, Any]] = None, budget_type: str = None, init_params: Optional[Dict[str, Any]] = None, logger_port: Optional[int] = None, @@ -230,7 +230,7 @@ def eval_test_function( budget_type=budget_type, logger_port=logger_port, all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config, + pipeline_options=pipeline_options, search_space_updates=search_space_updates) evaluator.fit_predict_and_loss() diff --git a/autoPyTorch/evaluation/time_series_forecasting_train_evaluator.py b/autoPyTorch/evaluation/time_series_forecasting_train_evaluator.py index 0940d1e9a..07a87ede4 100644 --- a/autoPyTorch/evaluation/time_series_forecasting_train_evaluator.py +++ b/autoPyTorch/evaluation/time_series_forecasting_train_evaluator.py @@ -40,7 +40,7 @@ class TimeSeriesForecastingTrainEvaluator(TrainEvaluator): The amount of 
epochs/time a configuration is allowed to run. budget_type (str): The budget type, which can be epochs or time - pipeline_config (Optional[Dict[str, Any]]): + pipeline_options (Optional[Dict[str, Any]]): Defines the content of the pipeline being evaluated. For example, it contains pipeline specific settings like logging name, or whether or not to use tensorboard. @@ -106,7 +106,7 @@ def __init__(self, backend: Backend, queue: Queue, metric: autoPyTorchMetric, budget: float, budget_type: str = None, - pipeline_config: Optional[Dict[str, Any]] = None, + pipeline_options: Optional[Dict[str, Any]] = None, configuration: Optional[Configuration] = None, seed: int = 1, output_y_hat_optimization: bool = True, @@ -138,7 +138,7 @@ def __init__(self, backend: Backend, queue: Queue, logger_port=logger_port, keep_models=keep_models, all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config, + pipeline_options=pipeline_options, search_space_updates=search_space_updates ) self.datamanager = backend.load_datamanager() @@ -456,7 +456,7 @@ def forecasting_eval_train_function( include: Optional[Dict[str, Any]], exclude: Optional[Dict[str, Any]], disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, - pipeline_config: Optional[Dict[str, Any]] = None, + pipeline_options: Optional[Dict[str, Any]] = None, budget_type: str = None, init_params: Optional[Dict[str, Any]] = None, logger_port: Optional[int] = None, @@ -490,7 +490,7 @@ def forecasting_eval_train_function( The amount of epochs/time a configuration is allowed to run. budget_type (str): The budget type, which can be epochs or time - pipeline_config (Optional[Dict[str, Any]]): + pipeline_options (Optional[Dict[str, Any]]): Defines the content of the pipeline being evaluated. For example, it contains pipeline specific settings like logging name, or whether or not to use tensorboard. 
@@ -550,7 +550,7 @@ def forecasting_eval_train_function( budget_type=budget_type, logger_port=logger_port, all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config, + pipeline_options=pipeline_options, search_space_updates=search_space_updates, max_budget=max_budget, min_num_test_instances=min_num_test_instances, diff --git a/autoPyTorch/evaluation/train_evaluator.py b/autoPyTorch/evaluation/train_evaluator.py index 142af6bcc..e88c8eaca 100644 --- a/autoPyTorch/evaluation/train_evaluator.py +++ b/autoPyTorch/evaluation/train_evaluator.py @@ -60,7 +60,7 @@ class TrainEvaluator(AbstractEvaluator): The amount of epochs/time a configuration is allowed to run. budget_type (str): The budget type, which can be epochs or time - pipeline_config (Optional[Dict[str, Any]]): + pipeline_options (Optional[Dict[str, Any]]): Defines the content of the pipeline being evaluated. For example, it contains pipeline specific settings like logging name, or whether or not to use tensorboard. 
@@ -121,7 +121,7 @@ def __init__(self, backend: Backend, queue: Queue, budget: float, configuration: Union[int, str, Configuration], budget_type: str = None, - pipeline_config: Optional[Dict[str, Any]] = None, + pipeline_options: Optional[Dict[str, Any]] = None, seed: int = 1, output_y_hat_optimization: bool = True, num_run: Optional[int] = None, @@ -149,7 +149,7 @@ def __init__(self, backend: Backend, queue: Queue, budget_type=budget_type, logger_port=logger_port, all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config, + pipeline_options=pipeline_options, search_space_updates=search_space_updates ) @@ -420,7 +420,7 @@ def eval_train_function( include: Optional[Dict[str, Any]], exclude: Optional[Dict[str, Any]], disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, - pipeline_config: Optional[Dict[str, Any]] = None, + pipeline_options: Optional[Dict[str, Any]] = None, budget_type: str = None, init_params: Optional[Dict[str, Any]] = None, logger_port: Optional[int] = None, @@ -452,7 +452,7 @@ def eval_train_function( The amount of epochs/time a configuration is allowed to run. budget_type (str): The budget type, which can be epochs or time - pipeline_config (Optional[Dict[str, Any]]): + pipeline_options (Optional[Dict[str, Any]]): Defines the content of the pipeline being evaluated. For example, it contains pipeline specific settings like logging name, or whether or not to use tensorboard. 
@@ -506,7 +506,7 @@ def eval_train_function( budget_type=budget_type, logger_port=logger_port, all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config, + pipeline_options=pipeline_options, search_space_updates=search_space_updates, ) evaluator.fit_predict_and_loss() diff --git a/autoPyTorch/optimizer/smbo.py b/autoPyTorch/optimizer/smbo.py index 53eae4696..92bf7bb87 100644 --- a/autoPyTorch/optimizer/smbo.py +++ b/autoPyTorch/optimizer/smbo.py @@ -111,7 +111,7 @@ def __init__(self, watcher: StopWatch, n_jobs: int, dask_client: Optional[dask.distributed.Client], - pipeline_config: Dict[str, Any], + pipeline_options: Dict[str, Any], start_num_run: int = 1, seed: int = 1, resampling_strategy: Union[HoldoutValTypes, @@ -227,7 +227,7 @@ def __init__(self, self.backend = backend self.all_supported_metrics = all_supported_metrics - self.pipeline_config = pipeline_config + self.pipeline_options = pipeline_options # the configuration space self.config_space = config_space @@ -326,7 +326,7 @@ def run_smbo(self, func: Optional[Callable] = None ta=func, logger_port=self.logger_port, all_supported_metrics=self.all_supported_metrics, - pipeline_config=self.pipeline_config, + pipeline_options=self.pipeline_options, search_space_updates=self.search_space_updates, pynisher_context=self.pynisher_context, ) @@ -376,7 +376,7 @@ def run_smbo(self, func: Optional[Callable] = None ) scenario_dict.update(self.smac_scenario_args) - budget_type = self.pipeline_config['budget_type'] + budget_type = self.pipeline_options['budget_type'] if budget_type in FORECASTING_BUDGET_TYPE: if STRING_TO_TASK_TYPES.get(self.task_type, -1) != TIMESERIES_FORECASTING: raise ValueError('Forecasting Budget type is only available for forecasting task!') diff --git a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/scaling/utils.py b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/scaling/utils.py index abd246072..cf314dc7e 100644 --- 
a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/scaling/utils.py +++ b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/scaling/utils.py @@ -6,6 +6,8 @@ from sklearn.base import BaseEstimator +from autoPyTorch.constants import VERY_SMALL_VALUE + # Similar to / inspired by # https://github.com/tslearn-team/tslearn/blob/a3cf3bf/tslearn/preprocessing/preprocessing.py @@ -41,7 +43,7 @@ def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Any = None) -> "TimeSeriesS self.loc[self.static_features] = X[self.static_features].mean() # ensure that if all the values are the same in a group, we could still normalize them correctly - self.scale[self.scale == 0] = 1. + self.scale[self.scale < VERY_SMALL_VALUE] = 1. elif self.mode == "min_max": X_grouped = X.groupby(X.index) @@ -55,14 +57,14 @@ def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Any = None) -> "TimeSeriesS self.loc = min_ self.scale = diff_ self.scale.mask(self.scale == 0.0, self.loc) - self.scale[self.scale == 0.0] = 1.0 + self.scale[self.scale < VERY_SMALL_VALUE] = 1.0 elif self.mode == "max_abs": X_abs = X.transform("abs") max_abs_ = X_abs.groupby(X_abs.index).agg("max") max_abs_[self.static_features] = max_abs_[self.static_features].max() - max_abs_[max_abs_ == 0.0] = 1.0 + max_abs_[max_abs_ < VERY_SMALL_VALUE] = 1.0 self.loc = None self.scale = max_abs_ @@ -73,7 +75,7 @@ def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Any = None) -> "TimeSeriesS mean_abs_[self.static_features] = mean_abs_[self.static_features].mean() self.scale = mean_abs_.mask(mean_abs_ == 0.0, X_abs.agg("max")) - self.scale[self.scale == 0] = 1 + self.scale[self.scale < VERY_SMALL_VALUE] = 1 self.loc = None elif self.mode == "none": @@ -108,7 +110,7 @@ def transform(self, X: Union[pd.DataFrame, np.ndarray]) -> Union[pd.DataFrame, n loc = X.mean(axis=0, keepdims=True) scale = np.nan_to_num(X.std(axis=0, ddof=1, keepdims=True)) scale = np.where(scale == 0, loc, scale) - scale[scale == 0] 
= 1. + scale[scale < VERY_SMALL_VALUE] = 1. return (X - loc) / scale elif self.mode == 'min_max': @@ -119,13 +121,13 @@ def transform(self, X: Union[pd.DataFrame, np.ndarray]) -> Union[pd.DataFrame, n loc = min_ scale = diff_ scale = np.where(scale == 0., loc, scale) - scale[scale == 0.0] = 1.0 + scale[scale < VERY_SMALL_VALUE] = 1.0 return (X - loc) / scale elif self.mode == "max_abs": X_abs = np.abs(X) max_abs_ = X_abs.max(0, keepdims=True) - max_abs_[max_abs_ == 0.0] = 1.0 + max_abs_[max_abs_ < VERY_SMALL_VALUE] = 1.0 scale = max_abs_ return X / scale @@ -133,7 +135,7 @@ def transform(self, X: Union[pd.DataFrame, np.ndarray]) -> Union[pd.DataFrame, n X_abs = np.abs(X) mean_abs_ = X_abs.mean(0, keepdims=True) scale = np.where(mean_abs_ == 0.0, np.max(X_abs), mean_abs_) - scale[scale == 0] = 1 + scale[scale < VERY_SMALL_VALUE] = 1 return X / scale elif self.mode == "none": diff --git a/autoPyTorch/pipeline/components/setup/forecasting_target_scaling/utils.py b/autoPyTorch/pipeline/components/setup/forecasting_target_scaling/utils.py index 7b4782206..ed4da6d2a 100644 --- a/autoPyTorch/pipeline/components/setup/forecasting_target_scaling/utils.py +++ b/autoPyTorch/pipeline/components/setup/forecasting_target_scaling/utils.py @@ -4,6 +4,8 @@ import torch +from autoPyTorch.constants import VERY_SMALL_VALUE + # Similar to / inspired by # https://github.com/tslearn-team/tslearn/blob/a3cf3bf/tslearn/preprocessing/preprocessing.py @@ -30,7 +32,7 @@ def transform(self, offset_targets = past_targets - loc scale = torch.where(torch.logical_or(scale == 0.0, scale == torch.nan), offset_targets[:, [-1]], scale) - scale[scale == 0.0] = 1.0 + scale[scale < VERY_SMALL_VALUE] = 1.0 if future_targets is not None: future_targets = (future_targets - loc) / scale return (past_targets - loc) / scale, future_targets, loc, scale @@ -42,14 +44,14 @@ def transform(self, diff_ = max_ - min_ loc = min_ scale = torch.where(diff_ == 0, past_targets[:, [-1]], diff_) - scale[scale == 0.0] = 1.0 + 
scale[scale < VERY_SMALL_VALUE] = 1.0 if future_targets is not None: future_targets = (future_targets - loc) / scale return (past_targets - loc) / scale, future_targets, loc, scale elif self.mode == "max_abs": max_abs_ = torch.max(torch.abs(past_targets), dim=1, keepdim=True)[0] - max_abs_[max_abs_ == 0.0] = 1.0 + max_abs_[max_abs_ < VERY_SMALL_VALUE] = 1.0 scale = max_abs_ if future_targets is not None: future_targets = future_targets / scale @@ -58,7 +60,7 @@ def transform(self, elif self.mode == 'mean_abs': mean_abs = torch.mean(torch.abs(past_targets), dim=1, keepdim=True) scale = torch.where(mean_abs == 0.0, past_targets[:, [-1]], mean_abs) - scale[scale == 0.0] = 1.0 + scale[scale < VERY_SMALL_VALUE] = 1.0 if future_targets is not None: future_targets = future_targets / scale return past_targets / scale, future_targets, None, scale @@ -82,7 +84,7 @@ def transform(self, offset_targets = past_targets - loc # ensure that all the targets are scaled properly scale = torch.where(torch.logical_or(scale == 0.0, scale == torch.nan), offset_targets[:, [-1]], scale) - scale[scale == 0.0] = 1.0 + scale[scale < VERY_SMALL_VALUE] = 1.0 if future_targets is not None: future_targets = (future_targets - loc) / scale @@ -100,7 +102,7 @@ def transform(self, diff_ = max_ - min_ loc = min_ scale = torch.where(diff_ == 0, past_targets[:, [-1]], diff_) - scale[scale == 0.0] = 1.0 + scale[scale < VERY_SMALL_VALUE] = 1.0 if future_targets is not None: future_targets = (future_targets - loc) / scale @@ -110,7 +112,7 @@ def transform(self, elif self.mode == "max_abs": max_abs_ = torch.max(torch.abs(valid_past_targets), dim=1, keepdim=True)[0] - max_abs_[max_abs_ == 0.0] = 1.0 + max_abs_[max_abs_ < VERY_SMALL_VALUE] = 1.0 scale = max_abs_ if future_targets is not None: future_targets = future_targets / scale @@ -122,8 +124,8 @@ def transform(self, elif self.mode == 'mean_abs': mean_abs = torch.sum(torch.abs(valid_past_targets), dim=1, keepdim=True) / valid_past_obs scale = 
torch.where(mean_abs == 0.0, valid_past_targets[:, [-1]], mean_abs) - # in case that all values in the tensor is 0 - scale[scale == 0.0] = 1.0 + # in case that all values in the tensor are too small + scale[scale < VERY_SMALL_VALUE] = 1.0 if future_targets is not None: future_targets = future_targets / scale diff --git a/autoPyTorch/pipeline/components/setup/network/forecasting_architecture.py b/autoPyTorch/pipeline/components/setup/network/forecasting_architecture.py index fc7ac3ae1..3912cadf9 100644 --- a/autoPyTorch/pipeline/components/setup/network/forecasting_architecture.py +++ b/autoPyTorch/pipeline/components/setup/network/forecasting_architecture.py @@ -368,7 +368,9 @@ def scale_value(self, outputs = raw_value - loc.to(device) else: outputs = (raw_value - loc.to(device)) / scale.to(device) - return outputs + return outputs + else: + return raw_value @abstractmethod def forward(self, diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py b/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py index 7d26c5481..8b4723066 100644 --- a/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py +++ b/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py @@ -52,8 +52,7 @@ def __init__( self.add_fit_requirements([ FitRequirement('X_train', (np.ndarray, list, pd.DataFrame), user_defined=False, dataset_property=False), FitRequirement('y_train', (np.ndarray, list, pd.Series,), user_defined=False, dataset_property=False), - FitRequirement('train_indices', (np.ndarray, list), user_defined=False, dataset_property=False), - FitRequirement('val_indices', (np.ndarray, list), user_defined=False, dataset_property=False)]) + FitRequirement('train_indices', (np.ndarray, list), user_defined=False, dataset_property=False)]) def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent: """ @@ -90,8 +89,14 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent: # train model
blockPrint() + val_indices = X.get('val_indices', None) + X_val = None + y_val = None + if val_indices is not None: + X_val = X['X_train'][val_indices] + y_val = X['y_train'][val_indices] self.fit_output = self.model.fit(X['X_train'][X['train_indices']], X['y_train'][X['train_indices']], - X['X_train'][X['val_indices']], X['y_train'][X['val_indices']]) + X_val, y_val) enablePrint() # infer diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/base_traditional_learner.py b/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/base_traditional_learner.py index 9c0166a9f..eaf40feb3 100644 --- a/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/base_traditional_learner.py +++ b/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/base_traditional_learner.py @@ -68,6 +68,8 @@ def __init__(self, self.is_classification = STRING_TO_TASK_TYPES[task_type] not in REGRESSION_TASKS + self.has_val_set = False + self.metric = get_metrics(dataset_properties={'task_type': task_type, 'output_type': output_type}, names=[optimize_metric] if optimize_metric is not None else None)[0] @@ -132,8 +134,8 @@ def _prepare_model(self, def _fit(self, X_train: np.ndarray, y_train: np.ndarray, - X_val: np.ndarray, - y_val: np.ndarray) -> None: + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None) -> None: """ Method that fits the underlying estimator Args: @@ -152,8 +154,8 @@ def _fit(self, def fit(self, X_train: np.ndarray, y_train: np.ndarray, - X_val: np.ndarray, - y_val: np.ndarray) -> Dict[str, Any]: + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None) -> Dict[str, Any]: """ Fit the model (possible using the validation set for early stopping) and return the results on the training and validation set. @@ -172,7 +174,10 @@ def fit(self, X_train: np.ndarray, Dictionary containing the results. 
see _get_results() """ X_train = self._preprocess(X_train) - X_val = self._preprocess(X_val) + + if X_val is not None: + self.has_val_set = True + X_val = self._preprocess(X_val) self._prepare_model(X_train, y_train) @@ -253,14 +258,14 @@ def _get_results(self, Dictionary containing the results """ pred_train = self.predict(X_train, predict_proba=self.is_classification, preprocess=False) - pred_val = self.predict(X_val, predict_proba=self.is_classification, preprocess=False) results = dict() - - results["val_preds"] = pred_val.tolist() - results["labels"] = y_val.tolist() - results["train_score"] = self.metric(y_train, pred_train) - results["val_score"] = self.metric(y_val, pred_val) + + if self.has_val_set: + pred_val = self.predict(X_val, predict_proba=self.is_classification, preprocess=False) + results["labels"] = y_val.tolist() + results["val_preds"] = pred_val.tolist() + results["val_score"] = self.metric(y_val, pred_val) return results diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/learners.py b/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/learners.py index 220c52dcd..fca02aa32 100644 --- a/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/learners.py +++ b/autoPyTorch/pipeline/components/setup/traditional_ml/traditional_learner/learners.py @@ -45,8 +45,10 @@ def _prepare_model(self, X_train: np.ndarray, y_train: np.ndarray ) -> None: - early_stopping = 150 if X_train.shape[0] > 10000 else max(round(150 * 10000 / X_train.shape[0]), 10) - self.config["early_stopping_rounds"] = early_stopping + + if self.has_val_set: + early_stopping = 150 if X_train.shape[0] > 10000 else max(round(150 * 10000 / X_train.shape[0]), 10) + self.config["early_stopping_rounds"] = early_stopping if not self.is_classification: self.model = LGBMRegressor(**self.config, random_state=self.random_state) else: @@ -57,11 +59,14 @@ def _prepare_model(self, def _fit(self, X_train: np.ndarray, y_train: 
np.ndarray, - X_val: np.ndarray, - y_val: np.ndarray + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None ) -> None: assert self.model is not None, "No model found. Can't fit without preparing the model" - self.model.fit(X_train, y_train, eval_set=[(X_val, y_val)]) + eval_set = None + if self.has_val_set: + eval_set = [(X_val, y_val)] + self.model.fit(X_train, y_train, eval_set=eval_set) def predict(self, X_test: np.ndarray, predict_proba: bool = False, @@ -125,15 +130,21 @@ def _prepare_model(self, def _fit(self, X_train: np.ndarray, y_train: np.ndarray, - X_val: np.ndarray, - y_val: np.ndarray) -> None: + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None + ) -> None: assert self.model is not None, "No model found. Can't fit without preparing the model" - early_stopping = 150 if X_train.shape[0] > 10000 else max(round(150 * 10000 / X_train.shape[0]), 10) categoricals = [ind for ind in range(X_train.shape[1]) if isinstance(X_train[0, ind], str)] X_train_pooled = Pool(data=X_train, label=y_train, cat_features=categoricals) - X_val_pooled = Pool(data=X_val, label=y_val, cat_features=categoricals) + X_val_pooled = None + if self.has_val_set: + X_val_pooled = Pool(data=X_val, label=y_val, cat_features=categoricals) + early_stopping: Optional[int] = 150 if X_train.shape[0] > 10000 else max( + round(150 * 10000 / X_train.shape[0]), 10) + else: + early_stopping = None self.model.fit(X_train_pooled, eval_set=X_val_pooled, @@ -189,8 +200,9 @@ def _prepare_model(self, def _fit(self, X_train: np.ndarray, y_train: np.ndarray, - X_val: np.ndarray, - y_val: np.ndarray) -> None: + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None + ) -> None: assert self.model is not None, "No model found. 
Can't fit without preparing the model" self.model.fit(X_train, y_train) @@ -244,8 +256,8 @@ def _prepare_model(self, def _fit(self, X_train: np.ndarray, y_train: np.ndarray, - X_val: np.ndarray, - y_val: np.ndarray) -> None: + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None) -> None: assert self.model is not None, "No model found. Can't fit without preparing the model" self.model.fit(X_train, y_train) if self.config["warm_start"]: @@ -303,8 +315,8 @@ def _prepare_model(self, def _fit(self, X_train: np.ndarray, y_train: np.ndarray, - X_val: np.ndarray, - y_val: np.ndarray) -> None: + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None) -> None: assert self.model is not None, "No model found. Can't fit without preparing the model" self.model.fit(X_train, y_train) @@ -346,8 +358,8 @@ def _prepare_model(self, def _fit(self, X_train: np.ndarray, y_train: np.ndarray, - X_val: np.ndarray, - y_val: np.ndarray) -> None: + X_val: Optional[np.ndarray] = None, + y_val: Optional[np.ndarray] = None) -> None: assert self.model is not None, "No model found. 
Can't fit without preparing the model" self.model.fit(X_train, y_train) diff --git a/autoPyTorch/utils/parallel_model_runner.py b/autoPyTorch/utils/parallel_model_runner.py new file mode 100644 index 000000000..d4237f683 --- /dev/null +++ b/autoPyTorch/utils/parallel_model_runner.py @@ -0,0 +1,300 @@ +import logging +import math +import time +import unittest +from typing import Dict, List, Optional, Tuple, Union + +from ConfigSpace.configuration_space import Configuration, ConfigurationSpace + +import dask.distributed + +from smac.runhistory.runhistory import RunHistory +from smac.stats.stats import Stats +from smac.tae import StatusType + +from autoPyTorch.automl_common.common.utils.backend import Backend +from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash +from autoPyTorch.evaluation.utils import DisableFileOutputParameters +from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric +from autoPyTorch.utils.common import dict_repr +from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates +from autoPyTorch.utils.logging_ import PicklableClientLogger + + +def run_models_on_dataset( + time_left: int, + func_eval_time_limit_secs: int, + model_configs: List[Tuple[str, Configuration]], + logger: PicklableClientLogger, + metric: autoPyTorchMetric, + dask_client: dask.distributed.Client, + backend: Backend, + seed: int, + multiprocessing_context: str, + current_search_space: ConfigurationSpace, + n_jobs: int = 1, + initial_num_run: int = 1, + all_supported_metrics: bool = False, + include: Optional[Dict[str, List[str]]] = None, + exclude: Optional[Dict[str, List[str]]] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None, + logger_port: Optional[int] = logging.handlers.DEFAULT_TCP_LOGGING_PORT, + memory_limit: Optional[int] = None, + disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None, + pipeline_options: 
Optional[Dict] = None, +) -> RunHistory: + """ + Runs models specified by `model_configs` on dask parallel infrastructure. + + Args: + time_left (int): + Time limit in seconds for the search of appropriate models. + By increasing this value, autopytorch has a higher + chance of finding better models. + func_eval_time_limit_secs (int): + Time limit for a single call to the machine learning model. + Model fitting will be terminated if the machine + learning algorithm runs over the time limit. Set + this value high enough so that typical machine + learning algorithms can be fit on the training + data. + Set to np.inf in case no time limit is desired. + model_configs (List[Tuple[str, Configuration]]): + List containing the configuration and the budget for the model to be evaluated. + logger (PicklableClientLogger): + Logger + metric (autoPyTorchMetric): + autoPyTorchMetric to be used for evaluation. + dask_client (dask.distributed.Client): + dask client where the function evaluation jobs are submitted. + backend (Backend): + Current backend object where the data is stored. The backend + is used to interact with the disk. + all_supported_metrics (bool): + If True, all metrics supporting current task will be calculated + for each pipeline. + seed (int): + Seed to be used for reproducibility. + multiprocessing_context (str): + context used for spawning child processes. + n_jobs (int): + Number of consecutive processes to spawn. + current_search_space (ConfigurationSpace): + The search space of the neural networks which will be used to instantiate Configuration objects. + initial_num_run (int): + Initial num run for running the models. + include (Optional[Dict[str, List[str]]]): + Dictionary containing components to include. Key is the node + name and Value is an Iterable of the names of the components + to include. Only these components will be present in the + search space. Defaults to None. 
+ exclude (Optional[Dict[str, List[str]]]): + Dictionary containing components to exclude. Key is the node + name and Value is an Iterable of the names of the components + to exclude. All except these components will be present in + the search space. Defaults to None. + search_space_updates (Optional[HyperparameterSearchSpaceUpdates]): + Search space updates that can be used to modify the search + space of particular components or choice modules of the pipeline. + Defaults to None. + logger_port (Optional[int]): + Port used to create the logging server. Defaults to logging.handlers.DEFAULT_TCP_LOGGING_PORT. + memory_limit (Optional[int]): + Memory limit in MB for the machine learning algorithm. + Autopytorch will stop fitting the machine learning algorithm + if it tries to allocate more than memory_limit MB. If None + is provided, no memory limit is set. In case of multi-processing, + memory_limit will be per job. This memory limit also applies to + the ensemble creation process. Defaults to None. + disable_file_output (Optional[List[Union[str, DisableFileOutputParameters]]]): + Used as a list to pass more fine-grained + information on what to save. Must be a member of `DisableFileOutputParameters`. + Allowed elements in the list are: + + + `y_optimization`: + do not save the predictions for the optimization set, + which would later on be used to build an ensemble. Note that SMAC + optimizes a metric evaluated on the optimization set. + + `pipeline`: + do not save any individual pipeline files + + `pipelines`: + In case of cross validation, disables saving the joint model of the + pipelines fit on each fold. + + `y_test`: + do not save the predictions for the test set. + + `all`: + do not save any of the above. + For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`. + Defaults to None. 
+ pipeline_options (Optional[Dict]): + Valid config options include "device", + "torch_num_threads", "early_stopping", "use_tensorboard_logger", + "metrics_during_training". + + Returns: + RunHistory: + run_history: + Run History of training all the models in model_configs + """ + starttime = time.time() + run_history = RunHistory() + if memory_limit is not None: + memory_limit = int(math.ceil(memory_limit)) + total_models = len(model_configs) + dask_futures: List[dask.distributed.Future] = [] + for n_r, (config, budget) in enumerate(model_configs): + + # Only launch a task if there is time + start_time = time.time() + if time_left >= func_eval_time_limit_secs: + logger.info(f"{n_r}: Started fitting {config} with cutoff={func_eval_time_limit_secs}") + scenario_mock = unittest.mock.Mock() + scenario_mock.wallclock_limit = time_left + # This stats object is a hack - maybe the SMAC stats object should + # already be generated here! + stats = Stats(scenario_mock) + stats.start_timing() + + if isinstance(config, Configuration): + config.config_id = n_r + init_num_run = initial_num_run + else: + init_num_run = initial_num_run + n_r + + ta = ExecuteTaFuncWithQueue( + pynisher_context=multiprocessing_context, + backend=backend, + seed=seed, + metric=metric, + multi_objectives=["cost"], + logger_port=logger_port, + pipeline_options=pipeline_options, + cost_for_crash=get_cost_of_crash(metric), + abort_on_first_run_crash=False, + initial_num_run=init_num_run, + stats=stats, + memory_limit=memory_limit, + disable_file_output=disable_file_output, + all_supported_metrics=all_supported_metrics, + include=include, + exclude=exclude, + search_space_updates=search_space_updates + ) + dask_futures.append([ + config, + dask_client.submit( + ta.run, config=config, + cutoff=func_eval_time_limit_secs, + budget=budget + ) + ]) + + # When managing time, we need to take into account the allocated time resources, + # which are dependent on the number of cores. 
'dask_futures' is a proxy to the number + # of workers /n_jobs that we have, in that if there are 4 cores allocated, we can run at most + # 4 task in parallel. Every 'cutoff' seconds, we generate up to 4 tasks. + # If we only have 4 workers and there are 4 futures in dask_futures, it means that every + # worker has a task. We would not like to launch another job until a worker is available. To this + # end, the following if-statement queries the number of active jobs, and forces to wait for a job + # completion via future.result(), so that a new worker is available for the next iteration. + if len(dask_futures) >= n_jobs: + + # How many workers to wait before starting fitting the next iteration + workers_to_wait = 1 + if n_r >= total_models - 1 or time_left <= func_eval_time_limit_secs: + # If on the last iteration, flush out all tasks + workers_to_wait = len(dask_futures) + + while workers_to_wait >= 1: + workers_to_wait -= 1 + # We launch dask jobs only when there are resources available. + # This allow us to control time allocation properly, and early terminate + # the traditional machine learning pipeline + _process_result( + current_search_space=current_search_space, + dask_futures=dask_futures, + run_history=run_history, + seed=seed, + starttime=starttime, + logger=logger) + + # In the case of a serial execution, calling submit halts the run for a resource + # dynamically adjust time in this case + time_left -= int(time.time() - start_time) + + # Exit if no more time is available for a new classifier + if time_left < func_eval_time_limit_secs: + logger.warning("Not enough time to fit all machine learning models." 
+ "Please consider increasing the run time to further improve performance.") + break + + return run_history + + +def _process_result( + dask_futures: List[dask.distributed.Future], + current_search_space: ConfigurationSpace, + run_history: RunHistory, + seed: int, + starttime: float, + logger: PicklableClientLogger +) -> None: + """ + Update run_history in-place using results of the + latest finishing model. + + Args: + dask_futures (List[dask.distributed.Future]): + List of dask futures which are used to get the results of a finished run. + run_history (RunHistory): + RunHistory object to be appended with the finished run + seed (int): + Seed used for reproducibility. + starttime (float): + starttime of the runs. + logger (PicklableClientLogger): + Logger. + """ + cls, future = dask_futures.pop(0) + status, cost, runtime, additional_info = future.result() + if status == StatusType.SUCCESS: + logger.info( + "Fitting {} took {} [sec] and got performance: {}.\n" + "additional info:\n{}".format(cls, runtime, cost, dict_repr(additional_info)) + ) + origin: str = additional_info['configuration_origin'] + current_config: Union[str, dict] = additional_info['configuration'] + + # indicates the finished model is part of autopytorch search space + if isinstance(current_config, dict): + configuration = Configuration(current_search_space, current_config) # type: ignore[misc] + else: + # we assume that it is a traditional model and `pipeline_configuration` + # specifies the configuration. + configuration = additional_info.pop('pipeline_configuration', None) + + if configuration is not None: + run_history.add(config=configuration, cost=cost, + time=runtime, status=status, seed=seed, + starttime=starttime, endtime=starttime + runtime, + origin=origin, additional_info=additional_info) + else: + logger.warning(f"Something went wrong while processing the results of {current_config}." + f"with additional_info: {additional_info} and status_type: {status}. 
" + f"Refer to the log file for more information.\nSkipping for now.") + else: + if additional_info.get('exitcode') == -6: + logger.error( + "Prediction for {} failed with run state {},\n" + "because the provided memory limits were too tight.\n" + "Please increase the 'ml_memory_limit' and try again.\n" + "If you still get the problem, please open an issue\n" + "and paste the additional info.\n" + "Additional info:\n{}".format(cls, str(status), dict_repr(additional_info)) + ) + else: + logger.error( + "Prediction for {} failed with run state {}.\nAdditional info:\n{}".format( + cls, str(status), dict_repr(additional_info) + ) + ) diff --git a/autoPyTorch/utils/results_manager.py b/autoPyTorch/utils/results_manager.py index c1860b0f6..6547aaef2 100644 --- a/autoPyTorch/utils/results_manager.py +++ b/autoPyTorch/utils/results_manager.py @@ -1,6 +1,6 @@ import io from datetime import datetime -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from ConfigSpace.configuration_space import Configuration @@ -12,6 +12,7 @@ from smac.tae import StatusType from smac.utils.io.traj_logging import TrajEntry +from autoPyTorch.constants import OPTIONAL_INFERENCE_CHOICES from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric @@ -69,7 +70,7 @@ def _extract_metrics_info( run_value: RunValue, scoring_functions: List[autoPyTorchMetric], inference_name: str -) -> Dict[str, float]: +) -> Dict[str, Optional[float]]: """ Extract the metric information given a run_value and a list of metrics of interest. 
@@ -97,7 +98,14 @@ def _extract_metrics_info( if inference_name not in inference_choices: raise ValueError(f'inference_name must be in {inference_choices}, but got {inference_choices}') - cost_info = run_value.additional_info[f'{inference_name}_loss'] + cost_info = run_value.additional_info.get(f'{inference_name}_loss', None) + if cost_info is None: + if inference_name not in OPTIONAL_INFERENCE_CHOICES: + raise ValueError(f"Expected loss for {inference_name} set to not be None, but got {cost_info}") + else: + # Additional info for metrics is not available in this case. + return {metric.name: None for metric in scoring_functions} + avail_metrics = cost_info.keys() return { @@ -175,7 +183,7 @@ def _update(self, data: Dict[str, Any]) -> None: ) self._train_scores.append(data[f'train_{self.metric_name}']) - self._test_scores.append(data[f'test_{self.metric_name}']) + self._test_scores.append(data.get(f'test_{self.metric_name}', None)) self._end_times.append(datetime.timestamp(data['Timestamp'])) def _sort_by_endtime(self) -> None: @@ -413,11 +421,31 @@ def _extract_results_from_run_history(self, run_history: RunHistory) -> None: config = run_history.ids_config[run_key.config_id] self._update(config=config, run_key=run_key, run_value=run_value) + self._check_null_in_optional_inference_choices() + self.rank_opt_scores = scipy.stats.rankdata( -1 * self._metric._sign * self.opt_scores, # rank order method='min' ) + def _check_null_in_optional_inference_choices( + self + ) -> None: + """ + Checks if the data is missing or if all the runs failed for each optional inference choice and + sets the scores for that inference choice to all None. 
+ """ + for inference_choice in OPTIONAL_INFERENCE_CHOICES: + metrics_dict = getattr(self, f'{inference_choice}_metric_dict') + new_metric_dict = {} + + for metric in self._scoring_functions: + scores = metrics_dict[metric.name] + if all([score is None or score == metric._worst_possible_result for score in scores]): + scores = [None] * len(self.status_types) + new_metric_dict[metric.name] = scores + setattr(self, f'{inference_choice}_metric_dict', new_metric_dict) + class MetricResults: def __init__( @@ -486,12 +514,24 @@ def _extract_results(self) -> None: for inference_name in ['train', 'test', 'opt']: # TODO: Extract information from self.search_results data = getattr(self.search_results, f'{inference_name}_metric_dict')[metric_name] + if all([d is None for d in data]): + if inference_name not in OPTIONAL_INFERENCE_CHOICES: + raise ValueError(f"Expected {metric_name} score for {inference_name} set" + f" to not be None, but got {data}") + else: + continue self.data[f'single::{inference_name}::{metric_name}'] = np.array(data) if self.ensemble_results.empty() or inference_name == 'opt': continue data = getattr(self.ensemble_results, f'{inference_name}_scores') + if all([d is None for d in data]): + if inference_name not in OPTIONAL_INFERENCE_CHOICES: + raise ValueError(f"Expected {metric_name} score for {inference_name} set" + f" to not be None, but got {data}") + else: + continue self.data[f'ensemble::{inference_name}::{metric_name}'] = np.array(data) def get_ensemble_merged_data(self) -> Dict[str, np.ndarray]: @@ -516,6 +556,8 @@ def get_ensemble_merged_data(self) -> Dict[str, np.ndarray]: cur, timestep_size, sign = 0, self.cum_times.size, self.metric._sign key_train, key_test = f'ensemble::train::{self.metric.name}', f'ensemble::test::{self.metric.name}' + all_test_perfs_null = all([perf is None for perf in test_scores]) + train_perfs = np.full_like(self.cum_times, self.metric._worst_possible_result) test_perfs = np.full_like(self.cum_times, 
self.metric._worst_possible_result) @@ -530,9 +572,16 @@ def get_ensemble_merged_data(self) -> Dict[str, np.ndarray]: time_index = min(cur, timestep_size - 1) # If there already exists a previous allocated value, update by a better value train_perfs[time_index] = sign * max(sign * train_perfs[time_index], sign * train_score) - test_perfs[time_index] = sign * max(sign * test_perfs[time_index], sign * test_score) + # test_perfs can be none when X_test is not passed + if not all_test_perfs_null: + test_perfs[time_index] = sign * max(sign * test_perfs[time_index], sign * test_score) + + update_dict = {key_train: train_perfs} + if not all_test_perfs_null: + update_dict[key_test] = test_perfs + + data.update(update_dict) - data.update({key_train: train_perfs, key_test: test_perfs}) return data diff --git a/autoPyTorch/utils/results_visualizer.py b/autoPyTorch/utils/results_visualizer.py index e1debe29c..44f931285 100644 --- a/autoPyTorch/utils/results_visualizer.py +++ b/autoPyTorch/utils/results_visualizer.py @@ -6,6 +6,7 @@ import numpy as np +from autoPyTorch.constants import OPTIONAL_INFERENCE_CHOICES from autoPyTorch.utils.results_manager import MetricResults @@ -318,7 +319,15 @@ def plot_perf_over_time( minimize = (results.metric._sign == -1) for key in data.keys(): + inference_name = key.split('::')[1] _label, _color, _perfs = labels[key], colors[key], data[key] + all_null_perfs = all([perf is None for perf in _perfs]) + + if all_null_perfs: + if inference_name not in OPTIONAL_INFERENCE_CHOICES: + raise ValueError(f"Expected loss for {inference_name} set to not be None") + else: + continue # Take the best results over time _cum_perfs = np.minimum.accumulate(_perfs) if minimize else np.maximum.accumulate(_perfs) diff --git a/docs/releases.rst b/docs/releases.rst index ef6d65717..1dcb742b9 100644 --- a/docs/releases.rst +++ b/docs/releases.rst @@ -12,6 +12,22 @@ Releases ======== +Version 0.2.1 +============= +| [FIX] ADD forecasting init design to pip data files by 
@dengdifan in https://github.com/automl/Auto-PyTorch/pull/459 +| checks for time series dataset split by @dengdifan in https://github.com/automl/Auto-PyTorch/pull/464 +| [FIX] Numerical stability scaling for timeseries forecasting tasks by @dengdifan in https://github.com/automl/Auto-PyTorch/pull/467 +| [FIX] pipeline options in `fit_pipeline` by @ravinkohli in https://github.com/automl/Auto-PyTorch/pull/466 +| [FIX] results management and visualisation with missing test data by @ravinkohli in https://github.com/automl/Auto-PyTorch/pull/465 +| [ADD] Robustly refit models in final ensemble in parallel by @ravinkohli in https://github.com/automl/Auto-PyTorch/pull/471 + +Contributors v0.2.1 +******************* +* Difan Deng +* Ravin Kohli +* Shuhei Watanabe +* Theodoros Athanasiadis + Version 0.2 =========== | [FIX] Documentation and docker workflow file (#449) diff --git a/examples/20_basics/example_tabular_classification.py b/examples/20_basics/example_tabular_classification.py index 636281eff..291e017ac 100644 --- a/examples/20_basics/example_tabular_classification.py +++ b/examples/20_basics/example_tabular_classification.py @@ -3,13 +3,15 @@ Tabular Classification ====================== -The following example shows how to fit a sample classification model -with AutoPyTorch +The following example shows how to fit a simple classification ensemble +with AutoPyTorch and refit the found ensemble.
""" import os import tempfile as tmp import warnings +from autoPyTorch.datasets.resampling_strategy import CrossValTypes + os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir() os.environ['OMP_NUM_THREADS'] = '1' os.environ['OPENBLAS_NUM_THREADS'] = '1' @@ -62,13 +64,39 @@ ) ############################################################################ -# Print the final ensemble performance -# ==================================== +# Print the final ensemble performance before refit +# ================================================= + y_pred = api.predict(X_test) score = api.score(y_pred, y_test) print(score) -# Print the final ensemble built by AutoPyTorch -print(api.show_models()) # Print statistics from search print(api.sprint_statistics()) + +########################################################################### +# Refit the models on the full dataset. +# ===================================== + +api.refit( + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + dataset_name="Australian", + # you can change the resampling strategy to + # for example, CrossValTypes.k_fold_cross_validation + # to fit k fold models and have a voting classifier + # resampling_strategy=CrossValTypes.k_fold_cross_validation +) + +############################################################################ +# Print the final ensemble performance after refit +# ================================================ + +y_pred = api.predict(X_test) +score = api.score(y_pred, y_test) +print(score) + +# Print the final ensemble built by AutoPyTorch +print(api.show_models()) diff --git a/examples/20_basics/example_tabular_regression.py b/examples/20_basics/example_tabular_regression.py index 127f26829..6357d23e1 100644 --- a/examples/20_basics/example_tabular_regression.py +++ b/examples/20_basics/example_tabular_regression.py @@ -50,19 +50,44 @@ optimize_metric='r2', total_walltime_limit=300, func_eval_time_limit_secs=50, + dataset_name="Boston" ) 
############################################################################ -# Print the final ensemble performance -# ==================================== +# Print the final ensemble performance before refit +# ================================================= y_pred = api.predict(X_test) - -# Rescale the Neural Network predictions into the original target range score = api.score(y_pred, y_test) - print(score) -# Print the final ensemble built by AutoPyTorch -print(api.show_models()) # Print statistics from search print(api.sprint_statistics()) + +########################################################################### +# Refit the models on the full dataset. +# ===================================== + +api.refit( + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + dataset_name="Boston", + total_walltime_limit=500, + run_time_limit_secs=50 + # you can change the resampling strategy to + # for example, CrossValTypes.k_fold_cross_validation + # to fit k fold models and have a voting classifier + # resampling_strategy=CrossValTypes.k_fold_cross_validation +) + +############################################################################ +# Print the final ensemble performance after refit +# ================================================ + +y_pred = api.predict(X_test) +score = api.score(y_pred, y_test) +print(score) + +# Print the final ensemble built by AutoPyTorch +print(api.show_models()) diff --git a/setup.py b/setup.py index bd524276d..422c6f24d 100755 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ # noinspection PyInterpreter setuptools.setup( name="autoPyTorch", - version="0.2", + version="0.2.1", author="AutoML Freiburg", author_email="eddiebergmanhs@gmail.com", description=("Auto-PyTorch searches neural architectures using smac"), @@ -83,6 +83,7 @@ }, test_suite="pytest", data_files=[('configs', ['autoPyTorch/configs/default_pipeline_options.json']), - ('portfolio', ['autoPyTorch/configs/greedy_portfolio.json'])], + ('portfolio', 
['autoPyTorch/configs/greedy_portfolio.json']), + ('forecasting_init', ['autoPyTorch/configs/forecasting_init_cfgs.json'])], dependency_links=['https://github.com/automl/automl_common.git/tarball/autoPyTorch#egg=package-0.0.1'] ) diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 465d74c6b..f5c99b31d 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -46,6 +46,43 @@ HOLDOUT_NUM_SPLITS = 1 +def refit_test_estimator( + estimator, + X_train, + y_train, + X_test, + y_test, +): + estimator.refit( + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test + ) + + # Check if the refit models are actually used in the new ensemble. + refit_ensemble_identifiers = estimator.ensemble_.get_selected_model_identifiers() + refit_run_history_path = os.path.join(estimator._backend.internals_directory, 'refit_run_history.json') + assert os.path.exists(refit_run_history_path) + + refit_run_history: RunHistory = RunHistory() + refit_run_history.update_from_json(refit_run_history_path, estimator.search_space) + + all_refit_runs_in_new_ensemble = [] + model_num_runs = [] + for run_key, run_value in refit_run_history.data.items(): + any_refit_run_in_new_ensemble = False + num_run = run_value.additional_info["num_run"] + model_num_runs.append(num_run) + for identifier in refit_ensemble_identifiers: + if num_run == identifier[1]: + any_refit_run_in_new_ensemble = True + break + all_refit_runs_in_new_ensemble.append(any_refit_run_in_new_ensemble) + + assert all(all_refit_runs_in_new_ensemble), "All successful runs in the refit should be a part of the new ensemble" + + # Test # ==== @unittest.mock.patch('autoPyTorch.evaluation.tae.eval_train_function', @@ -186,8 +223,9 @@ def test_tabular_classification(openml_id, resampling_strategy, backend, resampl # Ensemble Builder produced an ensemble estimator.ensemble_ is not None + ensemble_identifiers = estimator.ensemble_.identifiers_ # There should be a weight for each element of the 
ensemble - assert len(estimator.ensemble_.identifiers_) == len(estimator.ensemble_.weights_) + assert len(ensemble_identifiers) == len(estimator.ensemble_.weights_) y_pred = estimator.predict(X_test) assert np.shape(y_pred)[0] == np.shape(X_test)[0] @@ -207,6 +245,15 @@ def test_tabular_classification(openml_id, resampling_strategy, backend, resampl successful_num_run) assert 'train_loss' in incumbent_results + # Test refit on dummy data + refit_test_estimator( + estimator=estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test + ) + # Check that we can pickle dump_file = os.path.join(estimator._backend.temporary_directory, 'dump.pkl') @@ -217,9 +264,6 @@ def test_tabular_classification(openml_id, resampling_strategy, backend, resampl restored_estimator = pickle.load(f) restored_estimator.predict(X_test) - # Test refit on dummy data - estimator.refit(dataset=backend.load_datamanager()) - # Make sure that a configuration space is stored in the estimator assert isinstance(estimator.get_search_space(), CS.ConfigurationSpace) @@ -387,6 +431,15 @@ def test_tabular_regression(openml_name, resampling_strategy, backend, resamplin successful_num_run) assert 'train_loss' in incumbent_results, estimator.run_history.data + # Test refit on dummy data + refit_test_estimator( + estimator=estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test + ) + # Check that we can pickle dump_file = os.path.join(estimator._backend.temporary_directory, 'dump.pkl') @@ -397,9 +450,6 @@ def test_tabular_regression(openml_name, resampling_strategy, backend, resamplin restored_estimator = pickle.load(f) restored_estimator.predict(X_test) - # Test refit on dummy data - estimator.refit(dataset=backend.load_datamanager()) - # Make sure that a configuration space is stored in the estimator assert isinstance(estimator.get_search_space(), CS.ConfigurationSpace) @@ -580,7 +630,14 @@ def test_time_series_forecasting(forecasting_toy_dataset, 
resampling_strategy, b assert np.shape(y_pred) == np.shape(y_test) # Test refit on dummy data - estimator.refit(dataset=backend.load_datamanager()) + refit_test_estimator( + estimator=estimator, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test + ) + # Make sure that a configuration space is stored in the estimator assert isinstance(estimator.get_search_space(), CS.ConfigurationSpace) @@ -932,6 +989,61 @@ def test_pipeline_fit(openml_id, assert not os.path.exists(cv_model_path) +@pytest.mark.parametrize('openml_id,budget', [(40984, 1)]) +def test_pipeline_fit_pass_pipeline_options( + openml_id, + backend, + budget, + n_samples +): + # Get the data and check that contents of data-manager make sense + X, y = sklearn.datasets.fetch_openml( + data_id=int(openml_id), + return_X_y=True, as_frame=True + ) + X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + X[:n_samples], y[:n_samples], random_state=1) + + # Search for a good configuration + estimator = TabularClassificationTask( + backend=backend, + ensemble_size=0 + ) + + dataset = estimator.get_dataset(X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test) + + configuration = estimator.get_search_space(dataset).get_default_configuration() + pipeline, run_info, run_value, dataset = estimator.fit_pipeline(dataset=dataset, + configuration=configuration, + run_time_limit_secs=50, + budget_type='epochs', + budget=budget, + pipeline_options={'early_stopping': 100} + ) + assert isinstance(dataset, BaseDataset) + assert isinstance(run_info, RunInfo) + assert isinstance(run_info.config, Configuration) + + assert isinstance(run_value, RunValue) + assert 'SUCCESS' in str(run_value.status) + + # Make sure that the pipeline can be pickled + dump_file = os.path.join(tempfile.gettempdir(), 'automl.dump.pkl') + with open(dump_file, 'wb') as f: + pickle.dump(pipeline, f) + + num_run_dir = estimator._backend.get_numrun_directory( + run_info.seed, 
run_value.additional_info['num_run'], budget=float(budget)) + model_path = os.path.join(num_run_dir, estimator._backend.get_model_filename( + run_info.seed, run_value.additional_info['num_run'], budget=float(budget))) + + # We expect the model path always + assert os.path.exists(model_path) + + @pytest.mark.parametrize('openml_id', (40984,)) @pytest.mark.parametrize('resampling_strategy,resampling_strategy_args', ((HoldoutValTypes.holdout_validation, {'val_share': 0.8}), diff --git a/test/test_api/test_base_api.py b/test/test_api/test_base_api.py index bb8f9c061..ac01af25e 100644 --- a/test/test_api/test_base_api.py +++ b/test/test_api/test_base_api.py @@ -27,7 +27,7 @@ def test_nonsupported_arguments(fit_dictionary_tabular): api = BaseTask() with pytest.raises(ValueError, match=r".*Invalid configuration arguments given.*"): - api.set_pipeline_config(unsupported=True) + api.set_pipeline_options(unsupported=True) with pytest.raises(ValueError, match=r".*No search space initialised and no dataset.*"): api.get_search_space() api.resampling_strategy = None @@ -95,7 +95,7 @@ def test_show_models(fit_dictionary_tabular): assert re.search(expected, api.show_models()) is not None -def test_set_pipeline_config(): +def test_set_pipeline_options(): # checks if we can correctly change the pipeline options BaseTask.__abstractmethods__ = set() estimator = BaseTask() @@ -103,7 +103,7 @@ def test_set_pipeline_config(): "budget_type": "epochs", "epochs": 51, "runtime": 360} - estimator.set_pipeline_config(**pipeline_options) + estimator.set_pipeline_options(**pipeline_options) assert pipeline_options.items() <= estimator.get_pipeline_options().items() @@ -118,12 +118,12 @@ def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, bud estimator = BaseTask(task_type='tabular_classification', ensemble_size=0) # Fixture pipeline config - default_pipeline_config = { + default_pipeline_options = { 'device': 'cpu', 'budget_type': 'epochs', 'epochs': 50, 'runtime': 3600, 
'torch_num_threads': 1, 'early_stopping': 20, 'use_tensorboard_logger': False, 'metrics_during_training': True, 'optimize_metric': 'accuracy' } - default_pipeline_config.update(expected) + default_pipeline_options.update(expected) # Create pre-requisites dataset = fit_dictionary_tabular['backend'].load_datamanager() @@ -141,7 +141,7 @@ def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, bud enable_traditional_pipeline=False, total_walltime_limit=20, func_eval_time_limit_secs=10, load_models=False) - assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_config'] == default_pipeline_config + assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_options'] == default_pipeline_options assert list(smac_mock.call_args)[1]['max_budget'] == max_budget assert list(smac_mock.call_args)[1]['initial_budget'] == min_budget @@ -174,12 +174,12 @@ def test_pipeline_get_budget_forecasting(fit_dictionary_forecasting, min_budget, BaseTask.__abstractmethods__ = set() estimator = BaseTask(task_type='time_series_forecasting', ensemble_size=0) # Fixture pipeline config - default_pipeline_config = { + default_pipeline_options = { 'device': 'cpu', 'budget_type': 'epochs', 'epochs': 50, 'runtime': 3600, 'torch_num_threads': 1, 'early_stopping': 20, 'use_tensorboard_logger': False, 'metrics_during_training': True, 'optimize_metric': 'mean_MASE_forecasting' } - default_pipeline_config.update(expected) + default_pipeline_options.update(expected) # Create pre-requisites dataset = fit_dictionary_forecasting['backend'].load_datamanager() @@ -198,6 +198,6 @@ def test_pipeline_get_budget_forecasting(fit_dictionary_forecasting, min_budget, total_walltime_limit=20, func_eval_time_limit_secs=10, memory_limit=8192, load_models=False) - assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_config'] == default_pipeline_config + assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_options'] == default_pipeline_options assert 
list(smac_mock.call_args)[1]['max_budget'] == max_budget assert list(smac_mock.call_args)[1]['initial_budget'] == min_budget diff --git a/test/test_api/utils.py b/test/test_api/utils.py index bbee9a3c4..701c455c1 100644 --- a/test/test_api/utils.py +++ b/test/test_api/utils.py @@ -63,6 +63,11 @@ def _fit_and_predict(self, pipeline, fold: int, train_indices, test_indices=test_indices, ) + # the configuration is used in refit where + # pipeline.config is used to retrieve the + # original configuration. + pipeline.config = self.configuration + if add_pipeline_to_self: self.pipeline = pipeline else: @@ -94,7 +99,7 @@ def dummy_eval_train_function( include, exclude, disable_file_output, - pipeline_config=None, + pipeline_options=None, budget_type=None, init_params=None, logger_port=None, @@ -118,7 +123,7 @@ def dummy_eval_train_function( budget_type=budget_type, logger_port=logger_port, all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config, + pipeline_options=pipeline_options, search_space_updates=search_space_updates, ) evaluator.fit_predict_and_loss() @@ -137,7 +142,7 @@ def dummy_forecasting_eval_train_function( include, exclude, disable_file_output, - pipeline_config=None, + pipeline_options=None, budget_type=None, init_params=None, logger_port=None, @@ -163,7 +168,7 @@ def dummy_forecasting_eval_train_function( budget_type=budget_type, logger_port=logger_port, all_supported_metrics=all_supported_metrics, - pipeline_config=pipeline_config, + pipeline_options=pipeline_options, search_space_updates=search_space_updates, max_budget=max_budget, min_num_test_instances=min_num_test_instances, diff --git a/test/test_datasets/test_time_series_datasets.py b/test/test_datasets/test_time_series_datasets.py index fa8faa625..68d866e09 100644 --- a/test/test_datasets/test_time_series_datasets.py +++ b/test/test_datasets/test_time_series_datasets.py @@ -13,7 +13,7 @@ import torch -from autoPyTorch.datasets.resampling_strategy import CrossValTypes, 
HoldoutValTypes +from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes from autoPyTorch.datasets.time_series_dataset import ( TimeSeriesForecastingDataset, TimeSeriesSequence, @@ -297,7 +297,8 @@ def test_freq_valeus(): def test_target_normalization(): Y = [[1, 2], [3, 4, 5]] - dataset = TimeSeriesForecastingDataset(None, Y, normalize_y=True) + dataset = TimeSeriesForecastingDataset(None, Y, normalize_y=True, + resampling_strategy=NoResamplingStrategyTypes.no_resampling) assert np.allclose(dataset.y_mean.values, np.vstack([np.mean(y) for y in Y])) assert np.allclose(dataset.y_std.values, np.vstack([np.std(y, ddof=1) for y in Y])) @@ -356,7 +357,8 @@ def test_test_tensors(backend, fit_dictionary_forecasting): assert test_tensors[0].shape == (n_seq * forecast_horizon, datamanager.num_features) assert test_tensors[1].shape == (n_seq * forecast_horizon, datamanager.num_targets) - datamanager2 = TimeSeriesForecastingDataset(X=None, Y=[[1, 2]]) + datamanager2 = TimeSeriesForecastingDataset(X=None, Y=[[1, 2]], + resampling_strategy=NoResamplingStrategyTypes.no_resampling) assert datamanager2.test_tensors is None @@ -397,7 +399,7 @@ def test_splits(): n_prediction_steps=10, freq='1M') # the length of each sequence does not support 5 splitions - assert len(dataset.splits) == 3 + assert len(dataset.splits) == 2 # datasets with long but little sequence y = [np.arange(4000) for _ in range(2)] @@ -457,6 +459,14 @@ def test_splits(): refit_set = dataset.create_refit_set() assert len(refit_set.splits[0][0]) == len(refit_set) + y = [np.arange(10)] + with pytest.raises(ValueError): + dataset = TimeSeriesForecastingDataset(None, y, + resampling_strategy=CrossValTypes.time_series_cross_validation, + resampling_strategy_args=resampling_strategy_args, + n_prediction_steps=5, + freq='1M') + def test_extract_time_features(): feature_shapes = {'b': 5, 'a': 3, 'c': 7, 'd': 12} diff --git 
a/test/test_evaluation/test_abstract_evaluator.py b/test/test_evaluation/test_abstract_evaluator.py index a0be2c3f3..bb9df88e7 100644 --- a/test/test_evaluation/test_abstract_evaluator.py +++ b/test/test_evaluation/test_abstract_evaluator.py @@ -307,7 +307,7 @@ def test_error_unsupported_budget_type(self): backend=backend, output_y_hat_optimization=False, queue=queue_mock, - pipeline_config={'budget_type': "error", 'error': 0}, + pipeline_options={'budget_type': "error", 'error': 0}, metric=accuracy, budget=0, configuration=1) diff --git a/test/test_evaluation/test_evaluators.py b/test/test_evaluation/test_evaluators.py index 2ca32af10..0f0f15cdc 100644 --- a/test/test_evaluation/test_evaluators.py +++ b/test/test_evaluation/test_evaluators.py @@ -97,12 +97,13 @@ def test_holdout(self, pipeline_mock): pipeline_mock.get_additional_run_info.return_value = None configuration = unittest.mock.Mock(spec=Configuration) + configuration.get_dictionary.return_value = {} backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') backend_api.load_datamanager = lambda: D queue_ = multiprocessing.Queue() evaluator = TrainEvaluator(backend_api, queue_, configuration=configuration, metric=accuracy, budget=0, - pipeline_config={'budget_type': 'epochs', 'epochs': 50}) + pipeline_options={'budget_type': 'epochs', 'epochs': 50}) evaluator.file_output = unittest.mock.Mock(spec=evaluator.file_output) evaluator.file_output.return_value = (None, {}) @@ -136,12 +137,13 @@ def test_cv(self, pipeline_mock): pipeline_mock.get_additional_run_info.return_value = None configuration = unittest.mock.Mock(spec=Configuration) + configuration.get_dictionary.return_value = {} backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') backend_api.load_datamanager = lambda: D queue_ = multiprocessing.Queue() evaluator = TrainEvaluator(backend_api, queue_, configuration=configuration, metric=accuracy, budget=0, - pipeline_config={'budget_type': 'epochs', 'epochs': 50}) + 
pipeline_options={'budget_type': 'epochs', 'epochs': 50}) evaluator.file_output = unittest.mock.Mock(spec=evaluator.file_output) evaluator.file_output.return_value = (None, {}) @@ -246,7 +248,7 @@ def test_predict_proba_binary_classification(self, mock): queue_ = multiprocessing.Queue() evaluator = TrainEvaluator(self.backend_mock, queue_, configuration=configuration, metric=accuracy, budget=0, - pipeline_config={'budget_type': 'epochs', 'epochs': 50}) + pipeline_options={'budget_type': 'epochs', 'epochs': 50}) evaluator.fit_predict_and_loss() Y_optimization_pred = self.backend_mock.save_numrun_to_dir.call_args_list[0][1][ @@ -278,12 +280,13 @@ def test_additional_metrics_during_training(self, pipeline_mock): D = get_binary_classification_datamanager() configuration = unittest.mock.Mock(spec=Configuration) + configuration.get_dictionary.return_value = {} backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') backend_api.load_datamanager = lambda: D queue_ = multiprocessing.Queue() evaluator = TrainEvaluator(backend_api, queue_, configuration=configuration, metric=accuracy, budget=0, - pipeline_config={'budget_type': 'epochs', 'epochs': 50}, all_supported_metrics=True) + pipeline_options={'budget_type': 'epochs', 'epochs': 50}, all_supported_metrics=True) evaluator.file_output = unittest.mock.Mock(spec=evaluator.file_output) evaluator.file_output.return_value = (None, {}) @@ -339,6 +342,7 @@ def test_no_resampling(self, pipeline_mock): pipeline_mock.get_default_pipeline_options.return_value = {'budget_type': 'epochs', 'epochs': 10} configuration = unittest.mock.Mock(spec=Configuration) + configuration.get_dictionary.return_value = {} backend_api = create(self.tmp_dir, self.output_dir, 'autoPyTorch') backend_api.load_datamanager = lambda: D queue_ = multiprocessing.Queue() diff --git a/test/test_evaluation/test_forecasting_evaluators.py b/test/test_evaluation/test_forecasting_evaluators.py index 580402d5c..5eea055df 100644 --- 
a/test/test_evaluation/test_forecasting_evaluators.py +++ b/test/test_evaluation/test_forecasting_evaluators.py @@ -60,8 +60,8 @@ def test_budget_type_choices(self, pipeline_mock): queue_, configuration=configuration, metric=mean_MASE_forecasting, budget=0, - pipeline_config={'budget_type': budget_type, - budget_type: 0.1}, + pipeline_options={'budget_type': budget_type, + budget_type: 0.1}, min_num_test_instances=100) self.assertTrue('epochs' not in evaluator.fit_dictionary) if budget_type == 'resolution': @@ -85,6 +85,7 @@ def test_holdout(self, pipeline_mock): pipeline_mock.get_additional_run_info.return_value = None configuration = unittest.mock.Mock(spec=Configuration) + configuration.get_dictionary.return_value = {} backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') backend_api.load_datamanager = lambda: D queue_ = multiprocessing.Queue() @@ -93,7 +94,7 @@ def test_holdout(self, pipeline_mock): queue_, configuration=configuration, metric=mean_MASE_forecasting, budget=0, - pipeline_config={'budget_type': 'epochs', 'epochs': 50}, + pipeline_options={'budget_type': 'epochs', 'epochs': 50}, min_num_test_instances=100) self.assertTrue('epochs' in evaluator.fit_dictionary) evaluator.file_output = unittest.mock.Mock(spec=evaluator.file_output) @@ -140,6 +141,7 @@ def test_cv(self, pipeline_mock): pipeline_mock.get_additional_run_info.return_value = None configuration = unittest.mock.Mock(spec=Configuration) + configuration.get_dictionary.return_value = {} backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') backend_api.load_datamanager = lambda: D queue_ = multiprocessing.Queue() @@ -148,7 +150,7 @@ def test_cv(self, pipeline_mock): queue_, configuration=configuration, metric=mean_MASE_forecasting, budget=0, - pipeline_config={'budget_type': 'epochs', 'epochs': 50}) + pipeline_options={'budget_type': 'epochs', 'epochs': 50}) evaluator.file_output = unittest.mock.Mock(spec=evaluator.file_output) 
evaluator.file_output.return_value = (None, {}) @@ -189,6 +191,7 @@ def test_proxy_val_set(self, pipeline_mock): pipeline_mock.get_additional_run_info.return_value = None configuration = unittest.mock.Mock(spec=Configuration) + configuration.get_dictionary.return_value = {} backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch') backend_api.load_datamanager = lambda: D queue_ = multiprocessing.Queue() @@ -197,7 +200,7 @@ def test_proxy_val_set(self, pipeline_mock): queue_, configuration=configuration, metric=mean_MASE_forecasting, budget=0.3, - pipeline_config={'budget_type': 'epochs', 'epochs': 50}, + pipeline_options={'budget_type': 'epochs', 'epochs': 50}, min_num_test_instances=1) evaluator.file_output = unittest.mock.Mock(spec=evaluator.file_output) evaluator.file_output.return_value = (None, {}) @@ -247,7 +250,7 @@ def test_finish_up(self, pipeline_mock, queue_mock): queue_mock, configuration=configuration, metric=mean_MASE_forecasting, budget=0.3, - pipeline_config={'budget_type': 'epochs', 'epochs': 50}, + pipeline_options={'budget_type': 'epochs', 'epochs': 50}, min_num_test_instances=1) val_splits = D.splits[0][1] diff --git a/test/test_pipeline/components/preprocessing/forecasting/test_scaling.py b/test/test_pipeline/components/preprocessing/forecasting/test_scaling.py index 047806bc5..877cf7afa 100644 --- a/test/test_pipeline/components/preprocessing/forecasting/test_scaling.py +++ b/test/test_pipeline/components/preprocessing/forecasting/test_scaling.py @@ -26,6 +26,7 @@ def setUp(self) -> None: columns = ['f1', 's', 'f2'] self.raw_data = [data_seq_1, data_seq_2] + self.data = pd.DataFrame(np.concatenate([data_seq_1, data_seq_2]), columns=columns, index=[0] * 3 + [1] * 4) self.static_features = ('s',) self.static_features_column = (1, ) @@ -38,6 +39,9 @@ def setUp(self) -> None: 'static_features': self.static_features, 'is_small_preprocess': True} + self.small_data = pd.DataFrame(np.array([[1e-10, 0., 1e-15], + [-1e-10, 0., +1e-15]]), 
columns=columns, index=[0] * 2) + def test_base_and_standard_scaler(self): scaler_component = BaseScaler(scaling_mode='standard') X = { @@ -82,6 +86,10 @@ def test_base_and_standard_scaler(self): transformed_test = np.concatenate([scaler.transform(raw_data) for raw_data in self.raw_data]) self.assertTrue(np.allclose(transformed_test[:, [0, -1]], transformed_test[:, [0, -1]])) + scaler.dataset_is_small_preprocess = True + scaler.fit(self.small_data) + self.assertTrue(np.allclose(scaler.scale.values.flatten(), np.array([1.41421356e-10, 1., 1.]))) + def test_min_max(self): scaler = TimeSeriesScaler(mode='min_max', static_features=self.static_features @@ -109,6 +117,10 @@ def test_min_max(self): transformed_test = np.concatenate([scaler.transform(raw_data) for raw_data in self.raw_data]) self.assertTrue(np.allclose(transformed_test[:, [0, -1]], transformed_test[:, [0, -1]])) + scaler.dataset_is_small_preprocess = True + scaler.fit(self.small_data) + self.assertTrue(np.all(scaler.scale.values.flatten() == np.array([2e-10, 1., 1.]))) + def test_max_abs_scaler(self): scaler = TimeSeriesScaler(mode='max_abs', static_features=self.static_features @@ -136,6 +148,10 @@ def test_max_abs_scaler(self): transformed_test = np.concatenate([scaler.transform(raw_data) for raw_data in self.raw_data]) self.assertTrue(np.allclose(transformed_test[:, [0, -1]], transformed_test[:, [0, -1]])) + scaler.dataset_is_small_preprocess = True + scaler.fit(self.small_data) + self.assertTrue(np.all(scaler.scale.values.flatten() == np.array([1e-10, 1., 1.]))) + def test_mean_abs_scaler(self): scaler = TimeSeriesScaler(mode='mean_abs', static_features=self.static_features @@ -162,6 +178,10 @@ def test_mean_abs_scaler(self): transformed_test = np.concatenate([scaler.transform(raw_data) for raw_data in self.raw_data]) self.assertTrue(np.allclose(transformed_test[:, [0, -1]], transformed_test[:, [0, -1]])) + scaler.dataset_is_small_preprocess = True + scaler.fit(self.small_data) + 
self.assertTrue(np.all(scaler.scale.values.flatten() == np.array([1e-10, 1., 1.]))) + def test_no_scaler(self): scaler = TimeSeriesScaler(mode='none', static_features=self.static_features diff --git a/test/test_pipeline/components/setup/forecasting/test_forecasting_target_scaling.py b/test/test_pipeline/components/setup/forecasting/test_forecasting_target_scaling.py index a415e2e22..3cded4954 100644 --- a/test/test_pipeline/components/setup/forecasting/test_forecasting_target_scaling.py +++ b/test/test_pipeline/components/setup/forecasting/test_forecasting_target_scaling.py @@ -95,6 +95,11 @@ def test_target_mean_abs_scalar(self): self.assertIsNone(loc_full) + _, _, _, scale = scalar( + torch.Tensor([[1e-10, 1e-10, 1e-10], [1e-15, 1e-15, 1e-15]]).reshape([2, 3, 1]) + ) + self.assertTrue(torch.equal(scale.flatten(), torch.Tensor([1e-10, 1.]))) + def test_target_standard_scalar(self): X = {'dataset_properties': {}} scalar = BaseTargetScaler(scaling_mode='standard') @@ -178,6 +183,11 @@ def test_target_standard_scalar(self): self.assertTrue(torch.equal(loc, loc_full)) self.assertTrue(torch.equal(scale, scale_full)) + _, _, _, scale = scalar( + torch.Tensor([[1e-10, -1e-10, 1e-10], [1e-15, -1e-15, 1e-15]]).reshape([2, 3, 1]) + ) + self.assertTrue(torch.all(torch.isclose(scale.flatten(), torch.Tensor([1.1547e-10, 1.])))) + def test_target_min_max_scalar(self): X = {'dataset_properties': {}} scalar = BaseTargetScaler(scaling_mode='min_max') @@ -245,6 +255,11 @@ def test_target_min_max_scalar(self): self.assertTrue(torch.equal(transformed_future_targets_full, transformed_future_targets_full)) self.assertTrue(torch.equal(scale, scale_full)) + _, _, _, scale = scalar( + torch.Tensor([[1e-10, 1e-10, 1e-10], [1e-15, 1e-15, 1e-15]]).reshape([2, 3, 1]) + ) + self.assertTrue(torch.equal(scale.flatten(), torch.Tensor([1e-10, 1.]))) + def test_target_max_abs_scalar(self): X = {'dataset_properties': {}} scalar = BaseTargetScaler(scaling_mode='max_abs') @@ -309,3 +324,8 @@ def 
test_target_max_abs_scalar(self): self.assertTrue(torch.equal(transformed_future_targets_full, transformed_future_targets_full)) self.assertIsNone(loc_full) self.assertTrue(torch.equal(scale, scale_full)) + + _, _, _, scale = scalar( + torch.Tensor([[1e-10, 1e-10, 1e-10], [1e-15, 1e-15, 1e-15]]).reshape([2, 3, 1]) + ) + self.assertTrue(torch.equal(scale.flatten(), torch.Tensor([1e-10, 1.]))) diff --git a/test/test_utils/runhistory_no_test.json b/test/test_utils/runhistory_no_test.json new file mode 100644 index 000000000..35bf7311b --- /dev/null +++ b/test/test_utils/runhistory_no_test.json @@ -0,0 +1,1582 @@ +{ + "data": [ + [ + [ + 1, + "{\"task_id\": \"Australian\"}", + 0, + 5.555555555555555 + ], + [ + 0.15204678362573099, + 3.154788017272949, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342638.6119366, + 1637342642.7887495, + { + "opt_loss": { + "accuracy": 0.15204678362573099, + "balanced_accuracy": 0.15263157894736845, + "roc_auc": 0.08981994459833786, + "average_precision": 0.1040861796433199, + "log_loss": 0.5765479137672738, + "precision": 0.17948717948717952, + "precision_macro": 0.15425971877584788, + "precision_micro": 0.15204678362573099, + "precision_weighted": 0.15145666758569976, + "recall": 0.1578947368421053, + "recall_macro": 0.15263157894736845, + "recall_micro": 0.15204678362573099, + "recall_weighted": 0.15204678362573099, + "f1": 0.16883116883116878, + "f1_macro": 0.15356452058579717, + "f1_micro": 0.15204678362573099, + "f1_weighted": 0.15186822633631147 + }, + "duration": 3.11077618598938, + "num_run": 8, + "train_loss": { + "accuracy": 0.09537572254335258, + "balanced_accuracy": 0.10239948774980623, + "roc_auc": 0.03963198867657458, + "average_precision": 0.044469547423341305, + "log_loss": 0.28008669264774966, + "precision": 0.03731343283582089, + "precision_macro": 0.08469445226696704, + "precision_micro": 0.09537572254335258, + "precision_weighted": 0.0890765118675354, + "recall": 0.17834394904458595, + "recall_macro": 
0.10239948774980623, + "recall_micro": 0.09537572254335258, + "recall_weighted": 0.09537572254335258, + "f1": 0.11340206185567014, + "f1_macro": 0.09784816309741107, + "f1_micro": 0.09537572254335258, + "f1_weighted": 0.0964096522295953 + }, + "configuration_origin": "Default" + } + ] + ], + [ + [ + 2, + "{\"task_id\": \"Australian\"}", + 0, + 5.555555555555555 + ], + [ + 0.4444444444444444, + 3.2763524055480957, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342642.963385, + 1637342647.2651122, + { + "opt_loss": { + "accuracy": 0.4444444444444444, + "balanced_accuracy": 0.5, + "roc_auc": 0.25526315789473697, + "average_precision": 0.35005634879129066, + "log_loss": 1.0913122792494052, + "precision": 1.0, + "precision_macro": 0.7222222222222222, + "precision_micro": 0.4444444444444444, + "precision_weighted": 0.691358024691358, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.4444444444444444, + "recall_weighted": 0.4444444444444444, + "f1": 1.0, + "f1_macro": 0.6428571428571428, + "f1_micro": 0.4444444444444444, + "f1_weighted": 0.6031746031746031 + }, + "duration": 3.2138161659240723, + "num_run": 9, + "train_loss": { + "accuracy": 0.45375722543352603, + "balanced_accuracy": 0.5, + "roc_auc": 0.2745256630606949, + "average_precision": 0.4037230365622788, + "log_loss": 1.1229484684905306, + "precision": 1.0, + "precision_macro": 0.726878612716763, + "precision_micro": 0.45375722543352603, + "precision_weighted": 0.7016188312339203, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.45375722543352603, + "recall_weighted": 0.45375722543352603, + "f1": 1.0, + "f1_macro": 0.6467289719626168, + "f1_micro": 0.45375722543352603, + "f1_weighted": 0.6140565069418183 + }, + "configuration_origin": "Random Search" + } + ] + ], + [ + [ + 3, + "{\"task_id\": \"Australian\"}", + 0, + 5.555555555555555 + ], + [ + 0.5555555555555556, + 22.723600149154663, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342651.4707444, + 1637342675.2555833, + { + 
"opt_loss": { + "accuracy": 0.5555555555555556, + "balanced_accuracy": 0.5, + "roc_auc": 0.4924515235457063, + "average_precision": 0.5493808049535605, + "log_loss": 0.7291908971747459, + "precision": 0.5555555555555556, + "precision_macro": 0.7777777777777778, + "precision_micro": 0.5555555555555556, + "precision_weighted": 0.8024691358024691, + "recall": 0.0, + "recall_macro": 0.5, + "recall_micro": 0.5555555555555556, + "recall_weighted": 0.5555555555555556, + "f1": 0.3846153846153847, + "f1_macro": 0.6923076923076923, + "f1_micro": 0.5555555555555556, + "f1_weighted": 0.7264957264957266 + }, + "duration": 22.021637201309204, + "num_run": 10, + "train_loss": { + "accuracy": 0.546242774566474, + "balanced_accuracy": 0.5, + "roc_auc": 0.514423887035352, + "average_precision": 0.5521926852639938, + "log_loss": 0.7258427792546377, + "precision": 0.546242774566474, + "precision_macro": 0.773121387283237, + "precision_micro": 0.546242774566474, + "precision_weighted": 0.7941043803668683, + "recall": 0.0, + "recall_macro": 0.5, + "recall_micro": 0.546242774566474, + "recall_weighted": 0.546242774566474, + "f1": 0.3757455268389662, + "f1_macro": 0.6878727634194831, + "f1_micro": 0.546242774566474, + "f1_weighted": 0.7167400222939817 + }, + "configuration_origin": "Random Search (sorted)" + } + ] + ], + [ + [ + 4, + "{\"task_id\": \"Australian\"}", + 0, + 5.555555555555555 + ], + [ + 0.29824561403508776, + 4.990685224533081, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342675.317421, + 1637342681.334954, + { + "opt_loss": { + "accuracy": 0.29824561403508776, + "balanced_accuracy": 0.30263157894736836, + "roc_auc": 0.26869806094182835, + "average_precision": 0.3191125709864897, + "log_loss": 0.6374789248084465, + "precision": 0.33333333333333337, + "precision_macro": 0.30208333333333337, + "precision_micro": 0.29824561403508776, + "precision_weighted": 0.29861111111111116, + "recall": 0.3421052631578947, + "recall_macro": 0.30263157894736836, + "recall_micro": 
0.29824561403508776, + "recall_weighted": 0.29824561403508776, + "f1": 0.3377483443708609, + "f1_macro": 0.30238202558857186, + "f1_micro": 0.29824561403508776, + "f1_weighted": 0.29845243461276183 + }, + "duration": 4.924501419067383, + "num_run": 11, + "train_loss": { + "accuracy": 0.3728323699421965, + "balanced_accuracy": 0.3800930138509757, + "roc_auc": 0.3314460957773059, + "average_precision": 0.3638537658311296, + "log_loss": 0.6533903728503023, + "precision": 0.4014084507042254, + "precision_macro": 0.3771748135874068, + "precision_micro": 0.3728323699421965, + "precision_weighted": 0.3749335523511692, + "recall": 0.4585987261146497, + "recall_macro": 0.3800930138509757, + "recall_micro": 0.3728323699421965, + "recall_weighted": 0.3728323699421965, + "f1": 0.43143812709030094, + "f1_macro": 0.3798412009497306, + "f1_micro": 0.3728323699421965, + "f1_weighted": 0.3750692309020478 + }, + "configuration_origin": "Random Search" + } + ] + ], + [ + [ + 5, + "{\"task_id\": \"Australian\"}", + 0, + 5.555555555555555 + ], + [ + 0.4444444444444444, + 10.684926509857178, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342681.548915, + 1637342693.2717755, + { + "opt_loss": { + "accuracy": 0.4444444444444444, + "balanced_accuracy": 0.5, + "roc_auc": 0.6092797783933518, + "average_precision": 0.6129755132627962, + "log_loss": 0.6905045174715811, + "precision": 1.0, + "precision_macro": 0.7222222222222222, + "precision_micro": 0.4444444444444444, + "precision_weighted": 0.691358024691358, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.4444444444444444, + "recall_weighted": 0.4444444444444444, + "f1": 1.0, + "f1_macro": 0.6428571428571428, + "f1_micro": 0.4444444444444444, + "f1_weighted": 0.6031746031746031 + }, + "duration": 10.401196956634521, + "num_run": 12, + "train_loss": { + "accuracy": 0.45375722543352603, + "balanced_accuracy": 0.5, + "roc_auc": 0.6309102551140767, + "average_precision": 0.6325768698403712, + "log_loss": 0.691941062839045, + 
"precision": 1.0, + "precision_macro": 0.726878612716763, + "precision_micro": 0.45375722543352603, + "precision_weighted": 0.7016188312339203, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.45375722543352603, + "recall_weighted": 0.45375722543352603, + "f1": 1.0, + "f1_macro": 0.6467289719626168, + "f1_micro": 0.45375722543352603, + "f1_weighted": 0.6140565069418183 + }, + "configuration_origin": "Random Search" + } + ] + ], + [ + [ + 6, + "{\"task_id\": \"Australian\"}", + 0, + 5.555555555555555 + ], + [ + 0.4444444444444444, + 9.947429180145264, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342693.356699, + 1637342704.341065, + { + "opt_loss": { + "accuracy": 0.4444444444444444, + "balanced_accuracy": 0.5, + "roc_auc": 0.24930747922437668, + "average_precision": 0.31612650360994055, + "log_loss": 0.6525155201292875, + "precision": 1.0, + "precision_macro": 0.7222222222222222, + "precision_micro": 0.4444444444444444, + "precision_weighted": 0.691358024691358, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.4444444444444444, + "recall_weighted": 0.4444444444444444, + "f1": 1.0, + "f1_macro": 0.6428571428571428, + "f1_micro": 0.4444444444444444, + "f1_weighted": 0.6031746031746031 + }, + "duration": 9.76927137374878, + "num_run": 13, + "train_loss": { + "accuracy": 0.45375722543352603, + "balanced_accuracy": 0.5, + "roc_auc": 0.22427796313146642, + "average_precision": 0.2451792573360162, + "log_loss": 0.64721482587343, + "precision": 1.0, + "precision_macro": 0.726878612716763, + "precision_micro": 0.45375722543352603, + "precision_weighted": 0.7016188312339203, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.45375722543352603, + "recall_weighted": 0.45375722543352603, + "f1": 1.0, + "f1_macro": 0.6467289719626168, + "f1_micro": 0.45375722543352603, + "f1_weighted": 0.6140565069418183 + }, + "configuration_origin": "Random Search" + } + ] + ], + [ + [ + 7, + "{\"task_id\": \"Australian\"}", + 0, + 5.555555555555555 + ], + [ 
+ 1.0, + 11.687273979187012, + { + "__enum__": "StatusType.CRASHED" + }, + 1637342713.4931705, + 1637342726.1866672, + { + "error": "Result queue is empty", + "exit_status": "", + "subprocess_stdout": "", + "subprocess_stderr": "Process pynisher function call:\nTraceback (most recent call last):\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/evaluation/tae.py\", line 39, in fit_predict_try_except_decorator\n ta(queue=queue, **kwargs)\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/evaluation/train_evaluator.py\", line 485, in eval_function\n evaluator.fit_predict_and_loss()\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/evaluation/train_evaluator.py\", line 163, in fit_predict_and_loss\n y_train_pred, y_opt_pred, y_valid_pred, y_test_pred = self._fit_and_predict(pipeline, split_id,\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/evaluation/train_evaluator.py\", line 337, in _fit_and_predict\n fit_and_suppress_warnings(self.logger, pipeline, X, y)\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/evaluation/abstract_evaluator.py\", line 321, in fit_and_suppress_warnings\n pipeline.fit(X, y)\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/pipeline/base_pipeline.py\", line 153, in fit\n self.fit_estimator(X, y, **fit_params)\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/pipeline/base_pipeline.py\", line 172, in fit_estimator\n self._final_estimator.fit(X, y, **fit_params)\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/pipeline/components/training/trainer/__init__.py\", line 211, in fit\n self._fit(\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/pipeline/components/training/trainer/__init__.py\", line 290, in _fit\n train_loss, train_metrics = self.choice.train_epoch(\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/pipeline/components/training/trainer/base_trainer.py\", line 303, in train_epoch\n loss, outputs = self.train_step(data, targets)\n File 
\"/home/shuhei/research/Auto-PyTorch/autoPyTorch/pipeline/components/training/trainer/base_trainer.py\", line 357, in train_step\n outputs = self.model(data)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 727, in _call_impl\n result = self.forward(*input, **kwargs)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/container.py\", line 117, in forward\n input = module(input)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 727, in _call_impl\n result = self.forward(*input, **kwargs)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/container.py\", line 117, in forward\n input = module(input)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 727, in _call_impl\n result = self.forward(*input, **kwargs)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/container.py\", line 117, in forward\n input = module(input)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 727, in _call_impl\n result = self.forward(*input, **kwargs)\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/pipeline/components/setup/network_backbone/ResNetBackbone.py\", line 274, in forward\n x2 = self.shake_shake_layers(x)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 727, in _call_impl\n result = self.forward(*input, **kwargs)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/container.py\", line 117, in forward\n input = module(input)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 727, in _call_impl\n result = self.forward(*input, 
**kwargs)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/modules/linear.py\", line 93, in forward\n return F.linear(input, self.weight, self.bias)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/torch/nn/functional.py\", line 1690, in linear\n ret = torch.addmm(bias, input, weight.t())\nRuntimeError: [enforce fail at CPUAllocator.cpp:65] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 713632 bytes. Error code 12 (Cannot allocate memory)\n\nDuring handling of the above exception, another exception occurred:\n\nTraceback (most recent call last):\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/multiprocessing/process.py\", line 315, in _bootstrap\n self.run()\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/multiprocessing/process.py\", line 108, in run\n self._target(*self._args, **self._kwargs)\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/site-packages/pynisher/limit_function_call.py\", line 138, in subprocess_func\n return_value = ((func(*args, **kwargs), 0))\n File \"/home/shuhei/research/Auto-PyTorch/autoPyTorch/evaluation/tae.py\", line 52, in fit_predict_try_except_decorator\n queue.put({'loss': cost_for_crash,\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/multiprocessing/queues.py\", line 88, in put\n self._start_thread()\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/multiprocessing/queues.py\", line 173, in _start_thread\n self._thread.start()\n File \"/home/shuhei/anaconda3/envs/auto-pytorch/lib/python3.8/threading.py\", line 852, in start\n _start_new_thread(self._bootstrap, ())\nRuntimeError: can't start new thread\n", + "exitcode": 1, + "configuration_origin": "Random Search (sorted)" + } + ] + ], + [ + [ + 8, + "{\"task_id\": \"Australian\"}", + 0, + 5.555555555555555 + ], + [ + 0.5555555555555556, + 8.478890419006348, + { + "__enum__": "StatusType.SUCCESS" + }, + 
1637342733.815657, + 1637342743.3274522, + { + "opt_loss": { + "accuracy": 0.5555555555555556, + "balanced_accuracy": 0.5, + "roc_auc": 0.44605263157894737, + "average_precision": 0.526907034743722, + "log_loss": 0.722785997111895, + "precision": 0.5555555555555556, + "precision_macro": 0.7777777777777778, + "precision_micro": 0.5555555555555556, + "precision_weighted": 0.8024691358024691, + "recall": 0.0, + "recall_macro": 0.5, + "recall_micro": 0.5555555555555556, + "recall_weighted": 0.5555555555555556, + "f1": 0.3846153846153847, + "f1_macro": 0.6923076923076923, + "f1_micro": 0.5555555555555556, + "f1_weighted": 0.7264957264957266 + }, + "duration": 8.288825988769531, + "num_run": 15, + "train_loss": { + "accuracy": 0.546242774566474, + "balanced_accuracy": 0.5, + "roc_auc": 0.4537121288713646, + "average_precision": 0.5218043063878082, + "log_loss": 0.7198673617633092, + "precision": 0.546242774566474, + "precision_macro": 0.773121387283237, + "precision_micro": 0.546242774566474, + "precision_weighted": 0.7941043803668683, + "recall": 0.0, + "recall_macro": 0.5, + "recall_micro": 0.546242774566474, + "recall_weighted": 0.546242774566474, + "f1": 0.3757455268389662, + "f1_macro": 0.6878727634194831, + "f1_micro": 0.546242774566474, + "f1_weighted": 0.7167400222939817 + }, + "configuration_origin": "Random Search (sorted)" + } + ] + ], + [ + [ + 9, + "{\"task_id\": \"Australian\"}", + 0, + 5.555555555555555 + ], + [ + 0.4444444444444444, + 5.485020637512207, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342743.4267018, + 1637342749.9442234, + { + "opt_loss": { + "accuracy": 0.4444444444444444, + "balanced_accuracy": 0.5, + "roc_auc": 0.5, + "average_precision": 0.5555555555555556, + "log_loss": 15.350567287868923, + "precision": 1.0, + "precision_macro": 0.7222222222222222, + "precision_micro": 0.4444444444444444, + "precision_weighted": 0.691358024691358, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.4444444444444444, + "recall_weighted": 
0.4444444444444444, + "f1": 1.0, + "f1_macro": 0.6428571428571428, + "f1_micro": 0.4444444444444444, + "f1_weighted": 0.6031746031746031 + }, + "duration": 5.376826286315918, + "num_run": 16, + "train_loss": { + "accuracy": 0.45375722543352603, + "balanced_accuracy": 0.5, + "roc_auc": 0.5, + "average_precision": 0.546242774566474, + "log_loss": 15.67221934809161, + "precision": 1.0, + "precision_macro": 0.726878612716763, + "precision_micro": 0.45375722543352603, + "precision_weighted": 0.7016188312339203, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.45375722543352603, + "recall_weighted": 0.45375722543352603, + "f1": 1.0, + "f1_macro": 0.6467289719626168, + "f1_micro": 0.45375722543352603, + "f1_weighted": 0.6140565069418183 + }, + "configuration_origin": "Random Search" + } + ] + ], + [ + [ + 1, + "{\"task_id\": \"Australian\"}", + 0, + 16.666666666666664 + ], + [ + 0.15204678362573099, + 11.514830589294434, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342750.0053334, + 1637342762.5487585, + { + "opt_loss": { + "accuracy": 0.15204678362573099, + "balanced_accuracy": 0.15263157894736845, + "roc_auc": 0.08981994459833786, + "average_precision": 0.1040861796433199, + "log_loss": 0.5765479137672738, + "precision": 0.17948717948717952, + "precision_macro": 0.15425971877584788, + "precision_micro": 0.15204678362573099, + "precision_weighted": 0.15145666758569976, + "recall": 0.1578947368421053, + "recall_macro": 0.15263157894736845, + "recall_micro": 0.15204678362573099, + "recall_weighted": 0.15204678362573099, + "f1": 0.16883116883116878, + "f1_macro": 0.15356452058579717, + "f1_micro": 0.15204678362573099, + "f1_weighted": 0.15186822633631147 + }, + "duration": 11.44463586807251, + "num_run": 8, + "train_loss": { + "accuracy": 0.09537572254335258, + "balanced_accuracy": 0.10239948774980623, + "roc_auc": 0.03963198867657458, + "average_precision": 0.044469547423341305, + "log_loss": 0.28008669264774966, + "precision": 0.03731343283582089, + 
"precision_macro": 0.08469445226696704, + "precision_micro": 0.09537572254335258, + "precision_weighted": 0.0890765118675354, + "recall": 0.17834394904458595, + "recall_macro": 0.10239948774980623, + "recall_micro": 0.09537572254335258, + "recall_weighted": 0.09537572254335258, + "f1": 0.11340206185567014, + "f1_macro": 0.09784816309741107, + "f1_micro": 0.09537572254335258, + "f1_weighted": 0.0964096522295953 + }, + "configuration_origin": "Default" + } + ] + ], + [ + [ + 1, + "{\"task_id\": \"Australian\"}", + 0, + 50.0 + ], + [ + 0.15204678362573099, + 15.370736837387085, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342762.794756, + 1637342779.192385, + { + "opt_loss": { + "accuracy": 0.15204678362573099, + "balanced_accuracy": 0.15263157894736845, + "roc_auc": 0.08981994459833786, + "average_precision": 0.1040861796433199, + "log_loss": 0.5765479137672738, + "precision": 0.17948717948717952, + "precision_macro": 0.15425971877584788, + "precision_micro": 0.15204678362573099, + "precision_weighted": 0.15145666758569976, + "recall": 0.1578947368421053, + "recall_macro": 0.15263157894736845, + "recall_micro": 0.15204678362573099, + "recall_weighted": 0.15204678362573099, + "f1": 0.16883116883116878, + "f1_macro": 0.15356452058579717, + "f1_micro": 0.15204678362573099, + "f1_weighted": 0.15186822633631147 + }, + "duration": 15.300711154937744, + "num_run": 8, + "train_loss": { + "accuracy": 0.09537572254335258, + "balanced_accuracy": 0.10239948774980623, + "roc_auc": 0.03963198867657458, + "average_precision": 0.044469547423341305, + "log_loss": 0.28008669264774966, + "precision": 0.03731343283582089, + "precision_macro": 0.08469445226696704, + "precision_micro": 0.09537572254335258, + "precision_weighted": 0.0890765118675354, + "recall": 0.17834394904458595, + "recall_macro": 0.10239948774980623, + "recall_micro": 0.09537572254335258, + "recall_weighted": 0.09537572254335258, + "f1": 0.11340206185567014, + "f1_macro": 0.09784816309741107, + "f1_micro": 
0.09537572254335258, + "f1_weighted": 0.0964096522295953 + }, + "configuration_origin": "Default" + } + ] + ], + [ + [ + 10, + "{\"task_id\": \"Australian\"}", + 0, + 16.666666666666664 + ], + [ + 0.4035087719298246, + 23.846530199050903, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342779.4572933, + 1637342804.3368232, + { + "opt_loss": { + "accuracy": 0.4035087719298246, + "balanced_accuracy": 0.39473684210526316, + "roc_auc": 0.3946675900277007, + "average_precision": 0.4846825737029168, + "log_loss": 6.419999084913276, + "precision": 0.4639175257731959, + "precision_macro": 0.39412092504876006, + "precision_micro": 0.4035087719298246, + "precision_weighted": 0.38636574719048944, + "recall": 0.3157894736842105, + "recall_macro": 0.39473684210526316, + "recall_micro": 0.4035087719298246, + "recall_weighted": 0.4035087719298246, + "f1": 0.3988439306358381, + "f1_macro": 0.4035639771522386, + "f1_micro": 0.4035087719298246, + "f1_weighted": 0.404088426765172 + }, + "duration": 23.588075160980225, + "num_run": 17, + "train_loss": { + "accuracy": 0.3988439306358381, + "balanced_accuracy": 0.3947359552455094, + "roc_auc": 0.4153776160145586, + "average_precision": 0.49525391358194226, + "log_loss": 7.800554750687936, + "precision": 0.4486486486486486, + "precision_macro": 0.39513177774047337, + "precision_micro": 0.3988439306358381, + "precision_weighted": 0.39018224054665374, + "recall": 0.35031847133757965, + "recall_macro": 0.3947359552455094, + "recall_micro": 0.3988439306358381, + "recall_weighted": 0.3988439306358381, + "f1": 0.4035087719298246, + "f1_macro": 0.39889724310776953, + "f1_micro": 0.3988439306358381, + "f1_weighted": 0.39847074333231924 + }, + "configuration_origin": "Random Search" + } + ] + ], + [ + [ + 11, + "{\"task_id\": \"Australian\"}", + 0, + 16.666666666666664 + ], + [ + 0.4444444444444444, + 6.757539510726929, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342813.02129, + 1637342820.8067145, + { + "opt_loss": { + "accuracy": 
0.4444444444444444, + "balanced_accuracy": 0.5, + "roc_auc": 0.6250692520775624, + "average_precision": 0.642743659212315, + "log_loss": 0.6874508627674036, + "precision": 1.0, + "precision_macro": 0.7222222222222222, + "precision_micro": 0.4444444444444444, + "precision_weighted": 0.691358024691358, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.4444444444444444, + "recall_weighted": 0.4444444444444444, + "f1": 1.0, + "f1_macro": 0.6428571428571428, + "f1_micro": 0.4444444444444444, + "f1_weighted": 0.6031746031746031 + }, + "duration": 6.695321321487427, + "num_run": 18, + "train_loss": { + "accuracy": 0.45375722543352603, + "balanced_accuracy": 0.5, + "roc_auc": 0.5560273649445624, + "average_precision": 0.5899410773859022, + "log_loss": 0.689653262926664, + "precision": 1.0, + "precision_macro": 0.726878612716763, + "precision_micro": 0.45375722543352603, + "precision_weighted": 0.7016188312339203, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.45375722543352603, + "recall_weighted": 0.45375722543352603, + "f1": 1.0, + "f1_macro": 0.6467289719626168, + "f1_micro": 0.45375722543352603, + "f1_weighted": 0.6140565069418183 + }, + "configuration_origin": "Random Search (sorted)" + } + ] + ], + [ + [ + 12, + "{\"task_id\": \"Australian\"}", + 0, + 16.666666666666664 + ], + [ + 0.4444444444444444, + 15.061991930007935, + { + "__enum__": "StatusType.SUCCESS" + }, + 1637342829.9214745, + 1637342846.0210106, + { + "opt_loss": { + "accuracy": 0.4444444444444444, + "balanced_accuracy": 0.5, + "roc_auc": 0.378393351800554, + "average_precision": 0.4680399341300143, + "log_loss": 0.6910723817278768, + "precision": 1.0, + "precision_macro": 0.7222222222222222, + "precision_micro": 0.4444444444444444, + "precision_weighted": 0.691358024691358, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.4444444444444444, + "recall_weighted": 0.4444444444444444, + "f1": 1.0, + "f1_macro": 0.6428571428571428, + "f1_micro": 0.4444444444444444, + 
"f1_weighted": 0.6031746031746031 + }, + "duration": 14.850486516952515, + "num_run": 19, + "train_loss": { + "accuracy": 0.45375722543352603, + "balanced_accuracy": 0.5, + "roc_auc": 0.44353452633707413, + "average_precision": 0.4796876232765652, + "log_loss": 0.6915661034556483, + "precision": 1.0, + "precision_macro": 0.726878612716763, + "precision_micro": 0.45375722543352603, + "precision_weighted": 0.7016188312339203, + "recall": 1.0, + "recall_macro": 0.5, + "recall_micro": 0.45375722543352603, + "recall_weighted": 0.45375722543352603, + "f1": 1.0, + "f1_macro": 0.6467289719626168, + "f1_micro": 0.45375722543352603, + "f1_weighted": 0.6140565069418183 + }, + "configuration_origin": "Random Search (sorted)" + } + ] + ], + [ + [ + 10, + "{\"task_id\": \"Australian\"}", + 0, + 50.0 + ], + [ + 1.0, + 50.010520696640015, + { + "__enum__": "StatusType.TIMEOUT" + }, + 1637342846.0745292, + 1637342897.1205413, + { + "error": "Timeout", + "configuration_origin": "Random Search" + } + ] + ], + [ + [ + 13, + "{\"task_id\": \"Australian\"}", + 0, + 50.0 + ], + [ + 1.0, + 22.011935234069824, + { + "__enum__": "StatusType.TIMEOUT" + }, + 1637342905.7068844, + 1637342928.7456856, + { + "error": "Timeout", + "configuration_origin": "Random Search (sorted)" + } + ] + ], + [ + [ + 14, + "{\"task_id\": \"Australian\"}", + 0, + 50.0 + ], + [ + 1.0, + 0.0, + { + "__enum__": "StatusType.STOP" + }, + 1637342928.8133125, + 1637342928.8133128, + {} + ] + ] + ], + "config_origins": { + "1": "Default", + "2": "Random Search", + "3": "Random Search (sorted)", + "4": "Random Search", + "5": "Random Search", + "6": "Random Search", + "7": "Random Search (sorted)", + "8": "Random Search (sorted)", + "9": "Random Search", + "10": "Random Search", + "11": "Random Search (sorted)", + "12": "Random Search (sorted)", + "13": "Random Search (sorted)", + "14": "Random Search" + }, + "configs": { + "1": { + "data_loader:batch_size": 64, + "encoder:__choice__": "OneHotEncoder", + 
"coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "NoFeaturePreprocessor", + "imputer:categorical_strategy": "most_frequent", + "imputer:numerical_strategy": "mean", + "lr_scheduler:__choice__": "ReduceLROnPlateau", + "network_backbone:__choice__": "ShapedMLPBackbone", + "network_embedding:__choice__": "NoEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "XavierInit", + "optimizer:__choice__": "AdamOptimizer", + "scaler:__choice__": "StandardScaler", + "trainer:__choice__": "StandardTrainer", + "lr_scheduler:ReduceLROnPlateau:factor": 0.1, + "lr_scheduler:ReduceLROnPlateau:mode": "min", + "lr_scheduler:ReduceLROnPlateau:patience": 10, + "network_backbone:ShapedMLPBackbone:activation": "relu", + "network_backbone:ShapedMLPBackbone:max_units": 200, + "network_backbone:ShapedMLPBackbone:mlp_shape": "funnel", + "network_backbone:ShapedMLPBackbone:num_groups": 5, + "network_backbone:ShapedMLPBackbone:output_dim": 200, + "network_backbone:ShapedMLPBackbone:use_dropout": false, + "network_head:fully_connected:num_layers": 2, + "network_init:XavierInit:bias_strategy": "Normal", + "optimizer:AdamOptimizer:beta1": 0.9, + "optimizer:AdamOptimizer:beta2": 0.9, + "optimizer:AdamOptimizer:lr": 0.01, + "optimizer:AdamOptimizer:weight_decay": 0.0, + "trainer:StandardTrainer:weighted_loss": true, + "network_head:fully_connected:activation": "relu", + "network_head:fully_connected:units_layer_1": 128 + }, + "2": { + "data_loader:batch_size": 142, + "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "PowerTransformer", + "imputer:categorical_strategy": "constant_!missing!", + "imputer:numerical_strategy": "median", + "lr_scheduler:__choice__": "NoScheduler", + "network_backbone:__choice__": "ShapedResNetBackbone", + "network_embedding:__choice__": "NoEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "KaimingInit", + 
"optimizer:__choice__": "AdamOptimizer", + "scaler:__choice__": "Normalizer", + "trainer:__choice__": "MixUpTrainer", + "feature_preprocessor:PowerTransformer:standardize": true, + "network_backbone:ShapedResNetBackbone:activation": "relu", + "network_backbone:ShapedResNetBackbone:blocks_per_group": 1, + "network_backbone:ShapedResNetBackbone:max_units": 175, + "network_backbone:ShapedResNetBackbone:num_groups": 3, + "network_backbone:ShapedResNetBackbone:output_dim": 550, + "network_backbone:ShapedResNetBackbone:resnet_shape": "funnel", + "network_backbone:ShapedResNetBackbone:use_dropout": false, + "network_backbone:ShapedResNetBackbone:use_shake_drop": true, + "network_backbone:ShapedResNetBackbone:use_shake_shake": true, + "network_head:fully_connected:num_layers": 2, + "network_init:KaimingInit:bias_strategy": "Normal", + "optimizer:AdamOptimizer:beta1": 0.8660298289969375, + "optimizer:AdamOptimizer:beta2": 0.9517157453274235, + "optimizer:AdamOptimizer:lr": 1.0377748473731365e-05, + "optimizer:AdamOptimizer:weight_decay": 0.07437634123996516, + "scaler:Normalizer:norm": "mean_abs", + "trainer:MixUpTrainer:alpha": 0.13179357367568267, + "trainer:MixUpTrainer:weighted_loss": true, + "network_backbone:ShapedResNetBackbone:max_shake_drop_probability": 0.7993610769045779, + "network_head:fully_connected:activation": "sigmoid", + "network_head:fully_connected:units_layer_1": 308 + }, + "3": { + "data_loader:batch_size": 246, + "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "PowerTransformer", + "imputer:categorical_strategy": "constant_!missing!", + "imputer:numerical_strategy": "most_frequent", + "lr_scheduler:__choice__": "CosineAnnealingWarmRestarts", + "network_backbone:__choice__": "ResNetBackbone", + "network_embedding:__choice__": "LearnedEntityEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "XavierInit", + "optimizer:__choice__": "AdamOptimizer", 
+ "scaler:__choice__": "Normalizer", + "trainer:__choice__": "StandardTrainer", + "feature_preprocessor:PowerTransformer:standardize": true, + "lr_scheduler:CosineAnnealingWarmRestarts:T_0": 17, + "lr_scheduler:CosineAnnealingWarmRestarts:T_mult": 1.0577034671447638, + "network_backbone:ResNetBackbone:activation": "sigmoid", + "network_backbone:ResNetBackbone:blocks_per_group_0": 1, + "network_backbone:ResNetBackbone:blocks_per_group_1": 4, + "network_backbone:ResNetBackbone:num_groups": 10, + "network_backbone:ResNetBackbone:num_units_0": 974, + "network_backbone:ResNetBackbone:num_units_1": 151, + "network_backbone:ResNetBackbone:use_dropout": true, + "network_backbone:ResNetBackbone:use_shake_drop": true, + "network_backbone:ResNetBackbone:use_shake_shake": true, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_0": 0.6898659803969109, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_1": 0.6193894885012183, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_2": 0.27044405840757246, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_3": 0.3353276257116905, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_4": 0.25330009522745545, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_5": 0.28087428370045076, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_6": 0.985667346693578, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_7": 0.532995443030165, + "network_embedding:LearnedEntityEmbedding:min_unique_values_for_embedding": 4, + "network_head:fully_connected:num_layers": 4, + "network_init:XavierInit:bias_strategy": "Normal", + "optimizer:AdamOptimizer:beta1": 0.9319239463981089, + "optimizer:AdamOptimizer:beta2": 0.9290660642046109, + "optimizer:AdamOptimizer:lr": 4.934398361769327e-05, + "optimizer:AdamOptimizer:weight_decay": 0.0647885374302594, + "scaler:Normalizer:norm": "mean_abs", + "trainer:StandardTrainer:weighted_loss": true, + 
"network_backbone:ResNetBackbone:blocks_per_group_10": 2, + "network_backbone:ResNetBackbone:blocks_per_group_2": 1, + "network_backbone:ResNetBackbone:blocks_per_group_3": 1, + "network_backbone:ResNetBackbone:blocks_per_group_4": 3, + "network_backbone:ResNetBackbone:blocks_per_group_5": 1, + "network_backbone:ResNetBackbone:blocks_per_group_6": 3, + "network_backbone:ResNetBackbone:blocks_per_group_7": 4, + "network_backbone:ResNetBackbone:blocks_per_group_8": 1, + "network_backbone:ResNetBackbone:blocks_per_group_9": 3, + "network_backbone:ResNetBackbone:dropout_0": 0.1998483800982469, + "network_backbone:ResNetBackbone:dropout_1": 0.21671729531007777, + "network_backbone:ResNetBackbone:dropout_10": 0.2027668457966562, + "network_backbone:ResNetBackbone:dropout_2": 0.7140248388727144, + "network_backbone:ResNetBackbone:dropout_3": 0.1324478677866992, + "network_backbone:ResNetBackbone:dropout_4": 0.26711076053122573, + "network_backbone:ResNetBackbone:dropout_5": 0.2895993889716623, + "network_backbone:ResNetBackbone:dropout_6": 0.047419135928320616, + "network_backbone:ResNetBackbone:dropout_7": 0.593522761474697, + "network_backbone:ResNetBackbone:dropout_8": 0.11825542268484464, + "network_backbone:ResNetBackbone:dropout_9": 0.5802180655508312, + "network_backbone:ResNetBackbone:max_shake_drop_probability": 0.8422119101175598, + "network_backbone:ResNetBackbone:num_units_10": 1012, + "network_backbone:ResNetBackbone:num_units_2": 793, + "network_backbone:ResNetBackbone:num_units_3": 184, + "network_backbone:ResNetBackbone:num_units_4": 1022, + "network_backbone:ResNetBackbone:num_units_5": 88, + "network_backbone:ResNetBackbone:num_units_6": 666, + "network_backbone:ResNetBackbone:num_units_7": 927, + "network_backbone:ResNetBackbone:num_units_8": 614, + "network_backbone:ResNetBackbone:num_units_9": 552, + "network_head:fully_connected:activation": "relu", + "network_head:fully_connected:units_layer_1": 92, + "network_head:fully_connected:units_layer_2": 
202, + "network_head:fully_connected:units_layer_3": 171 + }, + "4": { + "data_loader:batch_size": 269, + "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "PowerTransformer", + "imputer:categorical_strategy": "constant_!missing!", + "imputer:numerical_strategy": "median", + "lr_scheduler:__choice__": "CosineAnnealingLR", + "network_backbone:__choice__": "ShapedMLPBackbone", + "network_embedding:__choice__": "LearnedEntityEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "KaimingInit", + "optimizer:__choice__": "RMSpropOptimizer", + "scaler:__choice__": "MinMaxScaler", + "trainer:__choice__": "StandardTrainer", + "feature_preprocessor:PowerTransformer:standardize": true, + "lr_scheduler:CosineAnnealingLR:T_max": 57, + "network_backbone:ShapedMLPBackbone:activation": "relu", + "network_backbone:ShapedMLPBackbone:max_units": 199, + "network_backbone:ShapedMLPBackbone:mlp_shape": "stairs", + "network_backbone:ShapedMLPBackbone:num_groups": 12, + "network_backbone:ShapedMLPBackbone:output_dim": 641, + "network_backbone:ShapedMLPBackbone:use_dropout": true, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_0": 0.8093046015402414, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_1": 0.6952888698136637, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_2": 0.7136167420874352, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_3": 0.7071870846094686, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_4": 0.8821351885181623, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_5": 0.21840740866837938, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_6": 0.7366390825638998, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_7": 0.548715467945816, + "network_embedding:LearnedEntityEmbedding:min_unique_values_for_embedding": 7, + 
"network_head:fully_connected:num_layers": 3, + "network_init:KaimingInit:bias_strategy": "Zero", + "optimizer:RMSpropOptimizer:alpha": 0.23716801972855298, + "optimizer:RMSpropOptimizer:lr": 0.0011708542709120838, + "optimizer:RMSpropOptimizer:momentum": 0.5620565618493047, + "optimizer:RMSpropOptimizer:weight_decay": 0.05858239202799009, + "trainer:StandardTrainer:weighted_loss": true, + "network_backbone:ShapedMLPBackbone:max_dropout": 0.20819857031346878, + "network_head:fully_connected:activation": "relu", + "network_head:fully_connected:units_layer_1": 177, + "network_head:fully_connected:units_layer_2": 196 + }, + "5": { + "data_loader:batch_size": 191, + "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "RandomKitchenSinks", + "imputer:categorical_strategy": "constant_!missing!", + "imputer:numerical_strategy": "most_frequent", + "lr_scheduler:__choice__": "ExponentialLR", + "network_backbone:__choice__": "ShapedResNetBackbone", + "network_embedding:__choice__": "LearnedEntityEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "SparseInit", + "optimizer:__choice__": "AdamOptimizer", + "scaler:__choice__": "MinMaxScaler", + "trainer:__choice__": "MixUpTrainer", + "feature_preprocessor:RandomKitchenSinks:gamma": 0.00023806069646323692, + "feature_preprocessor:RandomKitchenSinks:n_components": 6, + "lr_scheduler:ExponentialLR:gamma": 0.7718494018636944, + "network_backbone:ShapedResNetBackbone:activation": "tanh", + "network_backbone:ShapedResNetBackbone:blocks_per_group": 3, + "network_backbone:ShapedResNetBackbone:max_units": 869, + "network_backbone:ShapedResNetBackbone:num_groups": 2, + "network_backbone:ShapedResNetBackbone:output_dim": 868, + "network_backbone:ShapedResNetBackbone:resnet_shape": "triangle", + "network_backbone:ShapedResNetBackbone:use_dropout": true, + "network_backbone:ShapedResNetBackbone:use_shake_drop": true, + 
"network_backbone:ShapedResNetBackbone:use_shake_shake": true, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_0": 0.08846693746970624, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_1": 0.6597252449477167, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_2": 0.11290616066859738, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_3": 0.4187266624427779, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_4": 0.026810815995375492, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_5": 0.02021466731982824, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_6": 0.01616376260397212, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_7": 0.6463510235731745, + "network_embedding:LearnedEntityEmbedding:min_unique_values_for_embedding": 6, + "network_head:fully_connected:num_layers": 3, + "network_init:SparseInit:bias_strategy": "Normal", + "optimizer:AdamOptimizer:beta1": 0.9966044931531224, + "optimizer:AdamOptimizer:beta2": 0.9293356180290759, + "optimizer:AdamOptimizer:lr": 0.07180366191531826, + "optimizer:AdamOptimizer:weight_decay": 0.012304534471441598, + "trainer:MixUpTrainer:alpha": 0.8900376828213522, + "trainer:MixUpTrainer:weighted_loss": true, + "network_backbone:ShapedResNetBackbone:max_dropout": 0.6688622458251051, + "network_backbone:ShapedResNetBackbone:max_shake_drop_probability": 0.28903761225065516, + "network_head:fully_connected:activation": "sigmoid", + "network_head:fully_connected:units_layer_1": 198, + "network_head:fully_connected:units_layer_2": 283 + }, + "6": { + "data_loader:batch_size": 53, + "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "PowerTransformer", + "imputer:categorical_strategy": "constant_!missing!", + "imputer:numerical_strategy": "median", + "lr_scheduler:__choice__": "StepLR", + "network_backbone:__choice__": "ResNetBackbone", + 
"network_embedding:__choice__": "LearnedEntityEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "OrthogonalInit", + "optimizer:__choice__": "AdamOptimizer", + "scaler:__choice__": "Normalizer", + "trainer:__choice__": "MixUpTrainer", + "feature_preprocessor:PowerTransformer:standardize": false, + "lr_scheduler:StepLR:gamma": 0.27529217359012764, + "lr_scheduler:StepLR:step_size": 8, + "network_backbone:ResNetBackbone:activation": "sigmoid", + "network_backbone:ResNetBackbone:blocks_per_group_0": 3, + "network_backbone:ResNetBackbone:blocks_per_group_1": 1, + "network_backbone:ResNetBackbone:num_groups": 6, + "network_backbone:ResNetBackbone:num_units_0": 884, + "network_backbone:ResNetBackbone:num_units_1": 160, + "network_backbone:ResNetBackbone:use_dropout": false, + "network_backbone:ResNetBackbone:use_shake_drop": true, + "network_backbone:ResNetBackbone:use_shake_shake": false, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_0": 0.5659324295712268, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_1": 0.8744001957677244, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_2": 0.3415903412295024, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_3": 0.8599314829187148, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_4": 0.9869678384973877, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_5": 0.7490528427155283, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_6": 0.9979477892240094, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_7": 0.4171316119626819, + "network_embedding:LearnedEntityEmbedding:min_unique_values_for_embedding": 5, + "network_head:fully_connected:num_layers": 1, + "network_init:OrthogonalInit:bias_strategy": "Zero", + "optimizer:AdamOptimizer:beta1": 0.9576776568384536, + "optimizer:AdamOptimizer:beta2": 0.9605074039230137, + "optimizer:AdamOptimizer:lr": 0.02098507521065345, + 
"optimizer:AdamOptimizer:weight_decay": 0.021686007599294888, + "scaler:Normalizer:norm": "mean_abs", + "trainer:MixUpTrainer:alpha": 0.8399712211486785, + "trainer:MixUpTrainer:weighted_loss": true, + "network_backbone:ResNetBackbone:blocks_per_group_2": 3, + "network_backbone:ResNetBackbone:blocks_per_group_3": 3, + "network_backbone:ResNetBackbone:blocks_per_group_4": 1, + "network_backbone:ResNetBackbone:blocks_per_group_5": 1, + "network_backbone:ResNetBackbone:blocks_per_group_6": 3, + "network_backbone:ResNetBackbone:max_shake_drop_probability": 0.09160627667494659, + "network_backbone:ResNetBackbone:num_units_2": 396, + "network_backbone:ResNetBackbone:num_units_3": 587, + "network_backbone:ResNetBackbone:num_units_4": 169, + "network_backbone:ResNetBackbone:num_units_5": 546, + "network_backbone:ResNetBackbone:num_units_6": 92 + }, + "7": { + "data_loader:batch_size": 232, + "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "RandomKitchenSinks", + "imputer:categorical_strategy": "most_frequent", + "imputer:numerical_strategy": "most_frequent", + "lr_scheduler:__choice__": "NoScheduler", + "network_backbone:__choice__": "ResNetBackbone", + "network_embedding:__choice__": "NoEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "XavierInit", + "optimizer:__choice__": "AdamOptimizer", + "scaler:__choice__": "Normalizer", + "trainer:__choice__": "StandardTrainer", + "feature_preprocessor:RandomKitchenSinks:gamma": 5.7968061542283495e-05, + "feature_preprocessor:RandomKitchenSinks:n_components": 5, + "network_backbone:ResNetBackbone:activation": "sigmoid", + "network_backbone:ResNetBackbone:blocks_per_group_0": 2, + "network_backbone:ResNetBackbone:blocks_per_group_1": 3, + "network_backbone:ResNetBackbone:num_groups": 14, + "network_backbone:ResNetBackbone:num_units_0": 63, + "network_backbone:ResNetBackbone:num_units_1": 229, + 
"network_backbone:ResNetBackbone:use_dropout": true, + "network_backbone:ResNetBackbone:use_shake_drop": false, + "network_backbone:ResNetBackbone:use_shake_shake": true, + "network_head:fully_connected:num_layers": 2, + "network_init:XavierInit:bias_strategy": "Normal", + "optimizer:AdamOptimizer:beta1": 0.9646917651093316, + "optimizer:AdamOptimizer:beta2": 0.9949552394978046, + "optimizer:AdamOptimizer:lr": 0.018422761006289576, + "optimizer:AdamOptimizer:weight_decay": 0.01700341747601285, + "scaler:Normalizer:norm": "mean_squared", + "trainer:StandardTrainer:weighted_loss": false, + "network_backbone:ResNetBackbone:blocks_per_group_10": 4, + "network_backbone:ResNetBackbone:blocks_per_group_11": 3, + "network_backbone:ResNetBackbone:blocks_per_group_12": 2, + "network_backbone:ResNetBackbone:blocks_per_group_13": 1, + "network_backbone:ResNetBackbone:blocks_per_group_14": 3, + "network_backbone:ResNetBackbone:blocks_per_group_2": 1, + "network_backbone:ResNetBackbone:blocks_per_group_3": 3, + "network_backbone:ResNetBackbone:blocks_per_group_4": 4, + "network_backbone:ResNetBackbone:blocks_per_group_5": 3, + "network_backbone:ResNetBackbone:blocks_per_group_6": 4, + "network_backbone:ResNetBackbone:blocks_per_group_7": 2, + "network_backbone:ResNetBackbone:blocks_per_group_8": 3, + "network_backbone:ResNetBackbone:blocks_per_group_9": 2, + "network_backbone:ResNetBackbone:dropout_0": 0.3872694110962167, + "network_backbone:ResNetBackbone:dropout_1": 0.7182095865182352, + "network_backbone:ResNetBackbone:dropout_10": 0.7518775284870586, + "network_backbone:ResNetBackbone:dropout_11": 0.3717581189860213, + "network_backbone:ResNetBackbone:dropout_12": 0.055178370982331075, + "network_backbone:ResNetBackbone:dropout_13": 0.5670307132839905, + "network_backbone:ResNetBackbone:dropout_14": 0.7859566818562779, + "network_backbone:ResNetBackbone:dropout_2": 0.5796670187291707, + "network_backbone:ResNetBackbone:dropout_3": 0.05370643307213783, + 
"network_backbone:ResNetBackbone:dropout_4": 0.37288408223729974, + "network_backbone:ResNetBackbone:dropout_5": 0.47179695650262793, + "network_backbone:ResNetBackbone:dropout_6": 0.20070003914010803, + "network_backbone:ResNetBackbone:dropout_7": 0.638048407623313, + "network_backbone:ResNetBackbone:dropout_8": 0.6190670404279601, + "network_backbone:ResNetBackbone:dropout_9": 0.33325853682297146, + "network_backbone:ResNetBackbone:num_units_10": 925, + "network_backbone:ResNetBackbone:num_units_11": 164, + "network_backbone:ResNetBackbone:num_units_12": 247, + "network_backbone:ResNetBackbone:num_units_13": 339, + "network_backbone:ResNetBackbone:num_units_14": 769, + "network_backbone:ResNetBackbone:num_units_2": 502, + "network_backbone:ResNetBackbone:num_units_3": 101, + "network_backbone:ResNetBackbone:num_units_4": 842, + "network_backbone:ResNetBackbone:num_units_5": 906, + "network_backbone:ResNetBackbone:num_units_6": 933, + "network_backbone:ResNetBackbone:num_units_7": 329, + "network_backbone:ResNetBackbone:num_units_8": 898, + "network_backbone:ResNetBackbone:num_units_9": 161, + "network_head:fully_connected:activation": "relu", + "network_head:fully_connected:units_layer_1": 327 + }, + "8": { + "data_loader:batch_size": 164, + "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "NoFeaturePreprocessor", + "imputer:categorical_strategy": "most_frequent", + "imputer:numerical_strategy": "mean", + "lr_scheduler:__choice__": "StepLR", + "network_backbone:__choice__": "ShapedMLPBackbone", + "network_embedding:__choice__": "NoEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "NoInit", + "optimizer:__choice__": "RMSpropOptimizer", + "scaler:__choice__": "MinMaxScaler", + "trainer:__choice__": "MixUpTrainer", + "lr_scheduler:StepLR:gamma": 0.2905213739360219, + "lr_scheduler:StepLR:step_size": 10, + "network_backbone:ShapedMLPBackbone:activation": "sigmoid", 
+ "network_backbone:ShapedMLPBackbone:max_units": 903, + "network_backbone:ShapedMLPBackbone:mlp_shape": "brick", + "network_backbone:ShapedMLPBackbone:num_groups": 10, + "network_backbone:ShapedMLPBackbone:output_dim": 943, + "network_backbone:ShapedMLPBackbone:use_dropout": false, + "network_head:fully_connected:num_layers": 3, + "network_init:NoInit:bias_strategy": "Zero", + "optimizer:RMSpropOptimizer:alpha": 0.25445785033325663, + "optimizer:RMSpropOptimizer:lr": 0.00012058949092384073, + "optimizer:RMSpropOptimizer:momentum": 0.6601732030357997, + "optimizer:RMSpropOptimizer:weight_decay": 0.030275825765581223, + "trainer:MixUpTrainer:alpha": 0.2222082093355312, + "trainer:MixUpTrainer:weighted_loss": true, + "network_head:fully_connected:activation": "sigmoid", + "network_head:fully_connected:units_layer_1": 110, + "network_head:fully_connected:units_layer_2": 70 + }, + "9": { + "data_loader:batch_size": 94, + "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "PolynomialFeatures", + "imputer:categorical_strategy": "most_frequent", + "imputer:numerical_strategy": "mean", + "lr_scheduler:__choice__": "CyclicLR", + "network_backbone:__choice__": "MLPBackbone", + "network_embedding:__choice__": "NoEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "KaimingInit", + "optimizer:__choice__": "RMSpropOptimizer", + "scaler:__choice__": "StandardScaler", + "trainer:__choice__": "StandardTrainer", + "feature_preprocessor:PolynomialFeatures:degree": 2, + "feature_preprocessor:PolynomialFeatures:include_bias": false, + "feature_preprocessor:PolynomialFeatures:interaction_only": false, + "lr_scheduler:CyclicLR:base_lr": 0.09510483864725039, + "lr_scheduler:CyclicLR:max_lr": 0.026215723559513626, + "lr_scheduler:CyclicLR:mode": "triangular", + "lr_scheduler:CyclicLR:step_size_up": 3829, + "network_backbone:MLPBackbone:activation": "sigmoid", + 
"network_backbone:MLPBackbone:num_groups": 7, + "network_backbone:MLPBackbone:num_units_1": 47, + "network_backbone:MLPBackbone:use_dropout": true, + "network_head:fully_connected:num_layers": 3, + "network_init:KaimingInit:bias_strategy": "Zero", + "optimizer:RMSpropOptimizer:alpha": 0.75085811094601, + "optimizer:RMSpropOptimizer:lr": 0.0002950013672615944, + "optimizer:RMSpropOptimizer:momentum": 0.515129966307681, + "optimizer:RMSpropOptimizer:weight_decay": 0.01979731884468683, + "trainer:StandardTrainer:weighted_loss": true, + "network_backbone:MLPBackbone:dropout_1": 0.5362963908147109, + "network_backbone:MLPBackbone:dropout_2": 0.09403575191589564, + "network_backbone:MLPBackbone:dropout_3": 0.5576340928985162, + "network_backbone:MLPBackbone:dropout_4": 0.3102921398336836, + "network_backbone:MLPBackbone:dropout_5": 0.36841155269138837, + "network_backbone:MLPBackbone:dropout_6": 0.459182557172949, + "network_backbone:MLPBackbone:dropout_7": 0.2741849570242409, + "network_backbone:MLPBackbone:num_units_2": 323, + "network_backbone:MLPBackbone:num_units_3": 424, + "network_backbone:MLPBackbone:num_units_4": 637, + "network_backbone:MLPBackbone:num_units_5": 668, + "network_backbone:MLPBackbone:num_units_6": 507, + "network_backbone:MLPBackbone:num_units_7": 972, + "network_head:fully_connected:activation": "sigmoid", + "network_head:fully_connected:units_layer_1": 482, + "network_head:fully_connected:units_layer_2": 425 + }, + "10": { + "data_loader:batch_size": 70, + "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "PowerTransformer", + "imputer:categorical_strategy": "most_frequent", + "imputer:numerical_strategy": "constant_zero", + "lr_scheduler:__choice__": "CyclicLR", + "network_backbone:__choice__": "ShapedMLPBackbone", + "network_embedding:__choice__": "LearnedEntityEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "SparseInit", + 
"optimizer:__choice__": "RMSpropOptimizer", + "scaler:__choice__": "NoScaler", + "trainer:__choice__": "MixUpTrainer", + "feature_preprocessor:PowerTransformer:standardize": false, + "lr_scheduler:CyclicLR:base_lr": 0.05611929109855669, + "lr_scheduler:CyclicLR:max_lr": 0.01831055731936772, + "lr_scheduler:CyclicLR:mode": "triangular2", + "lr_scheduler:CyclicLR:step_size_up": 3104, + "network_backbone:ShapedMLPBackbone:activation": "tanh", + "network_backbone:ShapedMLPBackbone:max_units": 759, + "network_backbone:ShapedMLPBackbone:mlp_shape": "brick", + "network_backbone:ShapedMLPBackbone:num_groups": 10, + "network_backbone:ShapedMLPBackbone:output_dim": 155, + "network_backbone:ShapedMLPBackbone:use_dropout": true, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_0": 0.14333490052026576, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_1": 0.7794060644969828, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_2": 0.28020395999441905, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_3": 0.2820327943739419, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_4": 0.7390548552027222, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_5": 0.025302711343403672, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_6": 0.5677825375428477, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_7": 0.7093786601691139, + "network_embedding:LearnedEntityEmbedding:min_unique_values_for_embedding": 6, + "network_head:fully_connected:num_layers": 4, + "network_init:SparseInit:bias_strategy": "Zero", + "optimizer:RMSpropOptimizer:alpha": 0.15659144532965727, + "optimizer:RMSpropOptimizer:lr": 0.015691888676781927, + "optimizer:RMSpropOptimizer:momentum": 0.30317416976729206, + "optimizer:RMSpropOptimizer:weight_decay": 0.010642526008626797, + "trainer:MixUpTrainer:alpha": 0.3709089665342886, + "trainer:MixUpTrainer:weighted_loss": false, + 
"network_backbone:ShapedMLPBackbone:max_dropout": 0.3789762581174825, + "network_head:fully_connected:activation": "tanh", + "network_head:fully_connected:units_layer_1": 499, + "network_head:fully_connected:units_layer_2": 465, + "network_head:fully_connected:units_layer_3": 238 + }, + "11": { + "data_loader:batch_size": 274, + "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "RandomKitchenSinks", + "imputer:categorical_strategy": "constant_!missing!", + "imputer:numerical_strategy": "mean", + "lr_scheduler:__choice__": "CyclicLR", + "network_backbone:__choice__": "ShapedMLPBackbone", + "network_embedding:__choice__": "NoEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "SparseInit", + "optimizer:__choice__": "AdamOptimizer", + "scaler:__choice__": "Normalizer", + "trainer:__choice__": "StandardTrainer", + "feature_preprocessor:RandomKitchenSinks:gamma": 0.00017836687895829377, + "feature_preprocessor:RandomKitchenSinks:n_components": 3, + "lr_scheduler:CyclicLR:base_lr": 0.061001847805883254, + "lr_scheduler:CyclicLR:max_lr": 0.037867703829357294, + "lr_scheduler:CyclicLR:mode": "triangular", + "lr_scheduler:CyclicLR:step_size_up": 3395, + "network_backbone:ShapedMLPBackbone:activation": "relu", + "network_backbone:ShapedMLPBackbone:max_units": 94, + "network_backbone:ShapedMLPBackbone:mlp_shape": "diamond", + "network_backbone:ShapedMLPBackbone:num_groups": 4, + "network_backbone:ShapedMLPBackbone:output_dim": 763, + "network_backbone:ShapedMLPBackbone:use_dropout": false, + "network_head:fully_connected:num_layers": 3, + "network_init:SparseInit:bias_strategy": "Zero", + "optimizer:AdamOptimizer:beta1": 0.9010241766841086, + "optimizer:AdamOptimizer:beta2": 0.9275862063741073, + "optimizer:AdamOptimizer:lr": 0.00048241454070108375, + "optimizer:AdamOptimizer:weight_decay": 0.058438892093437125, + "scaler:Normalizer:norm": "max", + 
"trainer:StandardTrainer:weighted_loss": true, + "network_head:fully_connected:activation": "relu", + "network_head:fully_connected:units_layer_1": 293, + "network_head:fully_connected:units_layer_2": 177 + }, + "12": { + "data_loader:batch_size": 191, + "encoder:__choice__": "NoEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "NoFeaturePreprocessor", + "imputer:categorical_strategy": "constant_!missing!", + "imputer:numerical_strategy": "median", + "lr_scheduler:__choice__": "CosineAnnealingWarmRestarts", + "network_backbone:__choice__": "ResNetBackbone", + "network_embedding:__choice__": "NoEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "XavierInit", + "optimizer:__choice__": "RMSpropOptimizer", + "scaler:__choice__": "Normalizer", + "trainer:__choice__": "StandardTrainer", + "lr_scheduler:CosineAnnealingWarmRestarts:T_0": 18, + "lr_scheduler:CosineAnnealingWarmRestarts:T_mult": 1.7405132785152093, + "network_backbone:ResNetBackbone:activation": "relu", + "network_backbone:ResNetBackbone:blocks_per_group_0": 1, + "network_backbone:ResNetBackbone:blocks_per_group_1": 4, + "network_backbone:ResNetBackbone:num_groups": 6, + "network_backbone:ResNetBackbone:num_units_0": 894, + "network_backbone:ResNetBackbone:num_units_1": 395, + "network_backbone:ResNetBackbone:use_dropout": true, + "network_backbone:ResNetBackbone:use_shake_drop": false, + "network_backbone:ResNetBackbone:use_shake_shake": false, + "network_head:fully_connected:num_layers": 4, + "network_init:XavierInit:bias_strategy": "Zero", + "optimizer:RMSpropOptimizer:alpha": 0.6521242194975473, + "optimizer:RMSpropOptimizer:lr": 4.097035283946373e-05, + "optimizer:RMSpropOptimizer:momentum": 0.1792833337110808, + "optimizer:RMSpropOptimizer:weight_decay": 0.006909623450893943, + "scaler:Normalizer:norm": "mean_abs", + "trainer:StandardTrainer:weighted_loss": false, + "network_backbone:ResNetBackbone:blocks_per_group_2": 2, + 
"network_backbone:ResNetBackbone:blocks_per_group_3": 3, + "network_backbone:ResNetBackbone:blocks_per_group_4": 2, + "network_backbone:ResNetBackbone:blocks_per_group_5": 4, + "network_backbone:ResNetBackbone:blocks_per_group_6": 1, + "network_backbone:ResNetBackbone:dropout_0": 0.6575114752945207, + "network_backbone:ResNetBackbone:dropout_1": 0.28916184819601504, + "network_backbone:ResNetBackbone:dropout_2": 0.09888388652277876, + "network_backbone:ResNetBackbone:dropout_3": 0.791809735686961, + "network_backbone:ResNetBackbone:dropout_4": 0.06432675017963892, + "network_backbone:ResNetBackbone:dropout_5": 0.3015819044494064, + "network_backbone:ResNetBackbone:dropout_6": 0.792332044450592, + "network_backbone:ResNetBackbone:num_units_2": 173, + "network_backbone:ResNetBackbone:num_units_3": 290, + "network_backbone:ResNetBackbone:num_units_4": 633, + "network_backbone:ResNetBackbone:num_units_5": 16, + "network_backbone:ResNetBackbone:num_units_6": 542, + "network_head:fully_connected:activation": "relu", + "network_head:fully_connected:units_layer_1": 429, + "network_head:fully_connected:units_layer_2": 342, + "network_head:fully_connected:units_layer_3": 322 + }, + "13": { + "data_loader:batch_size": 35, + "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "PowerTransformer", + "imputer:categorical_strategy": "most_frequent", + "imputer:numerical_strategy": "most_frequent", + "lr_scheduler:__choice__": "ExponentialLR", + "network_backbone:__choice__": "ShapedMLPBackbone", + "network_embedding:__choice__": "NoEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "XavierInit", + "optimizer:__choice__": "AdamWOptimizer", + "scaler:__choice__": "Normalizer", + "trainer:__choice__": "MixUpTrainer", + "feature_preprocessor:PowerTransformer:standardize": false, + "lr_scheduler:ExponentialLR:gamma": 0.863670772292724, + 
"network_backbone:ShapedMLPBackbone:activation": "sigmoid", + "network_backbone:ShapedMLPBackbone:max_units": 957, + "network_backbone:ShapedMLPBackbone:mlp_shape": "long_funnel", + "network_backbone:ShapedMLPBackbone:num_groups": 7, + "network_backbone:ShapedMLPBackbone:output_dim": 16, + "network_backbone:ShapedMLPBackbone:use_dropout": true, + "network_head:fully_connected:num_layers": 3, + "network_init:XavierInit:bias_strategy": "Normal", + "optimizer:AdamWOptimizer:beta1": 0.9298951109018316, + "optimizer:AdamWOptimizer:beta2": 0.9367719861032991, + "optimizer:AdamWOptimizer:lr": 2.3043911799203502e-05, + "optimizer:AdamWOptimizer:weight_decay": 0.08948752020001628, + "scaler:Normalizer:norm": "max", + "trainer:MixUpTrainer:alpha": 0.1848582510096881, + "trainer:MixUpTrainer:weighted_loss": true, + "network_backbone:ShapedMLPBackbone:max_dropout": 0.4933977554884884, + "network_head:fully_connected:activation": "relu", + "network_head:fully_connected:units_layer_1": 105, + "network_head:fully_connected:units_layer_2": 185 + }, + "14": { + "data_loader:batch_size": 154, + "encoder:__choice__": "OneHotEncoder", + "coalescer:__choice__": "NoCoalescer", + "feature_preprocessor:__choice__": "KernelPCA", + "imputer:categorical_strategy": "most_frequent", + "imputer:numerical_strategy": "mean", + "lr_scheduler:__choice__": "StepLR", + "network_backbone:__choice__": "ResNetBackbone", + "network_embedding:__choice__": "LearnedEntityEmbedding", + "network_head:__choice__": "fully_connected", + "network_init:__choice__": "NoInit", + "optimizer:__choice__": "AdamWOptimizer", + "scaler:__choice__": "NoScaler", + "trainer:__choice__": "StandardTrainer", + "feature_preprocessor:KernelPCA:kernel": "sigmoid", + "feature_preprocessor:KernelPCA:n_components": 3, + "lr_scheduler:StepLR:gamma": 0.5658285105415104, + "lr_scheduler:StepLR:step_size": 10, + "network_backbone:ResNetBackbone:activation": "tanh", + "network_backbone:ResNetBackbone:blocks_per_group_0": 2, + 
"network_backbone:ResNetBackbone:blocks_per_group_1": 2, + "network_backbone:ResNetBackbone:num_groups": 1, + "network_backbone:ResNetBackbone:num_units_0": 623, + "network_backbone:ResNetBackbone:num_units_1": 42, + "network_backbone:ResNetBackbone:use_dropout": false, + "network_backbone:ResNetBackbone:use_shake_drop": true, + "network_backbone:ResNetBackbone:use_shake_shake": true, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_0": 0.7061800992159439, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_1": 0.40404533505032336, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_2": 0.14124612419045746, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_3": 0.24304972767199295, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_4": 0.8403938666630251, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_5": 0.11081539209354929, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_6": 0.5150164644256714, + "network_embedding:LearnedEntityEmbedding:dimension_reduction_7": 0.6185258490472787, + "network_embedding:LearnedEntityEmbedding:min_unique_values_for_embedding": 3, + "network_head:fully_connected:num_layers": 2, + "network_init:NoInit:bias_strategy": "Zero", + "optimizer:AdamWOptimizer:beta1": 0.9639206805787317, + "optimizer:AdamWOptimizer:beta2": 0.9439342949959634, + "optimizer:AdamWOptimizer:lr": 0.05110804312778185, + "optimizer:AdamWOptimizer:weight_decay": 0.026136253949706992, + "trainer:StandardTrainer:weighted_loss": true, + "feature_preprocessor:KernelPCA:coef0": 0.27733876378393374, + "network_backbone:ResNetBackbone:max_shake_drop_probability": 0.4280891218905112, + "network_head:fully_connected:activation": "relu", + "network_head:fully_connected:units_layer_1": 506 + } + } + } \ No newline at end of file diff --git a/test/test_utils/test_parallel_model_runner.py b/test/test_utils/test_parallel_model_runner.py new file mode 100644 index 000000000..a0a163f6e 
@unittest.mock.patch('autoPyTorch.evaluation.tae.eval_train_function',
                     new=dummy_eval_train_function)
def test_run_models_on_dataset(backend):
    """
    Run sampled pipeline configurations plus one traditional model ('lgb')
    through ``run_models_on_dataset`` and check the returned run history.

    ``autoPyTorch.evaluation.tae.eval_train_function`` is replaced by
    ``dummy_eval_train_function`` while the test runs. The test passes when
    at least one run ended with ``StatusType.SUCCESS`` and the configuration
    recorded for a successful run equals one of the configurations passed in.
    """
    datamanager = get_binary_classification_datamanager()
    backend.save_datamanager(datamanager)

    # Derive the search space from the properties of the stored dataset
    requirements = get_dataset_requirements(info=datamanager.get_required_dataset_info())
    dataset_properties = datamanager.get_dataset_properties(requirements)
    search_space = get_configuration_space(info=dataset_properties)

    # Sampled configurations first, then a traditional model; every entry is
    # paired with the value 1 (presumably a budget -- TODO confirm against
    # run_models_on_dataset)
    num_random_configs = 5
    model_configurations = [
        (search_space.sample_configuration(), 1)
        for _ in range(num_random_configs)
    ]
    model_configurations += [('lgb', 1)]

    available_metrics = get_metrics(dataset_properties=dataset_properties,
                                    names=["accuracy"],
                                    all_supported_metrics=False)
    metric = available_metrics.pop()

    runhistory = run_models_on_dataset(
        time_left=15,
        func_eval_time_limit_secs=5,
        model_configs=model_configurations,
        logger=unittest.mock.Mock(spec=PicklableClientLogger),
        metric=metric,
        dask_client=SingleThreadedClient(),
        backend=backend,
        seed=1,
        multiprocessing_context="fork",
        current_search_space=search_space,
    )

    # Plain (dict / str) form of every configuration that was requested
    requested = [
        entry.get_dictionary() if isinstance(entry, Configuration) else entry
        for entry, _ in model_configurations
    ]
    # Configurations recorded in the run history for the successful runs
    recorded_on_success = [
        run_value.additional_info['configuration']
        for run_value in runhistory.data.values()
        if run_value.status == StatusType.SUCCESS
    ]

    has_successful_model = len(recorded_on_success) > 0
    has_matching_config = any(
        wanted == recorded
        for recorded in recorded_on_success
        for wanted in requested
    )

    assert has_successful_model, "Atleast 1 model should be successfully trained"
    assert has_matching_config, "Configurations should match with the passed model configurations"
def test_search_results_sprint_statistics_no_test():
    """
    Check ``get_search_results`` and ``sprint_statistics`` against the run
    history stored in ``runhistory_no_test.json``.

    Before a run history is attached, ``get_search_results``,
    ``sprint_statistics`` and ``get_incumbent_results`` must each raise
    ``RuntimeError``. Afterwards the search results are validated with the
    ``_check_*`` helpers and the statistics text is compared line by line
    with the expected message.
    """
    # Clearing the abstract-method set lets the abstract BaseTask be
    # instantiated directly. NOTE(review): this mutates the class itself and
    # is not undone when the test ends.
    BaseTask.__abstractmethods__ = set()
    api = BaseTask()
    for method in ['get_search_results', 'sprint_statistics', 'get_incumbent_results']:
        with pytest.raises(RuntimeError):
            getattr(api, method)()

    # Use a context manager so the fixture file is closed deterministically
    run_history_path = os.path.join(os.path.dirname(__file__), 'runhistory_no_test.json')
    with open(run_history_path, mode='r') as run_history_file:
        run_history_data = json.load(run_history_file)['data']
    api._results_manager.run_history = MagicMock()
    api.run_history.empty = MagicMock(return_value=False)

    # The run_history has 16 runs + 1 run interruption ==> 16 runs
    api.run_history.data = make_dict_run_history_data(run_history_data)
    api._metric = accuracy
    api.dataset_name = 'iris'
    api._scoring_functions = [accuracy, balanced_accuracy]
    api.search_space = MagicMock(spec=ConfigurationSpace)
    worst_val = api._metric._worst_possible_result
    search_results = api.get_search_results()

    _check_status(search_results.status_types)
    _check_costs(search_results.opt_scores)
    _check_end_times(search_results.end_times)
    _check_fit_times(search_results.fit_times)
    _check_budgets(search_results.budgets)
    _check_metric_dict(search_results.opt_metric_dict, search_results.status_types, worst_val)
    _check_additional_infos(status_types=search_results.status_types,
                            additional_infos=search_results.additional_infos)

    # config_ids can duplicate because of various budget size
    config_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 10, 11, 12, 10, 13]
    assert config_ids == search_results.config_ids

    # assert that contents of search_results are of expected types
    assert isinstance(search_results.rank_opt_scores, np.ndarray)
    # `np.int` was only an alias of the builtin `int` and no longer exists in
    # NumPy >= 1.24, so the builtin is used to name the same dtype.
    assert search_results.rank_opt_scores.dtype == np.dtype(int)
    assert isinstance(search_results.configs, list)

    n_success, n_timeout, n_memoryout, n_crashed = 13, 2, 0, 1
    msg = ["autoPyTorch results:", f"\tDataset name: {api.dataset_name}",
           f"\tOptimisation Metric: {api._metric.name}",
           f"\tBest validation score: {max(search_results.opt_scores)}",
           "\tNumber of target algorithm runs: 16", f"\tNumber of successful target algorithm runs: {n_success}",
           f"\tNumber of crashed target algorithm runs: {n_crashed}",
           f"\tNumber of target algorithms that exceeded the time limit: {n_timeout}",
           f"\tNumber of target algorithms that exceeded the memory limit: {n_memoryout}"]

    statistics = api.sprint_statistics()
    assert isinstance(statistics, str)
    # zip() stops at the shorter sequence, so only the first len(msg) lines
    # of the statistics text are compared.
    assert all(m1 == m2 for m1, m2 in zip(statistics.split("\n"), msg))
@pytest.mark.parametrize('metric_name', ('balanced_accuracy', 'accuracy'))
def test_plot_perf_over_time_no_test(metric_name):  # TODO
    """
    Check ``plot_perf_over_time`` when the ensemble history entry has
    ``'test_accuracy': None`` and the run history is loaded from
    ``runhistory_no_test.json``.

    The plot is drawn once without and once with an explicit ``ax``; the
    legend of the explicit axes must contain exactly the 'single train',
    'single opt' and -- only when ``metric_name`` is the optimisation
    metric -- 'ensemble train' entries.
    """
    dummy_history = [{'Timestamp': datetime(2022, 1, 1), 'train_accuracy': 1, 'test_accuracy': None}]
    # Clearing the abstract-method set lets the abstract BaseTask be
    # instantiated directly. NOTE(review): this mutates the class itself and
    # is not undone when the test ends.
    BaseTask.__abstractmethods__ = set()
    api = BaseTask()

    # Use a context manager so the fixture file is closed deterministically
    run_history_path = os.path.join(os.path.dirname(__file__), 'runhistory_no_test.json')
    with open(run_history_path, mode='r') as run_history_file:
        run_history_data = json.load(run_history_file)['data']
    api._results_manager.run_history = MagicMock()
    api.run_history.empty = MagicMock(return_value=False)

    # The run_history has 16 runs + 1 run interruption ==> 16 runs
    api.run_history.data = make_dict_run_history_data(run_history_data)
    api._results_manager.ensemble_performance_history = dummy_history
    api._metric = accuracy
    api.dataset_name = 'iris'
    api._scoring_functions = [accuracy, balanced_accuracy]
    api.search_space = MagicMock(spec=ConfigurationSpace)

    api.plot_perf_over_time(metric_name=metric_name)
    _, ax = plt.subplots(nrows=1, ncols=1)
    api.plot_perf_over_time(metric_name=metric_name, ax=ax)

    # remove ensemble keys if metric name is not for the opt score
    candidate_labels = (
        f'single train {metric_name}',
        f'single opt {metric_name}',
        f'ensemble train {metric_name}',
    )
    ans = {
        name
        for name in candidate_labels
        if metric_name == api._metric.name or not name.startswith('ensemble')
    }
    # Read the legend labels through the public accessor rather than the
    # private `_text` attribute.
    legend_set = {txt.get_text() for txt in ax.get_legend().texts}
    assert ans == legend_set
    plt.close()