import conllu
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn import preprocessing
from transformers import BertTokenizer, BertModel, BertForMaskedLM
import torch
from collections import defaultdict, Counter

import h5py
from tqdm import tqdm
import torch.nn as nn
import numpy as np
import torch.optim as optim

import sys
sys.path.insert(0, "../")
import os
# from get_data import (
#     ParallelSentenceDataFamilies,
#     ParallelSentenceDataSyntax,
#     get_lang_texts_ud,
#     run_bert,
# )

# Dimensionality of the BERT vectors handled by this module; presumably the
# hidden size of a BERT-base model — TODO confirm against the model in use.
BERT_DIM = 768

def get_tokens_and_labels(filename, limit=-1, case_set=None, role_set=("A", "O"),
                          balanced=False):
    """
    Parse a CoNLL-U treebank into per-token label lists.

    Parameters:
    filename: the location of the treebank (conll file, read as UTF-8)
    limit: stop reading further sentences once at least this many relevant
           examples have been seen (no limit if <= 0). Relevant means nouns
           whose role is in ROLE_SET and whose case is in CASE_SET (if not None).
    case_set: What cases to count as relevant (None = any case)
    role_set: Which ASO roles to count as relevant (None = any non-None role)
    balanced: Count relevant examples as if the corpus were balanced to have
              the same number of examples from every role in ROLE_SET, i.e.
              (smallest per-role count) * (number of roles)

    This function parses the conll file and returns a tuple of:
    - labels: A dict whose keys are types of labels (eg, "animacy", "role",
        "token") and whose values are lists with one entry per parsed
        sentence; each entry is the list of that label for every token.
    - num_relevant_examples: how many relevant examples were seen, counted as
        described for BALANCED
    - relevant_examples_index: a list of (sent_idx, word_idx) tuples, telling
        us where in the data the relevant examples are (not truncated to a
        balanced subset).
    - cases_per_role: for every role (including None, for tokens that have
        no role), a Counter of the cases seen with it
    """
    with open(filename, encoding="utf-8") as f:
        sentences = conllu.parse(f.read())
    labels = defaultdict(list)
    cases_per_role = defaultdict(Counter)
    relevant_examples_index = []
    if balanced:
        assert role_set is not None, "Must provide which roles to balance if we're balancing!"
        # Closed set of roles when balancing, so a role with no examples counts as 0.
        role_example_counts = dict.fromkeys(role_set, 0)
    else:
        # Open set otherwise: any role that shows up gets counted.
        role_example_counts = Counter()

    def count_relevant():
        # When balancing, only min(per-role count) examples per role are usable.
        if balanced:
            return min(role_example_counts.values(), default=0) * len(role_example_counts)
        return sum(role_example_counts.values())

    for sent_i, tokenlist in enumerate(sentences):
        sentence_info = defaultdict(list)
        if "sent_id" in tokenlist.metadata:
            sentence_info["sent_id"] = [tokenlist.metadata["sent_id"]] * len(tokenlist)
        noun_count = 0
        for token in tokenlist:
            token_info = get_token_info(token, tokenlist)
            token_case = None
            token_animacy = ""
            # Case/animacy are only recorded for tokens that received a role.
            if token_info["role"] is not None:
                feats = token['feats']
                if feats and 'Case' in feats:
                    token_case = feats['Case']
                if feats and 'Animacy' in feats:
                    token_animacy = feats['Animacy']
            token_info["case"] = token_case
            token_info["animacy"] = token_animacy
            sentence_info["token"].append(token['form'])
            for label_type, value in token_info.items():
                sentence_info[label_type].append(value)
            # Number of nouns/proper nouns/pronouns before this token.
            sentence_info["preceding_nouns"].append(noun_count)
            if token["upostag"] in ("NOUN", "PROPN", "PRON"):
                noun_count += 1
        for label_type, values in sentence_info.items():
            labels[label_type].append(values)
        assert len(sentence_info["case"]) == len(sentence_info["role"]), \
               "Length of case and role should be the same for every sentence (though both lists can include Nones)"
        for i, (role, case) in enumerate(zip(sentence_info["role"], sentence_info["case"])):
            role_ok = role is not None and (role_set is None or role in role_set)
            case_ok = case_set is None or case in case_set
            if role_ok and case_ok:
                relevant_examples_index.append((sent_i, i))
                role_example_counts[role] += 1
            cases_per_role[role][case] += 1
        if limit > 0 and count_relevant() >= limit:
            break
    # Computed unconditionally so the count is correct even when no limit is set.
    num_relevant_examples = count_relevant()
    print("Counts of each role", role_example_counts)
    print("Case counts per role", cases_per_role)
    print("returning from get_tokens, the keys are", list(labels.keys()))
    return dict(labels), num_relevant_examples, relevant_examples_index, cases_per_role

