import numpy as np
import random
from tensorflow.python.ops.gen_random_ops import TruncatedNormal
from transformers import AutoTokenizer, AutoConfig, AdamW
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer,EvalPrediction
import torch
import csv
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score
import copy
from itertools import cycle
from transformers.file_utils import BACKENDS_MAPPING
import os 
os.chdir('/user/workspace')
from utils import *
def flat_accuracy(preds, labels):
    """Score argmax predictions against gold labels.

    Args:
        preds: array of per-class scores; the predicted class is the argmax
            over axis 1, flattened to 1-D.
        labels: array of gold class ids; flattened to 1-D before comparison.

    Returns:
        ``sklearn.metrics.accuracy_score`` of the flattened labels vs. the
        flattened predictions.
    """
    predicted_ids = np.argmax(preds, axis=1).flatten()
    return accuracy_score(labels.flatten(), predicted_ids)


def _forward(model, batch, device):
    """Move one collated batch to ``device`` and run ``model`` on it.

    Args:
        model: the sequence-classification model being trained/evaluated.
        batch: mapping with ``input_ids``, ``attention_mask`` and ``labels``
            tensors, plus ``token_type_ids`` when the tokenizer produced them.
        device: target ``torch.device``.

    Returns:
        The model's raw output; callers read index 0 as the loss and index 1
        as the logits.
    """
    inputs = {
        'attention_mask': batch['attention_mask'].to(device),
        'labels': batch['labels'].to(device),
    }
    # Some tokenizers do not emit token_type_ids; only forward them if present.
    if 'token_type_ids' in batch:
        inputs['token_type_ids'] = batch['token_type_ids'].to(device)
    return model(batch['input_ids'].to(device), **inputs)


def _evaluate(model, loader, device):
    """Return ``(mean_loss, mean_accuracy)`` of ``model`` over ``loader``.

    Both metrics are averaged per batch (not per example), as the script has
    always reported them. Leaves the model in eval mode; the caller must
    switch it back to train mode before continuing training.

    Raises:
        ValueError: if ``loader`` yields no batches (averaging would divide
            by zero).
    """
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError('dev loader is empty; cannot compute validation metrics')

    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    with torch.no_grad():
        for batch in loader:
            outputs = _forward(model, batch, device)
            total_loss += outputs[0].item()
            logits = outputs[1].detach().cpu().numpy()
            label_ids = batch['labels'].cpu().numpy()
            total_acc += flat_accuracy(logits, label_ids)
    return total_loss / n_batches, total_acc / n_batches


if __name__ == "__main__":
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    classes = ['entailment', 'neutral', 'contradiction']

    # Training data and dev data paths.
    dataset_name = "MNLI"
    train_path = f'{dataset_name}/train.tsv'
    dev_path = f'{dataset_name}/dev.tsv'

    # Pretrained model path.
    model_path = 'model/bert-base-cased'
    model_name = os.path.split(model_path)[-1]

    is_regression = False

    # Hyperparameters for training, see Appendix Table 1.
    num_labels = 3
    max_len = 128
    weight_decay = 0.01
    lr_rate = 2e-5
    batch_size = 16
    epochs = 3

    # Set to True to append the augmented examples to the training set.
    aug_training = False

    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    config = AutoConfig.from_pretrained(model_path, num_labels=num_labels, finetuning_task='mnli', do_lower_case=False)
    model = AutoModelForSequenceClassification.from_pretrained(model_path, config=config)

    # NOTE(review): config.label2id comes from the checkpoint config. Confirm
    # that read_csv_data maps the MNLI label strings (see `classes`) onto it
    # consistently.
    train_premise, train_hypo, train_label = read_csv_data(train_path, config.label2id, p_id=8, h_id=9, l_id=-1)
    dev_premise, dev_hypo, dev_label = read_csv_data(dev_path, config.label2id, p_id=8, h_id=9, l_id=-1)

    if aug_training:
        augmented_path = 'data&code/DATA/aug_train.csv'
        aug_premise, aug_hypo, aug_label = read_csv_data(augmented_path, config.label2id)
        train_premise.extend(aug_premise)
        train_hypo.extend(aug_hypo)
        train_label.extend(aug_label)

    train_encoded = tokenizer(train_premise, train_hypo, truncation=True, padding='max_length', max_length=max_len)
    dev_encoded = tokenizer(dev_premise, dev_hypo, truncation=True, padding='max_length', max_length=max_len)

    train_dataset = BindDataset(train_encoded, train_label)
    dev_dataset = BindDataset(dev_encoded, dev_label)

    model.to(device)

    # NOTE(review): the training batch is half of `batch_size`; the reason is
    # not visible here — confirm it is intentional.
    train_loader = DataLoader(train_dataset, batch_size=batch_size // 2, shuffle=True)
    # Evaluation order does not matter, so no shuffling: keeps the per-batch
    # averaged metrics deterministic.
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False)
    optim = AdamW(model.parameters(), lr=lr_rate, weight_decay=weight_decay)

    # Validate once more than `evals_after_logs` loss printouts have happened
    # since the last validation (i.e. roughly every 11 * log_every steps).
    log_every = 50
    evals_after_logs = 10

    # Track the model with the lowest dev loss. Use +inf, not a magic number,
    # so the first validation always yields a candidate.
    min_loss = float('inf')
    best_model = None
    for epoch in range(epochs):
        model.train()
        logs_since_eval = 0
        for step, batch in enumerate(train_loader):
            optim.zero_grad()
            loss = _forward(model, batch, device)[0]

            if step % log_every == 0:
                print(f'steps: {step:.1f} loss: ', loss.item())
                logs_since_eval += 1
            loss.backward()
            optim.step()

            if logs_since_eval > evals_after_logs:
                dev_loss, dev_acc = _evaluate(model, dev_loader, device)
                logs_since_eval = 0

                # Save the best model based on the dev-set loss.
                if dev_loss < min_loss:
                    best_model = copy.deepcopy(model)
                    min_loss = dev_loss

                print(f'Validation loss: {dev_loss:.2f}')
                print(f'Validation Accuracy: {dev_acc:.2f}')
                print('\n')

                # _evaluate switched to eval mode; restore train mode so
                # dropout stays active for the remaining steps of the epoch.
                model.train()

    if best_model is None:
        # No validation ran (e.g. too few steps), so fall back to the final
        # weights rather than crashing on None.save_pretrained.
        print('No validation was run; saving the final model instead of the best one.')
        best_model = model

    # Output path. NOTE(review): this template yields a trailing "-" without
    # augmentation and "--aug" with it. It is kept unchanged because other
    # scripts may load from these exact directories.
    output_dir = f'model/{model_name}-{dataset_name.lower()}-{"-aug" if aug_training else ""}'
    best_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    config.save_pretrained(output_dir)
    # output path








