import argparse
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import pandas as pd
from tqdm.auto import tqdm
from cpi_run import run_with_cpi, run_get_noise_and_te, get_key_pos_effects
from os.path import join
import numpy as np

# Import-time configuration for this script.
# Seed torch's global RNG so repeated runs are reproducible (presumably this
# governs the noise sampled inside cpi_run -- confirm there).
torch.manual_seed(100)
# Nothing in this script calls backward(), so autograd tracking is switched
# off globally.
torch.set_grad_enabled(False)
# Module families whose indirect effects are measured; main() stores each one
# under the results key f'IE_{kind}'.
module_kinds = 'res attn mlp'.split()


def main(args):
    """Run CPI patching over one chunk of the hallucination dataset and save results.

    Reads ``pararel_hallucinations_<model_name>.csv`` from ``args.data_dir``,
    keeps chunk ``args.data_chunk_idx`` of ``args.n_data_chunk`` near-equal row
    chunks (``np.array_split``), loads ``args.model_name`` from
    ``args.model_dir``, and for every prompt in the chunk records:

    * ``TE``: mean of the total-effect tensor from ``run_get_noise_and_te``;
    * ``y_0`` and ``intervene_success_rate``: as returned by that same call;
    * ``IE_<kind>`` for each entry of ``module_kinds``: ``run_with_cpi``
      effects reduced by ``get_key_pos_effects`` and averaged over the first
      ``args.n_intervene`` noise samples.

    Every result list is stacked with ``torch.stack`` and written to
    ``<results_dir>/<key>_<data_chunk_idx>.pt``.

    Args:
        args: ``argparse.Namespace`` carrying the attributes defined by this
            script's command-line parser.

    Raises:
        ValueError: if ``data_chunk_idx`` is not in ``[0, n_data_chunk)``,
            ``n_intervene`` is smaller than 1, or the selected chunk is empty
            (there would be nothing to stack and save).
    """
    # Local import keeps this entry point self-contained; only used for makedirs.
    import os

    # Validate arguments before any expensive work. Without this, a negative
    # chunk index silently selects a chunk counted from the end, and
    # n_intervene < 1 only fails in torch.stack([]) after the model load.
    if not 0 <= args.data_chunk_idx < args.n_data_chunk:
        raise ValueError(
            f'data_chunk_idx must be in [0, n_data_chunk={args.n_data_chunk}), '
            f'got {args.data_chunk_idx}'
        )
    if args.n_intervene < 1:
        raise ValueError(f'n_intervene must be >= 1, got {args.n_intervene}')

    # torch.save does not create parent directories. Create it now so a missing
    # directory fails fast instead of after the whole chunk has been computed.
    os.makedirs(args.results_dir, exist_ok=True)

    # Load this job's slice of the hallucination dataframe for args.model_name.
    hall_df = pd.read_csv(join(args.data_dir, f'pararel_hallucinations_{args.model_name}.csv'))
    hall_row_idx_chunk = np.array_split(np.arange(hall_df.shape[0]), args.n_data_chunk)[args.data_chunk_idx]
    hall_df = hall_df.iloc[hall_row_idx_chunk].reset_index(drop=True)
    if hall_df.empty:
        # torch.stack at the end cannot stack empty result lists.
        raise ValueError(
            f'chunk {args.data_chunk_idx} of {args.n_data_chunk} contains no rows'
        )

    # Load the LM and tokenizer.
    device = torch.device('cuda')
    model_path = join(args.model_dir, args.model_name)
    model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
    model.eval()
    num_layers = model.config.n_layer
    model_hidden_dim = model.config.n_embd
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token

    # Derive the IE keys from module_kinds so the two can never disagree.
    results = {
        'TE': [],
        'y_0': [],
        'intervene_success_rate': [],
        **{f'IE_{kind}': [] for kind in module_kinds},
    }

    for _, row in tqdm(hall_df.iterrows(), total=hall_df.shape[0]):
        # Tokenize the prompt once and replicate the ids. This yields the same
        # tensor as tokenizing batch_size identical copies of the string.
        prompt_ids = tokenizer(row['prompt'], return_tensors='pt')['input_ids']  # (1, seq_len)
        seq_len = prompt_ids.shape[1]
        true_obj_id, hall_obj_id = row['true object first token id'], row['predicted object first token id']
        subj_start, subj_end = row['cue entity start idx'], row['cue entity end idx']
        # One batch row per (layer, position) patch site, plus 2 extra rows
        # (presumably baseline runs inside run_with_cpi -- confirm there).
        batch_size = seq_len * num_layers + 2
        batch_inputs = prompt_ids.to(device).repeat(batch_size, 1)

        # NOTE(review): n_sample_noise is passed as this helper's batch_size;
        # noise_i must hold at least n_intervene samples for the loop below.
        noise_i, TE, y_0, intervene_success_rate = run_get_noise_and_te(
            model, tokenizer, row, model_hidden_dim, batch_size=args.n_sample_noise
        )

        results['y_0'].append(y_0)
        results['TE'].append(TE.mean())  # scalar per prompt
        results['intervene_success_rate'].append(intervene_success_rate)

        # Flattened patch sites in layer-major order (k = layer * seq_len + pos),
        # consistent with the IEs_ij.view(num_layers, seq_len) reshape below.
        patch_layer_idx = torch.arange(num_layers).repeat_interleave(seq_len)
        patch_seq_idx = torch.arange(seq_len).repeat(num_layers)

        for module_kind in module_kinds:
            IEs = []
            for j in range(args.n_intervene):
                IEs_ij = run_with_cpi(
                    model, batch_inputs, noise_i[j],
                    # +1: presumably turns the CSV's inclusive end index into
                    # an exclusive bound -- confirm in cpi_run.
                    subj_start, subj_end + 1,
                    true_obj_id, hall_obj_id,
                    patch_layer_idx, patch_seq_idx,
                    num_layers,
                    module_kind=module_kind
                )
                torch.cuda.empty_cache()
                key_pos_IEs = get_key_pos_effects(IEs_ij.view(num_layers, seq_len), row)
                IEs.append(key_pos_IEs)  # (num_layers, 6) per the original author's note

            # Average the indirect effects over the n_intervene noise samples.
            results[f'IE_{module_kind}'].append(torch.stack(IEs).mean(0))

    for k, v in results.items():
        torch.save(
            torch.stack(v), join(args.results_dir, f'{k}_{args.data_chunk_idx}.pt')
        )


if __name__ == '__main__':
    # Command-line entry point. Every option has a default, so the script runs
    # with no flags; --data_chunk_idx / --n_data_chunk let several jobs split
    # the dataset, and each job's outputs carry its chunk index.
    cli = argparse.ArgumentParser()
    cli_options = (
        # (flag, type, default)
        ('--results_dir', str, '/home/scratch/hallucination_mech_evol/hall_mech/cpi/results/'),
        ('--data_dir', str, '/home/scratch/hallucination_mech_evol/data/'),
        ('--model_dir', str, '/home/scratch/cma_hallucination/models/'),
        ('--model_name', str, 'gpt2-xl'),
        ('--n_intervene', int, 10),
        ('--data_chunk_idx', int, 0),
        ('--n_data_chunk', int, 10),
        ('--n_sample_noise', int, 100),
    )
    for flag, arg_type, default_value in cli_options:
        cli.add_argument(flag, default=default_value, type=arg_type)
    main(cli.parse_args())