def get_token_info(token, tokenlist):
    """
    Work out the grammatical role of a single token within its sentence.

    Only NOUN/PROPN tokens can receive a role, and only if their head is a
    VERB or AUX (or the head id is not found in TOKENLIST). Returns a dict:
    - role: None, or
        "A"           deprel contains "nsubj" and the head has a dependent
                      whose deprel contains "obj",
        "S"           deprel contains "nsubj" and there is no such object,
        "S-expletive" deprel contains "nsubj" and the head has an "expl"
                      dependent,
        "O"           deprel contains "obj" (this also matches e.g. "iobj").
      A "-passive" suffix is added when a subject's deprel contains "pass",
      and an "-aux" suffix when the head is an AUX.
    - verb_word: lemma of the head verb/aux ("" if none was found)
    - subject_word / object_word: surface forms of the subject and object
      attached to the same head (the last one in sentence order wins; "" if
      none)
    """
    info = {"role": None, "verb_word": "", "subject_word": "", "object_word": ""}
    if token["upostag"] not in ("NOUN", "PROPN"):
        return info

    head_id = token['head']
    head_is_aux = False
    matches = tokenlist.filter(id=head_id)
    if matches:
        head = matches[0]
        head_upos = head["upostag"]
        if head_upos not in ("VERB", "AUX"):
            return info
        head_is_aux = head_upos == "AUX"
        info["verb_word"] = head["lemma"]

    deprel = token['deprel']
    # The 'deps' column is often empty in treebanks, so siblings (other
    # dependents of the same head) are found by scanning the whole sentence.
    if "nsubj" in deprel:
        info["subject_word"] = token['form']
        found_object = found_expletive = False
        for other in tokenlist:
            if other['head'] != head_id:
                continue
            if "obj" in other['deprel']:
                found_object = True
                info["object_word"] = other["form"]
            if other['deprel'] == "expl":
                found_expletive = True
        if found_expletive:
            role = "S-expletive"
        else:
            role = "A" if found_object else "S"
        if "pass" in deprel:
            role += "-passive"
        info["role"] = role
    elif "obj" in deprel:
        info["role"] = "O"
        info["object_word"] = token['form']
        for other in tokenlist:
            if other['head'] == head_id and "subj" in other['deprel']:
                info["subject_word"] = other['form']

    if head_is_aux and info["role"] is not None:
        info["role"] += "-aux"
    return info

