import os

import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm

import utils
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rc('pdf', fonttype=42)

def bar_plot_for_BabelNet_wordcount():
    """Save a log-scale bar chart of the number of BabelNet entries per language.

    Every file returned by ``collect_files(config.directories.wordsyns)`` is
    loaded with ``utils.load``; the language name is taken from the file name
    (text before the first ``.``) and the bar height is ``len()`` of the loaded
    object. Bars are ordered by decreasing height and the figure is written to
    ``BabelNet_bar_plot.pdf`` inside ``config.directories.plots``.
    """
    SMALL_SIZE = 7  # was 2.5
    matplotlib.rc('font', size=SMALL_SIZE)
    plt.rcParams["figure.figsize"] = (8, 3)
    config = utils.get_config()

    # Step 1: for each language, load the file produced by passing our
    # wordlists through BabelNet and record how many entries it holds.
    files = collect_files(config.directories.wordsyns)
    length_dict = {}
    for file in tqdm(files):
        lang_name = file.split('/')[-1].split('.')[0]
        length_dict[lang_name] = len(utils.load(file))

    # Step 2: sort by entry count so the plot shows languages in decreasing order.
    sorted_dict = dict(sorted(length_dict.items(), key=lambda item: item[1], reverse=True))
    languages = list(sorted_dict)
    lengths = list(sorted_dict.values())

    # One thin bar per language, spread evenly over the range [0, 40].
    width = [0.2] * len(sorted_dict)
    x_pos = list(np.linspace(start=0, stop=40, num=len(sorted_dict)))

    plt.bar(x_pos, lengths, width=width)
    plt.xticks(x_pos, languages, rotation=90, size=2.5)
    # Base 10 is matplotlib's default for log axes. The old ``basey`` keyword
    # was removed in matplotlib 3.5, so it is not passed here.
    plt.yscale('log')

    plt.savefig(os.path.join(config.directories.plots, "BabelNet_bar_plot.pdf"))
    # Release the implicit pyplot figure so later plots start from a clean one.
    plt.close()

def get_plot_color_codes():
    """Return a fresh dict mapping each plotted series name to a hex colour.

    Keys are the nine language codes followed by the three analysis metrics.
    """
    language_colors = [
        ('ar', '#ff5733'),
        ('en', '#a4a01e'),
        ('es', '#41a41e'),
        ('fi', '#53a588'),
        ('fr', '#2a7158'),
        ('he', '#17c9c1'),
        ('pl', '#000000'),
        ('ru', '#5a17c9'),
        ('zh', '#8b6fb9'),
    ]
    metric_colors = [
        ('syn_overlap', '#ff5733'),
        ('edge_diffs', '#a4a01e'),
        ('perc_overlap_syn_per_word', '#41a41e'),
    ]
    return dict(language_colors + metric_colors)

def get_plot_symbols():
    """Return a fresh dict mapping each plotted series name to its marker code."""
    series_names = ('ar', 'en', 'es', 'fi', 'fr', 'he', 'pl', 'ru', 'zh',
                    'syn_overlap', 'edge_diffs', 'perc_overlap_syn_per_word')
    # Marker codes listed in the same order as ``series_names``.
    marker_codes = ('v', '^', '<', '>', 'P', 's', 'h', '*', 'X',
                    'X', 's', '*')
    return dict(zip(series_names, marker_codes))

def get_plot_styles():
    """Return a fresh dict mapping each plotted series name to its line style.

    Every series currently uses the dashed style ``'--'``.
    """
    series_names = ('ar', 'en', 'es', 'fi', 'fr', 'he', 'pl', 'ru', 'zh',
                    'syn_overlap', 'edge_diffs', 'perc_overlap_syn_per_word')
    return dict.fromkeys(series_names, '--')

def get_plot_linewidths():
    """Return a fresh dict mapping each plotted series name to its line width.

    Every series currently uses a width of 2.
    """
    series_names = ('ar', 'en', 'es', 'fi', 'fr', 'he', 'pl', 'ru', 'zh',
                    'syn_overlap', 'edge_diffs', 'perc_overlap_syn_per_word')
    return {name: 2 for name in series_names}

def lang_inventory_LSIM_plot(args):
    """Plot per-language Spearman correlation against language-inventory size.

    Loads the ``LangInventory~colex_*`` result pickles for the 5-, 10-, 20-,
    50-, 100- and 200-language graphs plus the ``colex_all`` graph (plotted
    under the tick label '499'), draws one line per language, and saves the
    figure as ``langInventoryEffect.pdf`` in ``config.directories.plots``.

    Args:
        args: Object whose ``languages`` attribute is an underscore-separated
            string of language codes to plot (e.g. ``"en_fr"``).
    """
    config = get_config()

    def _load_result(filename):
        return utils.load(os.path.join(config.directories.results, filename))

    full_inventory = _load_result('LangInventory~colex_all.pkl')
    limited_inventories = [
        _load_result('LangInventory~colex_limited_langs5.pkl'),
        _load_result('LangInventory~colex_limited_langs10.pkl'),
        _load_result('LangInventory~colex_limited_langs20.pkl'),
        _load_result('LangInventory~colex_limited_langs50.pkl'),
        _load_result('LangInventory~colex_limited_langs100.pkl'),
        _load_result('LangInventory~colex_limited_langs200.pkl'),
    ]
    # Points run from the smallest inventory to the full one.
    inventories = limited_inventories + [full_inventory]

    # Evenly spaced positions; the real inventory sizes appear only as labels.
    tick_labels = ['5', '10', '20', '50', '100', '200', '499']
    tick_positions = list(range(len(tick_labels)))
    plt.xticks(tick_positions, tick_labels)

    languages = args.languages.split("_")
    palette = get_plot_color_codes()
    widths = get_plot_linewidths()
    markers = get_plot_symbols()
    line_styles = get_plot_styles()
    for lang in languages:
        heights = [inventory[lang]['spearman_rank_corr'] for inventory in inventories]
        # Legend entry shows the number of pairs from the full-inventory run.
        legend_label = "{} ({})".format(lang.upper(), full_inventory[lang]['num_pairs'])
        plt.plot(tick_positions, heights,
                 palette[lang],
                 linestyle=line_styles[lang],
                 marker=markers[lang],
                 label=legend_label,
                 linewidth=widths[lang])

    # Anchor chosen so the legend sits above the axes.
    plt.legend(bbox_to_anchor=(0.91, 1.21), ncol=3, fancybox=False, shadow=False, framealpha=1.0)
    plt.grid(axis='y')
    plt.ylabel('Spearman rank correlation ρ')
    plt.xlabel('Number of Languages Used for Colexification Graph Construction')

    output_path = os.path.join(config.directories.plots, 'langInventoryEffect.pdf')
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.05)
    plt.close()

def lang_inventory_LSIM_plot_version2(args):
    """Plot Spearman correlation against inventory size, starting at 9 languages.

    Loads the ``LangInventory~colex_*`` result pickles for the 9-, 20-, 50-,
    100- and 200-language graphs plus the ``colex_all`` graph (plotted under
    the tick label '499'), draws one line per language, and saves the figure
    as ``langInventoryEffect_version2.pdf`` in ``config.directories.plots``.

    Args:
        args: Object whose ``languages`` attribute is an underscore-separated
            string of language codes to plot (e.g. ``"en_fr"``).
    """
    config = get_config()

    def _load_result(filename):
        return utils.load(os.path.join(config.directories.results, filename))

    full_inventory = _load_result('LangInventory~colex_all.pkl')
    limited_inventories = [
        _load_result('LangInventory~colex_limited_langs9.pkl'),
        _load_result('LangInventory~colex_limited_langs20.pkl'),
        _load_result('LangInventory~colex_limited_langs50.pkl'),
        _load_result('LangInventory~colex_limited_langs100.pkl'),
        _load_result('LangInventory~colex_limited_langs200.pkl'),
    ]
    # Points run from the smallest inventory to the full one.
    inventories = limited_inventories + [full_inventory]

    # Evenly spaced positions; the real inventory sizes appear only as labels.
    tick_labels = ['9', '20', '50', '100', '200', '499']
    tick_positions = list(range(len(tick_labels)))
    plt.xticks(tick_positions, tick_labels)

    languages = args.languages.split("_")
    palette = get_plot_color_codes()
    widths = get_plot_linewidths()
    markers = get_plot_symbols()
    line_styles = get_plot_styles()
    for lang in languages:
        heights = [inventory[lang]['spearman_rank_corr'] for inventory in inventories]
        # Legend entry shows the number of pairs from the full-inventory run.
        legend_label = "{} ({})".format(lang.upper(), full_inventory[lang]['num_pairs'])
        plt.plot(tick_positions, heights,
                 palette[lang],
                 linestyle=line_styles[lang],
                 marker=markers[lang],
                 label=legend_label,
                 linewidth=widths[lang])

    # Anchor chosen so the legend sits above the axes.
    plt.legend(bbox_to_anchor=(0.91, 1.21), ncol=3, fancybox=False, shadow=False, framealpha=1.0)
    plt.grid(axis='y')
    plt.ylabel('Spearman rank correlation ρ')
    plt.xlabel('Number of Languages Used for Colexification Graph Construction')

    output_path = os.path.join(config.directories.plots, 'langInventoryEffect_version2.pdf')
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.05)
    plt.close()

def lang_inventory_analysis_plot(args):
    """Plot the language-inventory analysis metrics on a log-scaled x axis.

    Reads ``LangInventoryAnalysis.pkl`` from ``config.directories.results`` and
    draws three series at x = 5, 10, 20, 50 and 499:

    * ``syn_overlap`` and ``perc_overlap_syn_per_word``: the stored values
      under the keys ``five``/``ten``/``twenty``/``fifty``/``all``;
    * ``edge_diffs``: the ``mean`` entry of each of those keys, with the
      ``std`` entry drawn as error bars.

    The figure is saved as ``langInventoryAnalysis.pdf`` in
    ``config.directories.plots``.

    Args:
        args: Unused by this function.
    """
    config = get_config()
    results = utils.load(os.path.join(config.directories.results, 'LangInventoryAnalysis.pkl'))

    x_values = [5, 10, 20, 50, 499]
    # Result keys in the same order as ``x_values``.
    inventory_keys = ["five", "ten", "twenty", "fifty", "all"]

    color_codes = get_plot_color_codes()
    linewidths = get_plot_linewidths()
    symbols = get_plot_symbols()
    styles = get_plot_styles()
    # Only edge_diffs stores mean/std pairs; the other metrics are plain values.
    err_bars = {"syn_overlap": False, "edge_diffs": True, "perc_overlap_syn_per_word": False}
    labels = {"syn_overlap": "Synset Pair Overlap",
              "edge_diffs": "Mean Normalized Edge Weight Difference",
              "perc_overlap_syn_per_word": "Mean Synset Overlap Per Evaluation Word"}

    for data_type in ["syn_overlap", "edge_diffs", "perc_overlap_syn_per_word"]:
        series = results[data_type]
        label = labels[data_type]
        if err_bars[data_type]:
            plot_heights = [series[key]["mean"] for key in inventory_keys]
            yerr = [series[key]["std"] for key in inventory_keys]
            plt.errorbar(x_values,
                         plot_heights,
                         yerr=yerr,
                         color=color_codes[data_type],
                         linestyle=styles[data_type],
                         marker=symbols[data_type],
                         label=label,
                         linewidth=linewidths[data_type])
        else:
            plot_heights = [series[key] for key in inventory_keys]
            plt.plot(x_values, plot_heights,
                     color_codes[data_type],
                     linestyle=styles[data_type],
                     marker=symbols[data_type],
                     label=label,
                     linewidth=linewidths[data_type])

    plt.legend(bbox_to_anchor=(0.86, 1.21), ncol=1, fancybox=False, shadow=False, framealpha=1.0)
    plt.grid(axis='y')
    # Base 10 is matplotlib's default for log axes. The old ``basex`` keyword
    # was removed in matplotlib 3.5, so it is not passed here.
    plt.xscale('log')
    plt.ylabel('Fraction')
    plt.xlabel('Number of Languages Used for Similarity Graph Construction')

    dump_path = os.path.join(config.directories.plots, 'langInventoryAnalysis.pdf')
    plt.savefig(dump_path, bbox_inches='tight', pad_inches=0.05)
    plt.close()

def LSIM_table(args):
    """Print the rows of a LaTeX table of per-language LSIM Spearman scores.

    Loads the ``LSIM~*.pkl`` result dictionaries for fastText, BERT, ARES and
    the three COLEX variants from ``config.directories.results`` and prints:

    1. a header row with the upper-cased ``config.eval_languages`` followed by
       ``Mean`` and ``Std.``;
    2. a row with the number of evaluated pairs per language;
    3. one row per method with ``spearman_rank_corr`` per language and the
       mean and standard deviation over languages, all to two decimals.

    Args:
        args: Unused by this function.

    Raises:
        AssertionError: If the methods report different ``num_pairs`` for a
            language.
    """
    config = get_config()
    results_dir = config.directories.results

    colex_mono = utils.load(os.path.join(results_dir, 'LSIM~colex_mono.pkl'))
    colex_all = utils.load(os.path.join(results_dir, 'LSIM~colex_all.pkl'))
    colex_all_maxsim = utils.load(os.path.join(results_dir, 'LSIM~colex_all_maxsim.pkl'))
    fasttext = utils.load(os.path.join(results_dir, 'LSIM~fasttext.pkl'))
    BERT = utils.load(os.path.join(results_dir, 'LSIM~BERT.pkl'))
    ARES = utils.load(os.path.join(results_dir, 'LSIM~ARES.pkl'))

    # (row header, result dict) in the order the rows are printed.
    methods = [
        ("fastText ", fasttext),
        ("BERT ", BERT),
        ("ARES ", ARES),
        ("COLEX\\textsubscript{mono} ", colex_mono),
        ("COLEX\\textsubscript{maxsim} ", colex_all_maxsim),
        ("COLEX\\textsubscript{cross} ", colex_all),
    ]
    eval_languages = config.eval_languages

    first_line = " & " + "".join(lang.upper() + " & " for lang in eval_languages)
    first_line += " Mean & Std. \\\\ "

    # Backslashes are doubled explicitly: "\s" and "\h" are invalid escape
    # sequences and trigger a SyntaxWarning on recent Python versions.
    pair_line = " "
    for lang in eval_languages:
        num_pairs = colex_mono[lang]['num_pairs']
        # All methods must have been evaluated on the same pairs.
        for header, result in methods:
            assert num_pairs == result[lang]['num_pairs'], \
                "num_pairs mismatch for " + header.strip() + " on " + lang
        pair_line += " & {\\small(" + str(num_pairs) + ")} "
    pair_line += " & $\\uparrow$ & $\\downarrow$ \\\\ \\hline"

    data_lines = []
    for header, result in methods:
        scores = [result[lang]['spearman_rank_corr'] for lang in eval_languages]
        line = header + "".join(" & " + "{:.2f}".format(score) for score in scores)
        mean = np.mean(np.asarray(scores))
        std = np.std(np.asarray(scores))
        line += " & " + "{:.2f}".format(mean) + " & " + "{:.2f}".format(std)
        line += " \\\\ "
        data_lines.append(line)

    print(first_line)
    print(pair_line)
    for line in data_lines:
        print(line)

def LSIM_table_version2(args):
    """Print the LSIM LaTeX table rows, including the fusion methods.

    Loads the ``LSIM~*.pkl`` result dictionaries for fastText, BERT, ARES, the
    three COLEX variants and the four fusion runs (C+F, C+B, B+F, C+F+B) from
    ``config.directories.results`` and prints:

    1. a header row with the upper-cased ``config.eval_languages`` followed by
       ``Mean`` and ``Std.``;
    2. a row with the number of evaluated pairs per language;
    3. one row per method with ``spearman_rank_corr`` per language and the
       mean and standard deviation over languages, all to two decimals.

    Args:
        args: Unused by this function.

    Raises:
        AssertionError: If the methods report different ``num_pairs`` for a
            language.
    """
    config = get_config()
    results_dir = config.directories.results

    colex_mono = utils.load(os.path.join(results_dir, 'LSIM~colex_mono.pkl'))
    colex_all = utils.load(os.path.join(results_dir, 'LSIM~colex_all.pkl'))
    colex_all_maxsim = utils.load(os.path.join(results_dir, 'LSIM~colex_all_maxsim.pkl'))
    fasttext = utils.load(os.path.join(results_dir, 'LSIM~fasttext.pkl'))
    BERT = utils.load(os.path.join(results_dir, 'LSIM~BERT.pkl'))
    ARES = utils.load(os.path.join(results_dir, 'LSIM~ARES.pkl'))
    CF_fusion = utils.load(os.path.join(results_dir, 'LSIM~colex_all~fasttext.pkl'))
    CB_fusion = utils.load(os.path.join(results_dir, 'LSIM~colex_all~BERT.pkl'))
    BF_fusion = utils.load(os.path.join(results_dir, 'LSIM~BERT~fasttext.pkl'))
    CBF_fusion = utils.load(os.path.join(results_dir, 'LSIM~colex_all~fasttext~BERT.pkl'))

    # (row header, result dict) in the order the rows are printed.
    methods = [
        ("fastText ", fasttext),
        ("BERT ", BERT),
        ("ARES ", ARES),
        ("COLEX\\textsubscript{mono} ", colex_mono),
        ("COLEX\\textsubscript{maxsim} ", colex_all_maxsim),
        ("COLEX\\textsubscript{cross} ", colex_all),
        ("C+F ", CF_fusion),
        ("C+B ", CB_fusion),
        ("B+F ", BF_fusion),
        ("C+F+B ", CBF_fusion),
    ]
    eval_languages = config.eval_languages

    first_line = " & " + "".join(lang.upper() + " & " for lang in eval_languages)
    first_line += " Mean & Std. \\\\ "

    # Backslashes are doubled explicitly: "\s" and "\h" are invalid escape
    # sequences and trigger a SyntaxWarning on recent Python versions.
    pair_line = " "
    for lang in eval_languages:
        num_pairs = colex_mono[lang]['num_pairs']
        # All methods must have been evaluated on the same pairs.
        for header, result in methods:
            assert num_pairs == result[lang]['num_pairs'], \
                "num_pairs mismatch for " + header.strip() + " on " + lang
        pair_line += " & {\\small(" + str(num_pairs) + ")} "
    pair_line += " & $\\uparrow$ & $\\downarrow$ \\\\ \\hline"

    data_lines = []
    for header, result in methods:
        scores = [result[lang]['spearman_rank_corr'] for lang in eval_languages]
        line = header + "".join(" & " + "{:.2f}".format(score) for score in scores)
        mean = np.mean(np.asarray(scores))
        std = np.std(np.asarray(scores))
        line += " & " + "{:.2f}".format(mean) + " & " + "{:.2f}".format(std)
        line += " \\\\ "
        data_lines.append(line)

    print(first_line)
    print(pair_line)
    for line in data_lines:
        print(line)

def Concat_LSIM_table(args):
    """Print the LaTeX table rows for the concatenation (CONCAT) experiments.

    Loads the ``CONCAT~*.pkl`` result dictionaries for C (``colex_all``), F
    (``fasttext``), B (``BERT``) and their combinations from
    ``config.directories.results`` and prints:

    1. a header row with the upper-cased ``config.eval_languages`` followed by
       ``Mean`` and ``Std.``;
    2. a row with the number of evaluated pairs per language;
    3. one row per method with ``spearman_rank_corr`` per language and the
       mean and standard deviation over languages, all to two decimals.

    Args:
        args: Unused by this function.

    Raises:
        AssertionError: If the methods report different ``num_pairs`` for a
            language.
    """
    config = get_config()
    results_dir = config.directories.results

    C = utils.load(os.path.join(results_dir, 'CONCAT~colex_all.pkl'))
    B = utils.load(os.path.join(results_dir, 'CONCAT~BERT.pkl'))
    F = utils.load(os.path.join(results_dir, 'CONCAT~fasttext.pkl'))
    CF = utils.load(os.path.join(results_dir, 'CONCAT~colex_all~fasttext.pkl'))
    CB = utils.load(os.path.join(results_dir, 'CONCAT~colex_all~BERT.pkl'))
    BF = utils.load(os.path.join(results_dir, 'CONCAT~BERT~fasttext.pkl'))
    CFB = utils.load(os.path.join(results_dir, 'CONCAT~colex_all~fasttext~BERT.pkl'))

    # (row header, result dict) in the order the rows are printed.
    methods = [
        ("C ", C),
        ("F ", F),
        ("B ", B),
        ("C+F ", CF),
        ("C+B ", CB),
        ("B+F ", BF),
        ("C+F+B ", CFB),
    ]
    eval_languages = config.eval_languages

    first_line = " " + "".join(" & " + lang.upper() for lang in eval_languages)
    first_line += " & Mean & Std. \\\\ "

    # Backslashes are doubled explicitly: "\s" and "\h" are invalid escape
    # sequences and trigger a SyntaxWarning on recent Python versions.
    pair_line = " "
    for lang in eval_languages:
        num_pairs = C[lang]['num_pairs']
        # All methods must have been evaluated on the same pairs.
        for header, result in methods:
            assert num_pairs == result[lang]['num_pairs'], \
                "num_pairs mismatch for " + header.strip() + " on " + lang
        pair_line += " & {\\small(" + str(num_pairs) + ")} "
    pair_line += " & $\\uparrow$ & $\\downarrow$ \\\\ \\hline"

    data_lines = []
    for header, result in methods:
        scores = [result[lang]['spearman_rank_corr'] for lang in eval_languages]
        line = header + "".join(" & " + "{:.2f}".format(score) for score in scores)
        mean = np.mean(np.asarray(scores))
        std = np.std(np.asarray(scores))
        line += " & " + "{:.2f}".format(mean) + " & " + "{:.2f}".format(std)
        line += " \\\\ "
        data_lines.append(line)

    print(first_line)
    print(pair_line)
    for line in data_lines:
        print(line)

