from transformers import GPTNeoXForCausalLM, AutoTokenizer
import os
import torch
from os.path import join
from torch.nn.functional import log_softmax
# Seed PyTorch's global RNG at import time so the Gaussian embedding noise drawn
# with torch.randn in this module is reproducible from run to run.
torch.manual_seed(100)

def load_model(step=0, model_size="1b", deduped=True, torch_dtype=torch.float16):
    """
    Load a Pythia checkpoint and its tokenizer at a given training step.

    Args:
        step: training step; selects the ``step{step}`` revision of the repo.
        model_size: one of (70m, 160m, 410m, 1b, 1.4b, 2.8b, 6.9b, 12b)
        deduped: load ``EleutherAI/pythia-{model_size}-deduped`` if True,
            otherwise ``EleutherAI/pythia-{model_size}``.
        torch_dtype: dtype the model weights are loaded in (float16 by default).

    Returns:
        (model, tokenizer) tuple.
    """
    repo_name = f"pythia-{model_size}-deduped" if deduped else f"pythia-{model_size}"
    repo_id = f"EleutherAI/{repo_name}"
    revision = f"step{step}"

    # Per-checkpoint cache folder named after the repo actually loaded. Previously
    # non-deduped checkpoints were stored under a "-deduped" folder. When
    # TRANSFORMERS_CACHE is unset, pass None so the library's default cache is
    # used instead of raising KeyError.
    cache_root = os.environ.get('TRANSFORMERS_CACHE')
    cache_dir = join(cache_root, repo_name, revision) if cache_root else None

    model = GPTNeoXForCausalLM.from_pretrained(
        repo_id,
        revision=revision,
        cache_dir=cache_dir,
        torch_dtype=torch_dtype,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        repo_id,
        revision=revision,
        cache_dir=cache_dir,
    )

    return model, tokenizer


def untuple(x):
    """Unwrap a module output: the first element of a tuple, anything else as-is."""
    if not isinstance(x, tuple):
        return x
    first, *_rest = x
    return first


def split_heads(tensor, num_heads, attn_head_size):
    """
    Reshape the trailing hidden dimension into (num_heads, attn_head_size) and
    move the head axis in front of the sequence axis.

    Input is 3-D, (batch, seq_length, num_heads * attn_head_size). Output is a
    view of shape (batch, num_heads, seq_length, attn_head_size).
    """
    per_head = tensor.view(*tensor.shape[:-1], num_heads, attn_head_size)
    # (batch, seq, head, feat) -> (batch, head, seq, feat)
    return per_head.permute(0, 2, 1, 3)


def get_key_pos_effects(effects, row):
    """
    Summarise per-position effects at six key token positions.

    Args:
        effects: (n_layer, seq_len) tensor of per-layer, per-position effects.
        row: mapping with integer 'cue entity start idx' and inclusive
            'cue entity end idx' giving the cue entity's token span.

    Returns:
        (n_layer, 6) tensor. Its columns are, in order:
        - the first entity token;
        - the mean of the middle entity tokens, falling back to the first token
          when the entity has no middle tokens;
        - the last entity token;
        - the first token after the entity;
        - the mean of the tokens from that first-after token up to (excluding)
          the final token, falling back to the first-after token when that
          range is empty;
        - the final token.
    """
    start = row['cue entity start idx']
    end = row['cue entity end idx']

    first_subj = effects[:, start]
    if end > start + 1:
        mid_subj = effects[:, start + 1:end].mean(-1)
    else:
        mid_subj = first_subj
    last_subj = effects[:, end]

    first_after = effects[:, end + 1]
    further_span = effects[:, end + 1:-1]
    # If the first token after the entity is already the final token, this span is
    # empty and .mean() would silently produce NaN; mirror the mid-entity fallback.
    further = further_span.mean(-1) if further_span.shape[-1] > 0 else first_after
    last = effects[:, -1]

    return torch.stack(
        [first_subj, mid_subj, last_subj, first_after, further, last], -1
    )  # (n_layer, 6)


def _select_noise_indices(valid_idx, n_intervene):
    """
    Return exactly ``n_intervene`` indices into the noised sub-batch.

    Prefers the indices in ``valid_idx``, in their original order. Pads by
    repeating the last one when there are too few, and falls back to the first
    ``n_intervene`` samples when there are none.
    """
    n_valid = len(valid_idx)
    if n_valid >= n_intervene:
        return valid_idx[:n_intervene]
    if n_valid > 0:
        padding = valid_idx[-1].repeat(n_intervene - n_valid)
        return torch.cat([valid_idx, padding])
    return torch.arange(n_intervene)


