import os, sys
from pathlib import Path
# Project root, resolved from the current working directory (two levels up) --
# presumably the script is launched from two directories below the root; TODO confirm.
PROJECT_ROOT = Path(os.curdir).resolve().parent.parent
# Project code lives in PROJECT_ROOT/src; fail fast (even under `python -O`) if missing.
if not (PROJECT_ROOT / 'src').exists():
    raise FileNotFoundError(f"expected a 'src' directory under {PROJECT_ROOT}")
print(PROJECT_ROOT)
# sys.path holds strings, so compare against the string form of the directory we
# actually append (comparing the root Path object never matched, so every
# execution appended a duplicate entry).
_SRC_DIR = str(PROJECT_ROOT / 'src')
if _SRC_DIR not in sys.path:
    sys.path.append(_SRC_DIR)

import vllm
from transformers import AutoTokenizer
from tqdm import tqdm
import torch
import random
import json
from pyserini.search.lucene import LuceneSearcher
import numpy as np
from string import Template
from collections import defaultdict
import pandas as pd


# Base directory for retrieval data and result logs; set the ROOT_DIR environment
# variable to override the 'default' placeholder.
ROOT_DIR = os.getenv('ROOT_DIR', 'default')

# Encoder names evaluated by main(); each one selects the
# retrieval/runs_<dataset>/<encoder> run directory.
encoders = [
    'ance',
    'contriever',
    'dpr',
    'gtr',
    'simcse',
    'tasb',
]

def append_to_result(result, file):
    """Append one record as a single line to ``file``, creating the file if needed.

    Args:
        result: a ``str`` is written verbatim; any other value (e.g. a dict) is
            serialized with ``json.dumps`` first.
        file: path of the log file.
    """
    if not isinstance(result, str):
        result = json.dumps(result)

    # Mode 'a' already creates a missing file, so no exists()/'w' special case is needed.
    with open(file, 'a', encoding='utf-8') as w:
        w.write(result + '\n')


