import os
import argparse

import matplotlib.pyplot as plt

import utils
import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
import fasttext.util
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from nltk.stem import WordNetLemmatizer
import math


def filter_word_pairs_POS(word_pairs, args):
    """Keep only the word pairs whose POS tag equals ``args.POS_type``.

    Args:
        word_pairs: dict mapping a pair key to an info dict that has at least
            the keys 'POS_tag', 'word1' and 'word2'.
        args: namespace whose ``POS_type`` attribute is the tag to keep.

    Returns:
        (filtered_pairs, unique_words): the sub-dict of matching pairs (same
        keys and info dicts, in original order) and the list of distinct
        words appearing in those pairs, in first-occurrence order.
    """
    filtered_pairs = {
        key: info
        for key, info in word_pairs.items()
        if info['POS_tag'] == args.POS_type
    }
    # dict.fromkeys de-duplicates while preserving first-seen order, so a word
    # that appears in several pairs is only listed (and later looked up) once.
    unique_words = list(dict.fromkeys(
        word
        for info in filtered_pairs.values()
        for word in (info['word1'], info['word2'])
    ))
    return filtered_pairs, unique_words

def is_a_hierarchy(unique_words):
    """Compute the mean WordNet is-a (hypernym) depth of each word as a noun.

    Each word is lemmatized with ``WordNetLemmatizer`` and looked up among
    WordNet noun synsets. For each synset, its depth is the length (number of
    synsets) of its shortest path returned by ``hypernym_paths()``. The
    word's score is the mean of those depths over all of its noun synsets.

    Args:
        unique_words: iterable of words to score.

    Returns:
        dict mapping each lemma that has at least one noun synset to its mean
        depth (a numpy float). Words without any noun synset are omitted.
        Distinct input words that lemmatize to the same lemma share one
        entry.
    """
    # Import explicitly instead of relying on ``from utils import *`` to
    # happen to provide ``wn``.
    from nltk.corpus import wordnet as wn

    lemmatizer = WordNetLemmatizer()
    is_a_lengths = {}
    for word in tqdm(unique_words):
        lemma = lemmatizer.lemmatize(word)
        synsets = wn.synsets(lemma, pos=wn.NOUN)
        if not synsets:
            # No noun sense in WordNet: the word gets no depth score.
            continue
        # Shortest hypernym path per synset, then average across senses.
        shortest_depths = [
            min(len(path) for path in syn.hypernym_paths())
            for syn in synsets
        ]
        is_a_lengths[lemma] = np.mean(np.asarray(shortest_depths))
    return is_a_lengths

def _plot_is_a_histogram(lengths):
    """Show a histogram of mean is-a depths with integer-spaced bin edges."""
    min_value = int(math.floor(min(lengths)))
    max_value = int(math.ceil(max(lengths)))
    # Ensure at least two edges (one bin) when all depths are the same integer;
    # otherwise linspace would produce a single edge and therefore zero bins.
    max_value = max(max_value, min_value + 1)
    bins = np.linspace(start=min_value, stop=max_value, num=max_value - min_value + 1)
    plt.hist(lengths, bins=bins)
    plt.show()


def main(args):
    """Score evaluation words by WordNet is-a depth and plot their distribution.

    For each language in ``args.languages`` (underscore-separated, e.g. 'en'
    or 'ar_en_es'), loads word pairs and words with ``get_multisimlex``.
    If ``args.POS_type`` is non-empty, it keeps only pairs with that POS tag.
    It then computes per-lemma mean hypernym depths with ``is_a_hierarchy``
    and shows a histogram of them.

    If ``args.results_save_path`` is non-empty, the dict
    ``{lang: {lemma: mean_depth}}`` is written with ``dump`` to that filename
    inside ``config.directories.results``. The default CLI comment suggests a
    .pkl file; the exact format depends on ``utils.dump`` (TODO confirm).

    NOTE(review): ``is_a_hierarchy`` queries WordNet, which is English; scores
    for non-English languages are presumably only meaningful if
    ``get_multisimlex`` returns English-compatible words — confirm.
    """
    languages = args.languages.split("_")
    results = {}
    for lang in languages:
        word_pairs, unique_words = get_multisimlex(lang)
        if args.POS_type:
            # Restrict to pairs whose POS tag matches --POS_type.
            word_pairs, unique_words = filter_word_pairs_POS(word_pairs, args)
        is_a_lengths = is_a_hierarchy(unique_words)
        # Keep per-language scores so --results_save_path saves real data.
        results[lang] = is_a_lengths
        if not is_a_lengths:
            print(f"No WordNet noun synsets found for language '{lang}'; skipping histogram.")
            continue
        _plot_is_a_histogram(list(is_a_lengths.values()))

    if args.results_save_path:
        config = get_config()
        os.makedirs(config.directories.results, exist_ok=True)
        save_path = os.path.join(config.directories.results, args.results_save_path)
        dump(results, save_path)




if __name__ == "__main__":
    # (flag, default) for every command-line option; all of them are strings.
    cli_options = (
        ('--eval_word_type', 'LSIM'),
        # Underscore-separated language codes, e.g. ar_en_es_fi_fr_he_pl_ru_zh.
        ('--languages', 'en'),
        ('--rank_method', 'average'),
        # Output filename, e.g. LSIM_cross_colex_sum_10.pkl; '' disables saving.
        ('--results_save_path', ''),
        # POS tag to keep, e.g. nouns, adjectives, verbs, adverbs; '' keeps all pairs.
        ('--POS_type', ''),
    )
    cli = argparse.ArgumentParser(
        description='Arguments to check is-a hierarchy of LSIM evaluation words')
    for flag, default in cli_options:
        cli.add_argument(flag, type=str, default=default)
    main(cli.parse_args())


