Motion-enhanced Cardiac Anatomy Segmentation via an Insertable Temporal Attention Module

⚠ Implementation details and updated code will be released soon.

Cardiac anatomy segmentation is crucial for assessing cardiac morphology and function, aiding diagnosis and intervention. Deep learning (DL) improves accuracy over traditional methods, and recent studies show that adding motion information can enhance segmentation further. However, current methods either increase input dimensionality, making them computationally expensive, or use suboptimal techniques like non-DL registration, non-attention networks, or single-headed attention.

We propose a novel, computationally efficient approach that uses a scalable Temporal Attention Module (TAM) for motion enhancement and improved performance. TAM features a multi-headed cross-attention architecture with KQV projections and can be easily integrated into existing CNN, Transformer, or hybrid segmentation networks, offering flexibility for future implementations.

Key Contributions:

  • Novel Temporal Attention Mechanism for Segmentation:

    • We present a new Temporal Attention Module (TAM), a multi-headed cross-temporal attention mechanism based on KQV projection, that enables the network to effectively capture dynamic changes across temporal frames for motion-enhanced cardiac anatomy segmentation.
  • Flexible Integration into a Range of Segmentation Networks:

    • TAM can be integrated in a plug-and-play fashion into a variety of established backbone segmentation architectures, including UNet, FCN8s, UNetR, SwinUNetR, I²UNet, DT-VNet, and others, arming them with motion awareness. This provides a simple and elegant way to add motion awareness to future networks.
  • Consistent Performance Across Multiple Settings:

    • Generalizable across different image types and qualities, from 2D to 3D cardiac datasets.
    • Highly adaptable, improving segmentation performance across various backbone architectures.
    • Computationally efficient, adding minimal overhead and outperforming methods that increase input dimensionality.
  • Extensive evaluation on diverse cardiac datasets:

    • 2D echocardiography (CAMUS)
    • 3D echocardiography (MITEA)
    • 3D cardiac MRI (ACDC)

Our results confirm that TAM enhances motion-aware segmentation (see the demonstration video below) while maintaining computational efficiency, making it a promising addition to future deep learning-based cardiac segmentation methods (details will be provided in the paper).

*(Demonstration video.)*

📌 Implementation

This section provides an overview of how to load, preprocess, and structure cardiac imaging datasets (NIfTI format) for training and validating our motion-aware segmentation networks.

🔹 Training Hyperparameters:

✅ Loss Function: A combination of Dice and cross-entropy losses, in line with recent literature.

✅ Epochs: 300

✅ Batch Size: 8 for the CAMUS dataset and 4 for the MITEA and ACDC datasets.

✅ Image Size: 256×256 for the 2D CAMUS dataset, 128×128×128 for the 3D MITEA dataset, and 160×160×16 for the 3D ACDC dataset, following the nnFormer settings.

✅ Optimizer and LR: Adam optimizer with a learning rate of 1e-4 for the CAMUS and MITEA datasets; SGD with a polynomial learning-rate schedule for the ACDC dataset, following the nnFormer settings.
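
For reference, the snippet below is a minimal sketch of the Dice + cross-entropy objective and the optimizer settings listed above. The equal weighting of the two loss terms and the SGD momentum/weight-decay values are assumptions for illustration, not confirmed settings.

```python
# Minimal sketch of the combined Dice + cross-entropy objective described above.
# The equal weighting of the two terms is an assumption.
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceCELoss(nn.Module):
    """Soft multi-class Dice loss plus cross-entropy."""
    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, target):
        # logits: (B, C, ...) raw scores; target: (B, ...) integer class labels
        target = target.long()
        num_classes = logits.shape[1]
        probs = torch.softmax(logits, dim=1)
        one_hot = F.one_hot(target, num_classes).movedim(-1, 1).float()
        dims = tuple(range(2, logits.ndim))                 # spatial dimensions
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)
        return (1.0 - dice.mean()) + F.cross_entropy(logits, target)

# Optimizer choices quoted above ("model" is any TAM-equipped network):
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)        # CAMUS / MITEA
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-2,         # ACDC (nnFormer-style);
#                             momentum=0.99, weight_decay=3e-5)    # base LR/momentum assumed
# scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=300, power=0.9)
```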

🛠️ Dataset Processing Pipeline

Cardiac image sequences typically include multiple frames:

  • End-Diastolic (ED) Frame
  • End-Systolic (ES) Frame
  • Mid-Systolic Frames (optional, intermediate frames between ED and ES)

For training our TAM network, at least two frames (ED and ES) are required. However, incorporating a mid-systolic frame enhances performance by helping the network bridge the large motion between ED and ES.

Our dataset preparation pipeline ensures:

✅ Efficient loading of NIfTI images
✅ Rescaling & normalization to a consistent resolution
✅ Preserving segmentation labels during resizing
✅ Multi-frame integration for temporal attention

1️⃣ Load & Preprocess NIfTI Image/Mask

function load_nifti(filepath, target_shape, is_mask=False):
    # Load NIfTI image or mask
    # Resize: Cubic interpolation for images, Nearest-neighbor for masks
    # Normalize image intensities (if not a mask)
    # Convert to tensor (float for images, long for masks)
    return tensor

function load_image_mask_pair(base_path, frame_type):
    # Load image and mask for given frame type (ED, ES, Mid)
    # Use load_nifti() for consistent processing
    return {'image': image, 'mask': mask}
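
A runnable version of this loader could look as follows, assuming nibabel is used to read the NIfTI files and that image/mask filenames follow a simple `<base>_<frame>.nii.gz` convention (the naming is illustrative only; adapt it to the actual dataset layout):

```python
# Runnable sketch of the loader pseudocode above (file naming is an assumption).
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F

def load_nifti(filepath, target_shape, is_mask=False):
    """Load a NIfTI image or mask, resize it to target_shape, and return a tensor."""
    data = nib.load(filepath).get_fdata().astype(np.float32)
    x = torch.from_numpy(data)[None, None]                 # (1, 1, H, W) or (1, 1, D, H, W)
    if is_mask:
        # Nearest-neighbour resizing preserves the integer segmentation labels.
        x = F.interpolate(x, size=target_shape, mode='nearest')
        return x[0, 0].long()                              # label map
    # Cubic interpolation for 2D images; trilinear for 3D (PyTorch has no 3D cubic mode).
    mode = 'bicubic' if len(target_shape) == 2 else 'trilinear'
    x = F.interpolate(x, size=target_shape, mode=mode, align_corners=False)
    x = (x - x.min()) / (x.max() - x.min() + 1e-8)         # min-max intensity normalization
    return x[0].float()                                    # (1, ...) image tensor

def load_image_mask_pair(base_path, frame_type, target_shape=(256, 256)):
    """Load the image/mask pair for one frame type ('ED', 'ES', or 'Mid')."""
    image = load_nifti(f"{base_path}_{frame_type}.nii.gz", target_shape, is_mask=False)
    mask = load_nifti(f"{base_path}_{frame_type}_gt.nii.gz", target_shape, is_mask=True)
    return {'image': image, 'mask': mask}
```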

2️⃣ DataLoader Class

class TAM_Dataset(Dataset):
    function __init__(self, image_paths, mask_paths, num_mid_frames=None, transform=None):
        # Initialize dataset paths, frame count, and transformations

    function __getitem__(self, idx):
        # Extract base paths for images & masks
        # Load ED & ES frames
        # Load Mid frames if available
        # Apply transformations (if any) or convert to tensor
        return {'ED': ed_data, 'ES': es_data, 'Mid': mid_data (if available)}
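
A matching PyTorch `Dataset` is sketched below, reusing `load_image_mask_pair` from the previous snippet. For brevity, the pseudocode's separate image/mask path lists are collapsed into one base path per sequence, which is an assumption about the file layout.

```python
# Sketch of the TAM_Dataset pseudocode (per-sequence base paths are an assumption).
from torch.utils.data import Dataset, DataLoader

class TAM_Dataset(Dataset):
    def __init__(self, base_paths, target_shape=(256, 256), num_mid_frames=0, transform=None):
        self.base_paths = base_paths            # one base path per patient/sequence
        self.target_shape = target_shape
        self.num_mid_frames = num_mid_frames    # optional mid-systolic frames
        self.transform = transform

    def __len__(self):
        return len(self.base_paths)

    def __getitem__(self, idx):
        base = self.base_paths[idx]
        sample = {
            'ED': load_image_mask_pair(base, 'ED', self.target_shape),
            'ES': load_image_mask_pair(base, 'ES', self.target_shape),
        }
        # Optional intermediate frames that bridge the large ED-to-ES motion.
        if self.num_mid_frames:
            sample['Mid'] = [load_image_mask_pair(base, f'Mid{i + 1}', self.target_shape)
                             for i in range(self.num_mid_frames)]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

# Example: loader = DataLoader(TAM_Dataset(train_paths), batch_size=8, shuffle=True)
```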

📌 Temporal Attention Module (TAM)

This module performs multi-head attention across temporal frames to enhance temporal feature learning. It integrates multi-head attention, a gating mechanism, and convolutional refinement for motion-aware feature aggregation.


🛠️ Pseudocode for TAM

Class MultiHeadAttention:
    Initialize(num_channels, embedding_dim, num_heads):
        - Define Query, Key, and Value projection layers (Conv3D)
        - Initialize Multi-Head Attention
        - Define gating mechanism (Conv3D + Sigmoid)
        - Define feature fusion layer (Conv3D + BatchNorm + ReLU)
        - Define final classifier (Conv3D)

    Forward(frame_sequence):
        Initialize output_list

        For each reference_frame in frame_sequence:
            Initialize combined_output = 0

            For each comparison_frame in frame_sequence:
                If reference_frame == comparison_frame:
                    - Continue (skip self-attention)

                # Project frame features into Query, Key, Value
                query = ProjectQuery(comparison_frame)
                key = ProjectKey(reference_frame)
                value = ProjectValue(reference_frame)

                # Compute attention-weighted features using scaled dot-product attention
                attention_output = ComputeScaledDotProductAttention(query, key, value)

                # Apply gating mechanism to the attention output
                attention_mask = ApplyGatingMechanism(attention_output)
                attention_output = attention_output * attention_mask

                # Concatenate attended output with original frame
                combined_features = Concatenate(attention_output, reference_frame)
                combined_features = ApplyFeatureFusion(combined_features)
                combined_features = ApplyBatchNorm(combined_features)
                combined_features = ApplyReLU(combined_features)

                # Accumulate attention results
                combined_output += combined_features

            # Average attention results across all frames and classify
            avg_output = combined_output / (total_frames - 1)
            final_output = ApplyClassifier(avg_output)
            Add final_output to output_list

        Return output_list
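
A runnable PyTorch sketch of the pseudocode above follows. It assumes 3D feature maps of shape (B, C, D, H, W), 1×1×1 convolutional K/Q/V projections, and `nn.MultiheadAttention` over flattened spatial tokens; the embedding size, head count, and final projection width are illustrative choices rather than the released configuration.

```python
# Sketch of the Temporal Attention Module (TAM); dimensions are illustrative.
import torch
import torch.nn as nn

class TemporalAttentionModule(nn.Module):
    def __init__(self, num_channels, embedding_dim=64, num_heads=4):
        super().__init__()
        # Query / Key / Value projection layers (Conv3D), as in the pseudocode.
        self.to_q = nn.Conv3d(num_channels, embedding_dim, kernel_size=1)
        self.to_k = nn.Conv3d(num_channels, embedding_dim, kernel_size=1)
        self.to_v = nn.Conv3d(num_channels, embedding_dim, kernel_size=1)
        self.attn = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)
        # Gating mechanism: Conv3D + Sigmoid.
        self.gate = nn.Sequential(nn.Conv3d(embedding_dim, embedding_dim, 1), nn.Sigmoid())
        # Feature fusion over [attended, reference]: Conv3D + BatchNorm + ReLU.
        self.fuse = nn.Sequential(
            nn.Conv3d(embedding_dim + num_channels, num_channels, 3, padding=1),
            nn.BatchNorm3d(num_channels),
            nn.ReLU(inplace=True),
        )
        # Final Conv3D projection (the "classifier" in the pseudocode).
        self.classifier = nn.Conv3d(num_channels, num_channels, kernel_size=1)

    @staticmethod
    def _tokens(x):
        # (B, E, D, H, W) -> (B, D*H*W, E) token sequence for multi-head attention.
        return x.flatten(2).transpose(1, 2)

    def forward(self, frames):
        # frames: list of per-frame feature maps, each of shape (B, C, D, H, W).
        outputs = []
        for i, ref in enumerate(frames):
            b, _, d, h, w = ref.shape
            combined = 0
            for j, cmp in enumerate(frames):
                if i == j:
                    continue                                   # skip self-attention
                q = self._tokens(self.to_q(cmp))               # query from comparison frame
                k = self._tokens(self.to_k(ref))               # key/value from reference frame
                v = self._tokens(self.to_v(ref))
                attended, _ = self.attn(q, k, v)               # scaled dot-product attention
                attended = attended.transpose(1, 2).reshape(b, -1, d, h, w)
                attended = attended * self.gate(attended)      # gated attention output
                fused = self.fuse(torch.cat([attended, ref], dim=1))
                combined = combined + fused                    # accumulate over frame pairs
            avg = combined / (len(frames) - 1)                 # average across other frames
            outputs.append(self.classifier(avg))
        return outputs
```

In a TAM-equipped backbone, `frames` would be the per-frame feature maps produced by a shared encoder (e.g., at the bottleneck), and the returned list replaces those features before decoding.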

🛠️ Pseudocode for TAM-UNet

class ConvBlock(nn.Module):
    function __init__(self, in_channels, out_channels):
        # Initialize two 3D convolution layers followed by Batch Normalization and ReLU activation
        # conv1: Conv3D + BatchNorm + ReLU
        # conv2: Conv3D + BatchNorm + ReLU

    function forward(self, x):
        # Apply conv1, batch normalization, and ReLU activation
        # Apply conv2, batch normalization, and ReLU activation
        return processed_output


class EncoderBlock(nn.Module):
    function __init__(self, in_channels, out_channels):
        # Initialize ConvBlock followed by MaxPooling (2x2x2)

    function forward(self, x):
        # Pass input through ConvBlock
        # Apply MaxPooling
        return conv_output, pooled_output


class DecoderBlock(nn.Module):
    function __init__(self, in_channels, out_channels):
        # Initialize ConvTranspose3D for upsampling followed by ConvBlock

    function forward(self, x, skip_connection):
        # Upsample input using ConvTranspose3D
        # Concatenate upsampled input with the skip connection
        # Pass through ConvBlock
        return decoded_output


class Encoder(nn.Module):
    function __init__(self, input_channels, feature_depths):
        # Initialize EncoderBlocks for multiple stages and Bottleneck layer
        # Initialize attention mechanisms at bottleneck and last encoder stage

    function forward(self, *frames):
        # For each frame:
            # Process through EncoderBlock stages
            # Collect and store outputs at each stage (s1, s2, s3, s4) and pooled output (p4)
        # Stack outputs from all frames for attention
        # Apply attention (TAM) on pooled outputs and bottleneck outputs
        return tuple of all outputs


class UNet(nn.Module):
    function __init__(self, num_classes, feature_depths):
        # Initialize Encoder
        # Initialize DecoderBlocks for each stage of U-Net
        # Initialize final Conv3D layer for classification and Softmax activation

    function forward(self, *inputs):
        # Get outputs from Encoder
        # For each frame:
            # Unpack encoder outputs for skip connections and bottleneck
            # Pass through DecoderBlocks, using skip connections for each frame
        # Apply final classification (Conv3D + Softmax) to get segmentation mask
        return tuple of masks
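
The sketch below illustrates only the cross-frame wiring of the encoder–decoder pipeline above: each frame passes through shared encoder weights, the TAM is applied across frames at the last encoder stage and the bottleneck, and each frame is then decoded with its own skip connections. The ConvBlock/EncoderBlock/DecoderBlock instances and the two TemporalAttentionModule instances are assumed to be constructed as in the pseudocode and supplied by the caller.

```python
# Illustrative wiring of TAM into a U-Net-style backbone (building blocks supplied by caller).
import torch.nn as nn

class TAM_UNetWiring(nn.Module):
    def __init__(self, encoder_stages, bottleneck, decoder_stages, head,
                 tam_stage4, tam_bottleneck):
        super().__init__()
        self.encoder_stages = nn.ModuleList(encoder_stages)   # EncoderBlocks (shared across frames)
        self.bottleneck = bottleneck                           # ConvBlock
        self.decoder_stages = nn.ModuleList(decoder_stages)    # DecoderBlocks
        self.head = head                                       # final Conv3D (+ Softmax)
        self.tam_stage4 = tam_stage4                           # TAM at the last encoder stage
        self.tam_bottleneck = tam_bottleneck                   # TAM at the bottleneck

    def forward(self, *frames):
        skips, stage4_feats, bottlenecks = [], [], []
        for x in frames:                                       # same weights for every frame
            stage_outputs = []
            for enc in self.encoder_stages:
                s, x = enc(x)                                  # (conv_output, pooled_output)
                stage_outputs.append(s)
            skips.append(stage_outputs[:-1])                   # s1 .. s3 skip connections
            stage4_feats.append(stage_outputs[-1])             # s4, fed to TAM
            bottlenecks.append(self.bottleneck(x))
        # Cross-frame temporal attention at the deepest stages.
        stage4_feats = self.tam_stage4(stage4_feats)
        bottlenecks = self.tam_bottleneck(bottlenecks)
        masks = []
        for skip, s4, bott in zip(skips, stage4_feats, bottlenecks):
            d = bott
            for dec, s in zip(self.decoder_stages, [s4] + skip[::-1]):
                d = dec(d, s)                                  # upsample + concat skip
            masks.append(self.head(d))                         # per-frame segmentation
        return tuple(masks)
```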

📌 Result Synopsis

Results of integrating our novel TAM with CNN- and Transformer-based segmentation models on the public CAMUS dataset. The LVMYO, LVENDO, LVEPI, and LA columns report class-wise HD (mm, lower is better); the remaining columns report averages over all anatomical structures. Improvements introduced by the TAM are highlighted in bold. The paper describes the PIA metric, which measures anatomical plausibility: the percentage of the total segmentation area accounted for by "island areas," defined as any segmentation mass that is not the largest and that is disconnected from the largest mass (a computational sketch of PIA is given after the result tables).
| Methods | LVMYO | LVENDO | LVEPI | LA | DSC ↑ | HD (mm) ↓ | MASD (mm) ↓ | PIA (%) ↓ |
|---|---|---|---|---|---|---|---|---|
| UNet | 5.65 | 4.21 | 5.67 | 4.91 | 0.913 | 5.11 | 1.13 | 2.05 |
| **TAM-UNet** | **4.05** | **3.07** | **3.89** | **3.52** | **0.922** | **3.63** | **0.96** | **0.68** |
| FCN8s | 6.80 | 5.44 | 6.01 | 7.26 | 0.899 | 6.38 | 1.33 | 0.58 |
| **TAM-FCN8s** | **3.60** | **3.04** | **3.33** | **3.27** | **0.921** | **3.31** | **0.98** | **0.02** |
| UNetR | 8.03 | 5.59 | 7.71 | 8.35 | 0.897 | 7.42 | 1.43 | 2.43 |
| **TAM-UNetR** | **6.08** | **4.62** | **5.86** | **6.05** | **0.904** | **5.65** | **1.24** | **0.92** |
| SwinUNetR | 8.33 | 5.60 | 8.24 | 6.41 | 0.888 | 7.15 | 1.52 | 2.67 |
| **TAM-SwinUNetR** | **5.63** | **4.25** | **5.32** | **4.11** | **0.913** | **4.83** | **1.15** | **1.32** |
Results of integrating our novel 3D TAM with 3D CNN- and Transformer-based segmentation models on the public 3D echocardiography dataset (MITEA). Columns follow the same layout as the table above. Improvements introduced by the TAM are highlighted in bold.
| Methods | LVMYO | LVENDO | LVEPI | LA | DSC ↑ | HD (mm) ↓ | MASD (mm) ↓ | PIA (%) ↓ |
|---|---|---|---|---|---|---|---|---|
| UNet | 14.58 | 9.89 | 12.31 | - | 0.830 | 12.26 | 2.03 | 0.30 |
| **TAM-UNet** | **11.47** | **9.14** | **10.84** | - | **0.833** | **10.48** | **1.97** | **0.16** |
| FCN8s | 12.07 | 11.95 | 10.95 | - | 0.828 | 11.66 | 2.06 | 1.07 |
| **TAM-FCN8s** | **9.27** | **7.59** | **8.24** | - | **0.836** | **8.37** | **1.93** | **0.22** |
| UNetR | 13.39 | 11.85 | 12.74 | - | 0.806 | 12.66 | 2.34 | 0.53 |
| **TAM-UNetR** | **10.70** | **9.56** | **9.96** | - | **0.814** | **10.07** | **2.21** | **0.38** |
| SwinUNetR | 10.95 | 10.10 | 10.25 | - | 0.818 | 10.43 | 2.27 | 0.36 |
| **TAM-SwinUNetR** | **9.67** | **8.67** | **9.01** | - | **0.823** | **9.12** | **2.12** | **0.23** |
Comparison of segmentation performance on the CAMUS dataset across state-of-the-art methods and our proposed motion-aware TAM-based segmentation models. Best-performing metrics are highlighted in bold.
| Methods (motion?) | LVMYO DSC ↑ | LVMYO HD ↓ | LVMYO MASD ↓ | LVENDO DSC ↑ | LVENDO HD ↓ | LVENDO MASD ↓ | LVEPI DSC ↑ | LVEPI HD ↓ | LVEPI MASD ↓ | LA DSC ↑ | LA HD ↓ | LA MASD ↓ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| UNet (✘) | 0.864 | 5.65 | 1.10 | 0.927 | 4.21 | 1.06 | 0.954 | 5.67 | 1.15 | 0.904 | 4.91 | 1.21 |
| SwinUNetR (✘) | 0.834 | 8.33 | 1.41 | 0.908 | 5.60 | 1.42 | 0.939 | 8.24 | 1.56 | 0.869 | 6.41 | 1.68 |
| ACNN (✘) | - | - | - | 0.918 | 5.90 | 1.80 | 0.946 | 6.35 | 1.95 | - | - | - |
| BEASNet (✘) | - | - | - | 0.915 | 6.0 | 1.95 | 0.943 | 6.35 | 2.15 | - | - | - |
| UB2DNet (✘) | - | - | - | 0.858 | - | - | - | - | - | - | - | - |
| FFPN-R (✘) | 0.850 | 3.65 | - | 0.924 | 3.05 | - | - | - | - | 0.888 | 3.80 | - |
| PLANet (✘) | - | - | - | **0.944** | 4.14 | 1.26 | 0.957 | 5.0 | 1.72 | - | - | - |
| CoSTUNet (✘) | - | - | - | 0.916 | 6.55 | - | 0.837 | 7.65 | - | 0.875 | 6.70 | - |
| I2UNet (✘) | 0.873 | 4.72 | 1.03 | 0.933 | 3.49 | 1.02 | 0.956 | 4.39 | 1.09 | 0.910 | 4.25 | 1.19 |
| Our TAM-I2UNet (✔) | 0.872 | 4.19 | 1.03 | 0.933 | **3.02** | 0.972 | 0.956 | 3.92 | 1.06 | 0.913 | 3.74 | 1.11 |
| Our TAM-FCN8s (✔) | **0.876** | **3.60** | **0.973** | 0.935 | 3.04 | **0.949** | **0.959** | **3.33** | **0.961** | **0.916** | **3.27** | **1.06** |
| SOCOF (✔) | - | - | - | 0.932 | 3.21 | 1.40 | 0.953 | 4.0 | 1.65 | - | - | - |
| CLAS (✔) | - | - | - | 0.935 | 4.60 | 1.40 | 0.958 | 4.85 | 1.55 | 0.915 | - | - |
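
For completeness, the PIA metric used in the first table can be computed from a predicted label map with standard connected-component analysis. The sketch below computes it per class using SciPy's default connectivity; how the reported numbers aggregate over classes is an assumption left to the reader.

```python
# Sketch of the PIA (percentage of island area) metric described in the table caption.
import numpy as np
from scipy import ndimage

def percentage_island_area(pred_mask, class_id):
    """Percentage of a class's segmented area lying outside its largest connected component."""
    binary = (pred_mask == class_id)
    total = binary.sum()
    if total == 0:
        return 0.0
    labelled, num_components = ndimage.label(binary)        # connected-component labelling
    if num_components <= 1:
        return 0.0                                          # a single mass: no islands
    sizes = np.bincount(labelled.ravel())[1:]               # component sizes (background excluded)
    islands = total - sizes.max()                           # everything except the largest mass
    return 100.0 * islands / total

# Example aggregation over foreground classes (class indices are illustrative):
# pia = np.mean([percentage_island_area(prediction, c) for c in (1, 2, 3)])
```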
