"""
   Build colexification graphs from per-language synset/lemma text files.

   Main steps:
   (1) Get the list of languages to check for overlap
   (2) Get all lemmas from each language and store them as a list
   (3) Create the syn2id and id2syn dictionaries
   (4) Build the graphs under different edge-weight conditions
   """
import json
import os
import argparse
import matplotlib.pyplot as plt
import utils
import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
import fasttext.util
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from easynmt import EasyNMT
from sys import getsizeof
from tqdm import tqdm

def get_ignore_lemmas():
    """Return the lemmas excluded from colexification processing.

    These lemmas (the empty string, a single space, and the Russian and
    Armenian "list of asteroids" lemmas) each map to thousands of synsets,
    so pairing all of their synsets would produce an enormous number of
    edges. A fresh list is returned on every call.
    """
    return [
        "",
        " ",
        "список_астероидов",
        "աստերոիդների_ցանկ",
    ]

def get_languages_from_synset_text_files(args):
    """Collect the distinct language codes found in the synset text files.

    Each line of a synset text file is "~"-delimited and begins with the
    language code. Scanning stops early once 499 languages have been seen,
    which (per the author's note) is the total number available.

    Side effect: creates ``<exp_dir>/<lemma_savedir>`` if it does not exist.

    :param args: namespace providing ``exp_dir``, ``lemma_savedir`` and
        ``synset_text_files_dir``.
    :return: list of language codes in first-seen order.
    """
    lemma_savedir = os.path.join(args.exp_dir, args.lemma_savedir)
    if not os.path.isdir(lemma_savedir):
        os.mkdir(lemma_savedir)
    text_files = collect_files(args.synset_text_files_dir)
    languages = []
    seen = set()  # O(1) membership alongside the order-preserving list
    for file in text_files:
        # ``with`` guarantees the handle is closed even if reading fails.
        with open(file, 'r') as file1:
            lines = file1.readlines()
        for line in tqdm(lines):
            # Only the first field (language) is needed; lemmas may contain '~'.
            lang = line.replace("\n", "").split("~")[0]
            if lang not in seen:
                seen.add(lang)
                languages.append(lang)
            if len(languages) >= 500:
                break
        if len(languages) >= 499:  # seems like there are only 499 languages
            break
    return languages

def get_concept_syns(args):
    """Return the synset IDs whose type is 'concept', caching the result.

    On the first call every file in ``args.synset_types_files_dir`` is read.
    Each line is expected to look like ``<synset>~<type>...``; a synset is
    kept when its type (case-insensitive) equals ``'concept'``. The list is
    dumped to ``<exp_dir>/concept_syns.pkl``, and later calls just load it.

    :param args: namespace providing ``exp_dir`` and ``synset_types_files_dir``.
    :return: list of concept synset IDs in file order (duplicates preserved).
    """
    cache_path = os.path.join(args.exp_dir, 'concept_syns.pkl')
    if os.path.exists(cache_path):
        return load(cache_path)
    files = collect_files(args.synset_types_files_dir)
    concept_syns = []
    total_syns = 0
    for file in tqdm(files):
        # Close each file deterministically instead of leaking the handle.
        with open(file, 'r') as file1:
            lines = file1.readlines()
        for line in lines:
            line = line.replace("\n", "")
            if not line:
                continue  # a blank line has no type field
            fields = line.split("~")
            syn = fields[0]
            syn_type = fields[1].lower()
            if syn_type == 'concept':
                concept_syns.append(syn)
            total_syns += 1
    print("Total synsets: " + str(total_syns))
    print("Concept synsets: " + str(len(concept_syns)))
    dump(concept_syns, cache_path)
    return concept_syns

def get_lemmasyns(args):
    """Build per-language ``lemma -> synsets`` dictionaries and dump them to disk.

    The server can't hold the entire dictionary in memory, so languages are
    processed in chunks (``utils.divide_index(len(languages), 50)``) and the
    synset text files are re-read once per chunk. Languages that already have
    a pickle in ``<exp_dir>/<lemma_savedir>`` are skipped.

    Each input line is ``<lang>~<synset>~<lemma>``. Lemmas may themselves
    contain '~', so everything after the second field is re-joined. Lemmas are
    lower-cased and empty lemmas are dropped. To save memory, a lemma's
    synsets are stored as one "_"-delimited string rather than a list.

    Output: one ``<exp_dir>/<lemma_savedir>/<lang>.pkl`` per language holding
    ``{lemma: "syn1_syn2_..."}``.
    """
    lemma_savedir = os.path.join(args.exp_dir, args.lemma_savedir)
    if not os.path.isdir(lemma_savedir):
        os.mkdir(lemma_savedir)
    text_files = collect_files(args.synset_text_files_dir)
    languages = get_languages_from_synset_text_files(args)
    # Check which languages have already been done.
    completed_langs = [x.upper() for x in get_language_list(args)]
    languages = list(set(languages).difference(set(completed_langs)))
    language_chunks = utils.divide_index(len(languages), 50)
    for chunk in language_chunks:
        # A set: this membership test runs once per input line.
        temp_languages = {languages[x] for x in chunk}
        lang_lemma_dict = {}
        for file in text_files:
            with open(file, 'r') as file1:
                lines = file1.readlines()
            for line in tqdm(lines):
                line_info = line.replace("\n", "").split("~")
                lang = line_info[0]
                syn = line_info[1]
                lemma = "~".join(line_info[2:]).lower()
                if lemma == "" or lang not in temp_languages:
                    continue
                lemma_syns = lang_lemma_dict.setdefault(lang, {})
                existing = lemma_syns.get(lemma)
                if existing is None:
                    lemma_syns[lemma] = syn
                elif syn not in existing.split("_"):
                    # Compare whole synset IDs: a substring test on the joined
                    # string could wrongly report a synset as already present.
                    lemma_syns[lemma] = existing + "_" + syn

        for lang, lemmasyns in tqdm(lang_lemma_dict.items()):
            dump_path = os.path.join(args.exp_dir, args.lemma_savedir, lang + ".pkl")
            utils.dump(lemmasyns, dump_path)

def get_language_list(args):
    """Return the lower-cased language codes of the saved lemma pickles.

    A file's language is the part of its basename (the text after the last
    '/') before the first '.', e.g. ``.../EN.pkl`` -> ``'en'``.
    """
    lemma_dir = os.path.join(args.exp_dir, args.lemma_savedir)
    return [
        path.rsplit('/', 1)[-1].partition('.')[0].lower()
        for path in collect_files(lemma_dir)
    ]

