from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, default_data_collator
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import csv
import pandas as pd
import os
os.chdir('/user/workspace')
from torch.utils.data import DataLoader
from utils import *
from natsort import natsorted

def count_num(input):
    """Count how many elements of *input* fall into each residue class mod 6.

    Elements are read by position (``input[0]`` .. ``input[len(input) - 1]``).
    Each value is reduced modulo 6; values whose residue is not one of
    0..5 (e.g. non-integral floats) are ignored.

    Returns:
        A 6-tuple ``(n0, n1, n2, n3, n4, n5)`` where ``nk`` is the number of
        elements whose residue equals ``k``.
    """
    # Positional indexing is kept on purpose so the function behaves the same
    # for any container that supports len() and integer subscripting.
    residues = Counter(input[position] % 6 for position in range(len(input)))
    return tuple(residues[slot] for slot in range(6))


def plot_h(h1,h2,h3,h4,h5,h6,samples_per_hypothesis=1000):
    """Print the label distribution of the predictions for six hypotheses.

    Despite the name nothing is plotted; the report is written to stdout.

    Args:
        h1..h6: One sequence of predicted class names per hypothesis
            template. Only the names ``'entailment'``, ``'neutral'`` and
            ``'contradiction'`` are counted; anything else is ignored.
        samples_per_hypothesis: Denominator used to turn the counts into
            rates. Defaults to 1000, the value that used to be hard-coded,
            so existing callers get exactly the same output.

    Raises:
        ValueError: If ``samples_per_hypothesis`` is not positive.

    Printed output:
        * ``Total``: neutral predictions summed over all six hypotheses,
          divided by ``6 * samples_per_hypothesis``.
        * Per hypothesis: the number of non-neutral predictions
          (entailment + contradiction, reported as "wrong"), the raw
          ``[E, N, C]`` counts and the same counts as rates.
    """
    if samples_per_hypothesis <= 0:
        raise ValueError('samples_per_hypothesis must be positive')

    titles = (
        'h1:N2 is A1',
        'h2:N1 is A2',
        'h3:N2 is A2',
        'h4:N2 is not A1',
        'h5:N1 is not A2',
        'h6:N2 is not A2',
    )
    label_order = ('entailment', 'neutral', 'contradiction')

    # One [E, N, C] count row per hypothesis; Counter yields 0 for labels
    # that never occur.
    rows = []
    for predictions in (h1, h2, h3, h4, h5, h6):
        tally = Counter(predictions)
        rows.append([tally[label] for label in label_order])

    neutral_total = sum(row[1] for row in rows)
    print('Total', neutral_total/(len(rows)*samples_per_hypothesis))

    for title, row in zip(titles, rows):
        # "wrong" counts every prediction that is not neutral.
        print(f'{title},wrong:', row[0] + row[2])
        print('[E N C]=:', row)
        print('[E N C]=:', [count/samples_per_hypothesis for count in row], '\n')



