import re

import torch
import torch.nn.functional as F
from typing import *
import os
import signal
import subprocess
import platform
import json
import numpy as np
import math
from transformers import AutoTokenizer, AutoModelForCausalLM
from codebleu import calc_codebleu

# Status codes returned by run_cmd.
SUCCESS = 0  # the command exited with return code 0
RUN_ERROR = 1  # the command exited with a non-zero return code
TIME_OUT = 2  # the command was still running when the timeout expired
UNKNOWN_ERR = 3  # any other exception was raised while waiting for the command


def _kill_and_reap(p):
    """Forcefully stop the process started by run_cmd and wait for it to exit."""
    try:
        if hasattr(os, "killpg"):
            # run_cmd starts the shell with start_new_session=True, so the
            # shell's pid is also its process-group id; signalling the group
            # also reaches the commands the shell spawned.
            os.killpg(p.pid, signal.SIGKILL)
        else:
            # No process groups on this platform: fall back to the shell only.
            p.kill()
    except OSError:
        # Best effort: the process (group) is already gone or not signalable.
        pass

    if hasattr(os, "killpg"):
        p.wait()
    else:
        # Same cleanup subprocess.run performs on Windows after a timeout.
        p.communicate()


def run_cmd(cmd_string, timeout=5):
    """Run a shell command and report how it finished.

    The command is executed through the shell in a new session. Its combined
    stdout/stderr is captured and discarded; only the status is returned.
    Because the string is handed to the shell verbatim, callers must not pass
    untrusted input.

    Args:
        cmd_string: Command line to execute via the shell.
        timeout: Maximum number of seconds to wait for the command.

    Returns:
        SUCCESS if the command exited with code 0, RUN_ERROR if it exited with
        a non-zero code, TIME_OUT if it did not finish within ``timeout``
        seconds (the command is killed), or UNKNOWN_ERR if any other exception
        occurred while waiting for it (the command is killed as well).
    """
    p = subprocess.Popen(cmd_string, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, shell=True, close_fds=True,
                         start_new_session=True)

    # The context manager closes the output pipe and waits for the shell when
    # the block is left, so no file descriptor or zombie process is leaked.
    with p:
        try:
            p.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            _kill_and_reap(p)
            return TIME_OUT
        except Exception:
            _kill_and_reap(p)
            return UNKNOWN_ERR
        return RUN_ERROR if p.returncode else SUCCESS


def rename_result(res_dir):
    """Collapse the doubled ``.json.json`` in result file names to ``.json``.

    Only regular files directly inside ``res_dir`` are considered. Files whose
    name does not contain ``.json.json`` are left alone and are not logged, so
    every printed line corresponds to an actual rename.

    Args:
        res_dir: Directory whose files should be renamed.
    """
    for filename in os.listdir(res_dir):
        file_path = os.path.join(res_dir, filename)
        if not os.path.isfile(file_path):
            continue

        new_filename = filename.replace(".json.json", ".json")
        if new_filename == filename:
            # Nothing to fix; skip the no-op rename and its misleading log line.
            continue

        new_file_path = os.path.join(res_dir, new_filename)
        os.rename(file_path, new_file_path)

        print(f"rename {file_path} to {new_file_path}")


def delete_empty_testcases(res_dir):
    """Drop problems that have no test cases from the data files in ``res_dir``.

    Every entry of ``res_dir`` whose name contains
    ``out_codeforces_problems_data`` is loaded as JSON, each element whose
    ``test_case`` list is empty is removed (its ``pid`` is printed), and the
    remaining elements are written back to the same file.

    Args:
        res_dir: Directory holding the JSON data files.
    """
    targets = []
    for entry in os.listdir(res_dir):
        if 'out_codeforces_problems_data' in entry:
            targets.append(entry)

    for name in targets:
        path = f"{res_dir}/{name}"
        with open(path, 'r', encoding='utf-8') as reader:
            problems = json.load(reader)

        # Walk from the end so deleting an element never shifts the indices
        # that are still to be visited.
        for idx in reversed(range(len(problems))):
            entry = problems[idx]
            pid = entry['pid']
            cases = entry['test_case']
            if len(cases) == 0:
                print(f"pid: {pid} deleted")
                del problems[idx]

        with open(path, 'w', encoding='utf-8') as writer:
            json.dump(problems, writer, ensure_ascii=False)


