import os
import torch

from model import model_utils
from model import encoders
from model import decoders

import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from pytorch_pretrained_bert import BertModel

from model.von_mises_fisher import VonMisesFisher
from model.decorators import auto_init_args, auto_init_pytorch
from torch.autograd import Variable

# Number of position classes used by the position head: it is the output
# width of `MultiDisModel.pos_decode` and the length to which masks are
# padded in `MultiDisModel.pos_loss`.
MAX_LEN = 384


class MultiDisModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, embed_init, experiment):
        """Assemble encoders, posterior MLPs, the decoder and auxiliary heads.

        Args:
            vocab_size: forwarded to both encoders and to the decoder.
            embed_dim: forwarded to both encoders; also used as the encoder
                output width whenever the encoder type name does not
                contain "lstm".
            embed_init: forwarded to both encoders.
            experiment: object exposing ``config`` (hyper-parameters) and
                ``log`` (handed to the encoders).
        """
        super(MultiDisModel, self).__init__()
        self.expe = experiment
        config = self.expe.config
        self.eps = config.eps
        self.margin = config.m
        self.use_cuda = config.use_cuda

        def make_encoder(encoder_type):
            # Encoder classes are looked up by name in the `encoders` module.
            encoder_cls = getattr(encoders, encoder_type)
            return encoder_cls(
                embed_dim=embed_dim,
                embed_init=embed_init,
                hidden_size=config.ensize,
                vocab_size=vocab_size,
                log=experiment.log,
                embed_type=config.embed_type)

        def encoder_output_size(encoder_type):
            # "lstm" encoders are sized as 2 * ensize; every other encoder
            # type (including "word_avg") is sized as embed_dim.
            if "lstm" in encoder_type.lower():
                return 2 * config.ensize
            return embed_dim

        def make_mlp(input_size, output_size, n_layer):
            return model_utils.get_mlp(
                input_size=input_size,
                hidden_size=config.mhsize,
                output_size=output_size,
                n_layer=n_layer,
                dropout=config.dp)

        self.yencode = make_encoder(config.yencoder_type)
        self.zencode = make_encoder(config.zencoder_type)

        y_out_size = encoder_output_size(config.yencoder_type)
        z_out_size = encoder_output_size(config.zencoder_type)

        # Heads fed by the y-encoder output: a ysize-wide mean and a single
        # scalar from `logvar1`.
        self.mean1 = make_mlp(y_out_size, config.ysize, config.ymlplayer)
        self.logvar1 = make_mlp(y_out_size, 1, config.ymlplayer)

        # Heads fed by the z-encoder output: zsize-wide mean and log-variance.
        self.mean2 = make_mlp(z_out_size, config.zsize, config.zmlplayer)
        self.logvar2 = make_mlp(z_out_size, config.zsize, config.zmlplayer)

        if config.zencoder_type.lower() == "word_avg":
            assert config.decoder_type.lower() == "bag_of_words"

        decoder_cls = getattr(decoders, config.decoder_type)
        self.decode = decoder_cls(
            ysize=config.ysize,
            zsize=config.zsize,
            mlp_hidden_size=config.mhsize,
            mlp_layer=config.mlplayer,
            hidden_size=config.desize,
            dropout=config.dp,
            vocab_size=vocab_size)

        # Position head: input is zsize + embed_dim wide, output has one
        # logit per position up to MAX_LEN.
        self.pos_decode = make_mlp(
            config.zsize + embed_dim, MAX_LEN, config.mlplayer)

        self.classifier = nn.Linear(config.zsize, 2)

    def pos_loss(self, mask, vecs):
        """Position-prediction loss: each token should predict its own index.

        Args:
            mask: 2-D tensor (batch, seq_len); non-zero entries mark real
                tokens. Assumed to hold 0/1 values -- TODO confirm.
            vecs: per-token input to ``self.pos_decode``.

        Returns:
            Scalar tensor: mean over the batch of the per-sentence average
            negative log-probability assigned to the correct position.
        """
        batch_size, seq_len = mask.size()
        # batch size x seq len x MAX LEN
        logits = self.pos_decode(vecs)
        pad_len = MAX_LEN - seq_len
        if pad_len:
            # Allocate the padding with the mask's own dtype and device so
            # this also works on CPU (it used to call .cuda() regardless).
            padding = mask.new_zeros(batch_size, pad_len)
            invalid = 1 - torch.cat([mask, padding], -1)
        else:
            invalid = 1 - mask
        invalid = invalid.unsqueeze(1).expand_as(logits)
        # Positions beyond each sentence's length must get zero probability.
        # A comparison yields the mask dtype masked_fill_ expects on both
        # old (uint8) and new (bool) PyTorch versions.
        logits.data.masked_fill_(invalid.data != 0, -float('inf'))
        positions = np.arange(int(seq_len))
        # Probability that token i is assigned position i (the diagonal).
        loss = F.softmax(logits, -1)[:, positions, positions]
        loss = -(loss + self.eps).log() * mask

        loss = loss.sum(-1) / mask.sum(1)
        return loss.mean()

    def binary_cross_entropyloss(self, prob, target, weight=None):
        """Element-wise binary cross-entropy averaged over all target entries.

        Args:
            prob: tensor of probabilities in [0, 1].
            target: tensor of the same shape as ``prob`` (or broadcastable)
                holding the 0/1 targets.
            weight: optional tensor broadcast against the per-element loss;
                ``None`` means unweighted.

        Returns:
            Scalar tensor: sum of the (weighted) losses divided by
            ``target.numel()``.
        """
        # Clamp the logs at -100 (the bound torch.nn.BCELoss uses) so that a
        # saturated probability gives a large finite loss rather than
        # 0 * -inf = NaN.
        log_prob = torch.clamp(torch.log(prob), min=-100)
        log_not_prob = torch.clamp(torch.log(1 - prob), min=-100)
        loss = -(target * log_prob + (1 - target) * log_not_prob)
        if weight is not None:
            loss = weight * loss
        return torch.sum(loss) / torch.numel(target)

    def language_loss1(self, vecs, label):
        """Cross-entropy of the language classifier against a constant label.

        Args:
            vecs: input to ``self.classifier``; the first dimension is the
                batch.
            label: truthy -> every example is class 1, falsy -> class 0.

        Returns:
            Scalar loss tensor.
        """
        scores = torch.sigmoid(self.classifier(vecs))
        # Build the target on the same device as the scores; it used to be
        # created on the CPU, which fails when the model runs on CUDA.
        make_target = torch.ones if label else torch.zeros
        target = make_target(
            scores.shape[0], dtype=torch.long, device=scores.device)
        # NOTE(review): CrossEntropyLoss applies log-softmax itself, so the
        # sigmoid above squashes the scores first -- kept as in the original,
        # confirm this is intended.
        criterion = nn.CrossEntropyLoss()
        return criterion(scores, target)

    def language_loss(self, vecs, label):
        """Binary cross-entropy of the language classifier vs. a one-hot label.

        Args:
            vecs: input to ``self.classifier``; the first dimension is the
                batch.
            label: truthy -> target row is ``[1, 0]``, falsy -> ``[0, 1]``.

        Returns:
            Scalar loss tensor from ``binary_cross_entropyloss``.
        """
        y = torch.sigmoid(self.classifier(vecs))
        # Target and weight are created from `y` so they share its dtype and
        # device (they used to live on the CPU regardless of the model).
        target = y.new_zeros(y.shape[0], 2)
        target[:, 0 if label else 1] = 1
        weight = y.new_ones(y.shape[1])
        return self.binary_cross_entropyloss(y, target, weight)

    def sample_gaussian(self, mean, logvar):
        """Draw ``mean + std * noise`` with ``std = exp(0.5 * logvar)``.

        The noise is standard normal, allocated with the same type as
        ``logvar`` and the same shape.
        """
        std = (0.5 * logvar).exp()
        noise = Variable(logvar.data.new(logvar.size()).normal_())
        return mean + std * noise

    def to_var(self, inputs):
        """Wrap ``inputs`` as a ``Variable`` carrying this model's volatile flag.

        An existing ``Variable`` is returned as-is when ``self.use_cuda`` is
        set, and after ``.cpu()`` otherwise; in both cases its ``volatile``
        attribute is assigned. Anything else is converted with
        ``torch.from_numpy`` unless it is already a tensor, then wrapped.
        This method never moves data onto the GPU.
        """
        if isinstance(inputs, Variable):
            if not self.use_cuda:
                inputs = inputs.cpu()
            inputs.volatile = self.volatile
            return inputs

        if not torch.is_tensor(inputs):
            inputs = torch.from_numpy(inputs)
        return Variable(inputs, volatile=self.volatile)

    def to_vars(self, *inputs):
        """Apply ``to_var`` to every argument, returning a list.

        An argument that is ``None`` or whose ``.size`` attribute is falsy
        (e.g. an empty numpy array) yields ``None`` in the output.
        """
        converted = []
        for item in inputs:
            if item is not None and item.size:
                converted.append(self.to_var(item))
            else:
                converted.append(None)
        return converted

    def optimize(self, loss):
        """Run one optimizer step on ``loss``.

        Clears gradients, back-propagates, clips the global gradient norm
        to ``config.gclip`` when that is not ``None``, then steps
        ``self.opt``.
        """
        self.opt.zero_grad()
        loss.backward()
        max_norm = self.expe.config.gclip
        if max_norm is not None:
            # clip_grad_norm (no trailing underscore) is deprecated; the
            # in-place variant is the supported API.
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm)
        self.opt.step()

    def init_optimizer(self, opt_type, learning_rate, weight_decay):
        """Create an optimizer over this model's trainable parameters.

        Args:
            opt_type: "adam", "rmsprop" or "sgd" (case-insensitive).
            learning_rate: passed as ``lr``.
            weight_decay: passed as ``weight_decay``.

        Returns:
            The new optimizer; it is not stored on the model here.

        Raises:
            NotImplementedError: for any other ``opt_type``.
        """
        optimizer_classes = {
            "adam": torch.optim.Adam,
            "rmsprop": torch.optim.RMSprop,
            "sgd": torch.optim.SGD,
        }
        key = opt_type.lower()
        if key not in optimizer_classes:
            raise NotImplementedError("invalid optimizer: {}".format(opt_type))

        # Frozen parameters are left out of the optimizer.
        trainable = (p for p in self.parameters() if p.requires_grad)
        return optimizer_classes[key](
            params=trainable,
            weight_decay=weight_decay,
            lr=learning_rate)

    def save(self, dev_avg, dev_perf, test_avg,
             test_perf, epoch, iteration=None, name="best"):
        """Write a checkpoint to ``<experiment_dir>/<name>.ckpt``.

        The checkpoint stores the given metrics, ``epoch`` and
        ``iteration``, the model and optimizer state dicts, and the
        experiment config. The destination path is logged afterwards.
        """
        target = os.path.join(self.expe.experiment_dir, name + ".ckpt")
        checkpoint = dict(
            dev_perf=dev_perf,
            test_perf=test_perf,
            dev_avg=dev_avg,
            test_avg=test_avg,
            epoch=epoch,
            iteration=iteration,
            state_dict=self.state_dict(),
            opt_state_dict=self.opt.state_dict(),
            config=self.expe.config,
        )
        torch.save(checkpoint, target)
        self.expe.log.info("model saved to {}".format(target))

    def load(self, checkpointed_state_dict=None, name="best"):
        """Restore model (and optimizer) state.

        Args:
            checkpointed_state_dict: when given, it is loaded directly as the
                model state dict and nothing is read from disk.
            name: checkpoint base name, read from
                ``<experiment_dir>/<name>.ckpt`` when no state dict is given.

        Returns:
            ``(epoch, iteration, dev_avg, test_avg, kl_temp)`` from the
            checkpoint file (each defaulting to 0 when absent), or ``None``
            when ``checkpointed_state_dict`` was supplied.
        """
        if checkpointed_state_dict is not None:
            self.load_state_dict(checkpointed_state_dict)
            self.expe.log.info("model loaded!")
            return None

        save_path = os.path.join(self.expe.experiment_dir, name + ".ckpt")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # Everything is mapped to CPU storage first.
        checkpoint = torch.load(
            save_path, map_location=lambda storage, loc: storage)
        self.load_state_dict(checkpoint['state_dict'])

        opt_state = checkpoint.get("opt_state_dict")
        if opt_state:
            self.opt.load_state_dict(opt_state)
            if self.use_cuda:
                # The optimizer state was loaded onto the CPU above; move its
                # tensors to the GPU (this used to reassign them unchanged).
                for state in self.opt.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda()

        self.expe.log.info("model loaded from {}".format(save_path))
        return checkpoint.get('epoch', 0), \
            checkpoint.get('iteration', 0), \
            checkpoint.get('dev_avg', 0), \
            checkpoint.get('test_avg', 0), \
            checkpoint.get('kl_temp', 0)

    @property
    def volatile(self):
        """``True`` when the module is not in training mode."""
        in_training = self.training
        return not in_training


