import json
import os
import argparse
import matplotlib.pyplot as plt
import utils
import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
import fasttext.util
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from easynmt import EasyNMT
from sys import getsizeof
from tqdm import tqdm

def get_languages_from_synset_text_files(args):
    """Collect the distinct language codes that appear in the BabelNet synset text files.

    Every line of every file returned by ``collect_files(args.synset_text_files_dir)`` is
    split on ``"~"`` and its first field is taken as the language code. Scanning stops as
    soon as 500 distinct languages have been seen (the data reportedly has only 499).

    Side effect: creates ``<args.exp_dir>/<args.lemma_savedir>`` if it does not exist.

    Args:
        args: namespace providing ``exp_dir``, ``lemma_savedir`` and
            ``synset_text_files_dir``.

    Returns:
        list of language codes, in order of first appearance.
    """
    max_languages = 500  # early-exit cap; there appear to be only 499 languages
    lemma_savedir = os.path.join(args.exp_dir, args.lemma_savedir)
    if not os.path.isdir(lemma_savedir):
        os.mkdir(lemma_savedir)
    languages = []
    seen = set()  # O(1) membership test; ``languages`` keeps first-seen order
    for file in collect_files(args.synset_text_files_dir):
        # Multilingual data: read explicitly as UTF-8 and always close the file.
        with open(file, 'r', encoding='utf-8') as file1:
            for line in tqdm(file1.readlines()):
                lang = line.replace("\n", "").split("~")[0]
                if lang not in seen:
                    seen.add(lang)
                    languages.append(lang)
                if len(languages) >= max_languages:
                    break
        if len(languages) >= max_languages:
            break
    return languages

def get_lemmasyns(args):
    """Build per-language ``lemma -> synsets`` dictionaries and pickle one file per language.

    The server can't hold the entire dictionary in memory, so languages are processed in
    chunks (``utils.divide_index(len(languages), 50)``). The synset text files are re-read
    once per chunk. Languages that already have a pickle in the lemma save directory (as
    reported by ``get_language_list``) are skipped.

    Each input line has the form ``lang~synset~lemma``. The lemma may itself contain ``"~"``,
    so everything after the second ``"~"`` is re-joined. Lemmas are lowercased and empty
    lemmas are dropped. To save memory, a lemma's synsets are stored as one
    ``"_"``-delimited string rather than a list.

    Output: ``<args.exp_dir>/<args.lemma_savedir>/<lang>.pkl`` via ``utils.dump``.
    """
    lemma_savedir = os.path.join(args.exp_dir, args.lemma_savedir)
    if not os.path.isdir(lemma_savedir):
        os.mkdir(lemma_savedir)
    text_files = collect_files(args.synset_text_files_dir)
    languages = get_languages_from_synset_text_files(args)
    # Skip languages already done. get_language_list lowercases names, so upper-case them to
    # match the language codes as they appear in the text files.
    completed_langs = {x.upper() for x in get_language_list(args)}
    languages = list(set(languages).difference(completed_langs))
    for chunk in utils.divide_index(len(languages), 50):
        temp_languages = {languages[x] for x in chunk}  # set: tested once per input line
        lang_lemma_dict = {}
        for file in text_files:
            with open(file, 'r', encoding='utf-8') as file1:
                for line in tqdm(file1.readlines()):
                    line_info = line.replace("\n", "").split("~")
                    lang = line_info[0]
                    syn = line_info[1]
                    # Some lemmas contain '~', so re-join everything after the synset field.
                    lemma = "~".join(line_info[2:]).lower()
                    if lemma == "" or lang not in temp_languages:
                        continue
                    lemma2syns = lang_lemma_dict.setdefault(lang, {})
                    existing = lemma2syns.get(lemma)
                    if existing is None:
                        lemma2syns[lemma] = syn
                    elif syn not in existing.split("_"):
                        # Compare whole synset IDs, not substrings of the joined string.
                        lemma2syns[lemma] = existing + "_" + syn
        for lang, lemmasyns in tqdm(lang_lemma_dict.items()):
            utils.dump(lemmasyns, os.path.join(lemma_savedir, lang + ".pkl"))

def get_language_list(args):
    """Return the lowercased names of languages whose lemma file has already been saved.

    A language name is the last ``/``-separated component of each file path in
    ``<args.exp_dir>/<args.lemma_savedir>``, truncated at its first ``"."``.
    """
    lemma_dir = os.path.join(args.exp_dir, args.lemma_savedir)
    return [path.split('/')[-1].split('.')[0].lower() for path in collect_files(lemma_dir)]

def get_syn2id_and_id2syn(args, languages):
    """Return ``(syn2id, id2syn)`` mappings between synset IDs and consecutive integers.

    The mappings are cached as ``syn2id.pkl`` / ``id2syn.pkl`` in ``args.exp_dir``. If either
    cache file is missing, the mappings are rebuilt and both files are rewritten. To rebuild,
    every synset referenced by the per-language lemma dictionaries is collected, sorted, and
    numbered from 0. Those dictionaries live at ``<args.exp_dir>/<args.lemma_savedir>/<LANG>.pkl``
    and their values are ``"_"``-delimited synset strings.

    Args:
        args: namespace providing ``exp_dir`` and ``lemma_savedir``.
        languages: iterable of language codes (upper-cased to build file names). Only used
            when the cache has to be rebuilt.

    Returns:
        tuple ``(syn2id, id2syn)``.
    """
    id2syn_path = os.path.join(args.exp_dir, 'id2syn.pkl')
    syn2id_path = os.path.join(args.exp_dir, 'syn2id.pkl')
    if os.path.exists(id2syn_path) and os.path.exists(syn2id_path):
        id2syn = load(id2syn_path)
        syn2id = load(syn2id_path)
        return syn2id, id2syn

    # Accumulate into one set instead of rebuilding list(set(...)) after every language.
    syns = set()
    for lang in languages:
        word2syn_file = os.path.join(args.exp_dir, args.lemma_savedir, lang.upper() + '.pkl')
        word2syn = utils.load(word2syn_file)
        for synset_str in tqdm(word2syn.values()):
            syns.update(synset_str.split("_"))
    print(len(syns))
    # Sorting makes the ID assignment deterministic.
    id2syn = dict(enumerate(sorted(syns)))
    syn2id = {syn: i for i, syn in id2syn.items()}
    dump(id2syn, id2syn_path)
    dump(syn2id, syn2id_path)
    return syn2id, id2syn

def get_lemmas(args, languages, syn2id, id2syn):
    """Return ``{lang: {lemma: [synset_id, ...]}}`` for every language in ``languages``.

    This mirrors the structure of the WordNet version of the code. For each language the
    original ``lemma -> "_"-joined synsets`` pickle is loaded and every synset is replaced by
    its integer ID from ``syn2id``. The result is cached per language in
    ``<args.exp_dir>/<args.lemma_synIDdir>/<LANG>.pkl``, and an existing cache file is loaded
    instead of being recomputed. ``id2syn`` is accepted for interface compatibility but not
    used.
    """
    lemma_folder = os.path.join(args.exp_dir, args.lemma_synIDdir)
    os.makedirs(lemma_folder, exist_ok=True)
    lemmas = {}
    for lang in tqdm(languages):
        cached_path = os.path.join(lemma_folder, lang.upper() + ".pkl")
        if os.path.exists(cached_path):
            lemmas[lang] = load(cached_path)
            continue
        source_path = os.path.join(args.exp_dir, args.lemma_savedir, lang.upper() + '.pkl')
        word2syn = utils.load(source_path)
        lemma2ids = {}
        for lemma, synset_str in tqdm(word2syn.items()):
            synset_list = synset_str.split("_")
            if lemma == "":
                # An empty lemma leaked through from BabelNet with thousands of synsets;
                # report its synset count and skip it.
                print(len(synset_list))
                continue
            lemma2ids[lemma] = [syn2id[x] for x in synset_list]
        lemmas[lang] = lemma2ids
        dump(lemma2ids, cached_path)
    return lemmas

def filter_lang_edge_files(args, lang_edge_files):
    """Select the per-language edge files to use for ``args.graph_type``.

    - ``'cross_colex_sum_9'``: keep files whose lowercased stem is in
      ``config.eval_languages``.
    - ``'cross_colex_sum_50'``: keep the edge files of the 50 languages with the most entries
      in their word->synset file under ``config.directories.wordsyns``, ordered from most to
      fewest entries. Languages without a matching edge file are skipped.
    - anything else: return ``lang_edge_files`` unchanged.
    """
    config = utils.get_config()
    if args.graph_type == 'cross_colex_sum_9':
        languages = config.eval_languages
        return [file for file in lang_edge_files
                if file.split('/')[-1].split('.')[0].lower() in languages]
    if args.graph_type != 'cross_colex_sum_50':
        return lang_edge_files

    num_langs = 50
    # Key by lowercased stem so the lookup below is case-insensitive; the word->synset
    # files are named with upper-case codes (e.g. "EN.json").
    edge_file_dict = {file.split('/')[-1].split('.')[0].lower(): file for file in lang_edge_files}
    length_dict = {}
    for file in tqdm(collect_files(config.directories.wordsyns)):
        lang_name = file.split('/')[-1].split('.')[0]
        edge_file = edge_file_dict.get(lang_name.lower())
        if edge_file is None:
            # No edge file for this language, so it can never be selected: skip loading it.
            continue
        length_dict[lang_name] = {'len': len(utils.load(file)), 'file': edge_file}
    # Largest lexicons first (stable sort keeps file order for ties), then keep the top num_langs.
    ranked = sorted(length_dict.values(), key=lambda datum: datum['len'], reverse=True)
    return [datum['file'] for datum in ranked[:num_langs]]

