import os, torch, re

import numpy as np

from baselines.MetricClass import MetricClass
from vllm import LLM, SamplingParams

from utils.dsba import HG_SUMMARY_RELEVANCE, HG_SUMMARY_FACTUALITY, HG_SUMMARY_FLUENCY, HG_SUMMARY_COHERENCE


class DSBA(MetricClass):
    """LLM-as-a-judge summary metric served through vLLM.

    For every (source, hypothesis) pair, four prompts (relevance, factuality,
    fluency, coherence) are sent to the model. The last number found in each
    generated answer is taken as that aspect's score, and the per-pair
    ``avg`` is the arithmetic mean of the four aspect scores (NaN if any
    aspect could not be parsed).
    """
    name = 'DSBA'

    # Order in which aspect prompts are issued per pair in ``__call__`` and
    # regrouped in ``order_scores``; both must agree, so it lives in one place.
    ASPECTS = ("relevance", "factuality", "fluency", "coherence")

    # An optionally signed, optionally decimal number. Raw string: "\d" in a
    # plain literal is an invalid escape sequence (SyntaxWarning on 3.12+).
    _NUMBER_RE = re.compile(r"[-+]?(?:\d*\.*\d+)")

    def __init__(self, model, cache_dir="/work/<REDACTED>/.cache", *args, **kwargs):
        """Load ``model`` with vLLM.

        :param model: model id or path passed to ``vllm.LLM``.
        :param cache_dir: download directory for the model weights.
        :param args: accepted for interface compatibility; ignored.
        :param kwargs: accepted for interface compatibility; ignored.
        """
        # temperature=0 -> greedy decoding, so repeated runs give the same text;
        # each answer is capped at 180 generated tokens.
        self.sampling_params = SamplingParams(temperature=0, max_tokens=180)
        if "gptq" in model.lower():
            # Model ids mentioning GPTQ are loaded with the GPTQ quantization
            # backend and in eager mode (CUDA-graph capture disabled).
            self.llm = LLM(model=model, quantization="gptq", download_dir=cache_dir, enforce_eager=True)
        else:
            self.llm = LLM(model=model, download_dir=cache_dir)

    def get_score(self, text):
        """Return the last number in ``text`` as a float, or NaN if none parses.

        NOTE(review): a reply such as "4/5" yields 5.0 (the denominator);
        whether that can happen depends on the prompt wording in utils.dsba.
        """
        try:
            return float(self._NUMBER_RE.findall(text)[-1])
        except (IndexError, ValueError, TypeError) as e:
            # IndexError: no number at all (e.g. the 'MISSING' placeholder);
            # ValueError: a match like "1..2" that float() rejects;
            # TypeError: non-string input.
            print(e)
            # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0.
            return np.nan

    def order_scores(self, scores):
        """Regroup a flat score list into one dict per (source, hypothesis) pair.

        :param scores: flat list with ``len(ASPECTS)`` consecutive scores per
            pair, in ``ASPECTS`` order.
        :return: list of dicts keyed by aspect name plus ``"avg"``.
        :raises ValueError: if ``len(scores)`` is not a multiple of
            ``len(ASPECTS)``.
        """
        n = len(self.ASPECTS)
        if len(scores) % n:
            raise ValueError(f"expected a multiple of {n} scores, got {len(scores)}")

        ordered_scores = []
        for start in range(0, len(scores), n):
            entry = dict(zip(self.ASPECTS, scores[start:start + n]))
            # Plain mean: a single NaN aspect makes the average NaN.
            entry["avg"] = sum(entry.values()) / n
            ordered_scores.append(entry)
        return ordered_scores

    def __call__(self, gt, hyp):
        """Score each hypothesis against its source on all four aspects.

        :param gt: source texts (substituted as ``src`` in the prompt templates).
        :param hyp: hypotheses to score, paired with ``gt`` by position.
            ``zip`` is used, so extra items in the longer list are ignored.
        :return: dict with ``scores`` (per-pair average), ``all_scores``
            (per-pair dicts from ``order_scores``) and ``texts`` (raw model
            outputs, four per pair in ``ASPECTS`` order).
        """
        # Template order must match ASPECTS.
        templates = (
            HG_SUMMARY_RELEVANCE,
            HG_SUMMARY_FACTUALITY,
            HG_SUMMARY_FLUENCY,
            HG_SUMMARY_COHERENCE,
        )
        prompt_list = [t.format(src=g, hyp=h) for g, h in zip(gt, hyp) for t in templates]

        outputs = self.llm.generate(prompt_list, self.sampling_params)
        processed_results = [m.outputs[0].text for m in outputs]
        # Defensive padding: if fewer outputs than prompts came back, the flat
        # list still regroups into complete per-pair tuples; 'MISSING' has no
        # number and therefore scores NaN.
        processed_results += ['MISSING'] * (len(prompt_list) - len(processed_results))
        scores = [self.get_score(p) for p in processed_results]
        o = self.order_scores(scores)

        return {"scores": [a["avg"] for a in o], "all_scores": o, "texts": processed_results}



if __name__ == '__main__':
    import gc

    # Smoke test: score two toy (source, hypothesis) pairs with each judge model.
    sources = ["A test sentence", "Sentence B"]
    hypotheses = ["So Cummings was told that these units must be preserved in their entirety.", "Satz B"]
    models = [
        # "Open-Orca/OpenOrca-Platypus2-13B",  # currently disabled
        "TheBloke/Platypus2-70B-Instruct-GPTQ",
        "NousResearch/Nous-Hermes-13b",
    ]

    for model_name in models:
        b = DSBA(model=model_name)
        print(b(sources, hypotheses))

        # Drop the engine before loading the next model so its GPU memory can be
        # reused. gc.collect() covers objects kept alive by reference cycles
        # (assumption: vLLM may still not release everything in-process).
        del b.llm
        del b
        gc.collect()
        torch.cuda.empty_cache()