import json

import os
import sys

import nltk
import torch
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Union, Dict

import logging
import openai
import pandas as pd
import re

from dotenv import load_dotenv
from nltk import sent_tokenize, word_tokenize, WordNetLemmatizer
from tenacity import retry, wait_fixed, stop_after_attempt, wait_exponential

from colbert.gpt import GPT, clean_response


def re_extract(prefix: str, target: str, truncate: bool = False):
    """Pull the value that follows ``prefix ... : `` out of an LLM response.

    ``prefix`` is interpolated into a regular expression verbatim, so callers
    must escape regex metacharacters themselves (e.g. ``r"Expand\\[Query\\]"``).
    Because both wildcard runs in the pattern are greedy, the captured value is
    whatever follows the last ``": "`` that comes after a ``prefix`` match.

    Args:
        prefix: Regex fragment that must appear before the ``": "`` separator.
        target: Raw text to search.
        truncate: When True, only the first line of ``target`` is considered.

    Returns:
        The captured value (or the whole stripped text when nothing matches),
        with tabs replaced by spaces and one pair of surrounding single or
        double quotes removed.
    """
    if truncate:
        text = target.split("\n")[0].strip()
    else:
        text = target.strip()

    pattern = r"[\w\W]*{}[\w\W]*: ([\w\W]+)".format(prefix)
    matches = re.findall(pattern, text)
    extracted = matches[0] if matches else text
    extracted = extracted.replace("\t", " ")

    # Strip exactly one pair of matching quotes, double quotes checked first.
    for quote in ('"', "'"):
        if extracted.startswith(quote) and extracted.endswith(quote):
            return extracted[1:-1]
    return extracted


