"""
generate exe, input, output for each function
"""
import torch
import os
import time
import json
import struct
import re
import subprocess
import logging
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
from bleu import list_bleu


def fill_template(c_src):
    """Build the model prompt asking for unoptimized x86 assembly of *c_src*.

    The prompt ends with an opening ```asm fence so that the model's
    continuation is the assembly body itself.
    """
    prompt_lines = [
        "### Instruction: Compile the following C code into x86 assembly without optimization.",
        "### Human:",
        "```c",
        f"{c_src}",
        "```",
        "###Assistant:",
        "```asm",
    ]
    return "\n".join(prompt_lines)


def grab_assembly(raw_output: str, special_token="```asm"):
    """Extract the text between the first *special_token* and the next "```".

    Args:
        raw_output: Decoded model output (may include the prompt).
        special_token: Marker that opens the assembly section.

    Returns:
        The text following the first occurrence of *special_token* up to
        (not including) the next "```" fence. If the closing fence is
        missing (e.g. generation was truncated), everything after the
        opening marker is returned. If the opening marker is missing, an
        empty string is returned.
    """
    start = raw_output.find(special_token)
    if start == -1:
        # No assembly section at all; slicing from -1 would return garbage.
        return ""
    body_start = start + len(special_token)
    end = raw_output.find("```", body_start)
    if end == -1:
        # Unterminated block: keep the whole tail instead of dropping the
        # last character via a [-1] slice bound.
        return raw_output[body_start:]
    return raw_output[body_start:end]

def numeric_handler(asm: str):
    """Replace symbolic floating-point constants in *asm* with integer bit patterns.

    Two placeholder forms are rewritten:

    * ``float_to_int(<decimal>)``   -> the value packed as a native C float and
      reinterpreted as an unsigned 32-bit integer.
    * ``double_to_quad(<decimal>)`` -> the value packed as a native C double and
      reinterpreted as an unsigned 64-bit integer.

    ``<decimal>`` is ``digits.digits`` with an optional leading minus sign.

    Args:
        asm: Assembly text possibly containing the placeholders above.

    Returns:
        The assembly text with every placeholder replaced by its decimal
        integer representation.
    """

    def _float_bits(match):
        # Reinterpret the 4 bytes of a C float as an unsigned int.
        packed = struct.pack("@f", float(match.group(1)))
        return str(struct.unpack("@I", packed)[0])

    def _double_bits(match):
        # Reinterpret the 8 bytes of a C double as an unsigned long long.
        packed = struct.pack("@d", float(match.group(1)))
        return str(struct.unpack("@Q", packed)[0])

    asm = re.sub(r"float_to_int\((-?\d+\.\d+)\)", _float_bits, asm)
    asm = re.sub(r"double_to_quad\((-?\d+\.\d+)\)", _double_bits, asm)
    return asm

def test_bleu(ref_path, hyp_path):
    """Compute the BLEU score of a hypothesis assembly file against a reference.

    The reference is trimmed to the region starting on the line after the
    first ``.text`` and ending with the first ``.cfi_endproc`` that follows.
    If either marker is absent, that side of the reference is left untrimmed.
    Both texts then have newlines replaced by the literal two-character
    sequence ``\\n`` and runs of whitespace collapsed to single spaces before
    being passed to ``list_bleu``.

    Args:
        ref_path: Path of the reference assembly file.
        hyp_path: Path of the generated (hypothesis) assembly file.

    Returns:
        The score returned by ``list_bleu([ref], [hyp])``.
    """
    with open(ref_path, "r") as ref_file:
        ref = "\n" + ref_file.read()
    with open(hyp_path, "r") as hyp_file:
        hyp = hyp_file.read()

    # Trim the reference to the function body. Guard against find() == -1,
    # which would otherwise slice the reference down to garbage.
    text_pos = ref.find(".text")
    if text_pos != -1:
        ref = ref[text_pos:]
    # Drop the remainder of the current line (the ".text" line, or the
    # leading newline prepended above when there is no ".text").
    ref = ref[ref.find("\n") + 1 :]
    end_pos = ref.find(".cfi_endproc")
    if end_pos != -1:
        ref = ref[: end_pos + len(".cfi_endproc")]

    # normalize
    ref = ref.replace("\n", "\\n")
    hyp = hyp.replace("\n", "\\n")
    ref = " ".join(ref.split())
    hyp = " ".join(hyp.split())
    logging.info(f"hyp: {hyp}")
    logging.info(f"ref: {ref}")
    score = list_bleu([ref], [hyp])
    logging.info(f"bleu score: {score}")
    return score


