import pykp.utils.io as io
from inference.beam import Beam
from inference.beam import GNMTGlobalScorer

import torch

EPS = 1.0e-8  # numerical floor added to probabilities before torch.log to avoid log(0)

class SetGenerator(object):
    """Decode keyphrase sets from a set-prediction model, greedily or with beam search.

    Based on the calls below, the wrapped model exposes ``encoder`` (with a
    ``pos_embed`` embedding for the beam-search path) and a ``decoder`` with
    ``max_kp_num``, ``max_kp_len``, ``vocab_size``, ``init_state``,
    ``forward_seg`` and a call returning ``(decoder_dist, attn_dist)``.
    """

    def __init__(self, model, opt):
        """
        :param model: the model to decode with; each inference call puts it in eval mode
        :param opt: options object providing ``beam_size``, ``n_best`` and ``gpuid``
        """
        self.model = model
        self.beam_size = opt.beam_size
        self.n_best = opt.n_best
        self.cuda = opt.gpuid > -1
        # All scorer arguments are None: no penalty of any kind is applied.
        self.global_scorer = GNMTGlobalScorer(None, None, None, None)

    @classmethod
    def from_opt(cls, model, opt):
        """Alternate constructor; equivalent to ``SetGenerator(model, opt)``."""
        return cls(model, opt)

    @torch.no_grad()
    def inference(self, src, src_lens, src_oov, src_mask, oov_lists, word2idx):
        """Greedy (argmax) decoding of all ``max_kp_num`` keyphrase slots in parallel.

        :param src: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], with oov words replaced by unk idx
        :param src_lens: a list containing the length of src sequences for each batch, with len=batch
        :param src_oov: a LongTensor containing the word indices of source sentences, [batch, src_seq_len], contains the index of oov words (used by copy)
        :param src_mask: a FloatTensor, [batch, src_seq_len]
        :param oov_lists: list of oov words (idx2word) for each batch, len=batch
        :param word2idx: a dictionary
        :return: dict with keys "predictions", "attention" and "decoder_scores". Each
            is a list of length batch; entry b is respectively the predicted token ids
            [1, max_kp_num * max_kp_len] (ids >= vocab_size presumably index into
            oov_lists[b]), the attention [1, max_kp_num * max_kp_len, -1] and the
            per-step max decoder score [1, max_kp_num * max_kp_len].
        """
        # NOTE(review): unlike inference_with_beam_search, src is not truncated to the
        # encoder's positional-embedding size here.
        self.model.eval()
        batch_size = src.size(0)
        decoder = self.model.decoder
        max_kp_num = decoder.max_kp_num
        max_kp_len = decoder.max_kp_len
        vocab_size = decoder.vocab_size
        unk_idx = word2idx[io.UNK_WORD]

        # Encoding
        memory_bank = self.model.encoder(src, src_lens, src_mask)
        state = decoder.init_state(memory_bank, src_mask)
        control_embed = decoder.forward_seg(state)

        max_num_oov = max(len(oov) for oov in oov_lists)  # max number of oov for each batch
        attn_dict_list = []
        decoder_score_list = []
        # Column 0 holds BOS for every slot; columns 1..max_kp_len receive the predictions.
        output_tokens = src.new_zeros(batch_size, max_kp_num, max_kp_len + 1)
        output_tokens[:, :, 0] = word2idx[io.BOS_WORD]
        for t in range(1, max_kp_len + 1):
            decoder_inputs = output_tokens[:, :, :t]
            # Ids outside the vocabulary cannot be embedded; feed them back as UNK.
            decoder_inputs = decoder_inputs.masked_fill(decoder_inputs.gt(vocab_size - 1), unk_idx)

            decoder_dist, attn_dist = decoder(decoder_inputs, state, src_oov, max_num_oov, control_embed)
            step_scores, tokens = decoder_dist.max(-1)
            attn_dict_list.append(attn_dist.reshape(batch_size, max_kp_num, 1, -1))
            decoder_score_list.append(step_scores.reshape(batch_size, max_kp_num, 1))
            output_tokens[:, :, t] = tokens

        flat_len = max_kp_num * max_kp_len
        # Drop the BOS column and flatten the slots: [batch_size, 1, max_kp_num * max_kp_len]
        output_tokens = output_tokens[:, :, 1:].reshape(batch_size, flat_len)[:, None]
        # [batch, max_kp_num, max_kp_len, -1] -> [batch_size, 1, max_kp_num * max_kp_len, -1]
        attn_dicts = torch.cat(attn_dict_list, -2).reshape(batch_size, flat_len, -1)[:, None]
        # [batch_size, 1, max_kp_num * max_kp_len]
        decoder_scores = torch.cat(decoder_score_list, -1).reshape(batch_size, flat_len)[:, None]

        # Extract sentences
        return {
            "predictions": [output_tokens[b] for b in range(batch_size)],
            "attention": [attn_dicts[b] for b in range(batch_size)],
            "decoder_scores": [decoder_scores[b] for b in range(batch_size)],
        }

    @staticmethod
    def _reorder_history(decoder_inputs, beam_list, beam_size):
        """Re-index the token history so every hypothesis continues its true parent.

        :param decoder_inputs: LongTensor [beam_size * batch, max_kp_num, t], laid out
            beam-major (row = hyp_idx * batch + batch_idx), matching ``beam_list``
            where Beam j = batch_idx * max_kp_num + slot
        :param beam_list: Beam objects that have just been advanced for this step
        :param beam_size: number of hypotheses kept per Beam
        :return: tensor of the same shape, where hypothesis k of Beam j holds the history
            of the hypothesis it was extended from
        """
        num_beams = len(beam_list)
        history = decoder_inputs.reshape(beam_size, num_beams, -1)
        # Assumes Beam.get_current_origin() returns the back-pointers of the last
        # advance() (a [beam_size] index tensor). origins[k, j] is the previous
        # hypothesis that Beam j's new k-th hypothesis extends.
        origins = torch.stack([b.get_current_origin() for b in beam_list]).t()
        origins = origins.long().to(history.device)
        history = history.gather(0, origins.unsqueeze(-1).expand_as(history))
        return history.reshape(decoder_inputs.shape)

    @torch.no_grad()
    def inference_with_beam_search(self, src, src_lens, src_oov, src_mask, oov_lists, word2idx):
        """Beam-search decoding with one independent Beam per (batch item, keyphrase slot).

        Parameters are as in :meth:`inference`. Sources longer than the encoder's
        positional-embedding table are truncated first.

        :return: dict with keys "predictions", "decoder_scores" and "attention". Each is a
            list of length batch. Entry b concatenates, slot by slot, the n_best
            hypotheses / scores / attentions from ``Beam.sort_finished`` / ``Beam.get_hyp``.
            n_best is ``self.n_best``, or ``self.beam_size`` when ``n_best`` is falsy.
        """
        self.model.eval()
        batch_size = src.size(0)
        decoder = self.model.decoder
        max_kp_num = decoder.max_kp_num
        max_kp_len = decoder.max_kp_len
        vocab_size = decoder.vocab_size
        beam_size = self.beam_size
        num_beams = batch_size * max_kp_num
        unk_idx = word2idx[io.UNK_WORD]

        # Truncate everything to what the positional embedding can index. src_lens is
        # clipped BEFORE encoding so the lengths agree with the truncated src.
        max_pos = self.model.encoder.pos_embed.num_embeddings - 1
        src = src[:, :max_pos]
        src_mask = src_mask[:, :max_pos]
        src_oov = src_oov[:, :max_pos]
        src_lens = [min(length, max_pos) for length in src_lens]
        oov_lists = [oov[:max_pos] for oov in oov_lists]
        memory_bank = self.model.encoder(src, src_lens, src_mask)

        # Replicate per hypothesis, beam-major: [beam_size * batch, max_src_len, memory_bank_size]
        memory_bank = memory_bank.repeat(beam_size, 1, 1)
        src_mask = src_mask.repeat(beam_size, 1)
        src_oov = src_oov.repeat(beam_size, 1)

        state = decoder.init_state(memory_bank, src_mask)
        control_embed = decoder.forward_seg(state)

        # beam_list[b * max_kp_num + k] decodes keyphrase slot k of batch item b.
        beam_list = [Beam(beam_size, n_best=self.n_best, cuda=self.cuda, global_scorer=self.global_scorer,
                          pad=word2idx[io.PAD_WORD], eos=word2idx[io.EOS_WORD], bos=word2idx[io.BOS_WORD],
                          block_ngram_repeat=3, exclusion_tokens=set())
                     for _ in range(num_beams)]

        max_num_oov = max(len(oov) for oov in oov_lists)

        decoder_inputs = None  # token history, [beam_size * batch, max_kp_num, t]
        for t in range(1, max_kp_len + 1):
            # [num_beams, beam_size] -> [beam_size, num_beams], beam-major like memory_bank.
            current = torch.stack([b.get_current_tokens() for b in beam_list]).t()
            current = current.masked_fill(current.gt(vocab_size - 1), unk_idx)
            current = current.reshape(beam_size * batch_size, max_kp_num, 1)
            if decoder_inputs is None:
                decoder_inputs = current
            else:
                decoder_inputs = torch.cat([decoder_inputs, current], -1)

            decoder_dist, attn_dist = decoder(decoder_inputs, state, src_oov, max_num_oov, control_embed)
            # (beam_size * batch, max_kp_num, -1) --> (beam_size, batch * max_kp_num, -1)
            log_decoder_dist = torch.log(decoder_dist.squeeze(1) + EPS).reshape(beam_size, num_beams, -1)
            attn_dist = attn_dist.squeeze(1).reshape(beam_size, num_beams, -1)

            for beam_idx, beam in enumerate(beam_list):
                src_len = src_lens[beam_idx // max_kp_num]
                beam.advance(log_decoder_dist[:, beam_idx], attn_dist[:, beam_idx, :src_len])

            # A hypothesis may now extend a different parent than the one previously in
            # its slot. Realign the history before the next step uses it as the prefix.
            if t < max_kp_len:
                decoder_inputs = self._reorder_history(decoder_inputs, beam_list, beam_size)

        n_best = self.n_best or self.beam_size
        ret = {"predictions": [], "decoder_scores": [], "attention": []}
        for batch_idx in range(batch_size):
            hyps, attn, scores_list = [], [], []
            for beam in beam_list[batch_idx * max_kp_num:(batch_idx + 1) * max_kp_num]:
                scores, ks = beam.sort_finished(minimum=n_best)
                for i, (times, k) in enumerate(ks[:n_best]):
                    hyp, att = beam.get_hyp(times, k)
                    hyps.append(hyp)
                    attn.append(att)
                    scores_list.append(scores[i])
            ret["predictions"].append(hyps)
            ret["decoder_scores"].append(scores_list)
            ret["attention"].append(attn)

        return ret
        