from itertools import combinations
from math import ceil, log
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm

# Module-level scaler shared across calls: get_similarities() calls
# ss.fit_transform on it, so it keeps the statistics of the last data it
# was fit on.
ss = StandardScaler()


# Hugging Face model identifiers; presumably passed as ``model_name`` to
# compute_embeddings / compute_embeddings_from_tokens (verify against callers).
# Note compute_embeddings_from_tokens enables add_prefix_space for names
# containing 'roberta' (e.g. distilro).
mpnet = 'sentence-transformers/all-mpnet-base-v2'
distilro = 'sentence-transformers/all-distilroberta-v1'
bert_large = 'bert-large-uncased'


def compute_embeddings(sentences, model_name, batch_size=8, device='cuda'):
    '''
    Compute sentence embeddings using the API from
    sentence_transformers.

    Parameters
    ----------
    sentences : sequence of str
        Sentences to embed; must support slicing (e.g. a list).
    model_name : str
        Name or path passed to ``SentenceTransformer``.
    batch_size : int
        Number of sentences passed to each ``model.encode`` call.
    device : str
        Device the model is moved to. Defaults to ``'cuda'``, which matches
        the previously hard-coded behaviour.

    Returns
    -------
    numpy.ndarray
        One embedding row per sentence, in input order.
    '''
    model = SentenceTransformer(model_name)
    model.to(device)
    if len(sentences) == 0:
        # np.vstack([]) raises, so build the (0, dim) result directly.
        return np.empty((0, model.get_sentence_embedding_dimension()),
                        dtype=np.float32)
    embeddings_batches = [
        model.encode(sentences[i: i + batch_size])
        for i in tqdm(range(0, len(sentences), batch_size), desc=model_name)
    ]
    return np.vstack(embeddings_batches)


def _pooling_mask(word_ids, attention_mask, include_special_tokens):
    '''
    Build the boolean (batch, seq) mask of token positions to average over.

    Parameters
    ----------
    word_ids : list of list
        One list per sentence, as returned by ``BatchEncoding.word_ids``;
        ``None`` marks positions that are not part of an input word
        (special tokens and padding).
    attention_mask : torch.Tensor
        (batch, seq) tensor, non-zero for real tokens and zero for padding.
    include_special_tokens : bool
        If True, non-padding positions without a word id are kept as well.
    '''
    mask = torch.tensor(
        [[word_id is not None for word_id in row] for row in word_ids],
        dtype=torch.bool)
    if include_special_tokens:
        # Use the attention mask instead of comparing ids with 0: the pad id
        # is model-specific (RoBERTa pads with 1 and uses 0 for <s>).
        mask |= attention_mask.to(torch.bool)
    return mask


def _masked_mean(hidden, mask):
    '''
    Average ``hidden`` (batch, seq, dim) over the positions selected by the
    boolean ``mask`` (batch, seq), returning a (batch, dim) tensor.

    Raises ValueError when a row selects no positions.
    '''
    weights = mask.to(device=hidden.device, dtype=hidden.dtype).unsqueeze(-1)
    counts = weights.sum(dim=1)
    if bool((counts == 0).any()):
        raise ValueError(
            'No tokens selected for averaging in at least one sentence.')
    return (hidden * weights).sum(dim=1) / counts