def get_syn2id_and_id2syn(args, languages, concept_syns):
    """Return ``(syn2id, id2syn)`` synset/integer-ID mappings, building them if needed.

    If both ``syn2id.pkl`` and ``id2syn.pkl`` exist in ``args.exp_dir`` they are
    loaded. Otherwise the synset vocabulary is built in one of two ways:

    * when ``concept_syns.pkl`` is missing or ``args.filter_syns_by_concept`` is
      false: collect every synset from each language's word2syn pickle, keeping
      only synsets in ``concept_syns`` when filtering is enabled;
    * otherwise: use the cached ``concept_syns.pkl`` list directly.

    The unique synsets are sorted and numbered 0..N-1. Both mappings are dumped.

    :param languages: language codes; pickles are read as ``<LANG>.pkl``.
    :param concept_syns: iterable of concept synset IDs used for filtering.
    """
    id2syn_path = os.path.join(args.exp_dir, 'id2syn.pkl')
    syn2id_path = os.path.join(args.exp_dir, 'syn2id.pkl')
    if os.path.exists(id2syn_path) and os.path.exists(syn2id_path):
        id2syn = load(id2syn_path)
        syn2id = load(syn2id_path)
        return syn2id, id2syn

    concept_pkl = os.path.join(args.exp_dir, "concept_syns.pkl")
    if not os.path.exists(concept_pkl) or not args.filter_syns_by_concept:
        # Collect all synsets from each language dictionary via the word2syn files.
        # A set keeps the concept membership check O(1) instead of a list scan.
        concept_set = set(concept_syns) if args.filter_syns_by_concept else None
        syn_set = set()
        for lang in languages:
            word2syn_file = os.path.join(args.exp_dir, args.lemma_savedir, lang.upper() + '.pkl')
            word2syn = utils.load(word2syn_file)
            for lemma, synset_str in tqdm(word2syn.items()):
                for syn in synset_str.split("_"):
                    if concept_set is None or syn in concept_set:
                        syn_set.add(syn)
        syns = syn_set
    else:
        # De-duplicate so each synset gets exactly one ID (keeps the maps bijective).
        syns = set(load(concept_pkl))
    print(len(syns))
    syns = sorted(syns)
    id2syn = {}
    syn2id = {}
    for i, syn in tqdm(enumerate(syns)):
        id2syn[i] = syn
        syn2id[syn] = i
    dump(id2syn, id2syn_path)
    dump(syn2id, syn2id_path)
    return syn2id, id2syn

def get_lemmas(args, languages, syn2id, id2syn, concept_syns):
    """Return ``{lang: {lemma: [synID, ...]}}``, converting synset strings to IDs.

    We already have the original files pairing lemmas with their synsets, but
    this step keeps the structure consistent with the WordNet version of the
    code: each language's word2syn pickle is loaded and its "_"-delimited
    synset strings are replaced by lists of integer IDs from ``syn2id``.

    When ``args.filter_syns_by_concept`` is set, only synsets in
    ``concept_syns`` are kept. Lemmas left with no synsets, and the empty
    lemma, are dropped. Per-language results are cached in
    ``<exp_dir>/<lemma_synIDdir>/<LANG>.pkl`` and reloaded on later calls.

    :param id2syn: unused; kept for interface compatibility.
    """
    # A set makes the per-synset concept check O(1) (a list is far too slow).
    concept_syn_set = set(concept_syns)
    lemma_folder = os.path.join(args.exp_dir, args.lemma_synIDdir)
    os.makedirs(lemma_folder, exist_ok=True)
    lemmas = {}
    lem_counter = 0
    for lang in tqdm(languages):
        lemma_path = os.path.join(lemma_folder, lang.upper() + ".pkl")
        if os.path.exists(lemma_path):
            lemmas[lang] = load(lemma_path)
            lem_counter += 1
            print(lem_counter)
            continue
        local_lemmas = {}
        word2syn_file = os.path.join(args.exp_dir, args.lemma_savedir, lang.upper() + '.pkl')
        word2syn = utils.load(word2syn_file)
        lang_lemma_count = len(word2syn)
        for lemma, synset_str in tqdm(word2syn.items()):
            synset_list = synset_str.split("_")
            if args.filter_syns_by_concept:
                synset_list = [syn_ for syn_ in synset_list if syn_ in concept_syn_set]
            # The empty string leaked through from BabelNet with 5000+ synsets; skip it.
            if lemma != "" and synset_list:
                local_lemmas[lemma] = [syn2id[x] for x in synset_list]
            elif lemma == "":
                print(len(synset_list))
        print("Total lemmas in " + lang + ": " + str(lang_lemma_count))
        print("Lemmas in " + lang + " with concept synsets: " + str(len(local_lemmas)))
        lemmas[lang] = local_lemmas
        dump(local_lemmas, lemma_path)
    return lemmas

def get_all_edges_per_lang(args, languages):
    """Count, per language, how many lemmas colexify each pair of synset IDs.

    For every lemma with two or more synsets (excluding ``get_ignore_lemmas()``),
    every pairwise combination of its synset IDs is an edge. Edges are keyed as
    ``"<low>_<high>"`` (IDs sorted) so that both orientations share one key.
    The resulting ``{edge_key: count}`` dict for each language is dumped to
    ``<exp_dir>/<all_edges_dicts>/<lang>.pkl``.
    """
    edges_dir = os.path.join(args.exp_dir, args.all_edges_dicts)
    os.makedirs(edges_dir, exist_ok=True)
    lemma_folder = os.path.join(args.exp_dir, args.lemma_synIDdir)
    ignore_lemmas = set(get_ignore_lemmas())
    global_max_syns = -1
    for lang in languages:
        lemma_path = os.path.join(lemma_folder, lang.upper() + ".pkl")
        lem_list = load(lemma_path)
        lang_edges = {}
        two_or_more_counter = 0
        dump_path = os.path.join(edges_dir, lang + '.pkl')
        max_num_syns = -1
        for lem, syns in tqdm(lem_list.items()):
            if len(syns) < 2 or lem in ignore_lemmas:
                continue
            max_num_syns = max(max_num_syns, len(syns))
            # Get all pairwise combinations of synsets as edges.
            for edge in get_pairwise_edges(syns):
                # Sort numerically so (a, b) and (b, a) map to the same key.
                edge = sorted(edge)
                key = str(edge[0]) + '_' + str(edge[1])
                lang_edges[key] = lang_edges.get(key, 0) + 1
            two_or_more_counter += 1
        # NOTE: sys.getsizeof is shallow -- it measures the dict container only,
        # not the key strings or values, so this understates true memory use.
        size_in_GB = getsizeof(lang_edges) / (1024**3)
        print(str(size_in_GB) + "GB with " + str(two_or_more_counter) + " lemmas with 2 or more...")
        print(len(lang_edges))
        print(max_num_syns)
        global_max_syns = max(global_max_syns, max_num_syns)
        utils.dump(lang_edges, dump_path)
    print("Maximum number of synsets in a lexeme for (filtered) data is " + str(global_max_syns) + ".")

def build_cross_colex_binary_graph(args):
    """Return the edges that occur in at least ``args.cross_lingual_threshold`` languages.

    The per-language edge dicts in ``<exp_dir>/<all_edges_dicts>`` are loaded
    (restricted by ``filter_lang_edge_files`` according to ``args.graph_type``).
    Each edge is counted once per language that contains it, ignoring its
    in-language weight. Edges whose language count reaches the threshold are
    kept, and summary statistics are printed.

    :return: dict mapping ``"<id1>_<id2>"`` edge keys to the number of
        languages containing that edge.
    """
    edge_dir = os.path.join(args.exp_dir, args.all_edges_dicts)
    selected_files = filter_lang_edge_files(args, collect_files(edge_dir))
    language_counts = {}
    for path in tqdm(selected_files):
        for key in utils.load(path):
            language_counts[key] = language_counts.get(key, 0) + 1
    kept = {
        key: count
        for key, count in language_counts.items()
        if count >= args.cross_lingual_threshold
    }
    nodes = set()
    for key in kept:
        parts = key.split("_")
        nodes.add(parts[0])
        nodes.add(parts[1])
    most_languages = max([0] + list(kept.values()))
    print(str(len(language_counts)) + " total edges.")
    print(str(len(kept)) + " cross-lingual edges.")
    print(str(len(nodes)) + " unique nodes in cross-lingual edges.")
    print("Maximum number of languages an edge occured in was " + str(most_languages) + ".")
    return kept

