Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT Clean-up utils.__init__: move tests into corresponding test files #28842

Merged
merged 2 commits into from Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 19 additions & 0 deletions sklearn/utils/tests/test_mask.py
@@ -0,0 +1,19 @@
import pytest

from sklearn.utils._mask import safe_mask
from sklearn.utils.fixes import CSR_CONTAINERS
from sklearn.utils.validation import check_random_state


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_safe_mask(csr_container):
    """Check that masks returned by ``safe_mask`` select the expected rows.

    A five-element boolean mask with three ``True`` entries is applied to a
    dense array and to its ``csr_container`` counterpart; each selection
    must contain three rows.
    """
    rng = check_random_state(0)
    X_dense = rng.rand(5, 4)
    X_sparse = csr_container(X_dense)
    raw_mask = [False, False, True, True, True]

    dense_mask = safe_mask(X_dense, raw_mask)
    assert X_dense[dense_mask].shape[0] == 3

    # The mask computed for the dense array is reused as the input here.
    sparse_mask = safe_mask(X_sparse, dense_mask)
    assert X_sparse[sparse_mask].shape[0] == 3
27 changes: 27 additions & 0 deletions sklearn/utils/tests/test_missing.py
@@ -0,0 +1,27 @@
import numpy as np
import pytest

from sklearn.utils._missing import is_scalar_nan


@pytest.mark.parametrize(
    ("value", "result"),
    [
        # NaN spelled as Python and NumPy floating-point scalars.
        (float("nan"), True),
        (np.nan, True),
        (float(np.nan), True),
        (np.float32(np.nan), True),
        (np.float64(np.nan), True),
        # Everything below is expected not to be reported as a scalar NaN.
        (0, False),
        (0.0, False),
        (None, False),
        ("", False),
        ("nan", False),
        ([np.nan], False),
        (9867966753463435747313673, False),  # Python int that overflows with C type
    ],
)
def test_is_scalar_nan(value, result):
    """Check ``is_scalar_nan`` against the expected answer for ``value``.

    The comparison uses ``is`` so that only the exact ``True``/``False``
    singletons are accepted.
    """
    assert is_scalar_nan(value) is result
    # The returned object must be a built-in Python bool.
    returned = is_scalar_nan(value)
    assert isinstance(returned, bool)
148 changes: 1 addition & 147 deletions sklearn/utils/tests/test_utils.py
@@ -1,153 +1,7 @@
import warnings

import joblib
import numpy as np
import pytest

from sklearn.utils import (
check_random_state,
column_or_1d,
deprecated,
parallel_backend,
register_parallel_backend,
safe_mask,
tosequence,
)
from sklearn.utils._missing import is_scalar_nan
from sklearn.utils._testing import assert_array_equal
from sklearn.utils.fixes import CSR_CONTAINERS
from sklearn.utils.validation import _is_polars_df


def test_make_rng():
    """Exercise ``check_random_state`` with each kind of seed used below.

    Covers ``None``, the ``np.random`` module, integer seeds, an existing
    ``RandomState`` instance and a string that must be rejected.
    """
    # Both None and the np.random module must yield np.random.mtrand._rand.
    global_rng = np.random.mtrand._rand
    assert check_random_state(None) is global_rng
    assert check_random_state(np.random) is global_rng

    # An integer seed must reproduce the draw of RandomState(seed).
    reference = np.random.RandomState(42)
    assert check_random_state(42).randint(100) == reference.randint(100)

    # An existing RandomState instance is returned as the same object.
    existing = np.random.RandomState(42)
    assert check_random_state(existing) is existing

    # A different integer seed must give a different first draw.
    baseline = np.random.RandomState(42)
    assert check_random_state(43).randint(100) != baseline.randint(100)

    with pytest.raises(ValueError):
        check_random_state("some invalid seed")


def test_deprecated():
    """Check that the ``deprecated`` decorator emits a ``FutureWarning``.

    The decorator is applied first to a function and then to a class. In
    each case exactly one warning whose message mentions "deprecated" must
    be recorded, and the decorated object must remain usable.
    """
    # Pattern adapted from https://docs.python.org/library/warnings.html

    # Case 1: a decorated function.
    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")

        @deprecated()
        def ham():
            return "spam"

        spam = ham()

        assert spam == "spam"  # the function must still return its value

        assert len(recorded) == 1
        function_warning = recorded[0]
        assert issubclass(function_warning.category, FutureWarning)
        assert "deprecated" in str(function_warning.message).lower()

    # Case 2: a decorated class, this time with an extra message.
    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")

        @deprecated("don't use this")
        class Ham:
            SPAM = 1

        instance = Ham()

        assert hasattr(instance, "SPAM")

        assert len(recorded) == 1
        class_warning = recorded[0]
        assert issubclass(class_warning.category, FutureWarning)
        assert "deprecated" in str(class_warning.message).lower()

Comment on lines -40 to -74
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I simply removed this one because it is redundant and strictly less exhaustive than the test_deprecated function already in test_deprecation.py.


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_safe_mask(csr_container):
    """Check that masks returned by ``safe_mask`` select the expected rows.

    A five-element boolean mask with three ``True`` entries is applied to a
    dense array and to its ``csr_container`` counterpart; each selection
    must contain three rows.
    """
    rng = check_random_state(0)
    X_dense = rng.rand(5, 4)
    X_sparse = csr_container(X_dense)
    raw_mask = [False, False, True, True, True]

    dense_mask = safe_mask(X_dense, raw_mask)
    assert X_dense[dense_mask].shape[0] == 3

    # The mask computed for the dense array is reused as the input here.
    sparse_mask = safe_mask(X_sparse, dense_mask)
    assert X_sparse[sparse_mask].shape[0] == 3


