Skip to content

Commit

Permalink
MAINT Clean-up utils.__init__: move tests into corresponding test files (#28842)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedbb committed Apr 30, 2024
1 parent 0bdc754 commit f61dd6c
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 147 deletions.
19 changes: 19 additions & 0 deletions sklearn/utils/tests/test_mask.py
@@ -0,0 +1,19 @@
import pytest

from sklearn.utils._mask import safe_mask
from sklearn.utils.fixes import CSR_CONTAINERS
from sklearn.utils.validation import check_random_state


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_safe_mask(csr_container):
    """Check that a mask returned by ``safe_mask`` selects the expected rows.

    The same boolean selection (three ``True`` entries out of five) is applied
    to a 5x4 array and to the container built from it by ``csr_container``.
    """
    rng = check_random_state(0)
    dense = rng.rand(5, 4)
    wrapped = csr_container(dense)
    selection = [False, False, True, True, True]

    # Three entries are True, so indexing must keep exactly three rows.
    dense_mask = safe_mask(dense, selection)
    assert dense[dense_mask].shape[0] == 3

    # The second call is fed the mask produced by the first call, not the
    # raw Python list.
    wrapped_mask = safe_mask(wrapped, dense_mask)
    assert wrapped[wrapped_mask].shape[0] == 3
27 changes: 27 additions & 0 deletions sklearn/utils/tests/test_missing.py
@@ -0,0 +1,27 @@
import numpy as np
import pytest

from sklearn.utils._missing import is_scalar_nan


@pytest.mark.parametrize(
    "value, result",
    [
        # Scalars expected to be reported as NaN.
        *(
            (nan_scalar, True)
            for nan_scalar in (
                float("nan"),
                np.nan,
                float(np.nan),
                np.float32(np.nan),
                np.float64(np.nan),
            )
        ),
        # Values expected to be reported as not NaN.
        *(
            (other, False)
            for other in (
                0,
                0.0,
                None,
                "",
                "nan",
                [np.nan],
                9867966753463435747313673,  # Python int that overflows with C type
            )
        ),
    ],
)
def test_is_scalar_nan(value, result):
    """Check the answer of ``is_scalar_nan`` and that it is a Python ``bool``."""
    answer = is_scalar_nan(value)
    assert answer is result
    # make sure that we are returning a Python bool
    answer_again = is_scalar_nan(value)
    assert isinstance(answer_again, bool)
148 changes: 1 addition & 147 deletions sklearn/utils/tests/test_utils.py
@@ -1,153 +1,7 @@
import warnings

import joblib
import numpy as np
import pytest

from sklearn.utils import (
check_random_state,
column_or_1d,
deprecated,
parallel_backend,
register_parallel_backend,
safe_mask,
tosequence,
)
from sklearn.utils._missing import is_scalar_nan
from sklearn.utils._testing import assert_array_equal
from sklearn.utils.fixes import CSR_CONTAINERS
from sklearn.utils.validation import _is_polars_df


def test_make_rng():
    """Check the behavior of ``check_random_state`` for each seed kind used here."""
    shared_rng = np.random.mtrand._rand
    # Both None and the np.random module must resolve to np.random.mtrand._rand.
    for seed in (None, np.random):
        assert check_random_state(seed) is shared_rng

    # An int seed gives the same first draw as a RandomState seeded alike.
    reference = np.random.RandomState(42)
    first_draw = check_random_state(42).randint(100)
    assert first_draw == reference.randint(100)

    # A RandomState instance is returned as the very same object.
    existing = np.random.RandomState(42)
    assert check_random_state(existing) is existing

    # A different int seed gives a different first draw.
    reference = np.random.RandomState(42)
    other_draw = check_random_state(43).randint(100)
    assert other_draw != reference.randint(100)

    # A string that is not a seed must be rejected.
    with pytest.raises(ValueError):
        check_random_state("some invalid seed")


def test_deprecated():
    """Check that ``deprecated`` warns for a decorated function and class.

    The recording pattern is copied almost verbatim from
    https://docs.python.org/library/warnings.html
    """

    def check_recorded(recorded):
        # Exactly one FutureWarning whose text mentions "deprecated"
        # (compared case-insensitively).
        assert len(recorded) == 1
        entry = recorded[0]
        assert issubclass(entry.category, FutureWarning)
        assert "deprecated" in str(entry.message).lower()

    # First a function...
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        @deprecated()
        def ham():
            return "spam"

        returned = ham()
        assert returned == "spam"  # function must remain usable
        check_recorded(caught)

    # ... then a class.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        @deprecated("don't use this")
        class Ham:
            SPAM = 1

        instance = Ham()
        assert hasattr(instance, "SPAM")
        check_recorded(caught)


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_safe_mask(csr_container):
    """Check that a mask returned by ``safe_mask`` selects the expected rows.

    The same boolean selection (three ``True`` entries out of five) is applied
    to a 5x4 array and to the container built from it by ``csr_container``.
    """
    rng = check_random_state(0)
    dense = rng.rand(5, 4)
    wrapped = csr_container(dense)
    selection = [False, False, True, True, True]

    # Three entries are True, so indexing must keep exactly three rows.
    dense_mask = safe_mask(dense, selection)
    assert dense[dense_mask].shape[0] == 3

    # The second call is fed the mask produced by the first call, not the
    # raw Python list.
    wrapped_mask = safe_mask(wrapped, dense_mask)
    assert wrapped[wrapped_mask].shape[0] == 3


def test_column_or_1d():
    """Check ``column_or_1d`` on targets labelled with a target-type name.

    Targets labelled binary, multiclass or continuous must come back equal to
    ``np.ravel`` of the input; any other label must raise ``ValueError``.
    """
    flattenable = ("binary", "multiclass", "continuous")
    cases = [
        ("binary", ["spam", "egg", "spam"]),
        ("binary", [0, 1, 0, 1]),
        ("continuous", np.arange(10) / 20.0),
        ("multiclass", [1, 2, 3]),
        ("multiclass", [0, 1, 2, 2, 0]),
        ("multiclass", [[1], [2], [3]]),
        ("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]),
        ("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("continuous-multioutput", np.arange(30).reshape((-1, 3))),
    ]

    for label, target in cases:
        if label not in flattenable:
            # Labels outside the flattenable set are expected to be rejected.
            with pytest.raises(ValueError):
                column_or_1d(target)
            continue
        assert_array_equal(column_or_1d(target), np.ravel(target))


@pytest.mark.parametrize(
    "value, result",
    [
        # Scalars expected to be reported as NaN.
        *(
            (nan_scalar, True)
            for nan_scalar in (
                float("nan"),
                np.nan,
                float(np.nan),
                np.float32(np.nan),
                np.float64(np.nan),
            )
        ),
        # Values expected to be reported as not NaN.
        *(
            (other, False)
            for other in (
                0,
                0.0,
                None,
                "",
                "nan",
                [np.nan],
                9867966753463435747313673,  # Python int that overflows with C type
            )
        ),
    ],
)
def test_is_scalar_nan(value, result):
    """Check the answer of ``is_scalar_nan`` and that it is a Python ``bool``."""
    answer = is_scalar_nan(value)
    assert answer is result
    # make sure that we are returning a Python bool
    answer_again = is_scalar_nan(value)
    assert isinstance(answer_again, bool)


def dummy_func():
    """Do nothing and return ``None``.

    No caller is visible in this chunk; presumably a placeholder callable
    for tests -- confirm before relying on it.
    """
    return None


def test__is_polars_df():
    """Check that ``_is_polars_df`` is falsy for a non-dataframe look-alike.

    The stand-in object only carries ``columns`` and ``schema`` attributes.
    """

    class LooksLikePolars:
        def __init__(self):
            self.columns = ["a", "b"]
            self.schema = ["a", "b"]

    impostor = LooksLikePolars()
    detected = _is_polars_df(impostor)
    assert not detected
from sklearn.utils import parallel_backend, register_parallel_backend, tosequence


# TODO(1.7): remove
Expand Down
55 changes: 55 additions & 0 deletions sklearn/utils/tests/test_validation.py
Expand Up @@ -80,11 +80,31 @@
check_is_fitted,
check_memory,
check_non_negative,
check_random_state,
check_scalar,
column_or_1d,
has_fit_parameter,
)


def test_make_rng():
    """Check the behavior of ``check_random_state`` for each seed kind used here."""
    shared_rng = np.random.mtrand._rand
    # Both None and the np.random module must resolve to np.random.mtrand._rand.
    for seed in (None, np.random):
        assert check_random_state(seed) is shared_rng

    # An int seed gives the same first draw as a RandomState seeded alike.
    reference = np.random.RandomState(42)
    first_draw = check_random_state(42).randint(100)
    assert first_draw == reference.randint(100)

    # A RandomState instance is returned as the very same object.
    existing = np.random.RandomState(42)
    assert check_random_state(existing) is existing

    # A different int seed gives a different first draw.
    reference = np.random.RandomState(42)
    other_draw = check_random_state(43).randint(100)
    assert other_draw != reference.randint(100)

    # A string that is not a seed must be rejected.
    with pytest.raises(ValueError):
        check_random_state("some invalid seed")


def test_as_float_array():
# Test function for as_float_array
X = np.ones((3, 10), dtype=np.int32)
Expand Down Expand Up @@ -2061,3 +2081,38 @@ def test_to_object_array(sequence):
assert isinstance(out, np.ndarray)
assert out.dtype.kind == "O"
assert out.ndim == 1


def test_column_or_1d():
    """Check ``column_or_1d`` on targets labelled with a target-type name.

    Targets labelled binary, multiclass or continuous must come back equal to
    ``np.ravel`` of the input; any other label must raise ``ValueError``.
    """
    flattenable = ("binary", "multiclass", "continuous")
    cases = [
        ("binary", ["spam", "egg", "spam"]),
        ("binary", [0, 1, 0, 1]),
        ("continuous", np.arange(10) / 20.0),
        ("multiclass", [1, 2, 3]),
        ("multiclass", [0, 1, 2, 2, 0]),
        ("multiclass", [[1], [2], [3]]),
        ("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]),
        ("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("continuous-multioutput", np.arange(30).reshape((-1, 3))),
    ]

    for label, target in cases:
        if label not in flattenable:
            # Labels outside the flattenable set are expected to be rejected.
            with pytest.raises(ValueError):
                column_or_1d(target)
            continue
        assert_array_equal(column_or_1d(target), np.ravel(target))


def test__is_polars_df():
    """Check that ``_is_polars_df`` is falsy for a non-dataframe look-alike.

    The stand-in object only carries ``columns`` and ``schema`` attributes.
    """

    class LooksLikePolars:
        def __init__(self):
            self.columns = ["a", "b"]
            self.schema = ["a", "b"]

    impostor = LooksLikePolars()
    detected = _is_polars_df(impostor)
    assert not detected

0 comments on commit f61dd6c

Please sign in to comment.