Commit a921117

Last few mypy errors knocked out
eddiebergman committed Jul 15, 2022
1 parent 3dd7753 commit a921117
Showing 3 changed files with 258 additions and 407 deletions.
48 changes: 12 additions & 36 deletions autosklearn/automl.py
@@ -98,13 +98,7 @@
 from autosklearn.smbo import AutoMLSMBO
 from autosklearn.util import RE_PATTERN, pipeline
 from autosklearn.util.dask import Dask, LocalDask, UserDask
-from autosklearn.util.data import (
-    DatasetCompressionSpec,
-    default_dataset_compression_arg,
-    reduce_dataset_size_if_too_large,
-    supported_precision_reductions,
-    validate_dataset_compression_arg,
-)
+from autosklearn.util.data import DatasetCompression
 from autosklearn.util.logging_ import (
     PicklableClientLogger,
     get_named_client_logger,
@@ -252,15 +246,14 @@ def __init__(
         )

         # Validate dataset_compression and set its values
-        self._dataset_compression: DatasetCompressionSpec | None = None
-        if isinstance(dataset_compression, bool):
-            if dataset_compression is True:
-                self._dataset_compression = default_dataset_compression_arg
-        else:
-            self._dataset_compression = validate_dataset_compression_arg(
-                dataset_compression,
-                memory_limit=memory_limit,
-            )
+        self._dataset_compression: DatasetCompression | None = None
+        if dataset_compression is not False:
+
+            if memory_limit is None:
+                raise ValueError("Must provide a `memory_limit` for data compression")
+
+            spec = {} if dataset_compression is True else dataset_compression
+            self._dataset_compression = DatasetCompression(**spec, limit=memory_limit)

         # If we got something callable for `get_trials_callback`, wrap it so SMAC
         # will accept it.
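In effect the dict validation moves into the class itself: False disables compression, True selects the defaults, and a dict is forwarded as keyword arguments. Below is a self-contained sketch of that dispatch using a hypothetical stand-in class; the real `DatasetCompression` lives in `autosklearn/util/data.py` and is not shown in this diff.

    from typing import Any

    class DatasetCompression:  # stand-in only; interface sketched after the diff
        def __init__(self, *, limit: int, **spec: Any) -> None:
            self.limit, self.spec = limit, spec

    def resolve_compression(dataset_compression, memory_limit):
        # Mirrors the constructor branch above:
        #   False -> None, True -> all defaults, dict -> forwarded as kwargs
        if dataset_compression is False:
            return None
        if memory_limit is None:
            raise ValueError("Must provide a `memory_limit` for data compression")
        spec = {} if dataset_compression is True else dataset_compression
        return DatasetCompression(**spec, limit=memory_limit)

    assert resolve_compression(False, None) is None
    assert resolve_compression(True, 3072).limit == 3072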
@@ -667,30 +660,13 @@ def fit(
         X_test, y_test = input_validator.transform(X_test, y_test)

         # We don't support size reduction on pandas type object yet
-        if (
-            self._dataset_compression is not None
-            and not isinstance(X, pd.DataFrame)
-            and not (isinstance(y, pd.Series) or isinstance(y, pd.DataFrame))
-        ):
-            methods = self._dataset_compression["methods"]
-            memory_allocation = self._dataset_compression["memory_allocation"]
-
-            # Remove precision reduction if we can't perform it
-            if (
-                "precision" in methods
-                and X.dtype not in supported_precision_reductions
-            ):
-                methods = [method for method in methods if method != "precision"]
-
+        if self._dataset_compression and self._dataset_compression.supports(X, y):
             with warnings_to(self.logger):
-                X, y = reduce_dataset_size_if_too_large(
+                X, y = self._dataset_compression.compress(
                     X=X,
                     y=y,
-                    memory_limit=self._memory_limit,
-                    is_classification=self.is_classification,
+                    stratify=self.is_classification,
                     random_state=self._seed,
-                    operations=methods,
-                    memory_allocation=memory_allocation,
                 )

         # Check the re-sampling strategy
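Piecing the two call sites together, `DatasetCompression` now bundles what the removed free functions did: the pandas check from the old inline condition becomes `supports(X, y)`, and `reduce_dataset_size_if_too_large` becomes `compress(...)`. Here is a minimal interface sketch inferred from this diff alone; the field names come from the old dict-based spec, the defaults are assumptions, and the real implementation in `autosklearn/util/data.py` is not shown here.

    from __future__ import annotations

    from dataclasses import dataclass, field

    import pandas as pd

    @dataclass
    class DatasetCompression:
        """Interface implied by the call sites above; not the real class."""

        limit: int  # memory_limit, forwarded as `limit=` in __init__ above
        memory_allocation: float = 0.1  # assumed default share of `limit`
        methods: list[str] = field(
            default_factory=lambda: ["precision", "subsample"]  # keys from old spec
        )

        def supports(self, X, y) -> bool:
            # The removed inline check skipped pandas inputs; presumably
            # that test now lives here.
            return not isinstance(X, pd.DataFrame) and not isinstance(
                y, (pd.Series, pd.DataFrame)
            )

        def compress(self, *, X, y, stratify: bool, random_state):
            # Would apply `methods` (precision reduction, stratified
            # subsampling) until X, y fit within `memory_allocation * limit`;
            # body elided in this sketch.
            raise NotImplementedError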
