import os
import sys
import argparse
import json
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from dataloader import EFPDatasetforLLM
import torch
from openai import OpenAI

def parse_args():
    """Parse the command-line options of the generation script.

    Value-taking options are registered first, then the boolean switches,
    so ``--help`` lists them in the same order as before.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser()

    # (flag, converter, default) for every option that takes a value.
    valued_options = (
        ('--model_path', str, "Mistral-7B-Instruct-v0.2"),
        ('--output_dir', str, 'output'),
        ('--max_length', int, 256),
        ('--data_dir', str, 'test.jsonl'),
        ("--num_shot", int, 5),
        ("--seed", int, 42),
        ("--num_labels", int, 5),
        ("--batch_size", int, 32),
        ("--prompt_path", str, "prompt.json"),
        ("--shot_path", str, "shot.json"),
    )
    for flag, converter, default_value in valued_options:
        parser.add_argument(flag, type=converter, default=default_value)

    # Boolean switches: False unless the flag is present on the command line.
    for switch in ("--chat_mode", "--relation", "--arg", "--cot"):
        parser.add_argument(switch, action="store_true")

    return parser.parse_args()


def generate_input(prompt_path, shot_path, text, num_shot, cot=None, cause=None, precondition=None, argument=None):
    """Build the few-shot prompt for a single text.

    Args:
        prompt_path: JSON file with a "prompt" key, plus "rules", "relation" and
            "argument" keys that are read only when the matching option is used.
        shot_path: JSON file holding a list of {"text": ..., "label": ...}
            demonstrations; the first ``num_shot`` of them are used.
        text: The text to be labelled.
        num_shot: Number of demonstrations to include.
        cot: When truthy, the prompt's "rules" section follows the instruction.
        cause: Optional list of cause-relation strings, one per line.
        precondition: Optional list of precondition-relation strings, one per line.
        argument: Optional argument string appended verbatim.

    Returns:
        (system_output, output, remain_output): the instruction plus demonstrations,
        the full prompt, and the part of the full prompt following system_output
        (so ``output == system_output + remain_output``).

    Raises:
        ValueError: if ``num_shot`` exceeds the number of demonstrations in the file.
    """
    with open(prompt_path, "r", encoding="utf-8") as f:
        prompt = json.load(f)
    with open(shot_path, "r", encoding="utf-8") as f:
        shots = json.load(f)
    if num_shot > len(shots):
        raise ValueError(f"num_shot={num_shot} but {shot_path} only has {len(shots)} examples")

    # Instruction + demonstrations: this part is shared by every example.
    system_parts = [prompt["prompt"]]
    if cot:
        system_parts.append(prompt["rules"])
    system_parts.append("Here are some examples:\n\n")
    for example in shots[:max(num_shot, 0)]:  # max() keeps a negative num_shot meaning "no shots"
        system_parts.append("TEXT: " + example["text"] + "\n")
        system_parts.append("LABEL: " + example["label"] + "\n")
        system_parts.append("\n")
    system_output = "".join(system_parts)

    # Optional per-example context (relations / arguments).
    context = []
    if cause or precondition:
        context.append(prompt["relation"])
    if cause:
        context.append("Cause Relations: " + "".join(c + "\n" for c in cause) + "\n")
    if precondition:
        context.append("Precondition Relations: " + "".join(p + "\n" for p in precondition) + "\n")
    if argument:
        context.append(prompt["argument"] + "Arguments:" + argument + "\n\n")

    remain_parts = []
    # Only introduce the reference section when there is something to reference;
    # otherwise the prompt would contain a dangling "For your reference," line.
    if context:
        remain_parts.append("For your reference,\n")
        remain_parts.extend(context)
    remain_parts.append("Here is the text you need to generate the label for, please do not output other information other than the label.\n")
    remain_parts.append("TEXT: " + text + "\n")
    remain_parts.append("LABEL: ")
    remain_output = "".join(remain_parts)

    return system_output, system_output + remain_output, remain_output
  

# Marker tokens registered as additional special tokens on local tokenizers.
# NOTE(review): presumably they delimit events / causes / preconditions in the
# dataset texts — confirm against EFPDatasetforLLM.
_MARKER_TOKENS = ['<e>', '</e>', '<c>', '</c>', '<p>', '</p>']


def _output_filename(args):
    """Return the JSONL result file name encoding the model and prompt options."""
    # normpath drops a trailing slash so basename() does not return "".
    name = f"result_{os.path.basename(os.path.normpath(args.model_path))}_{args.num_shot}shot"
    if args.cot:
        name += "_cot"
    if args.relation:
        name += "_relation"
    if args.arg:
        name += "_arg"
    return name + ".jsonl"


def _load_local_model(model_path):
    """Load a tokenizer and bf16 causal LM, registering the marker tokens."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.add_special_tokens({'additional_special_tokens': _MARKER_TOKENS})
    # device_map="auto" already places the weights; calling .to() on a
    # dispatched model is unnecessary and can fail on multi-device maps.
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
    # The newly added tokens need embedding rows; without this their ids index
    # past the end of the embedding matrix.
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model


