import os
import argparse
import utils
import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
import fasttext.util
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
import math
import matplotlib.pyplot as plt

def post_process(embed, mean_center=True, normalize=True):
    """Optionally scale an embedding to unit norm and centre it on zero.

    With both flags enabled the vector is divided by its norm, shifted so
    that its components average to zero, and divided by its norm again.

    Args:
        embed: NumPy array holding one embedding.
        mean_center: When True, subtract the mean over all components.
        normalize: When True, divide by ``np.linalg.norm`` both before and
            after the centring step.

    Returns:
        The processed array (the input object itself if both flags are off).
    """
    def _to_unit_length(vec):
        return vec / np.linalg.norm(vec)

    processed = _to_unit_length(embed) if normalize else embed
    if mean_center:
        processed = processed - np.mean(processed)
    return _to_unit_length(processed) if normalize else processed

def get_PCA_(x, num_components):
    """Fit a full-SVD PCA on ``x`` and return ``x`` projected onto it.

    Args:
        x: 2-D array laid out as (batch, feat).
        num_components: Number of principal components to keep.

    Returns:
        The transformed array produced by ``PCA.fit_transform``.
    """
    return PCA(n_components=num_components, svd_solver='full').fit_transform(x)

def get_PCA(embeddings):
    """Replace every embedding with its coordinates in the PCA basis.

    The PCA is fitted on all vectors in ``embeddings`` at once (one row per
    word), and each word is mapped to its transformed row.

    Args:
        embeddings: Mapping from word to a 1-D embedding vector. All vectors
            are stacked into one 2-D array, so they must share a length.

    Returns:
        A new dict with the same keys, in the same order, whose values are
        the PCA-transformed vectors. An empty input yields an empty dict.
    """
    if not embeddings:
        # Nothing to fit; stacking zero vectors would not even give a 2-D array.
        return {}
    words = list(embeddings)
    vectors = np.asarray([embeddings[word] for word in words])
    # sklearn's full-SVD PCA rejects n_components > min(n_samples, n_features),
    # so keep every feature only when there are at least as many words as
    # dimensions; otherwise keep as many components as can be estimated.
    num_components = min(vectors.shape[0], vectors.shape[1])
    vectors = get_PCA_(vectors, num_components=num_components)
    return dict(zip(words, vectors))

def get_embeddings(args, unique_words, lang, embed_type, do_PCA=True):
    """Load the stored embeddings of one type for the requested words.

    Embedding files are collected from
    ``<config.directories.word_vectors>/<embed_type>/<lang>``; the word a file
    belongs to is its base name up to the first ``.``.

    Args:
        args: Parsed command-line arguments. Unused here (the embedding type
            is taken from ``embed_type``); kept so existing callers work.
        unique_words: Iterable of words to load; files for other words are
            ignored.
        lang: Language sub-directory name.
        embed_type: Name of the embedding method, used verbatim as the
            directory name.
        do_PCA: When True, pass the loaded embeddings through ``get_PCA``.
            Ignored for the two ``cross_colex_sum_senses_*`` types.

    Returns:
        Dict mapping each word that had a file to its embedding.
    """
    config = get_config()
    # Every supported type lives in a directory named after the type itself.
    embed_dir = os.path.join(config.directories.word_vectors, embed_type, lang)
    # These two types are returned exactly as stored: no post-processing and
    # no PCA (presumably already processed upstream -- TODO confirm).
    use_as_stored = embed_type in ("cross_colex_sum_senses_PCA",
                                   "cross_colex_sum_senses_no_PCA")
    wanted = set(unique_words)  # O(1) membership per file instead of a list scan
    embeddings = {}
    for file in collect_files(embed_dir):
        word = file.split('/')[-1].split('.')[0]
        if word not in wanted:
            continue
        embed = load(file)
        embeddings[word] = embed if use_as_stored else post_process(embed)
    # PCA needs at least one vector to fit on, so skip it when nothing loaded.
    if do_PCA and not use_as_stored and embeddings:
        embeddings = get_PCA(embeddings)
    return embeddings

def get_concat_embeddings(args, unique_words, lang, embed_type, do_PCA=True):
    """Build embeddings by concatenating several methods' vectors per word.

    ``embed_type`` names the methods separated by ``~``. Each method's files
    are read from ``<config.directories.word_vectors>/<method>/<lang>`` and
    post-processed individually; the per-method vectors of a word are then
    concatenated in method order and post-processed once more.

    Args:
        args: Parsed command-line arguments. Unused here; kept so existing
            callers work.
        unique_words: Iterable of words to build embeddings for.
        lang: Language sub-directory name.
        embed_type: ``~``-separated method names, e.g. ``"fasttext~BERT"``.
        do_PCA: When True, pass the concatenated embeddings through
            ``get_PCA``.

    Returns:
        Dict mapping word to its concatenated embedding. Words missing an
        embedding for any of the methods are left out.
    """
    config = get_config()
    wanted = set(unique_words)  # O(1) membership per file instead of a list scan
    # Load the embeddings for each method ahead of time. dict.fromkeys drops
    # repeated method names while keeping their first-seen order.
    method_embeds = {}
    for method in dict.fromkeys(embed_type.split("~")):
        directory = os.path.join(config.directories.word_vectors, method, lang)
        per_word = {}
        for file in collect_files(directory):
            word = file.split('/')[-1].split('.')[0]
            if word in wanted:
                per_word[word] = post_process(load(file))
        method_embeds[method] = per_word
    embeddings = {}
    for word in unique_words:
        # A word is only usable when every method supplies a vector for it;
        # skip it rather than fail with a KeyError on the first gap.
        if any(word not in per_word for per_word in method_embeds.values()):
            continue
        parts = [per_word[word] for per_word in method_embeds.values()]
        embeddings[word] = post_process(np.concatenate(parts))
    # PCA needs at least one vector to fit on, so skip it when nothing was built.
    if do_PCA and embeddings:
        embeddings = get_PCA(embeddings)
    return embeddings

def spearman_rank_correlation(LSIM_pairs, embeddings, rank_method):
    """Correlate embedding cosine similarities with gold similarity scores.

    For every pair whose two words both have an embedding, the cosine
    similarity is computed and also written back into the pair's dict under
    ``'embed_similarity'``. Gold scores and similarities are ranked with
    ``scipy.stats.rankdata`` and the ranks are fed to
    ``scipy.stats.spearmanr``.

    Args:
        LSIM_pairs: Dict whose values are dicts with the keys ``'word1'``,
            ``'word2'`` and ``'total_score'``. Scored entries are mutated.
        embeddings: Dict mapping word to embedding vector.
        rank_method: ``method`` argument forwarded to ``rankdata``.

    Returns:
        Tuple ``(spearman, completed_pairs)``. ``spearman`` is NaN when fewer
        than two pairs could be scored.
    """
    ground_truth_scores = []
    our_scores = []
    no_embedding_count = 0
    for pair in LSIM_pairs.values():
        # Only a missing key (word without an embedding, or an incomplete
        # pair record) counts as "no embedding"; other errors should surface.
        try:
            embedding1 = embeddings[pair['word1']]
            embedding2 = embeddings[pair['word2']]
            gold_score = pair['total_score']
        except KeyError:
            no_embedding_count += 1
            continue
        embed_similarity = 1 - cosine(embedding1, embedding2)
        pair['embed_similarity'] = embed_similarity
        ground_truth_scores.append(gold_score)
        our_scores.append(embed_similarity)
    total_pairs = len(LSIM_pairs)
    completed_pairs = total_pairs - no_embedding_count
    print(str(completed_pairs) + " / " + str(total_pairs) + " pairs had embeddings for both words.")
    if completed_pairs < 2:
        # A correlation is undefined for fewer than two observations.
        return float('nan'), completed_pairs
    ground_truth_ranks = scipy.stats.rankdata(np.asarray(ground_truth_scores), method=rank_method)
    our_ranks = scipy.stats.rankdata(np.asarray(our_scores), method=rank_method)
    # NOTE(review): spearmanr ranks its inputs again, so rank_method may have
    # less effect than intended -- confirm against the scipy documentation.
    spearman = scipy.stats.spearmanr(ground_truth_ranks, our_ranks).correlation
    return spearman, completed_pairs

def rank_difference_vs_word_frequency(LSIM_pairs, embeddings, rank_method, word_frequencies):
    """Correlate similarity-score error with mean log word frequency.

    For every pair that has embeddings and frequencies for both words, the
    gold score and the cosine similarity are each rescaled to [0, 1] and their
    absolute difference is taken. That difference is then Pearson-correlated
    with the mean of the two words' log10 frequencies. The cosine similarity
    is also written back into the pair's dict under ``'embed_similarity'``.

    Args:
        LSIM_pairs: Dict whose values are dicts with the keys ``'word1'``,
            ``'word2'``, ``'total_score'`` and ``'annotator_scores'``. Scored
            entries are mutated.
        embeddings: Dict mapping word to embedding vector.
        rank_method: Currently unused (the score difference, not the rank
            difference, is correlated); kept so existing callers work.
        word_frequencies: Dict mapping word to a positive frequency.

    Returns:
        Tuple ``(p, completed_pairs)`` where ``p`` is the result of
        ``scipy.stats.pearsonr``.

    Raises:
        ValueError: If ``LSIM_pairs`` is empty, or (from ``pearsonr``) if
            fewer than two pairs could be scored.
    """
    pairs = list(LSIM_pairs.values())
    if not pairs:
        raise ValueError("LSIM_pairs is empty; there is nothing to correlate.")
    ground_truth_scores = []
    our_scores = []
    log_freq_means = []
    no_embedding_count = 0
    for pair in pairs:
        # A pair is skipped when a word has no embedding or frequency
        # (KeyError) or a non-positive frequency (log10 raises ValueError).
        try:
            word1 = pair['word1']
            word2 = pair['word2']
            freq_log_mean = np.mean([math.log10(word_frequencies[word1]),
                                     math.log10(word_frequencies[word2])])
            embedding1 = embeddings[word1]
            embedding2 = embeddings[word2]
            gold_score = pair['total_score']
        except (KeyError, ValueError):
            no_embedding_count += 1
            continue
        embed_similarity = 1 - cosine(embedding1, embedding2)
        pair['embed_similarity'] = embed_similarity
        ground_truth_scores.append(gold_score)
        our_scores.append(embed_similarity)
        log_freq_means.append(freq_log_mean)
    # The annotator count is read from the last pair, as before.
    # NOTE(review): assumes every pair has the same number of annotators and
    # that 6.0 is the maximum score a single annotator can give -- confirm.
    num_annotators = len(pairs[-1]['annotator_scores'])
    max_annotator_score = 6.0 * num_annotators
    total_pairs = len(LSIM_pairs)
    completed_pairs = total_pairs - no_embedding_count
    print(str(completed_pairs) + " / " + str(total_pairs) + " pairs had embeddings for both words.")
    ground_truth_scaled_scores = np.asarray(ground_truth_scores) / max_annotator_score  # scales from range 0 to 1
    our_scores_scaled_scores = (np.asarray(our_scores) + 1.0) / 2.0  # cosine similarity in [-1, 1] -> [0, 1]
    score_diff = np.abs(ground_truth_scaled_scores - our_scores_scaled_scores)
    p = scipy.stats.pearsonr(score_diff, log_freq_means)
    return p, completed_pairs

def get_word_frequencies(lang, unique_words):
    """Look up normalized corpus frequencies for the requested words.

    Reads ``<config.directories.FrequencyWords>/<name>/<name>_full.txt``
    (``name`` is ``zh_cn`` for ``lang == "zh"``, otherwise ``lang``), where
    each line is ``<word> <count>`` separated by a single space. Counts are
    divided by the largest count in the file.

    Args:
        lang: Language code.
        unique_words: Iterable of words to look up.

    Returns:
        Tuple ``(freq_words, unique_words)``: a dict mapping each requested
        word found in the file to its normalized frequency, ordered from most
        to least frequent, and the ``unique_words`` argument unchanged.
    """
    config = get_config()
    file_lang_name = 'zh_cn' if lang == "zh" else lang
    freq_file = os.path.join(config.directories.FrequencyWords, file_lang_name, file_lang_name + "_full.txt")
    # Read with an explicit encoding so non-ASCII word lists do not depend on
    # the platform default (presumably the lists are UTF-8 -- TODO confirm),
    # and close the handle as soon as the lines are in memory.
    with open(freq_file, 'r', encoding='utf-8') as freq_handle:
        lines = freq_handle.readlines()
    freq_dict = {}
    max_value = -1
    for line in tqdm(lines):
        data = line.replace('\n', '').split(" ")
        if len(data) < 2:
            # Blank or malformed line (e.g. a trailing newline at end of file).
            continue
        total_count = int(data[1])
        if total_count > max_value:
            max_value = total_count
        freq_dict[data[0]] = total_count
    # Normalize by the max occurrence, but only for the words that were asked
    # for; the rest of the table is never returned.
    max_value = float(max_value)
    freq_words = {word: freq_dict[word] / max_value
                  for word in unique_words if word in freq_dict}
    freq_words = dict(sorted(freq_words.items(), key=lambda item: item[1], reverse=True))  # sort by most frequent words
    return freq_words, unique_words

