import os
import torch

from baselines.MetricClass import MetricClass
from vllm import LLM, SamplingParams

from utils.mqm_gemba_utils import parse_mqm_answer, apply_template_shared_task_like_mt


class LocalGembaMQM(MetricClass):
    """MQM-style metric that prompts a locally loaded vLLM model.

    Each (gt, hyp) pair is turned into a prompt by
    ``apply_template_shared_task_like_mt``, the model's completion is
    collected, and ``parse_mqm_answer`` converts that completion into the
    returned result.
    """
    name = 'localGembaMqm'

    def __init__(self, model, cache_dir="/work/<REDACTED>/.cache", *args, **kwargs):
        """Load ``model`` with vLLM and set up the sampling parameters.

        Args:
            model: Model identifier handed to ``vllm.LLM``. When it contains
                "gptq" (case-insensitive), GPTQ quantization is requested.
            cache_dir: Directory passed to vLLM as ``download_dir``.
            *args: Accepted but not used by this constructor.
            **kwargs: Accepted but not used by this constructor.
        """
        # temperature=0 with at most 180 generated tokens per prompt.
        self.sampling_params = SamplingParams(temperature=0, max_tokens=180)

        llm_options = {"model": model, "download_dir": cache_dir}
        if "gptq" in model.lower():
            llm_options["quantization"] = "gptq"
        self.llm = LLM(**llm_options)

    def __call__(self, gt, hyp, src_lang, hyp_lang):
        """Generate and parse one model answer per (gt, hyp) pair.

        Args:
            gt: Iterable of texts paired element-wise with ``hyp``
                (``evaluate_df`` passes the 'SRC' column here).
            hyp: Iterable of hypothesis texts.
            src_lang: Source language, forwarded to the prompt template.
            hyp_lang: Hypothesis language, forwarded to the prompt template.

        Returns:
            A tuple ``(parsed, raw)``: ``raw`` is the list of generated texts
            and ``parsed`` is ``parse_mqm_answer`` applied to each of them.
        """
        prompts = []
        for reference, hypothesis in zip(gt, hyp):
            prompts.append(
                apply_template_shared_task_like_mt(src_lang, hyp_lang, reference, hypothesis)
            )

        generations = self.llm.generate(prompts, self.sampling_params)
        raw_answers = [request.outputs[0].text for request in generations]

        # Pad with a placeholder if fewer generations than prompts came back,
        # so both returned lists have one entry per prompt.
        shortfall = len(prompts) - len(raw_answers)
        raw_answers.extend(['MISSING'] * shortfall)

        parsed_answers = [parse_mqm_answer(answer) for answer in raw_answers]
        return (parsed_answers, raw_answers)

    def evaluate_df(self, df, src_lang, tgt_lang):
        """Run the metric on the 'SRC' and 'HYP' columns of ``df``.

        Returns the same ``(parsed, raw)`` tuple as ``__call__``.
        """
        sources = df['SRC']
        hypotheses = df['HYP']
        return self.__call__(sources, hypotheses, src_lang, tgt_lang)



if __name__ == '__main__':
    # Smoke test: load several models one after another and print the metric
    # output for a tiny fixed batch.
    demo_models = (
        "Open-Orca/OpenOrca-Platypus2-13B",
        "TheBloke/Platypus2-70B-Instruct-GPTQ",
        "NousResearch/Nous-Hermes-13b",
    )
    demo_gt = ["A test sentence", "Sentence B"]
    demo_hyp = ["So Cummings was told that these units must be preserved in their entirety.", "Satz B"]

    for model_name in demo_models:
        # The class is LocalGembaMQM; 'localGembaMqm' is only its `name`
        # attribute, so instantiating that identifier raised NameError.
        b = LocalGembaMQM(model=model_name)
        print(b(demo_gt, demo_hyp, "English", "German"))

        # Drop the model and clear the CUDA cache before loading the next one.
        del b
        torch.cuda.empty_cache()
