Broken Type Hints in PyTorch 0.4.0, related to IDEs (e.g. PyCharm) #7318
Any related point? #4568
There's no auto-completion, quick definition, or quick documentation for
If any of you have suggestions for how to fix this, please let us know! It sounds like PyCharm's function search could be improved, but I'm not sure what can be done from the PyTorch side. Using
@zou3519 From 30ec06c:
    for name in dir(_C._VariableFunctions):
        globals()[name] = getattr(_C._VariableFunctions, name)
I thought
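To see why a loop like the one above defeats IDE tooling, here is a minimal, self-contained sketch. It assumes a plain-Python stand-in class (hypothetically named `_VariableFunctions`, not the real `torch._C` object) and compares the top-level names an editor could find by parsing the source with the names that only exist after the module body has run.

```python
import ast
import types

# Hypothetical module body mimicking the pattern from 30ec06c. The class is a
# plain-Python stand-in; in PyTorch the namespace comes from a C extension.
MODULE_SOURCE = """
class _VariableFunctions:
    @staticmethod
    def zeros(n):
        return [0] * n

for name in dir(_VariableFunctions):
    if not name.startswith("_"):
        globals()[name] = getattr(_VariableFunctions, name)
"""


def static_names(source):
    """Top-level names visible to a parser that never executes the code."""
    names = set()
    for node in ast.parse(source).body:
        if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            names.add(node.name)
        elif isinstance(node, ast.Assign):
            names.update(t.id for t in node.targets if isinstance(t, ast.Name))
    return names


def runtime_names(source):
    """Names that exist after actually executing the module body."""
    module = types.ModuleType("fake_torch")
    exec(source, module.__dict__)  # inside exec, globals() is module.__dict__
    return {n for n in vars(module) if not n.startswith("__")}


if __name__ == "__main__":
    print("zeros" in static_names(MODULE_SOURCE))   # False: no def or assignment to parse
    print("zeros" in runtime_names(MODULE_SOURCE))  # True: injected via globals()
```

A static analyzer only sees the class, so `zeros` has no definition it can jump to or complete, even though it exists at runtime.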
I'm trying to make a PR for this issue. Would you review it? @zou3519
@kimdwkimdw I don't know what's going on with the globals() there or what would make it better, but yes, please submit a PR and I will look at it :)
Just wanted to note that I have never run into this issue with any other package in PyCharm. That's not to say PyCharm shouldn't be doing things differently, but this issue does seem to be an uncommon case.
I'm mainly working on these files.
I'm trying to figure out what changed in how the C extensions are generated.
Any updates?
I made a working example in my fork. If anyone wants to try auto-complete first, follow these steps:
1. git clone
2. Install PyTorch
3. Clear the cache in PyCharm
Install failed when following the description above.
I'm having the same problem. I've always loved PyTorch and PyCharm. Such a shame they don't work well together :(
This is really strange: when I use the same code in PyCharm's Python console, it works normally, but not in the editor. O__O
Going to bump the priority on this because a lot of users have been requesting it... we'll try our best to look into it more.
A lot of users are requesting this... I plan to make a pull request this weekend.
@kimdwkimdw Should I reinstall PyTorch from source? Is there an easy way to update it on Windows?
@541435721
It doesn't seem to change anything for PyCharm (2018.1, with rebuilt caches).
You should check out @kimdwkimdw's pull request, not PyTorch master. The pull request has not been merged into PyTorch master yet. This is the PR: #8845.
Yes, this is what I did:
@nlgranger You have to delete the Python stubs manually; "Invalidate Caches" does not delete them. When you set up a project, PyCharm initializes its Python stub files. After you delete the Python stub directory, restart PyCharm and it will regenerate the stubs.
Same here. I tried PyCharm, VS Code, and Spyder: torch.cat can't be auto-completed, and the same goes for LongTensor.
Throughout the comments in this issue, this has continually been billed as a PyCharm problem, and it has been described as a problem with how PyCharm's internals and PyTorch's internals work together. This is a mischaracterization of the problem, and viewing it that way probably won't lead to the correct solution. PyCharm takes a fairly standard approach to resolving packages, and, as several others have mentioned, this problem is not unique to PyCharm. PyCharm just happens to be the most popular editor that runs into this issue with PyTorch. Any editor that does not dynamically resolve the package (for speed reasons) will have this problem. I think the fact that IPython, Jupyter, etc. can resolve it correctly is the exception rather than the rule. I just wanted to emphasize this so that the focus stays on "making it work in general" rather than "making it work with PyCharm".
That said, the stubs being worked on in this PR are a good "standard" solution and should be able to resolve the issue in general. If you're looking for a quick temporary fix, copy that stub file into your dist-packages/site-packages as explained in the gist. Manually add any other missing parts, and report them so they get handled.
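As a rough illustration of what such a stub contains, the sketch below introspects a module at runtime and emits one `def` line per public function, typed loosely as `Any`. The module `demo` and the helpers `make_stub` and `stub_function_names` are hypothetical; real stubs like the ones in the PR would carry actual parameter and return types. The point is only that a `.pyi` file turns dynamically injected names into plain definitions a parser can read.

```python
import ast
import types


def make_stub(module):
    """Emit a crude .pyi body declaring every public function of `module`."""
    lines = ["from typing import Any", ""]
    for name in sorted(vars(module)):
        obj = vars(module)[name]
        if name.startswith("_") or isinstance(obj, type) or not callable(obj):
            continue
        lines.append(f"def {name}(*args: Any, **kwargs: Any) -> Any: ...")
    return "\n".join(lines) + "\n"


def stub_function_names(stub_source):
    """Names a static analyzer would find in the stub without running anything."""
    return {n.name for n in ast.parse(stub_source).body if isinstance(n, ast.FunctionDef)}


# Hypothetical module whose functions are attached at runtime, as in torch 0.4.0.
demo = types.ModuleType("demo")
for fn_name, fn in {"zeros": lambda n: [0] * n, "cat": lambda seqs: sum(seqs, [])}.items():
    setattr(demo, fn_name, fn)

stub = make_stub(demo)

if __name__ == "__main__":
    print(stub)
    print(sorted(stub_function_names(stub)))  # ['cat', 'zeros']
```

Saved as `demo.pyi` next to the module, a stub like this is what stub-aware editors and type checkers read in place of the runtime module.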
This is more a basis for discussion than a ready solution, as it does lots of funny things, and for many of them a better solution will be found.
- Initial stab at creating a type stub, torch/__init__.pyi.
- We only do this for Python 3 because we want to use type hint introspection.
- So far, we only aim at doing this for torch functions and torch.Tensor.
- We need to import the newly built torch, so we do this at the end of the build. We use os.fork so that the module is only imported in a child process.
- We use an annotate decorator to put type hints on the native Python functions in a way that (a) they're available in the usual place for Python 3 and (b) we stay Python 2 compatible.
- Some annotations in torch/functional.py are provided as examples, but the remaining ones still need to be done.
This could end up fixing pytorch#7318.
This was backported from pytorch#12500, but with tests removed and the stub file checked in directly, so we don't have to also make sure all the build shenanigans work. Master will be properly tested. The stub file was generated by running 'python setup.py create_pyi' and then copying the generated pyi file in the build directory (find -name *.pyi) to torch/
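The os.fork step described above can be sketched as follows. This is a hedged, POSIX-only illustration (os.fork does not exist on Windows), not the code from the PR: the child process imports the target module, writes a crude stub with every public callable typed as `Any`, and exits with `os._exit`, so the parent never imports the module itself. `colorsys` is only a stand-in for the freshly built torch, and `write_stub_in_child` is a hypothetical name.

```python
import importlib
import os
import sys
import tempfile
import traceback


def write_stub_in_child(module_name, out_path):
    """Fork; the child imports `module_name` and writes a crude stub to `out_path`.

    Returns the child's exit status (0 on success). The import happens only in
    the child, so the parent's sys.modules is left unchanged.
    """
    pid = os.fork()
    if pid == 0:  # child process
        exit_code = 1
        try:
            module = importlib.import_module(module_name)
            names = sorted(
                n for n in dir(module)
                if not n.startswith("_") and callable(getattr(module, n))
            )
            with open(out_path, "w") as f:
                f.write("from typing import Any\n\n")
                for n in names:
                    f.write(f"def {n}(*args: Any, **kwargs: Any) -> Any: ...\n")
            exit_code = 0
        except BaseException:
            traceback.print_exc()
        finally:
            os._exit(exit_code)  # never return into the caller's code in the child
    _, status = os.waitpid(pid, 0)  # parent: wait for the child to finish
    return os.WEXITSTATUS(status) if os.WIFEXITED(status) else -1


if __name__ == "__main__":
    already_loaded = "colorsys" in sys.modules
    out = os.path.join(tempfile.mkdtemp(), "colorsys.pyi")
    print(write_stub_in_child("colorsys", out))           # 0
    print(("colorsys" in sys.modules) == already_loaded)  # True: parent never imported it
```

Isolating the import this way keeps the build process from loading a half-installed extension into its own interpreter; `multiprocessing` would be the portable alternative if Windows support mattered.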
* create type hint stub files for module torch
  This is more a basis for discussion than a ready solution, as it does lots of funny things, and for many of them a better solution will be found.
  - Initial stab at creating a type stub, torch/__init__.pyi.
  - We only do this for Python 3 because we want to use type hint introspection.
  - So far, we only aim at doing this for torch functions and torch.Tensor.
  - We need to import the newly built torch, so we do this at the end of the build. We use os.fork so that the module is only imported in a child process.
  - We use an annotate decorator to put type hints on the native Python functions in a way that (a) they're available in the usual place for Python 3 and (b) we stay Python 2 compatible.
  - Some annotations in torch/functional.py are provided as examples, but the remaining ones still need to be done.
  This could end up fixing #7318.
  This was backported from #12500, but with tests removed and the stub file checked in directly, so we don't have to also make sure all the build shenanigans work. Master will be properly tested. The stub file was generated by running 'python setup.py create_pyi' and then copying the generated pyi file in the build directory (find -name *.pyi) to torch/
* Ignore pyi in flake8
  Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Real fix for lint error
  Signed-off-by: Edward Z. Yang <ezyang@fb.com>
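The `annotate` decorator mentioned in the commit message is not shown here, so the following is only a minimal sketch of the idea under an assumed signature (return type as `ret`, parameter types as keyword arguments). It writes the types into `__annotations__`, where Python 3 tools such as `typing.get_type_hints` look for them, while the decorated function's own signature uses no Python 3-only annotation syntax. `zeros_list` is a hypothetical example function.

```python
from typing import List, get_type_hints


def annotate(ret, **param_types):
    """Attach type hints to a function without annotation syntax in its signature."""
    def decorator(func):
        hints = dict(param_types)
        hints["return"] = ret
        func.__annotations__ = hints  # where Python 3 tooling expects hints
        return func
    return decorator


@annotate(ret=List[int], n=int)
def zeros_list(n):  # plain signature: this line also parses on Python 2
    return [0] * n


if __name__ == "__main__":
    print(get_type_hints(zeros_list))
    print(zeros_list(3))  # [0, 0, 0]
```

The trade-off is that hints live in a decorator call rather than next to each parameter, which is noisier but avoids a syntax error on interpreters that predate annotations.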
Hi everyone! @t-vi's patch has been merged to master, so if you update, you should get working autocomplete for torch on master. Additionally, we committed a hand-generated type stub for 1.0.1, so when that release happens, autocomplete will also work for people on that release. The type stub is probably not good enough for actually type-checking your code with mypy. We're tracking this follow-up work in #16574; please pipe up there if it affects you. Also, we only fixed autocomplete on torch; if you are encountering problems with autocomplete on other modules (or you think there are missing identifiers on torch), please let us know with a bug report. Thanks for your patience!
Just FYI, the new v1.0.1 release has shipped with this fixed.
@ezyang
IntelliSense in VS Code partially works for me on torch 1.0.1, but some functions are still missing, e.g.
I was also getting errors and no autocomplete for some functions in PyCharm (it was missing when I needed it most), and I don't think this is completely solved yet. I wrote my first PyTorch program last weekend, so I'm not sure whether I made some mistake in the setup.
@vpj It was fixed in PyTorch v1.0.1; anything lower still has the issue. Check
The version is 1.0.1post2. Some definitions are missing in the __init__.pyi file. PyCharm also seems to be causing some problems because it fails to infer return types. For instance, it doesn't seem to infer that the return value of
So, apparently a few namespaces are still missing, though we fixed the larger problem. I opened a new tracking task for this here: #16996
@yarcowang We've been steadily adding more and more namespaces. This should work quite well on 1.2.0. If there are still missing things, please file bugs.
@ezyang OK, I see. It seems to be a long-standing bug.
I'm new to PyTorch. Thanks a lot, it works well for me.
Hello everyone, I went through the entire thread, and it's quite difficult to tell whether the issue is resolved or not. I'm using torch 1.4+cpu and torch.tensor still has warning issues. Can anyone help? Thanks
I'm also still experiencing this issue with torch.Tensor.
Thanks for the comments. We've been fixing lots of small type hint bugs over the last few months; if you notice things that are not working on the latest release or master, please open bugs for them and we'll look. Thanks!
Broken things that I have noticed in version 1.4.0:
@tjysdsg Thanks for reporting. Could you open a new issue with those cases for more visibility?
Issue description
Recently, I found that PyCharm cannot auto-complete torch.zeros. PyCharm says
After digging into it for a while, I found that the type hints were broken by these changes:
30ec06c
https://github.com/pytorch/pytorch/wiki/Breaking-Changes-from-Variable-and-Tensor-merge
In particular, 30ec06c#diff-14258fce7c17ccb97b488e64373b0803R308 @colesbury
This line breaks type hints for many IDEs. Originally, torch.zeros was in torch/_C/__init__.py, but it moved to torch/_C/_VariableFunctions.
Code example
https://gist.github.com/kimdwkimdw/50c18b5cf72c69c2d01bb4146c8a2b5c
This is a proof of concept for this bug.
If you look at main.py
System Info
- PyTorch or Caffe2:
- How you installed PyTorch (conda, pip, source): any of conda, pip, or source
- Build command you used (if compiling from source):
- OS: Any
- PyTorch version: 0.4.0
- Python version: 3.6.5
- CUDA/cuDNN version:
- GPU models and configuration:
- GCC version (if compiling from source):
- CMake version:
- Versions of any other relevant libraries: