import os
import argparse
import utils
import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
import fasttext.util
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
import SensEmBERT
import ARES

def get_SensEmBERT_embeddings(wordlist, lang):
    """Write one SensEmBERT-derived word embedding per word to disk.

    For every word in ``wordlist`` the value returned by
    ``SensEmBERT.get_word_embedding`` is passed to ``dump`` with the target
    path ``<word_vectors>/SensEmBERT/<lang>/<word>.pkl``. Words for which
    that call returns ``None`` are skipped.

    Args:
        wordlist: Iterable of words; each word is used verbatim in the
            output file name.
        lang: Language code. For ``'en'`` no BabelNet synset mapping is
            loaded and ``None`` is passed as ``BabelNetSyns``.
    """
    out_dir = os.path.join(get_config().directories.word_vectors, "SensEmBERT", lang)
    os.makedirs(out_dir, exist_ok=True)
    # Only non-English languages get a BabelNet word -> synset mapping.
    synsets_by_word = None if lang == 'en' else SensEmBERT.load_BabelNet_synsets_for_language(lang)
    sense_vectors = SensEmBERT.load_sense_embeddings(lang=lang)
    print("Getting word embeddings from SensEmBERT sense embeddings and writing to disk...")
    for word in tqdm(wordlist):
        vector = SensEmBERT.get_word_embedding(
            word,
            lang,
            sense_embeddings=sense_vectors,
            BabelNetSyns=synsets_by_word,
        )
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
    print("Done.")

def get_ARES_embeddings(wordlist, lang, args=None):
    """Write one ARES-derived word embedding per word to disk.

    For every word in ``wordlist`` the value returned by
    ``ARES.get_word_embedding`` is passed to ``dump`` with the target path
    ``<word_vectors>/ARES/<lang>/<word>.pkl``. Words for which that call
    returns ``None`` are skipped.

    Args:
        wordlist: Iterable of words; each word is used verbatim in the
            output file name.
        lang: Language code forwarded to the ARES loaders.
        args: Optional namespace providing ``sense_PCA``. When omitted, the
            module-level ``args`` (set when the file runs as a script) is
            used if present; otherwise ``sense_PCA`` falls back to ``True``,
            the command-line default.
    """
    if args is None:
        # The original implementation read a global ``args`` that only exists
        # when this file is executed as a script; importing the module and
        # calling this function raised NameError. Keep the global as a
        # fallback so existing two-argument callers behave as before.
        args = globals().get('args')
    use_sense_pca = getattr(args, 'sense_PCA', True)

    out_dir = os.path.join(get_config().directories.word_vectors, "ARES", lang)
    os.makedirs(out_dir, exist_ok=True)
    word2syn = ARES.load_BabelNet_synsets_for_language(lang)
    sense_embeds = ARES.load_sense_embeddings(lang=lang, syn_PCA=use_sense_pca)
    # Message previously named SensEmBERT (copy-paste slip).
    print("Getting word embeddings from ARES sense embeddings and writing to disk...")
    for word in tqdm(wordlist):
        embed = ARES.get_word_embedding(word, lang, sense_embeddings=sense_embeds, BabelNetSyns=word2syn)
        if embed is not None:
            dump(embed, os.path.join(out_dir, word + '.pkl'))
    print("Done.")

def get_fasttext_embeddings(unique_words, lang):
    """Write one fastText word vector per word to disk.

    Downloads (if not already present) and loads the ``cc.<lang>.300.bin``
    fastText model, then passes ``ft.get_word_vector(word)`` for every word
    to ``dump`` with the target path
    ``<word_vectors>/fasttext/<lang>/<word>.pkl``.

    Args:
        unique_words: Iterable of words; each word is used verbatim in the
            output file name.
        lang: fastText language id (e.g. ``'en'``, ``'zh'``, ``'ar'``). The
            id is handed to ``fasttext.util.download_model`` unchanged, so
            language validation is left to fastText. Previously only nine
            hard-coded ids were handled and any other id left the model
            variable unbound.
    """
    out_dir = os.path.join(get_config().directories.word_vectors, "fasttext", lang)
    os.makedirs(out_dir, exist_ok=True)

    print('Getting fasttext vectors...')
    # One parametrised download/load replaces the per-language elif chain;
    # every branch only differed in the language id.
    fasttext.util.download_model(lang, if_exists='ignore')
    ft = fasttext.load_model('cc.{}.300.bin'.format(lang))

    for word in tqdm(unique_words):
        embed = ft.get_word_vector(word)
        if embed is not None:
            dump(embed, os.path.join(out_dir, word + '.pkl'))
    print("Done.")