def get_bert_tokens(orig_tokens, tokenizer):
    """
    Convert pre-tokenised sentences into BERT word pieces plus index maps.

    Parameters:
    orig_tokens: a list of sentences, each a list of word strings.
    tokenizer: an object with ``tokenize(word) -> list of pieces`` and
        ``convert_tokens_to_ids(pieces) -> list of ids`` (e.g. a BertTokenizer).

    Returns a 4-tuple (bert_tokens, bert_ids, orig_to_bert_map, bert_to_orig_map):
    - bert_tokens[i]: the word pieces of sentence i, starting with "[CLS]" and
      ending with "[SEP]", truncated to at most 512 pieces in total.
    - bert_ids[i]: the vocabulary ids of bert_tokens[i].
    - orig_to_bert_map[i][j]: where in bert_tokens[i] word j of sentence i
      starts ("[CLS]" is position 0). It has one extra trailing entry holding
      the position just after the last word's pieces.
    - bert_to_orig_map[i][k]: index of the word that produced the k-th word
      piece, counted WITHOUT the leading "[CLS]" (so it describes
      bert_tokens[i][k + 1]).

    Note: truncation is applied to bert_tokens only; both maps describe the
    untruncated sentence, so for very long sentences they may point past the
    end of bert_tokens[i].
    """
    bert_tokens = []
    orig_to_bert_map = []
    bert_to_orig_map = []
    for sentence in orig_tokens:
        sentence_bert_tokens = ["[CLS]"]
        sentence_map_otb = []
        sentence_map_bto = []
        for orig_idx, orig_token in enumerate(sentence):
            sentence_map_otb.append(len(sentence_bert_tokens))
            # Tokenise each word once and reuse the pieces for both outputs.
            pieces = tokenizer.tokenize(orig_token)
            sentence_map_bto.extend([orig_idx] * len(pieces))
            sentence_bert_tokens.extend(pieces)
        sentence_map_otb.append(len(sentence_bert_tokens))
        # Keep room for "[SEP]" within BERT's 512-position limit.
        sentence_bert_tokens = sentence_bert_tokens[:511]
        sentence_bert_tokens.append("[SEP]")
        bert_tokens.append(sentence_bert_tokens)
        orig_to_bert_map.append(sentence_map_otb)
        bert_to_orig_map.append(sentence_map_bto)
    bert_ids = [tokenizer.convert_tokens_to_ids(b) for b in bert_tokens]
    return bert_tokens, bert_ids, orig_to_bert_map, bert_to_orig_map

def _load_bert_outputs(hdf5_path, num_sentences):
    """Load cached hidden states for NUM_SENTENCES sentences from HDF5_PATH.

    Returns a list with one {layer_index: array} dict per sentence, or None if
    the file raised an OSError while reading (the file is then deleted so the
    caller can regenerate it). Exits the process with status 1 if the file
    holds a different number of sentences than requested.
    """
    outputs = []
    try:
        with h5py.File(hdf5_path, 'r') as datafile:
            num_keys = len(datafile.keys())
            if num_keys != num_sentences:
                print("Found", num_keys, "keys, which doesn't match", num_sentences, "data points")
                print(f"Try removing the file {hdf5_path} and running again")
                sys.exit(1)
            for sent_idx in tqdm(range(num_sentences), desc='[Loading from disk]'):
                # [:] reads the whole (layers, seq_len, dim) dataset into memory.
                hidden_layers = datafile[str(sent_idx)][:]
                outputs.append(dict(enumerate(hidden_layers)))
    except OSError:
        print("Encountered hdf5 reading error.  Wiping file...")
        os.remove(hdf5_path)
        return None
    print(f"Loaded {len(outputs)} sentences from disk.")
    return outputs


def _compute_bert_outputs(bert_ids, bert_model, hdf5_path=None):
    """Run every sentence through BERT_MODEL and collect its hidden states.

    Returns a list with one {layer_index: array} dict per sentence; each array
    keeps the model's leading batch axis of size 1. If HDF5_PATH is given, each
    sentence's stacked hidden states are also written there under the key
    str(sentence_index); a partially written file is removed on failure so it
    is not mistaken for a valid cache later.
    """
    outputs = []
    datafile = None if hdf5_path is None else h5py.File(hdf5_path, 'w')
    completed = False
    try:
        with torch.no_grad():
            print(f"Running {len(bert_ids)} sentences through BERT. This takes a while")
            for idx, sentence in enumerate(tqdm(bert_ids)):
                bert_output = bert_model(torch.tensor(sentence).unsqueeze(0))
                hidden_layers = bert_output["hidden_states"]
                output = {layer: np.array(hidden) for layer, hidden in enumerate(hidden_layers)}
                outputs.append(output)
                if datafile is not None:
                    layer_count = len(hidden_layers)
                    _, sentence_length, dim = hidden_layers[0].shape
                    dset = datafile.create_dataset(str(idx), (layer_count, sentence_length, dim))
                    dset[:, :, :] = np.vstack(list(output.values()))
        completed = True
    finally:
        if datafile is not None:
            datafile.close()
            if not completed:
                os.remove(hdf5_path)
    return outputs


