% Cognition and Natural Sensory Processing (CNSP) Workshop
% Example 3 - CCA
%
% This example script loads and preprocesses a publicly available dataset
% (you can use any of the datasets in the CNSP resources). Then, the script
% runs a CCA analysis, evaluated with correlation in CC space as well as
% with a match-vs-mismatch classification score.
%
% Note:
% This code was written with the assumption that all subjects were
% presented with the same set of stimuli. Hence, we use a single stimulus
% file (dataStim.mat) that applies to all subjects. This is compatible
% with scenarios involving randomised presentation orders. In that case,
% the EEG/MEG trials should be sorted to match the single stimulus file. 
% The original order is preserved in a specific CND variable. If distinct
% subjects were presented with different stimuli, it is necessary to
% include a stimulus file per participant.
%
% CNSP-Workshop 2022
% https://cnsp-workshop.github.io/website/index.html
% Author: Giovanni M. Di Liberto
% Copyright 2021 - Giovanni Di Liberto
%                  Nathaniel Zuk
%                  Michael Crosse
%                  Aaron Nidiffer
%                  Giorgia Cantisani
%                  (see license file for details)
% Last update: 27 June 2022

clear('all')
close('all')

% Put the CNSP utilities, the CND helpers, the mTRF-Toolbox, NoiseTools and
% EEGLAB on the MATLAB path. Each addpath call prepends its folder, so the
% folders are added one at a time in this exact order.
libDirs = {'../libs/cnsp_utils', ...
           '../libs/cnsp_utils/cnd', ...
           '../libs/mTRF-Toolbox_v2/mtrf', ...
           '../libs/NoiseTools', ...
           '../libs/eeglab'};
for iLib = 1:numel(libDirs)
    addpath(libDirs{iLib})
end
clear libDirs iLib

% Start EEGLAB
eeglab


%% Parameters - Natural speech listening experiment
dataMainFolder = '../datasets/LalorNatSpeech/';
% dataMainFolder = '../datasets/LalorNatSpeechReverse/';
dataCNDSubfolder = 'dataCND/';

reRefType = 'Avg'; % or 'Mastoids'
bandpassFilterRange = [0.1,16]; % Hz (indicate 0 to avoid running the low-pass
                          % or high-pass filters or both)
                          % e.g., [0,8] will apply only a low-pass filter
                          % at 8 Hz
downFs = 32; % Hz. *** fs/downFs must be an integer value ***
             % Note that CCA is slower than the mTRF. As such, we will need
             % a heavier downsampling

% One 'dataSub*.mat' file per participant in the CND folder
eegFilenames = dir([dataMainFolder,dataCNDSubfolder,'dataSub*.mat']);
nSubs = length(eegFilenames);
if nSubs == 0
    disp(['Warning: no dataSub*.mat files found in ',dataMainFolder,dataCNDSubfolder])
end

% Anti-aliasing sanity check: the low-pass cut-off must lie strictly below
% the Nyquist frequency of the downsampled data (downFs/2). A cut-off equal
% to downFs/2 also triggers the warning. Skipped when the low-pass filter
% is disabled (cut-off = 0).
if bandpassFilterRange(2) > 0 && bandpassFilterRange(2) >= downFs/2
    disp('Warning: Be careful. The low-pass filter should use a cut-off frequency smaller than downFs/2')
end

%% Preprocess EEG - Natural speech listening experiment
% Same preprocessing as in examples 1 and 2
% This time, we downsample the data to 32 Hz
% Also, CCA has less tight constraints in terms of filtering than TRF
% analyses. As such, we can use wider frequency ranges or even no filters
% at all
%
% Per subject: load -> low-pass -> downsample -> high-pass -> replace bad
% channels -> re-reference -> remove initial padding -> save the result as
% 'preCCA_<original file name>' in the same folder.

