% Cognition and Natural Sensory Processing (CNSP) Workshop
% Example N - Banded Ridge Regression (forward TRF)
%
% This example script loads and preprocesses a publicly available dataset
% (you can use any of the datasets in the CNSP resources). Then, the script
% runs a forward TRF analysis, comparing standard (unbanded) ridge
% regression with banded ridge regression on combined feature sets.
%
% Note:
% This code was written with the assumption that all subjects were
% presented with the same set of stimuli. Hence, we use a single stimulus
% file (dataStim.mat) that applies to all subjects. This is also compatible
% with randomised presentation orders. In that case, the
% EEG/MEG trials should be sorted to match the single stimulus file.
% The original order is preserved in a specific CND variable. If distinct
% subjects were presented with different stimuli, it is necessary to
% include a stimulus file per participant.
%
% CNSP-Workshop 2022
% https://cnsp-workshop.github.io/website/index.html
% Author: Aaron Nidiffer
% Copyright 2021 - Giovanni Di Liberto
%                  Nathaniel Zuk
%                  Michael Crosse
%                  Aaron Nidiffer
%                  Giorgia Cantisani
%                  (see license file for details)
% Last update: 24 June 2022

% Start from a clean session: empty workspace, no open figures, clear console
clear('all');
close('all');
clc;

% Add dependencies to path. Folders are added one at a time, in this order,
% so that path precedence matches successive addpath calls (each call puts
% its folder at the top of the path, so the last one added wins).
libDirs = {'../libs/cnsp_utils', ...
           '../libs/cnsp_utils/cnd', ...
           '../libs/mTRF-Toolbox_v2/mtrf', ...
           '../libs/NoiseTools', ...
           '../libs/eeglab'};
for iDir = 1:numel(libDirs)
    addpath(libDirs{iDir});
end
clear('libDirs','iDir');

% initialize eeglab
eeglab

% EEGLAB opens a figure and creates base-workspace variables that this
% script does not use; get rid of them.
close('all');
clear('ALLCOM','ALLEEG','CURRENTSET','CURRENTSTUDY','EEG','globalvars','LASTCOM','PLUGINLIST','STUDY');

%% Parameters - Natural speech listening experiment
% Dataset root (CND format) and the subfolder holding the CND files.
% The commented path is an alternative dataset folder.
dataMainFolder = '../datasets/LalorNatSpeech/';
% dataMainFolder = '../datasets/LalorNatSpeechReverse/';
dataCNDSubfolder = 'dataCND/';

% Re-referencing scheme: 'Avg' or 'Mastoids'
reRefType = 'Avg';

% Band-pass edges in Hz, [high-pass cut-off, low-pass cut-off].
% A 0 disables that side of the filter, e.g., [0,8] applies only an
% 8 Hz low-pass filter.
bandpassFilterRange = [1,8];

% Target sampling rate in Hz. *** fs/downFs must be an integer value ***
downFs = 64;

% Subjects to analyse. The full set is 1:20; a single exemplar subject is
% used here (5 and 10 are other exemplar choices).
% Subs = 1:20; % all subjects
Subs = 12;

% Nyquist sanity check: the low-pass cut-off should be below downFs/2
if 2*bandpassFilterRange(2) > downFs
    disp('Warning: Be careful. The low-pass filter should use a cut-off frequency smaller than downFs/2')
end

