import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from unilm.modeling import BertEncoder, BertLayer


class BertOutput(nn.Module):
    """Residual output block with no normalization step.

    ``forward`` projects ``hidden_states`` with a single linear layer, adds
    ``input_tensor`` to the result and returns the sum wrapped in a 1-tuple.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Maps input_size features to output_size features.
        self.dense = nn.Linear(input_size, output_size)

    def forward(self, hidden_states, input_tensor):
        projected = self.dense(hidden_states)
        residual_sum = projected + input_tensor
        # Returned as a tuple to keep a tuple-style module output.
        return (residual_sum,)


def _replace_output_layer(layer, input_size, output_size):
    """Build a norm-free ``BertOutput`` that reuses ``layer``'s dense projection.

    The new module is constructed with its own ``nn.Linear``, which is then
    immediately replaced by ``layer.dense``. The module object is shared,
    not copied.

    Args:
        layer: an output sub-module exposing ``dense`` and
            ``layer_norm_approximation`` attributes.
        input_size: in-features passed to the ``BertOutput`` constructor.
        output_size: out-features passed to the ``BertOutput`` constructor.

    Returns:
        A ``(replacement, layer_norm)`` tuple: the new ``BertOutput`` and the
        ``layer.layer_norm_approximation`` object that was detached from the
        computation.
        NOTE(review): this object is presumably the element-wise affine
        stand-in for LayerNorm that ``_squash_dense_layer`` later folds into
        the next dense layer. Confirm against the model definition.
    """
    # Read the norm first, before constructing the replacement module.
    detached_norm = layer.layer_norm_approximation
    replacement = BertOutput(input_size, output_size)
    replacement.dense = layer.dense
    return replacement, detached_norm


def _squash_dense_layer(dense_layer, layer_norm_list):
    """Fold a chain of element-wise affine maps into ``dense_layer``, in place.

    Each entry of ``layer_norm_list`` is treated as ``x -> x * ln.weight +
    ln.bias``, with no mean/variance normalization. The entries are applied in
    list order before ``dense_layer``. After this call, ``dense_layer(x)``
    equals what the un-modified layer produced on ``ln_k(...ln_1(x))``.

    Args:
        dense_layer: an ``nn.Linear``-like module whose ``weight`` has shape
            ``(out_features, in_features)`` and whose ``bias`` is not ``None``.
        layer_norm_list: objects exposing ``weight`` and ``bias`` tensors that
            broadcast against an ``(N, in_features)`` input. May be empty, in
            which case the layer is left numerically unchanged.

    Raises:
        AssertionError: if the folded layer differs from the unfolded
            computation by ``1e-4`` or more on a fixed-seed probe batch. The
            check is an explicit ``raise``, so it also runs under ``python -O``.
    """
    weight = dense_layer.weight
    input_size = weight.size(1)
    # Work in the layer's own device/dtype instead of hard-coded float32.
    factory = {"device": weight.device, "dtype": weight.dtype}

    with torch.no_grad():
        # Compose x -> x * g_k + b_k over the list into one x -> x * scale + shift.
        # A leading batch dim of 1 mirrors the row layout the layer expects.
        scale = torch.ones(1, input_size, **factory)
        shift = torch.zeros(1, input_size, **factory)
        for layer_norm in layer_norm_list:
            scale = scale * layer_norm.weight
            shift = shift * layer_norm.weight + layer_norm.bias

        # dense(x * scale + shift) == x @ (W * scale).T + dense(shift)
        new_weight = weight * scale
        new_bias = dense_layer(shift)[0]

        # Sanity-check on a probe batch from a private, fixed-seed generator,
        # so the caller's global RNG stream is not consumed.
        generator = torch.Generator().manual_seed(0)
        probe = torch.rand(10, input_size, generator=generator).to(**factory)
        squashed = F.linear(probe, new_weight, new_bias)

        # Reference path: un-modified layer applied after the affine chain.
        # This must run before the parameters are overwritten below.
        reference = probe
        for layer_norm in layer_norm_list:
            reference = reference * layer_norm.weight + layer_norm.bias
        reference = dense_layer(reference)

        max_err = (squashed - reference).abs().max().item()
        # `not <` (rather than `>=`) also rejects NaN, like the original assert.
        if not max_err < 1e-4:
            raise AssertionError(
                f"squashed dense layer deviates from original by {max_err:.3g}"
            )

        weight.copy_(new_weight)
        dense_layer.bias.copy_(new_bias)


def squash_model_encoder(model, config, squash_attention_scale=False):
    """Rewrite the weights of every layer in ``model.bert.encoder`` in place.

    Only the attention-scale fold is currently active. When
    ``squash_attention_scale`` is true, each layer's query projection (weight
    and bias) is divided by ``sqrt(attention_head_size)``.
    NOTE(review): the attention forward pass presumably must skip its own
    1/sqrt(d) scaling for this to be equivalent; otherwise scores are scaled
    twice. Confirm against the attention implementation.

    The broader layer-norm squashing pass is kept below as disabled reference
    code.

    Args:
        model: a model whose ``bert.encoder`` is a unilm ``BertEncoder`` whose
            ``layer`` entries are ``BertLayer``s.
        config: model config. Only the disabled squashing code reads it
            (``hidden_size`` / ``intermediate_size``).
        squash_attention_scale: if true, divide each query projection by
            ``sqrt(attention_head_size)``.

    Raises:
        AssertionError: if the encoder or one of its layers is not the
            expected unilm type.
    """
    model_encoder = model.bert.encoder
    assert isinstance(model_encoder, BertEncoder)

    layer_norm_list = []  # consumed only by the disabled squashing pass below
    scaled_layers = 0
    for layer in model_encoder.layer:
        assert isinstance(layer, BertLayer)
        # NOTE: disabled layer-norm squashing pass, kept for reference.
        # if len(layer_norm_list) > 0:
        #     _squash_dense_layer(layer.attention.self.query, layer_norm_list)
        #     _squash_dense_layer(layer.attention.self.key, layer_norm_list)
        #     _squash_dense_layer(layer.attention.self.value, layer_norm_list)
        #
        # new_output, layer_norm = _replace_output_layer(layer.attention.output, config.hidden_size, config.hidden_size)
        # layer.attention.output = new_output
        # layer_norm_list.append(layer_norm)
        #
        # _squash_dense_layer(layer.intermediate.dense, layer_norm_list)
        #
        # new_output, layer_norm = _replace_output_layer(layer.output, config.intermediate_size, config.hidden_size)
        # layer_norm_list.append(layer_norm)
        # layer.output = new_output

        if squash_attention_scale:
            self_attention = layer.attention.self
            divisor = math.sqrt(self_attention.attention_head_size)
            self_attention.query.weight.data.div_(divisor)
            self_attention.query.bias.data.div_(divisor)
            scaled_layers += 1

    # Report once rather than emitting the same line for every layer.
    if scaled_layers:
        print("squash_attention_scale")

    # _squash_dense_layer(model.classifier, layer_norm_list)

    # squash into task layer
