Skip to content

Commit

Permalink
Add lstsq.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed May 10, 2024
1 parent da83683 commit 4b18bd4
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 0 deletions.
6 changes: 6 additions & 0 deletions keras/src/backend/jax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,9 @@ def solve_triangular(a, b, lower=False):

def svd(x, full_matrices=True, compute_uv=True):
    """Run `jnp.linalg.svd` on `x`.

    Args:
        x: Input handed to `jnp.linalg.svd` as-is.
        full_matrices: Forwarded unchanged to `jnp.linalg.svd`.
        compute_uv: Forwarded unchanged to `jnp.linalg.svd`.

    Returns:
        Whatever `jnp.linalg.svd` returns for the given flags.
    """
    options = {"full_matrices": full_matrices, "compute_uv": compute_uv}
    return jnp.linalg.svd(x, **options)


def lstsq(a, b, rcond=None):
    """Return the first element of `jnp.linalg.lstsq(a, b, rcond)`.

    Args:
        a: Left-hand side; converted with `convert_to_tensor`.
        b: Right-hand side; converted with `convert_to_tensor`.
        rcond: Forwarded unchanged to `jnp.linalg.lstsq`.

    Returns:
        Element `[0]` of the result of `jnp.linalg.lstsq`.
    """
    coefficients = convert_to_tensor(a)
    targets = convert_to_tensor(b)
    # Only the leading entry of the solver's result is exposed.
    result = jnp.linalg.lstsq(coefficients, targets, rcond=rcond)
    return result[0]
6 changes: 6 additions & 0 deletions keras/src/backend/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,9 @@ def solve_triangular(a, b, lower=False):

def svd(x, full_matrices=True, compute_uv=True):
    """Run `np.linalg.svd` on `x`.

    Args:
        x: Input handed to `np.linalg.svd` as-is.
        full_matrices: Forwarded unchanged to `np.linalg.svd`.
        compute_uv: Forwarded unchanged to `np.linalg.svd`.

    Returns:
        Whatever `np.linalg.svd` returns for the given flags.
    """
    options = {"full_matrices": full_matrices, "compute_uv": compute_uv}
    return np.linalg.svd(x, **options)


def lstsq(a, b, rcond=None):
    """Return the first element of `np.linalg.lstsq(a, b, rcond)`.

    Args:
        a: Left-hand side; converted with `convert_to_tensor`.
        b: Right-hand side; converted with `convert_to_tensor`.
        rcond: Forwarded unchanged to `np.linalg.lstsq`.

    Returns:
        Element `[0]` of the result of `np.linalg.lstsq`.
    """
    coefficients = convert_to_tensor(a)
    targets = convert_to_tensor(b)
    # Only the leading entry of the solver's result is exposed.
    result = np.linalg.lstsq(coefficients, targets, rcond=rcond)
    return result[0]
41 changes: 41 additions & 0 deletions keras/src/backend/tensorflow/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,44 @@ def svd(x, full_matrices=True, compute_uv=True):
x, full_matrices=full_matrices, compute_uv=compute_uv
)
return u, s, tf.linalg.adjoint(v)


