"""
   Big steps to code:
   (1) Get list of languages to check overlap with
   (2) Get all lemmas from each language and store as a list
   (3) Create syn2id and id2syn dictionaries
   (4) Build the graphs with different edge weight conditions
   """
import json
import os
import argparse
import matplotlib.pyplot as plt
import utils
import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
import fasttext.util
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from easynmt import EasyNMT
from sys import getsizeof

def get_language_list(args):
    """Return the lower-cased language codes found in ``args.word2syn_dir``.

    Every path returned by ``collect_files`` contributes the part of its base
    name before the first '.', lower-cased (e.g. ``.../ZH.json`` -> ``zh``).
    The order follows ``collect_files``.
    """
    return [
        path.split('/')[-1].split('.')[0].lower()
        for path in collect_files(args.word2syn_dir)
    ]

def get_syn2id_and_id2syn(args, languages):
    """Load, or build and cache, the synset <-> integer-ID mappings.

    When both ``id2syn.pkl`` and ``syn2id.pkl`` exist in ``args.exp_dir`` they
    are loaded with ``load``. Otherwise every synset listed in
    ``<args.word2syn_dir>/<LANG>.json`` for the given languages is collected.
    The distinct synsets are sorted and numbered 0..N-1, N is printed, and both
    mappings are written with ``dump``.

    Args:
        args: namespace providing ``exp_dir`` and ``word2syn_dir``.
        languages: language codes; each JSON file name is the upper-cased code.

    Returns:
        (syn2id, id2syn): dict synset -> id and dict id -> synset.
    """
    id2syn_path = os.path.join(args.exp_dir, 'id2syn.pkl')
    syn2id_path = os.path.join(args.exp_dir, 'syn2id.pkl')
    if os.path.exists(id2syn_path) and os.path.exists(syn2id_path):
        id2syn = load(id2syn_path)
        syn2id = load(syn2id_path)
        return syn2id, id2syn
    # Accumulate into a set directly instead of rebuilding list(set(...)) after every language.
    all_syns = set()
    for lang in tqdm(languages):
        word2syn_file = os.path.join(args.word2syn_dir, lang.upper() + '.json')
        # JSON text is UTF-8; don't depend on the platform's locale encoding (lemmas are multilingual).
        with open(word2syn_file, encoding='utf-8') as json_file:
            word2syn = json.load(json_file)
        for synset_list in word2syn.values():
            all_syns.update(synset_list)
    print(len(all_syns))
    # Sorting makes the ID assignment deterministic for a given set of synsets.
    id2syn = dict(enumerate(sorted(all_syns)))
    syn2id = {syn: i for i, syn in id2syn.items()}
    dump(id2syn, id2syn_path)
    dump(syn2id, syn2id_path)
    return syn2id, id2syn

def get_lemmas(args, languages, syn2id, id2syn):
    """Load, or build and cache, per-language mappings lemma -> synset IDs.

    We already have the original files pairing lemmas with their synset lists.
    This step keeps the structure consistent with the WordNet version of the
    code: each ``<args.word2syn_dir>/<LANG>.json`` is loaded and every synset is
    replaced by its integer ID from ``syn2id``. The result is cached at
    ``<args.exp_dir>/lemmas.pkl`` and loaded from there when present.

    The empty-string lemma is skipped, and the size of its synset list is
    printed. The original author notes it leaked through from BabelNet with
    over 5000 synsets.

    Args:
        args: namespace providing ``exp_dir`` and ``word2syn_dir``.
        languages: language codes; each JSON file name is the upper-cased code.
        syn2id: dict synset -> integer ID.
        id2syn: not used here; kept so the signature matches the callers.

    Returns:
        dict lang -> {lemma: [synset IDs]}.
    """
    lemma_path = os.path.join(args.exp_dir, 'lemmas.pkl')
    if os.path.exists(lemma_path):
        return load(lemma_path)
    lemmas = {}
    for lang in tqdm(languages):
        word2syn_file = os.path.join(args.word2syn_dir, lang.upper() + '.json')
        # JSON text is UTF-8; don't depend on the platform's locale encoding (lemmas are multilingual).
        with open(word2syn_file, encoding='utf-8') as json_file:
            word2syn = json.load(json_file)
        local_lemmas = {}
        for lemma, synset_list in word2syn.items():
            if lemma == "":
                print(len(synset_list))
                continue
            local_lemmas[lemma] = [syn2id[x] for x in synset_list]
        lemmas[lang] = local_lemmas
    dump(lemmas, lemma_path)
    return lemmas

