diff --git a/README.md b/README.md
index fd9cb351..1c6220bc 100755
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@ Zoobot is trained using millions of answers by Galaxy Zoo volunteers. This code
 - [Pretrained Weights](https://zoobot.readthedocs.io/en/latest/pretrained_models.html)
 - [Datasets](https://www.github.com/mwalmsley/galaxy-datasets)
 - [Documentation](https://zoobot.readthedocs.io/) (for understanding/reference)
+- [Mailing List](https://groups.google.com/g/zoobot) (for updates)
 
 ## Installation
 
@@ -59,44 +60,43 @@ The [Colab notebook](https://colab.research.google.com/drive/1A_-M3Sz5maQmyfW2A7
 
 Let's say you want to find ringed galaxies and you have a small labelled dataset of 500 ringed or not-ringed galaxies. You can retrain Zoobot to find rings like so:
 
- ```python
+```python
+import pandas as pd
+from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
+from zoobot.pytorch.training import finetune
 
- import pandas as pd
- from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
- from zoobot.pytorch.training import finetune
+# csv with 'ring' column (0 or 1) and 'file_loc' column (path to image)
+labelled_df = pd.read_csv('/your/path/some_labelled_galaxies.csv')
 
- # csv with 'ring' column (0 or 1) and 'file_loc' column (path to image)
- labelled_df = pd.read_csv('/your/path/some_labelled_galaxies.csv')
+datamodule = GalaxyDataModule(
+ label_cols=['ring'],
+ catalog=labelled_df,
+ batch_size=32
+)
 
- datamodule = GalaxyDataModule(
- label_cols=['ring'],
- catalog=labelled_df,
- batch_size=32
- )
+# load trained Zoobot model
+model = finetune.FinetuneableZoobotClassifier(checkpoint_loc, num_classes=2)
 
- # load trained Zoobot model
- model = finetune.FinetuneableZoobotClassifier(checkpoint_loc, num_classes=2)
-
- # retrain to find rings
- trainer = finetune.get_trainer(save_dir)
- trainer.fit(model, datamodule)
- ```
+# retrain to find rings
+trainer = finetune.get_trainer(save_dir)
+trainer.fit(model, datamodule)
+```
 
 Then you can make predict if new galaxies have rings:
 
- ```python
- from zoobot.pytorch.predictions import predict_on_catalog
+```python
+from zoobot.pytorch.predictions import predict_on_catalog
 
- # csv with 'file_loc' column (path to image). Zoobot will predict the labels.
- unlabelled_df = pd.read_csv('/your/path/some_unlabelled_galaxies.csv')
+# csv with 'file_loc' column (path to image). Zoobot will predict the labels.
+unlabelled_df = pd.read_csv('/your/path/some_unlabelled_galaxies.csv')
 
- predict_on_catalog.predict(
- unlabelled_df,
- model,
- label_cols=['ring'], # only used for
- save_loc='/your/path/finetuned_predictions.csv'
- )
- ```
+predict_on_catalog.predict(
+ unlabelled_df,
+ model,
+ label_cols=['ring'], # only used for
+ save_loc='/your/path/finetuned_predictions.csv'
+)
+```
 
 Zoobot includes many guides and working examples - see the [Getting Started](#getting-started) section below.
 