Demo Codes

Simple Demo

This demo shows the basic functions of this toolbox, including:

  • How to initialize a dataset and register the required hook functions (preprocessing and filterbank).

  • How to initialize a recognition model.

  • How to get the training data from the dataset, and train the model.

  • How to get the testing data from the dataset, and test the model.

  • How to calculate the classification accuracy and the ITR.

Demo file: demo/simple_example.py

# -*- coding: utf-8 -*-

import sys
sys.path.append('..')
from SSVEPAnalysisToolbox.datasets import BenchmarkDataset
from SSVEPAnalysisToolbox.utils.benchmarkpreprocess import (
    preprocess, filterbank, suggested_ch, suggested_weights_filterbank
)
from SSVEPAnalysisToolbox.algorithms import (
    SCCA_qr, SCCA_canoncorr, ECCA, MSCCA, MsetCCA, MsetCCAwithR,
    TRCA, ETRCA, MSETRCA, MSCCA_and_MSETRCA, TRCAwithR, ETRCAwithR, SSCOR, ESSCOR,
    TDCA
)
from SSVEPAnalysisToolbox.evaluator import cal_acc, cal_itr

import time

num_subbands = 5

# Prepare dataset
dataset = BenchmarkDataset(path = '2016_Tsinghua_SSVEP_database')
dataset.regist_preprocess(preprocess)
dataset.regist_filterbank(filterbank)

# Prepare recognition model
weights_filterbank = suggested_weights_filterbank()
recog_model = ETRCAwithR(weights_filterbank = weights_filterbank)

# Set simulation parameters
ch_used = suggested_ch()
all_trials = [i for i in range(dataset.trial_num)]
harmonic_num = 5
tw = 1
sub_idx = 1
test_block_idx = 0
test_block_list, train_block_list = dataset.leave_one_block_out(block_idx = test_block_idx)

# Get training data and train the recognition model
ref_sig = dataset.get_ref_sig(tw, harmonic_num)
freqs = dataset.stim_info['freqs']
X_train, Y_train = dataset.get_data(sub_idx = sub_idx,
                                    blocks = train_block_list,
                                    trials = all_trials,
                                    channels = ch_used,
                                    sig_len = tw)
tic = time.time()
recog_model.fit(X=X_train, Y=Y_train, ref_sig=ref_sig, freqs=freqs) 
toc_train = time.time()-tic

# Get testing data and test the recognition model
X_test, Y_test = dataset.get_data(sub_idx = sub_idx,
                                    blocks = test_block_list,
                                    trials = all_trials,
                                    channels = ch_used,
                                    sig_len = tw)
tic = time.time()
pred_label, _ = recog_model.predict(X_test)
toc_test = time.time()-tic
toc_test_onetrial = toc_test/len(Y_test)

# Calculate performance
acc = cal_acc(Y_true = Y_test, Y_pred = pred_label)
itr = cal_itr(tw = tw, t_break = dataset.t_break, t_latency = dataset.default_t_latency, t_comp = toc_test_onetrial,
              N = len(freqs), acc = acc)
print("""
Simulation Information:
    Method Name: {:s}
    Dataset: {:s}
    Signal length: {:.3f} s
    Channel: {:s}
    Subject index: {:n}
    Testing block: {:s}
    Training block: {:s}
    Training time: {:.5f} s
    Total Testing time: {:.5f} s
    Testing time of single trial: {:.5f} s

Performance:
    Acc: {:.3f} %
    ITR: {:.3f} bits/min
""".format(recog_model.ID,
           dataset.ID,
           tw,
           str(ch_used),
           sub_idx,
           str(test_block_list),
           str(train_block_list),
           toc_train,
           toc_test,
           toc_test_onetrial,
           acc*100,
           itr))

Recognition Performance in Benchmark Dataset

This demo verifies individual performance on the Benchmark Dataset with various signal lengths. The classification accuracies, ITRs, training times, testing times, and confusion matrices are computed and stored in res/benchmarkdataset_res.mat.

This demo shows the following points:

  • How to use the Benchmark Dataset. When you first run this demo, the dataset will be downloaded into the folder 2016_Tsinghua_SSVEP_database.

  • How to create recognition models.

  • How to use the provided evaluator BaseEvaluator to verify recognition performance.

Tip

  • This demo uses gen_trials_onedataset_individual_diffsiglen to generate the evaluation trials used by BaseEvaluator. These trials evaluate individual performance on various signal lengths. If that is not your target, you can follow this function to prepare your own evaluation trials.

  • This demo uses cal_performance_onedataset_individual_diffsiglen and cal_confusionmatrix_onedataset_individual_diffsiglen to calculate the recognition performance (accuracies and ITRs) and the confusion matrices. These two functions likewise target individual performance on various signal lengths. For your own evaluation trials, you can follow these two functions to compute your own performance metrics.

  • You can adjust the number of threads by changing n_jobs in evaluator.run. A higher number requires a more powerful computer. The current demo may take more than 24 hours. You may reduce the number of models or the number of signal lengths to shorten the running time.

  • ITRs are related to computational time, and different implementations may lead to different computational times. You may check the recorded testing time to see how much time enters the ITR computation. We are also trying to optimize the implementations to reduce the computational time. For example, the sCCA implemented with the QR decomposition is faster than the sCCA implemented with the conventional canonical correlation while achieving the same performance, as shown in the “Plot Recognition Performance” demo. A sketch of how these times typically enter the ITR is given below.
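
For reference, below is a minimal sketch of the standard Wolpaw ITR formula as it is commonly applied to SSVEP spellers, with the per-trial computational time t_comp folded into the selection time. How cal_itr combines these times internally is an assumption here; consult its implementation for the exact definition.

import numpy as np

def itr_bits_per_min(n_targets, acc, t_gaze, t_break, t_latency, t_comp):
    # Bits transferred per selection (standard Wolpaw formula)
    if acc >= 1.0:
        bits = np.log2(n_targets)
    elif acc <= 1.0 / n_targets:
        bits = 0.0  # at or below chance level, no information transferred
    else:
        bits = (np.log2(n_targets) + acc * np.log2(acc)
                + (1 - acc) * np.log2((1 - acc) / (n_targets - 1)))
    # Time per selection: gaze window plus break, visual latency, and
    # per-trial computation (assumption: cal_itr may combine these differently)
    t_trial = t_gaze + t_break + t_latency + t_comp
    return bits * 60.0 / t_trial

# e.g. 40 targets, 90% accuracy, 1-s window, 0.5-s break, 0.14-s latency:
# itr = itr_bits_per_min(40, 0.9, 1.0, 0.5, 0.14, 0.01)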

Demo file: demo/benchmarkdataset.py

# -*- coding: utf-8 -*-

import sys
sys.path.append('..')
from SSVEPAnalysisToolbox.datasets import BenchmarkDataset
from SSVEPAnalysisToolbox.utils.benchmarkpreprocess import (
    preprocess, filterbank, suggested_ch, suggested_weights_filterbank
)
from SSVEPAnalysisToolbox.algorithms import (
    SCCA_qr, SCCA_canoncorr, ECCA, MSCCA, MsetCCA, MsetCCAwithR,
    TRCA, ETRCA, MSETRCA, MSCCA_and_MSETRCA, TRCAwithR, ETRCAwithR, SSCOR, ESSCOR,
    TDCA
)
from SSVEPAnalysisToolbox.evaluator import (
    BaseEvaluator,
    gen_trials_onedataset_individual_diffsiglen,
    cal_performance_onedataset_individual_diffsiglen, 
    cal_confusionmatrix_onedataset_individual_diffsiglen
)
from SSVEPAnalysisToolbox.utils.io import savedata

import numpy as np

# Prepare dataset
dataset = BenchmarkDataset(path = '2016_Tsinghua_SSVEP_database')
dataset.regist_preprocess(preprocess)
dataset.regist_filterbank(filterbank)
ch_used = suggested_ch()
all_trials = [i for i in range(dataset.trial_num)]
harmonic_num = 5
dataset_container = [
                        dataset
                    ]


# Prepare train and test trials
tw_seq = [i/100 for i in range(25,100+5,5)]
trial_container = gen_trials_onedataset_individual_diffsiglen(dataset_idx = 0,
                                                             tw_seq = tw_seq,
                                                             dataset_container = dataset_container,
                                                             harmonic_num = harmonic_num,
                                                             trials = all_trials,
                                                             ch_used = ch_used,
                                                             t_latency = None,
                                                             shuffle = False)


# Prepare models
weights_filterbank = suggested_weights_filterbank()
model_container = [
                   SCCA_qr(weights_filterbank = weights_filterbank),
                   SCCA_canoncorr(weights_filterbank = weights_filterbank),
                   MsetCCA(weights_filterbank = weights_filterbank),
                   MsetCCAwithR(weights_filterbank = weights_filterbank),
                   ECCA(weights_filterbank = weights_filterbank),
                   MSCCA(n_neighbor = 12, weights_filterbank = weights_filterbank),
                   SSCOR(weights_filterbank = weights_filterbank),
                   ESSCOR(weights_filterbank = weights_filterbank),
                   TRCA(weights_filterbank = weights_filterbank),
                   TRCAwithR(weights_filterbank = weights_filterbank),
                   ETRCA(weights_filterbank = weights_filterbank),
                   ETRCAwithR(weights_filterbank = weights_filterbank),
                   MSETRCA(n_neighbor = 2, weights_filterbank = weights_filterbank),
                   MSCCA_and_MSETRCA(n_neighbor_mscca = 12, n_neighber_msetrca = 2, weights_filterbank = weights_filterbank),
                   TDCA(n_component = 8, weights_filterbank = weights_filterbank, n_delay = 6)
                  ]

# Evaluate models
evaluator = BaseEvaluator(dataset_container = dataset_container,
                          model_container = model_container,
                          trial_container = trial_container,
                          save_model = False,
                          disp_processbar = True)

evaluator.run(n_jobs = 10,
              eval_train = False)
evaluator_file = 'res/benchmarkdataset_evaluator.pkl'
evaluator.save(evaluator_file)

# Calculate performance
acc_store, itr_store = cal_performance_onedataset_individual_diffsiglen(evaluator = evaluator,
                                                                         dataset_idx = 0,
                                                                         tw_seq = tw_seq,
                                                                         train_or_test = 'test')
confusion_matrix = cal_confusionmatrix_onedataset_individual_diffsiglen(evaluator = evaluator,
                                                                        dataset_idx = 0,
                                                                        tw_seq = tw_seq,
                                                                        train_or_test = 'test')                                                                       

# Calculate training time and testing time
train_time = np.zeros((len(model_container), len(evaluator.performance_container)))
test_time = np.zeros((len(model_container), len(evaluator.performance_container)))
for trial_idx, performance_trial in enumerate(evaluator.performance_container):
    for method_idx, method_performance in enumerate(performance_trial):
        train_time[method_idx, trial_idx] = sum(method_performance.train_time)
        test_time[method_idx, trial_idx] = sum(method_performance.test_time_test)
train_time = train_time.T
test_time = test_time.T
            
# Save results
data = {"acc_store": acc_store,
        "itr_store": itr_store,
        "train_time": train_time,
        "test_time": test_time,
        "confusion_matrix": confusion_matrix,
        "tw_seq":tw_seq,
        "method_ID": [model.ID for model in model_container]}
data_file = 'res/benchmarkdataset_res.mat'
savedata(data_file, data, 'mat')






Plot Recognition Performance

This demo uses bar graphs with error bars and shadow lines to plot classification accuracies, ITRs, training times, and testing times. Before running this demo, please run the above two demos to obtain the res/benchmarkdataset_res.mat and res/betadataset_res.mat files.

This demo shows the following points:

  • How to use the provided bar_plot_with_errorbar to plot bar graphs with error bars.

  • How to use the provided shadowline_plot to plot shadow lines.

Demo file: demo/plot_performance.py

# -*- coding: utf-8 -*-

import sys
import os
sys.path.append('..')
from SSVEPAnalysisToolbox.utils.io import loaddata
from SSVEPAnalysisToolbox.evaluator import (
    bar_plot_with_errorbar, shadowline_plot, close_fig
)

import numpy as np

data_file_list = ['res/benchmarkdataset_res.mat',
                  'res/betadataset_res.mat',
                  'res/nakanishidataset_res.mat',
                  'res/eldbetadataset_res.mat',
                  'res/openbmidataset_res.mat',
                  'res/wearable_dry_res.mat',
                  'res/wearable_wet_res.mat']
sub_title = ['benchmark',
             'beta',
             'nakanishi',
             'eldbeta',
             'openbmi',
             'wearable_dry',
             'wearable_wet']


for dataset_idx, data_file in enumerate(data_file_list):
    if not os.path.isfile(data_file):
        print("'{:s}' does not exist.".format(data_file))
        continue
    data = loaddata(data_file, 'mat')
    acc_store = data["acc_store"]
    itr_store = data["itr_store"]
    train_time = data["train_time"]
    test_time = data["test_time"]
    tw_seq = data["tw_seq"]
    method_ID = data["method_ID"]
    method_ID = [name.strip() for name in method_ID]

    # Plot training time and testing time
    fig, _ = bar_plot_with_errorbar(train_time,
                                    x_label = 'Methods',
                                    y_label = 'Training time (s)',
                                    x_ticks = method_ID,
                                    grid = True,
                                    figsize=[6.4*2, 4.8])
    fig.savefig('res/{:s}_traintime_bar.jpg'.format(sub_title[dataset_idx]), bbox_inches='tight', dpi=300)
    close_fig(fig)
    

    fig, _ = bar_plot_with_errorbar(test_time,
                                    x_label = 'Methods',
                                    y_label = 'Testing time (s)',
                                    x_ticks = method_ID,
                                    grid = True,
                                    figsize=[6.4*2, 4.8])
    fig.savefig('res/{:s}_testtime_bar.jpg'.format(sub_title[dataset_idx]), bbox_inches='tight', dpi=300)
    close_fig(fig)


    # Plot Performance of bar plots
    fig, _ = bar_plot_with_errorbar(acc_store,
                                    x_label = 'Signal Length (s)',
                                    y_label = 'Acc',
                                    x_ticks = tw_seq,
                                    legend = method_ID,
                                    errorbar_type = '95ci',
                                    grid = True,
                                    ylim = [0, 1],
                                    figsize=[6.4*3, 4.8])
    fig.savefig('res/{:s}_acc_bar.jpg'.format(sub_title[dataset_idx]), bbox_inches='tight', dpi=300)
    close_fig(fig)

    fig, _ = bar_plot_with_errorbar(itr_store,
                                    x_label = 'Signal Length (s)',
                                    y_label = 'ITR (bits/min)',
                                    x_ticks = tw_seq,
                                    legend = method_ID,
                                    errorbar_type = '95ci',
                                    grid = True,
                                    figsize=[6.4*3, 4.8])
    fig.savefig('res/{:s}_itr_bar.jpg'.format(sub_title[dataset_idx]), bbox_inches='tight', dpi=300)
    close_fig(fig)

    # Plot Performance of shadow lines
    fig, _ = shadowline_plot(tw_seq,
                            acc_store,
                            'x-',
                            x_label = 'Signal Length (s)',
                            y_label = 'Acc',
                            legend = method_ID,
                            errorbar_type = '95ci',
                            grid = True,
                            ylim = [0, 1],
                            figsize=[6.4*3, 4.8])
    fig.savefig('res/{:s}_acc_shadowline.jpg'.format(sub_title[dataset_idx]), bbox_inches='tight', dpi=300)
    close_fig(fig)

    fig, _ = shadowline_plot(tw_seq,
                            itr_store,
                            'x-',
                            x_label = 'Signal Length (s)',
                            y_label = 'ITR (bits/min)',
                            legend = method_ID,
                            errorbar_type = '95ci',
                            grid = True,
                            figsize=[6.4*3, 4.8])
    fig.savefig('res/{:s}_itr_shadowline.jpg'.format(sub_title[dataset_idx]), bbox_inches='tight', dpi=300)
    close_fig(fig)

Generated graphs are stored in demo/res. One example of results is shown below.

  • Classification accuracies of the Benchmark Dataset:

    _images/benchmark_acc_bar.jpg

Plot Confusion Matrices

This demo provides an example of plotting confusion matrices. It directly uses imshow in matplotlib. You can also use heatmap in seaborn [1] or plot_confusion_matrix in sklearn [2] to plot these confusion matrices. The demo only shows the confusion matrices at the 0.5-s signal length, which is controlled by target_time in the demo file. Moreover, all subjects’ confusion matrices are summed together.

[1] Plot confusion matrices using seaborn

[2] Plot confusion matrices using sklearn
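
As an illustration of the seaborn alternative, below is a minimal sketch that renders a confusion matrix with seaborn.heatmap. The matrix is random placeholder data, and the file name and styling are only suggestions, not part of the toolbox.

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Placeholder 40x40 confusion matrix (random data, for illustration only)
cm = np.random.randint(0, 10, size=(40, 40))

fig, ax = plt.subplots()
sns.heatmap(cm, cmap='winter', square=True,
            xticklabels=5, yticklabels=5,  # draw every 5th tick label
            cbar_kws={'label': 'Trial count'}, ax=ax)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
fig.savefig('res/confusionmatrix_heatmap_example.jpg', bbox_inches='tight', dpi=300)
plt.close(fig)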

Demo file: demo/plot_confusion_matrix.py

# -*- coding: utf-8 -*-

import sys
sys.path.append('..')
from SSVEPAnalysisToolbox.datasets import BenchmarkDataset
from SSVEPAnalysisToolbox.datasets import BETADataset
from SSVEPAnalysisToolbox.utils.io import loaddata
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os

import numpy as np

data_file_list = ['res/benchmarkdataset_res.mat',
                  'res/betadataset_res.mat']
save_folder = ['benchmarkdataset_confusionmatrix',
               'beta_confusionmatrix']
dataset_container = [
                        BenchmarkDataset(path = '2016_Tsinghua_SSVEP_database'),
                        BETADataset(path = '2020_BETA_SSVEP_database_update')
                    ]
target_time = 0.5

for dataset_idx, data_file in enumerate(data_file_list):
    if not os.path.isfile(data_file):
        print("'{:s}' does not exist.".format(data_file))
        continue
    data = loaddata(data_file, 'mat')
    confusion_matrix = data["confusion_matrix"]
    method_ID = data["method_ID"]
    tw_seq = data["tw_seq"]
    freqs = dataset_container[dataset_idx].stim_info['freqs']
    sort_idx = list(np.argsort(freqs))

    signal_len_idx = int(np.where(np.array(tw_seq) == target_time)[0][0])

    # for signal_len_idx in range(len(tw_seq)):
    for method_idx, method in enumerate(method_ID):
        confusion_matrix_plot = confusion_matrix[method_idx, :, signal_len_idx, :, :]
        confusion_matrix_plot = np.sum(confusion_matrix_plot, axis = 0)
        confusion_matrix_plot = confusion_matrix_plot[sort_idx,:]
        confusion_matrix_plot = confusion_matrix_plot[:,sort_idx]
        N, _ = confusion_matrix_plot.shape
        min_v = 0
        max_v = np.amax(np.reshape(confusion_matrix_plot - np.diag(np.diag(confusion_matrix_plot)),(-1)))

        fig = plt.figure()
        ax = fig.add_axes([0,0,1,1])

        im = ax.imshow(confusion_matrix_plot,
                        interpolation = 'none',
                        origin = 'upper',
                        vmin = min_v,
                        vmax = max_v,
                        cmap='winter')

        for n in range(N):
            ax.add_patch(
                patches.Rectangle(xy=(n-0.5, n-0.5), width=1, height=1, facecolor='white')
            )
        for i in range(N):
            for j in range(N):
                if i==j:
                    text_color = 'black'
                else:
                    text_color = 'white'
                ax.text(i,j,"{:n}".format(int(confusion_matrix_plot[j,i])),
                    fontsize=5,
                    horizontalalignment='center',
                    verticalalignment='center',
                    color=text_color)
        ax.figure.colorbar(im, ax=ax)
        ax.set_xticks(list(range(N)))
        ax.set_yticks(list(range(N)))
        ax.spines[:].set_visible(False)
        ax.grid(which="minor", color="black", linestyle='-', linewidth=10)
        ax.tick_params(top=True, bottom=False,
                        labeltop=True, labelbottom=False)
        ax.tick_params(which="minor", bottom=False, left=False)
        ax.tick_params(axis='x',labelsize=5)
        ax.tick_params(axis='y',labelsize=5)
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')

        save_path = 'res/{:s}/{:s}_T{:n}.jpg'.format(save_folder[dataset_idx],
                                                        method_ID[method_idx].strip(),
                                                        tw_seq[signal_len_idx])
        destination_dir = os.path.dirname(save_path)
        if not os.path.exists(destination_dir):
            os.makedirs(destination_dir)
        fig.savefig(save_path, 
                    bbox_inches='tight', dpi=300)
        plt.close(fig)

Generated graphs are stored in demo/res/benchmarkdataset_confusionmatrix and demo/res/beta_confusionmatrix. One example of results is shown below.

  • eCCA (0.5s) in Benchmark Dataset

    _images/eCCA_T0.5.jpg

Calculate SNR

This demo shows how to calculate SNRs and how to plot their distributions. A sketch of one common FFT-based SNR definition is given below.
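
Before the demo file, here is a minimal sketch of one common FFT-based SNR definition for SSVEP: the power at the stimulus frequency relative to the mean power of the neighboring bins. This is an illustrative assumption, the function name is hypothetical, and the toolbox's get_snr (and its 'sine' variant) may define the signal and noise components differently.

import numpy as np

def fft_snr_db(x, srate, f_stim, nfft, half_band=1.0):
    # Power spectrum of a single-channel epoch
    spec = np.abs(np.fft.rfft(x, n=nfft)) ** 2
    freqs = np.fft.rfftfreq(nfft, d=1/srate)
    # Stimulus bin versus the mean of bins within +/- half_band Hz
    target = int(np.argmin(np.abs(freqs - f_stim)))
    neighbors = np.abs(freqs - f_stim) <= half_band
    neighbors[target] = False  # exclude the stimulus bin itself
    return 10 * np.log10(spec[target] / spec[neighbors].mean())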

Demo file: demo/plot_snr.py

# -*- coding: utf-8 -*-

import sys
sys.path.append('..')

from SSVEPAnalysisToolbox.datasets import (
    BenchmarkDataset, BETADataset, ELDBETADataset, NakanishiDataset, openBMIDataset,
    WearableDataset_wet, WearableDataset_dry
)

from SSVEPAnalysisToolbox.evaluator import (
    hist, close_fig, gen_colors
)
from SSVEPAnalysisToolbox.utils.io import savedata
from SSVEPAnalysisToolbox.utils.algsupport import nextpow2

harmonic_num = 5
sig_len = 1
# filterbank index is 0

for snr_type in ['fft','sine']:
    # Reset containers for each SNR type so that the 'fft' and 'sine'
    # results are stored and plotted separately
    snr_list = []
    legend = []
    dataset_no = 0

    # Benchmark dataset
    from SSVEPAnalysisToolbox.utils.benchmarkpreprocess import (
        preprocess, filterbank, suggested_ch
    )
    dataset = BenchmarkDataset(path = '2016_Tsinghua_SSVEP_database')
    dataset.regist_preprocess(preprocess)
    dataset.regist_filterbank(filterbank)
    print("{:s}: ".format(dataset.ID))
    if snr_type == 'sine':
        snr = dataset.get_snr(type = 'sine', ch_used_recog=suggested_ch(), display_progress = True)
        snr_list.append(snr)
    else:
        snr = dataset.get_snr(Nh = harmonic_num, display_progress = True, 
                            sig_len = sig_len,
                            remove_break = False, remove_pre_and_latency = False,
                            NFFT = 2 ** nextpow2(10*dataset.srate)) 
        snr_list.append(snr[:,:,:,suggested_ch()])
    legend.append(dataset.ID)
    dataset_no += 1

    # BETA dataset
    dataset = BETADataset(path = '2020_BETA_SSVEP_database_update')
    dataset.regist_preprocess(preprocess)
    dataset.regist_filterbank(filterbank)
    print("{:s}: ".format(dataset.ID))
    if snr_type == 'sine':
        snr = dataset.get_snr(type = 'sine', ch_used_recog=suggested_ch(), display_progress = True)
        snr_list.append(snr)
    else:
        snr = dataset.get_snr(Nh = harmonic_num, display_progress = True, 
                            sig_len = sig_len,
                            remove_break = False, remove_pre_and_latency = False,
                            NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
        snr_list.append(snr[:,:,:,suggested_ch()])
    legend.append(dataset.ID)
    dataset_no += 1

    # eldBETA dataset
    dataset = ELDBETADataset(path = 'eldBETA_database')
    dataset.regist_preprocess(preprocess)
    dataset.regist_filterbank(filterbank)
    print("{:s}: ".format(dataset.ID))
    if snr_type == 'sine':
        snr = dataset.get_snr(type = 'sine', ch_used_recog=suggested_ch(), display_progress = True)
        snr_list.append(snr)
    else:
        snr = dataset.get_snr(Nh = harmonic_num, display_progress = True, 
                            sig_len = sig_len,
                            remove_break = False, remove_pre_and_latency = False,
                            NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
        snr_list.append(snr[:,:,:,suggested_ch()])
    legend.append(dataset.ID)
    dataset_no += 1

    # Nakanishi dataset
    from SSVEPAnalysisToolbox.utils.nakanishipreprocess import (
        preprocess, filterbank, suggested_ch
    )
    dataset = NakanishiDataset(path = 'Nakanishi_2015')
    dataset.regist_preprocess(preprocess)
    dataset.regist_filterbank(filterbank)
    print("{:s}: ".format(dataset.ID))
    if snr_type == 'sine':
        snr = dataset.get_snr(type = 'sine', ch_used_recog=suggested_ch(), display_progress = True)
        snr_list.append(snr)
    else:
        snr = dataset.get_snr(Nh = harmonic_num, display_progress = True, 
                            sig_len = sig_len,
                            remove_break = False, remove_pre_and_latency = False,
                            NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
        snr_list.append(snr[:,:,:,suggested_ch()])
    legend.append(dataset.ID)
    dataset_no += 1

    # openBMI dataset
    from SSVEPAnalysisToolbox.utils.openbmipreprocess import (
        preprocess, filterbank, suggested_ch, ref_sig_fun
    )
    dataset = openBMIDataset(path = 'openBMI')
    downsample_srate = 100
    dataset.regist_preprocess(lambda dataself, X: preprocess(dataself, X, downsample_srate))
    dataset.regist_filterbank(lambda dataself, X: filterbank(dataself, X, downsample_srate))
    dataset.regist_ref_sig_fun(lambda dataself, sig_len, N, phases: ref_sig_fun(dataself, sig_len, N, phases, downsample_srate))
    print("{:s}: ".format(dataset.ID))
    if snr_type == 'sine':
        snr = dataset.get_snr(type = 'sine', ch_used_recog=suggested_ch(), display_progress = True)
        snr_list.append(snr)
    else:
        snr = dataset.get_snr(Nh = harmonic_num, display_progress = True, 
                            sig_len = sig_len,
                            srate = downsample_srate,
                            remove_break = False, remove_pre_and_latency = False,
                            NFFT = 2 ** nextpow2(10*downsample_srate)) # filterbank index is 0
        snr_list.append(snr[:,:,:,suggested_ch()])
    legend.append(dataset.ID)
    dataset_no += 1

    # Wearable dataset
    from SSVEPAnalysisToolbox.utils.wearablepreprocess import (
        preprocess, filterbank, suggested_ch
    )
    dataset = WearableDataset_wet(path = 'Wearable')
    dataset.regist_preprocess(preprocess)
    dataset.regist_filterbank(lambda dataself, X: filterbank(dataself, X, 5))
    print("{:s}: ".format(dataset.ID))
    if snr_type == 'sine':
        snr = dataset.get_snr(type = 'sine', ch_used_recog=suggested_ch(), display_progress = True)
        snr_list.append(snr)
    else:
        snr = dataset.get_snr(Nh = harmonic_num, display_progress = True, 
                            sig_len = sig_len,
                            remove_break = False, remove_pre_and_latency = False,
                            NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
        snr_list.append(snr[:,:,:,suggested_ch()])
    legend.append(dataset.ID)
    dataset_no += 1

    dataset = WearableDataset_dry(path = 'Wearable')
    dataset.regist_preprocess(preprocess)
    dataset.regist_filterbank(lambda dataself, X: filterbank(dataself, X, 5))
    print("{:s}: ".format(dataset.ID))
    if snr_type == 'sine':
        snr = dataset.get_snr(type = 'sine', ch_used_recog=suggested_ch(), display_progress = True)
        snr_list.append(snr)
    else:
        snr = dataset.get_snr(Nh = harmonic_num, display_progress = True, 
                            sig_len = sig_len,
                            remove_break = False, remove_pre_and_latency = False,
                            NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
        snr_list.append(snr[:,:,:,suggested_ch()])
    legend.append(dataset.ID)
    dataset_no += 1

    # Store results
    data = {"snr_list": snr_list,
            "legend": legend}
    data_file = 'res/snr_' + snr_type + '.mat'
    savedata(data_file, data, 'mat')

    # plot histogram of SNR
    if snr_type == 'sine':
        hist_bins = list(range(-100,0+1))
        hist_range = (-100, 0)
    else:
        hist_bins = list(range(-30,0+1))
        hist_range = (-30, 0)
    color = gen_colors(dataset_no)
    fig, ax = hist(snr_list, bins = hist_bins, range = hist_range, density = True,
                color = color, alpha = 0.3, fit_line = True, line_points = 1000,
                x_label = 'SNR (dB)',
                y_label = 'Probability',
                grid = True,
                legend = legend)
    fig.savefig('res/SNR_' + snr_type + '.jpg', bbox_inches='tight', dpi=300)
    close_fig(fig)

Generated graphs are stored in demo/res. One example of results is shown below.

_images/SNR_fft.jpg

Calculate Phase

This demo shows how to calculate phases and how to plot their distributions. A sketch of the underlying FFT-based phase extraction is given below.
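
Here is a minimal sketch of how a phase can be read off the FFT at the stimulus frequency. The function name is hypothetical; the toolbox's get_phase may additionally reference the measured phase to the nominal stimulus phase (cf. the remove_target_phase flag used in the demo below).

import numpy as np

def fft_phase(x, srate, f_stim, nfft):
    # Complex spectrum of a single-channel epoch
    spec = np.fft.rfft(x, n=nfft)
    freqs = np.fft.rfftfreq(nfft, d=1/srate)
    # Phase angle at the bin closest to the stimulus frequency
    idx = int(np.argmin(np.abs(freqs - f_stim)))
    return np.angle(spec[idx])  # radians in (-pi, pi]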

Demo file: demo/plot_phase.py

# -*- coding: utf-8 -*-

import sys
sys.path.append('..')

from SSVEPAnalysisToolbox.datasets import (
    BenchmarkDataset, BETADataset, ELDBETADataset, NakanishiDataset, openBMIDataset,
    WearableDataset_wet, WearableDataset_dry
)
from SSVEPAnalysisToolbox.utils.benchmarkpreprocess import (
    preprocess, filterbank, suggested_ch
)

from SSVEPAnalysisToolbox.evaluator import (
    polar_phase_shadow, close_fig, gen_colors
)
from SSVEPAnalysisToolbox.utils.io import savedata
from SSVEPAnalysisToolbox.utils.algsupport import nextpow2

phase_list = []
legend = []
dataset_no = 0
sig_len = 1

# Benchmark dataset
dataset = BenchmarkDataset(path = '2016_Tsinghua_SSVEP_database')
dataset.regist_preprocess(preprocess)
dataset.regist_filterbank(filterbank)
print("{:s}: ".format(dataset.ID))
phase = dataset.get_phase(display_progress = True,
                      sig_len = sig_len,
                      remove_break = False, remove_pre_and_latency = False, remove_target_phase = True,
                      NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
phase_list.append(phase[:,:,:,suggested_ch()])
legend.append(dataset.ID)
dataset_no += 1

# BETA dataset
dataset = BETADataset(path = '2020_BETA_SSVEP_database_update')
dataset.regist_preprocess(preprocess)
dataset.regist_filterbank(filterbank)
print("{:s}: ".format(dataset.ID))
phase = dataset.get_phase(display_progress = True,
                      sig_len = sig_len,
                      remove_break = False, remove_pre_and_latency = False, remove_target_phase = True,
                      NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
phase_list.append(phase[:,:,:,suggested_ch()])
legend.append(dataset.ID)
dataset_no += 1

# eldBETA dataset
dataset = ELDBETADataset(path = 'eldBETA_database')
dataset.regist_preprocess(preprocess)
dataset.regist_filterbank(filterbank)
print("{:s}: ".format(dataset.ID))
phase = dataset.get_phase(display_progress = True,
                      sig_len = sig_len,
                      remove_break = False, remove_pre_and_latency = False, remove_target_phase = True,
                      NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
phase_list.append(phase[:,:,:,suggested_ch()])
legend.append(dataset.ID)
dataset_no += 1

# Nakanishi dataset
from SSVEPAnalysisToolbox.utils.nakanishipreprocess import (
    preprocess, filterbank, suggested_ch
)
dataset = NakanishiDataset(path = 'Nakanishi_2015')
dataset.regist_preprocess(preprocess)
dataset.regist_filterbank(filterbank)
print("{:s}: ".format(dataset.ID))
phase = dataset.get_phase(display_progress = True,
                      sig_len = sig_len,
                      remove_break = False, remove_pre_and_latency = False, remove_target_phase = True,
                      NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
phase_list.append(phase[:,:,:,suggested_ch()])
legend.append(dataset.ID)
dataset_no += 1

# openBMI dataset
from SSVEPAnalysisToolbox.utils.openbmipreprocess import (
    preprocess, filterbank, suggested_ch
)
dataset = openBMIDataset(path = 'openBMI')
downsample_srate = 100
dataset.regist_preprocess(lambda dataself, X: preprocess(dataself, X, downsample_srate))
dataset.regist_filterbank(lambda dataself, X: filterbank(dataself, X, downsample_srate))
print("{:s}: ".format(dataset.ID))
phase = dataset.get_phase(display_progress = True,
                      sig_len = sig_len,
                      remove_break = False, remove_pre_and_latency = False, remove_target_phase = True,
                      NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
phase_list.append(phase[:,:,:,suggested_ch()])
legend.append(dataset.ID)
dataset_no += 1

# Wearable dataset
from SSVEPAnalysisToolbox.utils.wearablepreprocess import (
    preprocess, filterbank, suggested_ch
)
dataset = WearableDataset_wet(path = 'Wearable')
dataset.regist_preprocess(preprocess)
dataset.regist_filterbank(lambda dataself, X: filterbank(dataself, X, 5))
print("{:s}: ".format(dataset.ID))
phase = dataset.get_phase(display_progress = True,
                      sig_len = sig_len,
                      remove_break = False, remove_pre_and_latency = False, remove_target_phase = True,
                      NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
phase_list.append(phase[:,:,:,suggested_ch()])
legend.append(dataset.ID)
dataset_no += 1

dataset = WearableDataset_dry(path = 'Wearable')
dataset.regist_preprocess(preprocess)
dataset.regist_filterbank(lambda dataself, X: filterbank(dataself, X, 5))
print("{:s}: ".format(dataset.ID))
phase = dataset.get_phase(display_progress = True,
                      sig_len = sig_len,
                      remove_break = False, remove_pre_and_latency = False, remove_target_phase = True,
                      NFFT = 2 ** nextpow2(10*dataset.srate)) # filterbank index is 0
phase_list.append(phase[:,:,:,suggested_ch()])
legend.append(dataset.ID)
dataset_no += 1

# Store results
data = {"phase_list": phase_list,
        "legend": legend}
data_file = 'res/phase.mat'
savedata(data_file, data, 'mat')

# Plot distribution of phases
color = gen_colors(dataset_no)
fig, ax = polar_phase_shadow(phase_list,
                            color = color,
                            grid = True,
                            legend = legend,
                            errorbar_type = 'std')
fig.savefig('res/phase.jpg', bbox_inches='tight', dpi=300)
close_fig(fig)

The generated graph is stored in demo/res and shown below.

_images/phase.jpg