%   CNSP Workshop 2022: Encoding and decoding models, introduction to 
%   multivariate analysis for first-time users.
% 
%   This tutorial loads a natural speech EEG dataset (Di Liberto et al., 
%   2015; Broderick et al., 2018) and demonstrates how to train and test a
%   decoding model and how to perform a single-lag analysis. The tutorial 
%   uses the mTRF-Toolbox (Crosse et al., 2016; Crosse et al., 2021) to 
%   train and test the decoder and to evaluate its ability to reconstruct
%   the stimulus from new (held-out) neural responses.
%
%   Dependencies:
%      CNSP utils: https://cnspworkshop.net/resources.html
%      EEGLAB: https://sccn.ucsd.edu/eeglab/index.php
%      mTRF-Toolbox: https://github.com/mickcrosse/mTRF-Toolbox

%   References:
%      [1] Di Liberto GM, O'Sullivan JA, Lalor EC (2015) Low-frequency 
%          cortical entrainment to speech reflects phoneme-level
%          processing. Curr Biol 25(19):2457-2465.
%      [2] Broderick MP, Anderson AJ, Di Liberto GM, Crosse MJ, Lalor EC 
%          (2018) Electrophysiological Correlates of Semantic Dissimilarity 
%          Reflect the Comprehension of Natural, Narrative Speech. Curr 
%          Biol 28(5):803–809.e3.
%      [3] Crosse MJ, Zuk NJ, Di Liberto GM, Nidiffer A, Molholm S, Lalor 
%          EC (2021) Linear Modeling of Neurophysiological Responses to 
%          Speech and Other Continuous Stimuli: Methodological 
%          Considerations for Applied Research. Front Neurosci, 15:705621.
%      [4] Crosse MJ, Di Liberto GM, Bednar A, Lalor EC (2016) The
%          multivariate temporal response function (mTRF) toolbox: a MATLAB
%          toolbox for relating neural signals to continuous stimuli. Front
%          Hum Neurosci 10:604.

%   CNSP Workshop 2022
%   Resources: https://cnspworkshop.net/resources.html
%   Author: Michael Crosse <crossemj@tcd.ie>

%% Part 2: Decoding models and single-lag analysis

%% A.1. Data ingestion

close all;
clear; clc;

% a. Set main path. This is the author's local workshop folder: edit it to
% point at your own copy before running. Every path below is relative to it.
root_dir = 'C:\Users\mickc\Dropbox (Personal)\MATLAB\Workshops\CNSP 2022';
cd(root_dir)

% b. Add other directories to path. fullfile inserts the platform's file
% separator, so these paths also resolve on macOS/Linux (the original
% backslash-separated paths only worked on Windows).
addpath(fullfile('tutorials','TRFtutorial'))
addpath(fullfile('libs','cnsp_utils'))
addpath(fullfile('libs','cnsp_utils','cnd'))
addpath(fullfile('libs','eeglab'))
addpath(fullfile('libs','mTRF-Toolbox_v2','mtrf'))
addpath(fullfile('datasets','LalorNatSpeech'))

% c. Load data: the 'stim' variable from the 64 Hz stimulus file and the
% 'eeg' variable for subject 10 (CND format)
disp('Loading data...')
load(fullfile('.','datasets','LalorNatSpeech','Stimuli','dataStim_64Hz.mat'),'stim');
load(fullfile('.','datasets','LalorNatSpeech','dataCND','dataSub10.mat'),'eeg');

% d. Run EEGLAB, then close the current figure. Presumably this is done so
% EEGLAB's own subfolders (e.g. for topoplot, used in B.3) are initialised
% on the path while its GUI window is dismissed -- confirm for your version.
eeglab
close

%% A.2. Data preprocessing

% a-b. Filter design. Both filters are built for the ORIGINAL EEG sampling
% rate (eeg.fs), i.e. before downsampling: a 1 Hz highpass and an 8 Hz
% lowpass, each of order 3.
highpass_cutoff = 1;
highpass_order = 3;
lowpass_cutoff = 8;
lowpass_order = 3;
hd_hpf = getHPFilt(eeg.fs,highpass_cutoff,highpass_order);
hd_lpf = getLPFilt(eeg.fs,lowpass_cutoff,lowpass_order);

% Band-pass applied to one trial: highpass first, then lowpass, each via
% filtfilthd
bandpass = @(x) filtfilthd(hd_lpf,filtfilthd(hd_hpf,x));

% c. Filter EEG recording channels (one cell per trial)
disp('Filtering recording channels...')
eeg.data = cellfun(bandpass,eeg.data,'UniformOutput',false);

