import numpy as np
import os
import random
import argparse
import csv

import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences

import transformers
from transformers import AlbertTokenizer

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import AlbertTokenizer
from transformers import AlbertConfig, AlbertModel, AlbertForSequenceClassification
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

from sklearn.model_selection import train_test_split
import sklearn

import torch

def _str2bool(value):
    """Convert a command-line token into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including "False") becomes True. This converter maps the
    usual spellings explicitly.

    Args:
        value: The raw token from the command line, or an actual bool.

    Returns:
        The parsed boolean. An empty string is treated as False, which is what
        ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))


parser = argparse.ArgumentParser()

# The five path options are read unconditionally further down the script, so a
# missing one is reported here instead of failing later with a TypeError.
parser.add_argument('--input_dir', required=True, type=str,
                    help='Directory containing train.tsv, dev.tsv and test.tsv.')
parser.add_argument('--spm_model_path', required=True, type=str,
                    help='Path passed to AlbertTokenizer.')
parser.add_argument('--init_checkpoint', required=True, type=str,
                    help='Checkpoint passed to AlbertForSequenceClassification.from_pretrained.')
parser.add_argument('--albert_config_file', required=True, type=str,
                    help='JSON file passed to AlbertConfig.from_json_file.')
parser.add_argument('--output_dir', required=True, type=str,
                    help='Directory for the saved model and the result TSV files.')
parser.add_argument('--max_seq_len', default=256, type=int,
                    help='Length every token sequence is padded/truncated to.')
parser.add_argument('--batch_size', default=10, type=int,
                    help='Batch size used by every DataLoader.')
parser.add_argument('--epochs', default=5, type=int,
                    help='Number of training epochs.')
parser.add_argument('--lr', default=2e-5, type=float,
                    help='Learning rate handed to the optimizer.')
parser.add_argument('--do_train', default=True, type=_str2bool,
                    help='Run the training loop (true/false).')
parser.add_argument('--do_eval', default=True, type=_str2bool,
                    help='Write logits for dev.tsv (true/false).')
parser.add_argument('--do_predict', default=True, type=_str2bool,
                    help='Write logits for test.tsv (true/false).')

args = parser.parse_args()

# Pick the compute device once: CUDA when PyTorch reports it as usable,
# otherwise the CPU.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda")
    # Report how many GPUs are visible and the name of device 0.
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

def preprocessing(tokenizer, sentences,labels,datatype):
    """Tokenize sentences and package them with multi-hot labels for PyTorch.

    Args:
        tokenizer: Tokenizer providing ``encode``, ``pad_token`` and
            ``pad_token_id`` (an ``AlbertTokenizer`` in this script).
        sentences: Sequence of raw text strings.
        labels: Sequence of strings, one per sentence, each a comma-separated
            list of ``"<aspect>_<value>"`` items with aspect in 0..9 and value
            in 0..1 (the HoC label scheme built below).
        datatype: One of 'train', 'dev' or 'test'. Only 'train' gets a
            shuffling sampler; 'dev' and 'test' are iterated in file order.

    Returns:
        A tuple ``(dataset, sampler, dataloader)`` where ``dataset`` is a
        ``TensorDataset`` of (input ids, attention mask, multi-hot labels),
        ``sampler`` is a ``RandomSampler`` over it (created for every
        datatype, but only attached to the loader for 'train'), and
        ``dataloader`` batches the dataset with ``args.batch_size``.

    Raises:
        AssertionError: If ``datatype`` is not one of the allowed values.
        KeyError: If a label item is not part of the label map.
    """
    assert datatype in ['train','dev','test'], "Wrong datatype, only train,dev,test allowed"

    MAX_LEN = args.max_seq_len

    # Truncate inside the tokenizer so that over-long inputs keep their
    # trailing special token; cutting afterwards would chop it off.
    input_ids = [
        tokenizer.encode(sent, add_special_tokens=True,
                         max_length=MAX_LEN, truncation=True)
        for sent in sentences
    ]
    if input_ids:
        print('Original: ', sentences[0])
        print('Token IDs:', input_ids[0])

    print('\nPadding/truncating all sentences to %d values...' % MAX_LEN)
    print('\nPadding token: "{:}", ID: {:}'.format(tokenizer.pad_token, tokenizer.pad_token_id))

    # Pad with the tokenizer's own pad id (falling back to 0 if it has none).
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0

    # Remember the real lengths before padding so the mask does not depend on
    # which numeric id happens to be used for padding.
    lengths = np.asarray([len(ids) for ids in input_ids], dtype="int64")

    input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long",
                              value=pad_id, truncating="post", padding="post")
    print('Done.')

    # 1 for real tokens, 0 for padding positions.
    attention_masks = (np.arange(MAX_LEN)[None, :] < lengths[:, None]).astype("int64")

    # Label processing for HoC dataset: "0_0", "0_1", "1_0", ... "9_1" map to
    # consecutive indices 0..19.
    label_names = [str(aspect) + "_" + str(value)
                   for aspect in range(10) for value in (0, 1)]
    label_map = {name: index for index, name in enumerate(label_names)}

    # Convert each comma-separated label string to a multi-hot vector.
    label_ids = []
    for label in labels:
        multi_hot = [0] * len(label_map)
        for name in label.split(","):
            multi_hot[label_map[name]] = 1
        label_ids.append(multi_hot)

    pp_inputs = torch.tensor(input_ids)
    pp_labels = torch.tensor(label_ids)
    pp_masks = torch.tensor(attention_masks)

    pp_data = TensorDataset(pp_inputs, pp_masks, pp_labels)
    pp_sampler = RandomSampler(pp_data)
    # Shuffle only the training split; dev/test keep file order so the written
    # logits line up with the input rows.
    loader_sampler = pp_sampler if datatype == 'train' else None
    pp_dataloader = DataLoader(pp_data, sampler=loader_sampler, batch_size=args.batch_size)

    return pp_data,pp_sampler,pp_dataloader

import time
import datetime
def format_time(elapsed):
    """Render a duration given in seconds as a ``datetime.timedelta`` string.

    The value is rounded to whole seconds first, so the result looks like
    ``h:mm:ss`` (with a ``N day(s), `` prefix once it reaches 24 hours).
    """
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

def _read_split(file_name):
    """Read one TSV split from ``args.input_dir``.

    The first row is skipped as a header. Column 0 of every remaining row is
    taken as the label string and column 1 as the sentence.

    Returns:
        A ``(sentences, labels)`` pair of lists in file order.
    """
    split_path = os.path.join(args.input_dir, file_name)
    with open(split_path, 'r', encoding='utf-8') as handle:
        rows = list(csv.reader(handle, delimiter='\t'))[1:]
    split_sentences = [row[1] for row in rows]
    split_labels = [row[0] for row in rows]
    return split_sentences, split_labels


sentences_train, labels_train = _read_split("train.tsv")
sentences_dev, labels_dev = _read_split("dev.tsv")
sentences_test, labels_test = _read_split("test.tsv")

# Build the tokenizer from the file given by --spm_model_path.
tokenizer = AlbertTokenizer(args.spm_model_path)

# Tokenize every split up front; each call returns (dataset, sampler, loader).
train_data,train_sampler,train_dataloader = preprocessing(tokenizer,sentences_train,labels_train,'train')
dev_data,dev_sampler,dev_dataloader = preprocessing(tokenizer,sentences_dev,labels_dev,'dev')
test_data,test_sampler,test_dataloader = preprocessing(tokenizer,sentences_test,labels_test,'test')

# Make sure the output directory exists before anything is written to it.
tf.io.gfile.makedirs(args.output_dir)

config = AlbertConfig.from_json_file(args.albert_config_file)
# 10 aspects x 2 values, matching the label map built in preprocessing().
config.num_labels = 20
config.problem_type = "multi_label_classification"

model = AlbertForSequenceClassification.from_pretrained(args.init_checkpoint,config=config)
# Move the model to the device selected at start-up. A bare `.cuda()` would
# raise on machines without a GPU even though a CPU fallback is chosen above.
model.to(device)


# Print a short overview of the model's named parameters: the total count and
# the name/shape of a few slices from the front and the back of the list.
params = list(model.named_parameters())
print('The ALBERT model has {:} different named parameters.\n'.format(len(params)))

# NOTE(review): the slice boundaries below are fixed positions in the parameter
# list; confirm they line up with the section titles for this architecture.
for heading, section in (
    ('==== Embedding Layer ====\n', params[0:5]),
    ('\n==== First Transformer ====\n', params[5:21]),
    ('\n==== Output Layer ====\n', params[-4:]),
):
    print(heading)
    for param_name, param_tensor in section:
        print("{:<55} {:>12}".format(param_name, str(tuple(param_tensor.size()))))


# `AdamW` here is the class imported from `transformers` at the top of the
# file, not the one in `torch.optim`.
optimizer = AdamW(
    model.parameters(),
    lr=args.lr,   # taken from --lr
    eps=1e-8,
)

# The scheduler is stepped once per batch in the training loop, so the total
# number of steps is (batches per epoch) * (number of epochs).
epochs = args.epochs
total_steps = epochs * len(train_dataloader)

# Linear learning-rate schedule with no warm-up steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# Seeding is currently switched off, so runs are not reproducible.
# Uncomment to fix the seeds:
# seed_val = 42
# random.seed(seed_val)
# np.random.seed(seed_val)
# torch.manual_seed(seed_val)
# torch.cuda.manual_seed_all(seed_val)

if args.do_train:
    # Mean training loss of every epoch, kept so a learning curve can be plotted.
    loss_values = []

    # The loader length does not change between epochs; look it up once.
    num_batches = len(train_dataloader)

    for epoch_i in range(epochs):
        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
        print('Training...')

        # Start time of this epoch, used for the progress and summary lines.
        t0 = time.time()

        # Running sum of the per-batch losses for this epoch.
        total_loss = 0

        # Switch to training mode (this only changes the mode, e.g. enables
        # dropout; it does not run any training by itself).
        model.train()

        for step, batch in enumerate(train_dataloader):

            # Progress update every 40 batches (skipping the very first one).
            if step % 40 == 0 and step != 0:
                elapsed = format_time(time.time() - t0)
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, num_batches, elapsed))

            # Batch layout is (input ids, attention mask, labels). Labels are
            # cast to float for the multi-label loss.
            b_input_ids = batch[0].to(device).long()
            b_input_mask = batch[1].to(device).long()
            b_labels = batch[2].to(device).float()

            # Clear gradients left over from the previous step.
            model.zero_grad()

            # Passing labels makes the model return the loss with the logits.
            outputs = model(input_ids = b_input_ids,
                            attention_mask = b_input_mask,
                            return_dict=True,
                            labels = b_labels
                           )

            loss = outputs['loss']
            total_loss += loss.item()

            # Backward pass to compute the gradients.
            loss.backward()

            # Clip the gradient norm to 1.0 to guard against exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Apply the update, then advance the learning-rate schedule.
            optimizer.step()
            scheduler.step()

        # Average loss over the epoch. max(..., 1) avoids a ZeroDivisionError
        # when the training loader is empty.
        avg_train_loss = total_loss / max(num_batches, 1)

        loss_values.append(avg_train_loss)
        print("")
        print("  Average training loss: {0:.5f}".format(avg_train_loss))
        print("  Training epoch took: {:}".format(format_time(time.time() - t0)))
    print('Saving model ...')
    model.save_pretrained(args.output_dir)
    print("Done")

def _write_logits(dataloader, file_name):
    """Run the model over ``dataloader`` and dump its logits as TSV.

    The model is put in evaluation mode, every batch is forwarded without
    gradient tracking, and one row of logits per example is written to
    ``<args.output_dir>/<file_name>`` in loader order.

    Args:
        dataloader: Loader yielding (input ids, attention mask, labels)
            batches; the labels are not used here.
        file_name: Name of the TSV file to create inside ``args.output_dir``.
    """
    # Evaluation mode: dropout layers behave differently than during training.
    model.eval()

    logits_all = []
    for batch in dataloader:
        b_input_ids = batch[0].to(device).long()
        b_input_mask = batch[1].to(device).long()

        # No labels are passed: only the logits are needed, not the loss.
        with torch.no_grad():
            outputs = model(input_ids = b_input_ids,
                            attention_mask = b_input_mask,
                            return_dict=True)

        # Move the logits to the CPU and collect them row by row.
        logits_all.extend(outputs['logits'].detach().cpu().numpy())

    with open(os.path.join(args.output_dir, file_name),'w',newline='') as f_prd:
        wr = csv.writer(f_prd,delimiter='\t')
        wr.writerows(logits_all)


if args.do_eval:
    print("")
    print("Running Validation...")
    t0 = time.time()
    _write_logits(dev_dataloader, 'eval_results.tsv')
    print("  Validation took: {:}".format(format_time(time.time() - t0)))

if args.do_predict:
    print("")
    print("Running Prediction...")
    t0 = time.time()
    _write_logits(test_dataloader, 'test_results.tsv')
    print("  Prediction took: {:}".format(format_time(time.time() - t0)))


