from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import os
import json
import re
from torch.utils.data import Dataset, DataLoader

from huggingface_hub import login

# Option number (1-based) -> topic name. run_inference uses it to turn model
# answers given as an option number back into a topic name; its value order is
# also the tie-break order when searching an answer for topic names.
mapping = dict(
    enumerate(
        [
            'science/technology',
            'travel',
            'politics',
            'sports',
            'health',
            'entertainment',
            'geography',
        ],
        start=1,
    )
)

# Native translation prompt, e.g.
#   "Translate the following text from French to Hausa. \n\nFrench: <text> \n\nHausa:"
# Captures src_lang, tar_lang and the text to translate (txt).
# txt is lazy, so it must be followed by a *required* terminator: either the
# "\n\n<Lang>:" answer cue or the end of the string. With an optional
# terminator, re.match lets the lazy group stop at zero characters, and txt
# always came out empty. DOTALL lets txt span lines, so multi-line source
# texts are captured instead of going unnoticed.
PATTERN_1 = re.compile(
    r"Translate the following text from (?P<src_lang>[A-Za-z\s]+) to (?P<tar_lang>[A-Za-z\s]+)[.:]\s{0,2}\n{0,2}(?:\w+: )?(?P<txt>.*?)(\s*\n\n\w+:|\s*$)",
    re.DOTALL,
)
# Alternative prompt layout:
#   "<text> the previous text is in <src>. Here is a translation to <tar>"
PATTERN_2 = re.compile(r"(?P<txt>.*) the previous text is in (?P<src_lang>[A-Za-z\s]+). Here is a translation to (?P<tar_lang>[A-Za-z\s]+)")

class InferenceDataset(Dataset):
    """Map-style dataset that yields ``(instruction, task)`` pairs.

    Each record in ``data`` must provide the keys ``"instruction"`` and
    ``"task"``; any other keys are ignored.
    """

    def __init__(self, data):
        # Keep a reference to the records as given; nothing is copied or validated.
        self.all_samples = data

    def __len__(self):
        return len(self.all_samples)

    def __getitem__(self, index):
        record = self.all_samples[index]
        return record["instruction"], record["task"]

def preprocess_translation(instruction: str):
    """Rewrite a PATTERN_2-style translation prompt into the PATTERN_1 layout.

    An instruction shaped like
    "<text> the previous text is in <src>. Here is a translation to <dest>"
    is rebuilt as a "Translate the following text from <src> to <dest>." prompt
    that ends with a "<dest>:" answer cue. The captured language names and text
    are stripped. Instructions that do not match PATTERN_2 are returned unchanged.
    """
    found = PATTERN_2.match(instruction)
    if found is None:
        return instruction

    fields = found.groupdict()
    source = fields["src_lang"].strip()
    dest = fields["tar_lang"].strip()
    text = fields["txt"].strip()
    # NOTE(review): the template emits an unmatched '"' right after the text;
    # confirm this intentionally mirrors the benchmark's own prompts.
    return f"Translate the following text from {source} to {dest}. \n\n{source}: {text}\" \n\n{dest}:"

def extract_info(instruction: str):
    """Return the source text embedded in a translation ``instruction``.

    PATTERN_1 is tried first and PATTERN_2 is the fallback. The ``txt`` group
    of whichever one matches is returned; ``None`` when neither matches.
    """
    found = PATTERN_1.match(instruction) or PATTERN_2.match(instruction)
    if found is None:
        return None
    return found.groupdict()["txt"]

def truncate_QA(answer, symbol="Answer: "):
    """Return the stripped text that follows the first ``symbol`` in ``answer``.

    If ``symbol`` does not occur in ``answer``, the whole ``answer`` is
    returned, stripped.
    """
    start = answer.find(symbol)
    if start == -1:
        # Without this guard, -1 + len(symbol) would silently chop the first
        # len(symbol) - 1 characters off the answer.
        return answer.strip()
    return answer[start + len(symbol):].strip()

