Fix: Inaccurate Attribute Listing with dir(obj) #28749

Open: wants to merge 4 commits into base: main
12 changes: 10 additions & 2 deletions sklearn/base.py
@@ -193,6 +193,13 @@ class BaseEstimator(_HTMLDocumentationLinkMixin, _MetadataRequester):
array([3, 3, 3])
"""

def __dir__(self):
"""Filters conditional methods that should be hidden based
on the `available_if` decorator from SciKit Learn."""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
return [attr for attr in super().__dir__() if hasattr(self, attr)]

@classmethod
def _get_param_names(cls):
"""Get parameter names for the estimator"""
@@ -1353,8 +1360,9 @@ class _UnstableArchMixin:

def _more_tags(self):
return {
"non_deterministic": _IS_32BIT
or platform.machine().startswith(("ppc", "powerpc"))
"non_deterministic": _IS_32BIT or platform.machine().startswith(
("ppc", "powerpc")
)
}


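To illustrate the intended effect of the new `__dir__` (a minimal usage sketch, not part of this diff, mirroring the assertions added to the tests below): `SGDClassifier` only exposes `predict_proba` when `loss` is "log_loss" or "modified_huber", so with the hinge loss the `available_if` guard fails and the method should now be absent from `dir()` as well as failing `hasattr`.

# Minimal sketch, assuming a scikit-learn build with this patch applied.
from sklearn.linear_model import SGDClassifier

clf = SGDClassifier(loss="hinge")
assert not hasattr(clf, "predict_proba")  # already the behaviour before the patch
assert "predict_proba" not in dir(clf)    # new: dir() now agrees with hasattr

The `hasattr` probe in `__dir__` evaluates every `available_if` condition, and probing deprecated attributes can emit `FutureWarning`s, which is presumably why the method wraps the check in a `warnings.catch_warnings()` block.
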
1 change: 1 addition & 0 deletions sklearn/ensemble/tests/test_bagging.py
@@ -448,6 +448,7 @@ def test_error():
X, y = iris.data, iris.target
base = DecisionTreeClassifier()
assert not hasattr(BaggingClassifier(base).fit(X, y), "decision_function")
assert "decision_function" not in dir(BaggingClassifier(base).fit(X, y))


def test_parallel_classification():
3 changes: 3 additions & 0 deletions sklearn/ensemble/tests/test_stacking.py
@@ -882,8 +882,11 @@ def test_stacking_final_estimator_attribute_error():
estimators=estimators, final_estimator=final_estimator, cv=3
)

assert "decision_function" not in dir(clf.fit(X, y))

outer_msg = "This 'StackingClassifier' has no attribute 'decision_function'"
inner_msg = "'RandomForestClassifier' object has no attribute 'decision_function'"

with pytest.raises(AttributeError, match=outer_msg) as exec_info:
clf.fit(X, y).decision_function(X)
assert isinstance(exec_info.value.__cause__, AttributeError)
2 changes: 2 additions & 0 deletions sklearn/ensemble/tests/test_voting.py
@@ -79,8 +79,10 @@ def test_predictproba_hardvoting():
assert inner_msg in str(exec_info.value.__cause__)

assert not hasattr(eclf, "predict_proba")
assert "predict_proba" not in dir(eclf)
eclf.fit(X_scaled, y)
assert not hasattr(eclf, "predict_proba")
assert "predict_proba" not in dir(eclf)


def test_notfitted():
1 change: 1 addition & 0 deletions sklearn/feature_selection/tests/test_rfe.py
@@ -645,6 +645,7 @@ def test_rfe_estimator_attribute_error():

outer_msg = "This 'RFE' has no attribute 'decision_function'"
inner_msg = "'LinearRegression' object has no attribute 'decision_function'"
assert "decision_function" not in dir(rfe.fit(iris.data, iris.target))
with pytest.raises(AttributeError, match=outer_msg) as exec_info:
rfe.fit(iris.data, iris.target).decision_function(iris.data)
assert isinstance(exec_info.value.__cause__, AttributeError)
3 changes: 2 additions & 1 deletion sklearn/linear_model/_stochastic_gradient.py
@@ -1358,7 +1358,8 @@ def predict_proba(self, X):
raise NotImplementedError(
"predict_(log_)proba only supported when"
" loss='log_loss' or loss='modified_huber' "
"(%r given)" % self.loss
"(%r given)"
% self.loss
)

@available_if(_check_proba)
4 changes: 4 additions & 0 deletions sklearn/linear_model/tests/test_sgd.py
@@ -728,7 +728,9 @@ def test_sgd_predict_proba_method_access(klass):
loss
)
assert not hasattr(clf, "predict_proba")
assert "predict_proba" not in dir(clf)
assert not hasattr(clf, "predict_log_proba")
assert "predict_log_proba" not in dir(clf)
with pytest.raises(
AttributeError, match="has no attribute 'predict_proba'"
) as exec_info:
@@ -754,7 +756,9 @@ def test_sgd_proba(klass):
# anyway.
clf = SGDClassifier(loss="hinge", alpha=0.01, max_iter=10, tol=None).fit(X, Y)
assert not hasattr(clf, "predict_proba")
assert "predict_proba" not in dir(clf)
assert not hasattr(clf, "predict_log_proba")
assert "predict_log_proba" not in dir(clf)

# log and modified_huber losses can output probability estimates
# binary case
3 changes: 2 additions & 1 deletion sklearn/model_selection/_search.py
@@ -486,7 +486,8 @@ def score(self, X, y=None, **params):
if self.scorer_ is None:
raise ValueError(
"No score function explicitly defined, "
"and the estimator doesn't provide one %s" % self.best_estimator_
"and the estimator doesn't provide one %s"
% self.best_estimator_
)
if isinstance(self.scorer_, dict):
if self.multimetric_:
4 changes: 4 additions & 0 deletions sklearn/model_selection/tests/test_search.py
@@ -1557,6 +1557,7 @@ def test_predict_proba_disabled():
clf = SVC(probability=False)
gs = GridSearchCV(clf, {}, cv=2).fit(X, y)
assert not hasattr(gs, "predict_proba")
assert "predict_proba" not in dir(gs)


def test_grid_search_allows_nans():
@@ -1771,6 +1772,7 @@ def test_stochastic_gradient_loss_param():
# When the estimator is not fitted, `predict_proba` is not available as the
# loss is 'hinge'.
assert not hasattr(clf, "predict_proba")
assert "predict_proba" not in dir(clf)
clf.fit(X, y)
clf.predict_proba(X)
clf.predict_log_proba(X)
@@ -1784,8 +1786,10 @@ def test_stochastic_gradient_loss_param():
estimator=SGDClassifier(loss="hinge"), param_grid=param_grid, cv=3
)
assert not hasattr(clf, "predict_proba")
assert "predict_proba" not in dir(clf)
clf.fit(X, y)
assert not hasattr(clf, "predict_proba")
assert "predict_proba" not in dir(clf)


def test_search_train_scores_set_to_false():
4 changes: 4 additions & 0 deletions sklearn/neighbors/tests/test_lof.py
@@ -213,14 +213,18 @@ def test_hasattr_prediction():
assert hasattr(clf, "decision_function")
assert hasattr(clf, "score_samples")
assert not hasattr(clf, "fit_predict")
assert "fit_predict" not in dir(clf)

# when novelty=False
clf = neighbors.LocalOutlierFactor(novelty=False)
clf.fit(X)
assert hasattr(clf, "fit_predict")
assert not hasattr(clf, "predict")
assert "predict" not in dir(clf)
assert not hasattr(clf, "decision_function")
assert "decision_function" not in dir(clf)
assert not hasattr(clf, "score_samples")
assert "score_samples" not in dir(clf)


@parametrize_with_checks([neighbors.LocalOutlierFactor(novelty=True)])
3 changes: 2 additions & 1 deletion sklearn/neural_network/_multilayer_perceptron.py
@@ -754,7 +754,8 @@ def _check_solver(self):
if self.solver not in _STOCHASTIC_SOLVERS:
raise AttributeError(
"partial_fit is only available for stochastic"
" optimizers. %s is not stochastic." % self.solver
" optimizers. %s is not stochastic."
% self.solver
)
return True

4 changes: 3 additions & 1 deletion sklearn/neural_network/tests/test_mlp.py
@@ -500,6 +500,7 @@ def test_partial_fit_errors():

# lbfgs doesn't support partial_fit
assert not hasattr(MLPClassifier(solver="lbfgs"), "partial_fit")
assert "parital_fit" not in dir(MLPClassifier(solver="lbfgs"))


def test_nonfinite_params():
@@ -732,7 +733,8 @@ def test_warm_start():
message = (
"warm_start can only be used where `y` has the same "
"classes as in the previous call to fit."
" Previously got [0 1 2], `y` has %s" % np.unique(y_i)
" Previously got [0 1 2], `y` has %s"
% np.unique(y_i)
)
with pytest.raises(ValueError, match=re.escape(message)):
clf.fit(X, y_i)
2 changes: 1 addition & 1 deletion sklearn/preprocessing/tests/test_function_transformer.py
@@ -352,7 +352,7 @@ def test_function_transformer_feature_names_out_is_None():
transformer = FunctionTransformer()
X = np.random.rand(100, 2)
transformer.fit_transform(X)

assert "get_feature_names_out" not in dir(transformer)
msg = "This 'FunctionTransformer' has no attribute 'get_feature_names_out'"
with pytest.raises(AttributeError, match=msg):
transformer.get_feature_names_out()
5 changes: 5 additions & 0 deletions sklearn/semi_supervised/tests/test_self_training.py
@@ -311,6 +311,7 @@ def test_base_estimator_meta_estimator():
)

assert not hasattr(base_estimator, "predict_proba")
assert "predict_proba" not in dir(base_estimator)
clf = SelfTrainingClassifier(base_estimator=base_estimator)
with pytest.raises(AttributeError):
clf.fit(X_train, y_train_missing_labels)
@@ -337,6 +338,10 @@ def test_self_training_estimator_attribute_error():
# should raise an AttributeError
self_training = SelfTrainingClassifier(base_estimator=DecisionTreeClassifier())

assert "decision_function" not in dir(
self_training.fit(X_train, y_train_missing_labels)
)

outer_msg = "This 'SelfTrainingClassifier' has no attribute 'decision_function'"
inner_msg = "'DecisionTreeClassifier' object has no attribute 'decision_function'"
with pytest.raises(AttributeError, match=outer_msg) as exec_info:
6 changes: 4 additions & 2 deletions sklearn/svm/_base.py
@@ -297,7 +297,8 @@ def _warn_from_fit_status(self):
warnings.warn(
"Solver terminated early (max_iter=%i)."
" Consider pre-processing your data with"
" StandardScaler or MinMaxScaler." % self.max_iter,
" StandardScaler or MinMaxScaler."
% self.max_iter,
ConvergenceWarning,
)

@@ -1173,7 +1174,8 @@ def _fit_liblinear(
raise ValueError(
"This solver needs samples of at least 2 classes"
" in the data, but the data contains only one"
" class: %r" % classes_[0]
" class: %r"
% classes_[0]
)

class_weight_ = compute_class_weight(class_weight, classes=classes_, y=y)
2 changes: 2 additions & 0 deletions sklearn/svm/tests/test_svm.py
@@ -1119,8 +1119,10 @@ def test_hasattr_predict_proba():

G = svm.SVC(probability=False)
assert not hasattr(G, "predict_proba")
assert "predict_proba" not in dir(G)
G.fit(iris.data, iris.target)
assert not hasattr(G, "predict_proba")
assert "predict_proba" not in dir(G)

# Switching to `probability=True` after fitting should make
# predict_proba available, but calling it must not work:
9 changes: 9 additions & 0 deletions sklearn/tests/test_multiclass.py
@@ -133,7 +133,10 @@ def test_ovr_partial_fit():

# test partial_fit only exists if estimator has it:
ovr = OneVsRestClassifier(SVC())
# check __dir__ does not return partial_fit
assert not hasattr(ovr, "partial_fit")
assert "partial_fit" not in dir(ovr)


def test_ovr_partial_fit_exceptions():
@@ -385,12 +388,15 @@ def test_ovr_multilabel_predict_proba():
# Decision function only estimator.
decision_only = OneVsRestClassifier(svm.SVR()).fit(X_train, Y_train)
assert not hasattr(decision_only, "predict_proba")
assert "predict_proba" not in dir(decision_only)

# Estimator with predict_proba disabled, depending on parameters.
decision_only = OneVsRestClassifier(svm.SVC(probability=False))
assert not hasattr(decision_only, "predict_proba")
assert "predict_proba" not in dir(decision_only)
decision_only.fit(X_train, Y_train)
assert not hasattr(decision_only, "predict_proba")
assert "predict_proba" not in dir(decision_only)
assert hasattr(decision_only, "decision_function")

# Estimator which can get predict_proba enabled after fitting
@@ -399,6 +405,7 @@ def test_ovr_multilabel_predict_proba():
)
proba_after_fit = OneVsRestClassifier(gs)
assert not hasattr(proba_after_fit, "predict_proba")
assert "predict_proba" not in dir(proba_after_fit)
proba_after_fit.fit(X_train, Y_train)
assert hasattr(proba_after_fit, "predict_proba")

@@ -421,6 +428,7 @@ def test_ovr_single_label_predict_proba():
# Decision function only estimator.
decision_only = OneVsRestClassifier(svm.SVR()).fit(X_train, Y_train)
assert not hasattr(decision_only, "predict_proba")
assert "predict_proba" not in dir(decision_only)

Y_pred = clf.predict(X_test)
Y_proba = clf.predict_proba(X_test)
@@ -560,6 +568,7 @@ def test_ovo_partial_fit_predict():
# test partial_fit only exists if estimator has it:
ovr = OneVsOneClassifier(SVC())
assert not hasattr(ovr, "partial_fit")
assert "partial_fit" not in dir(ovr)


def test_ovo_decision_function():
9 changes: 8 additions & 1 deletion sklearn/tests/test_multioutput.py
@@ -97,6 +97,7 @@ def test_multi_target_regression_partial_fit():
y_pred = sgr.predict(X_test)
assert_almost_equal(references, y_pred)
assert not hasattr(MultiOutputRegressor(Lasso), "partial_fit")
assert "partial_fit" not in dir(MultiOutputRegressor(Lasso))


def test_multi_target_regression_one_target():
@@ -215,7 +216,9 @@ def test_hasattr_multi_output_predict_proba():
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5)
multi_target_linear = MultiOutputClassifier(sgd_linear_clf)
multi_target_linear.fit(X, y)

assert not hasattr(multi_target_linear, "predict_proba")
assert "predict_proba" not in dir(multi_target_linear)

# case where predict_proba attribute exists
sgd_linear_clf = SGDClassifier(loss="log_loss", random_state=1, max_iter=5)
@@ -478,8 +481,10 @@ def test_multi_output_delegate_predict_proba():
assert hasattr(moc, "predict_proba")

# A base estimator without `predict_proba` should raise an AttributeError
moc = MultiOutputClassifier(LinearSVC())
moc = MultiOutputClassifier(LinearSVC(dual="auto"))

assert not hasattr(moc, "predict_proba")
assert "predict_proba" not in dir(moc)

outer_msg = "'MultiOutputClassifier' has no attribute 'predict_proba'"
inner_msg = "'LinearSVC' object has no attribute 'predict_proba'"
@@ -490,6 +495,7 @@ def test_multi_output_delegate_predict_proba():

moc.fit(X, y)
assert not hasattr(moc, "predict_proba")
assert "predict_proba" not in dir(moc)
with pytest.raises(AttributeError, match=outer_msg) as exec_info:
moc.predict_proba(X)
assert isinstance(exec_info.value.__cause__, AttributeError)
@@ -525,6 +531,7 @@ def test_classifier_chain_fit_and_predict_with_linear_svc(chain_method):
Y_binary = Y_decision >= 0
assert_array_equal(Y_binary, Y_pred)
assert not hasattr(classifier_chain, "predict_proba")
assert "predict_proba" not in dir(classifier_chain)


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
6 changes: 6 additions & 0 deletions sklearn/tests/test_pipeline.py
@@ -817,24 +817,30 @@ def test_pipeline_ducktyping():

pipeline = make_pipeline(Transf())
assert not hasattr(pipeline, "predict")
assert "predict" not in dir(pipeline)
pipeline.transform
pipeline.inverse_transform

pipeline = make_pipeline("passthrough")
assert pipeline.steps[0] == ("passthrough", "passthrough")
assert not hasattr(pipeline, "predict")
assert "predict" not in dir(pipeline)
pipeline.transform
pipeline.inverse_transform

pipeline = make_pipeline(Transf(), NoInvTransf())
assert not hasattr(pipeline, "predict")
assert "predict" not in dir(pipeline)
pipeline.transform
assert not hasattr(pipeline, "inverse_transform")
assert "inverse_transform" not in dir(pipeline)

pipeline = make_pipeline(NoInvTransf(), Transf())
assert not hasattr(pipeline, "predict")
assert "predict" not in dir(pipeline)
pipeline.transform
assert not hasattr(pipeline, "inverse_transform")
assert "inverse_transform" not in dir(pipeline)


def test_make_pipeline():
1 change: 1 addition & 0 deletions sklearn/utils/tests/test_mocking.py
@@ -203,3 +203,4 @@ def test_mock_estimator_on_off_prediction(iris, response_methods):
assert getattr(estimator, response)(X) == response
else:
assert not hasattr(estimator, response)
assert response not in dir(estimator)