import sys
from math import log
from collections import defaultdict
from itertools import product, combinations
import pandas as pd
import umap
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import utils


def get_base_df(
    sentences,
    pronouns=frozenset({'me', 'you', 'him', 'her', 'it', 'us', 'them'}),
):
    """Build pairwise slot-overlap features for every unordered sentence pair.

    Pairs are enumerated in ``itertools.combinations`` order, i.e. ``(i, j)``
    with ``i < j``. For each pair one row is produced with:

    * ``SameSubj``, ``SameCopula``, ``SameAdjective``, ``SamePred`` -- 1 if the
      two sentences have the same value in the ``'subj'``, ``'copula'``,
      ``'adjective'`` or ``'pred'`` slot respectively, else 0;
    * ``ObjPair`` -- a categorical description of the ``'object'`` slot:
      ``'SamePron'`` (identical objects that are pronouns), ``'SameNoun'``
      (identical objects that are not pronouns) or ``'Different'`` (the
      intended regression baseline).

    Args:
        sentences: sequence of dicts, each with at least the keys ``'subj'``,
            ``'object'``, ``'copula'``, ``'adjective'`` and ``'pred'``.
        pronouns: object strings to treat as pronouns. The default covers the
            English object pronouns, so both ``it``/``them`` and
            ``him``/``me`` are recognised.

    Returns:
        A ``pandas.DataFrame`` with one row per sentence pair and the columns
        ``SameSubj, SameCopula, SameAdjective, SamePred, ObjPair`` in that
        order. With fewer than two sentences the frame is empty but still has
        those columns.
    """
    # The object slot is not given a plain Same* flag; it is described by the
    # three-way ObjPair category instead.
    same_slots = ('subj', 'copula', 'adjective', 'pred')
    same_cols = [f'Same{slot.capitalize()}' for slot in same_slots]
    columns = same_cols + ['ObjPair']
    data = {col: [] for col in columns}

    for s1, s2 in combinations(sentences, 2):
        for slot, col in zip(same_slots, same_cols):
            data[col].append(int(s1[slot] == s2[slot]))
        if s1['object'] != s2['object']:
            obj_pair = 'Different'  # The baseline
        elif s1['object'] in pronouns:
            obj_pair = 'SamePron'
        else:
            obj_pair = 'SameNoun'
        data['ObjPair'].append(obj_pair)

    return pd.DataFrame(data, columns=columns)


# Dimensionality reducer configured for a precomputed distance matrix; it is
# fitted in the --draw-plot branch below.
reducer = umap.UMAP(metric='precomputed')

# Model identifiers handed to the utils.compute_embeddings* helpers.
mpnet = 'sentence-transformers/all-mpnet-base-v2'
distilro = 'sentence-transformers/all-distilroberta-v1'
bert_large = 'bert-large-uncased'

# Slot fillers for the template "<Gerund> <object> <copula> a <adjective> <noun>."
# An alternative (commented-out) stimulus set was:
#   gerunds    = ['continuing', 'abandoning', 'starting', 'completing']
#   objects    = ['it', 'them', 'the project', 'the plan']
#   adjectives = ['big', 'real', 'negligible', 'insignificant']
#   subjects   = ['solution', 'mistake', 'failure', 'triumph']
gerunds = ['proposing', 'rejecting', 'praising', 'criticizing']
objects = ['him', 'me', 'the idea', 'the design']
copulas = ['is', 'was', 'will be', 'is going to be']
adjectives = ['huge', 'massive', 'small', 'unimportant']
# Despite the name, these fill the predicate-noun slot (stored under 'pred');
# the grammatical subject of each sentence is the gerund phrase ('subj').
subjects = ['decision', 'defeat', 'loss', 'improvement']

# One sentence per combination of slot fillers (full Cartesian product).
# NOTE(review): the article is always "a", so "unimportant" produces
# "a unimportant ..."; confirm this is an intentional control.
sentences = [
    {
        'text': f'{ger.capitalize()} {obj} {cop} a {adj} {pred_noun}.',
        'subj': ger,
        'object': obj,
        'copula': cop,
        'adjective': adj,
        'pred': pred_noun,
    }
    for ger, obj, cop, adj, pred_noun in product(
        gerunds, objects, copulas, adjectives, subjects)
]
l = len(sentences)
n_pairs = l * (l - 1) // 2
print(l, 'sentences,', n_pairs, 'sentence pairs')

