import os
import argparse
import utils
import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
import fasttext.util
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression

def post_process(embed, mean_center=False, normalize=True):
    """Scale an embedding to unit L2 norm, optionally mean-centering it.

    Args:
        embed: Array-like numeric embedding.
        mean_center: If True, subtract the scalar mean of all entries.
            Earlier experiments used True; the default is now False.
        normalize: If True, divide by the L2 norm both before and after the
            optional centering step.

    Returns:
        The processed embedding as a NumPy array. An embedding whose norm is
        exactly zero is left unscaled instead of being turned into NaNs.
    """
    def _unit(vec):
        norm = np.linalg.norm(vec)
        # Dividing by a zero norm would fill the whole vector with NaNs,
        # which then poisons every similarity computed from it.
        return vec if norm == 0 else vec / norm

    embed = np.asarray(embed)
    if normalize:
        embed = _unit(embed)
    if mean_center:
        embed = embed - np.mean(embed)
    if normalize:
        # Restores unit length after centering. Applied even when centering
        # is off so that outputs stay identical to previous runs.
        embed = _unit(embed)
    return embed

def get_PCA_(x, num_components, mean_center=True):
    """Project the rows of ``x`` onto their leading principal components.

    Args:
        x: 2-D array shaped (batch, feat).
        num_components: Number of principal components to keep.
        mean_center: If True, subtract the per-feature mean before fitting.

    Returns:
        The transformed array returned by ``PCA.fit_transform``.
    """
    centered = x - np.mean(x, axis=0)[None, :] if mean_center else x
    reducer = PCA(n_components=num_components, svd_solver='full')
    return reducer.fit_transform(centered)

def get_PCA(embeddings):
    """Replace every word vector with its PCA projection.

    All vectors are stacked into one (words, features) matrix, projected with
    ``get_PCA_`` onto ``min(n_words, n_features)`` components, and mapped back
    to their words.

    Args:
        embeddings: Dict mapping word -> 1-D vector; all vectors must share
            the same length.

    Returns:
        A new dict mapping each word to its projected vector, in the same
        key order as ``embeddings``. An empty input yields an empty dict.
    """
    if not embeddings:
        # Nothing to fit; stacking an empty list would give a 1-D array and
        # the shape lookup below would raise IndexError.
        return {}
    words = list(embeddings)
    vectors = np.asarray([embeddings[word] for word in words])
    num_components = min(vectors.shape[0], vectors.shape[1])
    projected = get_PCA_(vectors, num_components=num_components)
    return dict(zip(words, projected))

def get_embeddings(args, unique_words, lang, embed_type, do_PCA=True):
    """Load the stored vectors of one embedding method for the given words.

    Files are collected from ``<word_vectors>/<embed_type>/<lang>``; the word
    a file belongs to is its base name up to the first dot.

    Args:
        args: Parsed command-line namespace. Kept for interface
            compatibility; it is not read here.
        unique_words: Words whose embeddings should be loaded.
        lang: Language sub-directory to read from.
        embed_type: Embedding method name (also the directory name). When it
            contains 'maxsim' the loaded objects are returned untouched: no
            normalisation and no PCA.
        do_PCA: If True (and the method is not a 'maxsim' one), project the
            normalised vectors with ``get_PCA``.

    Returns:
        Dict mapping word -> embedding for every requested word that has a
        file in the directory.
    """
    config = get_config()
    embed_dir = os.path.join(config.directories.word_vectors, embed_type, lang)
    # Build the lookup once; membership on a list is a linear scan per file.
    wanted = set(unique_words)
    # Decide from the same embed_type that selected the directory, so the
    # on-disk format and the processing applied to it cannot disagree.
    keep_raw = 'maxsim' in embed_type

    embeddings = {}
    for file in collect_files(embed_dir):
        word = os.path.basename(file).split('.')[0]
        if word not in wanted:
            continue
        embed = load(file)
        embeddings[word] = embed if keep_raw else post_process(embed)
    if do_PCA and not keep_raw:
        embeddings = get_PCA(embeddings)
    return embeddings

