Example 4: How to plot a confusion matrix

This example shows how to use imshow in matplotlib directly to plot the confusion matrix generated by the built-in function. You can also use heatmap [1] in seaborn or plot_confusion_matrix [2] in sklearn to plot the confusion matrix.

[1] Plot confusion matrices using seaborn

[2] Plot confusion matrices using sklearn

You can find the related code in demo/plot_confusion_matrix.py or demo/plot_confusion_matrix.ipynb.
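
For the sklearn route, note that plot_confusion_matrix was deprecated in scikit-learn 1.0 and removed in 1.2, so recent versions need ConfusionMatrixDisplay to plot a precomputed matrix. The following is a minimal sketch under that assumption. The matrix values and class labels are made up for illustration and are not results from the toolbox.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay

# Hypothetical 3-class confusion matrix (rows: true labels, columns: predicted labels)
toy_cm = np.array([[8, 1, 1],
                   [2, 7, 1],
                   [0, 3, 7]])
toy_labels = ["8 Hz", "9 Hz", "10 Hz"]  # illustrative class names only

fig_alt, ax_alt = plt.subplots()
disp = ConfusionMatrixDisplay(confusion_matrix=toy_cm, display_labels=toy_labels)
# values_format="d" prints the counts as plain integers
disp.plot(ax=ax_alt, cmap="winter", values_format="d", colorbar=True)
```

The imshow approach used in the rest of this example gives finer control, for instance over how the diagonal cells are drawn.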

In the 2nd example, we generated the confusion matrices and stored them in res/benchmarkdataset_res.mat, so we first need to reload them. In this example, we only consider the results for a signal length of 0.5 s.

from SSVEPAnalysisToolbox.utils.io import loaddata
data_file = 'res/benchmarkdataset_res.mat'
data = loaddata(data_file, 'mat')
confusion_matrix = data["confusion_matrix"]
method_ID = data["method_ID"]
tw_seq = data["tw_seq"]

import numpy as np
target_time = 0.5
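# Take the first matching index explicitly; calling int() on an index array is deprecated in recent NumPy
signal_len_idx = int(np.where(np.array(tw_seq)==target_time)[0][0])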
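
The lookup above compares floating-point values exactly, which can miss a match if the stored signal lengths carry rounding error. A more tolerant variant, sketched here with a made-up time-window list (toy_tw_seq and toy_target are illustrative names, not toolbox variables), uses np.isclose and checks that exactly one entry matches.

```python
import numpy as np

# Hypothetical time-window sequence in seconds; the real values come from tw_seq
toy_tw_seq = [0.25, 0.5, 0.75, 1.0]
toy_target = 0.5

# np.isclose tolerates small floating-point differences that == would reject
matches = np.flatnonzero(np.isclose(toy_tw_seq, toy_target))
if matches.size != 1:
    raise ValueError(
        "expected exactly one matching signal length, found {}".format(matches.size)
    )
toy_idx = int(matches[0])
```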

We also want to display the confusion matrix ordered from the lowest to the highest stimulus frequency, so we need the stimulus frequencies. Because the dataset stores this information, we recreate the dataset and read it from there.

from SSVEPAnalysisToolbox.datasets import BenchmarkDataset
dataset = BenchmarkDataset(path = '2016_Tsinghua_SSVEP_database')
freqs = dataset.stim_info['freqs']
sort_idx = list(np.argsort(freqs))
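
To see what the sorting index does, here is a small sketch with made-up frequencies and counts (none of these values come from the benchmark dataset). Applying the same index list to the rows and then to the columns reorders the classes by frequency while keeping the correct classifications on the diagonal.

```python
import numpy as np

# Hypothetical stimulus frequencies (Hz) in the order the classes are stored
toy_freqs = [10.0, 8.0, 9.0]
# Positions of the classes from lowest to highest frequency
toy_sort_idx = list(np.argsort(toy_freqs))

# Toy confusion matrix in stored class order (rows: true, columns: predicted)
toy_matrix = np.array([[5, 1, 0],
                       [0, 6, 2],
                       [3, 0, 4]])

# Reorder rows, then columns, with the same index list
toy_sorted = toy_matrix[toy_sort_idx, :][:, toy_sort_idx]
```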

In the 2nd example, we evaluated the performance of multiple methods. This example only considers the 1st method, i.e., the eTRCA method implemented by the eigen decomposition. We can get the corresponding confusion matrix and plot it.

method_idx = 0
confusion_matrix_plot = confusion_matrix[method_idx, :, signal_len_idx, :, :]
confusion_matrix_plot = np.sum(confusion_matrix_plot, axis = 0)
confusion_matrix_plot = confusion_matrix_plot[sort_idx,:]
confusion_matrix_plot = confusion_matrix_plot[:,sort_idx]
N, _ = confusion_matrix_plot.shape
min_v = 0
max_v = np.amax(np.reshape(confusion_matrix_plot - np.diag(np.diag(confusion_matrix_plot)),(-1)))

import matplotlib.pyplot as plt
import matplotlib.patches as pach

fig = plt.figure()
ax = fig.add_axes([0,0,1,1])

im = ax.imshow(confusion_matrix_plot,
                interpolation = 'none',
                origin = 'upper',
                vmin = min_v,
                vmax = max_v,
                cmap='winter')

for n in range(N):
    ax.add_patch(
        pach.Rectangle(xy=(n-0.5, n-0.5), width=1, height=1, facecolor='white')
    )
for i in range(N):
    for j in range(N):
        if i==j:
            text_color = 'black'
        else:
            text_color = 'white'
        ax.text(i,j,"{:n}".format(int(confusion_matrix_plot[j,i])),
            fontsize=5,
            horizontalalignment='center',
            verticalalignment='center',
            color=text_color)
ax.figure.colorbar(im, ax=ax)
ax.set_xticks(list(range(N)))
ax.set_yticks(list(range(N)))
ax.spines[:].set_visible(False)
ax.grid(which="minor", color="black", linestyle='-', linewidth=10)
ax.tick_params(top=True, bottom=False,
                labeltop=True, labelbottom=False)
ax.tick_params(which="minor", bottom=False, left=False)
ax.tick_params(axis='x',labelsize=5)
ax.tick_params(axis='y',labelsize=5)
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')

[Figure: confusion_matrix_sCCA(qr)_T0.5.jpg]
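
The matrix preparation at the start of the code above can be checked on synthetic data. The sketch below assumes the stored array is laid out as (method, subject, signal length, true class, predicted class). That layout is inferred from how the array is indexed here and is not confirmed elsewhere in this example. Taking the color limit from the off-diagonal entries appears intended to keep small misclassification counts visible, since the diagonal cells are covered by white rectangles anyway.

```python
import numpy as np

# Synthetic results array; the axis order (method, subject, signal length,
# true class, predicted class) is an assumption, see the note above
rng = np.random.default_rng(0)
toy_results = rng.integers(0, 5, size=(2, 3, 4, 3, 3))

# Pick one method and one signal length, leaving (subject, true, predicted)
toy_selected = toy_results[0, :, 1, :, :]
# Add the per-subject matrices into one pooled matrix
toy_pooled = toy_selected.sum(axis=0)

# Zero the diagonal, then take the maximum: the largest misclassification count
toy_off_diag = toy_pooled - np.diag(np.diag(toy_pooled))
toy_vmax = toy_off_diag.max()
```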

Finally, we can save this figure.

save_path = 'res/confusion_matrix_sCCA(qr)_T{:n}.jpg'.format(tw_seq[signal_len_idx])
fig.savefig(save_path,
            bbox_inches='tight', dpi=300)