Cardiac anatomy segmentation is crucial for assessing cardiac morphology and function, aiding diagnosis and intervention. Deep learning (DL) improves accuracy over traditional methods, and recent studies show that adding motion information can enhance segmentation further. However, current methods either increase input dimensionality, making them computationally expensive, or use suboptimal techniques like non-DL registration, non-attention networks, or single-headed attention.
We propose a novel, computation-efficient approach using a scalable Temporal Attention Module (TAM) for motion enhancement and improved performance. TAM features a multi-headed, KQV projection cross-attention architecture and can be easily integrated into existing CNN, Transformer, or Hybrid segmentation networks, offering flexibility for future implementations.
Novel Temporal Attention Mechanism for Segmentation:
- We present a new Temporal Attention Module (TAM), a multi-headed cross-time attention mechanism based on KQV projection that enables the network to effectively capture dynamic changes across temporal frames for motion-enhanced cardiac anatomy segmentation.
Flexible Integration into a Range of Segmentation Networks:
- TAM can be integrated in a plug-and-play fashion into a variety of established backbone segmentation architectures, including UNet, FCN8s, UNetR, SwinUNetR, I²UNet, DT-VNet, and others, equipping them with motion awareness. This provides a simple and elegant way to add motion awareness to future networks.
Consistent Performance Across Multiple Settings:
- Generalizable across different image types and qualities, from 2D to 3D cardiac datasets.
- Highly adaptable, improving segmentation performance across various backbone architectures.
- Computationally efficient, adding minimal overhead and outperforming methods that increase input dimensionality.
Extensive Evaluation on Diverse Cardiac Datasets:
Our results confirm that TAM enhances motion-aware segmentation (see the video below) while maintaining computational efficiency, making it a promising addition to future deep learning-based cardiac segmentation methods (full details are provided in the paper).
This section provides an overview of how to load, preprocess, and structure cardiac imaging datasets (NIfTI format) for training and validating our motion-aware segmentation networks.
- Loss Function: combination of Dice loss and cross-entropy loss, in line with recent literature (a sketch follows this list).
- Epochs: 300
- Batch Size: 8 for the CAMUS dataset and 4 for the MITEA and ACDC datasets.
- Image Size: 256×256 for the CAMUS 2D dataset, 128×128×128 for the MITEA 3D dataset, and 160×160×16 for the ACDC 3D dataset, following the nnFormer settings.
- Optimizer and LR: Adam with a learning rate of 1e-4 for the CAMUS and MITEA datasets; SGD with a polynomial learning-rate schedule for the ACDC dataset, following the nnFormer settings.
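For concreteness, below is a minimal sketch of the combined Dice + cross-entropy objective and the optimizer setup listed above. The equal weighting of the two terms, the smoothing constant, and the SGD hyperparameters are assumptions, not values taken from the paper.

```python
# Minimal sketch of the training objective (Dice + cross-entropy), assuming
# equal weighting of the two terms; the paper may use a different combination.
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceCELoss(nn.Module):
    def __init__(self, smooth: float = 1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, target):
        # logits: (B, C, ...) raw scores; target: (B, ...) integer class labels
        ce = F.cross_entropy(logits, target)
        probs = torch.softmax(logits, dim=1)
        one_hot = F.one_hot(target, num_classes=logits.shape[1]).movedim(-1, 1).float()
        dims = tuple(range(2, logits.ndim))            # spatial dimensions
        intersection = (probs * one_hot).sum(dims)
        union = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return ce + (1 - dice.mean())

# Optimizers as listed above (the SGD values are assumptions):
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)                # CAMUS, MITEA
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.99)  # ACDC + poly LR schedule
```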
Cardiac image sequences typically include multiple frames:
- End-Diastolic (ED) Frame
- End-Systolic (ES) Frame
- Mid-Systolic Frames (optional, intermediate frames between ED and ES)
For training our TAM network, at least two frames (ED and ES) are required. However, incorporating a mid-systolic frame enhances performance by helping the network bridge the large motion between ED and ES.
Our dataset preparation pipeline ensures:
- Efficient loading of NIfTI images
- Rescaling and normalization to a consistent resolution
- Preservation of segmentation labels during resizing
- Multi-frame integration for temporal attention
function load_nifti(filepath, target_shape, is_mask=False):
# Load NIfTI image or mask
# Resize: Cubic interpolation for images, Nearest-neighbor for masks
# Normalize image intensities (if not a mask)
# Convert to tensor (float for images, long for masks)
return tensor
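A runnable sketch of load_nifti(), assuming volumes are read with nibabel, resized with scipy.ndimage.zoom, and min-max normalized; the exact preprocessing used in the paper may differ.

```python
# Hedged sketch of load_nifti(); target_shape must have the same number of
# dimensions as the stored volume (e.g. (256, 256) for 2D, (128, 128, 128) for 3D).
import nibabel as nib
import torch
from scipy.ndimage import zoom

def load_nifti(filepath, target_shape, is_mask=False):
    volume = nib.load(filepath).get_fdata()
    # Cubic interpolation for images, nearest-neighbor for masks (preserves labels)
    factors = [t / s for t, s in zip(target_shape, volume.shape)]
    volume = zoom(volume, factors, order=0 if is_mask else 3)
    if is_mask:
        return torch.from_numpy(volume).long()
    # Min-max normalize image intensities to [0, 1]
    volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-8)
    return torch.from_numpy(volume).float()
```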
function load_image_mask_pair(base_path, frame_type):
# Load image and mask for given frame type (ED, ES, Mid)
# Use load_nifti() for consistent processing
return {'image': image, 'mask': mask}
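A sketch of load_image_mask_pair(). The file-naming convention below (image at `<base_path>_<frame>.nii.gz`, mask at `<base_path>_<frame>_gt.nii.gz`) and the default 2D target shape are hypothetical; adapt them to the actual dataset layout.

```python
# Hypothetical file layout: <base_path>_ED.nii.gz / <base_path>_ED_gt.nii.gz, etc.
def load_image_mask_pair(base_path, frame_type, target_shape=(256, 256)):
    image = load_nifti(f"{base_path}_{frame_type}.nii.gz", target_shape)
    mask = load_nifti(f"{base_path}_{frame_type}_gt.nii.gz", target_shape, is_mask=True)
    return {"image": image, "mask": mask}
```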
class TAM_Dataset(Dataset):
function __init__(self, image_paths, mask_paths, num_mid_frames=None, transform=None):
# Initialize dataset paths, frame count, and transformations
function __getitem__(self, idx):
# Extract base paths for images & masks
# Load ED & ES frames
# Load Mid frames if available
# Apply transformations (if any) or convert to tensor
return {'ED': ed_data, 'ES': es_data, 'Mid': mid_data (if available)}
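One possible TAM_Dataset implementation following the pseudocode. How mask paths are paired with image paths and how mid-systolic frames are numbered are assumptions; here masks are resolved from the same base path, so mask_paths is kept only to mirror the pseudocode's signature.

```python
from torch.utils.data import Dataset

class TAM_Dataset(Dataset):
    def __init__(self, image_paths, mask_paths, num_mid_frames=0, transform=None):
        self.image_paths = image_paths      # one base path per case
        self.mask_paths = mask_paths        # unused in this sketch (see note above)
        self.num_mid_frames = num_mid_frames
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        base = self.image_paths[idx]
        sample = {
            "ED": load_image_mask_pair(base, "ED"),
            "ES": load_image_mask_pair(base, "ES"),
        }
        if self.num_mid_frames:
            # Hypothetical naming: Mid1, Mid2, ... for intermediate frames
            sample["Mid"] = [load_image_mask_pair(base, f"Mid{i + 1}")
                             for i in range(self.num_mid_frames)]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
```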
This module performs multi-frame cross-attention (attention across time frames) to enhance temporal feature learning. It integrates multi-head attention, gating, and convolutional refinement for motion-aware feature aggregation.
Class MultiHeadAttention:
Initialize(num_channels, embedding_dim, num_heads):
- Define Query, Key, and Value projection layers (Conv3D)
- Initialize Multi-Head Attention
- Define gating mechanism (Conv3D + Sigmoid)
- Define feature fusion layer (Conv3D + BatchNorm + ReLU)
- Define final classifier (Conv3D)
Forward(frame_sequence):
Initialize output_list
For each reference_frame in frame_sequence:
Initialize combined_output = 0
For each comparison_frame in frame_sequence:
If reference_frame == comparison_frame:
- Continue (skip self-attention)
# Project frame features into Query, Key, Value
query = ProjectQuery(comparison_frame)
key = ProjectKey(reference_frame)
value = ProjectValue(reference_frame)
# Compute attention-weighted features using scaled dot-product attention
attention_output = ComputeScaledDotProductAttention(query, key, value)
# Apply gating mechanism to the attention output
attention_mask = ApplyGatingMechanism(attention_output)
attention_output = attention_output * attention_mask
# Concatenate attended output with original frame
combined_features = Concatenate(attention_output, reference_frame)
combined_features = ApplyFeatureFusion(combined_features)
combined_features = ApplyBatchNorm(combined_features)
combined_features = ApplyReLU(combined_features)
# Accumulate attention results
combined_output += combined_features
# Average attention results across all frames and classify
avg_output = combined_output / (total_frames - 1)
final_output = ApplyClassifier(avg_output)
Add final_output to output_list
Return output_list
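The attention pseudocode above (class MultiHeadAttention) can be realized in PyTorch roughly as follows. This is a hedged sketch rather than the authors' released code: flattening voxels into tokens for nn.MultiheadAttention, the 1×1×1 projection convolutions, and keeping the output channel count equal to the input are implementation assumptions.

```python
# Hedged 3D sketch of the TAM block (a 2D variant would swap Conv3d/BatchNorm3d
# for their 2D counterparts). Shapes: each frame is (B, C, D, H, W).
import torch
import torch.nn as nn

class TemporalAttentionModule(nn.Module):
    def __init__(self, num_channels, embedding_dim, num_heads):
        super().__init__()
        self.to_q = nn.Conv3d(num_channels, embedding_dim, kernel_size=1)
        self.to_k = nn.Conv3d(num_channels, embedding_dim, kernel_size=1)
        self.to_v = nn.Conv3d(num_channels, embedding_dim, kernel_size=1)
        self.attn = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)
        self.gate = nn.Sequential(nn.Conv3d(embedding_dim, embedding_dim, 1), nn.Sigmoid())
        self.fuse = nn.Sequential(
            nn.Conv3d(embedding_dim + num_channels, num_channels, 3, padding=1),
            nn.BatchNorm3d(num_channels),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Conv3d(num_channels, num_channels, kernel_size=1)

    def forward(self, frame_sequence):
        outputs = []
        for i, ref in enumerate(frame_sequence):
            combined = 0
            for j, comp in enumerate(frame_sequence):
                if i == j:
                    continue                       # skip self-attention pairs
                B, _, D, H, W = ref.shape
                # Project to Q/K/V and flatten voxels into token sequences (B, DHW, E)
                q = self.to_q(comp).flatten(2).transpose(1, 2)
                k = self.to_k(ref).flatten(2).transpose(1, 2)
                v = self.to_v(ref).flatten(2).transpose(1, 2)
                attn_out, _ = self.attn(q, k, v)   # scaled dot-product attention
                attn_out = attn_out.transpose(1, 2).reshape(B, -1, D, H, W)
                attn_out = attn_out * self.gate(attn_out)              # gating
                combined = combined + self.fuse(torch.cat([attn_out, ref], dim=1))
            combined = combined / (len(frame_sequence) - 1)            # average over pairs
            outputs.append(self.classifier(combined))
        return outputs
```

Following the pseudocode, the query is taken from the comparison frame while the key and value come from the reference frame, and the attended features are fused back with the reference frame before averaging across all frame pairs.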
class ConvBlock(nn.Module):
function __init__(self, in_channels, out_channels):
# Initialize two 3D convolution layers followed by Batch Normalization and ReLU activation
# conv1: Conv3D + BatchNorm + ReLU
# conv2: Conv3D + BatchNorm + ReLU
function forward(self, x):
# Apply conv1, batch normalization, and ReLU activation
# Apply conv2, batch normalization, and ReLU activation
return processed_output
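A minimal ConvBlock sketch matching the description above; kernel size 3 with padding 1 is assumed.

```python
import torch.nn as nn

class ConvBlock(nn.Module):
    """Two Conv3d + BatchNorm3d + ReLU stages."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)
```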
class EncoderBlock(nn.Module):
function __init__(self, in_channels, out_channels):
# Initialize ConvBlock followed by MaxPooling (2x2x2)
function forward(self, x):
# Pass input through ConvBlock
# Apply MaxPooling
return conv_output, pooled_output
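An EncoderBlock sketch reusing ConvBlock from the previous snippet; it returns the pre-pooling features (for skip connections) together with the 2×2×2 max-pooled features.

```python
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)   # from the sketch above
        self.pool = nn.MaxPool3d(kernel_size=2)

    def forward(self, x):
        skip = self.conv(x)            # kept for the decoder's skip connection
        return skip, self.pool(skip)
```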
class DecoderBlock(nn.Module):
function __init__(self, in_channels, out_channels):
# Initialize ConvTranspose3D for upsampling followed by ConvBlock
function forward(self, x, skip_connection):
# Upsample input using ConvTranspose3D
# Concatenate upsampled input with the skip connection
# Pass through ConvBlock
return decoded_output
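A DecoderBlock sketch; the kernel size and stride of 2 for the transposed convolution are assumptions, and the channel count after concatenation assumes the skip connection carries out_channels channels.

```python
import torch

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(out_channels * 2, out_channels)

    def forward(self, x, skip_connection):
        x = self.up(x)                                # upsample by a factor of 2
        x = torch.cat([x, skip_connection], dim=1)    # fuse with encoder skip features
        return self.conv(x)
```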
class Encoder(nn.Module):
function __init__(self, input_channels, feature_depths):
# Initialize EncoderBlocks for multiple stages and Bottleneck layer
# Initialize attention mechanisms at bottleneck and last encoder stage
function forward(self, *frames):
# For each frame:
# Process through EncoderBlock stages
# Collect and store outputs at each stage (s1, s2, s3, s4) and pooled output (p4)
# Stack outputs from all frames for attention
# Apply attention (TAM) on pooled outputs and bottleneck outputs
return tuple of all outputs
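A hedged sketch of the multi-frame Encoder. The four-stage depth, feature widths, TAM embedding size and head count, and the return structure (per-frame skip tuples plus TAM-refined bottleneck features) are assumptions chosen to keep the example compact; it reuses EncoderBlock, ConvBlock, and TemporalAttentionModule from the sketches above.

```python
class Encoder(nn.Module):
    def __init__(self, input_channels=1, feature_depths=(32, 64, 128, 256)):
        super().__init__()
        f1, f2, f3, f4 = feature_depths
        self.e1 = EncoderBlock(input_channels, f1)
        self.e2 = EncoderBlock(f1, f2)
        self.e3 = EncoderBlock(f2, f3)
        self.e4 = EncoderBlock(f3, f4)
        self.bottleneck = ConvBlock(f4, f4 * 2)
        # Temporal attention at the last encoder stage and at the bottleneck
        self.tam_p4 = TemporalAttentionModule(f4, embedding_dim=f4, num_heads=4)
        self.tam_bn = TemporalAttentionModule(f4 * 2, embedding_dim=f4 * 2, num_heads=4)

    def forward(self, *frames):
        skips, pooled = [], []
        for x in frames:
            s1, p1 = self.e1(x)
            s2, p2 = self.e2(p1)
            s3, p3 = self.e3(p2)
            s4, p4 = self.e4(p3)
            skips.append((s1, s2, s3, s4))
            pooled.append(p4)
        pooled = self.tam_p4(pooled)                       # cross-frame attention
        bottlenecks = [self.bottleneck(p) for p in pooled]
        bottlenecks = self.tam_bn(bottlenecks)             # cross-frame attention again
        return skips, bottlenecks
```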
class UNet(nn.Module):
function __init__(self, num_classes, feature_depths):
# Initialize Encoder
# Initialize DecoderBlocks for each stage of U-Net
# Initialize final Conv3D layer for classification and Softmax activation
function forward(self, *inputs):
# Get outputs from Encoder
# For each frame:
# Unpack encoder outputs for skip connections and bottleneck
# Pass through DecoderBlocks, using skip connections for each frame
# Apply final classification (Conv3D + Softmax) to get segmentation mask
return tuple of masks
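Putting the pieces together, a sketch of the full motion-aware TAM-UNet: the shared multi-frame encoder above followed by a per-frame decoding path with skip connections and a final Conv3d + Softmax head. The default class count and feature widths are assumptions. Softmax is applied here to mirror the pseudocode; when training with the Dice + cross-entropy loss sketched earlier, one would typically return raw logits instead.

```python
class UNet(nn.Module):
    def __init__(self, num_classes=4, feature_depths=(32, 64, 128, 256)):
        super().__init__()
        f1, f2, f3, f4 = feature_depths
        self.encoder = Encoder(input_channels=1, feature_depths=feature_depths)
        self.d4 = DecoderBlock(f4 * 2, f4)
        self.d3 = DecoderBlock(f4, f3)
        self.d2 = DecoderBlock(f3, f2)
        self.d1 = DecoderBlock(f2, f1)
        self.head = nn.Conv3d(f1, num_classes, kernel_size=1)

    def forward(self, *frames):
        skips, bottlenecks = self.encoder(*frames)
        masks = []
        for (s1, s2, s3, s4), b in zip(skips, bottlenecks):
            x = self.d4(b, s4)
            x = self.d3(x, s3)
            x = self.d2(x, s2)
            x = self.d1(x, s1)
            masks.append(torch.softmax(self.head(x), dim=1))
        return tuple(masks)

# Usage sketch: ED and ES (and optionally Mid) volumes of shape (B, 1, D, H, W),
# with D, H, W divisible by 16 so the four pooling stages line up.
# model = UNet(num_classes=4)
# ed_mask, es_mask = model(ed_volume, es_volume)
```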
Results on the CAMUS dataset: class-wise HD (mm) per structure, plus DSC, HD (mm), MASD (mm), and PIA (%) averaged over the anatomical structures (↑: higher is better, ↓: lower is better).

Methods | LVMYO HD (↓) | LVENDO HD (↓) | LVEPI HD (↓) | LA HD (↓) | Avg DSC (↑) | Avg HD (↓) | Avg MASD (↓) | Avg PIA (%) (↓) |
---|---|---|---|---|---|---|---|---|
UNet | 5.65 | 4.21 | 5.67 | 4.91 | 0.913 | 5.11 | 1.13 | 2.05 |
TAM-UNet | 4.05 | 3.07 | 3.89 | 3.52 | 0.922 | 3.63 | 0.96 | 0.68 |
FCN8s | 6.80 | 5.44 | 6.01 | 7.26 | 0.899 | 6.38 | 1.33 | 0.58 |
TAM-FCN8s | 3.60 | 3.04 | 3.33 | 3.27 | 0.921 | 3.31 | 0.98 | 0.02 |
UNetR | 8.03 | 5.59 | 7.71 | 8.35 | 0.897 | 7.42 | 1.43 | 2.43 |
TAM-UNetR | 6.08 | 4.62 | 5.86 | 6.05 | 0.904 | 5.65 | 1.24 | 0.92 |
SwinUNetR | 8.33 | 5.60 | 8.24 | 6.41 | 0.888 | 7.15 | 1.52 | 2.67 |
TAM-SwinUNetR | 5.63 | 4.25 | 5.32 | 4.11 | 0.913 | 4.83 | 1.15 | 1.32 |
Results on the MITEA 3D dataset: class-wise HD (mm) per structure, plus DSC, HD (mm), MASD (mm), and PIA (%) averaged over the anatomical structures (no LA label is reported for this dataset).

Methods | LVMYO HD (↓) | LVENDO HD (↓) | LVEPI HD (↓) | LA HD (↓) | Avg DSC (↑) | Avg HD (↓) | Avg MASD (↓) | Avg PIA (%) (↓) |
---|---|---|---|---|---|---|---|---|
UNet | 14.58 | 9.89 | 12.31 | - | 0.830 | 12.26 | 2.03 | 0.30 |
TAM-UNet | 11.47 | 9.14 | 10.84 | - | 0.833 | 10.48 | 1.97 | 0.16 |
FCN8s | 12.07 | 11.95 | 10.95 | - | 0.828 | 11.66 | 2.06 | 1.07 |
TAM-FCN8s | 9.27 | 7.59 | 8.24 | - | 0.836 | 8.37 | 1.93 | 0.22 |
UNetR | 13.39 | 11.85 | 12.74 | - | 0.806 | 12.66 | 2.34 | 0.53 |
TAM-UNetR | 10.70 | 9.56 | 9.96 | - | 0.814 | 10.07 | 2.21 | 0.38 |
SwinUNetR | 10.95 | 10.10 | 10.25 | - | 0.818 | 10.43 | 2.27 | 0.36 |
TAM-SwinUNetR | 9.67 | 8.67 | 9.01 | - | 0.823 | 9.12 | 2.12 | 0.23 |
Comparison with state-of-the-art methods on the CAMUS dataset (DSC ↑; HD and MASD in mm, ↓). The mark after each method indicates whether it uses motion information (✓) or not (✗).

Methods (motion?) | LVMYO DSC (↑) | LVMYO HD (↓) | LVMYO MASD (↓) | LVENDO DSC (↑) | LVENDO HD (↓) | LVENDO MASD (↓) | LVEPI DSC (↑) | LVEPI HD (↓) | LVEPI MASD (↓) | LA DSC (↑) | LA HD (↓) | LA MASD (↓) |
---|---|---|---|---|---|---|---|---|---|---|---|---|
UNet (✗) | 0.864 | 5.65 | 1.10 | 0.927 | 4.21 | 1.06 | 0.954 | 5.67 | 1.15 | 0.904 | 4.91 | 1.21 |
SwinUNetR (✗) | 0.834 | 8.33 | 1.41 | 0.908 | 5.60 | 1.42 | 0.939 | 8.24 | 1.56 | 0.869 | 6.41 | 1.68 |
ACNN (✗) | - | - | - | 0.918 | 5.90 | 1.80 | 0.946 | 6.35 | 1.95 | - | - | - |
BEASNet (✗) | - | - | - | 0.915 | 6.0 | 1.95 | 0.943 | 6.35 | 2.15 | - | - | - |
UB2DNet (✗) | - | - | - | 0.858 | - | - | - | - | - | - | - | - |
FFPN-R (✗) | 0.850 | 3.65 | - | 0.924 | 3.05 | - | - | - | - | 0.888 | 3.80 | - |
PLANet (✗) | - | - | - | 0.944 | 4.14 | 1.26 | 0.957 | 5.0 | 1.72 | - | - | - |
CoSTUNet (✗) | - | - | - | 0.916 | 6.55 | - | 0.837 | 7.65 | - | 0.875 | 6.70 | - |
I2UNet (✗) | 0.873 | 4.72 | 1.03 | 0.933 | 3.49 | 1.02 | 0.956 | 4.39 | 1.09 | 0.910 | 4.25 | 1.19 |
Our TAM-I2UNet (✓) | 0.872 | 4.19 | 1.03 | 0.933 | 3.02 | 0.972 | 0.956 | 3.92 | 1.06 | 0.913 | 3.74 | 1.11 |
Our TAM-FCN8s (✓) | 0.876 | 3.60 | 0.973 | 0.935 | 3.04 | 0.949 | 0.959 | 3.33 | 0.961 | 0.916 | 3.27 | 1.06 |
SOCOF (✓) | - | - | - | 0.932 | 3.21 | 1.40 | 0.953 | 4.0 | 1.65 | - | - | - |
CLAS (✓) | - | - | - | 0.935 | 4.60 | 1.40 | 0.958 | 4.85 | 1.55 | 0.915 | - | - |