def POS_LSIM_table(args):
    """Print the rows of a LaTeX table of mean Spearman scores per part of speech.

    For every method and part of speech, loads
    ``'POS' + pos_type + '_' + method + '.pkl'`` from
    ``config.directories.results``, averages ``spearman_rank_corr`` over all
    languages in that file, and prints a header row, a row with the mean
    number of pairs per part of speech, and one row per method.

    The column order (Noun, Verb, Adj., Adv.) is defined once and used for
    both the header and the data. Previously the data was iterated as nouns,
    adjectives, verbs, adverbs under a Noun/Verb/Adj./Adv. header, which
    swapped the verb and adjective columns relative to their labels.

    Args:
        args: Unused by this function.

    Raises:
        AssertionError: If two methods report a different mean ``num_pairs``
            for the same part of speech.
    """
    config = get_config()
    methods = ['fasttext', 'BERT', 'ARES', 'colex_sum', 'cross_colex_sum_senses_PCA', 'cross_colex_sum']
    # (key used in the result file names, column label), in printed order.
    pos_columns = [('nouns', 'Noun'),
                   ('verbs', 'Verb'),
                   ('adjectives', 'Adj.'),
                   ('adverbs', 'Adv.')]
    pos_types = [pos_type for pos_type, _ in pos_columns]

    data_dict = {}
    for method in methods:
        data_dict[method] = {}
        for pos_type in pos_types:
            value = utils.load(os.path.join(config.directories.results, 'POS' + pos_type + '_' + method + '.pkl'))
            data_dict[method][pos_type] = value

    headers = {"fasttext": "fastText ",
               "BERT": "BERT ",
               "ARES": "ARES ",
               "colex_sum": "COLEX\\textsubscript{mono} ",
               "cross_colex_sum_senses_PCA": "COLEX\\textsubscript{sense} ",
               "cross_colex_sum": "COLEX\\textsubscript{cross} ",
               }

    first_line = "  & " + " & ".join(label for _, label in pos_columns) + " \\\\ "
    lines = []
    mean_pairs_dict = {}
    for method in methods:
        tmp_line = headers[method]
        for pos_type in pos_types:
            per_language = data_dict[method][pos_type]
            spearman_rank_list = [info['spearman_rank_corr'] for info in per_language.values()]
            num_pairs_list = [info['num_pairs'] for info in per_language.values()]
            mean_spearman = np.mean(np.asarray(spearman_rank_list))
            mean_pairs = np.mean(np.asarray(num_pairs_list))
            # Every method must have been evaluated on the same pairs.
            if pos_type in mean_pairs_dict:
                assert mean_pairs == mean_pairs_dict[pos_type]
            else:
                mean_pairs_dict[pos_type] = mean_pairs
            tmp_line += " & " + "{:.2f}".format(mean_spearman)
        tmp_line += " \\\\ "
        lines.append(tmp_line)

    # Backslashes are doubled explicitly: "\s" and "\h" are invalid escape
    # sequences and trigger a SyntaxWarning on recent Python versions.
    pair_line = " "
    for value in mean_pairs_dict.values():
        pair_line += " & {\\small(" + "{:.1f}".format(value) + ")} "
    pair_line += " \\\\ \\hline"

    print(first_line)
    print(pair_line)
    for line in lines:
        print(line)

def CrossLingual_LSIM(args):
    """Print the LaTeX rows of the cross-lingual LSIM score matrix.

    One row is printed per language in ``args.languages`` (codes joined by
    underscores).  Each row starts with the upper-cased language code,
    followed by one cell per column language:

    * the diagonal cell is ``-``;
    * cells before the diagonal show the ``colex_all`` Spearman correlation;
    * cells after the diagonal show the ``fasttext`` Spearman correlation.

    A cell is wrapped in ``\\textbf{}`` when ``args.bold_best_CLSIM`` is true
    and its score, rounded to two decimals, is at least the other method's
    rounded score for the same language pair.

    Results are read from ``CROSSLINGUAL~<method>.pkl`` in
    ``config.directories.results``; each file is indexed by
    ``"<lang>_<lang>"`` keys (either order is accepted) whose values hold a
    ``'spearman_rank_corr'`` entry.

    Args:
        args: namespace providing ``languages`` and ``bold_best_CLSIM``.
    """
    config = get_config()
    top_method = 'fasttext'
    bottom_method = 'colex_all'
    data_dict = {}
    for method in [top_method, bottom_method]:
        data_dict[method] = utils.load(os.path.join(config.directories.results, 'CROSSLINGUAL~' + method + '.pkl'))

    def pair_score(method, first_lang, second_lang):
        """Return the stored correlation for the pair in either key order, else None."""
        results = data_dict[method]
        for key in (first_lang + "_" + second_lang, second_lang + "_" + first_lang):
            if key in results:
                return results[key]['spearman_rank_corr']
        return None

    lines = []
    languages = args.languages.split("_")
    for row_lang in languages:
        # Rows are emitted one language at a time.
        done_with_bottom = False
        line = row_lang.upper() + " "
        for col_lang in languages:
            if row_lang == col_lang:
                done_with_bottom = True
                line += " & - "
                continue

            # Look the scores up afresh for every pair so that a missing pair
            # can never inherit the scores of the previously visited pair.
            bottom_score = pair_score(bottom_method, row_lang, col_lang)
            top_score = pair_score(top_method, row_lang, col_lang)
            if top_score is None or bottom_score is None:
                # Keep the column count intact instead of shifting later cells left.
                line += " & n/a"
                continue

            if not done_with_bottom:
                score = bottom_score
                bold = round(bottom_score, 2) >= round(top_score, 2)
            else:
                score = top_score
                bold = round(top_score, 2) >= round(bottom_score, 2)

            if bold and args.bold_best_CLSIM:
                line += " & \\textbf{" + "{:.2f}".format(score) + "}"
            else:
                line += " & " + "{:.2f}".format(score)
        line = line + " \\\\ "
        lines.append(line)
    for line in lines:
        print(line)