def get_all_edges_per_lang(args, lemmas):
    """Compute and dump the weighted colexification edges of every language.

    For each language, every lemma with two or more synset IDs adds one count
    to each unordered pair of its synsets. Edge keys are
    ``"<smaller_id>_<larger_id>"``, so a given pair always maps to the same key.
    One pickle per language is written to
    ``<args.exp_dir>/<args.all_edges_dicts>/<lang>.pkl``.

    Args:
        args: namespace providing ``exp_dir`` and ``all_edges_dicts``.
        lemmas: dict lang -> {lemma: [synset IDs]} (as built by get_lemmas).
    """
    edge_dir = os.path.join(args.exp_dir, args.all_edges_dicts)
    os.makedirs(edge_dir, exist_ok=True)
    for lang, lem_list in lemmas.items():
        lang_edges = {}
        polysemous_lemma_count = 0
        dump_path = os.path.join(edge_dir, lang + '.pkl')
        for syns in tqdm(lem_list.values()):
            if len(syns) < 2:
                continue
            # All pairwise combinations of this lemma's synsets become edges.
            for edge in get_pairwise_edges(syns):
                # Numeric ordering makes (a, b) and (b, a) share one key.
                low, high = sorted(edge)
                key = str(low) + '_' + str(high)
                lang_edges[key] = lang_edges.get(key, 0) + 1
            polysemous_lemma_count += 1
        # NOTE: sys.getsizeof is shallow (the dict's own table, not its keys/values),
        # so this underestimates the real memory footprint.
        size_in_GB = getsizeof(lang_edges) / (1024 ** 3)
        print(str(size_in_GB) + "GB with " + str(polysemous_lemma_count)
              + " lemmas with 2 or more synsets...")
        utils.dump(lang_edges, dump_path)

def build_cross_colex_binary_graph(args):
    """Build the cross-lingual binary colexification graph.

    Each per-language edge file in ``<args.exp_dir>/<args.all_edges_dicts>``
    (possibly narrowed by ``filter_lang_edge_files``) adds at most 1 to an
    edge's weight. An edge's weight is therefore the number of languages whose
    edge dict contains it. Edges found in only one language are dropped.

    Returns:
        dict edge key -> number of languages containing it (always >= 2).
    """
    edge_dir = os.path.join(args.exp_dir, args.all_edges_dicts)
    candidate_files = filter_lang_edge_files(args, collect_files(edge_dir))

    # Pass 1: for every edge, count how many language files contain it.
    language_counts = {}
    for path in tqdm(candidate_files):
        for edge_key in utils.load(path):
            language_counts[edge_key] = language_counts.get(edge_key, 0) + 1

    # Pass 2: keep only edges shared by at least two languages.
    shared_edges = {key: count for key, count in language_counts.items() if count >= 2}
    top_count = max(shared_edges.values(), default=0)

    print(str(len(language_counts)) + " total edges.")
    print(str(len(shared_edges)) + " cross-lingual edges.")
    print("Maximum number of languages an edge occured in was " + str(top_count) + ".")
    return shared_edges

