import sys
import json
from tqdm import tqdm
import re
import time
import os
import pandas as pd

from prompts import PromptTemplate
from style_lexicon.mark_style import generate_style_modified_texts

from absl import app
from absl import flags

# Root directory under which every evaluation result file is written.
OUTPUT_DIR = "Results/Evaluation"

# Delimiters placed around each sentence when it is inserted into a prompt.
D0, D1 = "{", "}"

# Whole-second Unix time captured once at import; used as a directory name so
# that all outputs of a single run end up under the same folder.
timestamp = str(time.time()).partition(".")[0]


# Helper methods
def create_dir(directory):
    """Create ``directory`` and any missing parents; do nothing if it exists.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() first: another process may create the directory
    # between the check and the makedirs() call.
    os.makedirs(directory, exist_ok=True)


def add_zeros(x):
    """Left-pad the string ``x`` with ``"0"`` characters to a width of five.

    Strings that already have five or more characters are returned unchanged.
    """
    return x.rjust(5, "0")


def read_dataset(tst_model, sentiment, evaluation_aspect, style_mask):
    """Load the source sentences and one TST model's outputs as paired rows.

    Args:
        tst_model: Model identifier; the text before the first ``_`` and the
            text after it become two nested directory names.
        sentiment: ``"positive"`` selects the ``.1`` files, anything else the
            ``.0`` files.
        evaluation_aspect: Aspect being evaluated. When it starts with
            ``"content_preservation"`` both texts go through
            ``generate_style_modified_texts``.
        style_mask: One of ``"unmask"``, ``"remove"``, ``"mask"``; selects
            which element of the ``generate_style_modified_texts`` result is
            used (only looked up for content preservation).

    Returns:
        A tuple ``(lines, input_path, transferred_path)`` where ``lines`` is a
        list of ``[index, input_text, transferred_text]`` rows. Pairing uses
        ``zip``, so it stops at the shorter of the two files.
    """
    label = "1" if sentiment == "positive" else "0"
    family, _, variant = tst_model.partition("_")

    input_path = "Data/model_evaluation_paper_data/data/sentiment.test." + label
    transferred_path = (
        "Data/model_evaluation_paper_data/transfer_model_outputs/"
        + family
        + "/"
        + variant
        + "/sentiment.test."
        + label
        + ".tsf"
    )

    modify_style = evaluation_aspect.startswith("content_preservation")
    variant_index = {"unmask": 0, "remove": 1, "mask": 2}

    def load(path):
        # Read every line, strip surrounding whitespace, then optionally swap
        # in the requested style-modified variant of the texts.
        with open(path, "r") as handle:
            texts = [raw.strip() for raw in handle.readlines()]
        if modify_style:
            # NOTE(review): presumably returns three parallel text lists
            # (unmasked, removed, masked) -- confirm in style_lexicon.
            texts = generate_style_modified_texts(texts)[variant_index[style_mask]]
        return texts

    sources = load(input_path)
    outputs = load(transferred_path)

    lines = [
        [row, source, output]
        for row, (source, output) in enumerate(zip(sources, outputs))
    ]
    return lines, input_path, transferred_path


def read_rerun_dataset(rerun_path, tst_model, sentiment, evaluation_aspect, style_mask):
    """Load the rows of a rerun file that belong to one sentiment.

    Reads ``<rerun_path>/<tst_model>.tsv`` as a header-less, tab-separated
    table. Column 3 is compared against ``sentiment`` to select rows; columns
    1 and 2 are treated as the input and transferred texts.

    Args:
        rerun_path: Directory containing the rerun ``.tsv`` files.
        tst_model: Model identifier, used as the file stem.
        sentiment: Value that column 3 must equal for a row to be kept.
        evaluation_aspect: Aspect being evaluated. When it starts with
            ``"content_preservation"`` columns 1 and 2 are replaced by the
            output of ``generate_style_modified_texts``.
        style_mask: One of ``"unmask"``, ``"remove"``, ``"mask"``; selects
            which element of the ``generate_style_modified_texts`` result is
            used (only looked up for content preservation).

    Returns:
        A tuple ``(lines, rerun_path, rerun_path)`` where ``lines`` is the
        filtered table, without column 3, as a list of row lists. The row
        identifier is whatever the file stores in column 0, not a fresh
        enumeration.
    """
    input_file = os.path.join(rerun_path, tst_model + ".tsv")
    table = pd.read_csv(input_file, sep="\t", header=None)

    # Take an explicit copy of the filtered rows: the columns are overwritten
    # below, and writing into a frame derived by boolean indexing is chained
    # assignment as far as pandas is concerned (SettingWithCopyWarning).
    table = table[table[3] == sentiment].copy()

    if evaluation_aspect.startswith("content_preservation"):
        variant_index = {"unmask": 0, "remove": 1, "mask": 2}[style_mask]
        # Column 1 is the input text, column 2 the transferred text.
        for column in (1, 2):
            texts = list(table.iloc[:, column])
            # NOTE(review): presumably returns three parallel text lists
            # (unmasked, removed, masked) -- confirm in style_lexicon.
            table.iloc[:, column] = generate_style_modified_texts(texts)[
                variant_index
            ]

    table = table.drop([3], axis=1)
    return table.values.tolist(), rerun_path, rerun_path


def get_input(template, line, evaluation_aspect, d0=D0, d1=D1):
    """Fill a prompt template with the delimited texts of one dataset row.

    Args:
        template: Format string. It receives ``input`` only for the
            ``"naturalness"`` aspect, and ``input`` plus ``transferred`` for
            every other aspect.
        line: Row whose index 1 holds the input text and index 2 the
            transferred text.
        evaluation_aspect: Name of the aspect being evaluated.
        d0: Opening delimiter placed before each text.
        d1: Closing delimiter placed after each text.

    Returns:
        The formatted prompt string.
    """

    def wrap(text):
        return d0 + text + d1

    if evaluation_aspect == "naturalness":
        # Naturalness judges the transferred sentence on its own, so it is
        # passed under the ``input`` placeholder.
        return template.format(input=wrap(line[2]))

    # Every other aspect compares the input with the transferred text.
    return template.format(input=wrap(line[1]), transferred=wrap(line[2]))


def make_output_path(
    eval_model: str,
    model_variant: str,
    tst_model: str,
    evaluation_aspect: str,
    template_index: str,
    style_mask: str,
    sample_decimals: bool,
):
    """Build and create the output directory for one experiment configuration.

    The layout is ``OUTPUT_DIR/eval_model/model_variant/<sampling>/
    evaluation_aspect/template_<index>/<timestamp>/[style_mask/]tst_model``,
    where ``<sampling>`` is ``"decimal_sampling"`` or ``"zeroshot"`` and the
    ``style_mask`` level is present only when the aspect starts with
    ``"content_preservation"``.

    Returns:
        The directory path, which exists when this function returns.
    """
    sampling_mode = "decimal_sampling" if sample_decimals else "zeroshot"

    components = [
        OUTPUT_DIR,
        eval_model,
        model_variant,
        sampling_mode,
        evaluation_aspect,
        "template_" + str(template_index),
        timestamp,
    ]
    if evaluation_aspect.startswith("content_preservation"):
        components.append(style_mask)
    components.append(tst_model)

    output_dir = os.path.join(*components)
    create_dir(output_dir)
    return output_dir


class ModelRunner:
    """Owns one LLM backend and uses it to answer batches of prompts."""

    def __init__(self, eval_model, model_key, model_variant):
        """Instantiate the backend selected by ``eval_model``.

        Each backend module is imported inside its own branch, so only the
        selected one is imported.

        Args:
            eval_model: Backend name (``"opt"``, ``"gpt3"``, ``"bloom"``,
                ``"bloomAPI"``, ``"llama2"`` or ``"falcon"``).
            model_key: Key passed to the GPT3 and BLOOM backends.
            model_variant: Variant name passed to the OPT, GPT3, Llama2 and
                pipeline backends.

        Raises:
            ValueError: If no backend exists for the requested combination,
                e.g. ``"falcon"`` without ``--pipeline_generation``.
                Previously this left ``self.model`` unset and only failed
                later, inside ``run``, with an ``AttributeError``.
        """
        if eval_model == "opt":
            from OPT.OPT import OPT

            self.model = OPT(model_variant, batch_size=FLAGS.batch_size)
        elif eval_model == "gpt3":
            from GPT3.GPT3 import GPT3

            self.model = GPT3(model_key, model_variant)
        elif eval_model in ("bloom", "bloomAPI"):
            from BLOOM.BLOOM import BLOOM

            self.model = BLOOM(model_key, batch_size=FLAGS.batch_size)
        elif FLAGS.pipeline_generation and eval_model in ("falcon", "llama2"):
            from pipeline_generation import PipelineGeneration

            self.model = PipelineGeneration(eval_model, model_variant)
        elif eval_model == "llama2":
            # Reached only when --pipeline_generation is off.
            from llama_2.llama_2 import Llama2

            self.model = Llama2(model_variant, batch_size=FLAGS.batch_size)
        else:
            raise ValueError(
                f"No backend available for eval_model={eval_model!r} "
                f"with pipeline_generation={FLAGS.pipeline_generation!r}"
            )

    def run(self, lines, evaluation_aspect: str, template_index: str):
        """Prompt the model for every row and return the rows with answers.

        Args:
            lines: Rows as produced by the dataset readers; each row is
                passed to ``get_input`` to build its prompt.
            evaluation_aspect: Aspect whose prompt templates are used.
            template_index: Index (as a string or int) into the list returned
                by ``PromptTemplate.get_prompts(evaluation_aspect)``.

        Returns:
            A list with one entry per returned answer: the original row
            followed by the model's answer.
        """
        template = PromptTemplate.get_prompts(evaluation_aspect)[int(template_index)]
        template_str = template.template

        answered_lines = []
        batch_size = FLAGS.batch_size
        for start in tqdm(range(0, len(lines), batch_size)):
            batch_lines = lines[start : start + batch_size]
            prompts = [
                get_input(template_str, line, evaluation_aspect) for line in batch_lines
            ]
            answers = self.model.compute(prompts, FLAGS.sample_decimals)
            # Answers are matched to rows by position within the batch.
            answered_lines.extend(
                [*batch_lines[position], answer]
                for position, answer in enumerate(answers)
            )
            # TODO: the original author left an open question here about
            # retrying a batch; no retry logic exists yet.

        return answered_lines


def save_answered_lines(answered_lines, output_file):
    """Write the answered rows to ``output_file`` as a header-less TSV.

    Args:
        answered_lines: Iterable of rows; each row becomes one output line.
        output_file: Destination path; an existing file is overwritten.
    """
    table = pd.DataFrame(answered_lines)
    # newline="" leaves line-ending handling entirely to the CSV writer.
    with open(output_file, "w", newline="") as handle:
        table.to_csv(handle, sep="\t", header=False, index=False)


# Command-line configuration. Every list flag is iterated in `evaluate`, so
# one invocation covers the full cross product of the values given.
FLAGS = flags.FLAGS

# What to evaluate.
flags.DEFINE_list(
    name="sentiments",
    default=["positive", "negative"],
    help="sentiment",
)
flags.DEFINE_enum(
    name="eval_model",
    default="gpt3",
    enum_values=["gpt3", "opt", "bloom", "bloomAPI", "llama2", "falcon"],
    help="LLM eval model",
)
flags.DEFINE_list(
    name="evaluation_aspects",
    default=["style_transfer_accuracy", "content_preservation", "naturalness"],
    help="Evaluation aspect",
)
flags.DEFINE_list(
    name="templates",
    default=[str(index) for index in range(11)],  # "0" through "10"
    help="Prompt template index",
)
flags.DEFINE_enum(
    name="style_mask",
    default="unmask",
    enum_values=["unmask", "mask", "remove"],
    help="Content preservation text style mask",
)
flags.DEFINE_list(
    name="tst_models",
    default=["ARAE_lambda_1", "CAAE_rho_0_5", "DAR_gamma_15"],
    help="TST model",
)

# How to run the evaluating model.
flags.DEFINE_string(
    name="model_key",
    default=None,
    help="Key for GPT3 and Bloom model",
)
flags.DEFINE_string(
    name="model_variant",
    default="text-davinci-003",
    help="Which gpt3 model to use",
)
flags.DEFINE_integer(name="batch_size", default=1, help="Batch size")
flags.DEFINE_string(
    name="rerun_path",
    default=None,
    help="rerun for these sentences",
)
flags.DEFINE_bool(
    name="sample_decimals",
    default=False,
    help="only sample decimal numbers",
)
flags.DEFINE_bool(
    name="pipeline_generation",
    default=False,
    help="use huggingface pipeline for text generation",
)


def main(argv):
    """Entry point passed to ``app.run``; runs the full evaluation.

    Args:
        argv: Positional command-line arguments; not used.
    """
    del argv  # Unused: all configuration comes from FLAGS.
    evaluate()


def _should_save_results():
    """Return True when this process is the one that writes result files."""
    if FLAGS.eval_model != "llama2" or FLAGS.pipeline_generation:
        return True
    # NOTE(review): the non-pipeline llama2 backend presumably runs one
    # process per device with LOCAL_RANK set by the launcher -- confirm.
    # Only rank 0 writes. A missing variable is treated as rank 0 instead of
    # raising KeyError after the whole evaluation has already been computed.
    return int(os.environ.get("LOCAL_RANK", "0")) == 0


def evaluate():
    """Run every configured experiment and save each result table.

    One experiment is run for every combination of
    ``FLAGS.evaluation_aspects`` x ``FLAGS.templates`` x ``FLAGS.tst_models``
    x ``FLAGS.sentiments``, in that nesting order (sentiment varies fastest).
    For each one the dataset is loaded (from the rerun directory when
    ``FLAGS.rerun_path`` is set), the configuration is printed, the model is
    run, and the answers are written through ``save_output``.
    """
    from itertools import product

    runner = ModelRunner(FLAGS.eval_model, FLAGS.model_key, FLAGS.model_variant)

    experiments = product(
        FLAGS.evaluation_aspects,
        FLAGS.templates,
        FLAGS.tst_models,
        FLAGS.sentiments,
    )
    for evaluation_aspect, template_index, tst_model, sentiment in experiments:
        dataset = "yelp-" + sentiment
        if FLAGS.rerun_path is None:
            lines, input_path, transferred_path = read_dataset(
                tst_model, sentiment, evaluation_aspect, FLAGS.style_mask
            )
        else:
            lines, input_path, transferred_path = read_rerun_dataset(
                FLAGS.rerun_path,
                tst_model,
                sentiment,
                evaluation_aspect,
                FLAGS.style_mask,
            )

        print("----------------------------------------")
        print("EXPERIMENT CONFIGURATION")
        print("----------------------------------------")
        print("eval_model: ", FLAGS.eval_model)
        print("tst_model: ", tst_model)
        print("sentiment: ", sentiment)
        print("input_path: ", input_path)
        print("transferred_path: ", transferred_path)
        print("evaluation_aspect: ", evaluation_aspect)
        print("template: ", template_index)
        print("sample_decimals: ", FLAGS.sample_decimals)
        # Same prefix test the dataset readers and make_output_path use, so
        # the mask is reported whenever it actually influences the run.
        if evaluation_aspect.startswith("content_preservation"):
            print("style_mask: ", FLAGS.style_mask)
        print("----------------------------------------")

        answered_lines = runner.run(
            lines,
            evaluation_aspect=evaluation_aspect,
            template_index=template_index,
        )
        if _should_save_results():
            save_output(
                evaluation_aspect,
                template_index,
                tst_model,
                dataset,
                answered_lines,
            )


def save_output(evaluation_aspect, template_index, tst_model, dataset, answered_lines):
    """Write one experiment's answers to ``<output_dir>/<dataset>.tsv``.

    The output directory is derived from the current flag values together
    with the experiment parameters, and is created if necessary.

    Args:
        evaluation_aspect: Aspect that was evaluated.
        template_index: Index of the prompt template that was used.
        tst_model: Identifier of the style-transfer model being evaluated.
        dataset: File stem for the output file.
        answered_lines: Rows to write, each ending with the model's answer.
    """
    output_dir = make_output_path(
        eval_model=FLAGS.eval_model,
        model_variant=FLAGS.model_variant,
        tst_model=tst_model,
        evaluation_aspect=evaluation_aspect,
        template_index=template_index,
        style_mask=FLAGS.style_mask,
        sample_decimals=FLAGS.sample_decimals,
    )
    save_answered_lines(answered_lines, "{}/{}.tsv".format(output_dir, dataset))


if __name__ == "__main__":
    # Echo the exact command line first so that captured logs record how the
    # run was invoked, then hand control to absl for flag parsing.
    print("argv: " + " ".join(sys.argv))
    app.run(main)
