Skip to content

Commit

Permalink
Update API to catch up with previous fiddling.
Browse files Browse the repository at this point in the history
Add weights and biases support
Update requirements to allow tf 2.5
  • Loading branch information
Mike Walmsley committed Jun 9, 2021
1 parent 5241439 commit dcb89d4
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 12 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Expand Up @@ -128,4 +128,6 @@ dmypy.json
# Pyre type checker
.pyre/

*.csv
*.csv

wandb_api.txt
2 changes: 0 additions & 2 deletions create_shards.py
Expand Up @@ -17,8 +17,6 @@
import pandas as pd
from tqdm import tqdm

# from shared_astro_utils import object_utils

from zoobot import label_metadata
from zoobot.data_utils import catalog_to_tfrecord, checks
from zoobot.estimators import preprocess
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Expand Up @@ -18,8 +18,8 @@ matplotlib
python-dateutil==2.8.1 # for boto3
boto3
statsmodels
requests>=2.4.2 # for panoptes-client
panoptes-client
# requests>=2.4.2 # for panoptes-client
# panoptes-client
# also requires personal package shared_astro_utils, pip install -e ../shared-astro-utilities
keras_applications # for efficientnet
tensorflow_probability >= 0.11 # 0.11 for tf 2.3, 0.14 for tf 2.4
Empty file added results/.gitkeep
Empty file.
28 changes: 25 additions & 3 deletions train_model.py
Expand Up @@ -5,6 +5,7 @@
import logging

import tensorflow as tf
import wandb

from zoobot.data_utils import tfrecord_datasets
from zoobot.training import training_config, losses
Expand Down Expand Up @@ -52,14 +53,18 @@

parser = argparse.ArgumentParser()
parser.add_argument('--experiment-dir', dest='save_dir', type=str)
parser.add_argument('--shard-img-size', dest='shard_img_size', type=int, default=256)
parser.add_argument('--resize-size', dest='resize_size', type=int, default=64)
parser.add_argument('--shard-img-size', dest='shard_img_size', type=int, default=300)
parser.add_argument('--resize-size', dest='resize_size', type=int, default=224)
parser.add_argument('--train-dir', dest='train_records_dir', type=str)
parser.add_argument('--eval-dir', dest='eval_records_dir', type=str)
parser.add_argument('--epochs', dest='epochs', type=int)
parser.add_argument('--batch-size', dest='batch_size', default=64, type=int)
parser.add_argument('--wandb', default=False, action='store_true')
args = parser.parse_args()

greyscale = True
# greyscale = False

initial_size = args.shard_img_size
resize_size = args.resize_size # step time prop. to resolution
batch_size = args.batch_size
Expand Down Expand Up @@ -87,7 +92,8 @@
preprocess_config = preprocess.PreprocessingConfig(
label_cols=schema.label_cols,
input_size=initial_size,
greyscale=True
make_greyscale=greyscale,
normalise_from_uint8=False
)
train_dataset = preprocess.preprocess_dataset(raw_train_dataset, preprocess_config)
test_dataset = preprocess.preprocess_dataset(raw_test_dataset, preprocess_config)
Expand All @@ -114,6 +120,22 @@
patience=10
)

if args.wandb:
this_script_dir = os.path.dirname(__file__)
# you need to make this file yourself, with your api key and nothing else
with open(os.path.join(this_script_dir, 'wandb_api.txt'), 'r') as f:
api_key = f.readline()
wandb.login(key=api_key)
wandb.init(sync_tensorboard=True)
config = wandb.config
config.label_cols=schema.label_cols,
config.initial_size=initial_size
config.greyscale = greyscale
config.resize_size = resize_size
config.batch_size = batch_size
config.train_records = train_records
config.epochs = epochs

# inplace on model
training_config.train_estimator(
model,
Expand Down
4 changes: 3 additions & 1 deletion zoobot/estimators/define_model.py
Expand Up @@ -62,7 +62,9 @@ def add_augmentation_layers(model, crop_size, resize_size, always_augment=False)
flip_layer = tf.keras.layers.experimental.preprocessing.RandomFlip
crop_layer = tf.keras.layers.experimental.preprocessing.RandomCrop

model.add(rotation_layer(np.pi, fill_mode='reflect'))

# np.pi fails with tf 2.5
model.add(rotation_layer(0.5, fill_mode='reflect')) # rotation range +/- 0.25 * 2pi i.e. +/- 90*.
model.add(flip_layer())
model.add(crop_layer(crop_size, crop_size))
if resize:
Expand Down
3 changes: 1 addition & 2 deletions zoobot/estimators/preprocess.py
Expand Up @@ -88,8 +88,7 @@ def preprocess_batch(batch, config):
batch_images = get_images_from_batch(
batch,
size=config.input_size,
channels=config.input_channels,
summary=True)
channels=config.input_channels)

if config.normalise_from_uint8:
batch_images = batch_images / 255.
Expand Down
2 changes: 1 addition & 1 deletion zoobot/predictions/predict_on_images.py
Expand Up @@ -45,7 +45,7 @@ def predict(image_ds: tf.data.Dataset, model: tf.keras.Model, n_samples: int, la

data = [prediction_to_row(predictions[n], image_paths[n], label_cols=label_cols) for n in range(len(predictions))]
predictions_df = pd.DataFrame(data)
logging.info(predictions_df)
# logging.info(predictions_df)

predictions_df.to_csv(save_loc, index=False)
logging.info(f'Predictions saved to {save_loc}')
Expand Down

0 comments on commit dcb89d4

Please sign in to comment.