PyCharm shows an incorrect dtype when assigning torch.randn to a variable #29

Open
jamesdsmith99 opened this issue Oct 3, 2021 · 5 comments

Comments

@jamesdsmith99

Hi,

I have recently started using this library, so I might be using it incorrectly, but linting seems to fail in PyCharm when assigning the result of torch.randn to a TensorType with a float dtype.

Here is an example:

import torch
from torchtyping import TensorType

Matrix = TensorType['h', 'w', float]
x: Matrix = torch.randn(5, 3)

The second line gets underlined with the following error:

Expected type 'TensorType[Any, Any, float]', got 'Tensor' instead

If I modify the second line to:

x: Matrix = torch.randn(5, 3).float()

The error goes away, but I would rather not do that: one of the upsides of this library is removing extra typing-related code from my main logic, and having to add an explicit .float() defeats the purpose of the library, IMO.

From reading the docs, this should work: torch.randn returns a tensor of the default dtype, and TensorTypes with float in them should be of the default dtype.

@patrick-kidger
Owner

Without having things set up in PyCharm myself it'll be a fair bit of work to diagnose this.

If you remove the float specifier and have only TensorType['h', 'w'] does that still produce an error? I'm just trying to gather some data on what raises an error and what doesn't. More broadly if you can track down what's causing the issue then I'd be happy to accept a PR.

@spietras

spietras commented Apr 8, 2022

If you remove the float specifier and have only TensorType['h', 'w'] does that still produce an error?

I checked it and yes, there is still an error. Doing torch.randn(5, 3).float() works only because float() is untyped, so PyCharm can't assume anything about the return type and doesn't emit any warnings.

It seems that torchtyping doesn't work with PyCharm's type checker at all: no matter what I do, there is always a warning when assigning a Tensor to anything with a TensorType type hint.

And I guess it's not surprising, because TensorType is a subclass of Tensor, so PyCharm complains when we try to assign an instance of the parent class to something expecting the subclass.
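To illustrate the subclass point with plain Python: the classes below are hypothetical stand-ins (not torch or torchtyping code), with `Tensor` playing torch.Tensor, `TensorTypeLike` playing TensorType, and `randn` typed to return the parent class, as torch.randn is. A static checker sees a subclass-annotated variable receiving a parent-class value and warns, even though nothing is enforced at runtime.

```python
# Hypothetical stand-ins: `Tensor` plays the role of torch.Tensor and
# `TensorTypeLike` the role of torchtyping's TensorType (a subclass of Tensor).
class Tensor:
    def __init__(self, *shape: int) -> None:
        self.shape = shape


class TensorTypeLike(Tensor):
    """Subclass used only as an annotation, like TensorType['h', 'w', float]."""


def randn(*shape: int) -> Tensor:
    # Declared to return the parent class, mirroring how torch.randn is typed.
    return Tensor(*shape)


# A static checker (PyCharm, mypy, ...) flags this line: the declared type is
# the subclass, but randn() is declared to return the parent class.
x: TensorTypeLike = randn(5, 3)

# At runtime nothing complains, because annotations are not enforced...
print(type(x).__name__)  # Tensor
# ...and the object really is not an instance of the subclass.
print(isinstance(x, TensorTypeLike))  # False
```

The reverse direction (assigning a `TensorTypeLike` value to a `Tensor`-annotated variable) would be accepted, since every subclass instance is also a parent-class instance.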

@fzyzcjy

fzyzcjy commented Dec 11, 2022

Hi, are there any updates?

@patrick-kidger
Owner

patrick-kidger commented Dec 11, 2022

Yes! I'd recommend trying jaxtyping. Despite the name, it actually works equally well for PyTorch.

In particular, it's designed to play much better with static type checkers.
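A rough sketch of why an annotation-based design avoids the subclass warning, assuming (not confirmed in this thread) that jaxtyping presents its annotations to static checkers as `typing.Annotated` over the underlying tensor type. A checker treats `Annotated[T, ...]` as plain `T`, so shape/dtype information travels as metadata without introducing a new class. This uses only the standard library and a hypothetical `Tensor` stand-in; it is not jaxtyping's actual code.

```python
from typing import Annotated, get_args, get_origin, get_type_hints


class Tensor:
    """Hypothetical stand-in for torch.Tensor."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


def randn(*shape: int) -> Tensor:
    return Tensor(*shape)


# Shape/dtype information stored as metadata; to a static checker this alias
# is just `Tensor`, so returning randn(...) from it raises no warning.
Matrix = Annotated[Tensor, "h w", float]


def make_matrix() -> Matrix:
    return randn(5, 3)


# With include_extras=True the metadata is still available at runtime.
ret = get_type_hints(make_matrix, include_extras=True)["return"]
print(get_origin(ret) is Annotated)  # True
print(get_args(ret)[1:])  # ('h w', <class 'float'>)

# Without include_extras, the annotation collapses to the plain class.
print(get_type_hints(make_matrix)["return"] is Tensor)  # True
```

With jaxtyping itself, the annotation from the original report would be written along the lines of `Float[torch.Tensor, "h w"]`; the jaxtyping documentation is the authority on the exact spelling and on how runtime checking is enabled.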

@fzyzcjy

fzyzcjy commented Dec 11, 2022

@patrick-kidger Interesting, thanks for the quick reply! (Never thought "jax" typing would work for "pytorch" before ;) )


4 participants