import torch
import torch.nn as nn
import torch.nn.functional as F

class MetricNN_V2(nn.Module):
    """Pairwise scoring network with a dedicated MLP encoder per modality.

    Each caption and each label vector is first projected to
    ``hidden_sizes[0]`` by a linear layer and then passed through its own
    Linear+ReLU stack whose widths follow ``hidden_sizes``.  The four
    resulting embeddings (two captions, two labels) are concatenated along
    dimension 1, projected back to ``hidden_sizes[0]`` and pushed through a
    third Linear+ReLU stack followed by a final linear scoring layer.

    The visual branch is not built in this variant: ``input_visual_size``
    and the ``v1`` argument of :meth:`forward` are accepted but unused, which
    keeps the signature identical to the other models in this file.

    Args:
        input_cptn_size: Feature size of a caption vector.
        input_label_size: Feature size of a label vector.
        input_visual_size: Unused.
        hidden_sizes: Sequence of layer widths; ``hidden_sizes[0]`` is the
            projection width and consecutive pairs define the MLP layers.
        output_size: Feature size of the returned score.
    """

    def __init__(self, input_cptn_size, input_label_size, input_visual_size,
                    hidden_sizes, output_size):
        super(MetricNN_V2, self).__init__()
        first_width = hidden_sizes[0]
        last_width = hidden_sizes[-1]

        # Submodules are created in a fixed order (caption, label, summary,
        # score) so parameter registration and random initialisation follow
        # the same sequence on every construction.
        self.linear_caption = nn.Linear(input_cptn_size, first_width, bias=True)
        self.caption_mlp = self._make_stack(hidden_sizes)

        self.linear_label = nn.Linear(input_label_size, first_width, bias=True)
        self.label_mlp = self._make_stack(hidden_sizes)

        # Two caption embeddings + two label embeddings are concatenated.
        self.linear_summarize = nn.Linear(4 * last_width, first_width, bias=True)
        self.score_mlp = self._make_stack(hidden_sizes)
        self.linear_score = nn.Linear(last_width, output_size, bias=True)

    @staticmethod
    def _make_stack(widths):
        """Return a Sequential of Linear+ReLU pairs for consecutive widths."""
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Linear(fan_in, fan_out, bias=True))
            blocks.append(nn.ReLU(inplace=True))
        return nn.Sequential(*blocks)

    def forward(self, c1, l1, c2, l2, v1):
        """Score a (caption, label) pair against another (caption, label) pair.

        Inputs are 2-D ``(batch, features)`` tensors; ``v1`` is ignored.
        Returns a ``(batch, output_size)`` tensor.
        """
        # Both captions share one encoder, both labels share another.
        caption_feats = [self.caption_mlp(self.linear_caption(cap))
                         for cap in (c1, c2)]
        label_feats = [self.label_mlp(self.linear_label(lbl))
                       for lbl in (l1, l2)]

        fused = torch.cat(caption_feats + label_feats, dim=1)
        hidden = self.score_mlp(self.linear_summarize(fused))
        return self.linear_score(hidden)


class MetricNN(nn.Module):
    """Pairwise scoring network over caption, label and visual features.

    Two captions, two labels and one visual vector are each projected to
    ``hidden_sizes[0]`` by a linear layer (captions share one layer, labels
    share another).  The five projections are concatenated along dimension
    1, reduced back to ``hidden_sizes[0]``, passed through a Linear+ReLU
    stack whose widths follow ``hidden_sizes`` and finally mapped to
    ``output_size``.

    Args:
        input_cptn_size: Feature size of a caption vector.
        input_label_size: Feature size of a label vector.
        input_visual_size: Feature size of the visual vector.
        hidden_sizes: Sequence of layer widths; consecutive pairs define the
            layers of the scoring MLP.
        output_size: Feature size of the returned score.
    """

    def __init__(self, input_cptn_size, input_label_size, input_visual_size,
                    hidden_sizes, output_size):
        super(MetricNN, self).__init__()
        embed_width = hidden_sizes[0]

        # One projection per modality.
        self.linear_caption = nn.Linear(input_cptn_size, embed_width, bias=True)
        self.linear_label = nn.Linear(input_label_size, embed_width, bias=True)
        self.linear_visual = nn.Linear(input_visual_size, embed_width, bias=True)

        # c1, c2, l1, l2 and v1 embeddings are concatenated -> 5 * width.
        self.linear_summarize = nn.Linear(5 * embed_width, embed_width, bias=True)

        stack = []
        for fan_in, fan_out in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            stack.append(nn.Linear(fan_in, fan_out, bias=True))
            stack.append(nn.ReLU(inplace=True))
        self.score_mlp = nn.Sequential(*stack)

        self.linear_score = nn.Linear(hidden_sizes[-1], output_size, bias=True)

    def forward(self, c1, l1, c2, l2, v1):
        """Return the score for two (caption, label) pairs and one visual input.

        Inputs are 2-D ``(batch, features)`` tensors; the result has shape
        ``(batch, output_size)``.
        """
        embeddings = [
            self.linear_caption(c1),
            self.linear_caption(c2),
            self.linear_label(l1),
            self.linear_label(l2),
            self.linear_visual(v1),
        ]
        fused = self.linear_summarize(torch.cat(embeddings, dim=1))
        return self.linear_score(self.score_mlp(fused))

