import os
import signal
import subprocess
import platform
import json
from typing import *
import numpy as np
import math
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
from codebleu import calc_codebleu

# Status codes returned by run_cmd.
SUCCESS = 0      # the command exited with return code 0
RUN_ERROR = 1    # the command exited with a non-zero return code
TIME_OUT = 2     # the command did not finish within the timeout
UNKNOWN_ERR = 3  # any other exception was raised while waiting for the command


def _kill_process_group(p):
    """Forcefully stop ``p`` (and, on POSIX, its whole process group) and reap it."""
    try:
        if hasattr(os, "killpg"):
            # The child was started with start_new_session=True, so it leads its
            # own process group and its pid is also the group id. Killing the
            # group also stops anything the shell spawned.
            os.killpg(p.pid, signal.SIGKILL)
        else:
            # No process groups available (e.g. Windows): kill the shell only.
            p.kill()
    except OSError:
        # The process/group is already gone or cannot be signalled; nothing
        # more can be done here.
        pass

    try:
        # Reap the child and drain/close its pipe. Bounded, so a descendant that
        # escaped the group and still holds the pipe cannot block us forever.
        p.communicate(timeout=1)
    except (subprocess.TimeoutExpired, OSError, ValueError):
        pass


def run_cmd(cmd_string, timeout=5):
    """Run a shell command and classify how it ended.

    The command is executed through the shell in a new session, with stderr
    merged into stdout. The captured output is discarded; only the outcome is
    reported.

    NOTE: ``cmd_string`` is interpreted by the shell (``shell=True``); do not
    pass untrusted input.

    Args:
        cmd_string: Command line handed to the shell.
        timeout: Maximum number of seconds to wait for the command.

    Returns:
        ``SUCCESS`` if the command exited with code 0, ``RUN_ERROR`` if it
        exited with any other code, ``TIME_OUT`` if it was still running after
        ``timeout`` seconds (it is then killed), or ``UNKNOWN_ERR`` if waiting
        for it raised any other exception.
    """
    p = subprocess.Popen(cmd_string, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, shell=True, close_fds=True,
                         start_new_session=True)

    try:
        p.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        _kill_process_group(p)
        code = TIME_OUT
    except Exception:
        # Best effort: report the failure as a status code instead of raising,
        # but do not leave the child running behind us.
        _kill_process_group(p)
        code = UNKNOWN_ERR
    else:
        # communicate() has waited for the child, so returncode is set.
        code = RUN_ERROR if p.returncode else SUCCESS

    return code


def rename_result(res_dir):
    """Collapse a doubled ``.json.json`` in the names of files in ``res_dir``.

    Every regular file directly inside ``res_dir`` whose name contains
    ``.json.json`` is renamed so that each occurrence becomes ``.json``.
    Sub-directories and files whose name would not change are left alone, and
    one line is printed per file actually renamed.

    NOTE: ``os.rename`` is used as-is; if the target name already exists the
    result is platform dependent (see the ``os.rename`` documentation).

    Args:
        res_dir: Directory whose files are inspected (not recursive).
    """
    for filename in os.listdir(res_dir):
        file_path = os.path.join(res_dir, filename)
        if not os.path.isfile(file_path):
            continue

        new_filename = filename.replace(".json.json", ".json")
        if new_filename == filename:
            # Nothing to fix; skipping keeps the log limited to real renames.
            continue

        new_file_path = os.path.join(res_dir, new_filename)
        os.rename(file_path, new_file_path)
        print(f"rename {file_path} to {new_file_path}")


def delete_empty_testcases(res_dir):
    """Drop problems that have no test cases from the data files in ``res_dir``.

    Every entry of ``res_dir`` whose name contains
    ``out_codeforces_problems_data`` is loaded as a JSON list of problems.
    Problems whose ``test_case`` list is empty are removed (their ``pid`` is
    printed, scanning from the last problem to the first) and the file is
    written back in place.

    Args:
        res_dir: Directory containing the JSON data files.
    """
    for entry_name in os.listdir(res_dir):
        if 'out_codeforces_problems_data' not in entry_name:
            continue

        with open(f"{res_dir}/{entry_name}", 'r', encoding='utf-8') as src:
            problems = json.load(src)

        # Walk backwards so the "deleted" messages appear last-to-first.
        survivors = []
        for problem in reversed(problems):
            pid = problem['pid']
            if len(problem['test_case']) == 0:
                print(f"pid: {pid} deleted")
            else:
                survivors.append(problem)
        # Restore the original relative order of the problems that are kept.
        survivors.reverse()

        with open(f"{res_dir}/{entry_name}", 'w', encoding='utf-8') as dst:
            json.dump(survivors, dst, ensure_ascii=False)


