from tqdm import tqdm

from transformers import AutoTokenizer

from dualdec import dualdec
from dualdec.models import LlamaForCausalLM
from dualdec.cache_engine import CacheEngine

import time, torch, re

import argparse, random

from datasets import load_dataset, load_from_disk



def filter_code(completion: str) -> str:
    """Return ``completion`` truncated just before its first ``[DONE]`` marker.

    If the marker never appears, the completion is returned unchanged.
    """
    # partition() yields (head, sep, tail); head is the whole string when sep is absent.
    head, _marker, _rest = completion.partition("[DONE]")
    return head


def count_indent(text: str) -> int:
    """Return the number of leading space characters in ``text``.

    Only ``" "`` counts; a tab or any other character ends the indentation.
    """
    return len(text) - len(text.lstrip(" "))


def fix_indents(text: str, multiple: int = 2):
    """Pad every line's leading spaces up to the next multiple of ``multiple``.

    Lines are split on newline characters and re-joined the same way. Only spaces count
    as indentation, so lines starting with a tab or a non-space character are unchanged.
    """
    def _align(line):
        # Width of the run of leading spaces on this line.
        width = len(line) - len(line.lstrip(" "))
        pad = 0
        while (width + pad) % multiple != 0:
            pad += 1
        return " " * pad + line

    return "\n".join(map(_align, text.split("\n")))


def test_fix_indents():
    """Print ``fix_indents`` applied to a sample whose first line has a 3-space indent."""
    sample = "\n".join(["   # TODO: Implement separate_paren_groups", "return []"])
    print(fix_indents(sample))

def gen_context_prompt(entry):
    """Render one solved MBPP example (task text, tests, reference code) as a few-shot demo.

    ``entry`` must provide the keys ``'text'``, ``'test_list'`` and ``'code'``. The code is
    wrapped in ``[BEGIN]`` / ``[DONE]`` markers.
    """
    task, tests, solution = entry['text'], entry['test_list'], entry['code']
    return (
        "You are an expert Python programmer, and here is your task: "
        f"{task} Your code should pass these tests:\n\n"
        f"{tests}\n[BEGIN]\n{solution}\n[DONE]\n"
    )

def gen_context(dataset, example_num):
    """Concatenate the first ``example_num`` entries of ``dataset`` as few-shot demos.

    Each entry is rendered with ``gen_context_prompt``; the results are joined with no
    separator (each rendered demo already ends with a newline).
    """
    chosen = list(dataset)[:example_num]
    return "".join(gen_context_prompt(item) for item in chosen)
        

def format_test_example(q, tests, code: str = None):
    """Format a problem, its test cases and optionally its code, as used by ``deepseek_gen_context``.

    Args:
        q: Problem statement; surrounding whitespace is stripped.
        tests: Iterable of test-case strings, one per output line.
        code: Optional reference solution. When truthy, carriage returns are removed,
            tabs become four spaces, and it is appended inside a fenced python block.

    Returns:
        The formatted prompt text.
    """
    header = ">>> Problem:\n{}\n>>> Test Cases:\n{}\n".format(q.strip(), "\n".join(tests))
    if not code:
        return header
    cleaned = code.replace("\r", "").replace("\t", "    ")
    return header + "\n>>> Code:\n```python\n{}\n```".format(cleaned)

def deepseek_gen_context(dataset, example_num):
    """Build a list of numbered few-shot examples from the first ``example_num`` rows.

    Each row must provide ``'text'``, ``'test_list'`` and ``'code'``; it is rendered with
    ``format_test_example`` and prefixed with ``- Example <i>:``. Raises ``IndexError``
    if ``dataset`` has fewer than ``example_num`` rows.
    """
    rows = list(dataset)
    numbered = []
    for idx in range(example_num):
        row = rows[idx]
        body = format_test_example(row['text'], row['test_list'], row['code'])
        numbered.append('- Example {}:\n{}'.format(idx, body))
    return numbered

