Skip to content

Commit

Permalink
Make tests more resilient to signs when rounding
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed May 13, 2024
1 parent 44313f6 commit d55faa1
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions nx/test/nx/lin_alg_test.exs
Expand Up @@ -586,17 +586,17 @@ defmodule Nx.LinAlgTest do
Nx.tensor([
[
Complex.new(-0.408, 0.0),
Complex.new(0.0, 0.707),
Complex.new(-0.0, 0.707),
Complex.new(0.577, 0.0)
],
[
Complex.new(0.0, -0.816),
Complex.new(-0.0, -0.816),
Complex.new(0.0, 0.0),
Complex.new(0.0, -0.577)
],
[
Complex.new(0.408, 0.0),
Complex.new(0.0, 0.707),
Complex.new(-0.0, 0.707),
Complex.new(-0.577, 0.0)
]
])
Expand Down Expand Up @@ -731,7 +731,8 @@ defmodule Nx.LinAlgTest do

assert {u, s, vt} = Nx.LinAlg.svd(t)

assert round(Nx.as_type(t, :f32), 2) == u |> Nx.multiply(s) |> Nx.dot(vt) |> round(2)
assert round(Nx.as_type(t, :f32), 2) ==
u |> Nx.multiply(s) |> Nx.dot(vt) |> Nx.abs() |> round(2)
end

test "finds the singular values of wide matrices" do
Expand All @@ -755,7 +756,7 @@ defmodule Nx.LinAlgTest do
|> Nx.broadcast({3, 3})
|> Nx.put_diagonal(s)

assert round(t, 1) == u |> Nx.dot(s_matrix) |> Nx.dot(v) |> round(1)
assert round(t, 1) == u |> Nx.dot(s_matrix) |> Nx.dot(v) |> Nx.abs() |> round(1)

assert round(u, 3) ==
Nx.tensor([
Expand Down

0 comments on commit d55faa1

Please sign in to comment.