def pass_at_k(num_samples: int, num_correct: int, k: List[int]):
    """Estimate pass@k for each requested ``k``.

    For ``n = num_samples`` and ``c = num_correct`` each value is
    ``1 - comb(n - c, k) / comb(n, k)``, evaluated as a running product.

    Args:
        num_samples: Total number of generated samples (n).
        num_correct: Number of samples that passed (c).
        k: The k values to evaluate.

    Returns:
        A list with one estimate per entry of ``k``, in the same order.
    """
    num_failed = num_samples - num_correct
    estimates = []
    for draws in k:
        if num_failed < draws:
            # Fewer failing samples than draws: at least one draw must pass.
            estimates.append(1.0)
            continue
        denominators = np.arange(num_failed + 1, num_samples + 1)
        estimates.append(1.0 - np.prod(1.0 - draws / denominators))
    return estimates


def avg_pass_at_k(all_pass_at_k: List[Dict]) -> Dict:
    """Average every metric across a list of per-problem result dicts.

    The metric names are taken from the first dict; the ``"pid"`` key is
    skipped. Each remaining key ``name`` yields an ``avg_<name>`` entry holding
    the mean of that key over all dicts.

    Args:
        all_pass_at_k: One dict per problem. Every dict must contain the keys
            of the first one.

    Returns:
        A dict mapping ``avg_<name>`` to the mean value, or an empty dict when
        ``all_pass_at_k`` is empty.
    """
    if not all_pass_at_k:
        # No results to average (and no first entry to take the keys from).
        return {}

    metric_names = [key for key in all_pass_at_k[0] if key != "pid"]
    return {
        f"avg_{name}": np.mean([entry[name] for entry in all_pass_at_k])
        for name in metric_names
    }


def get_code_bleu(prediction: str, reference: str):
    """Return the CodeBLEU score of one prediction against one reference.

    Both strings are scored as C++ (``lang="cpp"``) with four equal weights of
    0.25 and no custom tokenizer.

    Args:
        prediction: Generated source code.
        reference: Ground-truth source code.

    Returns:
        The ``'codebleu'`` entry of the result returned by ``calc_codebleu``.
    """
    references = [reference]
    predictions = [prediction]
    equal_weights = (0.25,) * 4
    scores = calc_codebleu(references, predictions, lang="cpp",
                           weights=equal_weights, tokenizer=None)
    return scores['codebleu']


def avg_test_case_pass_rate(all_res: List[Dict]) -> float:
    """Return the mean of the ``"pass_rate"`` values found in ``all_res``.

    Args:
        all_res: Result dicts, each holding a ``"pass_rate"`` entry.

    Returns:
        The arithmetic mean computed by ``np.mean``.
    """
    rates = []
    for record in all_res:
        rates.append(record["pass_rate"])
    return np.mean(rates)


