import json
import os
from typing import Dict, Optional

import torch
import transformers
from tqdm import tqdm

from llm_pretest.prompts_getter import get_prompt
from llm_pretest.utils import parse_naturalness_args, get_prompt_info, get_key_name
from utils import read_file

# Checkpoints loaded with device_map='auto' (and an "offload" folder) by
# get_hf_model instead of being moved to a single device.
NEED_DEVICE_MAP = [
    "facebook/opt-iml-30b",
    "facebook/opt-iml-max-30b",
    "bigscience/bloomz",
    "bigscience/bloomz-mt",
    "bigscience/bloomz-7b1",
    "facebook/opt-13b",
]

# Checkpoints on which get_hf_model calls model.parallelize() when they are
# not in NEED_DEVICE_MAP.
# NOTE(review): every entry here is also in NEED_DEVICE_MAP, so that path is
# currently never taken — confirm whether this list is still needed.
NEED_PARALELLIZE = [
    "bigscience/bloomz",
    "bigscience/bloomz-mt",
    "facebook/opt-iml-30b",
]


# Keyword arguments forwarded to model.generate(); they are also folded into
# the result key name so runs with different settings stay distinct.
BASE_ARGS = {
    "do_sample": True,
    "max_new_tokens": 50,
    "temperature": 0.7,
}


def get_hf_model(model_name: str, cache_dir: str):
    """Load a causal-LM checkpoint and its tokenizer with ``from_pretrained``.

    Models listed in ``NEED_DEVICE_MAP`` are loaded with ``device_map='auto'``
    and an ``"offload"`` folder. Any other model is either parallelized (if
    listed in ``NEED_PARALELLIZE``) or moved to CUDA when available, else CPU.

    :param model_name: identifier passed to ``from_pretrained``.
    :param cache_dir: cache directory passed to ``from_pretrained``.
    :return: ``(model, tokenizer)`` tuple.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    load_args = {'pretrained_model_name_or_path': model_name,
                 'cache_dir': cache_dir}

    needs_device_map = model_name in NEED_DEVICE_MAP
    model_args = dict(load_args)
    if needs_device_map:
        # Placement options apply to the model only, not to the tokenizer.
        model_args['device_map'] = 'auto'
        model_args['offload_folder'] = "offload"

    model = transformers.AutoModelForCausalLM.from_pretrained(**model_args)
    tokenizer = transformers.AutoTokenizer.from_pretrained(**load_args)

    if not needs_device_map:
        # Compare the model *name* (a str) with the list; comparing the model
        # object itself could never match and made parallelize() unreachable.
        # NOTE(review): all NEED_PARALELLIZE entries are currently also in
        # NEED_DEVICE_MAP, so this branch only fires if the lists diverge.
        if model_name in NEED_PARALELLIZE:
            model.parallelize()
        else:
            model = model.to(device)

    return model, tokenizer


def parse_prediction(pred: str) -> Optional[int]:
    """Extract the first integer score from a model completion.

    Whitespace-separated tokens are scanned in order. Leading and trailing
    periods are stripped from each token (so ``"7."`` parses as ``7``), and
    the first token that then consists only of decimal digits is returned.

    :param pred: decoded model output text.
    :return: the first integer found, or ``None`` when the text has none.
    """
    # str.split() with no argument splits on any whitespace, so numbers next
    # to tabs or "\r" (Windows line endings) are still found.
    for word in pred.split():
        # Strip only edge periods: removing *every* period turned "3.5" into 35.
        candidate = word.strip('.')
        # isdecimal() (unlike isnumeric()) only accepts characters int() can
        # parse; isnumeric() also accepts "²"/"½", which made int() raise.
        if candidate.isdecimal():
            return int(candidate)
    return None


def retrieve_model_predictions(model, tokenizer, generation_args: Dict, prompt: str, num_predictions: int = 15):
    """
    Sample the model ``num_predictions`` times on ``prompt`` and parse each
    completion into a naturalness score.

    :param model: causal LM exposing ``generate``.
    :param tokenizer: tokenizer matching ``model``.
    :param generation_args: extra keyword arguments forwarded to ``model.generate``.
    :param prompt: full prompt text.
    :param num_predictions: number of independent generations to draw.
    :return: list of length ``num_predictions`` with the parsed score of each
        generation (``None`` where no number could be parsed).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The prompt is identical for every draw, so tokenize it once.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    prompt_len = input_ids.shape[1]

    all_predictions = list()
    with torch.no_grad():
        for _ in range(num_predictions):
            output = model.generate(input_ids, attention_mask=attention_mask, **generation_args)
            # With a single input row, generate() returns its sequence in row 0.
            # Indexing by the loop counter raised IndexError from the 2nd draw.
            # Slice off the echoed prompt and drop special tokens (e.g. "</s>")
            # that would otherwise stick to a trailing score.
            decoded_output = tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)
            all_predictions.append(parse_prediction(decoded_output))

    return all_predictions


