import os
import argparse
import utils
import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
import fasttext.util
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression

def post_process(embed, mean_center=False, normalize=True):
    """Scale an embedding by its norm, optionally removing its mean first.

    The steps run in a fixed order: normalise, mean-centre, normalise again.
    With the defaults the array is therefore divided by ``np.linalg.norm``
    twice. ``mean_center`` used to default to True and was switched to False
    for the latest experiments.

    Args:
        embed: numeric array supporting scalar division and subtraction.
        mean_center: when True, subtract the mean over all entries.
        normalize: when True, divide by the norm before and after centring.

    Returns:
        The processed array.
    """
    def _divide_by_norm(vec):
        return vec / np.linalg.norm(vec)

    result = _divide_by_norm(embed) if normalize else embed
    if mean_center:
        result = result - np.mean(result)
    return _divide_by_norm(result) if normalize else result

def get_PCA_(x, num_components, mean_center=True):
    """Fit a PCA on `x` and return the transformed rows.

    Args:
        x: 2-D array laid out as (batch, feat).
        num_components: value passed to PCA as ``n_components``.
        mean_center: when True, subtract the per-feature mean (taken over
            axis 0) before fitting.

    Returns:
        The output of ``PCA.fit_transform`` using the 'full' SVD solver.
    """
    data = x
    if mean_center:
        feature_means = np.mean(data, axis=0)
        data = data - feature_means[None, :]
    reducer = PCA(n_components=num_components, svd_solver='full')
    return reducer.fit_transform(data)

def get_PCA(embeddings):
    """Replace every embedding with its PCA projection.

    All vectors are stacked into one matrix (one row per word), projected by
    `get_PCA_` onto min(number of words, vector length) components, and
    returned under their original keys.

    Args:
        embeddings: dict mapping word -> vector. The stacked vectors must
            form a 2-D array, i.e. every vector has the same length.

    Returns:
        A new dict with the same keys, in the same order, mapping to the
        projected vectors. An empty input yields an empty dict.
    """
    if not embeddings:
        # Nothing to fit; stacking zero vectors gives a 1-D array and the
        # shape lookup below would raise IndexError.
        return {}
    words = list(embeddings)
    matrix = np.asarray([embeddings[word] for word in words])
    # PCA cannot return more components than min(n_samples, n_features).
    num_components = min(matrix.shape[0], matrix.shape[1])
    projected = get_PCA_(matrix, num_components=num_components)
    return dict(zip(words, projected))

def update_keys(embs, lang, add=True):
    """Add or remove the trailing "_<code>" language tag on every key.

    Args:
        embs: dict mapping word -> embedding.
        lang: language code appended when `add` is True.
        add: True appends "_" + lang to each key; False strips the last
            "_"-separated piece from each key.

    Returns:
        A new dict with the rewritten keys and the original values.
    """
    if add:
        return {word + "_" + lang: emb for word, emb in embs.items()}
    # Split from the right so a word that itself contains "_" keeps
    # everything except the trailing language code.
    return {word.rsplit("_", 1)[0]: emb for word, emb in embs.items()}

def split_combined_emb_dict(embs, src, tgt):
    """Split a combined "<word>_<lang>" dict back into per-language dicts.

    Args:
        embs: dict whose keys are a word followed by "_" and a language code.
        src: language code whose entries go into the first result.
        tgt: language code whose entries go into the second result.

    Returns:
        (src_dict, tgt_dict), each mapping the bare word to its embedding.
        Keys tagged with any other code, or with no "_" at all, are skipped.
    """
    src_dict = {}
    tgt_dict = {}
    for word_key, emb in embs.items():
        # Partition on the LAST underscore: the language code is the final
        # piece, and the word itself may contain underscores.
        word, separator, lang = word_key.rpartition("_")
        if not separator:
            continue  # no language tag on this key
        if lang == src:
            src_dict[word] = emb
        elif lang == tgt:
            tgt_dict[word] = emb
    return src_dict, tgt_dict

