import numpy as np
import gensim
from gensim.models import KeyedVectors
import sys
import argparse
import json
from scipy import stats
from gensim import matutils


def parse_args(argv=None):
    """Parse the command-line options for the WEAT evaluation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as a plain ``parse_args()`` does.

    Returns:
        argparse.Namespace with ``embedding`` (str or None), ``output`` (str)
        and ``seed`` (int, default 1111).
    """
    parser = argparse.ArgumentParser(
        description='Evaluate WEAT bias scores on sense embeddings.')

    parser.add_argument('--embedding', type=str, required=False)
    parser.add_argument('--output', type=str, required=True,
                        help='File to write the WEAT results to.')
    parser.add_argument('--seed', type=int, default=1111,
                        help="Seed for NumPy's global random number generator.")

    return parser.parse_args(argv)


def load_lmms(npz_vecs_path):
    """Load sense embeddings stored in an ``.npz`` archive.

    The archive must contain a ``labels`` array (sense keys, e.g.
    ``bank%1:14:00::``) and a parallel ``vectors`` array whose i-th row is
    the embedding of ``labels[i]``.

    Args:
        npz_vecs_path: Path to the ``.npz`` file.

    Returns:
        dict mapping each label (str) to its row of ``vectors``. If a label
        repeats, the last occurrence wins.
    """
    # The context manager closes the underlying zip file handle. Arrays read
    # through loader[...] are fully materialised, so they stay valid afterwards.
    with np.load(npz_vecs_path) as loader:
        labels = loader['labels'].tolist()
        vectors = loader['vectors']
    return dict(zip(labels, vectors))

def load_ares_txt(path):
    """Load sense embeddings from a whitespace-separated text file.

    The first line is skipped (presumably a word2vec-style
    ``<count> <dim>`` header; it is not validated). Every following
    non-blank line is ``<label> <v1> <v2> ...``.

    Args:
        path: Path to the text file.

    Returns:
        dict mapping label (str) to a float ``np.ndarray`` of its values.
    """
    sense_vecs = {}
    with open(path, 'r', encoding='utf-8') as sfile:
        next(sfile, None)  # skip the header line (tolerates an empty file)
        for line in sfile:
            # split() with no argument ignores repeated/trailing whitespace and
            # the newline; split(' ') would leave '\n' or '' tokens that
            # float conversion rejects.
            fields = line.split()
            if not fields:
                continue  # blank line, e.g. at end of file
            label, values = fields[0], fields[1:]
            sense_vecs[label] = np.array(values, dtype=float)
    return sense_vecs


def get_sk_lemma(sensekey):
    """Return the lemma part of a sense key such as ``bank%1:14:00::``.

    The lemma is everything before the first ``%``; a key without ``%`` is
    returned unchanged.
    """
    lemma, _sep, _rest = sensekey.partition('%')
    return lemma


def word_assoc(w, A, B, emb):
    """Association of word ``w`` with attribute set ``A`` versus ``B``.

    Returns ``n_similarity(emb, [w], A) - n_similarity(emb, [w], B)``. That is
    the cosine between ``w`` and the mean vector of ``A``, minus the cosine
    between ``w`` and the mean vector of ``B``.

    NOTE(review): Caliskan et al.'s WEAT averages per-word cosines
    (mean_a cos(w, a)) rather than taking the cosine to the centroid. Confirm
    that the centroid form used here is intended.
    """
    sim_to_a = n_similarity(emb, [w], A)
    sim_to_b = n_similarity(emb, [w], B)
    return sim_to_a - sim_to_b


def diff_assoc(X, Y, A, B, emb):
    """WEAT-style effect size of targets ``X`` vs ``Y`` w.r.t. attributes ``A`` vs ``B``.

    With s(w) = word_assoc(w, A, B, emb), returns

        (mean_{x in X} s(x) - mean_{y in Y} s(y)) / std_{w in X + Y} s(w)

    where the standard deviation is NumPy's default population form
    (ddof=0). A zero denominator is not guarded against.
    """
    assoc_x = np.array([word_assoc(x, A, B, emb) for x in X])
    assoc_y = np.array([word_assoc(y, A, B, emb) for y in Y])
    pooled_std = np.std(np.concatenate((assoc_x, assoc_y), axis=0))
    return (assoc_x.mean() - assoc_y.mean()) / pooled_std


def random_choice(word_pairs, subset_size):
    """Sample ``subset_size`` distinct elements of ``word_pairs``.

    Uses NumPy's global random state (seeded in the ``__main__`` entry point)
    and returns a NumPy array.
    """
    sample = np.random.choice(word_pairs, size=subset_size, replace=False)
    return sample


