FIX euclidean_distances float32 numerical instabilities #13554

Merged: 105 commits, merged Apr 29, 2019
d38d8a0
vect vect euclidean
jeremiedbb Mar 4, 2019
01767d6
minmem
jeremiedbb Apr 1, 2019
83cada9
docstring
jeremiedbb Apr 1, 2019
ecb3d2c
lint
jeremiedbb Apr 1, 2019
6b140db
unrelated
jeremiedbb Apr 1, 2019
558b988
update tests + fix int
jeremiedbb Apr 1, 2019
de9d217
fix chunk size
jeremiedbb Apr 1, 2019
395be7a
fix sparse slice bounds
jeremiedbb Apr 2, 2019
9c79570
fix what's new
jeremiedbb Apr 2, 2019
4599379
tmp
jeremiedbb Apr 5, 2019
7425821
use gen_batches
jeremiedbb Apr 5, 2019
a2504af
nitpick
jeremiedbb Apr 5, 2019
fee1258
adress what's comments + clean naming
jeremiedbb Apr 12, 2019
fe74973
clearer comment
jeremiedbb Apr 15, 2019
7f5f257
Update doc/whats_new/v0.21.rst
glemaitre Apr 25, 2019
46fb590
Update doc/whats_new/v0.21.rst
glemaitre Apr 25, 2019
d8a2341
Update doc/whats_new/v0.21.rst
glemaitre Apr 25, 2019
e0a0dc5
DOC Fix missing space after backquotes (#13551)
Framartin Apr 1, 2019
4b6e707
FIX Explicitly ignore SparseEfficiencyWarning in DBSCAN (#13539)
peay Apr 1, 2019
dd634db
FIX _estimate_mi discrete_features str and value check (#13497)
hermidalc Apr 1, 2019
0e916a5
DOC Changed the docstring of class LinearSVR to reflect the default v…
mani2106 Apr 2, 2019
78cf2c6
DOC Added doc on how to override estimator tags (#13550)
NicolasHug Apr 2, 2019
1a96e1e
MNT Delete _scipy_sparse_lsqr_backport.py (#13569)
qinhanmin2014 Apr 3, 2019
46c4560
DOC Clarify eps parameter importance in DBSCAN (#13563)
kno10 Apr 3, 2019
dbbd855
DOC Minor typo in GridSearchCV (#13571)
NicolasHug Apr 3, 2019
77b456f
DOC add NicolasHug and thomasjpfan in authors list
TomDLT Apr 3, 2019
6dace0c
[MRG] Changed name n_components to n_connected_components in Agglomer…
Apr 4, 2019
37ef0c3
FIX non deterministic behavior in SGD (#13422)
ClemDoum Apr 4, 2019
1d03e13
ENH Allow nd array for CalibratedClassifierCV (#13485)
wdevazelhes Apr 4, 2019
1d5dca9
DOC reference tree structure example from user guide (#13561)
mani2106 Apr 4, 2019
bf1af7d
DOC correct reference to target in load_linnerud docstring (#13577)
mwestt Apr 4, 2019
b4fb670
FIX DummyEstimator when y is a 2d column vector (#13545)
adrinjalali Apr 5, 2019
2de3969
FIX Take sample weights into account in partial dependence computatio…
samronsin Apr 5, 2019
800ecae
DOC Add lucidfrontier to the emeritus core devs (#13579)
amueller Apr 6, 2019
5eff972
FIX MissingIndicator explicit zeros & output shape (#13562)
jeremiedbb Apr 6, 2019
5ebcef0
ENH cross_val_predict now handles multi-output predict_proba (#8773)
Apr 6, 2019
532f110
Use fixed random seed for generating X in test_mlp.test_gradient() (#…
aditya1702 Apr 7, 2019
37ad253
DOC fix typo in comments for svm/classes.py (#13589)
kfrncs Apr 7, 2019
984719c
DOC fix typo in contributing.rst (#13593)
mabdelaal86 Apr 7, 2019
d0b7441
Improve pipeline parameter error msg (#13536)
okz12 Apr 9, 2019
14a3f16
FEA Parameter for stacking missing indicator into imputer (#12583)
DanilBaibak Apr 9, 2019
1fd5b71
ENH Convert y in GradientBoosting to float64 instead of float32 (#13524)
adrinjalali Apr 9, 2019
7558d50
TST Fixes to make test_pprint.py more resilient to change (#13529)
Apr 9, 2019
b4d0527
FIX Fixed array equality check in pprint (#13584)
NicolasHug Apr 9, 2019
adceb7d
FEA VotingRegressor (#12513)
Apr 10, 2019
127cc41
CI update PyPy image to pypy-3-7.0.0 (#13600)
adrinjalali Apr 10, 2019
37b099b
Fix empty clusters not correctly relocated when using sample_weight(#…
jeremiedbb Apr 12, 2019
bdea46f
DOC Add project_urls to setup.py (#13623)
jarrodmillman Apr 12, 2019
d91682a
MNT Use a common language_level cython directive (#13630)
jeremiedbb Apr 13, 2019
d7815e3
[MRG] DOC Correct coef_ shape in RidgeClassifier (#13633)
qinhanmin2014 Apr 13, 2019
8a966f4
DOC Adds recent core devs to _contributors.rst (#13640)
thomasjpfan Apr 14, 2019
96ff3b0
DOC t-SNE perplexity docstring update (#13069)
kjacks21 Apr 14, 2019
c21be57
FIX feature importances in random forests sum up to 1 (#13636)
adrinjalali Apr 15, 2019
c3e72f2
DOC Removed a typo from the examples of normalized_mutual_info_score …
jfbeaumont Apr 15, 2019
78cb1b3
MAINT: n_jobs=-1 replaced with n_jobs=4 in tests (#13644)
oleksandr-pavlyk Apr 15, 2019
0e4f561
Add parameter for stacking missing indicator into iterative imputer (…
DanilBaibak Apr 15, 2019
383b132
FIX ignore single node trees in gbm's feature importances (#13620)
adrinjalali Apr 16, 2019
3801caf
DOC typo in sklearn.utils.extmath.weighted_mode (#13655)
Masstran Apr 16, 2019
5dc1c46
ENH Extension of v_measure_score metric to include beta parameter (#1…
Apr 16, 2019
b928396
BUG Fix missing 'const' in a few memoryview declaration in trees. (#1…
jeremiedbb Apr 16, 2019
b957edd
BLD: check OpenMP support and add a switch to build without it (#13543)
jeremiedbb Apr 16, 2019
c0fb225
FIX initialise Birch centroid_ in all cases (#13651)
jnothman Apr 16, 2019
95d39ca
DOC Improve performance of the plot_rbm_logistic_classification.py ex…
Framartin Apr 17, 2019
f11f647
MNT Import linear assignment from scipy (#13465)
praths007 Apr 17, 2019
306d202
DOC Fixes formatting issue in webkit (#13657)
thomasjpfan Apr 17, 2019
78c9cb5
Fixing parameter description (for assume_centered) (#13456)
falaktheoptimist Apr 17, 2019
66fa659
MAINT Unvendor joblib (#13531)
rth Apr 17, 2019
f229708
Improve comment in setup.py (#13661)
kfrncs Apr 17, 2019
6b29d38
MAINT Replace absolute imports with relative imports (#13653)
aditya1702 Apr 17, 2019
876908e
[MRG] Fix various solver issues in ridge_regression and Ridge classes…
btel Apr 18, 2019
04fcbce
[MRG + 1] Fix pprint ellipsis (#13436)
NicolasHug Apr 18, 2019
cc2c186
DOC Remove space in "cross-entropy" (#13671)
orestisfl Apr 18, 2019
02e864e
FIX broken references in documentation (#13664)
ogrisel Apr 18, 2019
740c0fb
DOC Remove synonyms in documentation of linear models (#13663)
tommyod Apr 18, 2019
247f0e7
MNT Correctly handle deprecated attribute warnings and docstrings (#1…
NicolasHug Apr 18, 2019
62b5a85
Deprecate "warn_on_dtype" from check_array (#13382)
praths007 Apr 19, 2019
3b54222
Fix MultiOutputClassifier checking for predict_proba method of base e…
rebekahkim Apr 19, 2019
63cd7d4
fix small latex issue (#13680)
NicolasHug Apr 19, 2019
e8be8aa
Fix sample_weight in label_ranking_average_precision_score (#13447)
dpwe Apr 20, 2019
78acc98
DOC Emeritus core devs final call (#13673)
amueller Apr 20, 2019
2a2caff
ENH Add verbose option to Pipeline, FeatureUnion, and ColumnTransform…
thomasjpfan Apr 21, 2019
5a1f05a
MNT Minor clean up in OneVsOneClassifier (#13677)
qinhanmin2014 Apr 22, 2019
e5a8e0b
DOC: Fixes for latest numpydoc (#13670)
larsoner Apr 22, 2019
0ca7c96
Typo (#13693)
jarrodmillman Apr 22, 2019
7005ab4
FIX make sure vectorizers read data from file before analyzing (#13641)
adrinjalali Apr 23, 2019
8e3cc5e
DOC: Add SLEP and Governance in Dev Docs (#13688)
bharatr21 Apr 23, 2019
18dc3a9
MAINT: minor fix to whats_new (#13695)
adrinjalali Apr 23, 2019
32851c1
DOC Describe what's new categories (#13697)
jnothman Apr 23, 2019
f818d86
MNT Minor clean up in OneVsRestClassifier (#13675)
qinhanmin2014 Apr 23, 2019
32ea647
MAINT: reduce example execution time of plot_image_denoising.py (#13683)
xinyuliu12 Apr 23, 2019
9cf6606
DOC better wording in changelog legend
jnothman Apr 24, 2019
8d4bffc
ENH Support Haversine distance in pairwise_distances (#12568)
eamanu Apr 24, 2019
5135252
FEA Partial dependence plots (#12599)
NicolasHug Apr 24, 2019
e15e391
MAINT Uses debian stretch for circleci doc building (#13642)
thomasjpfan Apr 24, 2019
9a5b3ac
FEA Add a stratify option to utils.resample (#13549)
NicolasHug Apr 24, 2019
77ac3df
additional tests for mean_shift algo (#13179)
rajdeepd Apr 25, 2019
184cd2c
DOC use :pr: rather than :issue: in what's new (#13701)
jnothman Apr 25, 2019
fbcfc52
DOC fix remaining :issue: (#13716)
glemaitre Apr 25, 2019
e5e2848
vect vect euclidean
jeremiedbb Mar 4, 2019
32c9f99
merge
jeremiedbb Apr 25, 2019
0f346f5
adress comments
jeremiedbb Apr 25, 2019
fce0713
Merge remote-tracking branch 'upstream/master' into euclidean_dist_up…
jeremiedbb Apr 25, 2019
117529e
astype=False per Roman's comment
jnothman Apr 29, 2019
9c205c1
Mib -> MiB
jeremiedbb Apr 29, 2019
985037a
back to copy=False
jeremiedbb Apr 29, 2019
11 changes: 8 additions & 3 deletions doc/whats_new/v0.21.rst
@@ -347,9 +347,14 @@ Support for Python 3.4 and below has been officially dropped.
:issue:`12855` by :user:`Pawel Sendyk <psendyk>`.

- |API| The parameter ``labels`` in :func:`metrics.hamming_loss` is deprecated
in version 0.21 and will be removed in version 0.23.
:issue:`10580` by :user:`Reshama Shaikh <reshamas>` and :user:`Sandra
Mitrovic <SandraMNE>`.
in version 0.21 and will be removed in version 0.23. :issue:`10580` by
:user:`Reshama Shaikh <reshamas>` and :user:`Sandra Mitrovic <SandraMNE>`.

- |Fix| The function :func:`euclidean_distances`, and therefore the functions
Comment (Member):
maybe more relevant: "and therefore several estimators with metric='euclidean'"

:func:`pairwise_distances` and :func:`pairwise_distances_chunked` with
``metric=euclidean``, suffered from numerical precision issues. Precision has
Comment (Member):
"with float32 features"

been increased for float32. :issue:`13554` by :user:` <Celelibi>` and
Comment (Member):
I think correct syntax is just :user:`Celelibi`

:user:`Jérémie du Boisberranger <jeremiedbb>`.

- |API| :func:`metrics.jaccard_similarity_score` is deprecated in favour of
the more consistent :func:`metrics.jaccard_score`. The former behavior for
99 changes: 90 additions & 9 deletions sklearn/metrics/pairwise.py
@@ -203,7 +203,7 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,

Returns
-------
distances : {array, sparse matrix}, shape (n_samples_1, n_samples_2)
distances : array, shape (n_samples_1, n_samples_2)
Comment (Member):
I put the comment here but I think that it should be written above.

I would either add a note (using sphinx syntax) or mention in the parameter descriptions that X_norm_squared and Y_norm_squared will be ignored if they are passed as float32, since they would lead to inaccurate distances.
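A possible wording for such a note, for illustration only (the exact phrasing is not taken from the PR), shown here as a reST fragment held in a Python string:

note_sketch = """
.. note::
    ``X_norm_squared`` and ``Y_norm_squared`` are ignored if they are passed
    as float32, because they may yield inaccurate distances; in that case the
    squared norms are recomputed in float64 from the data.
"""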


Examples
--------
@@ -231,34 +231,115 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
elif XX.shape != (X.shape[0], 1):
raise ValueError(
"Incompatible dimensions for X and X_norm_squared")
if XX.dtype == np.float32:
XX = None
Comment (Member):
should we raise a warning that these cannot be used??

Comment (Member Author):
Why not, but I'm afraid it could generate a lot of warnings, for instance when pairwise_distances is called in a loop.

One thing we should certainly do is add an option to row_norms to use a float64 accumulator. That's on my todo list :)
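A minimal sketch of what such an option could look like (a hypothetical helper, not part of this PR; the real row_norms lives in sklearn.utils.extmath):

import numpy as np
from scipy import sparse


def row_norms_float64(X, squared=False):
    # Squared row norms accumulated in float64, even for float32 input,
    # to avoid the precision loss discussed above.
    if sparse.issparse(X):
        X = X.tocsr()
        rows = np.repeat(np.arange(X.shape[0]), np.diff(X.indptr))
        norms = np.zeros(X.shape[0], dtype=np.float64)
        # accumulate the squared non-zero values row by row, in float64
        np.add.at(norms, rows, X.data.astype(np.float64) ** 2)
    else:
        # einsum with an explicit float64 dtype forces double-precision
        # accumulation of the dot products
        norms = np.einsum('ij,ij->i', X, X, dtype=np.float64)
    return norms if squared else np.sqrt(norms)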

elif X.dtype == np.float32:
XX = None
else:
XX = row_norms(X, squared=True)[:, np.newaxis]

if X is Y: # shortcut in the common case euclidean_distances(X, X)
if X is Y and XX is not None:
# shortcut in the common case euclidean_distances(X, X)
YY = XX.T
elif Y_norm_squared is not None:
YY = np.atleast_2d(Y_norm_squared)

if YY.shape != (1, Y.shape[0]):
raise ValueError(
"Incompatible dimensions for Y and Y_norm_squared")
if YY.dtype == np.float32:
YY = None
elif Y.dtype == np.float32:
YY = None
Comment (Member):
At first glance, it is not obvious that None is used by _euclidean_distances_upcast.

Comment (Member Author):
What do you mean? XX and YY are used by _euclidean_distances_upcast no matter their value.

Comment (Member):
Setting XX or YY to None tells _euclidean_distances_upcast to call row_norms(*_chunk, squared=True)[:, np.newaxis].

Without reading _euclidean_distances_upcast, it is difficult to tell why XX or YY are set to None.

Comment (Member):
I think that we would need a small comment to clarify this part.
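For instance, something along these lines next to the assignment (the wording is only a suggestion, not taken from the PR):

if XX.dtype == np.float32:
    # float32 norms are not precise enough; discard them so that
    # _euclidean_distances_upcast recomputes them in float64,
    # chunk by chunk, from the upcast data
    XX = None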

else:
YY = row_norms(Y, squared=True)[np.newaxis, :]

distances = safe_sparse_dot(X, Y.T, dense_output=True)
distances *= -2
distances += XX
distances += YY
if X.dtype == np.float32:
# To minimize precision issues with float32, we compute the distance
# matrix on chunks of X and Y upcast to float64
distances = _euclidean_distances_upcast_fast(X, XX, Y, YY)
else:
# if dtype is already float64, no need to chunk and upcast
distances = - 2 * safe_sparse_dot(X, Y.T, dense_output=True)
distances += XX
distances += YY
np.maximum(distances, 0, out=distances)

# Ensure that distances between vectors and themselves are set to 0.0.
Comment (Member):
why is this out of the if? It's only applicable if the if holds.

(Curiously we currently appear to duplicate this logic in pairwise_distances_chunked but not in pairwise_distances)

Comment (Member Author):
you mean zeroing the diagonal? It applies to both methods.

It's necessary to also do it in pairwise_distances_chunked, because we lose the diagonal info when passing chunks to pairwise_distances.
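A small illustration of the point (made-up data; the residual on the diagonal may or may not be exactly zero depending on the values):

import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

rng = np.random.RandomState(0)
X = rng.random_sample((5, 3)).astype(np.float32)

# full computation: X is Y, so the diagonal is explicitly set to 0
print(np.diag(euclidean_distances(X, X)))

# chunk against the full X: the "X is Y" shortcut no longer applies, so the
# self-distances are only approximately zero (rounding noise) unless the
# caller re-zeroes them afterwards
chunk = euclidean_distances(X[2:4], X)
print(chunk[0, 2], chunk[1, 3])  # the would-be diagonal entries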

Comment (Member):
Yes, but I don't think we zero the diagonal in pairwise_distances at all for other metrics

Comment (Member Author):
I was looking at euclidean_distances, not pairwise_distances... That's strange indeed. Even for metric=euclidean it's not enforced when n_jobs > 1. Let me put that also on my todo list :)

# This may not be the case due to floating point rounding errors.
if X is Y:
# Ensure that distances between vectors and themselves are set to 0.0.
# This may not be the case due to floating point rounding errors.
distances.flat[::distances.shape[0] + 1] = 0.0
np.fill_diagonal(distances, 0)

return distances if squared else np.sqrt(distances, out=distances)


def _euclidean_distances_upcast_fast(X, XX=None, Y=None, YY=None):
Comment (Member):
"_fast" is a bit obscure

Comment (Member Author):
it's a residue of my previous PR implementing the fast and exact methods. I removed it.

"""Euclidean distances between X and Y

Assumes X and Y have float32 dtype.
Assumes XX and YY have float64 dtype or are None.

X and Y are upcast to float64 in chunks, whose size is chosen to limit
memory increase by approximately 10%.
Comment (Member):
Suggested change
memory increase by approximately 10%.
memory increase by approximately 10% (at least 10 MiB).

"""
n_samples_X = X.shape[0]
n_samples_Y = Y.shape[0]
n_features = X.shape[1]

distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32)

x_density = X.nnz / np.prod(X.shape) if issparse(X) else 1
y_density = Y.nnz / np.prod(Y.shape) if issparse(Y) else 1

# Allow 10% more memory than X, Y and the distance matrix take (at least
# 10Mib)
Comment (Member):
Can we use MiB (or MB), as for all other cache sizes in scikit-learn? The Mib unit can lead to confusion when comparing to the CPU L3 cache size or to working_memory.

maxmem = max(
((x_density * n_samples_X + y_density * n_samples_Y) * n_features
+ (x_density * n_samples_X * y_density * n_samples_Y)) / 10,
Comment (@rth, Member, Apr 27, 2019):
This makes the assumption that sizeof(dtype) == 8 without making it appear in the computation, I think?

Maybe we could add X.dtype.itemsize / 8 even if that is equal to 1, to make it easier to follow what is going on.

Also not that it matters too much, but

  1. for dense, isn't that 5% of (X, Y, distances), as those are float32? i.e. X.nbytes == x_density * n_samples_X * 4 / 8?
  2. for sparse it's roughly 10% if it is CSR and we account for X.data and X.indptr in 32 bit, but again it's far from obvious for a casual reader.

Actually, maybe,

from scipy.sparse import issparse

def get_array_nbytes(X):
    if issparse(X):
        if hasattr(X, 'indices'):
            # CSR or CSC
            return X.data.nbytes + X.indices.nbytes + X.indptr.nbytes
        else:
            # some other sparse format, assume 8 bytes to index
            # one non zero element (e.g. COO)
            return X.data.nbytes + X.nnz * 8
    else:
        return X.nbytes

maxmem = (get_array_nbytes(X) + get_array_nbytes(Y) + get_array_nbytes(distances)) / 10

Comment (Member Author):
> This makes the assumption that sizeof(dtype) == 8 without making it appear in the computation, I think?

This function is only used for float32 X and Y. I've explained it in its docstring.

> for sparse it's roughly 10% if it is CSR and we account for X.data and X.indptr in 32 bit, but again it's far from obvious for a casual reader.

When you change the type of the csr matrix, only a copy of data is made. indices and indptr are still int, so I think this formula is correct

Comment (Member):
> When you change the type of the csr matrix, only a copy of data is made.

Are you sure?

>>> import scipy.sparse
>>> import numpy as np
>>> X = scipy.sparse.csr_matrix(np.ones((10, 10)))
>>> X.dtype
dtype('float64')
>>> Y = X.astype(np.float32)
>>> np.shares_memory(Y.data, Y.data)
True
>>> np.shares_memory(Y.data, X.data)
False
>>> np.shares_memory(Y.indices, X.indices)
False

That's not critical; my point here is that it might be better (not necessarily in this PR) to use ndarray.nbytes rather than trying to re-compute it from scratch. If we get this calculation wrong we won't likely know, since no test will fail and we will just be reasoning with the wrong memory sizes.

10 * 2**17)
Comment:
Why this choice?
In my tests for #11271 I found that the computation time slowly increases with the memory size. It's in the first few plots of this comment: #11271 (comment).
I'm mostly curious about this choice. I don't think this should prevent this PR from being merged.

Comment (Member Author):
A fixed 10Mib maxmem can lead to very small chunks when the number of features is very large (an extreme case, I agree). In that case the drop in performance can be quite severe.

Another possibility could be a fixed maxmem with a minimum chunk_size. I think it's kind of equivalent, but in that case the memory increase can be bigger than expected, whereas in the previous case the memory increase is bounded. That's why I went for the first one.

Comment:
Large chunks would also decrease cache efficiency. Although I'm not a big fan of having an unbounded amount of additional memory, 10% of the memory already used is probably fair.

When I look at my old benchmarks, I think that there might be something going on under the hood. Maybe the optimal block size is actually constant. Moreover, if the number of features is large, it can also be chunked. That might deserve to be explored as well.
I think we can add to the todo-list to benchmark the influence of the memory size and chunking to get a proper understanding of the performance.

Comment (Member):
If the input data is memmapped, calculating this as a function of n_features may not be fair.

Comment (Member):
I think my comment about memmapping was incorrect in any case.

Comment (Member):
This logic seems to be more complicated than necessary, but okay.

Comment (Member):
I am fine to merge with this heuristic.

In the future, I am wondering if we could expose a chunk_size parameter which by default would apply such a heuristic and could be fine-tuned by the user. @jeremiedbb mentioned to me that it might not be that easy, since it would impact pairwise_distances, where we would also need to expose this parameter, etc. Just a thought.


# The increase amount of memory is:
Comment (Member):
add "in 8-byte blocks"... in fact where do you account for this?

Comment (Member Author):
maxmem is the max allowed number of 8-byte blocks for extra memory usage.
10 MiB is 10 * 2**20 bytes = 10 * 2**17 8-byte blocks.
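To make the orders of magnitude concrete, here is the heuristic applied to made-up dense shapes (densities of 1; the numbers are purely illustrative):

import numpy as np

n_samples_X, n_samples_Y, n_features = 100_000, 100_000, 100

# extra memory allowance, counted in 8-byte (float64) elements
maxmem = max(
    ((n_samples_X + n_samples_Y) * n_features
     + n_samples_X * n_samples_Y) / 10,
    10 * 2**17)  # 10 MiB = 10 * 2**20 bytes = 10 * 2**17 8-byte blocks

# solve batch_size**2 + (xd + yd) * n_features * batch_size = maxmem
tmp = 2 * n_features  # (xd + yd) * n_features, with both densities equal to 1
batch_size = max(int((-tmp + np.sqrt(tmp ** 2 + 4 * maxmem)) / 2), 1)
print(batch_size)  # ~31500 rows per chunk for these shapes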

# - x_density * batch_size * n_features (copy of chunk of X)
# - y_density * batch_size * n_features (copy of chunk of Y)
# - batch_size * batch_size (chunk of distance matrix)
# Hence x² + (xd+yd)kx = M, where x=batch_size, k=n_features, M=maxmem
Comment (Member):
When a grade school student asks, "Why are we learning the quadratic equation?", we have an answer.

# xd=x_density and yd=y_density
tmp = (x_density + y_density) * n_features
batch_size = (-tmp + np.sqrt(tmp**2 + 4 * maxmem)) / 2
batch_size = max(int(batch_size), 1)
rth marked this conversation as resolved.

x_batches = gen_batches(X.shape[0], batch_size)
y_batches = gen_batches(Y.shape[0], batch_size)

for i, x_slice in enumerate(x_batches):
X_chunk = X[x_slice].astype(np.float64)
if XX is None:
XX_chunk = row_norms(X_chunk, squared=True)[:, np.newaxis]
else:
XX_chunk = XX[x_slice]

for j, y_slice in enumerate(y_batches):
if X is Y and j < i:
# when X is Y the distance matrix is symmetric so we only need
# to compute half of it.
d = distances[y_slice, x_slice].T
Comment (@rth, Member, Apr 27, 2019):
It must be correct since the corresponding tests pass, but after thinking about it for 20s I'm still not confident that when we compute distances[x_slice, y_slice], distances[y_slice, x_slice] is already computed, is not null and can be copied from.

How about just adding a continue here, then adding a second loop below:

if X is Y:
    # the result is symmetric: copy distances from the already-computed
    # blocks of the other triangle (x_batches / y_batches would need to be
    # lists to be iterated a second time)
    for i, x_slice in enumerate(x_batches):
        for j, y_slice in enumerate(y_batches):
            if j < i:
                distances[x_slice, y_slice] = distances[y_slice, x_slice].T

it shouldn't matter too much for run time I think, but it might make things a bit more explicit?

Comment (Member Author):
> still not confident that when we compute distances[x_slice, y_slice], distances[y_slice, x_slice] is already computed

It is, since we compute the chunks with i (rows) as the outer loop. This implies that every chunk of the upper right triangular part of the distance matrix is computed before its symmetric chunk in the lower left triangular part.
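A quick way to check that ordering with a toy example (gen_batches is the real sklearn.utils helper; the sizes are arbitrary):

from sklearn.utils import gen_batches

n, batch_size = 5, 2
order = []
for i, x_slice in enumerate(gen_batches(n, batch_size)):
    for j, y_slice in enumerate(gen_batches(n, batch_size)):
        order.append((i, j))

# every lower-triangle block (i, j) with j < i is visited after its
# symmetric upper-triangle block (j, i), so the copy always reads
# already-computed values
for i, j in order:
    if j < i:
        assert order.index((j, i)) < order.index((i, j))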


else:
Y_chunk = Y[y_slice].astype(np.float64)
if YY is None:
YY_chunk = row_norms(Y_chunk, squared=True)[np.newaxis, :]
else:
YY_chunk = YY[:, y_slice]

d = -2 * safe_sparse_dot(X_chunk, Y_chunk.T, dense_output=True)
d += XX_chunk
d += YY_chunk

distances[x_slice, y_slice] = d.astype(np.float32)
Comment (Member):
Add copy=False: in the case where d = distances[y_slice, x_slice].T, it is already float32, and astype makes a copy by default.

Comment (Member Author):
this function only computes euclidean distances in float64 and downcasts the result to float32 at the end, so there will always be a copy. Actually I find it clearer not to add the copy=False parameter, to emphasize that.
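For reference, the NumPy copy semantics this exchange hinges on (a standalone check, not code from the PR):

import numpy as np

d64 = np.ones((2, 2), dtype=np.float64)
d32 = np.ones((2, 2), dtype=np.float32)

# astype always copies when the dtype actually changes, even with copy=False
assert not np.shares_memory(d64.astype(np.float32, copy=False), d64)
# copy=False only avoids the copy when the array is already float32
assert np.shares_memory(d32.astype(np.float32, copy=False), d32)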


return distances


def _argmin_min_reduce(dist, start):
indices = dist.argmin(axis=1)
values = dist[np.arange(dist.shape[0]), indices]
113 changes: 93 additions & 20 deletions sklearn/metrics/tests/test_pairwise.py
@@ -540,41 +540,114 @@ def test_pairwise_distances_chunked():
assert_raises(StopIteration, next, gen)


def test_euclidean_distances():
# Check the pairwise Euclidean distances computation
X = [[0]]
Y = [[1], [2]]
@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
ids=["dense", "sparse"])
@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
ids=["dense", "sparse"])
def test_euclidean_distances_known_result(x_array_constr, y_array_constr):
# Check the pairwise Euclidean distances computation on known result
X = x_array_constr([[0]])
Y = y_array_constr([[1], [2]])
D = euclidean_distances(X, Y)
assert_array_almost_equal(D, [[1., 2.]])
assert_allclose(D, [[1., 2.]])

X = csr_matrix(X)
Y = csr_matrix(Y)
D = euclidean_distances(X, Y)
assert_array_almost_equal(D, [[1., 2.]])

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
ids=["dense", "sparse"])
def test_euclidean_distances_with_norms(dtype, y_array_constr):
# check that we still get the right answers with {X,Y}_norm_squared
# and that we get a wrong answer with wrong {X,Y}_norm_squared
rng = np.random.RandomState(0)
X = rng.random_sample((10, 4))
Y = rng.random_sample((20, 4))
X_norm_sq = (X ** 2).sum(axis=1).reshape(1, -1)
Y_norm_sq = (Y ** 2).sum(axis=1).reshape(1, -1)
X = rng.random_sample((10, 10)).astype(dtype, copy=False)
Y = rng.random_sample((20, 10)).astype(dtype, copy=False)

# norms will only be used if their dtype is float64
X_norm_sq = (X.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)
Y_norm_sq = (Y.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)

Y = y_array_constr(Y)

# check that we still get the right answers with {X,Y}_norm_squared
D1 = euclidean_distances(X, Y)
D2 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq)
D3 = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq)
D4 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq,
Y_norm_squared=Y_norm_sq)
assert_array_almost_equal(D2, D1)
assert_array_almost_equal(D3, D1)
assert_array_almost_equal(D4, D1)
assert_allclose(D2, D1)
assert_allclose(D3, D1)
assert_allclose(D4, D1)

# check we get the wrong answer with wrong {X,Y}_norm_squared
X_norm_sq *= 0.5
Y_norm_sq *= 0.5
wrong_D = euclidean_distances(X, Y,
X_norm_squared=np.zeros_like(X_norm_sq),
Y_norm_squared=np.zeros_like(Y_norm_sq))
assert_greater(np.max(np.abs(wrong_D - D1)), .01)
with pytest.raises(AssertionError):
assert_allclose(wrong_D, D1)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
ids=["dense", "sparse"])
@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
ids=["dense", "sparse"])
def test_euclidean_distances(dtype, x_array_constr, y_array_constr):
# check that euclidean distances gives same result as scipy cdist
# when X and Y != X are provided
rng = np.random.RandomState(0)
X = rng.random_sample((100, 10)).astype(dtype, copy=False)
X[X < 0.8] = 0
Y = rng.random_sample((10, 10)).astype(dtype, copy=False)
Y[Y < 0.8] = 0

expected = cdist(X, Y)

X = x_array_constr(X)
Y = y_array_constr(Y)
distances = euclidean_distances(X, Y)

assert_allclose(distances, expected, rtol=1e-6)
Comment (Member):
a small comment regarding rtol=1e-6 would be welcome

assert distances.dtype == dtype


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
ids=["dense", "sparse"])
def test_euclidean_distances_sym(dtype, x_array_constr):
# check that euclidean distances gives same result as scipy pdist
# when only X is provided
rng = np.random.RandomState(0)
X = rng.random_sample((100, 10)).astype(dtype, copy=False)
X[X < 0.8] = 0

expected = squareform(pdist(X))

X = x_array_constr(X)
distances = euclidean_distances(X)

assert_allclose(distances, expected, rtol=1e-6)
assert distances.dtype == dtype


@pytest.mark.parametrize("dtype, s",
Comment (Member):
Maybe we could have a better name for s

[(np.float32, 1e-4),
(np.float64, 1e-8)])
Comment (Member):
Suggested change
(np.float64, 1e-8)])
pytest.param(np.float64, 1e-8, marks=pytest.mark.xfail(reason='failing due to lack of precision'))])

@pytest.mark.parametrize("dim", [1, 1000000])
def test_euclidean_distances_extreme_values(dtype, s, dim):
# check that euclidean distances is correct where 'fast' method wouldn't be
Comment (Member):
fast method?

Comment (Member Author):
I made the comment clearer

# on float32 thanks to upcasting and is not on float64.
X = np.array([[1.] * dim], dtype=dtype)
Y = np.array([[1. + s] * dim], dtype=dtype)

distances = euclidean_distances(X, Y)
expected = cdist(X, Y)

if dtype == np.float64:
# This is expected to fail for float64 due to lack of precision of the
# fast method of euclidean distances.
with pytest.raises(AssertionError, match='Not equal to tolerance'):
assert_allclose(distances, expected, rtol=(1 - 0.001))
else:
assert_allclose(distances, expected, rtol=1e-5)


def test_cosine_distances():