def CrossLingual_LSIM_OOV_table(args):
    """Print the LaTeX rows of the cross-lingual OOV coverage matrix.

    One row is printed per language in ``args.languages`` (codes joined by
    underscores).  Each row starts with the upper-cased language code,
    followed by one cell per column language: ``-`` on the diagonal,
    otherwise ``$\\frac{completed_pairs}{total_pairs}$``.  Cells before the
    diagonal take their counts from ``colex_all`` and cells after it from
    ``fasttext``; the two methods are asserted to report identical counts.

    Results are read from ``CROSSLINGUAL_OOV~<method>.pkl`` in
    ``config.directories.results``; each file is indexed by
    ``"<lang>_<lang>"`` keys (either order is accepted) whose values hold
    ``'completed_pairs'`` and ``'total_pairs'`` entries.

    Args:
        args: namespace providing ``languages``.

    Raises:
        AssertionError: if the two methods disagree on the counts for a pair
            (including the case where only one of them has the pair).
    """
    config = get_config()
    top_method = 'fasttext'
    bottom_method = 'colex_all'
    data_dict = {}
    for method in [top_method, bottom_method]:
        data_dict[method] = utils.load(os.path.join(config.directories.results, 'CROSSLINGUAL_OOV~' + method + '.pkl'))

    def pair_counts(method, first_lang, second_lang):
        """Return (completed, total) for the pair in either key order, else (None, None)."""
        results = data_dict[method]
        for key in (first_lang + "_" + second_lang, second_lang + "_" + first_lang):
            if key in results:
                return results[key]['completed_pairs'], results[key]['total_pairs']
        return None, None

    lines = []
    languages = args.languages.split("_")
    for row_lang in languages:
        # Rows are emitted one language at a time.
        done_with_bottom = False
        line = row_lang.upper() + " "
        for col_lang in languages:
            if row_lang == col_lang:
                done_with_bottom = True
                line += " & - "
                continue

            # Look the counts up afresh for every pair so that a missing pair
            # can never inherit the counts of the previously visited pair.
            bottom_score_complete, bottom_score_total = pair_counts(bottom_method, row_lang, col_lang)
            top_score_complete, top_score_total = pair_counts(top_method, row_lang, col_lang)

            # We're checking OOV, so both methods must be evaluated on the same set.
            assert bottom_score_complete == top_score_complete, \
                "completed_pairs differ between methods for " + row_lang + "/" + col_lang
            assert bottom_score_total == top_score_total, \
                "total_pairs differ between methods for " + row_lang + "/" + col_lang

            if top_score_complete is None or bottom_score_complete is None:
                # Keep the column count intact instead of shifting later cells left.
                line += " & n/a"
                continue

            if not done_with_bottom:
                completed_pairs = bottom_score_complete
                total_pairs = bottom_score_total
            else:
                completed_pairs = top_score_complete
                total_pairs = top_score_total
            line += " & " + "$\\frac{" + str(completed_pairs) + "}{" + str(total_pairs) + "}$"
        line = line + " \\\\ \\hline"
        lines.append(line)
    for line in lines:
        print(line)

def LSIM_correlation_rank_maps(args):
    """Save one gold-vs-computed rank heatmap per model, pooled over languages.

    For every model the similarity scores of each language are converted to
    ranks (tie handling given by ``args.rank_method``), the (gold rank,
    computed rank) pairs of all languages are pooled, binned into a square
    15x15 2-D histogram and written as a PDF to ``config.directories.plots``.
    The largest rank seen per model and, at the end, the largest bin count
    over all models are printed.

    Args:
        args: namespace providing ``rank_method``.
    """
    tie_method = args.rank_method
    config = get_config()

    # Result file per model, in the order the plots are produced.
    result_files = (
        ("colex_all", "SIMSCORE~LSIM~colex_all.pkl"),
        ("colex_mono", "SIMSCORE~LSIM~colex_mono.pkl"),
        ("fasttext", "SIMSCORE~LSIM~fasttext.pkl"),
        ("Fusion", "SIMSCORE~LSIM~colex_all~fasttext.pkl"),
    )
    simscore_data = {
        name: utils.load(os.path.join(config.directories.results, file_name))
        for name, file_name in result_files
    }

    # Output file name per model (includes models that are currently not loaded).
    plot_names = {
        "colex_all": "RankPlot_COLEX_cross.pdf",
        "colex_mono": "RankPlot_COLEX_mono.pdf",
        "colex_maxsim": "RankPlot_COLEX_maxsim.pdf",
        "fasttext": "RankPlot_fasttext.pdf",
        "BERT": "RankPlot_BERT.pdf",
        "ARES": "RankPlot_ARES.pdf",
        "Fusion": "RankPlot_C+F_Fusion.pdf",
    }

    full_max_val = -1
    for model_type, simdata in simscore_data.items():
        rank_pairs = []
        max_gt_rank = -1
        max_our_rank = -1
        for lang_dict in simdata.values():
            gold = np.asarray([entry["ground_truth"] for entry in lang_dict.values()])
            computed = np.asarray([entry["our_score"] for entry in lang_dict.values()])
            gold_ranks = scipy.stats.rankdata(gold, method=tie_method)
            max_gt_rank = max(max_gt_rank, np.max(gold_ranks))
            computed_ranks = scipy.stats.rankdata(computed, method=tie_method)
            max_our_rank = max(max_our_rank, np.max(computed_ranks))
            rank_pairs.extend(zip(gold_ranks, computed_ranks))

        print("Max ground truth rank for " + model_type + " is " + str(max_gt_rank))
        print("Max our rank for " + model_type + " is " + str(max_our_rank))

        pooled = np.asarray(rank_pairs)
        gold_axis = pooled[:, 0]
        computed_axis = pooled[:, 1]

        # The plot should be square, so both axes share the larger of the two maxima.
        bounder = max(max_gt_rank, max_our_rank)
        heatmap, xedges, yedges = np.histogram2d(
            gold_axis, computed_axis, bins=15, range=[[0, bounder], [0, bounder]]
        )
        full_max_val = max(full_max_val, np.max(heatmap))

        plt.clf()
        plt.imshow(
            heatmap.T,
            extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]],
            origin='lower',
        )
        plt.set_cmap("hot")
        # Fixed colour scale shared by all models; 435 sits just above the
        # largest bin count (432) recorded for the four methods in the paper.
        plt.clim(0, 435)
        plt.colorbar()
        plt.xlabel("Gold Standard Rank")
        plt.ylabel("Computed Rank")

        savepath = os.path.join(config.directories.plots, plot_names[model_type])
        plt.savefig(savepath, bbox_inches='tight', pad_inches=0.05)
    print(full_max_val)