%% Preprocess EEG - Natural speech listening experiment
% For each subject in Subs: load dataSub<N>.mat, low-pass filter, downsample
% to downFs, high-pass filter, replace bad channels, re-reference, remove
% the initial padding, and save the result as pre_dataSub<N>.mat in the
% same folder. Change "if 1" to "if 0" to skip this step when the
% preprocessed files already exist.
if 1
    for sub = Subs
        % Loading EEG data
        eegFilename = [dataMainFolder,dataCNDSubfolder,sprintf('dataSub%i.mat',sub)];
        fprintf('Loading EEG data: dataSub%i.mat\n',sub)
        load(eegFilename,'eeg')
        eeg = cndNewOp(eeg,'Load'); % Saving the processing pipeline in the eeg struct
        
        % Filtering - LPF (low-pass filter), at the original sampling rate,
        % i.e., before downsampling
        if bandpassFilterRange(2) > 0
            hd = getLPFilt(eeg.fs,bandpassFilterRange(2));
            
            % A little coding trick - for loop vs cellfun (both branches
            % filter every trial; switch the condition to 1 for the loop)
            if (0)
                % Filtering each trial/run with a for loop
                for ft = 1:length(eeg.data)
                    eeg.data{ft} = filtfilthd(hd,eeg.data{ft});
                end
            else
                % Filtering each trial/run with a cellfun statement
                eeg.data = cellfun(@(x) filtfilthd(hd,x),eeg.data,'UniformOutput',false);
            end
            
            % Filtering external channels
            if isfield(eeg,'extChan')
                for extIdx = 1:length(eeg.extChan)
                    eeg.extChan{extIdx}.data = cellfun(@(x) filtfilthd(hd,x),eeg.extChan{extIdx}.data,'UniformOutput',false);
                end
            end
            
            eeg = cndNewOp(eeg,'LPF');
        end
        
        % Downsampling EEG and external channels
        if downFs < eeg.fs
            eeg = cndDownsample(eeg,downFs);
        end
        
        % Filtering - HPF (high-pass filter), at the (possibly downsampled) rate
        if bandpassFilterRange(1) > 0
            hd = getHPFilt(eeg.fs,bandpassFilterRange(1));
            
            % Filtering EEG data
            eeg.data = cellfun(@(x) filtfilthd(hd,x),eeg.data,'UniformOutput',false);
            
            % Filtering external channels
            if isfield(eeg,'extChan')
                for extIdx = 1:length(eeg.extChan)
                    eeg.extChan{extIdx}.data = cellfun(@(x) filtfilthd(hd,x),eeg.extChan{extIdx}.data,'UniformOutput',false);
                end
            end
            
            eeg = cndNewOp(eeg,'HPF');
        end
        
        % Replacing bad channels (only when channel locations are available)
        if isfield(eeg,'chanlocs')
            for ft = 1:length(eeg.data)
                eeg.data{ft} = removeBadChannels(eeg.data{ft}, eeg.chanlocs);
            end
        end
        
        % Re-referencing EEG data (reRefType: 'Avg' or 'Mastoids')
        eeg = cndReref(eeg,reRefType);
        
        % Removing initial padding (specific to this dataset).
        % Every trial keeps the samples from firstSample to its end; the EEG
        % and the external channels are trimmed identically so they stay
        % aligned, and each trial is trimmed in its own cell.
        % NOTE(review): paddingStartSample is treated as the number of
        % padding samples at the start of each trial, expressed at the
        % current (possibly downsampled) eeg.fs -- confirm against the
        % dataset description / cndDownsample.
        if isfield(eeg,'paddingStartSample')
            firstSample = 1 + eeg.paddingStartSample;
            for ft = 1:length(eeg.data)
                eeg.data{ft} = eeg.data{ft}(firstSample:end,:);
                if isfield(eeg,'extChan')
                    for extIdx = 1:length(eeg.extChan)
                        eeg.extChan{extIdx}.data{ft} = eeg.extChan{extIdx}.data{ft}(firstSample:end,:);
                    end
                end
            end
        end
        
        % Saving preprocessed data
        eegPreFilename = [dataMainFolder,dataCNDSubfolder,sprintf('pre_dataSub%i.mat',sub)];
        fprintf('Saving preprocessed EEG data: pre_dataSub%i.mat\n',sub)
        save(eegPreFilename,'eeg')
    end
end
%% Banded Ridge Regression

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TODOs:
%  1 - Go through the code and fill out the gaps
%  2 - Construct new stimulus feature sets with envelope, spectrogram, word 
%  onset, and phonetic features
%  3 - Construct the banded ridge grouping variable for the multivariate
%  features model
%  4 - Calculate the optimal lambda and extract the best model performance
%  for each banded ridge model.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Reference call signature:
% stats = mTRFcrossval(stim,resp,fs,dirTRF,tmin,tmax,lambdas);

% TRF hyperparameters - some may change in the code below

% Direction: 1 = forward model (encoder/TRF), 2 = backward model (decoder)
dirTRF = 1;

% Time-lag windows (ms)
% Wide window with pre-stimulus lags, for TRF visualization
tminF = -200;
tmaxF = 600;
% Reduced window, for EEG prediction
tminR = 0;
tmaxR = 300;

% Ridge regularisation values to search
% lambdas = 10.^(-2:2:2);   % small set of lambdas (quick)
lambdas = 10.^(-4:2:8);     % larger set of lambdas (slower)

% Backward banded ridge regression only makes sense in certain cases.
% Be careful: backward models with many electrodes and large time-windows
% can require long computational times. So, we suggest reducing the
% dimensionality if you are just playing around with the code (e.g., select
% only a few electrodes and/or reduce the TRF window).

% Loading Stimulus data
stimFilename = [dataMainFolder,dataCNDSubfolder,'dataStim.mat'];
disp(['Loading stimulus data: ','dataStim.mat'])
load(stimFilename,'stim')

% Downsampling stim if necessary
if downFs < stim.fs
    stim = cndDownsample(stim,downFs);
end