for sub = 1:nSubs
    % Loading EEG data
    eegFilename = [dataMainFolder,dataCNDSubfolder,eegFilenames(sub).name];
    disp(['Loading EEG data: ',eegFilenames(sub).name])
    load(eegFilename,'eeg')
    eeg = cndNewOp(eeg,'Load'); % Saving the processing pipeline in the eeg struct

    % Filtering - LPF (low-pass filter)
    if bandpassFilterRange(2) > 0
        hd = getLPFilt(eeg.fs,bandpassFilterRange(2));
        
        % A little coding trick - for loop vs cellfun
        % (the 'if (0)' branch is kept only to show the equivalent for loop)
        if (0)
            % Filtering each trial/run with a for loop
            for ii = 1:length(eeg.data)
                eeg.data{ii} = filtfilthd(hd,eeg.data{ii});
            end
        else
            % Filtering each trial/run with a cellfun statement
            eeg.data = cellfun(@(x) filtfilthd(hd,x),eeg.data,'UniformOutput',false);
        end
        
        % Filtering external channels
        if isfield(eeg,'extChan')
            for extIdx = 1:length(eeg.extChan)
                eeg.extChan{extIdx}.data = cellfun(@(x) filtfilthd(hd,x),eeg.extChan{extIdx}.data,'UniformOutput',false);
            end
        end
        
        eeg = cndNewOp(eeg,'LPF');
    end
    
    % Downsampling EEG and external channels
    if downFs < eeg.fs
        eeg = cndDownsample(eeg,downFs);
    end
    
    % Filtering - HPF (high-pass filter), applied after downsampling
    if bandpassFilterRange(1) > 0 
        hd = getHPFilt(eeg.fs,bandpassFilterRange(1));
        
        % Filtering EEG data
        eeg.data = cellfun(@(x) filtfilthd(hd,x),eeg.data,'UniformOutput',false);
        
        % Filtering external channels
        if isfield(eeg,'extChan')
            for extIdx = 1:length(eeg.extChan)
                eeg.extChan{extIdx}.data = cellfun(@(x) filtfilthd(hd,x),eeg.extChan{extIdx}.data,'UniformOutput',false);
            end  
        end
        
        eeg = cndNewOp(eeg,'HPF');
    end
    
    % Replacing bad channels
    if isfield(eeg,'chanlocs')
        for tr = 1:length(eeg.data)
            eeg.data{tr} = removeBadChannels(eeg.data{tr}, eeg.chanlocs);
        end
    end
    
    % Re-referencing EEG data
    eeg = cndReref(eeg,reRefType);
    
    % Removing initial padding (specific to this dataset)
    % The first eeg.paddingStartSample samples (rows) of every trial are
    % dropped. The same cut is applied to the EEG and to each external
    % channel of the same trial, so they stay aligned.
    % NOTE(review): this runs after downsampling, so paddingStartSample is
    % assumed to be expressed in samples at the current eeg.fs, and to be a
    % scalar shared by all trials. Confirm against the dataset's CND docs.
    if isfield(eeg,'paddingStartSample')
        firstKeptSample = 1 + eeg.paddingStartSample;
        for tr = 1:length(eeg.data)
            eeg.data{tr} = eeg.data{tr}(firstKeptSample:end,:);
            if isfield(eeg,'extChan')
                for extIdx = 1:length(eeg.extChan)
                    eeg.extChan{extIdx}.data{tr} = eeg.extChan{extIdx}.data{tr}(firstKeptSample:end,:);
                end
            end
        end
    end
    
    % Saving preprocessed data
    eegPreFilename = [dataMainFolder,dataCNDSubfolder,'preCCA_',eegFilenames(sub).name];
    disp(['Saving preprocessed EEG data: preCCA_',eegFilenames(sub).name])
    save(eegPreFilename,'eeg')
end

%% Canonical Correlation Analysis - Step 3

% TODOs:
%  1 - Go through the code and fill out the gaps
%  2 - What is the difference between 'dili_ccaDataPrep_shifts' and 
%      'dili_ccaDataPrep'? Compare their impact on the result
%  3 - 'windowSize' and 'ncomp' only affect the match-vs-mismatch
%      evaluation score. The model is unchanged. Things are different for
%      nPCS, which strongly impacts the result. A small number of selected
%      PCs may reduce the risk of overfitting and make the computations
%      faster. However, it may also mean that useful information is lost.
%      Try changing those parameters and study their impact.
          
% Stim parameters (which stimulus feature we are selecting)
stimIdx = 1; % 1: env; 2: word onset

% Loading Stimulus data
stimFilename = [dataMainFolder,dataCNDSubfolder,'dataStim.mat'];
load(stimFilename,'stim')