def synset_list_to_string_list(syns):
    """Return the ``_name`` attribute of every synset in ``syns``, preserving order."""
    return [synset._name for synset in syns]

def get_pairwise_edges(overlap):
    """Return every unordered pair of items in ``overlap`` as a ``[earlier, later]`` list.

    Each item is paired only with the items that follow it, so ``n`` items yield
    ``n * (n - 1) / 2`` pairs and no self-pairs. The first element of the pair advances
    slowest.
    """
    return [[first, second]
            for idx, first in enumerate(overlap)
            for second in overlap[idx + 1:]]

def collect_stats_on_graph(args):
    """Print and plot basic statistics of a weighted edge list.

    Reads ``<args.exp_dir>/<args.edgelist_path>``. Each non-blank line has the form
    ``"<node1> <node2> <weight>"`` with integer fields. For every node it accumulates the
    summed weight of its incident edges and the number of incident edges. It then prints how
    many nodes have a summed weight equal to their edge count (e.g. when every incident edge
    has weight 1). Finally it shows three log-log histograms: node weights, node edge counts,
    and edge weights.
    """
    edgelist_path = os.path.join(args.exp_dir, args.edgelist_path)
    with open(edgelist_path, 'r', encoding='utf-8') as file1:
        lines = file1.readlines()
    node_properties = {}
    edge_weights = []
    for line in tqdm(lines):
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing empty line)
        info = line.replace("\n", "").split(" ")
        node1, node2, weight = int(info[0]), int(info[1]), int(info[2])
        # A self-loop (node1 == node2) is deliberately counted twice, once per endpoint.
        for node in (node1, node2):
            props = node_properties.setdefault(node, {'weight': 0, 'count': 0})
            props['weight'] += weight
            props['count'] += 1
        edge_weights.append(weight)

    node_weights = [props['weight'] for props in node_properties.values()]
    node_counts = [props['count'] for props in node_properties.values()]
    total_num_nodes = len(node_properties)
    total_matching_count_weight_nodes = sum(
        1 for props in node_properties.values() if props['weight'] == props['count'])
    print("Of the total " + str(total_num_nodes) + " nodes, " + str(total_matching_count_weight_nodes) + " nodes had same weight and count.")

    bins = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90,
            100, 200, 300, 400, 500, 600, 700, 800, 900,
            1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000,
            10000]
    for values in (node_weights, node_counts, edge_weights):
        plt.hist(values, bins=bins)
        plt.yscale('log')
        plt.xscale('log')
        plt.show()

def load_edges(filepath, id2syn):
    """Load a weighted edge list into ``{"<syn1>_<syn2>": weight}``.

    Each non-blank line of ``filepath`` has the form ``"<id1> <id2> <weight>"`` with integer
    fields. The IDs are mapped to synset names through ``id2syn``. Node order is kept as
    written (edges are not canonicalised), and if a pair appears twice the later weight wins.
    """
    with open(filepath, 'r', encoding='utf-8') as file1:
        lines = file1.readlines()
    edges = {}
    for line in tqdm(lines):
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing empty line)
        info = line.replace("\n", "").split(" ")
        edge = id2syn[int(info[0])] + "_" + id2syn[int(info[1])]
        edges[edge] = int(info[2])
    return edges

def get_syn2gloss(args):
    """Map each BabelNet synset ID to its gloss.

    Reads every file under ``config.directories.BabelNetSynsetGlosses``. Each line is split
    on ``"~"``: field 0 is the synset and field 1 is the gloss. Later lines overwrite earlier
    ones for the same synset. ``args`` is unused.

    NOTE(review): a gloss that itself contains "~" is truncated at its first "~"; confirm
    against the file format whether ``split("~", 1)`` is intended.
    """
    config = get_config()
    syn2gloss = {}
    for file in tqdm(collect_files(config.directories.BabelNetSynsetGlosses)):
        with open(file, 'r', encoding='utf-8') as file1:
            for line in file1.readlines():
                fields = line.replace("\n", "").split("~")
                syn2gloss[fields[0]] = fields[1]
    return syn2gloss

def save_edges_with_glosses(args):
    """Write the edge list with synset IDs replaced by glosses, heaviest edges first.

    Reads ``<args.exp_dir>/<args.edgelist_path>``, whose lines have the form
    ``"<id1> <id2> <weight>"``. IDs are mapped to synsets via ``id2syn`` and synsets to
    glosses via ``get_syn2gloss``. The edges are sorted by weight (descending, stable) and
    written as lines ``"<weight>~<gloss1>~<gloss2>"`` to
    ``<args.exp_dir>/<args.edgelist_path>_glosses``.

    NOTE(review): if the syn2id/id2syn cache is missing, ``get_syn2id_and_id2syn`` rebuilds it
    from English ("en") lemmas only, which may not cover every ID in the edge list.
    """
    dump_path = os.path.join(args.exp_dir, args.edgelist_path + "_glosses")
    syn2gloss = get_syn2gloss(args)
    _, id2syn = get_syn2id_and_id2syn(args, ["en"])
    edgelist_path = os.path.join(args.exp_dir, args.edgelist_path)
    with open(edgelist_path, 'r', encoding='utf-8') as file1:
        edge_lines = file1.readlines()
    edges_with_glosses = []
    for line in tqdm(edge_lines):
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing empty line)
        info = line.replace("\n", "").split(" ")
        edges_with_glosses.append({'node1': syn2gloss[id2syn[int(info[0])]],
                                   'node2': syn2gloss[id2syn[int(info[1])]],
                                   'weight': int(info[2])})
    edges_with_glosses.sort(key=lambda edge: edge['weight'], reverse=True)
    lines = [str(edge['weight']) + "~" + edge['node1'] + "~" + edge['node2']
             for edge in edges_with_glosses]
    utils.write_file_from_list(lines, dump_path)

def compare_edges(graph1, graph2, id2syn1, id2syn2):
    """Compare the edges shared by two edge-list files and print summary counts.

    ``graph1`` and ``graph2`` are edge-list file paths. They are loaded with ``load_edges``
    using ``id2syn1`` and ``id2syn2`` respectively, so both are keyed by synset names even if
    their integer IDs differ. Prints the edge count of each graph and the number of edges
    present in both.

    Returns:
        dict mapping every shared edge to the absolute difference of its two weights.
    """
    edges1 = load_edges(graph1, id2syn1)
    edges2 = load_edges(graph2, id2syn2)
    weight_diffs = {}
    for edge, weight in tqdm(edges1.items()):
        if edge in edges2:
            weight_diffs[edge] = int(abs(edges2[edge] - weight))
    print(len(edges1))
    print(len(edges2))
    print(len(weight_diffs))
    return weight_diffs

def get_num_lexemes_per_lang(args):
    """Return the total number of lexemes summed over all language edge files.

    Loads ``<args.exp_dir>/lang_limitation_edge_files_lexeme_counts.pkl``, a mapping from
    edge-file path to lexeme count that is produced by ``random_sample_langs`` in
    colexAllBabelNet.py, and sums its values. Used to check whether the number of lemmas
    grows roughly linearly with the number of languages.
    """
    counts_path = os.path.join(args.exp_dir, 'lang_limitation_edge_files_lexeme_counts.pkl')
    lang_edge_files_lexeme_counts = load(counts_path)
    return sum(lang_edge_files_lexeme_counts.values())

def sorted_lang_list_by_num_lexemes(args):
    """Write ``langs_sorted_by_lexeme_counts.txt`` with one ``"<LANG>,<count>"`` line per language.

    The counts come from ``lang_limitation_edge_files_lexeme_counts.pkl`` in ``args.exp_dir``,
    whose keys are edge-file paths. No sorting is done here. Entries are written in the
    mapping's stored order, which is already sorted because it was produced by
    ``random_sample_langs`` in colexAllBabelNet.py.
    """
    counts = load(os.path.join(args.exp_dir, 'lang_limitation_edge_files_lexeme_counts.pkl'))
    rows = [path.split("/")[-1].split(".")[0].upper() + "," + str(count)
            for path, count in counts.items()]
    utils.write_file_from_list(rows, os.path.join(args.exp_dir, 'langs_sorted_by_lexeme_counts.txt'))

