Improve behaviour of torch.linalg.lstsq on CUDA GPU for rank-deficient matrices #117122

Open
tvercaut opened this issue Jan 10, 2024 · 3 comments · May be fixed by #125110
Labels
actionable
module: cuda (Related to torch.cuda, and CUDA support in general)
module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


tvercaut commented Jan 10, 2024

🚀 The feature, motivation and pitch

As per the documentation:

For CUDA input [torch.linalg.lstsq] assumes that A is full-rank.

While documented, this behaviour is counter-intuitive for end-users, especially as the function fails silently.

Interestingly, calling torch.linalg.lstsq on CUDA with rank-deficient input currently fails silently in non-batched mode but throws a _LinAlgError in batched mode.

It is also counter-intuitive that torch.linalg.lstsq on CUDA cannot fall back to a more stable SVD driver despite torch.linalg.svd being supported on CUDA.

It would be great if:

  1. The QR-based implementation always threw an error if the input is not full rank (a user-side version of such a check is sketched after this list)
  2. An SVD backend were available on CUDA as well
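
As a stop-gap, a guard along these lines could make the failure explicit before dispatching to the CUDA QR ('gels') driver. This is only a minimal sketch: lstsq_checked is a hypothetical helper, not an existing PyTorch API, and it pays for an extra rank computation.

import torch

def lstsq_checked(A, B):
    # Hypothetical wrapper: reject rank-deficient inputs explicitly instead of
    # relying on the CUDA QR ('gels') path, which assumes A is full-rank.
    ranks = torch.linalg.matrix_rank(A)
    full_rank = min(A.shape[-2], A.shape[-1])
    if not bool((ranks == full_rank).all()):
        raise torch.linalg.LinAlgError(
            "A is rank deficient; torch.linalg.lstsq on CUDA assumes full rank")
    return torch.linalg.lstsq(A, B).solution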

Alternatives

An alternative is to implement an SVD-based least-squares and use that instead of torch.linalg.lstsq. Here is a basic implementation (feel free to post refinements):

def svd_lstsq(AA, BB, tol=1e-5):
    # Solve AA @ X = BB in the least-squares sense via the SVD pseudo-inverse.
    U, S, Vh = torch.linalg.svd(AA, full_matrices=False)
    # Invert only the singular values above the tolerance, so rank-deficient
    # inputs are handled instead of producing infinities.
    Spinv = torch.zeros_like(S)
    Spinv[S > tol] = 1 / S[S > tol]
    UhBB = U.adjoint() @ BB
    # Broadcast Spinv over the right-hand-side columns when BB is a matrix.
    if Spinv.ndim != UhBB.ndim:
        Spinv = Spinv.unsqueeze(-1)
    SpinvUhBB = Spinv * UhBB
    return Vh.adjoint() @ SpinvUhBB
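
As a design note, the fixed absolute tol above differs from torch.linalg.pinv, which discards singular values relative to the largest one. A variant along those lines might look like the following sketch (svd_lstsq_rtol and its rtol default are illustrative choices, not a settled API):

def svd_lstsq_rtol(AA, BB, rtol=1e-5):
    # Same SVD-based solve, but with a cutoff relative to the largest singular
    # value, similar in spirit to the relative tolerance used by torch.linalg.pinv.
    U, S, Vh = torch.linalg.svd(AA, full_matrices=False)
    cutoff = rtol * S.amax(dim=-1, keepdim=True)
    Spinv = torch.where(S > cutoff, 1 / S, torch.zeros_like(S))
    UhBB = U.adjoint() @ BB
    if Spinv.ndim != UhBB.ndim:
        Spinv = Spinv.unsqueeze(-1)
    return Vh.adjoint() @ (Spinv * UhBB)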

Additional context

Simple script to replicate:

import torch
print(f'Running PyTorch version: {torch.__version__}')

torchdevice = torch.device('cpu')
if torch.cuda.is_available():
  torchdevice = torch.device('cuda')
  print('Default GPU is ' + torch.cuda.get_device_name(torch.device('cuda')))
print('Running on ' + str(torchdevice))

b = 2
r = 5
c = 3
k = 1

if b==1:
  A = torch.randn(r, c, device=torchdevice)
  if k==1:
    B = torch.randn(r, device=torchdevice)
  else:
    B = torch.randn(r, k, device=torchdevice)
else:
  A = torch.randn(b, r, c, device=torchdevice)
  B = torch.randn(b, r, k, device=torchdevice)

# degrade rank by zeroing the last column of A
A[...,-1] = 0
print("A",A)

try:
  X_lstsq = torch.linalg.lstsq(A, B).solution
  print("X_lstsq",X_lstsq)
except Exception as error:
  print("An error occurred:", type(error).__name__, "–", error)

X_pinv = torch.linalg.pinv(A) @ B
print("X_pinv",X_pinv)

def svd_lstsq(AA, BB, tol=1e-5):
    U, S, Vh = torch.linalg.svd(AA, full_matrices=False)
    Spinv = torch.zeros_like(S)
    Spinv[S>tol] = 1/S[S>tol]
    UhBB = U.adjoint() @ BB
    if Spinv.ndim!=UhBB.ndim:
      Spinv = Spinv.unsqueeze(-1)
    SpinvUhBB = Spinv * UhBB
    return Vh.adjoint() @ SpinvUhBB

X_svd = svd_lstsq(A, B)
print("X_svd",X_svd)

Related issues: #88101 #85021 #10454

cc @ptrblck @jianyuh @nikitaved @pearu @mruberry @walterddr @xwang233 @lezcano

bdhirsh added the module: cuda, triaged, and module: linear algebra labels on Jan 11, 2024
lezcano (Collaborator) commented Jan 29, 2024

Yep, this makes sense. I had this in mind when we were implementing torch.linalg but never got around to implementing it. Would you like to send a PR adding this behaviour?

tvercaut (Author) commented

Sorry I don't think I will be able to work on a PR.

ZelboK (Contributor) commented Apr 27, 2024

@lezcano I can work on a PR. Hopefully done by this weekend.
