import torch
import math
import numpy
import random
from torch.autograd import Variable
from torch import nn
from torch.nn import init
from torch import optim
from best.token_encoder import TokenEncoder
from best.entity_mention_encoder import EntityMentionEncoder
from best.relation_encoder import RelationEncoder
from best.event_encoder import EventEncoder
from best.entity_encoder import EntityEncoder
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
from best.best_evaluator import score_pst_tuples
from best.loaders import sentiment_lbl_enc, sentiment_encoder, sentiment_encoder_inverter
from best.metadata import get_relation_metadata, get_event_metadata
from best.pretrained import PretrainedEmbeddings
from best.predict import Predict


class SentimentPredictor(nn.Module):
    """Document-level sentiment model: LSTM encoder plus a ``Predict`` head.

    ``forward`` encodes the document tokens with an LSTM, builds encodings
    for entity mentions, entities, relation mentions and event mentions via
    the project encoders, and hands everything to ``Predict`` for scoring.
    The class also owns one ``CrossEntropyLoss`` per target type
    ("entity", "relation", "event") plus a catch-all ("all").
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size,
                 tagset_size, num_layers=1, class_weights=None,
                 bidirectional=False, dropout=0.0, pretrained=False,
                 attention="multilinear", relation_metadata=(None, None, None),
                 event_metadata=(None, None, None), encode_relation="concat",
                 encode_event="concat", attention_hyperparam=5,
                 parameterization="rank", use_entities=True):
        """Build the encoder, the mention/entity encoders and the losses.

        Args:
            embedding_dim: Size of the token embeddings fed to the LSTM.
            hidden_dim: Size of the encoded token vectors. When
                ``bidirectional`` is set, each LSTM direction gets half of it.
            vocab_size: Number of rows of the embedding table; only used when
                ``pretrained`` is falsy.
            tagset_size: Stored on the instance; not otherwise used here.
            num_layers: Number of stacked LSTM layers.
            class_weights: Mapping with the keys "entity", "relation",
                "event" and "all", each a weight tensor for the matching
                ``CrossEntropyLoss``. ``None`` means unweighted losses.
            bidirectional: Whether the LSTM is bidirectional.
            dropout: Dropout probability for the LSTM and the dropout layer.
            pretrained: Falsy to learn an embedding table. Truthy to treat
                the ``doc_ix`` given to ``forward`` as ready-made embeddings;
                the value "doc_projection" additionally adds a linear
                projection before and after the LSTM.
            attention: Forwarded to ``Predict``; a truthy value also widens
                the pair dimension (see ``get_pair_dim``).
            relation_metadata: Forwarded to ``RelationEncoder``.
            event_metadata: Forwarded to ``EventEncoder``.
            encode_relation: Passed as ``enc_mention`` to the relation encoder.
            encode_event: Passed as ``enc_mention`` to the event encoder.
            attention_hyperparam: Forwarded to ``Predict``.
            parameterization: Forwarded to ``Predict``.
            use_entities: Forwarded to ``Predict``.

        Raises:
            ValueError: If ``bidirectional`` is set and ``hidden_dim`` is odd.
        """
        super(SentimentPredictor, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        if bidirectional:
            if hidden_dim % 2 == 1:
                raise ValueError(
                    "hidden_dim must be even when bidirectional=True, "
                    "got {}".format(hidden_dim))
            # Each direction contributes half of the encoded vector, so the
            # concatenated output still has self.hidden_dim features.
            hidden_dim = hidden_dim // 2
        self.vocab_size = vocab_size
        self.tagset_size = tagset_size
        self.attention = attention
        self.pair_dim_ = self.get_pair_dim()
        self.pretrained = pretrained
        self.attention_hyperparam = attention_hyperparam
        self.parameterization = parameterization
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(p=self.dropout, inplace=False)
        if not self.pretrained:
            self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.predict = Predict("sentiment", 3, self.hidden_dim, self.pair_dim_,
                               self.resolve_author, self.dropout,
                               self.attention, self.parameterization,
                               self.attention_hyperparam, use_entities)
        self.entity_mention_encoder = EntityMentionEncoder()
        self.relation_metadata = relation_metadata
        self.encode_relation = encode_relation
        self.relation_encoder = RelationEncoder(self.hidden_dim,
                                                self.resolve_author,
                                                relation_metadata)
        self.event_metadata = event_metadata
        self.encode_event = encode_event
        self.event_encoder = EventEncoder(self.hidden_dim,
                                          self.resolve_author,
                                          event_metadata)
        self.entity_encoder = EntityEncoder()
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers,
                           bidirectional=bidirectional, dropout=self.dropout)
        if self.pretrained == "doc_projection":
            self.pre_projection = nn.Linear(self.embedding_dim,
                                            self.embedding_dim)
            self.post_projection = nn.Linear(self.hidden_dim, self.hidden_dim)
        # Learned stand-ins used by resolve_author() when an encoded value is
        # still a string (same / different author than the post author).
        self.enc_same_author = nn.Parameter(torch.Tensor(self.hidden_dim))
        self.enc_diff_author = nn.Parameter(torch.Tensor(self.hidden_dim))
        init.uniform_(self.enc_same_author, -0.01, 0.01)
        init.uniform_(self.enc_diff_author, -0.01, 0.01)
        # Learned encoding stored under the ``None`` entity key in forward().
        self.enc_missing_src = nn.Parameter(torch.Tensor(self.hidden_dim))
        init.uniform_(self.enc_missing_src, -0.01, 0.01)
        if class_weights is None:
            # The signature advertises None as the default; fall back to
            # unweighted losses instead of failing on ``None["entity"]``.
            class_weights = dict.fromkeys(
                ("entity", "relation", "event", "all"))
        self.loss_entity = nn.CrossEntropyLoss(weight=class_weights["entity"],
                                               reduction='mean')
        self.loss_relation = nn.CrossEntropyLoss(
            weight=class_weights["relation"], reduction='mean')
        self.loss_event = nn.CrossEntropyLoss(weight=class_weights["event"],
                                              reduction='mean')
        self.loss_all = nn.CrossEntropyLoss(weight=class_weights["all"],
                                            reduction='mean')

    def printnorm(self, input):
        """Print the norm of ``input`` (debugging helper)."""
        print('Norm: {}'.format(torch.norm(input)))

    def get_pair_dim(self):
        """Return the pair-vector width: 3x hidden with attention, else 2x."""
        if self.attention:
            return 3 * self.hidden_dim
        else:
            return 2 * self.hidden_dim

    def resolve_author(self, encoded_value, post_author):
        """Swap a string placeholder for a learned author encoding.

        Values exposing ``lower`` (i.e. strings) are replaced by
        ``enc_same_author`` when equal to ``post_author`` and by
        ``enc_diff_author`` otherwise; anything else is returned untouched.
        """
        if hasattr(encoded_value, 'lower'):
            if encoded_value == post_author:
                encoded_value = self.enc_same_author
            else:
                encoded_value = self.enc_diff_author
        return encoded_value

    def forward(self, doc, doc_ix):
        """Encode ``doc`` and return the scores produced by ``Predict``.

        Args:
            doc: Project document object exposing ``evaluator_ere``.
            doc_ix: Token indices (learned embeddings) or, when
                ``pretrained`` is truthy, the token embeddings themselves.
        """
        # Get embeddings for all words in doc.
        if self.pretrained:
            embeds = doc_ix
            if self.pretrained == "doc_projection":
                embeds = self.pre_projection(embeds)
        else:
            embeds = self.embed(doc_ix)
        # nn.LSTM is built with batch_first=False, so add a singleton batch
        # axis at position 1 and drop it again afterwards.
        encoded_doc, _ = self.rnn(embeds.unsqueeze(1))
        encoded_doc = encoded_doc.squeeze(1)
        if self.pretrained == "doc_projection":
            encoded_doc = self.post_projection(encoded_doc)

        # Encode all entity mentions, then the entities built from them.
        enc_mentions = {
            mtn_id: self.entity_mention_encoder.encode_entity_mention(
                mtn_id, doc, encoded_doc)
            for mtn_id in doc.evaluator_ere.entity_mentions}
        enc_entities = {
            ent_id: self.entity_encoder.encode_entity(ent_id, doc,
                                                      enc_mentions)
            for ent_id in doc.evaluator_ere.entities}

        enc_entities[None] = self.enc_missing_src

        # Encode relation mentions.
        for mtn_id in doc.evaluator_ere.relation_mentions:
            enc_mentions[mtn_id] = self.relation_encoder(
                mtn_id,
                doc,
                encoded_doc,
                enc_mentions,
                enc_mention=self.encode_relation)

        # Encode event mentions.
        for mtn_id in doc.evaluator_ere.event_mentions:
            enc_mentions[mtn_id] = self.event_encoder(
                mention_id=mtn_id,
                enc_mentions=enc_mentions,
                doc=doc,
                encoded_doc=encoded_doc,
                enc_mention=self.encode_event)
        return self.predict(doc, encoded_doc, enc_entities, enc_mentions)

    def compute_loss(self, pairs, ys, target_type, return_counts=False):
        """Cross-entropy loss for one target type, optionally with accuracy.

        Args:
            pairs: Score matrix with one row per pair and one column per
                label.
            ys: Gold label indices, one per row of ``pairs``.
            target_type: "entity", "relation" or "event"; any other value
                selects the "all" criterion.
            return_counts: When truthy, also return accuracy figures.

        Returns:
            The loss tensor, or ``(loss, accuracy, real_accuracy)`` when
            ``return_counts`` is truthy. Both accuracies are rounded to two
            decimals; ``real_accuracy`` only considers gold labels other
            than 2 (presumably the "no sentiment" class -- confirm against
            ``sentiment_lbl_enc``).
        """
        if target_type == "entity":
            loss = self.loss_entity(pairs, ys)
        elif target_type == "relation":
            loss = self.loss_relation(pairs, ys)
        elif target_type == "event":
            loss = self.loss_event(pairs, ys)
        else:
            loss = self.loss_all(pairs, ys)

        if not return_counts:
            return loss

        n_pairs = len(ys)
        eps = 0.00001  # keeps the ratios finite when a denominator is 0
        correct = real_targs = real_correct = 0
        if n_pairs:
            # detach()/cpu() keeps this working for tensors that require
            # grad or live on an accelerator.
            predicted = pairs.detach().cpu().numpy().argmax(axis=1)
            gold = ys.detach().cpu().numpy()
            hits = predicted == gold
            real = gold != 2
            correct = int(hits.sum())
            real_targs = int(real.sum())
            real_correct = int((hits & real).sum())
        return (loss,
                round(correct / (n_pairs + eps), 2),
                round(real_correct / (real_targs + eps), 2))

    def compute_total_loss(self, pair_groups, y_groups, use_entities=True,
                           return_counts=False):
        """Sum the per-target-type losses for one document.

        Args:
            pair_groups: ``(entities, relations, events)`` scores when
                ``use_entities`` is true, else ``(relations, events)``.
            y_groups: Always ``(entity, relation, event)`` gold labels; the
                entity labels are ignored when ``use_entities`` is false.
            use_entities: Whether entity targets take part in the loss.
            return_counts: When truthy, return a dict with the key 'Loss'
                plus '<Ent|Rel|Ev>_Acc' and '<Ent|Rel|Ev>_RAcc' entries
                instead of the bare loss tensor.
        """
        if use_entities:
            ents, rels, evs = pair_groups
            y_ents, y_rels, y_evs = y_groups
            groups = [('Ent', 'entity', ents, y_ents),
                      ('Rel', 'relation', rels, y_rels),
                      ('Ev', 'event', evs, y_evs)]
        else:
            rels, evs = pair_groups
            _, y_rels, y_evs = y_groups
            groups = [('Rel', 'relation', rels, y_rels),
                      ('Ev', 'event', evs, y_evs)]

        total = None
        stats = {'Loss': None}  # inserted first so it stays the first key
        for prefix, target_type, pairs, ys in groups:
            outcome = self.compute_loss(pairs, ys, target_type,
                                        return_counts=return_counts)
            if return_counts:
                loss, acc, real_acc = outcome
                stats[prefix + '_Acc'] = acc
                stats[prefix + '_RAcc'] = real_acc
            else:
                loss = outcome
            total = loss if total is None else total + loss

        if not return_counts:
            return total
        stats['Loss'] = total
        return stats
        
def _prepend_group_labels(labels, group_size):
    """Return ``labels`` with one extra label in front of every group.

    ``labels`` is cut into consecutive groups of ``group_size``. Each group
    is preceded by 0 when all of its labels are 2 (their sum equals
    ``2 * len(group)``) and by 2 otherwise.
    """
    labels = list(labels)
    packed = []
    for start in range(0, len(labels), group_size):
        group = labels[start:start + group_size]
        packed.append(0 if sum(group) == 2 * len(group) else 2)
        # extend() keeps this linear; rebuilding the list per group was
        # quadratic in the number of labels.
        packed.extend(group)
    return packed


def pack_data(docs, indices, parameterization):
    """Bundle every document with its token input and gold label tensors.

    Args:
        docs: Documents exposing ``sentiment_labels``, ``pairs_entity``,
            ``pairs_relation``, ``pairs_event`` and ``evaluator_ere.entities``.
        indices: Per-document model inputs, aligned with ``docs``.
        parameterization: "classify" keeps the gold labels as they are; any
            other value prepends one extra label per group of
            ``len(entities) + 1`` labels (see ``_prepend_group_labels``).

    Returns:
        A list of ``(doc, doc_ix, ys)`` where ``ys`` holds three long tensors:
        entity, relation and event labels.
    """
    data = []
    for doc, doc_ix in zip(docs, indices):
        # NOTE(review): the hasattr() result is discarded, so doc_ix is only
        # wrapped when probing ``.data`` itself raises RuntimeError. Kept
        # as-is; confirm whether an unconditional conversion was intended.
        try:
            hasattr(doc_ix, "data")
        except RuntimeError:
            doc_ix = Variable(doc_ix, requires_grad=False)
        ys = (doc.sentiment_labels(doc.pairs_entity),
              doc.sentiment_labels(doc.pairs_relation),
              doc.sentiment_labels(doc.pairs_event))
        if parameterization != "classify":
            sources = len(doc.evaluator_ere.entities) + 1
            ys = [_prepend_group_labels(y, sources) for y in ys]
        ys = [Variable(torch.Tensor(y), requires_grad=False).long()
              for y in ys]
        data.append((doc, doc_ix, ys))
    return data


def _score_doc(doc, y_preds, null_only=False, include_entities=True, include_relations=True, include_events=True, use_entities=True):
    """Score the predicted sentiment tuples of one document.

    Predictions are decoded with ``sentiment_encoder_inverter``, filtered,
    turned into pseudo annotations and compared with the document's gold
    state tuples through ``score_pst_tuples``.

    Args:
        doc: Document providing the pair lists and the annotation helpers.
        y_preds: One sequence of predicted labels per pair group.
        null_only: Keep only items whose source is ``None``.
        include_entities / include_relations / include_events: When false,
            drop targets of that kind. Kinds are recognised by the first
            character of the target id: 'm' / 'r' / 'e' for predicted pairs
            and 'e' / 'r' / 'h' for gold tuples.
        use_entities: Whether ``y_preds`` starts with an entity group.

    Returns:
        ``(tps, fn, fp, predicted, required)``; the last two are the tuple
        counts taken before scoring.
    """
    def is_excluded(target, entity_tag, event_tag):
        # Operands stay lazy: target[0] is only read for a disabled kind.
        return ((not include_entities and target[0] == entity_tag)
                or (not include_relations and target[0] == 'r')
                or (not include_events and target[0] == event_tag))

    pair_groups = ((doc.pairs_entity, doc.pairs_relation, doc.pairs_event)
                   if use_entities
                   else (doc.pairs_relation, doc.pairs_event))

    kept_predictions = []
    for group_pairs, group_pred in zip(pair_groups, y_preds):
        decoded = sentiment_encoder_inverter(group_pred)
        for (src, trg), label in zip(group_pairs, decoded):
            if null_only and src is not None:
                continue
            if is_excluded(trg, 'm', 'e'):
                continue
            kept_predictions.append((src, trg, label))

    ann_pred = doc.build_pseudo_annotations(sentiments=kept_predictions)
    pst_pred = doc.state_tuples(ann_pred, belief=False)
    pst_true = [
        pst for pst in doc.state_tuples(belief=False)
        if not (null_only and pst.source_entity is not None)
        and not is_excluded(pst.target_object, 'e', 'h')
    ]

    # Take the sizes first: score_pst_tuples alters pst_true.
    n_predicted, n_required = len(pst_pred), len(pst_true)
    tps, fn, fp = score_pst_tuples(pst_true, pst_pred)
    return tps, fn, fp, n_predicted, n_required


def _decode_rank_scores(pair_scores, sources, labels=3, aggressive=True):
    """Turn the score rows of one pair group into label predictions.

    Rows come in chunks of ``sources + 1``: one leading "dummy" row followed
    by one row per source. Every source defaults to ``labels - 1``. When the
    dummy row's argmax is ``labels - 1``, the single best (source, label)
    cell of the chunk is predicted; with ``aggressive`` only the first
    ``labels - 1`` columns compete.

    Returns:
        A 1-D integer array with ``sources`` predictions per chunk.
    """
    chunks = []
    for start in range(0, len(pair_scores), sources + 1):
        dummy = pair_scores[start]
        scores = pair_scores[start + 1:start + sources + 1]
        if aggressive:
            scores = scores[:, :labels - 1]
        preds = numpy.array([labels - 1] * sources)
        # Gate on the dummy row. The previous code read scores[0], which in
        # aggressive mode has only labels - 1 columns and could never yield
        # labels - 1, so no prediction was ever made.
        if scores.size and numpy.argmax(dummy) == labels - 1:
            # Split the flat argmax with the real column count (the slice
            # above may have removed a column).
            pos, val = divmod(int(numpy.argmax(scores)), scores.shape[1])
            preds[pos] = val
        chunks.append(preds)
    if not chunks:
        return numpy.array([], dtype=int)
    return numpy.concatenate(chunks)


def validate(model, data, parameterization, use_entities=True, valid=True):
    """Evaluate ``model`` on ``data`` and return micro P/R/F plus the loss.

    Args:
        model: Model providing ``eval``, ``__call__(doc, doc_ix)`` and
            ``compute_total_loss``. It is left in eval mode.
        data: Iterable of ``(doc, doc_ix, y)`` as produced by ``pack_data``.
        parameterization: 'rank' decodes chunked scores (see
            ``_decode_rank_scores``); anything else takes a row-wise argmax.
        use_entities: Forwarded to the loss and to ``_score_doc``.
        valid: When true, loss statistics are printed for a random subset
            of the documents.

    Returns:
        ``(results, total_loss)``. ``results[doc_type][config]`` is
        ``(micro_p, micro_r, micro_f)`` for doc_type in 'DF', 'NW', 'All'
        and config a ``(null_only, include_entities, include_relations,
        include_events)`` tuple. ``total_loss`` is the sum of the non-NaN
        document losses, or ``None`` when there were none.
    """
    model.eval()
    zero_counts = (0, 0, 0, 0, 0)  # tps, fn, fp, predicted, required
    # Keys: (null_only, include_entities, include_relations, include_events)
    validation_configs = {
        (False, True, True, True): zero_counts,     # all validation data
        (True, True, True, True): zero_counts,      # null source only
        (False, True, False, False): zero_counts,   # target entities
        (False, False, True, False): zero_counts,   # target relations
        (False, False, False, True): zero_counts,   # target events
    }
    results = {'DF': validation_configs.copy(),
               'NW': validation_configs.copy(),
               'All': validation_configs.copy()}
    total_loss = None
    for doc, doc_ix, y in data:
        # No gradients are needed here; without this guard the summed loss
        # kept every document's autograd graph alive.
        with torch.no_grad():
            pair_groups = model(doc, doc_ix)
            loss_stats = model.compute_total_loss(pair_groups, y,
                                                  use_entities=use_entities,
                                                  return_counts=True)
        if random.randint(0, 5) > 4 and valid:
            print("Validation Loss Stats: {}".format(loss_stats))
        loss_ = loss_stats['Loss']
        if not torch.isnan(loss_):
            total_loss = loss_ if total_loss is None else total_loss + loss_

        if parameterization == 'rank':
            sources = len(doc.evaluator_ere.entities) + 1
            y_pred = [
                _decode_rank_scores(group.detach().cpu().numpy(), sources)
                for group in pair_groups]
        else:
            y_pred = [group.detach().cpu().numpy().argmax(axis=1)
                      for group in pair_groups]

        doc_type = 'DF' if doc.doc_id[4:6] == 'DF' else 'NW'
        for key in validation_configs:
            (null_only, include_entities,
             include_relations, include_events) = key
            doc_results = _score_doc(doc, y_pred, null_only=null_only,
                                     include_entities=include_entities,
                                     include_relations=include_relations,
                                     include_events=include_events,
                                     use_entities=use_entities)
            for bucket in ('All', doc_type):
                results[bucket][key] = tuple(
                    p + q for p, q in zip(results[bucket][key], doc_results))

    # Replace the raw counts with micro-averaged precision / recall / F1.
    for key in validation_configs:
        for result in results:
            (total_tps, total_fn, total_fp,
             total_predicted, total_required) = results[result][key]
            micro_p = total_tps / total_predicted if total_predicted > 0 else 1
            micro_r = total_tps / total_required if total_required > 0 else 1
            micro_f = (2 * micro_p * micro_r / (micro_p + micro_r)
                       if micro_p + micro_r > 0 else 0)
            results[result][key] = (micro_p, micro_r, micro_f)

    return results, total_loss


def _compute_class_weights(data):
    """Compute normalised 'balanced' class weights per target type.

    Args:
        data: Sequence of ``(doc, doc_ix, ys)`` where ``ys`` unpacks into
            entity, relation and event label tensors.

    Returns:
        A dict with the keys "entity", "relation", "event" and "all" (all
        three label sets pooled). Each value is a float32 tensor of
        sklearn 'balanced' weights rescaled to sum to 1.
    """
    label_triples = [labels for _, _, labels in data]
    entity_parts = [ent for ent, _, _ in label_triples]
    relation_parts = [rel for _, rel, _ in label_triples]
    event_parts = [ev for _, _, ev in label_triples]

    pooled = torch.cat(entity_parts + relation_parts + event_parts)
    labels_by_type = {
        "entity": torch.cat(entity_parts),
        "relation": torch.cat(relation_parts),
        "event": torch.cat(event_parts),
        "all": pooled,
    }

    def balanced_weights(label_tensor):
        return compute_class_weight(
            class_weight='balanced',
            classes=range(len(sentiment_lbl_enc.classes_)),
            y=label_tensor.data.numpy())

    raw_weights = {name: balanced_weights(vals)
                   for name, vals in labels_by_type.items()}

    class_weights = {}
    for name, raw in raw_weights.items():
        normalised = raw / sum(raw)  # rescale so the weights sum to 1
        class_weights[name] = torch.from_numpy(normalised.astype('float32'))
    return class_weights


def _build_optimizer(model, optimize, initial_learning_rate, momentum, weight_decay):
    """Create the optimizer named by ``optimize`` for ``model``'s parameters."""
    if optimize == 'SGD':
        return optim.SGD(model.parameters(),
                         lr=initial_learning_rate,
                         momentum=momentum,
                         weight_decay=weight_decay)
    if optimize == 'Adam':
        # Adam keeps its library defaults, as before.
        return optim.Adam(model.parameters())
    if optimize == 'RMSProp':
        # The torch class is spelled ``RMSprop``; ``optim.RMSProp`` raised
        # AttributeError.
        return optim.RMSprop(model.parameters(),
                             lr=initial_learning_rate,
                             weight_decay=weight_decay,
                             momentum=momentum)
    raise NotImplementedError


