from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from num2words import num2words
import numpy as np
from random import shuffle, random, randint, sample, choices
import copy
from collections import defaultdict
from langchain.llms import HuggingFacePipeline
from torch import cuda, bfloat16
import transformers
from transformers import StoppingCriteria, StoppingCriteriaList
import torch

class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the generated ids end with a stop sequence.

    Each string in ``stop_list`` is tokenized with ``tokenizer`` and the
    resulting id sequence is compared against the tail of the first sequence
    in the batch (``input_ids[0]``).

    NOTE(review): ``tokenizer(x)['input_ids']`` is used as-is; if the tokenizer
    prepends special tokens (e.g. a BOS id) they become part of the stop
    sequence -- confirm this is intended.
    """

    def __init__(self, stop_list, tokenizer):
        """Store the stop strings and pre-tokenize them.

        Args:
            stop_list: iterable of strings that should end generation.
            tokenizer: callable returning a mapping with an ``'input_ids'``
                entry for a given string.
        """
        super().__init__()
        self.stop_list = stop_list
        self.tokenizer = tokenizer
        # Tokenize once here instead of on every generation step, and keep the
        # ids as tensors: torch.eq cannot compare a tensor with a Python list.
        self.stop_token_ids = [
            torch.tensor(tokenizer(x)['input_ids'], dtype=torch.long)
            for x in stop_list
        ]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when ``input_ids[0]`` ends with any stop sequence."""
        generated = input_ids[0]
        for stop_ids in self.stop_token_ids:
            stop_len = len(stop_ids)
            # An empty stop sequence would turn the slice below into the whole
            # history, and a history shorter than the stop sequence cannot match.
            if stop_len == 0 or stop_len > len(generated):
                continue
            tail = generated[-stop_len:]
            if torch.eq(tail, stop_ids.to(tail.device)).all():
                return True
        return False


# Delimiter tags wrapped around the instruction and the system prompt by get_prompt.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# System prompt used by get_prompt when the caller does not supply one.
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""


