# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class Biaffine(nn.Module):
    """Biaffine scorer between two sequences of feature vectors.

    For every pair of positions ``(i, j)`` and every output channel ``o``
    the layer computes ``x1[i]^T W_o x2[j]`` where ``x1`` / ``x2`` are the
    (optionally dropped-out and bias-augmented) rows of the two inputs.

    Args:
        inputs1_size: size of the last dimension of ``inputs1``.
        inputs2_size: size of the last dimension of ``inputs2``.
        output_size: number of output channels.
        input_bias1: if true, a constant ``1`` feature is appended to
            ``inputs1`` before the bilinear product.
        input_bias2: if true, a constant ``1`` feature is appended to
            ``inputs2`` before the bilinear product.
        output_bias: if true, a learnable per-channel bias (initialised to
            zero) is added to the output.
        dropout: dropout probability; when ``> 0`` one dropout mask per
            batch element and feature is drawn and shared by all positions
            of the sequence.
    """

    def __init__(self, inputs1_size, inputs2_size, output_size,
                 input_bias1=True, input_bias2=False, output_bias=False,
                 dropout=0.0):
        super().__init__()

        self.output_size = output_size

        self.input_bias1 = input_bias1
        self.input_bias2 = input_bias2
        self.output_bias = output_bias

        # One extra row / column block for each enabled input bias feature.
        rows = inputs1_size + int(input_bias1)
        cols = output_size * (inputs2_size + int(input_bias2))
        self.weights = nn.Parameter(torch.zeros(rows, cols))

        if output_bias:
            self.bias = nn.Parameter(torch.zeros(1, output_size, 1))

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

        self.reset_parameters()

    def extra_repr(self):
        return f'output_size={self.output_size}'

    def reset_parameters(self):
        """Re-initialise the bilinear weight with an orthogonal matrix."""
        init.orthogonal_(self.weights)

    def forward(self, inputs1, inputs2):
        """Score every pair of positions of the two inputs.

        Args:
            inputs1: tensor of shape ``(batch, length1, inputs1_size)``.
            inputs2: tensor of shape ``(batch, length2, inputs2_size)``.

        Returns:
            Tensor of shape ``(batch, length1, output_size, length2)``.
        """
        batch_size, length1, feature_size1 = inputs1.shape
        _, length2, feature_size2 = inputs2.shape

        if self.dropout is not None:
            # The mask has length 1 on the sequence axis, so the same
            # features are dropped at every position of a given sample.
            noise1 = self.dropout(inputs1.new_ones((batch_size, 1, feature_size1)))
            noise2 = self.dropout(inputs2.new_ones((batch_size, 1, feature_size2)))

            inputs1 = inputs1 * noise1
            inputs2 = inputs2 * noise2

        if self.input_bias1:
            # Append a constant 1 on the feature axis.
            inputs1 = F.pad(inputs1, [0, 1], value=1)
            feature_size1 += 1

        if self.input_bias2:
            inputs2 = F.pad(inputs2, [0, 1], value=1)
            feature_size2 += 1

        # b: batch_size; n1: length1; n2: length2; o: output_size;
        # d1: feature_size1; d2: feature_size2
        # (b n1 x d1) x (d1 x o d2) -> (b n1 x o d2)
        # reshape (not view) so that non-contiguous inputs are accepted.
        lin = torch.matmul(inputs1.reshape(-1, feature_size1), self.weights)
        # (b n1 x o d2) -> (b x n1 o x d2)
        lin = lin.view(batch_size, length1 * self.output_size, feature_size2)

        # (b x n1 o x d2) x (b x n2 x d2)T -> (b x n1 o x n2)
        bilin = torch.matmul(lin, inputs2.transpose(1, 2))
        bilin = bilin.view(batch_size, length1, self.output_size, length2)

        if self.output_bias:
            # (1, o, 1) broadcasts over batch, n1 and n2.
            bilin = bilin + self.bias

        return bilin


class DiagonalBiaffine(nn.Module):
    """Biaffine scorer whose bilinear term uses a diagonal weight per channel.

    For positions ``i`` / ``j`` and output channel ``o`` the score is::

        sum_k x1[i, k] * weights[o, k] * x2[j, k]
            + x1[i] . weights1[:, o]
            + x2[j] . weights2[:, o]
            + bias[o]

    Args:
        input_size: size of the last dimension of both inputs.
        output_size: number of output channels.
        dropout: dropout probability; when ``> 0`` one dropout mask per
            batch element and feature is drawn for each input and shared by
            all positions of the sequence.
    """

    def __init__(self, input_size, output_size, dropout=0):
        super().__init__()

        self.output_size = output_size
        self.weights = nn.Parameter(torch.zeros(output_size, input_size))
        self.weights1 = nn.Parameter(torch.zeros(input_size, output_size))
        self.weights2 = nn.Parameter(torch.zeros(input_size, output_size))
        self.bias = nn.Parameter(torch.zeros(1, output_size, 1))

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

        # Without this call every weight stays at its zero placeholder and
        # the layer initially outputs only the (zero) bias.
        self.reset_parameters()

    def extra_repr(self):
        return f'output_size={self.output_size}'

    def reset_parameters(self):
        """Re-initialise the three weight matrices (Xavier normal)."""
        init.xavier_normal_(self.weights)
        init.xavier_normal_(self.weights1)
        init.xavier_normal_(self.weights2)

    def forward(self, inputs1, inputs2):
        """Score every pair of positions of the two inputs.

        Args:
            inputs1: tensor of shape ``(batch, length1, input_size)``.
            inputs2: tensor of shape ``(batch, length2, input_size)``.

        Returns:
            Tensor of shape ``(batch, length1, output_size, length2)``.

        Raises:
            AssertionError: if the feature sizes of the inputs differ.
        """
        assert inputs1.size(-1) == inputs2.size(-1)
        batch_size, length1, input_size = inputs1.shape
        length2 = inputs2.size(1)

        if self.dropout is not None:
            # Two independent masks, each shared along the sequence axis.
            ones = inputs1.new_ones((batch_size, 1, input_size))
            noise1 = self.dropout(ones)
            noise2 = self.dropout(ones)

            inputs1 = inputs1 * noise1
            inputs2 = inputs2 * noise2

        # b: batch; n1: length1; n2: length2; i: input_size; o: output_size
        # (b, n1, i) x (i, o) -> (b, n1, o) -> (b, n1, o, 1)
        lin_inputs1 = torch.matmul(inputs1, self.weights1).unsqueeze(-1)

        # (b, n2, i) x (i, o) -> (b, n2, o) -> (b, 1, o, n2)
        lin_inputs2 = torch.matmul(inputs2, self.weights2).transpose(1, 2).unsqueeze(1)

        # (b, n1, 1, i) * (o, i) -> (b, n1, o, i) -> (b, n1 o, i)
        scaled = inputs1.unsqueeze(2) * self.weights
        scaled = scaled.reshape(batch_size, length1 * self.output_size, input_size)
        # (b, n1 o, i) x (b, i, n2) -> (b, n1 o, n2) -> (b, n1, o, n2)
        bilin = torch.bmm(scaled, inputs2.transpose(1, 2))
        bilin = bilin.view(batch_size, length1, self.output_size, length2)

        # Linear terms and bias broadcast over the missing axes.
        return bilin + lin_inputs1 + lin_inputs2 + self.bias
