import json
import logging
import os
import re
import sys

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from .retriever import ReferenceRetiever
# Request synchronous CUDA kernel launches so device-side errors surface at the
# call that caused them (a debugging aid that costs GPU throughput).
# setdefault lets a value the caller already exported take precedence instead
# of being silently overwritten at import time.
os.environ.setdefault('CUDA_LAUNCH_BLOCKING', "1")

class WebGLM:
    """Retrieval-augmented question answering.

    For each question, references are fetched with a ``ReferenceRetiever``,
    packed into a prompt (each segment terminated by a literal backslash) and
    fed to a seq2seq model. The answer is the text the model emits between
    ``<|startofpiece|>`` and ``<|endofpiece|>``.
    """

    # Prompts reaching this many tokens cause errors in generation, so
    # references are dropped once adding one would reach the limit.
    _MAX_PROMPT_TOKENS = 1024
    # Forwarded to ``build_inputs_for_generation`` and ``generate`` respectively.
    _MAX_GEN_LENGTH = 4096
    _MAX_LENGTH = 6000
    # DOTALL so an answer spanning several lines is still captured; the greedy
    # ``.+`` runs from the first start marker to the last end marker.
    _ANSWER_RE = re.compile(r"<\|startofpiece\|>(.+)<\|endofpiece\|>", re.DOTALL)

    def __init__(self, webglm_ckpt_path, retriever_ckpt_path, device=None, filter_max_batch_size=400, searcher_name="serpapi"):
        """Load the reference retriever, the tokenizer and the model.

        Args:
            webglm_ckpt_path: Path or hub id passed to ``from_pretrained`` for
                both the tokenizer and the seq2seq model.
            retriever_ckpt_path: Checkpoint forwarded to ``ReferenceRetiever``.
            device: Forwarded to the retriever; when truthy, the model and the
                generation inputs are moved to it.
            filter_max_batch_size: Forwarded to ``ReferenceRetiever``.
            searcher_name: Forwarded to ``ReferenceRetiever`` (presumably
                selects the web search backend -- confirm).
        """
        self.device = device
        self.ref_retriever = ReferenceRetiever(retriever_ckpt_path, device, filter_max_batch_size, searcher_name)
        self.tokenizer = AutoTokenizer.from_pretrained(webglm_ckpt_path, trust_remote_code=True)
        # Weights are cast to half precision before inference.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(webglm_ckpt_path, trust_remote_code=True).half()
        if device:
            self.model.to(device)
        self.model.eval()

    def _build_prompt(self, refs, question_info):
        """Return numbered reference passages followed by ``question_info``.

        Passages are appended in order until adding the next one would make
        the full prompt reach ``_MAX_PROMPT_TOKENS`` tokens; that reference and
        every later one are dropped.
        """
        prompt = ""
        for ix, ref in enumerate(refs, start=1):
            # The backslash is a literal segment separator in the prompt format.
            passage = f"Reference [{ix}]: {ref['text']}\\"
            candidate = prompt + passage + question_info
            if len(self.tokenizer(candidate)["input_ids"]) >= self._MAX_PROMPT_TOKENS:
                break
            prompt += passage
        return prompt + question_info

    def _generate(self, prompt):
        """Run the model on ``prompt``; return all answer spans found (may be empty)."""
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = self.tokenizer.build_inputs_for_generation(inputs, max_gen_length=self._MAX_GEN_LENGTH)
        if self.device:
            inputs = inputs.to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_length=self._MAX_LENGTH,
            eos_token_id=self.tokenizer.eop_token_id,
            pad_token_id=self.tokenizer.eop_token_id,
        )
        return self._ANSWER_RE.findall(self.tokenizer.decode(outputs[0].tolist()))

    def _first_answer(self, prompt):
        """Return the first stripped answer span for ``prompt``.

        Raises:
            AssertionError: If the model output contains no answer span. The
                type matches the ``assert`` used originally; raising explicitly
                keeps the check active under ``python -O``.
        """
        matches = self._generate(prompt)
        if not matches:
            raise AssertionError("model output contains no <|startofpiece|>...<|endofpiece|> span")
        return matches[0].strip()

    def query(self, question):
        """Answer ``question`` using retrieved references.

        Args:
            question (str): The question to answer.

        Returns:
            dict: ``{"answer": str, "references": list[dict], "prompt": str}``.
            ``references`` is the full list returned by the retriever (even
            entries left out of the prompt by the token limit) and ``prompt``
            is the exact text given to the model. When the retriever returns
            no references, all three values are empty and the model is not run.

        Raises:
            AssertionError: If the decoded output contains no answer span.
        """
        refs = self.ref_retriever.query(question)
        if not refs:
            return {"references": [], "answer": "", "prompt": ""}
        prompt = self._build_prompt(refs, "Question: " + question + "\\Answer: [gMASK]")
        answer = self._first_answer(prompt)
        return {"answer": answer, "references": refs, "prompt": prompt}

    def query_references(self, question):
        """Retrieve references for ``question`` and build a prompt without running the model.

        Retrieval is best-effort: if the retriever raises, the error is logged
        and the question is treated as having no references.

        Args:
            question (str): The question to look up.

        Returns:
            dict: ``{"question": str, "prompt": str, "output": [], "model_predict": ""}``.
            ``prompt`` holds the reference passages and the ``\\Answer:``
            suffix (no ``[gMASK]``), or ``""`` when no references were found.
            ``output`` and ``model_predict`` are empty placeholders (per the
            original docstring, filled in by later processing).
        """
        try:
            refs = self.ref_retriever.query(question)
        except Exception:
            logging.getLogger(__name__).exception("Reference retrieval failed for question %r", question)
            refs = []
        if not refs:
            return {"question": question, "prompt": "", "output": [], "model_predict": ""}
        prompt = self._build_prompt(refs, "Question: " + question + "\\Answer:")
        return {"question": question, "prompt": prompt, "output": [], "model_predict": ""}

    def query_generate(self, question):
        """Answer ``question`` directly with the model, without retrieved references.

        The prompt is ``question + " [gMASK]"``. This is best-effort: any error
        during generation is logged and ``""`` is returned.

        Args:
            question (str): The question to answer.

        Returns:
            str: The first stripped answer span, or ``""`` if generation failed
            or the output contains no answer span.

        TODO:
            - Per the original author, only English questions are supported so
              far; Chinese support still needs to be added.
        """
        try:
            matches = self._generate(question + " [gMASK]")
        except Exception:
            logging.getLogger(__name__).exception("Generation failed for question %r", question)
            return ""
        return matches[0].strip() if matches else ""

    def stream_query(self, question):
        """Answer ``question`` in two steps, yielding partial results.

        Args:
            question (str): The question to answer.

        Yields:
            dict: First ``{"references": refs}`` with the retriever's list, then
            ``{"answer": str}`` once generation finishes. If no references are
            found, a single ``{"references": [], "answer": "", "prompt": ""}``
            is yielded instead. The prompt is built with the same token limit
            as :meth:`query`.

        Raises:
            AssertionError: If the decoded output contains no answer span.
        """
        refs = self.ref_retriever.query(question)
        if not refs:
            yield {"references": [], "answer": "", "prompt": ""}
            return
        yield {"references": refs}
        prompt = self._build_prompt(refs, "Question: " + question + "\\Answer: [gMASK]")
        yield {"answer": self._first_answer(prompt)}


