import os
import argparse
import utils
import nltk
from nltk import ConcordanceIndex
from nltk.tokenize import wordpunct_tokenize
from tqdm import tqdm
from utils import *
from easydict import EasyDict as edict
import easydict
from nltk.tokenize import RegexpTokenizer
import random
from datasets import load_dataset
from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents
import multiprocessing
import concurrent.futures
import numpy as np
import re
import fasttext.util
from scipy.spatial.distance import cosine
import scipy
import joblib
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from tokenizers.normalizers import Lowercase, NFD, StripAccents
from tokenizers import normalizers
from transformers import BertModel, BertConfig, BertTokenizer
from torch.nn.utils.rnn import pad_sequence
from transformers import FlaubertModel, FlaubertTokenizer

# Module-level configuration object; get_context_examples reads directory paths from it
# (config.directories.wikipedia). get_config is not defined in this file -- presumably it
# comes from `from utils import *`; TODO confirm.
config = get_config()

def check_sentences_for_one_word_occurance_and_max_length(sentence, word, tokenizer):
    """Decide whether ``sentence`` is a usable context example for ``word``.

    The sentence is rejected when it encodes to more than 500 tokens (headroom below
    the 512-token model limit). Otherwise it is accepted only if the tokens of the
    sentence whose ids occur in ``word``'s token ids, taken in order, are exactly
    ``word``'s token sequence. This rejects sentences where the word appears more
    than once or where one of its sub-tokens also appears elsewhere.

    Returns:
        bool: True if the sentence passes both checks.
    """
    sentence_ids = tokenizer.encode(sentence, add_special_tokens=False, return_tensors='pt')
    if sentence_ids.shape[1] > 500:  # max length is 512
        return False
    target_ids = tokenizer.encode(word, add_special_tokens=False, return_tensors='pt')
    hit_positions = get_boundary_indices(target_ids, sentence_ids)
    return torch.equal(target_ids[0], sentence_ids[0, hit_positions])

def get_context_input_strings(context_example_dict, word, tokenizer, normalizer):
    """Select normalized context sentences for ``word`` to run through the model.

    Candidate sentences for ``word`` are shuffled, normalized, and filtered with
    ``check_sentences_for_one_word_occurance_and_max_length``. At most 10 are kept.

    Args:
        context_example_dict: mapping from search word to a list of candidate sentences.
        word: search word (the un-normalized key into ``context_example_dict``).
        tokenizer: tokenizer passed through to the sentence filter.
        normalizer: object whose ``normalize_str`` is applied to the word and to
            every candidate sentence.

    Returns:
        list[str]: between 3 and 10 normalized sentences. If ``word`` is missing, has
        no candidates, or fewer than 3 candidates pass the filter, returns
        ``[w, w, w]`` where ``w`` is the *normalized* word. The list always has at
        least three entries because a batch of size 1 caused shape problems
        downstream, as noted in the original code.
    """
    normalized_word = normalizer.normalize_str(word)
    # Always fall back to the normalized word. The caller locates the word's tokens
    # using the normalized form, so a raw (cased/accented) fallback might not match.
    fallback = [normalized_word] * 3

    # Shuffle a copy so the caller's stored example list is not reordered in place.
    candidates = list(context_example_dict.get(word, []))
    random.shuffle(candidates)

    selected = []
    for sentence in candidates:
        sentence = normalizer.normalize_str(sentence)
        if check_sentences_for_one_word_occurance_and_max_length(sentence, normalized_word, tokenizer):
            selected.append(sentence)
            if len(selected) >= 10:
                break
    return selected if len(selected) >= 3 else fallback

def get_boundary_indices(word_ids, input_ids):
    """Return the positions in ``input_ids`` whose token id occurs anywhere in ``word_ids``.

    Both arguments are torch tensors, e.g. the ``(1, n)`` output of
    ``tokenizer.encode(..., return_tensors='pt')`` or a single 1-D row. Both are
    flattened, so the returned indices are flat positions. For a single row these
    equal the sequence positions.

    Returns:
        np.ndarray: 1-D array of integer indices, empty when nothing matches.
    """
    input_flat = np.ravel(input_ids.detach().cpu().numpy())
    word_flat = np.ravel(word_ids.detach().cpu().numpy())
    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0. flatnonzero also
    # avoids calling nonzero on a 0-d array when the input is a single token.
    return np.flatnonzero(np.isin(input_flat, word_flat))