def pass_at_k(num_samples: int, num_correct: int, k: List[int]):
    """Estimate pass@k for each requested k.

    For every value in ``k`` this evaluates
    ``1 - comb(n - c, k) / comb(n, k)`` with ``n = num_samples`` and
    ``c = num_correct``, computed as a running product.

    Args:
        num_samples: Total number of generated samples (n).
        num_correct: Number of samples that passed (c).
        k: The k values to evaluate.

    Returns:
        A list with one estimate per entry of ``k``, in the same order.
    """
    num_failed = num_samples - num_correct
    estimates = []
    for budget in k:
        if num_failed < budget:
            # Fewer failing samples than draws: at least one draw must pass.
            estimates.append(1.0)
            continue
        denominators = np.arange(num_failed + 1, num_samples + 1)
        estimates.append(1.0 - np.prod(1.0 - budget / denominators))
    return estimates


def avg_pass_at_k(all_pass_at_k: List[Dict]) -> Dict:
    """Average every metric of the per-problem result dicts.

    The metric names are taken from the first dict; the ``"pid"`` key is
    skipped. Each metric ``name`` is reported under ``avg_<name>``.

    Args:
        all_pass_at_k: One dict per problem, all holding the same metric keys.

    Returns:
        Mapping from ``avg_<name>`` to the mean of that metric over all dicts.
    """
    metric_names = [name for name in all_pass_at_k[0] if name != "pid"]

    collected = {}
    for name in metric_names:
        collected[f"avg_{name}"] = [entry[name] for entry in all_pass_at_k]

    averages = {}
    for label, values in collected.items():
        averages[label] = np.mean(values)
    return averages


def avg_test_case_pass_rate(all_res: List[Dict]) -> float:
    """Return the mean of the ``"pass_rate"`` entries of the given result dicts."""
    rates = []
    for record in all_res:
        rates.append(record["pass_rate"])
    return np.mean(rates)


def get_bleu(prediction: str, reference: str, tokenizer=None, bleu=None):
    """Score one prediction against one reference with the given BLEU metric.

    Args:
        prediction: Generated text.
        reference: Ground-truth text.
        tokenizer: Optional tokenizer forwarded to ``bleu.compute``; omitted
            from the call when ``None``.
        bleu: Metric object exposing ``compute(predictions=..., references=...)``
            and returning a mapping with a ``'bleu'`` entry.

    Returns:
        The ``'bleu'`` value of the metric result, or ``0.0`` when the metric
        raises ``ZeroDivisionError`` (the error is printed).
    """
    call_kwargs = {"predictions": [prediction], "references": [reference]}
    if tokenizer is not None:
        call_kwargs["tokenizer"] = tokenizer

    try:
        scores = bleu.compute(**call_kwargs)
    except ZeroDivisionError as err:
        print(err)
        return 0.0
    return scores['bleu']


def get_rouge(prediction: str, reference: str, tokenizer=None, rouge=None):
    """Score one prediction against one reference with the given ROUGE metric.

    Args:
        prediction: Generated text.
        reference: Ground-truth text.
        tokenizer: Optional tokenizer forwarded to ``rouge.compute``; omitted
            from the call when ``None``.
        rouge: Metric object exposing ``compute(predictions=..., references=...)``
            and returning a mapping with ``'rouge1'``, ``'rouge2'`` and
            ``'rougeL'`` entries.

    Returns:
        The tuple ``(rouge1, rouge2, rougeL)`` taken from the metric result.
    """
    call_kwargs = {"predictions": [prediction], "references": [reference]}
    if tokenizer is not None:
        call_kwargs["tokenizer"] = tokenizer

    scores = rouge.compute(**call_kwargs)
    return tuple(scores[name] for name in ('rouge1', 'rouge2', 'rougeL'))


def get_code_bleu(prediction: str, reference: str):
    """Return the ``'codebleu'`` score of one prediction against one reference.

    ``calc_codebleu`` is called with ``lang="cpp"``, four equal weights of
    0.25 and no custom tokenizer.
    """
    equal_weights = (0.25, 0.25, 0.25, 0.25)
    scores = calc_codebleu([reference], [prediction], lang="cpp",
                           weights=equal_weights, tokenizer=None)
    return scores['codebleu']


