import os
import json
import argparse
from tqdm.auto import tqdm
import numpy as np
import torch
from collections import namedtuple
from models import VUPScorer
from data_utils import encode_truncate

# Run on the default CUDA device when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class Scorer:

    def __init__(self, args):
        """Load the BERT-VUP checkpoint and move it to the module-level ``device``.

        Args:
            args: configuration object. Only ``args.weight_dir`` (the
                directory that contains ``BERT-VUP.ckpt``) is read here; the
                whole object is kept on ``self.args``.
        """
        self.args = args
        vup_path = os.path.join(args.weight_dir, 'BERT-VUP.ckpt')

        # This object is only used for scoring, so switch the model to eval
        # mode after loading (as the Lightning checkpoint docs do before
        # inference); this disables any dropout layers so repeated scoring of
        # the same text is deterministic. Assumes VUPScorer is a torch
        # nn.Module / LightningModule, which its use of .to() already implies.
        self.vup_model = (
            VUPScorer.load_from_checkpoint(checkpoint_path=vup_path)
            .to(device)
            .eval()
        )

        print('[!] loading models complete')

    def get_scores(self, contexts, responses, normalize=False):
        """Score each (context, response) pair and average the scores per key.

        Pairs whose context or response is empty or whitespace-only are
        skipped. Inputs are paired with ``zip``, so iteration stops at the
        end of the shorter of ``contexts`` / ``responses``.

        Args:
            contexts: iterable of context strings.
            responses: iterable of response strings, paired positionally
                with ``contexts``.
            normalize: forwarded unchanged to ``self.get_score``.

        Returns:
            A tuple ``(avg_scores, scores)``. ``scores`` is the list of
            mappings returned by ``self.get_score`` for the pairs that were
            not skipped, so its indices need not line up with the inputs.
            ``avg_scores`` maps each key of the first such mapping to its
            mean over all of them. When no pair is scored (empty inputs or
            all pairs blank), returns ``({}, [])``.
        """
        scores = []
        for c, r in tqdm(zip(contexts, responses)):
            # Nothing to score when either side of the pair is blank.
            if not c.strip() or not r.strip():
                continue
            scores.append(self.get_score(c, r, normalize=normalize))

        # Without this guard, scores[0] below raises IndexError when every
        # pair was skipped or the inputs were empty.
        if not scores:
            return {}, []

        # Keys come from the first result; every result is expected to carry
        # the same keys (a missing key raises KeyError).
        avg_scores = {
            k: sum(score[k] for score in scores) / len(scores)
            for k in scores[0]
        }

        return avg_scores, scores



    def get_vup(self, response):
        """Return the VUP model's prediction for a single response.

        Args:
            response: the response passed straight to
                ``self.vup_model.predict`` (presumably a string).

        Returns:
            Whatever ``self.vup_model.predict`` returns, unchanged.
        """
        predictor = self.vup_model
        prediction = predictor.predict(response)
        return prediction

