How do you load a custom checkpoint? #53

Open
bottiger1 opened this issue May 7, 2022 · 2 comments

bottiger1 commented May 7, 2022

Hello, I want to train the network on my own samples, but I'm finding it quite difficult.

Right now I have edited Toxic_comment_classification_BERT.json to point to my own training and test CSVs. Then I have to edit train.py to manually save the model object inside ToxicClassifier at the end of training.

    torch.save(model.model, 'custom.pt')

Then I have to load the file manually, instantiate a normal Detoxify instance, and replace its internal model object with the saved version to get it to work.

    import torch
    import detoxify

    saved = torch.load('custom.pt')
    d = detoxify.Detoxify('original')
    d.model = saved  # swap in the custom-trained model

If I try to load a checkpoint generated at "saved\Jigsaw_BERT\lightning_logs\version_x\checkpoints\epoch=3-step=76.ckpt" with detoxify, whether I instantiate Detoxify with the checkpoint parameter or point it at a file generated by torch.save(model), it always says

Checkpoint needs to contain the config it was trained with as well as the state dict

What's the proper way of saving the checkpoint so it has the config and state dict with it? Or is my workaround the best way to use custom training data?

laurahanu (Collaborator) commented

Hello!
To load the custom pretrained model you would need to save both the config and the state dict in the checkpoint, e.g.
torch.save({"config": custom_config, "state_dict": custom_model_state_dict}, "checkpoint.pt"). Make sure you only save the state_dict and not the whole PyTorch Lightning checkpoint.
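A quick way to sanity-check a saved file before handing it to Detoxify (a minimal sketch; "checkpoint.pt" is a placeholder path for your own saved file):

    import torch

    # The file must contain a dict with both a "config" and a
    # "state_dict" entry, which is what Detoxify's loader looks for.
    ckpt = torch.load("checkpoint.pt", map_location="cpu")
    assert "config" in ckpt and "state_dict" in ckpt

    # State dict keys should be bare Hugging Face names such as
    # "bert.encoder.layer.0.attention.self.query.weight", with no
    # "model." prefix and no Lightning metadata.
    print(list(ckpt["state_dict"])[:5])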


auadams commented Feb 16, 2023

I just went down this path of pain and suffering, and I want to post this for anyone else who wants to train this model on their own data. It was a really painful experience that taught me how detoxify works at the tensor level lol.

    trainer.fit(model, data_loader, valid_data_loader)
    torch.save({"config": model.config, "state_dict": model.state_dict()}, "model.pt")

My original idea after seeing this post was just to put this save line after the trainer.fit. The issue I ran into is that you can't use model.state_dict() as-is, because every key in the state dictionary is prefixed with model., e.g. model.bert.encoder.layer.8.output.LayerNorm.weight needs to be converted to bert.encoder.layer.8.output.LayerNorm.weight. After translating every element in the state dictionary, I could successfully use the checkpoint method in detoxify.

you need to add this to the bottom of train.py

    trainer.fit(model, data_loader, valid_data_loader)

    # ToxicClassifier stores the Hugging Face model as self.model, so every
    # key in its state dict is prefixed with "model." (e.g.
    # "model.bert.encoder.layer.8.output.LayerNorm.weight"). Detoxify
    # expects the bare names, so strip the prefix while copying.
    full_state_dict = model.state_dict()
    statedict = {}
    for param_tensor in full_state_dict:
        if "model.bert." in param_tensor:
            newname = param_tensor.replace("model.", "")
            statedict[newname] = full_state_dict[param_tensor]
    # The classifier head sits outside the "model.bert." namespace, so copy
    # its weights over explicitly.
    statedict["classifier.weight"] = full_state_dict["model.classifier.weight"]
    statedict["classifier.bias"] = full_state_dict["model.classifier.bias"]
    torch.save({"config": model.config, "state_dict": statedict}, "model.pt")

Then you can just load your model using detoxify like below:

    ai = Detoxify(checkpoint="model.pt")

Also, for other models like RoBERTa or ALBERT, you just need to replace the "bert" in the if statement above with the matching prefix.
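
For completeness, a short usage sketch (predict is Detoxify's standard inference method; "model.pt" and the example strings are placeholders):

    from detoxify import Detoxify

    # Load the custom checkpoint saved by the train.py snippet above
    # and score a couple of comments.
    ai = Detoxify(checkpoint="model.pt")
    results = ai.predict(["an example comment", "another example comment"])
    print(results)  # dict mapping each label to a list of scores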
