import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaForMLM
from datasets import load_dataset
from tqdm import tqdm
import torch
from peft import PeftModel
import json 
import nltk
from nltk.translate.bleu_score import sentence_bleu

def evaluate_completion(
    completion: str,
    target: str,
    case_sensitive: bool = False,
) -> bool:
    """Return whether the completion contains any of the accepted targets.

    ``target`` may hold several accepted answers separated by commas. Each
    alternative is stripped of surrounding whitespace, and the completion is
    judged correct when at least one alternative occurs anywhere inside it
    as a substring (so "ann" also matches "joanna"). Matching ignores case
    unless ``case_sensitive`` is true.

    e.g. completion " World is vast" with target "world" is correct,
    and completion "It was Bob." with target "Alice, Bob" is correct.

    Args:
        completion: Text generated by the model.
        target: One accepted answer, or several separated by commas.
        case_sensitive: Compare with case preserved when true.

    Returns:
        True if a non-empty alternative is found in the completion, False
        otherwise (including when ``target`` has no non-empty alternative).
    """
    candidates = [part.strip() for part in target.split(',')]
    # An empty alternative (empty target, trailing or doubled comma) must be
    # discarded: "" is a substring of every string and would otherwise mark
    # every completion as correct.
    candidates = [candidate for candidate in candidates if candidate]

    haystack = completion.strip()
    if not case_sensitive:
        haystack = haystack.lower()
        candidates = [candidate.lower() for candidate in candidates]

    return any(candidate in haystack for candidate in candidates)

def compute_bleu(run_results):
    """Return the mean sentence-level BLEU of the generated answers.

    Args:
        run_results: Dict with ``'completions'`` (a list of dicts, each
            holding the generated text under ``'gen_answer'``) and
            ``'targets'`` (the reference for each completion, same order).

    Returns:
        The sum of per-pair BLEU scores divided by the number of
        completions, or 0.0 when there are no completions.
    """
    completions = run_results['completions']
    targets = run_results['targets']
    if not completions:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    # NOTE(review): reference and hypothesis are handed to sentence_bleu
    # as-is. If they are plain strings, NLTK will presumably treat each
    # character as a token (character-level BLEU) -- confirm that this is
    # intended rather than word-level BLEU on split tokens.
    total = sum(
        sentence_bleu([target], completion['gen_answer'])
        for completion, target in zip(completions, targets)
    )
    return total / len(completions)

def compute_metric(run_results):
    """Compute the accuracy of the generated answers.

    Each generated answer is judged by ``evaluate_completion`` against the
    target at the same position.

    Args:
        run_results: Dict with ``'completions'`` (a list of dicts, each
            holding the generated text under ``'gen_answer'``) and
            ``'targets'`` (the expected answer for each completion).

    Returns:
        Number of correct completions divided by the number of completions,
        or 0.0 when there are no completions.
    """
    completions = run_results['completions']
    targets = run_results['targets']
    if not completions:
        # An empty evaluation set has no meaningful accuracy; report 0.0
        # instead of raising ZeroDivisionError.
        return 0.0

    n_correct = sum(
        1
        for completion, target in zip(completions, targets)
        if evaluate_completion(completion['gen_answer'], target)
    )
    return n_correct / len(completions)

def batch_split(prompts, batch_num):
    """Split the evaluation prompts into consecutive mini-batches.

    Every mini-batch holds ``batch_num`` items except possibly the last one,
    which holds the remainder. No empty mini-batch is ever produced.
    """
    items = list(prompts)
    if batch_num < 1:
        # A non-positive size is never reached while filling a batch, so
        # everything ends up in one single mini-batch (if there is anything).
        return [items] if items else []
    starts = range(0, len(items), batch_num)
    return [items[start:start + batch_num] for start in starts]