class bilstm(nn.Module):
    """Bidirectional LSTM that mean-pools its outputs over unmasked steps.

    The input vectors and mask are fixed at construction time, so
    ``forward`` takes no arguments.
    """

    def __init__(self, embed_dim, hidden_size, input_vecs, mask):
        super(bilstm, self).__init__()
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True)
        self.input_vecs = input_vecs
        self.mask = mask

    def forward(self):
        """Return the masked average of the LSTM outputs along dim 1."""
        mask = self.mask
        hidden_states, _ = model_utils.get_rnn_vecs(
            self.input_vecs, mask, self.lstm, bidir=True)
        # Zero out the masked steps, then divide by the per-row mask sum.
        masked_states = hidden_states * mask.unsqueeze(-1)
        step_counts = mask.sum(1, keepdim=True)
        return masked_states.sum(1) / step_counts


class vgvae(MultiDisModel):
    @auto_init_pytorch
    @auto_init_args
    def __init__(self, vocab_size, embed_dim, embed_init, experiment, tags_vocab_size):
        """Build the base model, optionally load BERT, and add a tag classifier.

        Args:
            vocab_size, embed_dim, embed_init, experiment: forwarded to
                ``MultiDisModel.__init__``.
            tags_vocab_size: output width of the tag classifier ``self.fc``.
        """
        super(vgvae, self).__init__(vocab_size, embed_dim, embed_init, experiment)
        # `self.experiment` is not assigned in this class or in the base
        # class; presumably `auto_init_args` binds constructor arguments as
        # attributes -- TODO confirm against model/decorators.py.
        # NOTE(review): `self.bert` only exists when embed_type is "lm", yet
        # the other methods of this class call it unconditionally.
        if self.experiment.config.embed_type.lower() == "lm":
            self.bert = BertModel.from_pretrained(self.experiment.config.ml_model_path)
            # BERT is put in eval mode right after loading.
            self.bert.eval()
        self.embed_dim = embed_dim
        # NOTE(review): 968 is hard-coded; it has to equal the width of the
        # vectors passed to `self.fc` in `forward` (token vector concatenated
        # with the syntax sample) -- confirm against embed_dim + zsize.
        self.fc = nn.Linear(968, tags_vocab_size)
        # Prints the tag vocabulary size to stdout on construction.
        print(tags_vocab_size)

    def sent2param(self, sent, mask):
        """Encode a batch with BERT and compute both posteriors' parameters.

        Args:
            sent: token-id tensor (cast to long before being fed to BERT).
            mask: tensor multiplied into per-token values and summed along
                dim 1 for averaging; assumed (batch, seq) -- TODO confirm.

        Returns:
            ``(token_vecs, mean, var, mean2, logvar2)`` where ``token_vecs``
            is the last element of BERT's first output, and the other four
            are masked averages over dim 1 of, respectively: the
            L2-normalised ``mean1`` output, ``softplus(logvar1) + 50``, the
            ``mean2`` output and the ``logvar2`` output.
        """
        # BERT runs without gradient tracking.
        with torch.no_grad():
            encoded_layers, _ = self.bert(sent.long())
            token_vecs = encoded_layers[-1]

        token_mask = mask.unsqueeze(-1)
        token_count = mask.sum(1, keepdim=True)

        def masked_average(per_token):
            # Average over dim 1, counting only unmasked positions.
            return (per_token * token_mask).sum(1) / token_count

        unit_mean = self.mean1(token_vecs)
        unit_mean = unit_mean / unit_mean.norm(dim=-1, keepdim=True)
        mean = masked_average(unit_mean)

        var = masked_average(F.softplus(self.logvar1(token_vecs)) + 50)

        mean2 = masked_average(self.mean2(token_vecs))
        logvar2 = masked_average(self.logvar2(token_vecs))

        return token_vecs, mean, var, mean2, logvar2

    def forward(self, sent1, mask1, sent2, mask2, tgt1,
                tgt_mask1, tgt2, tgt_mask2,
                neg_sent1, neg_mask1, ntgt1, ntgt_mask1,
                neg_sent2, neg_mask2, ntgt2, ntgt_mask2, vtemp,
                gtemp, y, use_margin, true_it):
        """Compute the training loss for a batch of sentence pairs.

        Args:
            sent1, mask1, sent2, mask2: the two sentences of each pair and
                their masks.
            tgt1, tgt_mask1, tgt2, tgt_mask2: decoder targets and masks.
            neg_sent1, neg_mask1, neg_sent2, neg_mask2: negative examples,
                only used for the margin loss.
            ntgt1, ntgt_mask1, ntgt2, ntgt_mask2: accepted and moved to the
                device with the other negatives, but otherwise unused here.
            vtemp: weight of the von Mises-Fisher KL term.
            gtemp: weight of the Gaussian KL term.
            y: indexable pair of numpy arrays with the gold tags of
                sentence 1 and sentence 2; only read when
                ``config.posratio`` is truthy.
            use_margin: enables the margin loss (and device transfer of the
                negative examples).
            true_it: iteration counter; the margin loss is only applied when
                it exceeds 400.

        Returns:
            ``(loss, vkl, gkl, rec_logloss, para_logloss, ploss, dist)``.
            ``dist`` accumulates the margin loss and the tagging losses.
        """
        config = self.expe.config
        # Forward always switches the module to training mode.
        self.train()

        sent1, mask1, sent2, mask2, tgt1, \
            tgt_mask1, tgt2, tgt_mask2, neg_sent1, \
            neg_mask1, ntgt1, ntgt_mask1, neg_sent2, \
            neg_mask2, ntgt2, ntgt_mask2 = \
            self.to_vars(sent1, mask1, sent2, mask2, tgt1,
                         tgt_mask1, tgt2, tgt_mask2,
                         neg_sent1, neg_mask1, ntgt1, ntgt_mask1,
                         neg_sent2, neg_mask2, ntgt2, ntgt_mask2)

        if config.use_cuda:
            def on_gpu(*tensors):
                # to_vars yields None for missing/empty inputs; skip those.
                return [t if t is None else t.cuda() for t in tensors]

            sent1, mask1, sent2, mask2, tgt1, tgt_mask1, tgt2, tgt_mask2 = \
                on_gpu(sent1, mask1, sent2, mask2,
                       tgt1, tgt_mask1, tgt2, tgt_mask2)
            if use_margin:
                # Negatives are only transferred when they will be used.
                neg_sent1, neg_mask1, ntgt1, ntgt_mask1, \
                    neg_sent2, neg_mask2, ntgt2, ntgt_mask2 = \
                    on_gpu(neg_sent1, neg_mask1, ntgt1, ntgt_mask1,
                           neg_sent2, neg_mask2, ntgt2, ntgt_mask2)

        s1_vecs, sent1_mean, sent1_var, sent1_mean2, sent1_logvar2 = \
            self.sent2param(sent1, mask1)
        # NaN is the only value not equal to itself; searching the printed
        # tensor for "nan" misses values elided from large tensors.
        if (sent1_mean != sent1_mean).any():
            print('sent1_mean nan wrong')
        s2_vecs, sent2_mean, sent2_var, sent2_mean2, sent2_logvar2 = \
            self.sent2param(sent2, mask2)

        sent1_dist = VonMisesFisher(sent1_mean, sent1_var)
        sent2_dist = VonMisesFisher(sent2_mean, sent2_var)

        sent1_syntax = self.sample_gaussian(sent1_mean2, sent1_logvar2)
        sent2_syntax = self.sample_gaussian(sent2_mean2, sent2_logvar2)

        sent1_semantic = sent1_dist.rsample()
        sent2_semantic = sent2_dist.rsample()

        # Reconstruction: each sentence from its own semantic + syntax sample.
        logloss1 = self.decode(
            sent1_semantic, sent1_syntax, tgt1, tgt_mask1)
        logloss2 = self.decode(
            sent2_semantic, sent2_syntax, tgt2, tgt_mask2)

        # Paraphrase reconstruction: semantic samples swapped within the pair.
        logloss3 = self.decode(
            sent2_semantic, sent1_syntax, tgt1, tgt_mask1)
        logloss4 = self.decode(
            sent1_semantic, sent2_syntax, tgt2, tgt_mask2)

        def with_syntax(token_vecs, syntax):
            # Append the sentence-level syntax sample to every token vector
            # of that same sentence.
            tiled = syntax.unsqueeze(1).expand(-1, token_vecs.size(1), -1)
            return torch.cat([token_vecs, tiled], -1)

        if config.pratio:
            s1_vecs = with_syntax(s1_vecs, sent1_syntax)
            s2_vecs = with_syntax(s2_vecs, sent2_syntax)
            ploss1 = self.pos_loss(mask1, s1_vecs)
            ploss2 = self.pos_loss(mask2, s2_vecs)

        sent1_kl = model_utils.gauss_kl_div(
            sent1_mean2, sent1_logvar2,
            eps=self.eps).mean()
        sent2_kl = model_utils.gauss_kl_div(
            sent2_mean2, sent2_logvar2,
            eps=self.eps).mean()

        dist = torch.zeros_like(sent1_kl)
        # The margin loss is skipped for the first 400 iterations.
        if use_margin and true_it > 400:
            n1_vecs, nsent1_mean, nsent1_var, nsent1_mean2, nsent1_logvar2 = \
                self.sent2param(neg_sent1, neg_mask1)
            n2_vecs, nsent2_mean, nsent2_var, nsent2_mean2, nsent2_logvar2 = \
                self.sent2param(neg_sent2, neg_mask2)

            sent_cos_pos = F.cosine_similarity(sent1_mean, sent2_mean)

            sent1_cos_neg = F.cosine_similarity(sent1_mean, nsent1_mean)
            sent2_cos_neg = F.cosine_similarity(sent2_mean, nsent2_mean)

            dist_ = F.relu(self.margin - sent_cos_pos + sent1_cos_neg) + \
                F.relu(self.margin - sent_cos_pos + sent2_cos_neg)

            dist += dist_.mean()

        vkl = sent1_dist.kl_div().mean() + sent2_dist.kl_div().mean()
        gkl = sent1_kl + sent2_kl

        rec_logloss = logloss1 + logloss2

        para_logloss = logloss3 + logloss4

        if config.pratio:
            ploss = ploss1 + ploss2
        else:
            ploss = torch.zeros_like(gkl)

        loss = config.lratio * rec_logloss + \
            config.plratio * para_logloss + \
            vtemp * vkl + gtemp * gkl + \
            config.pratio * ploss

        if config.posratio:
            def tagging_loss(token_vecs, gold_tags):
                logits = self.fc(token_vecs)
                logits = logits.view(-1, logits.shape[-1])  # (N*T, VOCAB)
                gold = torch.from_numpy(gold_tags).view(-1)  # (N*T,)
                # Tag id 0 is ignored by the loss. Targets follow the logits'
                # device rather than being sent to CUDA unconditionally.
                criterion = nn.CrossEntropyLoss(ignore_index=0)
                return criterion(logits, gold.to(logits.device))

            # When pratio is set the syntax sample was already appended above.
            if not config.pratio:
                # Uses sentence 1's own length (it previously used
                # sentence 2's, which breaks when the lengths differ).
                s1_vecs = with_syntax(s1_vecs, sent1_syntax)
            # NOTE(review): only the second sentence's tagging loss is scaled
            # by posratio -- kept as in the original, confirm it is intended.
            dist += tagging_loss(s1_vecs, y[0])
            if not config.pratio:
                s2_vecs = with_syntax(s2_vecs, sent2_syntax)
            # this posratio is for low-source languages
            dist += config.posratio * tagging_loss(s2_vecs, y[1])

        loss += dist

        return loss, vkl, gkl, rec_logloss, para_logloss, ploss, dist

    def score(self, sent1, mask1, sent2, mask2):
        """Pairwise cosine similarity between two batches, as a numpy array.

        Each batch is run through BERT, the y-encoder and ``mean1``; the
        results are compared with
        ``model_utils.pariwise_cosine_similarity``. The module is switched
        to eval mode.
        """
        self.eval()
        sent1, mask1, sent2, mask2 = self.to_vars(sent1, mask1, sent2, mask2)
        if self.expe.config.use_cuda:
            sent1, mask1, sent2, mask2 = [
                t.cuda() for t in (sent1, mask1, sent2, mask2)]
        # The `volatile` flag set by to_vars has no effect on current
        # PyTorch, so disable autograd explicitly for inference.
        with torch.no_grad():
            sent1_vecs, _ = self.bert(sent1.long())
            sent2_vecs, _ = self.bert(sent2.long())
            _, yvecs1 = self.yencode(sent1_vecs[-1], mask1)
            _, yvecs2 = self.yencode(sent2_vecs[-1], mask2)
            sent1_vec = self.mean1(yvecs1)
            sent2_vec = self.mean1(yvecs2)
            similarity = model_utils.pariwise_cosine_similarity(
                sent1_vec, sent2_vec)

        return similarity.data.cpu().numpy()

    def pred(self, sent1, mask1, sent2, mask2):
        """Cosine similarity of the two sentences' ``mean1`` vectors.

        Per-token ``mean1`` outputs of the last BERT layer are
        L2-normalised, averaged over unmasked positions, and compared
        row by row. The module is switched to eval mode. Returns a numpy
        array.
        """
        self.eval()
        sent1, mask1, sent2, mask2 = self.to_vars(sent1, mask1, sent2, mask2)
        if self.expe.config.use_cuda:
            sent1, mask1, sent2, mask2 = [
                t.cuda() for t in (sent1, mask1, sent2, mask2)]

        def pooled_unit_mean(token_vecs, mask):
            per_token = self.mean1(token_vecs)
            per_token = per_token / per_token.norm(dim=-1, keepdim=True)
            summed = (per_token * mask.unsqueeze(-1)).sum(1)
            return summed / mask.sum(1, keepdim=True)

        # The `volatile` flag set by to_vars has no effect on current
        # PyTorch, so disable autograd explicitly for inference.
        with torch.no_grad():
            sent1_vecs, _ = self.bert(sent1.long())
            sent2_vecs, _ = self.bert(sent2.long())
            sent1_mean = pooled_unit_mean(sent1_vecs[-1], mask1)
            sent2_mean = pooled_unit_mean(sent2_vecs[-1], mask2)
            sent_cos_pos = F.cosine_similarity(sent1_mean, sent2_mean)

        return sent_cos_pos.data.cpu().numpy()

    def predz(self, sent1, mask1, sent2, mask2):
        """Cosine similarity of the two sentences' ``mean2`` vectors.

        Same procedure as ``pred`` but with the ``mean2`` head: per-token
        outputs are L2-normalised, averaged over unmasked positions, and
        compared row by row. The module is switched to eval mode. Returns
        a numpy array.
        """
        self.eval()
        sent1, mask1, sent2, mask2 = self.to_vars(sent1, mask1, sent2, mask2)
        if self.expe.config.use_cuda:
            sent1, mask1, sent2, mask2 = [
                t.cuda() for t in (sent1, mask1, sent2, mask2)]

        def pooled_unit_mean(token_vecs, mask):
            # NOTE(review): `sent2param` does not normalise the mean2 output
            # before pooling, whereas this does -- confirm which is intended.
            per_token = self.mean2(token_vecs)
            per_token = per_token / per_token.norm(dim=-1, keepdim=True)
            summed = (per_token * mask.unsqueeze(-1)).sum(1)
            return summed / mask.sum(1, keepdim=True)

        # The `volatile` flag set by to_vars has no effect on current
        # PyTorch, so disable autograd explicitly for inference.
        with torch.no_grad():
            sent1_vecs, _ = self.bert(sent1.long())
            sent2_vecs, _ = self.bert(sent2.long())
            sent1_mean2 = pooled_unit_mean(sent1_vecs[-1], mask1)
            sent2_mean2 = pooled_unit_mean(sent2_vecs[-1], mask2)
            sent_cos_pos = F.cosine_similarity(sent1_mean2, sent2_mean2)

        return sent_cos_pos.data.cpu().numpy()

    def pred_bert(self, sent1, mask1, sent2, mask2):
        """Cosine similarity computed directly on the last BERT layer.

        The masks are converted by ``to_vars`` but not used afterwards.
        The module is switched to eval mode. Returns a numpy array.
        """
        self.eval()
        sent1, mask1, sent2, mask2 = self.to_vars(sent1, mask1, sent2, mask2)
        if self.expe.config.use_cuda:
            sent1 = sent1.cuda()
            sent2 = sent2.cuda()
        # The `volatile` flag set by to_vars has no effect on current
        # PyTorch, so disable autograd explicitly for inference.
        with torch.no_grad():
            encoded_layers1, _ = self.bert(sent1.long())
            encoded_layers2, _ = self.bert(sent2.long())
            # NOTE(review): cosine_similarity defaults to dim=1; if these are
            # (batch, seq, hidden) tensors that is the sequence axis, not the
            # hidden axis, and no pooling or masking is applied -- confirm
            # this is intended.
            sent_cos_pos = F.cosine_similarity(
                encoded_layers1[-1], encoded_layers2[-1])
        return sent_cos_pos.data.cpu().numpy()