def compute_embeddings_from_tokens(sentences,
                                   model_name,
                                   include_special_tokens=True,
                                   batch_size=8,
                                   device='cuda'):
    '''
    Compute sentence embeddings by averaging token embeddings
    from the final layer of a BERT-like encoder.

    Each sentence is split on whitespace and tokenised as pre-split words.
    A prefix space is added when 'roberta' appears in ``model_name``.

    Parameters
    ----------
    sentences : sequence of str
    model_name : str
        Name or path passed to ``AutoTokenizer`` / ``AutoModel``.
    include_special_tokens : bool
        If True, special tokens are averaged together with the word tokens.
        Padding is always excluded.
    batch_size : int
        Number of sentences per forward pass.
    device : str
        Device for the model and its inputs (``'cuda'`` by default, as before).

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The first-position hidden state (CLS) and the averaged token
        embedding, each with one row per sentence.
    '''
    tokeniser = AutoTokenizer.from_pretrained(
        model_name, add_prefix_space='roberta' in model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    model.to(device)
    if len(sentences) == 0:
        # torch.vstack on an empty list raises; return empty matrices instead.
        empty = np.empty((0, model.config.hidden_size), dtype=np.float32)
        return empty, empty.copy()
    cls_embeddings = []
    averaged_embeddings = []
    for i in tqdm(range(0, len(sentences), batch_size), desc=model_name):
        batch = [el.split() for el in sentences[i: i + batch_size]]
        tokenisation = tokeniser(batch, truncation=True, padding=True,
                                 return_tensors='pt', is_split_into_words=True)
        with torch.no_grad():
            hidden = model(
                **{k: v.to(device) for k, v in tokenisation.items()}
            ).last_hidden_state
        cls_embeddings.append(hidden[:, 0])
        attention_mask = tokenisation.get('attention_mask')
        if attention_mask is None:
            # No mask returned: identify padding via the tokeniser's pad id.
            attention_mask = tokenisation['input_ids'] != tokeniser.pad_token_id
        word_ids = [tokenisation.word_ids(batch_index=j)
                    for j in range(len(batch))]
        mask = _pooling_mask(word_ids, attention_mask, include_special_tokens)
        averaged_embeddings.append(_masked_mean(hidden, mask))
    return (
        torch.vstack(cls_embeddings).cpu().numpy(),
        torch.vstack(averaged_embeddings).cpu().numpy()
    )


def self_cosine_sim(df):
    '''
    Return the cosine similarity between every pair of rows of ``df``.

    Entry (i, j) of the square result compares row i with row j. Passing no
    second argument makes sklearn compare ``df`` with itself.
    '''
    return cosine_similarity(df)


def _permutation_p_values(lr, X, y, original_coefs, n_perm):
    '''
    Permutation-test p-value for each coefficient.

    The target is shuffled ``n_perm`` times with ``np.random.shuffle``, and a
    copy of ``lr`` is refit on each permutation. A coefficient's p-value is
    the fraction of permuted fits whose absolute coefficient is strictly
    greater than the absolute observed coefficient.
    '''
    from copy import deepcopy

    # Refit a copy so the caller's ``lr`` keeps the fit on the real target.
    # Previously ``lr`` was left fitted to the last shuffled target.
    perm_lr = deepcopy(lr)
    perm_coefs = np.zeros((n_perm, len(original_coefs)))
    y_perm = np.copy(y)
    for perm in tqdm(range(n_perm), desc='Computing p-values', leave=False):
        np.random.shuffle(y_perm)
        perm_coefs[perm, :] = perm_lr.fit(X, y_perm).coef_
    n_bigger = (np.abs(perm_coefs) > np.abs(original_coefs)).sum(axis=0)
    return n_bigger / n_perm


def _fit_and_report(lr, X, y, n_perm, print_r_squared):
    '''
    Fit ``lr`` on (X, y) and print the intercept and one line per column of X.

    Each column line shows the coefficient and, when ``n_perm > 0``, its
    permutation p-value. R-squared is printed last if ``print_r_squared``.
    '''
    fit = lr.fit(X, y)
    original_coefs = fit.coef_
    col_width = max(map(len, ['Intercept'] + list(X.columns))) + 2
    print('Intercept'.ljust(col_width), round(fit.intercept_, 2))
    if n_perm > 0:
        p_values = _permutation_p_values(lr, X, y, original_coefs, n_perm)
        for colname, coef, p in zip(X.columns, original_coefs, p_values):
            print(colname.ljust(col_width), round(coef, 2), round(p, 2))
    else:
        for colname, coef in zip(X.columns, original_coefs):
            print(colname.ljust(col_width), round(coef, 2))
    if print_r_squared:
        # Valid even after permutations, because ``lr`` was never refit.
        print('R-squared', round(lr.score(X, y), 2))


def fit_model_and_print(lr, data, n_skipped_columns, n_perm=0):
    '''
    Fit ``lr`` to predict ``data.Similarity`` from the columns of ``data``
    after the first ``n_skipped_columns``, and print the coefficients.

    Parameters
    ----------
    lr : estimator
        sklearn-style regressor exposing ``fit``, ``coef_`` and ``intercept_``.
        It is left fitted on the real (unshuffled) target.
    data : pandas.DataFrame
        Must contain a ``Similarity`` column.
    n_skipped_columns : int
        Number of leading columns excluded from the predictors.
    n_perm : int
        If > 0, the number of target permutations used for p-values.
    '''
    X = data.iloc[:, n_skipped_columns:]
    y = data.Similarity
    _fit_and_report(lr, X, y, n_perm, print_r_squared=False)


def fit_model_and_print_X_y(lr, X, y, n_perm=0):
    '''
    Fit ``lr`` on predictors ``X`` (a DataFrame) and target ``y``, then print
    the coefficients and R-squared. If ``n_perm > 0``, also print
    permutation p-values.

    ``lr`` is left fitted on the real (unshuffled) target. R-squared is now
    reported with or without p-values.
    '''
    _fit_and_report(lr, X, y, n_perm, print_r_squared=True)


def list2dict(arr):
    '''
    Map every element of ``arr`` to its position.

    If an element occurs more than once, its last position is kept.
    '''
    index_of = {}
    for position, element in enumerate(arr):
        index_of[element] = position
    return index_of


def get_colours(cmap, idx_dict, key, sentences):
    '''
    Return one colour per sentence.

    The colour is ``cmap`` applied to the index that ``idx_dict`` assigns to
    the sentence's ``key`` value.
    '''
    return list(map(lambda sentence: cmap(idx_dict[sentence[key]]), sentences))


def get_colours_new(cmap, colouring_to_dict_mapping, key, sentences):
    '''
    Return one colour per sentence.

    The index dict is looked up in ``colouring_to_dict_mapping`` under the
    same ``key`` that is read from each sentence.
    '''
    # The index dict is fetched up front, so a missing ``key`` raises even
    # when there are no sentences.
    lookup = colouring_to_dict_mapping[key]
    colours = []
    for sentence in sentences:
        colours.append(cmap(lookup[sentence[key]]))
    return colours


def get_subj_obj_coding(subj1, obj1, subj2, obj2):
    '''
    Two-character code describing how the second (subject, object) pair
    relates to the first.

    The first character describes the second subject and the second character
    describes the second object. 'A' stands for the first subject, 'B' for the
    first object, and '0' for neither. Rules are applied in this order:
    'AB', 'A0', '0B', 'BA', 'B0', '0A', else '00'.
    '''
    same_subj = subj1 == subj2
    same_obj = obj1 == obj2
    if same_subj:
        return 'AB' if same_obj else 'A0'
    if same_obj:
        return '0B'
    subj_is_old_obj = subj2 == obj1
    obj_is_old_subj = obj2 == subj1
    if subj_is_old_obj:
        return 'BA' if obj_is_old_subj else 'B0'
    if obj_is_old_subj:
        return '0A'
    # '00' sorts first alphabetically, so a dummy encoder will use it as
    # the baseline category.
    return '00'


def normalise_cosine(cosine_similarity, eps=1e-12):
    '''
    Map a cosine similarity in [-1, 1] onto the whole real line.

    The value is rescaled to p = (cos + 1) / 2 in [0, 1] and passed through
    the logit function log(p / (1 - p)).

    Parameters
    ----------
    cosine_similarity : float or array-like
        Cosine similarity value(s). The parameter name shadows the
        module-level sklearn import, but only inside this function.
    eps : float
        p is clipped to [eps, 1 - eps]. As a result, cos = +/-1, or values
        pushed slightly outside [-1, 1] by rounding, give a large finite
        result instead of ZeroDivisionError / a math domain error.

    Returns
    -------
    float or numpy.ndarray
        A float for scalar input, or an array of the input's shape otherwise.
    '''
    # Map to [0, 1], keeping p strictly inside (0, 1).
    p = np.clip((np.asarray(cosine_similarity, dtype=float) + 1) / 2,
                eps, 1 - eps)
    # Map to the whole number line.
    result = np.log(p / (1 - p))
    return float(result) if np.ndim(result) == 0 else result


def get_similarities(sentences, similarities):
    '''
    Flatten the upper triangle of a pairwise similarity matrix and
    standardise it.

    Parameters
    ----------
    sentences : sequence
        Only its length ``n`` is used. The top-left ``n x n`` block of
        ``similarities`` is read.
    similarities : array-like
        Pairwise similarity matrix indexable as ``[i, j]``.

    Returns
    -------
    numpy.ndarray
        The ``n * (n - 1) // 2`` values ``similarities[i, j]`` with ``i < j``,
        in the order of ``itertools.combinations(range(n), 2)``. They are
        standardised with the module-level ``ss`` scaler, which is refit on
        them. The result is empty when there are fewer than two sentences.
    '''
    n_sentences = len(sentences)
    if n_sentences < 2:
        # No pairs to flatten; StandardScaler would reject an empty input.
        return np.empty(0)
    # triu_indices(k=1) enumerates (i, j) with i < j row by row. This is the
    # same order as combinations(range(n), 2), but needs no Python loop.
    rows, cols = np.triu_indices(n_sentences, k=1)
    flat = np.asarray(similarities)[rows, cols]
    # Perform standard scaling and return
    return ss.fit_transform(flat.reshape(-1, 1)).flatten()
