import os
DEVICE = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = DEVICE
import json
import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, GenerationConfig
from transformers import TopPLogitsWarper, LogitsProcessorList
from tqdm import tqdm
from data_utils import Prompt, MDDataset
import argparse
import regex
import string
import collections
from rouge_score import rouge_scorer, scoring
from xopen import xopen
import logging

# Send log records to the console as "time : level : message".
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
# Also set the root level explicitly: basicConfig() does nothing when the root logger
# already has handlers, so this guarantees INFO records are not filtered out.
logging.getLogger().setLevel(logging.INFO)
# Module logger used throughout this script.
logger = logging.getLogger(__name__)

# Checkpoints selectable with ``--model <index>``; the index is the position in this list.
MODEL_PATH = dict(enumerate([
    "meta-llama/Llama-2-7b-hf",       # 0
    "meta-llama/Llama-2-7b-chat-hf",  # 1
    "../RL/model/llama2_13b",         # 2
    "../models/llama2_7b",            # 3
    "../models/llama2_13b",           # 4
    "../models/vicuna_7b_v1.5",       # 5
    "../models/vicuna_13b_v1.5",      # 6
    "../RL/model/llama2_7b",          # 7
]))
# Prompt templates; ``{question}`` / ``{search_results}`` are filled by the Prompt builder.
CLOSED_BOOK_PROMPT = '''Question: {question}\nAnswer:'''
# Multi-document prompt used with --use_random (documents shuffled).
# The former chained assignment `MD_RANDOM_PROMPT = prompt = ...` also created a stray
# module-level global named `prompt`; it is removed.
MD_RANDOM_PROMPT = '''Write a high-quality answer for the given question using only the provided search results (some of which might be irrelevant). The search results are ordered randomly.

{search_results}

Question: {question}
Answer:'''
# Multi-document prompt with documents in their given order.
# NOTE: the first line keeps its original trailing space; do not strip it, or the
# rendered prompts will differ from earlier runs.
MD_PROMPT = '''Write a high-quality answer for the given question using only the provided search results (some of which might be irrelevant). 

{search_results}

Question: {question}
Answer:'''

# Hugging Face model id (a RoBERTa SQuAD reader, judging by its name). Not referenced in
# this file; presumably used by evaluation code elsewhere -- confirm.
QA_MODEL = "gaotianyu1350/roberta-large-squad"


def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's RNGs from ``args.seed``.

    CUDA generators on every visible device are seeded too when ``args.n_gpu > 0``.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def get_argument():
    """Parse the command line and derive runtime settings.

    Besides the CLI options, the returned namespace carries:

    * ``model`` -- the checkpoint path looked up in ``MODEL_PATH`` from the ``--model`` index;
    * ``closed_book`` -- ``True`` when ``--doc_num`` is 0;
    * ``device_num`` -- the ``DEVICE`` string exported to ``CUDA_VISIBLE_DEVICES``;
    * ``device`` / ``n_gpu`` -- the torch device and the number of visible GPUs.

    Raises:
        ValueError: if ``--eval_during_infer`` is combined with ``--citation``, or if
            ``--special_location`` is set for a ``--type`` other than ``nq_doc_num``.
    """
    parser = argparse.ArgumentParser()
    # environmental
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    # `choices` turns an unknown index into a usage error instead of a KeyError below.
    parser.add_argument("--model", type=int, default=0, choices=sorted(MODEL_PATH),
                        help="See MODEL_PATH for available models")
    # file
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--type", type=str, required=True)
    # prompt setting
    parser.add_argument("--shot", type=int, default=0)
    parser.add_argument("--doc_num", type=int, default=0)
    parser.add_argument("--prompt", type=str, help="closed, md or from file", required=True)
    parser.add_argument("--use_random", action="store_true", help="The documents are ordered randomly.")
    parser.add_argument("--special_location", type=int, default=-1)
    # generation config
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--do_sample", action="store_true")
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--max_new_tokens", type=int, default=300)
    parser.add_argument("--max_length", type=int, default=4000)
    # evaluation setting
    parser.add_argument("--do_inference", action="store_true")
    parser.add_argument("--do_evaluate", action="store_true")
    parser.add_argument("--eval_during_infer", action="store_true")
    parser.add_argument("--metrics", type=str, default=None,
                        help="accuracy, extractive or asqa (required by --eval_during_infer)")
    parser.add_argument("--citation", action="store_true")
    # structure
    parser.add_argument("--mode", type=str, default="standard",
                        help="standard, nbce, topk, probsub, maxprob or pcw")
    parser.add_argument("--doc_cluster", type=int, default=-1)
    parser.add_argument("--closed_p", type=str, default=None, help="for nbce,pcw")
    # debug
    parser.add_argument("--debug", action="store_true")

    args = parser.parse_args()
    args.model = MODEL_PATH[args.model]
    args.closed_book = args.doc_num == 0
    args.device_num = DEVICE
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    if args.eval_during_infer and args.citation:
        raise ValueError("Citation is post-hoc, which can't be evaluated during inference ! ")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.special_location > 0 and args.type != "nq_doc_num":
        raise ValueError("--special_location is only supported with --type nq_doc_num")
    logger.info(args)
    return args


