from copy import copy, deepcopy

import torch
from torch import nn as nn

class Reshape(nn.Module):
    """Reshape a tensor inside an ``nn.Sequential`` pipeline.

    Only ``reshape_type="flat_dim0"`` is implemented. It keeps dim 0 and
    collapses every remaining dimension into one, via ``Tensor.view``.
    ``view`` therefore requires a view-compatible (e.g. contiguous) input.

    Args:
        reshape_type: name of the reshape to apply. Any value other than
            ``"flat_dim0"`` raises ``NotImplementedError`` at forward time.
    """

    def __init__(self, reshape_type="flat_dim0"):
        super().__init__()
        self.reshape_type = reshape_type

    def forward(self, x):
        # Guard clause: reject unsupported modes before touching the tensor.
        if self.reshape_type != "flat_dim0":
            raise NotImplementedError("Un-supported reshape_type: {}".format(self.reshape_type))
        batch_size = x.shape[0]
        return x.view(batch_size, -1)


class GatingLayer(nn.Module):
    """Predict per-block gating scores from an input feature.

    Pipeline: the optional ``SwitchNorm1d`` input normalization runs first,
    but only while ``norm_input_each_forward`` is True. Then comes
    ``Reshape("flat_dim0")``, then ``nn.Linear``, then ``nn.BatchNorm1d``, and
    finally a sigmoid.

    Args:
        seperate_trans: stored as ``self.seperate_trans``. It is not read
            anywhere in this class. NOTE(review): presumably consumed by
            callers; confirm.
        input_feat_size: flattened size of the input feature (default 1024,
            the previously hard-coded value).
        num_blocks: number of gating scores produced per example (default 24,
            the previously hard-coded value).
    """

    def __init__(self, seperate_trans=0, input_feat_size=1024, num_blocks=24):
        super().__init__()

        # ---------------------------- Build gating layer ----------------------------
        # Creation/assignment order matches the original layout, so the
        # parameter-init RNG order and state_dict keys are unchanged.
        # SwitchNorm1d only accepts 2-D (B, input_feat_size) inputs.
        self.norm_input_layer = SwitchNorm1d(num_features=input_feat_size)
        self.gating = nn.Sequential(
            Reshape(reshape_type="flat_dim0"),
            nn.Linear(in_features=input_feat_size, out_features=num_blocks),
            nn.BatchNorm1d(num_features=num_blocks),
        )

        self.seperate_trans = seperate_trans

        # When True, forward() normalizes its input with norm_input_layer
        # before running the gating network.
        self.norm_input_each_forward = True

    def forward(self, x):
        """Return ``(gating_score, trans_weight)``.

        Both elements are currently the same tensor object. When input
        normalization is enabled, ``x`` must be 2-D (``SwitchNorm1d`` raises
        ``ValueError`` otherwise).
        """
        if self.norm_input_each_forward:
            x = self.norm_input_layer(x)
        gating_score = self._forward(x, self.gating)
        trans_weight = gating_score
        return gating_score, trans_weight

    def _forward(self, x, layer):
        """Apply ``layer`` to ``x`` and reduce its output to per-block scores.

        A 3-D result is averaged over dim 1. A 2-D result is squashed with a
        sigmoid. Any other rank raises ``RuntimeError``.
        """
        res = layer(x)
        if res.dim() == 3:  # [B, seq_len, block_len] for texts
            return torch.mean(res, dim=1)
        if res.dim() == 2:  # [B, block_len] for images
            # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
            return torch.sigmoid(res)
        raise RuntimeError(f"Un-expected mask weights shape for gating layer, got {res.dim()}")

class SwitchNorm1d(nn.Module):
    """Switchable Normalization for 2-D ``(batch, num_features)`` inputs.

    The input is normalized with a learned, softmax-weighted mix of two sets
    of statistics:

    * layer-norm statistics, computed per sample over dim 1;
    * batch-norm statistics, computed per feature over dim 0 in training, or
      taken from the running buffers in eval mode.

    A per-feature affine transform (``weight``, ``bias``) is applied last.

    Args:
        num_features: size of dim 1 of the input.
        eps: added to the mixed variance before the square root.
        momentum: decay factor applied to the running buffers on each training
            step when ``using_moving_average`` is True.
        using_moving_average: if True, keep exponential moving averages of the
            batch mean/variance. If False, add the batch mean and
            ``mean**2 + var`` to the running buffers on every training step
            (raw accumulation, not an average).

    Notes:
        ``Tensor.var`` is unbiased by default, so a batch of size 1 (or
        ``num_features == 1``) yields NaN variances.
        The running buffers start at zero, so eval mode before any training
        step uses zero batch statistics.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.997, using_moving_average=True):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        self.weight = nn.Parameter(torch.ones(1, num_features))
        self.bias = nn.Parameter(torch.zeros(1, num_features))
        # Logits for mixing [layer-norm, batch-norm] statistics; softmaxed in forward().
        self.mean_weight = nn.Parameter(torch.ones(2))
        self.var_weight = nn.Parameter(torch.ones(2))
        self.register_buffer('running_mean', torch.zeros(1, num_features))
        self.register_buffer('running_var', torch.zeros(1, num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the running statistics and reset the affine transform to identity."""
        self.running_mean.zero_()
        self.running_var.zero_()
        with torch.no_grad():
            self.weight.fill_(1)
            self.bias.zero_()

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2-D."""
        if input.dim() != 2:
            raise ValueError('expected 2D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):
        self._check_input_dim(x)
        mean_ln = x.mean(1, keepdim=True)
        var_ln = x.var(1, keepdim=True)

        if self.training:
            mean_bn = x.mean(0, keepdim=True)
            var_bn = x.var(0, keepdim=True)
            # Buffer updates must not be recorded by autograd; mean_bn/var_bn
            # themselves keep their graph for the normalization below.
            with torch.no_grad():
                if self.using_moving_average:
                    self.running_mean.mul_(self.momentum)
                    self.running_mean.add_((1 - self.momentum) * mean_bn)
                    self.running_var.mul_(self.momentum)
                    self.running_var.add_((1 - self.momentum) * var_bn)
                else:
                    self.running_mean.add_(mean_bn)
                    self.running_var.add_(mean_bn ** 2 + var_bn)
        else:
            # Buffers are plain tensors; the deprecated Variable wrapper is unnecessary.
            mean_bn = self.running_mean
            var_bn = self.running_var

        mean_weight = torch.softmax(self.mean_weight, dim=0)
        var_weight = torch.softmax(self.var_weight, dim=0)

        mean = mean_weight[0] * mean_ln + mean_weight[1] * mean_bn
        var = var_weight[0] * var_ln + var_weight[1] * var_bn

        x = (x - mean) / (var + self.eps).sqrt()
        return x * self.weight + self.bias