def lstsq(a, b, rcond=None):
    """Solve `a @ x = b` in the least-squares sense via an SVD pseudo-inverse.

    Args:
        a: Coefficient matrix. Must be 2-D, of shape `(m, n)`.
        b: Right-hand side. Must be 1-D or 2-D, with the same leading
            dimension as `a`.
        rcond: Relative cut-off for singular values. Singular values
            smaller than `rcond * s[0]` are treated as zero. When `None`,
            `eps * max(n, m)` is used; negative values are replaced by
            `eps`, where `eps` is the machine epsilon of `a.dtype`.

    Returns:
        The solution tensor. It is flattened to 1-D when `b` was passed as
        a 1-D tensor; otherwise it is 2-D with `n` rows.

    Raises:
        ValueError: If the leading dimensions of `a` and `b` differ.
        TypeError: If `a` is not 2-D, or `b` is not 1-D or 2-D.
    """
    a = convert_to_tensor(a)
    b = convert_to_tensor(b)
    if a.shape[0] != b.shape[0]:
        raise ValueError("Leading dimensions of input arrays must match")
    b_orig_ndim = b.ndim
    if b_orig_ndim == 1:
        # Treat a vector right-hand side as a single-column matrix.
        b = b[:, None]
    if a.ndim != 2:
        raise TypeError(
            f"{a.ndim}-dimensional array given. "
            "Array must be two-dimensional"
        )
    if b.ndim != 2:
        raise TypeError(
            f"{b.ndim}-dimensional array given. "
            "Array must be one or two-dimensional"
        )
    m, n = a.shape
    dtype = a.dtype
    eps = tf.experimental.numpy.finfo(dtype).eps
    if m == 0 or n == 0:
        # `a` is guaranteed to be 2-D here, so an empty system is one with
        # a zero-sized dimension. There are no singular values to index in
        # that case, so return an all-zero solution directly.
        x = tf.zeros((n, *b.shape[1:]), dtype=dtype)
    else:
        if rcond is None:
            rcond = eps * max(n, m)
        else:
            rcond = tf.where(rcond < 0, eps, rcond)
        u, s, vt = svd(a, full_matrices=False)
        # Keep only singular values above the relative threshold.
        mask = s >= tf.convert_to_tensor(rcond, dtype=s.dtype) * s[0]
        # Swap discarded values for 1 so the reciprocal below never
        # divides by (near) zero; they are masked out again right after.
        safe_s = tf.cast(tf.where(mask, s, 1), dtype=dtype)
        s_inv = tf.where(mask, 1 / safe_s, 0)[:, tf.newaxis]
        # x = vt^H @ diag(s_inv) @ u^H @ b
        u_t_b = tf.matmul(tf.transpose(tf.math.conj(u)), b)
        x = tf.matmul(tf.transpose(tf.math.conj(vt)), s_inv * u_t_b)

    if b_orig_ndim == 1:
        x = tf.reshape(x, [-1])
    return x
6 changes: 6 additions & 0 deletions keras/src/backend/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,9 @@ def svd(x, full_matrices=True, compute_uv=True):
"`compute_uv=False` is not supported for torch backend."
)
return torch.linalg.svd(x, full_matrices=full_matrices)


def lstsq(a, b, rcond=None):
    """Return the first element of `torch.linalg.lstsq(a, b, rcond)`.

    Args:
        a: Left-hand side; converted with `convert_to_tensor`.
        b: Right-hand side; converted with `convert_to_tensor`.
        rcond: Forwarded unchanged to `torch.linalg.lstsq`.

    Returns:
        Element `[0]` of the result of `torch.linalg.lstsq`.
    """
    coefficients = convert_to_tensor(a)
    targets = convert_to_tensor(b)
    # Only the leading entry of the solver's result is exposed.
    result = torch.linalg.lstsq(coefficients, targets, rcond=rcond)
    return result[0]
77 changes: 77 additions & 0 deletions keras/src/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,83 @@ def _svd(x, full_matrices=True, compute_uv=True):
return backend.linalg.svd(x, full_matrices, compute_uv)


class Lstsq(Operation):
    """Operation that dispatches to the backend least-squares solver.

    Args:
        rcond: Stored on the instance and forwarded to
            `backend.linalg.lstsq` on every call.
    """

    def __init__(self, rcond=None):
        super().__init__()
        self.rcond = rcond

    def call(self, a, b):
        """Compute the solution with `backend.linalg.lstsq`."""
        return backend.linalg.lstsq(a, b, rcond=self.rcond)

    def compute_output_spec(self, a, b):
        """Validate input ranks/shapes and describe the output tensor.

        Args:
            a: Input whose `shape` must have rank 2, i.e. `(m, n)`.
            b: Input whose `shape` must have rank 1 or 2 and whose leading
                dimension must match `m`.

        Returns:
            A `KerasTensor` of shape `(n,)` when `b` has rank 1, or
            `(n, k)` when `b` has shape `(m, k)`, with `a`'s dtype.

        Raises:
            ValueError: If a rank requirement is violated, or if the
                statically known leading dimensions of `a` and `b` differ.
        """
        if len(a.shape) != 2:
            raise ValueError(
                "Expected a to have rank 2. " f"Received: a.shape={a.shape}"
            )
        if len(b.shape) not in (1, 2):
            raise ValueError(
                "Expected b to have rank 1 or 2. "
                f"Received: b.shape={b.shape}"
            )
        m, n = a.shape
        b_rows = b.shape[0]
        # A `None` entry stands for a dimension that is not known
        # statically; a mismatch can only be diagnosed when both leading
        # dimensions are concrete, otherwise valid inputs would be rejected.
        if m is not None and b_rows is not None and b_rows != m:
            raise ValueError(
                "Expected b.shape[0] to be equal to "
                "a.shape[0]. Received: "
                f"a.shape={a.shape}, b.shape={b.shape}"
            )
        if len(b.shape) == 2:
            return KerasTensor((n, b.shape[1]), dtype=a.dtype)
        return KerasTensor((n,), dtype=a.dtype)


