import os
import argparse
import pandas
import transformers
import torch
import json
from pathlib import Path
import ast
from prompts import tofueval_prompts, ultrachat_prompts, _example_output_document_text
import openai

# Shared OpenAI client, used only by get_gpt4_response. openai.Client() may
# raise openai.OpenAIError at construction (e.g. when no API key is
# configured); catching it keeps the llama-only code paths importable and
# runnable without OpenAI credentials. `client` is None in that case.
try:
    client = openai.Client()
except openai.OpenAIError as err:
    print(f"warning: OpenAI client unavailable ({err}); the gpt4 model cannot be used")
    client = None

# System message passed to the llama2/llama3 prompt builders.
# Note: the double space after "safe." is part of the original prompt text.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, respectful and honest assistant. "
    "Always answer as helpfully as possible, while being safe.  "
    "Your answers should not include any harmful, unethical, racist, sexist, "
    "toxic, dangerous, or illegal content. "
    "Please ensure that your responses are socially unbiased and positive in nature."
)

def make_llama3_prompt(system, user_message):
    """Build a single-turn Llama 3 Instruct chat prompt.

    Follows Meta's documented Llama 3 format: each header is followed by a
    blank line ("\\n\\n"), turns are closed with <|eot_id|>, and there are no
    extra newlines between turns. The prompt ends with an open assistant
    header so the model generates the assistant reply.

    Args:
        system: System-message text.
        user_message: User-message text.

    Returns:
        The formatted prompt string.
    """
    return (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

def get_gpt4_response(user_prompt: str, model="gpt-4-0125-preview"):
    """Send a single user message to the OpenAI chat API and return the reply text.

    Args:
        user_prompt: Text sent as the only (user-role) message.
        model: OpenAI model name forwarded to ``chat.completions.create``.

    Returns:
        ``response.choices[0].message.content`` of the completion.

    Raises:
        RuntimeError: if the module-level OpenAI ``client`` is unavailable (None).
    """
    if client is None:
        raise RuntimeError("OpenAI client is not configured; cannot query GPT-4")
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": user_prompt}],
    )
    return response.choices[0].message.content


def make_llama2_prompt(
        system: str,
        user_message: str,
) -> str:
    """Build a single-turn Llama 2 chat prompt.

    Follows Meta's Llama 2 chat layout: the system prompt sits inside
    ``<<SYS>>\\n ... \\n<</SYS>>\\n\\n`` at the start of the first [INST] block,
    followed by the user message.

    Args:
        system: System-message text.
        user_message: User-message text.

    Returns:
        The formatted prompt string. Values are inserted verbatim, so braces
        in ``system``/``user_message`` are not interpreted.
    """
    return f"<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user_message} [/INST]"


# Prompt-template collections, keyed by the value passed as --dataset.
prompt_set = dict(
    tofueval=tofueval_prompts,
    ultrachat=ultrachat_prompts,
)

def _format_json_feedback(raw_response):
    """Render model feedback given as a list of dicts into numbered lines.

    The response is parsed with ``json.loads`` first and, failing that, with
    ``ast.literal_eval`` (to also accept Python-literal output such as
    single-quoted strings). Each parsed item must be a mapping with
    'inconsistency' and 'feedback' keys.

    Returns:
        The numbered feedback text ("" for an empty list), or None when
        neither parser yields items of the expected shape.
    """
    for parse in (json.loads, ast.literal_eval):
        try:
            items = parse(raw_response)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        try:
            return "".join(
                f"{n}. For sentence/span in the summary: '{item['inconsistency']}', "
                f"feedback is: '{item['feedback']}'\n"
                for n, item in enumerate(items, start=1)
            )
        except (KeyError, TypeError, IndexError):
            # Parsed, but not a sequence of {'inconsistency', 'feedback'} mappings.
            continue
    return None