def model_handler(train_docs, valid_docs, path_to_embeddings, glove_file, embedding_dim=100,
                  n_epochs=5, hidden_dim=5, initial_learning_rate=0.1,
                  momentum=0.0, weight_decay=0.1, num_layers=1, bidirectional=False,
                  dropout=0.0, batch=20, pretrained=False, optimize='SGD',
                  attention="multilinear", encode_relation="concat", encode_event="concat",
                  attention_hyperparam=5, parameterization="rank", batch_mentions=True):
    """Train a ``SentimentPredictor`` and report metrics every fifth epoch.

    Args:
        train_docs / valid_docs: Training and validation documents.
        path_to_embeddings, glove_file: Forwarded to ``PretrainedEmbeddings``
            when ``pretrained`` is truthy.
        embedding_dim, hidden_dim, num_layers, bidirectional, dropout,
        attention, encode_relation, encode_event, attention_hyperparam,
        parameterization: Forwarded to ``SentimentPredictor``.
        n_epochs: Number of passes over the training data.
        initial_learning_rate, momentum, weight_decay: Optimizer settings
            (ignored for 'Adam', which uses its defaults).
        batch: Number of documents whose losses are summed per update.
        pretrained: Falsy to fit a ``TokenEncoder`` vocabulary; truthy to
            use pretrained embeddings.
        optimize: 'SGD', 'Adam' or 'RMSProp'.
        batch_mentions: Currently unused.

    Returns:
        The validation results of the last evaluated epoch (see
        ``validate``), or ``{"loss": []}`` when no epoch was evaluated.

    Raises:
        NotImplementedError: For an unknown ``optimize`` value.
    """
    if pretrained:
        print("Fetching Pretrained Embeddings")
        embeds = PretrainedEmbeddings(path_to_embeddings, glove_file, embedding_dim, pretrained)
        train_ix = embeds(train_docs)
        valid_ix = embeds(valid_docs)
        vocab_size = 0
        print("Completed Fetching Pretrained Embeddings")
    else:
        vocab = TokenEncoder().fit(train_docs)
        train_ix = vocab.transform(train_docs)
        valid_ix = vocab.transform(valid_docs)
        vocab_size = vocab.size()
    train_data = pack_data(train_docs, train_ix, parameterization)
    valid_data = pack_data(valid_docs, valid_ix, parameterization)
    class_weights = _compute_class_weights(train_data)

    # Entity targets are switched off for the model, the loss and scoring.
    use_entities = False

    print("Class Weights: {}".format(class_weights))

    model = SentimentPredictor(embedding_dim=embedding_dim,
                               hidden_dim=hidden_dim,
                               vocab_size=vocab_size,
                               tagset_size=len(sentiment_lbl_enc.classes_),
                               num_layers=num_layers,
                               bidirectional=bidirectional,
                               dropout=dropout,
                               class_weights=class_weights,
                               pretrained=pretrained,
                               attention=attention,
                               relation_metadata=get_relation_metadata(train_docs),
                               event_metadata=get_event_metadata(train_docs),
                               encode_relation=encode_relation,
                               encode_event=encode_event,
                               attention_hyperparam=attention_hyperparam,
                               parameterization=parameterization,
                               use_entities=use_entities)

    optimizer = _build_optimizer(model, optimize, initial_learning_rate,
                                 momentum, weight_decay)

    # Headings for the report, keyed like the configs used by validate():
    # (null_only, include_entities, include_relations, include_events).
    validation_print = {
        (False, True, True, True): "All",
        (True, True, True, True): "Null source only",
        (False, True, False, False): "Target entities",
        (False, False, True, False): "Target relations",
        (False, False, False, True): "Target events",
    }

    n_train = len(train_data)
    validation_results = {"loss": []}
    print("Begin Training")
    for epoch in range(n_epochs):
        print("Epoch {:03d}".format(epoch))
        permutation = torch.randperm(n_train).tolist()
        minibatch = batch
        number_of_minibatches = int(math.ceil(n_train / minibatch))
        model.train()  # validate() leaves the model in eval mode
        for group in range(number_of_minibatches):
            model.zero_grad()
            total_loss = None
            start_idx = group * minibatch
            end_idx = min((group + 1) * minibatch, n_train)
            for i in range(start_idx, end_idx):
                doc, doc_ix, y_groups = train_data[permutation[i]]
                pair_groups = model(doc, doc_ix)
                loss_ = model.compute_total_loss(pair_groups, y_groups, use_entities=use_entities)
                # Compare with None: tensor truthiness would treat a loss of
                # exactly zero as "no loss yet".
                total_loss = loss_ if total_loss is None else total_loss + loss_
            if random.randint(0, 5) == 2:
                print('Training Loss for Group {}: {}'.format(group, total_loss))
            total_loss.backward()
            optimizer.step()
        if epoch % 5 == 0:
            print("Validation for epoch {:03d}".format(epoch))
            training_results, training_loss = validate(model, train_data, parameterization, use_entities=use_entities, valid=False)
            validation_results, validation_loss = validate(model, valid_data, parameterization, use_entities=use_entities, valid=True)
            for config, heading in validation_print.items():
                print(heading)
                for result in training_results:
                    (tp, tr, tf) = training_results[result][config]
                    (vp, vr, vf) = validation_results[result][config]
                    print("          Doc type: {} Train micro P={:.3f} R={:.3f} F1={:.3f}".format(result, tp, tr, tf))
                    print("          Doc type: {} Valid micro P={:.3f} R={:.3f} F1={:.3f}".format(result, vp, vr, vf))
            print("Training Loss: {}  Validation Loss: {}".format(training_loss, validation_loss))
    print("Completed Training")
    # The former exit() call here terminated the interpreter and made the
    # return below unreachable.
    return validation_results
