Training breaks after 1 epoch ("Training epoch complete"); can't pre-train beyond 1 epoch? #554

Open
Xuekai-Zhu opened this issue Apr 23, 2024 · 3 comments
Labels
type/bug An issue about a bug

Comments

@Xuekai-Zhu

🐛 Describe the bug

File: OLMo/olmo/train.py
In the following training loop, does pre-training stop after only 1 epoch?

@property
def max_epochs(self) -> int:
    if isinstance(self.cfg.max_duration, str) and self.cfg.max_duration.endswith("ep"):
        return int(self.cfg.max_duration[:-2].strip())
    else:
        return 1

with torch_profiler as p:
            for epoch in range(self.epoch or 0, self.max_epochs):
                for batch in self.train_loader:
                    # Bookkeeping.
                    # NOTE: To track the global batch size / number of tokens per batch we make the assumption that all
                    # batches see the same number of tokens, which should be the case for language model pre-training
                    # (at least when drop_last=True).
                    # Alternatively we'd have to use a distributed all reduce over seq_len here, but I don't want that
                    # overhead. So for now I'm putting these assertions here so if the assumption is violated it will
                    # fail loudly.
                    batch_size, seq_len = batch["input_ids"].shape
                    assert seq_len == self.cfg.model.max_sequence_length
                    assert batch_size == self.cfg.device_train_batch_size
                    global_batch_size = batch_size * get_world_size()  # assumes batch size equal across ranks
                    self.global_step += 1
                    self.global_train_examples_seen_this_epoch += global_batch_size
                    self.global_train_tokens_seen += global_batch_size * seq_len
                    speed_monitor.batch_start(
                        self.global_train_tokens_seen,
                        batch_size * seq_len,  # num tokens in batch for this device
                        # We start monitoring speed after the first batch since the first
                        # batch might be an outlier due to compiling and other initialization overhead.
                        record=not first_batch,
                    )

                    should_log_this_step = self.should_log_this_step()

                    # Run train step on batch.
                    metrics = self.train_step(batch, reduce_global_loss=should_log_this_step)

                    # Maybe collect other metrics.
                    if should_log_this_step:
                        # Speed metrics.
                        metrics.update(speed_monitor.check())
                        # System metrics.
                        metrics.update(self.system_metrics())
                        # Learning rate metrics.
                        metrics.update(lr_monitor.check())

                    # Log metrics to console.
                    if self.global_step % self.cfg.console_log_interval == 0:
                        self.log_metrics_to_console(f"[step={self.global_step}/{self.max_steps}]", metrics)

                    # Log metrics to W&B.
                    if (
                        wandb.run is not None
                        and self.cfg.wandb is not None
                        and self.global_step % self.cfg.wandb.log_interval == 0
                    ):
                        wandb.log(metrics, step=self.global_step)

                    # Check if/when run should be canceled.
                    if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
                        cancel_initiated, extra_steps = self.check_if_cancelled()
                        if cancel_initiated:
                            stop_at = (
                                self.global_step + extra_steps
                                if stop_at is None
                                else min(self.global_step + extra_steps, stop_at)
                            )

                    # Maybe save sharded checkpoint.
                    if save_checkpoints and (
                        cancel_initiated
                        or (
                            self.global_step % self.cfg.save_interval == 0
                            and self.cfg.save_num_checkpoints_to_keep != 0
                        )
                    ):
                        log.info("Saving checkpoint...")
                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
                        log.info(f"Checkpoint saved to {checkpoint_path}")

                        # Remove any ephemeral checkpoints.
                        while self.ephemeral_checkpoints:
                            self.remove_ephemeral_checkpoint()

                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
                        speed_monitor.reset()

                        # If the run was just canceled this will be the final checkpoint.
                        if cancel_initiated:
                            save_checkpoints = False
                    elif (
                        self.cfg.save_interval_ephemeral is not None
                        and self.global_step % self.cfg.save_interval_ephemeral == 0
                    ):
                        log.info("Saving ephemeral checkpoint...")
                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded_ephemeral)
                        log.info(f"Checkpoint saved to {checkpoint_path}")

                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
                        speed_monitor.reset()

                    # Maybe save unsharded checkpoint.
                    if (
                        save_checkpoints
                        and self.cfg.save_interval_unsharded is not None
                        and self.global_step % self.cfg.save_interval_unsharded == 0
                        and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
                    ):
                        log.info("Saving unsharded checkpoint...")
                        checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
                        log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

                        # Reset speed monitor so that we don't count the time taken to save checkpoints.
                        speed_monitor.reset()

                    # Maybe run evaluations.
                    if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
                        eval_metrics = self.eval()

                        # Log metrics to W&B.
                        if wandb.run is not None:
                            wandb.log(eval_metrics, step=self.global_step)

                        # Reset speed monitor so that we don't count the time taken to run evaluations.
                        speed_monitor.reset()

                        # Reset model to 'train' mode.
                        self.fsdp_model.train()

                    # End of batch.
                    first_batch = False
                    if p is not None:
                        p.step()

                    if stop_at is not None and self.global_step >= stop_at:
                        break

                    # Python Profiler stuff
                    # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
                    if python_profiler is not None:
                        if self.global_step == 5:
                            python_profiler.enable()
                        elif self.global_step == 8:
                            python_profiler.disable()
                            python_profiler.print_stats(sort=SortKey.CUMULATIVE)
                            python_profiler = None
                else:
                    log.info("Training epoch complete")
                    self.epoch = epoch + 1
                    self.global_train_examples_seen_this_epoch = 0
                    if self.epoch < self.max_epochs:
                        self.dataset.reshuffle()
                    continue

                break

Versions

Python 3.10.13
WARNING: Could not find a Python project for directory /scratch2/nlp/zhuxuekai/scaling_law4AI_data/OLMo (tried all parent directories)
-e git+ssh://git@github.com/Xuekai-Zhu/scaling_law4AI_data.git@a15301e68a4dd616e3971c54370cb4a957e4d14c#egg=ai2_olmo
aiohttp==3.9.3
aiosignal==1.3.1
aniso8601==9.0.1
annotated-types==0.6.0
antlr4-python3-runtime==4.9.3
anykeystore==0.2
appdirs==1.4.4
async-timeout==4.0.3
asyncio==3.4.3
attrs==23.2.0
backports.tarfile==1.1.0
beaker-gantry==0.22.2
beaker-py==1.26.4
black==23.12.1
blinker==1.7.0
boltons==24.0.0
boto3==1.34.86
botocore==1.34.86
build==1.2.1
cached_path==1.6.2
cachetools==5.3.3
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
click-help-colors==0.9.4
cmake==3.28.3
contourpy==1.2.0
cryptacular==1.6.2
cryptography==42.0.5
cycler==0.12.1
datasets==2.18.0
deepspeed==0.14.0
deepspeed-kernels==0.0.1.dev1698255861
deepspeed-mii==0.2.3
defusedxml==0.7.1
dill==0.3.8
docker==6.1.3
docker-pycreds==0.4.0
docutils==0.21.1
exceptiongroup==1.2.0
face==20.1.1
filelock==3.9.0
Flask==3.0.2
Flask-RESTful==0.3.10
fonttools==4.50.0
frozenlist==1.4.1
fsspec==2024.2.0
ftfy==6.2.0
gitdb==4.0.11
GitPython==3.1.42
glom==23.5.0
google-api-core==2.18.0
google-auth==2.29.0
google-cloud-core==2.4.1
google-cloud-storage==2.16.0
google-crc32c==1.5.0
google-resumable-media==2.7.0
googleapis-common-protos==1.63.0
greenlet==3.0.3
grpcio==1.62.1
grpcio-tools==1.62.1
hjson==3.1.0
huggingface-hub==0.21.4
hupper==1.12.1
idna==3.6
importlib_metadata==7.1.0
iniconfig==2.0.0
isort==5.12.0
itsdangerous==2.1.2
jaraco.classes==3.4.0
jaraco.context==5.3.0
jaraco.functools==4.0.0
jeepney==0.8.0
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.4.0
keyring==25.1.0
kiwisolver==1.4.5
lightning-utilities==0.11.2
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.3
mdurl==0.1.2
Megatron==0.5.1
megatron_core==0.5.0
more-itertools==10.2.0
mpmath==1.3.0
msgspec==0.18.6
multidict==6.0.5
multiprocess==0.70.16
mypy==1.3.0
mypy-extensions==1.0.0
necessary==0.4.3
networkx==3.2.1
nh3==0.2.17
ninja==1.11.1.1
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==8.7.0.84
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.3.0.86
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
nvidia-nccl-cu11==2.19.3
nvidia-nvtx-cu11==11.8.86
oauthlib==3.2.2
omegaconf==2.3.0
packaging==24.0
pandas==2.2.1
PasteDeploy==3.1.0
pathspec==0.12.1
pbkdf2==1.3
petname==2.6
pillow==10.2.0
pkginfo==1.10.0
plaster==1.1.2
plaster-pastedeploy==1.0.1
platformdirs==4.2.0
pluggy==1.4.0
proto-plus==1.23.0
protobuf==4.25.3
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==15.0.2
pyarrow-hotfix==0.6
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycparser==2.22
pydantic==2.6.4
pydantic_core==2.16.3
Pygments==2.17.2
pynvml==11.5.0
pyparsing==3.1.2
pyproject_hooks==1.0.0
pyramid==2.0.2
pyramid-mailer==0.15.1
pytest==8.1.1
pytest-sphinx==0.6.3
python-dateutil==2.9.0.post0
python3-openid==3.2.0
pytz==2024.1
PyYAML==6.0.1
pyzmq==25.1.2
readme_renderer==43.0
regex==2023.12.25
repoze.sendmail==4.4.1
requests==2.31.0
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
requirements-parser==0.9.0
rfc3986==2.0.0
rich==13.7.1
rsa==4.9
ruff==0.3.7
s3transfer==0.10.1
safetensors==0.4.2
scikit-learn==1.4.2
scipy==1.13.0
seaborn==0.13.2
SecretStorage==3.3.3
sentry-sdk==1.43.0
setproctitle==1.3.3
six==1.16.0
smart-open==7.0.4
smashed==0.21.5
smmap==5.0.1
SQLAlchemy==2.0.29
sympy==1.12
threadpoolctl==3.4.0
tokenizers==0.15.2
tomli==2.0.1
torch==2.2.1+cu118
torchmetrics==1.3.2
tqdm==4.66.2
transaction==4.0
transformers==4.38.2
translationstring==1.4
triton==2.2.0
trouting==0.3.3
twine==5.0.0
types-setuptools==69.5.0.20240415
typing_extensions==4.8.0
tzdata==2024.1
ujson==5.9.0
urllib3==2.2.1
velruse==1.1.1
venusian==3.1.0
wandb==0.16.4
wcwidth==0.2.13
WebOb==1.8.7
websocket-client==1.7.0
Werkzeug==3.0.1
wrapt==1.16.0
WTForms==3.1.2
wtforms-recaptcha==0.3.2
xxhash==3.4.1
yarl==1.9.4
zipp==3.18.1
zmq==0.0.0
zope.deprecation==5.0
zope.interface==6.3
zope.sqlalchemy==3.1

@Xuekai-Zhu Xuekai-Zhu added the type/bug An issue about a bug label Apr 23, 2024
@dumitrac
Contributor

@Xuekai-Zhu, what is the value of max_duration in the config that you're using?
If you want more than 1 epoch, say 2 epochs, the config should have max_duration: 2ep.
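For reference, the behavior described here can be reproduced in isolation. The sketch below copies the `max_epochs` logic quoted in the issue into a standalone function; taking `max_duration` as a plain argument instead of reading `self.cfg` is an assumption made only so the snippet runs on its own.

```python
from typing import Union


def max_epochs(max_duration: Union[int, str]) -> int:
    """Standalone copy of the `max_epochs` property quoted in the issue.

    Assumption: `max_duration` is passed in directly rather than read
    from `self.cfg.max_duration`.
    """
    if isinstance(max_duration, str) and max_duration.endswith("ep"):
        # "2ep" -> 2
        return int(max_duration[:-2].strip())
    # Any other value (for example a step count) falls back to a single epoch.
    return 1


two_epochs = max_epochs("2ep")
step_based = max_epochs(10000)
print(two_epochs, step_based)
```

Only values ending in `ep` change the epoch count; everything else leaves the outer loop at one pass over the data.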

@Xuekai-Zhu
Author

Yes, I found that if I want more than 1 epoch, the config should have max_duration: 2ep.
But when I want to use a maximum token count to control the training process, I can't reach it, because training is limited by the default of 1 epoch.

source tokens 8B, max_duration: 30B -> training completes at 8B tokens (1 epoch);
❌ can't reach the max_duration set in the config.
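To make the arithmetic concrete, here is a minimal sketch of how many passes over the data a token budget would need. The helper `epochs_needed` is hypothetical and is not part of OLMo or of any fix for this issue; it only illustrates why a single epoch cannot satisfy the numbers above.

```python
def epochs_needed(max_tokens: int, tokens_per_epoch: int) -> int:
    """Hypothetical helper: epochs required to see `max_tokens` tokens.

    Uses integer ceiling division so large token counts stay exact.
    """
    if tokens_per_epoch <= 0:
        raise ValueError("tokens_per_epoch must be positive")
    # -(-a // b) is ceil(a / b) for positive integers.
    return max(1, -(-max_tokens // tokens_per_epoch))


# Numbers from the report: 8B source tokens, 30B token budget.
needed = epochs_needed(30_000_000_000, 8_000_000_000)
print(needed)
```

With 8B source tokens and a 30B budget, the run would need to loop over the data four times (the last pass only partially), whereas the fallback of one epoch stops at 8B tokens.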

@dumitrac
Contributor

@Xuekai-Zhu - agreed, this is a bug.
Thank you for reporting it.
