Author : Alexandre Gramfort alexandre.gramfort@telecom-paristech.fr
License : BSD 3 clause
# add plot inline in the page
%matplotlib inline
import matplotlib.pyplot as plt
First, load the mne package:
import mne
We set the log level to 'WARNING' so that the output is less verbose:
# From here on, only warnings and errors from MNE are printed.
mne.set_log_level(verbose='WARNING')
Now we import the dataset. If you don't already have it, it will be downloaded automatically (be patient: it is approx. 2 GB).
import os

from mne.datasets import spm_face
# Local path of the SPM faces dataset (downloaded on first use).
data_path = spm_face.data_path()
# os.path.join accepts both str and pathlib.Path; plain "+" concatenation
# fails if data_path is a Path (as returned by newer MNE releases).
raw_fname = os.path.join(data_path, 'MEG', 'spm',
                         'SPM_CTF_MEG_example_faces1_3D_raw.fif')
Read data from file:
# preload=True loads the recording into memory so it can be modified in place.
raw = mne.io.Raw(raw_fname, preload=True)
# print() with a single argument behaves identically on Python 2 and 3.
print(raw)
Band-pass filter the data between 1 Hz and 45 Hz:
# Band-pass filter in place: high-pass edge at 1 Hz, low-pass edge at 45 Hz.
raw.filter(l_freq=1, h_freq=45)
# Switch away from inline figures, presumably to get an interactive window for
# browsing the raw traces.
# NOTE(review): the 'osx' backend presumably only works on macOS; on other
# platforms something like '%matplotlib qt' would be needed — confirm.
%matplotlib osx
fig = raw.plot()
First, extract the events:
%matplotlib inline
events = mne.find_events(raw, stim_channel='UPPT001', verbose=True)
Look at the experimental design graphically:
# Place the events on a time axis using the sampling rate and the first sample.
mne.viz.plot_events(events, sfreq=raw.info['sfreq'],
                    first_samp=raw.first_samp);
Define epochs parameters:
# Condition names mapped to their trigger codes.
event_id = {"faces": 1, "scrambled": 2}
# Epoch window relative to each event, in seconds.
tmin, tmax = -0.1, 0.5
# Set up pick list: MEG, stimulus and EOG channels, excluding the reference
# MEG sensors and any channel marked as bad.
picks = mne.pick_types(raw.info, meg=True, stim=True, eog=True,
                       ref_meg=False, exclude='bads')
# Read epochs
decim = 4  # decimate to make the example faster to run
# reject: drop epochs exceeding the magnetometer threshold (1.5e-12,
# presumably Tesla peak-to-peak — confirm against the MNE docs).
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
                    picks=picks, baseline=None, preload=True,
                    reject=dict(mag=1.5e-12), decim=decim)
# print() with a single argument behaves identically on Python 2 and 3.
print(epochs)
Look at the event-related fields (ERFs) of each condition and at the contrast between faces and scrambled faces:
# Average the epochs of each condition into an evoked response.
evoked_faces = epochs['faces'].average()
evoked_scrambled = epochs['scrambled'].average()
# Contrast = faces minus scrambled. Evoked arithmetic ("a - b") was deprecated
# and removed from MNE; combine_evoked with weights [1, -1] is the explicit,
# supported way to take the plain difference of the two averages.
evoked_contrast = mne.combine_evoked([evoked_faces, evoked_scrambled],
                                     weights=[1, -1])
