Requirements for inference mode and path to weights.pt #4

Open
kkedich opened this issue Nov 5, 2020 · 1 comment

Comments

kkedich commented Nov 5, 2020

It seems that torchvision needs to be included in the requirements list. I tried to create a conda environment with just PyTorch and Python 3.7, but the script then failed on import.

Conda environment created with:

    conda create --name dists python=3.7
    pip install -r requirements.txt

Error:

    Traceback (most recent call last):
      File "DISTS_pt.py", line 8, in <module>
        from torchvision import models, transforms
    ModuleNotFoundError: No module named 'torchvision'

Adding torchvision to the requirements list fixes this.
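
As a rough illustration of how a missing dependency like this could be caught before the script runs, here is a minimal sketch. The helper name `missing_modules` and the module list passed to it are hypothetical and not part of the DISTS code; it only uses the standard library's `importlib.util.find_spec`.

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    # find_spec returns None when a top-level module is not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # The list below is an assumption based on the imports mentioned in this issue.
    missing = missing_modules(["torch", "torchvision"])
    if missing:
        print("Missing packages:", ", ".join(missing))
```

Running this inside the freshly created environment would list torchvision if it has not been installed yet.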
Additionally, when using the conda env, the script looks for the weights file in the wrong place. It builds the path from sys.prefix, which points to the env directory (example: /home/karinabogdan/anaconda3/envs/dists/), not to DISTS_pytorch/, where weights.pt is actually located.

    python DISTS_pt.py --ref ../example_input/12_enh.jpg --dist ../example_input/12_raw.jpg
    Traceback (most recent call last):
      File "DISTS_pt.py", line 134, in <module>
        model = DISTS().to(device)
      File "DISTS_pt.py", line 63, in __init__
        weights = torch.load(os.path.join(sys.prefix, 'weights.pt'))
      File "/home/karinabogdan/anaconda3/envs/dists/lib/python3.7/site-packages/torch/serialization.py", line 581, in load
        with _open_file_like(f, 'rb') as opened_file:
      File "/home/karinabogdan/anaconda3/envs/dists/lib/python3.7/site-packages/torch/serialization.py", line 230, in _open_file_like
        return _open_file(name_or_buffer, mode)
      File "/home/karinabogdan/anaconda3/envs/dists/lib/python3.7/site-packages/torch/serialization.py", line 211, in __init__
        super(_open_file, self).__init__(open(name, mode))
    FileNotFoundError: [Errno 2] No such file or directory: '/home/karinabogdan/anaconda3/envs/dists/weights.pt'

I fixed that by changing the line:

    weights = torch.load(os.path.join(sys.prefix, 'weights.pt'))

to

    from pathlib import Path, PurePosixPath
    weights = torch.load(str(PurePosixPath(Path.cwd()).joinpath('weights.pt')))

but I am not sure if this is the best way to handle this.
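
One possible alternative, offered as a sketch rather than the maintainers' fix: resolve weights.pt relative to the script file instead of the current working directory, so the lookup does not depend on where the command is run from. The helper name `resolve_weights_path` is hypothetical, and the sketch assumes weights.pt sits in the same directory as DISTS_pt.py.

```python
from pathlib import Path


def resolve_weights_path(script_file: str, filename: str = "weights.pt") -> Path:
    """Return the path of `filename` in the same directory as `script_file`."""
    # resolve() makes the path absolute, so the result is independent of the
    # current working directory and of sys.prefix.
    return Path(script_file).resolve().parent / filename


# Hypothetical use inside DISTS_pt.py (requires torch, so it is left as a comment):
#     weights = torch.load(str(resolve_weights_path(__file__)))
```

Unlike the Path.cwd() version, this should keep working when the script is launched from another directory, for example with python DISTS_pytorch/DISTS_pt.py.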

Ellyuca commented Sep 7, 2022

Thanks, this was really helpful.
