"""The code is adapted from https://github.com/princeton-nlp/ALCE/blob/main/eval.py"""
import copy
import json
import logging
import random
import re
from argparse import ArgumentParser

import os
import numpy as np
import pandas as pd
import torch
from nltk import sent_tokenize
from tqdm import tqdm
from transformers import (
    AutoTokenizer, AutoModelForCausalLM
)
from concurrent.futures import ThreadPoolExecutor, as_completed

# Seed Python's RNG so any use of `random` in this module is repeatable.
random.seed(0)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)


# Judge model and tokenizer read by `_run_nli_autoais`; both stay None until
# `main` loads them. Plain assignments are enough here: a `global` statement
# has no effect at module scope.
mistral_7b_instruct = None
mistral_7b_tokenizer = None

# Module-level counter printed at the end of `main`.
total = 0


def remove_citations(sent):
    """Return `sent` with citation markers such as ``[1]`` stripped out.

    The steps run in this order: drop every ``[<digits>`` opener that follows a
    space (together with that space), drop any remaining ``[<digits>`` opener,
    delete each space-plus-pipe sequence, and finally delete every ``]``.
    The last step removes all closing brackets in the string, not only the
    ones that belonged to a citation.
    """
    without_spaced_openers = re.sub(r" \[\d+", "", sent)
    without_openers = re.sub(r"\[\d+", "", without_spaced_openers)
    without_bars = without_openers.replace(" |", "")
    return without_bars.replace("]", "")


def truncate_paragraph(paragraph, max_words):
    """Split `paragraph` into chunks of whole sentences of at most `max_words` words.

    Sentences are found with a regex that splits on whitespace following '.' or
    '?', while skipping patterns that look like abbreviations. Consecutive
    sentences are packed greedily into a chunk until adding the next one would
    exceed `max_words` (words are whitespace-separated tokens).

    A single sentence longer than `max_words` is never cut: it becomes a chunk
    of its own. No empty chunk is emitted in front of such a sentence.

    Args:
        paragraph: text to split.
        max_words: word budget per chunk.

    Returns:
        List of chunk strings, in document order. An empty `paragraph` yields
        ``['']``.
    """
    # Tokenize paragraph into sentences
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', paragraph)

    trunks = []
    current_trunk = []
    current_word_count = 0

    for sentence in sentences:
        sentence_word_count = len(sentence.split())

        if current_word_count + sentence_word_count <= max_words:
            current_trunk.append(sentence)
            current_word_count += sentence_word_count
            continue

        # The sentence does not fit. Flush what has been collected so far, but
        # only if there is something to flush: when the very first sentence is
        # already over budget the accumulator is still empty.
        if current_trunk:
            trunks.append(' '.join(current_trunk))
        current_trunk = [sentence]
        current_word_count = sentence_word_count

    if current_trunk:
        trunks.append(' '.join(current_trunk))

    return trunks


def _run_nli_autoais(passage, claim, partial):
    """
    Ask the judge model whether `passage` supports `claim`.
    Adapted from https://github.com/google-research-datasets/Attributed-QA/blob/main/evaluation.py

    The passage is split into chunks of at most 500 words and the model is
    queried once per chunk with a Yes/No question. The first chunk answered
    with 'Yes' settles the result.

    Args:
        passage: source text to check against.
        claim: statement to verify.
        partial: if True, ask whether the source at least partially supports
                 the claim; otherwise ask whether the claim is faithful to it.

    Returns:
        1 if the answer for any chunk starts with 'Yes', otherwise 0.
    """
    global mistral_7b_instruct, mistral_7b_tokenizer

    marker = '[/INST]'
    for trunk in truncate_paragraph(passage, 500):
        if partial:
            s = f"Can the source at least partially support the claim? Start your answer with 'Yes' or 'No'.\nSource: {trunk}\nClaim: {claim}"
        else:
            s = f"Is the claim faithful to the source? A claim is faithful to the source if the core part in the claim can be supported by the source.\nStart your answer with 'Yes' or 'No'.\nSource: {trunk}\nClaim: {claim}"
        messages = [{'role': 'user', 'content': s}]
        encodeds = mistral_7b_tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to('cuda')
        generated_ids = mistral_7b_instruct.generate(model_inputs, max_new_tokens=200, do_sample=False, pad_token_id=mistral_7b_tokenizer.eos_token_id)
        # NOTE(review): `temperature` is not a decoding option and is presumably
        # ignored by `batch_decode` - confirm before removing it.
        decoded = mistral_7b_tokenizer.batch_decode(generated_ids, temperature=0)[0]

        # The decoded text contains the prompt too; keep only what follows the
        # instruction terminator. If the marker is missing, use the whole text
        # instead of slicing at an arbitrary offset.
        marker_pos = decoded.find(marker)
        res = decoded if marker_pos == -1 else decoded[marker_pos + len(marker):]

        if res.strip().startswith('Yes'):
            return 1

    # No chunk was judged supportive (or the passage produced no chunks).
    return 0