def _generate_openai(client, model_name, system_prompt, remain_prompt):
    """Query the OpenAI chat API greedily and return the reply text."""
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": remain_prompt},
        ],
        max_tokens=16,
        temperature=0.0,
    )
    return response.choices[0].message.content


def _generate_llama(tokenizer, model, prompt, system_prompt, remain_prompt, chat_mode):
    """Sample a completion from a Llama model; return only the generated text."""
    if chat_mode:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": remain_prompt},
        ]
        input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    else:
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
    input_ids = input_ids.to(model.device)
    # "<|eot_id|>" is the end-of-turn token of Llama 3 chat models; on other
    # tokenizers it presumably maps to the unknown-token id — verify if needed.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=32,
        eos_token_id=terminators,
        pad_token_id=pad_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)


def _generate_default(tokenizer, model, prompt, chat_mode):
    """Sample a completion from any other local model (e.g. Mistral); return only new text."""
    if chat_mode:
        # The whole prompt goes in a single user turn (no system role).
        messages = [{"role": "user", "content": prompt}]
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        outputs = model.generate(input_ids, max_new_tokens=16, do_sample=True)
    else:
        model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
        input_ids = model_inputs["input_ids"]
        outputs = model.generate(**model_inputs, max_new_tokens=32, do_sample=True)
    # Decode only the newly generated tokens. Splitting the full decode on
    # "[/INST]" fails for raw (non-chat) prompts, which contain no such marker.
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)


def generate_output():
    """Run the configured model over the dataset and append one JSON line per example.

    Each line of ``<output_dir>/<result file>`` is
    ``{"text": ..., "label": <gold label name>, "output": <model reply>}``.
    The file is opened in append mode, so re-running appends to existing results.

    Model names containing "gpt" are served through the OpenAI API, which reads
    the key from the OPENAI_API_KEY environment variable. Any other name is
    loaded locally with transformers. Names containing "Llama" use the Llama
    prompt and sampling settings; all other local models use the Mistral-style
    path.
    """
    args = parse_args()
    print("num_shot: ", args.num_shot)
    # --seed was parsed but never applied; seed torch so sampled generations are reproducible.
    torch.manual_seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)
    output_path = os.path.join(args.output_dir, _output_filename(args))

    use_openai = "gpt" in args.model_path
    is_llama = "Llama" in args.model_path
    if use_openai:
        # One client for the whole run; credentials come from the environment.
        client = OpenAI()
    else:
        tokenizer, model = _load_local_model(args.model_path)

    dataset = EFPDatasetforLLM(args.data_dir)
    texts = dataset.texts
    causes = dataset.causes
    preconditions = dataset.preconditions
    arguments = dataset.arguments
    id2label = {0: "CT+", 1: "CT-", 2: "PS+", 3: "PS-", 4: "Uu"}
    gold_labels = [id2label[label_id] for label_id in dataset.labels]

    with open(output_path, "a", encoding="utf-8") as out_file:
        for i in tqdm(range(len(dataset)), desc=f"Generating outputs with {args.model_path}"):
            system_prompt, prompt, remain_prompt = generate_input(
                args.prompt_path, args.shot_path, texts[i], args.num_shot,
                cot=args.cot,
                cause=causes[i] if args.relation else None,
                precondition=preconditions[i] if args.relation else None,
                argument=arguments[i] if args.arg else None,
            )
            if use_openai:
                output = _generate_openai(client, args.model_path, system_prompt, remain_prompt)
            elif is_llama:
                output = _generate_llama(tokenizer, model, prompt, system_prompt, remain_prompt, args.chat_mode)
            else:
                output = _generate_default(tokenizer, model, prompt, args.chat_mode)

            res = {"text": texts[i], "label": gold_labels[i], "output": output}
            out_file.write(json.dumps(res, ensure_ascii=False) + "\n")
            # Flush per record so partial results survive an interrupted run.
            out_file.flush()


if __name__ == "__main__":
    # Script entry point: all options are read from the command line inside generate_output().
    generate_output()