Skip to content

Commit

Permalink
FIX euclidean_distances float32 numerical instabilities (scikit-learn…
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedbb authored and koenvandevelde committed Jul 12, 2019
1 parent 4853345 commit a28577d
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 33 deletions.
11 changes: 8 additions & 3 deletions doc/whats_new/v0.21.rst
Expand Up @@ -543,9 +543,14 @@ Support for Python 3.4 and below has been officially dropped.
:pr:`13447` by :user:`Dan Ellis <dpwe>`.

- |API| The parameter ``labels`` in :func:`metrics.hamming_loss` is deprecated
in version 0.21 and will be removed in version 0.23.
:pr:`10580` by :user:`Reshama Shaikh <reshamas>` and :user:`Sandra
Mitrovic <SandraMNE>`.
in version 0.21 and will be removed in version 0.23. :pr:`10580` by
:user:`Reshama Shaikh <reshamas>` and :user:`Sandra Mitrovic <SandraMNE>`.

- |Fix| The function :func:`euclidean_distances`, and therefore
several estimators with ``metric='euclidean'``, suffered from numerical
precision issues with ``float32`` features. Precision has been increased at the
cost of a small drop in performance. :pr:`13554` by :user:`Celelibi` and
:user:`Jérémie du Boisberranger <jeremiedbb>`.

- |API| :func:`metrics.jaccard_similarity_score` is deprecated in favour of
the more consistent :func:`metrics.jaccard_score`. The former behavior for
Expand Down
109 changes: 100 additions & 9 deletions sklearn/metrics/pairwise.py
Expand Up @@ -193,17 +193,24 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
Y_norm_squared : array-like, shape (n_samples_2, ), optional
Pre-computed dot-products of vectors in Y (e.g.,
``(Y**2).sum(axis=1)``)
May be ignored in some cases, see the note below.
squared : boolean, optional
Return squared Euclidean distances.
X_norm_squared : array-like, shape = [n_samples_1], optional
Pre-computed dot-products of vectors in X (e.g.,
``(X**2).sum(axis=1)``)
May be ignored in some cases, see the note below.
Notes
-----
To achieve better accuracy, `X_norm_squared` and `Y_norm_squared` may be
unused if they are passed as ``float32``.
Returns
-------
distances : {array, sparse matrix}, shape (n_samples_1, n_samples_2)
distances : array, shape (n_samples_1, n_samples_2)
Examples
--------
Expand All @@ -224,41 +231,125 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
"""
X, Y = check_pairwise_arrays(X, Y)

# If norms are passed as float32, they are unused. If arrays are passed as
# float32, norms need to be recomputed on upcast chunks.
# TODO: use a float64 accumulator in row_norms to avoid the latter.
if X_norm_squared is not None:
XX = check_array(X_norm_squared)
if XX.shape == (1, X.shape[0]):
XX = XX.T
elif XX.shape != (X.shape[0], 1):
raise ValueError(
"Incompatible dimensions for X and X_norm_squared")
if XX.dtype == np.float32:
XX = None
elif X.dtype == np.float32:
XX = None
else:
XX = row_norms(X, squared=True)[:, np.newaxis]

if X is Y: # shortcut in the common case euclidean_distances(X, X)
if X is Y and XX is not None:
# shortcut in the common case euclidean_distances(X, X)
YY = XX.T
elif Y_norm_squared is not None:
YY = np.atleast_2d(Y_norm_squared)

if YY.shape != (1, Y.shape[0]):
raise ValueError(
"Incompatible dimensions for Y and Y_norm_squared")
if YY.dtype == np.float32:
YY = None
elif Y.dtype == np.float32:
YY = None
else:
YY = row_norms(Y, squared=True)[np.newaxis, :]

distances = safe_sparse_dot(X, Y.T, dense_output=True)
distances *= -2
distances += XX
distances += YY
if X.dtype == np.float32:
# To minimize precision issues with float32, we compute the distance
# matrix on chunks of X and Y upcast to float64
distances = _euclidean_distances_upcast(X, XX, Y, YY)
else:
# if dtype is already float64, no need to chunk and upcast
distances = - 2 * safe_sparse_dot(X, Y.T, dense_output=True)
distances += XX
distances += YY
np.maximum(distances, 0, out=distances)

# Ensure that distances between vectors and themselves are set to 0.0.
# This may not be the case due to floating point rounding errors.
if X is Y:
# Ensure that distances between vectors and themselves are set to 0.0.
# This may not be the case due to floating point rounding errors.
distances.flat[::distances.shape[0] + 1] = 0.0
np.fill_diagonal(distances, 0)

return distances if squared else np.sqrt(distances, out=distances)


def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None):
    """Euclidean distances between X and Y.

    Assumes X and Y have float32 dtype.
    Assumes XX and YY have float64 dtype or are None.

    X and Y are upcast to float64 by chunks, whose size is chosen to limit
    the memory increase to approximately 10% (at least 10MiB).

    Parameters
    ----------
    X : array or sparse matrix, shape (n_samples_X, n_features)

    XX : array, shape (n_samples_X, 1), or None
        Squared row norms of X. Recomputed on each upcast chunk when None.

    Y : array or sparse matrix, shape (n_samples_Y, n_features)
        Must be provided despite the ``None`` default, since it is
        dereferenced unconditionally. Pass the same object as X to take
        advantage of the symmetry of the distance matrix.

    YY : array, shape (1, n_samples_Y), or None
        Squared row norms of Y. Recomputed on each upcast chunk when None.

    Returns
    -------
    distances : array, shape (n_samples_X, n_samples_Y), dtype float32
        Values of ``XX - 2 * X.dot(Y.T) + YY``. They are neither clipped at
        zero nor square-rooted here.
    """
    n_samples_X = X.shape[0]
    n_samples_Y = Y.shape[0]
    n_features = X.shape[1]

    distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32)

    x_density = X.nnz / np.prod(X.shape) if issparse(X) else 1
    y_density = Y.nnz / np.prod(Y.shape) if issparse(Y) else 1

    # Allow 10% more memory than X, Y and the distance matrix take (at least
    # 10MiB, expressed in 8-byte blocks)
    maxmem = max(
        ((x_density * n_samples_X + y_density * n_samples_Y) * n_features
         + (x_density * n_samples_X * y_density * n_samples_Y)) / 10,
        10 * 2 ** 17)

    # The increase amount of memory in 8-byte blocks is:
    # - x_density * batch_size * n_features (copy of chunk of X)
    # - y_density * batch_size * n_features (copy of chunk of Y)
    # - batch_size * batch_size (chunk of distance matrix)
    # Hence x² + (xd+yd)kx = M, where x=batch_size, k=n_features, M=maxmem
    #                                 xd=x_density and yd=y_density
    # batch_size is the positive root of that quadratic equation.
    tmp = (x_density + y_density) * n_features
    batch_size = (-tmp + np.sqrt(tmp ** 2 + 4 * maxmem)) / 2
    batch_size = max(int(batch_size), 1)

    for i, x_slice in enumerate(gen_batches(n_samples_X, batch_size)):
        X_chunk = X[x_slice].astype(np.float64)
        if XX is None:
            XX_chunk = row_norms(X_chunk, squared=True)[:, np.newaxis]
        else:
            XX_chunk = XX[x_slice]

        # gen_batches yields its slices only once, so the batches of Y must
        # be regenerated for every chunk of X. Sharing a single generator
        # across the outer loop would leave every row of ``distances`` past
        # the first chunk of X uninitialised.
        for j, y_slice in enumerate(gen_batches(n_samples_Y, batch_size)):
            if X is Y and j < i:
                # when X is Y the distance matrix is symmetric so we only
                # need to compute half of it: block (j, i) was already
                # filled during an earlier iteration of the outer loop.
                d = distances[y_slice, x_slice].T

            else:
                Y_chunk = Y[y_slice].astype(np.float64)
                if YY is None:
                    YY_chunk = row_norms(Y_chunk, squared=True)[np.newaxis, :]
                else:
                    YY_chunk = YY[:, y_slice]

                d = -2 * safe_sparse_dot(X_chunk, Y_chunk.T, dense_output=True)
                d += XX_chunk
                d += YY_chunk

            # Downcast the float64 block back to the float32 result.
            distances[x_slice, y_slice] = d.astype(np.float32, copy=False)

    return distances


def _argmin_min_reduce(dist, start):
indices = dist.argmin(axis=1)
values = dist[np.arange(dist.shape[0]), indices]
Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/pairwise_fast.pyx
Expand Up @@ -7,10 +7,10 @@
#
# License: BSD 3 clause

from libc.string cimport memset
import numpy as np
cimport numpy as np
from cython cimport floating
from libc.string cimport memset

from ..utils._cython_blas cimport _asum

Expand Down
114 changes: 94 additions & 20 deletions sklearn/metrics/tests/test_pairwise.py
Expand Up @@ -584,41 +584,115 @@ def test_pairwise_distances_chunked():
assert_raises(StopIteration, next, gen)


def test_euclidean_distances():
# Check the pairwise Euclidean distances computation
X = [[0]]
Y = [[1], [2]]
@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
ids=["dense", "sparse"])
@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
ids=["dense", "sparse"])
def test_euclidean_distances_known_result(x_array_constr, y_array_constr):
# Check the pairwise Euclidean distances computation on known result
X = x_array_constr([[0]])
Y = y_array_constr([[1], [2]])
D = euclidean_distances(X, Y)
assert_array_almost_equal(D, [[1., 2.]])
assert_allclose(D, [[1., 2.]])

X = csr_matrix(X)
Y = csr_matrix(Y)
D = euclidean_distances(X, Y)
assert_array_almost_equal(D, [[1., 2.]])

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
ids=["dense", "sparse"])
def test_euclidean_distances_with_norms(dtype, y_array_constr):
# check that we still get the right answers with {X,Y}_norm_squared
# and that we get a wrong answer with wrong {X,Y}_norm_squared
rng = np.random.RandomState(0)
X = rng.random_sample((10, 4))
Y = rng.random_sample((20, 4))
X_norm_sq = (X ** 2).sum(axis=1).reshape(1, -1)
Y_norm_sq = (Y ** 2).sum(axis=1).reshape(1, -1)
X = rng.random_sample((10, 10)).astype(dtype, copy=False)
Y = rng.random_sample((20, 10)).astype(dtype, copy=False)

# norms will only be used if their dtype is float64
X_norm_sq = (X.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)
Y_norm_sq = (Y.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)

Y = y_array_constr(Y)

# check that we still get the right answers with {X,Y}_norm_squared
D1 = euclidean_distances(X, Y)
D2 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq)
D3 = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq)
D4 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq,
Y_norm_squared=Y_norm_sq)
assert_array_almost_equal(D2, D1)
assert_array_almost_equal(D3, D1)
assert_array_almost_equal(D4, D1)
assert_allclose(D2, D1)
assert_allclose(D3, D1)
assert_allclose(D4, D1)

# check we get the wrong answer with wrong {X,Y}_norm_squared
X_norm_sq *= 0.5
Y_norm_sq *= 0.5
wrong_D = euclidean_distances(X, Y,
X_norm_squared=np.zeros_like(X_norm_sq),
Y_norm_squared=np.zeros_like(Y_norm_sq))
assert_greater(np.max(np.abs(wrong_D - D1)), .01)
with pytest.raises(AssertionError):
assert_allclose(wrong_D, D1)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
                         ids=["dense", "sparse"])
@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
                         ids=["dense", "sparse"])
def test_euclidean_distances(dtype, x_array_constr, y_array_constr):
    """Compare euclidean_distances(X, Y) against scipy's cdist, X is not Y."""
    rng = np.random.RandomState(0)

    # X is drawn before Y so that the random stream is consumed in a fixed
    # order; entries below 0.8 are zeroed out in both arrays.
    dense_X = rng.random_sample((100, 10)).astype(dtype, copy=False)
    dense_X[dense_X < 0.8] = 0
    dense_Y = rng.random_sample((10, 10)).astype(dtype, copy=False)
    dense_Y[dense_Y < 0.8] = 0

    # The reference is computed on the dense arrays, before they are wrapped
    # by the parametrized constructors.
    reference = cdist(dense_X, dense_Y)

    computed = euclidean_distances(x_array_constr(dense_X),
                                   y_array_constr(dense_Y))

    # the default rtol=1e-7 is too close to the float32 precision
    # and fails due to rounding errors.
    assert_allclose(computed, reference, rtol=1e-6)
    assert computed.dtype == dtype


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
                         ids=["dense", "sparse"])