def LSIM_correlation_rank_maps_individual_langs(args):
    """Save one gold-vs-computed rank heatmap per model and per language.

    For every model and language the similarity scores are converted to
    ranks (tie handling given by ``args.rank_method``), binned into a square
    10x10 2-D histogram, normalised to percentages and written as
    ``RankPlot_<model>_<LANG>.pdf`` to ``config.directories.plots``.  The
    largest percentage found in any bin is printed at the end.

    Each plot is built from the current language's points only.  Earlier
    revisions histogrammed a list that kept growing across languages, so
    every plot after the first also contained the points of the languages
    processed before it.

    Args:
        args: namespace providing ``rank_method``.
    """
    rank_method = args.rank_method
    config = get_config()
    simscore_data = {}
    simscore_data["colex_all"] = utils.load(os.path.join(config.directories.results, "SIMSCORE~LSIM~colex_all.pkl"))
    simscore_data["colex_mono"] = utils.load(os.path.join(config.directories.results, "SIMSCORE~LSIM~colex_mono.pkl"))
    simscore_data["Fusion"] = utils.load(os.path.join(config.directories.results, "SIMSCORE~LSIM~colex_all~fasttext.pkl"))

    # File-name prefix per model (includes models that are currently not loaded).
    plot_prefixes = {
        "colex_all": "RankPlot_COLEX_cross_",
        "colex_mono": "RankPlot_COLEX_mono_",
        "colex_maxsim": "RankPlot_COLEX_maxsim_",
        "fasttext": "RankPlot_fasttext_",
        "BERT": "RankPlot_BERT_",
        "ARES": "RankPlot_ARES_",
        "Fusion": "RankPlot_C+F_Fusion_",
    }

    full_percentage_max = -1
    for model_type, simdata in simscore_data.items():
        for lang, lang_dict in simdata.items():
            ground_truth = np.asarray([entry["ground_truth"] for entry in lang_dict.values()])
            our_scores = np.asarray([entry["our_score"] for entry in lang_dict.values()])
            ground_truth_ranks = scipy.stats.rankdata(ground_truth, method=rank_method)
            our_ranks = scipy.stats.rankdata(our_scores, method=rank_method)

            # The plot should be square, so both axes share the larger of the two maxima.
            bounder = max(np.max(ground_truth_ranks), np.max(our_ranks))
            # Histogram this language's ranks only (see docstring).
            heatmap, xedges, yedges = np.histogram2d(
                ground_truth_ranks, our_ranks, bins=10, range=[[0, bounder], [0, bounder]]
            )

            # Normalise the counts and express them as percentages for the plot.
            heatmap = heatmap / np.sum(heatmap)
            heatmap = 100 * heatmap
            percent_max = np.max(heatmap)
            if percent_max > full_percentage_max:
                full_percentage_max = percent_max
            extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

            plt.clf()
            plt.imshow(heatmap.T, extent=extent, origin='lower')
            plt.set_cmap("hot")
            # NOTE(review): 9.3 was tuned on output produced before the
            # per-language fix; compare with the printed full_percentage_max
            # and adjust if the colour scale no longer fits.
            plt.clim(0, 9.3)
            plt.colorbar()
            plt.xlabel("Gold Standard Rank")
            plt.ylabel("Computed Rank")

            savepath = os.path.join(config.directories.plots, plot_prefixes[model_type] + lang.upper() + ".pdf")
            plt.savefig(savepath, bbox_inches='tight', pad_inches=0.05)
    print(full_percentage_max)

def LSIM_dissimilar_words_similarly_ranked(args):
    """Write, per model and language, the word pairs ranked low by both sides.

    Scores are converted to ranks (tie handling given by
    ``args.rank_method``).  Every word pair whose gold-standard rank and
    computed rank are both at most 200 is written, together with the two
    ranks, to ``dissimilar_words_similarly_ranked/<model>_<lang>.csv``.

    Args:
        args: namespace providing ``rank_method``.
    """
    separator = ","
    out_dir = "dissimilar_words_similarly_ranked"
    os.makedirs(out_dir, exist_ok=True)
    tie_method = args.rank_method
    config = get_config()

    # Result file per model, in the order the CSV files are produced.
    result_files = (
        ("colex_all", "SIMSCORE~LSIM~colex_all.pkl"),
        ("colex_mono", "SIMSCORE~LSIM~colex_mono.pkl"),
        ("colex_maxsim", "SIMSCORE~LSIM~colex_all_maxsim.pkl"),
        ("fasttext", "SIMSCORE~LSIM~fasttext.pkl"),
        ("C+F_Fusion", "SIMSCORE~LSIM~colex_all~fasttext.pkl"),
    )
    simscore_data = {
        name: utils.load(os.path.join(config.directories.results, file_name))
        for name, file_name in result_files
    }

    rank_threshold = 200
    for model_type, simdata in simscore_data.items():
        for lang, lang_dict in simdata.items():
            pair_names = list(lang_dict)
            gold = np.asarray([lang_dict[pair]["ground_truth"] for pair in pair_names])
            computed = np.asarray([lang_dict[pair]["our_score"] for pair in pair_names])
            gold_ranks = scipy.stats.rankdata(gold, method=tie_method)
            computed_ranks = scipy.stats.rankdata(computed, method=tie_method)

            # Header row followed by one row per word pair under the threshold on both ranks.
            rows = ["word pair,ground truth rank,our rank"]
            rows.extend(
                separator.join((pair, str(gold_rank), str(computed_rank)))
                for gold_rank, computed_rank, pair in zip(gold_ranks, computed_ranks, pair_names)
                if gold_rank <= rank_threshold and computed_rank <= rank_threshold
            )
            utils.write_file_from_list(rows, os.path.join(out_dir, model_type + "_" + lang + ".csv"))