def readjsonl(_path):
    """Read a JSON Lines file and return its records as a list.

    Whitespace-only lines (e.g. a stray empty line) are skipped instead of making
    ``json.loads`` raise.
    """
    with open(_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def percent_round(_value):
    """Convert a fraction to a percentage rounded to one decimal place."""
    scaled = _value * 100
    return round(scaled, 1)


def truncate_passage(passage, max_words):
    """Keep at most ``max_words`` whitespace-delimited words of ``passage``.

    Any run of whitespace (spaces, tabs, newlines) between kept words collapses
    to a single space.
    """
    words = passage.split()
    kept = words[:max_words]
    return ' '.join(kept)


def deduplicate_dicts(dict_list, key):
    """Return the dicts of ``dict_list`` whose ``key`` value has not been seen before.

    Order is preserved and the first occurrence wins. A dict without ``key`` counts
    as having the value ``None``. Values must be hashable.
    """
    already = set()
    kept = []
    for item in dict_list:
        marker = item.get(key)
        if marker in already:
            continue
        already.add(marker)
        kept.append(item)
    return kept

def _fetch_content(searcher, pid):
    """Fetch the 'contents' field of document ``pid`` from a pyserini searcher, on one line.

    Real newlines become the two-character sequence ``\\n`` and tabs become spaces,
    so the text can be stored in tab-separated files and split on ``\\n`` later.
    """
    # pandas parses numeric ids as numpy integers, and pyserini's ``doc()`` treats an
    # integer docid as an *internal* Lucene id, so always look up the external id string.
    doc = searcher.doc(str(pid))
    if doc is None:
        raise KeyError(f"docid {pid!r} not found in the index")
    return json.loads(doc.raw())['contents'].replace('\n', '\\n').replace('\t', ' ')


def get_records(file_path, searcher=None):
    """Load retrieval records from ``file_path``.

    Two input layouts are supported (checked in this order):

    1. Tab-separated files with a header row: ``*.aug``, ``*tsv``, or any path that
       contains 'rrf'. Each row becomes a dict keyed by column name; 'pid' is expected.
    2. ``*.txt`` space-separated files without a header, presumably pyserini / TREC
       run files (qid Q0 docid rank score tag). File columns 0, 2 and 4 become
       'qid', 'pid' and 'score'.

    If ``searcher`` (a pyserini ``LuceneSearcher``) is given, the passage text is
    looked up into 'content' -- for TSV input only when the file has no 'content'
    column. Without a searcher, TSV rows are returned as read and TXT records get
    ``content=None``.

    Returns:
        List of dicts: [{'qid': ..., 'pid': ..., 'content': str | None, ...}]

    Raises:
        ValueError: if ``file_path`` matches neither layout.
        KeyError: if a pid is not present in the searcher's index.
    """
    if file_path.endswith(('.aug', 'tsv')) or 'rrf' in file_path:
        records = pd.read_csv(file_path, sep='\t').to_dict('records')
        if searcher is None:
            return records
        if records and 'content' not in records[0]:
            for record in records:
                record['content'] = _fetch_content(searcher, record['pid'])
        return records

    if file_path.endswith('.txt'):
        records = []
        rows = pd.read_csv(file_path, sep=' ', header=None).to_records()
        for row in tqdm(rows):
            # to_records() prepends the DataFrame index, so file column k sits at row[k + 1].
            pid = row[3]
            records.append({
                'qid': row[1],
                'pid': pid,
                'score': row[5],
                'content': _fetch_content(searcher, pid) if searcher is not None else None,
            })
        return records

    raise ValueError(f"Unsupported retrieval file format: {file_path}")

def list2dict(lines):
    """Group passage texts by query id, highest score first.

    Args:
        lines: iterable of records with at least 'qid', 'content' and 'score' keys.

    Returns:
        A plain dict ``{qid: [passage, ...]}`` where the passages are sorted by
        descending score (ties keep input order). Each passage is split on the literal
        two-character sequence ``\\n``. The first segment is dropped (presumably a
        title -- confirm against the index format), and the remaining segments are
        joined with single spaces.
    """
    by_qid = {}
    for rec in lines:
        by_qid.setdefault(rec['qid'], []).append((rec['content'], rec['score']))

    def _body(text):
        return ' '.join(text.split("\\n")[1:])

    grouped = {}
    for qid, pairs in by_qid.items():
        ranked = sorted(pairs, key=lambda pair: pair[1], reverse=True)
        grouped[qid] = [_body(text) for text, _score in ranked]
    return grouped


def get_gpu_memory_utilization():
    """Choose the GPU memory fraction passed to vLLM, based on GPU 0's total memory.

    Cards with more than 20 GiB in total use 0.9. Smaller cards keep about 1.2 GiB of
    headroom, capped at 0.95. The chosen value is printed and returned.
    """
    _free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    total_gib = total_bytes / 1024**3
    utilization = 0.9 if total_gib > 20.0 else min(0.95, (total_gib - 1.2) / total_gib)
    print(f"[INFO] GPU memory utilization: {utilization:.2f}")
    return utilization


class LLM:
    """Thin wrapper around a vLLM engine and the matching Hugging Face tokenizer."""

    def __init__(self, model_name_or_path, tokenizer_name_or_path=None,
                 use_slow_tokenizer=False, max_model_len=4096):
        """Load the tokenizer and build a vLLM engine for ``model_name_or_path``.

        Args:
            model_name_or_path: HF hub id or local path of the model.
            tokenizer_name_or_path: optional separate tokenizer; defaults to the model.
            use_slow_tokenizer: use vLLM tokenizer_mode "slow" instead of "auto".
            max_model_len: maximum context length given to vLLM.
        """
        tokenizer_path = tokenizer_name_or_path if tokenizer_name_or_path else model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model_vllm = vllm.LLM(
            model=model_name_or_path,
            tokenizer=tokenizer_path,
            tokenizer_mode="slow" if use_slow_tokenizer else "auto",
            # tensor-parallel across every GPU visible to torch
            tensor_parallel_size=torch.cuda.device_count(),
            gpu_memory_utilization=get_gpu_memory_utilization(),
            dtype="half",
            max_model_len=max_model_len,
        )

    def generate(self, input_text_list, **sampling_kwargs):
        """Generate one completion per prompt and return the texts in input order.

        ``sampling_kwargs`` are forwarded to ``vllm.SamplingParams`` (callers in this
        file pass e.g. temperature, top_p, max_tokens, stop).

        Raises:
            RuntimeError: if vLLM does not return exactly one output per prompt, in
                order -- silently dropping or reordering outputs would misalign the
                answers with their prompts downstream.
        """
        sampling_params = vllm.SamplingParams(**sampling_kwargs)
        outputs = self.model_vllm.generate(input_text_list, sampling_params)
        if len(outputs) != len(input_text_list):
            raise RuntimeError(
                f"vLLM returned {len(outputs)} outputs for {len(input_text_list)} prompts"
            )

        output_text_list = []
        for input_text, output in zip(input_text_list, outputs):
            if output.prompt != input_text:
                raise RuntimeError("vLLM output does not match the prompt at the same position")
            output_text_list.append(output.outputs[0].text)
        return output_text_list


def _truncate_passage(tokenizer, passage, max_tokens) -> str:
    """Cut ``passage`` to at most ``max_tokens`` tokens of ``tokenizer``.

    ``max_length``/``truncation`` are honoured by fast HF tokenizers, but slow
    (pure-Python) tokenizers may ignore them inside ``tokenize``. The token list is
    therefore also sliced explicitly, so the bound always holds.

    Returns:
        The kept tokens converted back to text. A non-positive ``max_tokens``
        yields an empty string.
    """
    if max_tokens <= 0:
        return ''
    tokens = tokenizer.tokenize(passage, max_length=max_tokens, truncation=True)[:max_tokens]
    return tokenizer.convert_tokens_to_string(tokens)


def _extract_first_paired_parentheses(s):
    """Return the text inside the first complete top-level ``( ... )`` group of ``s``.

    Nested parentheses are kept in the result; the outer pair itself is excluded.
    A ``)`` with no open ``(`` is ignored. Returns ``None`` if no top-level group
    is ever closed.
    """
    depth = 0
    open_at = None
    for pos, ch in enumerate(s):
        if ch == '(':
            if depth == 0:
                open_at = pos  # start of a new top-level group
            depth += 1
        elif ch == ')' and depth > 0:
            depth -= 1
            if depth == 0:
                return s[open_at + 1:pos]
    return None


def load_prompt_dict():
    """(Re)build the module-level ``prompt_dict`` of prompt templates.

    Layout:
        prompt_dict["template"]["unifiedqa"]       -> Template(query, options, context)
        prompt_dict["template"]["llama"]["sciq"]    -> Template(context, query)
        prompt_dict["template"]["llama"]["scifact"] -> Template(context, query)
        prompt_dict["demo"]                         -> empty defaultdict(list)

    The ``\\n`` inside the templates is a literal backslash followed by 'n', not a newline.
    """
    global prompt_dict
    llama_templates = {
        'sciq': Template("Given the knowledge source: ${context} \\n Question: ${query} \\n Reply with one phrase. \\n Answer:"),
        'scifact': Template("Context: ${context} \\n Claim: ${query} \\n For the claim, the context is supportive, contradictory or not related? \\n Options: (A) Supportive (B) Contradictory (C) Not related \\n Answer:"),
    }
    prompt_dict = {
        "template": {
            "unifiedqa": Template("${query} \\n ${options} \\n ${context}"),
            "llama": llama_templates,
        },
        "demo": defaultdict(list),
    }



def get_prompt(args, model_name, question, options, context):
    """Render the LLM prompt for one query.

    Args:
        args: namespace whose ``task`` selects the template ('sciq' or 'scifact').
        model_name: must start with "llama"; other model families are not supported.
        question: the question (sciq) or claim (scifact) text.
        options: formatted answer options. Currently not used by the llama
            templates; kept for interface compatibility.
        context: retrieved passage text inserted into the prompt.

    Raises:
        ValueError: for an unsupported model family or task (previously an
            unsupported task raised UnboundLocalError).
    """
    load_prompt_dict()
    if not model_name.startswith("llama"):
        raise ValueError(f"model_name {model_name} not supported")

    llama_templates = prompt_dict["template"]["llama"]
    if args.task not in llama_templates:
        raise ValueError(f"task {args.task} not supported")
    return llama_templates[args.task].substitute(query=question, context=context)

def prepare_prompts(args, query_list, max_words=256):
    """Build one prompt record per query.

    Args:
        args: namespace providing ``top_passages`` and ``model_name`` (and ``task``,
            which ``get_prompt`` reads).
        query_list: list of dicts with keys
            - qid: str; ids containing 'scifact' are treated as claim verification
            - question: question / claim text
            - results: retrieved passage strings (presumably best first)
            - answers: list of gold answers
            - options: answer options; required only for non-scifact queries
        max_words: word budget for the concatenated passages or, when
            ``args.top_passages`` is set, the number of top passages to use.

    Returns:
        list of dicts with keys qid, options (formatted string), options_list,
        question, answers and input_text.

    Note:
        For scifact queries the gold labels 'yes'/'no' are rewritten in place to
        'supportive'/'contradictory' (idempotent, so repeated calls are harmless).
    """
    prompt_list = []
    for query in query_list:
        question = query["question"]
        results = query["results"]

        if 'scifact' in query['qid']:
            # Claim verification uses a fixed label set. Options are not read from
            # the query, so scifact entries need no 'options' key.
            options = "(A) yes (B) no"
            options_list = ['supportive', 'contradictory']
            query['answers'] = [
                'supportive' if ans == 'yes' else 'contradictory' if ans == 'no' else ans
                for ans in query['answers']
            ]
        else:
            options_list = query['options']
            # "(A) x (B) y ..." -- identical to the old format for 4 options, but no
            # longer raises IndexError for fewer.
            options = ' '.join(
                f"({chr(ord('A') + i)}) {opt}" for i, opt in enumerate(options_list)
            )

        if args.top_passages:
            # Here max_words is a passage count.
            context = ' '.join(results[:max_words])
        else:
            context = truncate_passage(" ".join(results), max_words=max_words)

        input_text = get_prompt(args, args.model_name, question=question, options=options, context=context)
        prompt_list.append({"qid": query["qid"],
                            "options": options,
                            "options_list": options_list,
                            "question": question,
                            "answers": query["answers"],
                            "input_text": input_text})

    print(f'[INFO] {len(prompt_list)} prompts generated')
    return prompt_list

def _choice_index(output, prompt):
    """Return the 0-based option index named by a letter answer, or None.

    '(A)'-prefixed text or a bare 'A' (likewise B, C, D) selects an option. C and D
    only count when ``prompt['options_list']`` is long enough to contain them.
    """
    for idx, letter in enumerate('ABCD'):
        if output.startswith(f'({letter})') or output == letter:
            if idx < 2 or len(prompt['options_list']) > idx:
                return idx
            return None  # a letter beyond the available options is not a valid choice
    return None


def inference(args, prompt_list, model):
    """Generate answers for ``prompt_list`` in batches and normalize them.

    Each raw generation is stripped and cut at the first literal ``\\n`` sequence.
    Letter answers are mapped to the matching option text. Any other answer is
    kept verbatim for the 'sciq' task and replaced by 'not related' otherwise. The
    normalized answers (one per prompt/generation pair) are printed and returned.
    """
    batch_size = 256
    raw_outputs = []
    for start in tqdm(range(0, len(prompt_list), batch_size)):
        batch_texts = [item["input_text"] for item in prompt_list[start:start + batch_size]]
        raw_outputs.extend(model.generate(
            batch_texts, temperature=args.temperature, top_p=args.top_p,
            max_tokens=args.max_new_tokens, stop=["\n"]
        ))

    normalized = []
    for prompt, raw in zip(prompt_list, raw_outputs):
        answer = raw.strip().split('\\n')[0].strip()
        idx = _choice_index(answer, prompt)
        if idx is not None:
            answer = prompt['options_list'][idx]
        elif args.task != 'sciq':
            answer = 'not related'
        normalized.append(answer)

    print(normalized)

    return normalized

def compute_score(prompt_list, output_text_list, max_words=256):
    """Score model outputs with exact match and F1 against each prompt's gold answers.

    Gold answers are normalized with ``metrics.normalize_answer``, and the best
    score over all of them is kept. Returns one dict per entry of ``prompt_list``:
    ``{"qid": ..., "em_<max_words>": float, "f1_<max_words>": float}``.
    If a qid occurs more than once, every occurrence reports the scores of the last
    scored occurrence. Prompts beyond the end of ``output_text_list`` get only "qid".
    """
    from metrics import normalize_answer, f1_score, exact_match_score, metric_max_over_ground_truths

    em_key = f'em_{max_words}'
    f1_key = f'f1_{max_words}'
    scores_by_qid = defaultdict(dict)
    for prompt, prediction in zip(prompt_list, output_text_list):
        gold = [normalize_answer(ans) for ans in prompt["answers"]]
        em_value = metric_max_over_ground_truths(exact_match_score, prediction, gold)
        f1_value = metric_max_over_ground_truths(f1_score, prediction, gold)
        scores = scores_by_qid[prompt["qid"]]
        scores[em_key] = float(em_value)
        scores[f1_key] = float(f1_value)

    return [{"qid": prompt["qid"], **scores_by_qid[prompt["qid"]]} for prompt in prompt_list]

def save_results(args, eval_list, output_path):
    """Write ``eval_list`` to ``output_path`` as JSON Lines (one JSON object per line).

    The file is overwritten. ``args`` is unused and kept for interface
    compatibility. The previous body called an undefined ``write_jsonl`` and
    always raised NameError.
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        for item in eval_list:
            f.write(json.dumps(item) + '\n')


def prepare_data(dataset, encoder, args=None):
    """Load the queries and the retrieval runs to evaluate for one dataset/encoder.

    Args:
        dataset: dataset name used in the ROOT_DIR-relative paths (e.g. 'sciq').
        encoder: retriever name used in the run directory path.
        args: CLI namespace providing ``prop_bm25_dir``. When omitted, the
            module-level ``args`` set by the ``__main__`` block is used (the
            historical behaviour).

    Returns:
        (retrieval_dict, deduplicates):
        - retrieval_dict maps run name -> list of records from ``get_records``,
          with passage text looked up in the ``prop_bm25_dir`` index;
        - deduplicates is the qa.jsonl query list with the first entry per qid.

    Raises:
        ValueError: if no ``args`` is given and none exists at module level.
    """
    if args is None:
        args = globals().get('args')
        if args is None:
            raise ValueError("prepare_data() needs `args` (with prop_bm25_dir) when not run as a script")

    searcher = LuceneSearcher(args.prop_bm25_dir)

    qa_file = os.path.join(ROOT_DIR, f'retrieval/task_topics/{dataset}/qa.jsonl')
    deduplicates = deduplicate_dicts(readjsonl(qa_file), 'qid')

    # Only these proposition-level runs are evaluated. Chunk-level runs
    # (run.whole.chunk.<encoder>.hits-500.txt, rrf/qid_cid_max_conv.tsv, rrf/conv.tsv,
    # looked up in a LuceneSearcher over args.doc_bm25_dir) used to be loaded here
    # too but were never put in retrieval_dict, so they only cost time (potentially
    # one index lookup per hit). Load them and add them below to evaluate them.
    retrieval_dict = {
        'qid_max': get_records(os.path.join(ROOT_DIR, f'retrieval/runs_{dataset}/{encoder}/rrf_prop/qid_max.tsv'), searcher=searcher),
        'rrf_qid_max': get_records(os.path.join(ROOT_DIR, f'retrieval/runs_{dataset}/{encoder}/rrf_prop/rrf_qid_max.tsv'), searcher=searcher),
    }

    return retrieval_dict, deduplicates

def score_main(args, retrieval_dict, deduplicates, model, dataset, encoder, max_words):
    """Evaluate every retrieval run at one context budget and append scores to LOG_PATH.

    Args:
        args: CLI namespace (seed, top_passages, task, model_name, sampling options).
        retrieval_dict: run name -> list of retrieval records (see ``get_records``).
        deduplicates: query dicts; their 'results' entry is overwritten for each run.
        model: an ``LLM`` instance.
        dataset, encoder: labels copied into each logged result.
        max_words: context budget -- a word limit, or a passage count when
            ``args.top_passages`` is set.

    Raises:
        KeyError: if a query has no retrieved passages in a run.
    """
    random.seed(args.seed)

    # Smoke-test the engine before the long evaluation loop.
    print("Test LLM ...")
    print(model.generate(["Hi! How are you?"], temperature=0, max_tokens=32, stop=[]))
    print("Done.")

    for _name, _retrieval in retrieval_dict.items():

        qid_content_dict = list2dict(_retrieval)

        for query in deduplicates:
            qid = query['qid']
            if qid not in qid_content_dict:
                raise KeyError(f"no retrieved passages for qid {qid!r} in run '{_name}'")
            query['results'] = qid_content_dict[qid]

        prompt_list = prepare_prompts(args, deduplicates, max_words=max_words)

        output_text_list = inference(args, prompt_list, model)

        result_list = compute_score(prompt_list, output_text_list, max_words=max_words)

        result_dict = {
            'dataset': dataset,
            'encoder': encoder,
            'retrieval': _name,
            'em': percent_round(np.mean([_x[f'em_{max_words}'] for _x in result_list])),
            'f1': percent_round(np.mean([_x[f'f1_{max_words}'] for _x in result_list])),
        }

        # Record which kind of budget max_words represents.
        if args.top_passages:
            result_dict['top_passages'] = max_words
        else:
            result_dict['max_words'] = max_words

        append_to_result(result_dict, LOG_PATH)


def main(args):
    """Load the selected LLM and evaluate every (dataset, encoder, context budget) combination.

    With ``args.top_passages`` the budget is the number of top-ranked passages put
    in the prompt (1-5). Otherwise it is a word limit on the concatenated passages
    (10-500 words).

    Raises:
        ValueError: if ``args.model_name`` is not a supported model key.
    """
    hf_models = {
        'llama3': 'meta-llama/Meta-Llama-3-8B-Instruct',
        'llama': 'meta-llama/Llama-2-7b-chat-hf',
        'llama_instruct': 'togethercomputer/Llama-2-7B-32K-Instruct',
    }
    if args.model_name not in hf_models:
        raise ValueError(f"model_name {args.model_name} not supported")

    model = LLM(hf_models[args.model_name])

    if args.top_passages:
        budgets = [1, 2, 3, 4, 5]
        batch_threshold, large_budget_batch_size = 2, 4
    else:
        budgets = [10, 20, 50, 100, 200, 500]
        batch_threshold, large_budget_batch_size = 200, 2

    for dataset in ['sciq']:
        for encoder in encoders:
            # prepare_data takes (dataset, encoder) and reads the index directories from
            # the module-level ``args`` -- the same namespace passed here when the file
            # runs as a script. (The old call passed args positionally: TypeError.)
            retrieval_dict, data_points = prepare_data(dataset, encoder)
            for budget in budgets:
                # NOTE(review): inference() uses its own fixed batch size and does not read
                # args.batch_size; it is still set for compatibility.
                args.batch_size = 8 if budget < batch_threshold else large_budget_batch_size
                args.task = dataset
                score_main(args, retrieval_dict, data_points, model, dataset, encoder, budget)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=0)

    # for model
    parser.add_argument('--model_name', type=str, default='llama3',
                        choices=['llama3', 'llama', 'llama_instruct'])
    parser.add_argument('--top_passages', action='store_true', default=False)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--max_new_tokens", type=int, default=32)
    # pyserini index directories used to look up passage text
    parser.add_argument("--prop_bm25_dir", type=str, default='prop_bm25_dir')
    parser.add_argument("--doc_bm25_dir", type=str, default='doc_bm25_dir')

    # `args` and `LOG_PATH` are module-level names here (this block is not a function),
    # so other functions in the file can read them without a `global` statement.
    args = parser.parse_args()
    LOG_PATH = os.path.join(ROOT_DIR, f'context_{args.model_name}.log')

    print(args)

    main(args)