# Shared y-axis range for the magnetometer traces so the plots are comparable.
ylim = dict(mag=[-400., 400.])
fig = evoked_faces.plot(ylim=ylim)
fig = evoked_scrambled.plot(ylim=ylim)
fig = evoked_contrast.plot(ylim=ylim)
Plot some topographies
import numpy as np
# Ten equally spaced latencies from -0.1 s to 0.3 s.
times = np.linspace(start=-0.1, stop=0.3, num=10)
# Magnetometer topographies at those latencies, without contour lines.
fig = evoked_faces.plot_topomap(ch_type='mag', times=times, contours=0)
fig = evoked_scrambled.plot_topomap(ch_type='mag', times=times, contours=0)
fig = evoked_contrast.plot_topomap(ch_type='mag', times=times, contours=0)
Equalize the number of epochs in each condition so that chance-level accuracy is 50%:
# One Epochs object per condition, in the iteration order of event_id.
epochs_list = [epochs[cond] for cond in event_id]
# Drop epochs in place so that every condition keeps the same count.
mne.epochs.equalize_epoch_counts(epochs_list)
A classifier takes as input a feature vector $x$ and returns a label $y \in \{0, 1\}$. Here $x$ will be the data at one or all time point(s) on all MEG sensors.
We work with all sensors jointly and try to find a discriminative pattern between the two conditions to predict the class.
n_times = len(epochs.times)
# Take only the MEG data channels. For this CTF recording these are
# magnetometers; the stim/EOG channels kept in `epochs` are not selected.
data_picks = mne.pick_types(epochs.info, meg=True, exclude='bads')
# Make arrays X and y such that:
# X is 3d (epochs x channels x times); X.shape[0] is the total number of
# epochs to classify
# y holds one class code per epoch (0., 1., ... following the order of
# epochs_list; integer-valued but stored as floats)
# We must have X.shape[0] equal to y.shape[0]
X = [e.get_data()[:, data_picks, :] for e in epochs_list]
y = [k * np.ones(len(this_X)) for k, this_X in enumerate(X)]
X = np.concatenate(X)
y = np.concatenate(y)
# Single-argument print() gives the same output on Python 2 and 3.
print('{} {}'.format(X.shape, y.shape))
print(y)
Now let's use a support vector machine (SVM) to classify our MEG data:
from sklearn.svm import SVC
# sklearn.cross_validation was removed in scikit-learn 0.20 in favour of
# sklearn.model_selection. The legacy module is tried first so that code using
# the legacy ShuffleSplit(n, n_iter, ...) signature keeps working on releases
# that still ship it; newer releases fall back to model_selection.
try:
    from sklearn.cross_validation import cross_val_score, ShuffleSplit
except ImportError:
    from sklearn.model_selection import cross_val_score, ShuffleSplit
# Define an SVM classifier (SVC) with a linear kernel
clf = SVC(C=1., kernel='linear')
Define a Monte Carlo cross-validation generator (repeated random train/test splits, to reduce variance):
# 10 random splits holding out 20% of the epochs each, with a fixed seed.
# New-style scikit-learn splitters expose .split() and take n_splits; the
# legacy sklearn.cross_validation.ShuffleSplit needs the sample count up front
# (n, n_iter). Build whichever form matches the imported class.
if hasattr(ShuffleSplit, 'split'):
    cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=42)
else:
    cv = ShuffleSplit(len(X), 10, test_size=0.2, random_state=42)
The goal is to learn on 80% of the epochs and to evaluate, on the remaining 20% of trials, whether we can predict the condition accurately.
# Flatten each epoch (channels x times) into one feature vector and scale all
# features by a single global standard deviation.
X_2d = X.reshape(len(X), -1)
X_2d = X_2d / np.std(X_2d)
scores_full = cross_val_score(clf, X_2d, y, cv=cv, n_jobs=1)
# Single-argument print() gives the same output on Python 2 and 3.
print("Classification score: %s (std. %s)"
      % (np.mean(scores_full), np.std(scores_full)))
It is also possible to run the same classifier at each time point separately, to find when in time the two conditions can best be discriminated:
from scipy.stats import zscore
# Per-time-point decoding accuracy: mean and std over the CV splits.
scores = np.empty(n_times)
std_scores = np.empty(n_times)
# z-score each channel's time course within each epoch (along the last, time
# axis). Note that this rebinds X.
X = zscore(X, axis=-1)
for t in range(X.shape[-1]):
    # Features at time t: one value per channel -> (n_epochs, n_channels).
    fold_scores = cross_val_score(clf, X[:, :, t], y, cv=cv, n_jobs=1)
    scores[t], std_scores[t] = fold_scores.mean(), fold_scores.std()
A bit of rescaling (times in ms, scores in percent):
# Seconds -> milliseconds for the time axis.
times = epochs.times * 1e3
# Fractions -> percentage accuracy (scaled in place).
scores *= 100
std_scores *= 100
Now a bit of plotting:
# Decoding accuracy over time, with chance level, stimulus onset and the
# whole-epoch accuracy as reference lines.
plt.plot(times, scores, label="Classif. score")
plt.axhline(50., color='k', linestyle='--', label="Chance level")
plt.axvline(0., color='r', label='stim onset')
plt.axhline(100. * np.mean(scores_full), color='g', label='Accuracy full epoch')
# Shade +/- one standard deviation of the score across CV splits.
hyp_limits = (scores - std_scores, scores + std_scores)
plt.fill_between(times, hyp_limits[0], y2=hyp_limits[1], color='b', alpha=0.5,
                 label='+/- 1 std')
# Build the legend after every labelled artist exists, so the band is listed.
plt.legend()
plt.xlabel('Time (ms)')
plt.ylabel('CV classification score (% correct)')
plt.ylim([30., 100.])
plt.title('Sensor space decoding')