def FP_FN_TP_TN_table_per_language(args):
    """Print LaTeX rows with TP/FN/TN/FP counts per model and language.

    Scores are converted to ranks (tie handling given by
    ``args.rank_method``).  With ``m`` the largest rank of a language, a rank
    counts as positive when it is at least ``0.75 * m`` and as negative when
    it is at most ``0.25 * m``; pairs whose gold or computed rank falls in
    between are not counted.  For every model a header row listing the
    languages of ``args.languages`` is printed, followed by one row each for
    TP, FN, TN and FP.

    Args:
        args: namespace providing ``languages`` (codes joined by underscores)
            and ``rank_method``.
    """
    config = get_config()
    languages = args.languages.split("_")

    # Result file per model, in the order the table sections are printed.
    result_files = (
        ("colex_mono", "SIMSCORE~LSIM~colex_mono.pkl"),
        ("colex_all", "SIMSCORE~LSIM~colex_all.pkl"),
        ("C+F_Fusion", "SIMSCORE~LSIM~colex_all~fasttext.pkl"),
    )
    simscore_data = {
        name: utils.load(os.path.join(config.directories.results, file_name))
        for name, file_name in result_files
    }

    fraction_threshold = 0.25  # top quarter for positive labels, bottom quarter for negative labels
    label_types = ("TP", "FN", "TN", "FP")
    counts_per_model = {}
    for model_type, simdata in simscore_data.items():
        counts_per_model[model_type] = {}
        for lang, lang_dict in simdata.items():
            gold = np.asarray([entry["ground_truth"] for entry in lang_dict.values()])
            computed = np.asarray([entry["our_score"] for entry in lang_dict.values()])
            gold_ranks = scipy.stats.rankdata(gold, method=args.rank_method)
            computed_ranks = scipy.stats.rankdata(computed, method=args.rank_method)
            max_rank = max(np.max(gold_ranks), np.max(computed_ranks))

            positive_cutoff = (1 - fraction_threshold) * max_rank
            negative_cutoff = fraction_threshold * max_rank
            counts = dict.fromkeys(label_types, 0)
            for gold_rank, computed_rank in zip(gold_ranks, computed_ranks):
                gold_high = gold_rank >= positive_cutoff
                gold_low = gold_rank <= negative_cutoff
                computed_high = computed_rank >= positive_cutoff
                computed_low = computed_rank <= negative_cutoff
                if gold_high and computed_high:
                    counts["TP"] += 1
                elif gold_high and computed_low:
                    counts["FN"] += 1
                elif gold_low and computed_low:
                    counts["TN"] += 1
                elif gold_low and computed_high:
                    counts["FP"] += 1
            counts_per_model[model_type][lang] = counts

    # Row label per model; unknown models get an empty label.
    headers = {
        "colex_all": "COLEX\\textsubscript{cross} ",
        "colex_mono": "COLEX\\textsubscript{mono} ",
        "C+F_Fusion": "C+F Fusion ",
    }
    lines = []
    for model_type, lang_counts in counts_per_model.items():
        lines.append(
            headers.get(model_type, "")
            + "".join(" & " + lang.upper() for lang in languages)
            + " \\\\"
        )
        for label_type in label_types:
            cells = "".join(" & " + str(lang_counts[lang][label_type]) for lang in languages)
            lines.append(label_type + " " + cells + " \\\\")
    for line in lines:
        print(line)

def main(args):
    """Create the plot directory and run the routine chosen by ``args.plot_type``.

    Args:
        args: parsed command-line namespace; ``plot_type`` selects the
            routine, the remaining attributes are forwarded to it.

    Raises:
        ValueError: if ``args.plot_type`` does not name a known routine.
    """
    config = utils.get_config()
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(config.directories.plots, exist_ok=True)

    # Lambdas keep name resolution lazy: only the selected routine is looked up.
    handlers = {
        'BabelNet_words_bar_plot': lambda: bar_plot_for_BabelNet_wordcount(),
        'lang_inventory_LSIM': lambda: lang_inventory_LSIM_plot(args),
        'lang_inventory_LSIM_version2': lambda: lang_inventory_LSIM_plot_version2(args),
        'lang_inventory_analysis': lambda: lang_inventory_analysis_plot(args),
        'LSIM': lambda: LSIM_table(args),
        'LSIM_version2': lambda: LSIM_table_version2(args),
        'Concat_LSIM': lambda: Concat_LSIM_table(args),
        'POS_LSIM': lambda: POS_LSIM_table(args),
        'CrossLingual_LSIM': lambda: CrossLingual_LSIM(args),
        'CrossLingual_LSIM_OOV_table': lambda: CrossLingual_LSIM_OOV_table(args),
        'LSIM_correlation_rank_maps': lambda: LSIM_correlation_rank_maps(args),
        'LSIM_correlation_rank_maps_individual_langs': lambda: LSIM_correlation_rank_maps_individual_langs(args),
        'LSIM_dissimilar_words_similarly_ranked': lambda: LSIM_dissimilar_words_similarly_ranked(args),
        'FP_FN_TP_TN_table_per_language': lambda: FP_FN_TP_TN_table_per_language(args),
    }
    handler = handlers.get(args.plot_type)
    if handler is None:
        # A mistyped plot type used to be a silent no-op; report it instead.
        raise ValueError(
            "Unknown plot_type '" + str(args.plot_type) + "'; expected one of: "
            + ", ".join(sorted(handlers))
        )
    handler()



if __name__ == "__main__":
    # Imported here so the script does not depend on argparse reaching this
    # module through a wildcard import.
    import argparse

    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description='Arguments to pass for plotting or printing tables for the paper')
    # Name of the plot/table routine to run (see main()).
    parser.add_argument('--plot_type', type=str, default='LSIM_correlation_rank_maps_individual_langs')
    # Language codes joined by underscores.
    parser.add_argument('--languages', type=str, default='ar_en_es_fi_fr_he_pl_ru_zh')  # ar_en_es_fi_fr_he_pl_ru_zh
    # Whether CrossLingual_LSIM wraps the better score of each pair in \textbf{}.
    parser.add_argument('--bold_best_CLSIM', type=utils.str2bool, default=True)
    # parser.add_argument('--embed_type', type=str, default='BERT')  # binary and sum much better than pairwise product!!!
    # Tie-handling method passed to scipy.stats.rankdata.
    parser.add_argument('--rank_method', type=str, default='average')
    # parser.add_argument('--PCA', type=str2bool, default=True)
    # parser.add_argument('--use_gpu', type=str2bool, default=True)
    args = parser.parse_args()
    main(args)