I'm trying to finetune OLMo 7B on a single 4xA100 40GB node. I'm using the official config with no repo changes, except that I've 1) loaded pretrained model parameters from HuggingFace and 2) enabled FSDP CPU offloading in scripts/train.py as follows:
fsdp_model = FSDP(
    olmo_model,
    sharding_strategy=cfg.fsdp.sharding_strategy,
    mixed_precision=cfg.fsdp_precision,
    auto_wrap_policy=wrap_policy,
    use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
    limit_all_gathers=True,
    device_id=get_local_rank(),
    param_init_fn=param_init_fn,
    cpu_offload=CPUOffload(offload_params=True),
)
On both my personal dataset and the Dolma v1.6_sample, training consistently fails: the loss increases from the first iteration. Here are training loss curves on Dolma:
I ran it for longer on my personal dataset and the loss eventually decreased, but it never fully recovered:
I've tried changing the learning rate, batch size, gradient clipping warmup period, and gradient clipping warmup factor, but no dice. I've also been able to reproduce this behavior by enabling FSDP with CPU offloading on OLMo-1B, which otherwise trains normally.
Hey @gahdritz, it's very possible you're running into an FSDP bug. It certainly wouldn't be the first major FSDP bug we've seen. If you can reproduce this at the 1B scale, that's enough evidence to report it to PyTorch.
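For what it's worth, a standalone repro separate from the OLMo trainer is usually the easiest thing to attach to an upstream report. Below is a minimal sketch along those lines; the toy model, dimensions, learning rate, and step count are placeholders chosen for illustration, not anything from the official config:

# Hedged sketch: isolate the suspected FSDP CPU-offload issue outside the OLMo trainer.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP


def main() -> None:
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Toy model standing in for OLMo; sizes are arbitrary.
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
    fsdp_model = FSDP(
        model,
        device_id=local_rank,
        cpu_offload=CPUOffload(offload_params=True),  # remove to get the non-offloaded baseline
    )
    optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

    for step in range(20):
        x = torch.randn(8, 1024, device="cuda")
        loss = fsdp_model(x).pow(2).mean()  # dummy objective: push outputs toward zero
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        if dist.get_rank() == 0:
            print(f"step {step}: loss {loss.item():.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Launched with something like torchrun --nproc_per_node=4 on a hypothetical repro.py, runs with and without the cpu_offload argument can be compared directly; if only the offloaded run misbehaves, the script is small enough to include verbatim in a PyTorch issue.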
Coming late to this discussion: are you loading optimizer state from somewhere? If not, you should warm up your learning rate from 0 over a number of steps.
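For reference, here is a minimal sketch of warming the learning rate up from 0 using the stock PyTorch LambdaLR scheduler rather than OLMo's own scheduler config; warmup_steps and the base learning rate are illustrative values, not the official ones:

# Hedged sketch: linear LR warmup from 0 via torch.optim.lr_scheduler.LambdaLR.
import torch

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # base (peak) learning rate

warmup_steps = 2000  # hypothetical warmup length


def warmup_lambda(step: int) -> float:
    # Scale the base LR linearly from ~0 up to 1x over warmup_steps, then hold.
    return min(1.0, (step + 1) / warmup_steps)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)

for step in range(5):
    # ... forward pass and loss.backward() would go here in a real training loop ...
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())

When finetuning from pretrained weights without restoring optimizer state, a warmup like this keeps the first updates small while the optimizer's moment estimates are still being populated.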
Versions
aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.2.0
beaker-gantry==0.21.0
beaker-py==1.24.0
black==23.12.1
boltons==23.1.1
boto3==1.34.39
botocore==1.34.39
build==1.0.3
cached-path==1.5.1
cachetools==5.3.2
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
click-help-colors==0.9.4
cryptography==42.0.2
datasets==2.17.0
dill==0.3.8
docker==6.1.3
docker-pycreds==0.4.0
docutils==0.20.1
exceptiongroup==1.2.0
face==20.1.1
filelock==3.12.4
frozenlist==1.4.1
fsspec==2023.10.0
ftfy==6.1.3
gitdb==4.0.11
GitPython==3.1.41
glom==23.5.0
google-api-core==2.17.0
google-auth==2.27.0
google-cloud-core==2.4.1
google-cloud-storage==2.14.0
google-crc32c==1.5.0
google-resumable-media==2.7.0
googleapis-common-protos==1.62.0
huggingface-hub==0.19.4
idna==3.6
importlib-metadata==7.0.1
iniconfig==2.0.0
isort==5.12.0
jaraco.classes==3.3.1
jeepney==0.8.0
Jinja2==3.1.3
jmespath==1.0.1
joblib==1.3.2
keyring==24.3.0
lightning-utilities==0.10.1
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
more-itertools==10.2.0
mpmath==1.3.0
msgspec==0.18.6
multidict==6.0.5
multiprocess==0.70.16
mypy==1.3.0
mypy-extensions==1.0.0
necessary==0.4.3
networkx==3.2.1
nh3==0.2.15
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
packaging==23.2
pandas==2.2.0
pathspec==0.12.1
petname==2.6
pkginfo==1.9.6
platformdirs==4.2.0
pluggy==1.4.0
protobuf==4.25.2
psutil==5.9.8
pyarrow==15.0.0
pyarrow-hotfix==0.6
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycparser==2.21
pydantic==2.6.1
pydantic_core==2.16.2
Pygments==2.17.2
pyproject_hooks==1.0.0
pytest==8.0.0
pytest-sphinx==0.6.0
python-dateutil==2.8.2
pytz==2024.1
PyYAML==6.0.1
readme-renderer==42.0
regex==2023.12.25
requests==2.31.0
requests-toolbelt==1.0.0
requirements-parser==0.5.0
rfc3986==2.0.0
rich==13.7.0
rsa==4.9
ruff==0.2.1
s3transfer==0.10.0
safetensors==0.4.2
scikit-learn==1.4.0
scipy==1.12.0
SecretStorage==3.3.3
sentry-sdk==1.40.3
setproctitle==1.3.3
six==1.16.0
smart-open==6.4.0
smashed==0.21.5
smmap==5.0.1
sympy==1.12
threadpoolctl==3.2.0
tokenizers==0.15.1
tomli==2.0.1
torch==2.1.2
torchmetrics==1.3.0.post0
tqdm==4.66.2
transformers==4.37.2
triton==2.1.0
trouting==0.3.3
twine==4.0.2
types-setuptools==69.0.0.20240125
typing_extensions==4.9.0
tzdata==2023.4
urllib3==1.26.18
wandb==0.16.3
wcwidth==0.2.13
websocket-client==1.7.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.17.0