def build_cross_colex_presence_absence_graph(args):
    """Return the cross-lingual edges with every weight set to 1.

    The edge set comes from ``build_cross_colex_binary_graph`` (edges present
    in at least ``args.cross_lingual_threshold`` languages). That function
    collects and filters the language edge files itself, so no file handling
    is needed here. The returned dict is that same dict, mutated in place.
    """
    cross_lingual_edges = build_cross_colex_binary_graph(args)
    # Presence/absence: discard the language counts.
    for edge in cross_lingual_edges:
        cross_lingual_edges[edge] = 1
    max_edge_weight = 1 if cross_lingual_edges else 0
    print("Maximum edge weight in cross-colex presence/absence graph is " + str(max_edge_weight) + ".")
    return cross_lingual_edges

def filter_lang_edge_files(args, lang_edge_files):
    """Restrict the per-language edge files according to ``args.graph_type``.

    * ``'cross_colex_sum_9'``: keep files whose lower-cased language code is in
      ``config.eval_languages``.
    * ``'cross_colex_sum_50'``: keep the edge files of the 50 languages with the
      most entries in their ``config.directories.wordsyns`` file, largest
      first. Languages without a matching edge file are skipped with a message.
    * anything else: return ``lang_edge_files`` unchanged.

    The language of a path is its basename (after the last '/') up to the first '.'.
    """
    config = utils.get_config()
    if args.graph_type == 'cross_colex_sum_9':
        languages = config.eval_languages
        return [
            file for file in lang_edge_files
            if file.split('/')[-1].split('.')[0].lower() in languages
        ]
    if args.graph_type == 'cross_colex_sum_50':
        edge_file_dict = {file.split('/')[-1].split('.')[0]: file for file in lang_edge_files}
        length_dict = {}
        json_files = collect_files(config.directories.wordsyns)
        for file in tqdm(json_files):
            lang_name = file.split('/')[-1].split('.')[0]
            edge_file = edge_file_dict.get(lang_name.lower())
            if edge_file is None:
                # Previously a KeyError; skip languages that have no edge file.
                print("No edge file for " + lang_name + "; skipping.")
                continue
            data = utils.load(file)
            length_dict[lang_name] = {'len': len(data), 'file': edge_file}
        # Sort by size (descending, stable) and keep the 50 largest languages.
        ranked = sorted(length_dict.values(), key=lambda datum: datum['len'], reverse=True)
        return [datum['file'] for datum in ranked[:50]]
    return lang_edge_files

def build_cross_colex_sum_graph(args):
    """Sum the per-language weights of cross-lingual edges.

    Only edges returned by ``build_cross_colex_binary_graph`` are considered.
    Within each selected language, an edge contributes its weight only when
    that weight is at least ``args.edge_threshold``. A histogram of the summed
    weights is shown before returning.

    :return: dict mapping edge keys to summed weights.
    """
    edge_dir = os.path.join(args.exp_dir, args.all_edges_dicts)
    selected_files = filter_lang_edge_files(args, collect_files(edge_dir))
    cross_lingual_edges = build_cross_colex_binary_graph(args)
    # Revisit each language file to accumulate the original weights.
    summed = {}
    for path in tqdm(selected_files):
        per_lang = utils.load(path)
        for key in per_lang:
            count = per_lang[key]
            if key not in cross_lingual_edges or count < args.edge_threshold:
                continue
            summed[key] = summed.get(key, 0) + count
    weights = list(summed.values())
    max_edge_weight = max([0] + weights)
    print(str(len(summed)) + " remaining edges after filtering insignificant edges.")
    print("Maximum edge weight in cross-colex sum graph is " + str(max_edge_weight) + ".")
    histogram_bins = np.linspace(start=5, stop=300, num=30)
    plt.hist(weights, bins=histogram_bins)
    plt.show()
    return summed

def random_sample_langs(args, lang_edge_files):
    """Group edge files into nested prefixes of the languages with the most lexemes.

    Despite the name, the final order is not random:

    1. The file list is shuffled once and cached in
       ``lang_limitation_edge_files_list.pkl``; later calls reuse the cached
       order. Because the sort below is stable, the shuffle only decides the
       relative order of languages with equal lexeme counts.
    2. Each language's lexeme count (excluding ``get_ignore_lemmas()``) is read
       from ``<exp_dir>/<lemma_synIDdir>/<LANG>.pkl``. The counts are sorted in
       descending order and cached in ``lang_limitation_edge_files_lexeme_counts.pkl``.
    3. For each level in ``lang_intervals``, the first ``level`` files are taken.

    The caller's ``lang_edge_files`` list is not modified.

    :return: ``{level: [edge file path, ...]}``.
    """
    lang_intervals = [5, 10, 20, 50, 100, 200]
    lang_edge_files_list_path = os.path.join(args.exp_dir, 'lang_limitation_edge_files_list.pkl')
    if not os.path.exists(lang_edge_files_list_path):
        # Shuffle a copy so the caller's list is not mutated.
        lang_edge_files = list(lang_edge_files)
        random.shuffle(lang_edge_files)
        dump(lang_edge_files, lang_edge_files_list_path)
    else:
        lang_edge_files = load(lang_edge_files_list_path)

    counts_path = os.path.join(args.exp_dir, 'lang_limitation_edge_files_lexeme_counts.pkl')
    if not os.path.exists(counts_path):
        # Use the configured synID directory (as get_lemmas does); fall back to
        # the historical default name if the option is absent.
        lemma_dir = os.path.join(args.exp_dir, getattr(args, 'lemma_synIDdir', 'lemmas_synID'))
        remove_lemmas = get_ignore_lemmas()
        total_lemmas = 0
        counter = 0
        lexeme_counts = {}
        for file in tqdm(lang_edge_files):
            counter += 1
            lang = file.split("/")[-1].split(".")[0]
            lemmas = load(os.path.join(lemma_dir, lang.upper() + ".pkl"))
            for lem in remove_lemmas:
                if lem in lemmas:
                    del lemmas[lem]
                    print("Deleted " + lem + " from lemmas...")
            num_lemmas = len(lemmas)
            total_lemmas += num_lemmas
            lexeme_counts[file] = num_lemmas
            # Report the running total to check it grows roughly linearly.
            if counter in lang_intervals:
                print(str(counter) + ": total lemmas is " + str(total_lemmas))
        print(str(counter) + ": total lemmas is " + str(total_lemmas))
        # Sort from largest to smallest in terms of lexemes.
        lexeme_counts = dict(sorted(lexeme_counts.items(), key=lambda item: item[1], reverse=True))
        dump(lexeme_counts, counts_path)
    else:
        lexeme_counts = load(counts_path)

    ordered_files = list(lexeme_counts)
    return {interval: ordered_files[:interval] for interval in lang_intervals}

