[MRG] Fix euclidean_distances numerical instabilities #13410

Closed
wants to merge 35 commits into from
Changes from 17 commits

Commits (35)
005a311
vect vect euclidean
jeremiedbb Mar 4, 2019
a7e8e5e
safe euclidean distances small n_features
jeremiedbb Mar 7, 2019
419d90e
fix arry order
jeremiedbb Mar 8, 2019
ecf8c3c
fix array order
jeremiedbb Mar 8, 2019
cd5eec2
tst debug windows
jeremiedbb Mar 8, 2019
81ef489
tst debug windows
jeremiedbb Mar 8, 2019
25d90f1
tst debug windows
jeremiedbb Mar 8, 2019
0d10f03
clean
jeremiedbb Mar 10, 2019
d6849f7
implement chunking upcasting
jeremiedbb Mar 11, 2019
f376632
lint
jeremiedbb Mar 11, 2019
8cf87df
comment on norms
jeremiedbb Mar 11, 2019
9adcc24
Merge branch 'master' into euclidean-dist
jeremiedbb Mar 11, 2019
bb750fc
remove unnecessary condition
jeremiedbb Mar 11, 2019
9f02593
consistent names
jeremiedbb Mar 12, 2019
9bd0f4d
move to Notes in docstrings and typos
jeremiedbb Mar 12, 2019
2c2098a
revert unrelated changes in docstring
jeremiedbb Mar 12, 2019
cecdb5f
document ignored norms
jeremiedbb Mar 12, 2019
7e4acdf
what's new
jeremiedbb Mar 12, 2019
dfc82b8
typo
jeremiedbb Mar 13, 2019
ab78675
what's new
jeremiedbb Mar 13, 2019
17fa839
clean and comments safe euclidean sparse dense
jeremiedbb Mar 13, 2019
f66dc97
same
jeremiedbb Mar 13, 2019
d49fe8b
sym -> symmetric
jeremiedbb Mar 13, 2019
aab5f48
add norms numpy style
jeremiedbb Mar 13, 2019
9e954b2
accept precomputed norms
jeremiedbb Mar 13, 2019
809591e
fix sym
jeremiedbb Mar 13, 2019
9b4a5a4
symmetric case in upcast euclidean
jeremiedbb Mar 13, 2019
79faace
lint
jeremiedbb Mar 13, 2019
bdbb5c3
switch 32 -> 16
jeremiedbb Mar 13, 2019
4e9e4c2
remove symmetric with syrk & don't force c order
jeremiedbb Mar 14, 2019
390cba4
fix upcast symmetric
jeremiedbb Mar 14, 2019
5d7721b
special case c contiguous
jeremiedbb Mar 14, 2019
f4fca49
clean
jeremiedbb Mar 14, 2019
5fe7644
fix docstring
jeremiedbb Mar 14, 2019
6694ee1
typo
jeremiedbb Mar 14, 2019
69 changes: 69 additions & 0 deletions sklearn/metrics/_safe_euclidean_sparse.pyx
@@ -0,0 +1,69 @@
#cython: language_level=3
#cython: boundscheck=False, cdivision=True, wraparound=False


import numpy as np
cimport numpy as np
from cython cimport floating
from libc.math cimport fmax


np.import_array()


ctypedef fused INT:
np.int32_t
np.int64_t


def _euclidean_sparse_dense_exact(floating[::1] X_data,
INT[::1] X_indices,
INT[::1] X_indptr,
np.ndarray[floating, ndim=2, mode='c'] Y,
floating[::1] y_squared_norms):
cdef:
int n_samples_X = X_indptr.shape[0] - 1
int n_samples_Y = Y.shape[0]
int n_features = Y.shape[1]

int i, j

floating[:, ::1] D = np.empty((n_samples_X, n_samples_Y), Y.dtype)

for i in range(n_samples_X):
for j in range(n_samples_Y):
D[i, j] = _euclidean_sparse_dense_exact_1d(
&X_data[X_indptr[i]],
&X_indices[X_indptr[i]],
X_indptr[i + 1] - X_indptr[i],
&Y[j, 0],
y_squared_norms[j])

return np.asarray(D)


cdef floating _euclidean_sparse_dense_exact_1d(floating *y_data,
INT *y_indices,
int y_nnz,
floating *x,
floating x_squared_norm) nogil:
"""Euclidean distance between x dense and y sparse"""
cdef:
int i
floating xi
floating tmp = 0.0
floating result = 0.0
floating partial_x_squared_norm = 0.0

# Split the loop to avoid unsafe compiler auto optimizations
@Celelibi Mar 12, 2019
I might say something stupid, but I'm not sure what you're trying to avoid there.
If a compiler optimizes aggressively enough to rearrange your arithmetic operations in an unsafe way, then I don't think fusing two independent loops would be a problem for it.
Additionally, AFAIK, most (if not all) compilers nowadays know pretty well how floating point arithmetic works.
Moreover, I think I see a few opportunities to improve the numeric accuracy by fusing the loops.

Member

I was also wondering about this -- I imagine there should be no unsafe optimizations unless you are using non-standard flags (e.g. -Ofast)?

Member Author

I thought that too, but this is not what I observe. I don't really know what's going on. The loop should be:

for i in range(y_nnz):
    xi = x[y_indices[i]]
    tmp = y_data[i] - xi
    result += (tmp * tmp) - (xi * xi)

but even without the -Ofast flag, there is some kind of optimization happening because the result is not correct.
It looks like gcc expands the expression

result += (tmp * tmp) - (xi * xi) = (y_data[i] - xi)² - xi²
                                  = y_data[i]² - 2*y_data[i]*xi

which might be wrong in floating point arithmetic.
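
To make the discussion concrete, here is a minimal sketch (not part of the PR; plain Python/NumPy standing in for the Cython loops) that compares the fused accumulation result += (tmp * tmp) - (xi * xi) with the split accumulation used in this PR. Which order is more accurate depends on the data, so both are compared against a float64 reference:

import numpy as np

rng = np.random.RandomState(0)
x = rng.random_sample(1000).astype(np.float32)   # dense vector values
y = rng.random_sample(1000).astype(np.float32)   # sparse vector values
y[y < 0.8] = 0                                   # mimic sparsity

# Fused accumulation: one loop, mixed-sign terms.
fused = np.float32(0)
for xi, yi in zip(x, y):
    tmp = yi - xi
    fused += tmp * tmp - xi * xi

# Split accumulation (as in this PR): two separate sums, combined once.
sq_diff = np.float32(0)
partial_x_sq = np.float32(0)
for xi, yi in zip(x, y):
    partial_x_sq += xi * xi
    tmp = yi - xi
    sq_diff += tmp * tmp
split = sq_diff - partial_x_sq

# float64 reference of the same quantity
ref = np.sum((y.astype(np.float64) - x.astype(np.float64)) ** 2
             - x.astype(np.float64) ** 2)
print(fused, split, ref)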

Contributor

Make sure to add a unit test that detects bad numeric optimization!

Member Author

How would you do that?
I added tests for the euclidean distances which failed when I didn't split the loop or when I used the -Ofast flag. Now they pass. Is that what you mean?
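
For illustration, a test along these lines could compare the float32 result against a float64 reference computed with scipy. This is only a sketch, not the actual test added in the PR; the name, the data and the tolerance are made up:

import numpy as np
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import euclidean_distances

def test_euclidean_distances_float32_accuracy():
    # Data with a large offset tends to expose cancellation in the
    # expanded formulation ||x||^2 - 2 x.y + ||y||^2.
    rng = np.random.RandomState(0)
    X = (rng.random_sample((100, 10)) + 1000).astype(np.float32)
    Y = (rng.random_sample((50, 10)) + 1000).astype(np.float32)

    # Reference computed in float64 with the direct formula.
    expected = cdist(X.astype(np.float64), Y.astype(np.float64))

    np.testing.assert_allclose(euclidean_distances(X, Y), expected, rtol=1e-6)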

Contributor

Parts (terms) of the last equation.

Contributor

A similar loss likely happens here: (tmp * tmp) - (xi * xi). When tmp and xi are of similar magnitude, this expression suffers from catastrophic cancellation (here, that happens when yi is small).
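
A tiny float32 example of that cancellation (illustrative only, not part of the PR):

import numpy as np

xi = np.float32(1.0)
yi = np.float32(1e-8)                         # yi much smaller than xi

tmp = yi - xi                                 # rounds to -1.0 in float32
naive = tmp * tmp - xi * xi                   # 1.0 - 1.0 = 0.0: all digits lost
expanded = yi * yi - np.float32(2) * yi * xi  # close to the true value

exact = float(yi) ** 2 - 2.0 * float(yi) * float(xi)  # float64 reference
print(naive, expanded, exact)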

@Celelibi Mar 14, 2019

I'm not an expert of floating point arithmetic

Neither am I, so take everything I say here with a grain of salt.

First, I'd like to make a note on the vocabulary around floating point arithmetic.

  • precision usually refers to the number of bits used to make the calculation.
  • accuracy usually refers to how close to the real value a result is.
  • error usually refers to the difference between the real value and what's actually stored.

up to here, there shouldn't be any precision loss.

Well, none of your transformations are exactly equivalent. Floating point arithmetic is commutative, but not associative: in general (a+b)+c != a+(b+c). You can't reorder a sum as you wish without changing the result. That might improve the accuracy, but it might also make it worse. You usually want to add values that are not too far from each other in order to avoid cancellation.
As a side note about summation, numpy implements a pairwise summation which has better accuracy than the straightforward sum and is faster (but less accurate) than Kahan summation (a minimal sketch of Kahan summation is given below).
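
For reference, here is a minimal Kahan (compensated) summation sketch in plain Python/NumPy; it is illustrative only and is not taken from the PR or from numpy's internals:

import numpy as np

def kahan_sum(values):
    total = np.float32(0.0)
    c = np.float32(0.0)                # running compensation for lost low bits
    for v in values:
        y = np.float32(v) - c
        t = total + y
        c = (t - total) - y            # what got rounded away in total + y
        total = t
    return total

values = np.full(10**5, np.float32(0.1), dtype=np.float32)

naive = np.float32(0.0)
for v in values:                       # straightforward left-to-right sum
    naive += v

print(naive, kahan_sum(values), values.astype(np.float64).sum())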

Now, we can regroup the 2 loops into:
but this causes loss of precision

What makes you say so? How did you measure it?

Member Author

What makes you say so? How did you measure it?

well, just that the tests fail when I regroup the 2 loops :)


well, just that the tests fail when I regroup the 2 loops :)

Which test? I can't reproduce.

for i in range(y_nnz):
xi = x[y_indices[i]]
partial_x_squared_norm += xi * xi

for i in range(y_nnz):
tmp = y_data[i] - x[y_indices[i]]
result += tmp * tmp
Member

Would using float64 accumulators here help precision and reduce the (low) risk of overflow? Related to #13010

Member Author

Does it force conversion of tmp in result += tmp * tmp in that case? Also, we need to convert the result back to float32 when we want float32.

I should test how it impacts performance.

Member Author

The cost is quite high. In this example:

import numpy as np
from scipy.sparse import csr_matrix

X = np.random.RandomState(0).random_sample((100000, 16)).astype(np.float32)
Y = np.random.RandomState(1).random_sample((100, 16)).astype(np.float32)
X[X < 0.8] = 0
X = csr_matrix(X)

There's a 20% slowdown using a float64 accumulator.

I guess the precision is fine here since it's the exact calculation method.
I'm not sure about the risk of overflow. If the result is too big to be represented in float32, even if you can compute its correct value in float64, when you downcast it to float32 at the end you'll get an inf anyway.
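
A small sketch of that overflow point (illustrative only, not part of the PR): the squared distance can overflow float32 even when a float64 accumulator holds the correct value, and the final downcast produces inf regardless.

import numpy as np

x = np.full(10, np.float32(1e20))
y = np.zeros(10, dtype=np.float32)

acc32 = np.float32(0.0)
for xi, yi in zip(x, y):
    tmp = xi - yi
    acc32 += tmp * tmp                 # (1e20)**2 = 1e40 overflows float32 -> inf
                                       # (numpy may emit an overflow warning)

acc64 = np.float64(0.0)
for xi, yi in zip(x, y):
    tmp = np.float64(xi) - np.float64(yi)
    acc64 += tmp * tmp                 # 1e41 is representable in float64

print(acc32)                           # inf
print(acc64)                           # 1e+41
print(np.float32(acc64))               # downcast to float32 -> inf anyway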


result += x_squared_norm - partial_x_squared_norm

return fmax(result, 0)
212 changes: 172 additions & 40 deletions sklearn/metrics/pairwise.py
@@ -30,6 +30,10 @@
from ..utils._joblib import effective_n_jobs

from .pairwise_fast import _chi2_kernel_fast, _sparse_manhattan
from .pairwise_fast import _euclidean_dense_dense_exact
from .pairwise_fast import _euclidean_dense_dense_fast_sym
from .pairwise_fast import _add_norms
from ._safe_euclidean_sparse import _euclidean_sparse_dense_exact
Member

It might be a bit surprising that these implementations are in different modules. Why did you want them separated?

Member Author

I don't want them to be separated. I had to separate them because I wanted to compile pairwise_fast with the -ffast-math flag. The thing is that with this flag, gcc performs unsafe optimizations for the _safe_euclidean_sparse functions.

The reason I wanted to use this flag for pairwise_fast is that it's much faster, and it should be safe because there's no operation there that gcc would optimize in an unsafe way. By the way, scipy uses this flag for its euclidean distance.
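
For readers unfamiliar with per-extension compiler flags, here is a sketch of how two Cython modules can be built with different options. It is illustrative only and is not the PR's actual setup.py:

# Illustrative build script, not the actual sklearn/metrics/setup.py.
from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy

extensions = [
    # Compiled with -ffast-math: assumed safe because no operation relies
    # on strict IEEE evaluation order.
    Extension("pairwise_fast", sources=["pairwise_fast.pyx"],
              include_dirs=[numpy.get_include()],
              extra_compile_args=["-O3", "-ffast-math"]),
    # No -ffast-math here: gcc would otherwise rearrange the accumulation
    # in an unsafe way.
    Extension("_safe_euclidean_sparse", sources=["_safe_euclidean_sparse.pyx"],
              include_dirs=[numpy.get_include()],
              extra_compile_args=["-O3"]),
]

setup(ext_modules=cythonize(extensions))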



# Utility Functions
@@ -168,20 +172,6 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
Considering the rows of X (and Y=X) as vectors, compute the
distance matrix between each pair of vectors.

For efficiency reasons, the euclidean distance between a pair of row
vector x and y is computed as::

dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y))

This formulation has two advantages over other ways of computing distances.
First, it is computationally efficient when dealing with sparse data.
Second, if one argument varies but the other remains unchanged, then
`dot(x, x)` and/or `dot(y, y)` can be pre-computed.

However, this is not the most precise way of doing this computation, and
the distance matrix returned by this function may not be exactly
symmetric as required by, e.g., ``scipy.spatial.distance`` functions.

Read more in the :ref:`User Guide <metrics>`.

Parameters
@@ -193,17 +183,43 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
Y_norm_squared : array-like, shape (n_samples_2, ), optional
Pre-computed dot-products of vectors in Y (e.g.,
``(Y**2).sum(axis=1)``)
May be ignored in some cases, see the note below.

squared : boolean, optional
Return squared Euclidean distances.

X_norm_squared : array-like, shape = [n_samples_1], optional
Pre-computed dot-products of vectors in X (e.g.,
``(X**2).sum(axis=1)``)
May be ignored in some cases, see the note below.

Returns
-------
distances : {array, sparse matrix}, shape (n_samples_1, n_samples_2)
distances : array, shape (n_samples_1, n_samples_2)

Member Author

I kept that change because the original docstring was wrong: the output is never sparse.

Notes
-----
When ``n_features > 32``, the euclidean distance between a pair of row
vectors x and y is computed as::
Member

vector -> vectors


dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y))

This formulation is computationally more efficient than the usual one and
can benefit from pre-computed ``dot(x, x)`` and/or ``dot(y, y)``. When the
input is stored in float32, computations are done by first upcasting ``X``
and ``Y`` to float64 (by chunks, to limit memory usage). In that case,
``X_norm_squared`` and ``Y_norm_squared`` are ignored and recomputed from the
upcast ``X`` and ``Y`` to keep good precision.

However, this is not the most precise way of doing this computation, and
the distance matrix returned by this function may not be exactly
symmetric as required by, e.g., ``scipy.spatial.distance`` functions.

When ``n_features <= 32``, the previous method is not as efficient and is
more likely to suffer from numerical instabilities, so the euclidean
distance between a pair of row vectors x and y is computed as::

dist(x, y) = sqrt(dot(x - y, x - y))

Examples
--------
@@ -224,41 +240,157 @@
"""
X, Y = check_pairwise_arrays(X, Y)

if X_norm_squared is not None:
XX = check_array(X_norm_squared)
if XX.shape == (1, X.shape[0]):
XX = XX.T
elif XX.shape != (X.shape[0], 1):
raise ValueError(
"Incompatible dimensions for X and X_norm_squared")
else:
XX = row_norms(X, squared=True)[:, np.newaxis]
XX, YY = _check_norms(X, Y, X_norm_squared, Y_norm_squared)

n_features = X.shape[1]

# For n_features > 32 we use the 'fast' method to compute the euclidean
# distance, i.e. d(x,y)² = ||x||² + ||y||² - 2 * x.y
# It's faster but less precise.
if n_features > 32:

if X is Y: # shortcut in the common case euclidean_distances(X, X)
YY = XX.T
elif Y_norm_squared is not None:
YY = np.atleast_2d(Y_norm_squared)
# To minimize precision issues with float32, we compute the distance
# matrix on chunks of X and Y upcast to float64
if X.dtype == np.float32:
distances = _euclidean_distances_upcast_fast(X, XX, Y, YY)

if YY.shape != (1, Y.shape[0]):
raise ValueError(
"Incompatible dimensions for Y and Y_norm_squared")
# if dtype is already float64, no need to chunk and upcast
else:
if X is Y and not issparse(X):
# In this case the distance matrix is symmetric, so we only
# need to compute half of it. When X is dense, we can benefit
# from the BLAS triangular matrix matrix multiplication `syrk`.
distances = _euclidean_dense_dense_fast_sym(X, XX)
else:
distances = - 2 * safe_sparse_dot(X, Y.T, dense_output=True)
_add_norms(distances, XX, YY)

# For n_features <= 32, we use the 'exact' method, i.e. the usual method,
# d(x,y)² = ||x - y||².
else:
YY = row_norms(Y, squared=True)[np.newaxis, :]

distances = safe_sparse_dot(X, Y.T, dense_output=True)
distances *= -2
distances += XX
distances += YY
np.maximum(distances, 0, out=distances)
# Since distances are computed between rows of X and Y, it's more efficient
# to work on C-contiguous arrays
if not issparse(X):
X = np.asarray(X, order='C')
if not issparse(Y):
Y = np.asarray(Y, order='C')

# Euclidean distance between 2 sparse vectors is very slow. It's much
# faster to densify one. We densify the smaller one for lower memory
# usage.
if issparse(X) and issparse(Y):
if Y.shape[0] > X.shape[0]:
X = X.toarray()
else:
Y = Y.toarray()

if issparse(X):
distances = _euclidean_sparse_dense_exact(
X.data, X.indices, X.indptr, Y, YY)
elif issparse(Y):
distances = _euclidean_sparse_dense_exact(
Y.data, Y.indices, Y.indptr, X, XX).T
else:
distances = _euclidean_dense_dense_exact(X, Y)

# Ensure that distances between vectors and themselves are set to 0.0.
# This may not be the case due to floating point rounding errors.
if X is Y:
# Ensure that distances between vectors and themselves are set to 0.0.
# This may not be the case due to floating point rounding errors.
distances.flat[::distances.shape[0] + 1] = 0.0
np.fill_diagonal(distances, 0)

return distances if squared else np.sqrt(distances, out=distances)


def _check_norms(X, Y=None, X_norm_squared=None, Y_norm_squared=None):
n_features = X.shape[1]

if n_features > 32 and X.dtype == np.float32:
# In this case, we compute euclidean distances by upcasting to float64.
# It's necessary to compute the norms on the upcast X, and not to upcast
# the norms computed on X to keep good precision, so we don't use
# provided norms.
return None, None
else:
if X_norm_squared is not None:
XX = np.atleast_1d(X_norm_squared).reshape(-1)
if XX.shape != (X.shape[0],):
raise ValueError(
"Incompatible dimensions for X and X_norm_squared")
else:
XX = row_norms(X, squared=True)

if X is Y: # shortcut in the common case euclidean_distances(X, X)
YY = XX
elif Y_norm_squared is not None:
YY = np.atleast_1d(Y_norm_squared).reshape(-1)
if YY.shape != (Y.shape[0],):
raise ValueError(
"Incompatible dimensions for Y and Y_norm_squared")
else:
YY = row_norms(Y, squared=True)

XX = XX.astype(X.dtype, copy=False)
YY = YY.astype(Y.dtype, copy=False)

return XX, YY


def _euclidean_distances_upcast_fast(X, XX, Y, YY):
"""Euclidean distances between X and Y

Assumes X and Y have float32 dtype.
X and Y are upcast to float64 by chunks, whose size is chosen to limit the
memory increase to approximately 10 MiB.
"""
n_samples_X = X.shape[0]
n_samples_Y = Y.shape[0]
n_features = X.shape[1]

distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32)

maxmem = 10 * 2**17  # this number of float64 values takes 10 MiB of memory.

x_density = X.getnnz() / np.prod(X.shape) if issparse(X) else 1
y_density = Y.getnnz() / np.prod(Y.shape) if issparse(Y) else 1

# The increase in memory usage is:
# - x_density * chunk_size * n_features (copy of chunk of X)
# - y_density * chunk_size * n_features (copy of chunk of Y)
# - chunk_size * chunk_size (chunk of distance matrix)
# Hence x² + (xd+yd)kx = M, where x=chunk_size, k=n_features, M=maxmem
# xd=x_density and yd=y_density
tmp = (x_density + y_density) * n_features
chunk_size = (-tmp + np.sqrt(tmp**2 + 4 * maxmem)) / 2
chunk_size = max(int(chunk_size), 1)
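# Illustrative example (not part of the PR): for dense data
# (x_density = y_density = 1) with n_features = 100,
#   tmp = 2 * 100 = 200
#   chunk_size = (-200 + sqrt(200**2 + 4 * 10 * 2**17)) / 2 ~= 1049
# and indeed 1049**2 + 2 * 1049 * 100 ~= 1.31e6 float64 values, i.e. ~10 MiB.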

n_samples_X_rem = n_samples_X % chunk_size
n_chunks_X = n_samples_X // chunk_size + (n_samples_X_rem != 0)
n_samples_Y_rem = n_samples_Y % chunk_size
n_chunks_Y = n_samples_Y // chunk_size + (n_samples_Y_rem != 0)

for i in range(n_chunks_X):
xs = i * chunk_size
xe = xs + (chunk_size if i < n_chunks_X - 1 else n_samples_X_rem)

X_chunk = X[xs:xe].astype(np.float64)
XX_chunk = row_norms(X_chunk, squared=True)
Member

you won't use the XX and YY that were passed in?

Member Author

In that special case, I won't indeed. The reason is that I only take that path when n_features > 32 and dtype=float32 (for float64 there's no need to upcast, so no need to chunk). In that case I can't use norms computed on float32 data: this is the main reason for the loss of precision. So I need to first upcast X and then compute the norms.
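
A rough sketch of why the float32 norms can't be reused (illustrative only, not from the PR): with offset data, the float32 norm carries an absolute error that can dwarf the true squared distance in the expanded formula.

import numpy as np

rng = np.random.RandomState(0)
X = (rng.random_sample((1, 100)) + 10000).astype(np.float32)
Y = X + np.float32(0.01)                     # true squared distance is tiny

X64, Y64 = X[0].astype(np.float64), Y[0].astype(np.float64)

def expanded_sq_dist(x, y, xx, yy):
    # d(x, y)^2 = ||x||^2 - 2 x.y + ||y||^2
    return xx - 2 * np.dot(x, y) + yy

# Norms computed in float32, then upcast (what reusing XX/YY would amount to).
xx32 = np.float64((X[0] ** 2).sum())
yy32 = np.float64((Y[0] ** 2).sum())

# Norms recomputed on the upcast data.
xx64 = (X64 ** 2).sum()
yy64 = (Y64 ** 2).sum()

print(expanded_sq_dist(X64, Y64, xx32, yy32))   # with float32 norms
print(expanded_sq_dist(X64, Y64, xx64, yy64))   # with float64 norms
print(((X64 - Y64) ** 2).sum())                 # direct reference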

Member

Am I correct to think that you ask for XX and YY to be passed in, but don't use them?

Member Author

Yes I fixed that. I do use them now if they are in float64.


for j in range(n_chunks_Y):
ys = j * chunk_size
ye = ys + (chunk_size if j < n_chunks_Y - 1 else n_samples_Y_rem)

Y_chunk = Y[ys:ye].astype(np.float64)
YY_chunk = row_norms(Y_chunk, squared=True)

d = - 2 * safe_sparse_dot(X_chunk, Y_chunk.T, dense_output=True)
_add_norms(d, XX_chunk, YY_chunk)

distances[xs:xe, ys:ye] = d.astype(np.float32)

return distances


def _argmin_min_reduce(dist, start):
indices = dist.argmin(axis=1)
values = dist[np.arange(dist.shape[0]), indices]