
Embedding factory script looping through each MGRS tile #125

Draft · wants to merge 4 commits into base: main
Conversation

weiji14 (Contributor) commented Jan 17, 2024

Jupyter notebook script to generate GeoParquet embedding files on a per-MGRS-tile basis.

Steps:

  1. The script first requires an mgrs_world.txt file listing MGRS code names like 12ABC, generated by running this command first:
    aws s3 ls s3://clay-tiles-02/02/ | tr -s ' ' | cut -d ' ' -f 3 | cut -d '/' -f 1 > mgrs_world.txt
    
  2. A for-loop then iterates over each MGRS tile, running the model's prediction step to generate GeoParquet files that are uploaded to s3.
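
The two steps above can be sketched in Python. This mock replicates the shell pipeline's field extraction on fake `aws s3 ls` output; the tile codes and the loop body are illustrative only, not the notebook's actual code:

```python
import re

def parse_mgrs_codes(s3_ls_output: str) -> list[str]:
    """Replicate `tr -s ' ' | cut -d ' ' -f 3 | cut -d '/' -f 1` in Python."""
    codes = []
    for line in s3_ls_output.splitlines():
        squeezed = re.sub(r" +", " ", line)        # tr -s ' ': squeeze repeated spaces
        fields = squeezed.split(" ")               # cut -d ' ': split on single spaces
        if len(fields) >= 3:                       # -f 3: take the third field
            codes.append(fields[2].split("/")[0])  # cut -d '/' -f 1: drop trailing slash
    return codes

# Example with mocked `aws s3 ls` output (the tile names are made up):
listing = "                           PRE 12ABC/\n                           PRE 45XYZ/"
for tile in parse_mgrs_codes(listing):
    # In the notebook, each tile would go through prediction and s3 upload here.
    print(tile)
```

Reading a pre-generated mgrs_world.txt (as the notebook does) is just `open("mgrs_world.txt").read().splitlines()`; the parsing above only matters if the listing step is folded into the script itself.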

Notes:

  • About 947,019 rows of embeddings were generated from the clay-small-70MT-1100T-10E.ckpt model checkpoint in Dec 2023.
  • Embeddings were generated on a g5.4xlarge EC2 instance with 1 NVIDIA A10G GPU, which supports bfloat16 dtype calculations.

Closes #120

weiji14 added this to the v1 Release milestone on Jan 17, 2024
weiji14 changed the title from "🚧 Embedding factory script looping through each MGRS tile" to "Embedding factory script looping through each MGRS tile" on Jan 17, 2024
Comment on lines +81 to +84
"# !aws s3 cp s3://clay-model-ckpt/v0/clay-small-70MT-1100T-10E.ckpt checkpoints/\n",
"trainer = L.Trainer(precision=\"bf16-mixed\", logger=False)\n",
"model: L.LightningModule = CLAYModule.load_from_checkpoint(\n",
" checkpoint_path=\"checkpoints/clay-small-70MT-1100T-10E.ckpt\"\n",
weiji14 (Contributor, Author):

Should be possible to load the checkpoint directly from HuggingFace now, instead of manually downloading from the s3 bucket, as mentioned at #116 (comment)

Suggested change
"# !aws s3 cp s3://clay-model-ckpt/v0/clay-small-70MT-1100T-10E.ckpt checkpoints/\n",
"trainer = L.Trainer(precision=\"bf16-mixed\", logger=False)\n",
"model: L.LightningModule = CLAYModule.load_from_checkpoint(\n",
" checkpoint_path=\"checkpoints/clay-small-70MT-1100T-10E.ckpt\"\n",
"trainer = L.Trainer(precision=\"bf16-mixed\", logger=False)\n",
"model: L.LightningModule = CLAYModule.load_from_checkpoint(\n",
" checkpoint_path=\"https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt\"\n",

weiji14 (Contributor, Author):

Might want to move this into the scripts/ folder, and will need to update the paths below accordingly.

Comment on lines +86 to +87
"#!mamba install triton\n",
"# model.model.encoder = torch.compile(model=model.model.encoder)"
weiji14 (Contributor, Author):

Can probably remove this torch.compile line. I was trying to speed up the model by compiling it, but some layers didn't work with compilation.

"import os\n",
"import warnings\n",
"\n",
"import duckdb\n",
weiji14 (Contributor, Author):

DuckDB is optional and can be removed, but it's nice to get a quick count of all the rows across the GeoParquet files (see last cell).
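
The quick row count mentioned here can be done with a single DuckDB query over a glob of the GeoParquet files. A minimal sketch, where the `embeddings/*.gpq` path pattern is an assumption to be pointed at the actual output location:

```python
# Sketch: count rows across all per-tile GeoParquet files with DuckDB.
# The "embeddings/*.gpq" glob is an assumption; adjust to wherever the
# generated files actually live (locally or on s3).
query = "SELECT COUNT(*) AS n_rows FROM read_parquet('embeddings/*.gpq')"

# With the duckdb package installed, execute it like so:
# import duckdb
# n_rows = duckdb.sql(query).fetchone()[0]
```

DuckDB's `read_parquet` accepts glob patterns directly, which is what makes the cross-file count a one-liner instead of a loop over files.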

Successfully merging this pull request may close these issues:

  • predict_step fails when input has more than one EPSG