import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from locked_dropout import LockedDropout
from embed_regularize import embedded_dropout
from cell import Cell
from distribution import Distribution


class OrderedMemoryRecurrent(nn.Module):
    """Ordered-memory recurrent core: a step-wise encoder plus a slot decoder.

    Every slot vector is the concatenation ``[semantic | syntax]`` of length
    ``semantic_size + syntax_size``. The encoder (``distribution`` + ``cell``)
    updates ``nslot`` memory slots once per time step in ``omr_step``. Afterwards
    ``predict_future`` reads all stacked memories with ``decoder_distribution`` +
    ``decoder_cell`` and produces one output vector per time step.
    """

    def __init__(self, semantic_size, syntax_size, nslot, dropout=0.2, sample_structure=False, distribution='softmax'):
        super(OrderedMemoryRecurrent, self).__init__()

        slot_size = semantic_size + syntax_size

        # Parameter-free layer norms, applied separately to each half of a slot.
        self.semantic_act = nn.LayerNorm(semantic_size, elementwise_affine=False)
        self.syntax_act = nn.LayerNorm(syntax_size, elementwise_affine=False)

        # The encoder distribution scores [input, candidate-slot] pairs (2 * slot_size
        # features, see omr_step). The decoder distribution scores single slots.
        self.distribution = Distribution(slot_size * 2, slot_size, nslot, dropout,
                                         sample=sample_structure, process=distribution)
        self.decoder_distribution = Distribution(slot_size, slot_size, nslot, dropout,
                                                 sample=sample_structure, process=distribution)

        self.cell = Cell(semantic_size, syntax_size, dropout)
        self.decoder_cell = Cell(semantic_size, syntax_size, dropout)
        # Learned vector used to fill c_M initially and at every sequence start.
        self.init_slot = nn.Parameter(torch.zeros(slot_size))
        self.init_slot.data.uniform_(-0.1, 0.1)

        self.nslot = nslot
        self.slot_size = slot_size
        self.semantic_size = semantic_size
        self.syntax_size = syntax_size

    def reset_semantic_parameter(self, semantic_size):
        """Rebuild the semantic parts of both cells and the semantic layer norm.

        NOTE(review): ``init_slot`` and the two ``Distribution`` modules were sized
        from the original ``slot_size`` in ``__init__`` and are not rebuilt here.
        Calling this with a *different* ``semantic_size`` presumably leaves them
        mismatched with the new ``slot_size``. Confirm the intended usage.
        """
        self.semantic_size = semantic_size
        self.slot_size = semantic_size + self.syntax_size

        self.cell.reset_semantic_parameter(semantic_size)
        self.decoder_cell.reset_semantic_parameter(semantic_size)
        self.semantic_act = nn.LayerNorm(semantic_size, elementwise_affine=False)

    def semantic_parameters(self):
        """Return the semantic parameters of the encoder and decoder cells as a list."""
        return self.cell.semantic_parameters() + self.decoder_cell.semantic_parameters()

    def syntax_parameters(self):
        """Return syntax parameters: both cells' syntax params, both distributions, and ``init_slot``."""
        return self.cell.syntax_parameters() + self.decoder_cell.syntax_parameters() \
               + list(self.distribution.parameters()) + list(self.decoder_distribution.parameters()) + [self.init_slot]

    def init_hidden(self, bsz):
        """Build the initial recurrent state for a batch of ``bsz`` sequences.

        Returns:
            ``((M, c_M, p), prev_structure)``:

            * ``M`` is zeros of shape (bsz, nslot, slot_size).
            * ``c_M`` is ``activation(init_slot)`` broadcast to (bsz, nslot, slot_size).
              It is an expanded view, not a copy.
            * ``p`` is zeros of shape (bsz, nslot).
            * ``prev_structure`` is a (bsz, 1) long tensor filled with ``nslot - 1``.
        """
        weight = next(self.parameters()).data
        M = weight.new(bsz, self.nslot, self.slot_size).zero_()
        c_M = self.activation(self.init_slot)[None, None, :].expand(bsz, self.nslot, -1)
        p = weight.new(bsz, self.nslot).zero_()
        prev_structure = weight.new(bsz, 1).zero_().long() + self.nslot - 1
        return (M, c_M, p), prev_structure

    def split_semantic_syntax(self, x):
        """Split ``x`` along its last dim into ``(semantic, syntax)`` parts."""
        return x.split([self.semantic_size, self.syntax_size], dim=-1)

    def detach_syntax(self, x):
        """Return ``x`` with its syntax half cut from the autograd graph."""
        semantic, syntax = self.split_semantic_syntax(x)
        return torch.cat([semantic, syntax.detach()], dim=-1)

    def detach_semantic(self, x):
        """Return ``x`` with its semantic half cut from the autograd graph."""
        semantic, syntax = self.split_semantic_syntax(x)
        return torch.cat([semantic.detach(), syntax], dim=-1)

    def activation(self, x):
        """Layer-normalise the semantic and syntax halves of ``x`` independently.

        If only one of the two sizes is non-zero, that half's norm is applied
        to the whole last dimension. NOTE(review): when both sizes are 0 no
        branch matches and ``None`` is returned.
        """
        if self.syntax_size > 0 and self.semantic_size > 0:
            semantic, syntax = self.split_semantic_syntax(x)
            return torch.cat([self.semantic_act(semantic), self.syntax_act(syntax)], dim=-1)
        elif self.semantic_size > 0:
            return self.semantic_act(x)
        elif self.syntax_size > 0:
            return self.syntax_act(x)

    def omr_step(self, in_val, hidden, ctrl_idx):
        """Run one encoder time step over all memory slots.

        Args:
            in_val: (batch, slot_size) activated input for this step.
            hidden: ``(M, c_M, p)`` from the previous step. M and c_M are
                (batch, nslot, slot_size).
            ctrl_idx: value forwarded unchanged to ``self.distribution`` (may be None).

        Returns:
            ``((M, c_M, p), p_predicted)``: the updated memory, the new candidate
            memory and structure distribution, plus ``p_predicted`` as returned
            by ``self.distribution``.
        """
        prev_M, prev_c_M, prev_p = hidden
        batch_size, nslot, mem_slot_size = prev_M.size()
        in_batch_size, in_slot_size = in_val.size()

        # Check both the memory's and the input's slot size. The original
        # unpacking shadowed the memory's value, so it was never checked.
        assert self.slot_size == mem_slot_size
        assert self.slot_size == in_slot_size
        assert self.nslot == nslot
        assert batch_size == in_batch_size

        # 1 look ahead: score each slot from the (input, previous candidate) pair.
        # Semantic halves are detached, so the distribution's gradients only
        # reach the syntax halves.
        dist_input = torch.cat([self.detach_semantic(in_val)[:, None, :].expand(-1, self.nslot, -1),
                                self.detach_semantic(prev_c_M)], dim=-1)
        p, cp, rcp, p_predicted = self.distribution(dist_input, prev_p, ctrl_idx)
        # Per-slot blend between the previous memory (weight 1 - rcp) and the
        # previous candidates (weight rcp).
        M = prev_M * (1 - rcp)[:, :, None] + prev_c_M * rcp[:, :, None]

        M_list = []
        keep_input = 1 - cp  # loop-invariant weight on the raw input

        # Init with first value
        h = in_val
        for i in range(nslot):
            # Skip the cell when no batch element has a positive weight on slot i.
            # Presumably cp >= 0, in which case the skipped result would be scaled by 0.
            if cp[:, i].max() > 0:
                h = self.cell(h, M[:, i, :])
            h = in_val * keep_input[:, i, None] + h * cp[:, i, None]
            M_list.append(h)
        c_M = torch.stack(M_list, dim=1)

        hidden = (M, c_M, p)
        return hidden, p_predicted

    def predict_future(self, M_array, c_M_array, structure_p, trg_idx):
        """Decode one output vector per (batch, time) position from the stacked memories.

        Args:
            M_array, c_M_array: (batch, length, nslot, slot_size) memories and
                candidate memories collected by ``forward``.
            structure_p: (batch, length, nslot) encoder structure distributions.
            trg_idx: optional indices. They are flattened and forwarded to
                ``decoder_distribution``.

        Returns:
            ``(output, output_p, q_array)``:

            * ``output`` (batch, length, slot_size) is the ``output_p``-weighted
              sum over slots of the decoded slot states.
            * ``output_p`` and ``q_array`` are reshaped to (batch, length, nslot).
        """
        batch_size, length, _, _ = M_array.size()
        M = M_array.view(-1, self.nslot, self.slot_size)
        c_M = c_M_array.view(-1, self.nslot, self.slot_size)
        p = structure_p.view(-1, self.nslot)

        output_p, cq, rcq, q_array = self.decoder_distribution(self.detach_semantic(c_M), p,
                                                               trg_idx.reshape(-1) if trg_idx is not None else None)

        encoded = M * (1 - rcq)[:, :, None] + c_M * rcq[:, :, None]
        # encoded = self.detach_syntax(encoded)
        # Run decoder_cell from the last slot down to the first, then flip back.
        # Slot i's decoded state thus depends on slots i..nslot-1.
        encoded = encoded.flip([1])
        h = encoded[:, 0, :]
        h_list = [h]
        for i in range(1, self.nslot):
            h = self.decoder_cell(encoded[:, i, :], h)
            h_list.append(h)
        rnn_encoded = torch.stack(h_list, dim=1)
        rnn_encoded = rnn_encoded.flip([1])

        # Alternative decoder (disabled):
        # h = torch.zeros_like(M[:, 0, :])
        # h_list = [h]
        # for i in range(self.nslot - 1, 0, -1):
        #     h = self.decoder_cell(M[:, i, :], h)
        #     h_list.append(h)
        # rnn_encoded = torch.stack(h_list, dim=1)
        # rnn_encoded = rnn_encoded.flip([1])
        # rnn_encoded = self.decoder_cell(c_M, rnn_encoded)

        output = torch.bmm(output_p.view(batch_size * length, 1, self.nslot), rnn_encoded)

        output = output.view(batch_size, length, -1)
        output_p = output_p.view(batch_size, length, self.nslot)
        q_array = q_array.view(batch_size, length, self.nslot)

        return output, output_p, q_array

    def forward(self, X, hidden, ctrl_idx, trg_idx, sos_mark):
        """Encode ``X`` step by step, then decode predictions from the collected memories.

        Args:
            X: (batch, length, slot_size) input vectors. They are activated here.
            hidden: ``(M, c_M, p)`` state, e.g. the first element of ``init_hidden``.
            ctrl_idx: optional (batch, length) tensor. Column t is passed to the
                encoder distribution at step t.
            trg_idx: optional indices forwarded to ``predict_future``.
            sos_mark: optional (batch, length) mask. Where it is 0, the state is
                reset before that step: M and p are zeroed and c_M is refilled
                from ``init_slot``.

        Returns:
            ``(output, probs, hidden)``:

            * ``output`` is (batch, length, slot_size).
            * ``probs`` is ``(structure_p, output_p, p_array, q_array)``.
            * ``hidden`` is the final ``(M, c_M, p)``.
        """
        batch_size, length, _ = X.size()
        self.cell.generate_weight()
        self.decoder_cell.generate_weight()

        M_list = []
        c_M_list = []
        structure_p_list = []
        p_list = []

        X = self.activation(X)
        # Loop-invariant: the activated init slot used to refill c_M at sequence starts.
        init_c_M = self.activation(self.init_slot)[None, None, :] if sos_mark is not None else None
        for t in range(length):
            if sos_mark is not None:
                M, c_M, structure_p = hidden
                keep = sos_mark[:, t, None, None]
                M = M * keep
                c_M = c_M * keep + init_c_M * (1 - keep)
                structure_p = structure_p * sos_mark[:, t, None]
                hidden = (M, c_M, structure_p)

            hidden, p_predicted = \
                self.omr_step(X[:, t], hidden,
                              ctrl_idx[:, t] if ctrl_idx is not None else None)

            M, c_M, structure_p = hidden
            M_list.append(M)
            c_M_list.append(c_M)
            structure_p_list.append(structure_p)
            p_list.append(p_predicted)

        M_array = torch.stack(M_list, dim=1)
        c_M_array = torch.stack(c_M_list, dim=1)
        structure_p = torch.stack(structure_p_list, dim=1)
        p_array = torch.stack(p_list, dim=1)

        output, output_p, q_array = self.predict_future(M_array, c_M_array, structure_p, trg_idx)

        probs = (structure_p, output_p, p_array, q_array)
        return output, probs, hidden


class OrderedMemory(nn.Module):
    """Token model built on ``OrderedMemoryRecurrent`` with tied ``[semantic | syntax]`` embeddings.

    Input tokens are embedded with the concatenation of a semantic and a syntax
    embedding table. The same concatenated matrix (plus ``decoder_bias``) is
    reused as the output projection onto the ``ntokens`` vocabulary.
    """

    def __init__(self, semantic_size, syntax_size, nslot, ntokens,
                 dropoute=0.1, dropout=0.2, dropouto=0.5,
                 sos=None, sample_structure=False, distribution='softmax'):
        super(OrderedMemory, self).__init__()

        self.lockdrop = LockedDropout()
        # NOTE(review): self.drop is not used in this class's forward.
        self.drop = nn.Dropout(dropout)

        self.semantic_encoder = nn.Embedding(ntokens, semantic_size)
        self.syntax_encoder = nn.Embedding(ntokens, syntax_size)
        self.OM_forward = OrderedMemoryRecurrent(semantic_size, syntax_size, nslot, dropout,
                                                 sample_structure=sample_structure, distribution=distribution)
        self.decoder_bias = nn.Parameter(torch.Tensor(ntokens))

        self.init_weights()

        self.nslot = nslot
        self.dropoute = dropoute  # embedding dropout rate, applied only in training mode
        self.dropouto = dropouto  # locked dropout rate on the recurrent output
        self.sos = sos  # token id whose positions reset the recurrent state (None = never)

    def reset_semantic_parameter(self, semantic_size):
        """Re-create and re-initialise the semantic embedding and the core's semantic parts.

        The new embedding is moved to the device/dtype of the one it replaces.
        ``forward`` concatenates it with ``syntax_encoder.weight``, which fails
        when the two tables live on different devices or have different dtypes
        (e.g. after the model was moved to a GPU before this call).
        """
        self.OM_forward.reset_semantic_parameter(semantic_size)
        old_weight = self.semantic_encoder.weight
        new_encoder = nn.Embedding(self.semantic_encoder.num_embeddings, semantic_size)
        self.semantic_encoder = new_encoder.to(device=old_weight.device, dtype=old_weight.dtype)
        initrange = 0.1
        self.semantic_encoder.weight.data.uniform_(-initrange, initrange)

    def semantic_parameters(self):
        """Return the core's semantic params, the semantic embedding, and the output bias."""
        return self.OM_forward.semantic_parameters() + list(self.semantic_encoder.parameters()) + [self.decoder_bias]

    def syntax_parameters(self):
        """Return the core's syntax params plus the syntax embedding."""
        return self.OM_forward.syntax_parameters() + list(self.syntax_encoder.parameters())

    def init_weights(self):
        """Initialise both embeddings uniformly in [-0.1, 0.1] and zero the output bias."""
        initrange = 0.1
        self.semantic_encoder.weight.data.uniform_(-initrange, initrange)
        self.syntax_encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder_bias.data.fill_(0)

    def init_hidden(self, bsz):
        """Delegate to ``OrderedMemoryRecurrent.init_hidden``."""
        return self.OM_forward.init_hidden(bsz)

    def forward(self, X, hidden, ctrl_idx=None, trg_idx=None):
        """Score every position of ``X`` against the vocabulary.

        Args:
            X: (batch, length) token ids.
            hidden: recurrent state ``(M, c_M, p)``.
            ctrl_idx, trg_idx: optional indices forwarded to the recurrent core.

        Returns:
            ``(output, probs, hidden)``:

            * ``output`` has shape (batch * length, ntokens).
            * ``probs`` and ``hidden`` come from the recurrent core.
        """
        batch_size, length = X.size()

        if self.sos is not None:
            # 0.0 where X is the sos token; the core resets its state there.
            sos_mark = (X != self.sos).float()
        else:
            sos_mark = None

        # Tied weights: the concatenated table is used for both input and output.
        emb_weight = torch.cat([self.semantic_encoder.weight, self.syntax_encoder.weight], dim=1)
        emb = embedded_dropout(
            self.semantic_encoder, emb_weight, X,
            dropout=self.dropoute if self.training else 0
        )

        raw_output, probs, hidden = self.OM_forward(emb, hidden, ctrl_idx, trg_idx, sos_mark)

        raw_output = self.lockdrop(raw_output, self.dropouto)

        output = F.linear(raw_output, emb_weight, self.decoder_bias)

        return output.view(batch_size * length, -1), probs, hidden


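# Legacy variant (disabled, kept for reference). It has no decoder and reads each
# step's prediction straight from the last slot of c_M (`output = c_M_array[:, :, -1]`).
# NOTE(review): its reset_semantic_parameter calls self.decoder_cell, which this
# variant never defines. Fix that before re-enabling it.
# class OrderedMemoryRecurrent(nn.Module):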
#     def __init__(self, semantic_size, syntax_size, nslot, dropout=0.2, sample_structure=False):
#         super(OrderedMemoryRecurrent, self).__init__()
#
#         slot_size = semantic_size + syntax_size
#
#         self.semantic_act = nn.LayerNorm(semantic_size, elementwise_affine=False)
#         self.syntax_act = nn.LayerNorm(syntax_size, elementwise_affine=False)
#
#         self.distribution = Distribution(slot_size * 2, slot_size, nslot, dropout, sample=sample_structure)
#
#         self.cell = Cell(semantic_size, syntax_size, dropout)
#         self.init_slot = nn.Parameter(torch.zeros(slot_size))
#         self.init_slot.data.uniform_(-0.1, 0.1)
#
#         self.nslot = nslot
#         self.slot_size = slot_size
#         self.semantic_size = semantic_size
#         self.syntax_size = syntax_size
#
#     def reset_semantic_parameter(self, semantic_size):
#         self.semantic_size = semantic_size
#         self.slot_size = semantic_size + self.syntax_size
#
#         self.cell.reset_semantic_parameter(semantic_size)
#         self.decoder_cell.reset_semantic_parameter(semantic_size)
#         self.semantic_act = nn.LayerNorm(semantic_size, elementwise_affine=False)
#
#     def semantic_parameters(self):
#         return self.cell.semantic_parameters()
#
#     def syntax_parameters(self):
#         return self.cell.syntax_parameters() \
#                + list(self.distribution.parameters()) + [self.init_slot]
#
#     def init_hidden(self, bsz):
#         weight = next(self.parameters()).data
#         M = weight.new(bsz, self.nslot, self.slot_size).zero_()
#         c_M = self.activation(self.init_slot)[None, None, :].expand(bsz, self.nslot, -1)
#         p = weight.new(bsz, self.nslot).zero_()
#         prev_structure = weight.new(bsz, 1).zero_().long() + self.nslot - 1
#         return (M, c_M, p), prev_structure
#
#     def split_semantic_syntax(self, x):
#         return x.split([self.semantic_size, self.syntax_size], dim=-1)
#
#     def detach_syntax(self, x):
#         semantic, syntax = self.split_semantic_syntax(x)
#         return torch.cat([semantic, syntax.detach()], dim=-1)
#
#     def detach_semantic(self, x):
#         semantic, syntax = self.split_semantic_syntax(x)
#         return torch.cat([semantic.detach(), syntax], dim=-1)
#
#     def activation(self, x):
#         if self.syntax_size > 0 and self.semantic_size > 0:
#             semantic, syntax = self.split_semantic_syntax(x)
#             return torch.cat([self.semantic_act(semantic), self.syntax_act(syntax)], dim=-1)
#         elif self.semantic_size > 0:
#             return self.semantic_act(x)
#         elif self.syntax_size > 0:
#             return self.syntax_act(x)
#
#     def omr_step(self, in_val, hidden, ctrl_idx):
#         prev_M, prev_c_M, prev_p = hidden
#         batch_size, nslot, slot_size = prev_M.size()
#         _batch_size, slot_size = in_val.size()
#
#         assert self.slot_size == slot_size
#         assert self.nslot == nslot
#         assert batch_size == _batch_size
#
#         # 1 look ahead
#         dist_input = torch.cat([self.detach_semantic(in_val)[:, None, :].expand(-1, self.nslot, -1),
#                                 self.detach_semantic(prev_c_M)], dim=-1)
#         p, cp, rcp, p_predicted = self.distribution(dist_input, prev_p, ctrl_idx)
#         M = prev_M * (1 - rcp)[:, :, None] + prev_c_M * rcp[:, :, None]
#
#         M_list = []
#
#         # Init with first value
#         h = in_val
#         for i in range(nslot):
#             if cp[:, i].max() > 0:
#                 h = self.cell(h, M[:, i, :])
#             h = in_val * (1 - cp)[:, i, None] + h * cp[:, i, None]
#             M_list.append(h)
#         c_M = torch.stack(M_list, dim=1)
#
#         hidden = (M, c_M, p)
#         return hidden, p_predicted
#
#     def forward(self, X, hidden, ctrl_idx, trg_idx, sos_mark):
#         batch_size, length, _ = X.size()
#         self.cell.generate_weight()
#
#         M_list = []
#         c_M_list = []
#         structure_p_list = []
#         p_list = []
#
#         X = self.activation(X)
#         for t in range(X.size(1)):
#             if sos_mark is not None:
#                 M, c_M, structure_p = hidden
#                 M = M * sos_mark[:, t, None, None]
#                 c_M = c_M * sos_mark[:, t, None, None] + self.activation(self.init_slot)[None, None, :] * (1 - sos_mark[:, t, None, None])
#                 structure_p = structure_p * sos_mark[:, t, None]
#                 hidden = (M, c_M, structure_p)
#
#             hidden, p_predicted = \
#                 self.omr_step(X[:, t], hidden,
#                               ctrl_idx[:, t] if ctrl_idx is not None else None)
#
#             M, c_M, structure_p = hidden
#             M_list.append(M)
#             c_M_list.append(c_M)
#             structure_p_list.append(structure_p)
#             p_list.append(p_predicted)
#
#         M_array = torch.stack(M_list, dim=1)
#         c_M_array = torch.stack(c_M_list, dim=1)
#         structure_p = torch.stack(structure_p_list, dim=1)
#         p_array = torch.stack(p_list, dim=1)
#
#         output = c_M_array[:, :, -1]
#         output_p = torch.zeros_like(structure_p)
#         output_p[:, :, -1] = 1
#         q_array = output_p
#
#         probs = (structure_p, output_p, p_array, q_array)
#         return output, probs, hidden
#
#
# class OrderedMemory(nn.Module):
#     def __init__(self, semantic_size, syntax_size, nslot, ntokens,
#                  dropoute=0.1, dropout=0.2, dropouto=0.5,
#                  sos=None, sample_structure=False):
#         super(OrderedMemory, self).__init__()
#
#         self.lockdrop = LockedDropout()
#         self.drop = nn.Dropout(dropout)
#
#         self.semantic_encoder = nn.Embedding(ntokens, semantic_size)
#         self.syntax_encoder = nn.Embedding(ntokens, syntax_size)
#         self.OM_forward = OrderedMemoryRecurrent(semantic_size, syntax_size, nslot,
#                                                  dropout=dropout, sample_structure=sample_structure)
#         self.decoder_bias = nn.Parameter(torch.Tensor(ntokens))
#
#         self.init_weights()
#
#         self.nslot = nslot
#         self.dropoute = dropoute
#         self.dropouto = dropouto
#         self.sos = sos
#
#     def reset_semantic_parameter(self, semantic_size):
#         self.OM_forward.reset_semantic_parameter(semantic_size)
#         self.semantic_encoder = nn.Embedding(self.semantic_encoder.num_embeddings, semantic_size)
#         initrange = 0.1
#         self.semantic_encoder.weight.data.uniform_(-initrange, initrange)
#
#     def semantic_parameters(self):
#         return self.OM_forward.semantic_parameters() + list(self.semantic_encoder.parameters()) + [self.decoder_bias]
#
#     def syntax_parameters(self):
#         return self.OM_forward.syntax_parameters() + list(self.syntax_encoder.parameters())
#
#     def init_weights(self):
#         initrange = 0.1
#         self.semantic_encoder.weight.data.uniform_(-initrange, initrange)
#         self.syntax_encoder.weight.data.uniform_(-initrange, initrange)
#         self.decoder_bias.data.fill_(0)
#
#     def init_hidden(self, bsz):
#         return self.OM_forward.init_hidden(bsz)
#
#     def forward(self, X, hidden, ctrl_idx=None, trg_idx=None):
#         batch_size, length = X.size()
#
#         if self.sos is not None:
#             sos_mark = (X != self.sos).float()
#         else:
#             sos_mark = None
#
#         emb_weight = torch.cat([self.semantic_encoder.weight, self.syntax_encoder.weight], dim=1)
#         emb = embedded_dropout(
#             self.semantic_encoder, emb_weight, X,
#             dropout=self.dropoute if self.training else 0
#         )
#
#         raw_output, probs, hidden = self.OM_forward(emb, hidden, ctrl_idx, trg_idx, sos_mark)
#
#         raw_output = self.lockdrop(raw_output, self.dropouto)
#
#         output = F.linear(raw_output, emb_weight, self.decoder_bias)
#
#         return output.view(batch_size * length, -1), probs, hidden
