# In this analysis, we check which components of a simple
# intransitive sentence most influence its embedding representation.

import sys
from collections import defaultdict
from itertools import product, combinations
import pandas as pd
import umap
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from tqdm.auto import tqdm
import utils


def get_base_df(sentences, *, params=('det', 'adv', 'pred', 'punct', 'subj')):
    """Tabulate, for every sentence pair, which components the two share.

    One row is produced per unordered pair ``(i, j)`` with ``i < j``, in the
    order yielded by ``itertools.combinations(range(len(sentences)), 2)``.
    For each name in *params* the column ``Same<Name>`` (e.g. ``SameDet``)
    holds 1 when both sentences have the same value for that key and 0
    otherwise.

    Args:
        sentences: Sequence of dicts, each containing every key in *params*.
        params: Sentence components to compare (keyword-only). Defaults to
            the five components produced by this script's sentence template.

    Returns:
        A ``pandas.DataFrame`` with one ``Same<Name>`` column per entry of
        *params*, in that order, and ``len(sentences) * (len(sentences) - 1)
        // 2`` rows. With fewer than two sentences the frame has the same
        columns but no rows.
    """
    n = len(sentences)
    # Column names are loop-invariant, so build them once.
    columns = {param: f'Same{param.capitalize()}' for param in params}
    # Pre-create every column so an input with fewer than two sentences
    # still yields the expected (empty) columns rather than none at all.
    data = {name: [] for name in columns.values()}
    for i, j in tqdm(combinations(range(n), 2), total=n * (n - 1) // 2):
        s1, s2 = sentences[i], sentences[j]
        for param, name in columns.items():
            data[name].append(int(s1[param] == s2[param]))
    return pd.DataFrame.from_dict(data)


# Shared UMAP reducer. With metric='precomputed', fit_transform is given a
# distance matrix rather than raw feature vectors.
reducer = umap.UMAP(metric='precomputed')

# Model identifiers handed to utils.compute_embeddings* below.
mpnet, distilro, bert_large = (
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-distilroberta-v1',
    'bert-large-uncased',
)

# Words for constructing sentences
articles = ['the', 'a']
# Subject nouns, grouped by broad semantic category.
nouns_animals = ['wolf', 'bear']
nouns_food = ['fruit', 'vegetable']
nouns_artifacts = ['building', 'car']
nouns_nature = ['lightning', 'wave']
subjects = nouns_animals + nouns_food + nouns_artifacts + nouns_nature
adverbs = ['suddenly', 'predictably']
verbs_intransitive = ['stabilizes', 'bursts']
verbs_labile = ['grows', 'shrinks']
# Single source of truth for the predicates: used both to build the
# sentences and, in the plotting section, to build the verb colour index.
verbs = verbs_intransitive + verbs_labile
puncts = ['.', '!']

print('Preparing sentences...')
# One sentence per combination of word choices, e.g.
# 'The wolf suddenly stabilizes.'; the verb never takes an object.
sentences = [
    {
        'text': f'{det.capitalize()} {subj} {adverb} {verb}{punct}',
        'det': det,
        'subj': subj,
        'adv': adverb,
        'pred': verb,
        'punct': punct,
    }
    for det, subj, adverb, verb, punct in product(
        articles, subjects, adverbs, verbs, puncts
    )
]
n_sentences = len(sentences)
n_pairs = n_sentences * (n_sentences - 1) // 2
print(n_sentences, 'sentences,', n_pairs, 'sentence pairs')

# Raw sentence strings, in the same order as `sentences`.
text = [sent['text'] for sent in sentences]

# Sentence embeddings from the two sentence-transformers checkpoints,
# computed in the order mpnet, distilroberta.
embeddings_mpnet, embeddings_distilro = (
    utils.compute_embeddings(text, model_name)
    for model_name in (mpnet, distilro)
)
# compute_embeddings_from_tokens returns a pair; only its second element
# is kept as the BERT representation.
_, embeddings_bert = utils.compute_embeddings_from_tokens(text, bert_large)

# Cosine self-similarity of each model's embeddings; used below as the
# similarity matrices for UMAP and for the regressions.
cosine_mpnet, cosine_distilro, cosine_bert = (
    utils.self_cosine_sim(emb)
    for emb in (embeddings_mpnet, embeddings_distilro, embeddings_bert)
)

if '--draw-plot' in sys.argv:
    print('Applying UMAP...')
    # Similarities are turned into distances as eps_1 - cosine. The offset
    # sits slightly above 1, presumably so that distances stay positive even
    # for identical sentences or cosines that round just above 1.
    eps_1 = 1.00001
    # (column title, cosine-similarity matrix) for every model that is drawn.
    # Titles travel with their data, and only drawn models are projected, so
    # no UMAP fit is spent on a model that never appears in the figure.
    plotted_models = [('mpnet', cosine_mpnet), ('bert', cosine_bert)]
    # Convert similarities to distances for dimensionality reduction
    projections = [
        (title, reducer.fit_transform(eps_1 - cosine))
        for title, cosine in plotted_models
    ]

    # Sentence component -> mapping from each word to the value passed to
    # cmap for its colour. One figure row is drawn per entry, in this order.
    colouring_index_mapping = {
        'subj': utils.list2dict(subjects),
        'pred': utils.list2dict(verbs),
    }

    print('Drawing the plot...')
    plt.figure(figsize=(6, 8))
    size = 7
    cmap = plt.get_cmap('tab20')
    n_rows = len(colouring_index_mapping)
    n_cols = len(projections)
    for row, (colouring, index) in enumerate(colouring_index_mapping.items()):
        # Point colours depend on the component only, not on the model, so
        # compute them once per row instead of once per subplot.
        c = utils.get_colours_new(
            cmap,
            colouring_index_mapping,
            colouring,
            sentences
        )
        for col, (title, embedding) in enumerate(projections):
            plt.subplot(n_rows, n_cols, n_cols * row + col + 1)
            plt.scatter(embedding[:, 0], embedding[:, 1], c=c, s=size)
            # Draw the colour legend once per row, beside the right-most panel.
            if col == n_cols - 1:
                patches = [
                    Patch(color=cmap(v), label=k) for k, v in index.items()
                ]
                plt.legend(handles=patches, bbox_to_anchor=(1.6, 1.025),
                           title=colouring)
            plt.title(title, fontsize=10)
    plt.tight_layout()
    plt.savefig('../img/simple_intransitives.pdf')

if '--fit-models' in sys.argv:
    # Fit a model to predict similarities from sentence properties
    print('Preparing data frames...')
    # get_dummies only expands object/string/category columns. The Same*
    # columns from get_base_df are integer 0/1 flags (for non-empty input),
    # so this call currently leaves the frame unchanged; presumably it is
    # kept for categorical predictors.
    base_df = pd.get_dummies(get_base_df(sentences), drop_first=True)
    mpnet_sim = utils.get_similarities(sentences, cosine_mpnet)
    distilro_sim = utils.get_similarities(sentences, cosine_distilro)
    bert_sim = utils.get_similarities(sentences, cosine_bert)

    print('Fitting the models...')
    lr = LinearRegression()
    # Label each printout so the three models' results can be told apart.
    for k, (model_name, sim) in enumerate([
        ('mpnet', mpnet_sim),
        ('distilroberta', distilro_sim),
        ('bert', bert_sim),
    ]):
        if k:
            print()
        print(f'== {model_name} ==')
        utils.fit_model_and_print_X_y(lr, base_df, sim)
