From dcb89d477f1f685d9945018b091f9ce0eee9c5b2 Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Wed, 9 Jun 2021 21:32:34 +0100 Subject: [PATCH] Update API to catch up with previous fiddling. Add weights and biases support Update requirements to allow tf 2.5 --- .gitignore | 4 +++- create_shards.py | 2 -- requirements.txt | 4 ++-- results/.gitkeep | 0 train_model.py | 28 ++++++++++++++++++++++--- zoobot/estimators/define_model.py | 4 +++- zoobot/estimators/preprocess.py | 3 +-- zoobot/predictions/predict_on_images.py | 2 +- 8 files changed, 35 insertions(+), 12 deletions(-) create mode 100644 results/.gitkeep diff --git a/.gitignore b/.gitignore index 52360e2e..3f42968c 100644 --- a/.gitignore +++ b/.gitignore @@ -128,4 +128,6 @@ dmypy.json # Pyre type checker .pyre/ -*.csv \ No newline at end of file +*.csv + +wandb_api.txt \ No newline at end of file diff --git a/create_shards.py b/create_shards.py index 37623074..291527b3 100644 --- a/create_shards.py +++ b/create_shards.py @@ -17,8 +17,6 @@ import pandas as pd from tqdm import tqdm -# from shared_astro_utils import object_utils - from zoobot import label_metadata from zoobot.data_utils import catalog_to_tfrecord, checks from zoobot.estimators import preprocess diff --git a/requirements.txt b/requirements.txt index ad3c530c..2f9610d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,8 +18,8 @@ matplotlib python-dateutil==2.8.1 # for boto3 boto3 statsmodels -requests>=2.4.2 # for panoptes-client -panoptes-client +# requests>=2.4.2 # for panoptes-client +# panoptes-client # also requires personal package shared_astro_utils, pip install -e ../shared-astro-utilities keras_applications # for efficientnet tensorflow_probability >= 0.11 # 0.11 for tf 2.3, 0.14 for tf 2.4 diff --git a/results/.gitkeep b/results/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/train_model.py b/train_model.py index 07748dcf..96f98b7c 100644 --- a/train_model.py +++ b/train_model.py @@ -5,6 +5,7 @@ import 
logging import tensorflow as tf +import wandb from zoobot.data_utils import tfrecord_datasets from zoobot.training import training_config, losses @@ -52,14 +53,18 @@ parser = argparse.ArgumentParser() parser.add_argument('--experiment-dir', dest='save_dir', type=str) - parser.add_argument('--shard-img-size', dest='shard_img_size', type=int, default=256) - parser.add_argument('--resize-size', dest='resize_size', type=int, default=64) + parser.add_argument('--shard-img-size', dest='shard_img_size', type=int, default=300) + parser.add_argument('--resize-size', dest='resize_size', type=int, default=224) parser.add_argument('--train-dir', dest='train_records_dir', type=str) parser.add_argument('--eval-dir', dest='eval_records_dir', type=str) parser.add_argument('--epochs', dest='epochs', type=int) parser.add_argument('--batch-size', dest='batch_size', default=64, type=int) + parser.add_argument('--wandb', default=False, action='store_true') args = parser.parse_args() + greyscale = True + # greyscale = False + initial_size = args.shard_img_size resize_size = args.resize_size # step time prop. 
to resolution batch_size = args.batch_size @@ -87,7 +92,8 @@ preprocess_config = preprocess.PreprocessingConfig( label_cols=schema.label_cols, input_size=initial_size, - greyscale=True + make_greyscale=greyscale, + normalise_from_uint8=False ) train_dataset = preprocess.preprocess_dataset(raw_train_dataset, preprocess_config) test_dataset = preprocess.preprocess_dataset(raw_test_dataset, preprocess_config) @@ -114,6 +120,22 @@ patience=10 ) + if args.wandb: + this_script_dir = os.path.dirname(__file__) + # you need to make this file yourself, with your api key and nothing else + with open(os.path.join(this_script_dir, 'wandb_api.txt'), 'r') as f: + api_key = f.readline() + wandb.login(key=api_key) + wandb.init(sync_tensorboard=True) + config = wandb.config + config.label_cols=schema.label_cols, + config.initial_size=initial_size + config.greyscale = greyscale + config.resize_size = resize_size + config.batch_size = batch_size + config.train_records = train_records + config.epochs = epochs + # inplace on model training_config.train_estimator( model, diff --git a/zoobot/estimators/define_model.py b/zoobot/estimators/define_model.py index 8271a511..3b0f2050 100644 --- a/zoobot/estimators/define_model.py +++ b/zoobot/estimators/define_model.py @@ -62,7 +62,9 @@ def add_augmentation_layers(model, crop_size, resize_size, always_augment=False) flip_layer = tf.keras.layers.experimental.preprocessing.RandomFlip crop_layer = tf.keras.layers.experimental.preprocessing.RandomCrop - model.add(rotation_layer(np.pi, fill_mode='reflect')) + + # np.pi fails with tf 2.5 + model.add(rotation_layer(0.5, fill_mode='reflect')) # rotation range +/- 0.5 * 2pi i.e. +/- 180 degrees. 
model.add(flip_layer()) model.add(crop_layer(crop_size, crop_size)) if resize: diff --git a/zoobot/estimators/preprocess.py b/zoobot/estimators/preprocess.py index 1920d3fb..e6378923 100644 --- a/zoobot/estimators/preprocess.py +++ b/zoobot/estimators/preprocess.py @@ -88,8 +88,7 @@ def preprocess_batch(batch, config): batch_images = get_images_from_batch( batch, size=config.input_size, - channels=config.input_channels, - summary=True) + channels=config.input_channels) if config.normalise_from_uint8: batch_images = batch_images / 255. diff --git a/zoobot/predictions/predict_on_images.py b/zoobot/predictions/predict_on_images.py index 6852f16c..dd53690d 100644 --- a/zoobot/predictions/predict_on_images.py +++ b/zoobot/predictions/predict_on_images.py @@ -45,7 +45,7 @@ def predict(image_ds: tf.data.Dataset, model: tf.keras.Model, n_samples: int, la data = [prediction_to_row(predictions[n], image_paths[n], label_cols=label_cols) for n in range(len(predictions))] predictions_df = pd.DataFrame(data) - logging.info(predictions_df) + # logging.info(predictions_df) predictions_df.to_csv(save_loc, index=False) logging.info(f'Predictions saved to {save_loc}')