import torch
from captum.attr import LayerIntegratedGradients
from torch.nn.functional import log_softmax


def _reduce_layer_attributions(layer_attrs):
    """Collapse a list of per-layer attributions into one feature vector.

    Args:
        layer_attrs: sequence of tensors, one per layer. Presumably each has
            shape (1, seq_len, dim) as returned by Captum for a single
            prompt -- TODO confirm for every hooked module type.

    Returns:
        A CPU tensor of shape (dim,): attributions summed over the sequence
        axis and then averaged over layers.
    """
    stacked = torch.stack(layer_attrs, 1)  # (1, n_layer, seq_len, dim)
    # Drop ONLY the batch axis. A bare ``squeeze()`` would also remove the
    # layer or sequence axis whenever it has length 1 (e.g. a one-token
    # prompt), making the reductions below hit the wrong axes and return a
    # scalar instead of a (dim,) vector.
    per_layer = stacked.squeeze(0).detach().cpu()  # (n_layer, seq_len, dim)
    return per_layer.sum(-2).mean(0)  # (dim,)


def compute_ig_features(model, prompt, tokenizer, internal_features, target_token_id):
    """Compute layer-wise Integrated Gradients features for one prompt.

    Attributions of the log-probability of ``target_token_id`` at the last
    position are computed with Captum's ``LayerIntegratedGradients`` for three
    groups of modules in every transformer block: the ``attn.c_attn`` output,
    the ``mlp`` output, and the *input* of ``ln_1``. Each group is
    summed over tokens, averaged over layers, cast to float32 and appended to
    ``internal_features``.

    NOTE(review): the embeddings fed to the model come from
    ``gpt_input_embedding_func`` and already include positional embeddings.
    If ``model`` is a Hugging Face GPT-2, its forward pass presumably adds
    positional embeddings to ``inputs_embeds`` again -- confirm whether the
    double addition is intended.

    Args:
        model: causal LM exposing ``transformer.h[i].{attn.c_attn, mlp, ln_1}``,
            ``config.num_hidden_layers``, ``device`` and accepting
            ``inputs_embeds``.
        prompt: text of the single prompt to attribute.
        tokenizer: callable returning a mapping with ``'input_ids'`` when
            called with ``return_tensors='pt'``.
        internal_features: mapping whose ``'IG_res'``, ``'IG_attn'`` and
            ``'IG_mlp'`` entries support ``append``; mutated in place.
        target_token_id: vocabulary index whose log-probability is attributed.

    Returns:
        The same ``internal_features`` object, with one new tensor appended
        under each of the three keys.
    """
    blocks = [
        model.transformer.h[i]
        for i in range(model.config.num_hidden_layers)
    ]
    attn_modules = [block.attn.c_attn for block in blocks]
    mlp_modules = [block.mlp for block in blocks]
    res_stream_modules = [block.ln_1 for block in blocks]

    def gpt_forward_func(input_emb):
        # Scalar target per batch row: log p(target token | prompt).
        logits = model(inputs_embeds=input_emb).logits[:, -1]  # (B, vocab_size)
        log_prob = log_softmax(logits, -1)[:, target_token_id]  # (B, )
        return log_prob

    lig_attn = LayerIntegratedGradients(gpt_forward_func, attn_modules)
    lig_mlp = LayerIntegratedGradients(gpt_forward_func, mlp_modules)
    lig_res = LayerIntegratedGradients(gpt_forward_func, res_stream_modules)

    input_ids = tokenizer(
        [prompt], return_tensors='pt'
    )['input_ids'].to(model.device)
    input_embs = gpt_input_embedding_func(model, input_ids, model.device)

    # Each call returns a list with one attribution tensor per hooked layer.
    igs_attn = lig_attn.attribute(inputs=input_embs, internal_batch_size=32)
    igs_mlp = lig_mlp.attribute(inputs=input_embs, internal_batch_size=32)
    # The residual stream is read at the INPUT of ln_1, i.e. before the
    # layer norm is applied.
    igs_res = lig_res.attribute(inputs=input_embs, internal_batch_size=32,
                                attribute_to_layer_input=True)

    # The last axis is whatever width the hooked module produces; for the
    # fused ``c_attn`` projection this is presumably wider than the hidden
    # size -- verify against the model.
    internal_features['IG_res'].append(
        _reduce_layer_attributions(igs_res).to(torch.float32))
    internal_features['IG_attn'].append(
        _reduce_layer_attributions(igs_attn).to(torch.float32))
    internal_features['IG_mlp'].append(
        _reduce_layer_attributions(igs_mlp).to(torch.float32))

    return internal_features


def gpt_input_embedding_func(model, input_ids, device):
    """Return token embeddings plus positional embeddings for ``input_ids``.

    Args:
        model: object exposing ``transformer.wte`` (token embedding) and
            ``transformer.wpe`` (position embedding) callables.
        input_ids: integer tensor whose last axis is the sequence axis.
        device: device on which the position indices are created.

    Returns:
        ``wte(input_ids) + wpe(positions)``, where ``positions`` is a single
        row ``[0, 1, ..., seq_len - 1]`` that is broadcast by the addition.
    """
    seq_len = input_ids.size()[-1]

    # One (1, seq_len) row of absolute positions starting at zero.
    positions = torch.arange(
        seq_len, dtype=torch.long, device=device
    ).view(-1, seq_len)

    embedder = model.transformer
    token_part = embedder.wte(input_ids)
    position_part = embedder.wpe(positions)
    return token_part + position_part
