mamba-train


A single repo with all scripts and utils to train / fine-tune the Mamba model with or without the Fill-in-Middle (FIM) objective (for code infilling).

Data

Currently, the train.py script only supports training from either a Lance dataset or a Hugging Face dataset. If you are training with a Hugging Face dataset, substitute MambaDataset with your Hugging Face dataset in the train.py file.

For training to run with a Hugging Face dataset, the data needs to be chunked into blocks of the context length; that is, each sample in the dataset must contain exactly 'context length' tokens. For more information on how to achieve this, see the group_texts function.
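As a rough illustration, the sketch below follows the standard Hugging Face chunking pattern; the block size and column name are assumptions, so match them to the actual group_texts implementation in this repo.

```python
from itertools import chain

CONTEXT_LEN = 2048  # assumed context length; use whatever train.py expects

def group_texts(examples, block_size=CONTEXT_LEN):
    # Concatenate all tokenized sequences into one long stream of token IDs
    concatenated = list(chain.from_iterable(examples["input_ids"]))
    # Drop the trailing remainder so every block has exactly block_size tokens
    total_length = (len(concatenated) // block_size) * block_size
    # Split the stream into fixed-size, non-overlapping blocks
    return {
        "input_ids": [
            concatenated[i : i + block_size]
            for i in range(0, total_length, block_size)
        ]
    }
```

Applied with dataset.map(group_texts, batched=True) after tokenization, this yields samples of exactly 'context length' tokens each.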

Once the data is in the right format, call the apply_fim function in the training loop, passing in the samples along with the appropriate parameters. If you face any problems, please open an issue!
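For orientation, here is a minimal sketch of where that call could sit in a training step; apply_fim's actual signature lives in this repo, so the parameter names below (fim_rate, tokenizer) are assumptions, as is the HF-style model interface that returns a loss.

```python
# Hedged sketch of one training step; assumes `model`, `optimizer`, `tokenizer`,
# `train_dataloader`, and `apply_fim` are already set up as in train.py.
for batch in train_dataloader:
    input_ids = batch["input_ids"]

    # Transform a fraction of the samples with FIM before the forward pass
    # (parameter names here are assumptions, not the repo's actual signature)
    input_ids = apply_fim(input_ids, tokenizer=tokenizer, fim_rate=0.9)

    # Standard causal LM step on the (possibly FIM-rearranged) tokens;
    # assumes a model that accepts labels and returns a loss
    outputs = model(input_ids, labels=input_ids)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```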

For the Lance dataset, I will be releasing a 5M-sample subset of the CodeParrot dataset soon. For more information on how it was made using Lance, see my article.

A note about MambaSampler: I am training the model on the Lance dataset, which is one large contiguous array of tokens. In this setting, it is hard to distinguish between different samples (each of size 'context length') without altering the dataset creation process, and we need non-overlapping samples so that the model does not overfit.

My workaround was to write a new sampler that draws len(dataset) // context_len samples from the dataset, where each sampled starting index is at least context_len positions apart from the others. This "emulates" individual samples with minimal processing overhead.
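MambaSampler itself lives in this repo; the sketch below only illustrates the idea with assumed names, drawing one starting index per context_len-sized stride so the resulting windows never overlap.

```python
import torch
from torch.utils.data import Sampler

class StridedRandomSampler(Sampler):
    """Illustrative stand-in for MambaSampler (names and details are assumptions):
    yields roughly len(dataset) // context_len starting indices, each at least
    context_len apart, so context-length windows over a flat token array
    never overlap."""

    def __init__(self, dataset_len: int, context_len: int):
        # Non-overlapping starting positions: 0, context_len, 2 * context_len, ...
        self.starts = torch.arange(0, dataset_len - context_len + 1, context_len)

    def __iter__(self):
        # Shuffle the starting positions each epoch
        return iter(self.starts[torch.randperm(len(self.starts))].tolist())

    def __len__(self):
        return len(self.starts)
```

Passing such a sampler to the DataLoader via its sampler argument keeps the processing overhead minimal, as described above.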

Fill-in-Middle

Both the Lance and HF datasets apply the Fill-in-Middle transformation to each 'sample' during the training run. The FIM training objective allows the model to infill code, and FIM-trained models are what power code-completion tools like GitHub Copilot. To learn more about the Fill-in-Middle training objective, see the OpenAI paper.

To adjust what percentage of training samples are FIM-transformed, change the fim_rate parameter in both datasets. By default it is set to 0.9, meaning 90% of all samples will be FIM-transformed (a high rate, because I am fine-tuning the model rather than pre-training it).
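To make the transformation concrete, here is a character-level sketch of the PSM (prefix-suffix-middle) rearrangement described in the OpenAI paper; the sentinel strings and function name are placeholders, and the repo's apply_fim operates on token IDs rather than raw text.

```python
import random

# Placeholder sentinels; real FIM training uses dedicated special tokens
FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE = "<fim_prefix>", "<fim_suffix>", "<fim_middle>"

def fim_transform(text: str, fim_rate: float = 0.9, rng=random) -> str:
    """Character-level illustration of the FIM (PSM) transform; not the repo's apply_fim."""
    if len(text) < 2 or rng.random() > fim_rate:
        return text  # leave (1 - fim_rate) of samples as ordinary left-to-right text
    # Pick two cut points and split the document into prefix / middle / suffix
    lo, hi = sorted(rng.sample(range(len(text) + 1), 2))
    prefix, middle, suffix = text[:lo], text[lo:hi], text[hi:]
    # PSM ordering: the model is shown prefix and suffix, then learns to emit the middle
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}{middle}"
```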

Training

Before starting the training run, install all the dependencies from the requirements file:

pip install -r requirements.txt

Once that is done, start the training run via:

python train.py