def incoder_generate_code(model: "AutoModelForCausalLM", tokenizer: "AutoTokenizer", prompt: str, max_to_generate=512, temperature=0.2):
    """Generate code for ``prompt`` with an InCoder-style infilling model.

    ``<insert>`` is appended to ``prompt``, the text is split on ``<insert>``
    and the gaps are filled by sampling from ``model``. The first infill is
    then trimmed so that it starts at the first ``#include`` (if present) and
    stops before the first ``<|/`` or ``</code>`` marker (if present).

    Args:
        model: Causal language model used for sampling. Input ids are moved to
            CUDA before ``model.generate`` is called.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: Text to complete.
        max_to_generate: Number of tokens requested on top of the prompt
            length.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        A tuple ``(result_text, overlen)``. ``overlen`` is True when, for at
        least one generation call, the prompt length plus ``max_to_generate``
        exceeded 2048 tokens and the total length was capped at 2048.
    """

    # signals the start of a document
    BOS = "<|endoftext|>"
    # signals the end of a generated infill
    EOM = "<|endofmask|>"

    overlen = False

    def generate(text: str):
        """
        Do standard left-to-right completion of the prefix `text` by sampling from the model
        """
        # `overlen` belongs to the enclosing function; `nonlocal` is required
        # so that the flag returned to the caller is the one updated here.
        nonlocal overlen
        input_ids = tokenizer(text, return_tensors="pt",
                              truncation=True, max_length=2048).input_ids
        # if CUDA:
        input_ids = input_ids.cuda()
        max_length = max_to_generate + input_ids.flatten().size(0)
        if max_length > 2048:
            # Cap the total length and remember that the budget was cut.
            max_length = 2048
            overlen = True
        with torch.no_grad():
            output = model.generate(input_ids=input_ids, do_sample=True,
                                    top_p=0.95, temperature=temperature, max_length=max_length)
        detok_hypo_str = tokenizer.decode(
            output.flatten(), clean_up_tokenization_spaces=False)
        if detok_hypo_str.startswith(BOS):
            detok_hypo_str = detok_hypo_str[len(BOS):]
        return detok_hypo_str

    def make_sentinel(i):
        # signals (1) a location to insert an infill and (2) the start of the infill generation
        return f"<|mask:{i}|>"

    def infill(parts: List[str], extra_sentinel: bool = False, max_retries: int = 1):
        """Fill the gaps between ``parts``; retry while an infill lacks EOM."""
        retries_attempted = 0
        done = False

        while (not done) and (retries_attempted < max_retries):
            retries_attempted += 1

            # (1) build the prompt
            if len(parts) == 1:
                prompt = parts[0]
            else:
                prompt = ""
                # encode parts separated by sentinel
                for sentinel_ix, part in enumerate(parts):
                    prompt += part
                    if extra_sentinel or (sentinel_ix < len(parts) - 1):
                        prompt += make_sentinel(sentinel_ix)

            infills = []
            complete = []

            done = True

            # (2) generate infills
            for sentinel_ix, part in enumerate(parts[:-1]):
                complete.append(part)
                prompt += make_sentinel(sentinel_ix)
                completion = generate(prompt)
                # keep only the newly generated text after the prompt
                completion = completion[len(prompt):]
                if EOM not in completion:
                    # the model did not close the infill: close it ourselves
                    # and mark this attempt as incomplete
                    completion += EOM
                    done = False
                completion = completion[:completion.index(EOM) + len(EOM)]
                infilled = completion[:-len(EOM)]
                infills.append(infilled)
                complete.append(infilled)
                prompt += completion
            complete.append(parts[-1])
            text = ''.join(complete)

        return {
            # str, the completed document (with infills inserted)
            'text': text,
            # List[str], length N-1. The list of infills generated
            'infills': infills
        }

    def process_result_text(result_text: str):
        """Cut ``result_text`` to the span from ``#include`` to the first stop word."""
        begin_index = result_text.find("#include")
        if begin_index == -1:
            begin_index = 0

        stop_words = ["<|/", "</code>"]
        stop_index = []
        for word in stop_words:
            stop_index.append(result_text.find(word))

        find_index = [i for i in stop_index if i > -1]
        if len(find_index) == 0:
            return result_text[begin_index:]
        else:
            end_index = min(find_index)
            return result_text[begin_index:end_index]

    # The trailing "<insert>" guarantees at least one gap to fill.
    example = prompt + "<insert>"
    parts = example.split("<insert>")
    result = infill(parts)
    result_text = process_result_text(result['infills'][0])

    return result_text, overlen


