Commit
Modified the example leadfield matrix notebook and the leadfield function script to fix some Python version incompatibilities, and added some new features.
Wirkungstreffer committed Apr 3, 2024
1 parent efc51a4 commit ffdac07
Showing 2 changed files with 426 additions and 30 deletions.
248 changes: 223 additions & 25 deletions examples/example-0.8-leadfield-matrix.ipynb

Large diffs are not rendered by default.

208 changes: 203 additions & 5 deletions neurolib/utils/leadfield.py
@@ -1,6 +1,8 @@
import os
import numpy as np
import matplotlib.pyplot as plt
import typing
import pandas as pd

import nibabel as nib
import mne
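# [Note, not part of the diff] `typing` backs the list[str] -> typing.List[str] changes
# below: subscripting built-in types in annotations requires Python >= 3.9, which is
# likely the version incompatibility the commit message refers to. `pandas` is used by
# the new CSV export in simulated_eeg_data().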
@@ -15,8 +17,10 @@
class LeadfieldGenerator:

"""
-   Authors: Mohammad Orabe <orabe.mhd@gmail.com>
-   Zixuan liu <zixuan.liu@campus.tu-berlin.de>
+   Authors:
+   Zixuan liu <zixuan.liu@campus.tu-berlin.de>
+   Mohammad Orabe <orabe.mhd@gmail.com>
A class to compute the lead-field matrix and perform related operations.
The default loaded data is the template data 'fsaverage'.
@@ -318,7 +322,7 @@ def __get_backprojection(

return back_proj_rounded

-   def __filter_for_regions(self, label_strings: list[str], regions: list[str]) -> list[bool]:
+   def __filter_for_regions(self, label_strings: typing.List[str], regions: typing.List[str]) -> typing.List[bool]:
"""
Create a list of bools indicating if the label_strings are in the regions list.
This function can be used if one is only interested in a subset of regions defined by an atlas.
@@ -349,7 +353,7 @@ def __get_labels_of_points(
xml_file: dict,
atlas="aal2_cortical",
cortex_parts="only_cortical_parts",
-   ) -> tuple[list[bool], np.ndarray, list[str]]:
+   ) -> typing.Tuple[typing.List[bool], np.ndarray, typing.List[str]]:
"""
Gives labels of regions the points fall into.
@@ -443,7 +447,7 @@ def __get_labels_of_points(

def __downsample_leadfield_matrix(
self, leadfield: np.ndarray, label_codes: np.ndarray
-   ) -> tuple[np.ndarray, np.ndarray]:
+   ) -> typing.Tuple[np.ndarray, np.ndarray]:
"""
Downsample the leadfield matrix by computing the average across all dipoles falling within specific regions. This process assumes a one-to-one correspondence between source positions and dipoles, as commonly found in a surface source space where the dipoles' orientations are aligned with the surface normals.
@@ -611,3 +615,197 @@ def check_atlas_missing_regions(self, atlas_xml_path, unique_labels):
missed_region_indices = np.array([i + 1 for i, e in enumerate(label_numbers) if e in subset])
print("missed region indices:", missed_region_indices)
print("=====================================================")

def view_all_region_names(self, atlas_xml_path):
"""
Print all region names in the atlas.
Parameters:
==========
atlas_xml_path (str): Path to the XML file containing label information.
Returns:
=======
None
"""
xml_file = self.__create_label_lut(atlas_xml_path)
label_numbers = np.array(list(map(int, xml_file.keys())))[:-1] # Convert the keys to integers
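# np.setdiff1d against an empty list simply returns the sorted unique label numbers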
empty_set = []
all_region_labels = np.setdiff1d(label_numbers, empty_set)
#print("all region labels:", all_region_labels)

all_region_labels_str = all_region_labels.astype(str)
all_region_values = list(xml_file[label] for label in all_region_labels_str if label in xml_file)
print("all region names:", all_region_values)
print("=====================================================")

def find_region_corresponding_index(self, atlas_xml_path, region_name):
"""
Find the index of a given region name in the atlas.
Parameters:
==========
atlas_xml_path (str): Path to the XML file containing label information.
region_name (str): The name of the region of the atlas.
Returns:
=======
region_index (int): The index of the given region name in the atlas's region list.
"""
xml_file = self.__create_label_lut(atlas_xml_path)
label_numbers = np.array(list(map(int, xml_file.keys())))[:-1] # Convert the keys to integers
empty_set = []
all_region_labels = np.setdiff1d(label_numbers, empty_set)
#print("all region labels:", all_region_labels)

all_region_labels_str = all_region_labels.astype(str)
all_region_values = list(xml_file[label] for label in all_region_labels_str if label in xml_file)
#print("all region names:", all_region_values)

all_subset = set(all_region_labels)
all_region_index = np.array([i+1 for i, e in enumerate(label_numbers) if e in all_subset])
#print("all region index:", all_region_index)

region_index = all_region_values.index(region_name)
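# note: list.index raises a ValueError if region_name is not one of the atlas region names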

print("Index for %s:" %region_name, region_index)
print("=====================================================")

return region_index

def simulated_source_data(
self,
leadfield_downsampled,
timepoints_number=1000,
frequency_parameter=(5,10),
time_parameter=(0, 1)
):
"""
Generate simulated source data.
Parameters:
==========
leadfield_downsampled (np.ndarray): Channels x Regions leadfield matrix.
timepoints_number (int): Number of timepoints in the generated data.
frequency_parameter (tuple): Lower and upper bound (Hz) of the random frequency drawn for each dipole.
time_parameter (tuple): Start and end time (s) of the generated data.
Returns:
=======
simulated_source_data (np.ndarray): The generated source data with dimensions regions x timepoints.
time (np.ndarray): The time vector of the generated data.
"""

n_dipoles = leadfield_downsampled.shape[1] # Number of dipoles
n_timepoints = timepoints_number # Number of time points in the simulated data
frequencies = np.random.uniform(frequency_parameter[0], frequency_parameter[1], n_dipoles) # Random frequencies for each dipole
time = np.linspace(time_parameter[0], time_parameter[1], n_timepoints)  # time vector (defaults to 1 second of data)

# Create source time-series: [n_dipoles x n_timepoints]
simulated_source_data = np.array([np.sin(2 * np.pi * f * time) for f in frequencies])

return simulated_source_data, time

def plot_eeg_data(self, eeg_data, time, title, offset_per_channel):
"""
Plot the calculated EEG data.
Parameters:
==========
eeg_data (np.ndarray): The calculated EEG data (channels x timepoints) based on the source data and the leadfield matrix.
time (np.ndarray): The time vector of the generated data.
title (str): The title of the simulated EEG plot.
offset_per_channel (float): Vertical offset between neighbouring channels in the plot.
Returns:
=======
None
"""
channel_offsets = np.arange(eeg_data.shape[0]) * offset_per_channel
for i, channel_data in enumerate(eeg_data):
plt.plot(time, channel_data + channel_offsets[i], label=f'Channel {i}')
plt.yticks(channel_offsets, [f'Channel {i}' for i in range(eeg_data.shape[0])])
plt.title(title)
plt.xlabel("Time (s)")
plt.ylabel("Channels")
plt.tight_layout()
plt.show() # Explicitly display the plot

def simulated_eeg_data(
self,
simulated_source_data,
leadfield_downsampled,
time,
visualization=True,
plot_title="Simulated EEG Data",
plot_offset=None,
plot_size=(8, 16),
csv_file_name="simulated_eeg_data.csv",
folder_to_save_csv="examples/data/AAL2_atlas_data"
):
"""
Calculate simulated EEG data from the generated source data, optionally plot it, and save it as a csv file.
Parameters:
==========
simulated_source_data (np.ndarray): The generated source data with dimensions regions x timepoints.
leadfield_downsampled (np.ndarray): Channels x Regions leadfield matrix.
time (np.ndarray): The time vector of the generated data.
visualization (bool): Whether to plot the simulated EEG data.
plot_title (str): The title of the simulated EEG plot.
plot_offset (float or None): Vertical offset between channels in the plot; if None, it is derived from the data.
plot_size (tuple): Figure size of the simulated EEG plot.
csv_file_name (str): File name of the saved csv file.
folder_to_save_csv (str or None): The folder to save the csv file in; if None, no file is written.
Returns:
=======
simulated_eeg_data (np.ndarray): The calculated EEG data with dimensions channels x timepoints.
"""
# Simulate EEG data: [n_sensors x n_timepoints]
simulated_eeg_data = np.dot(leadfield_downsampled, simulated_source_data)
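# (n_channels x n_regions) @ (n_regions x n_timepoints) -> (n_channels x n_timepoints)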

# Plot EEG data
if visualization:
# Define an offset between each channel's plot
if plot_offset is None:
offset_per_channel = np.max(np.abs(simulated_eeg_data)) * 1.5
else:
offset_per_channel = plot_offset

plt.figure(figsize=plot_size)
self.plot_eeg_data(simulated_eeg_data, time, plot_title, offset_per_channel)

# List to store individual dataframes for each channel
dfs = []
# Loop through each channel and create individual dataframes
for i in range(simulated_eeg_data.shape[0]):
df_channel = pd.DataFrame({
f'Channel_{i+1}': simulated_eeg_data[i, :],
})
dfs.append(df_channel)

# Concatenate individual dataframes along columns axis
df_pairwise = pd.concat(dfs, axis=1)

# Save to CSV
if folder_to_save_csv is not None:

folder_path = folder_to_save_csv
if not os.path.exists(folder_path):
os.makedirs(folder_path)

path_to_save_csv = os.path.join(folder_path, csv_file_name)

df_pairwise.to_csv(path_to_save_csv, index=False)

print(f"The simulated EEG data is saved as a csv file at {path_to_save_csv}")
print("=====================================================")

return simulated_eeg_data
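
A minimal usage sketch of the new simulation helpers (illustration only, not part of this commit). The default LeadfieldGenerator() construction, the import path, the atlas path, and the 64-channel x 80-region stand-in matrix are assumptions; in the example notebook the downsampled leadfield comes from the atlas pipeline instead.

import numpy as np
from neurolib.utils.leadfield import LeadfieldGenerator

lf_gen = LeadfieldGenerator()                    # assumed default construction ('fsaverage' template)
leadfield_downsampled = np.random.rand(64, 80)   # stand-in for the notebook's channels x regions leadfield

# Optionally inspect the atlas first (path is a hypothetical example):
# lf_gen.view_all_region_names("examples/data/AAL2_atlas_data/AAL2.xml")

# Generate sinusoidal source activity and project it to the sensors.
sources, time = lf_gen.simulated_source_data(leadfield_downsampled)   # shapes (80, 1000) and (1000,)
eeg = lf_gen.simulated_eeg_data(
    sources,
    leadfield_downsampled,
    time,
    visualization=True,          # plot the channels with an automatic vertical offset
    folder_to_save_csv=None,     # None skips the csv export
)
# eeg.shape == (64, 1000): leadfield_downsampled (64 x 80) @ sources (80 x 1000)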