def test_euclidean_distances_sym(dtype, x_array_constr):
    """Compare euclidean_distances(X) against scipy's pdist (only X given)."""
    rng = np.random.RandomState(0)

    # Entries below 0.8 are zeroed out.
    dense_X = rng.random_sample((100, 10)).astype(dtype, copy=False)
    dense_X[dense_X < 0.8] = 0

    # pdist returns the condensed form; squareform expands it to the full
    # square matrix that euclidean_distances is compared with.
    reference = squareform(pdist(dense_X))

    computed = euclidean_distances(x_array_constr(dense_X))

    # the default rtol=1e-7 is too close to the float32 precision
    # and fails due to rounding errors.
    assert_allclose(computed, reference, rtol=1e-6)
    assert computed.dtype == dtype


@pytest.mark.parametrize(
    "dtype, eps, rtol",
    [(np.float32, 1e-4, 1e-5),
     pytest.param(
         np.float64, 1e-8, 0.99,
         marks=pytest.mark.xfail(reason='failing due to lack of precision'))])
@pytest.mark.parametrize("dim", [1, 1000000])
def test_euclidean_distances_extreme_values(dtype, eps, rtol, dim):
    """Check euclidean_distances on two nearly identical constant vectors.

    With float32 input the result is expected to be correct thanks to
    upcasting. On float64 there are still precision issues, hence the xfail
    mark on that parametrization.
    """
    # Single-row inputs whose coordinates all differ by ``eps``.
    X = np.full((1, dim), 1., dtype=dtype)
    Y = np.full((1, dim), 1. + eps, dtype=dtype)

    computed = euclidean_distances(X, Y)
    reference = cdist(X, Y)

    # NOTE(review): the parametrized ``rtol`` is not consumed; the tolerance
    # below is fixed at 1e-5 for every parametrization.
    assert_allclose(computed, reference, rtol=1e-5)


def test_cosine_distances():
Expand Down

0 comments on commit a28577d

Please sign in to comment.