def filter_lang_edge_files(args, lang_edge_files):
    """Restrict the per-language edge files according to ``args.graph_type``.

    - 'cross_colex_sum_9': keep the files whose lower-cased stem is in
      ``config.eval_languages``.
    - 'cross_colex_sum_50': keep the edge files of the 50 languages with the
      most entries in their JSON file under ``config.directories.wordsyns``.
      Languages without an edge file are reported and skipped.
    - any other graph type: return ``lang_edge_files`` unchanged.

    Args:
        args: parsed namespace; only ``graph_type`` is read.
        lang_edge_files: list of paths to per-language edge pickles.

    Returns:
        list of edge-file paths (the input list itself for unfiltered types).
    """
    config = utils.get_config()
    if args.graph_type == 'cross_colex_sum_9':
        languages = config.eval_languages
        return [file for file in lang_edge_files
                if file.split('/')[-1].split('.')[0].lower() in languages]
    if args.graph_type == 'cross_colex_sum_50':
        # Edge files are keyed by their file stem; JSON stems are lower-cased before lookup.
        edge_file_dict = {file.split('/')[-1].split('.')[0]: file for file in lang_edge_files}
        sized_edge_files = []
        for json_file in tqdm(collect_files(config.directories.wordsyns)):
            lang_name = json_file.split('/')[-1].split('.')[0]
            edge_file = edge_file_dict.get(lang_name.lower())
            if edge_file is None:
                # Without an edge dict this language cannot contribute edges; skip rather than raise KeyError.
                print("No edge file for " + lang_name + "; skipping it.")
                continue
            sized_edge_files.append((len(utils.load(json_file)), edge_file))
        # Largest languages first; the sort is stable, so ties keep collect_files order.
        sized_edge_files.sort(key=lambda item: item[0], reverse=True)
        return [edge_file for _, edge_file in sized_edge_files[:50]]
    return lang_edge_files

def build_cross_colex_sum_graph(args):
    """Build the cross-lingual colexification graph weighted by summed counts.

    An edge is kept only if build_cross_colex_binary_graph finds it in at least
    two languages. Its weight is the sum of its per-language colexification
    counts over the same (possibly filtered) language edge files. A histogram
    of the weights is shown before returning.

    Returns:
        dict edge key -> summed per-language count.
    """
    edge_dir = os.path.join(args.exp_dir, args.all_edges_dicts)
    selected_files = filter_lang_edge_files(args, collect_files(edge_dir))
    shared_edges = build_cross_colex_binary_graph(args)

    # Second pass: recover the original per-language weights of the shared edges.
    summed = {}
    for path in tqdm(selected_files):
        per_lang = utils.load(path)
        for key in per_lang:
            if key not in shared_edges:
                continue
            summed[key] = summed.get(key, 0) + per_lang[key]

    all_weights = list(summed.values())
    heaviest = max([0] + all_weights)
    print("Maximum edge weight in cross-colex sum graph is " + str(heaviest) + ".")
    plt.hist(all_weights, bins=[0, 5, 10, 15, 30, 100, 1000, 2000])
    plt.show()
    return summed

def build_cross_colex_pairwise_product_graph(args):
    """Build the cross-lingual graph weighted by pairwise products of counts.

    For every unordered pair of languages (l0, l1), and every edge present in
    both per-language edge dicts, add count_l0(edge) * count_l1(edge) to that
    edge's weight. Edges that occur in only one language therefore never
    appear. Languages listed in ``args.word2syn_dir`` without an edge file are
    reported and skipped. A histogram of the weights is shown before returning.

    Returns:
        dict edge key -> sum of pairwise count products.
    """
    languages = get_language_list(args)
    # Load all the language-wise graphs into a dict keyed by file stem.
    lang_graphs = {}
    lang_edge_files = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    print("Loading individual language graphs...")
    for file in tqdm(lang_edge_files):
        local_lang = file.split("/")[-1].split(".")[0]
        lang_graphs[local_lang] = utils.load(file)
    print("Done.")
    missing = [lang for lang in languages if lang not in lang_graphs]
    if missing:
        print("No edge file for " + ", ".join(missing) + "; skipping these languages.")
        languages = [lang for lang in languages if lang in lang_graphs]
    lang_pairs = get_pairwise_edges(languages)
    # Order does not affect the sums; presumably shuffled so tqdm's ETA is less
    # skewed by very large languages -- TODO confirm.
    random.shuffle(lang_pairs)
    pairwise_product_graph = {}
    for lang0, lang1 in tqdm(lang_pairs):
        small_graph, large_graph = lang_graphs[lang0], lang_graphs[lang1]
        # The product is symmetric, so scan the smaller dict and probe the larger one.
        if len(large_graph) < len(small_graph):
            small_graph, large_graph = large_graph, small_graph
        for edge, count_small in small_graph.items():
            if edge in large_graph:
                product = count_small * large_graph[edge]
                pairwise_product_graph[edge] = pairwise_product_graph.get(edge, 0) + product
    weights = list(pairwise_product_graph.values())
    max_edge_weight = max([0] + weights)
    print("Maximum edge weight in cross-colex pairwise product graph is " + str(max_edge_weight) + ".")
    plt.hist(weights, bins=[0, 5, 10, 15, 30, 100, 1000, 2000])
    plt.show()
    return pairwise_product_graph

