Feature request: keras.ops.linalg.lstsq #19678

Closed
jonbarron opened this issue May 6, 2024 · 4 comments
Labels
type:feature The user is asking for a new feature.

Comments

@jonbarron

Can this be added? It seems like it should be straightforward to implement, as it's just a thin wrapper around an SVD (in JAX at least https://github.com/google/jax/blob/main/jax/_src/numpy/linalg.py#L1368-L1406), and y'all already have SVD implemented.
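As a rough illustration of the "thin wrapper around an SVD" idea, here is a minimal NumPy sketch of the minimum-norm least-squares solution x = V diag(1/s) U^T b, where singular values at or below rcond × s_max are treated as zero. The function name `lstsq_svd` is hypothetical, and the default cutoff simply follows NumPy's documented `rcond=None` behavior (machine epsilon × max(M, N)); this is not the JAX code linked above.

```python
import numpy as np


def lstsq_svd(a, b, rcond=None):
    """Minimum-norm least-squares solution of a @ x ~= b via a thin SVD.

    Hypothetical helper for illustration; not any backend's actual code.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    m, n = a.shape
    if rcond is None:
        # Same default cutoff NumPy documents for rcond=None.
        rcond = np.finfo(a.dtype).eps * max(m, n)
    # Thin SVD: a = u @ diag(s) @ vt, singular values sorted in descending order.
    u, s, vt = np.linalg.svd(a, full_matrices=False)
    # Singular values at or below rcond * s_max are treated as zero (rank truncation).
    keep = s > rcond * s[0]
    s_inv = np.where(keep, 1.0 / np.where(keep, s, 1.0), 0.0)
    # x = V @ diag(1/s) @ U^T @ b; reshape s_inv so it broadcasts over 2-D b.
    utb = u.T @ b
    scale = s_inv if b.ndim == 1 else s_inv[:, None]
    return vt.T @ (scale * utb)
```

For a full-rank, overdetermined `a`, this should agree with `np.linalg.lstsq(a, b, rcond=None)[0]` up to floating-point error; for rank-deficient `a`, both return the minimum-norm solution.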

@fchollet
Member

fchollet commented May 7, 2024

Yes, that's in scope and it seems straightforward.

I'm curious, what are you building that requires all these niche linalg ops?

SuryanarayanaY added the labels type:feature (The user is asking for a new feature.), stat:awaiting keras-eng (Awaiting response from Keras engineer) and stat:awaiting response from contributor, and removed stat:awaiting keras-eng, May 7, 2024
@fchollet
Member

fchollet commented May 9, 2024

I've investigated this, and it's actually impossible to achieve consistency across backends for the resid and s return values (the solution x, the first returned value, is fine). Even jax.numpy isn't consistent with numpy, and Torch does it differently from both (despite having the same API).

We could make the function only return x. Are the other values ever useful?

Looking at code in the wild, I could only find samples that used x. The other values are apparently always discarded. It's a pretty weird API tbh.
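NumPy alone shows why the extra outputs are awkward to standardize: the shape of `residuals` changes with the problem's shape and rank (per the NumPy docs it is an empty array when the rank of `a` is below N or when M <= N). This sketch only illustrates NumPy's behavior and does not reproduce the JAX or Torch differences; `lstsq_x_only` is a hypothetical name. It also shows that a caller who needs the residual can recompute it from x.

```python
import numpy as np

rng = np.random.default_rng(0)

# Overdetermined, full-rank system (M > N): NumPy fills in the residuals.
a_tall = rng.normal(size=(5, 2))
b_tall = rng.normal(size=5)
x, residuals, rank, s = np.linalg.lstsq(a_tall, b_tall, rcond=None)
print(residuals.shape)  # (1,)

# Underdetermined system (M <= N): NumPy returns an empty residuals array.
a_wide = rng.normal(size=(2, 5))
b_wide = rng.normal(size=2)
x, residuals, rank, s = np.linalg.lstsq(a_wide, b_wide, rcond=None)
print(residuals.shape)  # (0,)


def lstsq_x_only(a, b, rcond=None):
    """Hypothetical x-only wrapper: keep the solution, drop backend-specific extras."""
    return np.linalg.lstsq(a, b, rcond=rcond)[0]


# A caller who needs the sum of squared residuals can recompute it from x.
x_tall = lstsq_x_only(a_tall, b_tall)
print(np.sum((a_tall @ x_tall - b_tall) ** 2))
```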

@jonbarron
Author

jonbarron commented May 9, 2024

yeah I think it makes sense to only return x; I can imagine very few use cases where the caller would really care about the other outputs. It might be helpful to return None for the other 3 outputs of the numpy interface, just so it's a drop-in replacement for JAX. The other upside of placeholder None outputs is that they avoid a potential footgun: if someone happens to be solving for an x with x.shape[0] == 4 and calls lstsq with 4 output slots, it will silently unstack the tensor along the first dimension, which is definitely not what the caller would want.
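To make the footgun concrete: with an x-only API, a caller who keeps the NumPy-style four-way unpacking gets no error whenever x.shape[0] == 4, because Python unpacks the array along its first axis. In this sketch `lstsq_x_only` is a hypothetical stand-in for an x-only lstsq built on NumPy, not the Keras implementation.

```python
import numpy as np


def lstsq_x_only(a, b, rcond=None):
    # Hypothetical stand-in for an lstsq op that returns only the solution x.
    return np.linalg.lstsq(a, b, rcond=rcond)[0]


rng = np.random.default_rng(0)
a = rng.normal(size=(10, 4))  # 4 unknowns per right-hand side
b = rng.normal(size=(10, 3))  # 3 right-hand sides, so x has shape (4, 3)

# NumPy-style 4-way unpacking "works" because x.shape[0] == 4:
# Python iterates over the rows of x and binds one row to each name.
x, residuals, rank, s = lstsq_x_only(a, b)
print(x.shape)  # (3,): only the first row of the solution

# With any other leading dimension the same mistake fails loudly.
try:
    x, residuals, rank, s = lstsq_x_only(a[:, :3], b)  # x has shape (3, 3)
except ValueError as err:
    print("ValueError:", err)
```

When x.shape[0] != 4 the same unpacking raises ValueError, so the silent failure is specific to that shape.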

@fchollet
Member

For reference, TensorFlow's tf.linalg.lstsq only returns x https://www.tensorflow.org/api_docs/python/tf/linalg/lstsq (though it also has a different signature).

I added the API, only returning x. Returning None entries would be problematic since ops are only ever supposed to return tensors.
