import argparse
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import pandas as pd
from tqdm.auto import tqdm
from causal_patching import run_get_causal_effects
from os.path import join
import numpy as np

# No training happens in this script, so autograd is switched off globally.
torch.set_grad_enabled(False)
# Seed PyTorch's global RNG so repeated runs are reproducible
# (presumably the interventions in causal_patching are stochastic -- confirm).
torch.manual_seed(100)


def main(args):
    """Compute causal-effect features for each QA example and save them.

    For every row of the input CSV, a prompt is built from the ``question``
    column and the first token of the ``top answer`` column is used as the
    target. ``run_get_causal_effects`` is called ``args.n_intervene`` times
    with ``effect_kind='indirect'`` and ``effect_kind='direct'``; the
    resulting effect tensors are averaged over those repetitions.

    Three tensors are written to ``args.results_dir``, named
    ``{key}_{dataset}_{data_chunk_idx}.pt`` with ``key`` in
    ``log_p_target_token``, ``IE`` and ``DE``.

    Args:
        args: Parsed command-line namespace providing ``model_dir``,
            ``model_name``, ``data_dir``, ``dataset``, ``results_dir``,
            ``data_chunk_idx``, ``n_data_chunk`` and ``n_intervene``.

    Raises:
        ValueError: If ``args.n_intervene`` is smaller than 1, or if the
            selected data chunk contains no rows.
    """
    # Local import: the file only imports ``join`` from ``os.path``.
    import os

    # Validate cheap inputs before loading the model, so a misconfigured job
    # fails immediately instead of at the final ``torch.stack``.
    if args.n_intervene < 1:
        raise ValueError(
            f'--n_intervene must be at least 1, got {args.n_intervene}'
        )

    # Load the per-model dataframe for the requested dataset.
    df = pd.read_csv(join(args.data_dir, f'{args.dataset}_{args.model_name}.csv'))
    if args.dataset != 'truthful_qa':
        # Every dataset except truthful_qa is processed one row chunk at a time.
        row_idx_chunk = np.array_split(
            np.arange(df.shape[0]), args.n_data_chunk
        )[args.data_chunk_idx]
        df = df.iloc[row_idx_chunk].reset_index(drop=True)
    if df.empty:
        raise ValueError(
            f'no rows to process for dataset {args.dataset!r} '
            f'(chunk {args.data_chunk_idx} of {args.n_data_chunk})'
        )

    # Make sure the output location exists before doing the expensive work.
    os.makedirs(args.results_dir, exist_ok=True)

    # Load LM and tokenizer (a CUDA device is required).
    device = torch.device('cuda')
    model_path = join(args.model_dir, args.model_name)
    model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
    model.eval()
    num_layers = model.config.n_layer
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token

    # The prompt is replicated 2 * num_layers + 2 times per example;
    # presumably one row per intervention plus baselines -- confirm against
    # causal_patching.
    batch_size = 2 * num_layers + 2

    results = {
        'log_p_target_token': [],
        'IE': [], 'DE': []
    }

    for _, row in tqdm(df.iterrows(), total=df.shape[0]):
        prompt = f"Question: {row['question']}; Answer: "
        # Only the first token of the answer serves as the target.
        target_token_id = tokenizer(row['top answer'])['input_ids'][0]
        batch_inputs = tokenizer(
            [prompt] * batch_size, return_tensors='pt'
        )['input_ids'].to(device)

        IEs, DEs = [], []
        for j in range(args.n_intervene):
            log_p_target, IEs_ij = run_get_causal_effects(
                model, batch_inputs, target_token_id,
                effect_kind='indirect'
            )
            _, DEs_ij = run_get_causal_effects(
                model, batch_inputs, target_token_id,
                effect_kind='direct'
            )
            torch.cuda.empty_cache()

            # The target log-probability is recorded once per example, from
            # the first repetition only.
            if j == 0:
                results['log_p_target_token'].append(log_p_target)
            # Original author's note: each effects tensor is (2, num_layer)
            # -- TODO confirm against causal_patching.
            IEs.append(IEs_ij)
            DEs.append(DEs_ij)

        # Average the effects over the repeated interventions.
        results['IE'].append(torch.stack(IEs).mean(0))
        results['DE'].append(torch.stack(DEs).mean(0))

    for k, v in results.items():
        torch.save(
            torch.stack(v),
            join(args.results_dir, f'{k}_{args.dataset}_{args.data_chunk_idx}.pt')
        )


if __name__ == '__main__':
    # Command-line entry point: every option keeps its original flag, default
    # and type; only a description and help text (with defaults shown) are added.
    parser = argparse.ArgumentParser(
        description='Compute indirect/direct causal-effect features for a '
                    'GPT-2 model over a QA dataset.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '--results_dir',
        default='/home/scratch/hallucination_mech_evol/hall_detect/features/',
        type=str,
        help='directory where the resulting .pt tensors are written',
    )
    parser.add_argument(
        '--data_dir',
        default='/home/scratch/hallucination_mech_evol/data/',
        type=str,
        help='directory containing the <dataset>_<model_name>.csv input file',
    )
    parser.add_argument(
        '--dataset',
        default='trivial_qa',
        type=str,
        help='dataset name used in the input CSV and output file names; '
             'truthful_qa is processed whole instead of in chunks',
    )
    parser.add_argument(
        '--model_dir',
        default='/home/scratch/cma_hallucination/models/',
        type=str,
        help='directory that contains the pretrained model folder',
    )
    parser.add_argument(
        '--model_name',
        default='gpt2-xl',
        type=str,
        help='model folder under --model_dir; also part of the input CSV name',
    )
    parser.add_argument(
        '--data_chunk_idx',
        default=0,
        type=int,
        help='index of the row chunk to process; also part of the output '
             'file names',
    )
    parser.add_argument(
        '--n_data_chunk',
        default=10,
        type=int,
        help='number of row chunks the dataset is split into',
    )
    parser.add_argument(
        '--n_intervene',
        default=10,
        type=int,
        help='number of repeated causal-effect runs averaged per example',
    )
    args = parser.parse_args()
    main(args)