def run_inference(
    model: LlamaForCausalLM,
    tokenizer: AutoTokenizer,
    eval_script: str,
    in_len: list = None,
    out_len: list = None,
    max_new_tokens: int = 2048,
):
    """Generate a completion for *eval_script* and return the decoded text.

    Args:
        model: Causal language model used for generation.
        tokenizer: Tokenizer matching *model*.
        eval_script: Prompt text.
        in_len: Optional list; when given, the prompt length in tokens is
            appended to it.
        out_len: Optional list; when given, the length of the first sequence
            returned by ``model.generate`` is appended to it (presumably
            prompt plus new tokens for a decoder-only model -- verify).
        max_new_tokens: Upper bound on generated tokens passed to
            ``model.generate``.

    Returns:
        The decoded first generated sequence.
    """
    model_input = tokenizer(eval_script, return_tensors="pt").to("cuda")
    # The length accumulators are optional; their default is None.
    if in_len is not None:
        in_len.append(model_input.input_ids.shape[1])
    model.eval()
    with torch.no_grad():
        generated = model.generate(**model_input, max_new_tokens=max_new_tokens)[0]
    if out_len is not None:
        out_len.append(generated.shape[0])
    return tokenizer.decode(generated)


def cleanup():
    """Remove per-function scratch files from the current working directory.

    Deletes every entry inside the ``input``, ``output`` and ``my_output``
    directories (the directories themselves are kept) and the build
    artifacts ``tmp.s``, ``tmp.o``, ``libtmp.a``, ``tmp_driver.cpp`` and
    ``tmp.exe``. Directories or files that do not exist are skipped, so the
    function is safe to call before any of them has been created.
    """
    for directory in ("input", "output", "my_output"):
        if not os.path.isdir(directory):
            # Nothing has been generated into this directory yet.
            continue
        for entry in os.listdir(directory):
            os.remove(os.path.join(directory, entry))

    for artifact in ("tmp.s", "tmp.o", "libtmp.a", "tmp_driver.cpp", "tmp.exe"):
        try:
            os.remove(artifact)
        except FileNotFoundError:
            pass  # already absent, which is the desired end state


def patch_driver(c_exe, func_def):
    """Adapt an exebench driver source so it links against the generated code.

    Two rewrites are applied, each only when its marker is present:

    * the body of the first ``extern "C" { ... }`` block (up to the first
      ``}``) is replaced by a declaration of the function under test, built
      from everything in *func_def* before its first ``{``;
    * the ``#include <clib/synthesizer.h...>`` directive is removed.

    Args:
        c_exe: Driver source text.
        func_def: C source of the function definition under test.

    Returns:
        The patched driver source; *c_exe* unchanged if no marker is found.
    """
    brace = func_def.find("{")
    func_decl = func_def if brace == -1 else func_def[:brace]

    patched = c_exe
    start = patched.find('extern "C" {')
    end = patched.find("}", start + 1) if start != -1 else -1
    if start != -1 and end != -1:
        patched = (
            patched[:start] + 'extern "C" {\n' + func_decl + ";\n}\n" + patched[end + 1 :]
        )

    # remove the unknown/unused header
    start = patched.find("#include <clib/synthesizer.h")
    end = patched.find(">", start + 1) if start != -1 else -1
    if start != -1 and end != -1:
        patched = patched[:start] + patched[end + 1 :]
    return patched


