OLMo 7B finetuning w/ CPU offloading does not work #478

Open
gahdritz opened this issue Mar 2, 2024 · 4 comments
Labels
type/bug An issue about a bug

Comments

@gahdritz
Contributor

gahdritz commented Mar 2, 2024

🐛 Describe the bug

I'm trying to finetune OLMo 7B on a single 4xA100 40GB node. I'm using the official config with no changes to the repo except that I've 1) loaded pretrained model parameters from HuggingFace and 2) enabled FSDP CPU offloading in scripts/train.py as follows:

    fsdp_model = FSDP(
        olmo_model,
        sharding_strategy=cfg.fsdp.sharding_strategy,
        mixed_precision=cfg.fsdp_precision,
        auto_wrap_policy=wrap_policy,
        use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
        limit_all_gathers=True,
        device_id=get_local_rank(),
        param_init_fn=param_init_fn,
        cpu_offload=CPUOffload(offload_params=True),
    )
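
For reference, change 1) amounts to loading the pretrained weights into olmo_model (the module constructed earlier in scripts/train.py) before the FSDP wrap. A rough sketch of that step, where the checkpoint filename and the key-prefix handling are assumptions rather than verified details of the allenai/OLMo-7B repo:

    # Rough sketch only: the checkpoint filename and the key-prefix handling are
    # assumptions and may need adjusting for the actual allenai/OLMo-7B repo.
    import torch
    from huggingface_hub import hf_hub_download

    ckpt_path = hf_hub_download(repo_id="allenai/OLMo-7B", filename="pytorch_model.bin")
    state_dict = torch.load(ckpt_path, map_location="cpu")

    # Weights exported through the HF wrapper may carry a "model." prefix that the
    # native OLMo module does not expect; strip it if present before loading.
    state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}
    olmo_model.load_state_dict(state_dict)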

On my personal dataset and also on Dolma v1.6_sample, training consistently fails: the loss increases from the first iteration. Here are training loss curves on Dolma:

[image: training loss curves on Dolma]

I ran it for longer on my personal dataset and the loss eventually decreased, but it never fully recovered:

[image: training loss curve on the personal dataset]

I've tried changing the learning rate, batch size, gradient clipping warmup period, and gradient clipping warmup factor, but no dice. I've also been able to reproduce this behavior by enabling FSDP w/ CPU offloading on OLMo-1B, which otherwise trains normally.

Versions

aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.2.0
beaker-gantry==0.21.0
beaker-py==1.24.0
black==23.12.1
boltons==23.1.1
boto3==1.34.39
botocore==1.34.39
build==1.0.3
cached-path==1.5.1
cachetools==5.3.2
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
click-help-colors==0.9.4
cryptography==42.0.2
datasets==2.17.0
dill==0.3.8
docker==6.1.3
docker-pycreds==0.4.0
docutils==0.20.1
exceptiongroup==1.2.0
face==20.1.1
filelock==3.12.4
frozenlist==1.4.1
fsspec==2023.10.0
ftfy==6.1.3
gitdb==4.0.11
GitPython==3.1.41
glom==23.5.0
google-api-core==2.17.0
google-auth==2.27.0
google-cloud-core==2.4.1
google-cloud-storage==2.14.0
google-crc32c==1.5.0
google-resumable-media==2.7.0
googleapis-common-protos==1.62.0
huggingface-hub==0.19.4
idna==3.6
importlib-metadata==7.0.1
iniconfig==2.0.0
isort==5.12.0
jaraco.classes==3.3.1
jeepney==0.8.0
Jinja2==3.1.3
jmespath==1.0.1
joblib==1.3.2
keyring==24.3.0
lightning-utilities==0.10.1
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
more-itertools==10.2.0
mpmath==1.3.0
msgspec==0.18.6
multidict==6.0.5
multiprocess==0.70.16
mypy==1.3.0
mypy-extensions==1.0.0
necessary==0.4.3
networkx==3.2.1
nh3==0.2.15
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
packaging==23.2
pandas==2.2.0
pathspec==0.12.1
petname==2.6
pkginfo==1.9.6
platformdirs==4.2.0
pluggy==1.4.0
protobuf==4.25.2
psutil==5.9.8
pyarrow==15.0.0
pyarrow-hotfix==0.6
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycparser==2.21
pydantic==2.6.1
pydantic_core==2.16.2
Pygments==2.17.2
pyproject_hooks==1.0.0
pytest==8.0.0
pytest-sphinx==0.6.0
python-dateutil==2.8.2
pytz==2024.1
PyYAML==6.0.1
readme-renderer==42.0
regex==2023.12.25
requests==2.31.0
requests-toolbelt==1.0.0
requirements-parser==0.5.0
rfc3986==2.0.0
rich==13.7.0
rsa==4.9
ruff==0.2.1
s3transfer==0.10.0
safetensors==0.4.2
scikit-learn==1.4.0
scipy==1.12.0
SecretStorage==3.3.3
sentry-sdk==1.40.3
setproctitle==1.3.3
six==1.16.0
smart-open==6.4.0
smashed==0.21.5
smmap==5.0.1
sympy==1.12
threadpoolctl==3.2.0
tokenizers==0.15.1
tomli==2.0.1
torch==2.1.2
torchmetrics==1.3.0.post0
tqdm==4.66.2
transformers==4.37.2
triton==2.1.0
trouting==0.3.3
twine==4.0.2
types-setuptools==69.0.0.20240125
typing_extensions==4.9.0
tzdata==2023.4
urllib3==1.26.18
wandb==0.16.3
wcwidth==0.2.13
websocket-client==1.7.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.17.0

@gahdritz gahdritz added the type/bug An issue about a bug label Mar 2, 2024
@gahdritz gahdritz changed the title OLMO 7B finetuning does not work OLMO 7B finetuning w/ CPU offloading does not work Mar 2, 2024
@gahdritz gahdritz changed the title OLMO 7B finetuning w/ CPU offloading does not work OLMo 7B finetuning w/ CPU offloading does not work Mar 2, 2024
@AkshitaB
Contributor

AkshitaB commented Mar 4, 2024

Hi @epwalsh, do you have an intuition for why this might be the case?

@epwalsh
Member

epwalsh commented Mar 5, 2024

Hey @gahdritz, it's very possible you're running into an FSDP bug. It certainly wouldn't be the first major FSDP bug we've seen. If you can reproduce this at the 1B scale, then that's enough evidence to report it to PyTorch.
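
A minimal standalone repro along these lines would probably be enough to attach to a PyTorch issue. The toy model, sizes, and step count below are made up and nothing here is OLMo-specific; launch with something like torchrun --nproc_per_node=4:

    # Minimal FSDP CPU-offload repro sketch (assumes torchrun on a multi-GPU node).
    # The toy model and hyperparameters are placeholders, not OLMo's.
    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import CPUOffload
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def main():
        dist.init_process_group("nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        model = torch.nn.Sequential(
            torch.nn.Linear(1024, 4096),
            torch.nn.GELU(),
            torch.nn.Linear(4096, 1024),
        )
        fsdp_model = FSDP(
            model,
            device_id=local_rank,
            cpu_offload=CPUOffload(offload_params=True),  # toggle off to compare loss curves
        )
        optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

        for step in range(100):
            x = torch.randn(8, 1024, device="cuda")
            loss = fsdp_model(x).pow(2).mean()
            loss.backward()
            optim.step()
            optim.zero_grad(set_to_none=True)
            if dist.get_rank() == 0:
                print(step, loss.item())

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()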

@rajasekharmekala

Hi @gahdritz,
How did you manage to download the checkpoint? It seems like the Hugging Face model cannot be directly loaded for fine-tuning.

@dirkgr
Member

dirkgr commented May 8, 2024

Coming late to this discussion. Are you loading optimizer state from somewhere? If you are not, you should warm up your learning rate from 0 over a number of steps.
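
Something along these lines, sketched with a plain PyTorch scheduler rather than the trainer's own scheduler config; the warmup length and the stand-in model are only illustrative:

    # Hedged sketch of warming the learning rate up from ~0; the OLMo trainer has its
    # own scheduler config, so this only illustrates the shape of the schedule.
    import torch

    model = torch.nn.Linear(16, 16)  # stand-in for the FSDP-wrapped model
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    warmup_steps = 2000  # illustrative value, not OLMo's default

    # Scale the LR linearly from near 0 up to the target value over warmup_steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps),
    )

    for _ in range(10):  # replace with the real training loop
        loss = model(torch.randn(4, 16)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)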
