import os
import json
import argparse
from tqdm.auto import tqdm
import numpy as np
import torch
from collections import namedtuple
from models import ABBAScorer, distinct, composite_one_instance
from data_utils import encode_truncate

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class Scorer:

    def __init__(self, args):
        """Load the ABBA scoring model and the normalization-score table.

        Args:
            args: Configuration object. Only ``args.weight_dir`` is read here;
                that directory must contain ``ABBA.ckpt`` (the model
                checkpoint) and ``mlm_minmax_score.json``.

        Raises:
            FileNotFoundError: If ``mlm_minmax_score.json`` does not exist.
            json.JSONDecodeError: If that file is not valid JSON.
        """
        self.args = args

        abba_path = os.path.join(args.weight_dir, 'ABBA.ckpt')
        # Move the restored model onto the module-level `device` (GPU if available).
        self.nup_model = ABBAScorer.load_from_checkpoint(checkpoint_path=abba_path).to(device)

        # Normalization table; presumably consumed by get_score when
        # normalize=True — TODO confirm against get_score.
        norm_score_path = os.path.join(args.weight_dir, 'mlm_minmax_score.json')
        # Explicit UTF-8: JSON is UTF-8 by spec, but open()'s default encoding
        # is locale-dependent (e.g. cp1252 on Windows). The `with` block closes
        # the file, so no manual close() is needed.
        with open(norm_score_path, encoding='utf-8') as f:
            self.norm_scores = json.load(f)

        print('[!] loading models complete')

    def get_scores(self, contexts, responses, normalize=False):
        """Score each (context, response) pair and average every metric.

        Pairs in which the context or the response is empty or whitespace-only
        are skipped. ``contexts`` and ``responses`` are walked in lockstep with
        ``zip``, so surplus items in the longer input are ignored.

        Args:
            contexts: Iterable of context strings.
            responses: Iterable of response strings aligned with ``contexts``.
            normalize: Forwarded unchanged to ``self.get_score``.

        Returns:
            A tuple ``(avg_scores, scores)``. ``scores`` is the list of
            per-pair results from ``self.get_score``, one per non-skipped
            pair. ``avg_scores`` maps each key of the first result to the
            arithmetic mean of that key over all results. If no pair was
            scored (empty input or all pairs blank), returns ``({}, [])``.
        """
        scores = []
        for context, response in tqdm(zip(contexts, responses)):
            # Skip blank pairs; they contribute nothing to `scores` or the averages.
            if not context.strip() or not response.strip():
                continue
            scores.append(self.get_score(context, response, normalize=normalize))

        if not scores:
            # Nothing was scored: avoid indexing into an empty list below.
            return {}, []

        # Metric keys are taken from the first result; every result is assumed
        # to carry the same keys (a missing key raises KeyError).
        avg_scores = {
            key: sum(score[key] for score in scores) / len(scores)
            for key in scores[0].keys()
        }
        return avg_scores, scores




