# Ref: https://github.com/kojima-takeshi188/zero_shot_cot

import re
import os
import json
import random
import torch
import numpy as np
import pandas as pd
import transformers
from tqdm import tqdm, trange
import argparse
import math
from collections import defaultdict, Counter
import glob
import sys

import ssl
import urllib.request
import zipfile
import tiktoken
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

from generation import LLM

# Level 40 is the numeric value of ERROR in Python's logging levels, so only
# errors (and above) from the transformers library are reported.
transformers.logging.set_verbosity(40)

# Matches a "#### <number>" marker and captures the number: an optional
# leading minus sign followed by digits, dots and commas.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# NOTE(review): the constants below are not referenced anywhere in the visible
# part of this file; presumably they were carried over from the referenced
# zero_shot_cot script -- confirm before relying on or removing them.
INVALID_ANS = "[invalid]"

N_SHOT = 8
COT_FLAG = True
DEBUG = True
ANSWER_TRIGGER = "The answer is"


def num_tokens_from_message(message, model="davinci"):
    """
    Count the tokens `message` encodes to under the tiktoken encoding of `model`.

    :param message: The text to tokenize.
    :param model: Model name passed to ``tiktoken.encoding_for_model``.
    :return: Number of tokens in the encoded message.
    """
    token_ids = tiktoken.encoding_for_model(model).encode(message)
    return len(token_ids)


def truncate_message(prompt1, prompt2, model="davinci", max_tokens=2033):
    """
    Concatenate two prompt parts, shortening the first one to fit a token budget.

    When ``prompt1 + prompt2`` exceeds ``max_tokens`` tokens, trailing
    space-separated words are dropped from ``prompt1`` until it fits into the
    budget left over by ``prompt2``. ``prompt2`` is never shortened.

    :param prompt1: Leading text; the only part that may be truncated.
    :param prompt2: Trailing text that is always kept in full.
    :param model: Model name used to select the tokenizer for every count.
    :param max_tokens: Token budget for the combined prompt.
    :return: The (possibly truncated) ``prompt1`` followed by ``prompt2``.
    """
    if num_tokens_from_message(prompt1 + prompt2, model) > max_tokens:
        # Count prompt2 with the same tokenizer as the combined prompt.
        budget = max_tokens - num_tokens_from_message(prompt2, model)
        words = prompt1.split(' ')
        # `words` guards termination: if prompt2 alone exceeds the budget the
        # budget is negative and even an empty prompt1 would never "fit".
        while words and num_tokens_from_message(" ".join(words), model) > budget:
            words.pop()
        prompt1 = " ".join(words)
    return prompt1 + prompt2


# NOTE(review): `demo_keys` is never read in the visible part of this file.
demo_keys = []

# Per data type: the record fields holding the two candidate outputs
# (hallucinated first, right second).
# NOTE(review): not read in the visible part of this file.
data_candidate_keys = {
    'summarization': ['hallucinated_summary', 'right_summary'],
    'qa': ['hallucinated_answer', 'right_answer'],
    'dialogue': ['hallucinated_response', 'right_response'],
    'xsum': ['hallucinated_summary', 'right_summary'],
    'cnndm': ['hallucinated_summary', 'right_summary'],
}

# Per data type: the record field that load_jsonl() reads the context text from.
# There is no 'litm' entry here.
data_context_keys = {
    'summarization': 'document',
    'qa': 'question',
    'dialogue': 'dialogue_history',
    'xsum': 'document',
    'cnndm': 'article',
}

# Per data type: the label load_jsonl() puts in the "#<label>#: " prefix of the
# context section of the prompt.
data_context_names = {
    'summarization': 'Document',
    'qa': 'Question',
    'dialogue': 'Dialogue History',
    'litm': 'Document',
    'xsum': 'Article',
    'cnndm': 'Article',
}

# Per data type: the label used in the "\n#<label>#:" suffix that the main
# script appends to the prompt to start the model's response.
data_response_names = {
    'summarization': 'Summary',
    'qa': 'Answer',
    'dialogue': 'Response',
    'litm': 'Answer',
    'xsum': 'Summary',
    'cnndm': 'Summary',
}