def prepare_input(tokenizer, prompts, device='cuda'):
    """Tokenize a batch of prompts into keyword arguments for the model.

    Args:
        tokenizer: Object exposing ``batch_encode_plus``; it is called with
            ``return_tensors="pt"`` and ``padding=True``.
        prompts: The prompt strings of one mini-batch.
        device: Device the tensors are moved to. Defaults to ``'cuda'``,
            the value that used to be hard-coded here.

    Returns:
        A dict containing only the ``input_ids`` and ``attention_mask``
        entries of the tokenizer output, with tensor values moved to
        ``device``.
    """
    encoded = tokenizer.batch_encode_plus(prompts, return_tensors="pt", padding=True)

    # Keep only the two entries that are forwarded to generation; anything
    # else the tokenizer returns is dropped.
    wanted = ("input_ids", "attention_mask")
    model_inputs = {}
    for key in encoded:
        if key not in wanted:
            continue
        value = encoded[key]
        if torch.is_tensor(value):
            value = value.to(device)
        model_inputs[key] = value

    return model_inputs

    # Checkpoint ids used in main():
    #   [1]    evaluates llama2 without SFT (no fine-tuning).
    #   [1840] is the iteration reached when training on 40% of the data.

def _eval_data_file(task, task_type, train_ratio):
    """Return the evaluation file of ``task``, relative to the data directory."""
    is_reverse = 'reverse' in task
    if task_type == 'ood' and not is_reverse:
        return f"ood_{train_ratio}/{task}.json"
    # 'permuate' and 'forge' files carry the task type as a prefix; every
    # other task type uses the bare task name.
    prefix = f"{task_type}_" if task_type in ('permuate', 'forge') else ""
    category = 'static' if is_reverse else 'dynamic'
    return f"{prefix}{task}_dataset_{category}_category_1.json"