def build_colex_sum_graph(args):
    """Return the monolingual colexification graph of ``args.monolingual_lang``.

    The graph is that language's per-language edge dict, loaded unchanged from
    ``<args.exp_dir>/<args.all_edges_dicts>``. The file stem is matched
    case-insensitively. A histogram of the weights is shown before returning.

    Returns:
        dict edge key -> colexification count.

    Raises:
        FileNotFoundError: if no edge file is named after ``args.monolingual_lang``.
    """
    edge_dir = os.path.join(args.exp_dir, args.all_edges_dicts)
    wanted = args.monolingual_lang.lower()
    edge_file = next(
        (file for file in collect_files(edge_dir)
         if file.split('/')[-1].split('.')[0].lower() == wanted),
        None,
    )
    if edge_file is None:
        # Fail clearly instead of handing None to utils.load.
        raise FileNotFoundError("No edge file for language '" + args.monolingual_lang
                                + "' in " + edge_dir + "; run with --compute_edges first.")
    edges = utils.load(edge_file)
    weights = list(edges.values())
    max_edge_weight = max([0] + weights)
    print("Maximum edge weight in colex sum graph is " + str(max_edge_weight) + ".")
    plt.hist(weights, bins=[0, 5, 10, 15, 30, 60, 100])
    plt.show()
    return edges

def write_edgelist_dict_to_edgelist(args, edge_dict):
    """Write an edge dict to ``<args.exp_dir>/<args.edgelist_savename>``.

    Each key ``"<a>_<b>"`` with weight ``w`` becomes the text line ``"<a> <b> <w>"``.
    Lines follow the dict's iteration order and are written with
    ``utils.write_file_from_list``.
    """
    dump_path = os.path.join(args.exp_dir, args.edgelist_savename)

    def _to_line(edgekey, weight):
        parts = edgekey.split('_')
        return ' '.join((parts[0], parts[1], str(weight)))

    lines = [_to_line(key, weight) for key, weight in edge_dict.items()]
    utils.write_file_from_list(lines, dump_path)

def synset_list_to_string_list(syns):
    """Return the ``_name`` attribute of each synset object, in input order."""
    return [syn._name for syn in syns]

def get_pairwise_edges(overlap):
    """Return every unordered pair of elements of ``overlap`` as 2-item lists.

    Pairs are ``[overlap[i], overlap[j]]`` for ``i < j``, in index order.
    ``overlap`` must support indexing-style slicing (e.g. a list).
    """
    return [
        [first, second]
        for idx, first in enumerate(overlap)
        for second in overlap[idx + 1:]
    ]