def load_nq_open(file_path, parallel=False, total_shard=8, shard_id=0, debug=False, data_type='nq_open', subsample=None):
    """
    Load an NQ-Open JSON-lines file and build one document/question prompt per record.

    Every line is a JSON object with a "question" string and an "answers" list.
    Passages are read from "positive_ctxs" / "negative_ctxs" when ``file_path``
    contains 'nq_train'; otherwise from "ctxs", split by each passage's boolean
    "hasanswer" field. Only the "text" of a passage is used.

    The context of each item is the first negative passage, the first positive
    passage and the second negative passage, followed by the question
    (first letter capitalized, '?' appended when missing).

    :param file_path: Path of the JSON-lines file.
    :param parallel: If true, keep only the shard selected by ``shard_id``.
    :param total_shard: Number of shards the data is divided into.
    :param shard_id: Zero-based shard to keep; the last shard also receives the
        records left over when the data does not divide evenly.
    :param debug: If true, keep only the first 100 records.
    :param data_type: Unused; kept for signature compatibility.
    :param subsample: If given, keep the records whose position (after the
        debug cut) is a multiple of this value.
    :return: List of dicts with keys ``context``, ``response``,
        ``net_response``, ``answer`` (first reference answer) and
        ``data_index`` (zero-based line number of the record in the file).
    :raises AssertionError: If a record has no positive passage or fewer than
        two negative passages.
    """
    is_train = 'nq_train' in file_path
    with open(file_path, 'r', encoding="utf-8") as f:
        # Pair every record with its line number so the index survives the
        # debug / subsample / shard filtering below, as load_jsonl does.
        indexed = list(enumerate(json.loads(line) for line in f))

    if debug:
        indexed = indexed[:100]
    if subsample is not None:
        indexed = [pair for pos, pair in enumerate(indexed) if pos % subsample == 0]
    if parallel:
        chunk_size = len(indexed) // total_shard
        start = shard_id * chunk_size
        # The last shard runs to the end so the remainder is not dropped.
        end = None if shard_id == total_shard - 1 else start + chunk_size
        indexed = indexed[start:end]

    list_data_dict = []
    for data_index, record in indexed:
        question = record['question']
        # Slicing instead of indexing keeps an empty question from raising.
        question = question[:1].upper() + question[1:]
        if not question.endswith('?'):
            question += '?'
        answers = record['answers']
        if is_train:
            pos_ctxs = record['positive_ctxs']
            neg_ctxs = record['negative_ctxs']
        else:
            ctxs = record['ctxs']
            pos_ctxs = [ctx for ctx in ctxs if ctx['hasanswer']]
            neg_ctxs = [ctx for ctx in ctxs if not ctx['hasanswer']]
        assert len(pos_ctxs) > 0, "No positive context found."
        assert len(neg_ctxs) >= 2, "At least two negative contexts are required."
        # The answer-bearing passage sits between two distractor passages.
        context = "#Document#: " + neg_ctxs[0]['text'] + '\n' + pos_ctxs[0]['text'] + '\n' + neg_ctxs[1]['text']
        context += f"\n#Question#: {question}"
        list_data_dict.append(dict(
            context=context,
            response="\n#Answer#:",
            net_response=None,
            answer=answers[0],
            data_index=data_index,
        ))
    return list_data_dict