def check_overlap_words(args, unique_words, lang):
    """Keep only the words that have an embedding file for every listed method.

    ``args.method_word_intersections`` names the methods separated by ``~``.
    For each one, the files under
    ``<config.directories.word_vectors>/<method>/<lang>`` are collected and
    the word of each file is its base name up to the first ``.``.

    Args:
        args: Parsed command-line arguments; only
            ``method_word_intersections`` is read.
        unique_words: Iterable of candidate words.
        lang: Language sub-directory name.

    Returns:
        List of the distinct words from ``unique_words`` present for all
        methods, in the order they first appear in ``unique_words``.
    """
    config = get_config()
    methods = args.method_word_intersections.split("~")
    common = set(unique_words)
    for method in methods:
        word_files = utils.collect_files(os.path.join(config.directories.word_vectors, method, lang))
        common &= {x.split("/")[-1].split(".")[0] for x in word_files}
    # Emit in input order (duplicates dropped) rather than set order, so the
    # result is identical from run to run despite string hash randomization.
    return [word for word in dict.fromkeys(unique_words) if word in common]

def main(args):
    """Evaluate the chosen embeddings on every requested language.

    For each language code in ``args.languages`` (separated by ``_``) the
    word pairs are fetched, the vocabulary is restricted to words covered by
    all methods in ``args.method_word_intersections``, the embeddings are
    loaded, and the Spearman rank correlation against the gold scores is
    printed. If ``args.results_save_path`` is non-empty, the per-language
    results are dumped under ``config.directories.results``.
    """
    results = {}

    for lang in args.languages.split("_"):
        word_pairs, vocabulary = get_multisimlex(lang)
        vocabulary = check_overlap_words(args, vocabulary, lang)
        # Word frequencies are only consumed by the alternative metric
        # rank_difference_vs_word_frequency, which is not called here.
        word_frequencies, vocabulary = get_word_frequencies(lang, vocabulary)
        # A "~" in the embedding type marks a concatenation of several methods.
        is_concatenation = "~" in args.embed_type
        load_embeddings = get_concat_embeddings if is_concatenation else get_embeddings
        embeddings = load_embeddings(args, vocabulary, lang=lang,
                                     embed_type=args.embed_type, do_PCA=args.PCA)
        print(len(embeddings))
        score, num_completed_pairs = spearman_rank_correlation(
            LSIM_pairs=word_pairs,
            embeddings=embeddings,
            rank_method=args.rank_method,
        )

        print(lang + ": " + str(score))
        results[lang] = {"spearman_rank_corr": score, "num_pairs": num_completed_pairs}

    if args.results_save_path == '':
        return
    config = get_config()
    results_dir = config.directories.results
    os.makedirs(results_dir, exist_ok=True)
    dump(results, os.path.join(results_dir, args.results_save_path))




if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them and
    # run the evaluation.
    parser = argparse.ArgumentParser(description='Arguments to evaluate on LSIM task')
    # One (flag, type, default) triple per supported option.
    option_specs = [
        ('--eval_word_type', str, 'LSIM'),
        # "_"-separated language codes, e.g. ar_en_es_fi_fr_he_pl_ru_zh
        ('--languages', str, 'ar_en_es_fi_fr_he_pl_ru_zh'),
        # binary and sum are much better than pairwise product
        ('--embed_type', str, 'cross_colex_sum'),
        ('--rank_method', str, 'average'),
        ('--results_save_path', str, ''),
        # only words that have valid embeddings for all these methods
        ('--method_word_intersections', str, 'cross_colex_sum~fasttext~BERT'),
        ('--PCA', str2bool, True),
        ('--use_gpu', str2bool, True),
    ]
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()
    main(args)