def normalize_answer(s: str) -> str:
    """Normalize an answer string as in the official SQuAD evaluation script.

    Steps, in order: lower-case, delete ASCII punctuation, replace the articles
    "a"/"an"/"the" with a space, and collapse runs of whitespace to single spaces.

    See https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
    """
    text = s.lower()
    # Deleting every character of string.punctuation in one translate() pass.
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = regex.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def get_tokens(s):
    """Split the normalized form of ``s`` on whitespace; falsy input yields ``[]``."""
    if s:
        return normalize_answer(s).split()
    return []


def compute_exact(a_gold, a_pred):
    """Return 1 when gold and predicted answers are identical after normalization, else 0."""
    gold = normalize_answer(a_gold)
    pred = normalize_answer(a_pred)
    return 1 if gold == pred else 0


def compute_f1(a_gold, a_pred):
    """Token-level F1 between a gold and a predicted answer (SQuAD style).

    If either side has no tokens, returns 1 when both are empty and 0 otherwise.
    Returns 0 when the two share no tokens.
    """
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    if not gold_toks or not pred_toks:
        # No-answer case: agreement means both are empty.
        return int(gold_toks == pred_toks)
    # Multiset intersection: each shared token counts min(gold_count, pred_count) times.
    overlap = sum((collections.Counter(gold_toks) & collections.Counter(pred_toks)).values())
    if overlap == 0:
        return 0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)


def exact_presence(short_answers, context):
    """Check whether any short answer occurs in the context after normalization.

    Args:
        short_answers: candidate answer strings.
        context: text (e.g. a model prediction) to search in.

    Returns:
        True if some normalized short answer is a substring of the normalized context.
    """
    normalized_answers = [normalize_answer(answer) for answer in short_answers]
    haystack = normalize_answer(context)
    return any(answer in haystack for answer in normalized_answers)


