import torch
from torch import nn
from models.common import GradReverse


class Fudan(nn.Module):
    """Multi-task CNN text classifier with one shared and ``n`` private encoders.

    For ``n = len(cfg.data['tasks'])`` tasks the module holds ``n + 1``
    convolutional encoders (``self.convs``) and ``n + 1`` linear heads
    (``self.fcs``):

    * indices ``0 .. n-1`` are the per-task encoders and their classifiers,
      which consume the concatenation of the shared and the task features;
    * index ``n`` is the shared encoder and the task discriminator, which
      predicts the task id from the shared features after they pass through
      ``GradReverse`` (imported from ``models.common``; presumably a
      gradient-reversal layer -- confirm against that module).
    """

    def __init__(self, cfg, embeddings):
        """Build the encoders and heads.

        Args:
            cfg: configuration object. Reads ``cfg.model`` (``embed_size``,
                ``num_channels``, ``kernel_size`` (only element 0 is used),
                ``max_sen_len``, ``output_size``) and ``cfg.data['tasks']``.
            embeddings: 2-D tensor of pre-trained embedding weights; it is
                installed as a frozen (``requires_grad=False``) parameter.
        """
        super(Fudan, self).__init__()
        self.cfg = cfg
        self.config = cfg.model
        self.embeddings = nn.Embedding(*embeddings.shape)
        self.embeddings.weight = nn.Parameter(embeddings, requires_grad=False)
        self.n = len(cfg.data['tasks'])
        # NOTE(review): the rate is hard-coded; self.config.dropout is not read.
        self.dropout = nn.Dropout(0.5)
        kernel_size = self.config.kernel_size[0]
        convs = []
        fcs = []
        # Layers are created interleaved (conv i, then fc i) so that the order
        # of random initialisation is unchanged.
        for i in range(self.n + 1):
            convs.append(nn.Sequential(
                nn.Conv1d(in_channels=self.config.embed_size,
                          out_channels=self.config.num_channels,
                          kernel_size=kernel_size),
                nn.ReLU(),
                # The pooling window equals the conv output length for inputs
                # of exactly max_sen_len tokens, i.e. max-over-time pooling.
                nn.MaxPool1d(self.config.max_sen_len - kernel_size + 1)
            ))
            if i == self.n:
                # Discriminator: shared features -> task id.
                fcs.append(nn.Linear(self.config.num_channels, self.n))
            else:
                # Task classifier: [shared ; private] features -> label.
                fcs.append(nn.Linear(self.config.num_channels * 2, self.config.output_size))
        self.convs = nn.ModuleList(convs)
        self.fcs = nn.ModuleList(fcs)

    @staticmethod
    def _orthogonality_loss(task_out, share_out, weight=0.01):
        """Return ``weight`` times the mean squared cross-correlation.

        Both feature matrices are centred over the batch dimension and each
        row is scaled to unit L2 norm before the ``task^T @ share`` product.
        ``nn.functional.normalize`` guards the division with an epsilon, so a
        zero row (e.g. when the subset holds a single sample and centring
        removes everything) contributes 0 instead of NaN.
        """
        centred_task = task_out - task_out.mean(0)
        centred_share = share_out - share_out.mean(0)
        unit_task = nn.functional.normalize(centred_task, p=2, dim=1)
        unit_share = nn.functional.normalize(centred_share, p=2, dim=1)
        corr = unit_task.t().mm(unit_share)
        # A mean of squares is never negative, so no clamping is required.
        return (corr * corr).mean() * weight

    def forward(self, x, y, t_label):
        """Compute the training losses for a mixed-task batch.

        Args:
            x: token-id tensor of shape ``(max_sen_len, batch_size)``.
            y: class labels, one per sample.
            t_label: task index of every sample, values in ``0 .. n-1``.

        Returns:
            ``(c_losses, d_loss, r_losses)`` where ``c_losses`` and
            ``r_losses`` are lists with one entry per task (classification
            loss and orthogonality loss) and ``d_loss`` is the discriminator
            cross-entropy. A task without samples in the batch gets a zero
            loss that is still attached to the graph.
        """
        criterion = nn.CrossEntropyLoss()
        # (max_sen_len, batch, embed) -> (batch, embed, max_sen_len) for Conv1d.
        embedded = self.embeddings(x).permute(1, 2, 0)
        embedded = self.dropout(embedded)
        share_out = self.convs[self.n](embedded).squeeze(2)
        # Discriminator loss on the shared features.
        t_ = self.fcs[self.n](self.dropout(GradReverse.apply(share_out)))
        d_loss = criterion(t_, t_label)
        c_losses = []
        r_losses = []
        for i in range(self.n):
            is_t = t_label == i
            if not bool(is_t.any()):
                # Mean-reduced losses over an empty subset would be NaN.
                # Use a zero that keeps a grad_fn so callers may still call
                # backward() on individual entries.
                zero = share_out.sum() * 0.0
                c_losses.append(zero)
                r_losses.append(zero)
                continue
            # Private features of task i and the matching shared features.
            task_out = self.convs[i](embedded[is_t]).squeeze(2)
            share_out_task = share_out[is_t]
            feature = torch.cat([share_out_task, task_out], 1)
            # Classification loss.
            y_ = self.fcs[i](self.dropout(feature))
            c_losses.append(criterion(y_, y[is_t]))
            # Orthogonality loss between private and shared features.
            r_losses.append(self._orthogonality_loss(task_out, share_out_task))
        return c_losses, d_loss, r_losses

    def val(self, x, t_label):
        """Return the logits of task ``t_label`` for the whole batch ``x``.

        Args:
            x: token-id tensor of shape ``(max_sen_len, batch_size)``.
            t_label: a single task index used to pick the private encoder and
                classifier (every sample is treated as belonging to it).

        Note:
            Dropout is applied to the features only while the module is in
            training mode; call ``eval()`` first for deterministic output.
        """
        embedded = self.embeddings(x).permute(1, 2, 0)
        share_out = self.convs[self.n](embedded).squeeze(2)
        task_out = self.convs[t_label](embedded).squeeze(2)
        feature = torch.cat([share_out, task_out], 1)
        y_ = self.fcs[t_label](self.dropout(feature))
        return y_
