import torch
from torch.nn.functional import log_softmax
from torch.distributions.categorical import Categorical


def untuple(x):
    """Unwrap a tuple to its first element; pass any non-tuple value through unchanged.

    Module outputs (and gradients handed to backward hooks) are sometimes a tuple
    whose first entry is the tensor of interest; this normalises both cases.
    """
    if not isinstance(x, tuple):
        return x
    return x[0]


def approx_causal_effects(module_outputs, module_grads):
    """Collapse captured activations and gradients into per-hidden-unit vectors.

    Batch row 0 is treated as the reference ("clean") run and rows 1: as the
    intervened runs (in this file, ``compute_act_grad_features`` adds noise to the
    embeddings of rows 1:). Every vector is produced by summing over the position
    axis and then averaging over the layer axis on the CPU.

    Args:
        module_outputs: module activations, laid out as
            (B, num_layer, seq_len, h_dim) per the caller's stacking; B must be >= 2.
        module_grads: gradients w.r.t. ``module_outputs``, same layout.

    Returns:
        A 7-tuple of CPU tensors, each of shape (h_dim,):
        ``(approx_ie, approx_de, agg_module_outputs, agg_module_grads,
        agg_module_outputs_x_grads, last_hidden_state_0, last_hidden_state_1)``.
        ``approx_ie`` averages ``grad_row * (act_row - act_clean)`` over the
        intervened rows. ``approx_de`` does the same with the negated clean-row
        gradient. The two ``last_hidden_state`` entries are the last-layer,
        last-position activations of rows 0 and 1.

    Raises:
        IndexError: if the batch has fewer than two rows.
    """
    clean_acts = module_outputs[0]          # (num_layer, seq_len, h_dim)
    perturbed_acts = module_outputs[1:]     # (n_intervene, num_layer, seq_len, h_dim)
    act_deltas = perturbed_acts - clean_acts

    clean_grads = module_grads[0]
    perturbed_grads = module_grads[1:]
    # NOTE(review): the aggregate-gradient features use the FIRST INTERVENED row's
    # gradients rather than the clean row's; confirm this is intended.
    first_perturbed_grads = module_grads[1]

    def layer_mean_on_cpu(per_layer):
        # (num_layer, h_dim) -> (h_dim): move to CPU first, then average over layers.
        return per_layer.cpu().mean(0)

    agg_module_outputs = layer_mean_on_cpu(clean_acts.sum(-2))
    agg_module_grads = layer_mean_on_cpu(first_perturbed_grads.sum(-2))
    agg_module_outputs_x_grads = layer_mean_on_cpu((clean_acts * first_perturbed_grads).sum(-2))
    # The intervened-row mean is taken on the original device, before moving to CPU.
    approx_ie = layer_mean_on_cpu((perturbed_grads * act_deltas).sum(-2).mean(0))
    approx_de = layer_mean_on_cpu(((-clean_grads) * act_deltas).sum(-2).mean(0))

    last_hidden_state_0 = clean_acts[-1, -1].cpu()
    last_hidden_state_1 = module_outputs[1, -1, -1].cpu()

    return (approx_ie, approx_de, agg_module_outputs, agg_module_grads,
            agg_module_outputs_x_grads, last_hidden_state_0, last_hidden_state_1)


def make_output_hook(layer_id, output_dict):
    """Create a hook that saves a detached tensor into ``output_dict[layer_id]``.

    The returned callable takes ``(module, arg1, arg2)``. PyTorch forward hooks
    (``(module, inputs, output)``) and full backward hooks
    (``(module, grad_input, grad_output)``) both have that shape, and this file
    uses the factory for both. Only the third argument is used. If it is a tuple,
    its first element is stored. Each call overwrites any earlier entry for
    ``layer_id``.
    """
    def record(module, _inputs, outputs):
        tensor = untuple(outputs)
        # Detach so the stored copy does not keep the autograd graph alive.
        output_dict[layer_id] = tensor.detach()  # (B, seq_len, h_dim) per the callers' comments

    return record