# Sentence embeddings from each encoder (BERT via its token-level helper,
# whose first return value is unused here).
text = [sentence['text'] for sentence in sentences]
embeddings_mpnet = utils.compute_embeddings(text, mpnet)
embeddings_distilro = utils.compute_embeddings(text, distilro)
_, embeddings_bert = utils.compute_embeddings_from_tokens(text, bert_large)

# Pairwise cosine-similarity matrices, one per encoder.
cosine_mpnet, cosine_distilro, cosine_bert = (
    utils.self_cosine_sim(emb)
    for emb in (embeddings_mpnet, embeddings_distilro, embeddings_bert)
)

if '--draw-plot' in sys.argv:
    # Draw a (template slot x encoder) grid of 2-D UMAP projections. Each row
    # colours the points by the value a sentence has in one slot; the figure
    # is saved as a PDF.

    # The reducer uses metric='precomputed', so it needs distances rather than
    # similarities. The offset is presumably chosen slightly above 1 so that
    # distances stay non-negative when rounding pushes a cosine similarity
    # marginally above 1.
    eps_1 = 1.00001
    mpnet_umap = reducer.fit_transform(eps_1 - cosine_mpnet)
    distilroberta_umap = reducer.fit_transform(eps_1 - cosine_distilro)
    bert_umap = reducer.fit_transform(eps_1 - cosine_bert)

    # Value -> colour-index maps for each slot.
    gerund_idx = utils.list2dict(gerunds)
    pronoun_idx = utils.list2dict(objects)
    copula_idx = utils.list2dict(copulas)
    adjective_idx = utils.list2dict(adjectives)
    noun_idx = utils.list2dict(subjects)
    # One grid row per entry, keyed by the sentence-dict field it colours by.
    # Keeping a single mapping guarantees a row's colours and its legend
    # always come from the same index map.
    colour_indices = {
        'subj': gerund_idx,
        'object': pronoun_idx,
        'copula': copula_idx,
        'adjective': adjective_idx,
        'pred': noun_idx,
    }

    # Draw the embedding plots
    plt.figure(figsize=(12, 24))
    size = 7

    colourings = list(colour_indices)
    embeddings = [mpnet_umap, distilroberta_umap, bert_umap]
    embedding_titles = ['mpnet', 'distilroberta', 'bert']
    n_rows, n_cols = len(colourings), len(embeddings)
    cmap = plt.get_cmap('Set1')
    for row, colouring in enumerate(colourings):
        value_idx = colour_indices[colouring]
        # The colours depend only on the row, so they are computed once per
        # row rather than once per subplot (assumes utils.get_colours is a
        # pure function; confirm).
        c = utils.get_colours(cmap, value_idx, colouring, sentences)
        for col, (embedding, title) in enumerate(
                zip(embeddings, embedding_titles)):
            plt.subplot(n_rows, n_cols, n_cols * row + col + 1)
            plt.scatter(embedding[:, 0], embedding[:, 1], c=c, s=size)
            # The legend (one colour patch per slot value) is drawn only in
            # the right-most column, whatever the number of encoders.
            if col == n_cols - 1:
                patches = [Patch(color=cmap(v), label=k)
                           for k, v in value_idx.items()]
                plt.legend(handles=patches, bbox_to_anchor=(1.6, 1.025),
                           title=colouring)

            plt.gca().set_aspect('equal', 'datalim')
            plt.title(title, fontsize=12)
    plt.tight_layout()
    plt.savefig('../img/gerund_subj.pdf')

if '--fit-models' in sys.argv:
    # Regress each encoder's pairwise cosine similarity on the slot-overlap
    # features produced by get_base_df.
    print('Preparing data frames...')
    # ObjPair is the only string column, so it is the only one expanded into
    # dummies. With drop_first=True its first sorted level ('Different') is
    # dropped and acts as the baseline.
    base_df = pd.get_dummies(get_base_df(sentences), drop_first=True)
    mpnet_sim, distilro_sim, bert_sim = (
        utils.get_similarities(sentences, cosine)
        for cosine in (cosine_mpnet, cosine_distilro, cosine_bert)
    )

    print('Fitting the models...')
    # A single estimator instance is handed to the helper for every encoder.
    lr = LinearRegression()
    all_sims = (mpnet_sim, distilro_sim, bert_sim)
    for position, similarities in enumerate(all_sims):
        if position > 0:
            print()  # blank line between consecutive reports only
        utils.fit_model_and_print_X_y(lr, base_df, similarities)