def lang_inventory_graph_analysis(args):
    """Compare graphs/embeddings built from 5, 10, 20 and 50 languages with the all-language ones.

    Settings: "five", "ten", "twenty", "fifty" and "all" (the reference). Three measures:

    1. Per-word synset coverage. For every language in ``args.languages`` ("_"-separated)
       and every word with a maxsim embedding file in the all-language setting, compute
       ``len(setting pickle) / len(all-language pickle)``. These pickles presumably hold one
       embedding per synset. The fraction is 0.0 when the setting has no file for the word.
       Fractions are averaged per language, then across languages.
    2. Edge coverage: each setting's edge count divided by the all-language edge count.
    3. Mean and standard deviation of ``|relative weight - all-language relative weight|``
       over each setting's edges. The relative weight is the edge weight divided by the
       number of languages used for that setting.

    Results are printed. When ``args.results_save_path`` is non-empty, they are also dumped to
    ``config.directories.results/<args.results_save_path>``. The input paths below are
    relative to the current working directory.
    """
    settings = ["five", "ten", "twenty", "fifty", "all"]
    file_dict = {"five": "colex_from_AllBabelNet_Concepts/cross_colex_binary_limited_langs.edgelist_5",
                 "ten": "colex_from_AllBabelNet_Concepts/cross_colex_binary_limited_langs.edgelist_10",
                 "twenty": "colex_from_AllBabelNet_Concepts/cross_colex_binary_limited_langs.edgelist_20",
                 "fifty": "colex_from_AllBabelNet_Concepts/cross_colex_binary_limited_langs.edgelist_50",
                 "all": "colex_from_AllBabelNet_Concepts/cross_colex_binary.edgelist"}
    maxsim_embs = {"five": "word_vectors/cross_colex_binary_limited_langs5_5_maxsim",
                   "ten": "word_vectors/cross_colex_binary_limited_langs10_5_maxsim",
                   "twenty": "word_vectors/cross_colex_binary_limited_langs20_5_maxsim",
                   "fifty": "word_vectors/cross_colex_binary_limited_langs50_5_maxsim",
                   "all": "word_vectors/cross_colex_binary_5_maxsim"}
    # Number of languages behind each setting's graph; 499 is the full BabelNet language count.
    dividing_factor = {"five": 5.0, "ten": 10, "twenty": 20, "fifty": 50.0, "all": 499.0}

    # 1) Per-word synset coverage relative to the all-language maxsim embeddings.
    languages = args.languages.split("_")
    per_word_synset_overlap = {}
    for lang in tqdm(languages):
        # The reference word list does not depend on the method, so list it once per language.
        all_langs_embs = collect_files(os.path.join(maxsim_embs["all"], lang))
        per_word_synset_overlap[lang] = {}
        for method, directory in maxsim_embs.items():
            synset_overlap_percentages = []
            for file in all_langs_embs:
                word = file.split("/")[-1].split(".")[0]
                method_word_file = os.path.join(directory, lang, word + ".pkl")
                if os.path.exists(method_word_file):
                    method_syns = utils.load(method_word_file)
                    all_langs_syns = utils.load(file)
                    overlap_frac = float(len(method_syns)) / float(len(all_langs_syns))
                else:
                    overlap_frac = 0.0
                synset_overlap_percentages.append(overlap_frac)
            per_word_synset_overlap[lang][method] = np.mean(np.asarray(synset_overlap_percentages))
    # Average over languages for each method.
    perc_overlap_method_scores = {}
    for method in maxsim_embs:
        lang_scores = [per_word_synset_overlap[lang][method] for lang in languages]
        method_score = np.mean(np.asarray(lang_scores))
        print(method + ": " + str(method_score))
        perc_overlap_method_scores[method] = method_score

    # Load every setting's edge list with weights normalised by its number of languages.
    setting_graphs = {}
    for setting in settings:
        divider = dividing_factor[setting]
        with open(file_dict[setting], 'r', encoding='utf-8') as file1:
            lines = file1.readlines()
        graph = {}
        for line in tqdm(lines):
            pieces = line.split(" ")
            graph[pieces[0] + "_" + pieces[1]] = float(pieces[2]) / divider
        setting_graphs[setting] = graph

    # 2) Edge coverage relative to the all-language graph (printed, not plotted).
    synset_overlap_percentage = {}
    for setting in settings:
        perc = len(setting_graphs[setting]) / len(setting_graphs["all"])
        synset_overlap_percentage[setting] = perc
        print(perc)

    # 3) Differences in relative edge weights.
    # NOTE(review): assumes every edge of a subset graph also exists in the all-language graph
    # (otherwise KeyError); confirm this holds for the cross-lingual threshold used.
    mean_std_edge_diffs = {}
    for setting in settings:
        all_graph = setting_graphs["all"]
        diffs = np.asarray([np.abs(relative_weight - all_graph[edge])
                            for edge, relative_weight in setting_graphs[setting].items()])
        mean = np.mean(diffs)
        std = np.std(diffs)
        mean_std_edge_diffs[setting] = {"mean": mean, "std": std}
        print("Mean: " + str(mean) + ", Std:" + str(std))

    results = {"syn_overlap": synset_overlap_percentage,
               "edge_diffs": mean_std_edge_diffs,
               "perc_overlap_syn_per_word": perc_overlap_method_scores}
    if args.results_save_path != '':
        config = get_config()
        os.makedirs(config.directories.results, exist_ok=True)
        save_path = os.path.join(config.directories.results, args.results_save_path)
        dump(results, save_path)

def build_cross_colex_binary_graph(args):
    """Build the cross-lingual colexification graph, weighting each edge by its language count.

    Each per-language edge file under ``<args.exp_dir>/<args.all_edges_dicts>`` (optionally
    restricted by ``filter_lang_edge_files``) contributes 1 for every edge it contains.
    Edges found in at least ``args.cross_lingual_threshold`` languages are kept, which drops
    edges that occur in only one language when the threshold is 2. Summary counts are
    printed.

    Returns:
        dict mapping edge key ``"<node1>_<node2>"`` to the number of languages containing it.
    """
    from collections import Counter

    lang_edge_files = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    # Depending on configuration, we may use a subset of languages.
    lang_edge_files = filter_lang_edge_files(args, lang_edge_files)

    per_lang_binary_edges = Counter()
    for lang_file in tqdm(lang_edge_files):
        # Iterate explicitly: if utils.load returns a mapping, Counter.update(mapping) would
        # add its values instead of counting each edge once.
        for edge in utils.load(lang_file):
            per_lang_binary_edges[edge] += 1

    cross_lingual_edges = {edge: value for edge, value in per_lang_binary_edges.items()
                           if value >= args.cross_lingual_threshold}
    unique_nodes = set()
    for edge in cross_lingual_edges:
        nodes = edge.split("_")
        unique_nodes.add(nodes[0])
        unique_nodes.add(nodes[1])
    max_occurances = max(cross_lingual_edges.values(), default=0)

    print(str(len(per_lang_binary_edges)) + " total edges.")
    print(str(len(cross_lingual_edges)) + " cross-lingual edges.")
    print(str(len(unique_nodes)) + " unique nodes in cross-lingual edges.")
    print("Maximum number of languages an edge occured in was " + str(max_occurances) + ".")
    return cross_lingual_edges

def cross_lang_edges_per_lang(args):
    """Histogram how many cross-lingual edges each language contains.

    Builds the cross-lingual graph with ``build_cross_colex_binary_graph``. Then, for every
    per-language edge file under ``<args.exp_dir>/<args.all_edges_dicts>``, it counts that
    language's edges that are in the graph. Note that all files are used here, not the
    filtered subset used to build the graph. Prints the largest and smallest counts and
    saves a log-log histogram to ``cross_lingual_synset_pairs.pdf`` in the working
    directory.
    """
    binary_graph = build_cross_colex_binary_graph(args)
    lang_edge_files = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    per_lang_cross_ling_edge_count = {}
    for lang_file in tqdm(lang_edge_files):
        lang_edges = utils.load(lang_file)
        per_lang_cross_ling_edge_count[lang_file] = sum(1 for edge in lang_edges if edge in binary_graph)

    values = sorted(per_lang_cross_ling_edge_count.values(), reverse=True)
    if not values:
        print("No language edge files found; nothing to plot.")
        return
    print("Max number of cross-lingual edges in a language: " + str(values[0]))
    print("Min number of cross-lingual edges in a language: " + str(values[-1]))

    bins = [1000, 10000, 100000, 1000000, 5000000]
    # Draw on a fresh figure so earlier plots don't bleed into the saved PDF, then free it.
    fig = plt.figure()
    plt.hist(values, bins=bins)
    plt.yscale('log')
    plt.xscale('log')
    plt.ylabel("Number of languages")
    plt.xlabel("Number of cross-lingual synset pairs")
    plt.savefig("cross_lingual_synset_pairs.pdf")
    plt.close(fig)