def get_bias_scores_mean_err(word_pairs, emb, n_iter=5000):
    """Estimate the WEAT effect size by repeated random half-subsampling.

    Each iteration draws, without replacement, ``min(|X|, |Y|) // 2`` words
    from each of X and Y and ``min(|A|, |B|) // 2`` words from each of A and
    B, then scores the draw with :func:`diff_assoc`.

    Args:
        word_pairs: dict with word lists under 'X', 'Y', 'A' and 'B'. Each list
            is assumed to hold at least two words present in ``emb``
            (run_test filters for this) so that every subset is non-empty.
        emb: mapping word -> vector.
        n_iter: number of random subsamples to score (default 5000).

    Returns:
        ``(mean, sem)``: the mean effect size over all iterations and its
        standard error of the mean (``scipy.stats.sem``, ddof=1).

    Raises:
        ValueError: if ``n_iter`` is less than 1.
    """
    if n_iter < 1:
        raise ValueError(f'n_iter must be >= 1, got {n_iter}')
    subset_size_target = min(len(word_pairs['X']), len(word_pairs['Y'])) // 2
    subset_size_attr = min(len(word_pairs['A']), len(word_pairs['B'])) // 2
    # Argument order (X, Y, A, B) fixes the order of RNG draws, so results are
    # reproducible for a given global seed.
    bias_scores = [diff_assoc(
        random_choice(word_pairs['X'], subset_size_target),
        random_choice(word_pairs['Y'], subset_size_target),
        random_choice(word_pairs['A'], subset_size_attr),
        random_choice(word_pairs['B'], subset_size_attr),
        emb) for _ in range(n_iter)]
    return np.mean(bias_scores), stats.sem(bias_scores)


def n_similarity(emb, ws1, ws2):
    """Return the cosine similarity between the centroids of two word sets.

    Each set is reduced to the arithmetic mean of its words' vectors
    (``emb[word]``). The two means are L2-normalised with
    ``gensim.matutils.unitvec`` and their dot product is returned. This is
    the computation of gensim's ``KeyedVectors.n_similarity``, but it accepts
    any mapping ``emb`` from word to vector (run_test passes a plain dict).

    Args:
        emb: mapping word -> 1-D vector; all vectors must share one length.
        ws1, ws2: non-empty iterables of words, every word present in ``emb``.

    Returns:
        The cosine similarity as a NumPy scalar.

    Raises:
        ValueError: if either word set is empty.
    """
    v1 = [emb[word] for word in ws1]
    v2 = [emb[word] for word in ws2]
    # Check the materialised lists: ws1/ws2 may be NumPy arrays (from
    # random_choice), whose truth value is ambiguous.
    if not v1 or not v2:
        raise ValueError('n_similarity requires two non-empty word sets')
    centroid1 = matutils.unitvec(np.array(v1).mean(axis=0))
    centroid2 = matutils.unitvec(np.array(v2).mean(axis=0))
    return np.dot(centroid1, centroid2)


def load_annotated_senses(fn_path):
    """Read a ``word,sense`` CSV into a dict mapping word -> sense key.

    Only the first two comma-separated fields of each line are used, and
    extra fields are ignored. Blank lines are skipped. A later line for the
    same word overrides an earlier one.

    Args:
        fn_path: Path to the CSV file (no header row is expected).

    Returns:
        dict mapping word (str) to sense key (str).

    Raises:
        ValueError: if a non-blank line has no comma.
    """
    word2sense_dict = {}
    with open(fn_path, 'r', encoding='utf-8') as sfile:
        for lineno, line in enumerate(sfile, start=1):
            line = line.rstrip('\r\n')
            if not line.strip():
                continue  # e.g. trailing empty line at end of file
            fields = line.split(',')
            if len(fields) < 2:
                raise ValueError(
                    f"{fn_path}:{lineno}: expected 'word,sense', got {line!r}")
            word2sense_dict[fields[0]] = fields[1]
    return word2sense_dict


def _average_sense_vectors(words, sense_emb):
    """Map each word to the mean vector of all senses whose lemma equals it."""
    wanted = set(words)
    # Index sense keys by lemma in one pass. Per-word lookup is then O(1)
    # instead of a scan over every sense key. Keys keep sense_emb's iteration
    # order, so each average is computed in the same order.
    senses_by_lemma = {}
    for sense_key in sense_emb:
        lemma = get_sk_lemma(sense_key)
        if lemma in wanted:
            senses_by_lemma.setdefault(lemma, []).append(sense_key)

    emb = {}
    for word in words:
        relevant_sks = senses_by_lemma.get(word)
        if not relevant_sks:
            continue  # no sense of this word in the embedding
        emb[word] = np.mean(np.stack([sense_emb[sk] for sk in relevant_sks]), axis=0)
    return emb