% d. Filter EEG external channels (the first external-channel set)
disp('Filtering external channels...')
eeg.extChan{1,1}.data = cellfun(bandpass,eeg.extChan{1,1}.data,'UniformOutput',false);

% e. Downsample EEG data to 64 Hz
fs_new = 64;
disp('Downsampling data...')
eeg = cndDownsample(eeg,fs_new);

% f. Interpolate bad channels; skipped when no channel locations are stored
disp('Interpolating bad channels...')
if isfield(eeg,'chanlocs')
    eeg.data = cellfun(@(x) removeBadChannels(x,eeg.chanlocs),eeg.data,'UniformOutput',false);
end

% g. Re-reference EEG data to the mastoids
disp('Re-referencing EEG data...')
eeg = cndReref(eeg,'Mastoids');

% h. Normalize: divide every trial by a single global standard deviation,
% computed over all samples of all channels of all trials pooled together
disp('Normalizing data...')
eeg_data_mat = cell2mat(eeg.data');
eeg_std = std(reshape(eeg_data_mat,[],1));
eeg.data = cellfun(@(x) x./eeg_std,eeg.data,'UniformOutput',false);

% i. Crop EEG to match stim length
[stim,eeg] = cndCheckStimNeural(stim,eeg);

%% B.1. Speech envelope visualization

% a. Load audio (audio.wav must be in the current folder or on the path)
[audio,fs] = audioread('audio.wav'); 

% Envelope sampling rate (Hz); matches the 64 Hz stimulus file used in A.1
fs_env = 64;

% b. Compute speech envelope at 64 Hz with compression exponent 1 (none)
envelope = mTRFenvelope(audio,fs,fs_env,1,1);

% c. Compute speech envelope at 64 Hz with power-law compression (^0.3)
envelope_comp = mTRFenvelope(audio,fs,fs_env,1,0.3);

% d. Plot waveform and both envelopes on a shared time axis in seconds.
% Sample k (1-based) sits at (k-1)/rate, so every trace starts at t = 0.
% With (1:N)/rate, the 64 Hz envelopes would be shifted ~15.6 ms later
% than the audio.
figure(1)
hold on
plot((0:length(audio)-1)/fs,audio)
plot((0:length(envelope)-1)/fs_env,envelope,'LineWidth',2)
plot((0:length(envelope_comp)-1)/fs_env,envelope_comp,'LineWidth',2)
% Legend rates come from the actual file/settings rather than hard-coded text
legend(sprintf('Audio (%g kHz)',fs/1e3),sprintf('Env (%g Hz)',fs_env),...
    sprintf('Env^0^.^3 (%g Hz)',fs_env))
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
xlim([0.75,4])

%% B.2. Cross-validation

% a. Define training and test sets. Trials 10-13 are held out for testing
% (4 trials; "20% of data" assumes the dataset has 20 trials).
% Row 1 of stim.data is the stimulus feature that is reconstructed.
test_trials = 10:13;
stim_train = stim.data(1,:); 
eeg_train = eeg.data;
stim_train(test_trials) = [];
eeg_train(test_trials) = [];
stim_test = stim.data(1,test_trials);
eeg_test = eeg.data(test_trials);

% b. Model hyperparameters
Dir = -1;                      % -1: backward (decoding) model
tmin = 0;                      % minimum time lag (ms)
tmax = 250;                    % maximum time lag (ms)
lambda_idx = -2:2:8;           % regularization exponents
lambda_vals = 10.^lambda_idx;  % regularization values 1e-2 ... 1e8
nlambda = numel(lambda_vals);

% c. Run fast cross-validation
disp('Running cross-validation...')
cv = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda_vals,...
    'zeropad',0,'fast',1);

% Number of CV folds, taken from the results themselves (folds run along
% the first dimension of cv.r / cv.err). This stays correct even if the
% folds do not map one-to-one onto trials. It is the n in the SEM below.
nfold = size(cv.r,1);

