import torch.nn as nn
import torch
import torch.nn.functional as F

def list_to_tensor(batch, max_len, gpu, is_word=True):
    """Pad a batch of integer sequences into a ``[len(batch), max_len]`` LongTensor.

    Args:
        batch: list of integer sequences (token ids, sentence end indices, ...).
        max_len: number of columns of the padded tensor.
        gpu: if truthy, the padded tensor (but NOT the lengths) is moved to CUDA.
        is_word: pad with 1 when True, with 0 otherwise.
            NOTE(review): 1 is presumably the vocabulary padding id -- confirm.

    Returns:
        (seq_tensor, seq_lengths): the padded LongTensor and a CPU LongTensor
        holding the original length of every sequence.

    A sequence that cannot be copied into its row (e.g. it is longer than
    ``max_len``) is reported on stdout and its row is left as all padding;
    its entry in ``seq_lengths`` still holds the original length.
    """
    seq_lengths = torch.LongTensor([len(seq) for seq in batch])
    pad_value = 1 if is_word else 0
    seq_tensor = torch.full((len(batch), max_len), pad_value, dtype=torch.long)
    for idx, seq in enumerate(batch):
        seqlen = len(seq)
        try:
            seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
        except (RuntimeError, TypeError, ValueError, OverflowError) as err:
            # Best-effort: keep building the batch, but say which row failed and why.
            # (A bare ``except:`` here would also swallow KeyboardInterrupt/SystemExit.)
            print('list_to_tensor: could not copy sequence {} (length {}, max_len {}): {}'.format(
                idx, seqlen, max_len, err))
    if gpu:
        seq_tensor = seq_tensor.cuda()
    return seq_tensor, seq_lengths

def rouge_list_to_tensor(batch, max_len, gpu, eps=1e-20):
    """Pad a batch of per-sentence score lists into a float32 tensor.

    Every list is written left-aligned into its own row of a
    ``[len(batch), max_len]`` zero tensor. ``eps`` is then added to every cell,
    padding included, so no entry is exactly zero. A list longer than
    ``max_len`` makes torch raise on the row assignment. The result is moved
    to CUDA when ``gpu`` is truthy.
    """
    padded = torch.zeros(len(batch), max_len, dtype=torch.float)
    for row, scores in enumerate(batch):
        padded[row, :len(scores)] = torch.FloatTensor(scores)
    shifted = padded + eps
    return shifted.cuda() if gpu else shifted

class SalienceModel(nn.Module):
    """Predict a salience score for every sentence of a target abstract.

    ``encoder`` embeds the citing context and both abstracts. The source
    abstract's sentence embeddings are averaged over dim 1 into a document
    vector, which ``salience`` uses together with the target-abstract sentence
    embeddings to score the target sentences.
    """

    def __init__(self, encoder, salience, gpu):
        super(SalienceModel, self).__init__()
        self.encoder, self.salience, self.gpu = encoder, salience, gpu

    def forward(self, batch, src_max_len, abs_max_sent_num):
        """Score target-abstract sentences for a batch of examples.

        Args:
            batch: list of example dicts with keys 'tgt_abstract',
                'src_abstract', 'tgt_abs_end_idx', 'src_abs_end_idx',
                'context_end_idx', 'context' and 'ROUGE'.
            src_max_len: padded length of the context / abstract token ids.
            abs_max_sent_num: padded number of sentence end indices and
                ROUGE scores per example.

        Returns:
            (salience_scores, tgt_salience): the first output of
            ``self.salience`` and the padded gold ROUGE tensor
            ``[batch, abs_max_sent_num]`` (eps-shifted).
        """
        wanted = ('tgt_abstract', 'src_abstract', 'tgt_abs_end_idx', 'src_abs_end_idx',
                  'context_end_idx', 'context', 'ROUGE')
        raw = dict((key, []) for key in wanted)
        for example in batch:
            for key in wanted:
                raw[key].append(example[key])
        del batch

        def pad_ids(seqs):
            # Token ids padded to [batch, src_max_len].
            padded, lengths = list_to_tensor(seqs, src_max_len, self.gpu)
            assert padded.size()[1] == src_max_len
            return padded, lengths

        def pad_end_idx(seqs):
            # Sentence end indices zero-padded to [batch, abs_max_sent_num].
            return list_to_tensor(seqs, abs_max_sent_num, self.gpu, is_word=False)

        src_abstract, src_abstract_lengths = pad_ids(raw['src_abstract'])
        tgt_abstract, tgt_abstract_lengths = pad_ids(raw['tgt_abstract'])
        src_abs_end_idx, _ = pad_end_idx(raw['src_abs_end_idx'])
        tgt_abs_end_idx, tgt_abs_sent_num = pad_end_idx(raw['tgt_abs_end_idx'])
        context_end_idx, _ = pad_end_idx(raw['context_end_idx'])
        tgt_salience = rouge_list_to_tensor(raw['ROUGE'], abs_max_sent_num, self.gpu)
        context, context_lengths = pad_ids(raw['context'])
        del raw

        # The encoder yields 10 values; only the sentence embeddings of both
        # abstracts and the target-abstract max length are used here.
        (_, _, _,
         _, _, tgt_abs_sent_emb,
         _, _, src_abs_sent_emb,
         tgt_abs_max_len) = self.encoder(context, tgt_abstract, src_abstract, context_lengths,
                                         tgt_abstract_lengths, src_abstract_lengths,
                                         context_end_idx, tgt_abs_end_idx, src_abs_end_idx)
        src_abs_doc_emb = src_abs_sent_emb.mean(1)

        del (tgt_abstract, src_abstract, src_abstract_lengths, src_abs_end_idx,
             context, context_lengths, tgt_abstract_lengths)

        # list_to_tensor keeps lengths on the CPU; the salience head gets them on the GPU.
        sent_num = tgt_abs_sent_num.cuda() if self.gpu else tgt_abs_sent_num
        salience_scores, _ = self.salience(src_abs_doc_emb, tgt_abs_sent_emb, sent_num,
                                           tgt_abs_end_idx, tgt_abs_max_len)
        return salience_scores, tgt_salience