def get_cross_colex_embeddings(wordlist, lang, args):
    """Dump node-vector based colexification word embeddings for one language.

    Embeddings are requested from
    ``colexAllBabelNet.get_word_embeddings_nodevectors`` for the embedding
    file ``args.exp_dir/args.embed_file``. Each word is then looked up with
    ``embeddings[word]`` and every non-``None`` value is passed to ``dump``
    with the target path
    ``<word_vectors>/<args.local_save_folder>/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    vectors written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code forwarded to the embedding lookup.
        args: Namespace providing ``local_save_folder``, ``exp_dir``,
            ``embed_file`` and ``NODE_MODEL_TYPE``; it is also forwarded
            unchanged to the lookup function.
    """
    import colexAllBabelNet

    out_dir = os.path.join(get_config().directories.word_vectors, args.local_save_folder, lang)
    source_file = os.path.join(args.exp_dir, args.embed_file)
    print(len(wordlist))
    word_vectors = colexAllBabelNet.get_word_embeddings_nodevectors(
        args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
        model_type=args.NODE_MODEL_TYPE,
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Getting word embeddings from cross_colex_sum synset embeddings and writing to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def get_cross_colex_binary_embeddings(wordlist, lang):
    """Dump ``cross_colex_binary`` word embeddings for one language.

    Uses ``colex.get_word_embeddings`` with a fixed configuration
    (experiment directory ``colex_from_BabelNet``, embedding file
    ``cross_colex_binary.emb``). Each non-``None`` ``embeddings[word]`` is
    passed to ``dump`` with the target path
    ``<word_vectors>/cross_colex_binary/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    vectors written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code forwarded to the embedding lookup.
    """
    import colex

    out_dir = os.path.join(get_config().directories.word_vectors, "cross_colex_binary", lang)
    # Hard-coded settings for this embedding variant.
    colex_args = edict({
        'exp_dir': 'colex_from_BabelNet',
        'word2syn_dir': 'wordsyns/word2syns_by_lang',
    })
    source_file = os.path.join(colex_args.exp_dir, 'cross_colex_binary.emb')
    print(len(wordlist))
    word_vectors = colex.get_word_embeddings(
        colex_args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Getting word embeddings from cross_colex_binary sense embeddings and writing to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def get_cross_colex_sum_embeddings(wordlist, lang, args):
    """Dump colexification word embeddings obtained via ``get_word_embeddings``.

    Embeddings are requested from ``colexAllBabelNet.get_word_embeddings``
    for the embedding file ``args.exp_dir/args.embed_file``. Each
    non-``None`` ``embeddings[word]`` is passed to ``dump`` with the target
    path ``<word_vectors>/<args.local_save_folder>/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    vectors written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code forwarded to the embedding lookup.
        args: Namespace providing ``local_save_folder``, ``exp_dir`` and
            ``embed_file``; it is also forwarded unchanged to the lookup
            function.
    """
    import colexAllBabelNet

    out_dir = os.path.join(get_config().directories.word_vectors, args.local_save_folder, lang)
    source_file = os.path.join(args.exp_dir, args.embed_file)
    print(len(wordlist))
    word_vectors = colexAllBabelNet.get_word_embeddings(
        args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Getting word embeddings from cross_colex_sum synset embeddings and writing to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def get_cross_colex_pairwise_product_embeddings(wordlist, lang):
    """Dump ``cross_colex_pairwise_product`` word embeddings for one language.

    Uses ``colex.get_word_embeddings`` with a fixed configuration
    (experiment directory ``colex_from_BabelNet``, embedding file
    ``cross_colex_pairwise_product.emb``). Each non-``None``
    ``embeddings[word]`` is passed to ``dump`` with the target path
    ``<word_vectors>/cross_colex_pairwise_product/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    vectors written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code forwarded to the embedding lookup.
    """
    import colex

    out_dir = os.path.join(
        get_config().directories.word_vectors, "cross_colex_pairwise_product", lang
    )
    # Hard-coded settings for this embedding variant.
    colex_args = edict({
        'exp_dir': 'colex_from_BabelNet',
        'word2syn_dir': 'wordsyns/word2syns_by_lang',
    })
    source_file = os.path.join(colex_args.exp_dir, 'cross_colex_pairwise_product.emb')
    print(len(wordlist))
    word_vectors = colex.get_word_embeddings(
        colex_args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Getting word embeddings from cross_colex_pairwise_product sense embeddings and writing to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def get_colex_embeddings(wordlist, lang, args):
    """Dump per-language colexification word embeddings.

    The embedding file name is ``args.embed_file`` with the placeholder
    ``LANG`` replaced by ``lang``; the placeholder is required. Embeddings
    come from ``colexAllBabelNet.get_word_embeddings`` and each non-``None``
    ``embeddings[word]`` is passed to ``dump`` with the target path
    ``<word_vectors>/<args.local_save_folder>/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    vectors written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code; substituted into the embedding file name.
        args: Namespace providing ``local_save_folder``, ``exp_dir`` and
            ``embed_file``; it is also forwarded unchanged to the lookup
            function.

    Raises:
        AssertionError: If ``args.embed_file`` does not contain ``"LANG"``.
    """
    import colexAllBabelNet

    out_dir = os.path.join(get_config().directories.word_vectors, args.local_save_folder, lang)
    # A separate embedding file is loaded for each language in this approach.
    assert "LANG" in args.embed_file
    source_file = os.path.join(args.exp_dir, args.embed_file.replace("LANG", lang))
    print(len(wordlist))
    word_vectors = colexAllBabelNet.get_word_embeddings(
        args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Getting word embeddings from colex_sum sense embeddings and writing to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def get_colex_sum_embeddings(wordlist, lang, args):
    """Dump per-language ``colex_sum`` word embeddings.

    The embedding file name is ``args.embed_file`` with any ``LANG``
    placeholder replaced by ``lang`` (the placeholder is not enforced here).
    Embeddings come from ``colexAllBabelNet.get_word_embeddings`` and each
    non-``None`` ``embeddings[word]`` is passed to ``dump`` with the target
    path ``<word_vectors>/<args.local_save_folder>/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    vectors written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code; substituted into the embedding file name.
        args: Namespace providing ``local_save_folder``, ``exp_dir`` and
            ``embed_file``; it is also forwarded unchanged to the lookup
            function.
    """
    import colexAllBabelNet

    out_dir = os.path.join(get_config().directories.word_vectors, args.local_save_folder, lang)
    # A separate embedding file is loaded for each language in this approach.
    per_language_file = args.embed_file.replace("LANG", lang)
    source_file = os.path.join(args.exp_dir, per_language_file)
    print(len(wordlist))
    word_vectors = colexAllBabelNet.get_word_embeddings(
        args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Getting word embeddings from colex_sum sense embeddings and writing to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def get_cross_colex_maxsim_embeddings(wordlist, lang, args):
    """Dump synset-level colexification embeddings for one language.

    Embeddings are requested from
    ``colexAllBabelNet.get_word_embeddings_synset`` for the embedding file
    ``args.exp_dir/args.embed_file`` with ``syn_PCA=args.sense_PCA``. Each
    non-``None`` ``embeddings[word]`` is passed to ``dump`` with the target
    path ``<word_vectors>/<args.local_save_folder>/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    entries written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code forwarded to the embedding lookup.
        args: Namespace providing ``local_save_folder``, ``exp_dir``,
            ``embed_file`` and ``sense_PCA``; it is also forwarded unchanged
            to the lookup function.
    """
    import colexAllBabelNet

    out_dir = os.path.join(get_config().directories.word_vectors, args.local_save_folder, lang)
    source_file = os.path.join(args.exp_dir, args.embed_file)
    print(len(wordlist))
    word_vectors = colexAllBabelNet.get_word_embeddings_synset(
        args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
        syn_PCA=args.sense_PCA,
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Saving cross_colex_sum_synset embeddings to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def get_cross_colex_sum_maxsim_embeddings(wordlist, lang, args):
    """Dump synset-level colexification embeddings without a PCA setting.

    Embeddings are requested from
    ``colexAllBabelNet.get_word_embeddings_synset`` for the embedding file
    ``args.exp_dir/args.embed_file``. No ``syn_PCA`` argument is passed, so
    whatever default the callee defines applies. Each non-``None``
    ``embeddings[word]`` is passed to ``dump`` with the target path
    ``<word_vectors>/<args.local_save_folder>/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    entries written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code forwarded to the embedding lookup.
        args: Namespace providing ``local_save_folder``, ``exp_dir`` and
            ``embed_file``; it is also forwarded unchanged to the lookup
            function.
    """
    import colexAllBabelNet

    out_dir = os.path.join(get_config().directories.word_vectors, args.local_save_folder, lang)
    source_file = os.path.join(args.exp_dir, args.embed_file)
    print(len(wordlist))
    word_vectors = colexAllBabelNet.get_word_embeddings_synset(
        args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Saving cross_colex_sum_synset embeddings to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def get_colex_all_embeddings(wordlist, lang, args):
    """Dump node-vector word embeddings from a single shared embedding file.

    Embeddings are requested from
    ``colexAllBabelNet.get_word_embeddings_nodevectors`` for the embedding
    file ``args.exp_dir/args.embed_file`` with
    ``model_type=args.NODE_MODEL_TYPE``. Each non-``None``
    ``embeddings[word]`` is passed to ``dump`` with the target path
    ``<word_vectors>/<args.local_save_folder>/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    vectors written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code forwarded to the embedding lookup.
        args: Namespace providing ``local_save_folder``, ``exp_dir``,
            ``embed_file`` and ``NODE_MODEL_TYPE``; it is also forwarded
            unchanged to the lookup function.
    """
    import colexAllBabelNet

    out_dir = os.path.join(get_config().directories.word_vectors, args.local_save_folder, lang)
    source_file = os.path.join(args.exp_dir, args.embed_file)
    print(len(wordlist))
    word_vectors = colexAllBabelNet.get_word_embeddings_nodevectors(
        args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
        model_type=args.NODE_MODEL_TYPE,
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Getting word embeddings from cross_all synset embeddings and writing to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def get_colex_mono_embeddings(wordlist, lang, args):
    """Dump node-vector word embeddings from a per-language embedding file.

    The embedding file name is ``args.embed_file`` with the placeholder
    ``LANG`` replaced by ``lang``; the placeholder is required. Embeddings
    come from ``colexAllBabelNet.get_word_embeddings_nodevectors`` with
    ``model_type=args.NODE_MODEL_TYPE`` and each non-``None``
    ``embeddings[word]`` is passed to ``dump`` with the target path
    ``<word_vectors>/<args.local_save_folder>/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    vectors written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code; substituted into the embedding file name.
        args: Namespace providing ``local_save_folder``, ``exp_dir``,
            ``embed_file`` and ``NODE_MODEL_TYPE``; it is also forwarded
            unchanged to the lookup function.

    Raises:
        AssertionError: If ``args.embed_file`` does not contain ``"LANG"``.
    """
    import colexAllBabelNet

    out_dir = os.path.join(get_config().directories.word_vectors, args.local_save_folder, lang)
    # A separate embedding file is loaded for each language in this approach.
    assert "LANG" in args.embed_file
    source_file = os.path.join(args.exp_dir, args.embed_file.replace("LANG", lang))
    print(len(wordlist))
    word_vectors = colexAllBabelNet.get_word_embeddings_nodevectors(
        args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
        model_type=args.NODE_MODEL_TYPE,
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Getting word embeddings from colex_sum sense embeddings and writing to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def get_colex_all_maxsim_embeddings(wordlist, lang, args):
    """Dump synset-level node-vector embeddings for one language.

    Embeddings are requested from
    ``colexAllBabelNet.get_word_embeddings_synset_nodevectors`` for the
    embedding file ``args.exp_dir/args.embed_file``, forwarding
    ``args.NODE_MODEL_TYPE``, ``args.sense_PCA`` and ``args.original_graph``.
    Each non-``None`` ``embeddings[word]`` is passed to ``dump`` with the
    target path ``<word_vectors>/<args.local_save_folder>/<lang>/<word>.pkl``.
    The word count, the size of the returned mapping and the number of
    entries written are printed.

    Args:
        wordlist: Sized iterable of words.
        lang: Language code forwarded to the embedding lookup.
        args: Namespace providing ``local_save_folder``, ``exp_dir``,
            ``embed_file``, ``NODE_MODEL_TYPE``, ``sense_PCA`` and
            ``original_graph``; it is also forwarded unchanged to the lookup
            function.
    """
    import colexAllBabelNet

    out_dir = os.path.join(get_config().directories.word_vectors, args.local_save_folder, lang)
    source_file = os.path.join(args.exp_dir, args.embed_file)
    print(len(wordlist))
    word_vectors = colexAllBabelNet.get_word_embeddings_synset_nodevectors(
        args,
        wordlist,
        lang=lang,
        embed_filepath=source_file,
        weight_type='inverse',
        model_type=args.NODE_MODEL_TYPE,
        syn_PCA=args.sense_PCA,
        original_graph=args.original_graph,
    )
    print(len(word_vectors))
    os.makedirs(out_dir, exist_ok=True)
    print("Saving cross_colex_sum_synset embeddings to disk...")
    written = 0
    for word in tqdm(wordlist):
        vector = word_vectors[word]
        if vector is None:
            continue
        dump(vector, os.path.join(out_dir, word + '.pkl'))
        written += 1
    print(written)
    print("Done.")

def main(args):
    """Fetch the evaluation vocabulary per language and dump its embeddings.

    ``args.languages`` is an underscore-separated list of language codes.
    For each language the word list is obtained according to
    ``args.eval_word_type`` and handed to the embedding writer selected by
    ``args.embed_type``.

    Args:
        args: Parsed command-line namespace. ``languages``,
            ``eval_word_type`` and ``embed_type`` are read here; the
            namespace is forwarded to writers that take it.

    Raises:
        ValueError: If ``args.eval_word_type`` or ``args.embed_type`` is not
            one of the supported values. Previously an unknown embedding type
            was a silent no-op and an unknown word type passed ``None`` as
            the word list.
    """
    # Writers called as f(words, lang).
    writers_without_args = {
        'SensEmBERT': get_SensEmBERT_embeddings,
        'ARES': get_ARES_embeddings,
        'fasttext': get_fasttext_embeddings,
        'cross_colex_pairwise_product': get_cross_colex_pairwise_product_embeddings,
    }
    # Writers called as f(words, lang, args).
    writers_with_args = {
        'cross_colex_binary': get_cross_colex_embeddings,
        'cross_colex_presence_absence': get_cross_colex_embeddings,
        'cross_colex_sum': get_cross_colex_embeddings,
        'colex_binary_all': get_cross_colex_embeddings,
        'colex_sum': get_colex_embeddings,
        'colex_binary': get_colex_embeddings,
        'colex_filtered_binary': get_colex_embeddings,
        'cross_colex_sum_maxsim': get_cross_colex_maxsim_embeddings,
        'cross_colex_binary_maxsim': get_cross_colex_maxsim_embeddings,
        'colex_all': get_colex_all_embeddings,
        'colex_limited_langs': get_colex_all_embeddings,
        'colex_mono': get_colex_mono_embeddings,
        'colex_all_maxsim': get_colex_all_maxsim_embeddings,
    }

    # Fail fast on bad options instead of doing nothing or crashing later
    # with an unrelated error on a ``None`` word list.
    if args.eval_word_type != 'LSIM':
        raise ValueError(
            "Unsupported eval_word_type {!r}; expected 'LSIM'".format(args.eval_word_type)
        )
    embed_type = args.embed_type
    if embed_type not in writers_without_args and embed_type not in writers_with_args:
        supported = sorted(list(writers_without_args) + list(writers_with_args))
        raise ValueError(
            "Unsupported embed_type {!r}; expected one of {}".format(embed_type, supported)
        )

    languages = args.languages.split("_")
    config = get_config()
    # makedirs also copes with a missing parent directory, unlike os.mkdir.
    os.makedirs(config.directories.word_vectors, exist_ok=True)

    for lang in languages:
        # First, get the set of words to grab vectors for.
        word_pairs, unique_words = get_multisimlex(lang)
        # Next, get the vectors using the appropriate method.
        if embed_type in writers_without_args:
            writers_without_args[embed_type](unique_words, lang)
        else:
            writers_with_args[embed_type](unique_words, lang, args)


if __name__ == "__main__":
    # (flag, type, default) for every command-line option, in declaration order.
    _cli_options = [
        ('--eval_word_type', str, 'LSIM'),
        # Underscore-separated language codes, e.g. en_ar_es_fi_fr_he_pl_ru_zh
        ('--languages', str, 'en_ar_es_fi_fr_he_pl_ru_zh'),
        ('--embed_type', str, 'colex_limited_langs'),
        # mono is colex_LANG.zip, maxsim is colex_all.zip
        ('--embed_file', str, 'colex_limited_langs200.zip'),
        ('--local_save_folder', str, 'colex_limited_langs200'),
        # Needed for maxsim with ProNE model for PCA!
        ('--original_graph', str, 'colex_from_AllBabelNet_Concepts/colex_all.edgelist'),
        ('--NODE_MODEL_TYPE', str, 'ProNE'),
        ('--exp_dir', str, 'colex_from_AllBabelNet_Concepts'),
        ('--word2syn_dir', str, 'wordsyns/word2syns_by_lang'),
        ('--sense_PCA', utils.str2bool, True),
        ('--synset_text_files_dir', str, 'synset_text_files'),
        ('--synset_types_files_dir', str, 'BabelNet_Synset_Types'),
        ('--lemma_savedir', str, 'lemmas'),
        ('--lemma_synIDdir', str, 'lemmas_synID'),
    ]
    parser = argparse.ArgumentParser(description='Arguments to get word embedding vectors')
    for _flag, _type, _default in _cli_options:
        parser.add_argument(_flag, type=_type, default=_default)
    # Keep ``args`` as a module-level name: code in this file may read it as a global.
    args = parser.parse_args()
    main(args)