def _read_mapped_embeddings(path, wanted_words):
    """Read one vecmap *_MAPPED.EMB text file, keeping only `wanted_words`."""
    kept = {}
    # NOTE(review): opened with the platform default encoding, as before;
    # confirm this matches how the mapped files were written.
    with open(path, 'r') as handle:
        lines = handle.readlines()
    for line in tqdm(lines[1:]):  # skip header (first line)
        pieces = line.split(" ")
        # '~' is the delimiter used in place of spaces inside an entry.
        word = pieces[0].replace("~", " ")
        if word in wanted_words:
            emb = np.asarray([float(x) for x in pieces[1:]])
            kept[word] = post_process(emb)
    return kept


def _load_vectors_from_dir(embed_dir, wanted_words):
    """Load the per-word embedding files under `embed_dir` for `wanted_words`."""
    kept = {}
    for file in collect_files(embed_dir):
        # The word is the file name up to its first '.'.
        word = file.split('/')[-1].split('.')[0]
        if word in wanted_words:
            kept[word] = post_process(load(file))
    return kept


def get_embeddings(args, unique_words_src, unique_words_tgt, data_src, data_tgt, embed_type, do_PCA=True):
    """Load post-processed embeddings for both languages of a pair.

    For ``embed_type == 'fasttext'`` the vectors are read from the vecmap
    output files in ``config.directories.vecmap``. The pair may have been
    mapped in either direction; when only "<tgt>_<src>.DICT" exists, the
    SRC/TGT mapped files are swapped so the first result still belongs to
    `data_src`. For any other `embed_type` the vectors are read from
    per-word files under ``config.directories.word_vectors``.

    Args:
        args: parsed CLI arguments; ``args.embed_type`` names the
            word-vector sub-directory in the non-fasttext branch.
            NOTE(review): that branch reads ``args.embed_type`` rather than
            the `embed_type` parameter — confirm this is intended.
        unique_words_src: words of `data_src` to keep.
        unique_words_tgt: words of `data_tgt` to keep.
        data_src: source language code.
        data_tgt: target language code.
        embed_type: 'fasttext' selects the vecmap files; anything else
            selects the per-word files.
        do_PCA: when True, both languages are projected together by
            `get_PCA` before being split apart again.

    Returns:
        (src_embs, tgt_embs, src, tgt): two word -> vector dicts followed by
        the language codes (always `data_src`, `data_tgt`).

    Raises:
        FileNotFoundError: 'fasttext' was requested but neither vecmap
            ".DICT" file exists for this pair.
    """
    config = get_config()
    src, tgt = data_src, data_tgt
    # Sets give O(1) membership tests; the mapped files are scanned line by line.
    wanted_src = set(unique_words_src)
    wanted_tgt = set(unique_words_tgt)

    if embed_type == 'fasttext':
        vecmap_dir = config.directories.vecmap
        forward_pair = data_src + "_" + data_tgt
        backward_pair = data_tgt + "_" + data_src
        if os.path.exists(os.path.join(vecmap_dir, forward_pair + '.DICT')):
            src_mapped_file = os.path.join(vecmap_dir, forward_pair + "_SRC_MAPPED.EMB")
            tgt_mapped_file = os.path.join(vecmap_dir, forward_pair + "_TGT_MAPPED.EMB")
        elif os.path.exists(os.path.join(vecmap_dir, backward_pair + '.DICT')):
            # vecmap ran with the languages flipped, so our source language
            # lives in its TGT file and vice versa.
            src_mapped_file = os.path.join(vecmap_dir, backward_pair + "_TGT_MAPPED.EMB")
            tgt_mapped_file = os.path.join(vecmap_dir, backward_pair + "_SRC_MAPPED.EMB")
        else:
            raise FileNotFoundError(
                "No vecmap dictionary found for " + forward_pair + " or " + backward_pair
                + " in " + str(vecmap_dir))
        src_embs = _read_mapped_embeddings(src_mapped_file, wanted_src)
        tgt_embs = _read_mapped_embeddings(tgt_mapped_file, wanted_tgt)
    else:
        # Embeddings are read straight from disk (no adaptation needed for this approach).
        src_embed_dir = os.path.join(config.directories.word_vectors, args.embed_type, data_src)
        tgt_embed_dir = os.path.join(config.directories.word_vectors, args.embed_type, data_tgt)
        src_embs = _load_vectors_from_dir(src_embed_dir, wanted_src)
        tgt_embs = _load_vectors_from_dir(tgt_embed_dir, wanted_tgt)

    # Tag every word with its language code so identical spellings in the two
    # languages do not collide, then merge into one dict for a joint PCA.
    src_embs = update_keys(src_embs, src, add=True)
    tgt_embs = update_keys(tgt_embs, tgt, add=True)
    src_embs.update(tgt_embs)
    embeddings = src_embs
    if do_PCA:
        embeddings = get_PCA(embeddings)
    # Remove the language code from the words again.
    src_embs, tgt_embs = split_combined_emb_dict(embeddings, src=src, tgt=tgt)
    return src_embs, tgt_embs, src, tgt

