import numpy as np

import nltk
# nltk.download("stopwords")
from nltk.corpus import stopwords
from contractions import fix

from sklearn.linear_model import RidgeCV
from sklearn.linear_model import LogisticRegressionCV
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVR

from scipy.stats import zscore
from sklearn.model_selection import train_test_split, cross_val_predict
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

from tqdm import tqdm



############################# PREPROCESS DATA ###############################
def lemmatize_and_expand(ds, wnl):
    """Normalize every token of ``ds.data`` in place.

    Each token goes through four steps, in this order:

    1. lemmatized as a noun with ``wnl.lemmatize(w, pos='n')``;
    2. passed through ``contractions.fix`` and reduced to the first
       whitespace-separated word of the result;
    3. a trailing ``'s`` is stripped when the token is longer than 2 chars;
    4. a trailing ``'re`` is stripped when the token is longer than 3 chars.

    Whitespace-only tokens (including ``""``) skip step 2 unchanged.

    Parameters
    ----------
    ds : object with a mutable ``data`` attribute
        ``ds.data`` must be an iterable of string tokens; it is replaced by a
        new list of normalized tokens.
    wnl : lemmatizer
        Any object exposing ``lemmatize(word, pos=...)``; presumably an
        ``nltk.stem.WordNetLemmatizer`` -- confirm at call sites.

    Returns
    -------
    None
        ``ds`` is modified in place.
    """
    ds.data = [_normalize_token(w, wnl) for w in ds.data]


def _normalize_token(w, wnl):
    """Apply the lemmatize / expand-contraction / suffix-strip steps to one token."""
    w = wnl.lemmatize(w, pos='n')

    # ``fix(w).split()`` is empty for whitespace-only input, so indexing [0]
    # would raise IndexError; leave such tokens untouched.
    if w.strip():
        expanded = fix(w).split()
        if expanded:
            w = expanded[0]

    # Drop possessive / contracted suffixes that survive expansion.
    if len(w) > 2 and w.endswith("'s"):
        w = w[:-2]
    if len(w) > 3 and w.endswith("'re"):
        w = w[:-3]
    return w

    
def rate_word(word, ratings, S):
    """Return the ``"Conc.M"`` rating of ``word``, or ``3.0`` if it is unrated.

    Parameters
    ----------
    word : str
        Token to rate.
    ratings : pandas.DataFrame-like
        Table with a ``"Word"`` column and a ``"Conc.M"`` column (presumably
        mean concreteness). The value from the first row whose ``"Word"``
        equals ``word`` is used.
    S : container
        Words that have a rating; membership is tested with ``in``.

    Returns
    -------
    The rating for ``word`` when ``word in S``, otherwise the fallback ``3.0``.
    """
    if word not in S:
        return 3.0
    matching_rows = ratings[ratings["Word"] == word]
    return matching_rows["Conc.M"].values[0]

    
def rate_ds(ds, ratings, S):
    """Replace every token of ``ds.data`` with its rating, in place.

    Each token is scored with ``rate_word`` using the same ``ratings`` table
    and known-word container ``S``; ``ds.data`` becomes a list of the scores.
    """
    ds.data = [rate_word(token, ratings, S) for token in ds.data]
    

def remove_stopwords(wordseq_chunks, rateseq_chunks, storyseq_chunks):
    """Drop English stopwords from three parallel per-chunk sequences.

    A boolean keep-mask is built for each chunk of ``wordseq_chunks`` (True
    where the word is not in NLTK's ``stopwords.words("english")``). That
    mask is then applied to the matching chunk of all three sequences.

    Parameters
    ----------
    wordseq_chunks : sequence of numpy arrays of str
        Words of each chunk.
    rateseq_chunks : sequence of numpy arrays
        Per-word ratings, aligned element-wise with ``wordseq_chunks``.
    storyseq_chunks : sequence of numpy arrays
        Per-word feature rows, aligned with ``wordseq_chunks`` along axis 0.

    Returns
    -------
    tuple of three lists
        Filtered word chunks, rating chunks and feature chunks, in that order.
    """
    assert len(wordseq_chunks) == len(rateseq_chunks)
    assert len(wordseq_chunks) == len(storyseq_chunks)

    # Build the stopword set once, outside the per-word loop; set membership
    # is O(1) whereas testing against the returned list is a linear scan.
    english_stopwords = set(stopwords.words("english"))
    masks = [np.array([w not in english_stopwords for w in row], dtype=bool)
             for row in wordseq_chunks]

    basic_wordseq_chunks = [chunk[mask] for chunk, mask in zip(wordseq_chunks, masks)]
    basic_rateseq_chunks = [chunk[mask] for chunk, mask in zip(rateseq_chunks, masks)]
    basic_storyseq_chunks = [chunk[mask] for chunk, mask in zip(storyseq_chunks, masks)]
    return basic_wordseq_chunks, basic_rateseq_chunks, basic_storyseq_chunks



