import argparse
from os.path import join
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import pandas as pd
from tqdm.auto import tqdm
from act_grad_features import compute_act_grad_features
from ig_features import compute_ig_features


def main(args):
    """Extract internal model features for every dataset row and save them.

    Loads a GPT-2 language model and tokenizer from
    ``join(args.model_dir, args.model_name)``, reads
    ``{args.dataset}_{args.model_name}.csv`` from ``args.data_dir``, calls
    ``compute_act_grad_features`` once per row, and finally writes one
    ``{feature}-{args.dataset}-{args.model_name}.pt`` file per collected
    feature into ``args.results_dir``.

    Args:
        args: Parsed command-line namespace providing ``model_dir``,
            ``model_name``, ``data_dir``, ``dataset``, ``results_dir`` and
            ``n_intervene``.
    """
    import os  # local import: only needed to prepare the output directory

    # Create the output directory up front so a missing path fails fast
    # instead of after the whole dataset has been processed.
    if args.results_dir:
        os.makedirs(args.results_dir, exist_ok=True)

    # load LM and tokenizer (fall back to CPU when no GPU is visible)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_path = join(args.model_dir, args.model_name)
    model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
    model.eval()
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    # load the dataframe for the selected dataset / model combination
    df = pd.read_csv(join(args.data_dir, f'{args.dataset}_{args.model_name}.csv'))

    # One list per feature name; the per-row helper below returns the
    # (updated) dictionary after each call.
    internal_features = {
        'uncertainty': [],
        'IE_res': [], 'IE_attn': [], 'IE_mlp': [],
        'DE_res': [], 'DE_attn': [], 'DE_mlp': [],
        'act_res': [], 'act_attn': [], 'act_mlp': [],
        'grad_res': [], 'grad_attn': [], 'grad_mlp': [],
        'act_x_grad_res': [], 'act_x_grad_attn': [], 'act_x_grad_mlp': [],
        # 'IG_res': [], 'IG_attn': [], 'IG_mlp': [],
        'last_hidden_state_0_res': [], 'last_hidden_state_0_attn': [], 'last_hidden_state_0_mlp': [],
        'last_hidden_state_1_res': [], 'last_hidden_state_1_attn': [], 'last_hidden_state_1_mlp': [],
    }

    # The same prompt is repeated 1 + n_intervene times; this does not depend
    # on the row, so compute it once.
    batch_size = 1 + args.n_intervene
    use_pararel_columns = args.dataset == 'pararel_questions'

    for _, row in tqdm(df.iterrows(), total=len(df)):
        if use_pararel_columns:
            prompt = row['prompt']
            target_token_id = row['predicted object first token id']
        else:
            prompt = f"Question: {row['question']}; Answer: "
            # target is the first token id of the tokenized top answer
            target_token_id = tokenizer(row['top answer'])['input_ids'][0]
        # All sequences in the batch are the same string, so they tokenize
        # to equal lengths and no padding argument is needed.
        batch_inputs = tokenizer(
            [prompt] * batch_size,
            truncation=True,
            max_length=32,
            return_tensors='pt'
        )

        internal_features = compute_act_grad_features(model, batch_inputs, internal_features, target_token_id)
        # internal_features = compute_ig_features(model, prompt, tokenizer, internal_features, target_token_id)
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    for feature_name, values in internal_features.items():
        # torch.stack raises on an empty sequence; skip features that were
        # never populated so the remaining ones are still written.
        if not values:
            print(f'Skipping {feature_name}: no values were collected.')
            continue
        torch.save(
            torch.stack(values),
            join(args.results_dir, f'{feature_name}-{args.dataset}-{args.model_name}.pt')
        )


if __name__ == '__main__':
    # Command-line options as (flag, default value, value type) triples.
    cli_options = (
        ('--results_dir', '/home/scratch/hallucination_mech_evol/hall_detect/features/', str),
        ('--data_dir', '/home/scratch/hallucination_mech_evol/data/', str),
        ('--dataset', 'natural_qa', str),
        ('--model_dir', '/home/scratch/cma_hallucination/models/', str),
        ('--model_name', 'gpt2-xl', str),
        ('--n_intervene', 10, int),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, default_value, value_type in cli_options:
        arg_parser.add_argument(flag, default=default_value, type=value_type)
    # Parse the command line and run the feature extraction.
    args = arg_parser.parse_args()
    main(args)



