%   CNSP Workshop 2022: Encoding and decoding models, introduction to 
%   multivariate analysis for first-time users.
% 
%   This tutorial loads a natural speech EEG dataset (Di Liberto et al., 
%   2015; Broderick et al., 2018) and demonstrates how to extract 
%   multivariate temporal response functions (mTRFs) to different features 
%   of natural speech, including the envelope, spectrogram, phonetic 
%   features and semantic dissimilarity. The tutorial uses the mTRF-Toolbox 
%   (Crosse et al., 2016; Crosse et al., 2021) to train and test the 
%   corresponding mTRF models to evaluate their ability to predict neural 
%   responses to new stimuli.
%
%   Dependencies:
%      CNSP utils: https://cnspworkshop.net/resources.html
%      EEGLAB: https://sccn.ucsd.edu/eeglab/index.php
%      mTRF-Toolbox: https://github.com/mickcrosse/mTRF-Toolbox

%   References:
%      [1] Di Liberto GM, O'Sullivan JA, Lalor EC (2015) Low-frequency 
%          cortical entrainment to speech reflects phoneme-level
%          processing. Curr Biol 25(19):2457-2465.
%      [2] Broderick MP, Anderson AJ, Di Liberto GM, Crosse MJ, Lalor EC 
%          (2018) Electrophysiological Correlates of Semantic Dissimilarity 
%          Reflect the Comprehension of Natural, Narrative Speech. Curr 
%          Biol 28(5):803-809.
%      [3] Crosse MJ, Zuk NJ, Di Liberto GM, Nidiffer A, Molholm S, Lalor 
%          EC (2021) Linear Modeling of Neurophysiological Responses to 
%          Speech and Other Continuous Stimuli: Methodological 
%          Considerations for Applied Research. Front Neurosci, 15:705621.
%      [4] Crosse MJ, Di Liberto GM, Bednar A, Lalor EC (2016) The
%          multivariate temporal response function (mTRF) toolbox: a MATLAB
%          toolbox for relating neural signals to continuous stimuli. Front
%          Hum Neurosci 10:604.

%   CNSP Workshop 2022
%   Resources: https://cnspworkshop.net/resources.html
%   Author: Michael Crosse <crossemj@tcd.ie>

%% Part 1: Encoding models and multivariate analyses

%% A.1. Data ingestion
% Resets the session, changes to the workshop root folder, adds the
% tutorial, utility, toolbox and dataset folders to the MATLAB path, and
% loads the stimulus features ('stim') and one subject's EEG ('eeg').

close all;
clear; clc;

% a. Set main path (edit this line to point at your own workshop folder)
cd('C:\Users\mickc\Dropbox (Personal)\MATLAB\Workshops\CNSP 2022')

% b. Add other directories to path
% All relative paths are assembled with fullfile so that the file
% separator is correct on Windows, macOS and Linux.
addpath(fullfile('tutorials','TRFtutorial'))
addpath(fullfile('libs','cnsp_utils'))
addpath(fullfile('libs','cnsp_utils','cnd'))
addpath(fullfile('libs','eeglab'))
addpath(fullfile('libs','mTRF-Toolbox_v2','mtrf'))
addpath(fullfile('datasets','LalorNatSpeech'))

% c. Load data (only the named variable is read from each MAT-file)
disp('Loading data...')
load(fullfile('.','datasets','LalorNatSpeech','Stimuli','dataStim_64Hz.mat'),'stim');
load(fullfile('.','datasets','LalorNatSpeech','dataCND','dataSub10.mat'),'eeg');

% d. Run EEGlab
% Starts EEGLAB (presumably so that its functions, e.g. topoplot, become
% available on the path) and then closes the current figure window.
eeglab
close

%% A.2. Data preprocessing
% Filters, downsamples, cleans, re-references and rescales the EEG, then
% lets cndCheckStimNeural reconcile stimulus and EEG lengths.

% a. Design the highpass filter (cutoff 1, order 3) at the EEG sample rate
highpass_cutoff = 1;
highpass_order = 3;
hd_hpf = getHPFilt(eeg.fs,highpass_cutoff,highpass_order);

% b. Design the lowpass filter (cutoff 8, order 3) at the EEG sample rate
lowpass_cutoff = 8;
lowpass_order = 3;
hd_lpf = getLPFilt(eeg.fs,lowpass_cutoff,lowpass_order);