def get_similarity(prediction: str, reference: str, tokenizer, model):
    """Return the cosine similarity of the mean-pooled embeddings of two texts.

    Both texts are tokenized together (padding and truncation enabled), run
    through ``model`` without gradients, mean-pooled over the tokens selected
    by the attention mask, L2-normalized and compared.

    All tokenizer outputs are moved to the device the model's parameters live
    on, so the function works for a model on any device instead of requiring
    the default CUDA device.

    Args:
        prediction: Generated text.
        reference: Ground-truth text.
        tokenizer: Callable returning a mapping of tensors that includes
            ``'attention_mask'`` when called with a list of sentences.
        model: Callable accepting the tokenizer output as keyword arguments;
            the first element of its output must hold the token embeddings.

    Returns:
        The cosine similarity as a Python float.
    """
    def mean_pooling(model_output, attention_mask):
        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Clamp the denominator so a fully masked row cannot divide by zero.
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        # Not an nn.Module (or no parameters): keep the previous preference
        # for CUDA, but do not fail on CPU-only hosts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenize sentences
    encoded_input = tokenizer([reference, prediction], padding=True,
                              truncation=True, return_tensors='pt')
    # Move every tensor (not just input_ids / attention_mask) so optional
    # outputs such as token_type_ids end up on the same device as the model.
    encoded_input = {name: tensor.to(device) for name, tensor in encoded_input.items()}

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    cos_sim = F.cosine_similarity(sentence_embeddings[0], sentence_embeddings[1], dim=0)
    return cos_sim.item()


def get_SWS_and_F1(prediction: str, reference: str, threshold: float, sim_tokenizer, sim_model):
    """Compute the SWS and F1 scores of predicted steps against reference steps.

    Both texts are stripped of leading/trailing newlines and split into steps
    on ``'\\n'``. Reference steps are visited in order; for each one the
    predicted steps are scanned forward from where the previous scan stopped
    until one reaches ``threshold`` similarity. Every predicted step is
    therefore compared at most once.

    Args:
        prediction: Generated text, one step per line.
        reference: Ground-truth text, one step per line.
        threshold: Minimum similarity for a predicted step to count as a match.
        sim_tokenizer: Tokenizer forwarded to ``get_similarity``.
        sim_model: Model forwarded to ``get_similarity``.

    Returns:
        ``(sws, f1)`` where ``sws`` is the sum of the similarities of matched
        steps divided by the number of reference steps, and ``f1`` combines
        precision (matches / predicted steps) and recall (matches / reference
        steps); ``f1`` is ``0.0`` when nothing matched.
    """
    predicted = prediction.strip('\n').split('\n')
    expected = reference.strip('\n').split('\n')
    num_expected, num_predicted = len(expected), len(predicted)

    cursor = 0
    matched = 0
    matched_sim_total = 0.0
    for expected_step in expected:
        while cursor < num_predicted:
            candidate = predicted[cursor]
            score = get_similarity(prediction=candidate, reference=expected_step,
                                   tokenizer=sim_tokenizer, model=sim_model)
            # The candidate is consumed whether or not it matched.
            cursor += 1
            if score >= threshold:
                matched_sim_total += score
                matched += 1
                break

    sws = matched_sim_total / num_expected
    precision = matched / num_predicted
    recall = matched / num_expected
    if matched == 0:
        f1 = 0.0
    else:
        f1 = 2 * (precision * recall) / (precision + recall)
    return sws, f1


def process_result_text(result_text: str, stop_words=('<|endoftext|>',)):
    """Cut generated text at the earliest occurrence of any stop word.

    This is the manual counterpart of stop-word handling during generation.

    Args:
        result_text: The generated part only (excluding the original prompt).
        stop_words: Iterable of marker strings; the default is an immutable
            tuple so it cannot be modified across calls.

    Returns:
        ``result_text`` up to (not including) the first stop word found, or
        ``result_text`` unchanged when none of the stop words occurs.
    """
    hits = [pos for pos in (result_text.find(word) for word in stop_words) if pos > -1]
    if not hits:
        return result_text
    return result_text[:min(hits)]


def codegen_generate(tokenizer, model, texts, max_to_generate, top_p, temperature, device: str) -> List[str]:
    """Sample one continuation per prompt with ``model.generate``.

    The prompts are tokenized as one padded batch (truncated to 2048 tokens)
    and ``input_ids`` / ``attention_mask`` are moved to CUDA device
    ``device``. Sampling is enabled with the given ``top_p`` and
    ``temperature``, and the tokenizer's EOS id is used as the pad token id.

    Args:
        tokenizer: Tokenizer used for encoding the prompts and decoding output.
        model: Model exposing ``generate``.
        texts: Prompt or list of prompts passed to the tokenizer.
        max_to_generate: Number of tokens to allow beyond the padded prompt
            length; the total length is capped at 2048.
        top_p: ``top_p`` value forwarded to ``generate``.
        temperature: ``temperature`` value forwarded to ``generate``.
        device: CUDA device passed to ``Tensor.cuda``.

    Returns:
        The decoded text (special tokens skipped) of everything after the
        first ``prompt_len`` positions of each generated row.
    """
    batch = tokenizer(texts, return_tensors="pt", padding=True,
                      truncation=True, max_length=2048)
    for field in ("input_ids", "attention_mask"):
        batch[field] = batch[field].cuda(device)

    prompt_len = batch["input_ids"].shape[1]
    total_len = min(max_to_generate + prompt_len, 2048)

    with torch.no_grad():
        generated = model.generate(**batch, max_length=total_len, top_p=top_p,
                                   temperature=temperature, do_sample=True, pad_token_id=tokenizer.eos_token_id)

    # Keep only the positions after the (padded) prompt.
    continuation = generated[:, prompt_len:]
    return tokenizer.batch_decode(continuation, skip_special_tokens=True)