def check_overlap_words(args, unique_words, lang):
    """Find the evaluation words that each method has no representation for.

    For every method in ``args.method_word_intersections`` ("~"-separated):

    - ``"BabelNet"``: a word is covered if it is a key of
      ``config.directories.wordsyns/<LANG>.json``.
    - any other method: a word is covered if
      ``config.directories.word_vectors/<method>/<lang>`` contains a file whose name, up to
      its first ".", equals the word.

    Returns:
        dict mapping method -> list of words from ``unique_words`` that are not covered.
        List order is unspecified because it comes from a set difference.
    """
    config = get_config()
    missing_word_dict = {}
    for method in args.method_word_intersections.split("~"):
        if method == "BabelNet":
            word2syn_file = os.path.join(config.directories.wordsyns, lang.upper() + ".json")
            # JSON is UTF-8 by specification; don't depend on the platform default encoding.
            with open(word2syn_file, encoding='utf-8') as json_file:
                word2syn = json.load(json_file)
            # Exact string match against the BabelNet dictionary keys.
            method_words = [word for word in unique_words if word in word2syn]
        else:
            word_files = utils.collect_files(os.path.join(config.directories.word_vectors, method, lang))
            method_words = [x.split("/")[-1].split(".")[0] for x in word_files]
        missing_word_dict[method] = list(set(unique_words).difference(set(method_words)))
    return missing_word_dict

def check_overlap_words_num_synsets(args, unique_words, lang):
    """Find uncovered evaluation words per method, resolving BabelNet coverage once up front.

    ``config.directories.wordsyns/<LANG>.json`` is loaded once. The words of ``unique_words``
    that are keys of it form the BabelNet vocabulary. For every method in
    ``args.method_word_intersections`` ("~"-separated):

    - ``"BabelNet"``: covered words are that BabelNet vocabulary.
    - any other method: covered words are the file names, up to the first ".", in
      ``config.directories.word_vectors/<method>/<lang>``.

    NOTE: missing words can have more than one synset because the graphs are built from
    LEXEMES, while evaluation words are looked up in BabelNet directly. Either BabelNet does
    not map some words to lexemes, or the ProNE embedding method does not produce an
    embedding for every node — this still needs investigating.

    Returns:
        dict mapping method -> list of words from ``unique_words`` that are not covered.
        List order is unspecified because it comes from a set difference.
    """
    config = get_config()
    word2syn_file = os.path.join(config.directories.wordsyns, lang.upper() + ".json")
    # JSON is UTF-8 by specification; don't depend on the platform default encoding.
    with open(word2syn_file, encoding='utf-8') as json_file:
        word2syn = json.load(json_file)
    BabelNet_words = [word for word in unique_words if word in word2syn]

    missing_word_dict = {}
    for method in args.method_word_intersections.split("~"):
        if method == "BabelNet":
            method_words = BabelNet_words
        else:
            word_files = utils.collect_files(os.path.join(config.directories.word_vectors, method, lang))
            method_words = [x.split("/")[-1].split(".")[0] for x in word_files]
        missing_word_dict[method] = list(set(unique_words).difference(set(method_words)))
    return missing_word_dict

def get_OOV_words_table(args):
    """Print per-language out-of-vocabulary counts and the body of a LaTeX table of them.

    For every language in ``args.languages`` ("_"-separated), the Multi-SimLex evaluation
    words are checked with ``check_overlap_words``, and one summary line with the number of
    missing words per method is printed. Then the LaTeX rows are printed: a header of
    language codes, followed by one row per method in a fixed order. A method that was not
    requested in ``args.method_word_intersections`` is shown as "-".
    """
    languages = args.languages.split("_")
    language_stats = {}
    for lang in languages:
        _, unique_words = get_multisimlex(lang)
        missing_word_dict = check_overlap_words(args, unique_words, lang)
        # NOTE: ARES used the old wordsyns files, where BabelNet was queried with the eval
        # words included, so BabelNet could resolve some words better than exact matching.
        # COLEX uses a direct dictionary look-up that requires an exact string match, which
        # is why ARES has fewer missing words. Double-check that ARES's missing words are a
        # subset of COLEX's before reporting BabelNet misses with this approach.
        # Keep counts in a separate dict instead of overwriting the callee's result in place.
        counts = {method: len(words) for method, words in missing_word_dict.items()}
        language_stats[lang] = counts
        print(lang.upper() + ": " + "".join(method + ": " + str(n) + " " for method, n in counts.items()))

    # Print the LaTeX table.
    method_order = ["fasttext", "BERT", "BabelNet", "ARES", "colex_mono",
                    "colex_all_maxsim", "colex_all"]
    method_text = {"fasttext": "fastText", "BERT": "BERT",
                   "ARES": "ARES", "colex_mono": "COLEX\\textsubscript{mono}",
                   "colex_all_maxsim": "COLEX\\textsubscript{maxsim}",
                   "colex_all": "COLEX\\textsubscript{cross}",
                   "BabelNet": "BabelNet"}
    print("".join(" & " + lang.upper() for lang in languages) + " \\\\")
    for method in method_order:
        cells = "".join(" & " + str(language_stats[lang].get(method, "-")) for lang in languages)
        print(method_text[method] + cells + " \\\\")

def get_OOV_words_table_checking_numbers(args):
    """Like ``get_OOV_words_table``, but checks coverage with ``check_overlap_words_num_synsets``.

    For every language in ``args.languages`` ("_"-separated), one summary line with the
    number of missing Multi-SimLex words per method is printed. Then the LaTeX rows are
    printed: a header of language codes, followed by one row per method in a fixed order.
    A method that was not requested in ``args.method_word_intersections`` is shown as "-".
    """
    languages = args.languages.split("_")
    language_stats = {}
    for lang in languages:
        _, unique_words = get_multisimlex(lang)
        missing_word_dict = check_overlap_words_num_synsets(args, unique_words, lang)
        # NOTE: ARES used the old wordsyns files, where BabelNet was queried with the eval
        # words included, so BabelNet could resolve some words better than exact matching.
        # COLEX uses a direct dictionary look-up that requires an exact string match, which
        # is why ARES has fewer missing words. Double-check that ARES's missing words are a
        # subset of COLEX's before reporting BabelNet misses with this approach.
        # Keep counts in a separate dict instead of overwriting the callee's result in place.
        counts = {method: len(words) for method, words in missing_word_dict.items()}
        language_stats[lang] = counts
        print(lang.upper() + ": " + "".join(method + ": " + str(n) + " " for method, n in counts.items()))

    # Print the LaTeX table.
    method_order = ["fasttext", "BERT", "BabelNet", "ARES", "colex_mono",
                    "colex_all_maxsim", "colex_all"]
    method_text = {"fasttext": "fastText", "BERT": "BERT",
                   "ARES": "ARES", "colex_mono": "COLEX\\textsubscript{mono}",
                   "colex_all_maxsim": "COLEX\\textsubscript{maxsim}",
                   "colex_all": "COLEX\\textsubscript{cross}",
                   "BabelNet": "BabelNet"}
    print("".join(" & " + lang.upper() for lang in languages) + " \\\\")
    for method in method_order:
        cells = "".join(" & " + str(language_stats[lang].get(method, "-")) for lang in languages)
        print(method_text[method] + cells + " \\\\")