def test_column_or_1d():
    """Check ``column_or_1d`` on a table of labelled target examples.

    Examples labelled "binary", "multiclass" or "continuous" must come back
    equal to ``np.ravel`` of the input; every other label must raise
    ``ValueError``.
    """
    accepted_labels = ("binary", "multiclass", "continuous")
    cases = [
        ("binary", ["spam", "egg", "spam"]),
        ("binary", [0, 1, 0, 1]),
        ("continuous", np.arange(10) / 20.0),
        ("multiclass", [1, 2, 3]),
        ("multiclass", [0, 1, 2, 2, 0]),
        ("multiclass", [[1], [2], [3]]),
        ("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]),
        ("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("continuous-multioutput", np.arange(30).reshape((-1, 3))),
    ]

    for label, target in cases:
        if label not in accepted_labels:
            # Any other label is expected to be rejected.
            with pytest.raises(ValueError):
                column_or_1d(target)
            continue
        assert_array_equal(column_or_1d(target), np.ravel(target))


@pytest.mark.parametrize(
    ("value", "result"),
    [
        # NaN spelled as Python and NumPy floating-point scalars.
        (float("nan"), True),
        (np.nan, True),
        (float(np.nan), True),
        (np.float32(np.nan), True),
        (np.float64(np.nan), True),
        # Everything below is expected not to be reported as a scalar NaN.
        (0, False),
        (0.0, False),
        (None, False),
        ("", False),
        ("nan", False),
        ([np.nan], False),
        (9867966753463435747313673, False),  # Python int that overflows with C type
    ],
)
def test_is_scalar_nan(value, result):
    """Check ``is_scalar_nan`` against the expected answer for ``value``.

    The comparison uses ``is`` so that only the exact ``True``/``False``
    singletons are accepted.
    """
    assert is_scalar_nan(value) is result
    # The returned object must be a built-in Python bool.
    returned = is_scalar_nan(value)
    assert isinstance(returned, bool)


def dummy_func():
    """Do nothing and return ``None``."""
    return None


def test__is_polars_df():
    """Check that ``_is_polars_df`` is falsy for a polars look-alike object."""

    class LooksLikePolars:
        # Exposes only ``columns`` and ``schema`` attributes.
        def __init__(self):
            self.columns = ["a", "b"]
            self.schema = ["a", "b"]

    impostor = LooksLikePolars()
    assert not _is_polars_df(impostor)
from sklearn.utils import parallel_backend, register_parallel_backend, tosequence


# TODO(1.7): remove
Expand Down
55 changes: 55 additions & 0 deletions sklearn/utils/tests/test_validation.py
Expand Up @@ -80,11 +80,31 @@
check_is_fitted,
check_memory,
check_non_negative,
check_random_state,
check_scalar,
column_or_1d,
has_fit_parameter,
)


def test_make_rng():
    """Exercise ``check_random_state`` with each kind of seed used below.

    Covers ``None``, the ``np.random`` module, integer seeds, an existing
    ``RandomState`` instance and a string that must be rejected.
    """
    # Both None and the np.random module must yield np.random.mtrand._rand.
    global_rng = np.random.mtrand._rand
    assert check_random_state(None) is global_rng
    assert check_random_state(np.random) is global_rng

    # An integer seed must reproduce the draw of RandomState(seed).
    reference = np.random.RandomState(42)
    assert check_random_state(42).randint(100) == reference.randint(100)

    # An existing RandomState instance is returned as the same object.
    existing = np.random.RandomState(42)
    assert check_random_state(existing) is existing

    # A different integer seed must give a different first draw.
    baseline = np.random.RandomState(42)
    assert check_random_state(43).randint(100) != baseline.randint(100)

    with pytest.raises(ValueError):
        check_random_state("some invalid seed")


def test_as_float_array():
# Test function for as_float_array
X = np.ones((3, 10), dtype=np.int32)
Expand Down Expand Up @@ -2061,3 +2081,38 @@ def test_to_object_array(sequence):
assert isinstance(out, np.ndarray)
assert out.dtype.kind == "O"
assert out.ndim == 1


def test_column_or_1d():
    """Check ``column_or_1d`` on a table of labelled target examples.

    Examples labelled "binary", "multiclass" or "continuous" must come back
    equal to ``np.ravel`` of the input; every other label must raise
    ``ValueError``.
    """
    accepted_labels = ("binary", "multiclass", "continuous")
    cases = [
        ("binary", ["spam", "egg", "spam"]),
        ("binary", [0, 1, 0, 1]),
        ("continuous", np.arange(10) / 20.0),
        ("multiclass", [1, 2, 3]),
        ("multiclass", [0, 1, 2, 2, 0]),
        ("multiclass", [[1], [2], [3]]),
        ("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]),
        ("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("continuous-multioutput", np.arange(30).reshape((-1, 3))),
    ]

    for label, target in cases:
        if label not in accepted_labels:
            # Any other label is expected to be rejected.
            with pytest.raises(ValueError):
                column_or_1d(target)
            continue
        assert_array_equal(column_or_1d(target), np.ravel(target))


def test__is_polars_df():
    """Check that ``_is_polars_df`` is falsy for a polars look-alike object."""

    class LooksLikePolars:
        # Exposes only ``columns`` and ``schema`` attributes.
        def __init__(self):
            self.columns = ["a", "b"]
            self.schema = ["a", "b"]

    impostor = LooksLikePolars()
    assert not _is_polars_df(impostor)