def codegen_generate_code(model, tokenizer, prompts: List[str], max_to_generate=512, top_p=0.95, temperature=0.2):
    """Sample one completion per prompt with a CodeGen model.

    The prompts are tokenized as one padded batch (truncated to 2048 tokens),
    moved to CUDA and passed to ``model.generate``. Each decoded text is then
    cut so that it starts at its first ``#include``; a text without
    ``#include`` is returned unchanged.

    NOTE(review): the padding side required for batched generation depends on
    the tokenizer configuration, which is not visible here - confirm.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``; its ``eos_token_id`` is used
            as the pad token id during generation.
        prompts: Batch of prompt strings.
        max_to_generate: Number of tokens requested on top of the padded
            prompt length; the total is capped at 2048.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        A list with one processed text per prompt.
    """

    def process_result_text(batch_result_text: List[str]):
        """Drop everything before the first ``#include`` of each text."""
        batch_without_prompt = []
        for result_text in batch_result_text:
            begin_idx = result_text.find("#include")
            if begin_idx == -1:
                # No "#include": keep the whole text. Slicing from -1 would
                # keep only the last character.
                begin_idx = 0
            batch_without_prompt.append(result_text[begin_idx:])
        return batch_without_prompt

    inputs = tokenizer(prompts, return_tensors="pt", padding=True,
                       truncation=True, max_length=2048)
    inputs["input_ids"] = inputs["input_ids"].cuda()
    inputs["attention_mask"] = inputs["attention_mask"].cuda()
    # Total length budget = padded prompt length + requested new tokens.
    max_length = max_to_generate + inputs["input_ids"].shape[1]
    if max_length > 2048:
        max_length = 2048

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_length=max_length, temperature=temperature,
                                       top_p=top_p, do_sample=True, pad_token_id=tokenizer.eos_token_id)
    batch_result_text = tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True)
    batch_res_final = process_result_text(batch_result_text)

    return batch_res_final


def starcoderbase_generate_code(model, tokenizer, prompts: List[str], max_to_generate=512, top_p=0.95, temperature=0.2):
    """Sample one completion per prompt with a StarCoderBase model.

    The prompts are tokenized as one padded batch (truncated to 2048 tokens),
    moved to CUDA and passed to ``model.generate``. Each decoded text is then
    cut so that it starts at its first ``#include``; a text without
    ``#include`` is returned unchanged.

    NOTE(review): the padding side required for batched generation depends on
    the tokenizer configuration, which is not visible here - confirm.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``; its ``eos_token_id`` is used
            as the pad token id during generation.
        prompts: Batch of prompt strings.
        max_to_generate: Number of tokens requested on top of the padded
            prompt length; the total is capped at 2048.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        A list with one processed text per prompt.
    """
    def process_result_text(batch_result_text: List[str]):
        """Drop everything before the first ``#include`` of each text."""
        batch_without_prompt = []
        for result_text in batch_result_text:
            begin_idx = result_text.find("#include")
            if begin_idx == -1:
                # No "#include": keep the whole text. Slicing from -1 would
                # keep only the last character.
                begin_idx = 0
            batch_without_prompt.append(result_text[begin_idx:])
        return batch_without_prompt

    inputs = tokenizer(prompts, return_tensors="pt", padding=True,
                       truncation=True, max_length=2048)
    inputs["input_ids"] = inputs["input_ids"].cuda()
    inputs["attention_mask"] = inputs["attention_mask"].cuda()
    # Total length budget = padded prompt length + requested new tokens.
    max_length = max_to_generate + inputs["input_ids"].shape[1]
    if max_length > 2048:
        max_length = 2048

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_length=max_length, temperature=temperature,
                                       top_p=top_p, do_sample=True, pad_token_id=tokenizer.eos_token_id)
    batch_result_text = tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True)
    batch_res_final = process_result_text(batch_result_text)

    return batch_res_final


def generate_one_batch(model, tokenizer, prompts: List[str], max_to_generate=512, top_p=0.95, temperature=0.2):
    """Generate completions for one batch of prompts with the active backend.

    All arguments are forwarded unchanged, by keyword, to
    ``starcoderbase_generate_code`` and its result is returned.

    To switch backend, call ``codegen_generate_code`` with the same arguments
    instead. ``incoder_generate_code`` is not a drop-in replacement: it takes a
    single ``prompt`` string and has no ``top_p`` parameter.
    """
    generation_args = dict(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        max_to_generate=max_to_generate,
        top_p=top_p,
        temperature=temperature,
    )

    ### starcoderbase ###
    return starcoderbase_generate_code(**generation_args)