def load_jsonl(file_path, parallel=False, total_shard=8, shard_id=0, debug=False, data_type='summarization', subsample=None):
    """
    Load a JSON-lines dataset and build the context section of each prompt.

    The context text is read from the field named by
    ``data_context_keys[data_type]`` and prefixed with
    ``#<data_context_names[data_type]>#: ``. For ``data_type == 'qa'`` the
    record's "knowledge" field is prepended as a ``#Knowledge#:`` line.

    :param file_path: Path of the JSON-lines file.
    :param parallel: If true, keep only the shard selected by ``shard_id``.
    :param total_shard: Number of shards the data is divided into.
    :param shard_id: Zero-based shard to keep; the last shard also receives the
        records left over when the data does not divide evenly.
    :param debug: If true, keep only the first 100 records.
    :param data_type: Key into the module-level context lookup tables.
    :param subsample: If given, keep the records whose position (after the
        debug cut) is a multiple of this value.
    :return: List of dicts with keys ``context`` and ``data_index`` (zero-based
        line number of the record in the file).
    """
    with open(file_path, 'r', encoding="utf-8") as f:
        # Keep each record paired with its line number so the index survives
        # the filtering steps below.
        indexed = list(enumerate(json.loads(line) for line in f))

    if debug:
        indexed = indexed[:100]
    if subsample is not None:
        indexed = [pair for pos, pair in enumerate(indexed) if pos % subsample == 0]
    if parallel:
        chunk_size = len(indexed) // total_shard
        start = shard_id * chunk_size
        # The last shard runs to the end so the remainder is not dropped.
        end = None if shard_id == total_shard - 1 else start + chunk_size
        indexed = indexed[start:end]

    list_data_dict = []
    for data_index, record in indexed:
        context = f"#{data_context_names[data_type]}#: " + record[data_context_keys[data_type]]
        if data_type == 'qa':
            context = f"#Knowledge#: {record['knowledge']}\n" + context
        list_data_dict.append(dict(context=context, data_index=data_index))

    return list_data_dict

def dump_jsonl(data, output_path, append=False):
    """
    Serialize `data` as one JSON record and write it as a single line.

    :param data: JSON-serializable object; written as one line.
    :param output_path: Destination file path.
    :param append: If true, add the line to the end of the file; otherwise the
        file is overwritten.
    """
    if append:
        file_mode = 'a+'
    else:
        file_mode = 'w'
    with open(output_path, file_mode, encoding='utf-8') as out:
        out.write(json.dumps(data, ensure_ascii=False) + '\n')


def create_demo_text(pondering=None, data_type='summarization'):
    """
    Return the task instruction that opens the prompt for `data_type`.

    :param pondering: Unused.
    :param data_type: One of 'summarization', 'litm', 'xsum' or 'cnndm'.
    :return: The instruction text, ending with a blank line.
    :raises ValueError: If `data_type` is none of the supported values.
    """
    instructions = (
        ('summarization', "Generate a summary based on the information in the document.\n\n"),
        ('litm', "Answer the question based on the information in the document. Explain your reasoning in the document step-by-step before providing the final answer.\n\n"),
        ('xsum', "Generate a summary comprising of 1 sentence for the given article.\n\n"),
        ('cnndm', "Generate a summary comprising of 3 sentences for the given article.\n\n"),
    )
    for known_type, instruction in instructions:
        if data_type == known_type:
            return instruction
    raise ValueError("Please specify the data type.")


# instruction = 'You should try your best to determine if the summary contains non-factual or hallucinated information according to the above hallucination types. The answer you give MUST be \"Yes\" or \"No\"".'


def build_prompt(context, response, pondering=None, data_type='summarization'):
    """
    Assemble the full model input: task instruction, context, response marker.

    For 'summarization' the instruction-plus-context part is passed through
    truncate_message() together with `response`; every other data type is
    concatenated as-is.

    :param context: Context section of the prompt.
    :param response: Text appended after the context.
    :param pondering: Forwarded to create_demo_text().
    :param data_type: Data type selecting the instruction text.
    :return: The complete prompt string.
    """
    instruction_and_context = create_demo_text(pondering, data_type) + context
    if data_type != 'summarization':
        return instruction_and_context + response
    return truncate_message(instruction_and_context, response)


