import ujson
import os
import argparse

import torch
from sklearn.metrics import classification_report

from model import MODE

def set_args():
    """Parse and return the command-line options for this evaluation run.

    Options (all required, all parsed as strings):
        --times: run tag, one of 'iter1', 'iter2', 'iter3'.
        --seed: seed tag used to locate the checkpoint and the output folder.
        --checkpoint_dir: checkpoint sub-directory to evaluate.
    """
    ap = argparse.ArgumentParser()
    required_str = dict(type=str, required=True)
    ap.add_argument('--times', choices=['iter1', 'iter2', 'iter3'], **required_str)
    ap.add_argument('--seed', **required_str)
    ap.add_argument('--checkpoint_dir', **required_str)
    parsed = ap.parse_args()
    return parsed

# CLI arguments are parsed at import time; they select the checkpoint to load
# and (in the __main__ block) the output location.
args = set_args()


# Checkpoint location: ./output-chatglm2-2-<seed>/<checkpoint_dir>
model_path = f'./output-chatglm2-2-{args.seed}/{args.checkpoint_dir}'
# MODE['glm2'] supplies the tokenizer and model classes; presumably
# HuggingFace-style classes exposing from_pretrained -- verify in model.py.
glm2_tokenizer = MODE['glm2']['tokenizer'].from_pretrained(model_path)
glm2_model = MODE['glm2']['model'].from_pretrained(model_path)
# Move the model to the default CUDA device; a CUDA device must be available.
glm2_model.cuda()
print('done')

from uni_user_prompt import fixed_instruction
def template(text: str):
    """Build the classification prompt for a single user utterance.

    The prompt is ``fixed_instruction``, then a line ``用户输入：<text>``
    ("user input"), then an open ``类别：`` ("category") line for the model
    to complete.
    """
    user_line = f'用户输入：{text}'
    return f'{fixed_instruction}\n{user_line}\n类别：'

_label_map = {
        '自杀未遂': 0,
        '自杀准备行为': 1,
        '自杀计划': 2,
        '主动自杀意图': 3,
        '被动自杀意图': 4,
        '用户攻击行为': 5,
        '他人攻击行为': 6,
        '自伤行为': 7,
        '自伤意图': 8,
        '关于自杀的探索': 9,
        '与自杀/自伤/攻击行为无关': 10
    }

def convert_labels_into_list(labels: list, label_map: 'dict | None' = None):
    """Convert a collection of category names into a multi-hot label vector.

    Args:
        labels: iterable of category names, e.g. the gold ``labels`` of a test
            item or the model response split on the full-width comma.
        label_map: mapping from category name to vector index. Defaults to
            the module-level ``_label_map``.

    Returns:
        A ``(label, ok)`` tuple. ``label`` is a list of length
        ``len(label_map)`` holding 1 at the index of every recognised name and
        0 elsewhere. ``ok`` is False when a name is not in ``label_map``, a
        name is unhashable, or ``labels`` is not iterable. Conversion stops at
        the first bad name, so ``label`` then reflects only the names seen
        before it.
    """
    if label_map is None:
        label_map = _label_map
    # Size the vector from the map instead of a hard-coded 11 so the two
    # cannot drift apart.
    label = [0] * len(label_map)
    try:
        for name in labels:
            label[label_map[name]] = 1
    except (KeyError, TypeError):
        # Unknown name (KeyError) or malformed input (TypeError): keep the
        # partial vector but flag it. Catch only these two so interrupts and
        # genuine bugs still propagate.
        return label, False
    return label, True

if __name__ == '__main__':
    # Run the loaded checkpoint over ./data/test.json, score each prediction
    # by exact match of the multi-hot vectors, print accuracy plus a
    # per-category classification report, and dump every annotated item to
    # ./statistics/chatglm2/<checkpoint_dir>/<seed>/<times>.json.
    print('begin')
    correct = 0
    total = 0
    test_preds = []  # predicted multi-hot vectors, one per item
    test_trues = []  # gold multi-hot vectors, one per item
    with open('./data/test.json', 'r', encoding='utf-8') as f:
        data = ujson.load(f)
    for item in data:
        text = item['text']
        labels = item['labels']
        label, golden_flag = convert_labels_into_list(labels=labels)
        item['label'] = label
        # False when the gold labels contain a name missing from _label_map.
        item['golden_flag'] = golden_flag
        instruction = template(text=text)

        # Each item is queried independently with an empty history; the
        # updated history returned by chat() is not needed.
        response, _ = glm2_model.chat(
            glm2_tokenizer,
            instruction,
            history=[])
        item['predicted_desc'] = response

        # The model is expected to answer with category names separated by
        # the full-width comma '，'.
        predict_labels = response.split('，')
        predict_label, predict_flag = convert_labels_into_list(labels=predict_labels)
        item['predict_label'] = predict_label
        item['predict_flag'] = predict_flag
        test_preds.append(predict_label)
        test_trues.append(label)
        # An item counts as correct only if the whole label vector matches.
        if label == predict_label:
            item['flag'] = True
            correct += 1
        else:
            item['flag'] = False
            print('********')
            print(item)
            print(response)
            print('********')
        total += 1
    print('total:', total)
    print('correct:', correct)

    if total:
        print('accuracy:', correct / total)
        report = classification_report(test_trues, test_preds, digits=5)
        print(f'report: \n{report}')
    else:
        # Empty test set: accuracy is undefined and classification_report
        # rejects empty inputs, so report that instead of crashing before the
        # (empty) results file is written.
        print('accuracy: n/a (empty test set)')

    out_dir = f'./statistics/chatglm2/{args.checkpoint_dir}/{args.seed}'
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/{args.times}.json', 'w', encoding='utf-8') as f:
        ujson.dump(data, f, ensure_ascii=False, indent=2)

    print('done')