def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Build the full prompt string for one instruction.

    The result is ``<s>[INST] `` followed by the system prompt wrapped in the
    ``<<SYS>>`` tags, then the instruction, then ``[/INST]``.

    Args:
        instruction: the user instruction text.
        new_system_prompt: system prompt to embed; defaults to
            DEFAULT_SYSTEM_PROMPT.

    Returns:
        The assembled prompt string.
    """
    pieces = (
        "<s>",
        B_INST,
        " ",
        B_SYS,
        new_system_prompt,
        E_SYS,
        instruction,
        E_INST,
    )
    return "".join(pieces)

def compare_function(src, tgt):
    """Report whether ``src`` and ``tgt`` are near-identical strings.

    The two strings count as a match when their normalized Levenshtein
    distance is at most 0.1. The verdict is also printed ("Correct" for a
    match, "False" otherwise).

    Returns:
        True for a match, False otherwise.
    """
    metric = NormalizedLevenshtein()
    matched = bool(metric.distance(src, tgt) <= 0.1)
    print("Correct" if matched else "False")
    return matched
def str2bool(v):
    """Convert a textual boolean flag to a ``bool``.

    Args:
        v: a ``bool`` (returned unchanged) or a string spelling "true" or
            "false" in any letter case, optionally surrounded by whitespace.

    Returns:
        True or False for recognised input; None for anything else, which
        matches what callers received previously for unrecognised values.
    """
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        normalized = v.strip().lower()
        if normalized == "true":
            return True
        if normalized == "false":
            return False
    # Explicit fallback instead of silently falling off the end of the function.
    return None

def contains_all_alphabets(s: str) -> bool:
    """Return True if every letter from 'a' to 'z' occurs in ``s``, ignoring case."""
    present = set(s.lower())
    required = set("abcdefghijklmnopqrstuvwxyz")
    return required.issubset(present)

def alphabet_position(c: str) -> int:
    """Return the zero-based offset of character ``c`` from 'a', ignoring case.

    'a' and 'A' map to 0, 'z' and 'Z' map to 25.
    """
    base = ord('a')
    lowered = c.lower()
    return ord(lowered) - base

def unique_character(text):
    """Count the alphabetic characters of ``text``, ignoring case.

    Returns:
        A dict mapping each lower-cased alphabetic character to its number of
        occurrences, keyed in order of first appearance. Non-alphabetic
        characters are skipped.
    """
    tally = {}
    letters = (ch for ch in text.lower() if ch.isalpha())
    for letter in letters:
        tally[letter] = tally.get(letter, 0) + 1
    return tally

def few_shot_rule_correction(rules, few_shot_examples, determinstic=False):
    """Pair each example with a rule table that may have been corrupted.

    A corrupted table is a deep copy of ``rules`` in which the "Altered"
    entries of two letters occurring in the example are swapped, with the
    value moved onto the first sampled letter upper-cased.

    Args:
        rules: container indexed by alphabet position whose items expose an
            "Altered" string entry.
        few_shot_examples: sequence of example strings.
        determinstic: when False each example is corrupted with probability
            0.5; when True exactly the first half of the examples is corrupted.

    Returns:
        A list of ``[rule_table, example, was_corrupted]`` entries, one per
        example and in the same order.
    """
    corrupted_quota = int(len(few_shot_examples) / 2) if determinstic == True else None
    labelled = []
    for position, sentence in enumerate(few_shot_examples):
        letters = list(unique_character(sentence).keys())
        candidate_rules = copy.deepcopy(rules)

        # The coin is drawn for every example, in deterministic mode too, so
        # the sequence of random numbers consumed stays the same.
        coin_is_heads = random() < 0.5
        if determinstic == False:
            corrupt = coin_is_heads
        elif determinstic == True:
            corrupt = position < corrupted_quota
        else:
            corrupt = False

        if corrupt:
            first_letter, second_letter = sample(letters, 2)
            first_slot = alphabet_position(first_letter)
            second_slot = alphabet_position(second_letter)
            saved = candidate_rules[first_slot]["Altered"]
            candidate_rules[first_slot]["Altered"] = candidate_rules[second_slot]["Altered"].upper()
            candidate_rules[second_slot]["Altered"] = saved

        labelled.append([candidate_rules, sentence, corrupt])
    return labelled


def remove_elements(num_list, remove_list):
    """Return the items of ``num_list`` whose index is not in ``remove_list``.

    Args:
        num_list: container supporting ``len()`` and integer indexing.
        remove_list: container of indices to drop.

    Returns:
        A new list with the surviving items in their original order.
    """
    # Items are fetched by index (not by iterating the container) so anything
    # indexable by 0..len-1 behaves the same as a list here.
    return [
        num_list[position]
        for position in range(len(num_list))
        if position not in remove_list
    ]

def few_shot_rule_incorporation(rules, few_shot_examples, added_num=8, determinstic=False):
    """Pair each example with a rule table that may lack the rules it needs.

    For a "positive" example, the rules at the alphabet positions of every
    letter occurring in the example are removed from a deep copy of ``rules``.

    Args:
        rules: container of rules indexed by alphabet position.
        few_shot_examples: sequence of example strings.
        added_num: accepted but not used by this function.
        determinstic: when False each example is made positive with
            probability 0.5; when True exactly the first half is positive.

    Returns:
        A list of ``[rule_table, occurring_letters, example, is_positive]``
        entries, one per example and in the same order.
    """
    positive_quota = int(len(few_shot_examples) / 2) if determinstic == True else None
    shots = []
    for position, sentence in enumerate(few_shot_examples):
        letters = list(unique_character(sentence).keys())
        letter_slots = [alphabet_position(letter) for letter in letters]
        partial_rules = copy.deepcopy(rules)

        # The coin is drawn for every example, in deterministic mode too, so
        # the sequence of random numbers consumed stays the same.
        coin_is_heads = random() < 0.5
        if determinstic == False:
            strip_rules = coin_is_heads
        elif determinstic == True:
            strip_rules = position < positive_quota
        else:
            strip_rules = False

        if strip_rules:
            partial_rules = remove_elements(partial_rules, letter_slots)

        shots.append([partial_rules, letters, sentence, strip_rules])
    return shots

def result_dump(results, new_prompt, method, shot_num, shift_num, shift_mention, cipher_table_mention,encoder_decode, prompt_hyper, args):
    """Append a report for one experiment to its result file.

    The file is
    ``result/<method>/<encoder_decode>/<first word of prompt_hyper>/<shift ...> <cipher table ...>.txt``;
    missing directories are created. Reports are appended, never overwritten.

    Args:
        results: sequence of runs, each indexable as
            ``(accuracy, correct_count, total_count)``.
        new_prompt: prompt text to record in the report.
        method: method name (string); used in the path and in the report.
        shot_num: number of shots, recorded as-is.
        shift_num: shift amount, recorded as-is.
        shift_mention: truthy when the shift was given in the prompt.
        cipher_table_mention: truthy when the cipher table was given.
        encoder_decode: sub-directory name for the task direction.
        prompt_hyper: string whose first whitespace-separated word names the
            innermost directory.
        args: run arguments, recorded via ``str()``.
    """
    import os  # local import so the block does not depend on file-level imports

    shift_mention_string = "shift known" if shift_mention else "shift unknown"
    cipher_table_string = "cipher table known" if cipher_table_mention else "cipher table unknown"

    out_dir = f"result/{method}/{encoder_decode}/{prompt_hyper.split()[0]}"
    # open() does not create parent directories; without this the first
    # experiment for a new configuration fails with FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)
    out_path = f"{out_dir}/{shift_mention_string} {cipher_table_string}.txt"

    accuracies = [x[0] for x in results]
    run_count = len(results)
    # numpy warns on empty input; write the same "nan" without the warning.
    mean_accuracy = np.mean(accuracies) if accuracies else float("nan")
    std_accuracy = np.std(accuracies) if accuracies else float("nan")

    with open(out_path, "a", encoding="utf-8") as f:
        f.write("\n================Begin of Experiment=====================\n")
        for run_number, result in enumerate(results, start=1):
            f.write("Result of " + num2words(run_number, to="ordinal_num") + " Run: \n"
                    "Accuracy:" + str(result[0] * 100) + "% \n"
                    "Correct Count: " + str(result[1]) + "\n"
                    "Total Count:" + str(result[2]))
            f.write("\n====================================\n")
        # Report the real number of runs rather than a hard-coded 5.
        f.write(f"Avg Accuracy over {run_count} runs: " + str(mean_accuracy) + "\n")
        f.write(f"Standard Deviation over {run_count} runs: " + str(std_accuracy) + "\n")
        f.write("Method: " + method + "\n")
        f.write("Shot Number: " + str(shot_num) + "\n")
        f.write("Shift Number: " + str(shift_num) + "\n")
        f.write("Shift Mention: " + str(shift_mention) + "\n")
        f.write("Table Mention: " + str(cipher_table_mention) + "\n")
        f.write("Arguments:" + str(args) + "\n")
        f.write("\n========== The prompt used is ====================\n")
        f.write(new_prompt + "\n")
        f.write("\n========== End of this experiment ====================\n")
        f.write("\n======================================================\n")

def load_llama(model_id):
    """Load a causal language model and wrap it as a LangChain LLM.

    The tokenizer and model for ``model_id`` are loaded through
    ``transformers``, the model is put in eval mode, and a text-generation
    pipeline around them is wrapped in ``HuggingFacePipeline``.

    The Hugging Face access token is read from the ``HF_TOKEN`` (or, failing
    that, ``HUGGING_FACE_HUB_TOKEN``) environment variable.

    Args:
        model_id: model identifier or path passed to ``from_pretrained``.

    Returns:
        The ``HuggingFacePipeline`` instance wrapping the generation pipeline.
    """
    import os  # local import so the block does not depend on file-level imports

    # Credentials must not live in source control. The token that used to be
    # hard-coded here should be treated as leaked and revoked.
    # NOTE(review): with no variable set this passes None -- presumably the
    # library then falls back to anonymous or cached credentials; confirm.
    hf_auth = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_id,
        use_auth_token=hf_auth
    )

    # NOTE(review): trust_remote_code=True allows code shipped with the model
    # repository to run locally -- only use with trusted model ids.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map="auto",
        use_auth_token=hf_auth,
    )
    model.eval()

    # NOTE: no stopping_criteria is passed to the pipeline; an earlier version
    # had one commented out with the remark that the model rambles without it.
    generate_text = transformers.pipeline(
        model=model, tokenizer=tokenizer,
        return_full_text=True,  # langchain expects the full text
        task='text-generation',
        # we pass model parameters here too
        temperature=0.01,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
        max_new_tokens=8192,  # max number of tokens to generate in the output
        repetition_penalty=1.1  # without this output begins repeating
    )
    llm = HuggingFacePipeline(pipeline=generate_text, model_kwargs={'temperature': 0.01})
    return llm