Allow linalg.lstsq to use svd to compute the result for rank deficient matrices. #125110
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125110
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 3006f30 with merge base e5e623a. This comment was automatically generated by Dr. CI and updates every 15 minutes.
mind cleaning up all the spurious new lines and the PR in general?
… when matrices are rank deficient.
Force-pushed from 06e42e8 to 7372645
My apologies! I've cleaned it up. I missed some new lines left over from when I was cleaning up the debugging/experimenting code I had used to understand the codebase.
It looks mostly good. Needs tests in test_linalg.py, and the docs need updating to note that this gelss mode is also supported.
Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
…tion linalg_svd for computation
So when it comes to the tests, what kind of test did you think would be appropriate, aside from checking that it no longer throws? I can add
Edit: Workflow runs exposed two failing tests for CPU and complex lstsq computations. I hadn't noticed that I didn't build with LAPACK, so these tests were being skipped locally. Will look into it now.
Force-pushed from da93358 to c71e504
For testing, just add a path that exercises this driver in the relevant tests that test the other drivers. We may even already have a test that covers this driver on CPU.
Please try to keep the changes to the bare minimum
if (input.numel() == 0 || input.size(-2) == 0 || input.size(-1) == 0) {
  auto output_shape = input.sizes().vec();
  output_shape.back() = other.size(-1);
  solution = at::zeros(output_shape, input.options());
Is this necessary?
Sorry, I should've communicated these changes.
The tests will actually fail without this check because they generate tensors like torch.empty((0, 1)). The narrow code will lead to:
File "/home/ksm/pytorch/test/test_linalg.py", line 316, in test_linalg_lstsq
res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
RuntimeError: start (0) + length (1) exceeds dimension size (0).
Do we want to remove the edge-case handling to simplify the logic, and note in the docs that this error will occur?
If so, I'll also have to look at the tests again.
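The bounds failure above can be illustrated without PyTorch. A minimal sketch, assuming only the bounds rule stated in the error message (the helper `narrow_ok` is hypothetical, written just to mirror the check `narrow` performs): a tensor shaped `(0, 1)` has size 0 along dim -2, so any length-1 narrow on that dim must fail.

```python
def narrow_ok(dim_size, start, length):
    # Hypothetical helper mirroring the bounds rule behind the
    # RuntimeError above: narrow(dim, start, length) requires
    # start + length <= dim_size.
    return 0 <= start and start + length <= dim_size

# torch.empty((0, 1)): dim -2 has size 0, so narrowing a length-1
# slice is out of bounds, hence "start (0) + length (1) exceeds
# dimension size (0)".
print(narrow_ok(0, 0, 1))  # False for the empty-tensor case
print(narrow_ok(2, 0, 1))  # True once the dimension is non-empty
```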
// LAPACK stores residuals data for postprocessing in rows n:(m-n)
if (compute_residuals) {
  if (solution.size(-2) >= n + (m - n)) {
why is all this necessary?
The tests were exposing some problems that mainly originated from the use of narrow on specific tensor shapes, so these new conditionals were added to handle those situations.
File "/home/ksm/pytorch/test/test_linalg.py", line 316, in test_linalg_lstsq
res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
RuntimeError: start (1) + length (1) exceeds dimension size (1)
Would you prefer to revert these conditionals, run the test workflows, see what fails and go from there?
Going forward I will keep code changes to a minimum and document why changes are necessary. Didn't mean to make your job harder, my bad 😅
But this path was already working before. I don't understand why we should touch it at all.
It's not needed. When adding tests and working through the exceptions, I found narrow was the main cause, so I added guards around all the narrow calls in the code.
I added this one because if the second-to-last dimension was less than n + (m - n) (which I now realize is just m), this would throw.
I tried to produce a scenario where it actually throws, but the error is always caught earlier on, so this check is redundant and can be removed.
Edit: Also, to clarify, the exception I pasted in the comment above was not from this line of code.
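As a sanity check on the redundancy claim: the guard `solution.size(-2) >= n + (m - n)` simplifies algebraically to `size >= m`, which is exactly what the narrow into rows `n:(m-n)` needs. A small sketch (the helper name is hypothetical):

```python
def residual_narrow_in_bounds(sol_rows, m, n):
    # The added guard: solution.size(-2) >= n + (m - n).
    # Since n + (m - n) == m, this is just sol_rows >= m, the exact
    # condition for narrow(dim=-2, start=n, length=m - n) to be in
    # bounds on the solution buffer.
    return sol_rows >= n + (m - n)

# For an overdetermined m x n system, the solution buffer LAPACK
# writes into has m rows, so the guard holds there:
m, n = 5, 3
print(residual_narrow_in_bounds(m, m, n))  # True
print(n + (m - n) == m)                    # True: the bound is just m
```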
Wait. I actually just produced an exception from this, but I really don't understand how.
auto raw_residuals = solution.narrow(/*dim=*/-2, /*start=*/n, /*length*/m - n);
This raises an exception if I leave in the line rank.fill_(0) from earlier in the code, but when I remove it, the line above no longer raises...? Why? I understand that line is redundant, but I am still really curious about this.
https://pastebin.com/BfysJQv4
^ backtrace (quite long) in case it's useful
Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
solution.set_(solution.storage(), solution_view.storage_offset(), solution_view.sizes(), solution_view.strides());
}
if (solution.size(-2) >= n) {
  auto solution_view = solution.narrow(/*dim=*/-2, /*start=*/0, /*length*/n);
I added this; the tests for test_linalg_lstsq were failing for:
python test_linalg.py -k test_linalg_lstsq_cuda_float32
with tensors like:
torch.Size([2, 1]) is a shape
a contents: tensor([[0.1540],
[0.9887]], device='cuda:0')
@@ -1080,7 +1080,7 @@
 Keyword args:
     driver (str, optional): name of the LAPACK/MAGMA method to be used.
-        If `None`, `'gelsy'` is used for CPU inputs and `'gels'` for CUDA inputs.
+        If `None`, `'gelsy'` is used for CPU inputs, `'gels'` and `'gelss'` for CUDA inputs.
revert
// residuals are available only if m > n and drivers other than gelsy used
if (m > n && driver != "gelsy") {
  // if the driver is gelss or gelsd then the residuals are available only if rank == n
  bool compute_residuals = true;
  if (driver == "gelss" || driver == "gelsd") {
    if (input.dim() == 2) {
-      compute_residuals = (rank.item().toInt() == n);
+      compute_residuals = (rank.item().toDouble() == n);
why is this change necessary?
} else {
  auto [U, S, Vh] = at::_linalg_svd(input, false, true, "gesvd");
  rank = at::zeros({1}, at::kLong);
  rank[0] = (S > rcond).sum();
This is not correct. Compute the rank by looking at the zeros of S_pinv.
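A hedged pure-Python sketch of what the reviewer seems to be suggesting (the function name and the relative-cutoff convention are assumptions, modeled on pinv-style thresholds): build S_pinv by zeroing the reciprocals of singular values at or below a threshold, then read the rank off as the count of nonzero entries of S_pinv, rather than comparing S > rcond directly.

```python
def svd_rank(S, rcond):
    # Hypothetical helper: compute the pseudoinverse of the singular
    # values with a relative cutoff (rcond * largest singular value),
    # then take the rank as the number of entries NOT zeroed out.
    smax = max(S)
    S_pinv = [1.0 / s if s > rcond * smax else 0.0 for s in S]
    return sum(1 for v in S_pinv if v != 0.0)

S = [3.0, 1.0, 1e-12]
print(svd_rank(S, 1e-8))  # 2: the tiny singular value is zeroed in S_pinv
```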
Fixes #117122
This PR adds the logic so that, for rank-deficient matrices, lstsq can fall back to an SVD backend in batched mode. A big thank you to @tvercaut for the well-written issue and the suggestion on how to approach the problem.
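For reference, the underlying SVD fallback idea can be sketched in NumPy (this is only the math, not PyTorch's implementation, which calls at::_linalg_svd with the "gesvd" driver; the function name and rcond default here are illustrative): solve x = V S⁺ Uᵀ b, zeroing the reciprocals of small singular values so rank-deficient inputs yield the min-norm least-squares solution instead of an error.

```python
import numpy as np

def lstsq_svd(A, B, rcond=1e-15):
    # Illustrative sketch of an SVD-based least-squares solve.
    U, S, Vh = np.linalg.svd(A, full_matrices=False)
    cutoff = rcond * S.max()
    # Zero out reciprocals of small singular values: this is what
    # makes the method well-defined for rank-deficient A.
    S_pinv = np.where(S > cutoff, 1.0 / S, 0.0)
    return Vh.T @ (S_pinv[:, None] * (U.T @ B))

# Rank-deficient example: the second column is twice the first, so
# QR-based drivers would typically hit a singular triangular factor.
A = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])
B = np.array([[1.0], [2.0], [3.0]])
x = lstsq_svd(A, B)  # min-norm solution [[0.2], [0.4]]
```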
Summary:
Please keep in mind this is my 2nd PR to PyTorch, and I've never really used PyTorch before. I'm learning independently by digging deep into the internals, so I may make some obvious mistakes. Please forgive!