Skip to content

Tensorflow implementation of U-Net model with TPU Estimator support.

License

Notifications You must be signed in to change notification settings

tlatkowski/u-net-tpu

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

68 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

[Under development]

TPU-optimized U-Net model implementation

Model description

This repository contains U-Net model implementation (originally proposed in U-Net: Convolutional Networks for Biomedical Image Segmentation) with usage of Tensoflow TPUEstimators API.

Model architecture

Training U-Net on Cloud TPU

Config your TPU flock:

Check your current Cloud TPU configuration (in Google Cloud Shell):

ctpu print-config

As a result you should get the similar outcome

ctpu configuration:
        name: your-user-name
        project: your-project-name
        zone: your-zone

To run CTPU flock, in Google Cloud Shell, execute the following command

ZONE=YOUR-ZONE
TPU_NAME=YOUR-TPU-NAME

ctpu up --zone $ZONE --name $TPU_NAME

Once the CTPU flock is ready, ssh to your instance

GCP_PROJECT=YOUR-GCP-PROJECT

gcloud compute --project $GCP_PROJECT ssh --zone $ZONE $TPU_NAME
Create bucket for storing training checkpoints:

Execute the following commands to create a bucket on your TPU VM instance:

GCP_PROJECT=YOUR-GCP-PROJECT
MODEL_DIR=gs://YOUR-GOOGLE-CLOUD-BUCKET

gsutil mb -p $GCP_PROJECT $MODEL_DIR
Training U-Net model as a classifier on Cloud TPU

In order to train your MNIST image classifier based on U-Net architecture and with the usage of Cloud TPU use u_net_clf_tpu.py script with the following parameters:

# TPU configuration
GCP_PROJECT=YOUR-GCP-PROJECT
TPU_ZONE=YOUR-TPU-ZONE
TPU_NAME=YOUR-TPU-NAME
MODEL_DIR=gs://YOUR-GOOGLE-CLOUD-BUCKET

# problem/model configuration
PROBLEM=MNIST
TRAIN_DIR=PATH-TO-TRAIN-DIR
EVAL_DIR=PATH-TO-EVALUATION-DIR

python u_net_clf_tpu.py \
--train_dir=$TRAIN_DIR \
--eval_dir=$EVAL_DIR \
--problem=$PROBLEM \
--model_dir=$MODEL_DIR \
--tpu_name=$TPU_NAME \
--tpu_zone=$TPU_ZONE \
--gcp_project=$GCP_PROJECT
Vizualize model training on Tensorboard
gcloud auth application-default login
tensorboard --logdir $MODEL_DIR