import time

from transformers import AutoTokenizer, OPTForCausalLM
from llm_serving.model.wrapper import get_model
import Evaluation.resources_usage as resources_usage
import torch

# Example evaluation arguments: -eval_model opt -model_variant opt-175b -templates 0,1,2,3,4,5,6,7,8,9,10 -evaluation_aspects style_transfer_accuracy,content_preservation,naturalness -sentiments positive,negative -tst_models  ARAE_lambda_1,CAAE_rho_0_5,DAR_gamma_15

class OPT:
    """Wrapper around OPT causal language models for batched sampled generation.

    Variants listed in ``model_sizes['small']`` are loaded with Hugging Face
    ``transformers`` onto a single device (CUDA when available, else CPU).
    Variants listed in ``model_sizes['large']`` are loaded through alpa's
    ``llm_serving`` ``get_model`` from local checkpoint paths.
    """

    # opt-13b is not deployable with alpa (too big for 32GB VRAM), so it is
    # kept in the 'small' list and loaded through the Hugging Face path.
    model_sizes = {'small': ['opt-125m', 'opt-350m', 'opt-1.3b', 'opt-2.7b', 'opt-6.7b', 'opt-13b'],
                   'large': ['opt-30b', 'opt-66b', 'opt-175b']}

    def __init__(self, model_variant, batch_size):
        """Load the tokenizer and model for ``model_variant``.

        Args:
            model_variant: OPT variant name such as ``'opt-1.3b'``; must appear in
                ``OPT.model_sizes['small']`` or ``OPT.model_sizes['large']``.
            batch_size: forwarded to alpa's ``get_model`` for large variants;
                not used for small variants.

        Raises:
            ValueError: if ``model_variant`` is not a known OPT variant.
        """
        is_small = model_variant in OPT.model_sizes['small']
        if not is_small and model_variant not in OPT.model_sizes['large']:
            # Fail fast: otherwise `tokenizer`/`model` would stay unbound and the
            # constructor would crash only after printing a misleading "Done" line.
            known = OPT.model_sizes['small'] + OPT.model_sizes['large']
            raise ValueError(f"Unknown OPT variant {model_variant!r}; expected one of {known}")

        self.model_variant = model_variant
        # Must be set before loading: __get_small__ moves the model to self.device.
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        t0 = time.perf_counter()
        if is_small:
            tokenizer, model = self.__get_small__(model_variant)
        else:
            tokenizer, model = self.__get_large__(model_variant, batch_size)
        print(f"Done with Load the model: {(time.perf_counter() - t0) / 60} minutes")
        resources_usage.print_cpu_usage()
        resources_usage.print_memory_usage()
        resources_usage.print_GPU_usage()

        self.model = model
        self.tokenizer = tokenizer

    def __get_small__(self, model_variant):
        """Load a Hugging Face OPT checkpoint and its tokenizer; return ``(tokenizer, model)``."""
        # Left padding keeps every prompt flush against the generated tokens in a batch.
        tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_variant, use_fast=False, padding_side='left')
        tokenizer.add_bos_token = False
        model = OPTForCausalLM.from_pretrained("facebook/" + model_variant).to(self.device)
        return tokenizer, model

    def __get_large__(self, model_variant, batch_size):
        """Load an alpa-served OPT model from a local path; return ``(tokenizer, model)``."""
        # NOTE(review): these are placeholder paths; point them at the real
        # checkpoint directories before use. 'opt-6.7b' is unreachable here
        # because that variant is in the 'small' list.
        model_variant_path = {'opt-6.7b': '/path/to/opt/opt-6.7b',
                              'opt-30b': '/path/to/opt/opt-30b',
                              'opt-66b': '/path/to/opt/opt-66b', 'opt-175b': '/path/to/OPT175B'}
        # The opt-30b tokenizer is reused for every large variant (OPT checkpoints
        # presumably share one vocabulary; confirm if adding new variants).
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False, padding_side='left')
        tokenizer.add_bos_token = False
        model = get_model(model_name="alpa/" + model_variant, path=model_variant_path[model_variant],
                          batch_size=batch_size)
        return tokenizer, model

    def compute(self, prompts, sample_decimals=False, *, max_new_tokens=64, temperature=1.0, top_p=0.6):
        """Generate one sampled continuation per prompt.

        Args:
            prompts: prompt string(s), tokenized together as one left-padded batch.
            sample_decimals: unused; kept for interface compatibility.
            max_new_tokens: tokens to generate beyond the padded prompt length.
            temperature: sampling temperature passed to ``generate``.
            top_p: nucleus-sampling probability mass passed to ``generate``.

        Returns:
            The decoded output sequences with special tokens (e.g. padding)
            stripped. With ``transformers``' ``generate`` these include the
            prompt text followed by the continuation.
        """
        input_ids = self.tokenizer(prompts, return_tensors="pt", padding=True).input_ids
        if self.model_variant in OPT.model_sizes['small']:
            # Only the HF-loaded models live on self.device; alpa manages its own placement.
            input_ids = input_ids.to(self.device)
        outputs = self.model.generate(
            input_ids=input_ids,
            max_length=input_ids.shape[1] + max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True
        )
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