'''
###############################################################################

                            ABLATION EXPERIMENTS

###############################################################################
'''
class MetricNN_No_Visual(nn.Module):
    """Ablation of the metric network that drops the visual input.

    Two captions and two labels are each projected to ``hidden_sizes[0]``
    (captions share one linear layer, labels share another).  The four
    projections are concatenated along dimension 1, reduced back to
    ``hidden_sizes[0]``, passed through a Linear+ReLU stack whose widths
    follow ``hidden_sizes`` and finally mapped to ``output_size``.

    ``input_visual_size`` and the ``v1`` argument of :meth:`forward` are
    accepted but unused, which keeps the signature identical to the other
    models in this file.

    Args:
        input_cptn_size: Feature size of a caption vector.
        input_label_size: Feature size of a label vector.
        input_visual_size: Unused.
        hidden_sizes: Sequence of layer widths; consecutive pairs define the
            layers of the scoring MLP.
        output_size: Feature size of the returned score.
    """

    def __init__(self, input_cptn_size, input_label_size, input_visual_size,
                    hidden_sizes, output_size):
        super(MetricNN_No_Visual, self).__init__()
        embed_width = hidden_sizes[0]

        self.linear_caption = nn.Linear(input_cptn_size, embed_width, bias=True)
        self.linear_label = nn.Linear(input_label_size, embed_width, bias=True)

        # c1, c2, l1 and l2 embeddings are concatenated -> 4 * width.
        self.linear_summarize = nn.Linear(4 * embed_width, embed_width, bias=True)

        stack = []
        for fan_in, fan_out in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            stack.append(nn.Linear(fan_in, fan_out, bias=True))
            stack.append(nn.ReLU(inplace=True))
        self.score_mlp = nn.Sequential(*stack)

        self.linear_score = nn.Linear(hidden_sizes[-1], output_size, bias=True)

    def forward(self, c1, l1, c2, l2, v1):
        """Return the score for two (caption, label) pairs; ``v1`` is ignored.

        Inputs are 2-D ``(batch, features)`` tensors; the result has shape
        ``(batch, output_size)``.
        """
        embeddings = [
            self.linear_caption(c1),
            self.linear_caption(c2),
            self.linear_label(l1),
            self.linear_label(l2),
        ]
        fused = self.linear_summarize(torch.cat(embeddings, dim=1))
        return self.linear_score(self.score_mlp(fused))

class MetricNN_No_Label(nn.Module):
    """Ablation of the metric network that drops the label inputs.

    Two captions (sharing one linear layer) and one visual vector are each
    projected to ``hidden_sizes[0]``.  The three projections are
    concatenated along dimension 1, reduced back to ``hidden_sizes[0]``,
    passed through a Linear+ReLU stack whose widths follow ``hidden_sizes``
    and finally mapped to ``output_size``.

    ``input_label_size`` and the ``l1`` / ``l2`` arguments of
    :meth:`forward` are accepted but unused, which keeps the signature
    identical to the other models in this file.

    Args:
        input_cptn_size: Feature size of a caption vector.
        input_label_size: Unused.
        input_visual_size: Feature size of the visual vector.
        hidden_sizes: Sequence of layer widths; consecutive pairs define the
            layers of the scoring MLP.
        output_size: Feature size of the returned score.
    """

    def __init__(self, input_cptn_size, input_label_size, input_visual_size,
                    hidden_sizes, output_size):
        super(MetricNN_No_Label, self).__init__()
        embed_width = hidden_sizes[0]

        self.linear_caption = nn.Linear(input_cptn_size, embed_width, bias=True)
        self.linear_visual = nn.Linear(input_visual_size, embed_width, bias=True)

        # c1, c2 and v1 embeddings are concatenated -> 3 * width.
        self.linear_summarize = nn.Linear(3 * embed_width, embed_width, bias=True)

        stack = []
        for fan_in, fan_out in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            stack.append(nn.Linear(fan_in, fan_out, bias=True))
            stack.append(nn.ReLU(inplace=True))
        self.score_mlp = nn.Sequential(*stack)

        self.linear_score = nn.Linear(hidden_sizes[-1], output_size, bias=True)

    def forward(self, c1, l1, c2, l2, v1):
        """Return the score for two captions and one visual input.

        ``l1`` and ``l2`` are ignored.  Inputs are 2-D ``(batch, features)``
        tensors; the result has shape ``(batch, output_size)``.
        """
        embeddings = [
            self.linear_caption(c1),
            self.linear_caption(c2),
            self.linear_visual(v1),
        ]
        fused = self.linear_summarize(torch.cat(embeddings, dim=1))
        return self.linear_score(self.score_mlp(fused))

