from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, default_data_collator
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import csv
import pandas as pd
import os
os.chdir('/user/workspace')
from torch.utils.data import DataLoader
from utils import *
   
# Canonical NLI class names, matched case-insensitively against config.label2id.
NLI_LABEL_NAMES = ('entailment', 'neutral', 'contradiction')


def resolve_nli_label_ids(label2id):
    """Map each canonical NLI class name to the class id the model uses for it.

    Args:
        label2id: mapping from label name to class id (e.g. ``config.label2id``).

    Returns:
        A dict with the keys 'entailment', 'neutral' and 'contradiction'.
        When ``label2id`` does not name all three classes (for instance when
        it only holds generic ``LABEL_0``-style names) the historical
        assumption entailment=0, neutral=1, contradiction=2 is returned.
    """
    ids_by_name = {str(name).lower(): idx for name, idx in label2id.items()}
    if all(name in ids_by_name for name in NLI_LABEL_NAMES):
        return {name: ids_by_name[name] for name in NLI_LABEL_NAMES}
    return {'entailment': 0, 'neutral': 1, 'contradiction': 2}


def safe_rate(count, total):
    """Return ``count / total``, or 0.0 when ``total`` is zero."""
    return count / total if total else 0.0


def error_report_rows(confusions, error, label_ids):
    """Build the per-class error breakdown as ``(message, rate)`` rows.

    Args:
        confusions: mapping ``(gold_id, predicted_id) -> count`` holding the
            misclassified examples only.
        error: total number of misclassified examples; every rate is a share
            of this number.
        label_ids: dict as returned by ``resolve_nli_label_ids``.

    Returns:
        Nine ``(message, rate)`` tuples in print order. All rates are 0.0
        when ``error`` is zero.
    """
    ent = label_ids['entailment']
    neu = label_ids['neutral']
    con = label_ids['contradiction']

    def count(gold, predicted):
        return confusions.get((gold, predicted), 0)

    e2n, e2c = count(ent, neu), count(ent, con)
    n2e, n2c = count(neu, ent), count(neu, con)
    c2e, c2n = count(con, ent), count(con, neu)

    # The message strings (including their uneven trailing spaces) are kept
    # exactly as they have always been printed.
    return [
        ('false entailment rate: ', safe_rate(e2c + e2n, error)),
        ('entailment to neutral error rate:', safe_rate(e2n, error)),
        ('entailment to contradiction error rate: ', safe_rate(e2c, error)),
        ('false neutral rate: ', safe_rate(n2e + n2c, error)),
        ('neutral to entailment error rate:', safe_rate(n2e, error)),
        ('neutral to contradiction error rate: ', safe_rate(n2c, error)),
        ('false contradiction rate: ', safe_rate(c2e + c2n, error)),
        ('contradiction to entailment error rate:', safe_rate(c2e, error)),
        ('contradiction to neutral error rate: ', safe_rate(c2n, error)),
    ]


if __name__ == '__main__':
    # Evaluate every checkpoint found under model_dir on the matching NLI test
    # files and print accuracy plus a breakdown of the misclassifications.
    task = 'mnli'
    num_labels = 3
    bs_size = 128
    cuda_device_index = 4

    # NOTE(review): model_dir is not defined in this file; presumably it is
    # provided by `from utils import *` -- confirm.
    model_list = os.listdir(model_dir)

    # Only pin a GPU when CUDA is actually present, so the CPU fallback works.
    if torch.cuda.is_available():
        torch.cuda.set_device(cuda_device_index)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    for model_name in model_list:
        model_path = f'{model_dir}/{model_name}'
        print("%"*100)
        print(f"now turn to {model_path}\n")
        config = AutoConfig.from_pretrained(model_path, num_labels=num_labels, finetuning_task=task)
        print(config.label2id)

        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        model_mnli = AutoModelForSequenceClassification.from_pretrained(model_path, config=config)
        # Move the model once per checkpoint instead of once per test file.
        model_mnli.to(device)
        # from_pretrained already returns the model in evaluation mode; stated
        # explicitly because dropout must stay disabled while scoring.
        model_mnli.eval()

        classes = list(config.label2id.keys())
        label_ids = resolve_nli_label_ids(config.label2id)

        # Decide the dataset flavour once, case-insensitively, so the test
        # directory and the CSV column positions can never disagree.
        is_mnli = 'mnli' in model_name.lower()
        test_dir = "MNLI" if is_mnli else "SNLI"
        # Presumably the premise / hypothesis column indices of each CSV
        # layout (passed through to read_csv_data unchanged).
        p_id, h_id = (8, 9) if is_mnli else (7, 8)

        for test_type in os.listdir(test_dir):
            if 'test' not in test_type:
                continue
            dev_path = f'{test_dir}/{test_type}'
            dev_premise, dev_hypo, dev_label = read_csv_data(
                dev_path, config.label2id, p_id=p_id, h_id=h_id, l_id=-1)

            print("="*100)
            print(f"test dataset: {test_type}, from {dev_path}")
            print(f"test model from: {model_path}")

            total = len(dev_premise)
            if total == 0:
                # Nothing to score; avoids dividing by zero further down.
                print(f"no examples found in {dev_path}, skipping")
                continue

            dev_encoded = tokenizer(dev_premise, dev_hypo, truncation=True, padding='max_length', max_length=128)
            dev_dataset = BindDataset(dev_encoded, dev_label)
            data_loader = DataLoader(dev_dataset, batch_size=bs_size, shuffle=False)

            acc = 0
            error = 0
            pred = []
            confusions = Counter()  # (gold id, predicted id) -> count, errors only

            with torch.no_grad():
                for i, batch in enumerate(data_loader):
                    input_ids = batch['input_ids'].to(device)
                    model_kwargs = {'attention_mask': batch['attention_mask'].to(device)}
                    # Not every tokenizer emits token_type_ids.
                    if 'token_type_ids' in batch:
                        model_kwargs['token_type_ids'] = batch['token_type_ids'].to(device)
                    outputs = model_mnli(input_ids, **model_kwargs)

                    # Softmax is monotonic, so argmax over the raw logits
                    # selects the same class.
                    pred_batch = outputs['logits'].argmax(dim=1).tolist()

                    # shuffle=False keeps loader order aligned with dev_label,
                    # so batch i starts at position i * bs_size.
                    offset = i * bs_size
                    for idx, pred_index in enumerate(pred_batch):
                        gold = dev_label[offset + idx]
                        if pred_index == gold:
                            acc += 1
                        else:
                            error += 1
                            confusions[(int(gold), pred_index)] += 1
                        pred.append(classes[pred_index])

            print('acc:', acc/total)
            for message, rate in error_report_rows(confusions, error, label_ids):
                print(message, rate)