def get_concat_embeddings(args, unique_words, lang, embed_type, do_PCA=True):
    """Build one embedding per word by concatenating several methods' vectors.

    Args:
        args: parsed CLI arguments (not read by this function).
        unique_words: words to build embeddings for.
        lang: language code naming the sub-directory of each method.
        embed_type: method names joined by "~"; each names a directory under
            ``config.directories.word_vectors``.
        do_PCA: when True, the concatenated vectors are projected by `get_PCA`.

    Returns:
        dict mapping each word of `unique_words` to its post-processed
        concatenated vector.

    Raises:
        KeyError: a word in `unique_words` has no embedding file for one of
            the methods.
    """
    config = get_config()
    wanted = set(unique_words)  # O(1) membership while scanning the files

    # Load the embeddings for each method ahead of time.
    method_embeds = {}
    for method in dict.fromkeys(embed_type.split("~")):  # de-duplicated, order kept
        directory = os.path.join(config.directories.word_vectors, method, lang)
        per_word = {}
        for file in collect_files(directory):
            # The word is the file name up to its first '.'.
            word = file.split('/')[-1].split('.')[0]
            if word in wanted:
                per_word[word] = post_process(load(file))
        method_embeds[method] = per_word

    embeddings = {}
    for word in unique_words:
        parts = [vectors[word] for vectors in method_embeds.values()]
        # A single method needs no concatenation; use its vector as is.
        combined = parts[0] if len(parts) == 1 else np.concatenate(parts)
        embeddings[word] = post_process(combined)
    if do_PCA:
        embeddings = get_PCA(embeddings)
    return embeddings

def spearman_rank_correlation(LSIM_pairs, src_embs, tgt_embs, rank_method):
    """Correlate embedding cosine similarity with the gold pair scores.

    For every pair, 'word1' is looked up in `src_embs` and 'word2' in
    `tgt_embs`; the similarity is ``1 - cosine distance``. Gold scores and
    similarities are ranked with `rank_method` and the Spearman correlation
    of the two rank vectors is returned.

    Side effect: each scored pair dict gains an 'embed_similarity' entry.

    Args:
        LSIM_pairs: dict of pair dicts with keys 'word1', 'word2' and
            'total_score'.
        src_embs: word -> vector dict for the first word of each pair.
        tgt_embs: word -> vector dict for the second word of each pair.
        rank_method: `method` argument for ``scipy.stats.rankdata``.

    Returns:
        (spearman, completed_pairs, total_pairs).
    """
    ground_truth_scores = []
    our_scores = []
    no_embedding_count = 0
    for dictionary in LSIM_pairs.values():
        # Only a missing key means "no embedding for this pair". Any other
        # error is a real bug and must surface instead of being counted as OOV.
        try:
            embedding1 = src_embs[dictionary['word1']]  # order is src, tgt!!!
            embedding2 = tgt_embs[dictionary['word2']]
            gold_score = dictionary['total_score']
        except KeyError:
            no_embedding_count += 1
            continue
        embed_similarity = 1 - cosine(embedding1, embedding2)
        dictionary['embed_similarity'] = embed_similarity
        ground_truth_scores.append(gold_score)
        our_scores.append(embed_similarity)
    total_pairs = len(LSIM_pairs)
    completed_pairs = total_pairs - no_embedding_count
    print(str(completed_pairs) + " / " + str(total_pairs) + " pairs had embeddings for both words.")
    ground_truth_ranks = scipy.stats.rankdata(np.asarray(ground_truth_scores), method=rank_method)
    our_ranks = scipy.stats.rankdata(np.asarray(our_scores), method=rank_method)
    spearman = scipy.stats.spearmanr(ground_truth_ranks, our_ranks).correlation
    return spearman, completed_pairs, total_pairs