def _load_model_and_tokenizer(lang, device):
    """Load the pretrained encoder and tokenizer registered for ``lang``; move the model to ``device``.

    Raises:
        ValueError: if no pretrained model is registered for ``lang``.
    """
    pretrained_model_dict = {'en': 'bert-base-uncased',
                             'zh': 'bert-base-chinese',
                             'es': 'dccuchile/bert-base-spanish-wwm-uncased',
                             'ar': 'asafaya/bert-base-arabic',
                             'fi': 'TurkuNLP/bert-base-finnish-uncased-v1',
                             'ru': 'DeepPavlov/rubert-base-cased',
                             'pl': 'dkleczek/bert-base-polish-uncased-v1',  # uncased performs better than cased
                             'he': 'onlplab/alephbert-base',
                             'fr': 'flaubert/flaubert_base_cased'}
    if lang not in pretrained_model_dict:
        raise ValueError(f'language {lang} not implemented')
    model_name = pretrained_model_dict[lang]
    # French uses FlauBERT, which has its own model/tokenizer classes; all others use BERT.
    if lang == 'fr':
        model_cls, tokenizer_cls = FlaubertModel, FlaubertTokenizer
    else:
        model_cls, tokenizer_cls = BertModel, BertTokenizer
    # Download/load errors propagate as-is instead of being masked as "not implemented".
    model = model_cls.from_pretrained(model_name).to(device)
    tokenizer = tokenizer_cls.from_pretrained(model_name)
    return model, tokenizer


def get_embeddings(lang, context_example_dict, device):
    """Compute one unit-norm contextual embedding per word at several layer depths.

    For each key of ``context_example_dict``:
      1. ``get_context_input_strings`` chooses up to 10 context sentences.
      2. The sentences run through the model in a single padded batch, without
         special tokens.
      3. For each sentence and each depth ``n`` in (2, 4, 6, 8, 10, 12), the hidden
         states ``hidden_states[0:n]`` are summed per token. Index 0 is the
         embedding-layer output. The result is then averaged over the word's token
         positions, as found by ``get_boundary_indices``.
      4. The per-sentence vectors are summed and L2-normalized.

    Args:
        lang: language code that selects the pretrained model.
        context_example_dict: ``{word: [sentence, ...]}``; its value lists are read
            by ``get_context_input_strings``.
        device: torch device to run the model on.

    Returns:
        dict: ``{n: {word: np.ndarray}}`` for each depth ``n``. A word is omitted from
        depth ``n`` when any of its vectors contains NaN (e.g. the word's tokens were
        not located in a sentence).

    Raises:
        ValueError: if ``lang`` has no registered pretrained model.
    """
    layers = (2, 4, 6, 8, 10, 12)
    normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])
    model, tokenizer = _load_model_and_tokenizer(lang, device)
    model.eval()
    device = model.device

    embeddings = {layer: {} for layer in layers}
    for word in tqdm(list(context_example_dict)):
        inputs = get_context_input_strings(context_example_dict, word=word, tokenizer=tokenizer,
                                           normalizer=normalizer)
        word_ids = tokenizer.encode(normalizer.normalize_str(word), add_special_tokens=False,
                                    return_tensors='pt')

        # encode(...)[0] drops the batch dim and is always 1-D (also for one-token inputs),
        # which is what pad_sequence expects.
        rows = [tokenizer.encode(text, add_special_tokens=False, return_tensors='pt')[0]
                for text in inputs]
        token_ids = pad_sequence(rows, padding_value=tokenizer.pad_token_id,
                                 batch_first=True).to(device)
        attention_masks = pad_sequence([torch.ones_like(row) for row in rows], padding_value=0,
                                       batch_first=True).to(device)

        # Inference only: skip autograd graph construction to save memory and time.
        with torch.no_grad():
            outputs = model(token_ids, attention_mask=attention_masks, output_hidden_states=True)
        # Stack the per-layer outputs to (num_layers + 1, batch, seq_len, hidden).
        hidden_states = torch.stack(outputs.hidden_states)

        input_embeds = {layer: [] for layer in layers}
        for example_index in range(len(inputs)):
            # The word's token positions in this padded row are the same for every depth.
            sum_indices = get_boundary_indices(word_ids, token_ids[example_index])
            positions = torch.as_tensor(sum_indices, dtype=torch.long, device=hidden_states.device)
            example_states = hidden_states[:, example_index]   # (num_layers + 1, seq_len, hidden)
            word_states = example_states[:, positions]          # (num_layers + 1, n_positions, hidden)
            for layer in layers:
                # Sum the first `layer` hidden states per token, then average over the
                # word's tokens. The result is NaN when no positions were found.
                vector = word_states[:layer].sum(dim=0).mean(dim=0)
                input_embeds[layer].append(vector.cpu().numpy())

        for layer in layers:
            stacked = np.asarray(input_embeds[layer])
            if np.any(np.isnan(stacked)):
                print('nan word', word, torch.any(torch.isnan(hidden_states)))
                continue
            if stacked.ndim != 2:
                print(stacked.shape)
                stacked = stacked.reshape(1, -1)
            summed = np.sum(stacked, axis=0)
            # The sum points in the same direction as the mean; normalizing removes the scale.
            embeddings[layer][word] = summed / np.linalg.norm(summed)
    return embeddings