def _annotated_sense_vectors(words, sense_emb, word2sense):
    """Map each word to the vector of its single annotated sense, when available."""
    emb = {}
    for word in words:
        sense = word2sense.get(word)
        if sense is None or sense not in sense_emb:
            continue
        emb[word] = np.array(sense_emb[sense])
    return emb


def run_test(config, sense_emb, method='average',
             word2sense_path='data/word2sense.txt'):
    """Run one WEAT test on word vectors derived from sense embeddings.

    Args:
        config: Mapping for one test. It must have word lists under 'X', 'Y',
            'A' and 'B'; other keys are ignored.
        sense_emb: Mapping from sense key (``lemma%...``) to vector.
        method: How a word vector is built from sense vectors.
            'average' (default) takes the mean over every sense whose lemma
            equals the word. 'annotated' uses the single sense listed for the
            word in ``word2sense_path``.
        word2sense_path: ``word,sense`` CSV, read only when
            ``method == 'annotated'``.

    Returns:
        ``(mean, sem)`` from :func:`get_bias_scores_mean_err`, or
        ``(None, None)`` when any of X/Y/A/B keeps fewer than two words that
        have a non-zero vector.

    Raises:
        ValueError: if ``method`` is not 'average' or 'annotated'.
    """
    words = list(config['X'] + config['Y'] + config['A'] + config['B'])

    # Word vectors depend only on the word list, so build them once.
    if method == 'average':
        emb = _average_sense_vectors(words, sense_emb)
    elif method == 'annotated':
        word2sense = load_annotated_senses(word2sense_path)
        emb = _annotated_sense_vectors(words, sense_emb, word2sense)
    else:
        raise ValueError(
            f"unknown method {method!r}; expected 'average' or 'annotated'")

    word_pairs = {}
    for word_list_name, word_list in config.items():
        if word_list_name not in ('X', 'Y', 'A', 'B'):
            continue
        # All-zero vectors carry no direction and would make cosine undefined.
        word_list_filtered = [
            w for w in word_list if w in emb and np.count_nonzero(emb[w]) > 0]
        word_pairs[word_list_name] = word_list_filtered
        if len(word_list_filtered) < 2:
            print('ERROR: Words from list {} not found in embedding\n {}'.\
            format(word_list_name, word_list))
            print('All word groups must contain at least two words')
            return None, None

    return get_bias_scores_mean_err(word_pairs, emb)


def eval_weat(sense_emb, output, config_path='data/weat.json'):
    """Run every WEAT test in a JSON config and write the scores to a file.

    Args:
        sense_emb: Mapping from sense key to vector, passed to :func:`run_test`.
        output: Path of the results file (overwritten).
        config_path: JSON file whose 'tests' object maps test name to a test
            config (the word lists run_test expects).

    For each test that run_test accepts, three lines are written: the test
    name, ``Score: <mean effect size>`` and ``Std. error: <standard error of
    the mean>``, both rounded to 4 decimals. Tests for which run_test returns
    ``(None, None)`` are left out of the file.
    """
    with open(config_path) as config_file:
        config = json.load(config_file)
    with open(output, 'w') as fw:
        for name_of_test, test_config in config['tests'].items():
            print('name_of_test', name_of_test, 'test_config', test_config)
            mean, err = run_test(test_config, sense_emb)
            if mean is None:
                continue
            mean = str(round(mean, 4))
            err = str(round(err, 4))
            fw.write(f'{name_of_test}\n')
            fw.write(f'Score: {mean}\n')
            # err comes from get_bias_scores_mean_err, which returns
            # scipy.stats.sem of the sampled scores. It is a standard error,
            # not a p-value, so it is labelled accordingly.
            fw.write(f'Std. error: {err}\n')


def main(args):
    """Load sense embeddings and write WEAT results to ``args.output``.

    ``args.embedding`` optionally selects the sense-embedding file. A path
    ending in ``.npz`` is read with :func:`load_lmms`; any other path is
    read with :func:`load_ares_txt`. When it is absent, the LMMS archive at
    the historical default path below is used.
    """
    # Other files used previously:
    #   data/lmms_2348.bert-large-cased.fasttext-commoncrawl.npz
    #   ../senseEmbeddings/external/ares/ares_bert_large.txt
    embedding_path = getattr(args, 'embedding', None)
    if not embedding_path:
        embedding_path = '../senseEmbeddings/external/lmms/lmms_1024.bert-large-cased.npz'

    if embedding_path.endswith('.npz'):
        sense_emb = load_lmms(embedding_path)
    else:
        sense_emb = load_ares_txt(embedding_path)

    eval_weat(sense_emb, args.output)


if __name__ == '__main__':
    cli_args = parse_args()
    # Seed NumPy's global RNG so the random subsets drawn during scoring are
    # reproducible across runs.
    np.random.seed(cli_args.seed)
    main(cli_args)
