Skip to content

Commit

Permalink
Fix reference to InputValidator
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman committed Jul 16, 2022
1 parent 80fc8f2 commit 857a8c6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 11 deletions.
11 changes: 5 additions & 6 deletions autosklearn/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,15 +1741,15 @@ def score(self, X: SUPPORTED_FEAT_TYPES, y: SUPPORTED_TARGET_TYPES) -> float:
check_is_fitted(self)

prediction = self.predict(X)
y = self.InputValidator.target_validator.transform(y)
y = self.input_validator.target_validator.transform(y)

# Encode the prediction using the input validator
# We train autosklearn with a encoded version of y,
# which is decoded by predict().
# Above call to validate() encodes the y given for score()
# Below call encodes the prediction, so we compare in the
# same representation domain
prediction = self.InputValidator.target_validator.transform(prediction)
prediction = self.input_validator.target_validator.transform(prediction)

return compute_single_metric(
solution=y,
Expand Down Expand Up @@ -2267,16 +2267,15 @@ def predict(
n_jobs: int = 1,
) -> np.ndarray:
check_is_fitted(self)
assert self.InputValidator is not None

probabilities = self.predict_proba(X, batch_size=batch_size, n_jobs=n_jobs)
validator = self.input_validator

if self.InputValidator.target_validator.is_single_column_target():
if validator.target_validator.is_single_column_target():
predicted_indexes = np.argmax(probabilities, axis=1)
else:
predicted_indexes = (probabilities > 0.5).astype(int)

return self.InputValidator.target_validator.inverse_transform(predicted_indexes)
return validator.target_validator.inverse_transform(predicted_indexes)

def predict_proba(
self,
Expand Down
6 changes: 1 addition & 5 deletions autosklearn/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from scipy.sparse import spmatrix
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.ensemble import VotingClassifier, VotingRegressor
from sklearn.exceptions import NotFittedError
from sklearn.model_selection._split import (
BaseCrossValidator,
BaseShuffleSplit,
Expand Down Expand Up @@ -1522,10 +1521,7 @@ def classes_(self) -> np.ndarray:
np.ndarray
Class labels seen during fit
"""
if self.automl.InputValidator is None:
raise NotFittedError("Please call fit first")

return self.automl.InputValidator.target_validator.classes_
return self.automl.input_validator.classes_

def predict_proba(
self,
Expand Down

0 comments on commit 857a8c6

Please sign in to comment.