
train_gpt2.py *really* slow on 7900 XTX #4301

Open
anthonix opened this issue Apr 26, 2024 · 1 comment
@anthonix

Is there something I'm missing with llm.c/train_gpt2.py running so slow? It's much slower than even PyTorch.

$ python3 train_gpt2.py --sequence_length 1024
ram used:  0.55 GB, lm_head.weight        [...]
loaded weights in 321.07 ms, 0.70 GB loaded at 2.19 GB/s
loading cached tokens in data/tiny_shakespeare_val.bin
iteration 0, loss: 4.5791144371032715, time: 11184.098ms
iteration 1, loss: 4.052018165588379, time: 3466.631ms
iteration 2, loss: 3.755559206008911, time: 1731.382ms
iteration 3, loss: 3.5776381492614746, time: 1742.382ms
iteration 4, loss: 3.3802051544189453, time: 1747.056ms
iteration 5, loss: 3.202749490737915, time: 1719.359ms
iteration 6, loss: 3.049722909927368, time: 1716.452ms
iteration 7, loss: 2.8996615409851074, time: 1738.209ms
iteration 8, loss: 2.7452614307403564, time: 1759.945ms
iteration 9, loss: 2.5893540382385254, time: 1734.943ms
@anthonix
Author

I missed setting "BEAM=2" to enable the kernel search (thanks Chenyu); this brings it down to ~440 ms:

$ BEAM=2 python3 train_gpt2.py --sequence_length 1024
ram used:  0.55 GB, lm_head.weight   [...]
loaded weights in 28938.98 ms, 0.70 GB loaded at 0.02 GB/s
loading cached tokens in data/tiny_shakespeare_val.bin
iteration 0, loss: 4.57911491394043, time: 560790.932ms
iteration 1, loss: 4.052017688751221, time: 110120.866ms
iteration 2, loss: 3.7555594444274902, time: 462.663ms
iteration 3, loss: 3.577639102935791, time: 438.331ms
iteration 4, loss: 3.3802037239074707, time: 443.539ms
iteration 5, loss: 3.2027504444122314, time: 441.166ms
iteration 6, loss: 3.0497217178344727, time: 438.978ms
iteration 7, loss: 2.899660587310791, time: 442.850ms
iteration 8, loss: 2.7452595233917236, time: 441.698ms
iteration 9, loss: 2.5893521308898926, time: 442.761ms
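(The very slow iterations 0 and 1 above are presumably the beam search itself compiling and benchmarking candidate kernels; once they are cached, steady state settles at ~440 ms.)

For anyone who'd rather set the flag from inside a script than the shell, a minimal sketch; it assumes BEAM is read from the environment at import time and that tinygrad.helpers.Context can scope the same flag:

import os
os.environ["BEAM"] = "2"  # must be set before tinygrad is imported

from tinygrad import Tensor
from tinygrad.helpers import Context  # assumption: Context scopes flags like BEAM

x = Tensor.randn(1024, 1024)

with Context(BEAM=2):      # scoped alternative to the env var
    y = (x @ x).realize()  # first realize triggers the kernel search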

Still a bit of a gap to PyTorch nightly, which is at 254 ms per iteration on the same device.
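For context on how such a per-iteration number is measured, a rough sketch of timing a training step on the ROCm build of PyTorch (which exposes the 7900 XTX through torch.cuda); the tiny linear model is a hypothetical stand-in for GPT-2 at sequence length 1024, and the synchronize calls keep the wall-clock timing honest on an asynchronous device:

import time
import torch
from torch import nn

model = nn.Linear(768, 768).cuda()   # stand-in model; the real runs train GPT-2
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
x = torch.randn(4, 1024, 768, device="cuda")

for i in range(10):
    torch.cuda.synchronize()         # drain pending GPU work before starting the clock
    t0 = time.perf_counter()
    loss = model(x).square().mean()  # placeholder loss
    opt.zero_grad()
    loss.backward()
    opt.step()
    torch.cuda.synchronize()         # wait for the step to finish before stopping the clock
    print(f"iteration {i}, time: {(time.perf_counter() - t0) * 1e3:.3f}ms")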