class PromptGenerator:
    """Build and track the chat history sent to the LLM.

    The history always starts with one system message (added in ``__init__``).
    User and assistant turns are appended with :meth:`append_history` and
    removed again, keeping the system message, with :meth:`clear_history`.

    :meth:`get_prompt` looks up per-job templates in ``self.instructions`` and
    ``self.response_formats``. Nothing in this class defines them, so they must
    be provided by a subclass or assigned on the instance before use.
    """

    # this example is for general retrieval
    react_examples = {
        # these examples are from msmarco
        1: "Example 1:\n"
        "Observe[Query]: walgreens store sales average\n"
        "Observe[Doc 1]: The average Walgreens salary ranges from approximately $15,000 per year for Customer Service Associate / Cashier to $179,900 per year for District Manager. Average Walgreens hourly pay ranges from approximately $7.35 per hour for Laboratory Technician to $68.90 per hour for Pharmacy Manager. Salary information comes from 7,810 data points collected directly from employees, users, and jobs on Indeed.\n"
        "Think: The query asked about walgreens store sales on average. \n"
        "Expand[Query]: [salary, Walgreens]\n",
        2: "Example 2:\n"
        "Observe[Query]: dna in bacteria\n"
        "Observe[Doc 1]: 'Bacterial DNA in Human Genomes. A new study finds strong evidence that bacteria can transfer genes into human genomes, especially in cancer cells. By Ed Yong | June 20, 2013. Pseudomonas, one of the bacteria groups that have transferred genes to human'\n"
        "Think: the query is looking for dna and bacteria. The given document talks about bacterial DNA in human genome.\n"
        "Expand[Query]: [DNA, Genomes, cells]\n",
    }
    # Few-shot examples per dataset; every supported dataset currently shares
    # the generic examples above.
    dataset_specific_examples = {
        "webis-touche2020": react_examples,
        "nfcorpus": react_examples,
        "scidocs": react_examples,
        "fiqa": react_examples,
        "trec-covid-v2": react_examples,
        "scifact": react_examples,
    }

    _ROLES = ("user", "assistant", "system")

    def __init__(self, dataset_name: str):
        """Start a history containing only the system prompt.

        Args:
            dataset_name: Key into ``dataset_specific_examples`` (stored only;
                not validated here).
        """
        self.conversations: List[Dict[str, str]] = []
        self.response_format = None
        self.instruction = None
        self.dataset_name = dataset_name

        self.append_history(role="system", content=self._get_system_prompt())

    @staticmethod
    def __system_status():
        return "You are an intelligent assistant who can help user what to query with more clear intent."

    def __response_format(self):
        return f'Response Format: "{self.response_format}"'

    @staticmethod
    def __restriction():
        return (
            "Restriction1: The response must follow the given response format.\n"
            "Restriction2: Violating restriction 1 is strongly forbidden.\n"
            "Restriction3: Any other explanation or note cannot be added to the response."
        )

    def _get_system_prompt(self):
        """Return the system message text."""
        return self.__system_status()

    def __instruction(self):
        return f"Instruction: {self.instruction}"

    @staticmethod
    def __input_data(data: Dict[str, str]):
        # One "Key: value" line per entry, with the key capitalized.
        return "\n".join(f"{key.capitalize()}: {value}" for key, value in data.items())

    def _get_user_prompt(self, data: Dict[str, str]):
        """Render the instruction, a ``###`` separator, then ``data``."""
        return f"{self.__instruction()}\n" f"###\n" f"{self.__input_data(data)}"

    def get_prompt(self, job_type: str, data: Dict[str, str]):
        """Build a stand-alone two-message prompt for ``job_type``.

        Args:
            job_type: Key into ``self.instructions`` / ``self.response_formats``.
            data: Fields rendered as ``Key: value`` lines in the user message.

        Returns:
            ``[system_message, user_message]``. The stored history is not used
            or modified.

        Raises:
            ValueError: If no instruction template exists for ``job_type``
                (including when no templates were provided at all).
            KeyError: If an instruction exists but no response format does.
        """
        # The templates are not defined on this class. Look them up defensively
        # so a missing template gives a clear error instead of AttributeError.
        instructions = getattr(self, "instructions", None) or {}
        response_formats = getattr(self, "response_formats", None) or {}
        if job_type not in instructions:
            raise ValueError(
                f"Unknown job_type {job_type!r}; available: {sorted(instructions)}"
            )

        self.instruction = instructions[job_type]
        self.response_format = response_formats[job_type]
        return [
            {"role": "system", "content": self._get_system_prompt()},
            {"role": "user", "content": self._get_user_prompt(data)},
        ]

    def append_history(self, role, content):
        """Append one chat message and print a ``.`` progress marker.

        Raises:
            ValueError: If ``role`` is not user, assistant or system.
        """
        if role not in self._ROLES:
            raise ValueError(f"Invalid role {role!r}; expected one of {self._ROLES}")
        self.conversations.append({"role": role, "content": content})
        print(".", end=" ")

    def get_history(self):
        """Return the live conversation list (not a copy)."""
        return self.conversations

    def clear_history(self):
        """Drop every message except the first (the system prompt)."""
        del self.conversations[1:]


class LoggerBase:
    """Wrap a named ``logging`` logger that writes bare message lines.

    Records at INFO and above are formatted as the message text only. They go
    to ``filepath`` and, optionally, to stdout.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        filepath: File receiving the log lines (written as UTF-8).
        console_print: Also echo lines to stdout when True.
        continuing: Append to ``filepath`` when True; truncate it otherwise.
    """

    def __init__(
        self,
        name: str,
        filepath: str,
        console_print: bool = True,
        continuing: bool = False,
    ):
        self.logger = logging.getLogger(name)
        self.logger.setLevel(level=logging.DEBUG)
        # getLogger() returns the same object for the same name. Remove and close
        # handlers left by an earlier instance; otherwise every line is written
        # once per instance and the previous file stays open.
        for old_handler in list(self.logger.handlers):
            self.logger.removeHandler(old_handler)
            old_handler.close()
        # Keep these records out of the root logger so handlers configured
        # there do not print them a second time.
        self.logger.propagate = False

        log_stream_formatter = logging.Formatter(fmt="%(message)s")

        # Explicit UTF-8 so query/document text is written identically on every
        # platform instead of in the locale's default encoding.
        datafile_handler = logging.FileHandler(
            filename=filepath,
            mode="a+" if continuing else "w+",
            encoding="utf-8",
        )
        datafile_handler.setFormatter(log_stream_formatter)
        datafile_handler.setLevel(level=logging.INFO)
        self.logger.addHandler(datafile_handler)

        if console_print:
            console_handler = logging.StreamHandler(stream=sys.stdout)
            console_handler.setFormatter(log_stream_formatter)
            console_handler.setLevel(level=logging.INFO)
            self.logger.addHandler(console_handler)