def get_context_examples(lang, context_example_path, search_words):
    """Collect Wikipedia sentences containing each search word and ``dump`` them to disk.

    Reads ``<lang>wiki.txt`` from ``config.directories.wikipedia``. Each line is split
    into sentences on a language-specific delimiter. Sentences and search words are
    normalized (NFD, lowercase, accent stripping) before matching:
      - For every language except ``'zh'``, a word matches when it equals one of
        the sentence's ``\\w+`` tokens.
      - For ``'zh'``, a word matches as a substring of the sentence.

    NOTE: because tokens come from ``\\w+``, a search term containing spaces or
    punctuation (a multi-word expression) can never match outside ``'zh'``.

    At most 300 sentences are kept per word. Every 100000 lines the scan stops early
    if every word already has at least 50 sentences.

    Args:
        lang: language code; one of en, es, ar, fi, fr, he, pl, ru, zh.
        context_example_path: destination of the ``{word: [sentence, ...]}`` dict.
        search_words: iterable of words to collect sentences for.

    Raises:
        ValueError: if ``lang`` is not supported.
    """
    sentence_split_token = {'en': '.', 'es': '.', 'ar': '.', 'fi': '.', 'fr': '.', 'he': '.',
                            'pl': '.', 'ru': '.', 'zh': '。'}
    if lang not in sentence_split_token:
        raise ValueError(f'language {lang} not implemented')
    split_token = sentence_split_token[lang]
    by_words = lang != 'zh'  # Chinese has no spaces, so it falls back to substring matching

    normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])  # maybe remove the accent stripping
    nltk_tokenizer = RegexpTokenizer(r'\w+')
    # Keep search_words' order and any duplicates, exactly as a plain loop over it would.
    targets = [(word, normalizer.normalize_str(word)) for word in search_words]
    context_example_dict = {word: [] for word, _ in targets}

    dataset = os.path.join(config.directories.wikipedia, lang + 'wiki.txt')
    # Stream the dump instead of readlines(): the scan can stop early, and the file is closed.
    with open(dataset, 'r', encoding='utf-8') as wiki_file:
        for i, line in tqdm(enumerate(wiki_file)):
            for sentence in line.split(split_token):
                sentence = normalizer.normalize_str(sentence)
                # A set makes each per-word membership test O(1) instead of O(#tokens).
                tokens = set(nltk_tokenizer.tokenize(sentence)) if by_words else None
                for s_word, normalized_word in targets:
                    found = normalized_word in tokens if by_words else normalized_word in sentence
                    if found and len(context_example_dict[s_word]) < 300:
                        context_example_dict[s_word].append(sentence)

            # Periodically stop once every word has at least 50 examples.
            if i % 100000 == 0 and all(len(s) >= 50 for s in context_example_dict.values()):
                break
    dump(context_example_dict, context_example_path)

