import torch.nn as nn
import torch
from utils.global_attention import GlobalAttention
from models.model import list_to_tensor
import torch.nn.functional as F

def neighbor_list_to_tensor(neighbor, node_max_neighbor, gpu, bow_feature=None):
    """Pack a list of neighbor feature rows into a zero-padded float tensor.

    Args:
        neighbor: sequence of equal-length numeric rows, one per neighbor.
        node_max_neighbor: number of rows in the returned tensor; rows at
            index ``len(neighbor)`` and beyond are left as zeros.
        gpu: if truthy, the result is moved with ``.cuda()``.
        bow_feature: width of each row. Defaults to ``len(neighbor[0])``;
            it must be given explicitly when ``neighbor`` is empty.

    Returns:
        A ``torch.float`` tensor of shape ``(node_max_neighbor, bow_feature)``.

    Raises:
        ValueError: if ``neighbor`` is empty and ``bow_feature`` is None, or
            if ``neighbor`` has more than ``node_max_neighbor`` rows.
    """
    num_rows = len(neighbor)
    if bow_feature is None:
        if num_rows == 0:
            # The row width cannot be inferred from an empty list.
            raise ValueError("bow_feature must be given when neighbor is empty")
        bow_feature = len(neighbor[0])
    if num_rows > node_max_neighbor:
        raise ValueError(
            "got %d neighbors but node_max_neighbor is %d"
            % (num_rows, node_max_neighbor)
        )
    neighbor_vec = torch.zeros(node_max_neighbor, bow_feature, dtype=torch.float)
    # Skip the copy for an empty list: a (0,)-shaped tensor cannot be
    # assigned into a (0, bow_feature) slice.
    if num_rows > 0:
        neighbor_vec[:num_rows, :] = torch.FloatTensor(neighbor)
    if gpu:
        neighbor_vec = neighbor_vec.cuda()
    return neighbor_vec


class CitationModel(nn.Module):
    """Encode (src, tgt) pairs from bag-of-words features and score them.

    Abstract BoW vectors and neighbor BoW vectors are all passed through the
    same dropout + ``linear_in`` projection (``bow_size`` -> ``node_dim``).
    Each projected abstract is then given to ``GlobalAttention`` together with
    its projected, zero-padded neighbors; ``memory_lengths`` carries the true
    neighbor counts (presumably so the padding is masked -- confirm in
    utils.global_attention). The src and tgt results are concatenated and,
    when ``is_train`` is set, ``linear_out`` maps them to one score each.

    Args:
        model_opt: mapping that provides "bow_size", "node_dim" and
            "citation_dropout".
        gpu: if truthy, tensors built from the batch are moved with ``.cuda()``.
    """

    def __init__(self, model_opt, gpu):
        # nn.Module must be initialised before any attribute assignment.
        super(CitationModel, self).__init__()
        self.in_feat = int(model_opt["bow_size"])
        self.out_feat = int(model_opt["node_dim"])
        self.dropout = float(model_opt['citation_dropout'])
        self.attention = GlobalAttention(
            self.out_feat, coverage=False,
            attn_type="general", attn_func="softmax"
        )
        # Shared projection for both abstracts and neighbors.
        self.linear_in = nn.Linear(self.in_feat, self.out_feat, bias=True)
        # Scores the concatenated [src, tgt] representation.
        self.linear_out = nn.Linear(self.out_feat * 2, 1, bias=True)
        self.gpu = gpu

    def _project(self, features):
        """Apply dropout (active only in training mode) then ``linear_in``."""
        features = F.dropout(features, self.dropout, training=self.training)
        return self.linear_in(features)

    def forward(self, batch, node_max_neighbor, is_train=False):
        """Encode a batch of (src, tgt) examples.

        Args:
            batch: iterable of dicts, each providing 'src_abstract_bow',
                'tgt_abstract_bow', 'src_neighbor', 'tgt_neighbor' and
                'relation'. Neighbor entries are lists of BoW rows.
            node_max_neighbor: number of rows every neighbor list is padded to.
            is_train: when truthy, also compute scores and return labels.

        Returns:
            ``output`` if ``is_train`` is falsy, otherwise
            ``(output, score, tgt_label)``. ``score`` is ``linear_out(output)``
            with dim 1 squeezed, and ``tgt_label`` is a float tensor of the
            'relation' values.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        tgt_abstract, src_abstract = [], []
        tgt_neighbor, src_neighbor = [], []
        tgt_nei_num, src_nei_num = [], []
        tgt_label = []
        for batch_data in batch:
            tgt_abstract.append(batch_data['tgt_abstract_bow'])
            src_abstract.append(batch_data['src_abstract_bow'])
            # Only pad here; dropout + projection run once per side below
            # instead of once per example.
            tgt_neighbor.append(neighbor_list_to_tensor(
                batch_data['tgt_neighbor'], node_max_neighbor, self.gpu))
            src_neighbor.append(neighbor_list_to_tensor(
                batch_data['src_neighbor'], node_max_neighbor, self.gpu))
            tgt_nei_num.append(len(batch_data['tgt_neighbor']))
            src_nei_num.append(len(batch_data['src_neighbor']))
            tgt_label.append(batch_data['relation'])
        if not tgt_label:
            raise ValueError("CitationModel.forward received an empty batch")

        # shape: [batch, max_neighbor, node_dim]
        tgt_neighbor = self._project(torch.stack(tgt_neighbor, dim=0))
        src_neighbor = self._project(torch.stack(src_neighbor, dim=0))

        src_abstract, _ = list_to_tensor(src_abstract, self.in_feat, self.gpu)  # shape: [batch, bow_dim]
        src_abstract = self._project(src_abstract.float())
        tgt_abstract, _ = list_to_tensor(tgt_abstract, self.in_feat, self.gpu)  # shape: [batch, bow_dim]
        tgt_abstract = self._project(tgt_abstract.float())

        tgt_label = torch.FloatTensor(tgt_label)
        tgt_nei_num = torch.LongTensor(tgt_nei_num)
        src_nei_num = torch.LongTensor(src_nei_num)
        if self.gpu:
            tgt_nei_num = tgt_nei_num.cuda()
            src_nei_num = src_nei_num.cuda()
            tgt_label = tgt_label.cuda()

        # NOTE(review): the SECOND value returned by GlobalAttention is used as
        # each side's representation; confirm in utils.global_attention which
        # output (context vector vs. alignment weights) that is, since
        # linear_out expects 2 * node_dim features.
        _, src_abs = self.attention(src_abstract, src_neighbor, memory_lengths=src_nei_num)
        _, tgt_abs = self.attention(tgt_abstract, tgt_neighbor, memory_lengths=tgt_nei_num)

        output = torch.cat((src_abs.squeeze(1), tgt_abs.squeeze(1)), dim=1)
        if not is_train:
            return output
        score = self.linear_out(output).squeeze(1)
        return output, score, tgt_label
