from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, default_data_collator
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import csv
import pandas as pd
import os
os.chdir('/user/workspace')
from torch.utils.data import DataLoader
from utils import *
from natsort import natsorted
def count_num(input):
    """Tally the entries of *input* by their remainder modulo 6.

    Returns a 6-tuple ``(n0, n1, n2, n3, n4, n5)`` where ``nk`` is the number
    of entries ``v`` for which ``v % 6 == k``. Entries whose remainder is not
    one of the integers 0..5 (for example non-integral floats) are ignored.
    """
    # Counter yields 0 for residues that never occur, so every slot is filled.
    residue_counts = Counter(value % 6 for value in input)
    return tuple(residue_counts[residue] for residue in range(6))



def _ratio(count, total):
    """Return ``count / total``, or NaN when *total* is zero (empty dataset)."""
    return count / total if total else float('nan')


def _predict_label_ids(model, data_loader, device):
    """Return the predicted class index for every example in *data_loader*.

    Each batch must be a mapping holding ``input_ids`` and ``attention_mask``
    tensors; ``token_type_ids`` is forwarded to the model only when the batch
    provides it. The result is a flat list of Python ints in loader order.
    """
    predictions = []
    with torch.no_grad():
        for batch in data_loader:
            extra_inputs = {'attention_mask': batch['attention_mask'].to(device)}
            if 'token_type_ids' in batch:
                extra_inputs['token_type_ids'] = batch['token_type_ids'].to(device)
            outputs = model(batch['input_ids'].to(device), **extra_inputs)
            # Softmax is monotonic, so the argmax of the raw logits selects the
            # same class the softmax probabilities would.
            predictions.extend(outputs['logits'].argmax(dim=1).tolist())
    return predictions


if __name__ == "__main__":
    task = 'mnli'
    num_labels = 3
    bs_size = 128
    # These do not depend on the model being evaluated, so compute them once.
    dev_path = 'data&code/DATA/extended_conjun.csv'
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # NOTE(review): `model_dir` is not defined in this file; presumably it is
    # provided by `from utils import *` -- confirm.
    model_list = natsorted(os.listdir(model_dir))

    for model_type in model_list:
        model_path = f'{model_dir}/{model_type}'
        # The directory name is split around the first 'nli' occurrence: the
        # model name ends two characters before it and the dataset name starts
        # one character before it (e.g. 'bert-base-mnli' -> 'bert-base', 'mnli').
        # NOTE(review): a name without 'nli' makes find() return -1 and yields
        # meaningless slices -- confirm every entry in model_dir follows the pattern.
        name_idx = model_type.find('nli')
        model_name = model_type[:name_idx-2]
        dataset_name = model_type[name_idx-1:]
        save_path = f'result/{dataset_name}-model/Extended/{model_name}_extended_conjun.csv'
        createDir(os.path.split(save_path)[0])

        config = AutoConfig.from_pretrained(model_path, num_labels=num_labels, finetuning_task=task)
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        model_mnli = AutoModelForSequenceClassification.from_pretrained(model_path, config=config)

        dev_premise, dev_hypo, dev_label = read_csv_data(dev_path, config.label2id)
        print('='*100)
        print(f'test model from {model_path}')
        print(f'test data from {dev_path}')
        print(f'save to {save_path}')

        dev_encoded = tokenizer(dev_premise, dev_hypo, truncation=True, padding='max_length', max_length=128)
        dev_dataset = BindDataset(dev_encoded, dev_label)
        # shuffle=False keeps predictions aligned with dev_label positions.
        data_loader = DataLoader(dev_dataset, batch_size=bs_size, shuffle=False)

        # Class names ordered by label id, so classes[i] names prediction i.
        classes = [name for name, _ in sorted(config.label2id.items(), key=lambda item: item[1])]

        model_mnli.to(device)
        pred = _predict_label_ids(model_mnli, data_loader, device)

        # Split predictions into correct ones and the class ids chosen in error.
        acc = 0
        wrong_choices = []
        for position, pred_index in enumerate(pred):
            if pred_index == dev_label[position]:
                acc += 1
            else:
                wrong_choices.append(pred_index)

        dataframe = pd.DataFrame({'pred:': pred})
        dataframe.to_csv(save_path, sep=',')

        total = len(dev_premise)
        print('acc:', _ratio(acc, total))
        print('\t')

        print('total:', total)
        print('correct:', acc)
        print('wrong:', len(wrong_choices), '\n')

        # Break the errors down by the upper-cased first letter of the chosen
        # class name (C / E / anything else).
        testL = [c[0].upper() for c in classes]
        wrong_by_letter = Counter(testL[choice] for choice in wrong_choices)
        tri_C = wrong_by_letter['C']
        tri_E = wrong_by_letter['E']

        print('Choose C:', tri_C)
        print(_ratio(tri_C, total), '\n')
        print('Choose E:', tri_E)
        print(_ratio(tri_E, total), '\n')
        # NOTE(review): the N figure is reported as `acc`, which presumably
        # relies on every gold label in this dataset being neutral -- confirm.
        print('Choose N:', acc)
        print(_ratio(acc, total), '\n')

        # Count all predictions per class name. Names are lower-cased so that
        # checkpoints whose config uses upper-case labels are counted too,
        # in line with the case-insensitive first-letter lookup above.
        h1_c = Counter(classes[pred_index].lower() for pred_index in pred)
        r1 = [h1_c['entailment'], h1_c['neutral'], h1_c['contradiction']]
        h1n = r1[0] + r1[2]

        print('Extended Conjunction,wrong:', h1n)
        print('[E N C]=:', r1)
        # Ratios use the actual dataset size rather than a fixed 6000 so they
        # stay consistent with the other ratios printed above.
        print('[E N C]=:', [_ratio(count, total) for count in r1], '\n')