Skip to content

Commit

Permalink
Merge pull request #11 from 1-w/inverse_transformation
Browse files Browse the repository at this point in the history
implement inverse transformation into input space
  • Loading branch information
ravnoor committed Aug 12, 2022
2 parents 28fcb18 + 1400d0b commit d03d8bc
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 22 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
data
__pycache__
12 changes: 8 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ USER user
ENV HOME=/home/user
RUN chmod 777 /home/user

RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py37_4.11.0-Linux-x86_64.sh \
&& /bin/bash Miniconda3-py37_4.11.0-Linux-x86_64.sh -b -p /home/user/conda \
&& rm -f Miniconda3-py37_4.11.0-Linux-x86_64.sh
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py37_4.12.0-Linux-x86_64.sh \
&& /bin/bash Miniconda3-py37_4.12.0-Linux-x86_64.sh -b -p /home/user/conda \
&& rm -f Miniconda3-py37_4.12.0-Linux-x86_64.sh

RUN conda update -n base -c defaults conda

RUN git clone --depth 1 https://github.com/NOEL-MNI/deepMask.git \
&& rm -rf deepMask/.git
Expand All @@ -48,12 +50,14 @@ RUN eval "$(conda shell.bash hook)" \
&& python -m pip install -r deepMask/app/requirements.txt \
&& conda deactivate

COPY app/ /app/
COPY app/requirements.txt /app/requirements.txt

RUN python -m pip install -r /app/requirements.txt \
&& conda install -c conda-forge pygpu==0.7.6 \
&& pip cache purge

COPY app/ /app/

RUN sudo chmod -R 777 /app && sudo chmod +x /app/inference.py

CMD ["python3"]
21 changes: 20 additions & 1 deletion app/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,32 @@
# --------------------------------------------------
# test_list = ['mcd_0468_1']
test_list = [args.id]
# t1_file = sys.argv[3]
# t2_file = sys.argv[4]
t1_file = args.t1
t2_file = args.t2

t1_transform = os.path.join(args.outdir, "transforms", args.id + "_t1-native-to-MNI152.mat")
t2_transform = os.path.join(args.outdir, "transforms", args.id + "_t2-native-to-MNI152.mat")

files = [args.t1, args.t2]

orig_files = {'T1':args.t1,'FLAIR':args.t2}

transform_files = [t1_transform, t2_transform]
# files = {}
# files['T1'], files['FLAIR'] = str(t1_file), t2_file
test_data = {}
# test_data = {f: {m: os.path.join(tfolder, f, m+'_stripped.nii.gz') for m in modalities} for f in test_list}
test_data = {f: {m: os.path.join(options['test_folder'], f, n) for m, n in zip(modalities, files)} for f in test_list}
test_tranforms = {f: {m: n for m, n in zip(modalities, transform_files)} for f in test_list}
# test_data = {f: {m: os.path.join(options['test_folder'], f, n) for m, n in zip(modalities, files)} for f in test_list}

for _, scan in enumerate(tqdm(test_list, desc='serving predictions using the trained model', colour='blue')):
t_data = {}
t_data[scan] = test_data[scan]
transforms = {}
transforms[scan] = test_tranforms[scan]

options['pred_folder'] = os.path.join(options['test_folder'], scan, options['experiment'])
if not os.path.exists(options['pred_folder']):
Expand All @@ -155,7 +174,7 @@
# test1: pred/stage2
# test2: morphological processing + contiguous clusters
# pred0, pred1, postproc, _, _ = test_model(model, t_data, options)
test_model(model, t_data, options)
test_model(model, t_data, options, transforms=transforms, orig_files=orig_files, invert_xfrm=True)

end = time.time()
diff = (end - start) // 60
Expand Down
3 changes: 3 additions & 0 deletions app/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

# set up parameters
args.outdir = os.path.join(args.dir, args.id)
args.tmpdir = os.path.join(args.outdir, "tmp")
if not os.path.exists(args.tmpdir):
os.makedirs(args.tmpdir)
args.t1 = os.path.join(args.outdir, args.t1_fname)
args.t2 = os.path.join(args.outdir, args.t2_fname)
args.seed = 666
Expand Down
1 change: 1 addition & 0 deletions app/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
antspyx==0.3.2
Theano==1.0.4
keras==2.2.4
h5py==2.10.0
Expand Down
71 changes: 57 additions & 14 deletions app/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sklearn.model_selection import LeaveOneGroupOut
from utils.patch_dataloader import *
from utils.post_processor import *
import ants


