optimizer load_state_dict() problem? #2830

Closed
JianyuZhan opened this issue Sep 22, 2017 · 26 comments · Fixed by #3658
Labels
awaiting response (this tag is deprecated), needs reproduction

Comments

@JianyuZhan

JianyuZhan commented Sep 22, 2017

Hi, I encountered this bug:


    optimizer.step()
    exp_avg.mul_(beta1).add_(1 - beta1, grad)

TypeError: add_ received an invalid combination of arguments - got (float, torch.cuda.FloatTensor), but expected one of:
 * (float value)
 * (torch.FloatTensor other)
 * (torch.SparseFloatTensor other)
 * (float value, torch.FloatTensor other)
      didn't match because some of the arguments have invalid types: (float, torch.cuda.FloatTensor)
 * (float value, torch.SparseFloatTensor other)
      didn't match because some of the arguments have invalid types: (float, torch.cuda.FloatTensor)

The code skeleton looks like this:

model = Model()
model.load_state_dict(checkpoint['model'])
model.cuda()

optimizer = optim.Adam()
optimizer.load_state_dict(checkpoint['optimizer'])

...
#  In train loop
for epoch in range(...):
  ...
  optimizer.step()
     -> BUG <-

It seems the loaded param_groups are torch.cuda.FloatTensor. I've tried a workaround that moves optimizer.param_groups to CPU, but it still hits the same bug.

@chenzhekl

Could you provide a full script to reproduce the problem?

@soumith soumith added the awaiting response (this tag is deprecated) and needs reproduction labels Oct 5, 2017
@soumith soumith added this to Crashes / Segfaults / Errors in Issue Categories Oct 5, 2017
@hefeicyp

Maybe you can try it like this:
optimizer.step()
exp_avg.mul_(beta1).add_(1 - beta1, grad.cpu())

@JianyuZhan
Author

JianyuZhan commented Oct 11, 2017

Sorry, I missed the reply email.

I am afraid I am unable to provide a reproducer right now. This is work I am doing for the OpenNMT-py project (https://github.com/OpenNMT/OpenNMT-py), trying to use lr_scheduler for the learning-rate updates. I encountered this problem when testing the case of resuming a suspended training run, so I factored out the code skeleton above.

I've tried several methods, including tricks like the one @hefeicyp suggests, but the error still happens.

Per my analysis, this is because the previous training was done on GPU, so when optimizer.state_dict() was saved, the stored state tensors were CUDA tensors. When resuming, load_state_dict() loads this CUDA state into the CPU-side optimizer (the model, an nn.Module, can be moved to GPU easily, but a torch.optim optimizer seems to lack this ability?), so this problem emerges.
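For reference, here is a minimal sketch (the checkpoint path and keys are assumptions based on the skeleton above) that prints the type of each saved optimizer state tensor, which makes the CUDA-vs-CPU mismatch visible before load_state_dict() is even called:

```python
# Sketch: inspect where the saved optimizer state lives before loading it.
# 'checkpoint.pt' and the 'optimizer' key are assumptions from the skeleton above.
import torch

checkpoint = torch.load('checkpoint.pt')
for param_id, param_state in checkpoint['optimizer']['state'].items():
    for name, value in param_state.items():
        if torch.is_tensor(value):
            print(param_id, name, value.type())  # e.g. torch.cuda.FloatTensor
```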

@dogancan

Try moving the optimizer state to GPU memory manually after loading it from the checkpoint.

optimizer = optim.Adam()
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

I agree that having an optimizer.cuda() method for this operation would be nice.

@JianyuZhan
Author

@dogancan, thanks. My work is suspended due to other problems; when I resume, I will try your method.

@apaszke
Contributor

apaszke commented Oct 12, 2017

I'm afraid @dogancan's solution won't work. It will make the error go away, but your optimizer will no longer be training the model. You should recreate optimizers after casting modules to a different type or device, and you can use load_state_dict to restore the state from a previous copy. This currently doesn't work, but we should fix it (by copying the data from the state dict instead of using the tensors directly - this allows for cross-device or cross-type updates).
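A rough sketch of what that copy-based restore could look like (a hypothetical helper, not the actual PyTorch implementation): cast each saved state value to the dtype and device of the parameter it belongs to.

```python
# Hypothetical sketch of the copy-based restore described above: cast each
# saved state tensor to the dtype/device of its parameter, which would allow
# cross-device and cross-type resumes. Not the actual PyTorch implementation.
import torch

def cast_saved_state(saved_state, param):
    restored = {}
    for name, value in saved_state.items():
        if torch.is_tensor(value):
            restored[name] = value.to(dtype=param.dtype, device=param.device)
        else:
            restored[name] = value  # plain Python numbers, e.g. Adam's step count
    return restored
```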

@JianyuZhan
Author

@apaszke, yep, your method is what I currently use, and it works. But I will wait for upstream to fix this problem. Thanks for your great work!

@dogancan

@apaszke Ah, my bad. I forgot to update the line where the optimizer is recreated. But otherwise, the following should do the job, right?

model = Model()
model.load_state_dict(checkpoint['model'])
model.cuda()
optimizer = optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

@apaszke
Contributor

apaszke commented Oct 12, 2017

ah, right. That should work 😊

@apaszke
Contributor

apaszke commented Oct 12, 2017

Except that you should use torch.is_tensor(v) instead of isinstance(v, torch.Tensor).

@stormraiser

I had a similar problem. When I save the optimizer state from a GPU other than GPU 0 and then load the state, it still loads everything onto GPU 0. Specifying map_location in torch.load() didn't work either. @dogancan's solution solves this though.
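For completeness, a minimal sketch of the two remapping approaches discussed here (the device indices, model, and checkpoint layout are just examples):

```python
# Sketch of the two workarounds discussed here; device ids and checkpoint
# layout are examples only.
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 10).cuda(1)            # example model living on GPU 1
optimizer = optim.Adam(model.parameters())

# 1) Remap storages at load time so nothing silently lands on GPU 0.
checkpoint = torch.load('checkpoint.pt',
                        map_location={'cuda:0': 'cuda:1', 'cuda:1': 'cuda:1'})

# 2) Load the state, then move any remaining state tensors explicitly.
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
    for k, v in state.items():
        if torch.is_tensor(v):
            state[k] = v.cuda(1)
```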

@codars

codars commented Nov 18, 2017

Hi guys, I have a very similar problem to the one in this thread; here's my code:

model = inceptionresnetv2(num_classes=config['tr_classes'])
model = torch.nn.DataParallel(model).cuda()
model.load_state_dict(checkpoint['md_state_dict'])
optimizer = torch.optim.Adam(model.parameters(), lr=config['tr_lr'], weight_decay=config['tr_weightdecay'])
optimizer.load_state_dict(checkpoint['md_optimizer'])
for state in optimizer.state.values():
    for k, v in state.items():
        if torch.is_tensor(v):
            state[k] = v.cuda()

And then once I resume, I get a KeyError on my optimizer:

---> 40         optimizer.step()
     41 
     42         config['am_batch_time'].update(time.time() - end)
~/.conda/envs/env_pytorch/lib/python3.5/site-packages/torch/optim/adam.py in step(self, closure)
     44                     continue
     45                 grad = p.grad.data
---> 46                 state = self.state[p]
     47 
     48                 # State initialization
KeyError: Parameter containing:
(0 ,0 ,.,.) = 
 -1.6336e-01 -5.6482e-01 -4.2228e-02
...
[torch.cuda.FloatTensor of size 32x3x3x3 (GPU 0)]

Do you guys know how to fix this issue? BTW, I'm using 8 GPUs; I'm guessing this issue might be because of that?

@rafaelvalle

@CodArs-van were you able to solve your issue with multiple GPUs?

@codars

codars commented Feb 5, 2018

@rafaelvalle Thanks for asking. Yeah, I was able to. It turns out the issue was that I used an early version of PyTorch; after I updated it, everything works like a charm!

@lzcn

lzcn commented Mar 18, 2018

Just a comment: this problem is caused by

    def load_state_dict(self, state_dict):
        ...
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
       ...  

deepcopy moves all the state tensors onto GPU 0, so moving the optimizer's state to the intended GPU afterwards fixes this problem.
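One way to avoid hard-coding the "specific GPU" (a sketch that assumes the optimizer has already been recreated and load_state_dict() called, as in the snippets above) is to use each parameter's own device, since optimizer.state is keyed by the parameters themselves:

```python
# Sketch: move each state tensor onto the device of the parameter it belongs
# to, instead of hard-coding a GPU index. optimizer.state is keyed by the
# parameters themselves.
import torch

for param, state in optimizer.state.items():
    for k, v in state.items():
        if torch.is_tensor(v):
            state[k] = v.to(param.device)
```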

gwenniger added a commit to gwenniger/OpenNMT-py that referenced this issue Mar 20, 2018
pytorch/pytorch#2830
1. Recreating the optimizer, using the model parameters
2. Loading the optimizer state saved from the checkpoint to the
optimizer.

	modified:   onmt/Optim.py
	modified:   train.py
gwenniger added a commit to gwenniger/OpenNMT-py that referenced this issue Mar 20, 2018
The fix turned out not to be correct. It is still necessary to
(re-)create the optimizer at all times, using the state information.
But when loading an optimizer from a checkpoint, in a second
stage the saved optimizer state dictionary must be used with
the re-created optimizer to set the optimizer.state field.
In the case of Adam, for example, this is what restores the parameter
history from the previous epoch, which was previously lost
because the second step was not done.

As one additional last thing for this fix to work,
if the GPU is used, the relevant restored
optimizer state variables must be converted to their CUDA
counterparts. Note that this fix was inspired by a fix for
a similar problem, discussed at
pytorch/pytorch#2830

	modified:   train.py
@chrisliu54

Hi @lzcn, how do you know the specific GPU location of different tensors in advance?

@sebastienwood

Would a feature where all torch.save() calls always make use of an automatically generated CPU version be feasible?
And then at resume time torch.load() would use the "current" device (or any better strategy).
At the moment it seems we need a lot of boilerplate code to ensure saving and loading are consistent across devices for models/optimizers/schedulers/etc.
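In the meantime, such a device-agnostic checkpoint can be approximated by hand; below is a sketch (not a built-in torch.save option, and `model`/`optimizer` are assumed to exist) that copies the optimizer state and moves its tensors to CPU before saving:

```python
# Sketch: save a device-agnostic checkpoint by moving the optimizer state
# tensors to CPU before torch.save. Not a built-in torch.save option.
import copy
import torch

def optimizer_state_to_cpu(optimizer):
    state_dict = copy.deepcopy(optimizer.state_dict())
    for state in state_dict['state'].values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cpu()
    return state_dict

torch.save({'model': model.state_dict(),
            'optimizer': optimizer_state_to_cpu(optimizer)},
           'checkpoint.pt')
```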

ishaanb92 added a commit to ishaanb92/Probabalistic-U-Net that referenced this issue Apr 12, 2019
@ran337287

I have met a similar problem. Following @dogancan's solution, I recreated the Adam optimizer (without any optimizer.cuda()) after reloading the model, calling model.cuda() and DataParallel(model).

@jiangzhonglian

Thanks, it works!

@apaszke Ah, my bad. I forgot to update the line where the optimizer is recreated. But otherwise, the following should do the job, right?

model = Model()
model.load_state_dict(checkpoint['model'])
model.cuda()
optimizer = optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

jiangzhonglian added a commit to jiangzhonglian/tutorials that referenced this issue Jul 25, 2019
If you don't add this code, you will get an error when you resume from 4000_checkpoint.tar and continue the training updates with:

```
encoder_optimizer.step()  
```


Error message:

```
exp_avg.mul_(beta1).add_(1 - beta1, grad)
RuntimeError: expected backend CPU and dtype Float but got backend CUDA and dtype Float
```


Fix it: pytorch/pytorch#2830

```
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)  # missing line from original code
        labels = labels.to(device)  # missing line from original code
        images = images.reshape(-1, 28 * 28)
        out = model(images)
        _, predicted = torch.max(out.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
```
jiangzhonglian added a commit to jiangzhonglian/tutorials that referenced this issue Jul 25, 2019
If you don't add this code, you will get an error when you resume from 4000_checkpoint.tar and continue the training updates with:

```
encoder_optimizer.step()  
```

Error message:

```
exp_avg.mul_(beta1).add_(1 - beta1, grad)
RuntimeError: expected backend CPU and dtype Float but got backend CUDA and dtype Float
```

Fix it: pytorch/pytorch#2830

```
model = Model()
model.load_state_dict(checkpoint['model'])
model.cuda()
optimizer = optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()
```
wanchaol pushed a commit to wanchaol/tutorials that referenced this issue Aug 7, 2019
wanchaol pushed a commit to wanchaol/tutorials that referenced this issue Aug 7, 2019
wanchaol pushed a commit to wanchaol/tutorials that referenced this issue Aug 7, 2019
wanchaol pushed a commit to wanchaol/tutorials that referenced this issue Aug 7, 2019
@menghuu

menghuu commented Aug 25, 2019

@apaszke
Hi, as you say, every time we move the model to another device we should rebuild the optimizer. But if we move the model to another device and then move it back, should we rebuild the optimizer again?
Here is some example code:

model = Model()
model.cuda()
optimizer = optim.Adam(model.parameters())

for d, gt in trn_dataloader:
    # train
    ... 
    optimizer.step()
    model.cpu() # move to cpu
    # eval or do other things
    ...
    model.cuda()  # but finally, move back

Does the optimizer run as expected?

Also, if we call model.to(model.device), should we rebuild the optimizer?

avik-pal added a commit to fidler-lab/social-driving that referenced this issue Apr 9, 2020
After loading an optimizer originally saved on GPU, there seems to be
a device mismatch issue. Solution has been adapted from
[here](pytorch/pytorch#2830 (comment))
@mistermoutan

@apaszke Ah, my bad. I forgot to update the line where the optimizer is recreated. But otherwise, the following should do the job, right?

model = Model()
model.load_state_dict(checkpoint['model'])
model.cuda()
optimizer = optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

@apaszke Is there a problem if you switch the order to something like this?

model = Model()
model.to('cuda')
optimizer = optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()
model.load_state_dict(checkpoint['model'])

Meaning, moving the model to 'cuda' but only loading its state dict from the checkpoint after loading the optimizer's state dict first?

@pingguokiller

To sum up: the optimizer's state gets loaded onto the same device as the model. You must move the model to the GPU first and then load the optimizer's state, so that both the model and the optimizer's state end up on the GPU.

@kyteinsky

Instead of moving the optimizer to CUDA after loading it on the CPU, you could load the checkpoint directly onto CUDA:

model.to(device)

ckpt = torch.load(<model_path>, map_location=device)

model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])

del ckpt

@gzerveas

gzerveas commented Mar 3, 2021

Instead of moving the optimizer to CUDA after loading it on the CPU, you could load the checkpoint directly onto CUDA:

model.to(device)

ckpt = torch.load(<model_path>, map_location=device)

model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])

del ckpt

I've independently rediscovered that this works :) I should read to the end of the thread next time 😅

@chg0901

chg0901 commented Jun 10, 2021

I find my code still has this problem. I tried my best to arrange the modules as in the examples shown above.

Can anyone give me some hints?

  print('loading checkpoint {}'.format(cfg.TRAIN.RESUME_PATH))
  checkpoint = torch.load(cfg.TRAIN.RESUME_PATH, map_location={'cuda:0': 'cuda:1'})
  cfg.TRAIN.BEGIN_EPOCH = checkpoint['epoch'] + 1
  model = YOWO(cfg)
  model = nn.DataParallel(model, device_ids=None)  # in multi-gpu case

  model.load_state_dict(checkpoint['state_dict'])
  model = model.cuda()

  pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  logging('Total number of trainable parameters: {}'.format(pytorch_total_params))

  parameters = get_fine_tuning_parameters(model, cfg)

  optimizer = torch.optim.Adam(parameters, lr=cfg.TRAIN.LEARNING_RATE, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
  best_score = 0  # initialize best score

  optimizer.load_state_dict(checkpoint['optimizer'])
  for state in optimizer.state.values():
      for k, v in state.items():
          if torch.is_tensor(v):  # isinstance(v, torch.Tensor):
              state[k] = v.cuda()

I also tried this

    model = YOWO(cfg)
    model = model.cuda()
    model = nn.DataParallel(model, device_ids=None)  # in multi-gpu case
    checkpoint = torch.load(cfg.TRAIN.RESUME_PATH, map_location={'cuda:0': 'cuda:1'})
    cfg.TRAIN.BEGIN_EPOCH = checkpoint['epoch'] + 1
    print(checkpoint.keys())
    best_score = checkpoint['fscore']

    model.load_state_dict(checkpoint['state_dict'])

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging('Total number of trainable parameters: {}'.format(pytorch_total_params))

    parameters = get_fine_tuning_parameters(model, cfg)

    optimizer = torch.optim.Adam(parameters, lr=cfg.TRAIN.LEARNING_RATE, weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    optimizer.load_state_dict(checkpoint['optimizer'])

rodrigo-techera pushed a commit to Experience-Monks/tutorials that referenced this issue Nov 29, 2021
rodrigo-techera pushed a commit to Experience-Monks/tutorials that referenced this issue Nov 29, 2021
dsmoljan pushed a commit to dsmoljan/Lumen-Data-Science-2023 that referenced this issue Mar 28, 2023
@varungupta31

For me, map_location=device doesn't solve the issue, but the solution by @dogancan works. Any idea why?

I create the optimizer, load the state (with map_location to CUDA), and pass it to the train loop, where the model is pushed to the device (though I assume that's not needed if the loaded model is already pushed to CUDA by map_location).

Specifically, I would get stuck at the optimizer.step() call, which throws the device mismatch error.
