from __future__ import absolute_import, division, print_function

import argparse
from tqdm import tqdm, trange
import time


import torch
from torch.utils.data import Dataset
#from torcheval.metrics.functional import multiclass_f1_score
import random
import numpy as np
import pickle 
import os
import gzip
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from tokenization import BertTokenizer
from model import BertForSequenceClassification
from optimization import BertAdam
from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from train_utils import create_dataset, EarlyStopping, warmup_linear
import logging

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

# Default to GPUs 0,1,2, but respect a CUDA_VISIBLE_DEVICES value the caller
# already exported (an unconditional assignment would silently override it).
# The variable only takes effect if set before CUDA is first initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2")

def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--do_train', action='store_true')
    parser.add_argument('--do_eval', action='store_true')
    # Required parameters
    parser.add_argument("--train_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--eval_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--attacked_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    # BERT model
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                                "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                                "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    # Output directory
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    # Max sequence length
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    # Uncased?
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    # Set batch size
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    # Batch size for evaluation (dev set and attacked test sets)
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    # Learning rate for Adam
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    # Training epochs (truncated to an int by the training loop)
    parser.add_argument("--num_train_epochs",
                        default=20.0,
                        type=float,
                        help="Total number of training epochs to perform.")

    parser.add_argument('--early_stop',
                        type=int,
                        default=5,
                        help="Number of epochs to wait before early stop")
    # Fraction of the optimizer steps used for linear LR warmup
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--sample_seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--print_params',
                        action='store_true')
    parser.add_argument('--num_samples', type=int, default=1000, help='number of samples to train')
    # implicit_ode
    parser.add_argument('--method', type=str, default='gradientbased', help='set the desired ODE solver')

    return parser.parse_args()


def _setup_device(args):
    """Select the torch device and count GPUs.

    Returns:
        ``(device, n_gpu)``. With ``--local_rank`` set (and CUDA allowed), the
        device is that CUDA index and ``n_gpu`` is 1.
    """
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
        print('device: ', device, 'n_gpu: ', n_gpu)
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initialising the distributed backend (which synchronises nodes/GPUs)
        # is deliberately left disabled here:
        # torch.distributed.init_process_group(backend='nccl')

    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    return device, n_gpu