% Normalize stim data (preserving the ratio between features).
% Intended for continuous signals, e.g., speech envelope; per the original
% notes, cndNormalise skips features with a lot of zeros (i.e., stick
% models) -- see cndNormalise for the exact rule.
%
% Each feature set (one row of stim.data, one cell per trial) is passed to
% cndNormalise on its own.
for ft = 1:size(stim.data,1)
    % First pluck out each feature
    stim_ = stim;
    stim_.data = stim.data(ft,:);
    
    % Then Normalize continuous features
    stim_ = cndNormalise(stim_);
    
    % Return back to stim matrix
    stim.data(ft,:) = stim_.data;
end
clear stim_

% Check that, within each trial (column of stim.data), all feature sets
% (rows) have the same number of time samples (first dimension of each cell).
lens = cellfun(@(x) size(x,1), stim.data);
badTrials = find(any(lens ~= lens(1,:), 1));
if ~isempty(badTrials)
    warning('Time dimension is inconsistent across feature sets on trial(s): %s', mat2str(badTrials))
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% TODO: Construct Stimulus featuresets with envelope (E), spectrogram (S),
% phonetic features (F) and Word Onsets (O). We need at least EO, SF, and
% SFO sets. Add them to the end of stim.data, i.e., stim.data(end+(1:3),:).
% For example, EO can be constructed for each trial, tr, by calling:
% stim.data{5,tr} = [stim.data{1,tr} stim.data{2,tr}];
% And don't forget to update the names structure for each new model:
% stim.names{5} = [stim.names{1} stim.names{2}];
%
% NOTE(review): the banded models in the subject loop hard-code rows
% 5 (EO), 6 (SF) and 7 (SFO) of stim.data, and the SF/SFO groupings assume
% 16 spectrogram columns followed by 19 phonetic-feature columns (plus 1
% word-onset column for SFO). Whatever is constructed here must match.