def build_cross_colex_sum_limited_langs_graph(args):
    """Write one cross-colex sum graph per language-count level.

    ``random_sample_langs`` gives nested lists of edge files (largest languages
    first). For each level, cross-lingual edges with per-language weight at
    least ``args.edge_threshold`` are summed over that level's files. A
    histogram is shown, and the graph is written with
    ``write_edgelist_dict_to_edgelist`` under ``<edgelist_savename>_<level>``.
    ``args.edgelist_savename`` is restored afterwards, even on error.
    """
    lang_edge_files_ = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    # Depending on configuration, we may use a subset of languages.
    lang_edge_file_dicts = random_sample_langs(args, lang_edge_files_)
    # build_cross_colex_binary_graph does not depend on the level, so compute it once.
    # NOTE(review): the cross-lingual filter therefore reflects all languages
    # selected by args.graph_type, not only the level's subset -- confirm intended.
    cross_lingual_edges = build_cross_colex_binary_graph(args)
    original_edgelist_savename = args.edgelist_savename
    try:
        for level, lang_edge_files in lang_edge_file_dicts.items():
            args.edgelist_savename = original_edgelist_savename + "_" + str(level)
            # Collect the original weights of the cross-lingual edges.
            full_edges = {}
            for lang_file in tqdm(lang_edge_files):
                lang_edges = utils.load(lang_file)
                for edge, weight in lang_edges.items():
                    if edge in cross_lingual_edges and weight >= args.edge_threshold:
                        full_edges[edge] = full_edges.get(edge, 0) + weight
            weights = list(full_edges.values())
            max_edge_weight = max([0] + weights)
            print(str(len(full_edges)) + " remaining edges after filtering insignificant edges.")
            print("Maximum edge weight in cross-colex sum graph is " + str(max_edge_weight) + ".")
            bins = np.linspace(start=5, stop=300, num=30)
            plt.hist(weights, bins=bins)
            plt.show()
            write_edgelist_dict_to_edgelist(args, full_edges)
    finally:
        # Do not leave the caller's args pointing at the last level's filename.
        args.edgelist_savename = original_edgelist_savename

def build_cross_colex_binary_limited_langs_graph(args):
    """Write one cross-colex binary graph per language-count level.

    Same outline as the sum variant, but each language contributes 1 (not its
    weight) to a cross-lingual edge whose per-language weight is at least
    ``args.edge_threshold``. The resulting weight is the number of the level's
    languages that contain the edge. Each level's graph is written with
    ``write_edgelist_dict_to_edgelist`` under ``<edgelist_savename>_<level>``.
    ``args.edgelist_savename`` is restored afterwards, even on error.
    """
    lang_edge_files_ = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    # Depending on configuration, we may use a subset of languages.
    lang_edge_file_dicts = random_sample_langs(args, lang_edge_files_)
    # build_cross_colex_binary_graph does not depend on the level, so compute it once.
    # NOTE(review): the cross-lingual filter therefore reflects all languages
    # selected by args.graph_type, not only the level's subset -- confirm intended.
    cross_lingual_edges = build_cross_colex_binary_graph(args)
    original_edgelist_savename = args.edgelist_savename
    try:
        for level, lang_edge_files in lang_edge_file_dicts.items():
            args.edgelist_savename = original_edgelist_savename + "_" + str(level)
            # Collect BINARY (per-language presence) weights for cross-lingual edges.
            full_edges = {}
            for lang_file in tqdm(lang_edge_files):
                lang_edges = utils.load(lang_file)
                for edge, weight in lang_edges.items():
                    if edge in cross_lingual_edges and weight >= args.edge_threshold:
                        full_edges[edge] = full_edges.get(edge, 0) + 1
            weights = list(full_edges.values())
            max_edge_weight = max([0] + weights)
            print(str(len(full_edges)) + " remaining edges after filtering insignificant edges.")
            print("Maximum edge weight in cross-colex binary graph is " + str(max_edge_weight) + ".")
            bins = np.linspace(start=5, stop=300, num=30)
            plt.hist(weights, bins=bins)
            plt.show()
            write_edgelist_dict_to_edgelist(args, full_edges)
    finally:
        # Do not leave the caller's args pointing at the last level's filename.
        args.edgelist_savename = original_edgelist_savename

def build_cross_colex_pairwise_product_graph(args):
    """Sum, over language pairs, the product of an edge's weights in both languages.

    Language pairs come from ``get_pairwise_edges(get_language_list(args))``
    and are shuffled. For every pair, each edge present in both languages'
    graphs adds ``weight_in_first * weight_in_second`` to its total. A
    histogram of the totals is shown.

    :return: dict mapping edge keys to summed pairwise products.
    """
    lang_pairs = get_pairwise_edges(get_language_list(args))
    random.shuffle(lang_pairs)
    # Load every per-language graph up front, keyed by file basename stem.
    edge_files = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    print("Loading individual language graphs...")
    lang_graphs = {
        path.split("/")[-1].split(".")[0]: utils.load(path)
        for path in tqdm(edge_files)
    }
    print("Done.")
    products = {}
    for pair in tqdm(lang_pairs):
        graph_a = lang_graphs[pair[0]]
        graph_b = lang_graphs[pair[1]]
        for edge, count_a in graph_a.items():
            if edge not in graph_b:
                continue
            products[edge] = products.get(edge, 0) + count_a * graph_b[edge]
    weights = list(products.values())
    max_edge_weight = max([0] + weights)
    print("Maximum edge weight in cross-colex pairwise product graph is " + str(max_edge_weight) + ".")
    plt.hist(weights, bins=[0, 5, 10, 15, 30, 100, 1000, 2000])
    plt.show()
    return products

def build_colex_sum_graph(args):
    """Return the monolingual colexification graph of ``args.monolingual_lang``.

    The edge file whose basename stem equals ``args.monolingual_lang`` is
    loaded and returned as-is. If it was produced by get_all_edges_per_lang,
    each weight is the number of lemmas in that language colexifying the
    synset pair. A histogram of weights is shown.

    :raises FileNotFoundError: if no edge file matches ``args.monolingual_lang``.
    """
    lang_edge_files = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    edge_file = next(
        (file for file in lang_edge_files
         if file.split('/')[-1].split('.')[0] == args.monolingual_lang),
        None,
    )
    if edge_file is None:
        # Fail clearly instead of passing None to utils.load.
        raise FileNotFoundError("No edge file found for language '" + str(args.monolingual_lang) + "'.")
    edges = utils.load(edge_file)
    weights = [weight for _, weight in tqdm(edges.items())]
    max_edge_weight = max([0] + weights)
    print("Maximum edge weight in colex sum graph is " + str(max_edge_weight) + ".")
    plt.hist(weights, bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 30, 50])
    plt.show()
    return edges

def build_colex_binary_graph(args):
    """Return the monolingual colexification graph of ``args.monolingual_lang`` with unit weights.

    The edge file whose basename stem equals ``args.monolingual_lang`` is
    loaded, and every weight is set to 1 in place. A histogram is shown.

    :raises FileNotFoundError: if no edge file matches ``args.monolingual_lang``.
    """
    lang_edge_files = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    edge_file = next(
        (file for file in lang_edge_files
         if file.split('/')[-1].split('.')[0] == args.monolingual_lang),
        None,
    )
    if edge_file is None:
        # Fail clearly instead of passing None to utils.load.
        raise FileNotFoundError("No edge file found for language '" + str(args.monolingual_lang) + "'.")
    edges = utils.load(edge_file)
    # Reassigning values of existing keys is safe while iterating a dict.
    for edge in tqdm(edges):
        edges[edge] = 1
    weights = list(edges.values())
    max_edge_weight = 1 if edges else 0
    print("Maximum edge weight in colex binary graph is " + str(max_edge_weight) + ".")
    plt.hist(weights, bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 30, 50])
    plt.show()
    return edges

