[WIP] Providing stable implementation for euclidean_distances #10069
@@ -222,31 +222,9 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
     """
     X, Y = check_pairwise_arrays(X, Y)

-    if X_norm_squared is not None:
-        XX = check_array(X_norm_squared)
-        if XX.shape == (1, X.shape[0]):
-            XX = XX.T
-        elif XX.shape != (X.shape[0], 1):
-            raise ValueError(
-                "Incompatible dimensions for X and X_norm_squared")
-    else:
-        XX = row_norms(X, squared=True)[:, np.newaxis]

-    if X is Y:  # shortcut in the common case euclidean_distances(X, X)
-        YY = XX.T
-    elif Y_norm_squared is not None:
-        YY = np.atleast_2d(Y_norm_squared)

-        if YY.shape != (1, Y.shape[0]):
-            raise ValueError(
-                "Incompatible dimensions for Y and Y_norm_squared")
-    else:
-        YY = row_norms(Y, squared=True)[np.newaxis, :]
+    diff = X.reshape(-1, 1, X.shape[1]) - Y.reshape(1, -1, X.shape[1])
+    distances = np.sum(np.power(diff, 2), axis=2)
Can you please try using np.dot to calculate the sum of squares?
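One way to read this suggestion is to keep the explicit differences but let a dot-product style contraction do the squaring and summing, so no separate `np.power(diff, 2)` temporary is needed. The sketch below uses `np.einsum` for the batched case (a plain `np.dot` covers a single pair of vectors). The helper name `pairwise_sq_dists` is hypothetical and not part of this PR, and the approach is dense-only.

```python
import numpy as np


def pairwise_sq_dists(X, Y):
    """Squared Euclidean distances from explicit differences (dense only)."""
    # diff[i, j, :] = X[i] - Y[j]; memory grows as n_x * n_y * n_features.
    diff = X[:, np.newaxis, :] - Y[np.newaxis, :, :]
    # Contract the feature axis of diff with itself: a batched dot product.
    return np.einsum("ijk,ijk->ij", diff, diff)


rng = np.random.RandomState(0)
X = rng.rand(5, 3)
Y = rng.rand(4, 3)

D = pairwise_sq_dists(X, Y)
# Same quantity as the np.sum(np.power(diff, 2), axis=2) form in the diff.
D_ref = np.sum(np.power(X[:, None, :] - Y[None, :, :], 2), axis=2)
# For one pair of rows, np.dot on the difference vector gives the same value.
d00 = np.dot(X[0] - Y[0], X[0] - Y[0])
```

The intermediate `diff` array still has shape `(n_x, n_y, n_features)`, so this only addresses the squaring step, not the memory cost of the broadcasted subtraction.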
-    distances = safe_sparse_dot(X, Y.T, dense_output=True)
-    distances *= -2
-    distances += XX
-    distances += YY
-    np.maximum(distances, 0, out=distances)

     if X is Y:
@@ -384,6 +384,15 @@ def test_euclidean_distances():
     assert_array_almost_equal(D, [[1., 2.]])

+    rng = np.random.RandomState(0)
+    #check if it works for float32
+    X = rng.rand(1,3000000).astype(np.float32)
Is there a specific reason for such a complex test? It seems that the current problem can be reproduced with a very simple array, e.g.:

arr1_32 = np.array([555.5, 666.6, 777.7], dtype=np.float32)
arr2_32 = np.array([555.6, 666.7, 777.8], dtype=np.float32)
arr1_64 = np.array([555.5, 666.6, 777.7], dtype=np.float64)
arr2_64 = np.array([555.6, 666.7, 777.8], dtype=np.float64)
euclidean_distances(arr1_32.reshape(1, -1), arr2_32.reshape(1, -1))
# array([[ 0.17319804]], dtype=float32)
euclidean_distances(arr1_64.reshape(1, -1), arr2_64.reshape(1, -1))
# array([[ 0.17320508]])

Well, I thought it would be better to show that it fails with a random array. No particular reason, though. I'll change it to a simpler test if it's a problem.
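To illustrate why the small arrays above are enough to expose the problem, here is a minimal NumPy-only sketch that compares the expanded form (squared norms minus twice the dot product, as in the removed code) with the difference form, both in float32. It is an assumption-laden illustration rather than the library's code path: the exact float32 result of the expanded form depends on how the dot products are accumulated, so no specific value is claimed for it.

```python
import numpy as np

x64 = np.array([555.5, 666.6, 777.7], dtype=np.float64)
y64 = np.array([555.6, 666.7, 777.8], dtype=np.float64)
x32 = x64.astype(np.float32)
y32 = y64.astype(np.float32)

# Reference: subtract first, then square and sum, all in float64.
ref = np.sqrt(np.sum((x64 - y64) ** 2))

# Expanded form in float32: each term is about 1.36e6, where adjacent float32
# values are 0.125 apart, while the true squared distance is only about 0.03.
sq_expanded = np.dot(x32, x32) - 2 * np.dot(x32, y32) + np.dot(y32, y32)
# Clamp at zero because rounding can make the expanded sum slightly negative.
d_expanded = np.sqrt(np.maximum(sq_expanded, 0))

# Difference form in float32: the subtraction happens before any large
# intermediate value is formed, so the small result keeps its precision.
d_direct = np.sqrt(np.sum((x32 - y32) ** 2))

print("float64 reference:", ref)
print("float32 expanded: ", d_expanded)
print("float32 direct:   ", d_direct)
```

The point of the comparison is the cancellation: three numbers of magnitude around a million are combined to produce a result near 0.03, which float32 cannot resolve at that scale.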
+    Y = rng.rand(1,3000000).astype(np.float32)
+    #answer computed by np.linalg.norm
+    dist_np = np.linalg.norm(X-Y)[np.newaxis, np.newaxis]
+    #answer computed by scikit-learn
+    dist_sk = euclidean_distances(X, Y)
+    assert_almost_equal(dist_sk, dist_np)

     X = rng.random_sample((10, 4))
     Y = rng.random_sample((20, 4))
     X_norm_sq = (X ** 2).sum(axis=1).reshape(1, -1)
We need to support sparse matrices, which this does not.

Oh, OK, I'll fix this.
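For context on the sparse-matrix concern: the removed code only needs a matrix product and row norms, both of which have sparse equivalents, whereas the broadcasted `reshape` subtraction requires dense arrays. The sketch below shows that dot-product formulation working on SciPy sparse input. The helper names `sq_dists_sparse_ok` and `sq_row_norms` are hypothetical, and this is not the fix adopted in the PR; it also inherits the float32 precision issue the PR set out to address.

```python
import numpy as np
from scipy import sparse


def sq_row_norms(A):
    """Squared L2 norm of each row, for dense or sparse A."""
    if sparse.issparse(A):
        # Element-wise square stays sparse; the row sum comes back dense.
        return np.asarray(A.multiply(A).sum(axis=1)).ravel()
    return np.einsum("ij,ij->i", A, A)


def sq_dists_sparse_ok(X, Y):
    """Squared distances via ||x||^2 - 2 x.y + ||y||^2; X, Y may be sparse."""
    prod = X @ Y.T
    if sparse.issparse(prod):
        prod = prod.toarray()
    D = -2.0 * np.asarray(prod)
    D += sq_row_norms(X)[:, np.newaxis]
    D += sq_row_norms(Y)[np.newaxis, :]
    # Rounding can leave tiny negative values; clamp them to zero.
    np.maximum(D, 0, out=D)
    return D


rng = np.random.RandomState(0)
X = rng.rand(6, 5)
Y = rng.rand(4, 5)
X[X < 0.5] = 0.0  # introduce zeros so the sparse format is meaningful
Y[Y < 0.5] = 0.0

D_sparse = sq_dists_sparse_ok(sparse.csr_matrix(X), sparse.csr_matrix(Y))
# Dense reference computed from explicit differences.
D_dense = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=2)
```

A stable implementation that also supports sparse input would need to combine both properties, for example by branching on `sparse.issparse` and keeping the difference form for dense data only.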