def incoder_generate(tokenizer, model, text, max_to_generate, top_p, temperature):
    """Generate a continuation of ``text`` through sentinel-based infilling.

    ``"\\n<insert>"`` is appended to ``text``, the result is split on
    ``"<insert>"`` and handed to the nested ``infill`` routine. The first
    generated infill is then cut at the earliest ``"<|/"`` or
    ``"<|endoftext|>"`` via ``process_result_text`` and returned.

    Args:
        tokenizer: Tokenizer used for encoding prompts and decoding output.
        model: Model exposing ``generate``.
        text: Prompt text.
        max_to_generate: Number of tokens to allow beyond the prompt length.
        top_p: ``top_p`` value forwarded to ``model.generate``.
        temperature: ``temperature`` value forwarded to ``model.generate``.

    Returns:
        The processed text of the first infill.
    """
    def incoder_family(top_p, temperature):
        """Run the infilling pipeline with the given sampling parameters."""
        # signals the start of a document
        BOS = "<|endoftext|>"
        # signals the end of a generated infill
        EOM = "<|endofmask|>"

        def generate(input: str):
            """
            Do standard left-to-right completion of the prefix `input` by sampling from the model

            The prefix is tokenized with truncation to 2048 tokens and moved to
            the default CUDA device. The returned string is the decoded full
            output sequence with a leading BOS marker removed if present.
            """
            input_ids = tokenizer(input, return_tensors="pt",
                                  truncation=True, max_length=2048).input_ids
            # if CUDA:
            input_ids = input_ids.cuda()
            # Total length budget = number of prompt tokens + max_to_generate.
            max_length = max_to_generate + input_ids.flatten().size(0)
            with torch.no_grad():
                output = model.generate(input_ids=input_ids, do_sample=True,
                                        max_length=max_length, top_p=top_p, temperature=temperature)
            detok_hypo_str = tokenizer.decode(
                output.flatten(), clean_up_tokenization_spaces=False)
            if detok_hypo_str.startswith(BOS):
                detok_hypo_str = detok_hypo_str[len(BOS):]
            return detok_hypo_str

        def make_sentinel(i):
            # signals (1) a location to insert an infill and (2) the start of the infill generation
            return f"<|mask:{i}|>"

        def infill(parts: List[str], extra_sentinel: bool = False, max_retries: int = 1):
            """Generate one infill for every gap between consecutive ``parts``.

            The whole procedure is repeated until every infill ended with EOM
            or ``max_retries`` attempts were made.

            Returns a dict with ``'text'`` (parts joined with the infills
            inserted) and ``'infills'`` (the generated infills).

            NOTE(review): with ``max_retries < 1`` the loop body never runs,
            so ``text`` and ``infills`` are unbound at the ``return``.
            """
            retries_attempted = 0
            done = False

            while (not done) and (retries_attempted < max_retries):
                retries_attempted += 1

                # (1) build the prompt
                if len(parts) == 1:
                    prompt = parts[0]
                else:
                    prompt = ""
                    # encode parts separated by sentinel
                    for sentinel_ix, part in enumerate(parts):
                        prompt += part
                        if extra_sentinel or (sentinel_ix < len(parts) - 1):
                            prompt += make_sentinel(sentinel_ix)

                infills = []
                complete = []

                # Assume success; reset below if any infill lacks EOM.
                done = True

                # (2) generate infills
                for sentinel_ix, part in enumerate(parts[:-1]):
                    complete.append(part)
                    # Re-append the sentinel at the end of the prompt so the
                    # model continues with the infill for this gap.
                    prompt += make_sentinel(sentinel_ix)
                    # print(prompt)
                    completion = generate(prompt)
                    # NOTE(review): assumes the decoded output starts with the
                    # prompt text verbatim — confirm for the tokenizer in use.
                    completion = completion[len(prompt):]
                    if EOM not in completion:
                        # Infill was not terminated: close it manually and
                        # flag this attempt as incomplete.
                        completion += EOM
                        done = False
                    # Keep everything up to and including the first EOM.
                    completion = completion[:completion.index(EOM) + len(EOM)]
                    infilled = completion[:-len(EOM)]
                    infills.append(infilled)
                    complete.append(infilled)
                    # The finished infill becomes context for the next gap.
                    prompt += completion
                complete.append(parts[-1])
                text = ''.join(complete)

            return {
                # str, the completed document (with infills inserted)
                'text': text,
                # List[str], length N-1. The list of infills generated
                'infills': infills
            }

        # Ask for a single infill located right after the prompt text.
        example = text + "\n<insert>"
        parts = example.split("<insert>")
        result = infill(parts)
        # print(
        #     f"################## result_text before process ##################\n{result['text']}")
        # print(len(result['infills']))
        result_text = process_result_text(
            result['infills'][0], stop_words=["<|/", "<|endoftext|>"])
        return result_text
    return incoder_family(top_p=top_p, temperature=temperature)