% c. Filter EEG recording channels
% Each trial is highpass filtered first and the result is lowpass filtered.
disp('Filtering recording channels...')
eeg.data = cellfun(@(trial) filtfilthd(hd_lpf,filtfilthd(hd_hpf,trial)),...
    eeg.data,'UniformOutput',false);

% d. Filter EEG external channels
% Same highpass-then-lowpass chain, applied to the first external channel
% set.
disp('Filtering external channels...')
eeg.extChan{1,1}.data = cellfun(...
    @(trial) filtfilthd(hd_lpf,filtfilthd(hd_hpf,trial)),...
    eeg.extChan{1,1}.data,'UniformOutput',false);

% e. Downsample EEG data to 64 Hz
fs_new = 64;
disp('Downsampling data...')
eeg = cndDownsample(eeg,fs_new);

% f. Interpolate bad channels
% Only possible when channel locations are stored with the recording.
disp('Interpolating bad channels...')
if isfield(eeg,'chanlocs')
    for i = 1:numel(eeg.data)
        eeg.data{i} = removeBadChannels(eeg.data{i},eeg.chanlocs);
    end
end

% g. Re-reference EEG data
disp('Re-referencing EEG data...')
eeg = cndReref(eeg,'Mastoids');

% h. Normalize EEG data
% A single standard deviation is computed over all samples of all trials
% and channels, and every trial is divided by that one value.
disp('Normalizing data...')
eeg_data_mat = cell2mat(eeg.data');
eeg_std = std(eeg_data_mat(:));
eeg.data = cellfun(@(trial) trial/eeg_std,eeg.data,'UniformOutput',false);

% i. Crop EEG to match stim length
[stim,eeg] = cndCheckStimNeural(stim,eeg);

%% B.1. Speech envelope visualization
% Plots the raw audio together with its envelope, with and without
% amplitude compression.

% a. Load audio
[audio,fs] = audioread('audio.wav');

% b. Compute speech envelope at 64 Hz (last argument 1: no compression)
envelope = mTRFenvelope(audio,fs,64,1,1);

% c. Compute speech envelope at 64 Hz with compression (last argument 0.3)
envelope_comp = mTRFenvelope(audio,fs,64,1,0.3);

% d. Plot audio and both envelopes on a common time axis in seconds
figure(1)
hold('on')
plot((1:length(audio))/fs,audio)
plot((1:length(envelope))/64,envelope,'LineWidth',2)
plot((1:length(envelope))/64,envelope_comp,'LineWidth',2)
legend('Audio (44.1 kHz)','Env (64 Hz)','Env^0^.^3 (64 Hz)')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
xlim([0.75,4])

%% B.2. Cross-validation
% Cross-validates an envelope model on the training trials over a grid of
% regularization values and plots accuracy and error at one channel.

% a. Apply compression to envelope (exponent 0.3, every trial of row 1)
envelopes = cellfun(@(env) env.^0.3,stim.data(1,:),'UniformOutput',false);

% b. Define training and test sets
% The test trials are copied out; the training sets are what remains once
% those trials are deleted.
test_trials = 10:13; % 20% of data
stim_test = envelopes(test_trials);
eeg_test = eeg.data(test_trials);
stim_train = envelopes;
stim_train(test_trials) = [];
eeg_train = eeg.data;
eeg_train(test_trials) = [];

% c. Model hyperparameters
chan = 85; % Fz
Dir = 1;
tmin = -100;
tmax = 350;
lamda_idx = -4:2:10; % exponents: lambda = 10^idx
lambda_vals = 10.^lamda_idx;
nlambda = numel(lambda_vals);

% d. Run fast cross-validation
disp('Running cross-validation...')
cv = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,...
    lambda_vals,'zeropad',0,'fast',1);