def spearman_rank_correlation_most_similar(LSIM_pairs, embeddings, rank_method):
    """Spearman correlation using the best-matching vectors of each word.

    ``embeddings[word]`` is iterated as a collection of vectors. For each
    pair, every vector of 'word1' is compared with every vector of 'word2'
    and the largest ``1 - cosine distance`` is the pair's similarity. Pairs
    whose best similarity is not above 0 are left out of the correlation.

    Side effect: each scored pair dict gains an 'embed_similarity' entry.

    Args:
        LSIM_pairs: dict of pair dicts with keys 'word1', 'word2' and
            'total_score'.
        embeddings: word -> iterable of vectors.
        rank_method: `method` argument for ``scipy.stats.rankdata``.

    Returns:
        (spearman, completed_pairs, total_pairs).
        NOTE(review): completed_pairs only subtracts pairs with a missing
        key; pairs dropped for non-positive similarity still count as
        completed — confirm this is intended.
    """
    ground_truth_scores = []
    our_scores = []
    no_embedding_count = 0
    min_temp_embed_similarity = 10000  # running minimum over all comparisons, reported below
    for dictionary in tqdm(LSIM_pairs.values()):
        # Only a missing key means the pair cannot be scored; anything else
        # (including Ctrl-C) must propagate.
        try:
            vectors1 = embeddings[dictionary['word1']]
            vectors2 = embeddings[dictionary['word2']]
            gold_score = dictionary['total_score']
        except KeyError:
            no_embedding_count += 1
            continue
        # Check all pairwise similarities and pick the biggest one.
        embed_similarity = 0
        for emb1 in vectors1:
            for emb2 in vectors2:
                try:
                    temp_embed_similarity = 1 - cosine(emb1, emb2)
                except (ValueError, TypeError):
                    # Best effort: skip vector pairs cosine() cannot compare.
                    continue
                if temp_embed_similarity < min_temp_embed_similarity:
                    min_temp_embed_similarity = temp_embed_similarity
                if temp_embed_similarity > embed_similarity:
                    embed_similarity = temp_embed_similarity
        if embed_similarity > 0:
            dictionary['embed_similarity'] = embed_similarity
            ground_truth_scores.append(gold_score)
            our_scores.append(embed_similarity)
    total_pairs = len(LSIM_pairs)
    completed_pairs = total_pairs - no_embedding_count
    ground_truth_ranks = scipy.stats.rankdata(np.asarray(ground_truth_scores), method=rank_method)
    our_ranks = scipy.stats.rankdata(np.asarray(our_scores), method=rank_method)
    spearman = scipy.stats.spearmanr(ground_truth_ranks, our_ranks).correlation
    print("Minimum similarity between any vectors was " + str(min_temp_embed_similarity))
    return spearman, completed_pairs, total_pairs

def check_overlap_words(args, unique_words, lang):
    """Keep only the words that have an embedding file for every listed method.

    Args:
        args: parsed CLI arguments; ``args.method_word_intersections`` is a
            "~"-joined list of method names.
        unique_words: candidate words.
        lang: language code naming the sub-directory of each method.

    Returns:
        A list of the words present under every method's directory. The
        order is whatever the set intersection yields.
    """
    config = get_config()
    surviving = unique_words
    for method in args.method_word_intersections.split("~"):
        method_dir = os.path.join(config.directories.word_vectors, method, lang)
        # A file's word is its base name up to the first '.'.
        stems = [path.split("/")[-1].split(".")[0] for path in utils.collect_files(method_dir)]
        surviving = list(set(stems).intersection(set(surviving)))
    return surviving

