import json
import os
from collections import namedtuple

from tqdm import tqdm

from models.IESScorer import IESScorer

import argparse
import torch

# Device used for inference: CUDA when torch reports it available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def _min_max_normalize(scores):
    """Return *scores* rescaled linearly so the minimum maps to 0.0 and the maximum to 1.0.

    An empty input yields ``[]``. When every score is identical the range is
    zero, so every entry maps to 0.0 instead of raising ZeroDivisionError.
    """
    if not scores:
        return []
    lo = min(scores)
    hi = max(scores)
    span = hi - lo
    if span == 0:
        return [0.0] * len(scores)
    return [(s - lo) / span for s in scores]


def predict(args, X_test, metric):
    """Score every sample in *X_test* with the selected metric, then min-max normalize.

    Args:
        args: Parsed CLI namespace. ``args.weight_path`` is the checkpoint
            passed to ``IESScorer.load_from_checkpoint``. ``args.normalize``,
            if the namespace has it, is forwarded to ``model.predict`` for
            single-string samples.
        X_test: Iterable of samples. A ``str`` sample is scored on its own;
            any other sample is unpacked as positional arguments to
            ``model.predict`` (in this script, a ``(context, response)`` pair).
        metric: Metric name; only ``"IES-CLASS"`` is supported.

    Returns:
        list[float]: One score per sample, rescaled so the lowest is 0.0 and
        the highest is 1.0. All entries are 0.0 when every raw score is equal;
        an empty input gives ``[]``.

    Raises:
        ValueError: If *metric* is not a supported metric name.
    """
    if metric != "IES-CLASS":
        raise ValueError('Please select model from the following. IES-CLASS')

    model = IESScorer.load_from_checkpoint(checkpoint_path=args.weight_path)
    model = model.to(device)
    model.eval()

    # Forward ``normalize`` only when the caller supplied it: the CLI in this
    # file defines no ``--normalize`` flag, so reading it unconditionally
    # raised AttributeError for string samples.
    str_kwargs = {}
    if hasattr(args, 'normalize'):
        str_kwargs['normalize'] = args.normalize

    scores = []
    with torch.no_grad():
        for x in tqdm(X_test):
            if isinstance(x, str):  # a single string sample
                score = model.predict(x, **str_kwargs)
            else:  # otherwise unpack the sample, e.g. a (context, response) tuple
                score = model.predict(*x)
            # float() also unwraps a one-element tensor, keeping the result
            # JSON-serializable (in case predict returns a tensor).
            scores.append(float(score))

    # min/max are computed once inside the helper instead of once per element.
    return _min_max_normalize(scores)

def _read_lines(path):
    """Return the whitespace-stripped lines of the UTF-8 text file at *path*."""
    with open(path, encoding='utf-8') as f:
        return [line.strip() for line in f]


def _load_pairs(ctx_path, res_path):
    """Return (context, response) pairs aligned line-by-line from two text files.

    If the files have different line counts, a warning is printed and the
    surplus lines of the longer file are dropped (``zip`` truncates).
    """
    ctx = _read_lines(ctx_path)
    res = _read_lines(res_path)
    if len(ctx) != len(res):
        print("Warning: {} has {} lines but {} has {}; extra lines are ignored.".format(
            ctx_path, len(ctx), res_path, len(res)))
    return list(zip(ctx, res))


def _score_to_json(args, pairs, out_path):
    """Score *pairs* with ``predict`` and write the score list as JSON to *out_path*."""
    scores = predict(args, pairs, args.metric)
    with open(out_path, "w", encoding='utf-8') as f:
        json.dump(scores, f)


def main():
    """Command-line entry point: score a dataset with the chosen metric.

    ``--dataset dstc10`` (default) scores every sub-directory of the DSTC10
    split and writes ``IES-CLASS-s/<name>_score.json`` per sub-dataset.
    ``--dataset ies`` scores the IES training split and writes
    ``IES-CLASS-on-IESDataset-s.json``.
    """
    parser = argparse.ArgumentParser(description='IES-CLASS inference script')
    parser.add_argument('--metric', type=str, default='IES-CLASS')
    parser.add_argument('--weight-path', type=str, default='./checkpoints/IES-CLASS.ckpt',
                        help='Path to the model checkpoint (.ckpt) file')
    parser.add_argument('--dataset', type=str, choices=['dstc10', 'ies'], default='dstc10',
                        help='dstc10: score every DSTC10 sub-dataset (default); '
                             'ies: score the IES training split')
    args = parser.parse_args()

    if args.dataset == 'ies':
        # IES-CLASS metric scores on the IES dataset.
        ies_root = "../../dataset/dstc10-split-by-quality-for-train/IES"
        pairs = _load_pairs(os.path.join(ies_root, "ctx.txt"),
                            os.path.join(ies_root, "res.txt"))
        _score_to_json(args, pairs, "IES-CLASS-on-IESDataset-s.json")
        print("score finished!")
        return

    # IES-CLASS metric scores on DSTC10: one sub-directory per sub-dataset.
    dstc10_root = "../../dataset/dstc10-split-by-dialog-score"
    out_dir = "IES-CLASS-s"
    # open(..., "w") below fails if the output directory does not exist yet.
    os.makedirs(out_dir, exist_ok=True)

    # Only directories are sub-datasets; stray files (e.g. .DS_Store) are skipped.
    datasets = sorted(
        name for name in os.listdir(dstc10_root)
        if os.path.isdir(os.path.join(dstc10_root, name))
    )
    for each_dataset in tqdm(datasets):
        dataset_dir = os.path.join(dstc10_root, each_dataset)
        pairs = _load_pairs(
            os.path.join(dataset_dir, "{}_all_ctx.txt".format(each_dataset)),
            os.path.join(dataset_dir, "{}_all_res.txt".format(each_dataset)),
        )
        _score_to_json(args, pairs,
                       os.path.join(out_dir, "{}_score.json".format(each_dataset)))
        print("{}....score finished!".format(each_dataset))


if __name__ == "__main__":
    main()