def get_word_embeddings(args, wordlist, lang, embed_filepath, weight_type='freq'):
    """Build one L2-normalized embedding per word from its sense embeddings.

    Sense embeddings are read from ``embed_filepath``. The first line is
    skipped (presumably a word2vec-style header -- confirm). Every other
    non-blank line is ``<synset id> <float> <float> ...``.

    For each word, its synset IDs are taken from ``lemmas[lang][word]``. With
    ``weight_type == 'inverse'``, the first sense that has an embedding
    contributes with weight 1. Each later available sense at position ``i``
    contributes with weight ``1 / (i + 1) ** 0.0005``, which is very close to 1.
    The sum is L2-normalized.

    Args:
        args: namespace with ``exp_dir`` and ``word2syn_dir`` (used to load the
            cached language list, synset IDs and lemmas).
        wordlist: iterable of words to embed.
        lang: language key into the lemma dictionary.
        embed_filepath: path of the sense-embedding text file.
        weight_type: only 'inverse' combines senses. With any other value
            (including the default 'freq') every word maps to None.

    Returns:
        dict word -> normalized numpy vector, or None when the word is unknown
        or none of its senses has an embedding.
    """
    languages = get_language_list(args)
    syn2id, id2syn = get_syn2id_and_id2syn(args, languages)
    lemmas = get_lemmas(args, languages, syn2id, id2syn)

    syn_embs = {}
    with open(embed_filepath, 'r') as embed_file:
        next(embed_file, None)  # skip the first (header) line
        for line in tqdm(embed_file):
            pieces = line.split()
            if not pieces:
                continue  # tolerate blank lines, e.g. a trailing newline
            syn_embs[int(pieces[0])] = np.asarray([float(x) for x in pieces[1:]])

    word_embs = {}
    for word in tqdm(wordlist):
        try:
            synsets = lemmas[lang][word]
        except KeyError:
            word_embs[word] = None
            continue
        word_emb = None
        if weight_type == 'inverse':
            for i, syn in enumerate(synsets):
                temp_emb = syn_embs.get(syn)
                if temp_emb is None:
                    # Not all senses have embeddings after node2vec training (too infrequent).
                    continue
                if word_emb is None:
                    # Copy: accumulating with += into the stored array would
                    # corrupt syn_embs for every later word sharing this sense.
                    word_emb = temp_emb.copy()
                else:
                    word_emb += (1 / (i + 1) ** 0.0005) * temp_emb
        if word_emb is not None and len(word_emb) >= 1:
            word_embs[word] = word_emb / np.linalg.norm(word_emb)
        else:
            word_embs[word] = None
    return word_embs

def get_word_embeddings_sense(args, wordlist, lang, embed_filepath, weight_type='freq', syn_PCA=False):
    """Collect, for each word, the list of its available sense embeddings.

    Sense embeddings are read from ``embed_filepath``. The first line is
    skipped (presumably a word2vec-style header -- confirm). Every other
    non-blank line is ``<synset id> <float> <float> ...``. With ``syn_PCA``
    they are first transformed by ``LSIM.get_PCA``.

    Args:
        args: namespace with ``exp_dir`` and ``word2syn_dir`` (used to load the
            cached language list, synset IDs and lemmas).
        wordlist: iterable of words.
        lang: language key into the lemma dictionary.
        embed_filepath: path of the sense-embedding text file.
        weight_type: only 'inverse' collects senses (no weighting is applied).
            With any other value (including the default 'freq') every word
            maps to None.
        syn_PCA: apply ``LSIM.get_PCA`` to the sense embeddings first.

    Returns:
        dict word -> list of sense vectors (in synset order, senses without an
        embedding skipped), or None when the word is unknown or has none.
    """
    languages = get_language_list(args)
    syn2id, id2syn = get_syn2id_and_id2syn(args, languages)
    lemmas = get_lemmas(args, languages, syn2id, id2syn)

    syn_embs = {}
    with open(embed_filepath, 'r') as embed_file:
        next(embed_file, None)  # skip the first (header) line
        for line in tqdm(embed_file):
            pieces = line.split()
            if not pieces:
                continue  # tolerate blank lines, e.g. a trailing newline
            syn_embs[int(pieces[0])] = np.asarray([float(x) for x in pieces[1:]])
    if syn_PCA:
        from LSIM import get_PCA
        print('Getting PCA on sense embeddings...')
        syn_embs = get_PCA(syn_embs)
        print('Done.')

    word_embs = {}
    for word in tqdm(wordlist):
        try:
            synsets = lemmas[lang][word]
        except KeyError:
            word_embs[word] = None
            continue
        sense_embs = []
        if weight_type == 'inverse':
            for syn in synsets:
                try:
                    sense_embs.append(syn_embs[syn])
                except LookupError:
                    # Not all senses have embeddings after node2vec training (too infrequent).
                    # LookupError also covers IndexError in case get_PCA returns a sequence -- confirm its type.
                    continue
        word_embs[word] = sense_embs if sense_embs else None
    return word_embs

