import torch 
import random 
import numpy as np 
import time 
import os 
import csv 
import argparse 
import sys 
from tqdm import tqdm 
from torch.utils.data import TensorDataset, DataLoader, random_split 
from transformers import BertTokenizer, BertConfig 
from transformers import BertForSequenceClassification, AdamW 
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, RobertaForSequenceClassification
from transformers import DebertaTokenizer, DebertaModel, DebertaConfig, DebertaForSequenceClassification
from transformers import get_linear_schedule_with_warmup 
from adversarial_train import FreeLB, PGD, FGM 

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (the default
    generator and all CUDA devices) with ``seed``, then sets the cuDNN
    ``deterministic`` flag.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

# from sklearn.metrics import f1_score, accuracy_score

def flat_accuracy(preds, labels):
    """Return the fraction of rows whose arg-max class equals the label.

    preds:  array of per-class scores; the predicted class of each row is
            the arg-max along axis 1.
    labels: array of gold class ids; flattened before the comparison.
    """
    predicted = np.argmax(preds, axis=1).flatten()
    gold = labels.flatten()
    n_correct = int(np.count_nonzero(predicted == gold))
    return n_correct / len(predicted)
# RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer


def get_data(type):
    '''Load one split of the GLUE SST-2 data from ./data/SST-2/<type>.tsv.

    The first line of the file is treated as a header and skipped. Every
    other line is stripped and split on tabs: column 0 is the sentence and
    column 1 is the label.

    Returns a (texts, labels) pair of parallel lists. A label is the float
    1.0 when the label column is exactly '1', and 0.0 otherwise.
    '''
    print('Getting Data...')
    tsv_path = './data/SST-2/' + type + '.tsv'
    with open(tsv_path, 'r', encoding='utf8') as fin:
        next(fin, None)  # drop the header row (no-op on an empty file)
        rows = [raw.strip().split('\t') for raw in fin]

    sentences = [row[0] for row in rows]
    targets = [1. if row[1] == '1' else 0. for row in rows]

    print('Done...')
    return sentences, targets

def encode_fn(tokenizer, text_list):
    """Tokenize every text and stack the resulting id tensors.

    Each text is passed to ``tokenizer.encode`` with special tokens added,
    truncation enabled, padding to a fixed ``max_length`` of 40, and PyTorch
    tensors requested. The per-text results are concatenated along dim 0.

    Returns one tensor containing a row per input text (assuming the
    tokenizer returns a (1, 40) tensor for each text -- TODO confirm).
    """
    encode_kwargs = dict(
        truncation=True,
        add_special_tokens=True,  # let the tokenizer add its special tokens
        max_length=40,
        padding='max_length',
        return_tensors='pt',
    )
    encoded = [tokenizer.encode(text, **encode_kwargs) for text in text_list]
    return torch.cat(encoded, dim=0)

def build_inputs(batch):
    '''
    Move one batch onto the module-level `device` and package it as keyword
    arguments for the model.

    batch: a sequence whose first element is the token-id tensor and whose
           second element is the label tensor.
    return:
     A dict with the keys 'input_ids', 'attention_mask', 'token_type_ids'
     and 'labels'. The attention mask is True wherever a token id is greater
     than 0, 'token_type_ids' is always None, and the labels are cast to long.

    NOTE(review): the `> 0` mask presumes the padding id is 0 -- confirm that
    this holds for every tokenizer used with this script.
    '''
    token_ids = batch[0].to(device)
    attention_mask = (batch[0] > 0).to(device)
    targets = batch[1].long().to(device)
    return {
        'input_ids': token_ids,
        'attention_mask': attention_mask,
        'token_type_ids': None,
        'labels': targets,
    }

def s(args):
    """Write a shuffled train/dev re-split of the SST-2 training data.

    Loads the 'train' split through get_data, shuffles texts and labels, and
    writes the first 60614 examples to 'train.tsv' and the remainder to
    'dev.tsv' in the current directory, one "<text>\\t<int label>" per line.

    `args` is accepted but not used.
    """
    train_texts, train_labels = get_data('train')

    # Re-seeding NumPy before each shuffle is meant to apply the same
    # permutation to both lists so that texts and labels stay aligned.
    for column in (train_texts, train_labels):
        np.random.seed(2021)
        np.random.shuffle(column)

    split_at = 60614
    splits = (
        ('train.tsv', train_texts[:split_at], train_labels[:split_at]),
        ('dev.tsv', train_texts[split_at:], train_labels[split_at:]),
    )
    for filename, texts, labels in splits:
        with open(filename, 'w', encoding='utf8') as f:
            for text, label in zip(texts, labels):
                f.write('{}\t{}\n'.format(text, int(float(label))))