def build_colex_filtered_binary_graph(args):
    """Return ``args.monolingual_lang``'s edges that are also cross-lingual, with unit weights.

    Cross-lingual edges come from ``build_cross_colex_binary_graph``. Only the
    monolingual edges present in that set are kept, each with weight 1. The
    language's edge file is located first, so a missing language fails before
    the expensive cross-lingual build.

    :raises FileNotFoundError: if no edge file matches ``args.monolingual_lang``.
    """
    lang_edge_files = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    edge_file = next(
        (file for file in lang_edge_files
         if file.split('/')[-1].split('.')[0] == args.monolingual_lang),
        None,
    )
    if edge_file is None:
        raise FileNotFoundError("No edge file found for language '" + str(args.monolingual_lang) + "'.")
    # Get cross-lingual edges.
    cross_lingual_edges = build_cross_colex_binary_graph(args)
    lang_edges = utils.load(edge_file)
    edges = {edge: 1 for edge in tqdm(lang_edges) if edge in cross_lingual_edges}
    max_edge_weight = 1 if edges else 0
    print("Number of edges in full colexification graph for " + args.monolingual_lang + ": " + str(len(lang_edges)))
    print("Number of edges in filtered colexification graph for " + args.monolingual_lang + ": " + str(len(edges)))
    print("Maximum edge weight in colex binary graph is " + str(max_edge_weight) + ".")
    return edges

def build_colex_binary_all_graph(args):
    """Return every edge from every language, weighted by its language count.

    Unlike build_cross_colex_binary_graph, the returned graph uses all edge
    files (no ``filter_lang_edge_files``) and applies no cross-lingual
    threshold: each edge key maps to the number of languages containing it.
    Edges meeting ``args.cross_lingual_threshold`` are tallied only for the
    diagnostic printout.
    """
    edge_dir = os.path.join(args.exp_dir, args.all_edges_dicts)
    language_counts = {}
    for path in tqdm(collect_files(edge_dir)):
        for key in utils.load(path):
            language_counts[key] = language_counts.get(key, 0) + 1
    # Statistics for edges that reach the cross-lingual threshold (printout only).
    above = [
        (key, count)
        for key, count in language_counts.items()
        if count >= args.cross_lingual_threshold
    ]
    nodes = set()
    for key, _ in above:
        parts = key.split("_")
        nodes.add(parts[0])
        nodes.add(parts[1])
    most_languages = max([0] + [count for _, count in above])
    print(str(len(language_counts)) + " total edges.")
    print(str(len(above)) + " cross-lingual edges.")
    print(str(len(nodes)) + " unique nodes in cross-lingual edges.")
    print("Maximum number of languages an edge occured in was " + str(most_languages) + ".")
    return language_counts  # KEEP ALL EDGES!!!

def build_colex_all_graph(args):
    """Union of the colexification edges of all languages, every edge with weight 1.

    Every edge file under ``args.exp_dir/args.all_edges_dicts`` is loaded with
    ``utils.load``. No cross-lingual filtering is applied: an edge that occurs
    in any language is kept.

    Returns:
        dict mapping edge key -> 1, in first-seen order.
    """
    lang_edge_files = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    all_edges = {}
    for lang_file in tqdm(lang_edge_files):
        # Only membership matters, so assign the final weight directly instead
        # of counting occurrences and overwriting every count afterwards.
        for edge in utils.load(lang_file):
            all_edges[edge] = 1
    print(str(len(all_edges)) + " total edges.")
    return all_edges

def build_colex_mono_graph(args):
    """Load the colexification edges of ``args.monolingual_lang`` with every weight set to 1.

    The edge file is the one in ``args.exp_dir/args.all_edges_dicts`` whose
    file name, up to its first '.', equals ``args.monolingual_lang``.

    Returns:
        The loaded edge dict, modified in place so that every value is 1.

    Raises:
        FileNotFoundError: if no edge file matches ``args.monolingual_lang``.
    """
    lang_edge_files = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    edge_file = next(
        (f for f in lang_edge_files
         if os.path.basename(f).split('.')[0] == args.monolingual_lang),
        None,
    )
    if edge_file is None:
        # Previously this fell through to utils.load(None) with an opaque error.
        raise FileNotFoundError(
            "No edge file for language '" + str(args.monolingual_lang) + "' in "
            + os.path.join(args.exp_dir, args.all_edges_dicts))
    edges = utils.load(edge_file)
    for edge in edges:
        edges[edge] = 1
    return edges

def build_colex_limited_langs_graph(args):
    """Write one binary (weight-1) colexification edge list per language subset.

    ``random_sample_langs`` supplies the subsets. Presumably it returns a dict
    mapping subset level -> list of edge files (only ``.items()`` is used here;
    TODO confirm). For each level, the union of edges over that subset's files
    is written to ``args.exp_dir/<edgelist_savename>_<level>``.
    ``args.edgelist_savename`` is overridden temporarily per level and restored
    on exit.
    """
    lang_edge_files_ = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    # Depending on configuration, we may use a subset of languages.
    lang_edge_file_dicts = random_sample_langs(args, lang_edge_files_)
    original_edgelist_savename = args.edgelist_savename
    try:
        for level, files in lang_edge_file_dicts.items():
            # write_edgelist_dict_to_edgelist reads its output name from args.
            args.edgelist_savename = original_edgelist_savename + "_" + str(level)
            full_edges = {}
            for lang_file in tqdm(files):
                for edge in utils.load(lang_file):
                    full_edges[edge] = 1  # binary graph: presence only
            print(str(len(full_edges)) + " total edges.")
            write_edgelist_dict_to_edgelist(args, full_edges)
    finally:
        # Do not leave the caller's args pointing at the last level's file name.
        args.edgelist_savename = original_edgelist_savename

def random_sample_langs_version2(args, lang_edge_files, lang_intervals=(9, 20, 50, 100, 200)):
    """Build nested prefixes of language edge files: evaluation languages first, then largest first.

    Despite the name this is NOT random. The ordering is:
      1. the languages in ``get_config().eval_languages``, in that order;
      2. every other language, sorted by its number of edges (colexified
         synset pairs) from largest to smallest. Ties keep directory order.

    Args:
        args: must provide ``exp_dir`` and ``all_edges_dicts``. The edge
            directory is re-scanned with ``utils.collect_files``.
        lang_edge_files: unused (the directory is re-scanned); kept for call
            compatibility with ``random_sample_langs``.
        lang_intervals: prefix sizes to produce.

    Returns:
        dict mapping each interval -> the first ``interval`` edge-file paths.

    Raises:
        ValueError: if an evaluation language has no edge file.
    """
    config = get_config()
    all_edges_files = utils.collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    edge_file_dict = {}
    for edge_file in all_edges_files:
        # Edge files are named "<lang>.<ext>".
        edge_file_dict[os.path.basename(edge_file).split(".")[0]] = edge_file

    eval_languages = list(config.eval_languages)
    missing = [lang for lang in eval_languages if lang not in edge_file_dict]
    if missing:
        # Fail fast, before loading every edge file.
        raise ValueError("No edge file for evaluation language(s): " + ", ".join(map(str, missing)))

    # Evaluation languages always come first, so their sizes are never needed.
    eval_set = set(eval_languages)
    num_edges = {}
    for lang, edge_file in tqdm(edge_file_dict.items()):
        if lang not in eval_set:
            num_edges[lang] = len(utils.load(edge_file))
    ranked = sorted(num_edges, key=num_edges.get, reverse=True)

    ordered_files = [edge_file_dict[lang] for lang in eval_languages]
    ordered_files += [edge_file_dict[lang] for lang in ranked]
    return {interval: ordered_files[0:interval] for interval in lang_intervals}