def print_examples(lang, search_words):
    """Print every normalized sentence in the ``lang`` Wikipedia dump that contains a search word.

    Reads ``<lang>wiki.txt`` from the configured Wikipedia directory. Each line is
    split into sentences on a language-specific delimiter. Sentences and search words
    are normalized (NFD, lowercase, accent stripping):
      - For every language except ``'zh'``, a word matches when it equals one of
        the sentence's ``\\w+`` tokens.
      - For ``'zh'``, a word matches as a substring of the sentence.

    The whole file is scanned, and a sentence is printed once per search word it matches.

    Raises:
        ValueError: if ``lang`` is not supported.
    """
    config = get_config()
    sentence_split_token = {'en': '.', 'es': '.', 'ar': '.', 'fi': '.', 'fr': '.', 'he': '.',
                            'pl': '.', 'ru': '.', 'zh': '。'}
    if lang not in sentence_split_token:
        raise ValueError(f'language {lang} not implemented')
    split_token = sentence_split_token[lang]
    by_words = lang != 'zh'  # Chinese has no spaces, so it falls back to substring matching

    normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])  # maybe remove the accent stripping
    nltk_tokenizer = RegexpTokenizer(r'\w+')
    normalized_targets = [normalizer.normalize_str(word) for word in search_words]

    dataset = os.path.join(config.directories.wikipedia, lang + 'wiki.txt')
    # Stream the file line by line and make sure it is closed afterwards.
    with open(dataset, 'r', encoding='utf-8') as wiki_file:
        for line in wiki_file:
            for sentence in line.split(split_token):
                sentence = normalizer.normalize_str(sentence)
                tokens = set(nltk_tokenizer.tokenize(sentence)) if by_words else None
                for normalized_word in normalized_targets:
                    if (normalized_word in tokens) if by_words else (normalized_word in sentence):
                        print(sentence)


def main(args):
    """Script entry point: print example sentences, or build context examples and embeddings.

    ``args.languages`` is an underscore-separated list of language codes (e.g. ``es_ar_en``).

    If ``args.print_examples`` is true, only prints matching sentences for the
    hard-coded probe word ``"sillage"`` and returns.

    Otherwise, for each language:
      - Builds the context-example pickle if it is missing. This requires
        ``args.eval_word_type == 'LSIM'`` to supply the search words.
      - If ``args.get_embeddings`` is true, computes embeddings and dumps one pickle
        per word for depth ``args.layer``.

    Raises:
        ValueError: if the search words are needed but ``eval_word_type`` is unsupported,
            or if ``args.layer`` is not one of the depths returned by ``get_embeddings``.
    """
    languages = args.languages.split("_")
    if args.print_examples:
        for lang in languages:
            print_examples(lang, ["sillage"])
        return

    config = get_config()
    os.makedirs(config.directories.context_examples, exist_ok=True)
    # First, get the set of words to grab vectors for, based on the options.
    for lang in languages:
        unique_words = None
        if args.eval_word_type == 'LSIM':
            _, unique_words = get_multisimlex(lang)

        # Next, build the context examples if they are not cached yet.
        context_path = os.path.join(config.directories.context_examples, lang.upper() + '.pkl')
        if not os.path.exists(context_path):
            if unique_words is None:
                raise ValueError(f'unsupported eval_word_type {args.eval_word_type!r}: '
                                 f'no search words to build {context_path}')
            get_context_examples(lang, context_example_path=context_path, search_words=unique_words)

        if not args.get_embeddings:
            continue

        # get_context_examples only writes the pickle, so always load it from disk here.
        context_examples = load(context_path)
        device = torch.device('cuda' if args.use_gpu else 'cpu')
        embeddings = get_embeddings(lang, context_examples, device)
        if args.layer not in embeddings:
            raise ValueError(f'--layer must be one of {sorted(embeddings)}, got {args.layer}')
        layer_embeddings = embeddings[args.layer]

        dump_root = os.path.join(config.directories.word_vectors, "BERT", lang)
        os.makedirs(dump_root, exist_ok=True)
        for word in context_examples:
            # Best effort per word: a word can be missing (NaN embedding) or have a name
            # that is not a valid file name. Report the cause and keep going.
            try:
                dump(layer_embeddings[word], os.path.join(dump_root, word + '.pkl'))
            except Exception as err:
                print(word + f" failed for some reason... ({err!r})")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Arguments to get word embedding vectors')
    parser.add_argument('--eval_word_type', type=str, default='LSIM',
                        help='word set to embed; only LSIM (Multi-SimLex) supplies search words')
    parser.add_argument('--languages', type=str, default='fr',
                        help='underscore-separated language codes, e.g. es_ar_en_fi_fr_he_pl_ru_zh')
    parser.add_argument('--get_embeddings', type=utils.str2bool, default=True,
                        help='if False, only build the context-example pickles')
    parser.add_argument('--use_gpu', type=utils.str2bool, default=True)
    # Restrict to the depths get_embeddings actually produces so a typo fails immediately.
    parser.add_argument('--layer', type=int, default=8, choices=[2, 4, 6, 8, 10, 12],
                        help='number of leading hidden states summed per token (index 0 is the embedding layer)')
    parser.add_argument('--print_examples', type=utils.str2bool, default=True,
                        help='if True (the default), only print example sentences and skip embedding extraction')
    args = parser.parse_args()
    main(args)