import numpy as np
import os
from vendi_score import vendi
from openai_access import call_chatgpt

from depth import addTopicPrompt,addDk, addRequirement, addGoal, addApplication,addPs,addPc,addReasoning,addEmotion,all_Prompt
from breadth import BrandnewPrompt,FormattingIntputPrompt,FormattingOutputPrompt,RefinePrompt
import time
import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import json

from scipy.special import softmax
def start_timer():
    """Return the current wall-clock time in seconds, to be handed to ``stop_timer``."""
    started_at = time.time()
    return started_at
def stop_timer(start_time):
    """Return the number of seconds elapsed since ``start_time``.

    ``start_time`` is expected to be a ``time.time()`` timestamp, such as the
    value returned by ``start_timer``.
    """
    now = time.time()
    elapsed = now - start_time
    return elapsed
# Module-level setup: three causal LMs are loaded at import time, each on its
# own GPU (CUDA indices 4, 5 and 6), in float16 and switched to eval mode.
#   tokenizer  / model  / device  -> tagger used by get_instag_value
#   tokenizer1 / model1 / device1 -> quality scorer used by infer_quality
#   tokenizer2 / model2 / device2 -> complexity scorer used by infer_complexity
devices = [torch.device(f'cuda:{i}') for i in [4, 5, 6]]
device, device1, device2 = devices
# NOTE(review): the tokenizer path below is an empty string; presumably it
# should point at the same checkpoint as `model` ("/InsTagger") — confirm.
tokenizer = AutoTokenizer.from_pretrained("")
tokenizer1 = AutoTokenizer.from_pretrained("deita-quality-scorer")
tokenizer2 = AutoTokenizer.from_pretrained("deita-complexity-scorer")
model = AutoModelForCausalLM.from_pretrained("/InsTagger",device_map=device,torch_dtype=torch.float16, revision = 'v1.0.0')
model.eval()
model1 = AutoModelForCausalLM.from_pretrained("deita-quality-scorer",device_map=device1,torch_dtype=torch.float16)
model1.eval()
# NOTE(review): this model path has a leading "/" while tokenizer2's path does
# not — verify both resolve to the same complexity-scorer checkpoint.
model2 = AutoModelForCausalLM.from_pretrained("/deita-complexity-scorer",device_map=device2,torch_dtype=torch.float16)
model2.eval()
def infer_quality(input_text, resp_text):
    """Return the expected quality score of ``resp_text`` as an answer to ``input_text``.

    The quality scorer (``model1``) is prompted with the question/response
    pair and the logits of the first generated token are read. A softmax is
    taken over the six token ids that stand for the scores 1..6 and the
    probability-weighted mean of those scores is returned.

    Args:
        input_text: The question / instruction text.
        resp_text: The response whose quality is rated (may be empty).

    Returns:
        A NumPy floating scalar in the range [1, 6].
    """
    quality_template = ("You are a helpful assistant. Please identify the quality score of the Response corresponding to the Question. \n #Question#:\n{instruction}\n#Response#:\n{output} \n##Quality: ")
    user_input = quality_template.format(instruction=input_text, output=resp_text)
    input_ids = tokenizer1.encode(user_input, return_tensors="pt").to(device1)
    # Only the scores of the first generated token are used, so generating a
    # single token yields the same logits without decoding up to 512 tokens.
    outputs = model1.generate(input_ids, max_new_tokens=1, num_return_sequences=1, return_dict_in_generate=True, output_scores=True)
    # Upcast before the softmax so it is not evaluated in half precision
    # (the model is loaded with torch_dtype=torch.float16).
    first_step_logits = outputs.scores[0][0].float().cpu()
    # Vocabulary ids of the tokens "1".."6", in score order.
    # NOTE(review): these ids are tokenizer-specific — confirm they match tokenizer1.
    score_token_ids = [29896, 29906, 29941, 29946, 29945, 29953]
    score_values = np.array([1, 2, 3, 4, 5, 6])
    score_logits = first_step_logits[score_token_ids].numpy()
    score_probs = softmax(score_logits, axis=0)
    return np.sum(score_probs * score_values, axis=0)