def evaluate_example(answers, prediction, metrics, qa_pairs=None):
    """Score a single prediction.

    Args:
        answers: gold answer strings.
        prediction: the model's answer string.
        metrics: which score to compute:

            * ``"accuracy"`` -- 1.0 if any normalized gold answer is a substring of the
              normalized prediction, else 0.0;
            * ``"extractive"`` -- ``(best_em, best_f1)`` over ``answers``
              (``(0.0, 0.0)`` when ``answers`` is empty);
            * ``"asqa"`` -- ``(str_em, str_hit)``: the fraction of ``qa_pairs`` whose
              ``'short_answers'`` appear in the prediction, and 1 only if all of them do.
        qa_pairs: required for ``"asqa"``; iterable of dicts with a ``'short_answers'`` list.

    Raises:
        ValueError: for an unknown ``metrics`` value, or ``"asqa"`` without ``qa_pairs``.
    """
    if metrics == "accuracy":
        normalized_prediction = normalize_answer(prediction)
        # normalize_answer already lower-cases, so a plain substring test suffices.
        # NOTE(review): a gold answer that normalizes to "" (e.g. "the") matches any prediction.
        for ground_truth in answers:
            if normalize_answer(ground_truth) in normalized_prediction:
                return 1.0
        return 0.0

    if metrics == "extractive":
        best_em, best_f1 = 0.0, 0.0
        for target in answers:
            # max() keeps the current best on ties, like a strict ">" update.
            best_em = max(best_em, compute_exact(target, prediction))
            best_f1 = max(best_f1, compute_f1(target, prediction))
        return (best_em, best_f1)

    if metrics == "asqa":
        # Covers str_em / str_hit only; citation, QA, MAUVE (ASQA), QAMPARI citations and
        # ELI5 claims_nli are not computed here.
        if qa_pairs is None:
            raise ValueError("The 'asqa' metric requires qa_pairs.")
        loc_acc = [exact_presence(qa_pair['short_answers'], prediction) for qa_pair in qa_pairs]
        str_em = np.mean(loc_acc)
        str_hit = int(str_em == 1)
        return (str_em, str_hit)

    # Previously an unknown metric silently returned None.
    raise ValueError(f"Unknown metric: {metrics!r}")


def standard_generate(example, model, args, P, tokenizer):
    """Answer one example with ``model.generate`` on a single (optionally few-shot) prompt.

    Args:
        example: dataset item; ``question`` and ``answers`` are copied into the result.
        model: causal LM used through ``generate``.
        args: parsed CLI namespace (reads ``shot``, ``prompt``, ``type``, ``use_random``,
            ``special_location``, ``temperature``, ``do_sample``, ``top_p``,
            ``max_new_tokens``, ``max_length`` and ``device``).
        P: prompt builder whose ``apply`` renders demonstrations and the eval example.
        tokenizer: tokenizer matching ``model``.

    Returns:
        dict with ``prompt``, ``question``, ``answers``, ``token_len_prompt`` (prompt length
        in tokens, an int) and ``prediction`` (first line of the continuation, stripped).
    """
    demos = ""
    if args.shot > 0:
        # NOTE(review): the demo file is re-read for every example; consider caching it.
        train_examples = MDDataset("prompts/" + args.prompt, args.type + '_demo')
        demos = P.apply(train_examples[:args.shot], args.use_random, with_answer=True) + P.sep
    eval_prompt = P.apply([example], use_random=args.use_random, with_answer=False, special=args.special_location)
    input_prompt = demos + eval_prompt

    encoding = tokenizer(input_prompt, return_tensors="pt")
    input_ids = encoding['input_ids'].to(args.device)
    # Pass the mask explicitly so generate() does not have to infer it from the pad id.
    attention_mask = encoding.get('attention_mask')
    if attention_mask is not None:
        attention_mask = attention_mask.to(args.device)
    prompt_len = len(input_ids[0])

    # NOTE(review): if the prompt already exceeds max_length this is <= 0, which
    # generate() may reject -- confirm the desired handling.
    generation_config = GenerationConfig(
        temperature=args.temperature,
        do_sample=args.do_sample,
        top_p=args.top_p,
        max_new_tokens=min(args.max_new_tokens, args.max_length - prompt_len),
        return_dict_in_generate=False,
    )

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
    )
    # Drop special tokens (e.g. a trailing "</s>"); otherwise "Paris</s>" normalizes to
    # "pariss" and breaks exact-match/F1.
    decoded_output = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    model_answer = decoded_output.split('\n')[0].strip()
    return {
        'prompt': input_prompt,
        'question': example.question,
        'answers': example.answers,
        # Was accidentally a 1-tuple because of a trailing comma.
        'token_len_prompt': prompt_len,
        'prediction': model_answer,
    }


