
[GSOC] implement example of state-space model for connectivity #100

Open · wants to merge 18 commits into main

Conversation


@jadrew43 jadrew43 commented Jun 14, 2022

PR Description

Google Summer of Code (2022) project

Closes #99

WIP: Linear Dynamical System (a state-space model fit with the EM algorithm to find autoregressive coefficients) to infer functional connectivity by interpreting the autoregressive coefficients as connectivity strengths. The model takes M/EEG sensor data as input and outputs time-varying autoregressive coefficients for source-space labels.
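For orientation, a generic linear-Gaussian state-space model of this kind can be written as follows (the notation is illustrative, not necessarily the variable naming used in the code):

$$x_t = A_t x_{t-1} + w_t, \qquad w_t \sim \mathcal{N}(0, Q)$$
$$y_t = C x_t + v_t, \qquad v_t \sim \mathcal{N}(0, R)$$

Here $x_t$ holds the latent time courses of the source-space labels, $y_t$ the M/EEG sensor measurements, $C$ the observation (forward) mapping, and the time-varying autoregressive coefficients $A_t$, estimated with EM, are read out as connectivity strengths between labels.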

Completed during GSoC

  • A user-friendly API that allows the user to work easily with MEG and/or EEG data, following MNE-Python's standards and conventions for usability as much as possible:
    • Most of the code complexity is hidden from the user in the backend, keeping the interface as simple as possible:
import mne

data_path = mne.datasets.sample.data_path()
sample_folder = data_path / 'MEG' / 'sample'

raw_fname = sample_folder / 'sample_audvis_raw.fif'
raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60)
events = mne.find_events(raw, ...)

event_dict = {...}
epochs = mne.Epochs(raw, events, tmin=-0.2, tmax=0.7, event_id=event_dict,
                    preload=True).pick_types(meg=True, eeg=True)

fwd_fname = sample_folder / '....'
fwd = mne.read_forward_solution(fwd_fname)

cov_fname = sample_folder / 'sample_audvis-cov.fif'
cov = mne.read_cov(cov_fname)

label_names = ['Aud-lh', 'Aud-rh', 'Vis-lh', 'Vis-rh']
labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label',
                         subject='sample') for label in label_names]

model = LDS(lam0=0, lam1=100)
model.add_subject('sample', condition, epochs, labels, fwd, cov)  # condition, e.g. 'auditory/left'
model.fit()
At = model.A
assert At.shape == (len(labels), len(labels), len(epochs.times))
  • Preprocessing of the data into a format meant to increase its SNR
  • Downsampling of the data for faster processing
  • Utilizes the forward and covariance matrices in the API, as well as the labels for the regions of interest (ROIs) of the dataset
  • PCA is used to reduce the dimensionality of the data
  • model.fit() uses the Expectation Maximization algorithm to fit the autoregressive coefficients of the state-space model, mapping the sensor data to each ROI and computing the connectivity strength between ROI pairs
  • Plotting of the time-varying coefficients in a matrix layout to show the strength of connection between each pair of ROIs (see the sketch after this list)
  • An example script showing how to use the functionality and interpret its outputs using the MNE-Python sample dataset
  • Basic unit tests have been written and partially incorporated
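As a rough, hypothetical sketch of the matrix-style plot mentioned above, assuming At, label_names, and epochs as in the snippet earlier (the actual example script may differ, and the row/column direction convention here is an assumption):

import matplotlib.pyplot as plt

n = len(label_names)
fig, axes = plt.subplots(n, n, sharex=True, sharey=True, figsize=(8, 8))
for i in range(n):          # row: receiving ROI (assumed convention)
    for j in range(n):      # column: sending ROI (assumed convention)
        axes[i, j].plot(epochs.times, At[i, j])
        axes[i, j].set_title(f'{label_names[j]} -> {label_names[i]}', fontsize=8)
fig.supxlabel('Time (s)')
fig.supylabel('Connectivity strength (AR coefficients)')
fig.tight_layout()
plt.show()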

Check out this link to see my weekly progress. All of the code in this PR is new to MNE-Python's repositories.

Todo after GSoC

  • Finish unit tests
  • Finish replacing redundant functionality with MNE-Python equivalents (e.g., scale_data with compute_whitener and dot products)
  • The current implementation depends on autograd, which is no longer actively developed (though still maintained). The code should be updated to use JAX, which includes autograd's features and is actively developed (autograd_linalg should be replaced by scipy.linalg); see the sketch after this list.
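As a rough illustration of the planned swap (a sketch only, using a placeholder objective rather than the model's actual EM loss), autograd's grad and JAX's jax.grad share essentially the same functional API:

import jax.numpy as jnp
from jax import grad

def objective(params):           # placeholder objective, not the real EM loss
    return jnp.sum(params ** 2)

grad_fn = grad(objective)        # drop-in replacement for autograd.grad(objective)
print(grad_fn(jnp.arange(3.0)))  # [0. 2. 4.]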

Merge checklist

Maintainer, please confirm the following before merging:

  • All comments resolved
  • This is not your own PR
  • All CIs are happy
  • PR title starts with [MRG]
  • whats_new.rst is updated
  • PR description includes phrase "closes <#issue-number>"

@adam2392 adam2392 changed the title implement example of state-space model for connectivity [GSOC] implement example of state-space model for connectivity Jun 17, 2022

@adam2392 adam2392 left a comment


Left some minor suggestions

Lmk when you want us to activate the CI pipelines + do a more in-depth look at the code!


@larsoner larsoner left a comment


Looks like a good start, let's keep iterating here until we're happy with how the interface looks in the example!

@cbrnr

cbrnr commented Jun 19, 2022

Hello! Just wanted to say hi to include myself in the loop 😄. Could someone quickly explain what the aim of this PR is? Is it to add VAR-model-based connectivity estimates? "State-space" sounds like it's implemented as a Kalman filter. At the risk of sounding repetitive: could we look at https://github.com/scot-dev/scot to see if anything could be re-used here? I've implemented least-squares VAR estimation (optionally with regularization) to compute several popular (directed) connectivity measures (for a list, see https://scot-dev.github.io/scot-doc/api/scot/scot.html#module-scot.connectivity).

@jadrew43

Hi @cbrnr, you are correct that this PR aims to implement a Kalman filter using an AR model to measure connectivity. Reviewing SCoT is still on my to-do list; thanks for the reminder.

@larsoner

larsoner commented Jul 5, 2022

From a quick chat with Jordan, here is what we fleshed out a bit based on my suggestion for the public API:

Internal implementation sketch and public API
# Internal code

class MEGLDS:

    def __init__(self, ...):
        ...
        self._subject_data = dict()

    def add_subject(self, subject, forward, cov, ...):
        self._subject_data[subject] = dict()
        self._subject_data[subject]['G'] = something(forward, ...)
        self._subject_data[subject]['C'] = something_else(cov, ...)

    def fit(self):
        Gs = np.array([val['G'] for val in self._subject_data.values()])
        ...


# User API should only have:
# - subjects_dir, then per-subject:
# - subject
# - Forward
# - Covariance
# - Epochs
# - list of Label

data_path = mne.datasets.sample.data_path()
subjects_dir = data_path / 'subjects'
model = MEGLDS(lambda0, lambda1, ...)
forward = mne.read_forward_solution(...)
cov = mne.read_cov(...)
labels = list()
label_names = ('Aud.lh', 'Aud.rh', 'Visual.lh', 'Visual.rh')
for name in label_names:
    labels.append(mne.read_label(subjects_dir / 'sample' / 'labels' / name))
model.add_subject('sample', subjects_dir=subjects_dir, labels=labels,
                  forward=forward, cov=cov)
...
model.fit()
At = model.A
assert At.shape == (len(labels), len(labels), len(epochs.times))

EDIT: I resolved the conversations above where I talk about this since I think it's general enough to discuss in this main thread rather than inline

@jadrew43

jadrew43 commented Jul 7, 2022

examples/state_space_connectivity.py is functioning properly on my machine, with the output depicted below. I think these results look good for a single subject. Along the diagonal, values are close to 1, as expected for a computation similar to an autocorrelation. For the auditory/left condition, there seems to be a connection from Aud-lh to Aud-rh, as seen in the non-zero values in panel [0, 1]. I expect measurements to be less noisy when running for a large number of subjects. Please run the CI checks.

x-axis: time (seconds); y-axis: connectivity strength (autoregressive coefficients)

[image]

@jadrew43

jadrew43 commented Jul 7, 2022

I am at a research conference for the next 7 days. When I return I have the following to-do:

  • work to replace scale_sensor_data with mne.make_ad_hoc_cov() and compute_whitener()
  • test for different conditions; see how results change
  • use different dataset with multiple subjects for sanity check

Looking forward to your feedback!

@jadrew43 jadrew43 marked this pull request as ready for review July 7, 2022 21:46
self._roi_cov_0 = roi_cov_0

@property
def log_sigsq_lst(self):
Member:

All of this stuff looks like it should be private to me. That also means we don't need a @property or a @whatever.setter for each of them. We just use private attributes directly

Author:

So I should be able to remove all of these @property and @some.setter lines, right? Does that mean there need to be additional checks on variable format somewhere? I don't think any of these have checks as written here. I'd also like to talk about which variables and functions to make private; in my mind, all of the variables and functions in models.py should be private.


# retrieve forward and sensor covariance
fwd_src_snsr = fwd['sol']['data'].copy()
snsr_cov = cov.data.copy()
Member:

... probably here you do

epochs = epochs.copy().pick('data', exclude='bads')
snsr_cov = pick_channels_cov(cov, epochs.ch_names, ordered=True).data
fwd = convert_forward_solution(fwd, force_fixed=True)
fwd_src_snsr = pick_channels_forward(fwd, epochs.ch_names, ordered=True)['sol']['data']
data = epochs.get_data()
info = epochs.info
del cov, fwd, epochs

now you are guaranteed that data, fwd_src_snsr, and snsr_cov all have the same set of channels in the same order. Then to rescale them all you can just do:

rescale_cov = make_ad_hoc_cov(info)
scaler, _ = compute_whitener(rescale_cov, info)  # compute_whitener returns (whitener, ch_names)
del rescale_cov
fwd_src_snsr = scaler @ fwd_src_snsr
snsr_cov = scaler @ snsr_cov @ scaler.T
data = scaler @ data  # @ nicely knows to operate over the last two dims of (epochs, chs, time)

or something similar. Basically you let the MNE pick_channels_* make sure all channels are present and ordered properly, then make use of MNE functions to do the rescaling.

Author:

Previously you said to play around with these functions to see if they would produce the same outputs as the scale_sensor_data function. The original function was using a scale of 1 for each of the channel types, so nothing was changed. This scaling does alter the data, primarily through a significant increase in the number of principal components (PCA is run immediately after the scaling function). Also, the output of the model is much noisier using @ scaler, I think as expected with the significant increase in PCs. Let's talk to make sure this is working as expected.

Author:

Output using mne functions for scaling seems much noisier.
[image]

Author:

Compared to no scaling (or previous function with scale of 1 for each channel type).
[image]

Member:

> The original function was using a scale of 1 for each of the channel types, so nothing was changed.

This is the equivalent of using mne.make_ad_hoc_cov(..., std=1.). So if you want this mode you can get it easily with a very small change. Can you verify this looks the same as the old code?

> This scaling does alter the data, primarily through a significant increase in the number of principal components (PCA is run immediately after the scaling function).

This is to be expected. If you do no scaling and have MEG and EEG data, your EEG data will be ~6 orders of magnitude larger or so, so you will end up only looking at EEG data.

> Also, the output of the model is much noisier using @ scaler, I think as expected with the significant increase in PCs. Let's talk to make sure this is working as expected.

What's not clear to me: is it also noisier (and noisier in the same way) when using the old code path but enabling scaling?

In other words: is this a problem with the new way of scaling, or is this a problem that has always existed with scaling the data, regardless of scaling method?

We can chat about this today

@jadrew43

I am currently working to incorporate an old dataset to see if these scripts produce the expected results.

@jadrew43

Hey @adam2392, one of the CI errors is due to autograd: ModuleNotFoundError: No module named 'autograd'. I'd really like to move on to the next step of the proposal and work to integrate jax (to replace autograd) later, in order to keep pace with the summer milestones. Is it alright if I add autograd to the dependencies for mnedev (assuming that will fix the error) for now and work to integrate jax later on?

@adam2392

Yeah sure! You can include it as an optional dev dependency. Do you know where to add it? Please also add a comment, if you can, that we will replace it later on.

Perhaps we can start another separate issue to track the migration to jax later on?

@jadrew43

> Yeah sure! You can include it as an optional dev dependency. Do you know where to add it? Please also add a comment, if you can, that we will replace it later on.
>
> Perhaps we can start another separate issue to track the migration to jax later on?

@adam2392 Cool! I do not know where to add autograd as an optional dev dependency. Should it be installed in mnedev?

Where should I leave the comment that it can be replaced with jax later on? Just start a new issue on GitHub with the suggestion?

@adam2392

> @adam2392 Cool! I do not know where to add autograd as an optional dev dependency. Should it be installed in mnedev?

You can add it here: https://github.com/mne-tools/mne-connectivity/blob/main/requirements_testing.txt
and then just add a comment, e.g.

...
# TODO: we will replace this with Jax
autograd

Then when you import autograd, you should maybe import it within the functions you want to use and not at the top of the Python file. That way there is no error when someone tries to run import mne_connectivity if they don't have autograd/jax installed. For example, this function inside MNE-Python needs pyqt, but imports it within the function so mne still works if the user only has numpy/scipy installed.
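A minimal sketch of such a deferred import (the function name and error message here are hypothetical, not the PR's actual code):

def _fit_with_autograd(*args, **kwargs):
    # Import inside the function so that `import mne_connectivity` still works
    # when autograd is not installed.
    try:
        import autograd.numpy as anp
    except ImportError as err:
        raise ImportError('autograd is required for this functionality; '
                          'install it with `pip install autograd`.') from err
    ...  # use anp for the EM computations here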

> Where should I leave the comment that it can be replaced with jax later on? Just start a new issue on GitHub with the suggestion?

Yeah, I'll create the GH issue.

@jadrew43


Ok I think that's complete. Thanks!

@jadrew43

jadrew43 commented Aug 4, 2022

Update: I got the code to work for an old dataset; however, the results are not what I expected. I am currently working to get the sample data (much smaller than the old dataset) to work with my original model. That output will be my new ground truth, and I can work iteratively to make sure each edit I make to the original model to conform it to the MNE-Python API standards produces the same output. Lesson learned: have a simple ground truth to work with from the beginning :)

@jadrew43

jadrew43 commented Aug 8, 2022

Here is the output of my original model using the sample dataset for the auditory/left condition, with the 12 ROIs from a different dataset. My next step is to get this same output in the mnedev environment. Then I will work to use only the 4 ROIs commonly used with the sample dataset. Then, piece by piece, I will recreate the API.
[image]

@larsoner

larsoner commented Aug 9, 2022

Excellent!

A good next step is to make sure all random seeds can be set such that if you run this again you get the exact same output (to numerical precision at least). Then you can save the At result to disk now and compare to it each time you replace some piece of code.
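One possible sketch of that workflow (the file name and tolerances are placeholders):

import numpy as np

def check_against_reference(At, fname='At_reference.npy'):
    """Compare a new At result against the saved baseline."""
    np.testing.assert_allclose(At, np.load(fname), rtol=1e-7, atol=1e-10)

# First run, with all random seeds fixed:  np.save('At_reference.npy', At)
# After each refactor:                     check_against_reference(At)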

@jadrew43

jadrew43 commented Aug 11, 2022

I have changed the PCA method from one based on the rank of the matrix to one based on explaining 99% of the variance. This allows the model fitting to run much faster, as it produces 147 principal components vs. the 360 components produced by the rank method. The model output is noticeably, but not extremely, different. My next step is to perform the processing steps using the 4 labels provided in the sample dataset, which should reduce processing time even further. All processing was completed within the mnedev conda environment. (A rough sketch of the two component-selection strategies is shown below the figure.)
[image]
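For illustration only, the two component-selection strategies being compared could look like this with scikit-learn (the model's internal PCA may be implemented differently):

import numpy as np
from sklearn.decomposition import PCA

X = np.random.randn(1000, 360)  # placeholder (samples, sensors) data matrix
pca_rank = PCA(n_components=np.linalg.matrix_rank(X)).fit(X)  # rank-based selection
pca_var = PCA(n_components=0.99).fit(X)  # keep enough components to explain 99% of the variance
print(pca_rank.n_components_, pca_var.n_components_)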

@jadrew43

jadrew43 commented Aug 11, 2022

Processing completed with 4 labels from sample.
[image]

@jadrew43

Bootstrapping of epochs, and PCA of epochs.get_data() and the forward matrices, are now completed in the API. Model fitting was completed in the command-line model. The outputs of the API (LDS) and the command-line model (MEGLDS) are not identical, but extremely similar (a sketch of the bootstrapping step follows the figures).
[images]
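For reference, a rough sketch of what bootstrapping epochs can mean here (resampling epochs with replacement, assuming epochs as in the earlier snippet; the API may implement this differently):

import numpy as np

rng = np.random.default_rng(0)
data = epochs.get_data()                          # (n_epochs, n_channels, n_times)
idx = rng.integers(0, len(data), size=len(data))  # sample epoch indices with replacement
data_boot = data[idx]                             # one bootstrap replicate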

@larsoner

Nice!

It would be good to know what the differences are that make them not identical, but if this version of the API works on our UW data as well (maybe even just for one subject?), then I'd say you could use this as the "ground truth" for the correctness of additional changes!

@jadrew43

@larsoner Can you look at megssm/mne_util.py L114? Am I using the scaler correctly? The results do not agree with the original _scale_sensor_data (L140). Thanks.

I'll get to the CI checks first thing Monday!

def fwd_src_roi(self, val):
    self._fwd_src_roi = val

@property
Member:

@jadrew43 this is line 114 and it's not related to scaling. There is some stuff below that is. Can you confirm by commenting on the correct lines/lines in the "Files" tab of this PR? This not being the correct line makes me wonder if you forgot a push, and I don't want to look prematurely.

The best way to help me check would be to make it very easy to run and check the results. For example, add a commented-out (to keep CIs happy) # np.testing.assert_allclose(data, data_mne, rtol=..., atol=...) statement that fails (when you uncomment it) where you scale data the old way and data_mne the new way. Then I need a minimal script (hopefully one that runs in just a few seconds, e.g. on the sample data) that I can run on this branch and that will fail when I uncomment the assert_allclose line.

Author:

The updated files are in state_space/; I haven't done anything in examples/ in weeks, so I went ahead and deleted the files I had there. Running state_space_connectivity.py takes about 90 seconds, and adding _mne_scale_sensor_data() adds about 60 seconds. Happy to chat about building tests that take much less time. The commented testing line is L228 in mne_util.py.

Member:

Hmm, I don't think this is quite a minimal reproducible bit of code.

When I tried to run the code, it failed because I needed autograd. I installed that, then it complained that I needed autograd_linalg. I can't find this on pip or conda, so I tried swapping in from scipy.linalg import solve_triangular.

Next I tried to run the script and got:

FileNotFoundError: [Errno 2] No such file or directory: '/home/larsoner/mne_data/MNE-sample-data/MEG/sample/labels/AUD-lh.label'

This makes sense; usually labels are in the SUBJECTS_DIR. But even then, there aren't auditory labels in sample, so let's change it to something that should work, I think:

regexp = '^(G_temp_sup-Plan_tempo.*|Pole_occipital)'
labels = mne.read_labels_from_annot(
    'sample', 'aparc.a2009s', regexp=regexp, subjects_dir=subjects_dir)
label_names = [label.name for label in labels]
assert len(label_names) == 4

Then I got a bit farther, until I hit:

AttributeError: python: undefined symbol: mkl_get_max_threads

So I then disabled all the thread-setting stuff, and then get to:

  File "/home/larsoner/python/scipy/scipy/linalg/_basic.py", line 336, in solve_triangular
    raise ValueError('expected square matrix')
ValueError: expected square matrix

So my swap of scipy.linalg is not correct.

I don't think I can proceed until I can install autograd_linalg somehow, how do you recommend doing this?

Author:

Try this:
pip install --src deps -e git+ssh://git@github.com/nfoti/autograd_linalg.git@master#egg=autograd_linalg

Member:

I pushed a commit for a solve_triangular that allows the code to run on my machine. I can look into the issue now!

@adam2392

adam2392 commented Oct 3, 2022

Hi @jadrew43 and @larsoner, is any help needed to review code or look at prelim results here? Feel free to lmk where I can help!

@larsoner

larsoner commented Oct 3, 2022

IIRC the code is still WIP and needs to be systematically converted to MNE conventions, but some bugs have been found along the way, which is good!
