Can't reproduce the within-subject classification results on BCI Competition IV dataset 2a #31

Open
realblack0 opened this issue Sep 22, 2021 · 9 comments

Comments

@realblack0

Hello.

I am trying to reproduce the within-subject classification results on BCI Competition IV dataset 2a reported in the paper.
I reused the EEGNet class but got a mean accuracy of 60-63%, while I expected around 68%.

I read #7 and tried EEGNet-8,2 with kernLength = 32.
I used 4-fold blockwise cross-validation: the training session is split into three equal contiguous partitions (96/96/96 trials), each partition serves once as the validation set, and the evaluation session (288 trials) is always kept as the test set. This gives three training runs per subject.

I preprocessed the data with braindecode. I also tried preprocessing with SciPy but got similar accuracy.
I suspect I missed something in the preprocessing.
Could you check the code below and help me find what I missed?
Alternatively, I would be grateful if you could share the preprocessing code you used.

Here is my code using braindecode:

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint

import numpy as np
import pickle
import argparse
import os
import shutil

from arl_eegmodels.EEGModels import EEGNet

from braindecode.datasets.moabb import MOABBDataset
from braindecode.datautil.preprocess import (
    exponential_moving_standardize, preprocess, Preprocessor)
from braindecode.datautil.windowers import create_windows_from_events

# while the default tensorflow ordering is 'channels_last' we set it here
# to be explicit in case the user has changed the default ordering
K.set_image_data_format('channels_last')


###################
## Configuration ##
###################
# SYSTEM
parser = argparse.ArgumentParser()
parser.add_argument('--name', required=True)
parser.add_argument('--device', default=None)
parser.add_argument('--subject', type=int, required=True)
args = parser.parse_args()

if args.device:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device # use specific gpu

name = args.name
subject_id = args.subject

def make_results_directory(name, subject_id, base="."):
    results_dir = f"{base}/results/{name}_subject{subject_id}"
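    if os.path.exists(results_dir):
        # a bare `raise` outside an except block only produces
        # "RuntimeError: No active exception to reraise"
        raise FileExistsError(f"'{results_dir}' already exists!")
    os.makedirs(results_dir)  # also creates ./results/ if it does not exist yet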
    
    shutil.copy("main.py", results_dir)
    print(f"'{results_dir}' is created!")
    
make_results_directory(name, subject_id, base=".")

###################
#### Load Data ####
###################
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])

low_cut_hz = 4.  # low cut frequency for filtering
high_cut_hz = 40.  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(lambda x: x * 1e6),  # Convert from V to uV
    Preprocessor('resample', sfreq=128), # Added by me
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(exponential_moving_standardize,  # Exponential moving standardization
                 factor_new=factor_new, init_block_size=init_block_size)
]
# Transform the data
preprocess(dataset, preprocessors)

trial_start_offset_seconds = 0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info['sfreq']
assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)


trial_stop_offset_seconds = -1.5 
trial_stop_offset_samples = int(trial_stop_offset_seconds * sfreq)


# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=trial_stop_offset_samples,
    preload=True,
)

splitted = windows_dataset.split('session')
train_set = splitted['session_T']
test_set = splitted['session_E']


X_Train = []
y_Train = []
for run in train_set.datasets:
    for X, y, _ in run:
        X_Train.append(X)
        y_Train.append(y)
X_Train = np.array(X_Train)
y_Train = np.array(y_Train)
        
X_test = []
y_test = []
for run in test_set.datasets:
    for X, y, _ in run:
        X_test.append(X)
        y_test.append(y)
X_test = np.array(X_test)
y_test = np.array(y_test)

print("X_Train shape:", X_Train.shape, "y_Train shape:", y_Train.shape)
print("X_test shape:", X_test.shape, "y_test shape:", y_test.shape)

fold1 = list(range(0,  96))
fold2 = list(range(96, 192))
fold3 = list(range(192,288))
train_val_split = [
    (fold2 + fold3, fold1),    
    (fold3 + fold1, fold2),
    (fold1 + fold2, fold3)
]
    
###########################
#### Cross Validation #####
###########################

fold_hist = []
for fold_step, (train_index, val_index) in enumerate(train_val_split):
    print(f"Start fold {fold_step}")
    print("Train", train_index)
    print("Val", val_index)
    X_train = X_Train[train_index]
    y_train = y_Train[train_index]
    X_val   = X_Train[val_index]
    y_val   = y_Train[val_index]
    
    ###################
    #### Modeling #####
    ###################
    model = EEGNet(nb_classes=4,
                   Chans=22,
                   Samples=256, # 2 seconds at 128 Hz
                   dropoutRate=0.5,
                   kernLength=32, # for SMR data
                   F1=8,
                   D=2,
                   F2=16,
                   norm_rate=0.25, # FC layer
                   dropoutType="Dropout")
    
    ####################
    #    Training      #
    ####################
    # set a valid path for your system to record model checkpoints
    checkpointer = ModelCheckpoint(filepath=f'results/{name}_subject{subject_id}/{name}_subject{subject_id}_fold{fold_step}.h5', verbose=1,
                                   save_best_only=True)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    train_hist = model.fit(X_train, y_train, 
             batch_size=64, epochs=500,
             verbose=2, validation_data=(X_val, y_val),
             callbacks=[checkpointer])
    
    # load optimal weights
    model.load_weights(f'results/{name}_subject{subject_id}/{name}_subject{subject_id}_fold{fold_step}.h5')
    test_hist = model.evaluate(X_test,  y_test, verbose=2)
    
    fold_hist.append({"train_hist":train_hist.history, "test_hist":test_hist})
    
fold_best_val_acc = [np.max(h["train_hist"]["val_accuracy"]) for h in fold_hist]
fold_test_acc = [h["test_hist"][1] for h in fold_hist]  # index 1 = accuracy
fold_mean_test_acc = np.mean(fold_test_acc)

print("\n##############################################")
print("subject", subject_id, "- 3-fold best val accuracy:", fold_best_val_acc)
print("subject", subject_id, "- 3-fold test accuracy:", fold_test_acc)
print("subject", subject_id, "- 3-fold mean test accuracy:", fold_mean_test_acc)

with open(f"results/{name}_subject{subject_id}/fold_hist.pkl", "wb") as f:
    pickle.dump(fold_hist, f)
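Two quick shape checks for the script above, as a minimal sketch. The window arithmetic assumes that each BNCI2014001 trial created through MOABB/braindecode is annotated as 4 s long starting at the cue (an assumption about the annotations, not something the script checks); with the offsets used above this leaves 0.5 s to 2.5 s after the cue. EEGModels' EEGNet declares its input as `Input(shape=(Chans, Samples, 1))`, so adding the trailing axis explicitly makes the expected shape obvious instead of relying on Keras to reconcile a 3-D array with a 4-D input. `X_demo` is a hypothetical stand-in for `X_Train`.

```python
import numpy as np

# Values from the script above; trial_duration_s = 4.0 is an assumption about
# how the BNCI2014001 trials are annotated (cue onset + 4 s).
sfreq = 128
trial_duration_s = 4.0
trial_start_offset_seconds = 0.5
trial_stop_offset_seconds = -1.5

window_s = trial_duration_s + trial_stop_offset_seconds - trial_start_offset_seconds
n_samples = int(round(window_s * sfreq))
print(window_s, n_samples)  # 2.0 256 -> matches Samples=256 in the EEGNet call

# Hypothetical stand-in for X_Train with shape (trials, chans, samples)
X_demo = np.zeros((288, 22, n_samples), dtype=np.float32)
# Add the trailing axis expected by Input(shape=(Chans, Samples, 1)) with channels_last
X_demo = X_demo[..., np.newaxis]
print(X_demo.shape)  # (288, 22, 256, 1)
```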
@robintibor
Contributor

robintibor commented Sep 22, 2021

@realblack0
Author

realblack0 commented Sep 22, 2021

I just tried EEGNetv4 from braindecode, but the result was the same (mean accuracy 61%). I think the problem lies in the preprocessing. Could you give me some advice on preprocessing?

@okbalefthanded

@realblack0 See the Colab notebook below. The preprocessing is taken from https://github.com/iis-eth-zurich/eeg-tcnet and consists only of segmentation and normalization; no filtering is applied. I used it to test the equivalence of the EEGNet implementations in Keras and Braindecode.

Results:
 - Keras (EEGNet):
    - Mean accuracy over the whole dataset: 0.72126762548198 (std. 0.08706506238035744)
    - Mean kappa over the whole dataset: 0.6284418754376309 (std. 0.11591307087949579)
 - PyTorch (EEGNetV4):
    - Mean accuracy over the whole dataset: 0.7274140515382533 (std. 0.06291401587087057)
    - Mean kappa over the whole dataset: 0.6364804236516348 (std. 0.08392478435633635)

The notebook is self-contained and covers everything from data download to evaluation. Here it is if you want to play with it:
https://colab.research.google.com/drive/1ANF8PwvtUPawTeQt4Uu4iwscpyhHBgvM?usp=sharing
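The accuracy and kappa figures above are tied together. Cohen's kappa is (p_o - p_e) / (1 - p_e), where p_o is the observed accuracy and p_e the chance agreement, i.e. the sum over classes of P(true = k) * P(pred = k). If the true labels are balanced over the four classes (72 trials per class in each BCI IV 2a session), p_e = 0.25 whatever the predictions are, so kappa = (acc - 0.25) / 0.75. For the Keras mean accuracy, (0.7213 - 0.25) / 0.75 ≈ 0.628, close to the reported mean kappa. A minimal numpy sketch of the computation; the function name `cohen_kappa` and the toy labels are illustrative, not taken from the notebook:

```python
import numpy as np


def cohen_kappa(y_true, y_pred, n_classes):
    """Cohen's kappa from integer class labels: (p_o - p_e) / (1 - p_e)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    p_o = np.mean(y_true == y_pred)  # observed agreement = accuracy
    # chance agreement: sum over classes of P(true = k) * P(pred = k)
    p_true = np.bincount(y_true, minlength=n_classes) / len(y_true)
    p_pred = np.bincount(y_pred, minlength=n_classes) / len(y_pred)
    p_e = np.sum(p_true * p_pred)
    return (p_o - p_e) / (1 - p_e)


# Balanced 4-class toy example: 72 trials per class, 288 in total
y_true = np.repeat(np.arange(4), 72)
y_pred = y_true.copy()
y_pred[::4] = (y_pred[::4] + 1) % 4  # corrupt every 4th trial -> 75% accuracy
print(cohen_kappa(y_true, y_pred, 4))  # (0.75 - 0.25) / 0.75 = 0.666...
```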

@YoloEliwa

Thanks for your reply. But where is the normalization in the Colab notebook? I only see segmentation.

@okbalefthanded

The normalization is applied with these lines:

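    from sklearn.preprocessing import StandardScaler

    # X_train / X_test have shape (trials, 1, chans, samples)
    for j in range(22):  # one scaler per EEG channel
        scaler = StandardScaler()
        scaler.fit(X_train[:, 0, j, :])  # statistics from the training set only
        X_train[:, 0, j, :] = scaler.transform(X_train[:, 0, j, :])
        X_test[:, 0, j, :] = scaler.transform(X_test[:, 0, j, :])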
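What the loop computes, spelled out: for each channel the scaler is fitted on a (trials, samples) matrix, and StandardScaler standardizes each column of that matrix, so every time sample of every channel is centered and scaled across the training trials. The same training statistics are then applied to the test set. A vectorized numpy sketch of the same operation, assuming the (trials, 1, chans, samples) layout used in the snippet; `standardize_train_test` is a hypothetical helper name:

```python
import numpy as np


def standardize_train_test(X_train, X_test):
    """Vectorized equivalent of the per-channel StandardScaler loop (a sketch).

    Both arrays have shape (trials, 1, chans, samples). Mean and standard
    deviation are taken over the trial axis of the training set only, i.e.
    separately for every (channel, time sample) pair.
    """
    mean = X_train.mean(axis=0, keepdims=True)
    std = X_train.std(axis=0, keepdims=True)  # population std (ddof=0), as StandardScaler uses
    std = np.where(std == 0, 1.0, std)        # leave constant columns unscaled, as StandardScaler does
    return (X_train - mean) / std, (X_test - mean) / std


rng = np.random.default_rng(0)
X_tr = rng.normal(5.0, 3.0, size=(96, 1, 22, 256))
X_te = rng.normal(5.0, 3.0, size=(288, 1, 22, 256))
X_tr_s, X_te_s = standardize_train_test(X_tr, X_te)
print(np.abs(X_tr_s.mean(axis=0)).max() < 1e-9)  # True: training columns are centered
```

Because only training-set statistics are used, nothing from the evaluation session leaks into the scaling.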

@martinwimpff

@okbalefthanded: I had the same issue as @realblack0, and scaling the input resolved it for EEGNet.

I also tried to reproduce the ShallowNet results with the original braindecode model (PyTorch/skorch).
The main issue there was regularization. Adding a kernel_constraint to the first (or first two) layers did not help, but adding one to the final layer improved accuracy by roughly 10% across all subjects and resolved the issue.

@robintibor: Maybe this Conv2dWithConstraint should be added to the original ShallowNet implementation?
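One way to follow this suggestion for a final linear layer in PyTorch, as a sketch. The hypothetical `LinearWithMaxNorm` below is modeled on braindecode's `Conv2dWithConstraint`, which renormalizes its weights with `torch.renorm` before each forward pass. `torch.renorm(w, p=2, dim=0, maxnorm=c)` caps the L2 norm of every row of the `(out_features, in_features)` weight, i.e. the incoming weights of each output unit, which is what Keras' `max_norm` does to a Dense kernel. The value 0.25 is taken from `norm_rate` in the script above; the value appropriate for ShallowNet is not stated in this thread.

```python
import torch
from torch import nn


class LinearWithMaxNorm(nn.Linear):
    """nn.Linear whose per-output-unit weight norm is capped at max_norm.

    Hypothetical helper, modeled on braindecode's Conv2dWithConstraint.
    """

    def __init__(self, in_features, out_features, max_norm=0.25, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.max_norm = max_norm

    def forward(self, x):
        with torch.no_grad():
            # weight: (out_features, in_features); dim=0 renormalizes each row,
            # i.e. the incoming weight vector of one output unit (bias untouched)
            self.weight.copy_(
                torch.renorm(self.weight, p=2, dim=0, maxnorm=self.max_norm)
            )
        return super().forward(x)


layer = LinearWithMaxNorm(16, 4, max_norm=0.25)
with torch.no_grad():
    layer.weight.mul_(100.0)  # push the weights far above the cap
out = layer(torch.randn(8, 16))
print(out.shape)                 # torch.Size([8, 4])
print(layer.weight.norm(dim=1))  # every value is at most 0.25
```

Keras applies constraints right after each optimizer step, whereas this sketch applies them right before each forward pass. The weights used in every forward pass are capped either way, but a `state_dict` saved straight after an optimizer step may still hold the unconstrained values.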

@okbalefthanded

@martinwimpff Indeed, the kernel constraint has a significant effect on the results. In my PyTorch implementation of the EEGModels code, I re-implemented the kernel constraint for both the Conv2D and Linear layers as it is implemented in Keras.
The MaxNorm function is:

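import torch

def MaxNorm(tensor, max_value, axis=0):
    """Rescale `tensor` so its L2 norms along `axis` are at most `max_value` (Keras max_norm).

    Note: `axis` must match PyTorch's weight layout, e.g. nn.Linear.weight is
    (out, in), so Keras' Dense default axis=0 corresponds to axis=1 here.
    """
    eps = 1e-7
    norms = torch.sqrt(torch.sum(torch.square(tensor), dim=axis, keepdim=True))
    desired = torch.clamp(norms, 0, max_value)
    return tensor * (desired / (norms + eps))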
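The `axis` argument is where a direct port from Keras can go wrong. Keras stores a Dense kernel as (in_features, out_features), and its max_norm default `axis=0` constrains the incoming weights of each output unit; PyTorch's `nn.Linear.weight` is (out_features, in_features), so the same constraint needs `axis=1`. Likewise, a Keras Conv2D kernel (rows, cols, in, out) constrained over `axis=[0, 1, 2]` corresponds to dims (1, 2, 3) of a PyTorch (out, in, rows, cols) weight. A numpy transcription of the function above (renamed `max_norm_np`, a hypothetical helper) makes the difference easy to check:

```python
import numpy as np


def max_norm_np(w, max_value, axis=0, eps=1e-7):
    """Numpy transcription of MaxNorm above (same math, Keras semantics)."""
    norms = np.sqrt(np.sum(np.square(w), axis=axis, keepdims=True))
    desired = np.clip(norms, 0, max_value)
    return w * (desired / (norms + eps))


rng = np.random.default_rng(0)
keras_kernel = rng.normal(size=(16, 4))  # Keras Dense kernel: (in_features, out_features)
torch_weight = keras_kernel.T.copy()     # PyTorch nn.Linear.weight: (out_features, in_features)

keras_result = max_norm_np(keras_kernel, 0.25, axis=0)  # Keras default axis
torch_right = max_norm_np(torch_weight, 0.25, axis=1)   # same constraint, PyTorch layout
torch_wrong = max_norm_np(torch_weight, 0.25, axis=0)   # axis=0 copied verbatim

print(np.allclose(keras_result, torch_right.T))  # True
print(np.allclose(keras_result, torch_wrong.T))  # False
```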

@wwwyz02

wwwyz02 commented Dec 15, 2023

Hello, I would like to replicate EEGNet in PyTorch and validate it on the BCI Competition IV 2a dataset, but I am getting low accuracy. Could you give some guidance on how to set max_value for MaxNorm? Could you also describe the steps you followed to preprocess the dataset? Did you normalize the data? And could you share your code? @okbalefthanded

@okbalefthanded

Using the same default values as the Keras/TF implementation (in EEGModels' EEGNet: max_norm(1.) on the depthwise convolution and max_norm(norm_rate), default 0.25, on the final Dense layer) should give similar results with the PyTorch version. Check your preprocessing first; it is the most crucial part. And yes, we normalize the data before training.
For easier comparison, I suggest running this Colab notebook, where I reproduce the Keras implementations of both EEGNet and EEG-TCNet:
https://github.com/okbalefthanded/eeg-tcnet/blob/master/eeg_tcnet_colab.ipynb