def nbce_generate(example, model, args, P, CP, tokenizer):
    """Decode one example token by token over a batch of prompt variants (NBCE-style).

    Batch layout:

    * row 0 is the closed-book prompt from ``CP``;
    * row 1 is the prompt with all ``args.doc_num`` documents;
    * rows 2.. are the prompts returned by ``P.apply(..., doc_cluster=args.doc_cluster)``,
      presumably one per document cluster. They are omitted when the cluster size equals
      ``doc_num``.

    All rows are decoded jointly with a batched KV cache and are fed the same sampled token
    at every step. Decoding stops at EOS, after a newline token, or after
    ``min(max_new_tokens, max_length - prompt_len)`` steps.

    NOTE(review): the active strategy samples from the full-document row only
    (``logits = logits_full``). The per-step cluster selection ``k`` is tracked but does not
    affect the output. The original NBCE and CAD combinations are kept as comments.

    Returns:
        dict with ``prompt`` (the batch of prompt strings), ``question``, ``answers``,
        ``token_len_prompt`` (padded prompt length, an int) and ``prediction``.
    """
    full = True  # include the all-documents prompt as batch row 1
    # Top-p truncation (p capped at 0.95) applied to every row's log-probabilities.
    processors = LogitsProcessorList()
    processors.append(TopPLogitsWarper(min(0.95, args.top_p)))

    # Construct the batch: [closed-book] + [full] + per-cluster prompts.
    doc_cluster = args.doc_cluster
    closed_prompt = CP.apply([example], use_random=args.use_random, with_answer=False)
    if full:
        full_prompt = P.apply([example], use_random=args.use_random, with_answer=False, doc_cluster=args.doc_num)
    else:
        full_prompt = []
    if doc_cluster == args.doc_num and full:
        # A single cluster holding every document would duplicate the full prompt.
        doc_prompt = []
    else:
        doc_prompt = P.apply([example], use_random=args.use_random, with_answer=False, doc_cluster=doc_cluster)
    batch = [closed_prompt] + full_prompt + doc_prompt

    inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(args.device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    prompt_len = len(input_ids[0])  # padded length shared by all rows
    logger.info(f"The batch shape is {input_ids.shape}")

    past_key_values = None
    n = input_ids.shape[0]
    max_tokens = min(args.max_new_tokens, args.max_length - prompt_len)
    tokens = None
    eta = 0.1  # entropy bonus for the cluster row chosen at the previous step

    for i in range(max_tokens):
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        return_dict=True,
                        use_cache=True,
                        past_key_values=past_key_values)
        past_key_values = outputs.past_key_values
        # ===== Core part: adjust the logits =====
        logits = outputs.logits[:, -1]
        # Normalize to log-probabilities, then top-p filter each row.
        logits = logits - logits.logsumexp(dim=-1, keepdims=True)
        logits = processors(input_ids, logits)
        entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
        if i > 0:
            # `k` is the row selected at the previous step.
            entropy[k] -= eta
        if n > 2:
            # Most confident (lowest-entropy) cluster row.
            k = entropy[2:].argmin() + 2
            # Value unused; the draw is kept so the Python RNG stream (which other code
            # may also consume) matches earlier runs of this script.
            random.randint(2, n - 1)
        else:
            # No cluster rows (previously crashed here): fall back to the full-document row.
            k = 1

        logits_full = logits[1]
        # Active strategy: sample from the full-document row.
        logits = logits_full
        # Alternatives, with logits_uncond = logits[0] and logits_max = logits[k]:
        # original NBCE (beta = 0.25):
        #   logits_merged = (1 + beta) * logits_max - beta * logits_uncond
        #   logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
        # CAD (beta = 0.5):
        #   logits_merged = (1 + beta) * logits_full - beta * logits_uncond
        #   logits = torch.where(logits_uncond > -100, logits_merged, logits_full)
        # ===== End of core part =====

        # Build the distribution and sample. tau is the temperature: tau = 1 is plain
        # sampling, tau -> 0 approaches greedy decoding. Top-p truncation was already
        # applied by `processors` above.
        tau = args.temperature
        probas = torch.nn.functional.softmax(logits[None] / tau, dim=-1)
        full_probas = torch.nn.functional.softmax(logits_full[None] / tau, dim=-1)
        probas = torch.where(torch.isnan(probas), full_probas, probas)
        try:
            next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)
        except RuntimeError as e:
            # Invalid distribution (e.g. inf/negative entries): log the state and stop.
            logger.error(
                "Sampling failed: %s | k=%s | entropy=%s | logits=%s | probas=%s | partial=%s",
                e, k, entropy, logits, probas,
                tokenizer.decode(tokens[0]) if tokens is not None else 'None',
            )
            break
        if next_tokens[0] == tokenizer.eos_token_id:
            break

        if tokens is None:
            tokens = next_tokens[:, None]
        else:
            tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
        # Stop after the first newline (the newline token stays in `tokens`).
        if tokenizer.batch_decode(next_tokens)[0] == '\n':
            break

        # Feed the sampled token to every row and extend the attention mask.
        input_ids = next_tokens.unsqueeze(-1).tile(n, 1)
        attention_mask = torch.cat([attention_mask, torch.ones(n, 1, dtype=torch.long, device=args.device)], dim=-1)

    return {
        'prompt': batch,
        'question': example.question,
        'answers': example.answers,
        # Was accidentally a 1-tuple because of a trailing comma.
        'token_len_prompt': prompt_len,
        'prediction': tokenizer.decode(tokens[0]) if tokens is not None else '',
    }