def get_concat_embeddings(args, unique_words, lang, embed_type, do_PCA=True):
    """Build embeddings by concatenating the vectors of several methods.

    ``embed_type`` lists the methods separated by "~". Each method's vectors
    are read from ``<word_vectors>/<method>/<lang>`` and normalised with
    ``post_process``; the per-method vectors of a word are concatenated in
    method order and the result is normalised again.

    Args:
        args: Parsed command-line namespace. Kept for interface
            compatibility; it is not read here.
        unique_words: Words to build embeddings for.
        lang: Language sub-directory to read from.
        embed_type: "~"-separated method names.
        do_PCA: If True, project the concatenated vectors with ``get_PCA``.

    Returns:
        Dict mapping word -> concatenated embedding. Words that lack a vector
        for any of the methods are skipped rather than raising KeyError.
    """
    config = get_config()
    wanted = set(unique_words)

    # Load the embeddings of each method ahead of time. dict.fromkeys drops
    # repeated method names while keeping their first-seen order.
    method_embeds = {}
    for method in dict.fromkeys(embed_type.split("~")):
        directory = os.path.join(config.directories.word_vectors, method, lang)
        per_word = {}
        for file in collect_files(directory):
            word = os.path.basename(file).split('.')[0]
            if word in wanted:
                per_word[word] = post_process(load(file))
        method_embeds[method] = per_word

    embeddings = {}
    for word in unique_words:
        parts = [per_word.get(word) for per_word in method_embeds.values()]
        if any(part is None for part in parts):
            # No complete concatenation is possible for this word.
            continue
        embeddings[word] = post_process(np.concatenate(parts))
    if do_PCA:
        embeddings = get_PCA(embeddings)
    return embeddings

def spearman_rank_correlation(LSIM_pairs, embeddings, rank_method):
    """Correlate embedding cosine similarities with the gold pair scores.

    Args:
        LSIM_pairs: Dict whose values are dicts with at least the keys
            'word1', 'word2' and 'total_score'. Every scored entry is updated
            in place with an 'embed_similarity' key.
        embeddings: Dict mapping word -> 1-D vector.
        rank_method: ``method`` argument passed to ``scipy.stats.rankdata``.

    Returns:
        Tuple ``(spearman, completed_pairs, saved_similarity_scores)``:
        the Spearman correlation of the two rankings, the number of pairs
        that could be scored, and a dict keyed by "word1_word2" holding the
        'ground_truth' and 'our_score' values of each scored pair.
    """
    ground_truth_scores = []
    our_scores = []
    no_embedding_count = 0
    saved_similarity_scores = {}
    for info in LSIM_pairs.values():
        try:
            word1 = info['word1']
            word2 = info['word2']
            truth = info['total_score']
            embed_similarity = 1 - cosine(embeddings[word1], embeddings[word2])
        except (KeyError, ValueError, TypeError):
            # Missing word/embedding or vectors that cannot be compared.
            # Anything else (e.g. KeyboardInterrupt) must propagate.
            no_embedding_count += 1
            continue
        info['embed_similarity'] = embed_similarity
        ground_truth_scores.append(truth)
        our_scores.append(embed_similarity)
        saved_similarity_scores[f"{word1}_{word2}"] = {"ground_truth": truth,
                                                       "our_score": embed_similarity}
    total_pairs = len(LSIM_pairs)
    completed_pairs = total_pairs - no_embedding_count
    print(f"{completed_pairs} / {total_pairs} pairs had embeddings for both words.")
    ground_truth_ranks = scipy.stats.rankdata(np.asarray(ground_truth_scores), method=rank_method)
    our_ranks = scipy.stats.rankdata(np.asarray(our_scores), method=rank_method)
    spearman = scipy.stats.spearmanr(ground_truth_ranks, our_ranks).correlation
    return spearman, completed_pairs, saved_similarity_scores