class LoggerQuery(LoggerBase):
    """Logger named ``NewQuery`` that writes one ``qid<TAB>query`` line per call.

    Console echo is left at the ``LoggerBase`` default.
    """

    def __init__(self, filepath: str, continuing: bool):
        super().__init__("NewQuery", filepath, continuing=continuing)

    def log(self, qid, query):
        """Record the (expanded) query text produced for ``qid``."""
        line = "{}\t{}".format(qid, query)
        self.logger.info(line)


class LoggerIntentQuery(LoggerBase):
    """File-only logger (``Intent-NewQuery``) with one tab-joined row per query."""

    def __init__(self, filepath: str, continuing: bool):
        super().__init__(
            "Intent-NewQuery", filepath, console_print=False, continuing=continuing
        )

    def log(self, prev_query: str, docs: List[str], intent: str, new_query: str):
        """Write ``prev_query``, ``docs``, ``intent`` and ``new_query`` tab-separated.

        ``docs`` is rendered with the list's default string form.
        """
        fields = (prev_query, docs, intent, new_query)
        self.logger.info("\t".join(format(field) for field in fields))


class LoggerEverything(LoggerBase):
    """File-only logger that writes each call's keyword arguments as one JSON line."""

    def __init__(self, filepath: str, continuing: bool, name: str = "Everything"):
        # Positional order of LoggerBase: name, filepath, console_print, continuing.
        super().__init__(name, filepath, False, continuing)

    def log(self, **kwargs):
        """Serialize ``kwargs`` with ``json.dumps`` and log it."""
        record = json.dumps(kwargs)
        self.logger.info(record)


# Module-level WordNet lemmatizer shared by is_plural() below.
wnl = WordNetLemmatizer()


def is_plural(word, lemmatizer=None):
    """Report whether ``word`` looks like a plural noun, and return its lemma.

    The word is lemmatized as a noun (``pos="n"``). It counts as plural when
    the lemma differs from the word.

    Args:
        word: Token to check.
        lemmatizer: Object with a ``lemmatize(word, pos)`` method. Defaults to
            the module-level ``wnl`` WordNet lemmatizer.

    Returns:
        ``(plural, lemma)``.
    """
    if lemmatizer is None:
        lemmatizer = wnl
    lemma = lemmatizer.lemmatize(word, "n")
    # Compare by value, not identity: an unchanged lemma may still be a
    # different str object from ``word``.
    return word != lemma, lemma