def _merge_previous_results(samples, output_path: str) -> None:
    """Copy ``hf_model_results`` from an existing output file into matching ``samples`` (by ``sample_id``)."""
    if not os.path.exists(output_path):
        return
    previous = {done['sample_id']: done['hf_model_results']
                for done in read_file(output_path) if 'hf_model_results' in done}
    if not previous:
        return
    # One dict lookup per sample instead of a nested scan over both lists.
    for sample in samples:
        if sample['sample_id'] in previous:
            sample['hf_model_results'] = previous[sample['sample_id']]


def _write_samples(output_path: str, samples) -> None:
    """Write ``samples`` to ``output_path`` as one JSON object per line, replacing the file atomically."""
    tmp_path = output_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        for sample in samples:
            f.write(json.dumps(sample) + '\n')
    # os.replace swaps the file in one step, so an interrupted write never
    # leaves a truncated checkpoint for the resume logic to read.
    os.replace(tmp_path, output_path)


def main(input_path: str, output_path: str, prompt_info: str, model_names: str, cache_path: str,
         num_predictions: int = 10) -> None:
    """Score every sample with every (model, prompt) pair and save the results.

    Results are stored per sample under ``sample["hf_model_results"][key]``,
    where ``key`` comes from ``get_key_name``. If ``output_path`` already
    exists, its results are merged in first, and (model, prompt) pairs that
    are already present are skipped, so an interrupted run can be resumed.

    :param input_path: samples file read with ``read_file``. Each sample needs
        ``"sample_id"`` and ``"sentence"`` keys.
    :param output_path: destination file. It is rewritten after every scored
        sample as a checkpoint.
    :param prompt_info: comma-separated prompt specs, each parsed by ``get_prompt_info``.
    :param model_names: comma-separated model identifiers.
    :param cache_path: cache directory forwarded to ``from_pretrained``.
    :param num_predictions: generations sampled per sample and prompt.
    """
    samples = read_file(input_path)
    _merge_previous_results(samples, output_path)

    for mod_name in model_names.split(','):
        # (add_ex, prompt_name, prompt_args, result_key) for every prompt spec.
        prompt_jobs = []
        for prompt_spec in prompt_info.split(','):
            add_ex, prompt_name, prompt_args = get_prompt_info(prompt_spec)
            result_key = get_key_name(mod_name, prompt_spec.replace("__", "##"), BASE_ARGS)
            prompt_jobs.append((add_ex, prompt_name, prompt_args, result_key))

        # Avoid loading a (possibly very large) model when a resumed run
        # already has every result it would produce.
        has_pending = any(job[3] not in sample.get("hf_model_results", {})
                          for job in prompt_jobs for sample in samples)
        if not has_pending:
            continue

        model, tokenizer = get_hf_model(mod_name, cache_path)

        for add_ex, prompt_name, prompt_args, result_key in prompt_jobs:
            for sample in tqdm(samples, desc=f"{mod_name} / {prompt_name}"):
                if result_key in sample.get("hf_model_results", {}):
                    continue

                full_prompt = get_prompt('completion', prompt_name, prompt_args, sample["sentence"], add_ex)
                scores = retrieve_model_predictions(model, tokenizer, BASE_ARGS, full_prompt, num_predictions)
                sample.setdefault("hf_model_results", dict())[result_key] = scores

                # Checkpoint after every sample so progress survives interruption.
                _write_samples(output_path, samples)

        # Free this model's memory before the next one is loaded.
        del model, tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    _write_samples(output_path, samples)


if __name__ == "__main__":
    # Forward every parsed command-line option to main() as a keyword argument.
    main(**vars(parse_naturalness_args()))