def build_colex_limited_langs_graph_version2(args):
    """Write one binary (weight-1) edge list per language-subset size.

    Subsets come from ``random_sample_langs_version2``: evaluation languages
    first, then the remaining languages by number of colexified synset pairs.
    For each subset size, the union of edges over that subset's files is
    written to ``args.exp_dir/<edgelist_savename>_<size>``.
    ``args.edgelist_savename`` is overridden temporarily per size and
    restored on exit.
    """
    lang_edge_files_ = collect_files(os.path.join(args.exp_dir, args.all_edges_dicts))
    lang_edge_file_dicts = random_sample_langs_version2(args, lang_edge_files_)
    original_edgelist_savename = args.edgelist_savename
    try:
        for level, files in lang_edge_file_dicts.items():
            # write_edgelist_dict_to_edgelist reads its output name from args.
            args.edgelist_savename = original_edgelist_savename + "_" + str(level)
            full_edges = {}
            for lang_file in tqdm(files):
                for edge in utils.load(lang_file):
                    full_edges[edge] = 1  # binary graph: presence only
            print(str(len(full_edges)) + " total edges.")
            write_edgelist_dict_to_edgelist(args, full_edges)
    finally:
        # Do not leave the caller's args pointing at the last level's file name.
        args.edgelist_savename = original_edgelist_savename

def build_colex_lang_count_graph(args):
    """Weight every colexification edge by the number of language files it appears in.

    All edge files under ``args.exp_dir/args.all_edges_dicts`` are loaded and
    no edge is filtered out. Unlike ``build_colex_all_graph``, the counts are
    kept as weights.

    Returns:
        dict mapping edge key -> number of language files containing it.
    """
    edge_dir = os.path.join(args.exp_dir, args.all_edges_dicts)
    language_counts = {}
    for edge_path in tqdm(collect_files(edge_dir)):
        for key in utils.load(edge_path):
            language_counts[key] = language_counts.get(key, 0) + 1
    print(str(len(language_counts)) + " total edges.")
    return language_counts

def write_edgelist_dict_to_edgelist(args, edge_dict):
    """Write ``edge_dict`` as a space-separated edge list.

    Each key is split on '_'. Its first two pieces become the two node
    columns, and the dict value becomes the weight column ("node1 node2 weight").
    The lines go to ``args.exp_dir/args.edgelist_savename`` via
    ``utils.write_file_from_list``.
    """
    output_path = os.path.join(args.exp_dir, args.edgelist_savename)
    lines = []
    for key, weight in edge_dict.items():
        parts = key.split('_')
        lines.append(' '.join((str(parts[0]), str(parts[1]), str(weight))))
    utils.write_file_from_list(lines, output_path)

def synset_list_to_string_list(syns):
    """Return the ``_name`` attribute of every synset in ``syns``, preserving order."""
    return [synset._name for synset in syns]

def get_pairwise_edges(overlap):
    """Return every unordered pair of items in ``overlap`` as ``[earlier, later]`` lists.

    Pairs are produced in index order: (0, 1), (0, 2), ..., (1, 2), ...
    ``overlap`` must support slicing (e.g. a list).
    """
    return [[first, second]
            for idx, first in enumerate(overlap)
            for second in overlap[idx + 1:]]

def get_word_embeddings(args, wordlist, lang, embed_filepath, weight_type='freq'):
    """Build one L2-normalised embedding per word by summing its sense embeddings.

    Sense embeddings are read from ``embed_filepath``, a text file whose first
    line is a header and whose other lines are "<node id> <v1> <v2> ...".
    Each word's synsets come from ``<config.directories.wordsyns>/<LANG>.json``.
    Wordsyns are used instead of lemmas for better coverage, since a BabelNet
    query beats strict string matching. Synset names are mapped to node ids
    with ``syn2id``.

    Args:
        args: experiment config passed to get_language_list / get_concept_syns /
            get_syn2id_and_id2syn.
        wordlist: words to embed.
        lang: language code; upper-cased to select the word2syn JSON file.
        embed_filepath: path of the text embedding file described above.
        weight_type: only 'inverse' aggregates senses. Despite the name, that
            is an unweighted sum. Any other value, including the default
            'freq', yields None for every word.

    Returns:
        dict word -> normalised numpy vector, or None when the word is absent
        from the word2syn file or none of its synsets has an embedding.
    """
    languages = get_language_list(args)
    concept_syns = get_concept_syns(args)
    syn2id, _ = get_syn2id_and_id2syn(args, languages, concept_syns)

    config = get_config()
    word2syn_file = os.path.join(config.directories.wordsyns, lang.upper() + ".json")
    with open(word2syn_file) as json_file:
        word2syn = json.load(json_file)

    syn_embs = {}
    with open(embed_filepath, 'r') as embed_file:
        next(embed_file, None)  # skip the header line
        for line in tqdm(embed_file):
            pieces = line.split()
            if not pieces:
                continue  # tolerate blank/trailing lines
            syn_embs[int(pieces[0])] = np.asarray([float(x) for x in pieces[1:]])

    word_embs = {}
    for word in tqdm(wordlist):
        synsets = word2syn.get(word)
        if synsets is None or weight_type != 'inverse':
            word_embs[word] = None
            continue
        word_emb = None
        for syn in synsets:
            try:
                temp_emb = syn_embs[syn2id[syn]]
            except KeyError:
                # Not all senses have embeddings after node2vec training (too infrequent).
                continue
            # Use `+`, not `+=`: the first sense's vector is the very array stored
            # in syn_embs, and in-place addition would corrupt it for later words.
            word_emb = temp_emb if word_emb is None else word_emb + temp_emb
        word_embs[word] = None if word_emb is None else word_emb / np.linalg.norm(word_emb)
    return word_embs

