import math

import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F


class Linear(nn.Module):
    """Linear layer whose features are split into a semantic and a syntax part.

    Both the input and the output are the concatenation ``[semantic, syntax]``
    (semantic features first). Semantic outputs depend on the whole input;
    syntax outputs depend only on the syntax part of the input. This is done by
    composing a block weight matrix::

        [ sem_weight                       ]   (out_sem rows, in_sem + in_syn cols)
        [ zeros(out_syn, in_sem) | syn_weight ]   (out_syn rows)

    Args:
        input_semantic_size: number of leading semantic input features.
        input_syntax_size: number of trailing syntax input features.
        output_semantic_size: number of leading semantic output features.
        output_syntax_size: number of trailing syntax output features.
    """

    def __init__(self, input_semantic_size, input_syntax_size,
                 output_semantic_size, output_syntax_size):
        super().__init__()

        self.input_semantic_size = input_semantic_size
        self.input_syntax_size = input_syntax_size
        self.output_semantic_size = output_semantic_size
        self.output_syntax_size = output_syntax_size
        self.input_size = input_semantic_size + input_syntax_size
        self.output_size = output_semantic_size + output_syntax_size

        # Storage is allocated uninitialised here and filled by reset_parameters().
        self.sem_weight = nn.Parameter(torch.empty(output_semantic_size, self.input_size))
        self.syn_weight = nn.Parameter(torch.empty(output_syntax_size, input_syntax_size))
        self.sem_bias = nn.Parameter(torch.empty(output_semantic_size))
        self.syn_bias = nn.Parameter(torch.empty(output_syntax_size))

        # Composed (concatenated / zero-padded) weight and bias, set by generate_weight().
        self.weight = None
        self.bias = None

        self.reset_parameters()

    @staticmethod
    def _init_weight_and_bias(weight, bias, fan_in):
        """Initialise ``weight`` and ``bias`` in place, following torch.nn.Linear's scheme.

        ``fan_in`` is the number of input features feeding ``weight``. When it
        is 0 there is no valid uniform bound, so the bias is zero-filled rather
        than left holding uninitialised memory (or dividing by zero).
        """
        if weight.numel() > 0:
            init.kaiming_uniform_(weight, a=math.sqrt(5))
        if fan_in > 0:
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(bias, -bound, bound)
        else:
            init.zeros_(bias)

    def reset_parameters(self):
        """(Re)initialise the semantic and syntax weights and biases in place."""
        self._init_weight_and_bias(self.sem_weight, self.sem_bias, self.input_size)
        self._init_weight_and_bias(self.syn_weight, self.syn_bias, self.input_syntax_size)

    def reset_semantic_parameter(self, input_semantic_size, output_semantic_size):
        """Replace the semantic weight/bias with freshly initialised ones of a new size.

        The syntax parameters are kept as they are. The new semantic parameters
        use the device and dtype of the previous semantic weight. They are new
        ``nn.Parameter`` objects, so an optimizer built from the old parameters
        will not update them until it is rebuilt.

        Args:
            input_semantic_size: new number of semantic input features.
            output_semantic_size: new number of semantic output features.
        """
        factory_kwargs = {'device': self.sem_weight.device, 'dtype': self.sem_weight.dtype}

        self.input_semantic_size = input_semantic_size
        self.output_semantic_size = output_semantic_size
        self.input_size = self.input_semantic_size + self.input_syntax_size
        self.output_size = self.output_semantic_size + self.output_syntax_size

        self.sem_weight = nn.Parameter(
            torch.empty(output_semantic_size, self.input_size, **factory_kwargs))
        self.sem_bias = nn.Parameter(torch.empty(output_semantic_size, **factory_kwargs))
        self._init_weight_and_bias(self.sem_weight, self.sem_bias, self.input_size)

        # Any previously composed weight/bias no longer matches the new shapes.
        self.weight = None
        self.bias = None

    def generate_weight(self):
        """Compose the full block weight and bias and store them on the module.

        The syntax weight is left-padded with zeros over the semantic input
        columns, so syntax outputs ignore semantic inputs. The results are
        written to ``self.weight`` and ``self.bias``.
        """
        padded_syn_weight = F.pad(self.syn_weight, (self.input_semantic_size, 0), value=0.)
        self.weight = torch.cat([self.sem_weight, padded_syn_weight], dim=0)
        self.bias = torch.cat([self.sem_bias, self.syn_bias], dim=0)

    def forward(self, input):
        """Apply the block linear map ``input @ weight.T + bias``.

        The composed weight is rebuilt from the current parameters on every
        call. A cached ``self.weight`` would otherwise be ``None`` before the
        first ``generate_weight()`` and would go stale after optimizer steps or
        ``.to()``/``.double()``. It would also keep an autograd graph that is
        freed after the first backward pass.

        Args:
            input: tensor whose last dimension is ``input_size``
                (semantic features first, then syntax features).

        Returns:
            Tensor whose last dimension is ``output_size``.
        """
        self.generate_weight()
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        # Both the semantic and the syntax bias always exist, so bias is always True.
        return ('input_size={}, output_size={}, bias=True, '
                'input_semantic_size={}, input_syntax_size={}, '
                'output_semantic_size={}, output_syntax_size={}').format(
            self.input_size, self.output_size,
            self.input_semantic_size, self.input_syntax_size,
            self.output_semantic_size, self.output_syntax_size,
        )

    def semantic_parameters(self):
        """Return the semantic weight and bias parameters."""
        return [self.sem_weight, self.sem_bias]

    def syntax_parameters(self):
        """Return the syntax weight and bias parameters."""
        return [self.syn_weight, self.syn_bias]