def _top2_prob_gap(logit, tau):
    """Return p(top-1) - p(top-2) of ``softmax(logit / tau)`` for a 1-D logit vector."""
    probas = torch.nn.functional.softmax(logit[None] / tau, dim=-1)
    values, _ = torch.topk(probas, 2, largest=True, sorted=True)
    return values[0][0] - values[0][1]


def _topk_entropy_and_gap(logit, tau, topk):
    """Return (entropy over the ``topk`` most probable tokens, top-2 probability gap).

    Tokens are ranked by ``softmax(logit / tau)``. The entropy terms use the un-tempered
    ``logit`` values (expected to be log-probabilities), clipped to [-100, 0].
    """
    probas = torch.nn.functional.softmax(logit[None] / tau, dim=-1)
    _, indices = torch.topk(probas, topk, largest=True)
    top = logit[indices]
    entropy = -(top.exp() * top.clip(-100, 0)).sum(dim=-1).item()
    values, _ = torch.topk(probas, 2, largest=True, sorted=True)
    return entropy, values[0][0] - values[0][1]


def topk_generate(example, model, args, P, CP, tokenizer):
    """Decode one example with a dynamically weighted mix of closed-book, full and cluster logits.

    Batch layout:

    * row 0 is the closed-book prompt from ``CP``;
    * row 1 is the prompt with all ``args.doc_num`` documents;
    * rows 2.. are the per-cluster prompts from
      ``P.apply(..., doc_cluster=args.doc_cluster)``, omitted when the cluster size equals
      ``doc_num``.

    At every step, after top-p filtering each row's log-probabilities:

    * ``cad_entropy`` / ``full_entropy`` are the closed-book / full rows' entropy over their
      10 most probable tokens, and ``*_sub`` is each row's top-2 probability gap.
    * ``gamma = max(max_sub - min_sub, 0)`` weights the contrast between the first and the
      last cluster rows. Without cluster rows both are the full row, so gamma is 0.
      ``beta = max(full_sub - cad_sub, 0)`` is the CAD weight.
    * If the closed-book row is far more confident (``cad_entropy * 10 < full_entropy``),
      the closed-book logits are the base. Otherwise CAD (full vs. closed-book) is.

    Returns:
        dict with ``prompt`` (the batch of prompt strings), ``question``, ``answers``,
        ``token_len_prompt`` (padded prompt length, an int) and ``prediction``.
    """
    full = True  # include the all-documents prompt as batch row 1
    processors = LogitsProcessorList()
    processors.append(TopPLogitsWarper(min(0.95, args.top_p)))

    # Construct the batch: [closed-book] + [full] + per-cluster prompts.
    doc_cluster = args.doc_cluster
    closed_prompt = CP.apply([example], use_random=args.use_random, with_answer=False)
    if full:
        full_prompt = P.apply([example], use_random=args.use_random, with_answer=False, doc_cluster=args.doc_num)
    else:
        full_prompt = []
    if doc_cluster == args.doc_num and full:
        # A single cluster holding every document would duplicate the full prompt.
        doc_prompt = []
    else:
        doc_prompt = P.apply([example], use_random=args.use_random, with_answer=False, doc_cluster=doc_cluster)
    batch = [closed_prompt] + full_prompt + doc_prompt

    inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(args.device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    prompt_len = len(input_ids[0])  # padded length shared by all rows
    logger.info(f"The batch shape is {input_ids.shape}")

    past_key_values = None
    n = input_ids.shape[0]
    max_tokens = min(args.max_new_tokens, args.max_length - prompt_len)
    tokens = None
    topk = 10  # number of top tokens used for the truncated entropy
    tau = args.temperature

    for _ in range(max_tokens):
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        return_dict=True,
                        use_cache=True,
                        past_key_values=past_key_values)
        past_key_values = outputs.past_key_values
        # ===== Core part: adjust the logits =====
        logits = outputs.logits[:, -1]
        logits = logits - logits.logsumexp(dim=-1, keepdims=True)  # log-probabilities
        logits = processors(input_ids, logits)

        logits_uncond = logits[0]  # closed-book row
        logits_full = logits[1]  # all-documents row
        # First ("rebest") and last ("remin") cluster rows. Without cluster rows (n == 2),
        # `logits[2]` used to raise IndexError; both now fall back to the full row.
        logits_max = logits[2] if n > 2 else logits_full
        logits_min = logits[n - 1]

        cad_entropy, cad_sub = _topk_entropy_and_gap(logits_uncond, tau, topk)
        full_entropy, full_sub = _topk_entropy_and_gap(logits_full, tau, topk)
        max_sub = _top2_prob_gap(logits_max, tau)
        min_sub = _top2_prob_gap(logits_min, tau)

        gamma = max(max_sub - min_sub, 0)  # weight of the cluster-contrast term
        beta = max(full_sub - cad_sub, 0)  # CAD weight
        if cad_entropy * 10 < full_entropy:
            # The closed-book row is much more confident: build on it.
            logits_merged = logits_uncond + gamma * logits_max - gamma * logits_min
            logits = torch.where(logits_min > -100, logits_merged, logits_uncond)
        else:
            # CAD (full vs. closed-book) + cluster contrast.
            logits_merged = (1 + beta) * logits_full - beta * logits_uncond + gamma * logits_max - gamma * logits_min
            logits = torch.where(logits_uncond > -100, logits_merged, logits_full)
            logits = torch.where(logits_min > -100, logits, logits_full)
        # NOTE(review): when gamma == 0, positions where logits_max is -inf but logits_min is
        # finite become 0 * -inf = NaN. The NaN fallback below then likely samples the whole
        # step from the full row, even in the closed-book branch -- confirm this is intended.
        # ===== End of core part =====

        probas = torch.nn.functional.softmax(logits[None] / tau, dim=-1)
        full_probas = torch.nn.functional.softmax(logits_full[None] / tau, dim=-1)
        # NaN probabilities (e.g. from -inf - -inf terms above) fall back to the full row.
        probas = torch.where(torch.isnan(probas), full_probas, probas)
        next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)

        if next_tokens[0] == tokenizer.eos_token_id:
            break

        if tokens is None:
            tokens = next_tokens[:, None]
        else:
            tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
        # Stop after the first newline (the newline token stays in `tokens`).
        if tokenizer.batch_decode(next_tokens)[0] == '\n':
            break

        # Feed the sampled token to every row and extend the attention mask.
        input_ids = next_tokens.unsqueeze(-1).tile(n, 1)
        attention_mask = torch.cat([attention_mask, torch.ones(n, 1, dtype=torch.long, device=args.device)], dim=-1)

    return {
        'prompt': batch,
        'question': example.question,
        'answers': example.answers,
        # Was accidentally a 1-tuple because of a trailing comma.
        'token_len_prompt': prompt_len,
        'prediction': tokenizer.decode(tokens[0]) if tokens is not None else '',
    }