if __name__ == "__main__":
    # Script entry point. For each query not yet in --new_queries:
    #   1. Send the task description with few-shot examples.
    #   2. Show the query and its top-k PRF document, then ask for a thought.
    #   3. Ask for expansion terms.
    #   4. Log the cleaned terms.
    # The chat history is reset to the system prompt after every query.
    parser = ArgumentParser(
        # Passed as `description`: the first positional argument is `prog`.
        description="Obtain tacit intent of the query and refine with the generated intent."
    )

    parser.add_argument(
        "--dataset_name",
        type=str,
        required=True,
        # Validate up front instead of hitting a KeyError inside the loop.
        choices=sorted(PromptGenerator.dataset_specific_examples),
        help="Dataset name; selects the few-shot examples.",
    )

    parser.add_argument(
        "--new_queries",
        type=str,
        required=True,
        help="Output TSV of 'qid<TAB>expansion terms'; also read to resume.",
    )
    parser.add_argument(
        "--new_queries_intents",
        type=str,
        required=True,
        help="Output TSV of query, docs, intent and expansion terms.",
    )
    parser.add_argument(
        "--everything",
        type=str,
        required=True,
        help="Output log: full chat history per query, one JSON object per line.",
    )
    parser.add_argument(
        "--everything_step",
        type=str,
        required=True,
        help="Output log: query, intent and expansion per query, one JSON object per line.",
    )

    parser.add_argument(
        "--expansion",
        type=str,
        required=True,
        help="torch-saved file whose 'metadata' maps qid to its query and PRF documents.",
    )
    parser.add_argument(
        "--prf_terms",
        type=str,
        required=True,
        help="JSON file mapping qid to its PRF terms ('words', 'd_words').",
    )

    args = parser.parse_args()

    # ============= CHECK CONTINUE ============= #
    # Query ids already in --new_queries were finished by an earlier run and
    # are skipped. The file is parsed line by line rather than with pandas,
    # because pandas has three problems here:
    #   - read_csv raises on an empty file (e.g. a run that died on query 1);
    #   - it infers an int dtype for numeric ids, dropping leading zeros;
    #   - it can mis-parse rows whose text contains extra tabs or quotes.
    # Only the id column matters, so undecodable bytes in the text are replaced.
    done_qids = set()
    new_queries_path = Path(args.new_queries)
    if new_queries_path.is_file():
        with new_queries_path.open(encoding="utf-8", errors="replace") as f:
            done_qids = {
                line.rstrip("\r\n").split("\t", 1)[0]
                for line in f
                if line.strip()
            }
    print("# ============= CONTINUING ============= #")
    print(f"# ==    from {len(done_qids) + 1}    == #")
    print("# ====================================== #")

    # NOTE: torch.load unpickles the file; only load expansion files you trust.
    exp = torch.load(args.expansion)
    metadata = exp["metadata"]
    # qid -> list of "title text" strings for its PRF documents, in stored order.
    docs = {
        qid: [f'{d["title"]} {d["text"]}' for d in info["query_prf"]["PRF"]]
        for qid, info in metadata.items()
    }
    queries = {qid: info["query"] for qid, info in metadata.items()}

    # Set membership keeps this linear; ids are compared as strings because
    # that is how they are read back from the resume file.
    qids = [qid for qid in metadata if str(qid) not in done_qids]

    with open(args.prf_terms, encoding="utf-8") as f:
        prf_terms_meta = json.load(f)
    # ===================================== #

    # # ============= AGENT PREP ============ #
    prompt_generator = PromptGenerator(
        dataset_name=args.dataset_name,
    )
    gpt = GPT()

    # Prompts that do not depend on the query are built once.
    few_shot_examples = PromptGenerator.dataset_specific_examples[args.dataset_name]
    # NOTE(review): each example already starts with "Example N:", so the
    # "ExampleN:" lines below duplicate the heading. Kept verbatim so the
    # prompt matches earlier runs.
    task_intro = (
        "Solve a query expansion task with interleaving Observation, Thought, Action steps. \n"
        "Observe[Query, Doc]: observe the query and document. \n"
        "Think: reason about the current situation. \n"
        "Expand[query]: extracts terms from given document to expand query. \n"
        "Here are some examples. \n"
        "Example1:\n"
        f"{few_shot_examples[1]}\n"
        "Example2:\n"
        f"{few_shot_examples[2]}\n"
        f"###\n"
        f"Do you understand the task?"
    )

    n_target_words = 5
    target_output = (
        f"[{', '.join([f'word{i + 1}' for i in range(n_target_words)])}]"
    )
    expand_request = (
        f"Extract terms from the given document to expand the query.\n"
        "###\n"
        "Only response the expanding terms, do not say any other explain.\n"
        f"Total number of words to be extracted is {n_target_words}. Do not provide more or less.\n"
        f"Follow the format given below.\n"
        f"You must give me {n_target_words} in an array surrounded by square brackets. Other formats, such as list with line change and numbered lines, are not allowed\n"
        f"Expand[Query]: {target_output}\n"
    )
    # ===================================== #

    # ============ INTERACTION ============ #
    for p in [
        args.new_queries,
        args.new_queries_intents,
        args.everything,
        args.everything_step,
    ]:
        Path(p).parent.mkdir(parents=True, exist_ok=True)
    # Append to existing outputs when resuming; otherwise start them fresh.
    continuing = bool(done_qids)
    query_logger = LoggerQuery(args.new_queries, continuing=continuing)
    intent_query_logger = LoggerIntentQuery(
        args.new_queries_intents, continuing=continuing
    )
    everything_logger = LoggerEverything(args.everything, continuing=continuing)
    stepbystep_logger = LoggerEverything(
        args.everything_step,
        name="Everything-step",
        continuing=continuing,
    )

    k = 1  # number of top PRF documents shown to the LLM
    max_doc_words = 400  # per-document word budget (word_tokenize tokens)
    for idx, qid in enumerate(qids):
        # query + LLM + action / environment stage
        query = queries[qid]
        topk_docs = docs[qid]

        # NOTE(review): these PRF term lists are looked up (so a qid missing
        # from --prf_terms still raises KeyError) but are not used in the prompt.
        prf_terms = prf_terms_meta[qid]["words"]
        prf_terms_by_doc = prf_terms_meta[qid]["d_words"]

        prompt_generator.append_history(role="user", content=task_intro)
        response = gpt.query(prompt=prompt_generator.get_history())
        prompt_generator.append_history(role="assistant", content=response)

        prompt_generator.append_history(
            role="user",
            content=f"Now here is the actual task.\n"
            f"###\n"
            f"Observe[Query]: {query}",
        )
        for d_idx, doc_text in enumerate(topk_docs[:k]):
            # Shorten the document to at most max_doc_words tokens. A sentence
            # that would overflow the budget is skipped, not treated as a stop,
            # so a later shorter sentence can still be included.
            sliced_doc_sents: List[str] = []
            word_cnt = 0
            for sentence in sent_tokenize(doc_text):
                n_words = len(word_tokenize(sentence))
                if word_cnt + n_words > max_doc_words:
                    continue
                word_cnt += n_words
                sliced_doc_sents.append(sentence)

            sliced_doc_text = " ".join(sliced_doc_sents)
            prompt_generator.append_history(
                role="user", content=f"Observe[Doc {d_idx + 1}]: {sliced_doc_text}\n"
            )

        # few-shot intent (the literal "{thought}" placeholder is intentional)
        prompt_generator.append_history(role="user", content="Think: {thought}\n")

        response = gpt.query(prompt=prompt_generator.get_history())
        prompt_generator.append_history(role="assistant", content=response)
        final_intent = response

        # few-shot expand
        prompt_generator.append_history(role="user", content=expand_request)

        response = gpt.query(prompt=prompt_generator.get_history())
        prompt_generator.append_history(role="assistant", content=response)

        unclean_query = response
        missing_format = (
            "Expand[Query]" not in unclean_query
            or "[" not in unclean_query
            or "]" not in unclean_query
        )
        # Case-insensitive so "Sorry, ..." / "I Apologize" are also caught.
        lowered = unclean_query.lower()
        is_refusal = "apologize" in lowered or "sorry" in lowered
        if missing_format and is_refusal:
            clean_query = "[]"
        else:
            # Raw string: "\[" in a plain literal is an invalid escape sequence.
            clean_query = re_extract(r"Expand\[Query\]", unclean_query)
            clean_query = f'[{", ".join(clean_response(clean_query))}]'

        # logging
        print(f"{args.dataset_name}\t{idx + 1}/{len(qids)} :: \t", end="")
        query_logger.log(qid, clean_query)
        everything_logger.log(history=prompt_generator.get_history())
        stepbystep_logger.log(original=query, intent=final_intent, prf=clean_query)
        intent_query_logger.log(
            prev_query=query, docs=[""], intent=final_intent, new_query=clean_query
        )

        prompt_generator.clear_history()
    # ===================================== #