############################# MEAN-DOWNSAMPLE ###############################
def summarize(basic_storyseq_chunks, basic_rateseq_chunks, basic_wordseq_chunks, model_name, vector_dim=None):
    """Mean-downsample per-word features to one entry per chunk.

    Parameters
    ----------
    basic_storyseq_chunks : sequence of numpy arrays
        Per-word vectors of each chunk; averaged over axis 0.
    basic_rateseq_chunks : sequence of numpy arrays
        Per-word ratings of each chunk; averaged over axis 0.
    basic_wordseq_chunks : sequence of iterables of str
        Words of each chunk; joined with single spaces.
    model_name : str
        ``"w2v"`` (300-d vectors) or ``"gpt2"`` (768-d vectors). This sets
        the size of the zero vector used for empty chunks.
    vector_dim : int, optional
        Explicit vector size for empty chunks. It overrides the size implied
        by ``model_name`` and allows other embedding models.

    Returns
    -------
    chunk_vectors : list of numpy arrays
        Mean vector per chunk, or zeros of length ``vector_dim`` for empty chunks.
    chunk_ratings : list
        Mean rating per chunk, or ``np.array([3])`` for empty chunks. NOTE(review):
        for 1-D rating chunks the mean is a scalar while the fallback has shape
        (1,); confirm downstream code expects this.
    chunk_words : list of 0-d numpy str arrays
        Space-joined words per chunk, or ``" "`` for empty chunks.

    Raises
    ------
    ValueError
        If ``model_name`` is unknown and ``vector_dim`` is not given.
    """
    known_dims = {"w2v": 300, "gpt2": 768}
    if vector_dim is None:
        if model_name not in known_dims:
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of "
                f"{sorted(known_dims)} or pass vector_dim explicitly")
        vector_dim = known_dims[model_name]

    chunk_vectors = [(np.mean(entry, axis=0) if len(entry) else np.zeros(vector_dim))
                     for entry in basic_storyseq_chunks]
    chunk_ratings = [(np.mean(entry, axis=0) if len(entry) else np.array([3]))
                     for entry in basic_rateseq_chunks]
    chunk_words = [np.array(" ".join(entry) if len(entry) else " ")
                   for entry in basic_wordseq_chunks]
    return chunk_vectors, chunk_ratings, chunk_words



############################# RESIDUALS ###############################
def residual_predict(X, Y, y_strict):
    """Classify ``y_strict`` from cross-validated ridge residuals of Y given X.

    ``X`` and ``Y`` are z-scored, and ``Y`` is predicted from ``X`` with a
    ``RidgeCV`` model through 5-fold ``cross_val_predict``. The difference
    ``prediction - Y`` is then passed to ``logisticCV`` with ``z=False``
    (no further z-scoring) as the features for predicting ``y_strict``.

    Returns
    -------
    dict
        The metrics dict from ``logisticCV`` (keys "A", "P", "R", "F1").
    """
    X_std = zscore(X)
    Y_std = zscore(Y)

    ridge = RidgeCV(alphas=[1000, 5000, 10000, 50000, 100000], cv=5)
    Y_hat = cross_val_predict(ridge, X_std, Y_std, cv=5)

    return logisticCV(Y_hat - Y_std, y_strict, z=False)

def evaluate_set(w_set, ratings, S, threshold=3):
    """Split chunks into those whose mean word rating is above / not above a threshold.

    Each chunk string is split on single spaces, and every word is scored
    with ``rate_word``. The chunk counts as "concrete" when its mean score is
    strictly greater than ``threshold``; otherwise it counts as "abstract".

    Parameters
    ----------
    w_set : sequence of str
        Chunks of space-separated words.
    ratings, S
        Passed through unchanged to ``rate_word``.
    threshold : float, default 3
        Mean-rating cut-off separating the two groups.

    Returns
    -------
    (float, float)
        Fraction of chunks above the threshold and fraction not above it.
        Both are ``0.0`` when ``w_set`` is empty.
    """
    n = len(w_set)
    if n == 0:
        return 0.0, 0.0

    n_conc = 0
    for chunk in w_set:
        chunk_words = chunk.split(" ")
        # split(" ") always yields at least one element (possibly ""), so the
        # mean never divides by zero.
        r_chunk = sum(rate_word(w, ratings, S) for w in chunk_words) / len(chunk_words)
        if r_chunk > threshold:
            n_conc += 1

    n_abs = n - n_conc
    return n_conc / n, n_abs / n



