# In this analysis, we check how the overlap in the participant
# set of a three-participant predicate compares to a situation
# where two participants are fixed and the third one is not from
# the original set.
# E.g., given a participant set 'cat' (0), 'dog' (1), 'rat' (2), we check
# if the sentence 'The cat introduced the dog to the rat' (012) is
# more similar to 'The dog introduced the rat to the cat' (120) than
# to 'The cat introduced the dog to the squirrel' (01X),
# 'The cat introduced the squirrel to the rat' (0X2), and
# 'The squirrel introduced the dog to the rat' (X12).
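# Two word-list configurations are used (the one not in use is kept
# commented out next to the word lists below). Each has three basic
# participants, three predicates, and three adverbs. Both share the same
# three extra participants.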

import sys
from collections import defaultdict
from itertools import product, combinations, permutations
import pandas as pd
import umap
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from tqdm.auto import tqdm
import utils


def get_3_participant_coding(A, B, C, A_, B_, C_):
    '''
    Encode how the participants of two three-participant sentences relate.

    The encoding of a sentence pair is decomposed into two components:
    (1) the number of positions that hold the same participant in both
        sentences (0 to 3);
    (2) the size of the overlap between the two participant sets, ignoring
        positions (0 to 3 in general). For the sentences built in this
        script, each sentence is a permutation of the three basic nouns with
        at most one slot replaced by an extra noun, so the overlap is
        always 1 to 3.

    E.g., given an lhs sentence with participants ABC and a participant X
    not among A, B, C, some possible rhs sentences are coded as
    ABC (3, 3), ACB (1, 3), XCA (0, 2), ABX (2, 2), etc.

    Args:
        A, B, C: participants of the lhs sentence, in position order.
        A_, B_, C_: participants of the rhs sentence, in position order.

    Returns:
        A tuple ``(same_pos, overlap)`` of ints.
    '''
    lhs = (A, B, C)
    rhs = (A_, B_, C_)
    # Slot-by-slot matches; summing booleans yields an int count.
    same_pos = sum(p == q for p, q in zip(lhs, rhs))
    # Set overlap ignores positions (repeated participants collapse).
    overlap = len(set(lhs) & set(rhs))
    return same_pos, overlap


def get_base_df(sentences):
    '''
    Build the predictor table with one row per unordered pair of sentences.

    Pairs are visited in ``itertools.combinations`` order over the list
    ((0, 1), (0, 2), ..., (1, 2), ...). Each row has, in this column order:
    'Text' ("<first> vs. <second>"), 'SameAdv' and 'SamePred' (1 when the
    two sentences share the adverb / predicate, else 0), and 'SamePosCount'
    and 'Overlap' (as returned by get_3_participant_coding).

    Args:
        sentences: sequence of dicts with keys 'text', 'A', 'B', 'C',
            'adv' and 'pred'.

    Returns:
        A pandas DataFrame with one row per sentence pair.
    '''
    n = len(sentences)
    pair_count = n * (n - 1) // 2
    columns = defaultdict(list)
    for first, second in tqdm(combinations(sentences, 2), total=pair_count,
                              desc='Preparing the model data frame'):
        same_pos, overlap = get_3_participant_coding(
            first['A'], first['B'], first['C'],
            second['A'], second['B'], second['C'],
        )
        # The order of these entries fixes the DataFrame's column order.
        for name, value in (
            ('Text', f'{first["text"]} vs. {second["text"]}'),
            ('SameAdv', int(first['adv'] == second['adv'])),
            ('SamePred', int(first['pred'] == second['pred'])),
            ('SamePosCount', same_pos),
            ('Overlap', overlap),
        ):
            columns[name].append(value)
    return pd.DataFrame.from_dict(columns)


# UMAP reducer fed with a precomputed distance matrix (the distances are
# derived from the cosine similarities in the plotting section).
reducer = umap.UMAP(metric='precomputed')

# Model identifiers passed to the utils embedding helpers
mpnet = 'sentence-transformers/all-mpnet-base-v2'
distilro = 'sentence-transformers/all-distilroberta-v1'
bert_large = 'bert-large-uncased'

# Vocabulary for the template 'The {A} {adverb} {verb} the {B} to the {C}.'
# Alternative word lists kept for reference (not used in this run):
#   nouns_basic = ['cat', 'dog', 'rat']
#   verbs = ['shows', 'sells', 'describes']
#   adverbs = ['happily', 'quickly', 'secretly']
nouns_basic = 'horse pig donkey'.split()
nouns_extra = 'elephant bison moose'.split()
verbs = 'gives demonstrates entrusts'.split()
adverbs = 'suddenly predictably openly'.split()

print('Preparing sentences...')
# Build every sentence 'The {A} {adverb} {verb} the {B} to the {C}.' where
# (A, B, C) is a permutation of nouns_basic with at most one slot replaced
# by an extra noun X.
all_texts = set()
sentences = []
n_settings = len(nouns_extra) * len(verbs) * len(adverbs)
settings = product(nouns_extra, verbs, adverbs)
for X, verb, adverb in tqdm(settings, total=n_settings):
    for nouns_perm in permutations(nouns_basic):
        # Slots 0-2 receive the extra noun X; None keeps the permutation.
        for slot in (0, 1, 2, None):
            A, B, C = (
                X if i == slot else noun
                for i, noun in enumerate(nouns_perm)
            )
            text = f'The {A} {adverb} {verb} the {B} to the {C}.'
            # The unmodified permutation is generated again for every X;
            # keep only its first occurrence.
            if text in all_texts:
                continue
            all_texts.add(text)
            sentences.append(dict(
                text=text, A=A, B=B, C=C, adv=adverb, pred=verb,
            ))
l = len(sentences)
print(l, 'sentences,', l * (l - 1) // 2, 'sentence pairs')

# Embed every sentence with each model, then build each model's pairwise
# cosine-similarity matrix.
text = [sentence['text'] for sentence in sentences]
embeddings_mpnet = utils.compute_embeddings(text, mpnet)
embeddings_distilro = utils.compute_embeddings(text, distilro)
# Only the second of the two returned values is used here.
_, embeddings_bert = utils.compute_embeddings_from_tokens(text, bert_large)

cosine_mpnet, cosine_distilro, cosine_bert = (
    utils.self_cosine_sim(emb)
    for emb in (embeddings_mpnet, embeddings_distilro, embeddings_bert)
)

if '--draw-plot' in sys.argv:
    print('Applying UMAP...')
    # UMAP runs on distances, so convert cosine similarities to distances.
    # The offset is slightly above 1, which presumably keeps self-distances
    # (similarity ~1) strictly positive — TODO confirm intent.
    eps_1 = 1.00001
    mpnet_umap = reducer.fit_transform(eps_1 - cosine_mpnet)
    distilroberta_umap = reducer.fit_transform(eps_1 - cosine_distilro)
    bert_umap = reducer.fit_transform(eps_1 - cosine_bert)

    # Label -> colour-map index lookups, one per sentence property we
    # colour by (the values are passed to cmap() for the legend below).
    noun_idx = utils.list2dict(nouns_basic + nouns_extra)
    adverb_idx = utils.list2dict(adverbs)
    verb_idx = utils.list2dict(verbs)
    colouring_index_mapping = {
        'A': noun_idx,
        'B': noun_idx,
        'C': noun_idx,
        'pred': verb_idx,
        'adv': adverb_idx,
    }
    colourings = list(colouring_index_mapping)

    # Draw the embedding plots: one row per colouring, one column per model.
    print('Drawing the plot...')
    plt.figure(figsize=(12, 24))
    size = 7
    cmap = plt.get_cmap('Set1')
    model_embeddings = [
        ('mpnet', mpnet_umap),
        ('distilroberta', distilroberta_umap),
        ('bert', bert_umap),
    ]
    n_rows, n_cols = len(colourings), len(model_embeddings)
    for row, colouring in enumerate(colourings):
        # Colours depend only on the colouring, not on the model, so they
        # are computed once per row (assumes get_colours_new has no side
        # effects — TODO confirm).
        c = utils.get_colours_new(
            cmap,
            colouring_index_mapping,
            colouring,
            sentences
        )
        for col, (model_name, embedding) in enumerate(model_embeddings):
            plt.subplot(n_rows, n_cols, n_cols * row + col + 1)
            plt.scatter(embedding[:, 0], embedding[:, 1], c=c, s=size)
            # Legend with colour boxes, only on the right-most panel of a row
            if col == n_cols - 1:
                patches = [
                    Patch(color=cmap(v), label=k)
                    for k, v in colouring_index_mapping[colouring].items()
                ]
                plt.legend(handles=patches, bbox_to_anchor=(1.6, 1.025),
                           title=colouring)

            plt.gca().set_aspect('equal', 'datalim')
            plt.title(model_name, fontsize=12)
    plt.tight_layout()
    plt.savefig('../img/participant_set.pdf')

if '--fit-models' in sys.argv:
    # Fit a model to predict similarities from sentence properties
    print('Preparing data frames...')
    base_df = get_base_df(sentences)
    # Save the raw pairwise coding (including the sentence texts) first.
    base_df.to_csv('participant_set_df.csv', index=False)
    del base_df['Text']
    # Shift the count predictors so that their minimum becomes 0 (a shift,
    # not a rescaling). Use [] indexing: attribute assignment only updates
    # columns that already exist and silently sets an attribute otherwise.
    for column in ['Overlap', 'SamePosCount']:
        base_df[column] = base_df[column] - base_df[column].min()

    # Pairwise similarities per model; presumably in the same pair order as
    # get_base_df (itertools.combinations) — TODO confirm in utils.
    mpnet_sim = utils.get_similarities(sentences, cosine_mpnet)
    distilro_sim = utils.get_similarities(sentences, cosine_distilro)
    bert_sim = utils.get_similarities(sentences, cosine_bert)
    model_similarities = [
        ('mpnet', mpnet_sim),
        ('distilro', distilro_sim),
        ('bert', bert_sim),
    ]

    # One CSV per model: the shifted predictors plus that model's similarity.
    for prefix, sims in model_similarities:
        base_df.assign(Similarity=sims).to_csv(
            f'{prefix}_participant_set_repl.csv', index=False)

    # Fit and report one linear regression per model, separated by a blank
    # line.
    lr = LinearRegression()
    for i, (_, sims) in enumerate(model_similarities):
        if i > 0:
            print()
        utils.fit_model_and_print_X_y(lr, base_df, sims)

    # Reference result kept for the record. The table was generated in R
    # (see below); the model it was fitted on is not recorded here. It
    # reports an 'OverlapResiduals' predictor, which this script does not
    # compute.
    # % latex table generated in R 4.1.2 by xtable 1.8-4 package
    # % Thu Aug 04 17:00:32 2022
    # \begin{table}[ht]
    # \centering
    # \begin{tabular}{rrrrr}
    #   \hline
    #  & Estimate & Std. Error & t value & Pr($>$$|$t$|$) \\
    #   \hline
    # (Intercept) & 0.9992 & 0.0014 & 723.27 & 0.0000 \\
    #   SameAdv & 0.5741 & 0.0018 & 319.37 & 0.0000 \\
    #   SamePred & 0.5162 & 0.0018 & 287.19 & 0.0000 \\
    #   SamePosCount & 0.1265 & 0.0012 & 107.10 & 0.0000 \\
    #   OverlapResiduals & 0.4850 & 0.0014 & 343.57 & 0.0000 \\
    #    \hline
    # \end{tabular}
    # \end{table}