def print_data_shape(X):
Expand Down Expand Up @@ -227,7 +228,7 @@ def train_model(model, train_x_data, train_y_data, options):
return model


def test_model(model, test_x_data, options, performance=False, uncertainty=True):
def test_model(model, test_x_data, options, performance=False, uncertainty=True, transforms=None, orig_files=None, invert_xfrm=True):
threshold = options['th_dnn_train_2']
scan = options['test_scan'] + '_'
# organize experiments
Expand All @@ -236,24 +237,66 @@ def test_model(model, test_x_data, options, performance=False, uncertainty=True)
options['test_mean_name'] = scan + options['experiment'] + '_prob_mean_0.nii.gz'
options['test_var_name'] = scan + options['experiment'] + '_prob_var_0.nii.gz'

t1, header = test_scan(model[0], test_x_data, options, save_nifti=True, uncertainty=uncertainty, T=20)
if uncertainty:
pred_mean_0, pred_var_0, header = test_scan(model[0], test_x_data, options, save_nifti=True, uncertainty=uncertainty, T=20)
pred_var_0_img = nifti2ants(pred_var_0, affine=header.get_qform(), header=header)
else:
pred_mean_0, header = test_scan(model[0], test_x_data, options, save_nifti=True, uncertainty=uncertainty, T=20)

pred_mean_0_img = nifti2ants(pred_mean_0, affine=header.get_qform(), header=header)

if isinstance(transforms, dict):
apply_transforms(pred_mean_0_img, pred_var_0_img, transforms, orig_files, invert_xfrm, options, uncertainty)

# second network
options['test_name'] = scan + options['experiment'] + '_prob_1.nii.gz'
options['test_mean_name'] = scan + options['experiment'] + '_prob_mean_1.nii.gz'
options['test_var_name'] = scan + options['experiment'] + '_prob_var_1.nii.gz'
t2, header = test_scan(model[1], test_x_data, options, save_nifti=True, uncertainty=uncertainty, T=50, candidate_mask=t1>threshold)

if uncertainty:
pred_mean_1, pred_var_1, header = test_scan(model[1], test_x_data, options, save_nifti=True, uncertainty=uncertainty, T=50, candidate_mask=pred_mean_0>threshold)
pred_var_1_img = nifti2ants(pred_var_1, affine=header.get_qform(), header=header)
else:
pred_mean_1, header = test_scan(model[1], test_x_data, options, save_nifti=True, uncertainty=uncertainty, T=50, candidate_mask=pred_mean_0>threshold)

pred_mean_1_img = nifti2ants(pred_mean_1, affine=header.get_qform(), header=header)

if isinstance(transforms, dict):
apply_transforms(pred_mean_1_img, pred_var_1_img, transforms, orig_files, invert_xfrm, options, uncertainty)

if performance:
# postprocess the output segmentation
options['test_name'] = options['experiment'] + '_out_CNN.nii.gz'
out_segmentation, lpred, count = post_processing(t2, options, header, save_nifti=True)
outputs = [t1, t2, out_segmentation, lpred, count]
out_segmentation, lpred, count = post_processing(pred_mean_1, options, header, save_nifti=True)
outputs = [pred_mean_0, pred_mean_1, out_segmentation, lpred, count]
else:
outputs = [t1, t2]
outputs = [pred_mean_0, pred_mean_1]
return outputs


def nifti2ants(input_np, affine, header):
nifti = nib.Nifti1Image(input_np, affine=affine, header=header)
output_ants = ants.convert_nibabel.from_nibabel(nifti)
return output_ants


def apply_transforms(pred_mean_img, pred_var_img, transforms, orig_files, invert_xfrm, options, uncertainty):
print("writing data transformed to the appropriate sterotaxic space")
for m, t in transforms[options["test_scan"]].items():
xfrm = ants.read_transform(t)
if invert_xfrm:
xfrm = xfrm.invert()
if uncertainty:
pred_var_xfmd = ants.apply_ants_transform_to_image(transform=xfrm, image=pred_var_img, reference=ants.image_read(orig_files[m]), interpolation="nearestneighbor")
pred_var_xfmd.to_filename(os.path.join(options['pred_folder'], options['test_var_name'].replace(".nii.gz", "_native-"+m+".nii.gz")))
# pred_var_xfmd = ants.resample_image_to_target(image=pred_var_xfmd, target=ants.image_read(orig_files[m]), verbose=True, interp_type="nearestNeighbor")
# pred_var_xfmd.to_filename(os.path.join(options['pred_folder'], options['test_var_name'].replace(".nii.gz", "_native_rsl-"+m+".nii.gz")))
pred_mean_xfmd = ants.apply_ants_transform_to_image(transform=xfrm, image=pred_mean_img, reference=ants.image_read(orig_files[m]), interpolation="nearestneighbor")
pred_mean_xfmd.to_filename(os.path.join(options['pred_folder'], options['test_mean_name'].replace(".nii.gz", "_native-"+m+".nii.gz")))
# pred_mean_xfmd = ants.resample_image_to_target(image=pred_mean_xfmd, target=ants.image_read(orig_files[m]), verbose=True, interp_type="nearestNeighbor")
# pred_mean_xfmd.to_filename(os.path.join(options['pred_folder'], options['test_mean_name'].replace(".nii.gz", "_native_rsl-"+m+".nii.gz")))