############################# MODELS ###############################
def permutation_test(X, y, n_perm=100):
    """One-sided permutation test for the held-out accuracy of ``logisticCV``.

    The observed accuracy is compared with ``n_perm`` accuracies obtained
    after shuffling ``y``. The shuffling uses the global ``np.random`` state,
    so seed it beforehand for reproducibility.

    Returns
    -------
    true_score : float
        Accuracy on the unshuffled labels.
    scores : list of float
        Accuracy for each permutation, in order.
    p : float
        ``(count of permuted scores >= true_score + 1) / (n_perm + 1)``.
    """
    true_score = logisticCV(X, y)["A"]

    null_scores = []
    for _ in tqdm(range(n_perm)):
        null_scores.append(logisticCV(X, np.random.permutation(y))["A"])

    # Permutations at least as good as the real labels count as "extreme".
    n_extreme = sum(1 for s in null_scores if s >= true_score)
    p = (n_extreme + 1) / (n_perm + 1)
    return true_score, null_scores, p


def residual_permutations(X, Y, y_strict, n_perm=100):
    """Permutation test for the accuracy returned by ``residual_predict``.

    The null distribution is built by shuffling the rows of ``X`` with
    ``np.random.permutation``, which uses the global random state.
    Permutations scoring at or BELOW the observed accuracy are counted, so
    lower accuracy is treated as the extreme direction.

    Returns
    -------
    true_score : float
        Accuracy with the unshuffled ``X``.
    mean_null : float
        Mean accuracy over the ``n_perm`` permutations.
    p : float
        ``(count of permuted scores <= true_score + 1) / (n_perm + 1)``.
    """
    true_score = residual_predict(X, Y, y_strict)["A"]

    null_scores = []
    n_at_or_below = 0
    for iteration in range(1, n_perm + 1):
        print(f"Iteration {iteration}...")
        shuffled_X = np.random.permutation(X)
        null_score = residual_predict(shuffled_X, Y, y_strict)["A"]
        null_scores.append(null_score)
        if null_score <= true_score:
            n_at_or_below += 1

    p = (n_at_or_below + 1) / (n_perm + 1)
    return true_score, np.mean(null_scores), p
    

def logisticCV(X, y, val_result=False, get_misclassified=False, words=None, z=True, get_coeffs=False, return_set=True):
    """Fit a grid-searched logistic regression and evaluate it on a held-out split.

    The data are split 75/25 (``shuffle=True``, ``random_state=45``). On the
    training part, ``GridSearchCV`` picks ``C`` from {0.001, ..., 1000} by
    5-fold accuracy.

    Parameters
    ----------
    X : array-like, (n_samples, n_features)
        Features.
    y : numpy array, (n_samples,) or (n_samples, 1)
        Class labels. Precision/recall/F1 use sklearn's default binary
        averaging, so presumably binary -- confirm.
    val_result : bool
        If True, return the best mean cross-validated training accuracy
        (``best_score_``) instead of test-set metrics.
    get_misclassified : bool
        If True, return information about misclassified test samples.
        This requires ``words``.
    words : array-like, optional
        Per-sample labels aligned with ``y``, e.g. chunk transcripts.
    z : bool
        z-score ``X`` before splitting. NOTE(review): this uses statistics
        of the full data, test split included.
    get_coeffs : bool
        Also return ``best_estimator_.coef_`` alongside the metrics.
    return_set : bool
        With ``get_misclassified``, return the misclassified words as a set
        rather than an array.

    Returns
    -------
    float
        If ``val_result``.
    (set or ndarray, ndarray of bool, ndarray)
        If ``get_misclassified``: misclassified words, the misclassification
        mask over the test split, and the test predictions.
    dict or (dict, ndarray)
        Otherwise, ``{"A", "P", "R", "F1"}`` test metrics, optionally with
        the coefficients.

    Raises
    ------
    ValueError
        If ``get_misclassified`` is True but ``words`` is None.
    """
    if get_misclassified and words is None:
        raise ValueError("get_misclassified=True requires `words` aligned with `y`")

    X1 = zscore(X) if z else X
    y1 = y
    if len(y1.shape) == 2 and y1.shape[1] == 1:
        y1 = np.ravel(y1)

    if not get_misclassified:
        # Placeholder so one train_test_split call serves both modes; its
        # split is only read when get_misclassified is True.
        words = np.zeros(y.shape)

    model = GridSearchCV(
        LogisticRegression(max_iter=10000),
        param_grid={"C": [0.001, 0.01, 0.1, 1, 10, 100, 1000]},
        cv=5,
        scoring="accuracy",
        n_jobs=-1,
    )
    X_train, X_test, y_train, y_test, _, words_test = train_test_split(
        X1, y1, words, test_size=0.25, shuffle=True, random_state=45)

    model.fit(X_train, y_train)

    if val_result:
        return model.best_score_

    y_pred = model.predict(X_test)

    if get_misclassified:
        misclassified_idx = (y_pred != y_test)
        # asarray so boolean masking also works when `words` was a list.
        misclassified_trs = np.asarray(words_test)[misclassified_idx]
        if return_set:
            return set(misclassified_trs), misclassified_idx, y_pred
        return misclassified_trs, misclassified_idx, y_pred

    results = {
        "A": accuracy_score(y_test, y_pred),
        "P": precision_score(y_test, y_pred),
        "R": recall_score(y_test, y_pred),
        "F1": f1_score(y_test, y_pred),
    }

    if get_coeffs:
        return results, model.best_estimator_.coef_
    return results


