
Numerical precision of euclidean_distances with float32 #9354

Closed
mikeroberts3000 opened this issue Jul 13, 2017 · 102 comments · Fixed by #13554


@mikeroberts3000

mikeroberts3000 commented Jul 13, 2017

Description

I noticed that the sklearn.metrics.pairwise.pairwise_distances function agrees with np.linalg.norm when using np.float64 arrays, but disagrees when using np.float32 arrays. See the code snippet below.

Steps/Code to Reproduce

import numpy as np
import scipy
import sklearn.metrics.pairwise

# create 64-bit vectors a and b that are very similar to each other
a_64 = np.array([61.221637725830078125, 71.60662841796875,    -65.7512664794921875],  dtype=np.float64)
b_64 = np.array([61.221637725830078125, 71.60894012451171875, -65.72847747802734375], dtype=np.float64)

# create 32-bit versions of a and b
a_32 = a_64.astype(np.float32)
b_32 = b_64.astype(np.float32)

# compute the distance from a to b using numpy, for both 64-bit and 32-bit
dist_64_np = np.array([np.linalg.norm(a_64 - b_64)], dtype=np.float64)
dist_32_np = np.array([np.linalg.norm(a_32 - b_32)], dtype=np.float32)

# compute the distance from a to b using sklearn, for both 64-bit and 32-bit
dist_64_sklearn = sklearn.metrics.pairwise.pairwise_distances([a_64], [b_64])
dist_32_sklearn = sklearn.metrics.pairwise.pairwise_distances([a_32], [b_32])

# note that the 64-bit sklearn results agree exactly with numpy, but the 32-bit results disagree
np.set_printoptions(precision=200)

print(dist_64_np)
print(dist_32_np)
print(dist_64_sklearn)
print(dist_32_sklearn)

Expected Results

I expect that the results from sklearn.metrics.pairwise.pairwise_distances would agree with np.linalg.norm for both 64-bit and 32-bit. In other words, I expect the following output:

[ 0.0229059506440019884643266578905240749008953571319580078125]
[ 0.02290595136582851409912109375]
[[ 0.0229059506440019884643266578905240749008953571319580078125]]
[[ 0.02290595136582851409912109375]]

Actual Results

The code snippet above produces the following output for me:

[ 0.0229059506440019884643266578905240749008953571319580078125]
[ 0.02290595136582851409912109375]
[[ 0.0229059506440019884643266578905240749008953571319580078125]]
[[ 0.03125]]

Versions

Darwin-16.6.0-x86_64-i386-64bit
('Python', '2.7.11 | 64-bit | (default, Jun 11 2016, 03:41:56) \n[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)]')
('NumPy', '1.11.3')
('SciPy', '0.19.0')
('Scikit-Learn', '0.18.1')
@nvauquie

Same results with Python 3.5:

Darwin-15.6.0-x86_64-i386-64bit
Python 3.5.1 (v3.5.1:37a07cee5969, Dec  5 2015, 21:12:44) 
[GCC 4.2.1 (Apple Inc. build 5666) (dot 3)]
NumPy 1.11.0
SciPy 0.18.1
Scikit-Learn 0.17.1

It happens only with the Euclidean distance and can be reproduced directly using sklearn.metrics.pairwise.euclidean_distances:

import numpy as np
import scipy
import sklearn.metrics.pairwise

# create 64-bit vectors a and b that are very similar to each other
a_64 = np.array([61.221637725830078125, 71.60662841796875,    -65.7512664794921875],  dtype=np.float64)
b_64 = np.array([61.221637725830078125, 71.60894012451171875, -65.72847747802734375], dtype=np.float64)

# create 32-bit versions of a and b
a_32 = a_64.astype(np.float32)
b_32 = b_64.astype(np.float32)

# compute the distance from a to b using sklearn, for both 64-bit and 32-bit
dist_64_sklearn = sklearn.metrics.pairwise.euclidean_distances([a_64], [b_64])
dist_32_sklearn = sklearn.metrics.pairwise.euclidean_distances([a_32], [b_32])

np.set_printoptions(precision=200)

print(dist_64_sklearn)
print(dist_32_sklearn)

I couldn't track the error down any further.
I hope this helps.

@amueller amueller added this to the 0.20 milestone Jul 18, 2017
@jnothman
Member

jnothman commented Jul 18, 2017 via email

@osaid-r
Contributor

osaid-r commented Sep 21, 2017

I'd like to work on this if possible

@lesteve
Member

lesteve commented Sep 21, 2017

Go for it!

@osaid-r
Contributor

osaid-r commented Sep 24, 2017

So I think the problem lies in the fact that we are using sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y)) to compute the Euclidean distance.
If I try -2 * np.dot(X, Y.T) + (X * X).sum(axis=1) + (Y * Y).sum(axis=1), I get the answer 0 for np.float32, while I get the correct answer for np.float64.

@osaid-r
Contributor

osaid-r commented Sep 28, 2017

@jnothman What do you think I should do then? As mentioned in my comment above, the problem is probably in computing the Euclidean distance using sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y)).

@jnothman
Member

jnothman commented Oct 3, 2017

So you're saying that dot is returning a less precise result than product-then-sum?

@osaid-r
Contributor

osaid-r commented Oct 3, 2017

No, what I'm trying to say is that dot returns a more precise result than product-then-sum:
-2 * np.dot(X, Y.T) + (X * X).sum(axis=1) + (Y * Y).sum(axis=1) gives the output [[0.]]
while np.sqrt(((X-Y) * (X-Y)).sum(axis=1)) gives the output [ 0.02290595]

@lesteve
Member

lesteve commented Oct 3, 2017

It is not clear what you are doing, partly because you are not posting a fully stand-alone snippet.

Quickly looking at your last post, the two things you are trying to compare, [[0.]] and [0.022...], do not have the same dimensions (maybe a copy-and-paste problem, but again it is hard to know because we don't have a full snippet).

@osaid-r
Contributor

osaid-r commented Oct 3, 2017

Ok sorry my bad

import numpy as np
import scipy
from sklearn.metrics.pairwise import check_pairwise_arrays, row_norms
from sklearn.utils.extmath import safe_sparse_dot

# create 64-bit vectors a and b that are very similar to each other
a_64 = np.array([61.221637725830078125, 71.60662841796875,    -65.7512664794921875],  dtype=np.float64)
b_64 = np.array([61.221637725830078125, 71.60894012451171875, -65.72847747802734375], dtype=np.float64)

# create 32-bit versions of a and b
X = a_64.astype(np.float32)
Y = b_64.astype(np.float32)

X, Y = check_pairwise_arrays(X, Y)
XX = row_norms(X, squared=True)[:, np.newaxis]
YY = row_norms(Y, squared=True)[np.newaxis, :]

#Euclidean distance computed using product-then-sum
distances = safe_sparse_dot(X, Y.T, dense_output=True)
distances *= -2
distances += XX
distances += YY
print(np.sqrt(distances))

#Euclidean distance computed using (X-Y)^2
print(np.sqrt(row_norms(X-Y, squared=True)[:, np.newaxis]))

OUTPUT

[[ 0.03125]]
[[ 0.02290595136582851409912109375]]

The first method is how the euclidean distance function computes it.
Also, to clarify, what I meant above was that product-then-sum has lower precision even when we use numpy functions to do it.

@jnothman
Member

jnothman commented Oct 3, 2017 via email

@jnothman
Member

jnothman commented Oct 3, 2017 via email

@osaid-r
Contributor

osaid-r commented Oct 5, 2017

So for this example product-then-sum works perfectly fine for np.float64, so a possible solution could be to convert the input to float64, compute the result, and convert the result back to float32. I guess this would be more efficient, but I'm not sure whether it would work fine for some other example.
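For concreteness, here is a rough sketch of that upcasting idea (the function name upcast_euclidean and the details are only illustrative, not an actual scikit-learn API): do the product-then-sum in float64 and cast the result back to float32.

import numpy as np

def upcast_euclidean(X32, Y32):
    # Upcast the float32 inputs to float64, apply the expanded formula,
    # then downcast the result back to float32.
    X = X32.astype(np.float64)
    Y = Y32.astype(np.float64)
    XX = (X * X).sum(axis=1)[:, np.newaxis]
    YY = (Y * Y).sum(axis=1)[np.newaxis, :]
    D2 = XX - 2 * np.dot(X, Y.T) + YY
    np.maximum(D2, 0, out=D2)   # clip tiny negative values caused by rounding
    return np.sqrt(D2).astype(np.float32)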

@jnothman
Member

jnothman commented Oct 7, 2017 via email

@osaid-r
Contributor

osaid-r commented Oct 9, 2017

Oh yeah, you are right, sorry about that. But I think using float64 and then doing product-then-sum would be more efficient computationally, if not memory-wise.

@osaid-r
Contributor

osaid-r commented Oct 9, 2017

And the reason for using product-then-sum was computational efficiency, not memory efficiency.

@jnothman
Member

jnothman commented Oct 9, 2017 via email

@osaid-r
Contributor

osaid-r commented Oct 19, 2017

Ok, so I created a Python script to compare the time taken by subtraction-then-squaring versus conversion to float64 followed by product-then-sum, and it turns out that if we choose X and Y as very big vectors, the two results are very different. Also, @jnothman, you were right: subtraction-then-squaring is faster.
Here's the script that I wrote; if there's any problem please let me know.

import numpy as np
import scipy
from sklearn.metrics.pairwise import check_pairwise_arrays, row_norms
from sklearn.utils.extmath import safe_sparse_dot
from timeit import default_timer as timer

for i in range(9):
	X = np.random.rand(1,3 * (10**i)).astype(np.float32)
	Y = np.random.rand(1,3 * (10**i)).astype(np.float32)

	X, Y = check_pairwise_arrays(X, Y)
	XX = row_norms(X, squared=True)[:, np.newaxis]
	YY = row_norms(Y, squared=True)[np.newaxis, :]

	#Euclidean distance computed using product-then-sum
	distances = safe_sparse_dot(X, Y.T, dense_output=True)
	distances *= -2
	distances += XX
	distances += YY

	ans1 = np.sqrt(distances)

	start = timer()
	ans2 = np.sqrt(row_norms(X-Y, squared=True)[:, np.newaxis])
	end = timer()
	if ans1 != ans2:
		print(end-start)

		start = timer()
		X = X.astype(np.float64)
		Y = Y.astype(np.float64)
		X, Y = check_pairwise_arrays(X, Y)
		XX = row_norms(X, squared=True)[:, np.newaxis]
		YY = row_norms(Y, squared=True)[np.newaxis, :]
		distances = safe_sparse_dot(X, Y.T, dense_output=True)
		distances *= -2
		distances += XX
		distances += YY
		distances = np.sqrt(distances)
		end = timer()
		print(end-start)
		print('')
		if abs(ans2 - distances) > 1e-3:
			# np.set_printoptions(precision=200)
			print(ans2)
			print(distances)

			print(X, Y)
			break

@jnothman
Member

jnothman commented Oct 19, 2017 via email

@jnothman
Member

anyway, would you like to submit a PR, @ragnerok?

@osaid-r
Contributor

osaid-r commented Oct 21, 2017

Yeah sure, what do you want me to do?

@jnothman
Member

jnothman commented Oct 22, 2017 via email

@osaid-r
Contributor

osaid-r commented Nov 3, 2017

I wanted to ask if it is possible to find the distance between each pair of rows with vectorisation. I cannot figure out how to do it in a vectorised way.

@jnothman
Member

jnothman commented Nov 4, 2017

You mean difference (not distance) between pairs of rows? Sure you can do that if you're working with numpy arrays. If you have arrays with shapes (n_samples1, n_features) and (n_samples2, n_features), you just need to reshape them to (n_samples1, 1, n_features) and (1, n_samples2, n_features) and do the subtraction:

>>> X = np.random.randint(10, size=(10, 5))
>>> Y = np.random.randint(10, size=(11, 5))
>>> X.reshape(-1, 1, X.shape[1]) - Y.reshape(1, -1, X.shape[1])

@osaid-r
Contributor

osaid-r commented Nov 4, 2017

Yeah thanks that really helped 😄

@jeremiedbb
Member

After a quick discussion at the sprint, we settled on the following approach:

  • in the high-dimensional case (> 32 or > 64, whichever benchmarks best): upcast by chunks to float64 when the input is float32 and keep the 'fast' method (a sketch of the chunked upcast follows below). For this kind of data, numerical issues on float64 are almost negligible (I'll provide benchmarks for that)

  • in the low-dimensional case: implement the safe computation in sklearn (instead of using scipy cdist, because of its upcast).
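A rough sketch of the chunked-upcast idea from the first bullet (illustrative only: the function name, the chunk size, and the structure are assumptions, not the code that eventually went into scikit-learn):

import numpy as np

def euclidean_distances_chunked_upcast(X, Y, chunk_size=256):
    # Process float32 inputs chunk by chunk: upcast each chunk to float64,
    # apply the 'fast' expanded formula, and write the float32 result back,
    # so the extra memory stays bounded by the chunk size.
    D = np.empty((X.shape[0], Y.shape[0]), dtype=X.dtype)
    for i in range(0, X.shape[0], chunk_size):
        Xc = X[i:i + chunk_size].astype(np.float64)
        xx = np.einsum('ij,ij->i', Xc, Xc)[:, np.newaxis]
        for j in range(0, Y.shape[0], chunk_size):
            Yc = Y[j:j + chunk_size].astype(np.float64)
            yy = np.einsum('ij,ij->i', Yc, Yc)[np.newaxis, :]
            d2 = xx - 2 * np.dot(Xc, Yc.T) + yy
            np.maximum(d2, 0, out=d2)   # clip rounding-induced negatives
            D[i:i + chunk_size, j:j + chunk_size] = np.sqrt(d2)
    return D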

@jnothman
Member

(It's tempting to throw upcasting float32 into 0.20.3 also)

@Celelibi

Celelibi commented Mar 12, 2019

Here are some benchmarks for speed comparison between scipy and sklearn.
[... snip ...]

This is very interesting. I wasn't actually expecting this result. I re-did your benchmark and found a very similar result. Except I would advocate for a lower decision boundary. My benchmark would suggest 8 features.

[benchmark heatmap]

The cost of being wrong is not symmetric. cdist is better only for computations lasting less than a few seconds, and it gets slow really fast as the number of features increases. So it's better to use the BLAS implementation when in doubt.

Edit: This benchmark was for float64, but I also find that upcasting float32 matrices to float64 only adds a few percent to the total time and doesn't change the conclusion.

@jeremiedbb
Member

I noticed that the threshold depends on the machine you're running the benchmarks on. I suspect it may have to do with the AVX instructions. I realized the benchmarks I published were run on a machine which didn't have AVX2 instructions, only AVX. And on a machine which has AVX2, I got similar results to yours.

But the question is not only about performance but also about precision, and precision issues are more likely when the dimension is small. Maybe 16 is a good compromise. What do you think?

@Celelibi

With regard to this discussion, I'd say we need to benchmark the accuracy to make an informed decision.

However, with regard to your PR, the accuracy shouldn't be an issue anymore, at the cost of a slightly more expensive computation. Therefore the threshold should probably be decided by benchmarking your PR.

@kno10
Contributor

kno10 commented Mar 13, 2019

Benchmarking accuracy is not that easy, because the difficult cases will not be uniformly distributed.
And it may be problematic if it happens undetected in a corner case. Usually, you will want guaranteed numerical accuracy within affordable CPU limits.
But as mentioned elsewhere, a single feature with 10000000.01 and 10000000.00 should be enough to trigger numeric instability with fp64 when using the known-problematic equation (10000 and 10001 suffice with fp32). With 1024 features, try

>>> import sklearn.metrics.pairwise as sk, scipy.spatial.distance as sp
>>> X = [[10000.01] * 1024, [10000.00] * 1024]
>>> print(sk.euclidean_distances(X,X), "\n", sp.cdist(X,X))
[[ 0.          0.31895195]
 [ 0.31895195  0.        ]] 
 [[ 0.    0.32]
 [ 0.32  0.  ]]

(this was using 0.19.1) The correct distance is 0.32.

As you can see, the numeric instabilities tend to get worse with the number of features (unless your data is sparse). Here, the result has less than two digits of precision with FP64.
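A minimal single-feature float32 check of the 10000 vs. 10001 remark above (illustrative snippet, not code from the thread); all intermediates stay in float32:

import numpy as np

x = np.float32(10000.0)
y = np.float32(10001.0)

# Expanded formula: x^2 + y^2 - 2xy cancels catastrophically in float32.
d_expanded = np.sqrt(x * x + y * y - np.float32(2) * x * y)
# Direct formula keeps full precision for a single feature.
d_direct = np.abs(x - y)

print(d_expanded, d_direct)   # should print 0.0 and 1.0; the correct distance is 1.0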

@jeremiedbb
Member

#13410 does not fix this specific case, i.e. float64 + high dimension.
It does fix it for float32, however.

But we decided that for float64 + high dim, we keep it as it was, because the accuracy issues are very unlikely to happen and don't really apply to machine learning use cases.

In your example, X[0] and X[1] have norms equal to 320000.32 and 320000 and their distance is 0.32, i.e. 1e-6 times their norm. In machine learning, the 16 significant digits (in float64) are not all relevant.

@Celelibi

But we decided that for float64 + high dim, we keep it as it was, because the accuracy issues are very unlikely to happen and don't really apply to machine learning use cases.

I would be more moderate on this one. Reducing the dimensionality is a usual first step in ML. MDS can be used for that, and it makes heavy use of the Euclidean distance matrix.

If someone wants to have a look at improving the accuracy of the float64 case, there's a way to use two floats to represent the intermediate results, although I think it starts to fall beyond the scope of scikit-learn.
ftp://ftp.math.ethz.ch/users/wpp/CELL/qd.pdf
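For the curious, the basic building block of that two-float ("double-double") representation is Knuth's TwoSum, sketched below (illustration only, not something scikit-learn implements):

def two_sum(a, b):
    # Returns (s, err) with s = fl(a + b) and a + b == s + err exactly
    # in IEEE floating-point arithmetic.
    s = a + b
    bb = s - a
    err = (a - (s - bb)) + (b - bb)
    return s, err

def compensated_sum_of_squares(diffs):
    # Accumulate squared differences while tracking the rounding error of each
    # addition (the rounding error of the squaring itself is ignored here).
    total, comp = 0.0, 0.0
    for d in diffs:
        total, e = two_sum(total, d * d)
        comp += e
    return total + comp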

@jeremiedbb
Member

I was not clear. I'm not saying high-dimensional data does not apply to machine learning. I'm saying that the kind of precision issues which happen in float64 involve points whose distance is 6 orders of magnitude smaller than their norms. Having such precision has no meaning in a realistic machine learning model.

@kno10
Contributor

kno10 commented Mar 13, 2019

In machine learning, the 16 significant digits (in float64) are not all relevant.

I am not at all convinced that this is generally true.

In this example, we have lost 15 of 16 digits of precision. I'd agree if we only lost half of the precision, but we don't have such a relationship. The loss from downcasting FP64 to FP32 may often be tolerable because of measurement precision. Consumer-grade GPUs are much faster with FP32 than with FP64, for example (in some cases they now allow FP32 data with FP64 accumulators), and for neural network inference you may even see int8 now. But that doesn't hold everywhere.

For example in k-means clustering, there is the assumption that clusters differ substantially in their means (and that we don't know the means beforehand), and hence we have a loss in precision here. If you have many clusters, some of their norms can be large compared to their separation.
Furthermore, after the initial iterations, it's often small differences in distance that make a point switch to another cluster. Loss of precision here can affect results, and could cause instability.
Now consider k-means on time series fragments with many variables.

With increasing data sizes, we must assume that the distances to the nearest neighbor get smaller, and unless your norms are 0, they will eventually be smaller than the vector norms and cause problems. So this will likely become more severe with increasing data set sizes. The curse of dimensionality says that the largest and the smallest distances get more and more similar; so in order to compute the correct nearest neighbor ranking, we may need good precision in high-dimensional data. On the 20news data set, the smallest non-zero distance is around 0.02 (the norms are all 1). But that is just 10k instances, and fairly diverse contents. Now assume the data set was about near-duplicate detection instead...

I would not be so sure this is "unlikely" to happen in ML... of course it won't affect everybody, though.

@jeremiedbb
Member

When I say "In machine learning, the 16 significant digits (in float64) are not all relevant.", I'm not speaking of the computed distance, I'm speaking of the data X.
In machine learning, your data comes from a measure, and there's no measure precise to the 9th digit (besides very few ones in particle physics).
So in your example of 10000000.01 and 10000000.00, how would you give some importance to a distance of 0.01 when your uncertainty on the values of X are way bigger ?

For KMeans, first there's a way to overcome a large part of the precision loss: when you're looking for the closest center to an observation x, you don't need to add the norm of x to the distance calculation, which avoids the catastrophic cancellation in most cases (see the sketch below).
Then, KMeans clusters based on Euclidean distances, but you don't know if this is the exact way your data are gathered. In fact there's a probability of 0 that your data is clustered that way. KMeans gives an estimation of how your data could be clustered, and points which are at the frontier of 2 clusters can definitely not be considered as belonging with certainty to one or the other. What's your interpretation of a point at the same distance from 2 clusters? Mine is that either the 2 clusters should be only one cluster, or KMeans is not the best algorithm to cluster my data (or even: KMeans gives me a somewhat good idea of how my data is clustered, but I know that the frontiers of the clusters are not relevant).
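A minimal sketch of that trick (hypothetical code, not the actual KMeans implementation): since ||x||² is the same for every candidate center, the assignment can use argmin of ||c||² - 2·x·c instead of the full expanded distance.

import numpy as np

def assign_labels(X, centers):
    # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c; the ||x||^2 term is constant
    # per point, so it can be dropped from the argmin over centers.
    cc = (centers ** 2).sum(axis=1)                       # ||c||^2 for each center
    scores = cc[np.newaxis, :] - 2 * np.dot(X, centers.T)
    return np.argmin(scores, axis=1)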

@kno10
Contributor

kno10 commented Mar 14, 2019

The use of only "|b|^2-2ab" does not have catastrophic cancellation - but the same loss in precision in the digits that make the difference. The results are the same as if you added the norm of a to each distance afterwards; if the distances are much smaller than the norm of a, then you get a loss in precision that is avoidable by doing the computations the traditional way without BLAS hacks.
So you actually can NOT overcome the numerical problem this way!

K-means is an optimization problem. So such hacks may mean that sklearn finds only worse solutions than other tools. And as indicated before, this can also cause instabilities. In the worst case, this could cause sklearn kmeans to iterate through the same states until max_iter with no improvement (assuming tol=0, if you want to find a local optimum), which theory would say is impossible.
Until k-means has converged, you can't say much about points with the "same" distance to two clusters. The next iteration, the means may have moved and the difference could become much larger and matter!
I am not a big fan of k-means because it doesn't work too well on noisy data. But there are variations that handle such cases better. Nevertheless, if you use it, you should probably try to get the full quality (which is why I also always use tol=0) and not make it worse than necessary. It's cheap enough to do the proper calculations (and, as mentioned, the problems get worse with data size - so for small data the slower runtime usually does not matter, while for larger data sets the precision more likely becomes important).

Depending on the application, the difference between 10000000.01 and 10000000.00 can matter. And as I showed before, if you use multiple features the problems arise earlier: with fp32, as little as 10000 vs. 10001 with a single feature, and roughly 100 vs. 101 with 100 features, I guess.

As mentioned, the mean may have a physical meaning that you don't want to lose. If you have data with temperatures in Kelvin, you don't want to 0:1 scale them or center them; that would ruin your ratio scale. Now say you want to compare, for example, time series of the temperature of some steel product as it cools down, and figure out whether the cooldown process affects the reliability of your steel product. You may have temperatures of over 700 K, and the time series may have hundreds of data points if you want to analyze the cooldown process. Even with just 5 digits of input precision (0.01 K), with the length of the time series the numeric problem can occur. You may again end up with only 1-2 digits in the result. I don't think you can just rule out that precision ever matters in ML if you have this catastrophic kind of effect. It would be different if you could guarantee to always get, say, 10 of 16 digits of precision. Here you can't do that; you may have 0 correct digits in the worst case (that is why it's catastrophic).

@Celelibi

Celelibi commented Mar 15, 2019

In machine learning, your data comes from a measure, and there's no measure precise to the 9th digit (besides very few ones in particle physics).

The raw values from the real world rarely have that kind of accuracy, that's right. But ML isn't limited to that kind of input. One might want to apply ML to mathematical problems, like applying MDS to the graph of a Rubik's-cube-like puzzle or clustering the successful strategies found by your swarm of RL agents playing Pacman.
Even if the initial source of the information is the real world, there might be some mid-way processing that makes most digits relevant to the clustering algorithm, like the result of a gradient descent on a function whose parameters are statistically sampled in the real world.

I'm actually wondering why we're still discussing this. I guess we all agree that scikit-learn should try its best in the trade-off accuracy vs. computation time. And whoever isn't happy with the current state should submit a pull request.

@jeremiedbb
Member

The use of only "|b|^2-2ab" does not have catastrophic cancellation - but the same loss in precision in the digits that make the difference. The results are the same as if you added the norm of a to each distance afterwards; if the distances are much smaller than the norm of a, then you get a loss in precision that is avoidable by doing the computations the traditional way without BLAS hacks.
So you actually can NOT overcome the numerical problem this way!

There is a loss of precision, but it can't cause a catastrophic cancellation (at least when a and b are close), and you can show that the relative error on this quantity (which is not an actual distance) stays small.
In the case of KMeans where you're only interested in finding the closest center, you have enough precision to keep the ordering right. If at the end you want the inertia, then you can just calculate the distances of each point to its cluster center with the exact formula.

Besides, KMeans is not a convex optimization problem, so even if you let it run with tol=0 until convergence, you end up in a local minimum which can be far from the global minimum (even with kmeans++ initialization). So I'd rather run KMeans many times with different inits and a reasonably small number of iterations. You'll have a better chance of ending up in a better local minimum. Then you can rerun the best one until convergence.

@kno10
Contributor

kno10 commented Mar 15, 2019

The relative error compared to the real distance can be arbitrarily large, and hence cause wrong nearest neighbors. Consider the case where |a|²=|b|²=1, for example with tf-idf. Assume that the vectors are very close. Then ab is also close to 1, and at this point you have already lost much of your precision.
As I wrote above, the error is there, even if you don't have catastrophic cancellation. Consider 8 digits of precision. The real distance may be 0.000012345678 and can be computed with eight digits using FP32 and regular Euclidean distance. But with this equation, you compute the value ab=0.99998765432 instead, which with FP32 will be truncated to approximately 0.99998765 at best, so you lost three digits of precision unnecessarily in this example. The loss is as big as in the catastrophic case. If the distances are much smaller than the norms, your precision can become arbitrarily bad with this approach.

Yes, kmeans is not convex. But then you will want to at least find a local optimum, and not get stuck (or even oscillate because the resulting errors behave erratically) because your precision is too low. So you at least get a chance to find the global one in well-behaved cases and with multiple attempts.

@jnothman
Member

jnothman commented Mar 16, 2019

I appreciate this discussion, but what we really need is a solution that is no worse than what we were doing before we stopped upcasting things to float64. In that sense, @Celelibi's upcasting solution was sufficient. Using the exact solution in low dimensions is an added improvement on what we used to do.

Regarding a future version, do you feel any more confident about being able to efficiently detect when we might consider the exact computation in high dimensions?

@Celelibi

Celelibi commented Apr 2, 2019

I've run a benchmark to evaluate the average accuracy of the float64 case with random numbers. I compare 3 algorithms: neumaier_sum((x-y)**2), numpy.sum((x-y)**2) and X2 - 2*X.dot(Y.T) + Y2.T. The exact result to compare to has been obtained using mpmath with a precision of 256 bits.
X and Y have 100 samples and a variable number of features and are filled with random numbers between -2 and 2.

In the following GIF, there's one image per number of features (between 1 and 200). In each image, each dot represents the relative error of the squared Euclidean distance for one of the 10000 pairs of vectors from X and Y. The relative error is multiplied by 2^53 for readability, which corresponds roughly to the ULP unit.
The curves above the dots are the approximate distributions (using a kernel density estimate).

[animated plot: distribution of the relative error (in ULP) of the three algorithms as the number of features grows]

Note that the graphs were cut at 6 ULP for readability. This shows the average case, not the worst case. The error of the expanded formula can grow pretty large.

My analysis of this result is that on average, the relative error of the expanded formula can be very large with few features, but quickly becomes similar to that of the difference-based methods (Neumaier sum and numpy sum), the threshold being between 5 and 10 features.

I'm also currently trying to find an upper bound for the error of the expanded formula as well as pathological examples.
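A small sketch of that kind of accuracy measurement (illustrative; the helper names are made up, and this is not the actual benchmark script): compute a 256-bit reference with mpmath and express the relative error of each formula in units of 2^-53.

import numpy as np
from mpmath import mp, mpf

mp.prec = 256   # 256-bit reference precision

def exact_sqdist(x, y):
    # Exact squared Euclidean distance of the float64 inputs, in 256-bit arithmetic.
    return sum((mpf(float(a)) - mpf(float(b))) ** 2 for a, b in zip(x, y))

def rel_err_2_53(approx, exact):
    # Relative error scaled by 2^53, roughly in ULPs for float64.
    return float(abs((mpf(approx) - exact) / exact) * 2**53)

x = np.random.uniform(-2, 2, size=10)
y = np.random.uniform(-2, 2, size=10)

expanded = x @ x + y @ y - 2 * (x @ y)   # dot-based (expanded) formula
difference = np.sum((x - y) ** 2)        # difference-based formula

exact = exact_sqdist(x, y)
print(rel_err_2_53(expanded, exact), rel_err_2_53(difference, exact))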

@jnothman
Member

jnothman commented Apr 2, 2019 via email

@Celelibi

Celelibi commented Apr 3, 2019

Indeed, but I needed to be convinced that in practice, it's not complete BS. ^^

To complete the comment above: the relative error of the formula a²+b²-2ab seems to be unbounded. Unless my analysis is wrong, when a and b are close to each other, the relative error can be up to 2^(52*2), at least theoretically. In practice, the worst case I found is a relative error of 2^52+1.

>>> a, b = (0xfffffec4d6282+1) * 2.0**(511-52), 0xfffffec4d6282 * 2.0**(511-52)
>>> a, b
(6.703903473040778e+153, 6.7039034730407766e+153)
>>> exactdiff = (a-b)**2
>>> exactdiff
2.2158278651204453e+276
>>> computeddiff = a**2 + b**2 - 2*a*b
>>> computeddiff
-9.9792015476736e+291
>>> abs((computeddiff - exactdiff) / exactdiff)
4503599627370497.0
>>> bin(int(abs((computeddiff - exactdiff) / exactdiff)))
'0b10000000000000000000000000000000000000000000000000001'

Flipping the sign of the result would actually make it closer to the truth. This is the most dramatic example I could find, but actually changing the exponent in the values of a and b doesn't change the relative error.

>>> a, b = (0xfffffec4d6282+1) * 2.0**(-52), 0xfffffec4d6282 * 2.0**(-52)
>>> a, b
(0.9999999266202912, 0.999999926620291)
>>> exactdiff = (a-b)**2
>>> computeddiff = a**2 + b**2 - 2*a*b
>>> abs((computeddiff - exactdiff) / exactdiff)
4503599627370497.0

@kno10
Contributor

kno10 commented Apr 3, 2019

I think a histogram plot in ULPs would make more sense than the animation above with the within-ULP error distribution. So 0 ULP error and 1 ULP error are "as good as it gets". 2 ULP is likely unavoidable because of the sqrt. Any larger errors are worth investigating, I assume.

Using (computed - exact) / exact is reasonable as long as exact is large. But once we are getting numerical challenges for the exact value, this becomes quite unstable. In such cases, (computed-exact)/norm may be worth using instead, i.e. looking at the precision of our distance computations compared to the input data, not compared to the derived distances.
If we have two one-dimensional values that only differ by 1 ULP, an error of 2 ULP may seem huge; but we are at input data resolution already, so the results are quite unstable.
Note that with multiple dimensions, we may get a higher resolution in the input data.

Consider input data of the type (1, 1e-16) vs. (1, 2e-16). For example if we have a constant attribute in the input data, say, a white pixel in MNIST.
With the difference-based equation this will be fine, but the dot-version gets into trouble, doesn't it? That is why one-dimensional experiments may not be enough to study this.

@Celelibi

Celelibi commented Apr 4, 2019

I think a histogram plot in ULPs would make more sense than above animation with the within-ULP error distribution.

I'm not sure how I would have represented it. There would be one histogram per number of features and per algorithm. There's not much I can do besides a 3D plot or an animation.

Using (computed - exact) / exact is reasonable as long as exact is large. But once we are getting numerical challenges for the exact value, this becomes quite unstable.

I'm not sure what you mean by unstable in this context. The exact value should be computed with whatever it takes to make it exact.
(Speaking of which, I should have computed the relative error with arbitrary precision too in my plot, instead of comparing to the exactly rounded result. I updated my plot, the weird waves disappeared.)

In such cases, (computed-exact)/norm may be worth using instead, i.e. looking at the precision of our distance computations compared to the input data, not compared to the derived distances.

If I understand your idea correctly, you would rather compare the absolute error to the magnitude of the input data, using the vector norms as an aggregate measure of the magnitude of the inputs, whereas the standard relative error compares it to the magnitude of the exact result.

I think with this metric you are trying to capture how faulty an algorithm is. But it actually doesn't seem particularly useful, for a few reasons.

  • It doesn't really say how many digits of the result are inexact.
  • Actually, most algorithms would have a score less than 1e-15. Even the expanded formula (dot-based algorithm) would have a score bounded by something like 5 ULP(input) (rough estimation, I didn't write the full proof).
  • And since both metrics are just a rescaled version of the absolute error computed - exact, they would rank the algorithms in the same order when evaluated on the same inputs.
    So it's the same as the usual relative error, just with a less useful interpretation of the value (IMO).

Consider input data of the type (1, 1e-16) vs. (1, 2e-16). For example if we have a constant attribute in the input data, say, a white pixel in MNIST.
With the difference-based equation this will be fine, but the dot-version gets into trouble, doesn't it? That is why one-dimensional experiments may not be enough to study this.

The dot-based algorithm would have a relative error of 1, meaning that the error is as large as the exact result and thus no digit of the result is correct. And your metric would have a value of 1e-16, meaning that relative to the scale of the vector norm, only the 16th digit is off.
I'm unsure what you're trying to show with this example.
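To make that concrete, here is the (1, 1e-16) vs. (1, 2e-16) example with both metrics side by side (hypothetical snippet, not code from the thread):

import numpy as np

x = np.array([1.0, 1e-16])
y = np.array([1.0, 2e-16])

exact = 1e-16                                   # true distance
expanded = x @ x + y @ y - 2 * (x @ y)          # dot-based squared distance -> 0.0
dot_based = np.sqrt(max(expanded, 0.0))         # computed distance -> 0.0

rel_err = abs(dot_based - exact) / exact                # -> 1.0: no digit is correct
norm_rel = abs(dot_based - exact) / np.linalg.norm(x)   # -> 1e-16: looks harmless

print(dot_based, rel_err, norm_rel)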

@rth
Member

rth commented Apr 29, 2019

If we are still concerned about the precision of euclidean_distances with float64, it would probably be better to summarize this discussion in a new issue, as there are already 100 comments here.