def _add_embedding_outputs(outputs, bert_ids, bert_model):
    """Store the raw word- and position-embedding lookups in each sentence's dict."""
    # Word embeddings don't have matrix multiplications, so we might as well do
    # them on the fly every time since it's more convenient for other reasons.
    with torch.no_grad():
        for idx, sentence in enumerate(tqdm(bert_ids, desc="Getting word and position embeddings")):
            sentence = torch.tensor(sentence)
            positions = torch.arange(len(sentence)).unsqueeze(0)
            outputs[idx]["word_embeddings"] = bert_model.embeddings.word_embeddings(sentence)
            # Look up the position table directly: calling
            # bert_model.embeddings(positions) would treat the positions as
            # token ids and return the full embedding-layer output
            # (word + position + token-type, then LayerNorm).
            outputs[idx]["position_embeddings"] = bert_model.embeddings.position_embeddings(positions)


def get_bert_outputs(hdf5_path, bert_ids, bert_model):
    """
    Given a list of lists of bert IDs, runs them through BERT.

    Parameters:
    hdf5_path: cache file for the hidden states. If it exists, they are loaded
        from it; otherwise they are computed and written there. An unreadable
        cache is deleted and regenerated. Pass None to compute without caching.
    bert_ids: one list of BERT vocabulary ids per sentence.
    bert_model: a model whose output has a "hidden_states" entry (presumably
        it must be configured with output_hidden_states=True) and whose
        ``embeddings`` module has ``word_embeddings`` / ``position_embeddings``.

    Returns a list with one dict per sentence, mapping each hidden-layer index
    to an array of that layer's vectors, plus the keys "word_embeddings" and
    "position_embeddings" (torch tensors).

    NOTE(review): hidden-state arrays loaded from the cache are
    (seq_len, dim), while freshly computed ones keep a leading batch axis,
    (1, seq_len, dim); callers must handle both.
    """
    print(f"Bert vectors file is {hdf5_path}")
    outputs = None
    if hdf5_path is not None and os.path.exists(hdf5_path):
        outputs = _load_bert_outputs(hdf5_path, len(bert_ids))
    if outputs is None:
        outputs = _compute_bert_outputs(bert_ids, bert_model, hdf5_path)
    _add_embedding_outputs(outputs, bert_ids, bert_model)
    return outputs

class _classifier(nn.Module):
    def __init__(self, nlabel, bert_dim):
        super(_classifier, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(bert_dim, 64),
            nn.ReLU(),
            nn.Linear(64, nlabel),
            nn.Dropout(.1)
        )
    def forward(self, input):
        return self.main(input)

def train_classifier(dataset, logistic):
    """
    Train a probe on DATASET.

    If LOGISTIC is truthy, fit a scikit-learn logistic regression via
    train_classifier_logistic; otherwise train the PyTorch MLP via
    train_classifier_mlp. Returns whatever the chosen trainer returns.
    """
    trainer = train_classifier_logistic if logistic else train_classifier_mlp
    return trainer(dataset)

def train_classifier_mlp(train_dataset, epochs=20):
    """
    Train an _classifier MLP on TRAIN_DATASET with Adam and cross-entropy.

    TRAIN_DATASET must provide get_num_labels(), get_bert_dim() and
    get_dataloader(), the latter yielding (embeddings, labels, extra) triples.
    The mean batch loss is printed after each of the EPOCHS passes. The
    returned module is left in training mode.
    """
    model = _classifier(train_dataset.get_num_labels(), train_dataset.get_bert_dim())
    opt = optim.Adam(model.parameters())
    loss_fn = nn.CrossEntropyLoss()

    batches = train_dataset.get_dataloader()

    for epoch_idx in range(epochs):
        batch_losses = []
        for inputs, targets, _ in batches:
            logits = model(inputs)
            batch_loss = loss_fn(logits, targets)
            opt.zero_grad()
            batch_loss.backward()
            opt.step()
            # The loss is already a scalar (mean reduction).
            batch_losses.append(batch_loss.item())
        print('[%d/%d] Train loss: %.3f' % (epoch_idx + 1, epochs, np.mean(batch_losses)))
    return model