def ridgeCV(X, y, return_predictions=False):
    """Fit ``RidgeCV`` on z-scored data and evaluate it on a held-out split.

    ``X`` and ``y`` are z-scored, and a (n, 1) target is flattened. The data
    are split 75/25 with ``random_state=800``. The ridge penalty is chosen by
    5-fold CV over alphas {1e3, 5e3, 1e4, 5e4, 1e5}.

    Returns
    -------
    (y_test, y_pred) if ``return_predictions``; otherwise prints the Pearson
    and Spearman correlations between prediction and truth and returns None.
    """
    X_std = zscore(X)
    y_std = zscore(y)
    if y_std.ndim == 2 and y_std.shape[1] == 1:
        y_std = np.ravel(y_std)

    X_train, X_test, y_train, y_test = train_test_split(
        X_std, y_std, test_size=0.25, shuffle=True, random_state=800)

    ridge = RidgeCV(alphas=[1000, 5000, 10000, 50000, 100000], cv=5)
    ridge.fit(X_train, y_train)
    y_hat = ridge.predict(X_test)

    if return_predictions:
        return y_test, y_hat

    pearson = pearsonr(y_hat, y_test)[0]
    spearman = spearmanr(y_hat, y_test)[0]
    print("Pearson:", pearson,
          "Spearman:", spearman, '\n')
        
        
############################# VOXEL SELECTION ###############################
def explainable_variance(Y, do_zscore=True, bias_correction=True):
    """ Computes the explainable variance across repetition of voxel responses.
    Explainable variance is the amount of variance in a voxel's response that can be explained
    by the mean response across several repetitions. Repetitions are recorded while the voxel
    is exposed to the same stimulus several times.

    Parameters
    ----------
    Y : np.ndarray (nrepeats, nTRs, nvoxels)
        Repeated time course for each voxel. Each voxel and repeat is nTRs long.
        Repeats should be zscored across time samples.

    do_zscore : bool
        z-score the data across time. Only set to False
        if Y across time is already z-scored. Default is True.
    bias_correction : bool
        Bias correction for the number of repetitions

    Returns
    -------
    ev : np.ndarray (nvoxels,)
        Explainable variance per voxel

    Raises
    ------
    ValueError
        If ``bias_correction`` is True and fewer than two repeats are given,
        because the correction divides by ``nrepeats - 1``.

    References
    ----------
    Schoppe et al. 2016, Hsu et al. 2004

    Compare to https://github.com/gallantlab/tikreg/blob/master/tikreg/utils.py

    """

    if do_zscore:
        Y = zscore(Y, axis=1)

    n_repeats = Y.shape[0]
    res = Y - Y.mean(axis=0)  # deviation of each repeat from the mean across reps
    res_var = np.mean(res.var(axis=1), axis=0)  # variance over time, averaged over reps
    ev = 1 - res_var

    if bias_correction:
        if n_repeats < 2:
            raise ValueError(
                "bias_correction requires at least 2 repeats; got %d" % n_repeats)
        ev = ev - ((1 - ev) / float(n_repeats - 1))

    return ev

        
        