from random import shuffle, seed
from collections import defaultdict
from itertools import product, combinations
import pandas as pd
import umap
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from tqdm.auto import tqdm
import utils


def get_base_df(sentences):
    """Build one row of slot-match predictors per unordered sentence pair.

    Pairs are drawn with ``itertools.combinations(sentences, 2)``, so each
    unordered pair appears exactly once, in index order. For each of the
    three verb slots and three noun slots, the row holds 1 if both
    sentences use the same word in that slot and 0 otherwise.

    Args:
        sentences: sequence of dicts, each with ``'verbs'`` and ``'nouns'``
            lists holding (at least) three items.

    Returns:
        A DataFrame with columns ``V1Same, V2Same, V3Same, N1Same, N2Same,
        N3Same`` (in that order) and ``n * (n - 1) // 2`` rows for
        ``n = len(sentences)``. The columns are present even when there are
        fewer than two sentences and therefore no rows.
    """
    # (column-name prefix, sentence key) in output column order.
    slot_groups = (('V', 'verbs'), ('N', 'nouns'))
    n_slots = 3

    # Pre-create every column so an input with no pairs still yields the
    # expected schema instead of a column-less frame.
    data = {
        f'{prefix}{slot + 1}Same': []
        for prefix, _ in slot_groups
        for slot in range(n_slots)
    }

    # Set-overlap predictors (size of the intersection of the two
    # sentences' 'verbs', 'nouns' or 'so_pairs' sets, formerly named
    # VerbOverlap / NounOverlap / VerbNounPairOverlap) are not computed here.
    for s1, s2 in combinations(sentences, 2):
        for prefix, key in slot_groups:
            for slot in range(n_slots):
                data[f'{prefix}{slot + 1}Same'].append(
                    int(s1[key][slot] == s2[key][slot])
                )
    return pd.DataFrame.from_dict(data)


# Seed Python's global RNG so that the `shuffle` calls used to build the
# stimuli produce the same word orders on every run.
seed(a=42)
# NOTE(review): `reducer` is not referenced anywhere else in this file as
# seen here, and it is built without a random_state, so it is presumably
# not controlled by the `seed` call above — confirm whether it is needed.
reducer = umap.UMAP(metric='precomputed')

# Vocabulary for the stimuli. An alternative set, not used in this run:
# nouns 'cat', 'dog', 'rat', 'giraffe', 'wombat', 'hippo';
# verbs 'sees', 'chases', 'draws', 'meets', 'remembers', 'pokes'.
nouns_animals = [
    'mouse', 'fox', 'horse', 'kangaroo', 'bison', 'elephant'
]
verbs_transitive = [
    'hears', 'pursues', 'imagines', 'recognizes', 'touches', 'finds'
]

# One stimulus per (noun triple, verb triple) pair, so every combination
# of words is unique. Within each triple the order is shuffled (nouns
# first, then verbs, with Python's global RNG) to alleviate possible order
# biases.
sentences = []
for noun_triple, verb_triple in product(
    combinations(nouns_animals, 3),
    combinations(verbs_transitive, 3),
):
    nouns = list(noun_triple)
    shuffle(nouns)
    verbs = list(verb_triple)
    shuffle(verbs)

    first_verb, second_verb, third_verb = verbs
    first_noun, second_noun, third_noun = nouns
    sentences.append({
        'text': (
            f'The man {first_verb} the {first_noun}, '
            f'{second_verb} the {second_noun}, '
            f'and {third_verb} the {third_noun}.'
        ),
        'verbs': verbs,
        'nouns': nouns,
        # (verb, noun) pairs aligned slot by slot.
        'so_pairs': list(zip(verbs, nouns)),
    })

n_sentences = len(sentences)
n_pairs = n_sentences * (n_sentences - 1) // 2
print(n_sentences, 'sentences,', n_pairs, 'sentence pairs')

# Encode every sentence with each model. For BERT-large only the second
# value returned by compute_embeddings_from_tokens is kept.
sentence_texts = [sentence['text'] for sentence in sentences]
embeddings_mpnet = utils.compute_embeddings(sentence_texts, utils.mpnet)
embeddings_distilro = utils.compute_embeddings(
    sentence_texts, utils.distilro
)
_, embeddings_bert = utils.compute_embeddings_from_tokens(
    sentence_texts, utils.bert_large
)

# Within-model cosine similarities between the sentence embeddings
# (presumably an n_sentences x n_sentences matrix — confirm in utils).
cosine_mpnet = utils.self_cosine_sim(embeddings_mpnet)
cosine_distilro = utils.self_cosine_sim(embeddings_distilro)
cosine_bert = utils.self_cosine_sim(embeddings_bert)

# Fit a model to predict similarities from sentence properties
print('Preparing data frames...')

# Predictors: one row per unordered sentence pair.
base_df = get_base_df(sentences)
base_df.to_csv('../csv/coordinated_vp_predictor_df.csv', index=False)

# Responses: similarities extracted from each model's cosine matrix by
# utils.get_similarities (presumably one value per sentence pair, aligned
# with base_df's rows — confirm in utils). Computed in mpnet, distilro,
# bert order.
mpnet_sim, distilro_sim, bert_sim = (
    utils.get_similarities(sentences, cosine)
    for cosine in (cosine_mpnet, cosine_distilro, cosine_bert)
)
similarity_df = pd.DataFrame({
    'mpnet': mpnet_sim,
    'distilro': distilro_sim,
    'bert': bert_sim,
})
similarity_df.to_csv('../csv/coordinated_vp_sim_df.csv', index=False)

print('Fitting the models...')
# Residualising the predictors is awkward in Python, so additional
# analysis was done in R.
lr = LinearRegression()
# The same estimator object is handed to each call; reports are
# separated by a blank line.
for model_index, similarities in enumerate(
    (mpnet_sim, distilro_sim, bert_sim)
):
    if model_index > 0:
        print()
    utils.fit_model_and_print_X_y(lr, base_df, similarities)
