diffuseq-v2: TypeError: load_state_dict() takes 1 positional argument but 3 were given #61

Open
zzbuzzard opened this issue Oct 23, 2023 · 3 comments

Comments

@zzbuzzard

Hi, thanks for releasing this code! I'm trying to run the decoding script on the diffuseq-v2 branch, but I get the error in the title. The exact command I am running is:
CUDA_VISIBLE_DEVICES=2 python -u run_decode_solver.py --model_dir models/qqp --seed 110 --bsz 100 --step 10 --split test

I noticed that commit cc4e9b4 changes line 65 of sample_seq2seq.py from
dist_util.load_state_dict(args.model_path, map_location="cpu")
to
dist_util.load_state_dict(args.model_path, False, "model", map_location="cpu")
(the same change is also present in sample_seq2seq_dpmSolver.py)

However, the definition in diffuseq/utils/dist_util.py, which was not changed, does indeed take only one positional argument:

def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file.
    """
    # if int(os.environ['LOCAL_RANK']) == 0:
    with bf.BlobFile(path, "rb") as f:
        data = f.read()
    return th.load(io.BytesIO(data), **kwargs)

I'm not sure what's going on here. It seems the branch should also have updated this function? Any help appreciated!

P.S. I was able to run decoding successfully on the main branch, so it's just an issue with diffuseq-v2!
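A minimal reproduction of why the new call fails, as a sketch: the stub below copies only the signature of load_state_dict quoted above (path plus **kwargs); its body is a placeholder, not the real loader. Because **kwargs collects only keyword arguments, passing False and "model" positionally raises the same TypeError, while the old main-branch call binds fine.

```python
def load_state_dict(path, **kwargs):
    """Stub with the same signature as dist_util.load_state_dict (placeholder body)."""
    return path, kwargs


def try_call(*args, **kwargs):
    """Call the stub; return the TypeError message, or None if the call succeeds."""
    try:
        load_state_dict(*args, **kwargs)
    except TypeError as exc:
        return str(exc)
    return None


# Old call (main branch): one positional argument plus keywords.
print(try_call("model.pt", map_location="cpu"))  # None
# New call (diffuseq-v2): three positional arguments.
print(try_call("model.pt", False, "model", map_location="cpu"))
# load_state_dict() takes 1 positional argument but 3 were given
```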

@zzbuzzard
Author

The error:

Traceback (most recent call last):
  File "/mnt/c/Users/Z/Documents/GitHub/DiffuSeq/sample_seq2seq_dpmSolver.py", line 222, in <module>
    main()
  File "/mnt/c/Users/Z/Documents/GitHub/DiffuSeq/sample_seq2seq_dpmSolver.py", line 66, in main
    dist_util.load_state_dict(args.model_path, False, "amp", map_location="cpu")
TypeError: load_state_dict() takes 1 positional argument but 3 were given
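The traceback only surfaces the mismatch when the script runs. One way to check a call site against a function's signature without loading any checkpoint is the standard library's inspect.signature(...).bind(...), which raises TypeError when the arguments don't fit. In this sketch the stub copies only the quoted signature of load_state_dict, and call_is_compatible is a hypothetical helper, not part of DiffuSeq.

```python
import inspect


def load_state_dict(path, **kwargs):
    """Stub with the same signature as dist_util.load_state_dict (placeholder body)."""
    return path, kwargs


def call_is_compatible(fn, *args, **kwargs):
    """Return (True, None) if fn(*args, **kwargs) would bind, else (False, reason)."""
    try:
        inspect.signature(fn).bind(*args, **kwargs)
    except TypeError as exc:
        return False, str(exc)
    return True, None


# The call from line 66 of sample_seq2seq_dpmSolver.py, checked without executing it.
print(call_is_compatible(load_state_dict, "model.pt", False, "amp", map_location="cpu"))
# The main-branch call binds cleanly.
print(call_is_compatible(load_state_dict, "model.pt", map_location="cpu"))
```

bind() only checks argument structure; it never calls the function, so it is cheap to run against every call site.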

zzbuzzard changed the title from "TypeError: load_state_dict() takes 1 positional argument but 3 were given" to "diffuseq-v2: TypeError: load_state_dict() takes 1 positional argument but 3 were given" on Oct 23, 2023
@zzbuzzard
Author

Workaround for now, in case anyone else has this issue: I changed this line back to
dist_util.load_state_dict(args.model_path, map_location="cpu")
I also had to remove CUDA_VISIBLE_DEVICES=2 from my command for torch to recognise my GPU (probably because I only have one GPU), and now it's working. In fact, it's running far faster than the main branch. Very impressive :)
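A likely reason CUDA_VISIBLE_DEVICES=2 hid the GPU: CUDA numbers physical GPUs from 0, and NVIDIA's CUDA documentation states that if an index in CUDA_VISIBLE_DEVICES is invalid, only the devices listed before it are visible. On a single-GPU machine index 2 is invalid, so torch sees no GPU at all. The function below is a hypothetical sketch of that rule, not part of DiffuSeq or PyTorch. It handles only integer indices, not GPU UUIDs or MIG identifiers.

```python
from typing import List, Optional


def visible_cuda_devices(env_value: Optional[str], physical_count: int) -> List[int]:
    """Hypothetical model of how CUDA_VISIBLE_DEVICES maps to physical GPU indices.

    Assumptions: integer indices only; None means the variable is unset (all GPUs
    visible); parsing stops at the first invalid entry.
    """
    if env_value is None:
        return list(range(physical_count))
    visible: List[int] = []
    for token in env_value.split(","):
        try:
            index = int(token.strip())
        except ValueError:
            break  # non-integer entry: treated as invalid in this sketch
        if index < 0 or index >= physical_count:
            break  # invalid index: it and every later entry are ignored
        visible.append(index)
    return visible


# One physical GPU (index 0), as in the report above.
print(visible_cuda_devices("2", 1))   # [] (no GPU visible)
print(visible_cuda_devices("0", 1))   # [0]
print(visible_cuda_devices(None, 1))  # [0] (variable unset)
```

With an empty visible list, torch.cuda.is_available() returns False, which matches the symptom described. On a one-GPU machine, either CUDA_VISIBLE_DEVICES=0 or leaving the variable unset exposes the card.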

@summmeer
Collaborator

Thank you for pointing this out. I mixed them up during the version update. I will update this soon.
