This repository provides functionality for Stochastic Weight Averaging-Gaussian (SWAG) training of Transformer models. The implementation builds on two libraries:
- `transformers` (maintained by Hugging Face)
- `swa_gaussian` (maintained by the Language Technology Research Group at the University of Helsinki)
The goal is an implementation that works directly with the convenience tools in the `transformers` library (e.g. `Pipeline` and `Trainer`) as well as with `evaluator` from the related `evaluate` library.

See also the examples.
BERT model, sequence classification task:

- Load a pretrained BERT model: `base_model = AutoModelForSequenceClassification.from_pretrained(name_or_path)`
- Initialize the SWAG model: `swag_model = SwagBertForSequenceClassification.from_base(base_model)`
- Initialize the SWAG callback object: `swag_callback = SwagUpdateCallback(swag_model)`
- Initialize `transformers.Trainer` with `base_model` as the model and `swag_callback` in the callbacks.
- Train the model (`trainer.train()`).
- Store the complete model: `swag_model.save_pretrained(path)`
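The steps above can be sketched end to end as follows. This is a minimal sketch, not the repository's exact example: the `swag_transformers.*` import paths, the checkpoint name, `num_labels`, and the `train_dataset` are assumptions for illustration and should be adapted to your setup.

```python
from transformers import (
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
# Import paths below are assumptions; adjust them to this repository's layout.
from swag_transformers.swag_bert import SwagBertForSequenceClassification
from swag_transformers.trainer_utils import SwagUpdateCallback

name_or_path = "bert-base-uncased"  # any BERT checkpoint

# 1. Load the pretrained base model.
base_model = AutoModelForSequenceClassification.from_pretrained(
    name_or_path, num_labels=2
)

# 2. Wrap it in a SWAG model that will collect weight statistics.
swag_model = SwagBertForSequenceClassification.from_base(base_model)

# 3. The callback updates the SWAG parameters during training.
swag_callback = SwagUpdateCallback(swag_model)

# 4. Train the *base* model; the callback keeps swag_model in sync.
trainer = Trainer(
    model=base_model,
    args=TrainingArguments(output_dir="out"),
    train_dataset=train_dataset,  # assumed to be prepared elsewhere
    callbacks=[swag_callback],
)
trainer.train()

# 5. Store the complete SWAG model.
swag_model.save_pretrained("swag_model")
```

Note that the `Trainer` is given the plain `base_model`; the SWAG wrapper only accumulates the weight statistics via the callback and is saved separately at the end.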
- BERT
  - `BertPreTrainedModel` -> `SwagBertPreTrainedModel`
  - `BertModel` -> `SwagBertModel`
  - `BertForSequenceClassification` -> `SwagBertForSequenceClassification`
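A stored SWAG model can then be used like any other classifier. The sketch below assumes the standard Hugging Face loading convention (`from_pretrained`), which the use of `save_pretrained` above suggests; the import path is again an assumption.

```python
import torch
from transformers import AutoTokenizer
# Import path is an assumption; adjust it to this repository's layout.
from swag_transformers.swag_bert import SwagBertForSequenceClassification

# Reload the stored SWAG model (assumed Hugging Face convention,
# mirroring save_pretrained above).
swag_model = SwagBertForSequenceClassification.from_pretrained("swag_model")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("An example sentence.", return_tensors="pt")
with torch.no_grad():
    logits = swag_model(**inputs).logits
```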