def test_scan(model, test_x_data, options, transit=None, save_nifti=False, uncertainty=False, candidate_mask=None, T=20):
"""
Test data based on one model
Expand All @@ -273,7 +316,7 @@ def test_scan(model, test_x_data, options, transit=None, save_nifti=False, uncer
flair_scans = [test_x_data[s]['FLAIR'] for s in scans]
flair_image = load_nii(flair_scans[0]).get_data()
header = load_nii(flair_scans[0]).header
# affine = header.get_qform()
affine = header.get_qform()
seg_image = np.zeros_like(flair_image)
var_image = np.zeros_like(flair_image)

Expand All @@ -297,37 +340,37 @@ def test_scan(model, test_x_data, options, transit=None, save_nifti=False, uncer

if save_nifti:
# out_scan = nib.Nifti1Image(seg_image, np.eye(4))
out_scan = nib.Nifti1Image(seg_image, affine=None, header=header)
out_scan = nib.Nifti1Image(seg_image, affine=affine, header=header)
out_scan.to_filename(os.path.join(options['pred_folder'], options['test_mean_name']))

if uncertainty:
out_scan = nib.Nifti1Image(var_image, affine=None, header=header)
out_scan = nib.Nifti1Image(var_image, affine=affine, header=header)
out_scan.to_filename(os.path.join(options['pred_folder'], options['test_var_name']))

if transit is not None:
if not os.path.exists(test_folder):
os.mkdir(test_folder)
out_scan = nib.Nifti1Image(seg_image, affine=None, header=header)
out_scan = nib.Nifti1Image(seg_image, affine=affine, header=header)
test_name = str.replace(scan, '_flair.nii.gz', '') + '_out_pred_mean_0.nii.gz'
out_scan.to_filename(os.path.join(test_folder, test_name))

if uncertainty:
out_scan = nib.Nifti1Image(var_image, affine=None, header=header)
out_scan = nib.Nifti1Image(var_image, affine=affine, header=header)
test_name = str.replace(scan, '_flair.nii.gz', '') + '_out_pred_var_0.nii.gz'
out_scan.to_filename(os.path.join(test_folder, test_name))

if not os.path.exists(os.path.join(test_folder, options['experiment'])):
os.mkdir(os.path.join(test_folder, options['experiment']))

out_scan = nib.Nifti1Image(seg_image, affine=None, header=header)
out_scan = nib.Nifti1Image(seg_image, affine=affine, header=header)
test_name = str.replace(scan, '_flair.nii.gz', '') + '_out_pred_0.nii.gz'
out_scan.to_filename(os.path.join(test_folder, test_name))

return seg_image, header
return (seg_image, var_image, header) if uncertainty else (seg_image, header)


def copy_most_recent_model(path, net_model):
files = os.listdir(path)
paths = [os.path.join(path, basename) for basename in files if basename.endswith('.h5')]
latest_model = max(paths, key=os.path.getctime)
copyfile(latest_model, os.path.join(path, net_model)+'.h5')
copyfile(latest_model, os.path.join(path, net_model)+'.h5')
4 changes: 2 additions & 2 deletions app/utils/post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def post_processing(input_scan, options, header, save_nifti=True):

#save the output segmentation as nifti
if save_nifti:
nii_out = nib.Nifti1Image(output_scan, header=header)
nii_out = nib.Nifti1Image(output_scan, affine=header.get_qform(), header=header)
nii_out.to_filename(os.path.join(options['pred_folder'], options['test_name']))
labels_out = nib.Nifti1Image(labels_scan, header=header)
labels_out = nib.Nifti1Image(labels_scan, affine=header.get_qform(), header=header)
labels_out.to_filename(os.path.join(options['pred_folder'], options['test_morph_name']))
return output_scan, pred_labels, count

Expand Down

0 comments on commit d03d8bc

Please sign in to comment.