import json
import os
from tqdm import tqdm
import argparse
import torch

# Torch compute device: the GPU when CUDA is available, otherwise the CPU.
# NOTE(review): not referenced by the length-ratio code in this file; possibly
# kept for parity with sibling inference scripts — confirm before removing.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def predict(args, X_test, metric):
    scores = []
    norm_scores = []
    for context, response in X_test:
        ctx_len = len(context)
        res_len = len(response)
        for idx, char in enumerate(reversed(context)):
            if idx == 0:
                continue
            else:
                if char == '.':
                    ctx_len = idx
                    break
                if char == '?':
                    ctx_len = idx
                    break
                if char == ':':
                    ctx_len = idx
                    break
                if char == '!':
                    ctx_len = idx
                    break
            continue
        ratio = min(float(res_len / ctx_len), 5.0)
        scores.append(ratio)

        x = float(x - min(scores)) / (max(scores) - min(scores))
        norm_scores.append(x)

    return norm_scores


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Length Ratio inference script')
    parser.add_argument('--metric', type=str, default='LR')
    parser.add_argument('--data-dir', type=str, default='../../dataset/dstc10-split-by-dialog-score',
                        help='Directory holding one sub-directory per dataset, each containing '
                             '<name>_all_ctx.txt and <name>_all_res.txt')
    parser.add_argument('--output-dir', type=str, default='LengthRatio_s',
                        help='Directory where <name>_score.json files are written')
    parser.add_argument('--weight-path', type=str, default='./checkpoints/IES-CLASS.ckpt',
                        help='Path to directory that stores the weight')
    args = parser.parse_args()

    # Create the output directory up front; json output would otherwise fail
    # with FileNotFoundError when it does not exist yet.
    os.makedirs(args.output_dir, exist_ok=True)

    # Only sub-directories are datasets; skip stray files such as .DS_Store.
    # Sorting makes the processing order deterministic across runs.
    datasets = sorted(
        name for name in os.listdir(args.data_dir)
        if os.path.isdir(os.path.join(args.data_dir, name))
    )

    for each_dataset in tqdm(datasets):
        dataset_dir = os.path.join(args.data_dir, each_dataset)

        with open(os.path.join(dataset_dir, "{}_all_ctx.txt".format(each_dataset)),
                  encoding="utf-8") as f:
            ctx = [line.strip() for line in f]

        with open(os.path.join(dataset_dir, "{}_all_res.txt".format(each_dataset)),
                  encoding="utf-8") as f:
            res = [line.strip() for line in f]

        # zip() silently truncates to the shorter list; surface the mismatch.
        if len(ctx) != len(res):
            print("Warning: {} has {} contexts but {} responses; extra lines are ignored.".format(
                each_dataset, len(ctx), len(res)))

        scores = predict(args, zip(ctx, res), args.metric)

        with open(os.path.join(args.output_dir, "{}_score.json".format(each_dataset)), "w",
                  encoding='utf-8') as f:
            json.dump(scores, f)
        print("{}....score finished!".format(each_dataset))