def infer_complexity(input_text):
    """Return the expected complexity score of the user query ``input_text``.

    The complexity scorer (``model2``) is prompted with the query and the
    logits of the first generated token are read. A softmax is taken over the
    six token ids that stand for the scores 1..6 and the probability-weighted
    mean of those scores is returned.

    Args:
        input_text: The user query / instruction to rate.

    Returns:
        A NumPy floating scalar in the range [1, 6].
    """
    complexity_template = ("You are a helpful assistant. Please identify the complexity score of the following user query. \n##Query: {instruction}  \n##Complexity: ")
    user_input = complexity_template.format(instruction=input_text)
    input_ids = tokenizer2.encode(user_input, return_tensors="pt").to(device2)
    # Only the scores of the first generated token are used, so generating a
    # single token yields the same logits without decoding up to 512 tokens.
    outputs = model2.generate(input_ids, max_new_tokens=1, num_return_sequences=1, return_dict_in_generate=True, output_scores=True)
    # Upcast before the softmax so it is not evaluated in half precision
    # (the model is loaded with torch_dtype=torch.float16).
    first_step_logits = outputs.scores[0][0].float().cpu()
    # Vocabulary ids of the tokens "1".."6", in score order.
    # NOTE(review): these ids are tokenizer-specific — confirm they match tokenizer2.
    score_token_ids = [29896, 29906, 29941, 29946, 29945, 29953]
    score_values = np.array([1, 2, 3, 4, 5, 6])
    score_logits = first_step_logits[score_token_ids].numpy()
    score_probs = softmax(score_logits, axis=0)
    return np.sum(score_probs * score_values, axis=0)

def _count_instag_tags(result):
    """Count the tags contained in a tagger completion ``result`` (a string)."""
    try:
        parsed = json.loads(result)
    except ValueError:
        # json.JSONDecodeError is a ValueError: the completion is not valid
        # JSON (e.g. truncated), so fall back to the text heuristic below.
        parsed = None
    if isinstance(parsed, list):
        return len(parsed)
    if isinstance(parsed, dict):
        # A single {"tag": ..., "explanation": ...} object is one tag, not
        # one tag per key.
        return 1 if "tag" in parsed else len(parsed)
    # Heuristic: one "explanation" key per tag. Matches both the quoted JSON
    # key ("explanation":) and the bare form (explanation:).
    return len(re.findall(r'"?explanation"?\s*:', result))


def get_instag_value(instruction):
    """Return the number of intention tags the tagger model assigns to ``instruction``.

    The tagger (``model``) is asked to emit tags with explanations in JSON.
    Generation uses sampling (``do_sample=True``), so repeated calls on the
    same instruction may return different counts.

    Args:
        instruction: The user query to tag.

    Returns:
        The tag count as an int; 0 when no tag can be recognised.
    """
    gen_config_dic={"do_sample":True,"max_new_tokens":512,"num_return_sequences":1}
    prompt=f"Please identify tags of user intentions in the following user query and provide an explanation for each tag. Please response in the JSON format {{\"tag\": str, \"explanation\": str}}.\n User query: "+instruction
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    generate_ids = model.generate(**inputs,**gen_config_dic)
    # Drop the prompt by token position: decoding the full sequence and
    # slicing by len(prompt) is wrong whenever decode() does not reproduce
    # the prompt text character for character.
    prompt_len = inputs["input_ids"].shape[1]
    new_token_ids = generate_ids[:, prompt_len:]
    result = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0].strip()
    return _count_instag_tags(result)



