import torch, operator
from models.model import list_to_tensor
from queue import PriorityQueue
from models.copy_generator import collapse_copy_scores

class BeamSearchNode(object):
    """One partial hypothesis in beam search, linked back to its parent.

    Nodes form a tree through ``prevNode``; following it from any node back to
    the root (whose ``prevNode`` is ``None``) recovers the hypothesis' tokens.
    """

    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate: decoder state associated with this hypothesis; stored as-is in ``h``.
        :param previousNode: parent node, or ``None`` for the root.
        :param wordId: token id of this node (stored as-is; callers may pass a tensor).
        :param logProb: log-probability score of the hypothesis so far, expected to be
            accumulated by the caller (parent ``logp`` + this step's log-probability).
        :param length: number of tokens in the hypothesis, counting the root.
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length

    def eval(self, alpha=1.0):
        """Return the length-normalised score of this hypothesis (higher is better).

        ``logp`` is divided by ``leng - 1`` (which excludes the root when ``leng``
        counts it); the ``1e-6`` keeps the root (``leng == 1``) from dividing by
        zero. ``alpha`` weights an optional shaping reward, currently always 0.
        """
        reward = 0
        # Add here a function for shaping a reward

        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward

    def __lt__(self, other):
        """Order nodes by length, shorter first.

        This only serves as a tiebreaker: queue entries shaped like
        ``(score, node)`` compare their nodes when scores are equal, and without
        an ordering ``PriorityQueue``/``heapq`` would raise ``TypeError``.
        """
        if not isinstance(other, BeamSearchNode):
            return NotImplemented
        return self.leng < other.leng

class BeamSearch(object):
    """Best-first beam search decoder for the citation-generation model.

    The model's citation, encoder and salience modules are run once for the
    whole batch; then, for each example, a priority queue of partial
    hypotheses is searched by repeatedly expanding the best-scoring one
    (by ``BeamSearchNode.eval``) with its ``beam_width`` most likely next tokens.
    """

    def __init__(self, model, gpu, max_length, tgt_vocab, SOS_token=3, EOS_token=2, PAD_token=1):
        """
        :param model: model exposing the ``citation``, ``encoder``, ``salience``,
            ``decoder`` and ``generator`` sub-modules used by ``decode``.
        :param gpu: truthy to move decoding tensors to CUDA.
        :param max_length: maximum hypothesis length, counting the start token;
            a hypothesis reaching it is treated as finished. A falsy value
            disables the limit.
        :param tgt_vocab: target vocabulary mapping with ``"itos"`` and ``"stoi"`` entries.
        :param SOS_token: id of the start-of-sequence token.
        :param EOS_token: id of the end-of-sequence token.
        :param PAD_token: id of the padding token (stored; not used in this class).
        """
        self.model = model
        self.gpu = gpu
        self.max_length = max_length
        self.tgt_vocab = tgt_vocab
        self.EOS_token = EOS_token
        self.SOS_token = SOS_token
        self.PAD_token = PAD_token

    def decode(self, batch, src_max_len, abs_max_sent_num, abstract_src_map, context_src_map, node_max_neighbor,
               beam_width=4, topk=1, max_qsize=2000):
        """Beam-search decode every example in ``batch``.

        :param batch: list of example dicts providing ``tgt_abstract``, ``src_abstract``,
            ``tgt_abs_end_idx``, ``src_abs_end_idx``, ``context_end_idx``, ``context``,
            ``abs_src_vocab``, ``context_src_vocab`` and ``citation``; the list is also
            passed whole to ``model.citation``.
        :param src_max_len: padding length for word-level inputs (forwarded to ``list_to_tensor``).
        :param abs_max_sent_num: padding length for sentence-end-index inputs.
        :param abstract_src_map: copy source map for abstracts, sliced per example as ``[:, idx, :]``.
        :param context_src_map: copy source map for contexts, sliced per example as ``[:, idx, :]``.
        :param node_max_neighbor: forwarded to ``model.citation``.
        :param beam_width: number of next tokens expanded from each hypothesis.
        :param topk: number of finished hypotheses to return per example.
        :param max_qsize: search budget; an example's search stops once the queue-growth
            counter exceeds this value.
        :return: ``(decoded_batch, tgt)``. ``decoded_batch[i]`` is a list of up to ``topk``
            token-id lists, best first, each starting with the start token. ``tgt[i]``
            is the ``citation`` field of example ``i``.
        """
        batch_size = len(batch)
        tgt_abstract = [example['tgt_abstract'] for example in batch]
        src_abstract = [example['src_abstract'] for example in batch]
        tgt_abs_end_idx = [example['tgt_abs_end_idx'] for example in batch]
        src_abs_end_idx = [example['src_abs_end_idx'] for example in batch]
        context_end_idx = [example['context_end_idx'] for example in batch]
        context = [example['context'] for example in batch]
        abs_src_vocab = [example['abs_src_vocab'] for example in batch]
        context_src_vocab = [example['context_src_vocab'] for example in batch]
        tgt = [example['citation'] for example in batch]
        citation_output = self.model.citation(batch, node_max_neighbor, is_train=False)
        del batch

        src_abstract, src_abstract_lengths = list_to_tensor(src_abstract, src_max_len, self.gpu)
        tgt_abstract, tgt_abstract_lengths = list_to_tensor(tgt_abstract, src_max_len, self.gpu)
        src_abs_end_idx, src_abs_sent_num = list_to_tensor(src_abs_end_idx, abs_max_sent_num, self.gpu,
                                                           is_word=False)
        tgt_abs_end_idx, tgt_abs_sent_num = list_to_tensor(tgt_abs_end_idx, abs_max_sent_num, self.gpu,
                                                           is_word=False)
        context_end_idx, _ = list_to_tensor(context_end_idx, abs_max_sent_num, self.gpu, is_word=False)
        context, context_lengths = list_to_tensor(context, src_max_len, self.gpu)

        context_encoder_final, context_memory_bank, context_sent_emb, \
        tgt_abs_encoder_final, tgt_abs_memory_bank, tgt_abs_sent_emb, \
        src_abs_encoder_final, src_abs_memory_bank, src_abs_sent_emb, \
        tgt_abs_max_len = self.model.encoder(context, tgt_abstract, src_abstract, context_lengths, tgt_abstract_lengths,
                                             src_abstract_lengths, context_end_idx, tgt_abs_end_idx, src_abs_end_idx)

        # Average the source-abstract sentence embeddings over dim 1.
        src_abs_doc_emb = torch.mean(src_abs_sent_emb, 1)
        del src_abstract, src_abstract_lengths, src_abs_end_idx
        if self.gpu:
            tgt_abs_sent_num = tgt_abs_sent_num.cuda()
            context_lengths = context_lengths.cuda()
            tgt_abstract_lengths = tgt_abstract_lengths.cuda()
        salience_scores, salience_dist = self.model.salience(src_abs_doc_emb, tgt_abs_sent_emb, tgt_abs_sent_num,
                                                             tgt_abs_end_idx, tgt_abs_max_len)
        del tgt_abs_sent_emb, src_abs_sent_emb, src_abs_doc_emb

        context_encoder_final = list(context_encoder_final)
        decoded_batch = []
        for idx in range(batch_size):
            # Slice example idx (dim 1) out of each encoder-final tensor, keeping a batch dim of 1.
            cur_encoder_final = tuple(final[:, idx, :].unsqueeze(1) for final in context_encoder_final)
            self.model.decoder.init_state(context[idx].unsqueeze(0), context_memory_bank[idx].unsqueeze(0),
                                          cur_encoder_final)
            # Per-example decoder inputs are the same at every step, so build them once.
            memory_args = (context_memory_bank[idx].unsqueeze(0),
                           tgt_abs_memory_bank[idx].unsqueeze(0),
                           context_lengths[idx].unsqueeze(0),
                           tgt_abstract_lengths[idx].unsqueeze(0),
                           salience_dist[idx].unsqueeze(0),
                           tgt_abs_end_idx[idx].unsqueeze(0),
                           citation_output[idx].unsqueeze(0))
            decoded_batch.append(self._search_one(
                memory_args,
                abstract_src_map[:, idx, :].unsqueeze(1),
                context_src_map[:, idx, :].unsqueeze(1),
                abs_src_vocab[idx], context_src_vocab[idx],
                beam_width, topk, max_qsize))

        return decoded_batch, tgt

    def _search_one(self, memory_args, abs_src_map, ctx_src_map, abs_vocab, ctx_vocab,
                    beam_width, topk, max_qsize):
        """Run best-first beam search for one example whose decoder state is already initialised.

        Returns up to ``topk`` token-id lists, ordered best first.
        """
        decoder = self.model.decoder
        dec_in = torch.tensor([self.SOS_token]).unsqueeze(1)
        if self.gpu:
            dec_in = dec_in.cuda()

        # Each node keeps a shallow copy of the decoder's state mapping as it was
        # right after the node's token was produced. Expanding a node restores
        # that snapshot, so the step continues this hypothesis rather than
        # whichever node happened to be expanded last.
        # NOTE(review): assumes ``decoder.state`` is a dict whose entries are
        # rebound (not mutated in place) by each decoder step — confirm.
        root = BeamSearchNode(dict(decoder.state), None, dec_in, 0, 1)
        nodes = PriorityQueue()
        # Entries are (score, push_order, node). push_order breaks score ties, so
        # the queue never has to compare BeamSearchNode objects.
        push_order = 0
        nodes.put((-root.eval(), push_order, root))
        qsize = 1
        endnodes = []

        # Give up once the search budget is exhausted.
        while qsize <= max_qsize and not nodes.empty():
            score, _, node = nodes.get()
            at_eos = node.wordid.item() == self.EOS_token and node.prevNode is not None
            at_max_len = bool(self.max_length) and node.leng >= self.max_length
            if at_eos or at_max_len:
                endnodes.append((score, node))
                if len(endnodes) >= topk:
                    break
                continue

            decoder.state.update(node.h)
            dec_out, attns = decoder(node.wordid.unsqueeze(0), *memory_args)
            child_state = dict(decoder.state)
            out_scores = self._next_token_scores(dec_out, attns, abs_src_map, ctx_src_map, abs_vocab, ctx_vocab)
            log_prob, indexes = torch.topk(out_scores.log(), beam_width)
            for k in range(beam_width):
                child = BeamSearchNode(child_state, node, indexes[k].view(1, -1),
                                       node.logp + log_prob[k].item(), node.leng + 1)
                push_order += 1
                nodes.put((-child.eval(), push_order, child))
            qsize += beam_width - 1

        if not endnodes:
            # Nothing finished within the budget: fall back to the best partial
            # hypotheses still queued (never more than are available, so get() cannot block).
            for _ in range(min(topk, nodes.qsize())):
                score, _, node = nodes.get()
                endnodes.append((score, node))

        return [self._backtrack(node) for _, node in sorted(endnodes, key=operator.itemgetter(0))]

    def _next_token_scores(self, dec_out, attns, abs_src_map, ctx_src_map, abs_vocab, ctx_vocab):
        """Return one step's scores over the target vocabulary for a single example.

        The generator's abstract-copy and context-copy outputs are merged into
        the vocabulary distribution with ``collapse_copy_scores``, and the result
        is truncated to ``len(tgt_vocab["stoi"])`` entries. The values are
        presumably probabilities, since the caller takes ``.log()`` of them.
        """
        out_prob, abs_copy_prob, context_copy_prob = self.model.generator(
            self._bottle(dec_out), self._bottle(attns.get("copy")), self._bottle(attns.get("tgt_std")),
            self._bottle(attns.get("context_std")), abs_src_map, ctx_src_map
        )
        out_scores = collapse_copy_scores(torch.cat([out_prob, abs_copy_prob], 1), self.tgt_vocab,
                                          abs_vocab, len(self.tgt_vocab["itos"]), 0)
        assert out_scores.size(1) == out_prob.size(1) + abs_copy_prob.size(1)
        out_scores = collapse_copy_scores(torch.cat([out_scores, context_copy_prob], 1), self.tgt_vocab,
                                          ctx_vocab, out_prob.size(1) + abs_copy_prob.size(1), 0)
        return out_scores.squeeze(0)[:len(self.tgt_vocab["stoi"])]

    @staticmethod
    def _backtrack(node):
        """Return the token ids from the root to ``node`` inclusive, root first."""
        tokens = []
        while node is not None:
            tokens.append(node.wordid.item())
            node = node.prevNode
        return tokens[::-1]

    def _bottle(self, _v):
        """Reshape ``_v`` to 2-D with ``_v.size(2)`` columns.

        For a 3-D ``(a, b, d)`` tensor this yields ``(a * b, d)``.
        """
        return _v.view(-1, _v.size(2))