class MetricNN_No_Label_No_Visual(nn.Module):
    """Ablation of the metric network that uses the two captions only.

    Both captions are projected to ``hidden_sizes[0]`` by one shared linear
    layer.  The two projections are concatenated along dimension 1, reduced
    back to ``hidden_sizes[0]``, passed through a Linear+ReLU stack whose
    widths follow ``hidden_sizes`` and finally mapped to ``output_size``.

    ``input_label_size``, ``input_visual_size`` and the ``l1`` / ``l2`` /
    ``v1`` arguments of :meth:`forward` are accepted but unused, which keeps
    the signature identical to the other models in this file.

    Args:
        input_cptn_size: Feature size of a caption vector.
        input_label_size: Unused.
        input_visual_size: Unused.
        hidden_sizes: Sequence of layer widths; consecutive pairs define the
            layers of the scoring MLP.
        output_size: Feature size of the returned score.
    """

    def __init__(self, input_cptn_size, input_label_size, input_visual_size,
                    hidden_sizes, output_size):
        super(MetricNN_No_Label_No_Visual, self).__init__()
        embed_width = hidden_sizes[0]

        self.linear_caption = nn.Linear(input_cptn_size, embed_width, bias=True)

        # c1 and c2 embeddings are concatenated -> 2 * width.
        self.linear_summarize = nn.Linear(2 * embed_width, embed_width, bias=True)

        stack = []
        for fan_in, fan_out in zip(hidden_sizes[:-1], hidden_sizes[1:]):
            stack.append(nn.Linear(fan_in, fan_out, bias=True))
            stack.append(nn.ReLU(inplace=True))
        self.score_mlp = nn.Sequential(*stack)

        self.linear_score = nn.Linear(hidden_sizes[-1], output_size, bias=True)

    def forward(self, c1, l1, c2, l2, v1):
        """Return the score for two captions; ``l1``, ``l2``, ``v1`` are ignored.

        Inputs are 2-D ``(batch, features)`` tensors; the result has shape
        ``(batch, output_size)``.
        """
        first = self.linear_caption(c1)
        second = self.linear_caption(c2)
        fused = self.linear_summarize(torch.cat([first, second], dim=1))
        return self.linear_score(self.score_mlp(fused))


def create_net(input_cptn_size, input_label_size, input_visual_size,
                hidden_sizes, output_size, lr, model_cls=None):
    """Build a metric network together with its training components.

    Args:
        input_cptn_size: Feature size of a caption vector.
        input_label_size: Feature size of a label vector.
        input_visual_size: Feature size of the visual vector.
        hidden_sizes: Sequence of hidden layer widths passed to the model.
        output_size: Feature size of the model output.
        lr: Initial learning rate for the Adam optimizer.
        model_cls: Model class to instantiate.  It is called with the five
            size arguments above, in that order.  Defaults to ``MetricNN``;
            the other models in this file (``MetricNN_V2``,
            ``MetricNN_No_Visual``, ``MetricNN_No_Label``,
            ``MetricNN_No_Label_No_Visual``) share that constructor
            signature and can be passed instead.

    Returns:
        Tuple ``(net, criterion, optimizer, lr_scheduler)``.
    """
    # Resolve the default lazily so the function definition itself does not
    # depend on MetricNN being defined first.
    if model_cls is None:
        model_cls = MetricNN

    # 1. Create the model
    net = model_cls(input_cptn_size, input_label_size, input_visual_size,
                    hidden_sizes, output_size)

    # 2. Weight initialization (applied recursively to every submodule)
    net.apply(xavier_init_weights)

    # 3. MSE loss as criterion
    criterion = nn.MSELoss()

    # 4. Optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    # 5. LR scheduler: multiply the learning rate by 0.01 every 10
    #    scheduler steps.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10,
                                                   gamma=0.01)

    return net, criterion, optimizer, lr_scheduler


def xavier_init_weights(layer):
    """Initialize the weight of a linear layer with Xavier uniform values.

    Designed to be passed to ``nn.Module.apply``, which calls it once per
    submodule.  Modules that are not linear layers are left untouched, and
    the bias of a linear layer is not modified.

    Args:
        layer: Any ``nn.Module``; only instances of ``nn.Linear`` (including
            subclasses) are re-initialized.
    """
    # isinstance rather than an exact type comparison, so subclasses of
    # nn.Linear are initialized as well instead of being silently skipped.
    if isinstance(layer, nn.Linear):
        nn.init.xavier_uniform_(layer.weight)