% If stuck, I've provided a script:
constructStim;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Per-subject TRF analysis on the preprocessed EEG:
%  1) unbanded ridge (mTRFcrossval) for every feature set in stim.data;
%  2) banded ridge (mTRFcvbanded) for the EO, SF and SFO feature sets.
% Results are indexed by subject number:
%   l_unbanded(sub,ft), r_mean_unbanded(sub,ft), r_unbanded(:,sub,ft)
%   l_banded{sub,ft},   r_mean_banded(sub,ft),   r_banded(:,sub,ft)
% r_*(:,sub,ft) keeps one value per output of stats.r (third dimension),
% averaged over its first dimension (CV folds in mTRF-Toolbox v2).
for sub = Subs
    % Loading preprocessed EEG
    eegPreFilename = [dataMainFolder,dataCNDSubfolder,sprintf('pre_dataSub%i.mat',sub)];
    fprintf('Loading preprocessed EEG data: pre_dataSub%i.mat\n',sub)
    load(eegPreFilename,'eeg')
    
    % Downsample EEG if necessary
    if downFs < eeg.fs
        eeg = cndDownsample(eeg,downFs);
    end
    
    % Work on a per-subject copy of the stimulus. The length check below may
    % trim the stimulus to this subject's EEG; doing that on the shared
    % 'stim' would carry the trimming over to every later subject and make
    % results depend on the order in which subjects are processed.
    stimSub = stim;
    
    % Making sure that stim and neural data have the same length
    % The trial may end a few seconds after the end of the audio
    % e.g., the neural data may include the break between trials
    % It would be best to do this chunking at preprocessing, but let's
    % check here, just to be sure
    
    % cndCheckStimNeural is called on one feature set (row of stimSub.data)
    % at a time, so make sure all feature sets have the same dimensionality
    % before doing this
    for ft = 1:size(stimSub.data,1)
        % First pluck out each feature
        stim_ = stimSub;
        stim_.data = stimSub.data(ft,:);
        
        % Then standardize its length if necessary
        [stim_,eeg] = cndCheckStimNeural(stim_,eeg);
        
        % Return back to stim matrix
        stimSub.data(ft,:) = stim_.data;
    end
    clear stim_

    % Standardise neural data (preserving the ratio between channels)
    eeg = cndNormalise(eeg);
    
    % Now the TRFs
    % TRF crossvalidation
    
    % Determining optimal regularisation parameter for each feature separately.
    fprintf('\n**************************************\n')
    for ft = 1:size(stimSub.data,1)
        fprintf('Running mTRFcrossval for feature %s \n',stimSub.names{ft})
        stats = mTRFcrossval(stimSub.data(ft,:),eeg.data,eeg.fs,dirTRF,tminR,tmaxR,lambdas,'verbose',0);
        
        % Calculating optimal lambda: the one with the highest r after
        % averaging stats.r over its first and third dimensions.
        % Display and store results
        [maxR,Ilambda] = max(squeeze(mean(mean(stats.r,1),3)));
        l_unbanded(sub,ft) = lambdas(Ilambda);
        fprintf('r = %5.3f, best lambda = %g \n',maxR,l_unbanded(sub,ft))
        r_mean_unbanded(sub,ft) = maxR;
        r_unbanded(:,sub,ft) = squeeze(mean(stats.r(:,Ilambda,:),1));
    end
    
    % Banded TRF crossvalidation
    % NOTE(review): rows 5, 6 and 7 of stimSub.data are expected to hold the
    % EO, SF and SFO feature sets built by constructStim -- confirm.
    
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % TODO: After inspecting the optimal lambda above, should we narrow or 
    % expand lambda range for any feature model?
    % lambdas = [];
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Feel free to play around with tmin and tmax.
    % You could try an early (0-150) or late (150-300) window
    % tminR = [];
    % tmaxR = [];
    
    fprintf('\n**************************************\n')
    % EO MODEL - two univariate (1+1) features
    ft = 5;
    fprintf('Running mTRFcvbanded for feature %s \n',stimSub.names{ft})
    grouping = [1 2]; % Grouping parameter: 2 features, 2 bands
                      % Each element corresponding to a feature variate
                      % with its value specifying the regularization band.
    stats = mTRFcvbanded(stimSub.data(ft,:),eeg.data,eeg.fs,dirTRF,tminR,tmaxR,lambdas,grouping);
    plotBanded(stats,stimSub.names{ft})
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % TODO: calculate optimal lambda combinations and extract best predictions
    % Each row of stats.lambdas is one lambda combination (one per band).
    [maxR,Ilambda] = max(squeeze(mean(mean(stats.r,1),3)));
    l_banded{sub,ft} = stats.lambdas(Ilambda,:); 
    fprintf('r = %5.3f, best lambdas = %g, %g \n',maxR,stats.lambdas(Ilambda,:))
    r_mean_banded(sub,ft) = maxR;
    r_banded(:,sub,ft) = squeeze(mean(stats.r(:,Ilambda,:),1));
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    
    
    % SF MODEL - two multivariate (16+19) features
    ft = 6;
    fprintf('Running mTRFcvbanded for feature %s \n',stimSub.names{ft})
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % TODO: Construct the grouping parameter for SF model
    % Band 1: 16 spectrogram columns; band 2: 19 phonetic-feature columns
    grouping = [ones(1,16) 2.*ones(1,19)];
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    stats = mTRFcvbanded(stimSub.data(ft,:),eeg.data,eeg.fs,dirTRF,tminR,tmaxR,lambdas,grouping);
    plotBanded(stats,stimSub.names{ft})
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % TODO: calculate optimal lambda combination and extract best predictions
    % Can modify from previous call to mTRFcvbanded()
    [maxR,Ilambda] = max(squeeze(mean(mean(stats.r,1),3)));
    l_banded{sub,ft} = stats.lambdas(Ilambda,:); 
    fprintf('r = %5.3f, best lambdas = %g, %g \n',maxR,stats.lambdas(Ilambda,:))
    r_mean_banded(sub,ft) = maxR;
    r_banded(:,sub,ft) = squeeze(mean(stats.r(:,Ilambda,:),1));
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    
    
    % SFO MODEL - two multivariate and one univariate (16+19+1) features
    ft = 7;
    fprintf('Running mTRFcvbanded for feature %s \n',stimSub.names{ft})
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % TODO: Construct the grouping parameter for SFO model
    % Bands: 16 spectrogram, 19 phonetic-feature and 1 word-onset column(s)
    grouping = [ones(1,16) 2.*ones(1,19) 3];
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    stats = mTRFcvbanded(stimSub.data(ft,:),eeg.data,eeg.fs,dirTRF,tminR,tmaxR,lambdas,grouping);
    % plotBanded(stats,stimSub.names{ft}) % won't work. only supports 2 band visualization
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % TODO: calculate optimal lambda combination and extract best predictions
    % Can modify from previous call to mTRFcvbanded()
    [maxR,Ilambda] = max(squeeze(mean(mean(stats.r,1),3)));
    l_banded{sub,ft} = stats.lambdas(Ilambda,:); 
    fprintf('r = %5.3f, best lambdas = %g, %g, %g \n',maxR,stats.lambdas(Ilambda,:))
    r_mean_banded(sub,ft) = maxR;
    r_banded(:,sub,ft) = squeeze(mean(stats.r(:,Ilambda,:),1));
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    
end

% Save the entire workspace (parameters and per-subject results) to disk
save bandedRidge_ws.mat