def simscore_lists(args):
    """Write per-method, per-language CSV files comparing model and gold similarity scores.

    For every method in ``files`` the pickled SIMSCORE results are loaded. For each
    language in ``args.languages`` (underscore-separated), every word pair gets:

    * ``ground_truth`` / ``our_score`` rescaled with fixed ranges. Gold scores are
      assumed to lie in [0, 6] and model scores in [-1, 1], so both map to [0, 1].
    * ``score_diff``: absolute difference of the rescaled scores.
    * ``ground_truth_rank`` / ``our_score_rank``: ranks of the raw scores, computed
      with ``scipy.stats.rankdata`` using ``args.rank_method``.
    * ``rank_diff``: absolute difference of the two ranks.

    Rows are sorted by ``rank_diff`` (ascending) and written to
    ``<config.directories.simscore>/<method>_<lang>.csv``. Both the header and the
    data rows use ``args.simscore_delimiter`` as the column separator.

    Note: the loaded score dictionaries are modified in place.
    """
    delim = args.simscore_delimiter
    config = get_config()
    save_dir = config.directories.simscore
    os.makedirs(save_dir, exist_ok=True)
    files = {"fasttext": "results/SIMSCORE~LSIM~fasttext.pkl",
             "BERT": "results/SIMSCORE~LSIM~BERT.pkl",
             "ARES": "results/SIMSCORE~LSIM~ARES.pkl",
             "colex_all": "results/SIMSCORE~LSIM~colex_all.pkl",
             "colex_all_maxsim": "results/SIMSCORE~LSIM~colex_all_maxsim.pkl",
             "colex_mono": "results/SIMSCORE~LSIM~colex_mono.pkl",
             "C+F+B": "results/SIMSCORE~LSIM~colex_all~fasttext~BERT.pkl",
             "C+F": "results/SIMSCORE~LSIM~colex_all~fasttext.pkl"}

    # Fixed normalisation ranges: gold scores in [0, 6], model scores in [-1, 1].
    min_ground_truth = 0.0
    range_ground_truth = 6.0
    min_our_score = -1.0
    range_our_score = 2.0

    # Per-pair values written after word1/word2, in output order.
    value_columns = ["rank_diff", "ground_truth_rank", "our_score_rank",
                     "score_diff", "ground_truth", "our_score"]
    # The header must use the same delimiter as the data rows, or the columns misalign.
    header = delim.join(["word1", "word2", "rank_diff", "gt_rank", "our_rank",
                         "score_diff", "gt_score", "our_score"])

    languages = args.languages.split("_")
    for method, file in files.items():
        scores = utils.load(file)
        for lang in languages:
            pair_scores = scores[lang]

            # Keep the raw values for ranking, then overwrite them with normalised ones.
            gt_raw_values = []
            our_raw_values = []
            for pair_score in pair_scores.values():
                gt_raw_values.append(pair_score["ground_truth"])
                our_raw_values.append(pair_score["our_score"])
                pair_score["ground_truth"] = (pair_score["ground_truth"] - min_ground_truth) / range_ground_truth
                pair_score["our_score"] = (pair_score["our_score"] - min_our_score) / range_our_score
                pair_score["score_diff"] = np.abs(pair_score["ground_truth"] - pair_score["our_score"])

            ground_truth_ranks = scipy.stats.rankdata(np.asarray(gt_raw_values), method=args.rank_method)
            our_ranks = scipy.stats.rankdata(np.asarray(our_raw_values), method=args.rank_method)
            # The dict is iterated in the same order as above, so positions line up with the ranks.
            for pair_score, gt_rank, our_rank in zip(pair_scores.values(), ground_truth_ranks, our_ranks):
                pair_score["ground_truth_rank"] = gt_rank
                pair_score["our_score_rank"] = our_rank
                pair_score["rank_diff"] = np.abs(gt_rank - our_rank)

            scores[lang] = dict(sorted(pair_scores.items(), key=lambda item: item[1]["rank_diff"]))

            lines = [header]
            for word_pair, pair_score in scores[lang].items():
                words = word_pair.split("_")
                fields = [words[0], words[1]] + [str(round(pair_score[col], 2)) for col in value_columns]
                lines.append(delim.join(fields))
            save_path = os.path.join(save_dir, method + "_" + lang + ".csv")
            utils.write_file_from_list(lines, save_path)

def simscore_lists_synset_graph_analysis(args):
    """Write per-language CSVs of Multi-SimLex pairs annotated with the synsets of both words.

    For the ``colex_mono`` and ``colex_all`` SIMSCORE results, and for every language in
    ``args.languages`` (underscore-separated), each word pair gets the following values:

    * ``ground_truth`` / ``our_score`` are rescaled with fixed ranges. Gold scores are
      assumed to lie in [0, 6] and model scores in [-1, 1]. ``score_diff`` is their
      absolute difference.
    * ``ground_truth_rank`` / ``our_score_rank`` are ``scipy.stats.rankdata`` ranks of
      the raw scores (using ``args.rank_method``), and ``rank_diff`` is their absolute
      difference.

    Pairs are sorted by ``rank_diff``. For each pair the CSV contains:

    * a header row;
    * the pair's values;
    * a second header row;
    * one row per synset index. For each word it gives the synset's gloss (commas and
      the delimiter removed, truncated to ``max_gloss_len`` characters) and the
      synset's number of edges in the colexification graph. The graph is
      ``<exp_dir>/colex_all.edgelist`` or ``<exp_dir>/colex_<lang>.edgelist``.

    A synset is skipped if it is missing from ``syn2id``/``id2syn``, from the graph, or
    from the gloss table. Output goes to
    ``<config.directories.simscore_synsets>/<method>_<lang>.csv``.

    Note: the loaded score dictionaries are modified in place.
    """
    from itertools import zip_longest
    import colexAllBabelNet

    max_gloss_len = 120  # glosses are truncated to this many characters
    delim = args.simscore_delimiter
    config = get_config()
    save_dir = config.directories.simscore_synsets
    os.makedirs(save_dir, exist_ok=True)
    # Only the monolingual and the all-language graphs need to be analysed.
    files = {"colex_mono": "results/SIMSCORE~LSIM~colex_mono.pkl",
             "colex_all": "results/SIMSCORE~LSIM~colex_all.pkl"}

    # Fixed normalisation ranges: gold scores in [0, 6], model scores in [-1, 1].
    min_ground_truth = 0.0
    range_ground_truth = 6.0
    min_our_score = -1.0
    range_our_score = 2.0

    value_columns = ["rank_diff", "ground_truth_rank", "our_score_rank",
                     "score_diff", "ground_truth", "our_score"]
    # Headers use the same delimiter as the data rows so the columns line up.
    pair_header = delim.join(["word1", "word2", "rank_diff", "gt_rank", "our_rank",
                              "score_diff", "gt_score", "our_score"])
    synset_header = delim.join(["word1_gloss", "word1_synset_connections",
                                "word2_gloss", "word2_synset_connections"])

    syn2gloss = get_syn2gloss(args)
    languages = args.languages.split("_")
    concept_syns = colexAllBabelNet.get_concept_syns(args)
    syn2id, id2syn = colexAllBabelNet.get_syn2id_and_id2syn(args, languages, concept_syns)

    def describe_synsets(syns, node_counts):
        """Return ``[gloss, edge_count]`` for each synset in *syns* that can be resolved."""
        described = []
        for syn in syns:
            try:
                node_id = syn2id[syn]
                connection_count = node_counts[str(node_id)]
                syn_gloss = syn2gloss[id2syn[node_id]]
            except (KeyError, IndexError):
                # Not in the id maps, not a node of this graph, or no gloss: skip it.
                continue
            described.append([syn_gloss, connection_count])
        return described

    def clean_gloss(gloss):
        """Remove characters that would split the gloss across columns, then truncate it."""
        return gloss.replace(",", "").replace(delim, "")[0:max_gloss_len]

    for method, file in files.items():
        scores = utils.load(file)
        for lang in languages:
            word2syn_file = os.path.join(config.directories.wordsyns, lang.upper() + ".json")
            with open(word2syn_file) as json_file:
                word2syn = json.load(json_file)

            # Build the file name directly so a "LANG" substring in exp_dir is never touched.
            if method == "colex_mono":
                edgelist_path = os.path.join(args.exp_dir, "colex_" + lang + ".edgelist")
            else:
                edgelist_path = os.path.join(args.exp_dir, "colex_all.edgelist")
            edgelist = utils.read_edgelist(edgelist_path)

            # Number of edges each node takes part in; edge keys are "<node1>_<node2>".
            node_counts = {}
            for edge in tqdm(edgelist):
                nodes = edge.split("_")
                node_counts[nodes[0]] = node_counts.get(nodes[0], 0) + 1
                node_counts[nodes[1]] = node_counts.get(nodes[1], 0) + 1
            del edgelist  # release the (possibly large) graph before building the output

            pair_scores = scores[lang]
            gt_raw_values = []
            our_raw_values = []
            for word_pair, pair_score in pair_scores.items():
                words = word_pair.split("_")
                word1_syns_new = describe_synsets(word2syn[words[0]], node_counts)
                word2_syns_new = describe_synsets(word2syn[words[1]], node_counts)
                # Keep the raw values for ranking, then overwrite them with normalised ones.
                gt_raw_values.append(pair_score["ground_truth"])
                our_raw_values.append(pair_score["our_score"])
                pair_score["ground_truth"] = (pair_score["ground_truth"] - min_ground_truth) / range_ground_truth
                pair_score["our_score"] = (pair_score["our_score"] - min_our_score) / range_our_score
                pair_score["score_diff"] = np.abs(pair_score["ground_truth"] - pair_score["our_score"])
                pair_score["word_syn_data"] = {"word1": word1_syns_new, "word2": word2_syns_new}

            ground_truth_ranks = scipy.stats.rankdata(np.asarray(gt_raw_values), method=args.rank_method)
            our_ranks = scipy.stats.rankdata(np.asarray(our_raw_values), method=args.rank_method)
            # The dict is iterated in the same order as above, so positions line up with the ranks.
            for pair_score, gt_rank, our_rank in zip(pair_scores.values(), ground_truth_ranks, our_ranks):
                pair_score["ground_truth_rank"] = gt_rank
                pair_score["our_score_rank"] = our_rank
                pair_score["rank_diff"] = np.abs(gt_rank - our_rank)

            scores[lang] = dict(sorted(pair_scores.items(), key=lambda item: item[1]["rank_diff"]))

            lines = []
            for word_pair, pair_score in scores[lang].items():
                words = word_pair.split("_")
                lines.append(pair_header)
                lines.append(delim.join([words[0], words[1]]
                                        + [str(round(pair_score[col], 2)) for col in value_columns]))
                lines.append(synset_header)
                # Pad the shorter synset list with empty cells so both columns align.
                for (gloss1, count1), (gloss2, count2) in zip_longest(
                        pair_score["word_syn_data"]["word1"],
                        pair_score["word_syn_data"]["word2"],
                        fillvalue=("", "")):
                    lines.append(clean_gloss(gloss1) + delim + str(count1) + delim
                                 + clean_gloss(gloss2) + delim + str(count2))

            save_path = os.path.join(save_dir, method + "_" + lang + ".csv")
            utils.write_file_from_list(lines, save_path)