class NMTModel(nn.Module):
    """Citation-text generator with auxiliary salience and citation-function heads.

    Adapted from OpenNMT's core trainable ``NMTModel``:
    - an encoder embeds the citing context and both abstracts;
    - a salience module scores the target-abstract sentences;
    - a decoder generates the citation text.
    When ``is_citation_func`` is set, a small MLP (``linear1`` -> tanh ->
    ``linear2``) also produces 4 class scores for the citation function.

    Args:
      encoder: called as ``encoder(context, tgt_abstract, src_abstract, ...)``
        and must return 10 values. It must expose ``text_encoder.hidden_size``
        when ``is_citation_func`` is True.
      decoder: callable with an ``init_state(...)`` method. It must expose
        ``hidden_size`` when ``is_citation_func`` is True.
      salience: called as ``salience(src_doc_emb, tgt_sent_emb, tgt_sent_num,
        tgt_end_idx, tgt_max_len)`` and must return two values.
      citation: called as ``citation(batch, node_max_neighbor, is_train=False)``;
        its output is passed straight to the decoder.
      gpu (bool): move the padded batch tensors to CUDA.
      is_citation_func (bool): build and run the citation-function classifier.
    """

    def __init__(self, encoder, decoder, salience, citation, gpu, is_citation_func=True):
        super(NMTModel, self).__init__()
        self.encoder = encoder
        self.citation = citation
        self.decoder = decoder
        self.salience = salience
        self.gpu = gpu
        self.is_citation_func = is_citation_func
        if is_citation_func:
            # Input = last decoder output concatenated with the mean context embedding.
            self.linear1 = nn.Linear(decoder.hidden_size + encoder.text_encoder.hidden_size*2, 512)
            self.linear2 = nn.Linear(512, 4)

    def _predict_citation_function(self, dec_out, context_sent_emb):
        """Return citation-function scores from ``dec_out[-1]`` and the mean of
        ``context_sent_emb`` over dim 1 (raw ``linear2`` outputs, no softmax)."""
        context_doc_emb = torch.mean(context_sent_emb, 1)
        features = torch.cat([dec_out[-1], context_doc_emb], -1)
        # F.dropout defaults to training=True; tie it to the module mode so that
        # model.eval() really disables dropout for this head.
        features = F.dropout(features, 0.3, training=self.training)
        hidden = torch.tanh(self.linear1(features))
        return self.linear2(hidden)

    def forward(self, batch, src_max_len, tgt_max_len, abs_max_sent_num, node_max_neighbor, bptt=False):
        """Run encoder, salience module, decoder and the optional citation-function head.

        Args:
            batch (list[dict]): examples. Each must provide 'tgt_abstract',
                'src_abstract', 'tgt_abs_end_idx', 'src_abs_end_idx',
                'context_end_idx', 'context', 'citation', 'ROUGE' and
                'citation_function'. The list itself is also passed to the
                citation module.
            src_max_len (int): padded length of the context / abstract token ids.
            tgt_max_len (int): padded length of the target citation.
            abs_max_sent_num (int): padded number of sentence end indices (and
                ROUGE scores) per abstract / context.
            node_max_neighbor: forwarded unchanged to the citation module.
            bptt (bool): when exactly ``False``, ``decoder.init_state`` is
                called with the context encoding before decoding; otherwise it
                is not called.
        Returns:
            tuple: ``(dec_out, attns, tgt, salience_scores, tgt_salience,
            citation_func_pred, citation_function)``, where:
            - ``tgt`` is the padded target of shape ``[tgt_max_len, batch, 1]``
              (the decoder is fed ``tgt[:-1]``);
            - ``tgt_salience`` is the padded gold ROUGE tensor
              ``[batch, abs_max_sent_num]``;
            - ``citation_func_pred`` holds the classifier scores, or ``None``
              when ``is_citation_func`` is False;
            - ``citation_function`` holds the gold labels as a LongTensor, or
              the raw list when ``is_citation_func`` is False.
        """
        tgt_abstract = []
        src_abstract = []
        tgt_abs_end_idx = []
        src_abs_end_idx = []
        context_end_idx = []
        tgt = []
        context = []
        tgt_salience = []
        citation_function = []
        for batch_data in batch:
            tgt_abstract.append(batch_data['tgt_abstract'])
            src_abstract.append(batch_data['src_abstract'])
            tgt_abs_end_idx.append(batch_data['tgt_abs_end_idx'])
            src_abs_end_idx.append(batch_data['src_abs_end_idx'])
            context_end_idx.append(batch_data['context_end_idx'])
            context.append(batch_data['context'])
            tgt.append(batch_data['citation'])
            tgt_salience.append(batch_data['ROUGE'])
            citation_function.append(batch_data['citation_function'])
        # NOTE(review): is_train is hard-coded to False even when this model is in
        # training mode -- confirm that is intended for the citation module.
        citation_output = self.citation(batch, node_max_neighbor, is_train=False)
        del batch

        # Token ids padded to [batch, src_max_len].
        src_abstract, src_abstract_lengths = list_to_tensor(src_abstract, src_max_len, self.gpu)
        assert src_abstract.size()[1] == src_max_len
        tgt_abstract, tgt_abstract_lengths = list_to_tensor(tgt_abstract, src_max_len, self.gpu)
        assert tgt_abstract.size()[1] == src_max_len
        # Sentence end indices zero-padded to [batch, abs_max_sent_num].
        src_abs_end_idx, _ = list_to_tensor(src_abs_end_idx, abs_max_sent_num, self.gpu, is_word=False)
        tgt_abs_end_idx, tgt_abs_sent_num = list_to_tensor(tgt_abs_end_idx, abs_max_sent_num, self.gpu,
                                                           is_word=False)
        context_end_idx, _ = list_to_tensor(context_end_idx, abs_max_sent_num, self.gpu, is_word=False)
        context, context_lengths = list_to_tensor(context, src_max_len, self.gpu)  # [batch, src_max_len]
        assert context.size()[1] == src_max_len
        tgt_salience = rouge_list_to_tensor(tgt_salience, abs_max_sent_num,
                                            self.gpu)  # [batch, abs_max_sent_num]
        tgt, _ = list_to_tensor(tgt, tgt_max_len, self.gpu)  # [batch, tgt_max_len]
        tgt = tgt.transpose(1, 0).unsqueeze(2)  # [tgt_max_len, batch, 1]
        dec_in = tgt[:-1]  # decoder input excludes the last target position

        (context_encoder_final, context_memory_bank, context_sent_emb,
         _, tgt_abs_memory_bank, tgt_abs_sent_emb,
         _, _, src_abs_sent_emb,
         tgt_abs_max_len) = self.encoder(context, tgt_abstract, src_abstract, context_lengths,
                                         tgt_abstract_lengths, src_abstract_lengths,
                                         context_end_idx, tgt_abs_end_idx, src_abs_end_idx)

        # NOTE(review): plain mean over dim 1; if that dim contains padded sentence
        # slots they are averaged in as well -- confirm intended.
        src_abs_doc_emb = torch.mean(src_abs_sent_emb, 1)
        del tgt_abstract, src_abstract, src_abstract_lengths, src_abs_end_idx

        if self.gpu:
            # list_to_tensor keeps lengths on the CPU (the encoder above received
            # them there); the salience module and decoder get CUDA copies.
            tgt_abs_sent_num = tgt_abs_sent_num.cuda()
            context_lengths = context_lengths.cuda()
            tgt_abstract_lengths = tgt_abstract_lengths.cuda()
        salience_scores, salience_dist = self.salience(src_abs_doc_emb, tgt_abs_sent_emb, tgt_abs_sent_num,
                                                       tgt_abs_end_idx, tgt_abs_max_len)
        del tgt_abs_sent_emb, src_abs_sent_emb, src_abs_doc_emb

        if bptt is False:
            # Initialise the decoder from the context encoder's final state.
            self.decoder.init_state(context, context_memory_bank, context_encoder_final)
        del context

        dec_out, attns = self.decoder(dec_in, context_memory_bank, tgt_abs_memory_bank,
                                      context_lengths, tgt_abstract_lengths, salience_dist,
                                      tgt_abs_end_idx, citation_output)
        citation_func_pred = None
        if self.is_citation_func:
            citation_function = torch.LongTensor(citation_function)
            if self.gpu:
                citation_function = citation_function.cuda()
            citation_func_pred = self._predict_citation_function(dec_out, context_sent_emb)

        return dec_out, attns, tgt, salience_scores, tgt_salience, citation_func_pred, citation_function