/
FreivaldsAlgorithm.java
56 lines (42 loc) · 1.77 KB
/
FreivaldsAlgorithm.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/**
 * Freivalds' algorithm is a randomized algorithm for verifying matrix multiplication. Given three
 * n x n matrices A, B and C, it checks whether A * B = C using k independent trials that each cost
 * O(n^2). If A * B = C it always answers true; otherwise it wrongly answers true with probability
 * at most 2^-k.
 *
 * <p>Time Complexity: O(kn^2)
 *
 * @author William Fiset, william.alexandre.fiset@gmail.com
 */
package com.williamfiset.algorithms.linearalgebra;
/**
 * Randomized verification of integer matrix products using Freivalds' algorithm.
 *
 * <p>Instead of recomputing A * B in O(n^3), each trial multiplies both sides of A * B = C by a
 * random 0/1 vector v and compares A(Bv) with Cv, which costs only O(n^2).
 */
public class FreivaldsAlgorithm {

  /**
   * Overwrites every entry of {@code vector} with 0 or 1, each with probability 1/2.
   *
   * @param vector the vector to fill in place
   * @param rng the source of randomness
   */
  private static void randomizeVector(int[] vector, java.util.Random rng) {
    for (int i = 0; i < vector.length; i++) {
      vector[i] = rng.nextBoolean() ? 1 : 0;
    }
  }

  /**
   * Computes the matrix-vector product {@code matrix * v}, i.e. {@code result[i]} is the sum over j
   * of {@code matrix[i][j] * v[j]}. Arithmetic is 32-bit {@code int} and wraps on overflow.
   *
   * @param v a vector of length n
   * @param matrix an n x n matrix (n is taken from {@code matrix.length})
   * @return a newly allocated vector of length n
   */
  private static int[] product(int[] v, int[][] matrix) {
    final int n = matrix.length;
    int[] result = new int[n];
    for (int i = 0; i < n; i++) {
      int[] row = matrix[i];
      int sum = 0;
      for (int j = 0; j < n; j++) {
        sum += row[j] * v[j];
      }
      result[i] = sum;
    }
    return result;
  }

  /**
   * Ensures {@code m} has exactly n rows and that every row is non-null with length exactly n.
   *
   * @throws IllegalArgumentException if {@code m} is not an n x n matrix
   */
  private static void requireSquare(int[][] m, int n) {
    if (m.length != n) {
      throw new IllegalArgumentException("Input must be three nxn matrices");
    }
    for (int[] row : m) {
      if (row == null || row.length != n) {
        throw new IllegalArgumentException("Input must be three nxn matrices");
      }
    }
  }

  /**
   * Probabilistically checks whether {@code A * B == C} using thread-local randomness.
   *
   * <p>A {@code false} answer is always correct. A {@code true} answer is wrong with probability at
   * most 2^-max(k, 1). At least one trial is run even when {@code k <= 0}.
   *
   * @param A the left factor, an n x n matrix
   * @param B the right factor, an n x n matrix
   * @param C the claimed product, an n x n matrix
   * @param k the number of independent random trials
   * @return {@code false} if a trial proves {@code A * B != C}, {@code true} otherwise
   * @throws NullPointerException if any matrix is null
   * @throws IllegalArgumentException if the inputs are not three n x n matrices
   */
  public static boolean freivalds(int[][] A, int[][] B, int[][] C, int k) {
    return freivalds(A, B, C, k, java.util.concurrent.ThreadLocalRandom.current());
  }

  /**
   * Probabilistically checks whether {@code A * B == C}, drawing the random test vectors from
   * {@code rng}. Supplying a seeded {@link java.util.Random} makes runs reproducible.
   *
   * <p>All arithmetic is 32-bit {@code int} with two's-complement wraparound, so equality is
   * effectively tested modulo 2^32. A {@code false} answer is always correct. If {@code A * B}
   * and {@code C} differ modulo 2^32, each trial misses the difference with probability at most
   * 1/2. At least one trial is run even when {@code k <= 0}.
   *
   * @param A the left factor, an n x n matrix
   * @param B the right factor, an n x n matrix
   * @param C the claimed product, an n x n matrix
   * @param k the number of independent random trials
   * @param rng the source of randomness for the 0/1 test vectors
   * @return {@code false} if a trial proves {@code A * B != C}, {@code true} otherwise
   * @throws NullPointerException if any matrix or {@code rng} is null
   * @throws IllegalArgumentException if the inputs are not three n x n matrices
   */
  public static boolean freivalds(
      int[][] A, int[][] B, int[][] C, int k, java.util.Random rng) {
    java.util.Objects.requireNonNull(A, "A");
    java.util.Objects.requireNonNull(B, "B");
    java.util.Objects.requireNonNull(C, "C");
    java.util.Objects.requireNonNull(rng, "rng");

    final int n = A.length;
    // Validate every row, not just row 0, so ragged inputs are rejected up front.
    requireSquare(A, n);
    requireSquare(B, n);
    requireSquare(C, n);

    // The product of two 0x0 matrices is the 0x0 matrix: nothing to verify.
    if (n == 0) {
      return true;
    }

    // Preserve the historical contract: a non-positive k still performs one trial.
    final int trials = Math.max(k, 1);
    int[] v = new int[n];
    for (int t = 0; t < trials; t++) {
      randomizeVector(v, rng);
      // By associativity A(Bv) == (AB)v, so this compares (AB)v with Cv without forming AB.
      int[] expected = product(v, C);
      int[] actual = product(product(v, B), A);
      if (!java.util.Arrays.equals(expected, actual)) {
        return false; // v is a witness that A * B != C
      }
    }
    return true;
  }
}