import torch
import math
import numpy
from torch.autograd import Variable
from torch import nn
from torch.nn import init
from torch import optim
from best.entity_mention_encoder import EntityMentionEncoder
from best.relation_encoder import RelationEncoder
from best.event_encoder import EventEncoder
from best.entity_encoder import EntityEncoder
from best.token_encoder import TokenEncoder
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
from best.best_evaluator import score_pst_tuples
from best.loaders import belief_lbl_enc, belief_encoder, belief_encoder_inverter
from best.metadata import get_relation_metadata, get_event_metadata
from best.pretrained import PretrainedEmbeddings
from best.predict import Predict

class BeliefPredictor(nn.Module):
    """Document-level model that scores belief labels for source/target pairs.

    A document is run through a GRU; the token encodings are then turned into
    entity-mention, entity, relation-mention and event-mention encodings by
    the project encoders, and ``Predict`` produces the pair scores that
    ``compute_loss`` feeds to a summed cross-entropy loss.
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size,
                 tagset_size, num_layers=1, class_weights=None,
                 bidirectional=False, dropout=0.0, pretrained=False,
                 attention="multilinear",relation_metadata=(None,None,None),
                 event_metadata=(None,None,None), encode_relation="concat",
                 encode_event="concat", attention_hyperparam=5, parameterization="rank"):
        """Build the encoder stack.

        Args:
            embedding_dim: input size of the GRU (and of ``nn.Embedding`` when
                ``pretrained`` is falsy).
            hidden_dim: size of the encoded token vectors. With
                ``bidirectional=True`` each GRU direction gets half of it.
            vocab_size: number of rows of the embedding table (unused when
                ``pretrained`` is truthy).
            tagset_size: stored on the instance; not otherwise used here.
            num_layers: number of GRU layers.
            class_weights: optional per-class weights for the loss.
            bidirectional: whether the GRU is bidirectional.
            dropout: dropout value handed to the GRU and to ``Predict``.
            pretrained: falsy to learn embeddings; truthy to treat ``doc_ix``
                as ready-made embeddings. The value ``"doc_projection"``
                additionally adds linear layers before and after the GRU.
            attention, attention_hyperparam, parameterization: forwarded to
                ``Predict``.
            relation_metadata, event_metadata: forwarded to the relation and
                event encoders.
            encode_relation, encode_event: forwarded as ``enc_mention`` when
                encoding relation / event mentions.

        Raises:
            ValueError: if ``bidirectional`` is set and ``hidden_dim`` is odd.
        """
        super(BeliefPredictor, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        if bidirectional:
            if hidden_dim % 2 == 1:
                raise ValueError(
                    "hidden_dim must be even when bidirectional=True, "
                    "got {}".format(hidden_dim))
            # Each direction contributes half of the output, so the
            # concatenated encoding is still self.hidden_dim wide.
            hidden_dim = hidden_dim // 2
        self.vocab_size = vocab_size
        self.tagset_size = tagset_size
        self.attention = attention
        self.pair_dim_ = self.get_pair_dim()
        self.pretrained = pretrained
        self.attention_hyperparam = attention_hyperparam
        self.parameterization = parameterization
        self.dropout = dropout
        self.predict = Predict("belief", 4, self.hidden_dim, self.pair_dim_, self.resolve_author, self.dropout, self.attention, self.parameterization, self.attention_hyperparam)
        self.entity_mention_encoder = EntityMentionEncoder()
        self.relation_metadata = relation_metadata
        self.encode_relation = encode_relation
        self.relation_encoder = RelationEncoder(self.hidden_dim, self.resolve_author, relation_metadata)
        self.event_metadata = event_metadata
        self.encode_event = encode_event
        self.event_encoder = EventEncoder(self.hidden_dim, self.resolve_author, event_metadata)
        self.entity_encoder = EntityEncoder()
        # Learned stand-ins used by resolve_author for string-valued sources.
        self.enc_same_author = nn.Parameter(torch.Tensor(self.hidden_dim))
        self.enc_diff_author = nn.Parameter(torch.Tensor(self.hidden_dim))
        init.uniform_(self.enc_same_author, -0.01, 0.01)
        init.uniform_(self.enc_diff_author, -0.01, 0.01)
        if not self.pretrained:
            self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, hidden_dim, num_layers=num_layers,
                          bidirectional=bidirectional, dropout=self.dropout)
        if self.pretrained == "doc_projection":
            # Pre refers to before RNN
            self.pre_projection = nn.Linear(self.embedding_dim, self.embedding_dim)
            # Post refers to after RNN
            self.post_projection = nn.Linear(self.hidden_dim, self.hidden_dim)
        # Learned encoding used for pairs whose source entity is None.
        self.enc_missing_src = nn.Parameter(torch.Tensor(self.hidden_dim))
        init.uniform_(self.enc_missing_src, -0.01, 0.01)
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        self.loss = nn.CrossEntropyLoss(weight=class_weights,
                                        reduction='sum')

    def get_pair_dim(self):
        """Return the pair width: 3x hidden_dim with attention, else 2x."""
        # in the future, we could combine pairs differently
        # and have subclasses or something
        if self.attention:
            return 3 * self.hidden_dim
        else:
            return 2 * self.hidden_dim

    def resolve_author(self, encoded_value, post_author):
        """Replace a string-like source with a learned author encoding.

        Values that expose ``lower`` (i.e. behave like strings) are mapped to
        ``enc_same_author`` when equal to ``post_author`` and to
        ``enc_diff_author`` otherwise. Anything else is returned unchanged.
        """
        if hasattr(encoded_value, 'lower'):
            if encoded_value == post_author:
                encoded_value = self.enc_same_author
            else:
                encoded_value = self.enc_diff_author
        return encoded_value

    def forward(self, doc, doc_ix):
        """Encode ``doc`` and return the output of ``Predict``.

        Args:
            doc: project document object; must expose ``evaluator_ere`` with
                ``entity_mentions``, ``entities``, ``relation_mentions`` and
                ``event_mentions``.
            doc_ix: token indices when embeddings are learned, or the
                embeddings themselves when ``pretrained`` is truthy.
        """
        # get embeddings for all words in doc
        if self.pretrained:
            embeds = doc_ix
            if self.pretrained == "doc_projection":
                embeds = self.pre_projection(embeds)
        else:
            embeds = self.embed(doc_ix)

        # The GRU is not batch_first, so a singleton dimension is inserted at
        # position 1 for the RNN call and removed again afterwards.
        encoded_doc, _ = self.rnn(embeds.unsqueeze(1))
        encoded_doc = encoded_doc.squeeze(1)
        if self.pretrained == "doc_projection":
            encoded_doc = self.post_projection(encoded_doc)

        # encode all entity mentions
        enc_mentions = {
            mtn_id: self.entity_mention_encoder.encode_entity_mention(mtn_id, doc, encoded_doc)
            for mtn_id in doc.evaluator_ere.entity_mentions}

        enc_entities = {
            ent_id: self.entity_encoder.encode_entity(ent_id, doc, enc_mentions)
            for ent_id in doc.evaluator_ere.entities}

        enc_entities[None] = self.enc_missing_src

        # encode relation mentions
        for mtn_id in doc.evaluator_ere.relation_mentions:
            enc_mentions[mtn_id] = self.relation_encoder(mtn_id,
                                                         doc,
                                                         encoded_doc,
                                                         enc_mentions,
                                                         enc_mention=self.encode_relation)

        # encode event mentions
        for mtn_id in doc.evaluator_ere.event_mentions:
            enc_mentions[mtn_id] = self.event_encoder(mention_id=mtn_id,
                                                      enc_mentions=enc_mentions,
                                                      doc=doc,
                                                      encoded_doc=encoded_doc,
                                                      enc_mention=self.encode_event)
        return self.predict(doc, encoded_doc, enc_entities, enc_mentions)

    def compute_loss(self, pair_groups, y_groups, average=True,
                     return_counts=True):
        """Sum the cross-entropy loss over all non-empty pair groups.

        Args:
            pair_groups: iterable of score tensors, one per group.
            y_groups: iterable of gold-label tensors aligned with
                ``pair_groups``.
            average: divide the summed loss by the total number of labels.
                Skipped when there are no labels, so the loss stays 0 instead
                of raising ``ZeroDivisionError``.
            return_counts: also return the number of correct argmax
                predictions and the total number of labels.

        Returns:
            ``loss`` or ``(loss, correct, n_pairs)``. ``loss`` is the plain
            integer 0 when every group is empty.
        """
        non_empty = [(pairs, y)
                     for pairs, y in zip(pair_groups, y_groups)
                     if pairs.data.numel() != 0]

        loss = sum(self.loss(pairs, y) for pairs, y in non_empty)

        n_pairs = sum(len(y) for y in y_groups)

        if average and n_pairs > 0:
            loss = loss / n_pairs

        if return_counts:
            correct = sum(
                (pairs.data.numpy().argmax(axis=1) == y.data.numpy()).sum()
                for pairs, y in non_empty)
            return loss, correct, n_pairs
        else:
            return loss


def pack_data(docs, indices):
    """Bundle each document with its index tensor and its gold belief labels.

    Returns a list of ``(doc, doc_ix, ys)`` triples, where ``ys`` is a
    two-element list of label tensors: first for ``doc.pairs_relation``, then
    for ``doc.pairs_event``.
    """
    packed = []
    for document, token_ix in zip(docs, indices):
        # NOTE(review): the hasattr result is discarded, so the input is only
        # wrapped if the attribute lookup itself raises RuntimeError. Intent
        # unclear from here; kept as-is.
        try:
            hasattr(token_ix, "data")
        except RuntimeError:
            token_ix = Variable(token_ix, requires_grad=False)

        relation_labels = document.belief_labels(document.pairs_relation)
        event_labels = document.belief_labels(document.pairs_event)
        targets = []
        for labels in (relation_labels, event_labels):
            targets.append(Variable(torch.from_numpy(labels), requires_grad=False))

        packed.append((document, token_ix, targets))
    return packed


def _score_doc(doc, y_preds, null_only=False, include_relations=True, include_events=True,
               valid_and_print_output=False):
    """Score one document's predicted beliefs against its gold state tuples.

    Args:
        doc: project document exposing ``pairs_relation``, ``pairs_event``,
            ``build_pseudo_annotations`` and ``state_tuples``.
        y_preds: encoded predictions, one sequence per pair group (relations
            first, events second).
        null_only: keep only pairs / tuples whose source is ``None``.
        include_relations: when False, drop items whose target id starts
            with ``'r'``.
        include_events: when False, drop predicted pairs whose target id
            starts with ``'e'`` and gold tuples whose target starts with
            ``'h'``.
        valid_and_print_output: accepted for interface compatibility; unused.

    Returns:
        ``(tps, fn, fp, predicted, required)``: the three counts from
        ``score_pst_tuples`` plus the number of predicted and gold tuples.
    """
    drop_relations = not include_relations
    drop_events = not include_events

    beliefs = []
    for pairs, y_pred in zip((doc.pairs_relation, doc.pairs_event), y_preds):
        decoded = belief_encoder_inverter(y_pred)
        for (src, trg), label in zip(pairs, decoded):
            skip = ((null_only and src is not None)
                    or (drop_relations and trg[0] == 'r')
                    or (drop_events and trg[0] == 'e'))
            if not skip:
                beliefs.append((src, trg, label))

    ann_pred = doc.build_pseudo_annotations(beliefs=beliefs)
    pst_pred = doc.state_tuples(ann_pred, sentiment=False)

    # NOTE(review): gold event targets are matched on 'h' while predicted
    # ones use 'e' -- presumably gold tuples use a different id scheme;
    # confirm against state_tuples.
    pst_true = [
        pst for pst in doc.state_tuples(sentiment=False)
        if not ((null_only and pst.source_entity is not None)
                or (drop_relations and pst.target_object[0] == 'r')
                or (drop_events and pst.target_object[0] == 'h'))
    ]

    # Take the sizes first: pst_true will be altered in score_pst_tuples.
    predicted, required = len(pst_pred), len(pst_true)
    tps, fn, fp = score_pst_tuples(pst_true, pst_pred)
    return tps, fn, fp, predicted, required


def validate(model, data):
    """Evaluate ``model`` on packed validation data.

    The model is switched to eval mode and run under ``torch.no_grad()`` so
    that no autograd graph is built; the accumulated loss therefore does not
    keep every document's graph alive.

    Four slices of the data are scored, keyed by
    ``(null_only, include_relations, include_events)``:

    * ``(False, True, True)``  -- all validation data
    * ``(True, True, True)``   -- null source only
    * ``(False, True, False)`` -- target relations
    * ``(False, False, True)`` -- target events

    Args:
        model: object providing ``eval()``, ``__call__(doc, doc_ix)`` and
            ``compute_loss(pair_groups, y, return_counts=False)``.
        data: iterable of ``(doc, doc_ix, y)`` triples.

    Returns:
        ``(results, total_loss)`` where ``results`` maps each key above to
        ``(micro_p, micro_r, micro_f)``. Precision / recall default to 1 when
        nothing was predicted / required.
    """
    model.eval()
    # Null_only, include_relations, include_events
    config_keys = [
        (False, True, True),
        (True, True, True),
        (False, True, False),
        (False, False, True),
    ]
    # Running [tps, fn, fp, predicted, required] per configuration.
    totals = {key: [0, 0, 0, 0, 0] for key in config_keys}
    total_loss = 0

    with torch.no_grad():
        for doc, doc_ix, y in data:
            pair_groups = model(doc, doc_ix)
            loss_ = model.compute_loss(pair_groups, y, return_counts=False)
            total_loss = total_loss + loss_
            y_pred = [pair_scores.data.numpy().argmax(axis=1)
                      for pair_scores in pair_groups]
            for key, running in totals.items():
                null_only, include_relations, include_events = key
                counts = _score_doc(doc, y_pred, null_only=null_only,
                                    include_relations=include_relations,
                                    include_events=include_events,
                                    valid_and_print_output=True)
                for position, count in enumerate(counts):
                    running[position] += count

    results = {}
    for key, (tps, _fn, _fp, predicted, required) in totals.items():
        micro_p = tps / predicted if predicted > 0 else 1
        micro_r = tps / required if required > 0 else 1
        micro_f = (2 * micro_p * micro_r / (micro_p + micro_r) if micro_p + micro_r > 0 else 0)
        results[key] = (micro_p, micro_r, micro_f)
    return results, total_loss


def _compute_class_weights(data):
    """Return 'balanced' class weights for the labels in ``data``.

    Args:
        data: iterable of ``(doc, doc_ix, ys)`` triples as built by
            ``pack_data``; empty label tensors are skipped.

    Returns:
        A float32 tensor with one weight per class in
        ``belief_lbl_enc.classes_``.
    """
    labels = torch.cat([y for _, _, ys in data for y in ys if not y.data.numel() == 0])
    # scikit-learn documents ``classes`` as an ndarray; a bare ``range`` is
    # rejected by releases that validate parameter types.
    class_ids = numpy.arange(len(belief_lbl_enc.classes_))
    cw = compute_class_weight(class_weight='balanced',
                              classes=class_ids,
                              y=labels.data.numpy())
    return torch.from_numpy(cw.astype('float32'))


def belief_model_handler(train_docs, valid_docs, path_to_embeddings, embedding_dim=50,
                         n_epochs=10, hidden_dim=40, initial_learning_rate=0.12,
                         momentum=0.0, weight_decay=0.0001, num_layers=1, bidirectional=False,
                         dropout=0.0, batch=12, pretrained=False, optimize='SGD',
                         attention="multilinear", encode_relation="concat", encode_event="concat",
                         attention_hyperparam=5, parameterization="rank", batch_mentions=True):
    """Train a ``BeliefPredictor`` and collect periodic validation results.

    Args:
        train_docs, valid_docs: project document collections.
        path_to_embeddings: handed to ``PretrainedEmbeddings`` when
            ``pretrained`` is truthy; ignored otherwise.
        embedding_dim, hidden_dim, num_layers, bidirectional, dropout,
        attention, encode_relation, encode_event, attention_hyperparam,
        parameterization, pretrained: forwarded to ``BeliefPredictor``.
        n_epochs: number of passes over the training data.
        initial_learning_rate, momentum, weight_decay: SGD settings only;
            Adam and RMSprop run with the library defaults.
        batch: documents per optimizer step when ``batch_mentions`` is False.
        optimize: one of ``'SGD'``, ``'Adam'``, ``'RMSProp'``.
        batch_mentions: when True, one optimizer step is taken per document
            (mention-level batching is still a TODO); when False, one step
            is taken per ``batch`` documents.

    Returns:
        A dict with key ``"loss"`` plus one key per validation slice
        (``"All"``, ``"Null source only"``, ``"Target relations"``,
        ``"Target events"``), each mapping to a list with one entry per
        validation run (every third epoch, starting at epoch 0).

    Raises:
        NotImplementedError: for an unknown ``optimize`` value.
    """
    if pretrained:
        embeds = PretrainedEmbeddings(path_to_embeddings, embedding_dim, pretrained)
        train_ix = embeds(train_docs)
        valid_ix = embeds(valid_docs)
        vocab_size = 0
    else:
        vocab = TokenEncoder().fit(train_docs)
        train_ix = vocab.transform(train_docs)
        valid_ix = vocab.transform(valid_docs)
        vocab_size = vocab.size()
    train_data = pack_data(train_docs, train_ix)
    valid_data = pack_data(valid_docs, valid_ix)
    class_weights = _compute_class_weights(train_data)
    model = BeliefPredictor(embedding_dim=embedding_dim,
                            hidden_dim=hidden_dim,
                            vocab_size=vocab_size,
                            tagset_size=len(belief_lbl_enc.classes_),
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            dropout=dropout,
                            class_weights=class_weights,
                            pretrained=pretrained,
                            attention=attention,
                            relation_metadata=get_relation_metadata(train_docs),
                            event_metadata=get_event_metadata(train_docs),
                            encode_relation=encode_relation,
                            encode_event=encode_event,
                            attention_hyperparam=attention_hyperparam,
                            parameterization=parameterization)

    if optimize == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=initial_learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
    elif optimize == 'Adam':
        optimizer = optim.Adam(model.parameters())
    elif optimize == 'RMSProp':
        # The PyTorch class is spelled ``RMSprop``; ``optim.RMSProp`` does
        # not exist and raised AttributeError.
        optimizer = optim.RMSprop(model.parameters())
    else:
        raise NotImplementedError(
            "unsupported optimizer: {!r}".format(optimize))

    # Display names for the validation slices, keyed by
    # (null_only, include_relations, include_events).
    validation_print = {
        (False, True, True): "All",
        # Corresponds to validation data with null source
        (True, True, True): "Null source only",
        # Corresponds to validation data with target relations
        (False, True, False): "Target relations",
        # Corresponds to validation data with target events
        (False, False, True): "Target events",
    }

    n_train = len(train_data)
    validation_results = {"loss": []}
    minibatch = 1 if batch_mentions else batch
    number_of_minibatches = math.ceil(n_train / minibatch)

    for epoch in range(n_epochs):
        permutation = torch.randperm(n_train).tolist()
        correct = 0
        total = 0
        model.train()
        for group in range(number_of_minibatches):
            model.zero_grad()
            total_loss = 0
            start_idx = group * minibatch
            end_idx = min(start_idx + minibatch, n_train)
            for i in tqdm(range(start_idx, end_idx)):
                doc, doc_ix, y_groups = train_data[permutation[i]]
                pair_groups = model(doc, doc_ix)
                # TODO: Slice this correctly to handle batching on mentions
                loss_, correct_, total_ = model.compute_loss(pair_groups, y_groups, return_counts=True)
                total_loss += loss_
                correct += correct_
                total += total_
            # Always update the parameters. Previously the step was skipped
            # whenever batch_mentions was True (the default), so the model
            # never trained. The loss stays a plain 0 when the minibatch had
            # no pairs at all, in which case there is nothing to backprop.
            if isinstance(total_loss, torch.Tensor) and total_loss.requires_grad:
                total_loss.backward()
                optimizer.step()

        train_acc = correct / total if total else 0.0
        print("Epoch {:03d} train acc={:.3f}".format(epoch, train_acc))

        if epoch % 3 == 0:
            validation_configs, validation_loss = validate(model, valid_data)
            for validation_config, (p, r, f) in validation_configs.items():
                label = validation_print[validation_config]
                print(label)
                print("          micro P={:.3f} R={:.3f} F1={:.3f}".format(p, r, f))
                validation_results.setdefault(label, []).append((p, r, f))
            print(validation_loss)
            validation_results["loss"].append(validation_loss)
    return validation_results
