import os, time
from tenacity import retry, wait_random_exponential, stop_after_attempt
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastchat.model import load_model, get_conversation_template, add_model_args
import logging
import argparse
import json
import concurrent.futures
from dotenv import load_dotenv
import openai
import traceback
import replicate
from openai import AzureOpenAI
from torch.nn.functional import cross_entropy
load_dotenv()
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
import torch
import yaml

# Load the run configuration once at import time. The API helpers below
# (qwen_res, get_res, replicate_api) read credentials, endpoints and model
# names from this module-level ``config``. YAML is a Unicode format, so decode
# it explicitly as UTF-8 instead of relying on the platform's locale encoding
# (matching read_json below).
with open('generation_config.yaml', 'r', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

# Prompt prefixes asking the model which language is most suitable for answering
# the question appended right after them. Replies are stored under
# 'en_selection' / 'ch_selection' by process_item and run_single_test.
# English_Q_selection is prepended to an item's 'en_question'.
English_Q_selection = "As a helpful assistant, you need to categorize an English question, considering that the background of this question is not common in an English environment. Therefore, you need to choose the most suitable language for this question. You need to analyze the required language context for the question first, and then tell me at the end which language you think is most suitable to answer the question. The question is as follows: "
# Chinese counterpart of English_Q_selection, prepended to an item's 'ch_question'.
# English translation: "As a helpful assistant, you need to classify a Chinese
# question. Considering that the background of this question is uncommon in a
# Chinese-language environment, you need to return the language best suited to
# the question. First analyze the language context the question requires, then
# at the end tell me the language you consider most suitable for answering it.
# The question is as follows:"
Chinese_Q_selection = "作为乐于助人的助理，您需要将一个中文问题进行分类，考虑到该问题背景在中文环境中并不常见，因此您需要返回最适合该问题的语言。你需要首先对问题所需要的语言环境进行分析，然后在最后告诉我你返回的最适合回答该问题的语言。问题如下："


def read_json(data_path):
    """Parse the UTF-8 encoded JSON file at ``data_path`` and return its content."""
    with open(data_path, mode='r', encoding='utf-8') as handle:
        return json.load(handle)




from http import HTTPStatus
import dashscope


def qwen_res(string):
    """Query the qwen-turbo model through the ``dashscope`` SDK.

    A fixed system message ('You are a helpful assistant.') is sent together
    with ``string`` as the user turn. The raw response is printed.

    Args:
        string: the user prompt.

    Returns:
        The assistant message content when the call returns HTTP 200,
        otherwise ``None``. Errors raised while calling the SDK or reading the
        response are reported and also mapped to ``None``, so a batch run can
        continue past individual failures.
    """
    dashscope.api_key = config['api']['qwen_api']
    messages = [{'role': 'system', 'content': 'You are a helpful assistant.'},
                {'role': 'user', 'content': string}]
    try:
        response = dashscope.Generation.call(
            dashscope.Generation.Models.qwen_turbo,
            messages=messages,
            result_format='message',  # set the result to be "message" format.
        )
        # The full response is printed for both success and failure (debug trace).
        print(response)
        if response.status_code != HTTPStatus.OK:
            return None
        return response['output']['choices'][0].message.content
    except Exception as e:
        # Best-effort by design, but a bare ``except:`` also swallowed
        # KeyboardInterrupt/SystemExit and hid the cause entirely.
        traceback.print_exc()
        logging.error("qwen_res failed: %s", e)
        return None

@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
def get_res(string, model=None):
    """Send ``string`` as a single user message to the configured chat model.

    The backend is chosen by ``config['api']['type']``:

    * ``'azure'``: ``AzureOpenAI`` built from ``config['api']`` key, version
      and endpoint.
    * anything else: a default ``openai.OpenAI()`` client (no explicit
      credentials are passed here). The reply is also printed.

    The call is wrapped in tenacity's ``retry``: up to 6 attempts with random
    exponential back-off between 1 and 10 seconds.

    Args:
        string: the prompt text.
        model: accepted so callers such as ``process_item`` can pass their
            model label without a TypeError. NOTE(review): it is not sent to
            the API; the model/deployment always comes from
            ``config['api']['model']``.

    Returns:
        The content of the first choice of the chat completion.
    """
    use_azure = config['api']['type'] == 'azure'
    if use_azure:
        client = AzureOpenAI(
            api_key=config['api']['key'],
            api_version=config['api']['version'],
            azure_endpoint=config['api']['endpoint']
        )
    else:
        # Only ``AzureOpenAI`` is imported by name in this file; the plain
        # client is reached through the imported ``openai`` module.
        client = openai.OpenAI()
    completion = client.chat.completions.create(
        model=config['api']['model'],
        messages=[{"role": "user", "content": string}]
    )
    content = completion.choices[0].message.content
    if not use_azure:
        print(content)
    return content


def replicate_api(string, model, temperature=1):
    """Run ``string`` through one of the Replicate-hosted models.

    Args:
        string: the prompt text.
        model: short key, one of ``'yi-34b'``, ``'qwen-14b'`` or ``'yi-6b'``.
            Any other key raises ``KeyError``.
        temperature: NOTE(review): currently ignored; every request is sent
            with a fixed temperature of 0.5. Existing callers rely on that
            value, so honour this argument only together with updating them.

    Returns:
        The chunks returned by ``replicate.run`` joined into a single string
        (the string is also printed).
    """
    versions = {
        'yi-34b': "01-ai/yi-34b-chat:914692bbe8a8e2b91a4e44203e70d170c9c5ccc1359b283c84b0ec8d47819a46",
        'qwen-14b': "nomagick/qwen-14b-chat:f9e1ed25e2073f72ff9a3f46545d909b1078e674da543e791dec79218072ae70",
        'yi-6b': "01-ai/yi-6b:d302e64fad6b4d85d47b3d1ed569b06107504f5717ee1ec12136987bec1e94f1",
    }
    os.environ["REPLICATE_API_TOKEN"] = config['api']['replicate_api']
    payload = {"prompt": string, "temperature": 0.5}
    chunks = replicate.run(versions[model], input=payload)
    text = "".join(chunks)
    print(text)
    return text



def prompt2conversation(prompt, model_path):
    """Render ``prompt`` with the FastChat conversation template for ``model_path``.

    The template's system message is cleared, ``prompt`` is added as a turn
    for the template's first role, and an empty (``None``) turn is opened for
    the second role. Presumably these are user and assistant, so the rendered
    text ends where the reply should begin. The rendered prompt string is
    returned.
    """
    template = get_conversation_template(model_path)
    template.set_system_message('')
    first_role, second_role = template.roles[0], template.roles[1]
    template.append_message(first_role, prompt)
    template.append_message(second_role, None)
    return template.get_prompt()


def generation_text(model, tokenizer, prompt, device, max_length, temperature, model_path,
                    return_details=True):
    """Generate a reply to ``prompt`` with a local Hugging Face model.

    The prompt is wrapped in the model's FastChat chat template
    (``prompt2conversation``), tokenized, moved to ``device`` and passed to
    ``model.generate``. Both the generated text and any generation error are
    printed.

    Args:
        model: model exposing ``generate`` and ``config.is_encoder_decoder``.
            It is moved to ``device`` on every call.
        tokenizer: tokenizer matching ``model``.
        prompt: raw user text.
        device: device the model and input tensors are placed on.
        max_length: ``max_length`` passed to ``generate``.
        temperature: values above ``1e-5`` enable sampling at that
            temperature; anything smaller selects greedy decoding.
        model_path: selects the conversation template.
        return_details: unused; kept for call-site compatibility.

    Returns:
        The decoded continuation (prompt tokens stripped for decoder-only
        models), or ``None`` if generation or decoding raised.
    """
    model.to(device)
    conversation = prompt2conversation(prompt, model_path=model_path)
    encoded = tokenizer([conversation])
    inputs = {k: torch.tensor(v).to(device) for k, v in encoded.items()}
    # Dropped before generate(); presumably some target models do not accept it.
    inputs.pop('token_type_ids', None)

    do_sample = temperature > 1e-5
    gen_kwargs = {'do_sample': do_sample, 'max_length': max_length}
    if do_sample:
        # Temperature only matters when sampling. Passing e.g. 0 together with
        # greedy decoding just trips transformers' generation-config validation.
        gen_kwargs['temperature'] = temperature

    try:
        output_ids = model.generate(**inputs, **gen_kwargs)
        if model.config.is_encoder_decoder:
            output_ids = output_ids[0]
        else:
            # Decoder-only output starts with the prompt; keep only new tokens.
            output_ids = output_ids[0][len(inputs["input_ids"][0]):]
        generated_text = tokenizer.decode(
            output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
        )
    except Exception as e:
        print(e)
        generated_text = None

    print(generated_text)
    return generated_text


def process_item(item, model):
    """Ask the API model the item's questions and language-selection prompts.

    Adds ``ch_res`` / ``en_res`` (direct answers to ``ch_question`` /
    ``en_question``) and ``ch_selection`` / ``en_selection`` (replies to the
    selection prompts prefixed to those questions) to ``item`` in place.
    Returns the same dict.
    """
    def ask(text):
        return get_res(text, model=model)

    # Dict-literal values are evaluated in order, so the four requests are
    # issued in the same sequence as before.
    answers = {
        'ch_res': ask(item['ch_question']),
        'en_res': ask(item['en_question']),
        'ch_selection': ask(Chinese_Q_selection + item['ch_question']),
        'en_selection': ask(English_Q_selection + item['en_question']),
    }
    item.update(answers)
    return item


import functools


def _load_questions(filename):
    """Return the items of the JSON list in ``filename`` that have an ``en_question`` key."""
    return [el for el in read_json(filename) if 'en_question' in el]


def _results_path(subdir, filename, suffix):
    """Build ``results/<subdir>/<basename>`` with ``.json`` replaced by ``<suffix>.json``."""
    name = filename.split('/')[-1].replace('.json', suffix + '.json')
    return os.path.join('results', subdir, name)


def _dump_results(save_path, save_data):
    """Write ``save_data`` to ``save_path`` as indented JSON, creating parent directories."""
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as out:
        json.dump(save_data, out, indent=4)


def _answer(el, answer_fn, selection_fn):
    """Fill ``el`` with direct answers and language-selection replies; return ``el``."""
    ch_answer = answer_fn(el['ch_question'])
    en_answer = answer_fn(el['en_question'])
    ch_selection_answer = selection_fn(Chinese_Q_selection + el['ch_question'])
    en_selection_answer = selection_fn(English_Q_selection + el['en_question'])
    el['ch_res'] = ch_answer
    el['en_res'] = en_answer
    el['ch_selection'] = ch_selection_answer
    el['en_selection'] = en_selection_answer
    return el


def _local_generators(model, tokenizer, device, model_path):
    """Return ``(answer_fn, selection_fn)`` backed by ``generation_text``.

    Answers use temperature 0.5 (sampling). Selection prompts use temperature 0,
    which ``generation_text`` treats as greedy decoding. Both use max_length=8000.
    """
    common = dict(device=device, max_length=8000, model_path=model_path,
                  return_details=False)
    answer_fn = functools.partial(generation_text, model, tokenizer, temperature=0.5, **common)
    selection_fn = functools.partial(generation_text, model, tokenizer, temperature=0, **common)
    return answer_fn, selection_fn


def run_single_test(args, model=None, tokenizer=None):
    """Answer every question in ``args.filename`` with the backend named by ``args.model_path``.

    Only items with an ``en_question`` key are processed; each item also needs
    ``ch_question``. Every item gets ``ch_res``, ``en_res``, ``ch_selection``
    and ``en_selection`` and is saved as JSON under ``results/<backend>/``.
    Sequential backends rewrite the file after each item.

    Backends:
        * ``'THUDM/chatglm3-6b'``: loaded here with remote code (``model`` /
          ``tokenizer`` arguments are ignored).
        * ``'gpt-4'`` / ``'chatgpt'``: ``process_item`` over 4 worker threads.
        * ``'yi-34b'`` / ``'qwen-14b'`` / ``'yi-6b'``: ``replicate_api``.
        * ``'qwen-turbo'``: ``qwen_res``.
        * ``'/data/huggingface/hub/Llama3-Chinese'``: the passed ``model`` /
          ``tokenizer``, resuming from an existing results file.

    Raises:
        ValueError: if ``args.model_path`` is none of the above.
    """
    device = "cuda"
    if args.model_path == 'THUDM/chatglm3-6b':
        tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
        data = _load_questions(args.filename)
        print(len(data))
        answer_fn, selection_fn = _local_generators(model, tokenizer, device, args.model_path)
        save_path = _results_path('chatglm3', args.filename, '_chatglm3_time')
        save_data = []
        for el in data:
            save_data.append(_answer(el, answer_fn, selection_fn))
            # Rewrite the whole file after every item so progress survives a crash.
            _dump_results(save_path, save_data)
    elif args.model_path in ['gpt-4', 'chatgpt']:
        data = _load_questions(args.filename)
        print(len(data))
        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            worker = functools.partial(process_item, model=args.model_path)
            save_data = list(executor.map(worker, data))
        _dump_results(_results_path(args.model_path, args.filename, '_' + args.model_path),
                      save_data)
    elif args.model_path in ['yi-34b', 'qwen-14b', 'yi-6b', 'qwen-turbo']:
        if args.model_path == 'qwen-turbo':
            ask = qwen_res
        else:
            ask = functools.partial(replicate_api, model=args.model_path)
        data = _load_questions(args.filename)
        save_path = _results_path(args.model_path, args.filename, '_' + args.model_path)
        save_data = []
        for el in data:
            _answer(el, ask, ask)
            print(el)
            save_data.append(el)
            _dump_results(save_path, save_data)
    elif args.model_path == "/data/huggingface/hub/Llama3-Chinese":
        data = _load_questions(args.filename)
        print(len(data))
        save_path = _results_path('llama3', args.filename, '_llama3')
        # Resume from an earlier run: keep all saved records except the last,
        # which is regenerated (NOTE(review): presumably as a safety margin --
        # confirm). max(..., 0) fixes the empty-file case: previously
        # ``data[len(previous) - 1:]`` became ``data[-1:]`` and only the final
        # question was processed.
        resume_from = 0
        save_data = []
        if os.path.exists(save_path):
            previous_data = read_json(save_path)
            resume_from = max(len(previous_data) - 1, 0)
            save_data = previous_data[:resume_from]
        answer_fn, selection_fn = _local_generators(model, tokenizer, device, args.model_path)
        for el in data[resume_from:]:
            save_data.append(_answer(el, answer_fn, selection_fn))
            _dump_results(save_path, save_data)
    else:
        raise ValueError('no model!')

@torch.inference_mode()
def main(args, max_retries=20, retry_interval=3):
    """Load a local model if the backend needs one, then run the test with retries.

    Args:
        args: parsed command-line arguments (model/device options plus
            ``model_path`` and ``filename``).
        max_retries: maximum number of ``run_single_test`` attempts.
        retry_interval: seconds to wait between failed attempts.

    Returns:
        The value returned by the first successful ``run_single_test`` call,
        or ``None`` if every attempt raised.
    """
    # Backends for which run_single_test never uses the model/tokenizer passed
    # in: the API-backed ones (these names are not local checkpoints, so
    # FastChat's load_model cannot load them) and chatglm3, which
    # run_single_test loads itself (loading here would hold two copies).
    no_local_model = {'gpt-4', 'chatgpt', 'yi-34b', 'qwen-14b', 'yi-6b',
                      'qwen-turbo', 'THUDM/chatglm3-6b'}
    if args.model_path in no_local_model:
        model, tokenizer = None, None
    else:
        model, tokenizer = load_model(
            args.model_path,
            device=args.device,
            num_gpus=args.num_gpus,
            max_gpu_memory=args.max_gpu_memory,
            load_8bit=args.load_8bit,
            cpu_offloading=args.cpu_offloading,
            revision=args.revision,
            debug=args.debug,
        )

    for attempt in range(1, max_retries + 1):
        try:
            state = run_single_test(args, model, tokenizer)
        except Exception as e:
            traceback.print_exc()
            message = f"Test function failed on attempt {attempt}:{e}"
            logging.error(message)
            print(message)
            # No point waiting after the final attempt.
            if attempt < max_retries:
                print("Retrying in {} seconds...".format(retry_interval))
                time.sleep(retry_interval)
            continue
        message = f"Test function successful on attempt {attempt}"
        logging.info(message)
        print(message)
        return state

    logging.error(f"Test function failed after {max_retries} attempts")
    return None


# Every run logs to its own file in the working directory, named after the
# start time formatted as YYYYmmddHHMMSS (one file per second of start time).
timestamp = time.strftime("%Y%m%d%H%M%S")
log_filename = "test_log_" + timestamp + ".txt"

logging.basicConfig(
    filename=log_filename,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # FastChat's shared model options. These presumably provide args.device,
    # args.max_gpu_memory, args.load_8bit, args.cpu_offloading and
    # args.revision read by main(), which are not declared below.
    add_model_args(parser)
    # NOTE(review): --temperature, --repetition_penalty, --max_length and
    # --test_type are not read anywhere in this file; generation settings are
    # hard-coded in run_single_test.
    # NOTE(review): --num_gpus and --model_path may share their dest with
    # FastChat's --num-gpus / --model-path added above. argparse applies the
    # default of the first-registered action for a dest, so the defaults here
    # (2 and '') may never take effect -- verify.
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--num_gpus", type=int, default=2)
    parser.add_argument("--max_length", type=int, default=8000)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--model_path", type=str, default='')
    parser.add_argument("--filename", type=str, default='')
    parser.add_argument("--test_type", type=str, default='plugin')
    args = parser.parse_args()
    state = main(args, )
    print(state)