% e. Plot CV accuracy
% Mean over the first dimension of cv.r at the chosen channel; error bars
% are the standard deviation divided by sqrt(number of training trials).
figure(2)
subplot(2,2,1)
errorbar(1:nlambda,mean(cv.r(:,:,chan)),...
    std(cv.r(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lamda_idx)
xlim([0,nlambda+1])
title('CV Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square
grid on

% f. Plot CV error (same layout, using cv.err)
subplot(2,2,2)
errorbar(1:nlambda,mean(cv.err(:,:,chan)),...
    std(cv.err(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lamda_idx)
xlim([0,nlambda+1])
title('CV Error')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square
grid on

%% B.3. Model training
% Trains the envelope model with the best cross-validated lambda and plots
% its weights over time lags and over the scalp.

% a. Get optimal hyperparameters
% Pick the lambda whose mean CV correlation at the chosen channel is
% highest; rmax keeps that correlation for the later performance plot.
[rmax,idx] = max(mean(cv.r(:,:,chan)));
lambda = lambda_vals(idx);

% b. Train model
disp('Training model...')
Emodel = mTRFtrain(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,...
    'zeropad',0);

% c. Plot TRF weights
% Symmetric colour limit: largest absolute weight over lag indices 13:18
% across all channels.
lim = max(max(abs(Emodel.w(:,13:18,:)),[],3),[],2);
figure(3)

% TRF time course, one line per channel
subplot(2,2,1)
plot(Emodel.t,squeeze(Emodel.w))
xlim([-50,300])
title('Temporal Response Function (TRF)')
xlabel('Time lag (ms)')
ylabel('Amplitude (a.u.)')

% Standard deviation of the weights across channels at each lag
subplot(2,2,2)
plot(Emodel.t,std(Emodel.w,[],3))
xlim([-50,300])
title('Global Field Power (GFP)')
xlabel('Time lag (ms)')

% Scalp maps at lag indices 13 and 18
subplot(2,2,3)
topoplot(Emodel.w(:,13,:),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(Emodel.t(13)),' ms'])
subplot(2,2,4)
topoplot(Emodel.w(:,18,:),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(Emodel.t(18)),' ms'])

%% B.4. Model testing
% Applies the trained envelope model to the held-out trials, overlays the
% prediction on the recorded EEG, and compares validation and test
% accuracy at the selected channel.

% a. Test model
disp('Testing model...')
[pred,test] = mTRFpredict(stim_test,eeg_test,Emodel,'zeropad',0);

% b. Plot prediction
% First test trial at the selected channel; the prediction is multiplied
% by 5 (presumably so it is visible next to the EEG trace).
figure(2)
subplot(2,2,3)
plot((1:length(eeg_test{1}))/eeg.fs,eeg_test{1}(:,chan),'linewidth',1.5), hold on
plot((1:length(pred{1}))/eeg.fs,pred{1}(:,chan)*5,'linewidth',1.5), hold off
xlim([0,10])
title('Prediction')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
axis square, grid on
legend('Orig','Pred')

% c. Plot test correlation
% test.r is indexed as (trial,channel) so the bar is the mean over all
% test trials at the selected channel. A single linear index, test.r(chan),
% would pick one arbitrary element of that matrix instead.
subplot(2,2,4)
bar(1,rmax), hold on
bar(2,mean(test.r(:,chan))), hold off
xlim([0,3])
set(gca,'xtick',1:2,'xticklabel',{'Val.','Test'})
title('Model Performance')
xlabel('Dataset')
ylabel('Correlation')
axis square, grid on

%% C.1. Spectrogram visualization
% Shows the audio waveform above the spectrogram of the first trial, with
% and without amplitude compression.

% a. Get spectrogram (feature row 2, first trial)
spectrogram = stim.data{2,1};

% b. Apply compression to spectrogram (exponent 0.3)
spectrogram_comp = spectrogram.^0.3;

% c. Plot signals
figure(4)
hold on

% Audio waveform
subplot(3,1,1)
plot((1:length(audio))/fs,audio)
xlim([0.75,4])
title('Audio (44.1 kHz)')
ylabel('Amplitude (a.u.)')

% Uncompressed spectrogram; every second band is labelled using freqs
subplot(3,1,2)
imagesc((1:length(spectrogram))/64,1:16,spectrogram')
freqs = [335,560,900,1400,2050,3050,4500,8000];
set(gca,'YDir','normal','ytick',2:2:16,'yticklabel',freqs)
xlim([0.75,4])
title('Spectrogram (64 Hz)')
ylabel('Frequency (Hz)')

% Compressed spectrogram
subplot(3,1,3)
imagesc((1:length(spectrogram_comp))/64,1:16,spectrogram_comp')
set(gca,'YDir','normal','ytick',2:2:16,'yticklabel',freqs)
xlim([0.75,4])
title('Spectrogram^0^.^3 (64 Hz)')
xlabel('Time (s)')
ylabel('Frequency (Hz)')

%% C.2. Cross-validation
% Cross-validates a spectrogram model on the training trials over a grid
% of regularization values and plots accuracy and error at one channel.

% a. Apply compression to spectrogram (exponent 0.3, every trial of row 2)
spectrograms = cellfun(@(spec) spec.^0.3,stim.data(2,:),'UniformOutput',false);

% b. Define training and test sets
% The test trials are copied out; the training sets are what remains once
% those trials are deleted.
test_trials = 10:13; % 20% of data
stim_test = spectrograms(test_trials);
eeg_test = eeg.data(test_trials);
stim_train = spectrograms;
stim_train(test_trials) = [];
eeg_train = eeg.data;
eeg_train(test_trials) = [];

% c. Model hyperparameters
chan = 85; % Fz
Dir = 1;
tmin = -100;
tmax = 350;
lamda_idx = -4:2:10; % exponents: lambda = 10^idx
lambda_vals = 10.^lamda_idx;
nlambda = numel(lambda_vals);

% d. Run fast cross-validation
disp('Running cross-validation...')
cv = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,...
    lambda_vals,'zeropad',0,'fast',1);

% e. Plot CV accuracy
% Mean over the first dimension of cv.r at the chosen channel; error bars
% are the standard deviation divided by sqrt(number of training trials).
figure(5)
subplot(2,2,1)
errorbar(1:nlambda,mean(cv.r(:,:,chan)),...
    std(cv.r(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lamda_idx)
xlim([0,nlambda+1])
title('CV Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square
grid on

% f. Plot CV error (same layout, using cv.err)
subplot(2,2,2)
errorbar(1:nlambda,mean(cv.err(:,:,chan)),...
    std(cv.err(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lamda_idx)
xlim([0,nlambda+1])
title('CV Error')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square
grid on

%% C.3. Model training
% Trains the spectrogram model with the best cross-validated lambda and
% plots its weights over frequency band, time lag and scalp location.

% a. Get optimal hyperparameters
% Pick the lambda whose mean CV correlation at the chosen channel is
% highest; rmax keeps that correlation for the later performance plot.
[rmax,idx] = max(mean(cv.r(:,:,chan)));
lambda = lambda_vals(idx);

% b. Train model
disp('Training model...')
Smodel = mTRFtrain(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,...
    'zeropad',0);

% c. Plot STRF weights
% Symmetric colour limit: largest absolute weight of the first band over
% lag indices 13:18 across all channels.
lim = max(max(abs(Smodel.w(1,13:18,:)),[],3),[],2);
figure(6)

% Band-by-lag weight image at the chosen channel
subplot(2,2,1)
imagesc(Smodel.t,1:16,squeeze(Smodel.w(:,:,chan)))
set(gca,'YDir','normal','ytick',2:2:16,'yticklabel',freqs)
xlim([-50,300])
title({'Spectrotemporal Response';'Function (STRF) at Fz'})
xlabel('Time lag (ms)')
ylabel('Frequency (Hz)')

% Standard deviation of the weights across channels, per band and lag
subplot(2,2,2)
imagesc(Smodel.t,1:16,std(Smodel.w,[],3))
set(gca,'YDir','normal','ytick',2:2:16,'yticklabel',freqs)
xlim([-50,300])
title('Global Field Power (GFP)')
xlabel('Time lag (ms)')

% Scalp maps of the first band at lag indices 13 and 18
subplot(2,2,3)
topoplot(squeeze(Smodel.w(1,13,:)),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(Smodel.t(13)),' ms (250 Hz)'])
subplot(2,2,4)
topoplot(squeeze(Smodel.w(1,18,:)),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(Smodel.t(18)),' ms (250 Hz)'])

% d. Plot broadband TRF weights (placeholder: no code provided here)


%% C.4. Model testing
% Applies the trained spectrogram model to the held-out trials, overlays
% the prediction on the recorded EEG, and compares validation and test
% accuracy at the selected channel.

% a. Test model
disp('Testing model...')
[pred,test] = mTRFpredict(stim_test,eeg_test,Smodel,'zeropad',0);

% b. Plot prediction
% First test trial at the selected channel; the prediction is multiplied
% by 5 (presumably so it is visible next to the EEG trace).
figure(5)
subplot(2,2,3)
plot((1:length(eeg_test{1}))/eeg.fs,eeg_test{1}(:,chan),'linewidth',1.5), hold on
plot((1:length(pred{1}))/eeg.fs,pred{1}(:,chan)*5,'linewidth',1.5), hold off
xlim([0,10])
title('Prediction')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
axis square, grid on
legend('Orig','Pred')

% c. Plot test correlation
% test.r is indexed as (trial,channel) so the bar is the mean over all
% test trials at the selected channel. A single linear index, test.r(chan),
% would pick one arbitrary element of that matrix instead.
subplot(2,2,4)
bar(1,rmax), hold on
bar(2,mean(test.r(:,chan))), hold off
xlim([0,3])
set(gca,'xtick',1:2,'xticklabel',{'Val.','Test'})
title('Model Performance')
xlabel('Dataset')
ylabel('Correlation')
axis square, grid on

%% D.1. Phonetic feature visualization
% Shows the audio waveform above the phonetic-feature matrix of the first
% trial, with one labelled row per feature.

% a. Get phonetic features (feature row 3, first trial)
phon_feat = stim.data{3,1};

% b. Plot signals
figure(7)
hold on
subplot(2,1,1)
plot((1:length(audio))/fs,audio)
xlim([0.75,4])
title('Audio (44.1 kHz)')
ylabel('Amplitude (a.u.)')
subplot(2,1,2)
imagesc((1:length(phon_feat))/64,1:19,phon_feat')
set(gca,'YDir','normal')
title('Phonetic Features (64 Hz)')
xlabel('Time (s)')
ylabel('Phonetic feature')
% Tick labels for the 19 feature rows ('Diphthong' spelling corrected)
features = {'Glottal','Lingua-velar','Lingua-palatal','Lingua-alveolar',...
    'Lingua-dental','Labio-dental','Bilabial','Diphthong','Back','Central',...
    'Front','Voiceless','Voiced','Glide','Liquid','Nasal','Affricate',...
    'Fricative','Plosive'};
set(gca,'ytick',1:19,'yticklabel',features)
xlim([0.75,4])

%% D.2. Cross-validation
% Cross-validates a phonetic-feature model on the training trials over a
% grid of regularization values and plots accuracy and error at one
% channel.

% a. Define training and test sets
% The test trials are copied out of feature row 3; the training sets are
% what remains once those trials are deleted.
test_trials = 10:13; % 20% of data
stim_test = stim.data(3,test_trials);
eeg_test = eeg.data(test_trials);
stim_train = stim.data(3,:);
stim_train(test_trials) = [];
eeg_train = eeg.data;
eeg_train(test_trials) = [];

% b. Model hyperparameters
chan = 85; % Fz
Dir = 1;
tmin = -100;
tmax = 350;
lamda_idx = -4:2:10; % exponents: lambda = 10^idx
lambda_vals = 10.^lamda_idx;
nlambda = numel(lambda_vals);

% c. Run fast cross-validation
disp('Running cross-validation...')
cv = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,...
    lambda_vals,'zeropad',0,'fast',1);

% d. Plot CV accuracy
% Mean over the first dimension of cv.r at the chosen channel; error bars
% are the standard deviation divided by sqrt(number of training trials).
figure(8)
subplot(2,2,1)
errorbar(1:nlambda,mean(cv.r(:,:,chan)),...
    std(cv.r(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lamda_idx)
xlim([0,nlambda+1])
title('CV Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square
grid on

% e. Plot CV error (same layout, using cv.err)
subplot(2,2,2)
errorbar(1:nlambda,mean(cv.err(:,:,chan)),...
    std(cv.err(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lamda_idx)
xlim([0,nlambda+1])
title('CV Error')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square
grid on

%% D.3. Model training
% Trains the phonetic-feature model with the best cross-validated lambda
% and plots its weights over feature, time lag and scalp location.

% a. Get optimal hyperparameters
% Pick the lambda whose mean CV correlation at the chosen channel is
% highest; rmax keeps that correlation for the later performance plot.
[rmax,idx] = max(mean(cv.r(:,:,chan)));
lambda = lambda_vals(idx);

% b. Train model
% 'zeropad' is 0 here so that training uses the same setting as the
% cross-validation that selected lambda and as the later mTRFpredict call.
disp('Training model...')
Pmodel = mTRFtrain(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,...
    'zeropad',0);

% c. Plot TRF weights
% Symmetric colour limit: largest absolute weight of the first feature
% over lag indices 13:18 across all channels.
lim = max(max(abs(Pmodel.w(1,13:18,:)),[],3),[],2);
feats = {'Glo','Lin-v','Lin-p','Lin-a','Lin-d','Lab-d','Bi','Dip','Ba','Cen',...
    'Fro','Vls','Voi','Gli','Liq','Nas','Aff','Fri','Plo'};
figure(9)
% Feature-by-lag weight image at the chosen channel
subplot(2,2,1)
imagesc(Pmodel.t,1:19,squeeze(Pmodel.w(:,:,chan)))
set(gca,'YDir','normal','ytick',1:19,'yticklabel',feats)
xlim([-50,300])
title({'Phonetic Feature Temporal';'Response Function at Fz'})
xlabel('Time lag (ms)')
ylabel('Phonetic feature')
% Standard deviation of the weights across channels, per feature and lag
subplot(2,2,2)
imagesc(Pmodel.t,1:19,std(Pmodel.w,[],3))
set(gca,'YDir','normal','ytick',1:19,'yticklabel',feats)
xlim([-50,300])
title('Global Field Power (GFP)')
xlabel('Time lag (ms)')
% Scalp maps of the first feature at lag indices 13 and 18
subplot(2,2,3)
topoplot(squeeze(Pmodel.w(1,13,:)),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(Pmodel.t(13)),' ms (Glottal)'])
subplot(2,2,4)
topoplot(squeeze(Pmodel.w(1,18,:)),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(Pmodel.t(18)),' ms (Glottal)'])

%% D.4. Model testing
% Applies the trained phonetic-feature model to the held-out trials,
% overlays the prediction on the recorded EEG, and compares validation and
% test accuracy at the selected channel.

% a. Test model
disp('Testing model...')
[pred,test] = mTRFpredict(stim_test,eeg_test,Pmodel,'zeropad',0);

% b. Plot prediction
% First test trial at the selected channel; the prediction is multiplied
% by 5 (presumably so it is visible next to the EEG trace).
figure(8)
subplot(2,2,3)
plot((1:length(eeg_test{1}))/eeg.fs,eeg_test{1}(:,chan),'linewidth',1.5), hold on
plot((1:length(pred{1}))/eeg.fs,pred{1}(:,chan)*5,'linewidth',1.5), hold off
xlim([0,10])
title('Prediction')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
axis square, grid on
legend('Orig','Pred')

% c. Plot test correlation
% test.r is indexed as (trial,channel) so the bar is the mean over all
% test trials at the selected channel. A single linear index, test.r(chan),
% would pick one arbitrary element of that matrix instead.
subplot(2,2,4)
bar(1,rmax), hold on
bar(2,mean(test.r(:,chan))), hold off
xlim([0,3])
set(gca,'xtick',1:2,'xticklabel',{'Val.','Test'})
title('Model Performance')
xlabel('Dataset')
ylabel('Correlation')
axis square, grid on

%% D.5. Cross-validation for combined model
% Builds a joint spectrogram + phonetic-feature stimulus per trial and
% cross-validates a model on it over a grid of regularization values.

% a. Concatenate spectrogram and phonetic features
% Trial by trial, the compressed spectrogram and feature row 3 of
% stim.data are joined column-wise. Pairing the two cell rows directly
% avoids deriving the trial count from length(stim.data), which returns
% the larger dimension of the features-by-trials cell array rather than
% the number of trials.
spec_phon = cellfun(@(spec,phon) [spec,phon],spectrograms,stim.data(3,:),...
    'UniformOutput',false);

% b. Define training and test sets
test_trials = 10:13; % 20% of data
stim_train = spec_phon; 
eeg_train = eeg.data;
stim_train(test_trials) = [];
eeg_train(test_trials) = [];
stim_test = spec_phon(test_trials);
eeg_test = eeg.data(test_trials);

% c. Model hyperparameters
chan = 85; % Fz
Dir = 1;
tmin = -100;
tmax = 350;
lamda_idx = -4:2:10; % exponents: lambda = 10^idx
lambda_vals = 10.^lamda_idx;
nlambda = numel(lambda_vals);

% d. Run fast cross-validation
disp('Running cross-validation...')
cv = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda_vals,...
    'zeropad',0,'fast',1);

% e. Plot CV accuracy
% Mean over the first dimension of cv.r at the chosen channel; error bars
% are the standard deviation divided by sqrt(number of training trials).
figure(10)
subplot(2,2,1)
errorbar(1:nlambda,mean(cv.r(:,:,chan)),std(cv.r(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lamda_idx), xlim([0,nlambda+1])
title('CV Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square, grid on

% f. Plot CV error (same layout, using cv.err)
subplot(2,2,2)
errorbar(1:nlambda,mean(cv.err(:,:,chan)),std(cv.err(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lamda_idx), xlim([0,nlambda+1])
title('CV Error')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square, grid on

%% D.6. Model training
% Trains the combined spectrogram + phonetic-feature model with the best
% cross-validated lambda.

% a. Get optimal hyperparameters
% Pick the lambda whose mean CV correlation at the chosen channel is
% highest; rmax keeps that correlation for the later performance plot.
[rmax,idx] = max(mean(cv.r(:,:,chan)));
lambda = lambda_vals(idx);

% b. Train model
disp('Training model...')
SPmodel = mTRFtrain(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,...
    lambda,'zeropad',0);

%% D.7. Model testing
% Applies the trained combined model to the held-out trials, overlays the
% prediction on the recorded EEG, and compares validation and test
% accuracy at the selected channel.

% a. Test model
disp('Testing model...')
[pred,test] = mTRFpredict(stim_test,eeg_test,SPmodel,'zeropad',0);

% b. Plot prediction
% First test trial at the selected channel; the prediction is multiplied
% by 5 (presumably so it is visible next to the EEG trace).
figure(10)
subplot(2,2,3)
plot((1:length(eeg_test{1}))/eeg.fs,eeg_test{1}(:,chan),'linewidth',1.5), hold on
plot((1:length(pred{1}))/eeg.fs,pred{1}(:,chan)*5,'linewidth',1.5), hold off
xlim([0,10])
title('Prediction')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
axis square, grid on
legend('Orig','Pred')

% c. Plot test correlation
% test.r is indexed as (trial,channel) so the bar is the mean over all
% test trials at the selected channel. A single linear index, test.r(chan),
% would pick one arbitrary element of that matrix instead.
subplot(2,2,4)
bar(1,rmax), hold on
bar(2,mean(test.r(:,chan))), hold off
xlim([0,3])
set(gca,'xtick',1:2,'xticklabel',{'Val.','Test'})
title('Model Performance')
xlabel('Dataset')
ylabel('Correlation')
axis square, grid on

%% E.1. Semantic vectors visualization
% Shows the audio waveform above the semantic dissimilarity signal of the
% first trial.

% a. Get semantic vectors (feature row 4, first trial)
sem_vec = stim.data{4,1};

% b. Plot audio and semantic dissimilarity on a common time window
figure(11)

% Audio waveform
subplot(2,1,1)
plot((1:length(audio))/fs,audio)
xlim([0.75,4])
title('Audio (44.1 kHz)')
ylabel('Amplitude (a.u.)')

% Semantic dissimilarity, with time taken from the stimulus sample rate
subplot(2,1,2)
plot((1:length(sem_vec))/stim.fs,sem_vec)
xlim([0.75,4])
title('Semantic Dissimilarity (64 Hz)')
xlabel('Time (s)')
ylabel('1 - correlation')

%% E.2. Cross-validation
% Cross-validates a semantic dissimilarity model on the training trials
% over a grid of regularization values and plots accuracy and error at one
% channel.

% a. Define training and test sets
% The test trials are copied out of feature row 4; the training sets are
% what remains once those trials are deleted.
test_trials = 10:13; % 20% of data
stim_test = stim.data(4,test_trials);
eeg_test = eeg.data(test_trials);
stim_train = stim.data(4,:);
stim_train(test_trials) = [];
eeg_train = eeg.data;
eeg_train(test_trials) = [];

% b. Model hyperparameters (tmax is 550 here, not 350 as in the other
% sections)
chan = 85; % Fz
Dir = 1;
tmin = -100;
tmax = 550;
lamda_idx = -4:2:10; % exponents: lambda = 10^idx
lambda_vals = 10.^lamda_idx;
nlambda = numel(lambda_vals);

% c. Run fast cross-validation
disp('Running cross-validation...')
cv = mTRFcrossval(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,...
    lambda_vals,'zeropad',0,'fast',1);

% d. Plot CV accuracy
% Mean over the first dimension of cv.r at the chosen channel; error bars
% are the standard deviation divided by sqrt(number of training trials).
figure(12)
subplot(2,2,1)
errorbar(1:nlambda,mean(cv.r(:,:,chan)),...
    std(cv.r(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lamda_idx)
xlim([0,nlambda+1])
title('CV Accuracy')
xlabel('Regularization (1\times10^\lambda)')
ylabel('Correlation')
axis square
grid on

% e. Plot CV error (same layout, using cv.err)
subplot(2,2,2)
errorbar(1:nlambda,mean(cv.err(:,:,chan)),...
    std(cv.err(:,:,chan))/sqrt(numel(stim_train)),'linewidth',2)
set(gca,'xtick',1:nlambda,'xticklabel',lamda_idx)
xlim([0,nlambda+1])
title('CV Error')
xlabel('Regularization (1\times10^\lambda)')
ylabel('MSE')
axis square
grid on

%% E.3. Model training
% Trains the semantic dissimilarity model with the best cross-validated
% lambda and plots its weights over time lags and over the scalp.

% a. Get optimal hyperparameters
% Pick the lambda whose mean CV correlation at the chosen channel is
% highest; rmax keeps that correlation for the later performance plot.
[rmax,idx] = max(mean(cv.r(:,:,chan)));
lambda = lambda_vals(idx);

% b. Train model
disp('Training model...')
SDmodel = mTRFtrain(stim_train,eeg_train,eeg.fs,Dir,tmin,tmax,lambda,...
    'zeropad',0);

% c. Plot TRF weights
% Symmetric colour limit taken from exactly the lag indices shown in the
% scalp maps below (13, 21 and 34), so none of the three maps is scaled by
% a limit computed from lags that are not displayed.
lim = max(max(abs(SDmodel.w(:,[13,21,34],:)),[],3),[],2);
figure(13)
% TRF time course, one line per channel
subplot(2,2,1)
plot(SDmodel.t,squeeze(SDmodel.w))
xlim([-50,500])
title('Temporal Response Function (TRF)')
xlabel('Time lag (ms)')
ylabel('Amplitude (a.u.)')
% Standard deviation of the weights across channels at each lag
subplot(2,2,2)
plot(SDmodel.t,std(SDmodel.w,[],3))
xlim([-50,500])
title('Global Field Power (GFP)')
xlabel('Time lag (ms)')
% Scalp maps at lag indices 13, 21 and 34 (bottom row of a 2-by-3 grid)
subplot(2,3,4)
topoplot(SDmodel.w(:,13,:),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(SDmodel.t(13)),' ms'])
subplot(2,3,5)
topoplot(SDmodel.w(:,21,:),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(SDmodel.t(21)),' ms'])
subplot(2,3,6)
topoplot(SDmodel.w(:,34,:),eeg.chanlocs,'maplimits',[-lim,lim],'whitebk','on')
title([num2str(SDmodel.t(34)),' ms'])

%% E.4. Model testing
% Applies the trained semantic dissimilarity model to the held-out trials,
% overlays the prediction on the recorded EEG, and compares validation and
% test accuracy at the selected channel.

% a. Test model
disp('Testing model...')
[pred,test] = mTRFpredict(stim_test,eeg_test,SDmodel,'zeropad',0);

% b. Plot prediction
% First test trial at the selected channel; the prediction is multiplied
% by 10 (presumably so it is visible next to the EEG trace).
figure(12)
subplot(2,2,3)
plot((1:length(eeg_test{1}))/eeg.fs,eeg_test{1}(:,chan),'linewidth',1.5), hold on
plot((1:length(pred{1}))/eeg.fs,pred{1}(:,chan)*10,'linewidth',1.5), hold off
xlim([0,10])
title('Prediction')
xlabel('Time (s)')
ylabel('Amplitude (a.u.)')
axis square, grid on
legend('Orig','Pred')

% c. Plot test correlation
% test.r is indexed as (trial,channel) so the bar is the mean over all
% test trials at the selected channel. A single linear index, test.r(chan),
% would pick one arbitrary element of that matrix instead.
subplot(2,2,4)
bar(1,rmax), hold on
bar(2,mean(test.r(:,chan))), hold off
xlim([0,3])
set(gca,'xtick',1:2,'xticklabel',{'Val.','Test'})
title('Model Performance')
xlabel('Dataset')
ylabel('Correlation')
axis square, grid on