import torch
from torch import nn
from torch.nn import Module, init, Parameter
import torch.nn.functional as F
import math

from linear import Linear


class Cell(Module):
    """Gated cell that combines two inputs over separate semantic/syntax channels.

    The last dimension of every input is laid out as
    ``[semantic_size | syntax_size]``.  A two-layer network (``input_t``, built
    from the project's ``Linear`` layer) maps the concatenated, dropped-out
    inputs to, per channel, three gate pre-activations plus one candidate
    vector.  Gates go through ``LayerNorm -> Sigmoid`` and the channel output is
    ``LayerNorm(vg * vi + hg * hi + cg * candidate)``.
    """

    def __init__(self, semantic_size, syntax_size, dropout):
        """Build the cell.

        Args:
            semantic_size: width of the semantic channel.
            syntax_size: width of the syntax channel.
            dropout: dropout probability, used both on the inputs and inside
                ``input_t``.
        """
        super(Cell, self).__init__()

        self.semantic_size = semantic_size
        self.syntax_size = syntax_size
        self.hidden_size = semantic_size + syntax_size
        self.cell_hidden_size = 4 * self.hidden_size

        # Per channel: 2x width in (vi and hi), 4x width out (3 gates + 1 candidate).
        self.input_t = nn.Sequential(
            Linear(semantic_size * 2, syntax_size * 2, semantic_size * 4, syntax_size * 4),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            Linear(semantic_size * 4, syntax_size * 4, semantic_size * 4, syntax_size * 4),
        )

        self.semantic_gating = self._make_gating(semantic_size)
        self.syntax_gating = self._make_gating(syntax_size)

        self.semantic_act = self._make_act(semantic_size)
        self.syntax_act = self._make_act(syntax_size)

        self.drop = nn.Dropout(dropout)

    @staticmethod
    def _make_gating(channel_size):
        """Return the ``LayerNorm -> Sigmoid`` stack for the 3 gates of one channel."""
        return nn.Sequential(
            nn.LayerNorm(channel_size * 3),
            nn.Sigmoid()
        )

    @staticmethod
    def _make_act(channel_size):
        """Return the parameter-free output normalisation for one channel."""
        return nn.LayerNorm(channel_size, elementwise_affine=False)

    def generate_weight(self):
        """Delegate ``generate_weight`` to both ``Linear`` layers of ``input_t``."""
        self.input_t[0].generate_weight()
        self.input_t[3].generate_weight()

    def semantic_parameters(self):
        """Return the semantic parameters of ``input_t`` plus the semantic gating parameters."""
        return self.input_t[0].semantic_parameters() + self.input_t[3].semantic_parameters() \
               + list(self.semantic_gating.parameters())

    def syntax_parameters(self):
        """Return the syntax parameters of ``input_t`` plus the syntax gating parameters."""
        return self.input_t[0].syntax_parameters() + self.input_t[3].syntax_parameters() \
               + list(self.syntax_gating.parameters())

    def reset_semantic_parameter(self, semantic_size):
        """Reset the semantic channel, possibly to a new width.

        Updates the size bookkeeping, forwards the reset to both ``Linear``
        layers, and rebuilds the semantic normalisation layers when their
        width no longer matches ``semantic_size`` (otherwise ``forward`` would
        fail on a ``normalized_shape`` mismatch).  When the width is unchanged
        the existing gating/normalisation modules are kept untouched.

        Args:
            semantic_size: new width of the semantic channel.
        """
        self.semantic_size = semantic_size
        self.hidden_size = semantic_size + self.syntax_size
        self.cell_hidden_size = 4 * self.hidden_size

        self.input_t[0].reset_semantic_parameter(semantic_size * 2, semantic_size * 4)
        self.input_t[3].reset_semantic_parameter(semantic_size * 4, semantic_size * 4)

        old_norm = self.semantic_gating[0]
        if tuple(old_norm.normalized_shape) != (semantic_size * 3,):
            # Keep the replacement on the same device/dtype as the layer it replaces.
            # NOTE(review): the new gating parameters are fresh tensors; callers that
            # cached semantic_parameters() (e.g. in an optimizer) must re-query them.
            self.semantic_gating = self._make_gating(semantic_size).to(
                device=old_norm.weight.device, dtype=old_norm.weight.dtype
            )
        if tuple(self.semantic_act.normalized_shape) != (semantic_size,):
            self.semantic_act = self._make_act(semantic_size)

    def gated_sum(self, gates, cell, vi, hi):
        """Return ``vg * vi + hg * hi + cg * cell``.

        ``gates`` is chunked into three equal parts ``(vg, hg, cg)`` along the
        last dimension.
        """
        vg, hg, cg = gates.chunk(3, dim=-1)
        return vg * vi + hg * hi + cg * cell

    def forward(self, vi, hi, drop_vi=True):
        """Combine ``vi`` and ``hi`` into a single vector per position.

        Args:
            vi: tensor whose last dimension is ``semantic_size + syntax_size``.
            hi: tensor with the same last-dimension layout as ``vi``.
            drop_vi: if true, the dropped-out ``vi`` is used in the gated sum;
                otherwise the raw ``vi`` is used.  ``hi`` always enters the
                gated sum without dropout.

        Returns:
            The normalised semantic and syntax outputs concatenated on the last
            dimension; only one of them if the other channel has size zero;
            ``None`` if both sizes are zero.
        """
        channel_sizes = (self.semantic_size, self.syntax_size)

        # Dropout order (vi first, then hi) is kept so RNG consumption is unchanged.
        vi_drop, hi_drop = self.drop(vi), self.drop(hi)
        vi_drop_sem, vi_drop_syn = vi_drop.split(channel_sizes, dim=-1)
        hi_drop_sem, hi_drop_syn = hi_drop.split(channel_sizes, dim=-1)

        # Regroup so each channel's vi/hi parts are contiguous: [sem | syn].
        packed = torch.cat([vi_drop_sem, hi_drop_sem, vi_drop_syn, hi_drop_syn], dim=-1)

        g_sem, cell_sem, g_syn, cell_syn = self.input_t(packed).split(
            (self.semantic_size * 3, self.semantic_size, self.syntax_size * 3, self.syntax_size),
            dim=-1
        )
        g_sem = self.semantic_gating(g_sem)
        g_syn = self.syntax_gating(g_syn)

        if drop_vi:
            # Reuse the split already computed for the dropped-out vi.
            vi_sem, vi_syn = vi_drop_sem, vi_drop_syn
        else:
            vi_sem, vi_syn = vi.split(channel_sizes, dim=-1)
        hi_sem, hi_syn = hi.split(channel_sizes, dim=-1)

        outputs = []
        if self.semantic_size > 0:
            outputs.append(self.semantic_act(self.gated_sum(g_sem, cell_sem, vi_sem, hi_sem)))
        if self.syntax_size > 0:
            outputs.append(self.syntax_act(self.gated_sum(g_syn, cell_syn, vi_syn, hi_syn)))

        if not outputs:
            # Both channels are empty: nothing to return (matches the previous fall-through).
            return None
        if len(outputs) == 1:
            return outputs[0]
        return torch.cat(outputs, dim=-1)