def get_feedback(feedback_method, values, method, generator, model):
    """Ask the model for feedback on a summary.

    Args:
        feedback_method: ``str.format`` template for the feedback request.
        values: Placeholder values for the template. Not mutated.
        method: Refinement-method name. If it contains "json", the template
            also receives ``example_output`` and the response is parsed as a
            list of {'inconsistency', 'feedback'} items.
        generator: For "gpt4", a callable taking ``user_prompt=``; otherwise a
            text-generation callable returning ``[{'generated_text': ...}]``.
        model: One of "llama2", "llama3" or "gpt4".

    Returns:
        ``(feedback_str, feedback_prompt)``. For JSON methods, the feedback is
        rendered as numbered lines. If parsing fails, the raw model text is
        returned instead.

    Raises:
        ValueError: if ``model`` is not supported.
    """
    values = dict(values)  # copy so the caller's dict is not modified
    if "json" in method:
        values['example_output'] = _example_output_document_text
    prompt = feedback_method.format(**values)

    if model == 'llama2':
        feedback_prompt = make_llama2_prompt(system=DEFAULT_SYSTEM_PROMPT, user_message=prompt)
    elif model == 'llama3':
        feedback_prompt = make_llama3_prompt(system=DEFAULT_SYSTEM_PROMPT, user_message=prompt)
    elif model == "gpt4":
        feedback_prompt = prompt
    else:
        raise ValueError(f"unsupported model: {model!r} (expected llama2, llama3 or gpt4)")

    if model == "gpt4":
        feedback_response = generator(user_prompt=feedback_prompt)
    else:
        # generated_text is expected to echo the prompt; drop it to keep only the reply.
        feedback_response = generator(feedback_prompt)[0]['generated_text'].replace(feedback_prompt, "")

    if "json" not in method:
        return feedback_response, feedback_prompt

    feedback_str = _format_json_feedback(feedback_response)
    if feedback_str is None:
        print("error: could not parse structured feedback; using raw model output")
        feedback_str = feedback_response
    return feedback_str, feedback_prompt

def get_refinement_response(refinement_method, values, generator, model):
    """Ask the model to refine a summary, optionally using prior feedback.

    Args:
        refinement_method: ``str.format`` template for the refinement request.
        values: Placeholder values for the template.
        generator: For "gpt4", a callable taking ``user_prompt=``; otherwise a
            text-generation callable returning ``[{'generated_text': ...}]``.
        model: One of "llama2", "llama3" or "gpt4".

    Returns:
        ``(refinement_response, refinement_prompt)``: the generated text and
        the exact prompt sent to the model. This is the same order that
        ``main`` unpacks.

    Raises:
        ValueError: if ``model`` is not supported.
    """
    prompt = refinement_method.format(**values)

    if model == 'llama2':
        refinement_prompt = make_llama2_prompt(system=DEFAULT_SYSTEM_PROMPT, user_message=prompt)
    elif model == 'llama3':
        refinement_prompt = make_llama3_prompt(system=DEFAULT_SYSTEM_PROMPT, user_message=prompt)
    elif model == "gpt4":
        refinement_prompt = prompt
    else:
        raise ValueError(f"unsupported model: {model!r} (expected llama2, llama3 or gpt4)")

    if model == "gpt4":
        refinement_response = generator(user_prompt=refinement_prompt)
    else:
        # generated_text is expected to echo the prompt; drop it to keep only the reply.
        refinement_response = generator(refinement_prompt)[0]['generated_text'].replace(refinement_prompt, "")
    print(refinement_response)
    return refinement_response, refinement_prompt

