Embedding factory script looping through each MGRS tile #125
base: main
Conversation
Jupyter notebook script to generate GeoParquet embedding files on a per-MGRS-tile basis. The script first generates an mgrs_world.txt file with a list of MGRS tile codes like 12ABC. A for-loop then goes through each MGRS tile, with the model running predictions to generate GeoParquet files that are uploaded to s3. About 947,019 rows of embeddings were generated from the clay-small-70MT-1100T-10E.ckpt model checkpoint in Dec 2023.
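In pseudocode, the overall flow looks something like the sketch below. The function name `generate_embeddings`, the output paths, and the s3 bucket are placeholders for illustration, not the notebook's actual code:

```python
# Minimal sketch of the notebook's overall flow. Function names, file
# paths and the s3 bucket below are illustrative assumptions only.
import os
import subprocess


def generate_embeddings(mgrs_code: str) -> str:
    """Placeholder for the notebook's prediction step.

    Runs the Clay model over one MGRS tile and writes a GeoParquet file,
    returning the path of the written .gpq file.
    """
    outpath = f"data/embeddings/{mgrs_code}_v01.gpq"
    # ... model prediction and GeoParquet writing would happen here ...
    return outpath


# 1. Read the list of MGRS tile codes (e.g. "12ABC"), one per line
with open("mgrs_world.txt") as f:
    mgrs_codes = [line.strip() for line in f if line.strip()]

# 2. Loop through each MGRS tile, generate embeddings, upload to s3
for mgrs_code in mgrs_codes:
    geoparquet_file = generate_embeddings(mgrs_code)
    if os.path.exists(geoparquet_file):  # skip upload for the stub above
        subprocess.run(
            ["aws", "s3", "cp", geoparquet_file, "s3://example-bucket/embeddings/"],
            check=True,
        )
```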
"# !aws s3 cp s3://clay-model-ckpt/v0/clay-small-70MT-1100T-10E.ckpt checkpoints/\n", | ||
"trainer = L.Trainer(precision=\"bf16-mixed\", logger=False)\n", | ||
"model: L.LightningModule = CLAYModule.load_from_checkpoint(\n", | ||
" checkpoint_path=\"checkpoints/clay-small-70MT-1100T-10E.ckpt\"\n", |
Should be possible to load the checkpoint directly from HuggingFace now, instead of manually downloading from the s3 bucket, as mentioned at #116 (comment)
"# !aws s3 cp s3://clay-model-ckpt/v0/clay-small-70MT-1100T-10E.ckpt checkpoints/\n", | |
"trainer = L.Trainer(precision=\"bf16-mixed\", logger=False)\n", | |
"model: L.LightningModule = CLAYModule.load_from_checkpoint(\n", | |
" checkpoint_path=\"checkpoints/clay-small-70MT-1100T-10E.ckpt\"\n", | |
"trainer = L.Trainer(precision=\"bf16-mixed\", logger=False)\n", | |
"model: L.LightningModule = CLAYModule.load_from_checkpoint(\n", | |
" checkpoint_path=\"https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt\"\n", |
Might want to move this into the scripts/ folder, and will need to update the paths below accordingly.
"#!mamba install triton\n", | ||
"# model.model.encoder = torch.compile(model=model.model.encoder)" |
Can probably remove this torch.compile line. Was trying to speed up the model by compiling it, but some layers didn't work with compilation.
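If compiling is revisited later, one possible workaround (untested here, just an assumption) is to let TorchDynamo fall back to eager mode on the layers that fail to compile:

```python
# Sketch only: tell TorchDynamo to fall back to eager execution on graphs
# it can't compile, instead of raising. Assumes PyTorch 2.x and that
# `model` is the already-loaded CLAYModule from the cells above.
import torch
import torch._dynamo

torch._dynamo.config.suppress_errors = True  # eager fallback on failing layers
model.model.encoder = torch.compile(model=model.model.encoder)
```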
"import os\n", | ||
"import warnings\n", | ||
"\n", | ||
"import duckdb\n", |
DuckDB is optional and can be removed, but it's nice to get a quick count of all the rows across the GeoParquet files (see last cell).
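For reference, that quick row count can be done with a query along these lines (the data/embeddings/*.gpq glob is an assumed output location, not necessarily what the notebook uses):

```python
# Quick sanity check: count embedding rows across all GeoParquet files.
# The "data/embeddings/*.gpq" glob is an assumed output location.
import duckdb

total = duckdb.sql(
    "SELECT COUNT(*) FROM read_parquet('data/embeddings/*.gpq')"
).fetchone()[0]
print(f"Total embedding rows: {total}")
```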
Steps:

1. Generate an mgrs_world.txt file listing the MGRS tile codes to process.
2. Loop through each MGRS tile, running the model to predict embeddings and write a GeoParquet file per tile.
3. Upload the GeoParquet files to s3.

Notes:

- Run on a g5.4xlarge EC2 instance with 1 NVIDIA A10G GPU, which allows for bfloat16 dtype calculations.

Closes #120