import os
import argparse
import logging
import ujson

import torch
from accelerate.utils import set_seed
from transformers import BertConfig, BertTokenizer
from sklearn.metrics import classification_report

from util.dataloader import load_data
from util.encoder import BertForSequenceClassification, RobertaForSequenceClassification

# Supported pretrained-model identifiers, mapped to the classifier wrapper
# (imported from util.encoder) that main() instantiates for them. The keys are
# looked up with the raw --model_name_or_path value, so they must match it
# exactly.
MODEL_CLASS_MAPPING = dict([
    ('bert-base-chinese', BertForSequenceClassification),
    ('hfl/chinese-roberta-wwm-ext-large', RobertaForSequenceClassification),
])



def parse_args(argv=None):
    """Parse the command-line options of this evaluation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing a
            list makes the parser usable from tests and notebooks.

    Returns:
        argparse.Namespace with attributes ``max_length``,
        ``pad_to_max_length``, ``model_name_or_path``,
        ``per_device_test_batch_size``, ``seed`` and ``speaker_role``.
    """
    # The first positional parameter of ArgumentParser is ``prog`` (the
    # program name shown in usage/error messages), not the description, so
    # the text has to be passed by keyword.
    parser = argparse.ArgumentParser(
        description='Evaluate a fine-tuned transformers model on a '
        'multi-label text classification test set.')
    parser.add_argument(
        '--max_length',
        type=int,
        default=512,
        help=
        ('The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,'
         ' sequences shorter will be padded if `--pad_to_max_length` is passed.'
         ))
    # NOTE(review): main() in this file does not read this flag.
    parser.add_argument(
        '--pad_to_max_length',
        action='store_true',
        help=
        'If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.'
    )
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        help=
        'Path to pretrained model or model identifier from huggingface.co/models.',
        required=True)
    # NOTE(review): main() in this file evaluates one example at a time and
    # does not read this option.
    parser.add_argument(
        '--per_device_test_batch_size',
        type=int,
        default=16,
        help='Batch size (per device) for the test dataloader.')
    # main() also uses the seed to locate the checkpoint directory.
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help='A seed for reproducible evaluation.')
    # NOTE(review): parsed but not read anywhere in this script.
    parser.add_argument('--speaker_role', type=str, choices=['client', 'counselor'])

    return parser.parse_args(argv)


logger = logging.getLogger(__name__)

# Number of outputs of the multi-label classification head; must match the
# shape of the fine-tuned checkpoint being loaded.
_NUM_LABELS = 11


