import argparse
import json
import os

import numpy as np
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm import tqdm

# Loss assigned to utterances that cannot be scored by the language model
# (e.g. empty text); large enough to land at the bottom of the normalised scale.
_UNSCORABLE_LOSS = 1e6


def fluency_score(responses, arg):
    """Rate the fluency of each response with a GPT-2 language-model loss.

    Every non-empty response is tokenised and fed to the causal LM loaded
    from ``arg.pretrained_model_path`` with the tokens as their own labels;
    the loss returned by the model is the raw score. The negated losses are
    then clipped from below at their 5% quantile and rescaled so that the
    cutoff maps to 0.

    Args:
        responses: Iterable of response strings. Empty strings are not run
            through the model and receive a large sentinel loss.
        arg: Object exposing ``pretrained_model_path`` (tokenizer and model
            are both loaded from it), e.g. the parsed argparse namespace.

    Returns:
        A 1-D ``numpy.ndarray`` with one float per response, in input order.
        Lower loss gives a higher score; everything at or below the 5%
        quantile scores 0. An empty input yields an empty array.

    Note:
        The rescaling divides by ``abs(cutoff)``; a cutoff of exactly 0
        (all losses at the 5% quantile equal to zero) is not handled.
    """
    responses = list(responses)
    if not responses:
        # Nothing to score: skip the expensive model load and avoid taking
        # a quantile of an empty sample.
        return np.array([])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = GPT2Tokenizer.from_pretrained(arg.pretrained_model_path)
    model = GPT2LMHeadModel.from_pretrained(arg.pretrained_model_path)
    model.to(device)
    model.eval()  # inference only: disables dropout

    losses = []
    with torch.no_grad():
        for utterance in tqdm(responses):
            if not utterance:
                print('space sentence')
                losses.append(_UNSCORABLE_LOSS)
                continue

            token_ids = encoder.encode(utterance)
            if len(token_ids) == 1 and encoder.bos_token_id is not None:
                # The LM loss predicts each token from the ones before it, so
                # a single token has no target and the loss would be NaN,
                # which would then poison the quantile and every score.
                # Give the lone token a BOS context to be predicted from.
                token_ids = [encoder.bos_token_id] + token_ids
            if len(token_ids) < 2:
                # Still nothing to predict (no tokens, or no BOS available).
                losses.append(_UNSCORABLE_LOSS)
                continue

            batch = torch.tensor([token_ids], device=device)
            output = model(batch, labels=batch)
            losses.append(output[0].item())

    # Higher is better from here on: work with negated losses.
    negated = -np.asarray(losses, dtype=float)
    cutoff = np.quantile(negated, 0.05)
    # Clip the worst 5% up to the cutoff so outliers (and the sentinel) do
    # not stretch the scale, then shift/scale so the cutoff becomes 0.
    clipped = np.maximum(negated, cutoff)
    return (clipped - cutoff) / np.abs(cutoff)


if __name__ == '__main__':
    # Command-line driver: for every dataset directory under --file-path,
    # read "<dataset>/<dataset>_all_res.txt" (one response per line), score
    # the responses and write the scores to "fluency_s/<dataset>_score.json".

    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained-model-path', default='../../ckpt/dialogpt-medium', help='model path of pretrained gpt2 finetuned on dataset')
    parser.add_argument('--file-path', default='../../dataset/dstc10-split-by-dialog-score/', help='directory holding one <dataset>/<dataset>_all_res.txt file per dataset')

    parse = parser.parse_args()

    # Make sure the output directory exists instead of failing on the first write.
    output_dir = "fluency_s"
    os.makedirs(output_dir, exist_ok=True)

    datasets = os.listdir(parse.file_path)
    for each_dataset in tqdm(datasets):
        res_path = os.path.join(parse.file_path, each_dataset, "{}_all_res.txt".format(each_dataset))
        # Build a fresh list per dataset: reusing one list across iterations
        # would mix earlier datasets' responses into this dataset's scores.
        with open(res_path, "r", encoding="utf-8") as f:
            response = [line.strip() for line in f]

        score_list = fluency_score(response, parse)
        print(len(score_list))
        print(type(score_list))

        score_path = os.path.join(output_dir, "{}_score.json".format(each_dataset))
        # The context manager flushes and closes the JSON file.
        with open(score_path, "w", encoding='utf-8') as f:
            json.dump(list(score_list), f)
        print("{}....score finished!".format(each_dataset))