def get_word_embeddings_nodevectors(args, wordlist, lang, embed_filepath, weight_type='freq', model_type="ProNE"):
    """Build one L2-normalised embedding per word from a saved nodevectors model.

    The model at ``embed_filepath`` is loaded as ProNE, GGVec or Node2Vec.
    Each synset of a word (from ``<config.directories.wordsyns>/<LANG>.json``)
    is mapped to a node id with ``syn2id``, embedded with
    ``node_model.predict``, summed and normalised.

    Args:
        args: experiment config passed to get_language_list / get_concept_syns /
            get_syn2id_and_id2syn.
        wordlist: words to embed.
        lang: language code; upper-cased to select the word2syn JSON file.
        embed_filepath: path of the saved nodevectors model.
        weight_type: only 'inverse' aggregates senses (unweighted sum); any
            other value yields None for every word.
        model_type: 'ProNE', 'GGVec' or 'Node2Vec'.

    Returns:
        dict word -> normalised numpy vector, or None when the word is absent
        or none of its synsets can be embedded.

    Raises:
        ValueError: for an unsupported ``model_type``. Previously it silently
            produced all-None results.
    """
    import nodevectors
    if model_type == 'ProNE':
        node_model = nodevectors.ProNE.load(embed_filepath)
    elif model_type == 'GGVec':
        node_model = nodevectors.GGVec.load(embed_filepath)
    elif model_type == 'Node2Vec':
        node_model = nodevectors.Node2Vec.load(embed_filepath)
    else:
        raise ValueError("Unsupported model_type " + repr(model_type)
                         + "; expected 'ProNE', 'GGVec' or 'Node2Vec'")

    languages = get_language_list(args)
    concept_syns = get_concept_syns(args)
    syn2id, _ = get_syn2id_and_id2syn(args, languages, concept_syns)

    # Wordsyns are used instead of lemmas for better coverage.
    config = get_config()
    word2syn_file = os.path.join(config.directories.wordsyns, lang.upper() + ".json")
    with open(word2syn_file) as json_file:
        word2syn = json.load(json_file)

    word_embs = {}
    for word in tqdm(wordlist):
        synsets = word2syn.get(word)
        if synsets is None or weight_type != 'inverse':
            word_embs[word] = None
            continue
        word_emb = None
        for syn in synsets:
            try:
                temp_emb = node_model.predict(syn2id[syn])
            except (KeyError, IndexError):
                # Synset missing from syn2id, or node unknown to the model
                # (assumed to surface as KeyError/IndexError; TODO confirm for
                # every nodevectors backend).
                continue
            # Use `+`, not `+=`: predict() may return a view into the model's
            # embedding storage, which in-place addition would overwrite.
            word_emb = temp_emb if word_emb is None else word_emb + temp_emb
        word_embs[word] = None if word_emb is None else word_emb / np.linalg.norm(word_emb)
    return word_embs

def get_word_embeddings_synset_nodevectors(args, wordlist, lang, embed_filepath, weight_type='freq', syn_PCA=False, model_type="ProNE", original_graph=""):
    """Collect, for each word, the list of its sense embeddings from a nodevectors model.

    Every node in the edge list ``original_graph`` ("node1 node2 ..." per line)
    is embedded with the loaded model's ``predict``. The embeddings are
    optionally reduced with ``LSIM.get_PCA`` and then looked up per synset of
    each word.

    Args:
        args: experiment config passed to get_language_list / get_concept_syns /
            get_syn2id_and_id2syn.
        wordlist: words to embed.
        lang: language code; upper-cased to select the word2syn JSON file.
        embed_filepath: path of the saved nodevectors model.
        weight_type: only 'inverse' collects senses; any other value yields
            None for every word.
        syn_PCA: apply ``get_PCA`` to the sense embeddings first.
        model_type: 'ProNE', 'GGVec' or 'Node2Vec'.
        original_graph: edge-list file whose nodes are embedded.

    Returns:
        dict word -> list of sense embeddings (one per synset that has one),
        or None when the word is absent or no synset has an embedding.

    Raises:
        ValueError: for an unsupported ``model_type``.
    """
    import nodevectors
    print("Loading nodevectors model...")
    if model_type == 'ProNE':
        node_model = nodevectors.ProNE.load(embed_filepath)
    elif model_type == 'GGVec':
        node_model = nodevectors.GGVec.load(embed_filepath)
    elif model_type == 'Node2Vec':
        node_model = nodevectors.Node2Vec.load(embed_filepath)
    else:
        raise ValueError("Unsupported model_type " + repr(model_type)
                         + "; expected 'ProNE', 'GGVec' or 'Node2Vec'")
    print("Done.")
    languages = get_language_list(args)
    concept_syns = get_concept_syns(args)
    syn2id, _ = get_syn2id_and_id2syn(args, languages, concept_syns)

    # Wordsyns are used instead of lemmas for better coverage.
    config = get_config()
    word2syn_file = os.path.join(config.directories.wordsyns, lang.upper() + ".json")
    with open(word2syn_file) as json_file:
        word2syn = json.load(json_file)

    # Collect the graph's nodes (first-seen order), then embed them all so PCA
    # can be fit over the full set of sense embeddings.
    nodes = {}
    with open(original_graph, 'r') as graph_file:
        print("Extracting node embeddings from nodevectors model so we can perform PCA...")
        for line in tqdm(graph_file):
            pieces = line.split()
            if not pieces:
                continue  # tolerate blank/trailing lines
            nodes[int(pieces[0])] = ""
            nodes[int(pieces[1])] = ""
    syn_embs = {node: node_model.predict(node) for node in nodes}
    print("Done.")

    if syn_PCA:
        from LSIM import get_PCA
        print('Getting PCA on sense embeddings...')
        syn_embs = get_PCA(syn_embs)
        print('Done.')

    word_embs = {}
    for word in tqdm(wordlist):
        synsets = word2syn.get(word)
        sense_embs = []
        if synsets is not None and weight_type == 'inverse':
            for syn in synsets:
                try:
                    sense_embs.append(syn_embs[syn2id[syn]])
                except (KeyError, IndexError):
                    # Not all senses have embeddings (too infrequent / not in graph).
                    continue
        word_embs[word] = sense_embs if sense_embs else None
    return word_embs

def get_word_embeddings_synset(args, wordlist, lang, embed_filepath, weight_type='freq', syn_PCA=False):
    """Collect, for each word, the list of its sense embeddings from a text embedding file.

    ``embed_filepath`` has a header line followed by "<node id> <v1> <v2> ..."
    lines. The sense embeddings are optionally reduced with ``LSIM.get_PCA``
    and then looked up for each synset of the word, taken from
    ``<config.directories.wordsyns>/<LANG>.json``.

    Args:
        args: experiment config passed to get_language_list / get_concept_syns /
            get_syn2id_and_id2syn.
        wordlist: words to embed.
        lang: language code; upper-cased to select the word2syn JSON file.
        embed_filepath: path of the text embedding file.
        weight_type: only 'inverse' collects senses; any other value yields
            None for every word.
        syn_PCA: apply ``get_PCA`` to the sense embeddings first.

    Returns:
        dict word -> list of sense embeddings, or None when the word is absent
        or no synset has an embedding.
    """
    languages = get_language_list(args)
    concept_syns = get_concept_syns(args)
    syn2id, _ = get_syn2id_and_id2syn(args, languages, concept_syns)

    # Wordsyns are used instead of lemmas for better coverage.
    config = get_config()
    word2syn_file = os.path.join(config.directories.wordsyns, lang.upper() + ".json")
    with open(word2syn_file) as json_file:
        word2syn = json.load(json_file)

    syn_embs = {}
    with open(embed_filepath, 'r') as embed_file:
        next(embed_file, None)  # skip the header line
        for line in tqdm(embed_file):
            pieces = line.split()
            if not pieces:
                continue  # tolerate blank/trailing lines
            syn_embs[int(pieces[0])] = np.asarray([float(x) for x in pieces[1:]])
    if syn_PCA:
        from LSIM import get_PCA
        print('Getting PCA on sense embeddings...')
        syn_embs = get_PCA(syn_embs)
        print('Done.')

    word_embs = {}
    for word in tqdm(wordlist):
        synsets = word2syn.get(word)
        sense_embs = []
        if synsets is not None and weight_type == 'inverse':
            for syn in synsets:
                try:
                    sense_embs.append(syn_embs[syn2id[syn]])
                except (KeyError, IndexError):
                    # Not all senses have embeddings after node2vec training (too infrequent).
                    continue
        word_embs[word] = sense_embs if sense_embs else None
    return word_embs