def _load_model(model_name_or_path, checkpoint_path, device):
    """Build the classifier for *model_name_or_path*, load weights, move to *device*.

    Raises:
        ValueError: if *model_name_or_path* is not a key of MODEL_CLASS_MAPPING.
    """
    try:
        model_class = MODEL_CLASS_MAPPING[model_name_or_path]
    except KeyError:
        raise ValueError(
            f'Unsupported model {model_name_or_path!r}; expected one of '
            f'{sorted(MODEL_CLASS_MAPPING)}') from None

    config = BertConfig.from_pretrained(
        model_name_or_path,
        num_labels=_NUM_LABELS,
        problem_type="multi_label_classification",
        finetuning_task='text classification')
    model = model_class(config, model_name_or_path)
    # Load onto the CPU first so a checkpoint saved from a GPU can also be
    # restored on a CPU-only host; the model is moved to `device` afterwards.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def _evaluate(model, tokenizer, test_dataset, device, max_length):
    """Score *test_dataset* one example at a time.

    Returns:
        Tuple ``(num_correct, total_loss, test_preds, test_trues, records)``:
        the number of exact-match predictions, the summed per-example loss,
        per-example predicted / gold label rows (for sklearn), and a copy of
        every example extended with ``predict_label`` and ``flag``.
    """
    num_correct = 0
    total_loss = 0.0
    test_preds = []
    test_trues = []
    records = []

    with torch.no_grad():
        for item in test_dataset:
            # truncation=True keeps every input within `max_length` tokens;
            # without it, longer texts exceed the encoder's position limit.
            # NOTE(review): always pads to max_length; --pad_to_max_length
            # is not consulted.
            encoded = tokenizer(text=item['text'],
                                padding='max_length',
                                max_length=max_length,
                                truncation=True,
                                add_special_tokens=True,
                                return_token_type_ids=True,
                                return_tensors='pt').to(device)

            # NOTE(review): labels keep the dtype torch infers from
            # item['label'] (as before); HF multi-label losses expect float
            # targets, so confirm the util.encoder model handles this.
            labels = torch.tensor([item['label']], device=device)
            encoded['labels'] = labels

            outputs = model(**encoded)
            total_loss += outputs.loss.float().item()

            probabilities = torch.sigmoid(outputs.logits)
            predictions = probabilities.ge(0.5).int()
            golden_labels = labels.int()
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug('probabilities=%s predictions=%s gold=%s',
                             probabilities.cpu().tolist(),
                             predictions.cpu().tolist(),
                             golden_labels.cpu().tolist())

            # An example counts as correct only when every label matches
            # (exact-match / subset accuracy).
            correct = bool((predictions == golden_labels).all())
            num_correct += int(correct)

            # Store a copy so the caller's dataset items are not mutated.
            record = dict(item)
            record['predict_label'] = predictions[0].cpu().tolist()
            record['flag'] = correct
            records.append(record)

            test_preds.extend(predictions.cpu().numpy())
            test_trues.extend(golden_labels.cpu().numpy())

    return num_correct, total_loss, test_preds, test_trues, records


def main():
    """Evaluate a fine-tuned checkpoint on the test split and save per-example results.

    The checkpoint is read from ``out/<model>/<seed>/pytorch_model.bin``, where
    ``<model>`` is ``--model_name_or_path`` with ``/`` replaced by ``-``.
    Prints exact-match accuracy, the mean per-example loss and a scikit-learn
    classification report. Writes every test example with its predicted labels
    (``predict_label``) and correctness (``flag``) to
    ``./statistics/<model>/<seed>.json``.
    """
    args = parse_args()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
    )

    if args.seed is not None:
        set_seed(args.seed)

    datasets = load_data(filenames=['test'])
    test_dataset = datasets['test']
    logger.info('test dataset: %s (%d split(s) loaded)', test_dataset, len(datasets))

    total_length = len(test_dataset)
    print('total_length:', total_length)
    if total_length == 0:
        # Nothing to score; avoids division by zero in the metrics below.
        logger.warning('Test split is empty; nothing to evaluate.')
        return

    # Character count is only a rough proxy for token count; such texts may
    # be truncated to max_length tokens by the tokenizer.
    for item in test_dataset:
        if len(item['text']) > args.max_length:
            logger.warning('Text longer than %d characters may be truncated: %s',
                           args.max_length, item)

    tokenizer = BertTokenizer.from_pretrained(
        args.model_name_or_path,
        use_fast=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_dir_name = args.model_name_or_path.replace('/', '-')
    checkpoint_path = f'out/{model_dir_name}/{args.seed}/pytorch_model.bin'
    model = _load_model(args.model_name_or_path, checkpoint_path, device)

    num_correct, total_loss, test_preds, test_trues, records = _evaluate(
        model, tokenizer, test_dataset, device, args.max_length)

    print(f'accuracy: {num_correct / total_length}')
    print(f'test loss (mean per example): {total_loss / total_length}')

    report = classification_report(test_trues, test_preds, digits=5)
    print(f'report: \n{report}')

    target_dir = f'./statistics/{model_dir_name}'
    os.makedirs(target_dir, exist_ok=True)
    with open(f'{target_dir}/{args.seed}.json', 'w', encoding='utf-8') as f:
        ujson.dump(records, f, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    main()