def starcoderbase_generate(tokenizer, model, texts, max_to_generate, top_p, temperature, device: str) -> List[str]:
    """Sample one continuation per prompt with ``model.generate``.

    The prompts are tokenized as one padded batch (truncated to 2048 tokens)
    and ``input_ids`` / ``attention_mask`` are moved to CUDA device
    ``device``. Sampling is enabled with the given ``top_p`` and
    ``temperature``, and the tokenizer's EOS id is used as the pad token id.

    Args:
        tokenizer: Tokenizer used for encoding the prompts and decoding output.
        model: Model exposing ``generate``.
        texts: Prompt or list of prompts passed to the tokenizer.
        max_to_generate: Number of tokens to allow beyond the padded prompt
            length; the total length is capped at 2048.
        top_p: ``top_p`` value forwarded to ``generate``.
        temperature: ``temperature`` value forwarded to ``generate``.
        device: CUDA device passed to ``Tensor.cuda``.

    Returns:
        The decoded text (special tokens skipped) of everything after the
        first ``prompt_len`` positions of each generated row.
    """
    batch = tokenizer(texts, return_tensors="pt", padding=True,
                      truncation=True, max_length=2048)
    for field in ("input_ids", "attention_mask"):
        batch[field] = batch[field].cuda(device)

    prompt_len = batch["input_ids"].shape[1]
    total_len = min(max_to_generate + prompt_len, 2048)

    with torch.no_grad():
        generated = model.generate(**batch, max_length=total_len, top_p=top_p,
                                   temperature=temperature, do_sample=True, pad_token_id=tokenizer.eos_token_id)

    # Keep only the positions after the (padded) prompt.
    continuation = generated[:, prompt_len:]
    return tokenizer.batch_decode(continuation, skip_special_tokens=True)


def gpt2_generate(tokenizer, model, texts, max_to_generate, top_p, temperature, device: str) -> List[str]:
    """Sample one continuation per prompt with ``model.generate``.

    The prompts are tokenized as one padded batch (truncated to 1024 tokens)
    and ``input_ids`` / ``attention_mask`` are moved to CUDA device
    ``device``. Sampling is enabled with the given ``top_p`` and
    ``temperature``, and the tokenizer's EOS id is used as the pad token id.

    Args:
        tokenizer: Tokenizer used for encoding the prompts and decoding output.
        model: Model exposing ``generate``.
        texts: Prompt or list of prompts passed to the tokenizer.
        max_to_generate: Number of tokens to allow beyond the padded prompt
            length; the total length is capped at 1024.
        top_p: ``top_p`` value forwarded to ``generate``.
        temperature: ``temperature`` value forwarded to ``generate``.
        device: CUDA device passed to ``Tensor.cuda``.

    Returns:
        The decoded text (special tokens skipped) of everything after the
        first ``prompt_len`` positions of each generated row.
    """
    batch = tokenizer(texts, return_tensors="pt", padding=True,
                      truncation=True, max_length=1024)
    for field in ("input_ids", "attention_mask"):
        batch[field] = batch[field].cuda(device)

    prompt_len = batch["input_ids"].shape[1]
    total_len = min(max_to_generate + prompt_len, 1024)

    with torch.no_grad():
        generated = model.generate(**batch, max_length=total_len, top_p=top_p,
                                   temperature=temperature, do_sample=True, pad_token_id=tokenizer.eos_token_id)

    # Keep only the positions after the (padded) prompt.
    continuation = generated[:, prompt_len:]
    return tokenizer.batch_decode(continuation, skip_special_tokens=True)
