Add data streaming support through mosaic-streaming #1525

Open · wants to merge 5 commits into main

Conversation

@fmv1992 (Author) commented Apr 16, 2024:

Description

This PR adds support for (non-volatile) memory-efficient training through StreamingDataset.

Motivation and Context

Context: #585.

How has this been tested?

I have tested this through Docker on a VM.

I'm open to ideas as to how this should be added. Does the repo support an S3 bucket, for instance?
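For reference, a rough sketch of the intended flow, assuming the standard mosaicml-streaming API; the bucket URI, cache path, and single text column below are placeholders rather than anything this PR fixes:

    from streaming import MDSWriter, StreamingDataset

    # One-off conversion of samples into the MDS shard format.
    samples = [{"text": "hello world"}, {"text": "axolotl"}]
    with MDSWriter(out="s3://my-bucket/mds", columns={"text": "str"}) as writer:
        for sample in samples:
            writer.write(sample)

    # At training time, shards are fetched and cached lazily instead of
    # loading the whole dataset into memory.
    dataset = StreamingDataset(
        local="/tmp/mds-cache", remote="s3://my-bucket/mds", shuffle=True
    )
    for sample in dataset:
        print(sample["text"])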

@fmv1992 marked this pull request as draft April 16, 2024 12:42
requirements.txt Outdated
@@ -31,6 +31,7 @@ art
fschat==0.2.36
gradio==3.50.2
tensorboard
mosaicml-streaming
@fmv1992 (Author) commented:

question: mosaicml-streaming should be an optional dependency. Is this the right way of adding it in this capacity?
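For illustration only, one way I could see keeping it optional is an extras_require entry in setup.py; this is just a sketch, and the extra name "mosaic" is made up:

    # setup.py (sketch): expose mosaicml-streaming as an optional extra.
    from setuptools import find_packages, setup

    setup(
        name="axolotl",
        packages=find_packages(),
        install_requires=[
            # ... required dependencies ...
        ],
        extras_require={
            # Only installed with `pip install axolotl[mosaic]`.
            "mosaic": ["mosaicml-streaming"],
        },
    )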

Collaborator commented:

I'm okay with it being a required dependency. If it causes issues down the line we can make it optional then. Hoping to keep things simpler.

Collaborator commented:

Could this version be locked?

@fmv1992 (Author) commented:

Addressed by 976bc13.

setup.py Outdated (resolved)
@fmv1992 marked this pull request as ready for review April 16, 2024 12:43
#
# This is necessary because downstream functions use a different interface
# than `StreamingDataset` (e.g. the `features` attribute).
ds = Dataset.from_generator(
@winglian (Collaborator) commented Apr 16, 2024:

This becomes an IterableDataset, right?

@fmv1992 (Author) replied:

@winglian,

Sorry for the delay here.

No, that was something I wanted to verify, but it looks like it goes to def process and everything is evaluated eagerly.
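To illustrate the eager/lazy distinction with a toy generator (a sketch, not this PR's code):

    from datasets import Dataset, IterableDataset

    def gen():
        for i in range(3):
            yield {"text": f"sample {i}"}

    # Eager: materializes every sample to Arrow files at construction time.
    eager_ds = Dataset.from_generator(gen)

    # Lazy: samples are only produced while iterating.
    lazy_ds = IterableDataset.from_generator(gen)
    print(next(iter(lazy_ds)))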

I started a draft like:

    def process(self, dataset):
        features = list(dataset.features.keys())
        map_kwargs = {}
        if self.prompt_tokenizer.supports_batched:
            map_kwargs["batched"] = True
            map_kwargs["batch_size"] = 100

        if isinstance(dataset, IterableDataset):
            # IterableDataset.map is lazy and does not accept num_proc,
            # keep_in_memory, or desc.
            return dataset.map(
                self.prompt_tokenizer.tokenize_prompt,
                remove_columns=features,
                **map_kwargs,
            )

        num_proc = min(
            64, self.process_count if self.process_count else os.cpu_count()
        )
        return dataset.map(
            self.prompt_tokenizer.tokenize_prompt,
            num_proc=num_proc,
            remove_columns=features,
            keep_in_memory=self.keep_in_memory,
            desc="Tokenizing Prompts",
            **map_kwargs,
        )

But I don't know whether that's a good idea. The .map API is different between Dataset (here) and IterableDataset (here).
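Roughly, the difference is in the accepted keyword arguments; a sketch with a stand-in tokenizer (exact keyword support may vary across datasets versions):

    from datasets import Dataset, IterableDataset

    def fake_tokenize(batch):
        # Stand-in for prompt_tokenizer.tokenize_prompt.
        return {"length": [len(t) for t in batch["text"]]}

    data = {"text": ["a", "bb", "ccc"]}

    # Eager map: supports num_proc, desc, keep_in_memory and runs immediately.
    ds = Dataset.from_dict(data)
    ds = ds.map(fake_tokenize, batched=True, batch_size=2, desc="Tokenizing Prompts")

    # Lazy map: narrower signature (no num_proc/desc/keep_in_memory),
    # applied on the fly during iteration.
    ids = IterableDataset.from_generator(lambda: ({"text": t} for t in data["text"]))
    ids = ids.map(fake_tokenize, batched=True, batch_size=2)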

Feel free to remove the "ready to merge" tag from this.

setup.py Outdated (resolved)
@winglian (Collaborator) left a comment:

Good to go. Thank you @fmv1992!

@fmv1992 (Author) commented Apr 17, 2024:

Thanks, much appreciated; I'm just checking a few more things before merging.

The experience of contributing to this repo has been very positive.

@NanoCode012 (Collaborator) commented:

Hey, thanks for the PR. I just wanted to clarify something I asked previously. This would require users to preprocess their dataset into Mosaic's format first, right? If so, I would prefer this to be documented somewhere near the cloud loading section. For example: add stream: true to load a Mosaic streaming dataset.

You should also add this parameter to https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/docs/config.qmd

https://github.com/mosaicml/streaming?tab=readme-ov-file#quick-start

@Kesta-bos commented:

I think it also needs 'StreamingDataset' support for pretraining datasets (completion), in addition to fine-tuning datasets.

@ehartford (Collaborator) commented:

Can we pretrain with Axolotl, streaming a data mix from S3?

@winglian (Collaborator) commented:

> Hey, thanks for the PR. I just wanted to clarify something I asked previously. This would require users to preprocess their dataset into Mosaic's format first, right? If so, I would prefer this to be documented somewhere near the cloud loading section. For example: add stream: true to load a Mosaic streaming dataset.
>
> You should also add this parameter to https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/docs/config.qmd
>
> https://github.com/mosaicml/streaming?tab=readme-ov-file#quick-start

JSONL should be fine for streaming. See https://github.com/mosaicml/streaming?tab=readme-ov-file#1-prepare-your-data
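A sketch of what that could look like, assuming streaming's JSONWriter and an existing train.jsonl with a single text column (the bucket and paths are placeholders):

    import json

    from streaming import JSONWriter, StreamingDataset

    # Re-shard an existing JSONL file into a streaming-friendly layout.
    with JSONWriter(out="s3://my-bucket/shards", columns={"text": "str"}) as writer:
        with open("train.jsonl") as fin:
            for line in fin:
                writer.write(json.loads(line))

    # Shards are then downloaded and cached on demand during training.
    dataset = StreamingDataset(local="/tmp/shards", remote="s3://my-bucket/shards")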

@fmv1992 (Author) commented Apr 22, 2024:

> Can we pretrain with Axolotl, streaming a data mix from S3?

We can, but I'd prefer to include this in a second PR. Right now I would rather see this smaller change working and merged; expanding on it should be easier later.

@fmv1992 (Author) commented Apr 22, 2024:

> Hey, thanks for the PR. I just wanted to clarify something I asked previously. This would require users to preprocess their dataset into Mosaic's format first, right? If so, I would prefer this to be documented somewhere near the cloud loading section. For example: add stream: true to load a Mosaic streaming dataset.
>
> You should also add this parameter to https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/docs/config.qmd
>
> https://github.com/mosaicml/streaming?tab=readme-ov-file#quick-start

Addressed by ba86339. Let me know if that addresses all your points.

@fmv1992 (Author) commented Apr 22, 2024:

As per this comment, this is not ready for merging; maybe we want to remove that tag.

I posted a draft of the changes there, but the issue is that tokenization should happen as the data is downloaded. Right now I'm almost certain it does everything in a batch: it downloads everything, then tokenizes everything, then proceeds to do the fine-tuning.
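For contrast, the behaviour I am after would look roughly like this, sketched with a dummy stream and tokenizer rather than this PR's code:

    from datasets import IterableDataset

    def samples_from_stream():
        # Placeholder for samples arriving from the StreamingDataset.
        for i in range(1000):
            yield {"text": f"document {i}"}

    def tokenize(batch):
        # Placeholder for the real prompt tokenizer.
        return {"input_ids": [[len(t)] for t in batch["text"]]}

    ds = IterableDataset.from_generator(samples_from_stream)
    ds = ds.map(tokenize, batched=True, batch_size=100, remove_columns=["text"])

    # Nothing has been downloaded or tokenized yet; it happens per batch as
    # the trainer iterates, instead of as one big upfront pass.
    for sample in ds.take(5):
        print(sample["input_ids"])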

@NanoCode012 (Collaborator) commented Apr 22, 2024:

> ...but the issue is that tokenization should happen as the data is downloaded. Right now I'm almost certain it does everything in a batch: it downloads everything, then tokenizes everything, then proceeds to do the fine-tuning.

@fmv1992, this is correct. I only got to review your code in detail earlier; the section I pointed you to before was incorrect.

def load_tokenized_prepared_datasets(

This function runs over the whole dataset, merges it, and performs tokenization at this point here:

LOG.info("merging datasets")
dataset = concatenate_datasets(datasets)

The only part that "skips" tokenization before fine-tuning is the pretraining section that you attempted to modify before.

path = cfg.pretraining_dataset
split = "train"
name = None
if isinstance(cfg.pretraining_dataset, list) and isinstance(
    cfg.pretraining_dataset[0], dict
):
    path = cfg.pretraining_dataset[0]["path"]
    name = cfg.pretraining_dataset[0]["name"]
    if "split" in cfg.pretraining_dataset[0]:
        split = cfg.pretraining_dataset[0]["split"]

ds_wrapper_partial = functools.partial(
    get_dataset_wrapper,
    cfg.pretraining_dataset[0],
    tokenizer,
    cfg,
    cfg.pretraining_dataset[0]["type"] or "pretrain",
)

train_dataset = wrap_pretraining_dataset(
    load_dataset(path, streaming=True, split=split, name=name),
    tokenizer,
    cfg,
    ds_wrapper_partial,
    max_tokens=cfg.sequence_len,
    batch_size=cfg.micro_batch_size,
    seed=cfg.seed or 42,
    buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps, prompters

I have two ideas as of now:

  1. Discuss a better way to handle data preprocessing between the current pretraining_dataset and dataset formats before continuing further, as the code is currently messy.
  2. Hack around and support streaming for pretraining datasets first, and figure out SFT later. This is also because your code expects the data in completion format, i.e. ({ "text": ... }), which is not the case for SFT datasets (see the sketch after the snippet below).
    # Define dataset features according to the axolotl structure.
    features = Features({"text": Value("string")})

I would also appreciate @winglian's comments on this.


Side note: what should this batch_size be set to? Is it hardcoded to 4 on purpose?

local=None, remote=config_dataset.path, shuffle=True, batch_size=4
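For context, a sketch of how I understand that parameter is meant to be wired up; per the mosaicml-streaming docs it should match the per-device DataLoader batch size rather than a fixed value (the paths are placeholders, and in axolotl the value would presumably come from cfg.micro_batch_size):

    from streaming import StreamingDataset
    from torch.utils.data import DataLoader

    per_device_batch_size = 4  # e.g. cfg.micro_batch_size, not a hardcoded 4

    dataset = StreamingDataset(
        local="/tmp/mds-cache",
        remote="s3://my-bucket/mds",
        shuffle=True,
        # Used to partition samples across ranks/workers; should agree with
        # the DataLoader batch size instead of being hardcoded.
        batch_size=per_device_batch_size,
    )
    loader = DataLoader(dataset, batch_size=per_device_batch_size)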