def _output_file_name(args, dataset_type):
    """Build the result file name from the model, dataset type and prompt/decoding settings.

    ``nq_doc_num`` runs use gzipped JSON lines (``.jsonl.gz``); every other type uses one
    JSON file, with ``_{mode}{doc_cluster}`` appended for the non-standard decoding modes.
    """
    prefix = f"{args.model.split('/')[-1]}_{dataset_type}_{args.doc_num}doc_"
    if args.type == "nq_doc_num":
        if args.special_location > 0:
            return f"{prefix}{args.special_location}_loc_{args.shot}shot.jsonl.gz"
        return f"{prefix}{args.shot}shot.jsonl.gz"
    if args.mode in ("nbce", "pcw", "probsub", "topk", "maxprob"):
        return f"{prefix}{args.shot}shot_{args.mode}{args.doc_cluster}.json"
    return f"{prefix}{args.shot}shot.json"


def main():
    """Script entry point: optionally run inference, then optionally evaluate.

    With ``--do_inference``:

    1. load the tokenizer, dataset, prompt templates and model;
    2. generate one prediction per example with the decoding strategy chosen by ``--mode``;
    3. optionally score predictions on the fly (``--eval_during_infer``);
    4. write all records to ``--output_dir``.

    With ``--do_evaluate`` it then calls ``evaluate(args)``.

    NOTE(review): ``evaluate``, ``prob_sub_generate``, ``max_prob``, ``pcw_generate`` and
    ``load_pcw_wrapper`` are not defined in this file. The related flags/modes raise
    NameError unless they are provided elsewhere -- confirm.
    """
    args = get_argument()
    set_seed(args)

    if args.do_inference:
        logger.info(f"The model path is {args.model}")
        tokenizer = AutoTokenizer.from_pretrained(args.model, use_auth_token=True, trust_remote_code=True)
        # Left padding keeps the last position of every batched row a real token.
        tokenizer.padding_side = "left"
        torch_type = torch.float16
        tokenizer.pad_token = tokenizer.unk_token
        # Llama-3 checkpoints: load in bf16 and pad with eos. The old check
        # (`"3" in args.model`) also matched every "13b" path, and it assigned to
        # `torch.type` instead of `torch_type`, so bf16 was never actually used.
        model_name = args.model.lower()
        if "llama-3" in model_name or "llama3" in model_name:
            torch_type = torch.bfloat16
            tokenizer.pad_token = tokenizer.eos_token

        logger.info(f"Begin to process the data from input_file {args.input_file}, type is {args.type}")
        if args.type not in ["nq_doc_num", "nq", "tqa", "alce_asqa", "alce_qampari", "alce_eli5"]:
            raise NotImplementedError(f"Unsupported data type: {args.type}")
        eval_dataset = MDDataset(args.input_file, args.type)

        logger.info(f"The chosen prompt is {args.prompt} with doc_num is {args.doc_num}.")
        if args.use_random:
            logger.info("The document is ordered randomly.")
        else:
            logger.info("The document is not ordered randomly.")
        # Closed-book prompt for the multi-row decoding modes. Initialized up front so it is
        # always defined (it was missing for --prompt closed, and for md without --closed_p).
        CP = None
        if args.prompt == "closed":
            P = Prompt("str", CLOSED_BOOK_PROMPT, closed_book=True, ndoc=0)
        elif args.prompt == "md":
            if args.doc_num == 0:
                raise ValueError("The doc num is set to zero in multi-document experiments!")
            template = MD_RANDOM_PROMPT if args.use_random else MD_PROMPT
            P = Prompt("str", template, closed_book=False, ndoc=args.doc_num)
            if args.closed_p is not None:
                CP = Prompt("str", CLOSED_BOOK_PROMPT, closed_book=True, ndoc=0)
        else:
            # Any other value is handed to Prompt as a JSON prompt source (presumably a file).
            P = Prompt("json", args.prompt, closed_book=args.closed_book, ndoc=args.doc_num)
            if args.closed_p is not None:
                CP = Prompt("json", args.closed_p, closed_book=True, ndoc=0)
        if args.mode in ("nbce", "topk") and CP is None:
            raise ValueError(f"--mode {args.mode} needs a closed-book prompt: pass --closed_p "
                             "with --prompt md or a JSON prompt file.")

        if args.mode in ("standard", "nbce", "probsub", "topk", "maxprob"):
            model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch_type, device_map='auto',
                                                         use_auth_token=True, trust_remote_code=True)
            model.eval()
        elif args.mode == "pcw":
            # ceil(doc_num / doc_cluster) windows plus one. Integer arithmetic: the old
            # `/` produced a float window count when doc_num was divisible.
            n_windows = args.doc_num // args.doc_cluster + 1
            if args.doc_num % args.doc_cluster != 0:
                n_windows += 1
            model = load_pcw_wrapper(args.model, n_windows=n_windows)
            logger.info(f"Using PCW with {n_windows} windows.")
        else:
            raise ValueError(f"Unsupported mode: {args.mode}")

        datas = {}
        if args.eval_during_infer:
            logger.info(f"Evaluate during inference with metric {args.metrics}.")
            if args.metrics not in ("accuracy", "extractive", "asqa"):
                raise ValueError("--eval_during_infer supports only the accuracy, extractive or asqa "
                                 f"metrics, got {args.metrics!r}")
            scores = {
                "accuracy": [],
                "em": [],
                "f1": [],
                "str_em": [],
                "str_em_hit": []
            }
        with torch.inference_mode():
            for example in tqdm(eval_dataset, total=len(eval_dataset)):
                qid = example.qas_id
                datas[qid] = {}
                if args.mode == "standard":
                    datas[qid].update(standard_generate(example, model, args, P, tokenizer))
                elif args.mode == "nbce":
                    datas[qid].update(nbce_generate(example, model, args, P, CP, tokenizer))
                elif args.mode == "probsub":
                    datas[qid].update(prob_sub_generate(example, model, args, P, CP, tokenizer))
                elif args.mode == "maxprob":
                    datas[qid].update(max_prob(example, model, args, P, CP, tokenizer))
                elif args.mode == "topk":
                    datas[qid].update(topk_generate(example, model, args, P, CP, tokenizer))
                elif args.mode == "pcw":
                    datas[qid].update(pcw_generate(example, model, args, P, CP))
                # Carry optional gold annotations into the saved record.
                for field in ("qa_pairs", "annotations", "claims"):
                    value = getattr(example, field)
                    if value is not None:
                        datas[qid][field] = value
                if args.eval_during_infer:
                    score = evaluate_example(example.answers, datas[qid]['prediction'], args.metrics, example.qa_pairs)
                    if args.metrics == "accuracy":
                        scores['accuracy'].append(score)
                    elif args.metrics == "extractive":
                        scores["em"].append(score[0])
                        scores["f1"].append(score[1])
                    elif args.metrics == "asqa":
                        scores['str_em'].append(score[0])
                        scores["str_em_hit"].append(score[1])

                if len(datas) < 10:
                    # Echo the first few examples for a quick sanity check.
                    shown = datas[qid]['prompt']
                    for p in (shown if isinstance(shown, list) else [shown]):
                        print(p)
                    if "token_len_prompt" in datas[qid]:
                        print(datas[qid]['token_len_prompt'])
                    print(example.answers)
                    print(datas[qid]['prediction'])
                elif args.debug and len(datas) == 50:
                    break

        if args.eval_during_infer:
            for key, values in scores.items():
                if values:
                    scores[key] = sum(values) / len(values)
                    print(key, scores[key])

        output_file = _output_file_name(args, eval_dataset.type)
        logger.info(f"The result is saved to {args.output_dir}/{output_file}")
        os.makedirs(args.output_dir, exist_ok=True)
        output_path = os.path.join(args.output_dir, output_file)
        if args.type == "nq_doc_num":
            with xopen(output_path, "w") as f:
                for record in datas.values():
                    f.write(json.dumps(record) + '\n')
        else:
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(datas, f, indent=4)

    if args.do_evaluate:
        evaluate(args)


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()