% d. Plot CV accuracy (mean +/- SEM across folds, per lambda)
figure(14)
subplot(2,2,1)
errorbar(1:nlambda,mean(cv.r),std(cv.r)/sqrt(nfold),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lambda_idx), xlim([0,nlambda+1])
title('CV Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square, grid on

% e. Plot CV error (mean +/- SEM across folds, per lambda)
subplot(2,2,2)
errorbar(1:nlambda,mean(cv.err),std(cv.err)/sqrt(nfold),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lambda_idx), xlim([0,nlambda+1])
title('CV Error')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square, grid on

%% B.3. Model training

% a. Choose the lambda with the highest mean cross-validated correlation.
% rmax (that best mean CV correlation) is reused in the B.4 bar plot.
[rmax,idx] = max(mean(cv.r));
lambda = lambda_vals(idx);

% b. Fit the decoder on all training trials using the selected lambda
disp('Training model...')
Bmodel = mTRFtrain(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,...
    'zeropad',0);

% c. Scalp maps of the decoder weights at four lag indices. All maps share
% one symmetric colour range, set by the largest absolute weight across
% lag columns 7 to 14 of Bmodel.w.
lim = max(max(abs(Bmodel.w(:,7:14))));
map_lags = [7,9,11,14];
figure(15)
for k = 1:numel(map_lags)
    lag = map_lags(k);
    subplot(2,2,k)
    topoplot(Bmodel.w(:,lag),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
    title([num2str(Bmodel.t(lag)),' ms'])
end

%% B.4. Model testing

% a. Test the trained decoder on the held-out trials
disp('Testing model...')
[pred,test] = mTRFpredict(stim_test,eeg_test,Bmodel,'zeropad',0);

% b. Plot reconstruction of the first test trial (first 10 s).
% Sample k (1-based) sits at (k-1)/eeg.fs, so the time axis starts at 0.
% The prediction is multiplied by a fixed display factor; this only affects
% the plot (test.r was computed above on the unscaled prediction). The factor
% is shown in the legend so the amplitudes are not misread.
pred_scale = 1.5;
figure(14)
subplot(2,2,3)
plot((0:length(stim_test{1})-1)/eeg.fs,stim_test{1},'linewidth',1.5), hold on
plot((0:length(pred{1})-1)/eeg.fs,pred{1}*pred_scale,'linewidth',1.5), hold off
xlim([0,10])
title('Reconstruction')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
axis square, grid on
legend('Orig',sprintf('Pred (\\times%g)',pred_scale))

% c. Compare the best mean CV (validation) correlation with the mean
% correlation on the held-out test trials
subplot(2,2,4)
bar(1,rmax), hold on
bar(2,mean(test.r)), hold off
xlim([0,3])
set(gca,'xtick',1:2,'xticklabel',{'Val.','Test'})
title('Model Performance')
xlabel('Dataset')
ylabel('Correlation')
axis square, grid on

%% C.1. Single-lag stimulus reconstruction

% a. Model hyperparameters: backward model over a wider lag window (ms)
Dir = -1;
tmin = -250; 
tmax = 500;

% b. Run single-lag cross-validation at the lambda chosen in B.3
% ('type','single' evaluates each time lag separately); t holds the lags.
[stats,t] = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,...
    'type','single','zeropad',0);

% c. Mean and variance across CV folds (first dimension), as row vectors
macc = squeeze(mean(stats.r))'; vacc = squeeze(var(stats.r))';
merr = squeeze(mean(stats.err))'; verr = squeeze(var(stats.err))';

% d. Build mean +/- SEM bands as closed polygons for fill(): lower bound
% traversed backwards, then upper bound forwards. The fold count comes from
% the CV output itself rather than assuming one fold per training trial.
% Lags are negated (-t) and plotted against xlim([tmin,tmax]); presumably
% this maps the backward-model lag convention onto the stimulus-to-response
% direction -- confirm against the mTRF-Toolbox docs.
num_folds = size(stats.r,1);
sem_acc = sqrt(vacc/num_folds);
sem_err = sqrt(verr/num_folds);
xacc = [-fliplr(t),-t]; yacc = [fliplr(macc-sem_acc),macc+sem_acc];
xerr = [-fliplr(t),-t]; yerr = [fliplr(merr-sem_err),merr+sem_err];

% e. Plot accuracy (shaded SEM band plus mean curve)
figure(16)
subplot(1,2,1), h = fill(xacc,yacc,'b','edgecolor','none'); hold on
set(h,'facealpha',0.2), xlim([tmin,tmax]), axis square, grid on
plot(-fliplr(t),fliplr(macc),'linewidth',2), hold off
title('Reconstruction Accuracy'), xlabel('Time lag (ms)'), ylabel('Correlation')

% f. Plot error (shaded SEM band plus mean curve)
subplot(1,2,2)
h = fill(xerr,yerr,'b','edgecolor','none'); hold on
set(h,'facealpha',0.2), xlim([tmin,tmax]), axis square, grid on
plot(-fliplr(t),fliplr(merr),'linewidth',2), hold off
title('Reconstruction Error'), xlabel('Time lag (ms)'), ylabel('MSE')