import os
import argparse
import utils
import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
import fasttext.util
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression

def process_weird_spaces(word):
    """Make a word safe for space-delimited files by removing its spaces.

    The word is split on the single space character. If more than one
    non-empty piece results (a genuine multi-word expression), every space
    in the word is replaced with "~" -- this includes leading, trailing and
    repeated spaces. Otherwise the spaces are only stray padding and are
    simply removed.

    Args:
        word: The raw word string.

    Returns:
        The word with no " " characters left in it.
    """
    real_pieces = [piece for piece in word.split(" ") if piece != '']
    # Several real pieces -> join marker; zero or one -> just drop the spaces.
    filler = "~" if len(real_pieces) > 1 else ""
    return word.replace(" ", filler)

def _build_translation_lines(src, tgt):
    """Build the "src_word tgt_word" dictionary lines for one language pair.

    Loads the first element returned by ``get_multisimlex`` for each
    language, looks up the target entry under the same key as the source
    entry, and maps each source word to the target word in the same slot
    ('word1' / 'word2'). The first mapping seen for a source word wins.

    Args:
        src: Source language code.
        tgt: Target language code.

    Returns:
        A list of strings, one "src_word tgt_word" line per distinct
        source word, in first-seen order.
    """
    multisimlex_src = get_multisimlex(src)[0]
    multisimlex_tgt = get_multisimlex(tgt)[0]
    assert len(multisimlex_tgt) == len(multisimlex_src)

    translation_dictionary = {}
    for key, src_word_pair in multisimlex_src.items():
        # Entries are matched purely by key; the original author notes having
        # manually verified the alignment for en and es only.
        tgt_word_pair = multisimlex_tgt[key]
        for slot in ("word1", "word2"):
            # Spaces would break the space-delimited output format, so
            # multi-word entries are joined with '~' instead.
            src_word = process_weird_spaces(src_word_pair[slot])
            tgt_word = process_weird_spaces(tgt_word_pair[slot])
            assert " " not in src_word and " " not in tgt_word
            # Keep the first translation recorded for a source word.
            translation_dictionary.setdefault(src_word, tgt_word)

    return [src_word + " " + tgt_word
            for src_word, tgt_word in translation_dictionary.items()]


def _embedding_lines(fileset, emb_dir):
    """Render a set of per-word embedding files as text-embedding lines.

    The first line is a "<count> <dimension>" header, where the dimension is
    the length of the first loaded embedding. Each following line is the
    word (file name up to its first '.') followed by the space-separated
    vector components.

    Args:
        fileset: Sequence of embedding file paths, one file per word.
        emb_dir: Directory the files were collected from; only used in the
            error message.

    Returns:
        A list of strings: the header followed by one line per file.

    Raises:
        ValueError: If ``fileset`` is empty, since the header dimension
            cannot be determined.
    """
    if len(fileset) == 0:
        raise ValueError("No embedding files found in " + str(emb_dir))

    # Header: number of words and the embedding dimension.
    dimension = len(utils.load(fileset[0]))
    lines = [str(len(fileset)) + " " + str(dimension)]
    for path in fileset:
        emb = utils.load(path)
        # The word is the file name up to its first '.'.
        word = process_weird_spaces(os.path.basename(path).split(".")[0])
        assert " " not in word
        # Each component is followed by a single space (trailing space kept
        # so the output matches the previously generated files exactly).
        emb_as_text = "".join(str(x) + " " for x in emb)
        lines.append(word + " " + emb_as_text)
    return lines


def main(args):
    """Export translation dictionaries and embeddings as vecmap input files.

    For every unordered pair of languages in ``args.languages`` this writes,
    into ``config.directories.vecmap``:

    * ``<src>_<tgt>.DICT`` -- space-delimited source/target word pairs taken
      from the MultiSimLex data of both languages.
    * ``<src>_<tgt>_SRC.EMB`` and ``<src>_<tgt>_TGT.EMB`` -- text renderings
      of the per-word embedding files found under
      ``config.directories.word_vectors/<args.embed_type>/<language>``.

    Args:
        args: Parsed command-line arguments; ``languages`` ("_"-separated
            language codes) and ``embed_type`` are read here.
    """
    config = get_config()
    os.makedirs(config.directories.vecmap, exist_ok=True)
    languages = args.languages.split("_")
    # Every unordered pair, in the same order for both export steps.
    language_pairs = [(src, tgt)
                      for i, src in enumerate(languages)
                      for tgt in languages[i + 1:]]

    # Step 1: export supervised translation dictionaries for evaluation words.
    num_lang_pairs = 0
    for src, tgt in language_pairs:
        lines = _build_translation_lines(src, tgt)
        # Saved as a SPACE-delimited text file with a .DICT file ending.
        dump_path = os.path.join(config.directories.vecmap, src + "_" + tgt + ".DICT")
        utils.write_file_from_list(lines, dump_path)
        print(src + " " + tgt)
        num_lang_pairs += 1
    print(num_lang_pairs)

    # Step 2: export text files with embeddings from the saved embedding files.
    for src, tgt in language_pairs:
        lang_prefix = src + "_" + tgt
        src_emb_dir = os.path.join(config.directories.word_vectors, args.embed_type, src)
        src_emb_files = collect_files(src_emb_dir)
        tgt_emb_dir = os.path.join(config.directories.word_vectors, args.embed_type, tgt)
        tgt_emb_files = collect_files(tgt_emb_dir)

        # Each fileset carries its own suffix, so the target file can never be
        # mislabelled as the source file when the two lists compare equal.
        exports = (
            (src_emb_files, src_emb_dir, "_SRC.EMB"),
            (tgt_emb_files, tgt_emb_dir, "_TGT.EMB"),
        )
        for fileset, emb_dir, suffix in exports:
            lines = _embedding_lines(fileset, emb_dir)
            dump_path = os.path.join(config.directories.vecmap, lang_prefix + suffix)
            utils.write_file_from_list(lines, dump_path)
        print(lang_prefix)




if __name__ == "__main__":
    # Command-line entry point: every option is a plain string with a default.
    #   --eval_word_type : accepted here; main() does not read it in this file.
    #   --languages      : "_"-separated language codes; main() splits on "_".
    #   --embed_type     : sub-directory name under the word-vector directory.
    parser = argparse.ArgumentParser(description='Arguments to evaluate on LSIM task')
    string_options = (
        ('--eval_word_type', 'LSIM'),
        ('--languages', 'ar_en_es_fi_fr_he_pl_ru_zh'),
        ('--embed_type', 'fasttext'),
    )
    for option_name, option_default in string_options:
        parser.add_argument(option_name, type=str, default=option_default)
    args = parser.parse_args()
    main(args)