def _set_seeds(seed, n_gpu):
    """Seed Python, NumPy and torch (and all CUDA devices when GPUs exist)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def _print_param_summary(model):
    """Print names/shapes of the embedding, first transformer and output parameters."""
    params = list(model.named_parameters())
    print('The BERT model has {:} different named parameters.\n'.format(len(params)))

    print('==== Embedding Layer ====\n')
    for p in params[0:5]:
        print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

    print('\n==== First Transformer ====\n')
    for p in params[5:21]:
        print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

    print('\n==== Output Layer ====\n')
    for p in params[-4:]:
        print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))


def _classification_metrics(preds, labels, num_labels, epsilon=1e-7):
    """Compute dataset-level precision, recall and F1.

    With two labels the scores are for the positive class (label ``1``). With
    any other label count, per-class scores are macro-averaged over
    ``range(num_labels)``.

    Args:
        preds: iterable of predicted class ids.
        labels: iterable of gold class ids, same length as ``preds``.
        num_labels: number of classes.
        epsilon: small constant that keeps every division finite.

    Returns:
        ``(precision, recall, f1)`` as floats.

    Raises:
        ValueError: if ``preds`` and ``labels`` differ in length.
    """
    preds = list(preds)
    labels = list(labels)
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length "
                         "({} != {})".format(len(preds), len(labels)))

    classes = [1] if num_labels == 2 else list(range(num_labels))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for p, y in zip(preds, labels) if p == c and y == c)
        fp = sum(1 for p, y in zip(preds, labels) if p == c and y != c)
        fn = sum(1 for p, y in zip(preds, labels) if p != c and y == c)
        precision = tp / (tp + fp + epsilon)
        recall = tp / (tp + fn + epsilon)
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(2 * precision * recall / (precision + recall + epsilon))

    n = len(classes)
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n


def _evaluate(model, dataloader, device, desc):
    """Run ``model`` over ``dataloader`` in eval mode without tracking gradients.

    Each batch is unpacked as ``(input_ids, input_mask, segment_ids,
    label_ids)``. The model is called as ``model(input_ids, segment_ids,
    input_mask, label_ids)`` and must return ``(loss, logits)``.

    Returns:
        ``(accuracy, mean_loss, preds, labels)``. ``preds``/``labels`` are flat
        lists of ints. ``accuracy`` is 0.0 and ``mean_loss`` NaN for an empty
        loader.
    """
    model.eval()
    total_acc, total_count = 0, 0
    losses, all_preds, all_labels = [], [], []
    bar = tqdm(dataloader, desc=desc)
    # no_grad: evaluation never calls backward(), so building the autograd
    # graph only wastes memory.
    with torch.no_grad():
        for batch in bar:
            input_ids, input_mask, segment_ids, label_ids = (t.to(device) for t in batch)
            loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
            pred = logits.argmax(1)
            total_acc += (pred == label_ids).sum().item()
            total_count += label_ids.size(0)
            bar.set_postfix(acc=total_acc / total_count)
            losses.append(loss.item())
            all_preds.extend(pred.tolist())
            all_labels.extend(label_ids.tolist())

    accuracy = total_acc / total_count if total_count else 0.0
    mean_loss = float(np.mean(losses)) if losses else float('nan')
    return accuracy, mean_loss, all_preds, all_labels


def _train(args, device, num_labels):
    """Fine-tune a fresh classifier and checkpoint it to ``checkpoint_best.bin``."""
    train_dataset, dev_dataset = create_dataset(args.train_dir, args, type='train')

    model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          cache_dir=PYTORCH_PRETRAINED_BERT_CACHE,
                                                          num_labels=num_labels)
    # NOTE(review): parameters are halved but no fp16 master-weight optimizer
    # or loss scaling is used (--loss_scale is not read here); confirm intent.
    if args.fp16:
        model.half()
    model.to(device)

    if args.print_params:
        _print_param_summary(model)

    if args.num_samples is not None:
        train_sampler = RandomSampler(train_dataset, num_samples=args.num_samples)
        dev_sampler = RandomSampler(dev_dataset, num_samples=args.num_samples)
    else:
        train_sampler = RandomSampler(train_dataset)
        dev_sampler = SequentialSampler(dev_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    dev_dataloader = DataLoader(dev_dataset, sampler=dev_sampler, batch_size=args.eval_batch_size)

    # Number of optimizer steps actually taken.
    # - Derived from the dataloader, so --num_samples subsampling is reflected
    #   in the LR schedule.
    # - Includes one extra step per epoch for a trailing partial accumulation
    #   group.
    # - Clamped to >= 1 so global_step / t_total is always defined.
    num_batches = len(train_dataloader)
    steps_per_epoch = -(-num_batches // args.gradient_accumulation_steps)  # ceil division
    num_train_steps = max(1, steps_per_epoch * int(args.num_train_epochs))
    t_total = num_train_steps

    param_optimizer = list(model.named_parameters())
    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=t_total)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    os.makedirs(args.output_dir, exist_ok=True)
    avg_train_losses, avg_valid_losses, valid_accus = [], [], []
    best_valid_loss = None
    global_step = 0
    early_stopper = EarlyStopping(args, path='checkpoint_best.bin', verbose=True)

    for epoch in range(int(args.num_train_epochs)):
        print('epoch: ', epoch + 1)
        model.train()  # _evaluate() leaves the model in eval mode
        train_losses = []
        bar_train = tqdm(train_dataloader, desc='Train')

        for step, batch in enumerate(bar_train):
            input_ids, input_mask, segment_ids, label_ids = (t.to(device) for t in batch)
            loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
            # Record the unscaled loss so reported values do not depend on
            # gradient_accumulation_steps.
            train_losses.append(loss.item())
            bar_train.set_postfix(loss=loss.item())
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            # Update only once per accumulation group (or at the last batch).
            # Stepping every batch would defeat accumulation.
            if (step + 1) % args.gradient_accumulation_steps == 0 or step + 1 == num_batches:
                # NOTE(review): if BertAdam already applies its own warmup from
                # `warmup`/`t_total`, this manual override warms up twice --
                # confirm against optimization.BertAdam.
                lr_this_step = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

        avg_train_losses.append(np.mean(train_losses))
        print('train loss: {:8.4f}'.format(avg_train_losses[-1]))

        accu_val, valid_loss, _, _ = _evaluate(model, dev_dataloader, device, 'Eval')
        avg_valid_losses.append(valid_loss)
        valid_accus.append(accu_val)
        print('valid loss: {:8.4f}, valid acc: {:8.4f}'.format(avg_valid_losses[-1], accu_val))

        if args.early_stop > 0:
            early_stopper(args, avg_valid_losses[-1], model)
            if early_stopper.early_stop:
                print("Early stopping")
                break
        elif best_valid_loss is None or valid_loss < best_valid_loss:
            # Without early stopping, still keep only the best-so-far weights
            # so that "checkpoint_best.bin" matches its name.
            best_valid_loss = valid_loss
            model_to_save = model.module if hasattr(model, 'module') else model
            output_model_file = os.path.join(args.output_dir, 'checkpoint_best.bin')
            torch.save(model_to_save.state_dict(), output_model_file)


def _evaluate_attacks(args, device, num_labels):
    """Load the saved checkpoint and report metrics for every file in ``attacked_dir``."""
    # map_location='cpu' lets a checkpoint written on a GPU host load anywhere;
    # the model is moved to `device` right after construction.
    model_state_dict = torch.load(os.path.join(args.output_dir, 'checkpoint_best.bin'),
                                  map_location='cpu')
    model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          state_dict=model_state_dict,
                                                          num_labels=num_labels)
    model.to(device)
    model.eval()

    logger.info("***** Running Testing for attack *****")
    # Sorted so the report order is deterministic across filesystems.
    for att in sorted(os.listdir(args.attacked_dir)):
        attacked = create_dataset(os.path.join(args.attacked_dir, att), args, type='test')
        at_dataloader = DataLoader(attacked, sampler=SequentialSampler(attacked),
                                   batch_size=args.eval_batch_size)
        accu_val, loss_val, preds, labels = _evaluate(model, at_dataloader, device, 'attacked')
        del attacked
        # Metrics are computed once over the whole set; averaging per-batch
        # scores and dividing by the example count under-reports them.
        precision_val, recall_val, f1_val = _classification_metrics(preds, labels, num_labels)
        print('attacked method:{:s} attacked acc: {:8.4f}, attacked f1: {:8.4f}, attacked precision: {:8.4f}, attacked recall: {:8.4f}, attacked loss: {:8.4f}'.format(att, accu_val, f1_val, precision_val, recall_val, loss_val))


def main():
    """Fine-tune (``--do_train``) and/or attack-evaluate (``--do_eval``) a BERT classifier.

    Training writes ``checkpoint_best.bin``. With ``--early_stop 0`` this
    function writes it to ``output_dir``; otherwise the path is handled by
    EarlyStopping. Evaluation loads that file from ``output_dir`` and prints
    accuracy, F1, precision, recall and loss for each file in
    ``attacked_dir``.

    Raises:
        ValueError: on invalid accumulation/batch settings or an unknown task.
    """
    args = _parse_args()
    device, n_gpu = _setup_device(args)

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    # Per-forward-pass batch size. The requested total batch size is reached by
    # accumulating gradients over gradient_accumulation_steps batches.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
    if args.do_train and args.train_batch_size < 1:
        raise ValueError("train_batch_size must be >= gradient_accumulation_steps")

    _set_seeds(args.seed, n_gpu)

    num_labels_task = {
        "rte": 2,
        "sst2": 2,
        "qnli": 2,
        "qqp": 2,
        "mnli": 3,
        "mnli_mm": 3,
    }
    if args.task_name not in num_labels_task:
        raise ValueError("Unknown task_name {!r}; expected one of: {}".format(
            args.task_name, ", ".join(sorted(num_labels_task))))
    num_labels = num_labels_task[args.task_name]

    if args.do_train:
        _train(args, device, num_labels)
        # The training model/optimizer went out of scope with _train(); hand
        # their cached CUDA blocks back before evaluation.
        torch.cuda.empty_cache()

    if args.do_eval:
        _evaluate_attacks(args, device, num_labels)


if __name__ == "__main__":
    main()