def deepseek_convert_for_evaluation(example):
    """Extract the body of the first ```python fenced block from a model completion.

    The fence tag is matched case-insensitively and the body may span several lines.
    If no such block exists, a diagnostic is printed and the completion is returned
    unchanged so the caller can still evaluate it.

    Args:
        example: The raw completion text.

    Returns:
        The code inside the first fenced python block, or ``example`` itself.
    """
    # Only "no fenced block" is an expected failure; check for it explicitly instead of
    # catching every exception around findall(...)[0].
    match = re.search(r"```python\n(.*?)```", example, re.DOTALL | re.IGNORECASE)
    if match is None:
        print("Failed to extract codeblock:\n{}".format(example))
        return example
    return match.group(1)


# Matches a "#### <number>" answer marker (presumably the GSM8K reference-solution
# format); not referenced in the visible part of this file.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Placeholder for an unparseable answer; note clean_answer() returns the same literal
# "[invalid]" directly rather than referencing this constant.
INVALID_ANS = "[invalid]"

# Few-shot settings matching the create_demo_text()/build_prompt() defaults; these
# module constants themselves are not referenced in the visible part of this file.
N_SHOT = 8
COT_FLAG = True
# Not referenced in the visible part of this file.
DEBUG = False
# Phrase placed before the final answer in every demo, and the delimiter that
# clean_answer() splits predictions on.
ANSWER_TRIGGER = "The answer is"

def create_demo_text(n_shot=8, cot_flag=True):
    """Build the few-shot arithmetic demonstration block prepended to each GSM8K question.

    Args:
        n_shot: Slice bound applied to the eight built-in exemplars (``exemplars[:n_shot]``).
        cot_flag: When true, each demo is ``Q:/A:`` with the worked reasoning before
            ``ANSWER_TRIGGER``; otherwise ``Question:/Answer:`` with only the final answer.

    Returns:
        The concatenated demos, each terminated by a blank line.
    """
    # (question, worked reasoning, final answer), in presentation order.
    exemplars = [
        (
            ("There are 15 trees in the grove. "
             "Grove workers will plant trees in the grove today. "
             "After they are done, there will be 21 trees. "
             "How many trees did the grove workers plant today?"),
            ("There are 15 trees originally. "
             "Then there were 21 trees after some more were planted. "
             "So there must have been 21 - 15 = 6."),
            "6",
        ),
        (
            ("If there are 3 cars in the parking lot and 2 more cars arrive, "
             "how many cars are in the parking lot?"),
            "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
            "5",
        ),
        (
            ("Leah had 32 chocolates and her sister had 42. If they ate 35, "
             "how many pieces do they have left in total?"),
            ("Originally, Leah had 32 chocolates. "
             "Her sister had 42. So in total they had 32 + 42 = 74. "
             "After eating 35, they had 74 - 35 = 39."),
            "39",
        ),
        (
            ("Jason had 20 lollipops. He gave Denny some lollipops. Now Jason "
             "has 12 lollipops. How many lollipops did Jason give to Denny?"),
            ("Jason started with 20 lollipops. Then he had 12 after giving some "
             "to Denny. So he gave Denny 20 - 12 = 8."),
            "8",
        ),
        (
            ("Shawn has five toys. For Christmas, he got two toys each from his "
             "mom and dad. How many toys does he have now?"),
            ("Shawn started with 5 toys. If he got 2 toys each from his mom and "
             "dad, then that is 4 more toys. 5 + 4 = 9."),
            "9",
        ),
        (
            ("There were nine computers in the server room. Five more computers "
             "were installed each day, from monday to thursday. "
             "How many computers are now in the server room?"),
            ("There were originally 9 computers. For each of 4 days, 5 more "
             "computers were added. So 5 * 4 = 20 computers were added. "
             "9 + 20 is 29."),
            "29",
        ),
        (
            ("Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On "
             "wednesday, he lost 2 more. "
             "How many golf balls did he have at the end of wednesday?"),
            ("Michael started with 58 golf balls. After losing 23 on tuesday, "
             "he had 58 - 23 = 35. After losing 2 more, "
             "he had 35 - 2 = 33 golf balls."),
            "33",
        ),
        (
            ("Olivia has $23. She bought five bagels for $3 each. "
             "How much money does she have left?"),
            ("Olivia had 23 dollars. "
             "5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. "
             "So she has 23 - 15 dollars left. 23 - 15 is 8."),
            "8",
        ),
    ]

    demos = []
    for question, chain, answer in exemplars[:n_shot]:
        if cot_flag:
            demos.append(f"Q: {question}\nA: {chain} {ANSWER_TRIGGER} {answer}.\n\n")
        else:
            demos.append(f"Question: {question}\nAnswer: {ANSWER_TRIGGER} {answer}.\n\n")
    return "".join(demos)