def run(args):
    """Fine-tune a sequence classifier on SST-2 and save one checkpoint per epoch.

    Steps:
      1. Build the tokenizer, config and model selected by ``args.base_model``
         ('roberta', 'bert' or 'deberta') and move the model to the GPU.
      2. Load the 'train' split via get_data (discarded when ``args.aug_only``
         is set) and optionally append augmented examples:
           * ``args.fada``: every row of ``args.fada_path`` (text in CSV
             column 7, label in column 5);
           * otherwise ``args.ada``: rows of ``args.ada_path`` whose column 8
             equals 'Successful' (text in column 7, label in column 0).
      3. Randomly split the encoded data 90/10 into train and validation sets.
      4. Train for ``args.epochs`` epochs with AdamW (lr 2e-5) and a linear
         schedule without warmup, optionally using FreeLB (``args.freelb``)
         or FGM (``args.fgm``) adversarial training.
      5. After every epoch, report validation loss/accuracy and save the
         model's state dict to ``args.save_path + '<epoch>.pt'``.

    Raises:
        ValueError: if ``args.base_model`` is not a supported model name, or
            if there are no training examples to train on.
    """
    fada = args.fada
    fada_path = args.fada_path
    ada = args.ada
    ada_path = args.ada_path
    freelb = args.freelb
    fgm = args.fgm

    if args.base_model == 'roberta':
        model_type = 'roberta-base'
        tokenizer = RobertaTokenizer.from_pretrained(model_type)
        config = RobertaConfig.from_pretrained(
            model_type, num_labels=2, output_attentions=False, output_hidden_states=False,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            hidden_dropout_prob=args.hidden_dropout_prob,
        )
        model = RobertaForSequenceClassification.from_pretrained(model_type, config=config)
    elif args.base_model == 'bert':
        model_type = 'bert-base-uncased'
        tokenizer = BertTokenizer.from_pretrained(model_type, do_lower_case=True)
        config = BertConfig.from_pretrained(
            model_type, num_labels=2, output_attentions=False, output_hidden_states=False,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            hidden_dropout_prob=args.hidden_dropout_prob,
        )
        model = BertForSequenceClassification.from_pretrained(model_type, config=config)
    elif args.base_model == 'deberta':
        model_type = 'microsoft/deberta-base'
        tokenizer = DebertaTokenizer.from_pretrained(model_type, do_lower_case=True)
        config = DebertaConfig.from_pretrained(
            model_type, num_labels=2, output_attentions=False, output_hidden_states=False,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            hidden_dropout_prob=args.hidden_dropout_prob,
        )
        model = DebertaForSequenceClassification.from_pretrained(model_type, config=config)
    else:
        # Fail early with a clear message instead of calling .cuda() on None.
        raise ValueError(
            "Unsupported base_model {!r}; expected 'roberta', 'bert' or 'deberta'".format(args.base_model)
        )
    model.cuda()

    train_texts, train_labels = get_data('train')
    if args.aug_only:
        # Train on the augmented examples only.
        train_texts, train_labels = [], []
    if fada:
        with open(fada_path, 'r', encoding='utf8') as fin:
            reader = csv.reader(fin)
            reader = list(reader)[1:]  # skip the header row
            for line in reader:
                train_texts.append(str(line[7]))
                train_labels.append(int(float(line[5])))
    elif ada:
        with open(ada_path, 'r', encoding='utf8') as fin:
            reader = csv.reader(fin)
            reader = list(reader)[1:]  # skip the header row
            for line in reader:
                if line[8] == 'Successful':
                    train_texts.append(str(line[7]))
                    train_labels.append(int(float(line[0])))
    if not train_texts:
        # e.g. --aug_only without --fada/--ada; encoding an empty list would
        # otherwise fail inside torch.cat with an unhelpful message.
        raise ValueError('No training examples were loaded')

    print('Encoding Data...')
    all_train_ids = encode_fn(tokenizer, train_texts)
    labels = torch.tensor(train_labels)
    print('Done...')

    epochs = args.epochs
    batch_size = args.batch_size

    # Split data into train and validation
    dataset = TensorDataset(all_train_ids, labels)
    train_size = int(0.90 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create train and validation dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Create optimizer and learning rate schedule
    optimizer = AdamW(model.parameters(), lr=2e-5)
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    save_path = args.save_path
    # save_path is used as a plain string prefix for the checkpoint files. An
    # empty prefix (the CLI default) means the current directory, and
    # os.makedirs('') would raise, so only create non-empty paths.
    if save_path and not os.path.exists(save_path):
        os.makedirs(save_path)
    best_epoch = 0
    best_val = 0.0

    if freelb:
        adv_trainer = FreeLB(adv_K=args.adv_K, adv_lr=args.adv_lr,
                             adv_init_mag=args.adv_init_mag,
                             adv_norm_type=args.adv_norm_type,
                             base_model=args.base_model,
                             adv_max_norm=args.adv_max_norm)
    elif fgm:
        adv_trainer = FGM(model, emb_name='word_embeddings', epsilon=args.fgm_epsilon)

    print(save_path)
    print('Start Trainging...')
    for epoch in range(epochs):
        model.train()
        total_loss, total_val_loss = 0, 0
        total_eval_accuracy = 0
        # tqdm wraps the dataloader itself so the progress bar knows the total.
        for step, batch in enumerate(tqdm(train_dataloader)):
            model.zero_grad()

            inputs = build_inputs(batch)
            if freelb:
                loss, logits = adv_trainer.attack(model, inputs, gradient_accumulation_steps=args.gradient_accumulation_steps)
            elif fgm:
                # Clean pass: accumulate gradients of the unperturbed loss.
                outputs = model(**inputs)
                loss, logits = outputs[:2]
                loss.backward()

                # Adversarial pass: presumably attack() perturbs the embedding
                # weights and restore() undoes it (see adversarial_train.FGM).
                # The loss must come from the perturbed forward pass so that
                # its gradients are added to the clean ones.
                adv_trainer.attack()
                adv_outputs = model(**inputs)
                loss_adv = adv_outputs[0]
                loss_adv.backward()
                adv_trainer.restore()
            else:
                outputs = model(**inputs)
                loss, logits = outputs[:2]
                loss.backward()

            total_loss += loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            if (step+1) % 100 == 0:
                # Periodic progress report on the current batch only.
                logits = logits.detach().cpu().numpy()
                label_ids = batch[1].to('cpu').numpy()
                training_acc = flat_accuracy(logits, label_ids)
                print('\033[0;31;40m{}\033[0m,step:{},training_loss:{},training_acc:{}'.format(time.asctime(time.localtime(time.time())), step, loss.item(), training_acc))

        model.eval()
        with torch.no_grad():
            for batch in val_dataloader:
                inputs = build_inputs(batch)
                outputs = model(**inputs)
                loss, logits = outputs[:2]
                total_val_loss += loss.item()

                logits = logits.detach().cpu().numpy()
                label_ids = batch[1].to('cpu').numpy()
                total_eval_accuracy += flat_accuracy(logits, label_ids)

        # Losses and accuracy are averaged over batches, not over examples.
        avg_train_loss = total_loss / len(train_dataloader)
        avg_val_loss = total_val_loss / len(val_dataloader)
        avg_val_accuracy = total_eval_accuracy / len(val_dataloader)
        if best_val < avg_val_accuracy:
            best_val = avg_val_accuracy
            best_epoch = epoch

        print(f'Train loss     : {avg_train_loss}')
        print(f'Validation loss: {avg_val_loss}')
        print(f'Val Accuracy: {avg_val_accuracy}')
        print(f'Best Val Accuracy: {best_val}')
        print('Best Epoch:', best_epoch)

        print('Save model...')
        torch.save(model.state_dict(), save_path+str(epoch)+'.pt')
        print('Done...')

def test(args):
    """Evaluate the per-epoch checkpoints on the SST-2 'dev' split.

    Builds the tokenizer, config and model selected by ``args.base_model``
    ('roberta', 'bert' or 'deberta'), encodes the 'dev' split returned by
    get_data, and then, for every epoch index in ``range(args.epochs)``,
    loads ``args.save_path + '<epoch>.pt'`` into the model and prints the
    accuracy averaged over batches of 64 examples.

    Raises:
        ValueError: if ``args.base_model`` is not a supported model name.
    """
    save_path = args.save_path
    print('------------------------------------------------------------')
    print(save_path)
    print('------------------------------------------------------------')

    if args.base_model == 'roberta':
        model_type = 'roberta-base'
        tokenizer = RobertaTokenizer.from_pretrained(model_type)
        config = RobertaConfig.from_pretrained(
            model_type, num_labels=2, output_attentions=False, output_hidden_states=False,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            hidden_dropout_prob=args.hidden_dropout_prob,
        )
        model = RobertaForSequenceClassification.from_pretrained(model_type, config=config)
    elif args.base_model == 'bert':
        model_type = 'bert-base-uncased'
        tokenizer = BertTokenizer.from_pretrained(model_type, do_lower_case=True)
        config = BertConfig.from_pretrained(
            model_type, num_labels=2, output_attentions=False, output_hidden_states=False,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            hidden_dropout_prob=args.hidden_dropout_prob,
        )
        model = BertForSequenceClassification.from_pretrained(model_type, config=config)
    elif args.base_model == 'deberta':
        model_type = 'microsoft/deberta-base'
        tokenizer = DebertaTokenizer.from_pretrained(model_type, do_lower_case=True)
        config = DebertaConfig.from_pretrained(
            model_type, num_labels=2, output_attentions=False, output_hidden_states=False,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            hidden_dropout_prob=args.hidden_dropout_prob,
        )
        model = DebertaForSequenceClassification.from_pretrained(model_type, config=config)
    else:
        # Fail early with a clear message instead of using a None tokenizer/model.
        raise ValueError(
            "Unsupported base_model {!r}; expected 'roberta', 'bert' or 'deberta'".format(args.base_model)
        )

    batch_size = 64
    test_texts, test_labels = get_data('dev')
    all_test_ids = encode_fn(tokenizer, test_texts)
    test_labels = torch.tensor(test_labels)
    pred_data = TensorDataset(all_test_ids, test_labels)
    pred_dataloader = DataLoader(pred_data, batch_size=batch_size, shuffle=False)

    model.cuda()

    # run() writes one checkpoint per training epoch, so evaluate exactly
    # args.epochs of them (10 by default) rather than a fixed count.
    for epoch in range(args.epochs):
        model.load_state_dict(torch.load(save_path+str(epoch)+'.pt'))

        model.eval()
        total_test_accuracy = 0
        with torch.no_grad():
            for batch in tqdm(pred_dataloader):
                # Ids greater than 0 are treated as real (non-padding) tokens.
                outputs = model(batch[0].to(device), token_type_ids=None, attention_mask=(batch[0]>0).to(device))
                logits = outputs[0]

                logits = logits.detach().cpu().numpy()
                label_ids = batch[1].to('cpu').numpy()
                total_test_accuracy += flat_accuracy(logits, label_ids)
        # Accuracy is averaged over batches, not over examples.
        avg_test_accuracy = total_test_accuracy / len(pred_dataloader)

        print('test_acc:{}'.format(avg_test_accuracy))


if __name__ == '__main__':
    # Command-line options, registered in the order they are listed here.
    cli_options = [
        # Data augmentation sources
        ("--fada", dict(action='store_true', help="whether to use fada")),
        ("--fada_path", dict(type=str, default='./data/FADA_sst.csv')),
        ("--ada", dict(action='store_true', help="whether to use ada")),
        ("--ada_path", dict(type=str, default='./data/ADA_sst.csv')),
        # Adversarial training
        ("--freelb", dict(action='store_true', help="whether to use freelb")),
        ("--fgm", dict(action='store_true', help="whether to use fgm")),
        ("--fgm_epsilon", dict(type=float, default=1.0, help="fgm epsilon")),
        ('--adv_lr', dict(type=float, default=1e-2)),
        ('--adv_K', dict(type=int, default=3, help="should be at least 1")),
        ('--adv_init_mag', dict(type=float, default=2e-2)),
        ('--adv_norm_type', dict(type=str, default="l2", choices=["l2", "linf"])),
        ('--adv_max_norm', dict(type=float, default=0, help="set to 0 to be unlimited")),
        # Model and optimisation
        ('--base_model', dict(default='bert')),
        ('--hidden_dropout_prob', dict(type=float, default=0.1)),
        ('--attention_probs_dropout_prob', dict(type=float, default=0)),
        ('--gradient_accumulation_steps', dict(type=int, default=1)),
        ('--best_epoch', dict(type=int, default=0)),
        ('--epochs', dict(type=int, default=10)),
        ('--batch_size', dict(type=int, default=64)),
        # Run mode and output
        ('--do_train', dict(action='store_true')),
        ('--do_test', dict(action='store_true')),
        ('--aug_only', dict(action='store_true')),
        ('--freelb_o', dict(action='store_true')),
        ('--save_path', dict(type=str, default="")),
    ]
    parser = argparse.ArgumentParser(prog=sys.argv[0], conflict_handler='resolve')
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    set_seed(2021)
    # `device` is read as a module-level global by build_inputs() and test().
    device = torch.device('cuda')
    print(args)
    # s(args)  # disabled: rewrites train.tsv / dev.tsv from the training split
    if args.do_train:
        run(args)
    elif args.do_test:
        test(args)