def umap_plot(args, max_lines=1000000):
    """Fit UMAP on node embeddings read from a text file and show a diagnostic plot.

    The file ``<args.exp_dir>/<args.edgelist_path>`` is read as a whitespace-separated
    embedding file. Its first line is skipped as a header, and every other non-blank line
    is ``<node id> <v_1> ... <v_d>``. The node id is ignored and the vector is kept.

    Args:
        args: parsed arguments. ``exp_dir`` and ``edgelist_path`` are used.
        max_lines: only the first ``max_lines`` lines of the file are read, header
            included. The default of 1,000,000 is a debugging cap. Pass ``None`` to read
            the whole file.

    The ``'vq'`` diagnostic of the fitted mapper is displayed with ``umap.plot.show``.
    """
    from itertools import islice
    import umap
    import umap.plot
    embed_filepath = os.path.join(args.exp_dir, args.edgelist_path)
    emb_list = []
    with open(embed_filepath, 'r') as embed_file:
        # Stream the file instead of loading every line into memory first.
        lines = embed_file if max_lines is None else islice(embed_file, max_lines)
        for i, line in tqdm(enumerate(lines)):
            if i == 0 or not line.strip():
                continue  # header line / blank line
            pieces = line.split()
            emb_list.append(np.asarray([float(x) for x in pieces[1:]]))
    print("Starting UMAP...")
    mapper = umap.UMAP().fit(emb_list)
    p = umap.plot.diagnostic(mapper, diagnostic_type='vq')
    print("Done.")
    umap.plot.show(p)

def synset_pairs_limited_langs(args):
    """Print the share of the full graph's synset pairs kept by each limited-language graph.

    Keys of ``graphs`` are numbers of languages. 5-200 map to
    ``<exp_dir>/colex_limited_langs.edgelist_<n>``, and 499 maps to the full
    ``<exp_dir>/colex_all.edgelist``. For each graph, the number of entries returned by
    ``utils.read_edgelist`` is divided by the full graph's count. The result is printed
    as ``"<num_langs>, <percent>"`` followed by a LaTeX-escaped percent sign.
    """
    graphs = {num_langs: os.path.join(args.exp_dir, "colex_limited_langs.edgelist_" + str(num_langs))
              for num_langs in (5, 10, 20, 50, 100, 200)}
    graphs[499] = os.path.join(args.exp_dir, "colex_all.edgelist")
    # Only the edge count is kept, so each graph can be freed right after it is read.
    num_edges = {num_langs: float(len(utils.read_edgelist(edgelist_path)))
                 for num_langs, edgelist_path in graphs.items()}
    total_edges = num_edges[499]
    for num_langs, n_edges in num_edges.items():
        frac = n_edges / total_edges
        # "\\%" is a literal backslash + percent (LaTeX); a bare "\%" is an invalid escape.
        print(str(num_langs) + ", " + str(round(100*frac, 2)) + "\\%")

def synset_pairs_limited_langs_version2(args):
    """Print the share of the full graph's synset pairs kept by each limited-language graph.

    This variant uses graphs built from 9, 20, 50, 100 and 200 languages
    (``<exp_dir>/colex_limited_langs.edgelist_<n>``). Key 499 refers to the full
    ``<exp_dir>/colex_all.edgelist``. For each graph, the number of entries returned by
    ``utils.read_edgelist`` is divided by the full graph's count. The result is printed
    as ``"<num_langs>, <percent>"`` followed by a LaTeX-escaped percent sign.
    """
    graphs = {num_langs: os.path.join(args.exp_dir, "colex_limited_langs.edgelist_" + str(num_langs))
              for num_langs in (9, 20, 50, 100, 200)}
    graphs[499] = os.path.join(args.exp_dir, "colex_all.edgelist")
    # Only the edge count is kept, so each graph can be freed right after it is read.
    num_edges = {num_langs: float(len(utils.read_edgelist(edgelist_path)))
                 for num_langs, edgelist_path in graphs.items()}
    total_edges = num_edges[499]
    for num_langs, n_edges in num_edges.items():
        frac = n_edges / total_edges
        # "\\%" is a literal backslash + percent (LaTeX); a bare "\%" is an invalid escape.
        print(str(num_langs) + ", " + str(round(100*frac, 2)) + "\\%")

def synset_pairs_vs_performance_mono(args):
    """Correlate monolingual graph size with monolingual LSIM performance.

    For each language in ``args.languages`` (underscore-separated), this counts the
    entries of ``<exp_dir>/colex_<lang>.edgelist``. It pairs that count with the
    ``spearman_rank_corr`` stored for the language in
    ``<results dir>/LSIM~colex_mono.pkl``. It prints one line per language, then the
    result of ``scipy.stats.pearsonr`` between the two series.
    """
    config = get_config()
    lsim_results = utils.load(os.path.join(config.directories.results, "LSIM~colex_mono.pkl"))
    langs = args.languages.split("_")
    edge_files = {lang: os.path.join(args.exp_dir, "colex_" + lang + ".edgelist") for lang in langs}

    pair_counts = {}
    for lang, path in edge_files.items():
        pair_counts[lang] = float(len(utils.read_edgelist(path)))

    pair_series = []
    score_series = []
    for lang, n_pairs in pair_counts.items():
        pair_series.append(n_pairs)
        spearman = lsim_results[lang]["spearman_rank_corr"]
        score_series.append(spearman)
        print(lang + " - num_pairs:" + str(n_pairs) + ", result: " + str(spearman))

    correlation = scipy.stats.pearsonr(np.asarray(score_series), np.asarray(pair_series))
    print(correlation)

def mean_polysemy_LSIM_words(args):
    """Print, per language, the mean number of synsets of the evaluated Multi-SimLex words.

    The word list comes from ``get_multisimlex(lang)``, filtered by
    ``LSIM.check_overlap_words``. Synsets per word are read from
    ``<config.directories.wordsyns>/<LANG>.json``. Every filtered word must be present
    in that file, otherwise a KeyError is raised.
    """
    config = get_config()
    langs = args.languages.split("_")
    import LSIM
    for lang in langs:
        wordsyns_path = os.path.join(config.directories.wordsyns, lang.upper() + ".json")
        with open(wordsyns_path) as handle:
            word2syn = json.load(handle)
        _, candidate_words = get_multisimlex(lang)
        covered_words = LSIM.check_overlap_words(args, candidate_words, lang)
        synset_counts = np.asarray([len(word2syn[word]) for word in covered_words])
        print(lang + ": " + str(np.mean(synset_counts)))