def clean_answer(model_pred):
    """Extract the numeric answer string from a model's completion.

    The prediction is lower-cased and split on ``ANSWER_TRIGGER`` (also lower-cased).
    If the trigger occurs, the first number after its first occurrence is chosen;
    otherwise the last number anywhere in the text is chosen. Commas are removed
    before matching and a trailing period is stripped from the result.

    Returns:
        The chosen number as a string, or ``"[invalid]"`` if no number is found.
    """
    lowered = model_pred.lower()
    segments = lowered.split(ANSWER_TRIGGER.lower())
    triggered = len(segments) > 1
    # segments[1] is the text between the first and second trigger; without a trigger
    # there is exactly one segment holding the whole prediction.
    region = segments[1] if triggered else segments[-1]
    numbers = re.findall(r"-?\d+\.?\d*", region.replace(",", ""))
    if not numbers:
        return "[invalid]"
    picked = numbers[0] if triggered else numbers[-1]
    # The pattern can absorb a sentence-ending period (e.g. "8."); drop it.
    if picked.endswith("."):
        picked = picked[:-1]
    return picked


def build_prompt(input_text, n_shot=8, cot_flag=True):
    """Prefix ``input_text`` with the few-shot demos and leave an open ``A:`` for the model.

    ``n_shot`` and ``cot_flag`` are forwarded to ``create_demo_text``.
    """
    parts = (create_demo_text(n_shot, cot_flag), "Q: ", input_text, "\n", "A:")
    return "".join(parts)


# Zero-shot summarization prompt; evaluate() fills it with str.format(**entry), so each
# cnndm entry must provide an "article" field.
cnndm_prompt = "Please summarize the following document within one sentence. Document: {article}  Summary: "

# Zero-shot German-to-English prompt; evaluate() fills it from an entry's "translation"
# mapping, which must provide a "de" field.
wmt_prompt = "Translate this German sentence to English. German: {de} English: "

def _mbpp_task_prompt(entry):
    """Render an unsolved MBPP task, ending at ``[BEGIN]`` so the model writes the code."""
    return f"You are an expert Python programmer, and here is your task: {entry['text']} Your code should pass these tests:\n\n{entry['test_list']}\n[BEGIN]\n"


def _dump_token_map(model, tag):
    """Print a ``tag`` banner, then one line per vocabulary id with its decoded text and,
    when the id is a key of ``model.ngram_cache.token_map``, the decoded last entry stored for it."""
    print(tag * 50)
    tokenizer = model.tokenizer
    token_map = model.ngram_cache.token_map
    for i in range(tokenizer.vocab_size):
        out = tokenizer.decode(i) + ': '
        if i in token_map:
            out += tokenizer.decode(token_map[i][-1])
        print(out)