def compute_autoais(data,
                    decontext=False,
                    concat=False,
                    qampari=False,
                    at_most_citations=None):
    """
    Compute citation recall and precision (AutoAIS) for generated outputs.

    For every sentence the cited documents are concatenated and the judge is
    asked whether they support the sentence (recall). For supported sentences
    that cite more than one document, each citation is then checked for
    necessity (precision): a citation counts as precise when its document alone
    at least partially supports the sentence, or when the remaining cited
    documents no longer support the sentence without it.

    Args:
        data: iterable of items; each requires the fields `output` and `docs`
              (plus `question` when `qampari` is set)
              - docs should be a list of items with a `title` and either `sent`
                (used when present) or `text`
        decontext: accepted for interface compatibility; unused here.
        concat: accepted for interface compatibility; unused here.
        qampari: treat `output` as a comma-separated answer list and prefix
                 each answer with the question, instead of sentence-tokenizing
                 with NLTK.
        at_most_citations: if not None, only the first N citations of each
                           sentence are considered.

    Returns:
        dict with
        - `evaluation_logs`: one entry per sentence (`sent`, `target_sent`,
          `ref`, `joint_entail`, `unnecessary_citations`)
        - `citation_rec`: mean over items of supported sentences / sentences,
          times 100
        - `citation_prec`: mean over items of precise citations / valid
          citations (0 for an item without valid citations), times 100
        Items whose output yields no sentences are skipped.
    """

    def _format_document(doc):
        """Format document for AutoAIS."""
        # QA-extracted docs carry their content under `sent`, others under `text`.
        body = doc['sent'] if "sent" in doc else doc['text']
        return "Title: %s\n%s" % (doc['title'], body)

    ais_scores = []
    ais_scores_prec = []
    eval_log = []

    for item in data:
        # Get sentences by using NLTK
        if qampari:
            sents = [item['question'] + " " + x.strip() for x in
                     item['output'].rstrip().rstrip(".").rstrip(",").split(",")]
        else:
            sents = sent_tokenize(item['output'])
        if len(sents) == 0:
            continue

        target_sents = [remove_citations(sent).strip() for sent in sents]
        num_docs = len(item['docs'])

        entail = 0
        entail_prec = 0
        total_citations = 0
        for sent, target_sent in zip(sents, target_sents):
            joint_entail = -1  # Undecided

            # Find references
            ref = [int(r[1:]) - 1 for r in re.findall(r"\[\d+", sent)]  # In text citation id starts from 1
            logger.info(f"For `{sent}`, find citations {ref}")
            if len(ref) == 0:
                # No citations
                joint_entail = 0
            elif any(ref_id < 0 or ref_id >= num_docs for ref_id in ref):
                # Citations out of range. The lower bound matters: "[0]" maps to
                # index -1, which would otherwise silently pick the last document.
                joint_entail = 0
            else:
                if at_most_citations is not None:
                    ref = ref[:at_most_citations]
                total_citations += len(ref)
                joint_passage = '\n'.join([_format_document(item['docs'][psgs_id]) for psgs_id in ref])

            # If not directly rejected by citation format error, calculate the recall score
            if joint_entail == -1:
                joint_entail = _run_nli_autoais(joint_passage, target_sent, partial=False)

            entail += joint_entail
            if joint_entail == 0:
                logger.info(f'[Unsupported sentence] {sent}')

            unnecessary_citations = []

            # calculate the precision score if applicable
            if joint_entail and len(ref) > 1:
                # Precision check: did the model cite any unnecessary documents?
                for psgs_id in ref:
                    # condition A: the document alone gives at least partial support
                    passage = _format_document(item['docs'][psgs_id])
                    if _run_nli_autoais(passage, target_sent, partial=True):
                        entail_prec += 1
                        continue

                    # condition B: the other citations no longer support the
                    # sentence once this one is left out
                    subset_exclude = list(ref)
                    subset_exclude.remove(psgs_id)
                    passage = '\n'.join([_format_document(item['docs'][pid]) for pid in subset_exclude])
                    if _run_nli_autoais(passage, target_sent, partial=False):
                        # psgs_id is not necessary
                        logger.info(f'[Unnecessary citation] sent: {sent} citation: [{psgs_id}]')
                        unnecessary_citations.append(psgs_id)
                    else:
                        entail_prec += 1
            else:
                # Zero or one usable citation, or an unsupported sentence: the
                # precision credit equals the recall verdict.
                entail_prec += joint_entail

            eval_log.append({
                "sent": sent,
                "target_sent": target_sent,
                "ref": ref,
                "joint_entail": joint_entail,
                "unnecessary_citations": unnecessary_citations,
            })

        ais_scores.append(entail / len(sents))
        ais_scores_prec.append(entail_prec / total_citations if total_citations > 0 else 0)

    citation_rec = 100 * np.mean(ais_scores)
    citation_prec = 100 * np.mean(ais_scores_prec)

    return {
        "evaluation_logs": eval_log,
        "citation_rec": citation_rec,
        "citation_prec": citation_prec,
    }