def main(args):
    """Run one pipeline stage selected by the flags in ``args``.

    * ``args.get_lemmasyns``: create lemmasyns from the BabelNet text files.
    * ``args.compute_edges``: compute per-language colexification edges,
      optionally recomputing lemmas first when ``args.get_lemmas`` is set.
    * otherwise: build the graph named by ``args.graph_type`` and write its
      edge list. The "*limited_langs*" builders write their own edge lists,
      one per language subset.

    Raises:
        ValueError: if ``args.graph_type`` is not a known graph type.
    """
    os.makedirs(args.exp_dir, exist_ok=True)

    if args.get_lemmasyns:
        # Step 0: create lemmasyns from the BabelNet text files.
        get_lemmasyns(args)
        return

    if args.compute_edges:
        concept_syns = get_concept_syns(args)
        # Step 1: list of languages to check overlap with.
        languages = get_language_list(args)
        # Step 2: syn2id / id2syn are built over the FULL language list,
        # before any --specific_languages restriction below.
        syn2id, id2syn = get_syn2id_and_id2syn(args, languages, concept_syns)
        if args.specific_languages != "":
            languages = args.specific_languages.split("_")
        # Step 3: all lemmas per language, stored with their synset ids.
        if args.get_lemmas:
            get_lemmas(args, languages, syn2id, id2syn, concept_syns)
        # Step 4: edges.
        get_all_edges_per_lang(args, languages)
        return

    # graph_type -> (builder, whether main writes the returned edge dict).
    # Lambdas defer name lookup, so only the selected builder has to exist.
    graph_builders = {
        'cross_colex_binary': (lambda: build_cross_colex_binary_graph(args), True),
        'cross_colex_presence_absence': (lambda: build_cross_colex_presence_absence_graph(args), True),
        'cross_colex_sum': (lambda: build_cross_colex_sum_graph(args), True),
        'cross_colex_pairwise_product': (lambda: build_cross_colex_pairwise_product_graph(args), True),
        'cross_colex_sum_9': (lambda: build_cross_colex_sum_graph(args), True),
        'cross_colex_sum_50': (lambda: build_cross_colex_sum_graph(args), True),
        'colex_sum': (lambda: build_colex_sum_graph(args), True),
        # The limited-langs builders write all of their thresholds themselves.
        'cross_colex_sum_limited_langs': (lambda: build_cross_colex_sum_limited_langs_graph(args), False),
        'cross_colex_binary_limited_langs': (lambda: build_cross_colex_binary_limited_langs_graph(args), False),
        'colex_binary': (lambda: build_colex_binary_graph(args), True),
        'colex_binary_all': (lambda: build_colex_binary_all_graph(args), True),
        'colex_filtered_binary': (lambda: build_colex_filtered_binary_graph(args), True),
        'colex_all': (lambda: build_colex_all_graph(args), True),
        'colex_mono': (lambda: build_colex_mono_graph(args), True),
        'colex_limited_langs': (lambda: build_colex_limited_langs_graph(args), False),
        'colex_lang_count': (lambda: build_colex_lang_count_graph(args), True),
        'colex_limited_langs_version2': (lambda: build_colex_limited_langs_graph_version2(args), False),
    }
    try:
        builder, write_result = graph_builders[args.graph_type]
    except KeyError:
        # Previously an unknown graph_type silently did nothing.
        raise ValueError("Unknown graph_type " + repr(args.graph_type)
                         + "; expected one of: " + ", ".join(sorted(graph_builders))) from None
    all_edges = builder()
    if write_result:
        write_edgelist_dict_to_edgelist(args, all_edges)





if __name__ == "__main__":
    # Command-line entry point. The former inline comments are exposed as --help text.
    parser = argparse.ArgumentParser(
        description='Build colexification edges and graph edge lists from BabelNet synset data')
    parser.add_argument('--exp_dir', type=str, default='colex_from_AllBabelNet_Concepts',
                        help='experiment directory; created if it does not exist')
    parser.add_argument('--specific_languages', type=str, default='',
                        help='underscore-separated language codes to restrict edge computation to, '
                             'e.g. ru_hy (just for fixing problems with specific languages)')
    parser.add_argument('--cross_lingual_threshold', type=int, default=2,
                        help='edge must occur in at least this many languages to be cross-lingual')
    parser.add_argument('--filter_syns_by_concept', type=str2bool, default=True,
                        help='only keep synsets that are concepts')
    parser.add_argument('--synset_text_files_dir', type=str, default='synset_text_files')
    parser.add_argument('--synset_types_files_dir', type=str, default='BabelNet_Synset_Types')
    parser.add_argument('--lemma_savedir', type=str, default='lemmas',
                        help='lemma output directory, relative to exp_dir')
    parser.add_argument('--lemma_synIDdir', type=str, default='lemmas_synID')
    parser.add_argument('--word2syn_dir', type=str, default='wordsyns/word2syns_by_lang')
    parser.add_argument('--all_edges_dicts', type=str, default='all_edges',
                        help='directory of per-language edge files, relative to exp_dir')
    parser.add_argument('--get_lemmasyns', type=str2bool, default=False,
                        help='get lemmasyns first (takes a while)')
    parser.add_argument('--get_lemmas', type=str2bool, default=False,
                        help='get lemmas (set to False after doing it once)')
    parser.add_argument('--compute_edges', type=str2bool, default=False,
                        help="compute per-language edges; set to False if you've run that part already")
    parser.add_argument('--edgelist_savename', type=str, default='colex_limited_langs.edgelist',
                        help='output edge-list file name, relative to exp_dir')
    parser.add_argument('--edge_type', type=str, default='all', help='all or binary')
    parser.add_argument('--graph_type', type=str, default='colex_limited_langs_version2',
                        help='one of: cross_colex_binary, cross_colex_presence_absence, cross_colex_sum, '
                             'cross_colex_sum_9, cross_colex_sum_50, cross_colex_pairwise_product, '
                             'colex_sum, cross_colex_sum_limited_langs, cross_colex_binary_limited_langs, '
                             'colex_binary, colex_binary_all, colex_filtered_binary, colex_all, colex_mono, '
                             'colex_limited_langs, colex_lang_count, colex_limited_langs_version2')
    parser.add_argument('--monolingual_lang', type=str, default='zh',
                        help='language used for the single-language colex_mono graph')
    parser.add_argument('--num_task', type=int, default=16)
    parser.add_argument('--edge_threshold', type=int, default=1,
                        help='when collecting edges, remove edges with weight less than this')
    parser.add_argument('--DEBUG_COUNT', type=int, default=500000000000000000,
                        help='set low to run the code quickly to check everything works')
    args = parser.parse_args()
    main(args)