import sys
from collections import defaultdict
from itertools import product, combinations
import pandas as pd
import umap
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from tqdm.auto import tqdm
import utils


def get_base_df(sentences):
    """Build the table of pairwise predictors for the similarity models.

    Every unordered pair of sentences, taken in ``itertools.combinations``
    order, contributes one row with the columns:

    * ``SameModifier``, ``SameAdv``, ``SamePred`` -- 1 if both sentences hold
      the same value under ``'modifier'``, ``'adv'`` and ``'pred'``
      respectively, otherwise 0;
    * ``SubjObj`` -- the value returned by ``utils.get_subj_obj_coding`` for
      the subjects and objects of the two sentences.

    Args:
        sentences: sequence of dicts with at least the keys ``'subj'``,
            ``'obj'``, ``'modifier'``, ``'adv'`` and ``'pred'``.

    Returns:
        A ``pandas.DataFrame`` with one row per sentence pair.
    """
    compared_features = ('modifier', 'adv', 'pred')
    n_sentences = len(sentences)
    n_pairs = n_sentences * (n_sentences - 1) // 2

    columns = defaultdict(list)
    sentence_pairs = tqdm(
        combinations(sentences, 2),
        total=n_pairs,
        desc='Preparing the model data frame',
    )
    for first, second in sentence_pairs:
        for feature in compared_features:
            is_shared = first[feature] == second[feature]
            columns[f'Same{feature.capitalize()}'].append(int(is_shared))
        subj_obj_code = utils.get_subj_obj_coding(
            first['subj'], first['obj'],
            second['subj'], second['obj'],
        )
        columns['SubjObj'].append(subj_obj_code)
    return pd.DataFrame.from_dict(columns)


# Single UMAP instance shared by all projections in the --draw-plot section.
# metric='precomputed' because fit_transform is handed a distance matrix
# derived from the cosine similarities, not the raw embedding vectors.
reducer = umap.UMAP(metric='precomputed')

# Model names
# mpnet and distilro are passed to utils.compute_embeddings; bert_large is
# passed to utils.compute_embeddings_from_tokens.
mpnet = 'sentence-transformers/all-mpnet-base-v2'
distilro = 'sentence-transformers/all-distilroberta-v1'
bert_large = 'bert-large-uncased'

# Vocabulary from which the template sentences are assembled
articles = ['the', 'a']
nouns = ['cat', 'dog', 'rat', 'giraffe', 'wombat', 'hippo']
np_modifiers = [
    'with big shiny eyes',
    'that my brother saw yesterday',
    'whose photo was in the papers',
    'worth a great deal of money',
]
adverbs = ['quickly', 'slowly']
verbs = ['sees', 'chases', 'draws', 'meets', 'remembers', 'pokes']
puncts = ['.', '!']

# Every sentence follows the frame
#     "The <subj> <modifier> <adverb> <verb> the <obj>."
# with distinct subject and object nouns.  `articles` and `puncts` are not
# used here: the determiners are fixed to "The"/"the" and the final
# punctuation to ".".
sentences = [
    {
        'text': f'The {noun1} {modifier} {adverb} {verb} the {noun2}.',
        'subj': noun1,
        'obj': noun2,
        'modifier': modifier,
        'adv': adverb,
        'pred': verb,
    }
    # A single flattened product visits the combinations in the same order
    # as looping over noun pairs first and the remaining slots second.
    for noun1, noun2, modifier, adverb, verb in product(
        nouns, nouns, np_modifiers, adverbs, verbs
    )
    if noun1 != noun2
]
l = len(sentences)
print(f'{l} sentences, {l * (l - 1) // 2} sentence pairs')

# Plain sentence strings, in the same order as `sentences`
text = [sentence['text'] for sentence in sentences]

# Embeddings from the two models handled by utils.compute_embeddings
embeddings_mpnet, embeddings_distilro = (
    utils.compute_embeddings(text, model_name)
    for model_name in (mpnet, distilro)
)
# compute_embeddings_from_tokens returns a pair; only its second item is used
_, embeddings_bert = utils.compute_embeddings_from_tokens(text, bert_large)