def _pairwise_similarities(senses1, senses2):
    """Return the non-NaN cosine similarities of every vector pair across two collections."""
    similarities = []
    for emb1 in senses1:
        for emb2 in senses2:
            try:
                similarity = 1 - cosine(emb1, emb2)
            except (ValueError, TypeError, ArithmeticError):
                # Malformed or degenerate vector pair: ignore it and keep
                # comparing the remaining combinations.
                continue
            if not np.isnan(similarity):
                similarities.append(similarity)
    return similarities


def spearman_rank_correlation_most_similar(LSIM_pairs, embeddings, rank_method):
    """Correlate gold pair scores with the best similarity between vector sets.

    Each word maps to an iterable of vectors; the similarity of a word pair
    is the largest cosine similarity over all vector combinations.

    Args:
        LSIM_pairs: Dict whose values are dicts with at least the keys
            'word1', 'word2' and 'total_score'. Every scored entry is updated
            in place with an 'embed_similarity' key.
        embeddings: Dict mapping word -> iterable of 1-D vectors.
        rank_method: ``method`` argument passed to ``scipy.stats.rankdata``.

    Returns:
        Tuple ``(spearman, completed_pairs, saved_similarity_scores)``.
        ``completed_pairs`` counts only the pairs that actually received a
        score, so pairs without any usable vector combination are excluded.
    """
    ground_truth_scores = []
    our_scores = []
    saved_similarity_scores = {}
    # Sentinel; it is printed unchanged if no similarity is ever computed.
    min_temp_embed_similarity = 10000
    for info in tqdm(LSIM_pairs.values()):
        try:
            word1 = info['word1']
            word2 = info['word2']
            truth = info['total_score']
            similarities = _pairwise_similarities(embeddings[word1], embeddings[word2])
        except (KeyError, TypeError):
            # Missing word/embedding, or an embedding that is not iterable.
            continue
        if not similarities:
            # No valid combination: skip the pair instead of recording a
            # placeholder value as if it were a real similarity.
            continue
        embed_similarity = max(similarities)
        min_temp_embed_similarity = min(min_temp_embed_similarity, min(similarities))
        info['embed_similarity'] = embed_similarity
        ground_truth_scores.append(truth)
        our_scores.append(embed_similarity)
        saved_similarity_scores[f"{word1}_{word2}"] = {"ground_truth": truth,
                                                       "our_score": embed_similarity}
    completed_pairs = len(our_scores)
    ground_truth_ranks = scipy.stats.rankdata(np.asarray(ground_truth_scores), method=rank_method)
    our_ranks = scipy.stats.rankdata(np.asarray(our_scores), method=rank_method)
    spearman = scipy.stats.spearmanr(ground_truth_ranks, our_ranks).correlation
    print("Minimum similarity between any vectors was " + str(min_temp_embed_similarity))

    return spearman, completed_pairs, saved_similarity_scores

def check_overlap_words(args, unique_words, lang):
    """Keep only the words that have an embedding file for every listed method.

    Args:
        args: Namespace whose ``method_word_intersections`` attribute is a
            "~"-separated list of embedding method names.
        unique_words: Candidate words.
        lang: Language sub-directory inspected for each method.

    Returns:
        Sorted list of the words present in ``unique_words`` and in every
        method's directory. Sorting makes the order, and therefore everything
        computed from it downstream, reproducible across runs.
    """
    config = get_config()
    common = set(unique_words)
    for method in args.method_word_intersections.split("~"):
        method_dir = os.path.join(config.directories.word_vectors, method, lang)
        # A file's word is its base name up to the first dot.
        common &= {os.path.basename(path).split(".")[0]
                   for path in utils.collect_files(method_dir)}
    return sorted(common)

def filter_word_pairs_POS(word_pairs, args):
    """Select the word pairs whose 'POS_tag' equals ``args.POS_type``.

    Args:
        word_pairs: Dict whose values are dicts with the keys 'POS_tag',
            'word1' and 'word2'.
        args: Namespace providing the ``POS_type`` attribute.

    Returns:
        Tuple ``(kept_pairs, words)``: the matching entries under their
        original keys, and a flat list with word1 then word2 of each kept
        pair in iteration order (repeated words are not removed).
    """
    kept_pairs = {
        pair_id: entry
        for pair_id, entry in word_pairs.items()
        if entry['POS_tag'] == args.POS_type
    }
    words = []
    for entry in kept_pairs.values():
        words.extend((entry['word1'], entry['word2']))
    return kept_pairs, words