def compute_act_grad_features(model, batch_inputs, internal_features, target_token_id):
    """Run one noisy forward/backward pass and append per-component features.

    Row 0 of the batch is left untouched. Rows 1: get standard-normal noise added
    in place to their token embeddings by a forward hook on
    ``model.transformer.wte``. Forward hooks record the outputs of each transformer
    block ('res') and of its ``attn`` and ``mlp`` sub-modules. Full backward hooks
    record the gradients of the summed target-token log-probabilities with respect
    to those same outputs. Each component's captures are stacked over layers and
    reduced by ``approx_causal_effects``.

    Every hook is removed before returning, including when the pass raises.
    Features are appended to ``internal_features`` only after the backward pass
    succeeds.

    Args:
        model: causal LM exposing ``config.n_layer``, ``transformer.h[i].attn``,
            ``transformer.h[i].mlp``, ``transformer.wte`` and ``device``. Calling it
            must return an object with ``.logits`` (GPT-2-style layout, as assumed
            by the attribute paths used here).
        batch_inputs: object whose ``.to(device)`` returns a mapping of model keyword
            arguments (presumably a tokenizer ``BatchEncoding``; confirm with caller).
        internal_features: mapping from feature name to an appendable list. It is
            written under ``'uncertainty'`` and
            ``'{IE,DE,act,grad,act_x_grad,last_hidden_state_0,last_hidden_state_1}_{attn,mlp,res}'``
            (e.g. a ``defaultdict(list)``; assumption).
        target_token_id: vocabulary index whose log-probability at the last
            position is back-propagated.

    Returns:
        ``internal_features``, mutated in place.
    """
    num_layers = model.config.n_layer
    components = ('attn', 'mlp', 'res')
    feature_names = ('IE', 'DE', 'act', 'grad', 'act_x_grad',
                     'last_hidden_state_0', 'last_hidden_state_1')

    # component -> {layer index -> tensor}, filled in by the hooks below.
    forward_output_dicts = {k: {} for k in components}
    backward_output_dicts = {k: {} for k in components}

    def emb_intervene_hook(module, inputs, outputs):
        outputs_0 = untuple(outputs)  # (1+n_intervene, seq_len, h_dim)
        # Row 0 stays clean; rows 1: are perturbed in place. Noise comes from the
        # CPU generator and is then moved to the embeddings' device.
        noise = torch.randn(outputs_0[1:].shape)
        outputs_0[1:] += noise.to(outputs_0.device)
        return outputs

    handles = []
    try:
        for i in range(num_layers):
            block = model.transformer.h[i]
            hooked_modules = {'attn': block.attn, 'mlp': block.mlp, 'res': block}
            for k, module in hooked_modules.items():
                handles.append(module.register_full_backward_hook(
                    make_output_hook(i, backward_output_dicts[k])))
                handles.append(module.register_forward_hook(
                    make_output_hook(i, forward_output_dicts[k])))
        handles.append(model.transformer.wte.register_forward_hook(emb_intervene_hook))

        batch_logits = model(**batch_inputs.to(model.device)).logits[:, -1]  # (B, vocab_size)
        batch_log_probs = log_softmax(batch_logits, -1)
        with torch.no_grad():
            uncertainty = Categorical(logits=batch_logits).entropy().cpu()  # (B,)
        ans_log_probs = batch_log_probs[:, target_token_id]  # (B,)
        # NOTE(review): this also accumulates into the parameters' .grad and nothing
        # here clears them; the caller may want model.zero_grad() between calls.
        ans_log_probs.backward(torch.ones_like(ans_log_probs))
        torch.cuda.empty_cache()
    finally:
        # Always detach hooks. A leaked embedding hook would keep injecting noise
        # into every later forward pass of this model.
        for handle in handles:
            handle.remove()

    internal_features['uncertainty'].append(uncertainty)

    for k in components:
        forward_dict = forward_output_dicts[k]    # {i: tensor of shape (B, seq_len, h_dim)}
        backward_dict = backward_output_dicts[k]  # {i: tensor of shape (B, seq_len, h_dim)}
        module_acts = torch.stack([forward_dict[i] for i in range(num_layers)], 1)  # (B, num_layer, seq_len, h_dim)
        module_grads = torch.stack([backward_dict[i] for i in range(num_layers)], 1)  # (B, num_layer, seq_len, h_dim)

        internal_patterns = approx_causal_effects(module_acts, module_grads)  # 7 tensors of (h_dim)
        for name, value in zip(feature_names, internal_patterns):
            internal_features[f'{name}_{k}'].append(value)

    return internal_features