def load_str(path):
    """Return the text content of the file at `path`, exactly as stored.

    The file is read in one call. Joining `readlines()` with a newline would
    double every line break, because each line already ends with its own.
    """
    with open(path, 'r') as f:
        return f.read()


def load_json(path):
    """Read the file at `path` and return its decoded JSON content."""
    with open(path, 'r') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def extract_url(text):
    """Return the URL captured from a ``[number]: text (url)`` entry in `text`.

    Only the first match of the pattern is used. Returns None when the pattern
    does not occur.
    """
    found = re.search(r'\[\d+\]: .* \((https?://[^\)]+)\)', text)
    return found.group(1) if found else None


def expand_citaions(output):
    """
    Give sentences that carry no citation the citations of a neighbouring sentence.

    Each non-empty line of `output` is treated as a paragraph and split into
    sentences. For a sentence without any ``[n]`` marker:
        1. if a following sentence exists, its citations are copied:
           "<sentence 1>. <sentence 2> [2][3]." becomes
           "<sentence 1>[2][3]. <sentence 2> [2][3]."
        2. if it is the last sentence of the paragraph, the citations of the
           previous sentence are copied:
           "<sentence 1> [1]. <last sentence>." becomes
           "<sentence 1> [1]. <last sentence>[1]."
    Neighbours are looked up in the original sentences, so borrowed citations
    do not propagate further.

    Borrowed citations are placed in front of the closing '.', '?' or '!'. A
    sentence without such a terminator gets them appended, so its last
    character is never split off from the word it belongs to.

    Returns:
        The rewritten text; sentences are joined with a space, paragraphs with
        a newline, and empty lines are dropped.
    """

    def find_citations(sentence):
        return re.findall(r'\[(\d+)\]', sentence)

    sentence_endings = r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s"

    modified_pargraphs = []
    for paragraph in output.split("\n"):
        if len(paragraph) == 0:
            continue
        sentences = [piece.strip() for piece in re.split(sentence_endings, paragraph) if piece]
        last_idx = len(sentences) - 1
        modified_sentences = []
        for sentence_idx, sentence in enumerate(sentences):
            if len(sentence) == 0:
                continue

            added_citations = ''
            if not find_citations(sentence):
                if sentence_idx == last_idx and sentence_idx >= 1:
                    donor = sentences[sentence_idx - 1]
                elif sentence_idx < last_idx:
                    donor = sentences[sentence_idx + 1]
                else:
                    # Only sentence in the paragraph: nothing to borrow from.
                    donor = ''
                added_citations = ''.join(f"[{citation}]" for citation in find_citations(donor))

            if sentence[-1] in '.?!':
                modified_sentences.append(sentence[:-1] + added_citations + sentence[-1])
            else:
                modified_sentences.append(sentence + added_citations)
        modified_pargraphs.append(" ".join(modified_sentences))
    return '\n'.join(modified_pargraphs).strip()


