import json
from collections import namedtuple

from models.ABBAScorer import ABBAScorer

import argparse
import torch
import os
from tqdm import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Models already loaded by `_load_model`, keyed by checkpoint path, so that
# scoring several datasets in one process reads each checkpoint only once.
_MODEL_CACHE = {}


def _load_model(weight_path):
    """Load the ABBA scorer from *weight_path* once, move it to `device`, and set eval mode."""
    model = _MODEL_CACHE.get(weight_path)
    if model is None:
        model = ABBAScorer.load_from_checkpoint(checkpoint_path=weight_path)
        model = model.to(device)
        model.eval()
        _MODEL_CACHE[weight_path] = model
    return model


def predict(args, X_test, metric):
    """Score every item of ``X_test`` with the model selected by ``metric``.

    Args:
        args: Namespace providing ``weight_path`` (checkpoint to load) and,
            optionally, ``normalize`` (forwarded to ``model.predict`` for
            single-string inputs only).
        X_test: Iterable whose items are either a single string or a tuple
            that is unpacked into ``model.predict`` (e.g. ``(context, response)``).
        metric: Sub-metric name. Only ``"lsc"`` is supported: response logic
            self-consistency, scored with the ABBA model.

    Returns:
        A list with one ``model.predict`` result per item, in input order.

    Raises:
        ValueError: If ``metric`` is not supported.
    """
    if metric != "lsc":
        raise ValueError('Please select model from the following. lsc')

    model = _load_model(args.weight_path)

    # The CLI parser in this file does not declare `--normalize`, so only
    # forward it when the caller actually provides it; otherwise the model's
    # own default applies.
    normalize_kwargs = {}
    if hasattr(args, "normalize"):
        normalize_kwargs["normalize"] = args.normalize

    scores = []
    with torch.no_grad():
        for x in X_test:
            if isinstance(x, str):  # a single string
                score = model.predict(x, **normalize_kwargs)
            else:  # otherwise a tuple, e.g. (context, response)
                score = model.predict(*x)
            scores.append(score)
    return scores

def _read_lines(path):
    """Return the stripped lines of the UTF-8 text file at *path*."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='lsc inference script')
    # Sub-metric LSC: the ABBA model scores response logic self-consistency.
    parser.add_argument('--metric', type=str, default='lsc',
                        help='Sub-metric to compute (only "lsc" is supported)')
    parser.add_argument('--weight-path', type=str, default='./checkpoints/lsc.ckpt',
                        help='Path to directory that stores the weight')
    parser.add_argument('--data-dir', type=str,
                        default='../../dataset/dstc10-split-by-dialog-score',
                        help='Directory with one sub-directory per dataset')
    parser.add_argument('--output-dir', type=str, default='abba_s',
                        help='Directory where <dataset>_score.json files are written')
    args = parser.parse_args()

    # The output directory may not exist yet; create it before the first write.
    os.makedirs(args.output_dir, exist_ok=True)

    # Only sub-directories are datasets (skip stray files such as .DS_Store);
    # sort so that processing order is deterministic across platforms.
    datasets = sorted(
        name for name in os.listdir(args.data_dir)
        if os.path.isdir(os.path.join(args.data_dir, name))
    )
    for each_dataset in tqdm(datasets):
        dataset_dir = os.path.join(args.data_dir, each_dataset)
        ctx = _read_lines(os.path.join(dataset_dir, "{}_all_ctx.txt".format(each_dataset)))
        res = _read_lines(os.path.join(dataset_dir, "{}_all_res.txt".format(each_dataset)))
        # zip() would silently drop unmatched lines; a mismatch means the
        # context/response files are out of sync, so fail loudly instead.
        if len(ctx) != len(res):
            raise ValueError(
                "{}: {} context lines but {} response lines".format(
                    each_dataset, len(ctx), len(res)))

        scores = predict(args, list(zip(ctx, res)), args.metric)
        out_path = os.path.join(args.output_dir, "{}_score.json".format(each_dataset))
        with open(out_path, "w", encoding='utf-8') as f:
            json.dump(scores, f)
        # tqdm.write keeps the progress bar intact, unlike print.
        tqdm.write("{}....score finished!".format(each_dataset))