def filter_word_pairs_POS(word_pairs, args):
    """Select the word pairs whose 'POS_tag' equals ``args.POS_type``.

    Args:
        word_pairs: dict of pair dicts with keys 'POS_tag', 'word1', 'word2'.
        args: object exposing the wanted tag as ``POS_type``.

    Returns:
        (kept_pairs, words): the filtered dict, and a flat list holding
        'word1' then 'word2' of every kept pair (duplicates are not removed).
    """
    kept_pairs = {
        pair_id: entry
        for pair_id, entry in word_pairs.items()
        if entry['POS_tag'] == args.POS_type
    }
    words = [entry[slot] for entry in kept_pairs.values() for slot in ('word1', 'word2')]
    return kept_pairs, words

def main(args):
    """Evaluate cross-lingual word similarity for every pair of languages.

    For each unordered pair drawn from ``args.languages`` ("_"-separated
    codes), the word pairs are fetched with `get_crosslingual_multisimlex`,
    optionally filtered by part of speech, restricted to words that have
    embeddings for every method in ``args.method_word_intersections``, and
    scored with `spearman_rank_correlation`. Language pairs for which no
    word pairs are returned (None) are skipped.

    When ``args.results_save_path`` is non-empty, one dict is dumped under
    ``config.directories.results``: the out-of-vocabulary counts if the file
    name contains "OOV", otherwise the correlation results.

    Original author's note: "We use 'sum' option from now on, it's the best."

    Args:
        args: parsed CLI arguments; this function and its callees read
            languages, POS_type, embed_type, PCA, rank_method,
            method_word_intersections and results_save_path.
    """
    languages = args.languages.split("_")
    results = {}
    oov_data = {}
    for i, src in enumerate(languages):
        for tgt in languages[i+1:]:
            word_pairs, unique_words_src, unique_words_tgt, data_src, data_tgt = get_crosslingual_multisimlex(src + "_" + tgt)
            if word_pairs is None:
                continue
            # Name the pair from the codes the loader reported; built only
            # after the None check so a pair with no data is never touched.
            lang_pair = data_src + "_" + data_tgt
            if args.POS_type != '':
                # Filter word_pairs by POS_type (the word list it returns is not needed here).
                word_pairs, _ = filter_word_pairs_POS(word_pairs, args)
            unique_words_src = check_overlap_words(args, unique_words_src, data_src)
            unique_words_tgt = check_overlap_words(args, unique_words_tgt, data_tgt)
            src_embs, tgt_embs, _, _ = get_embeddings(args, unique_words_src, unique_words_tgt, data_src=data_src, data_tgt=data_tgt, embed_type=args.embed_type, do_PCA=args.PCA)
            print(len(src_embs))
            print(len(tgt_embs))
            score, num_completed_pairs, num_total_pairs = spearman_rank_correlation(LSIM_pairs=word_pairs,
                                                                                    src_embs=src_embs,
                                                                                    tgt_embs=tgt_embs,
                                                                                    rank_method=args.rank_method)
            print(lang_pair + ": " + str(score))
            results[lang_pair] = {"spearman_rank_corr": score, "num_pairs": num_completed_pairs}
            oov_data[lang_pair] = {"total_pairs": num_total_pairs, "completed_pairs": num_completed_pairs}
    if args.results_save_path != '':
        config = get_config()
        os.makedirs(config.directories.results, exist_ok=True)
        save_path = os.path.join(config.directories.results, args.results_save_path)
        # The file name decides which table is written.
        if "OOV" not in args.results_save_path:
            dump(results, save_path)
        else:
            dump(oov_data, save_path)




if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(description='Arguments to evaluate on LSIM task')
    # One (flag, type, default) row per option.
    option_table = [
        ('--eval_word_type', str, 'LSIM'),
        ('--languages', str, 'en_es_fi_fr_he_pl_ru_zh'),  # en_es_fi_fr_he_pl_ru_zh, no Arabic unfortunately
        ('--embed_type', str, 'colex_all'),
        ('--rank_method', str, 'average'),
        ('--results_save_path', str, 'CROSSLINGUAL_OOV~colex_all.pkl'),  # CROSSLINGUAL~fasttext.pkl, CROSSLINGUAL_OOV~fasttext.pkl
        ('--method_word_intersections', str, 'colex_all~fasttext'),  # only words that have valid embeddings for all these methods
        ('--PCA', str2bool, True),
        ('--use_gpu', str2bool, True),
        ('--POS_type', str, ''),  # nouns, adjectives, verbs, adverbs
    ]
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()
    main(args)