# One self-similarity matrix per model (presumably sentence-by-sentence
# cosine similarity, judging by the helper's name -- confirm in utils)
cosine_mpnet, cosine_distilro, cosine_bert = (
    utils.self_cosine_sim(vectors)
    for vectors in (embeddings_mpnet, embeddings_distilro, embeddings_bert)
)

if '--draw-plot' in sys.argv:
    # Project each model's similarity structure with UMAP and draw a grid of
    # scatter plots: one row per sentence property used for colouring, one
    # column per model.  The figure is written to ../img/np_modifier.pdf.
    print('Applying UMAP...')
    # Slightly above 1, so a cosine similarity of exactly 1 still maps to a
    # small positive distance instead of 0.
    eps_1 = 1.00001
    # Convert similarities to distances for dimensionality reduction
    mpnet_umap = reducer.fit_transform(eps_1 - cosine_mpnet)
    distilroberta_umap = reducer.fit_transform(eps_1 - cosine_distilro)
    bert_umap = reducer.fit_transform(eps_1 - cosine_bert)

    # Indices for colour maps
    noun_idx = utils.list2dict(nouns)
    modifier_idx = utils.list2dict(np_modifiers)
    adverb_idx = utils.list2dict(adverbs)
    verb_idx = utils.list2dict(verbs)

    # Draw the embedding plots
    print('Drawing the plot...')
    plt.figure(figsize=(12, 24))
    size = 7

    colourings = ['subj', 'obj', 'modifier', 'adv', 'pred']
    indices = [noun_idx, noun_idx, modifier_idx, adverb_idx, verb_idx]
    colouring_index_mapping = dict(zip(colourings, indices))
    embeddings = [mpnet_umap, distilroberta_umap, bert_umap]
    titles = ['mpnet', 'distilroberta', 'bert']
    n_rows, n_cols = len(colourings), len(embeddings)
    cmap = plt.get_cmap('Set1')
    for row, colouring in enumerate(colourings):
        # None of the arguments depends on the model, so the point colours
        # are computed once per row rather than once per subplot (assumes
        # utils.get_colours_new is a pure lookup).
        point_colours = utils.get_colours_new(
            cmap,
            colouring_index_mapping,
            colouring,
            sentences
        )
        for col, (title, embedding) in enumerate(zip(titles, embeddings)):
            plt.subplot(n_rows, n_cols, n_cols * row + col + 1)
            plt.scatter(
                embedding[:, 0],
                embedding[:, 1],
                c=point_colours,
                s=size
            )
            # One legend per row, attached to the right-most subplot and
            # pushed outside its axes by bbox_to_anchor.
            if col == n_cols - 1:
                patches = [
                    Patch(color=cmap(v), label=k)
                    for k, v in colouring_index_mapping[colouring].items()
                ]
                plt.legend(handles=patches, bbox_to_anchor=(
                    1.6, 1.025), title=colouring)

            plt.gca().set_aspect('equal', 'datalim')
            plt.title(title, fontsize=12)
    plt.tight_layout()
    plt.savefig('../img/np_modifier.pdf')

if '--fit-models' in sys.argv:
    # Regress each model's pairwise similarities on the pairwise sentence
    # properties produced by get_base_df.
    print('Preparing data frames...')
    # Dummy-code the categorical predictors, dropping the first level of each
    base_df = pd.get_dummies(get_base_df(sentences), drop_first=True)
    mpnet_sim, distilro_sim, bert_sim = (
        utils.get_similarities(sentences, cosine_matrix)
        for cosine_matrix in (cosine_mpnet, cosine_distilro, cosine_bert)
    )

    print('Fitting the models...')
    lr = LinearRegression()
    for position, target in enumerate((mpnet_sim, distilro_sim, bert_sim)):
        # Blank line between consecutive reports, none after the last one
        if position > 0:
            print()
        utils.fit_model_and_print_X_y(lr, base_df, target)