if __name__ == "__main__":
    # Script entry point: for every checkpoint directory found under
    # `model_dir`, classify the "N is A" coreference probe set, write the raw
    # predictions to a CSV file and print summary statistics.
    bs_size = 128
    # Rows with index < first_group_size are reported under 'P:N1 is A1.',
    # all remaining rows under 'P:N3 is A3.'.
    first_group_size = 6000
    # Rows cycle through six hypothesis templates (row index modulo 6).
    num_hypotheses = 6

    # NOTE(review): `model_dir` is not defined in this file; presumably it is
    # provided by `from utils import *` — confirm.
    model_list = os.listdir(model_dir)
    task = 'mnli'
    num_labels = 3
    model_list = natsorted(model_list)
    # The device does not depend on the model, so pick it once.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    for model_type in model_list:
        model_path = f'{model_dir}/{model_type}'
        # Split the directory name around the first 'nli': the part ending two
        # characters before it is the model name, the part starting one
        # character before it is the dataset name
        # (e.g. 'bert-base-mnli' -> 'bert-base', 'mnli').
        # NOTE(review): if 'nli' is absent, find() returns -1 and both slices
        # silently produce odd names — confirm every entry matches the pattern.
        name_idx = model_type.find('nli')
        model_name = model_type[:name_idx-2]
        dataset_name = model_type[name_idx-1:]
        save_path = f'{dataset_name}-model/core/{model_name}_N_is_A_core.csv'
        createDir(os.path.split(save_path)[0])

        config = AutoConfig.from_pretrained(model_path, num_labels=num_labels, finetuning_task=task)
        tokenizer = AutoTokenizer.from_pretrained(model_path,use_fast=True)
        model_mnli = AutoModelForSequenceClassification.from_pretrained(model_path, config=config)

        dev_path = 'result/data&code/DATA/NisA_coreference.csv'
        dev_premise,dev_hypo, dev_label = read_csv_data(dev_path,config.label2id)
        print('='*100)
        print(f'test model from {model_path}')
        print(f'test data from {dev_path}')
        print(f'save to {save_path}')

        dev_encoded = tokenizer(dev_premise, dev_hypo, truncation=True, padding='max_length', max_length=128)
        dev_dataset = BindDataset(dev_encoded,dev_label)
        # shuffle=False keeps batch order aligned with dev_label, which the
        # row-index arithmetic below relies on.
        data_loader = DataLoader(dev_dataset,batch_size=bs_size,shuffle=False)
        # Class names ordered by label id, so classes[id] is the name of id.
        classes = [
            label
            for label, _ in sorted(config.label2id.items(), key=lambda item: item[1])
        ]

        model_mnli.to(device)
        acc = 0      # number of predictions equal to the gold label
        pred = []    # predicted label id for every row, in dataset order
        erroL = []   # row indices of wrong predictions
        erroC = []   # predicted label id of each wrong prediction

        # hypothesis_preds[group][k] collects the predicted class names of the
        # rows in `group` (0: first premise group, 1: the rest) whose row
        # index is congruent to k modulo num_hypotheses.
        hypothesis_preds = [
            [[] for _ in range(num_hypotheses)]
            for _ in range(2)
        ]

        for i, batch in enumerate(data_loader):
            with torch.no_grad():
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                if 'token_type_ids' in batch:
                    token_type_ids = batch['token_type_ids'].to(device)
                    outputs = model_mnli(input_ids, attention_mask=attention_mask,token_type_ids=token_type_ids)
                else:
                    outputs = model_mnli(input_ids, attention_mask=attention_mask)
                logits = outputs['logits']

            # Softmax is monotonic, so the arg-max over the raw logits selects
            # the same class as the arg-max over the probabilities.
            pred_batch = logits.detach().cpu().argmax(dim=1).tolist()

            for idx, pred_index in enumerate(pred_batch):
                row = i*bs_size + idx
                pred_class = classes[pred_index]
                if pred_index == dev_label[row]:
                    acc += 1
                else:
                    erroL.append(row)
                    erroC.append(pred_index)

                pred.append(pred_index)

                group = 0 if row < first_group_size else 1
                hypothesis_preds[group][row % num_hypotheses].append(pred_class)

        dataframe = pd.DataFrame({'pred:':pred})
        dataframe.to_csv(save_path,sep=',')

        # First letter of each class name, upper-cased, indexed by label id.
        testL = [c[0].upper() for c in classes]
        print('total:', len(dev_premise))
        print('correct:', acc)
        print('wrong:',len(erroL),'\n')

        # Tally the wrong predictions by the initial of the predicted class.
        wrong_by_initial = Counter(testL[label_id] for label_id in erroC)
        tri_C = wrong_by_initial['C']
        tri_E = wrong_by_initial['E']

        print('Choose C:',tri_C)
        print(tri_C/len(dev_premise),'\n')
        print('Choose E:',tri_E)
        print(tri_E/len(dev_premise),'\n')
        # NOTE(review): the 'Choose N' figure is the number of correct
        # predictions; that equals the number of neutral predictions only if
        # every gold label is neutral — confirm against the data set.
        print('Choose N:',acc)
        print(acc/len(dev_premise),'\n')

        print('P:N1 is A1.')
        plot_h(*hypothesis_preds[0])

        print('P:N3 is A3.')
        plot_h(*hypothesis_preds[1])