class Node:
    """A node of the search tree: one instruction plus its search statistics."""

    def __init__(self, instruction=None,parent=None, prompt=None):
        """Create a node.

        Args:
            instruction: The instruction text held by this node.
            parent: The parent ``Node``, or ``None`` for the root.
            prompt: The rewrite prompt that produced ``instruction``.
        """
        self.instruction = instruction
        self.parent = parent
        self.rewrite_prompt = prompt
        self.children = []
        self.visits = 0
        self.value = 0
        # Depth is derived from the parent; the root sits at depth 0.
        self.depth = 0 if parent is None else parent.depth + 1
        self.is_terminal = False
        # Individual score components; filled in by the code that scores the node.
        self.values = []

    def uct(self, c=1):
        """Return the UCT score used to choose among sibling nodes.

        Unvisited nodes score ``inf`` so they are tried first. ``value`` is
        used directly as the exploitation term because ``Tree.backpropagate``
        already maintains it as a running mean; dividing by ``visits`` again
        would shrink it with every visit.

        Args:
            c: Weight of the exploration term.
        """
        if self.visits == 0:
            return float('inf')
        if self.parent is None:
            # No parent visit count to explore against: exploitation only.
            return self.value
        # Clamp so a parent with no recorded visits cannot produce log(0).
        parent_visits = max(self.parent.visits, 1)
        return self.value + c * np.sqrt(2 * np.log(parent_visits) / self.visits)

    def __str__(self):
        return f"Node(depth={self.depth}, value={self.value:.2f}, visits={self.visits}, instruction={self.instruction})"

    def _own_fields(self):
        """Return this node's own fields, without parent or children links."""
        return {
            'instruction': self.instruction,
            'visits': self.visits,
            'value': self.value,
            'depth': self.depth,
            'is_terminal': self.is_terminal
        }

    def _ancestors_dict(self):
        """Serialise this node and its chain of ancestors (no 'children' key)."""
        data = self._own_fields()
        data['parent'] = self.parent._ancestors_dict() if self.parent is not None else None
        return data

    def _subtree_dict(self):
        """Serialise this node and its descendants (no 'parent' key)."""
        data = self._own_fields()
        data['children'] = [child._subtree_dict() for child in self.children]
        return data

    def to_dict(self):
        """Return a plain-dict snapshot of this node.

        ``'parent'`` holds the chain of ancestors up to the root and
        ``'children'`` holds the full subtree below this node. Ancestors are
        emitted without a ``'children'`` key and descendants without a
        ``'parent'`` key: expanding both directions on every node would
        bounce between a parent and its child forever.
        """
        data = self._own_fields()
        data['parent'] = self.parent._ancestors_dict() if self.parent is not None else None
        data['children'] = [child._subtree_dict() for child in self.children]
        return data

