
Resuming Training gives CheckpointMismatchError #1300

Open
arushi-08 opened this issue Jul 29, 2023 · 10 comments
Labels
🛑 Checkpoints (issues related to checkpoints and resuming training)

Comments

arushi-08 commented Jul 29, 2023

I want to resume training my model from a checkpoint file (*.pt), but I am facing a pykeen.training.training_loop.CheckpointMismatchError.

Full stack trace:

INFO:pykeen.training.training_loop:=> loading checkpoint '/afs/ars539/.data/pykeen/checkpoints/complex-checkpoint.pt'
Traceback (most recent call last):
  File "/afs/ars539/biology_project/resume_training.py", line 32, in <module>
    result = pipeline(
  File "/afs/ars539/miniconda3/envs/biology-venv/lib/python3.10/site-packages/pykeen/pipeline/api.py", line 1546, in pipeline
    stopper_instance, configuration, losses, train_seconds = _handle_training(
  File "/afs/ars539/miniconda3/envs/biology-venv/lib/python3.10/site-packages/pykeen/pipeline/api.py", line 1190, in _handle_training
    losses = training_loop_instance.train(
  File "/afs/ars539/miniconda3/envs/biology-venv/lib/python3.10/site-packages/pykeen/training/training_loop.py", line 337, in train
    best_epoch_model_file_path, last_best_epoch = self._load_state(
  File "/afs/ars539/miniconda3/envs/biology-venv/lib/python3.10/site-packages/pykeen/training/training_loop.py", line 1185, in _load_state
    raise CheckpointMismatchError(
pykeen.training.training_loop.CheckpointMismatchError: The checkpoint file '/afs/ars539/.data/pykeen/checkpoints/complex-checkpoint.pt' that was provided already exists, but seems to be from a different training loop setup.

I have realised that the issue is a checksum mismatch, i.e. the checkpoint file was created with a different configuration:

logger.info(f"=> loading checkpoint '{path}'")
checkpoint = torch.load(path)
if checkpoint["checksum"] != self.checksum:
    raise CheckpointMismatchError(
        f"The checkpoint file '{path}' that was provided already exists, but seems to be "
        f"from a different training loop setup.",
    )

However, I am not sure how to load the same configuration as the one recorded in the checkpoint file. I believe this is what the "Word of Caution and Possible Errors" section of the documentation (https://pykeen.readthedocs.io/en/stable/tutorial/checkpoints.html#word-of-caution-and-possible-errors) refers to, but it is still unclear to me what the next steps are.

How do we resume training from the previous checkpoint?
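
For reference, my understanding from the tutorial is that resuming works by re-running the same pipeline call with the same checkpoint_name; a minimal sketch (the dataset, model, and numbers below are placeholders, not my actual configuration):

from pykeen.pipeline import pipeline

# Minimal resume sketch: re-running the *same* pipeline call with the same
# checkpoint_name should pick up the existing checkpoint file automatically.
# Dataset, model, and training_kwargs values are placeholders.
result = pipeline(
    dataset="nations",
    model="ComplEx",
    training_kwargs=dict(
        num_epochs=100,
        checkpoint_name="complex-checkpoint.pt",
        checkpoint_frequency=10,
    ),
)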

mberr (Member) commented Aug 8, 2023

Hi @arushi-08,

the checkpoint files are "just" normal torch archives, i.e., you can load them via torch.load, as done in the code snippet you linked (or, more precisely, one line above it; I have updated your text above to include it).

The checksum is calculated from the string representations of the model and the optimizer, cf. here:

@property
def checksum(self) -> str:  # noqa: D401
    """The checksum of the model and optimizer the training loop was configured with."""
    h = md5()  # noqa: S303
    h.update(str(self.model).encode("utf-8"))
    h.update(str(self.optimizer).encode("utf-8"))
    return h.hexdigest()

I would suggest that you load the checkpoint file via torch.load and carefully compare it with your configuration. If you still think that everything is sane, you can manually override the checkpoint file's checksum and write it to a new checkpoint file:

d = torch.load(path)
d["checksum"] = checksum
torch.save(d, new_path)
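
For example, a rough sketch (model and optimizer are assumed to be the instances you configured for the resumed run, and the paths are placeholders):

from hashlib import md5

import torch

# Recompute the checksum the way the `checksum` property above does, using the
# model and optimizer you intend to resume with (assumed to be defined already),
# then patch it into a copy of the checkpoint. Paths are placeholders.
h = md5()
h.update(str(model).encode("utf-8"))
h.update(str(optimizer).encode("utf-8"))
new_checksum = h.hexdigest()

checkpoint = torch.load("complex-checkpoint.pt")
print(sorted(checkpoint.keys()))          # inspect what the checkpoint actually stores
checkpoint["checksum"] = new_checksum     # only after verifying that the configurations really match
torch.save(checkpoint, "complex-checkpoint-patched.pt")  # keep the original file untouched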

@pablo-sanchez-sony

Hi,

I was having the same error. I believe the problem arises when using a learning rate scheduler object from PyTorch: as can be seen in the scheduler's constructor, whenever last_epoch=-1 the optimizer's initial_lr is set.

https://github.com/pytorch/pytorch/blob/a5d841ef01e615e2a654fb12cf0cd08697d12ccf/torch/optim/lr_scheduler.py#L38

Basically, this makes str(self.optimizer).encode("utf-8") differ, because at the time of the comparison neither the optimizer nor the scheduler has been reloaded yet.

I believe the issue can be solved by moving the checksum comparison to the end of the method.
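
A minimal illustration of the effect (a toy model, not PyKEEN code):

import torch

# Attaching a scheduler with the default last_epoch=-1 adds `initial_lr` to the
# optimizer's param groups, which changes str(optimizer) and therefore the
# checksum that is computed from it.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

before = str(optimizer)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
after = str(optimizer)

print(before == after)        # False
print("initial_lr" in after)  # True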

mberr (Member) commented Aug 24, 2023

@pablo-sanchez-sony, would you mind opening a PR with the changes you suggest?

@pablo-sanchez-sony

Sure!

arushi-08 (Author) commented Sep 23, 2023

I am facing this checkpoint mismatch error in the same training run for the RotatE KGE model. The following log messages show that rotate-checkpoint.pt is created at an early epoch, and that after 30 epochs the loop tries to read from this checkpoint and raises the error:

INFO:pykeen.training.training_loop:=> no checkpoint found at '/afs/ars539/.data/pykeen/checkpoints/rotate-checkpoint.pt'. Creating a new file.
Training epochs on cuda:0:   2%|▏         | 9/500 [07:47<6:22:12, 46.71s/epoch, loss=0.123, prev_loss=0.123]INFO:pykeen.evaluation.evaluator:Starting batch_size search for evaluation now...                                                                              
Saved model weights to /afs/ars539/.data/pykeen/checkpoints/best-model-weights-0aa7f269-26c4-4e84-8d47-a55fc43911c6.pt
INFO:pykeen.training.training_loop:=> Saved checkpoint after having finished epoch 10.
INFO:pykeen.training.training_loop:=> Saved checkpoint after having finished epoch 10.
INFO:pykeen.training.training_loop:=> Saved checkpoint after having finished epoch 20.
INFO:pykeen.stoppers.early_stopping:Stopping early at epoch 30. The best result 0.14622531740871292 occurred at epoch 10.
INFO:pykeen.stoppers.early_stopping:Re-loading weights from best epoch from /afs/ars539/.data/pykeen/checkpoints/best-model-weights-0aa7f269-26c4-4e84-8d47-a55fc43911c6.pt
INFO:pykeen.training.training_loop:=> Saved checkpoint after having finished epoch 30.
INFO:pykeen.evaluation.evaluator:Evaluation took 547.88s seconds
Best is trial 0 with value: 0.06680432707071304.
INFO:pykeen.pipeline.api:loaded random seed 42 from checkpoint.
INFO:pykeen.pipeline.api:Using device: None
INFO:pykeen.stoppers.early_stopping:Inferred checkpoint path for best model weights: /afs/ars539/.data/pykeen/checkpoints/best-model-weights-ea7a231a-d250-422a-a747-49f6b3a70e2f.pt
INFO:pykeen.training.training_loop:=> loading checkpoint '/afs/ars539/.data/pykeen/checkpoints/rotate-checkpoint.pt'
[W 2023-09-22 18:37:16,297] Trial 1 failed with parameters: {'model.embedding_dim': 200, 'loss.margin': 1.0271124464019343, 'optimizer.lr': 0.026733931043720773, 'negative_sampler.num_negs_per_pos': 3, 'training.batch_size': 64} because of the following error: CheckpointMismatchError("The checkpoint file '/afs/ars539/.data/pykeen/checkpoints/rotate-checkpoint.pt' that was provided already exists, but seems to be from a different training loop setup.").

My training script is:

result = hpo_pipeline(
    study_name='rotate_hpo',
    training=training,
    testing=testing,
    validation=validation,
    pruner="MedianPruner",
    sampler="tpe",
    model='RotatE',
    model_kwargs={
        "random_seed": 42,
    },
    model_kwargs_ranges=dict(
        embedding_dim=dict(type=int, low=100, high=300, q=100),
    ),
    negative_sampler_kwargs_ranges=dict(
        num_negs_per_pos=dict(type=int, low=1, high=100),
    ),
    stopper='early',
    n_trials=30,
    training_loop="sLCWA",
    training_kwargs=dict(
        num_epochs=500,
        checkpoint_name='rotate-checkpoint.pt',
        checkpoint_frequency=10,
    ),
    evaluator_kwargs={"filtered": True, "batch_size": 128},
)

Kindly suggest how to resolve this; I am not explicitly trying to resume training, rather hpo_pipeline itself is reloading from the checkpoint.

mberr (Member) commented Sep 23, 2023

When setting a checkpoint name

checkpoint_name='rotate-checkpoint.pt',

it seems to be used for all trials, so the second trial thinks it is a continuation of the first one, but the model hyperparameters do not match.

mberr (Member) commented Sep 23, 2023

Here is a smaller script to reproduce the error:

from pykeen.hpo import hpo_pipeline

result = hpo_pipeline(
    study_name="rotate_hpo",
    dataset="nations",
    model="RotatE",
    model_kwargs_ranges=dict(
        embedding_dim=dict(type=int, low=8, high=24, q=8),
    ),
    stopper="early",
    n_trials=2,
    training_loop="sLCWA",
    training_kwargs=dict(
        num_epochs=2,
        checkpoint_name="rotate-checkpoint.pt",
        checkpoint_frequency=1,
    ),
)

mberr (Member) commented Sep 23, 2023

@arushi-08, what is your use case for providing a checkpoint name? Do you want to save each trial's model? If yes, we have an explicit save_model_directory for that, which will take care of creating one sub-directory per trial.
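
A rough sketch along those lines, re-using the reproduction script above (the directory name is a placeholder):

from pykeen.hpo import hpo_pipeline

# Rough sketch: drop checkpoint_name and pass save_model_directory so that each
# trial's trained model is stored in its own sub-directory. The directory name
# is a placeholder.
result = hpo_pipeline(
    study_name="rotate_hpo",
    dataset="nations",
    model="RotatE",
    model_kwargs_ranges=dict(
        embedding_dim=dict(type=int, low=8, high=24, q=8),
    ),
    stopper="early",
    n_trials=2,
    training_loop="sLCWA",
    training_kwargs=dict(num_epochs=2),
    save_model_directory="hpo_models",
)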

mberr (Member) commented Sep 23, 2023

I have opened a small PR (#1324) to fail fast on the first trial with an error message about how to fix it 🙂

mberr (Member) commented Sep 23, 2023

@pablo-sanchez-sony , would this resolve your issue, too?

mberr added a commit that referenced this issue Sep 23, 2023
…uration (#1324)

When providing a `checkpoint_name` only the second trial will fail,
since it finds an existing checkpoint and tries to continue training,
but the model configuration has likely changed.

This PR checks the configuration and directly raises an error with a
descriptive error message.

cf. #1300
cthoyt added the 🛑 Checkpoints (issues related to checkpoints and resuming training) label on Sep 24, 2023