def train_classifier_logistic(train_dataset):
    """
    Fit a logistic-regression probe on standardised embeddings.

    Every example of TRAIN_DATASET (read one at a time via
    get_dataloader(batch_size=1)) is collected into a feature matrix, which is
    standardised with a StandardScaler before fitting.

    Returns the fitted LogisticRegression. Because it was trained on
    standardised features, inputs must be transformed the same way before
    prediction, e.g. ``clf.predict(clf.scaler_.transform(X))``; the fitted
    StandardScaler is attached as ``clf.scaler_`` for exactly that purpose.
    """
    X, y = [], []
    dataloader = train_dataset.get_dataloader(batch_size=1)
    for emb_batch, role_label_batch, _ in dataloader:
        X.append(emb_batch[0])
        y.append(role_label_batch[0])
    X = np.stack(X, axis=0)
    y = np.stack(y, axis=0)
    scaler = preprocessing.StandardScaler().fit(X)
    classifier = LogisticRegression(random_state=0, max_iter=10000).fit(scaler.transform(X), y)
    # Keep the training-set scaler; without it callers cannot reproduce the
    # feature standardisation the coefficients were fitted on.
    classifier.scaler_ = scaler
    return classifier

def eval_classifier(classifier, dataset):
    """
    Evaluate CLASSIFIER on DATASET, returning a dict of {role_name: accuracy}.

    Accuracy is per true role: of the examples labelled with a role, the
    fraction the classifier predicted correctly. Every role that occurs in the
    data appears in the result. DATASET must provide get_label_set() (index ->
    role name) and get_dataloader(shuffle=...) yielding
    (embeddings, labels, extra) triples. If CLASSIFIER is an nn.Module it is
    evaluated in eval mode (dropout off) and its previous mode is restored.
    """
    label_set = dataset.get_label_set()
    dataloader = dataset.get_dataloader(shuffle=False)
    role_correct = defaultdict(int)
    role_total = defaultdict(int)
    is_module = isinstance(classifier, nn.Module)
    was_training = is_module and classifier.training
    if is_module:
        classifier.eval()
    try:
        with torch.no_grad():
            for emb_batch, role_label_batch, _ in dataloader:
                output = classifier(emb_batch)
                _, role_predictions = output.max(1)
                # Iterate over the TRUE roles in the batch so that examples of a
                # role the classifier never predicted still count as errors.
                for role in set(role_label_batch.tolist()):
                    mask = role_label_batch == role
                    role_name = label_set[role]
                    role_correct[role_name] += torch.sum(role_predictions[mask] == role).item()
                    role_total[role_name] += torch.sum(mask).item()
    finally:
        if is_module:
            classifier.train(was_training)
    return {name: role_correct[name] / role_total[name] for name in role_total}

def eval_classifier_ood(classifier, classifier_labelset, dataset):
    """
    Run CLASSIFIER over a (possibly out-of-domain) DATASET and tabulate results.

    Parameters:
    classifier: a callable mapping a (1, dim) batch of embeddings to
        (1, num_labels) scores; if it is an nn.Module it is run in eval mode
        and its previous mode is restored afterwards.
    classifier_labelset: the labels the classifier was trained on, in output
        index order. Must contain "A".
    dataset: provides get_dataloader(shuffle=..., batch_size=...) yielding
        (embeddings, role_labels, aux_labels) triples, where aux_labels maps
        a label type to a length-1 batch of values.

    Returns a pandas DataFrame with one row per example and the columns
    "probability_A" (softmax probability the classifier gives to "A"),
    "predicted_role" (the classifier's arg-max label, named via
    CLASSIFIER_LABELSET) and one column per aux label type.
    """
    # Output indices belong to the classifier, so name them with its label set
    # rather than the dataset's (the two may be ordered differently).
    classifier_labels = list(classifier_labelset)
    A_index = classifier_labels.index("A")
    dataloader = dataset.get_dataloader(shuffle=False, batch_size=1)
    rows = defaultdict(list)
    is_module = isinstance(classifier, nn.Module)
    was_training = is_module and classifier.training
    if is_module:
        classifier.eval()
    try:
        with torch.no_grad():
            for emb_batch, _, aux_labels in dataloader:
                output = classifier(emb_batch)
                probs = torch.softmax(output, 1)
                _, role_predictions = output.max(1)
                rows["probability_A"].append(probs[0, A_index].item())
                rows["predicted_role"].append(classifier_labels[int(role_predictions[0])])
                for label_type, values in aux_labels.items():
                    label = values[0]
                    if isinstance(label, torch.Tensor):
                        label = label.item()
                    rows[label_type].append(label)
    finally:
        if is_module:
            classifier.train(was_training)
    return pd.DataFrame(rows)