def evaluate(model, mbpp, gsm, cnndm, wmt, round = 20, each_num = 1, **kwargs):
    """Benchmark ``model`` on a mix of MBPP, GSM8K, CNN/DM and WMT prompts.

    Args:
        model: Object with ``run(prompt)`` and ``latency()``. When the token-map dump is
            enabled it also needs ``tokenizer`` and ``ngram_cache.token_map`` (as ``SyldModel`` has).
        mbpp: Mapping with a ``'prompt'`` split (few-shot demos) and a ``'validation'`` split (tasks).
        gsm, cnndm, wmt: Iterables of entries for each task.
        round: Number of rounds.
        each_num: Samples per task per round, run task by task. A negative value selects
            shuffled mode instead: ``round`` samples of each task in a random interleaved order.
        **kwargs: ``dump_cache`` (bool, default True) -- in sequential mode, print the
            token map for the whole vocabulary after each task block. Unused in shuffled mode.

    Returns:
        Whatever ``model.latency()`` reports after all runs.
    """
    dump_cache = kwargs.get("dump_cache", True)
    # Shuffled mode runs exactly `round` samples per task; the old formula gave a
    # negative total there because each_num < 0.
    total = round * 4 if each_num < 0 else round * each_num * 4
    progress_bar = tqdm(total=total, desc="Generating samples")

    # make mbpp context
    example_num = 3
    context = gen_context(mbpp['prompt'], example_num)
    mbpp = list(mbpp['validation'])

    gsm = list(gsm)
    cnndm = list(cnndm)
    wmt = list(wmt)

    # (banner tag, entries, prompt builder), in the order tasks run within a round.
    tasks = [
        ("mbpp", mbpp, lambda entry: context + _mbpp_task_prompt(entry)),
        ("gsm", gsm, lambda entry: build_prompt(entry['question'])),
        ("cnndm", cnndm, lambda entry: cnndm_prompt.format(**entry)),
        ("wmt", wmt, lambda entry: wmt_prompt.format(**entry['translation'])),
    ]

    try:
        if each_num < 0:
            seq = [0] * round + [1] * round + [2] * round + [3] * round
            random.shuffle(seq)
            # Next unused entry index per task.
            cursors = [0, 0, 0, 0]
            for task_idx in seq:
                _, entries, make_prompt = tasks[task_idx]
                model.run(make_prompt(entries[cursors[task_idx]]))
                cursors[task_idx] += 1
                progress_bar.update(1)
        else:
            for r in range(round):
                for tag, entries, make_prompt in tasks:
                    for i in range(each_num):
                        model.run(make_prompt(entries[r * each_num + i]))
                        progress_bar.update(1)
                    if dump_cache:
                        _dump_token_map(model, tag)
    finally:
        progress_bar.close()

    return model.latency()