def _parse_args():
    """Define and parse the command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=list(prompt_set),
        help="tofueval or ultrachat",
    )
    parser.add_argument(
        "--hf_token",
        type=str,
        help="huggingface token, required to access llama2-7b-chat",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        choices=["llama2", "llama3", "gpt4"],
        help="model name -- llama2 or llama3 or gpt4",
    )
    parser.add_argument("--input", type=Path, required=True, help="Path to input jsonl file")
    parser.add_argument("--output_dir", type=Path, required=True, help="path to write output!")
    # Currently informational only; not used by the pipeline below.
    parser.add_argument("--feedback_type", type=str, help="feedback type whether it is string or json")
    parser.add_argument(
        "--feedback_method",
        help="one of improve_simple, simple, categories, json, or all (default: all)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=50,
        help="maximum number of input records to process (default: 50)",
    )
    return parser.parse_args()


def _load_generator(model_name):
    """Return the generation callable for ``model_name`` (HF pipeline or GPT-4 wrapper)."""
    if model_name == "llama2":
        return transformers.pipeline(
            "text-generation", model="meta-llama/Llama-2-7b-chat-hf",
            device_map="auto", max_new_tokens=2048,
        )
    if model_name == "llama3":
        return transformers.pipeline(
            "text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto", max_new_tokens=2048,
        )
    if model_name == "gpt4":
        return get_gpt4_response
    raise ValueError(f"unsupported model: {model_name!r}")


def _load_records(path, limit):
    """Parse a JSON-lines file (skipping blank lines) and keep the first ``limit`` records."""
    with open(path, encoding="utf-8") as fh:
        records = [json.loads(line) for line in fh if line.strip()]
    return records[:limit]


def _extract_fields(record, dataset):
    """Return ``(source, summary, aspect)`` for one record of ``dataset``."""
    if dataset == "ultrachat":
        return record['instruction'], record['completions'][0], ""
    if dataset == "tofueval":
        return record['source_doc'], record['summary'], record['topic']
    raise ValueError(f"unsupported dataset: {dataset!r}")


def _select_methods(refinement_methods, feedback_method):
    """Return the refinement-method names selected by ``--feedback_method``.

    ``None`` or "all" selects every method. Otherwise the value may be a
    full method name (e.g. "improve_simple") or the short form without the
    "feedback_" prefix (e.g. "json" -> "feedback_json").
    """
    if feedback_method in (None, "all"):
        return list(refinement_methods)
    for name in (feedback_method, f"feedback_{feedback_method}"):
        if name in refinement_methods:
            return [name]
    raise ValueError(
        f"unknown --feedback_method {feedback_method!r}; "
        f"expected one of {sorted(refinement_methods)} or 'all'"
    )


def main():
    """Run feedback + refinement over a JSON-lines dataset and write the results.

    For each selected refinement method, every input record is optionally
    passed through ``get_feedback`` and then through ``get_refinement_response``.
    The record is written to ``<output_dir>/<model>_<method>_<input filename>``
    with four extra keys: feedback_prompt, feedback_response,
    refinement_prompt and refinement.

    Raises:
        EnvironmentError: if a local (llama) model is requested while
            CUDA_VISIBLE_DEVICES is not set.
    """
    args = _parse_args()
    print(args)

    model_name = args.model_name
    dataset = args.dataset

    # Only the local Hugging Face models need GPU placement.
    if model_name != "gpt4" and 'CUDA_VISIBLE_DEVICES' not in os.environ:
        raise EnvironmentError("set CUDA_VISIBLE_DEVICES in environment")

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    print(output_dir)

    if args.hf_token:
        os.environ['HF_TOKEN'] = args.hf_token

    print("loading model!")
    print(model_name)
    generator = _load_generator(model_name)
    print("loaded model!")

    filename = args.input.name
    data = _load_records(args.input, args.limit)

    method = prompt_set[dataset]

    refinement_methods = {
        "improve_simple": {
            "feedback": None,
            "refinement": method["simple_refinement_prompt"]
        },
        "feedback_simple": {
            "feedback": method["simple_feedback_prompt"],
            "refinement": method["refinement_with_feedback_prompt"]
        },
        "feedback_categories": {
            "feedback": method["category_feedback_prompt"],
            "refinement": method["refinement_with_feedback_prompt"]
        },
        "feedback_json": {
            "feedback": method["json_category_feedback_prompt"],
            "refinement": method["refinement_with_feedback_prompt"]
        },
    }

    for r in _select_methods(refinement_methods, args.feedback_method):
        print(r)
        feedback_method = refinement_methods[r]['feedback']
        refinement_method = refinement_methods[r]['refinement']
        output_path = output_dir / f"{model_name}_{r}_{filename}"
        print(output_path)
        with open(output_path, "w", encoding="utf-8") as f_:
            for i, d in enumerate(data):
                print(i)
                if isinstance(d, str):
                    # Tolerate records that are themselves JSON-encoded strings.
                    d = json.loads(d)
                print(d.keys())
                source, summary, aspect = _extract_fields(d, dataset)

                feedback_response, feedback_prompt = None, None
                if feedback_method:
                    feedback_response, feedback_prompt = get_feedback(
                        feedback_method,
                        {"input": source, "summary": summary, "aspect": aspect},
                        method=r, generator=generator, model=model_name,
                    )

                refinement_response, refinement_prompt = get_refinement_response(
                    refinement_method,
                    {
                        "instruction": source,
                        "summary": summary,
                        "aspect": aspect,
                        "feedback": feedback_response,
                    },
                    generator=generator, model=model_name,
                )

                # Write a copy so loaded records are not mutated across methods.
                out = dict(d)
                out['feedback_prompt'] = feedback_prompt
                out['feedback_response'] = feedback_response
                out['refinement_prompt'] = refinement_prompt
                out['refinement'] = refinement_response
                f_.write(json.dumps(out))
                f_.write("\n")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