def load_model(args):
    """Build a :class:`WebGLM` from parsed command-line arguments.

    Checkpoint paths fall back to environment variables: ``WEBGLM_CKPT``
    (then ``'THUDM/WebGLM'``) for the generator, and ``WEBGLM_RETRIEVER_CKPT``
    for the retriever, which has no default and is required.

    Args:
        args: Namespace-like object providing ``webglm_ckpt_path``,
            ``retriever_ckpt_path``, ``serpapi_key``, ``device``,
            ``filter_max_batch_size`` and ``searcher``.

    Returns:
        WebGLM: The initialized model wrapper.

    Raises:
        SystemExit: With status 1 when no retriever checkpoint is configured.
    """
    webglm_ckpt_path = args.webglm_ckpt_path or os.getenv("WEBGLM_CKPT") or 'THUDM/WebGLM'
    retriever_ckpt_path = args.retriever_ckpt_path or os.getenv("WEBGLM_RETRIEVER_CKPT")
    if not retriever_ckpt_path:
        # Report on stderr and use sys.exit: the bare ``exit`` builtin is only
        # injected by the ``site`` module and is intended for interactive use.
        print('Retriever checkpoint not specified, please specify it with --retriever_ckpt_path or $WEBGLM_RETRIEVER_CKPT', file=sys.stderr)
        sys.exit(1)
    if args.serpapi_key:
        # Exported through the environment; presumably read by the search
        # backend inside the retriever -- confirm.
        os.environ["SERPAPI_KEY"] = args.serpapi_key

    print('WebGLM Initializing...')

    webglm = WebGLM(webglm_ckpt_path, retriever_ckpt_path, args.device, args.filter_max_batch_size, args.searcher)

    print('WebGLM Loaded')

    return webglm