def write_io_json(path, pair, trailing_newline=False):
    """Write one IO pair as a JSON object file.

    *pair* holds parallel ``"var"`` and ``"value"`` sequences. Each value is
    inserted verbatim after its quoted name (assumed to already be valid
    JSON text -- TODO confirm against the dataset).

    Args:
        path: Destination file path.
        pair: Mapping with ``"var"`` and ``"value"`` sequences.
        trailing_newline: Whether to end the file with a newline after ``}``.
    """
    entries = [
        f'    "{var}": {value}' for var, value in zip(pair["var"], pair["value"])
    ]
    with open(path, "w") as data_file:
        data_file.write("{\n")
        if entries:
            data_file.write(",\n".join(entries) + "\n")
        data_file.write("}\n" if trailing_newline else "}")


def safe_percent(numerator, denominator):
    """Return ``numerator * 100 / denominator``, or NaN when the denominator is 0."""
    if not denominator:
        return float("nan")
    return numerator * 100 / denominator


def safe_mean(values):
    """Return the arithmetic mean of *values*, or NaN when there are none."""
    values = list(values)
    if not values:
        return float("nan")
    return sum(values) / len(values)


if __name__ == "__main__":
    # os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"
    # 0. set up logging
    log_dir = "log_dir"  # anonymous data
    os.makedirs(log_dir, exist_ok=True)
    time_str = time.strftime("%Y_%m_%d_%H_%M_%S")
    log_file = os.path.join(log_dir, f"eval_exebench_run_{time_str}.log")
    logging.basicConfig(filename=log_file, level=logging.INFO)

    # 1. load eval dataset from exebench
    dataset = load_dataset("jordiae/exebench")["train_real_simple_io"]

    # NOTE(review): `resume` only offsets the reported index; the dataset
    # iteration below always starts from the first element.
    resume = 0
    max_succ = 10000  # stop once the running index reaches this value
    # Absolute, because the paths built from it are used after os.chdir().
    workspace_dir = os.path.abspath("workspace")  # anonymous data
    index = resume
    non_compilable_asm = []
    compilable_non_executable_skip_asm = (
        []
    )  # lacking some header/library, very few has internal error
    executable_skip_asm = []  # lacking json, wrong driver
    executable_io_correct_asm = []
    executable_io_error_asm = []
    long_input_skip_asm = []

    base_model = "finetuned_model"  # anonymous data
    need_numeric = False

    model = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    os.makedirs(workspace_dir, exist_ok=True)
    os.chdir(workspace_dir)

    # Create the scratch directories up front and start from a clean state so
    # stale files from a previous run cannot be mistaken for fresh results.
    for scratch_dir in ("input", "output", "my_output"):
        os.makedirs(scratch_dir, exist_ok=True)
    cleanup()

    input_path = os.path.join(workspace_dir, "input")
    my_path = os.path.join(workspace_dir, "my_output")
    ref_path = os.path.join(workspace_dir, "output")
    exe_path = os.path.join(workspace_dir, "tmp.exe")
    time_out_limit = 20.0
    build_errors = (subprocess.SubprocessError, OSError)

    # Headers prepended to the C source before compiling the reference.
    predefined_headers = """#include <stdio.h>
        #include <stdlib.h>
        #include <stdbool.h>
        #include <string.h>
        #include <math.h>
        #include <time.h>
        """

    in_c_len = []
    out_asm_len = []
    csv_name = "csv_name.csv"  # anonymous data
    bleu_map = {}

    start_time = time.time()
    for func in dataset:
        inputs = func["real_io_pairs"]["input"]
        outputs = func["real_io_pairs"]["output"]

        ### 1. run the model to generate assembly
        c_src = func["func_def"]
        prompt = fill_template(c_src)
        if len(prompt) > 2048:
            logging.info(f"SKIP very long input: {index}")
            long_input_skip_asm.append(index)
            cleanup()
            index += 1
            continue
        output = run_inference(model, tokenizer, prompt, in_c_len, out_asm_len)
        asm = grab_assembly(output).strip()
        if need_numeric:
            asm = numeric_handler(asm)
        with open("tmp.s", "w") as asm_file:
            asm_file.write(asm)

        # generate tmp.c and compile it to obtain the reference assembly
        with open("tmp.c", "w") as c_src_file:
            c_src_file.write(predefined_headers)
            c_src_file.write(c_src)
        if os.path.exists("tmp_ref.s"):
            os.remove("tmp_ref.s")
        try:
            subprocess.run(
                ["gcc-9", "-S", "-fno-jump-tables", "tmp.c", "-o", "tmp_ref.s"],
                check=True,
                timeout=20,
            )
        except build_errors:
            logging.warning(f"{index} ref cannot compile, using default assembly as reference, which has:\n{c_src}")
            # Discard any partial output left by the failed compile.
            if os.path.exists("tmp_ref.s"):
                os.remove("tmp_ref.s")
            asm_targets = func["asm"]["target"]
            ref_asm = None
            if "real_gcc_x86_O0" in asm_targets:
                ref_asm = func["asm"]["code"][asm_targets.index("real_gcc_x86_O0")]
            if ref_asm is None:
                logging.warning(f"ref_asm not found in {index}, skip bleu test")
            else:
                with open("tmp_ref.s", "w") as ref_asm_file:
                    ref_asm_file.write(ref_asm)
        # if tmp.s and tmp_ref.s exists, compute bleu
        if os.path.exists("tmp.s") and os.path.exists("tmp_ref.s"):
            bleu_map[index] = test_bleu("tmp_ref.s", "tmp.s")

        ### 2. assemble into static library
        try:
            subprocess.run(
                ["gcc-9", "-c", "tmp.s", "-o", "tmp.o"], check=True, timeout=20
            )
            subprocess.run(["ar", "rcs", "libtmp.a", "tmp.o"], check=True)
        except build_errors:
            logging.warning(f"non_compilable: {index}, which has c:\n{c_src}and asm:\n{asm}")
            non_compilable_asm.append(index)
            cleanup()
            index += 1
            continue

        ### 3. generate driver function from bench and link it
        with open("tmp_driver.cpp", "w") as c_exe_file:
            c_exe_file.write(patch_driver(func["real_exe_wrapper"], c_src))
        try:
            subprocess.run(
                ["g++-9", "tmp_driver.cpp", "-L.", "-ltmp", "-o", "tmp.exe"],
                check=True,
                timeout=20,
            )
        except build_errors:
            logging.warning(f"compilable_non_executable_skip: {index}")
            compilable_non_executable_skip_asm.append(index)
            cleanup()
            index += 1
            continue

        ### 4. assemble input jsons and reference output jsons
        for i, pair in enumerate(inputs):
            write_io_json(os.path.join("input", f"in{i}.json"), pair)
        for i, pair in enumerate(outputs):
            write_io_json(
                os.path.join("output", f"out{i}.json"), pair, trailing_newline=True
            )

        ### 5. run the exe on every input and compare with the reference
        local_err = 0
        skipped = False
        for i in range(len(inputs)):
            in_file = os.path.join(input_path, f"in{i}.json")
            my_out = os.path.join(my_path, f"myout{i}.json")
            ref_out = os.path.join(ref_path, f"out{i}.json")
            # catch timeout exception
            try:
                subprocess.run([exe_path, in_file, my_out], timeout=time_out_limit)
            except subprocess.TimeoutExpired:
                logging.warning(
                    f"executable_io_error: timeout exceed {time_out_limit}s in {index}"
                )
                local_err += 1
                break
            if not os.path.exists(my_out):
                logging.warning(
                    f"executable_io_error: no execution output found: {index}"
                )
                local_err += 1
                break
            with open(my_out, "r") as json_my_file:
                json_my_str = json_my_file.read()
            with open(ref_out, "r") as json_ref_file:
                json_ref_str = json_ref_file.read()
            try:
                json_my = json.loads(json_my_str)
                json_ref = json.loads(json_ref_str)
            except json.JSONDecodeError:
                logging.warning(
                    f"SKIP: json decode failed in file {index} for input {i}"
                )
                # Counted as a skip only, so the function lands in exactly
                # one result category.
                skipped = True
                break
            # A null output is accepted where the reference is an empty object.
            if json_my == json_ref or (json_my is None and json_ref == {}):
                continue
            logging.warning(
                f"executable_io_error: json IO not equal in {index} for input {i}, my: {json_my}, ref: {json_ref}"
            )
            local_err += 1
            break

        if skipped:
            executable_skip_asm.append(index)
        elif local_err == 0:
            executable_io_correct_asm.append(index)
            logging.info(f"executable_io_correct: {index}")
        else:
            executable_io_error_asm.append(index)
        index += 1
        cleanup()
        if index >= max_succ:
            logging.info(f"reach max succ: {max_succ}")
            break
    end_time = time.time()

    # print result
    compile_failed = len(non_compilable_asm)
    skip0 = len(long_input_skip_asm)
    skip1 = len(compilable_non_executable_skip_asm)
    skip2 = len(executable_skip_asm)
    io_correct = len(executable_io_correct_asm)
    io_error = len(executable_io_error_asm)
    total = skip0 + compile_failed + skip1 + skip2 + io_correct + io_error
    non_skip = total - skip0
    compilable = skip1 + skip2 + io_correct + io_error
    executable = skip2 + io_correct + io_error
    logging.info("result:\n-----------------------------------------------")
    logging.info(
        f"""
            {compile_failed}(FAIL COMPILE)
            {skip1}(SKIP)
            {skip2}(SKIP)
            {io_correct}(SUCCESS)
            {io_error}(FAIL EXECUTION)
            {skip0}(SKIP)
            {total}(TOTAL)
    """
    )
    logging.info(
        f"non_compilable_asm: total {compile_failed}, include <{non_compilable_asm}>"
    )
    logging.info(
        f"compilable_non_executable_skip_asm: total {skip1}, include <{compilable_non_executable_skip_asm}>"
    )
    logging.info(f"executable_skip_asm: total {skip2}, include <{executable_skip_asm}>")
    logging.info(
        f"executable_io_correct_asm: total {io_correct}, include <{executable_io_correct_asm}>"
    )
    logging.info(
        f"executable_io_error_asm: total {io_error}, include <{executable_io_error_asm}>"
    )
    logging.info(f"long_input_skip_asm: total {skip0}, include <{long_input_skip_asm}>")
    # Rates are NaN when their denominator is zero (e.g. nothing compiled).
    logging.info("rate:\n-----------------------------------------------")
    logging.info(f"compilable rate: {safe_percent(compilable, non_skip)}%")
    logging.info(f"executable rate: {safe_percent(executable, compilable - skip1)}%")
    logging.info(f"io correct rate: {safe_percent(io_correct, executable - skip2)}%")

    # extra datas
    avg_bleu = safe_mean(bleu_map.values())
    logging.info("extra datas:\n-----------------------------------------------")
    logging.info(f"average input length: {safe_mean(in_c_len)}")
    logging.info(f"average output length: {safe_mean(out_asm_len)}")
    logging.info(f"input length: {in_c_len}")
    logging.info(f"output length: {out_asm_len}")
    logging.info(f"bleu map: {bleu_map}")
    logging.info(f"avg bleu score on {len(bleu_map)} pairs: {avg_bleu}")

    # write to csv
    with open(csv_name, "w") as csv_file:
        csv_file.write("index,bleu\n")
        for k, v in bleu_map.items():
            csv_file.write(f"{k},       {v}\n")
        csv_file.write(f"total,         {len(bleu_map)}\n")
        csv_file.write(f"average,       {avg_bleu}\n")

    # log execution time in %H:%M:%S
    elapsed = end_time - start_time
    per_function = elapsed / index if index else 0
    logging.info(
        "execution statistics:\n-----------------------------------------------"
    )
    logging.info(f"total time: {time.strftime('%H:%M:%S', time.gmtime(elapsed))}")
    logging.info(
        f"average time per function: {time.strftime('%H:%M:%S', time.gmtime(per_function))}"
    )
    logging.info(
        f"final IO Accuracy: {safe_percent(io_correct, io_correct + io_error + compile_failed)}%"
    )