def get_BabelNet_data_stats(args):
    """Print a LaTeX table of BabelNet lexicon statistics for the languages in ``args.languages``.

    Per-language statistics are computed from ``<exp_dir>/<lemma_synIDdir>/<LANG>.pkl``,
    a lexeme -> synsets mapping, after dropping the lemmas returned by
    ``colexAllBabelNet.get_ignore_lemmas``. They are cached in
    ``<results dir>/BabelNet_data_stats.pkl``. If that cache exists it is loaded instead.
    The cache file name does not depend on ``args.languages``, so it must be deleted
    when the language list changes; otherwise the table lookups raise KeyError.

    Lexemes with exactly one synset count as monosemous and those with more than one as
    polysemous. Means over an empty set are nan.

    The number of colexified synset pairs per language is the number of entries in
    ``<exp_dir>/colex_<lang>.edgelist``. The table (a header row plus one row per entry
    of ``rows``) is printed to stdout.
    """
    from colexAllBabelNet import get_ignore_lemmas

    def summarize_lexicon(lexemes):
        """Compute synset statistics for one language's lexeme -> synsets mapping."""
        num_syns = []
        polysemous_sizes = []
        unique_syns = set()
        max_syns_per_lex = -1
        max_char = "EMPTY"
        for lex, syns in tqdm(lexemes.items()):
            n_syns = len(syns)
            num_syns.append(n_syns)
            if n_syns > 1:
                polysemous_sizes.append(n_syns)
            if n_syns > max_syns_per_lex:
                max_syns_per_lex = n_syns
                max_char = lex
            unique_syns.update(syns)
        return {
            "max_syns_per_lexeme": {"max_val": max_syns_per_lex, "max_char": max_char},
            # np.mean of an empty array warns and returns nan; return nan explicitly instead.
            "mean_syns_per_lexeme": np.mean(np.asarray(num_syns)) if num_syns else float("nan"),
            "num_synsets": len(unique_syns),
            "num_syns_full_sorted_list": sorted(num_syns, reverse=True),
            "num_monosemous_lexemes": num_syns.count(1),
            "num_polysemous_lexemes": len(polysemous_sizes),
            "mean_syns_per_polysemous_lexeme": (np.mean(np.asarray(polysemous_sizes))
                                                if polysemous_sizes else float("nan")),
        }

    config = get_config()
    languages = args.languages.split("_")
    results_save_path = os.path.join(config.directories.results, "BabelNet_data_stats.pkl")
    if not os.path.exists(results_save_path):
        stat_names = ["num_lexemes", "mean_syns_per_lexeme", "max_syns_per_lexeme", "num_synsets",
                      "num_syns_full_sorted_list", "num_monosemous_lexemes",
                      "num_polysemous_lexemes", "mean_syns_per_polysemous_lexeme"]
        BabelNet_stats = {name: {} for name in stat_names}
        remove_lems = get_ignore_lemmas()  # same list for every language, fetch once
        for lang in languages:
            lexemes_path = os.path.join(args.exp_dir, args.lemma_synIDdir, lang.upper() + ".pkl")
            lexemes = utils.load(lexemes_path)
            for lem in remove_lems:
                if lem in lexemes:
                    del lexemes[lem]
                    print("Removed " + lem + " from lemmas...")
            BabelNet_stats["num_lexemes"][lang] = len(lexemes)
            for name, value in summarize_lexicon(lexemes).items():
                BabelNet_stats[name][lang] = value
        utils.dump(BabelNet_stats, results_save_path)
    else:
        BabelNet_stats = utils.load(results_save_path)

    num_edges = {}
    for lang in languages:
        edgelist_path = os.path.join(args.exp_dir, "colex_" + lang + ".edgelist")
        num_edges[lang] = float(len(utils.read_edgelist(edgelist_path)))

    data_values = {"num_lexemes": BabelNet_stats["num_lexemes"],
                   "num_synsets": BabelNet_stats["num_synsets"],
                   "num_synset_pairs": num_edges,
                   "mean_syns_per_lexeme": BabelNet_stats["mean_syns_per_lexeme"],
                   "num_monosemous_lexemes": BabelNet_stats["num_monosemous_lexemes"],
                   "num_polysemous_lexemes": BabelNet_stats["num_polysemous_lexemes"],
                   "mean_syns_per_polysemous_lexeme": BabelNet_stats["mean_syns_per_polysemous_lexeme"]}
    # "\\#" is a literal backslash + hash (LaTeX-escaped #); a bare "\#" is an invalid escape.
    table_labels = {"num_lexemes": "Vocabulary size (\\# lexemes)",
                    "num_synsets": "\\# Synsets",
                    "num_synset_pairs": "\\# Colexified Synset Pairs",
                    "mean_syns_per_lexeme": "Mean \\# synsets per lexeme",
                    "num_monosemous_lexemes": "\\# Monosemous lexemes",
                    "num_polysemous_lexemes": "\\# Polysemous lexemes",
                    "mean_syns_per_polysemous_lexeme": "Mean \\# synsets per polysemous lexeme"}
    data_types = {"num_lexemes": "int",
                  "num_synsets": "int",
                  "num_synset_pairs": "int",
                  "mean_syns_per_lexeme": "float",
                  "num_monosemous_lexemes": "int",
                  "num_polysemous_lexemes": "int",
                  "mean_syns_per_polysemous_lexeme": "float"}
    rows = ["num_lexemes", "num_synsets", "num_synset_pairs",
            "mean_syns_per_lexeme", "num_polysemous_lexemes", "mean_syns_per_polysemous_lexeme"]

    print(" " + "".join(" & " + lang.upper() for lang in languages) + "\\\\")
    for row in rows:
        cells = [table_labels[row]]
        for lang in languages:
            value = data_values[row][lang]
            if data_types[row] == "int":
                value = int(value)
            elif data_types[row] == "float":
                value = round(value, 2)
            cells.append(str(value))
        print(" & ".join(cells) + " \\\\")

def get_BabelNet_data_stats_all_langs(args):
    """Print BabelNet lexicon statistics aggregated over every available language.

    Languages are discovered from the file names in ``<exp_dir>/<lemma_savedir>``
    (e.g. ``EN.pkl`` -> ``en``). Per-language statistics come from
    ``<exp_dir>/<lemma_synIDdir>/<LANG>.pkl``, a lexeme -> synsets mapping, after
    dropping the lemmas returned by ``colexAllBabelNet.get_ignore_lemmas``. They are
    cached in ``<results dir>/BabelNet_data_stats_all_langs.pkl``; if that file exists
    it is loaded instead.

    Printed totals:

    * sums of the per-language lexeme, synset and polysemous-lexeme counts;
    * the number of entries of ``<exp_dir>/colex_all.edgelist`` as the synset-pair count;
    * the per-language means, weighted by lexeme (respectively polysemous-lexeme) counts.
      Languages with zero weight are skipped so that their nan means cannot turn the
      total into nan.
    """
    from colexAllBabelNet import get_ignore_lemmas

    def summarize_lexicon(lexemes):
        """Compute synset statistics for one language's lexeme -> synsets mapping."""
        num_syns = []
        polysemous_sizes = []
        unique_syns = set()
        max_syns_per_lex = -1
        max_char = "EMPTY"
        for lex, syns in tqdm(lexemes.items()):
            n_syns = len(syns)
            num_syns.append(n_syns)
            if n_syns > 1:
                polysemous_sizes.append(n_syns)
            if n_syns > max_syns_per_lex:
                max_syns_per_lex = n_syns
                max_char = lex
            unique_syns.update(syns)
        return {
            "max_syns_per_lexeme": {"max_val": max_syns_per_lex, "max_char": max_char},
            # np.mean of an empty array warns and returns nan; return nan explicitly instead.
            "mean_syns_per_lexeme": np.mean(np.asarray(num_syns)) if num_syns else float("nan"),
            "num_synsets": len(unique_syns),
            "num_syns_full_sorted_list": sorted(num_syns, reverse=True),
            "num_monosemous_lexemes": num_syns.count(1),
            "num_polysemous_lexemes": len(polysemous_sizes),
            "mean_syns_per_polysemous_lexeme": (np.mean(np.asarray(polysemous_sizes))
                                                if polysemous_sizes else float("nan")),
        }

    def weighted_mean(means, weights, total_weight):
        """Average per-language means weighted by count, skipping zero-weight languages."""
        if total_weight == 0:
            return float("nan")
        result = 0.0
        for lang in languages:
            if weights[lang]:  # 0 * nan would be nan and poison the whole sum
                result = result + (weights[lang] / total_weight) * means[lang]
        return result

    config = get_config()
    files = collect_files(os.path.join(args.exp_dir, args.lemma_savedir))
    languages = [os.path.basename(file).split('.')[0].lower() for file in files]

    results_save_path = os.path.join(config.directories.results, "BabelNet_data_stats_all_langs.pkl")
    if not os.path.exists(results_save_path):
        stat_names = ["num_lexemes", "mean_syns_per_lexeme", "max_syns_per_lexeme", "num_synsets",
                      "num_syns_full_sorted_list", "num_monosemous_lexemes",
                      "num_polysemous_lexemes", "mean_syns_per_polysemous_lexeme"]
        BabelNet_stats = {name: {} for name in stat_names}
        remove_lems = get_ignore_lemmas()  # same list for every language, fetch once
        for lang in languages:
            lexemes_path = os.path.join(args.exp_dir, args.lemma_synIDdir, lang.upper() + ".pkl")
            lexemes = utils.load(lexemes_path)
            for lem in remove_lems:
                if lem in lexemes:
                    del lexemes[lem]
                    print("Removed " + lem + " from lemmas...")
            BabelNet_stats["num_lexemes"][lang] = len(lexemes)
            for name, value in summarize_lexicon(lexemes).items():
                BabelNet_stats[name][lang] = value
        utils.dump(BabelNet_stats, results_save_path)
    else:
        BabelNet_stats = utils.load(results_save_path)

    num_lexemes = BabelNet_stats["num_lexemes"]
    num_polysemous_lexemes = BabelNet_stats["num_polysemous_lexemes"]
    # Totals must be computed before the weighted means, which divide by them.
    total_num_lexemes = float(sum(num_lexemes[lang] for lang in languages))
    # NOTE: this sums per-language counts, so synsets shared by several languages are
    # counted once per language; it overestimates the number of distinct synsets.
    total_num_synsets = float(sum(BabelNet_stats["num_synsets"][lang] for lang in languages))
    total_num_polysemous_lexemes = float(sum(num_polysemous_lexemes[lang] for lang in languages))

    total_mean_syns_per_lexeme = weighted_mean(
        BabelNet_stats["mean_syns_per_lexeme"], num_lexemes, total_num_lexemes)
    total_mean_syns_per_polysemous_lexeme = weighted_mean(
        BabelNet_stats["mean_syns_per_polysemous_lexeme"], num_polysemous_lexemes,
        total_num_polysemous_lexemes)

    all_graph = utils.read_edgelist(os.path.join(args.exp_dir, "colex_all.edgelist"))
    total_num_synset_pairs = len(all_graph)

    print("Total number of lexemes: " + str(total_num_lexemes))
    print("Total number of synsets: " + str(total_num_synsets))
    print("Total number of synset pairs: " + str(total_num_synset_pairs))
    print("Mean number of synsets per lexeme: " + str(total_mean_syns_per_lexeme))
    print("Total number of polysemous lexemes: " + str(total_num_polysemous_lexemes))
    print("Mean number of synsets per polysemous lexeme: " + str(total_mean_syns_per_polysemous_lexeme))

