from models.ABACScorer import ABACScorer
from datasets import ABACDataset
import argparse
import torch
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def predict(args, X_test, metric):
    """Score every item of ``X_test`` with the model selected by ``metric``.

    Only the ``"ABAC"`` metric is supported. The LTR (last-turn relevance)
    sub-metric also scores with the ABAC model.

    Args:
        args: Parsed CLI namespace. ``args.weight_path`` is passed as
            ``checkpoint_path`` to ``ABACScorer.load_from_checkpoint``, and
            ``args.normalize`` is forwarded when scoring single-string items.
        X_test: Iterable of items. A ``str`` item is scored on its own; any
            other item (e.g. a ``[context, response]`` list or tuple) is
            unpacked positionally into ``model.predict``.
        metric: Name of the metric/model to use.

    Returns:
        list: One score per item of ``X_test``, in input order.

    Raises:
        ValueError: If ``metric`` is not ``"ABAC"``.
    """
    if metric != "ABAC":
        # ValueError subclasses Exception, so callers catching Exception keep
        # working; validated before any checkpoint is loaded.
        raise ValueError('Please select model :ABAC metric ')

    model = ABACScorer.load_from_checkpoint(checkpoint_path=args.weight_path)
    model = model.to(device)
    model.eval()

    scores = []
    with torch.no_grad():  # inference only: no gradients are needed
        for x in X_test:
            if isinstance(x, str):  # single string
                score = model.predict(x, normalize=args.normalize)
            else:  # a sequence such as (context, response)
                # NOTE(review): ``normalize`` is not forwarded for pair items;
                # confirm this is intended.
                score = model.predict(*x)
            scores.append(score)
    return scores

def read_dataset(path, terminators=".?:!"):
    """Read a test file into a list of items suitable for ``predict``.

    Each line is processed as follows:

    1. Keep only the text after the last sentence terminator (any character in
       ``terminators``). The line's final character, normally the trailing
       newline, is never treated as a terminator.
    2. If no terminator is found, keep the whole line.
    3. Split the kept text on tab characters and strip whitespace from each
       field. A single field is appended as a ``str``; several fields are
       appended as a ``list`` of ``str``.

    NOTE(review): a line whose last visible character is a terminator (e.g.
    ``Done.`` followed by a newline) therefore yields ``""``. Confirm that
    input lines are not expected to end in punctuation.

    Args:
        path: Path to a text file, read line by line.
        terminators: Characters treated as sentence terminators.

    Returns:
        list: One ``str`` or ``list[str]`` per line of the file.
    """
    arr = []
    with open(path) as f:
        for line in f:
            # Position of the last terminator, excluding the final character.
            # It is -1 when none is found, so ``line[cut + 1:]`` keeps the
            # whole line. (The original reversed()/slice approach raised
            # TypeError because a reversed iterator cannot be sliced.)
            search_end = len(line) - 1
            cut = max(
                (line.rfind(ch, 0, search_end) for ch in terminators),
                default=-1,
            )
            tail = line[cut + 1:]
            sents = [field.strip() for field in tail.split('\t')]
            arr.append(sents[0] if len(sents) == 1 else sents)
    return arr

if __name__ == "__main__":
    # CLI entry point. The LTR (last-turn relevance) sub-metric also scores
    # with the ABAC model.
    parser = argparse.ArgumentParser()
    # predict() only supports "ABAC"; restricting choices gives a usage error
    # instead of a traceback for anything else.
    parser.add_argument('--metric', type=str, required=True, choices=['ABAC'],
                        help='Scoring metric (only ABAC is supported)')
    # NOTE(review): the default is a directory path; confirm that
    # ABACScorer.load_from_checkpoint accepts it, or pass a checkpoint file.
    parser.add_argument('--weight-path', type=str, default='./checkpoints',
                        help='Checkpoint path passed to ABACScorer.load_from_checkpoint')
    parser.add_argument('--normalize', action='store_true',
                        help='option for MLM whether to do normalization or not')
    parser.add_argument('--test-path', type=str, required=True,
                        help='Path to the test-set file (read line by line)')
    args = parser.parse_args()
    test_data = read_dataset(args.test_path)
    scores = predict(args, test_data, args.metric)
    print(scores)