def format_data(root_dir, file_name_suffix, do_citation_expansion=False):
    """
    Load one generated article and its search results for citation scoring.

    Reads `{root_dir}{file_name_suffix}.txt` (the article) and
    `{root_dir}_search_results.json` (the sources). Note that `root_dir` is
    used as a plain string prefix; no path separator is inserted.

    Args:
        root_dir: path prefix shared by the article and search-result files.
        file_name_suffix: suffix that completes the article file name.
        do_citation_expansion: if True, run `expand_citaions` on the article.

    Returns:
        (output, docs): `output` is the article text without empty lines and
        without lines starting with '#'; `docs` is a list of
        ``{'title', 'text'}`` dicts, one per URL in the order of the search
        results, where `text` is the URL's distinct snippets joined by newlines.
    """
    final_page = load_str(f'{root_dir}{file_name_suffix}.txt')
    search_results = load_json(f'{root_dir}_search_results.json')
    if 'url_to_info' in search_results:
        url_to_info = search_results['url_to_info']
        # The i-th doc must correspond to citation index i.
        assert list(search_results['url_to_unified_index'].keys()) == list(url_to_info.keys())
    else:
        url_to_info = {d['url']: {'title': d['title'], 'snippets': d['snippets']} for d in search_results}

    # Drop empty lines and lines starting with '#'.
    output = '\n'.join(
        line for line in final_page.split('\n')
        if len(line) != 0 and line[0] != '#'
    ).strip()

    if do_citation_expansion:
        output = expand_citaions(output)

    # De-duplicate snippets while keeping their original order. A plain `set`
    # would make the text order (and therefore the chunks shown to the judge)
    # vary from run to run under string hash randomisation.
    docs = [
        {
            'title': info['title'],
            'text': '\n'.join(dict.fromkeys(info['snippets'])),
        }
        for info in url_to_info.values()
    ]

    return output, docs


def process_method(article_name, method_name, method_data):
    """Score a single method's output for a single article.

    `method_data` is passed to `compute_autoais` as a one-item dataset.

    Returns:
        (method_name, summary) where `summary` holds `recall`, `precision`,
        the `article` name and the per-sentence `eval_log`.
    """
    scores = compute_autoais(data=[method_data])
    summary = {
        "recall": scores['citation_rec'],
        "precision": scores['citation_prec'],
        "article": article_name,
        "eval_log": scores['evaluation_logs'],
    }
    return method_name, summary

def main(args):
    """
    Grade every article/method pair in `args.file_to_grade` and print averages.

    Steps:
        1. Load the Mistral-7B-Instruct judge and its tokenizer onto the GPU.
        2. For each article, score every method and write the results to
           `<args.output_path>/<article_name>.json`.
        3. Read back every `.json` file in `args.output_path` and print the
           average recall and precision per method.

    Args:
        args: namespace with `file_to_grade` (JSON mapping article name ->
              method name -> item for `compute_autoais`) and `output_path`
              (directory for per-article results).
    """
    global mistral_7b_instruct, mistral_7b_tokenizer
    mistral_7b_instruct = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
    mistral_7b_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
    mistral_7b_instruct = mistral_7b_instruct.to('cuda')

    with open(args.file_to_grade) as f:
        data_to_grade = json.load(f)

    os.makedirs(args.output_path, exist_ok=True)

    for article_name, article_data in tqdm(list(data_to_grade.items()), "total process: "):
        grading_result = {}
        for method_name, method_data in article_data.items():
            method_name, result = process_method(article_name, method_name, method_data)
            grading_result[method_name] = result

        # Write to a separate file for each article_name
        article_output_path = os.path.join(args.output_path, f"{article_name}.json")
        with open(article_output_path, "w") as f:
            json.dump(grading_result, f, indent=2)

    # Aggregate over every JSON file found in the output directory. This also
    # picks up result files that were already there before this run.
    all_grading_results = {}
    for filename in sorted(os.listdir(args.output_path)):  # sorted: stable report order
        if not filename.endswith(".json"):
            continue
        with open(os.path.join(args.output_path, filename), "r") as f:
            grading_result = json.load(f)
        for method_name, evaluation in grading_result.items():
            # The per-sentence log is not needed for the averages; tolerate
            # result files that do not contain it.
            evaluation.pop("eval_log", None)
            all_grading_results.setdefault(method_name, []).append(evaluation)

    # Compute and print aggregate statistics
    for method_name, evaluations in all_grading_results.items():
        count = len(evaluations)
        avg_recall = sum(evaluation["recall"] for evaluation in evaluations) / count
        avg_precision = sum(evaluation["precision"] for evaluation in evaluations) / count

        print(f"Method: {method_name}")
        print(f"Average Recall: {avg_recall}")
        print(f"Average Precision: {avg_precision}")
        print()

    print(f"total: {total}")

if __name__ == "__main__":
    # Script entry point: read the input file and output directory from the
    # command line, keep this module's logger at ERROR level, then grade.
    cli = ArgumentParser()
    cli.add_argument(
        "-f", "--file-to-grade",
        default="./grading_items.json",
        help="grading file",
    )
    cli.add_argument(
        "-o", "--output-path",
        default="./grading_output",
        help="grading result dir",
    )
    args = cli.parse_args()
    logger.setLevel(logging.ERROR)
    main(args)