def truncate_translation(answer, instruction):
    """Extract the model's translation from a decoded generation.

    ``answer`` is the decoded output, which normally starts with the prompt
    ``instruction``. The text after the prompt is stripped, and its first line
    (stripped) is returned. If the prompt cannot be found verbatim in
    ``answer``, the whole ``answer`` is used as the continuation.

    Raises:
        ValueError: if the source text extracted from ``instruction`` spans
            several lines. Only the first generated line is kept, so such a
            translation would be cut short.
    """
    source_text = extract_info(instruction)
    # extract_info returns None for prompts in an unrecognised layout; there is
    # nothing to validate in that case.
    if source_text is not None and "\n" in source_text:
        raise ValueError("translation source text must be a single line")

    start = answer.find(instruction)
    # find() returns -1 when the decoded text does not reproduce the prompt
    # exactly; slicing from -1 + len(instruction) would start at an arbitrary offset.
    continuation = answer[start + len(instruction):] if start != -1 else answer
    # partition() keeps the whole string when there is no newline, whereas
    # slicing up to find("\n") == -1 dropped the last character.
    first_line, _, _ = continuation.strip().partition("\n")
    return first_line.strip()

def truncate_classify(answer):
    """Return the predicted topic text from a decoded classification answer.

    When "Topic: " occurs, the answer is cut right after the first "Topic:".
    Otherwise it is cut after the first "Best:". If neither marker is present,
    the whole answer is used. The first line of the remainder is returned,
    stripped.
    """
    if "Topic: " in answer:
        start = answer.find("Topic:") + len("Topic:")
    elif "Best:" in answer:
        start = answer.find("Best:") + len("Best:")
    else:
        # No marker: use the whole answer instead of slicing from find() == -1 (+5).
        start = 0

    # partition() keeps a single-line remainder intact; slicing up to
    # find("\n") == -1 would drop its last character ("sports" -> "sport").
    first_line, _, _ = answer[start:].partition("\n")
    return first_line.strip()

def _normalize_topic(raw):
    """Map a raw classification answer onto a topic name from ``mapping``.

    The following are tried in order:
      1. An integer option number, with dots removed ("3." -> 3). Only the
         first digit is used, and 0 is treated as 1.
      2. A float option number; 0 maps to 'science/technology'.
      3. The literal "0.0%", which maps to 'science/technology'.
      4. The topic name that occurs earliest in ``raw``.
    If none of these apply (e.g. "8", which has no entry in ``mapping``),
    ``raw`` is returned unchanged.
    """
    try:
        digits = raw.replace(".", "")
        int(digits)  # the whole string must parse as an integer...
        index = int(digits[0]) or 1  # ...but only its first digit is used
        return mapping[index]
    except (ValueError, KeyError):
        # KeyError: a leading 8 or 9 has no topic. Catching it here stops it
        # from escaping and aborting the whole run.
        pass

    try:
        value = float(raw)
        if value == 0:
            return "science/technology"
        return mapping[int(value)]
    except (ValueError, KeyError, OverflowError):
        # OverflowError comes from int(float("inf")).
        pass

    if raw == "0.0%":
        return "science/technology"

    best, earliest = "", len(raw) + 1
    for topic in mapping.values():
        loc = raw.find(topic)
        if loc != -1 and loc < earliest:
            best, earliest = topic, loc
    return best or raw