def main():
    """Evaluate each configured checkpoint on the selected test sets.

    For every checkpoint id this loads the tokenizer and the base model,
    applies the LoRA adapter when ``lora_dir`` is set, greedily generates up
    to 40 new tokens per prompt in batches of 32, writes the generations and
    targets to ``<base_name>/<ckpt>-<task>.json`` and prints the exact-match
    accuracy returned by ``compute_metric``.
    """
    batch_size = 32

    for ckpt in [1840]:
        print(f'----------------------{ckpt}---------------------')
        # lora_dir must be adjusted to the checkpoint under evaluation.
        # Leaving it as None (i.e. removing the assignment below) skips the
        # adapter and evaluates the base model only.
        lora_dir = None
        train_ratio = 'all'   # either a concrete ratio or 'all'
        lora_dir = "/home/hadoop-aipnlp/nazarite/reverse_curse/mitigating-reversal-curse-main/final_output/ood_parent_to_child_final/ood_{}/low_bico/checkpoint-{}".format(train_ratio, ckpt)
        model_name_or_path = "/home/hadoop-aipnlp/nazarite/llama2/llama-2-7b"
        data_path = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/ood_parent_child_final/"
        # Directory the per-task result files are written to.
        base_name = "/home/hadoop-aipnlp/nazarite/reverse_curse/mitigating-reversal-curse-main/final_output/ood_parent_to_child_final/ood_{}/low_bico".format(train_ratio)

        # Left padding keeps every prompt flush against the generated
        # tokens, which batched generation relies on.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, padding_side="left",
        )
        tokenizer.pad_token_id = 0
        tokenizer.bos_token_id = 1
        tokenizer.eos_token_id = 2

        model = LlamaForCausalLM.from_pretrained(model_name_or_path, use_safetensors=False, device_map={"":0}, torch_dtype=torch.float16)

        if lora_dir is not None:
            model = PeftModel.from_pretrained(model, lora_dir)
            model.merge_adapter()

        model.eval()

        # All available test sets:
        # ['iid_ar_positive_positive_positive_test_high_dataset', 'iid_ar_positive_positive_negative_test_high_dataset', 'iid_ar_positive_negative_positive_test_high_dataset', 'iid_ar_positive_negative_negative_test_high_dataset',
        # 'iid_ar_negative_positive_positive_test_high_dataset', 'iid_ar_negative_positive_negative_test_high_dataset', 'iid_ar_negative_negative_positive_test_high_dataset', 'iid_ar_negative_negative_negative_test_high_dataset',
        # 'iid_ar_positive_positive_positive_test_low_dataset', 'iid_ar_positive_positive_negative_test_low_dataset', 'iid_ar_positive_negative_positive_test_low_dataset', 'iid_ar_positive_negative_negative_test_low_dataset',
        # 'iid_ar_negative_positive_positive_test_low_dataset', 'iid_ar_negative_positive_negative_test_low_dataset', 'iid_ar_negative_negative_positive_test_low_dataset', 'iid_ar_negative_negative_negative_test_low_dataset',
        # 'ood_positive_positive_positive_test_dataset', 'ood_positive_positive_negative_test_dataset', 'ood_positive_negative_positive_test_dataset', 'ood_positive_negative_negative_test_dataset',
        # 'ood_negative_positive_positive_test_dataset', 'ood_negative_positive_negative_test_dataset', 'ood_negative_negative_positive_test_dataset', 'ood_negative_negative_negative_test_dataset',
        # 'ood_ar_positive_positive_positive_test_dataset', 'ood_ar_positive_positive_negative_test_dataset', 'ood_ar_positive_negative_positive_test_dataset', 'ood_ar_positive_negative_negative_test_dataset',
        # 'ood_ar_negative_positive_positive_test_dataset', 'ood_ar_negative_positive_negative_test_dataset', 'ood_ar_negative_negative_positive_test_dataset', 'ood_ar_negative_negative_negative_test_dataset']

        # Only the forward question-answering sets of the OOD data are run.
        for task in ['ood_positive_positive_positive_test_dataset', 'ood_positive_positive_negative_test_dataset', 'ood_positive_negative_positive_test_dataset', 'ood_positive_negative_negative_test_dataset']:
            print(f'----------------------{task}---------------------')
            task_type = 'ood'  # original note: here is 'standard'
            output_name = f"{ckpt}-{task}.json"

            eval_file = data_path + _eval_data_file(task, task_type, train_ratio)
            dataset = load_dataset("json", data_files={"eval": eval_file})

            val_data = dataset['eval']
            prompts = [d['prompt'] for d in val_data]
            completions = [d['completion'] for d in val_data]

            results = []
            for batch_input in tqdm(batch_split(prompts, batch_size)):
                # Tokenize the batch and move it onto the GPU.
                encode_inputs = prepare_input(tokenizer, batch_input)
                prompt_len = encode_inputs['input_ids'].shape[1]
                with torch.no_grad():
                    output = model.generate(**encode_inputs, max_new_tokens=40, pad_token_id=tokenizer.pad_token_id, do_sample=False, num_beams=1,
                                        eos_token_id=tokenizer.eos_token_id)
                # The returned sequences start with the (left-padded) prompt
                # tokens. Cut them off by token position before decoding
                # instead of slicing the decoded text by len(prompt), which
                # is only right if decoding reproduces the prompt character
                # for character. The answer may lose a leading space
                # compared with text slicing; evaluate_completion strips it.
                answers = tokenizer.batch_decode(output[:, prompt_len:].cpu(), skip_special_tokens=True)
                for p, answer in zip(batch_input, answers):
                    results.append({'question': p, 'gen_answer': answer})
            run_results = {'completions':results, 'targets':completions}

            # Persist generations next to their targets for later inspection.
            os.makedirs(base_name, exist_ok=True)
            with open(os.path.join(base_name, output_name), 'w', encoding='utf-8') as f:
                json.dump(run_results, f, ensure_ascii=False, indent=2)

            # Every task predicts a person name, so exact match is sufficient.
            print('Using EM')
            metrics = compute_metric(run_results)
            print(metrics)

if __name__ == '__main__':
    # Checklist before running:
    #   - point lora_dir at the directory that holds the model checkpoint
    #   - update the output directory (output_dir)
    #   - update the task and dataset names; the 'forge' setting needs
    #     adjusting as well
    #   - update the GPU index and the ckpt location
    main()
    