def run_get_noise_and_te(model, tokenizer, row,
                         model_hidden_dim, n_intervene=10,
                         batch_size=100, intervene_ent='subj'):
    """
    Sample embedding noise for one prompt and measure its effect on the
    (predicted - true) object logit difference at the last position.

    The prompt is repeated 1 + batch_size times. Row 0 is the clean run. Rows
    1..batch_size each get their own Gaussian noise vector added to the input
    embeddings (output of ``model.gpt_neox.embed_in``) over the intervened span.

    Args:
        model: causal LM exposing ``gpt_neox.embed_in`` and ``.logits`` in its
            output. Presumably a GPTNeoXForCausalLM — confirm with callers.
        tokenizer: called as ``tokenizer(list_of_str, return_tensors='pt')``
            and indexed with ``['input_ids']``.
        row: mapping with keys 'prompt', 'true object first token id',
            'predicted object first token id', 'cue entity start idx' and
            'cue entity end idx' (inclusive).
        model_hidden_dim: size of the embedding vectors.
        n_intervene: number of noise samples returned. The no-valid-noise
            fallback indexes ``arange(n_intervene)``, so it must not exceed
            ``batch_size``.
        batch_size: number of noised copies of the prompt.
        intervene_ent: 'subj' noises the cue entity tokens. Any other value
            noises the tokens after the entity, excluding the final token.

    Returns:
        emb_noises: (n_intervene, 1, model_hidden_dim) selected noise vectors.
            Noises that lowered the logit difference are preferred.
        TEs: (n_intervene,) noised logit difference minus the clean one, for
            the selected noises.
        clean logit difference: 0-dim tensor (row 0).
        success rate: 0-dim tensor. It is the fraction of noised runs where the
            true object's logit exceeds the predicted object's.
    """
    batch_inputs = tokenizer(
        [row['prompt']] * (1 + batch_size), return_tensors='pt'
    )['input_ids'].to(model.device)
    true_obj_id, hall_obj_id = row['true object first token id'], row['predicted object first token id']
    subj_start, subj_end = row['cue entity start idx'], row['cue entity end idx']

    if intervene_ent == 'subj':
        intervene_start, intervene_end = subj_start, subj_end + 1
    else:
        intervene_start, intervene_end = subj_end + 1, batch_inputs.shape[1] - 1

    # One noise vector per noised row, broadcast over every intervened position.
    noise = torch.randn(batch_size, 1, model_hidden_dim)

    def wte_noise_hook(module, inputs, outputs):
        outputs_0 = untuple(outputs)
        # Row 0 stays clean; rows 1.. receive their own noise sample in place.
        outputs_0[1:, intervene_start:intervene_end] += noise.to(outputs_0.device)
        return outputs

    emb_hook = model.gpt_neox.embed_in.register_forward_hook(wte_noise_hook)
    try:
        with torch.no_grad():
            logits = model(batch_inputs).logits[:, -1]  # (B, vocab_size)
            obj_logit_diffs = (logits[:, hall_obj_id] - logits[:, true_obj_id]).cpu()
    finally:
        # Always detach the hook; a leaked hook would keep noising later forward passes.
        emb_hook.remove()
        torch.cuda.empty_cache()

    clean_diff = obj_logit_diffs[0]
    noised_diffs = obj_logit_diffs[1:]
    # Noises that reduced the (predicted - true) logit gap relative to the clean run.
    valid_batch_idx = (noised_diffs < clean_diff).nonzero(as_tuple=True)[0]
    # Noises after which the true object out-scores the predicted one.
    intervene_success_rate = float((noised_diffs < 0).sum()) / batch_size

    valid_batch_idx = _select_noise_indices(valid_batch_idx, n_intervene)
    emb_noises = noise[valid_batch_idx]
    TEs = noised_diffs[valid_batch_idx] - clean_diff

    return emb_noises, TEs, clean_diff, torch.tensor(intervene_success_rate)