def get_num_sillage_syns(args):
    """Print the synsets of the lemma "sillage" in the French and English lexicons.

    Loads ``<exp_dir>/<lemma_synIDdir>/FR.pkl`` and ``EN.pkl``. The French entry must
    exist (otherwise a KeyError is raised). A missing English entry is treated as an
    empty synset list. The function prints both lists, then their sizes.
    """
    lemma_dir = os.path.join(args.exp_dir, args.lemma_synIDdir)
    french_lexicon = utils.load(os.path.join(lemma_dir, "FR.pkl"))
    english_lexicon = utils.load(os.path.join(lemma_dir, "EN.pkl"))
    syns_FR = french_lexicon["sillage"]
    syns_EN = english_lexicon["sillage"] if "sillage" in english_lexicon else []
    for syns in (syns_FR, syns_EN):
        print(syns)
    for code, syns in (("FR", syns_FR), ("EN", syns_EN)):
        print(f'{len(syns)} synsets for "sillage" in {code}.')

def save_top_most_colexified_synset_pairs(args, num_keep=100000):
    """Copy the first ``num_keep`` lines of the edge-list file to ``<edge file>_top<num_keep>``.

    The edge list at ``<exp_dir>/<edgelist_path>`` is streamed, so only the kept lines
    are held in memory, and the file is closed afterwards. The lines keep their trailing
    newlines and are passed as-is to ``utils.write_file_from_list``. Presumably the edge
    list is sorted by colexification weight, so that the first lines are the "top" pairs
    (TODO: confirm).
    """
    from itertools import islice
    edges_path = os.path.join(args.exp_dir, args.edgelist_path)
    save_path = edges_path + "_top" + str(num_keep)
    with open(edges_path, 'r') as edges_file:
        keep_lines = list(islice(edges_file, num_keep))
    utils.write_file_from_list(keep_lines, save_path)

# def compare_node_count_to_ProNE_embedding_count(args):
#     """"""
#     import nodevectors
#     edgelist = read_edgelist(os.path.join(args.exp_dir, args.edgelist_path))
#     """Collect nodes"""
#     nodes = {}
#     for edge, weight in edgelist.items():
#         node1 = edge.split("_")[0]
#         node2 = edge.split("_")[1]
#         nodes[node1] = ""
#         nodes[node2] = ""
#     node_model = nodevectors.ProNE.load(embed_filepath)
#     stop = None

def main(args):
    """Run the statistics/analysis routine selected by ``args.stats_type``.

    ``args.exp_dir`` is created first (including any missing parent
    directories) if it does not already exist.

    Args:
        args: parsed CLI namespace; reads ``exp_dir`` and ``stats_type``. The
            selected routine receives ``args`` unchanged.

    Raises:
        ValueError: if ``args.stats_type`` does not name a known routine.
    """
    os.makedirs(args.exp_dir, exist_ok=True)

    stats_type = args.stats_type
    # An if/elif chain (rather than an eagerly built dict) keeps each handler
    # name from being looked up unless that mode is actually selected.
    if stats_type == 'collect_stats':
        collect_stats_on_graph(args)
    elif stats_type == 'save_edges_with_glosses':
        save_edges_with_glosses(args)
    elif stats_type == 'num_lexemes_per_lang':
        get_num_lexemes_per_lang(args)
    elif stats_type == 'lang_inventory_graph_analysis':
        lang_inventory_graph_analysis(args)
    elif stats_type == 'cross_ling_edges_per_lang':
        cross_lang_edges_per_lang(args)
    elif stats_type == 'OOV_words':
        get_OOV_words_table(args)
    elif stats_type == 'OOV_words_checking_numbers':
        get_OOV_words_table_checking_numbers(args)
    elif stats_type == 'SIMSCORE_lists':
        simscore_lists(args)
    elif stats_type == "UMAP":
        umap_plot(args)
    elif stats_type == 'SIMSCORE_lists_synset_graph_analysis':
        simscore_lists_synset_graph_analysis(args)
    elif stats_type == 'synset_pairs_limited_langs':
        synset_pairs_limited_langs(args)
    elif stats_type == 'synset_pairs_limited_langs_version2':
        synset_pairs_limited_langs_version2(args)
    elif stats_type == 'synset_pairs_vs_performance_mono':
        synset_pairs_vs_performance_mono(args)
    elif stats_type == 'mean_polysemy_LSIM_words':
        mean_polysemy_LSIM_words(args)
    elif stats_type == 'BabelNet_data_stats':
        get_BabelNet_data_stats(args)
    elif stats_type == 'BabelNet_data_stats_all_langs':
        get_BabelNet_data_stats_all_langs(args)
    elif stats_type == "sorted_lang_list_by_num_lexemes":
        sorted_lang_list_by_num_lexemes(args)
    elif stats_type == "sillage_syns":
        get_num_sillage_syns(args)
    elif stats_type == "most_colexified_synset_pairs":
        save_top_most_colexified_synset_pairs(args)
    else:
        # Previously a typo in --stats_type silently did nothing.
        raise ValueError(f"Unknown --stats_type: {stats_type!r}")





if __name__ == "__main__":
    # CLI entry point: parse arguments and dispatch via main(). The former
    # inline comments are exposed as --help text; defaults are unchanged.
    parser = argparse.ArgumentParser(description='Arguments to evaluate on LSIM task')
    parser.add_argument('--exp_dir', type=str, default='colex_from_AllBabelNet_Concepts',
                        help='Experiment directory (created if missing); many other paths are relative to it.')
    parser.add_argument('--synset_text_files_dir', type=str, default='synset_text_files',
                        help='Directory containing the synset text files.')
    parser.add_argument('--lemma_savedir', type=str, default='lemmas',
                        help='Subdirectory of --exp_dir for saved lemma files.')
    parser.add_argument('--lemma_synIDdir', type=str, default='lemmas_synID',
                        help='Subdirectory of --exp_dir with per-language lemma->synset pickles (e.g. FR.pkl).')
    parser.add_argument('--word2syn_dir', type=str, default='wordsyns/word2syns_by_lang')
    parser.add_argument('--all_edges_dicts', type=str, default='all_edges')
    parser.add_argument('--get_lemmasyns', type=str2bool, default=False,
                        help='Compute lemmasyns first (takes a while).')
    parser.add_argument('--edgelist_path', type=str, default='colex_lang_count.edgelist_glosses',
                        help='Edge list file, relative to --exp_dir.')
    parser.add_argument('--embed_filepath', type=str, default='',
                        help='Embedding file path, e.g. colex_he.zip.')
    parser.add_argument('--stats_type', type=str, default='synset_pairs_limited_langs_version2',
                        help='Analysis to run; see main() for the full list of modes '
                             '(e.g. collect_stats, save_edges_with_glosses, num_lexemes_per_lang, '
                             'lang_inventory_graph_analysis, cross_ling_edges_per_lang, OOV_words, '
                             'most_colexified_synset_pairs).')
    parser.add_argument('--edge_type', type=str, default='all',
                        help='Edge type: all or binary.')
    parser.add_argument('--results_save_path', type=str, default='',
                        help='Where to save results, e.g. LangInventoryAnalysis.pkl.')
    parser.add_argument('--graph_type', type=str, default='cross_colex_binary',
                        help='One of: cross_colex_binary, cross_colex_sum, '
                             'cross_colex_pairwise_product, colex_sum.')
    parser.add_argument('--monolingual_lang', type=str, default='en',
                        help='Language used for the plain (monolingual) colex graph.')
    parser.add_argument('--num_task', type=int, default=16)
    parser.add_argument('--edge_threshold', type=int, default=1,
                        help='When collecting edges, drop edges with weight below this value.')
    parser.add_argument('--DEBUG_COUNT', type=int, default=500000000000000000,
                        help='Set low to run the code quickly and check that everything works.')
    parser.add_argument('--cross_lingual_threshold', type=int, default=2,
                        help='An edge must occur in at least this many languages to count as cross-lingual.')
    parser.add_argument('--method_word_intersections', type=str,
                        default='colex_all~fasttext~BERT~colex_all_maxsim~colex_mono~ARES~BabelNet',
                        help='"~"-separated list of methods.')
    parser.add_argument('--languages', type=str, default='ar_en_es_fi_fr_he_pl_ru_zh',
                        help='"_"-separated list of language codes.')
    parser.add_argument('--rank_method', type=str, default='average')
    parser.add_argument('--simscore_delimiter', type=str, default=',')
    args = parser.parse_args()
    main(args)