- arch: limit dilation and kernel size when image dimensions are too small + allow different kernel size in X and Y in names

- objectwise_computation_tf.py: use spatial softmax to locate center instead of weighted mean
jeanollion committed Apr 26, 2024
1 parent 0547ac9 commit 2576393
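
The second bullet touches objectwise_computation_tf.py, which is not among the files shown in the excerpt below. As a rough illustration of the idea: a spatial softmax turns a response map into a probability map and takes the expectation of pixel coordinates under it, so background pixels contribute almost nothing, whereas a plain weighted mean normalizes by the raw sum and lets background mass pull the estimated center. The function name, signature, and temperature argument below are illustrative assumptions, not the commit's actual code.

import tensorflow as tf

def spatial_softmax_center(heatmap, temperature=1.0):
    # Illustrative sketch, not the committed implementation.
    # heatmap: (H, W) response map; returns the (y, x) center estimate.
    h, w = heatmap.shape
    # Softmax over all pixels yields a probability map that sums to 1.
    weights = tf.nn.softmax(tf.reshape(heatmap, [-1]) / temperature)
    ys, xs = tf.meshgrid(tf.range(h, dtype=tf.float32),
                         tf.range(w, dtype=tf.float32), indexing="ij")
    # Center = expectation of the coordinate grids under that distribution.
    y = tf.reduce_sum(weights * tf.reshape(ys, [-1]))
    x = tf.reduce_sum(weights * tf.reshape(xs, [-1]))
    return y, x
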
Showing 4 changed files with 106 additions and 42 deletions.
79 changes: 59 additions & 20 deletions distnet_2d/model/architectures.py
@@ -1,12 +1,17 @@
 import math
 import copy

+from ..utils.helpers import ensure_multiplicity
 def get_architecture(architecture_type:str, **kwargs):
     kwargs = copy.deepcopy(kwargs)
     if architecture_type.lower()=="blend":
         arch = BlendD2 if kwargs.pop("n_downsampling", 2) == 2 else BlendD3
         return arch(**kwargs)
     else:
         raise ValueError(f"Unknown architecture type: {architecture_type}")

 class BlendD2():
-    def __init__(self, filters:int = 128, blending_filter_factor:float=0.5, batch_norm:bool = True, dropout:float=0.2, self_attention:int = 0, attention:int = 0, combine_kernel_size:int=1, pair_combine_kernel_size:int=5, skip_connections=[-1]):
+    def __init__(self, filters:int = 128, blending_filter_factor:float=0.5, batch_norm:bool = True, dropout:float=0.2, self_attention:int = 0, attention:int = 0, combine_kernel_size:int=1, pair_combine_kernel_size:int=5, skip_connections=[-1], spatial_dimensions=None):
         prefix = f"{'a' if attention else ''}{'sa' if self_attention else ''}"
         self.name = f"{prefix}blendD2-{filters}"
         self.skip_connections=skip_connections
@@ -18,23 +23,28 @@ def __init__(self, filters:int = 128, blending_filter_factor:float=0.5, batch_no
         self.blending_filter_factor=blending_filter_factor
         self.downsampling_mode="maxpool_and_stride"
         self.upsampling_mode ="tconv"
+        ker1, _ = get_kernels_and_dilation(3, 1, spatial_dimensions, 2)
+        ker1_2, _ = get_kernels_and_dilation(5, 1, spatial_dimensions, 2)
+        ker2, dil2 = get_kernels_and_dilation(5, 2, spatial_dimensions, 2 * 2)
+        ker2_2, dil2_2 = get_kernels_and_dilation(5, 3, spatial_dimensions, 2 * 2)
+        ker2_3, dil2_3 = get_kernels_and_dilation(5, 4, spatial_dimensions, 2 * 2)
         self.encoder_settings = [
             [
                 {"filters":32, "downscale":2, "weight_scaled":False, "dropout_rate":0}
             ],
             [
-                {"filters":32, "op":"conv", "weighted_sum":False, "weight_scaled":False, "dropout_rate":0, "batch_norm":False},
-                {"filters":32, "op":"conv", "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":0, "batch_norm":False},
-                {"filters":filters, "downscale":2, "weight_scaled":False, "dropout_rate":0, "batch_norm":False}
+                {"filters":32, "op":"conv", "kernel_size":ker1, "weighted_sum":False, "weight_scaled":False, "dropout_rate":0, "batch_norm":False},
+                {"filters":32, "op":"conv", "kernel_size":ker1_2, "weighted_sum":False, "weight_scaled":False, "dropout_rate":0, "batch_norm":False},
+                {"filters":filters, "kernel_size":ker1, "downscale":2, "weight_scaled":False, "dropout_rate":0, "batch_norm":False}
             ]
         ]
         self.feature_settings = [
-            {"op":"res2d", "dilation":2, "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
-            {"op":"res2d", "dilation":2 if self_attention>0 else 3, "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
-            {"filters":filters, "op":"selfattention" if self_attention>0 else "res2d", "kernel_size":5, "dilation":2 if self_attention>0 else 4, "dropout_rate":dropout, "num_attention_heads":self_attention},
-            {"op":"res2d", "dilation":2, "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
-            {"op":"res2d", "dilation":2 if self_attention>0 else 3, "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
-            {"filters":1., "op":"conv", "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":0, "batch_norm":batch_norm},
+            {"op":"res2d", "dilation":dil2, "kernel_size":ker2, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
+            {"op":"res2d", "dilation":dil2 if self_attention>0 else dil2_2, "kernel_size":ker2 if self_attention>0 else ker2_2, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
+            {"filters":filters, "op":"selfattention" if self_attention>0 else "res2d", "kernel_size":ker2 if self_attention>0 else ker2_3, "dilation":dil2 if self_attention>0 else dil2_3, "dropout_rate":dropout, "num_attention_heads":self_attention},
+            {"op":"res2d", "dilation":dil2, "kernel_size":ker2, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
+            {"op":"res2d", "dilation":dil2 if self_attention>0 else dil2_2, "kernel_size":ker2 if self_attention>0 else ker2_2, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
+            {"filters":1., "op":"conv", "kernel_size":ker2, "weighted_sum":False, "weight_scaled":False, "dropout_rate":0, "batch_norm":batch_norm},
         ]
         self.feature_blending_settings = [
             {"op":"res2d", "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False, "split_conv":False},
@@ -52,7 +62,7 @@ def __init__(self, filters:int = 128, blending_filter_factor:float=0.5, batch_no
         ]

 class BlendD3():
-    def __init__(self, filters:int = 192, blending_filter_factor:float=0.5, batch_norm:bool = True, dropout:float=0.2, self_attention:int = 0, attention:int = 0, combine_kernel_size:int=1, pair_combine_kernel_size:int=5, skip_connections=[-1]):
+    def __init__(self, filters:int = 192, blending_filter_factor:float=0.5, batch_norm:bool = True, dropout:float=0.2, self_attention:int = 0, attention:int = 0, combine_kernel_size:int=1, pair_combine_kernel_size:int=5, skip_connections=[-1], spatial_dimensions=None):
         prefix = f"{'a' if attention else ''}{'sa' if self_attention else ''}"
         self.name = f"{prefix}blendD3-{filters}"
         self.skip_connections=skip_connections
@@ -64,6 +74,10 @@ def __init__(self, filters:int = 192, blending_filter_factor:float=0.5, batch_no
         self.blending_filter_factor = blending_filter_factor
         self.downsampling_mode="maxpool_and_stride"
         self.upsampling_mode ="tconv"
+        ker2, _ = get_kernels_and_dilation(3, 1, spatial_dimensions, 2 * 2)
+        ker3, dil3 = get_kernels_and_dilation(5, 2, spatial_dimensions, 2 * 2 * 2)
+        ker3_2, dil3_2 = get_kernels_and_dilation(5, 3, spatial_dimensions, 2 * 2 * 2)
+        ker3_3, dil3_3 = get_kernels_and_dilation(5, 4, spatial_dimensions, 2 * 2 * 2)
         self.encoder_settings = [
             [
                 {"filters":32, "downscale":2, "dropout_rate":0}
@@ -73,18 +87,18 @@ def __init__(self, filters:int = 192, blending_filter_factor:float=0.5, batch_no
                 {"filters":64, "downscale":2, "dropout_rate":0}
             ],
             [
-                {"filters":64, "op":"res2d", "weighted_sum":False, "weight_scaled":False, "dropout_rate":0},
-                {"filters":64, "op":"res2d", "weighted_sum":False, "weight_scaled":False, "dropout_rate":0},
-                {"filters":filters, "downscale":2, "weight_scaled":False, "dropout_rate":0, "batch_norm":False}
+                {"filters":64, "op":"res2d", "kernel_size":ker2, "weighted_sum":False, "weight_scaled":False, "dropout_rate":0},
+                {"filters":64, "op":"res2d", "kernel_size":ker2, "weighted_sum":False, "weight_scaled":False, "dropout_rate":0},
+                {"filters":filters, "kernel_size":ker2, "downscale":2, "weight_scaled":False, "dropout_rate":0, "batch_norm":False}
             ]
         ]
         self.feature_settings = [
-            {"op":"res2d", "dilation":2, "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
-            {"op":"res2d", "dilation":2 if self_attention>0 else 3, "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
-            {"filters":filters, "op":"selfattention" if self_attention>0 else "res2d", "kernel_size":5, "dilation":2 if self_attention>0 else 4, "dropout_rate":dropout},
-            {"op":"res2d", "dilation":2, "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
-            {"op":"res2d", "dilation":2 if self_attention>0 else 3, "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
-            {"filters":1., "op":"conv", "kernel_size":5, "weighted_sum":False, "weight_scaled":False, "dropout_rate":0, "batch_norm":batch_norm},
+            {"op":"res2d", "dilation":dil3, "kernel_size":ker3, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
+            {"op":"res2d", "dilation":dil3 if self_attention>0 else dil3_2, "kernel_size":ker3 if self_attention>0 else ker3_2, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
+            {"filters":filters, "op":"selfattention" if self_attention>0 else "res2d", "kernel_size":ker3 if self_attention>0 else ker3_3, "dilation":dil3 if self_attention>0 else dil3_3, "dropout_rate":dropout},
+            {"op":"res2d", "dilation":dil3, "kernel_size":ker3, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
+            {"op":"res2d", "dilation":dil3 if self_attention>0 else dil3_2, "kernel_size":ker3 if self_attention>0 else ker3_2, "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
+            {"filters":1., "op":"conv", "kernel_size":ker3, "weighted_sum":False, "weight_scaled":False, "dropout_rate":0, "batch_norm":batch_norm},
         ]
         self.feature_blending_settings = [
             {"op":"res2d", "weighted_sum":False, "weight_scaled":False, "dropout_rate":dropout, "batch_norm":False},
@@ -101,3 +115,28 @@ def __init__(self, filters:int = 192, blending_filter_factor:float=0.5, batch_no
             {"filters":32, "op":"res2d", "weighted_sum":False, "n_conv":2, "up_kernel_size":4, "weight_scaled_up":False, "weight_scaled":False, "batch_norm":False, "dropout_rate":0},
             {"filters":64, "op":"res2d", "weighted_sum":False, "n_conv":2, "up_kernel_size":4, "weight_scaled_up":False, "weight_scaled":False, "batch_norm":False, "dropout_rate":0}
         ]
+
+def get_kernels_and_dilation(target_kernel, target_dilation, spa_dimensions, downsampling):
+    if spa_dimensions is None:
+        return target_kernel, target_dilation
+    spa_dimensions = ensure_multiplicity(2, spa_dimensions)
+    kernel = ensure_multiplicity(2, target_kernel)
+    dilation = ensure_multiplicity(2, target_dilation)
+    spa_dimensions = [d/downsampling if d is not None and d>0 else None for d in spa_dimensions]
+    for i in range(len(spa_dimensions)):
+        while not test_ker_dil(kernel[i], dilation[i], spa_dimensions[i]):
+            if dilation[i] > 1:
+                dilation[i] -= 1
+            elif kernel[i] > 1:
+                kernel[i] = 1 + 2 * ((kernel[i] - 1) // 2 - 1)
+            else:
+                raise ValueError(f"Cannot find a kernel size that suits dimension: {spa_dimensions[i]}")
+    kernel = kernel[0] if kernel[0] == kernel[1] else kernel
+    dilation = dilation[0] if dilation[0] == dilation[1] else dilation
+    return kernel, dilation
+
+def test_ker_dil(ker, dil, dim):
+    if ker == 1 and dil == 1 or dim is None or dim <= 0:
+        return True
+    size = (ker-1)*dil
+    return dim >= size * 2
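
A few worked calls to the new helper, as a sanity check (assuming distnet_2d is importable; spatial_dimensions is given per axis, and test_ker_dil accepts a kernel/dilation pair only when the dilated extent (ker-1)*dil is at most half the downsampled dimension). Dilation is reduced to 1 before the kernel is stepped down in increments of 2 (5 -> 3 -> 1), independently per axis, which is how kernel size and dilation can end up different in Y and X:

from distnet_2d.model.architectures import get_kernels_and_dilation

# Dimensions comfortably large: targets pass through unchanged.
print(get_kernels_and_dilation(5, 4, (256, 256), 4))  # (5, 4)

# 32 pixels at 4x downsampling leaves 8 on the first axis: dilation is
# cut to 1 there, while the second axis keeps the target value.
print(get_kernels_and_dilation(5, 4, (32, 256), 4))   # (5, [1, 4])

# With dilation already at 1, the kernel itself shrinks on the small axis.
print(get_kernels_and_dilation(5, 1, (8, 256), 2))    # ([3, 5], 1)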
