
Brain Opto-fMRI Decoding Using Graph Neural Networks

Opto-fMRI decoding of locus coeruleus firing patterns using Graph Neural Networks and the BrainGNN framework. The code is comprehensively commented and serves as a good example for understanding how BrainGNN works and how to visualize the salient ROIs.


Introduction

The noradrenaline (NA) circuit plays a central role in the brain, and its regulation is governed by the Locus Coeruleus (LC), a nucleus in the brainstem. Despite its significance, the impact of LC activation patterns on circuit regulation remains underexplored; a deeper understanding could pave the way for new techniques in the biomedical field. In this study, data from [1] were used to investigate how opto-fMRI experiments, involving tonic or burst-like stimulation (3 Hz and 15 Hz) of the LC in mice, lead to distinct modulations of brain activity. Graph Neural Networks built with the BrainGNN framework were used to identify the brain regions that best discriminate between the two stimuli.

Methods

The dataset includes Functional Magnetic Resonance Imaging (fMRI) recordings of mouse brain activity. Each recording is a sequence of 1440 fMRI images of 90x50x25 voxels, with each image representing 1 second of recording time. In the experimental setup, 15 mice underwent tonic stimulation, while 18 mice received burst-like stimulation. The initial 8 minutes of each recording capture the mouse's resting state. Following this, a series of 8 stimulation blocks is applied, each consisting of 30 seconds of stimulation followed by 30 seconds of rest. The final 8 minutes of the recording document another resting state. The primary focus of this study is the analysis of the brain's response to stimulation during the middle 8 minutes.
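As a concrete illustration, the frame indices of these stimulation blocks follow directly from the timing above. A minimal sketch, assuming one frame per second; the names are illustrative, not from the repository:

```python
# Sketch: frame indices of the 8 stimulation blocks, assuming 1 frame = 1 s.
REST_BEFORE = 8 * 60   # initial 8 min resting state -> 480 frames
BLOCK_LEN = 60         # 30 s stimulation + 30 s rest
N_BLOCKS = 8

def stimulation_windows():
    """Return (start, stop) frame indices for each 60 s stimulation block."""
    return [(REST_BEFORE + i * BLOCK_LEN, REST_BEFORE + (i + 1) * BLOCK_LEN)
            for i in range(N_BLOCKS)]

print(stimulation_windows())
# [(480, 540), (540, 600), ..., (900, 960)] -- the middle 8 minutes
```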

To identify regions and networks that help distinguish between the two types of stimulation, we employed graph neural networks in the analysis. We leveraged the fMRI time series to model the brain as a graph, where nodes represent ROIs and edges denote functional connectivity between these ROIs. Traditional graph-based analyses of fMRI often rely on hand-engineered features, whose reliability strongly affects downstream analysis. Recently, Graph Neural Networks (GNNs) have gained prominence in neuroimaging for their effectiveness on graph-structured data, showing superior performance and interpretability [2].

Using the annotated brain atlas for this dataset, we first selected all atlas regions containing more than 53 voxels, which gave us a total of 68 ROIs. We then extracted the fMRI time series for each 60 s stimulation block. The mean time series for each node is derived from a random subset comprising one-third of the voxels within the ROI. This random sampling is performed 30 times, yielding 30 graphs per stimulation block. Since there are 8 stimulations in each of the 33 recordings, we obtain a total of $8 \times 33 \times 30 = 7920$ graphs. The time series are used to compute a Pearson correlation matrix, whose rows serve as node features, and a partial correlation matrix, which establishes the connectivity and the edges of the graph. We keep only the edges in the top $10\%$ of the partial correlation matrix. Partial correlation gives us sparser connections, which mitigates the over-smoothing observed in generic graph neural networks on densely connected graphs. Additionally, using both Pearson and partial correlations combines two different measures of functional connectivity.
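A minimal sketch of this graph construction is below. It uses scikit-learn's Ledoit-Wolf shrinkage estimator to obtain an invertible covariance for the partial correlation; the repository's exact estimator and thresholding may differ:

```python
import numpy as np
from sklearn.covariance import LedoitWolf

def build_graph(ts, edge_quantile=0.90):
    """Build one graph from an ROI time series `ts` of shape (T, n_roi).

    Node features: rows of the Pearson correlation matrix (R^68 per node).
    Edges: entries in the top 10% of the partial correlation matrix.
    """
    # Pearson correlation between ROI time series -> node feature matrix
    pearson = np.corrcoef(ts.T)                       # (n_roi, n_roi)

    # Partial correlation from the (shrunk, invertible) precision matrix
    prec = np.linalg.inv(LedoitWolf().fit(ts).covariance_)
    d = np.sqrt(np.diag(prec))
    partial = -prec / np.outer(d, d)
    np.fill_diagonal(partial, 0.0)

    # Keep only the strongest 10% of connections as edges
    thresh = np.quantile(np.abs(partial), edge_quantile)
    edge_index = np.argwhere(np.abs(partial) >= thresh).T  # (2, n_edges)
    return pearson, edge_index
```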

The model architecture, illustrated in the figure below, is composed of:

1. a graph attention convolutional layer that takes a graph (or a batch of graphs) and transforms its node embeddings from $\mathbf{R}^{68}$ to $\mathbf{R}^{32}$;
2. a top-k pooling layer that reduces the dimensionality of the entire graph by removing less important nodes, where node pruning is based on a projection score against a learnable vector;
3. a read-out layer in $\mathbf{R}^{2 \times 32}$ that flattens the graph by concatenating the global max and mean pools of the node embeddings; and
4. a three-layer Multilayer Perceptron (MLP) that transforms the read-out from $\mathbf{R}^{64}$ to $\mathbf{R}^{2}$, after which a log softmax is applied for the final classification, providing predictions for the 3 Hz and 15 Hz cases.

To reduce overfitting, we also added batch normalization to these last layers and employed dropout with a probability of $0.7$.

Model Architecture
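A minimal PyTorch Geometric sketch of this architecture follows; layer choices not stated above, such as the pooling ratio, are assumptions:

```python
import torch
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Dropout, Linear, ReLU, Sequential
from torch_geometric.nn import GATConv, TopKPooling, global_max_pool, global_mean_pool

class BrainGNN(torch.nn.Module):
    """Sketch of the architecture described above; names are illustrative."""

    def __init__(self, in_dim=68, hidden=32, ratio=0.5, n_classes=2, p_drop=0.7):
        super().__init__()
        self.conv = GATConv(in_dim, hidden)           # 1) attention convolution
        self.pool = TopKPooling(hidden, ratio=ratio)  # 2) prune low-score nodes
        self.mlp = Sequential(                        # 4) three-layer MLP head
            Linear(2 * hidden, hidden), BatchNorm1d(hidden), ReLU(), Dropout(p_drop),
            Linear(hidden, hidden), BatchNorm1d(hidden), ReLU(), Dropout(p_drop),
            Linear(hidden, n_classes),
        )

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv(x, edge_index))
        x, edge_index, _, batch, _, score = self.pool(x, edge_index, batch=batch)
        # 3) read-out: concatenated global max and mean pools -> R^{2x32}
        x = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)
        return F.log_softmax(self.mlp(x), dim=-1)     # predictions for 3 Hz vs 15 Hz
```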

The loss function is the sum of a cross-entropy classification loss and additional terms for regularization and interpretability: the unit loss, which addresses identifiability concerns by keeping the learnable vector in the pooling layer a unit vector; the group-level consistency (GLC) loss, which enforces the selection of similar ROIs in the pooling layer for different inputs of the same class; and the top-k pooling (TPK) loss, which promotes reasonable node selection by encouraging distinct scores for selected and unselected ROIs. The formulas for these terms are detailed in [2], and their implementation is in the accompanying code. The complete loss function is therefore as below, with $\lambda_{0}, \lambda_{1}, \lambda_{2}$ as hyperparameters of the model.

$L = L_{\text{Cross-entropy}} + \lambda_0 L_{\text{Unit}} + \lambda_1 L_{\text{TopK}} + \lambda_2 L_{\text{GLC}} $
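A hedged sketch of how these terms can be assembled is shown below; the exact formulations follow [2], and `k`, the tensor shapes, and the per-class GLC averaging here are assumptions:

```python
import torch
import torch.nn.functional as F

def braingnn_loss(log_probs, y, scores, pool_vector, lams=(0.1, 0.1, 0.4), k=34):
    """Combined loss: cross-entropy + unit + TPK + GLC (forms follow [2]).

    log_probs:   (batch, 2) log-softmax outputs of the model
    scores:      (batch, n_roi) node scores from the top-k pooling layer
    pool_vector: the pooling layer's learnable projection vector
    """
    ce = F.nll_loss(log_probs, y)

    # Unit loss: keep the learnable projection vector a unit vector
    unit = (pool_vector.norm(p=2) - 1.0) ** 2

    # TPK loss: push sorted scores towards a binary selected/unselected pattern
    s = torch.sigmoid(scores).sort(dim=1, descending=True).values
    tpk = -(torch.log(s[:, :k] + 1e-8).mean() + torch.log(1 - s[:, k:] + 1e-8).mean())

    # GLC loss: inputs from the same class should select similar ROIs
    glc = scores.new_zeros(())
    for c in torch.unique(y):
        sc = scores[y == c]
        glc = glc + ((sc.unsqueeze(0) - sc.unsqueeze(1)) ** 2).sum() / max(len(sc), 1)

    return ce + lams[0] * unit + lams[1] * tpk + lams[2] * glc
```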

The dataset was randomly partitioned into five folds based on subjects, ensuring that graphs from a single subject appeared exclusively in either the training or test set. Four folds were allocated for training, and the remaining fold served as the test set. We used PyTorch and PyTorch Geometric to implement the model and perform training and testing. The Adam optimizer was employed with a learning rate of $0.001$ and weight decay of $0.2$, and the model underwent training for 50 epochs. The StepLR scheduler was also used to decay the learning rate of each parameter group by a factor of $0.5$ every $10$ epochs. The loss parameters were chosen as $\lambda_0 = 0.1, \lambda_1 = 0.1, \lambda_2 = 0.4$ to maximize accuracy and interpretability [2].
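For reference, this training setup translates to roughly the following sketch, where `dataset`, `subject_ids`, and `BrainGNN` are placeholders tied to the snippets above and the batch size is an assumption:

```python
import torch
import torch.nn.functional as F
from sklearn.model_selection import GroupKFold
from torch_geometric.loader import DataLoader

# Subject-wise 5-fold cross-validation with the reported optimizer settings.
# `dataset` is a list of torch_geometric Data objects; `subject_ids` maps each
# graph to its mouse so a subject never spans train and test (placeholders).
splitter = GroupKFold(n_splits=5)
for train_idx, test_idx in splitter.split(dataset, groups=subject_ids):
    model = BrainGNN()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.2)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)

    loader = DataLoader([dataset[i] for i in train_idx], batch_size=32, shuffle=True)
    for epoch in range(50):
        for batch in loader:
            opt.zero_grad()
            out = model(batch.x, batch.edge_index, batch.batch)
            loss = F.nll_loss(out, batch.y)  # plus the regularizers above
            loss.backward()
            opt.step()
        sched.step()
```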

Results and Discussion

The cross-validation results showed an above-chance average performance of $61.6\%$ on the left-out test sets. The weights, the scores of the top-k pooling layer, and the metrics of the top-performing model in each fold were saved. To take advantage of the interpretability of our model, we combined the scores for ROIs across the folds and assigned a single average score to each ROI. The figure below shows these salient ROIs, which appear informative for discriminating between the two classes. The higher the score, the better that region is for the decoding task. Refer to visualization.ipynb to see how to visualize the top-k pooling scores using the atlas.
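The per-ROI averaging amounts to something like the following, assuming a hypothetical `fold_scores` array holding each fold's saved scores:

```python
import numpy as np

# Hypothetical: fold_scores has shape (n_folds, n_roi), one row of saved
# top-k pooling scores per fold's best model.
avg_score = fold_scores.mean(axis=0)   # one average score per ROI
salient = np.argsort(avg_score)[::-1]  # ROIs ranked by saliency
```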

Salient ROI

The primary challenge for the GNN model was the limited size of the dataset, a significant constraint for training a generalizable deep learning model. To mitigate overfitting, we first minimized model complexity, reducing the overall number of parameters; we then systematically reduced feature dimensionality by decreasing the number of nodes and edges in the graphs. We also introduced batch normalization and dropout to improve the model's generalization, and additionally incorporated weight decay to further regularize the model's parameters. Given the extremely marginal differences between the two classes, data augmentation could not help beyond a certain point. Despite these challenges, our model achieved above-chance accuracy and outlined the prominent ROIs.

References