def eval_classifier_ood_list(classifier, emb_list, labelset):
    """
    Classify a plain list of embeddings, treating every example as role "S".

    Parameters:
    classifier: a callable mapping one 1-D embedding to 1-D label scores; if it
        is an nn.Module it is run in eval mode and its previous mode restored.
    emb_list: an iterable of 1-D embeddings.
    labelset: the classifier's labels in output-index order.

    Returns {"S": {label: count}} with a count (zeros included) for every label
    in LABELSET, or {} if EMB_LIST is empty.
    """
    out = defaultdict(lambda: dict.fromkeys(labelset, 0))
    is_module = isinstance(classifier, nn.Module)
    was_training = is_module and classifier.training
    if is_module:
        classifier.eval()
    try:
        with torch.no_grad():
            for embedding in emb_list:
                output = classifier(embedding)
                _, role_pred = output.max(0)
                out["S"][labelset[int(role_pred)]] += 1
    finally:
        if is_module:
            classifier.train(was_training)
    return {role: dict(counts) for role, counts in out.items()}

def run_classifier(sentence_list, bert_model, bert_tokenizer, classifier,
                   labelset, layer_num=-1):
    """
    Run the classifier on a sentence list. The sentence list does not need to be
    conll, but it does need to be tokenised in the form:
    [["The", "words", "in", "sentence", "one"], ["And", "those", "in", "sentence", "two"]]
    Use the .split(" ") method on a string to achieve that easily.

    Parameters:
    layer_num: which entry of get_bert_outputs' per-sentence dict to probe: a
        hidden-layer index (negative values count back from the last hidden
        layer, so -1 is the final layer) or "word_embeddings" /
        "position_embeddings".

    For every word whose first word piece survived truncation, prints the word
    followed by the classifier's (up to) three most likely labels from LABELSET.
    """
    import tempfile

    _, bert_ids, otb_map, _ = get_bert_tokens(sentence_list, bert_tokenizer)
    # get_bert_outputs caches to an HDF5 file; use a throwaway location so no
    # stale cache is reused and nothing is left on disk.
    with tempfile.TemporaryDirectory() as cache_dir:
        bert_outputs = get_bert_outputs(
            os.path.join(cache_dir, "bert_vectors.hdf5"), bert_ids, bert_model)

    is_module = isinstance(classifier, nn.Module)
    was_training = is_module and classifier.training
    if is_module:
        classifier.eval()
    try:
        with torch.no_grad():
            for i_s, layers in enumerate(bert_outputs):
                layer_key = layer_num
                if isinstance(layer_num, int) and layer_num < 0:
                    # Hidden layers are stored under int keys 0..L-1.
                    num_hidden = sum(1 for key in layers if isinstance(key, int))
                    layer_key = num_hidden + layer_num
                vectors = np.asarray(layers[layer_key])
                # Drop a leading batch axis if present: -> (seq_len, dim).
                vectors = vectors.reshape(-1, vectors.shape[-1])
                # Map each word-start position to the first word starting there.
                # The last otb entry is an end-of-sentence sentinel, not a word.
                word_starts = {}
                for orig_index, bert_index in enumerate(otb_map[i_s][:-1]):
                    word_starts.setdefault(bert_index, orig_index)
                for i_w, vector in enumerate(vectors):
                    if i_w not in word_starts:
                        continue
                    output = classifier(torch.tensor(vector).unsqueeze(0))
                    k = min(3, output.shape[-1])
                    top_cases = [labelset[int(j)] for j in torch.topk(output, k)[1][0]]
                    print(sentence_list[i_s][word_starts[i_w]], top_cases)
    finally:
        if is_module:
            classifier.train(was_training)