def run_inference(model_name, model, tokenizer, device, files=None, tasks=("topic-classification",)):
    """Generate answers for benchmark samples and write them to JSON files.

    Each non-empty file in ``files`` is processed as follows:
      - Samples whose ``"task"`` is in ``tasks`` are sent through
        ``model.generate``. The post-processed answer is stored in
        ``sample["output"]``.
      - All other samples are copied through unchanged.
    Results are written as UTF-8 JSON to ``f"{model_name}_{file}"``.

    Args:
        model_name: Prefix for the output file names.
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized prompts are moved to.
        files: Benchmark JSON files to process. ``None`` selects the built-in list.
        tasks: Collection of task names to run inference for. The default
            runs only "topic-classification". "QA" and "translation" are
            also supported.

    Raises:
        ValueError: if ``tasks`` names an unsupported task.
    """
    supported_tasks = {"QA", "topic-classification", "translation"}
    unknown = set(tasks) - supported_tasks
    if unknown:
        raise ValueError(f"Unsupported task(s): {sorted(unknown)}")

    if files is None:
        # Earlier runs used: benchmark_hau_806.json, benchmark_ibo_806.json,
        # benchmark_zul_602.json, benchmark.json
        files = [
            "benchmark_swa_602.json",
            "benchmark_yor_602.json",
            "benchmark_zul_602.json",
            "benchmark.json",
        ]

    prefix = "You are an African Language Expert who only answer the one best option from ['science/technology', 'travel', 'politics', 'sports', 'health', 'entertainment', 'geography']"

    # "<|eot_id|>" only exists in Llama-3-style vocabularies. For other
    # tokenizers convert_tokens_to_ids yields the unk id (or None), which must
    # not be used as a stop token.
    eos_token_ids = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != tokenizer.unk_token_id:
        eos_token_ids.append(eot_id)

    for file in files:
        if os.path.getsize(file) == 0:
            continue

        print(f"Start inferencing on {file}")

        # Benchmark texts are multilingual; read them explicitly as UTF-8
        # rather than with the platform's default encoding.
        with open(file, "r", encoding="utf-8") as f:
            data = json.load(f)

        all_samples = []
        for i, sample in enumerate(data):
            instruction = sample["instruction"]
            task = sample["task"]

            if task not in tasks:
                all_samples.append(sample)
                continue

            if task == "translation":
                instruction = preprocess_translation(instruction)
            elif task == "topic-classification":
                # NOTE(review): prefix and instruction are joined with no separator.
                instruction = prefix + "Explain and " + instruction

            inputs = tokenizer(instruction, return_tensors="pt").to(device)
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                eos_token_id=eos_token_ids,
                do_sample=True,
                num_return_sequences=1,
            )
            # The decoded text still contains the prompt; the truncate_* helpers strip it.
            answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

            if task == "QA":
                output = truncate_QA(answer)
            elif task == "topic-classification":
                output = _normalize_topic(truncate_classify(answer))
            else:
                output = truncate_translation(answer, instruction)

            print(f"Completed {i + 1} / {len(data)}")

            sample["output"] = output
            all_samples.append(sample)

        with open(f"{model_name}_{file}", "w", encoding='utf-8') as f:
            json.dump(all_samples, f, indent=6, ensure_ascii=False)

if __name__ == "__main__":
    # Authenticate with the Hugging Face Hub using a token from the environment.
    # Never hard-code access tokens: anything committed to source control is
    # exposed to every reader, and such tokens should be revoked.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    # Without HF_TOKEN, rely on credentials cached by `huggingface-cli login`.

    # model_name = "meta-llama/Meta-Llama-3-8B"
    model_name = "meta-llama/Llama-2-7b-hf"
    device = "cuda:0"

    torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Short tag used as the prefix of the output file names.
    model_name = "Llama2"

    run_inference(model_name, model, tokenizer, device)

    # Playground: run a single ad-hoc classification prompt to inspect the raw
    # model output and what truncate_classify extracts from it.
    instruction_classify = "Explain and Classify the text \"Mkpụrụ ndụ dị oke mkpa na ọmụmụ gbasara ndụ, n’ezie, a na-akpọkarị ha “ntọala obibi nke ndụ”.\" into the following topics:\n- science/technology\n- travel\n- politics\n- sports\n- health\n- entertainment\n- geography\nTopic: "

    inputs = tokenizer(instruction_classify, return_tensors="pt").to(device)

    # Only use "<|eot_id|>" as a stop token when the vocabulary really has it;
    # otherwise convert_tokens_to_ids returns the unk id or None.
    eos_token_ids = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != tokenizer.unk_token_id:
        eos_token_ids.append(eot_id)

    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        eos_token_id=eos_token_ids,
        do_sample=True,
        num_return_sequences=1,
    )

    # The decoded text includes the prompt followed by the generated continuation.
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(result)

    print("\n")

    print(truncate_classify(result))