% Downsampling stim if necessary.
% This must happen BEFORE any sample-based quantity (e.g. 'shifts') is
% derived from stim.fs. Otherwise those sample counts would refer to the
% original rate instead of the rate of the data passed to the CCA.
if downFs < stim.fs, stim = cndDownsample(stim,downFs); end

% CCA parameters
tmin = -1000; % ms - search window
tmax = 1000;  % ms
% Shifts, in samples at stim.fs and in steps of 4 samples, covering
% [tmin, tmax]. Passed to nt_cca_crossvalidate_mm in the subject loop.
shifts = (floor(tmin/1000*stim.fs):4:ceil(tmax/1000*stim.fs));

% Preprocessing parameters
nPCS = 1;
% TODO TODO TODO: Try changing the number of PCs to keep for stimulus and EEG
% PCs to keep when preprocessing the stim and neural data
% (Same for stim and neural data here)

% Run CCA for each subject

% CCA input variables and parameters. These are the same for every subject,
% so they are set once here instead of inside the subject loop.
% TODO TODO TODO: Try changing the two parameters that follow. They
% impact the match-vs-mismatch score
windowSize = 30; % This is the window (in seconds) where the match-vs-mismatch
                % score is calculated
ncomp = 1; % Number of components for calculating the match-vs-mismatch score

% Time-lags used by dili_ccaDataPrep_shifts for both stimulus and EEG
tminModel = 0; % ms - start time-lags window
tmaxModel = 500;  % ms - end time-lag window
shiftsModel = ((floor(tminModel/1000*stim.fs):1:ceil(tmaxModel/1000*stim.fs)))+1;

% x-axis for the plots. 'shifts' is a number of samples, so dividing by the
% sampling rate gives seconds (the sign is negated for display).
% NOTE(review): assumes 'shifts' was computed at stim's current sampling
% rate. The sign convention of the NoiseTools shifts is not verified here.
shiftsSec = -shifts/stim.fs;

figure;
clear rAll rCCallSub accMMall
for sub = 1:nSubs
    % Loading preprocessed EEG
    eegPreFilename = [dataMainFolder,dataCNDSubfolder,'preCCA_',eegFilenames(sub).name];
    disp('Loading preprocessed EEG data')
    load(eegPreFilename,'eeg')
    
    % Downsampling eeg if necessary
    if downFs < eeg.fs, eeg = cndDownsample(eeg,downFs); end
    
    % Stim - time-lags and dimensionality reduction
    xx = stim.data(stimIdx,:);
    xx = dili_ccaDataPrep_shifts(xx,nPCS,shiftsModel); % shifts of the envelope + PCA
%     xx = dili_ccaDataPrep(xx,nPCS); % modulations of the envelope + PCA

    % EEG - time-lags and dimensionality reduction
    yy = eeg.data;        
    % TODO TODO TODO: Try the two options that follow. Either
    % 'dili_ccaDataPrep_shifts' or 'dili_ccaDataPrep'. What is the
    % difference?
%     yy = dili_ccaDataPrep(yy,nPCS); % modulations of the EEG + PCA
    yy = dili_ccaDataPrep_shifts(yy,nPCS,shiftsModel); % shifts of the EEG + PCA
        
    % cca crossval, match-vs-mismatch version (mm).
    % windowSize is converted from seconds to samples at stim.fs.
    [AA,BB,RR,~,accMM] = ...
        nt_cca_crossvalidate_mm(xx,yy,shifts,windowSize*stim.fs,ncomp); 
    
    rAll = mean(RR,3); % RR: nPCs x shifts x trials -> average over trials

    % Storing all CCA correlations for the first CC (one column per subject)
    rCCallSub(:,sub) = rAll(1,:)';

    % Storing MM classification (match-vs-mismatch)
    accMMall(:,sub) = accMM;
    
    disp(['rAll: ',num2str(rAll(1,:))])
    
    % Re-plot the results after adding the results from this subject
    subplot(1,2,1)
    plot(shiftsSec,rCCallSub,'.:','MarkerSize',20)
    xlabel('Shift (s)')
    ylabel('CCA Correlations')
    title('CC_1 Correlations')
    run prepExport.m
    
    subplot(1,2,2)
    plot(shiftsSec,accMMall,'.:','MarkerSize',20)
    xlabel('Shift (s)')
    ylabel('MM classification')
    title('Match-vs-mismatch')
    run prepExport.m
    
    drawnow
end