def main(args):
    """Run the pipeline selected by the command-line arguments.

    Steps: (1) list the languages, (2) load/build the synset ID mappings,
    (3) load/build the per-language lemma -> synset-ID dictionaries, then
    (4) either compute the per-language edge dicts (``args.compute_edges``), or
    build the graph named by ``args.graph_type`` and write it as an edge list.

    Raises:
        ValueError: if a graph is requested and ``args.graph_type`` is unknown.
            The check runs before any expensive loading.
    """
    graph_builders = {
        'cross_colex_binary': build_cross_colex_binary_graph,
        'cross_colex_sum': build_cross_colex_sum_graph,
        'cross_colex_pairwise_product': build_cross_colex_pairwise_product_graph,
        # The _9 / _50 variants reuse the sum builder; filter_lang_edge_files narrows the languages.
        'cross_colex_sum_9': build_cross_colex_sum_graph,
        'cross_colex_sum_50': build_cross_colex_sum_graph,
        'colex_sum': build_colex_sum_graph,
    }
    builder = None
    if not args.compute_edges:
        builder = graph_builders.get(args.graph_type)
        if builder is None:
            raise ValueError("Unknown graph_type '" + str(args.graph_type) + "'; expected one of: "
                             + ", ".join(graph_builders))

    os.makedirs(args.exp_dir, exist_ok=True)
    # Step 1: Get list of languages to check overlap with
    languages = get_language_list(args)
    # Step 2: Create syn2id and id2syn dictionaries
    syn2id, id2syn = get_syn2id_and_id2syn(args, languages)
    # Step 3: Get all lemmas from each language and store as a dictionary with their synset ids
    lemmas = get_lemmas(args, languages, syn2id, id2syn)

    # Step 4: Get edges, or build and write the requested graph
    if args.compute_edges:
        get_all_edges_per_lang(args, lemmas)
        return
    all_edges = builder(args)
    write_edgelist_dict_to_edgelist(args, all_edges)





if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Build colexification graphs (edge lists) from BabelNet word2syn files')
    parser.add_argument('--exp_dir', type=str, default='colex_from_BabelNet',
                        help='output directory for cached mappings, per-language edges and edge lists')
    parser.add_argument('--word2syn_dir', type=str, default='wordsyns/word2syns_by_lang',
                        help='directory of <LANG>.json files mapping lemmas to synset lists')
    parser.add_argument('--all_edges_dicts', type=str, default='all_edges',
                        help='subdirectory of exp_dir holding the per-language edge pickles')
    parser.add_argument('--compute_edges', type=str2bool, default=False,
                        help='compute and dump per-language edges instead of building a graph; '
                             'leave false once that step has been run')
    parser.add_argument('--edgelist_savename', type=str, default='colex_sum_zh.edgelist',
                        help='file name (inside exp_dir) of the written edge list')
    parser.add_argument('--edge_type', type=str, default='all',
                        help='all or binary (not read by the code in this file)')
    parser.add_argument('--graph_type', type=str, default='colex_sum',
                        help='one of: cross_colex_binary, cross_colex_sum, cross_colex_sum_9, '
                             'cross_colex_sum_50, cross_colex_pairwise_product, colex_sum')
    parser.add_argument('--monolingual_lang', type=str, default='zh',
                        help='language used for the monolingual colex_sum graph')
    parser.add_argument('--num_task', type=int, default=16,
                        help='not read by the code in this file')
    parser.add_argument('--SAVE_DISK_SPACE', type=utils.str2bool, default=False,
                        help='intended to drop weight-1 edges when collecting edges '
                             '(not read by the code in this file)')
    parser.add_argument('--DEBUG_COUNT', type=int, default=500000000000000000,
                        help='set low to run quickly when checking the pipeline '
                             '(not read by the code in this file)')
    args = parser.parse_args()
    main(args)