@keras_export(["keras.ops.lstsq", "keras.ops.linalg.lstsq"])
def lstsq(a, b, rcond=None):
    """Return the least-squares solution to a linear matrix equation.

    Finds the `x` that approximately solves `a @ x = b`. The system may be
    under-, well-, or over-determined (the number of linearly independent
    rows of `a` may be less than, equal to, or greater than its number of
    linearly independent columns). When `a` is square and of full rank,
    `x` is the exact solution up to round-off error. Otherwise `x`
    minimizes the L2 norm of `b - a @ x`, and among several minimizers the
    one with the smallest L2 norm is returned.

    Args:
        a: "Coefficient" matrix of shape `(M, N)`.
        b: Ordinate or "dependent variable" values, of shape `(M,)` or
            `(M, K)`. For a two-dimensional `b`, a least-squares solution
            is computed for each of its K columns.
        rcond: Cut-off ratio for small singular values of `a`. For rank
            determination, singular values smaller than `rcond` times the
            largest singular value of `a` are treated as zero.

    Returns:
        Tensor with shape `(N,)` or `(N, K)` containing the least-squares
        solutions.

    **NOTE:** The output differs from `numpy.linalg.lstsq`. NumPy returns
    a tuple with four elements, the first being the least-squares
    solutions and the others being essentially never used. Keras returns
    only the first value, both to ensure consistency across backends
    (which cannot be achieved for the other values) and to simplify the
    API.
    """
    operands = (a, b)
    if not any_symbolic_tensors(operands):
        # Concrete inputs: compute eagerly on the active backend.
        return backend.linalg.lstsq(a, b, rcond=rcond)
    # At least one symbolic input: go through the operation instead.
    return Lstsq(rcond=rcond).symbolic_call(a, b)


def _assert_1d(*arrays):
for a in arrays:
if a.ndim < 1:
Expand Down
21 changes: 21 additions & 0 deletions keras/src/ops/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,27 @@ def test_svd(self):
]
self.assertAllClose(x_reconstructed, x, atol=1e-4)

@parameterized.named_parameters(
    ("b_rank_1", 1, None),
    ("b_rank_2", 2, None),
    ("rcond", 1, 1e-3),
)
def test_lstsq(self, b_rank, rcond):
    """Check eager `lstsq` against NumPy and the symbolic output shape."""
    # The right-hand side is a vector for rank 1 and a matrix otherwise.
    rhs_shape = (5,) if b_rank == 1 else (5, 4)
    a = np.random.random((5, 7)).astype("float32")
    b = np.random.random(rhs_shape).astype("float32")

    # Eager result must match NumPy's solution (first tuple element).
    out = linalg.lstsq(a, b, rcond=rcond)
    expected = np.linalg.lstsq(a, b, rcond=rcond)[0]
    self.assertAllClose(out, expected, atol=1e-5)

    # Symbolic inputs must yield the same static shape as the eager run.
    a_symb = backend.KerasTensor((5, 7))
    b_symb = backend.KerasTensor(rhs_shape)
    out_symb = linalg.lstsq(a_symb, b_symb)
    self.assertEqual(out_symb.shape, out.shape)


class QrOpTest(testing.TestCase):
def test_qr_init_mode_reduced(self):
Expand Down

0 comments on commit 4b18bd4

Please sign in to comment.