import os
import json
import argparse
from tqdm.auto import tqdm
import numpy as np
import torch
from models import NUFScorer

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Scorer:

    def __init__(self, args):
        """Load the NUF scoring model and the score-normalization table.

        Args:
            args: Configuration object. Only ``args.weight_dir`` is read
                here; that directory must contain ``NUF-ClASS.ckpt`` and
                ``mlm_minmax_score.json``.

        Raises:
            OSError: If ``mlm_minmax_score.json`` cannot be opened.
            json.JSONDecodeError: If that file is not valid JSON.
        """
        self.args = args

        NUF_path = os.path.join(args.weight_dir, 'NUF-ClASS.ckpt')

        # Restore the scorer from its checkpoint and move it onto the
        # module-level ``device``.
        self.NUF_model = NUFScorer.load_from_checkpoint(checkpoint_path=NUF_path).to(device)

        # Load the normalization scores (the file name suggests MLM min/max
        # values -- TODO confirm against the code that consumes them).
        # JSON is UTF-8 by specification, so decode it explicitly instead of
        # relying on the platform's default encoding. The ``with`` block
        # closes the file on exit; no explicit close() is needed.
        norm_score_path = os.path.join(args.weight_dir, 'mlm_minmax_score.json')
        with open(norm_score_path, encoding='utf-8') as f:
            self.norm_scores = json.load(f)



    def get_scores(self, contexts, responses, normalize=False):
        """Score every (context, response) pair and average the results.

        Pairs in which either string is empty after ``strip()`` are skipped.
        Each remaining pair is scored with ``self.get_score``; the per-key
        mean over those results is then merged with the entries returned by
        ``self.get_distinct(responses)``.

        Args:
            contexts: Iterable of context strings, paired positionally with
                ``responses``.
            responses: Iterable of response strings.
            normalize: Forwarded unchanged to ``self.get_score``.

        Returns:
            A tuple ``(avg_scores, scores)``. ``scores`` is the list of
            per-pair results for the non-blank pairs. ``avg_scores`` maps
            each key of the first per-pair result to its mean across all
            results, plus the distinct-score entries. When no pair is
            scored, ``avg_scores`` holds only the distinct-score entries and
            ``scores`` is empty.
        """
        scores = [
            self.get_score(c, r, normalize=normalize)
            for c, r in tqdm(zip(contexts, responses))
            if c.strip() != "" and r.strip() != ""
        ]

        avg_scores = {}
        # Guard the averaging: with no scored pair there is no first element
        # to take the key set from (previously an IndexError).
        if scores:
            # The first result defines which keys are averaged.
            for k in scores[0]:
                avg_scores[k] = sum(score[k] for score in scores) / len(scores)

        # Distinct scores are computed over all responses, including the
        # ones skipped above, and override any averaged key of the same name.
        avg_scores.update(self.get_distinct(responses))

        return avg_scores, scores

    def get_NUF(self, context, response):
        """Return the NUF model's prediction for one context/response pair.

        Both arguments are forwarded unchanged, positionally, to
        ``self.NUF_model.predict`` and its result is returned as-is.
        """
        nuf_model = self.NUF_model
        prediction = nuf_model.predict(context, response)
        return prediction