class Tree:
    """Monte-Carlo tree search over rewritten instructions.

    Each node holds an instruction. Children are produced by asking the chat
    model to rewrite the parent's instruction with a randomly chosen prompt
    template, and are scored by summing the quality, tag-count and complexity
    scores.
    """

    def __init__(self,root,max_value):
        """Store the root node and the score at which a node counts as solved."""
        self.root = root
        self.max_value = max_value

    def get_value(self,instruct,new_node):
        """Score ``instruct``, record the components on ``new_node`` and return their sum.

        ``new_node.values`` is set to ``[quality, tag_count, complexity]``.
        """
        value_instag = get_instag_value(instruct)
        value_quality_score = infer_quality(instruct, "")
        complexity_score = infer_complexity(instruct)
        new_node.values = [value_quality_score, value_instag, complexity_score]
        return value_quality_score + value_instag + complexity_score

    def prompt_wrap(self,node):
        """Return one rewrite prompt per template, built from ``node.instruction``."""
        builders = (
            addTopicPrompt, addDk, addRequirement, addGoal, addApplication,
            addPs, addPc, addReasoning, addEmotion,
            BrandnewPrompt, FormattingIntputPrompt, FormattingOutputPrompt, RefinePrompt,
        )
        return [build(node.instruction) for build in builders]

    def generate_new_states(self,node,n,max_depth):
        """Sample ``n`` rewrites of ``node`` and return them as scored child nodes.

        A child is terminal when its value reaches ``self.max_value`` or its
        depth reaches ``max_depth``. The children are NOT attached to
        ``node.children``; that is left to the caller.
        """
        nodes = []
        for prompt, instruct in self.get_samples(node, n):
            new_node = Node(instruction=instruct, parent=node)
            new_node.rewrite_prompt = prompt
            new_node.depth = node.depth + 1
            new_node.value = self.get_value(instruct, new_node)
            new_node.is_terminal = (
                new_node.value >= self.max_value or new_node.depth >= max_depth
            )
            nodes.append(new_node)
        return nodes

    def get_samples(self,node,n):
        """Return ``n`` ``(prompt, response)`` pairs obtained from the chat model.

        Prompts are drawn with replacement from ``prompt_wrap(node)``. A failed
        call is kept as a best-effort placeholder response rather than raised.
        """
        result = []
        prompts = self.prompt_wrap(node)
        prompts_ = [random.choice(prompts) for _ in range(n)]
        for prompt in prompts_:
            try:
                response = call_chatgpt(prompt, n=1, temperature=0.7, max_tokens=len(prompt)+50)[0]
            except Exception:
                # Keep going on API errors, but let KeyboardInterrupt and
                # SystemExit propagate (a bare ``except:`` would swallow them).
                response = "Failure occur!"
            if 'rewritten' in response.lower():
                # Drop a leading label: keep the text after the last colon.
                response = response.split(':')[-1]
            result.append((prompt, response))
        return result

    def collect_all_nodes(self,node):
        """Return ``node`` and all of its descendants in pre-order."""
        collected = []
        stack = [node]
        while stack:
            current = stack.pop()
            collected.append(current)
            # Reversed so that children are visited left to right.
            stack.extend(reversed(current.children))
        return collected

    def backpropagate(self,node, value):
        """Fold ``value`` into the running mean of ``node`` and every ancestor."""
        while node:
            node.visits += 1
            node.value = (node.value * (node.visits - 1) + value) / node.visits
            node = node.parent

    def select_node(self,node):
        """Descend from ``node`` by UCT to a node that has no children yet.

        A non-root node whose children are all terminal is removed from its
        parent and the search backs up one level.

        Returns:
            The selected node, or ``None`` when the walk reaches a root whose
            children are all terminal (nothing left to explore).
        """
        while node and node.children:
            open_children = [child for child in node.children if not child.is_terminal]
            if open_children:
                node = max(open_children, key=lambda child: child.uct())
                continue
            if not node.parent:
                return None
            exhausted = node
            node = exhausted.parent
            node.children.remove(exhausted)
        return node

    def expand_node(self,node,n,max_depth):
        """Attach ``n`` new children to ``node``, or mark it terminal at ``max_depth``."""
        if node.depth >= max_depth:
            node.is_terminal = True
            return
        node.children.extend(self.generate_new_states(node, n, max_depth))

    def rollout(self,node, max_depth,n=1):
        """Greedily follow the best of ``n`` sampled children until a terminal node.

        Every sampled state is attached to the tree. The first terminal state
        found is returned immediately; otherwise the walk continues from the
        highest-valued state. The node the walk stops at is marked terminal.

        Raises:
            ValueError: If sampling is required and ``n < 1`` (no state could
                ever be produced, so the walk would never finish).
        """
        depth = node.depth
        while not node.is_terminal and depth < max_depth:
            if n < 1:
                raise ValueError("rollout requires n >= 1")
            new_states = []
            while not new_states:
                new_states = self.generate_new_states(node, n, max_depth)
                node.children.extend(new_states)
            for state in new_states:
                if state.is_terminal:
                    return state
            node = max(new_states, key=lambda state: state.value)
            depth += 1
        node.is_terminal = True
        return node

    def mcts_search(self,n,n_expands,max_depth):
        """Run up to ``n`` select / expand / rollout / backpropagate rounds.

        Args:
            n: Maximum number of search rounds.
            n_expands: Number of children sampled when expanding a node.
            max_depth: Depth at which nodes become terminal.

        Returns:
            ``(all_nodes, best_node)``: every node still attached to the tree
            and the best node found. The search stops early as soon as a
            terminal node reaches ``self.max_value``. If no rollout finished,
            the highest-valued node of the tree is returned as ``best_node``.
        """
        terminal_nodes = []
        for _ in range(n):
            node = self.select_node(self.root)
            if node is None:
                # Every branch is exhausted; nothing left to expand.
                break
            if node.is_terminal and node.value >= self.max_value:
                # Same (nodes, best) shape as every other exit.
                return self.collect_all_nodes(self.root), node
            self.expand_node(node, n_expands, max_depth)
            if not node.children:
                # Expansion produced nothing (depth limit or zero samples).
                continue
            start = max(node.children, key=lambda child: child.value)
            terminal_node = self.rollout(start, max_depth=max_depth, n=2)
            terminal_nodes.append(terminal_node)
            self.backpropagate(terminal_node, terminal_node.value)
            all_nodes_list = self.collect_all_nodes(self.root)
            # Judge each candidate by its OWN value, not by the value of the
            # node the latest rollout happened to return.
            solved = [
                cand for cand in all_nodes_list
                if cand.is_terminal and cand.value >= self.max_value
            ]
            if solved:
                return all_nodes_list, max(solved, key=lambda x: x.value)
        all_nodes_list = self.collect_all_nodes(self.root)
        # Without any finished rollout fall back to the whole tree, which
        # always contains at least the root.
        candidates = terminal_nodes or all_nodes_list
        best_child = max(candidates, key=lambda x: x.value)
        return all_nodes_list, best_child