class SyldModel:
    """Wrap a draft/target model pair and a decoding function, accumulating timing stats.

    ``test_func`` is called as ``test_func(input_ids, draft_model, target_model,
    ngram_cache, max_len, gamma, window_size, guess_set_size, lookahead_level,
    eos_token_id)``. It is assumed to return a token-id tensor whose last dimension holds
    the prompt followed by the generated tokens (inferred from the slicing in ``run``;
    confirm against the decoding function).
    """

    def __init__(self, draft_model, target_model, tokenizer, test_func, max_len, gamma, window_size, guess_set_size, lookahead_level, eos_token_id):
        self.draft_model = draft_model
        self.target_model = target_model
        self.tokenizer = tokenizer
        # Running totals over every run() call: seconds spent in test_func and tokens produced.
        self.tot_time = 0
        self.tot_tokens = 0
        self.test_func = test_func
        self.gamma = gamma
        self.window_size = window_size
        self.guess_set_size = guess_set_size
        self.lookahead_level = lookahead_level
        self.max_len = max_len
        self.eos_token_id = eos_token_id
        # Created once and reused (not reset) across run() calls.
        self.ngram_cache = CacheEngine(lookahead_level, guess_set_size)

    def run(self, prompt, temperature = 1.0):
        """Generate a continuation of ``prompt`` and add its cost to the running totals.

        Args:
            prompt: Input text; it is tokenized and moved to the ``"cuda"`` device.
            temperature: Accepted for API compatibility; not forwarded to ``test_func``.

        Returns:
            The decoded generated text (prompt excluded, special tokens skipped).
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to("cuda").view(1, -1)
        prompt_len = input_ids.shape[-1]
        # CUDA work is asynchronous: synchronize around the call so the measured interval
        # covers all GPU work. perf_counter is monotonic, unlike time.time().
        torch.cuda.synchronize()
        beg_time = time.perf_counter()
        output = self.test_func(input_ids, self.draft_model, self.target_model, self.ngram_cache, self.max_len, self.gamma, self.window_size, self.guess_set_size, self.lookahead_level, self.eos_token_id)
        torch.cuda.synchronize()
        end_time = time.perf_counter()
        self.tot_time += end_time - beg_time
        self.tot_tokens += output.shape[-1] - prompt_len
        generated = output[:, prompt_len:]
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

    def latency(self):
        """Return decoding throughput in tokens per second (despite the name).

        Returns 0.0 when no time has been recorded yet instead of dividing by zero.
        """
        if self.tot_time <= 0:
            return 0.0
        return self.tot_tokens / self.tot_time


def parse():
    """Parse command-line arguments for the DualDec mixed-task benchmark.

    The model and dataset paths are required: ``main`` loads every one of them
    unconditionally, so a missing path is reported by argparse up front instead of
    failing later inside model/dataset loading.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--target_model', type=str, required=True, help='Model name or path of target model in both greedy mode or dualdec mode.')
    parser.add_argument('--draft_model', type=str, required=True, help='Model name or path of draft model only in dualdec mode.')
    parser.add_argument('--mbpp', type=str, required=True, help="Data path of the dataset")
    parser.add_argument('--gsm8k', type=str, required=True, help="Data path of the dataset")
    parser.add_argument('--cnndm', type=str, required=True, help="Data path of the dataset")
    parser.add_argument('--wmt16', type=str, required=True, help="Data path of the dataset")
    parser.add_argument('--generate_len', type=int, help='Generate length during testing', default=512)
    parser.add_argument('--gamma', type=int, default=6)
    parser.add_argument('--window_size', type=int, default=18)
    parser.add_argument('--guess_set_size', type=int, default=18)
    parser.add_argument('--lookahead_level', type=int, default=5)
    parser.add_argument('--round', type=int, default=20)
    # A negative value selects the shuffled mode of evaluate().
    parser.add_argument('--each_num', type=int, default=1)
    return parser.parse_args()


def main():
    """Load models and datasets, warm up, run the mixed-task benchmark and print throughput."""
    args = parse()
    small_model = LlamaForCausalLM.from_pretrained(args.draft_model, torch_dtype=torch.float16, device_map='auto')
    target_model = LlamaForCausalLM.from_pretrained(args.target_model, torch_dtype=torch.float16, device_map='auto')
    torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(args.target_model)
    model = SyldModel(draft_model=small_model, target_model=target_model, tokenizer=tokenizer, test_func=dualdec, max_len=args.generate_len, gamma=args.gamma, window_size=args.window_size, guess_set_size=args.guess_set_size, lookahead_level=args.lookahead_level, eos_token_id=tokenizer.eos_token_id)

    gsm8k = load_from_disk(args.gsm8k)
    cnndm = load_from_disk(args.cnndm)
    wmt16 = load_from_disk(args.wmt16)
    mbpp = load_from_disk(args.mbpp)

    print("warm up...")
    for _ in range(5):
        model.run("warm up")
    # run() accumulates time and token counts; clear what the warm-up added so the
    # reported speed reflects only the benchmark prompts.
    model.tot_time = 0
    model.tot_tokens = 0
    print("start")

    # evaluate() returns model.latency(), i.e. generated tokens per second.
    throughput = evaluate(model, mbpp, gsm8k, cnndm, wmt16, args.round, args.each_num)
    print(f"total speed: {throughput:.2f} tok / s")


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