def run_with_cpi(model, batch_inputs, batch_noise,
                 intervene_start, intervene_end,
                 true_obj_id, hall_obj_id,
                 patch_layer_idx, patch_seq_idx,
                 num_layers,
                 module_kind='res',
                 window=10):
    """
    Run one forward pass that combines embedding noise with clean-activation
    patching, and return the per-patch indirect effect on the object log-prob gap.

    Batch layout (set up by the hooks below):
      - row 0: clean run. No noise is added, because noise only goes to rows 1:.
      - row 1: noised run with no patching; this is the baseline.
      - row i + 2: noised run. At token position ``patch_seq_idx[i]``, the
        activations are overwritten with row 0's, at the location selected by
        ``patch_layer_idx[i]`` and ``module_kind``.
    ``batch_inputs`` presumably holds 2 + len(patch_layer_idx) copies of the
    same prompt — confirm with callers.

    Args:
        model: model exposing ``gpt_neox.embed_in`` and ``gpt_neox.layers[j]``
            with ``.attention`` and ``.mlp`` submodules.
        batch_inputs: token ids fed directly to ``model``.
        batch_noise: noise added after ``unsqueeze(0)`` to rows 1: over token
            positions [intervene_start, intervene_end). Presumably a single
            (1, hidden) or (hidden,) vector — TODO confirm.
        true_obj_id, hall_obj_id: vocabulary ids compared at the last position.
        patch_layer_idx, patch_seq_idx: parallel sequences of patch locations.
        num_layers: number of transformer layers; caps the patching window.
        module_kind: what is patched.
            'res' patches the output of layer ``patch_layer_idx[i]``.
            'attn' / 'mlp' patch that submodule's output in every layer of a
            ``window``-sized range around ``patch_layer_idx[i]``.
        window: number of layers in the 'attn'/'mlp' patching window. The
            default of 10 covers layers [idx - 5, idx + 5).

    Returns:
        Tensor of shape (B - 2,). Each entry is (hall - true) log-prob
        difference of a patched row minus that of the unpatched noised row 1.

    Raises:
        ValueError: if ``module_kind`` is not 'res', 'attn' or 'mlp'. This is
            checked before any hook is registered.
    """
    # Validate before registering anything so a bad argument cannot leak hooks.
    if module_kind not in ('res', 'attn', 'mlp'):
        raise ValueError('Invalid patching module kind')

    def make_wte_noise_hook(noise, intervene_start, intervene_end):

        def wte_noise_hook(module, inputs, outputs):
            outputs_0 = untuple(outputs)
            outputs_0[1:, intervene_start:intervene_end] += noise.to(outputs_0.device).unsqueeze(0)
            return outputs

        return wte_noise_hook

    def make_patching_hook(patched_batch_id, patched_seq_id):

        def patching_hook(module, inputs, outputs):
            outputs_0 = untuple(outputs)  # (B, seq_len, hidden_dim)
            # Copy the clean run's (row 0) activation at this position into the patched row.
            outputs_0[patched_batch_id, patched_seq_id] = outputs_0[0, patched_seq_id]
            return outputs

        return patching_hook

    hooks = []
    try:
        hooks.append(model.gpt_neox.embed_in.register_forward_hook(
            make_wte_noise_hook(batch_noise, intervene_start, intervene_end)
        ))

        for i, layer_idx in enumerate(patch_layer_idx):
            patch_hook = make_patching_hook(i + 2, patch_seq_idx[i])

            if module_kind == 'res':
                hooks.append(model.gpt_neox.layers[layer_idx].register_forward_hook(patch_hook))
                continue

            # Window of `window` layers around layer_idx, clipped to [0, num_layers).
            window_start = max(0, layer_idx - window // 2)
            window_end = min(layer_idx - (-window // 2), num_layers)
            for j in range(window_start, window_end):
                layer = model.gpt_neox.layers[j]
                target = layer.attention if module_kind == 'attn' else layer.mlp
                hooks.append(target.register_forward_hook(patch_hook))

        # With the patching rules defined, run the patched model in inference.
        with torch.no_grad():
            batch_log_probs = log_softmax(model(batch_inputs).logits[:, -1], -1).cpu()  # (B, vocab_size)
    finally:
        # Remove every hook even on failure so the model is left unmodified.
        for hook in hooks:
            hook.remove()
        torch.cuda.empty_cache()

    log_prob_diffs = batch_log_probs[:, hall_obj_id] - batch_log_probs[:, true_obj_id]

    return log_prob_diffs[2:] - log_prob_diffs[1]
