- `fairseq-train` is a Python entry script; run `which fairseq-train` to find its location.
- Once invoked, it calls `def cli_main` in https://github.com/pytorch/fairseq/blob/master/fairseq_cli/train.py (see the sketch just below).
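For context, `fairseq-train` is just the setuptools console script that points at `cli_main`; a minimal sketch of what the auto-generated wrapper roughly looks like (the exact contents depend on your pip/setuptools version):

```python
#!/usr/bin/env python
# Rough sketch of the auto-generated `fairseq-train` console script;
# the real file is produced by pip/setuptools at install time.
import sys

from fairseq_cli.train import cli_main

if __name__ == "__main__":
    sys.exit(cli_main())
```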
- `cli_main` parses all the CLI args, which are then passed to `distributed_utils.call_main` in https://github.com/pytorch/fairseq/blob/master/fairseq/distributed_utils.py
- `call_main` checks some of the args and then calls `distributed_main`, which in turn calls the `def main` function in `fairseq_cli/train.py` -- that is the main training routine. A rough sketch of this dispatch is shown below.
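The dispatch inside `distributed_utils.call_main` looks roughly like this (a simplified paraphrase, not the verbatim implementation; multi-node and TPU branches, error handling, and version-specific argument plumbing are omitted):

```python
# Simplified paraphrase of fairseq.distributed_utils.call_main.
# `distributed_main` is defined in the same module; it sets up the
# torch.distributed process group and then calls `main`.
import torch


def call_main(args, main, **kwargs):
    if args.distributed_init_method is not None:
        # Multi-GPU path: spawn one worker process per GPU.
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(main, args, kwargs),
            nprocs=args.distributed_world_size,
        )
    else:
        # Single-GPU / CPU path: call the training routine directly.
        main(args, **kwargs)
```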
- In `train.py`, the `main` function uses the args to create the `Task` and `Model`, which are then used to create a `Trainer` object via https://github.com/pytorch/fairseq/blob/master/fairseq/trainer.py#L39
- `Trainer` sets up the devices, params, etc. required for parallel training and returns a `Trainer` object to the `main` function in `train.py`. A sketch of this setup is shown below.
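The setup inside `main` looks roughly like this (older-style `args` shown; newer fairseq versions use a structured `cfg` instead, so treat this as a hedged sketch rather than the exact code):

```python
# Hedged sketch of the setup in fairseq_cli/train.py:main().
from fairseq import tasks
from fairseq.trainer import Trainer


def main(args):
    task = tasks.setup_task(args)            # e.g. a TranslationTask for BART summarization
    model = task.build_model(args)           # e.g. BARTModel (bart.base / bart.large)
    criterion = task.build_criterion(args)   # e.g. label-smoothed cross entropy
    trainer = Trainer(args, task, model, criterion)
    # ... load the latest checkpoint, build the epoch iterator, enter the while loop ...
```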
- `train.py` now sets up a few more things to start training -- loading from the last checkpoint, epoch iterators, meters, etc.
- The main training loop is a `while` loop that checks `max_epochs` and the learning rate `lr`. The default `max_epochs` is infinity.
- Inside the loop, each step is one epoch. Every epoch (wrapped inside the `epoch_itr` iterator) is sent to `def train` in the same file. A paraphrase of the loop follows.
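The loop at https://github.com/pytorch/fairseq/blob/master/fairseq_cli/train.py#L117 roughly does the following (simplified; checkpoint saving and validation details are elided, and argument names vary slightly across versions):

```python
# Paraphrase of the epoch loop in fairseq_cli/train.py:main(); not verbatim.
while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
    # Train for one epoch; `train` is `def train(...)` in the same file.
    valid_losses, should_stop = train(args, trainer, task, epoch_itr)
    if should_stop:
        break
    # Adjust the learning rate based on validation loss and move to the next epoch.
    lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
    epoch_itr = trainer.get_train_iterator(epoch_itr.next_epoch_idx)
```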
- In `def train`, you iterate over the samples in the epoch and call `trainer.train_step(samples)` -- this method runs forward + backward + parameter update (sketched below).
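A simplified paraphrase of `def train(...)` in `fairseq_cli/train.py` (progress-bar handling, mid-epoch validation, and checkpointing are omitted):

```python
# Simplified paraphrase of def train(...) in fairseq_cli/train.py; not verbatim.
def train(args, trainer, task, epoch_itr):
    itr = epoch_itr.next_epoch_itr(shuffle=True)   # batched sample iterator for this epoch
    for samples in itr:
        log_output = trainer.train_step(samples)   # forward + backward + parameter update
    # ... log epoch-level stats, reset meters, validate and save checkpoints ...
    valid_losses, should_stop = [None], False      # placeholder for the real stopping logic
    return valid_losses, should_stop
```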
- `trainer.train_step` internally calls `self.task.train_step(sample, model, criterion, optimizer, ...)`. A simplified paraphrase of the trainer side follows.
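A heavily simplified paraphrase of `Trainer.train_step` (the real method in `fairseq/trainer.py` also handles gradient all-reduce across workers, gradient clipping, OOM recovery, and logging):

```python
# Heavily simplified paraphrase of Trainer.train_step in fairseq/trainer.py; not verbatim.
def train_step(self, samples):
    self.model.train()
    self.criterion.train()
    self.zero_grad()
    for sample in samples:   # `samples` may hold several sub-batches (gradient accumulation)
        loss, sample_size, logging_output = self.task.train_step(
            sample,
            self.model,
            self.criterion,
            self.optimizer,
            self.get_num_updates(),
        )
    self.optimizer.step()    # parameter update
    self.set_num_updates(self.get_num_updates() + 1)
    return logging_output
```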
- Here, the `task` is a `FairseqTask` (or one of the specific Translation, Classification, LM tasks). Its `train_step` function uses native PyTorch to run the forward and backward passes:
```python
def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
    """Do forward and backward, and return the loss as computed by *criterion*
    for the given *model* and *sample*."""
    model.train()  # This changes the model from `eval` mode to `train` mode!!!
    model.set_num_updates(update_num)
    with torch.autograd.profiler.record_function("forward"):
        loss, sample_size, logging_output = criterion(model, sample)
    if ignore_grad:
        loss *= 0
    with torch.autograd.profiler.record_function("backward"):
        optimizer.backward(loss)
    return loss, sample_size, logging_output
```
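For reference, `optimizer` here is fairseq's optimizer wrapper; its `backward` essentially defers to PyTorch autograd (a hedged sketch; the FP16 optimizer wrappers override this to add loss scaling):

```python
# Hedged sketch of FairseqOptimizer.backward; FP16 wrappers add loss scaling.
def backward(self, loss):
    loss.backward()
```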
- The sample loss is returned at the end of `self.task.train_step` (back to `trainer.train_step`).
- The `train` function in `train.py` finishes all samples in the epoch. After completing the epoch, it logs some stats, resets some meters, and returns the epoch losses together with `should_stop` to the `while` training loop in `def main` in the same file.
- If `should_stop` is true, the `while` loop breaks; otherwise it continues until training completes. `def main` ends by logging a `done training` message.
- The BART paper -- https://arxiv.org/pdf/1910.13461.pdf
- The main `while` loop invoked by `fairseq-train` -- https://github.com/pytorch/fairseq/blob/master/fairseq_cli/train.py#L117
- The `Trainer` class, which implements `train_step` for a list of samples in an epoch -- https://github.com/pytorch/fairseq/blob/master/fairseq/trainer.py#L39
- The `Task` class; summarization is implemented as a `Translation` task -- https://github.com/pytorch/fairseq/blob/master/fairseq/tasks/fairseq_task.py
- The `bart.base` model registration, along with all model params and arguments -- https://github.com/pytorch/fairseq/blob/master/fairseq/models/bart/model.py#L297
- Issue for the mismatch in BASE and LARGE vocab sizes; the fix is to truncate the weights and save -- facebookresearch/fairseq#2242
- Issue for fine-tuning with limited resources; lots of useful tips -- facebookresearch/fairseq#1413
- Issue for fine-tuning with different vocab sizes -- facebookresearch/fairseq#2120
- Issue discussing the confusion around `MAX_TOKENS` in BART for summarization (the README is broken) -- facebookresearch/fairseq#1685
- Issue regarding training time and resources for BART Base -- facebookresearch/fairseq#1651
- HuggingFace Transformers explain their fine-tuning procedure -- https://huggingface.co/transformers/training.html