def set_seed(seed):
    """
    Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators.

    :param seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def plot_histogram_distributions(scores, labels, filename='histogram.png'):
    """
    Plot per-label histograms of prediction scores and save them to an image file.

    :param scores: A list of predicted scores from the model.
    :param labels: A list of ground truth labels, one per score.
    :param filename: Path the figure is saved to.
    :raises ValueError: If `scores` and `labels` differ in length.
    """
    if len(scores) != len(labels):
        raise ValueError("The length of scores and labels must be the same.")

    # A DataFrame lets seaborn split the histogram by the 'Label' column.
    df = pd.DataFrame({'Score': scores, 'Label': labels})

    fig = plt.figure(figsize=(10, 6))
    try:
        sns.histplot(data=df, x='Score', hue='Label', element='step', stat='count', common_norm=False, palette='viridis')
        plt.title('Histogram of Prediction Scores by Label')
        plt.xlabel('Predicted Score')
        plt.ylabel('Count')
        plt.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

def find_best_threshold(fpr, tpr, thresholds):
    """
    Pick the ROC threshold whose (FPR, TPR) point lies nearest to the ideal
    top-left corner (0, 1).

    :param fpr: Array of False Positive Rates
    :param tpr: Array of True Positive Rates
    :param thresholds: Array of thresholds corresponding to each (FPR, TPR) point
    :return: The threshold at the point with the smallest Euclidean distance
        to (0, 1); the first such point on ties
    """
    miss_rate = 1 - tpr
    corner_distance = np.sqrt(miss_rate ** 2 + fpr ** 2)
    return thresholds[np.argmin(corner_distance)]

def calculate_metrics(scores, labels, threshold):
    """
    Calculate classification metrics after binarizing scores at a threshold.

    A score is predicted as class 1 when it is greater than or equal to
    `threshold`, otherwise as class 0.

    :param scores: A list of predicted scores from the model.
    :param labels: A list of ground truth labels (0 or 1).
    :param threshold: The threshold to convert scores to binary classifications.
    :return: A tuple of precision, recall, F1 score, overall accuracy, accuracy
        on the label-0 examples, accuracy on the label-1 examples, and the
        harmonic mean of the two per-label accuracies.
    """
    predictions = [1 if score >= threshold else 0 for score in scores]

    precision = precision_score(labels, predictions)
    recall = recall_score(labels, predictions)
    f1 = f1_score(labels, predictions)
    accuracy = accuracy_score(labels, predictions)

    def _subset_accuracy(target):
        # Accuracy restricted to the examples whose true label is `target`.
        chosen = [(label, pred) for label, pred in zip(labels, predictions) if label == target]
        return accuracy_score([label for label, _ in chosen], [pred for _, pred in chosen])

    subset_acc_where_label_0 = _subset_accuracy(0)
    subset_acc_where_label_1 = _subset_accuracy(1)

    # Harmonic mean of the per-label accuracies. When both are zero the
    # denominator vanishes, so the mean is defined as 0.0 instead of dividing.
    acc_sum = subset_acc_where_label_0 + subset_acc_where_label_1
    if acc_sum == 0:
        harmonic_mean_accuracy = 0.0
    else:
        harmonic_mean_accuracy = 2*subset_acc_where_label_0*subset_acc_where_label_1/acc_sum

    return precision, recall, f1, accuracy, subset_acc_where_label_0, subset_acc_where_label_1, harmonic_mean_accuracy

if __name__ == "__main__":
    # Script entry point: load a dataset, generate one completion per example
    # with the LLM wrapper, write the completions to
    # "<output-path>_<shard-id>.jsonl" and, unless --skip-save-attn is given,
    # save per-head attention ratios to "<output-path>_attn_scores_<shard-id>.pt".

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument("--device", type=str,
                        choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("--data-path", type=str, default="./gsm8k")
    parser.add_argument("--output-path", type=str, default="./gsm8k_result")
    # parallel mode (split the dataset into multiple parts, inference by separate processes)
    parser.add_argument("--early-exit-layers", type=str, default="-1")
    parser.add_argument("--divergence-type", type=str, default="js")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--total-shard", type=int, default=8)
    parser.add_argument("--shard-id", type=int, default=0)
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.9)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--extrapolate_coeff", type=float, default=10000.0)
    parser.add_argument("--relative_top", type=float, default=0.1)
    # parser.add_argument("--relative_top_value", type=float, default=-1000.0)
    parser.add_argument("--relative_top_with_norm", action="store_true")
    parser.add_argument("--contrast_disagree_only", action="store_true")
    parser.add_argument("--pre_softmax", action="store_true")
    parser.add_argument("--do_sample", action="store_true")
    parser.add_argument("--do_shuffle", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--subsample", type=int, default=None)
    parser.add_argument("--penalty_alpha", type=float, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--retry", type=int, default=1)
    parser.add_argument("--tuned-lens-path", type=str, default=None)
    parser.add_argument("--auth-token", type=str, default=None)
    parser.add_argument("--premature_temp", type=float, default=1.0)
    parser.add_argument("--apply_early_norm", action="store_true")
    parser.add_argument("--attn-intervention", action="store_true")
    parser.add_argument("--attn-intervention-low-prob", action="store_true")
    parser.add_argument("--attn-int-factor", type=float, default=0.0001)
    parser.add_argument("--low-prob-percentile", type=float, default=0.1)
    parser.add_argument("--keys-path", type=str, default=None)
    parser.add_argument("--pause-num", type=int, default=3)
    parser.add_argument("--alpha", type=float, default=10)
    # parser.add_argument("--subsets", type=str, default="hallucinated_summary,right_summary")
    parser.add_argument("--pondering", type=str, default=None)
    parser.add_argument("--free-form", action="store_true")
    parser.add_argument("--attn-score", action="store_true")
    parser.add_argument("--shift-by-1", action="store_true")
    parser.add_argument("--important-token-type", type=str, default=None)
    parser.add_argument("--data-type", type=str, default=None)
    # observe_top_n_heads
    parser.add_argument("--observe-top-n-heads", type=int, default=None)
    # skip save attn
    parser.add_argument("--skip-save-attn", action="store_true")

    args = parser.parse_args()
    model_name = args.model_name
    num_gpus = args.num_gpus
    device = args.device

    # Infer the data type from the data path when it is not given explicitly.
    if args.data_type is None:
        if 'cnndm' in args.data_path:
            args.data_type = 'summarization'
        elif 'nq-open' in args.data_path:
            args.data_type = 'litm'
        elif 'xsum' in args.data_path:
            args.data_type = 'xsum'
        else:
            raise ValueError("Please specify the data type.")
    # Get test file
    fp = args.data_path
    if not os.path.exists(fp):
        raise ValueError(f"Test file {fp} does not exist.")

    if "nq-open" in fp:
        list_data_dict = load_nq_open(fp, parallel=args.parallel, total_shard=args.total_shard, shard_id=args.shard_id, debug=args.debug, subsample=args.subsample)
    else:
        list_data_dict = load_jsonl(fp, parallel=args.parallel, total_shard=args.total_shard, shard_id=args.shard_id, debug=args.debug, data_type=args.data_type, subsample=args.subsample)

    if args.pondering is not None:
        # NOTE(review): load_jsonl() as defined in this file accepts neither
        # `pondering` nor `keys_path`, so this call raises TypeError whenever
        # --pondering is given. The intended loader is not visible here; the
        # call is left unchanged until that intent is confirmed.
        list_data_dict_keys = load_jsonl(
            fp, pondering=args.pondering, keys_path=args.keys_path, parallel=args.parallel, total_shard=args.total_shard, shard_id=args.shard_id, debug=args.debug, data_type=args.data_type, subsample=args.subsample)
    llm = LLM(
        model_name, device, num_gpus, args.tuned_lens_path, args.auth_token)
    # The backslash is escaped explicitly; the string value is unchanged but
    # no longer relies on an invalid escape sequence.
    stop_word_list = ["Q:", "\\end{code}", "#Document#:", "#Pondering#:", "#Question#:", "#Dialogue History#:"]
    llm.set_stop_words(stop_word_list)

    # Decoding mode is selected by how many layers --early-exit-layers lists.
    early_exit_layers = [int(x) for x in args.early_exit_layers.split(',')]
    if early_exit_layers == [-1]:
        print("MODE: naive decoding from the last layer", flush=True)
        mode = "vanilla"
        final_layer = None
        base_layer = None
        dynamic_exit_layers = None
    elif len(early_exit_layers) == 2:
        print(
            f"MODE: early exit contrastive with final layer: {early_exit_layers[1]} and base layer: {early_exit_layers[0]}")
        mode = "early_exit_contrastive"
        final_layer = early_exit_layers[1]
        base_layer = early_exit_layers[0]
        dynamic_exit_layers = None
    else:
        print(
            f"MODE: dynamic early exit contrastive with final layer: {early_exit_layers[-1]} and base layers: {early_exit_layers[:-1]}")
        mode = "dynamic_early_exit_contrastive"
        final_layer = early_exit_layers[-1]
        base_layer = None
        dynamic_exit_layers = early_exit_layers[:-1]

    to_save_attention_list = []
    output_path = args.output_path+"_"+str(args.shard_id)+".jsonl"

    # Marker appended to every prompt to start the model's response.
    response_prompt = f"\n#{data_response_names[args.data_type]}#:"
    # Token count of the marker minus one; presumably the subtracted token is
    # a leading special token added by the tokenizer -- TODO confirm.
    extra_prompt_length = len(llm.tokenizer(response_prompt)['input_ids']) - 1

    # Generation settings depend only on the command line, so build them once.
    generate_kwargs = dict(max_new_tokens=args.max_new_tokens, penalty_alpha=args.penalty_alpha, do_sample=args.do_sample, top_p=args.top_p, top_k=args.top_k, temperature=args.temperature, repetition_penalty=args.repetition_penalty, extrapolate_coeff=args.extrapolate_coeff, pre_softmax=args.pre_softmax, mode=mode, final_layer=final_layer, base_layer=base_layer,
                        base_layers=dynamic_exit_layers, divergence_type=args.divergence_type,
                        relative_top=args.relative_top, relative_top_with_norm=args.relative_top_with_norm,
                        contrast_disagree_only=args.contrast_disagree_only,
                        premature_temp=args.premature_temp, apply_early_norm=args.apply_early_norm,
                        return_attentions=True,)

    # Explicit UTF-8 because records are written with ensure_ascii=False; the
    # context manager closes the file even if generation fails midway.
    with open(output_path, 'w', encoding='utf-8') as fw:
        for idx, sample in enumerate(tqdm(list_data_dict)):
            if args.pondering is None:
                input_text_keys = None
            else:
                # NOTE(review): `input_text_keys` is built here but generate()
                # below receives `keywords` instead -- confirm which is intended.
                sample_keys = list_data_dict_keys[idx]
                input_text_keys = build_prompt(
                    sample_keys['context'], sample_keys['response'], pondering=args.pondering, data_type=args.data_type)

            input_text = build_prompt(sample['context'], response_prompt, data_type=args.data_type)
            # `mode` is never "attn_intervention" in this script, so this is always None.
            keywords = sample['response'] if mode == "attn_intervention" else None
            model_completion, attentions, gen_seq = llm.generate(
                input_text, input_text_keys=keywords, **generate_kwargs)
            if not args.skip_save_attn:
                # `attentions` is indexed as [generated token][layer]; each entry
                # is presumably (batch, heads, query, key) -- TODO confirm.
                context_length = attentions[0][0].shape[-1] - extra_prompt_length
                new_token_length = len(attentions)
                num_layers = len(attentions[0])
                num_heads = attentions[0][0].shape[1]
                attn_scores = torch.zeros((num_layers, num_heads, new_token_length))
                for i in range(new_token_length):
                    for l in range(num_layers):
                        # Mean attention of the last query position over the
                        # first `context_length` keys versus the remaining keys.
                        attn_on_context = attentions[i][l][0, :, -1, :context_length].mean(-1)
                        attn_on_new_tokens = attentions[i][l][0, :, -1, context_length:].mean(-1)
                        attn_scores[l, :, i] = attn_on_context / (attn_on_context + attn_on_new_tokens)

            # Drop a stop word left at the end of the completion, if any.
            cropped_model_completion = model_completion
            for stop_word in stop_word_list:
                if model_completion.endswith(stop_word):
                    cropped_model_completion = model_completion[:-len(stop_word)]

            return_dict = {
                sample['data_index']: cropped_model_completion.strip()
            }
            fw.write(json.dumps(return_dict, ensure_ascii=False) + '\n')
            # Flush after every record so partial results survive an interruption.
            fw.flush()
            if not args.skip_save_attn:
                to_save = {
                    'data_index': sample['data_index'],
                    'model_completion': cropped_model_completion,
                    'model_completion_ids': gen_seq,
                    'full_input_text': input_text,
                    'attn_scores': attn_scores,
                }
                to_save_attention_list.append(to_save)

    if not args.skip_save_attn:
        # save attention scores
        attn_scores_path = args.output_path + "_attn_scores_" + str(args.shard_id) + ".pt"
        torch.save(to_save_attention_list, attn_scores_path)