Improve behaviour of torch.linalg.lstsq
on CUDA GPU for rank-deficient matrices
#117122
Labels
actionable
module: cuda
Related to torch.cuda, and CUDA support in general
module: linear algebra
Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🚀 The feature, motivation and pitch
As per the documentation, on CUDA the only supported driver is `gels`, which assumes that the input matrix is full-rank.
While documented, this behaviour is counter-intuitive for end-users, especially if the function silently fails.
Interestingly, calling torch.linalg.lstsq on CUDA with rank-deficient input currently fails silently in non-batched mode but throws a LinAlgError in batched mode. It is also counter-intuitive that torch.linalg.lstsq on CUDA cannot fall back to a more stable SVD-based driver despite torch.linalg.svd being supported on CUDA. It would be great if:
Alternatives
An alternative is to implement an SVD-based least-squares and use that instead of torch.linalg.lstsq. Here is a basic implementation (feel free to post refinements):

Additional context
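The implementation referenced in the Alternatives section above did not survive in this capture; the following is a minimal sketch of what such an SVD-based minimum-norm least-squares could look like, not the issue author's original code. The name `svd_lstsq` and the `rcond` cutoff default (modeled on `numpy.linalg.lstsq`) are assumptions.

```python
import torch

def svd_lstsq(A: torch.Tensor, b: torch.Tensor, rcond=None) -> torch.Tensor:
    """Minimum-norm least-squares solution of A x = b via the SVD.

    Sketch only: works on any device where torch.linalg.svd is
    supported (including CUDA) and tolerates rank-deficient A.
    """
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    if rcond is None:
        # Cutoff modeled on numpy.linalg.lstsq's default rcond.
        rcond = max(A.shape[-2:]) * torch.finfo(S.dtype).eps
    # Zeroing reciprocals of small singular values is what makes
    # the solver robust to rank deficiency.
    cutoff = rcond * S.max(dim=-1, keepdim=True).values
    S_inv = torch.where(S > cutoff, 1.0 / S, torch.zeros_like(S))
    # x = V @ diag(S_inv) @ U^H @ b
    return Vh.mH @ (S_inv.unsqueeze(-1) * (U.mH @ b))

# Example: rank-deficient 4x2 system (both columns identical, rank 1).
A = torch.tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
b = torch.tensor([[1.], [2.], [3.], [4.]])
x = svd_lstsq(A, b)  # minimum-norm solution, approximately [0.5, 0.5]
```

Batched inputs work as well, since `torch.linalg.svd` and the matrix products broadcast over leading dimensions.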
Simple script to replicate:
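The original replication script was not captured here; a minimal repro along these lines (a sketch, not the issue's original script; the CUDA part requires a GPU) would be:

```python
import torch

# Rank-deficient system: the second column is a copy of the first,
# so A has rank 1 even though it is 4x2.
A = torch.tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
b = torch.tensor([[1.], [2.], [3.], [4.]])

# On CPU, the SVD-based 'gelsd' driver handles rank deficiency
# and reports the detected rank.
cpu_sol = torch.linalg.lstsq(A, b, driver="gelsd")
print("CPU rank:", cpu_sol.rank.item())            # expected: 1
print("CPU solution:", cpu_sol.solution.squeeze())  # minimum-norm

# On CUDA, only the QR-based 'gels' driver is available, which
# assumes full rank; for this input the result is unreliable
# (silently wrong in non-batched mode at the time of the issue).
if torch.cuda.is_available():
    gpu_sol = torch.linalg.lstsq(A.cuda(), b.cuda())
    print("GPU solution:", gpu_sol.solution.squeeze())
```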
Related issues: #88101 #85021 #10454
cc @ptrblck @jianyuh @nikitaved @pearu @mruberry @walterddr @xwang233 @lezcano