def main(args):
    """Score the selected embeddings against word-pair data for each language.

    For every language code in ``args.languages`` (underscore-separated) the
    word pairs are fetched with ``get_multisimlex``, optionally filtered by
    ``args.POS_type``, and restricted to words available for all methods in
    ``args.method_word_intersections``. Embeddings are then loaded and
    scored with a Spearman rank correlation.

    When ``args.results_save_path`` is non-empty, one object is written with
    ``dump`` inside the configured results directory: the per-pair similarity
    scores if the path contains "SIMSCORE", otherwise the per-language
    summary (correlation and number of scored pairs).
    """
    # Original author's note: the 'sum' option is used from now on, it's the best.
    results = {}
    all_sim_scores = {}
    for lang in args.languages.split("_"):
        word_pairs, unique_words = get_multisimlex(lang)
        if args.POS_type != '':
            # Restrict the evaluation to a single part of speech.
            word_pairs, unique_words = filter_word_pairs_POS(word_pairs, args)
        unique_words = check_overlap_words(args, unique_words, lang)

        # A "~" in the embedding type means several methods are concatenated.
        loader = get_concat_embeddings if "~" in args.embed_type else get_embeddings
        embeddings = loader(args, unique_words, lang=lang,
                            embed_type=args.embed_type, do_PCA=args.PCA)
        print(len(embeddings))

        # 'maxsim' embedding types are scored with the best-match variant.
        if 'maxsim' in args.embed_type:
            scorer = spearman_rank_correlation_most_similar
        else:
            scorer = spearman_rank_correlation
        score, num_completed_pairs, sim_scores = scorer(LSIM_pairs=word_pairs,
                                                        embeddings=embeddings,
                                                        rank_method=args.rank_method)
        print(lang + ": " + str(score))
        results[lang] = {"spearman_rank_corr": score, "num_pairs": num_completed_pairs}
        all_sim_scores[lang] = sim_scores

    if args.results_save_path == '':
        return
    config = get_config()
    os.makedirs(config.directories.results, exist_ok=True)
    save_path = os.path.join(config.directories.results, args.results_save_path)
    payload = all_sim_scores if "SIMSCORE" in args.results_save_path else results
    dump(payload, save_path)





if __name__ == "__main__":
    # (flag, type, default) of every command-line option, in registration order.
    cli_options = [
        ('--eval_word_type', str, 'LSIM'),
        # Underscore-separated language codes; main() evaluates each one.
        ('--languages', str, 'ar_en_es_fi_fr_he_pl_ru_zh'),
        # Embedding method; join several with "~" to concatenate them, and
        # include 'maxsim' in the name to use best-match scoring.
        ('--embed_type', str, 'colex_limited_langs200'),
        # Passed as `method` to scipy.stats.rankdata.
        ('--rank_method', str, 'average'),
        # File name inside the results directory; '' disables saving and a
        # name containing "SIMSCORE" stores per-pair scores instead.
        ('--results_save_path', str, 'LangInventory~colex_limited_langs200.pkl'),
        # "~"-separated methods; only words with an embedding file for all
        # of them are evaluated.
        ('--method_word_intersections', str, 'colex_all~fasttext~BERT~ARES~colex_all_maxsim~colex_mono'),
        ('--PCA', str2bool, True),
        ('--use_gpu', str2bool, True),
        # '' keeps every pair; otherwise only pairs whose POS_tag equals this
        # value (original note lists: nouns, adjectives, verbs, adverbs).
        ('--POS_type', str, ''),
    ]
    parser = argparse.ArgumentParser(description='Arguments to evaluate on LSIM task')
    for flag, arg_type, default in cli_options:
        parser.add_argument(flag, type=arg_type, default=default)
    args = parser.parse_args()
    main(args)


