import torch
from torch.nn.functional import log_softmax

# Seed PyTorch's global RNG when this module is imported; the embedding-noise
# hooks in run_get_causal_effects draw their noise from it via torch.randn.
torch.manual_seed(100)


def untuple(x):
    """Return the first element of ``x`` if it is a tuple, else ``x`` itself.

    Used to reach the primary output of modules whose forward pass may
    return either a bare value or a tuple whose first entry is that value.
    """
    if not isinstance(x, tuple):
        return x
    return x[0]


def run_get_causal_effects(model, batch_inputs,
                           target_token_id,
                           effect_kind='indirect'):
    """Run one hooked forward pass and measure per-layer patching effects.

    The batch is treated as a set of parallel runs, selected by row:

    * row 0 is never noised or patched (the reference run);
    * row 1 has noise added to its token embeddings;
    * row ``2 + k`` for ``k < n_layer`` is patched at the ``attn`` output of
      the layers in the window around layer ``k``;
    * row ``2 + n_layer + k`` is patched the same way at the ``mlp`` output.

    With ``effect_kind == 'indirect'`` the noise is added to every row from 1
    onward, patched rows receive row 0's activations, and effects are
    measured relative to row 1.  With any other value the noise is added to
    row 1 only, patched rows receive row 1's activations, and effects are
    measured relative to row 0.

    All forward hooks are removed before returning, including when hook
    registration or the forward pass raises.

    Args:
        model: Object exposing ``transformer.wte``, ``transformer.h[j].attn``,
            ``transformer.h[j].mlp`` and ``config.n_layer``; calling it must
            return an object with a ``logits`` attribute.  (Presumably a
            Hugging Face GPT-2 style model -- confirm against the caller.)
        batch_inputs: Passed positionally to ``model``.  Assumed to have
            ``2 + 2 * n_layer`` rows, as required by the final reshape --
            TODO confirm with callers.
        target_token_id: Index into the vocabulary axis of the last-position
            log-probabilities.  Assumed to be a single id.
        effect_kind: ``'indirect'`` selects the indirect-effect setup; any
            other value selects the direct-effect setup.

    Returns:
        A pair ``(reference_log_prob, effects)``: the target log-probability
        of row 0, and a ``(2, n_layer)`` tensor of log-probability
        differences whose first row is for ``attn`` and second for ``mlp``.
    """
    indirect = effect_kind == 'indirect'
    num_layers = model.config.n_layer
    half_window = 5  # patch layers [l - 5, l + 5), clipped to the model depth

    # Rows that receive embedding noise / row whose activations are copied
    # into the patched rows / row the effects are measured against.
    noised_rows = slice(1, None) if indirect else 1
    patch_source_row = 0 if indirect else 1
    baseline_row = 1 if indirect else 0

    def wte_noise_hook(module, inputs, outputs):
        embeddings = untuple(outputs)
        # One noise sample, drawn on CPU from the global RNG and broadcast
        # over every noised row so they all see the same corruption.
        noise = torch.randn(embeddings[0].shape)
        embeddings[noised_rows] += noise.to(embeddings.device)
        return outputs

    def make_patching_hook(patched_batch_id):
        def patching_hook(module, inputs, outputs):
            hidden = untuple(outputs)  # (B, seq_len, hidden_dim)
            hidden[patched_batch_id] = hidden[patch_source_row]
            return outputs

        return patching_hook

    hooks = []
    try:
        hooks.append(
            model.transformer.wte.register_forward_hook(wte_noise_hook))

        # Batch rows 2.. are laid out as all 'attn' runs, then all 'mlp' runs.
        for module_idx, module_name in enumerate(('attn', 'mlp')):
            for layer in range(num_layers):
                hook_fn = make_patching_hook(
                    2 + module_idx * num_layers + layer)
                window_start = max(0, layer - half_window)
                window_end = min(layer + half_window, num_layers)
                for j in range(window_start, window_end):
                    target = getattr(model.transformer.h[j], module_name)
                    hooks.append(target.register_forward_hook(hook_fn))

        # With the patching rules defined, run the patched model in inference.
        with torch.no_grad():
            batch_log_probs = log_softmax(
                model(batch_inputs).logits[:, -1], -1).cpu()  # (B, vocab_size)
    finally:
        # Always detach the hooks so a failed run cannot leave the model
        # permanently noised/patched for later callers.
        for hook in hooks:
            hook.remove()
        torch.cuda.empty_cache()

    target_log_probs = batch_log_probs[:, target_token_id]
    effects = (target_log_probs[2:]
               - target_log_probs[baseline_row]).view(2, num_layers)

    return target_log_probs[0], effects
