import argparse
import torch
from torch import nn
from mask_lstm_model import *
from getvectors import getVectors
import os
import numpy as np
import random
import dill
import torch.optim as optim
import random
import pickle


# Command-line configuration for training / evaluating MASK_LSTM.
parser = argparse.ArgumentParser(description='MASK_LSTM text classificer')
parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate')
parser.add_argument('-beta', type=float, default=1, help='beta')
parser.add_argument('--weight_decay', default=0, type=float, help='adding l2 regularization')
parser.add_argument('--clip', type=float, default=1, help='gradient clipping')
parser.add_argument('-epochs', type=int, default=400, help='number of epochs for training')
parser.add_argument('-batch-size', type=int, default=32, help='batch size for training')
parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the snapshot')
parser.add_argument('-dropout', type=float, default=0.2, help='the probability for dropout')
parser.add_argument('-embed-dim', type=int, default=300, help='original number of embedding dimension')
parser.add_argument('-lstm-hidden-dim', type=int, default=100, help='number of hidden dimension')
parser.add_argument('-lstm-hidden-layer', type=int, default=1, help='number of hidden layers')
parser.add_argument('-mask-hidden-dim', type=int, default=300, help='number of hidden dimension')
parser.add_argument("--max_sent_len", type=int, dest="max_sent_len", default=10000, help='max sentence length')
parser.add_argument("--activation", type=str, dest="activation", default="tanh", help='the choice of \
        non-linearity transfer function')
parser.add_argument('--save', type=str, default='masklstm_20210806.pt', help='path to save the final model')
parser.add_argument('--mode', type=str, default='static', help='available models: static, non-static')
parser.add_argument('--gpu', default=0, type=int, help='0:gpu, -1:cpu')
parser.add_argument('--gpu_id', default='0', type=str, help='gpu id')
parser.add_argument('--seed', type=int, default=1111, help='random seed')
args = parser.parse_args()
# Restrict the visible GPU(s) to --gpu_id; set before any CUDA call made by this script.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

# Run from the script's own directory so the relative data / checkpoint paths
# used later in this file resolve regardless of the caller's working directory.
dir_path = os.path.dirname(os.path.realpath(__file__))
os.chdir(dir_path)


def random_seed(seed=None):
    """Seed Python's ``random``, NumPy's and PyTorch's global RNGs.

    Args:
        seed: value to seed every generator with. Defaults to ``args.seed``
            (the ``--seed`` command-line flag) when omitted.
    """
    if seed is None:
        seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    return


# Resolve the compute device from --gpu (-1 selects the CPU). If a GPU is
# requested but CUDA is not available, fall back to the CPU with a warning
# instead of failing later on the first `.to("cuda")`.
if args.gpu > -1 and torch.cuda.is_available():
    args.device = "cuda"
else:
    if args.gpu > -1:
        print("WARNING: --gpu {} requested but CUDA is unavailable; using CPU".format(args.gpu))
    args.device = "cpu"

# args.save keeps the value of --save (default 'masklstm_20210806.pt'); main()
# replaces it with the best-checkpoint path when it trains.


#########################################################################
#########################################################################
#########################################################################
# Split sizes expected in the pickle loaded below; only the first N entries of
# each split are converted to tensors.
max_length = 10000
train_num = 300
train_unlabeled_num = 2170
dev_num = 99
test_num = 99
threshoud_num = 0.5  # NOTE(review): not referenced in the visible code; kept in case other modules import it


def _arrays_to_tensors(arrays, count):
    """Return ``list(arrays)`` with its first ``count`` items converted by ``torch.from_numpy``."""
    items = list(arrays)
    for idx in range(count):
        items[idx] = torch.from_numpy(items[idx])
    return items


def _labels_to_tensors(labels, count):
    """Return ``list(labels)`` with its first ``count`` items mapped to single-element tensors.

    An item equal to 1 becomes ``torch.tensor([1])``; anything else becomes
    ``torch.tensor([0])``. These are the per-example labels that
    ``batch_from_list`` concatenates.
    """
    items = list(labels)
    for idx in range(count):
        items[idx] = torch.tensor([1]) if items[idx] == 1 else torch.tensor([0])
    return items


# The pickle holds the seven splits in this order. pickle.load can execute
# arbitrary code, so this file must come from a trusted source.
with open("lstm_data_real_sentencec.pckl", 'rb') as data_file:
    (train_text_numpy, train_unlabeled_text_numpy, train_label_numpy,
     dev_text_numpy, dev_label_numpy,
     test_text_numpy, test_label_numpy) = pickle.load(data_file)

train_text_list = _arrays_to_tensors(train_text_numpy, train_num)
train_label_list = _labels_to_tensors(train_label_numpy, train_num)
train_unlabeled_text_list = _arrays_to_tensors(train_unlabeled_text_numpy, train_unlabeled_num)
dev_text_list = _arrays_to_tensors(dev_text_numpy, dev_num)
dev_label_list = _labels_to_tensors(dev_label_numpy, dev_num)
test_text_list = _arrays_to_tensors(test_text_numpy, test_num)
test_label_list = _labels_to_tensors(test_label_numpy, test_num)

# NOTE(review): `data` is not defined in this file; presumably it is provided
# by `from mask_lstm_model import *` — confirm.
wordvocab = data.wordvocab

# Index <-> word maps following wordvocab's iteration order.
dic_index2word = {index: word for index, word in enumerate(wordvocab)}
dic_word2index = {word: index for index, word in enumerate(wordvocab)}

vectors = getVectors(args, wordvocab)

# Short aliases used by main().
train_text = train_text_list
train_unlabeled_text = train_unlabeled_text_list
train_label = train_label_list
dev_text = dev_text_list
dev_label = dev_label_list
test_text = test_text_list
test_label = test_label_list

args.embed_num = len(wordvocab)
args.class_num = 2


class B:
    """Bare attribute holder for one batch, exposing ``.text`` and ``.label``."""

    # Class-level placeholder tensors on the configured device; the batch
    # builders assign per-instance tensors over these.
    text = torch.zeros(1, device=args.device)
    label = torch.zeros(1, device=args.device)


def batch_from_list(textlist, labellist):
    """Concatenate per-example tensors into one batch on ``args.device``.

    Args:
        textlist: sequence of text tensors, concatenated along dim 0.
        labellist: sequence of label tensors (``tensor([0])`` / ``tensor([1])``
            in this file), concatenated along dim 0.

    Texts and labels are paired like ``zip``: only the first
    ``min(len(textlist), len(labellist))`` items of each are used.

    Returns:
        A ``B`` instance whose ``.text`` and ``.label`` live on ``args.device``.

    Raises:
        ValueError: if either sequence is empty.
    """
    n = min(len(textlist), len(labellist))
    if n == 0:
        raise ValueError("batch_from_list needs at least one text and one label")
    batch = B()
    # A single torch.cat; concatenating inside a loop re-copies the growing
    # batch on every step (quadratic in the batch size).
    batch.text = torch.cat(list(textlist[:n]), 0).to(args.device)
    batch.label = torch.cat(list(labellist[:n]), 0).to(args.device)
    return batch

def batch_from_list_unlabeled(textlist, labellist):
    """Build a batch from unlabelled texts; same behaviour as ``batch_from_list``.

    ``labellist`` only supplies placeholder labels so the batch has a
    ``.label`` field. Texts and labels are paired like ``zip``, so the text
    batch is cut to ``len(labellist)`` items when the label list is shorter.
    """
    # Previously a verbatim copy of batch_from_list; delegate to keep them in sync.
    return batch_from_list(textlist, labellist)


# evaluate
def evaluation(model, data_text, data_label):
    """Evaluate ``model`` on a labelled split without updating its weights.

    The model is switched to eval mode and left there; callers that keep
    training must call ``model.train()`` again.

    Args:
        model: network called as ``model(batch, 'eval')``.
        data_text: list of per-example text tensors.
        data_label: list of per-example label tensors, aligned with ``data_text``.

    Returns:
        ``(loss, acc)``: the global ``criterion`` loss averaged over batches,
        and accuracy (correct / total) as a 0-dim tensor.

    Raises:
        ValueError: if ``data_label`` is empty.
    """
    if len(data_label) == 0:
        raise ValueError("evaluation() needs a non-empty split")
    model.eval()
    acc, loss, size = 0, 0, 0
    count = 0
    # Inference only: skip autograd bookkeeping (less memory, same outputs).
    with torch.no_grad():
        for stidx in range(0, len(data_label), args.batch_size):
            count += 1
            batch = batch_from_list(data_text[stidx:stidx + args.batch_size],
                                    data_label[stidx:stidx + args.batch_size])
            pred = model(batch, 'eval')

            loss += criterion(pred, batch.label).item()

            _, predicted = pred.max(dim=1)
            acc += (predicted == batch.label).sum().float()
            size += len(predicted)

    acc /= size
    loss /= count
    return loss, acc

# evaluate aopc
def evaluation_aopc(model, data_text, data_label):
    """Evaluate like ``evaluation`` and also return the raw model outputs.

    The model is switched to eval mode and left there.

    Args:
        model: network called as ``model(batch, 'eval')``.
        data_text: list of per-example text tensors.
        data_label: list of per-example label tensors, aligned with ``data_text``.

    Returns:
        ``(loss, acc, outputs)``: the global ``criterion`` loss averaged over
        batches, accuracy (correct / total) as a 0-dim tensor, and a NumPy
        array of every batch's model output stacked along axis 0 in data order.

    Raises:
        ValueError: if ``data_label`` is empty.
    """
    if len(data_label) == 0:
        raise ValueError("evaluation_aopc() needs a non-empty split")
    model.eval()
    acc, loss, size = 0, 0, 0
    count = 0
    batch_outputs = []
    # Inference only: skip autograd bookkeeping (less memory, same outputs).
    with torch.no_grad():
        for stidx in range(0, len(data_label), args.batch_size):
            count += 1
            batch = batch_from_list(data_text[stidx:stidx + args.batch_size],
                                    data_label[stidx:stidx + args.batch_size])
            pred = model(batch, 'eval')
            batch_outputs.append(pred.cpu().detach().numpy())

            loss += criterion(pred, batch.label).item()

            _, predicted = pred.max(dim=1)
            acc += (predicted == batch.label).sum().float()
            size += len(predicted)

    # Concatenate once at the end instead of re-copying the array every batch.
    total_pred_pro = np.concatenate(batch_outputs, axis=0)
    acc /= size
    loss /= count
    return loss, acc, total_pred_pro


def _safe_div(numerator, denominator):
    """Return numerator / denominator, or 0.0 when no batches contributed (denominator 0)."""
    return numerator / denominator if denominator else 0.0


def _load_full_model(path):
    """Load a whole model pickled with ``torch.save(model, f)``.

    torch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which rejects
    full ``nn.Module`` pickles, so ``weights_only=False`` is requested. This is
    acceptable only because the checkpoint was written by this script. torch
    versions without that keyword raise TypeError, handled by a plain load.
    """
    with open(path, 'rb') as f:
        try:
            return torch.load(f, weights_only=False)
        except TypeError:
            f.seek(0)
            return torch.load(f)


def _supervised_pass(model, optimizer, texts, labels, beta):
    """Run one optimisation pass over the labelled training batches.

    Per-batch loss: ``criterion(pred, label) + beta * model.infor_loss``.
    ``infor_loss`` is read from the model after the forward call (presumably
    set by MASK_LSTM.forward — confirm in mask_lstm_model).

    Returns:
        ``(loss_sum, batch_count, num_correct, num_examples)``
    """
    loss_sum, batch_count, num_correct, num_examples = 0, 0, 0, 0
    for stidx in range(0, len(labels), args.batch_size):
        batch_count += 1
        batch = batch_from_list(texts[stidx:stidx + args.batch_size],
                                labels[stidx:stidx + args.batch_size])
        pred = model(batch, 'train')
        optimizer.zero_grad()
        batch_loss = criterion(pred, batch.label) + beta * model.infor_loss
        loss_sum += batch_loss.item()
        batch_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip)
        optimizer.step()

        _, predicted = pred.max(dim=1)
        num_correct += (predicted == batch.label).sum().float()
        num_examples += len(predicted)
    return loss_sum, batch_count, num_correct, num_examples


def _unlabeled_consistency_pass(model, optimizer, texts, placeholder_labels, ratio):
    """Pull the 'train_unmask' and 'train' outputs together on unlabelled batches.

    Per-batch loss: ``||unmasked - masked||_2 * ratio * 0.1``.

    Returns:
        ``(loss_sum, batch_count)``
    """
    loss_sum, batch_count = 0, 0
    for stidx in range(0, len(texts), args.batch_size):
        batch_count += 1
        batch = batch_from_list_unlabeled(texts[stidx:stidx + args.batch_size],
                                          placeholder_labels)
        unmasked = model(batch, 'train_unmask')
        masked = model(batch, 'train')
        optimizer.zero_grad()
        batch_loss = torch.norm(unmasked - masked) * ratio * (1.0 / 10.0)
        loss_sum += batch_loss.item()
        batch_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip)
        optimizer.step()
    return loss_sum, batch_count


def _unlabeled_reverse_pass(model, optimizer, texts, placeholder_labels, ratio,
                            max_examples=500):
    """Push the 'train_unmask' and 'train_mask_important_words' outputs apart.

    Per-batch loss: ``-||unmasked - masked||_2 * ratio * 0.1``. The loss is
    negated, so the optimiser step increases the distance. Batches are
    processed only while the running total of ``args.batch_size`` steps stays
    below ``max_examples``.

    Returns:
        ``(loss_sum, batch_count)``
    """
    loss_sum, batch_count, budget_used = 0, 0, 0
    for stidx in range(0, len(texts), args.batch_size):
        budget_used += args.batch_size
        if budget_used >= max_examples:
            break
        batch_count += 1
        batch = batch_from_list_unlabeled(texts[stidx:stidx + args.batch_size],
                                          placeholder_labels)
        unmasked = model(batch, 'train_unmask')
        masked = model(batch, 'train_mask_important_words')
        optimizer.zero_grad()
        batch_loss = -torch.norm(unmasked - masked) * ratio * (1.0 / 10.0)
        loss_sum += batch_loss.item()
        batch_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip)
        optimizer.step()
    return loss_sum, batch_count


def main():
    """Train MASK_LSTM, keep the best dev-accuracy checkpoint, then test it.

    Each epoch runs three passes:
      1. a supervised pass over the shuffled labelled data;
      2. an unlabelled consistency pass ('train_unmask' vs 'train' outputs);
      3. an unlabelled reverse pass ('train_unmask' vs
         'train_mask_important_words' outputs, on roughly the first 500 examples).

    Whenever dev accuracy improves, the whole model is written to a new file
    whose path is stored in ``args.save``. After training, that file is
    reloaded and evaluated on the test split.

    Uses module globals: ``args``, ``vectors``, ``criterion``, the train, dev
    and test lists, ``train_num`` and ``train_unlabeled_num``.
    """
    model = MASK_LSTM(args, vectors)
    model.to(torch.device(args.device))
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    model.train()
    best_val_acc = None
    beta = args.beta
    # Labelled/unlabelled size ratio used to weight both unlabelled losses.
    unlabeled_ratio = 1.0 * train_num / train_unlabeled_num

    flag = "train"  # any other value skips training and only tests args.save

    if flag == "train":
        for epoch in range(1, args.epochs + 1):
            model.train()
            print("\n## The {} Epoch, All {} Epochs ! ##".format(epoch, args.epochs))

            # Shuffle labelled (text, label) pairs together; unlabelled order stays fixed.
            textlist1 = train_text.copy()
            labellist1 = train_label.copy()
            listpack = list(zip(textlist1, labellist1))
            random.shuffle(listpack)
            textlist1[:], labellist1[:] = zip(*listpack)
            textlist2 = train_unlabeled_text.copy()
            # Placeholder labels for unlabelled batches. Because texts and
            # labels are paired like zip, this slice also caps each
            # unlabelled batch at len(placeholder_labels) texts.
            placeholder_labels = labellist1[0:args.batch_size]

            trn_lstm_loss, lstm_count, trn_lstm_corrects, trn_lstm_size = _supervised_pass(
                model, optimizer, textlist1, labellist1, beta)
            trn_lstm_loss_unlabeled, lstm_count_unlabeled = _unlabeled_consistency_pass(
                model, optimizer, textlist2, placeholder_labels, unlabeled_ratio)
            trn_lstm_loss_unlabeled_reverse, lstm_count_unlabeled_reverse = _unlabeled_reverse_pass(
                model, optimizer, textlist2, placeholder_labels, unlabeled_ratio)

            dev_lstm_loss, dev_lstm_acc = evaluation(model, dev_text.copy(), dev_label.copy())
            if best_val_acc is None or dev_lstm_acc > best_val_acc:
                # float() keeps the tensor repr (parentheses, device string) out of the filename.
                args.save = ("masklstm_20210828_" + str(epoch) + "iterations_"
                             + "{:.4f}".format(float(dev_lstm_acc)) + "_ac.pt")
                with open(args.save, 'wb') as f:
                    torch.save(model, f)
                best_val_acc = dev_lstm_acc

            train_lstm_acc = _safe_div(trn_lstm_corrects, trn_lstm_size)
            train_lstm_loss = _safe_div(trn_lstm_loss, lstm_count)
            train_lstm_loss_unlabeled = _safe_div(trn_lstm_loss_unlabeled, lstm_count_unlabeled)
            # Zero reverse batches run when args.batch_size >= 500; _safe_div avoids dividing by zero.
            train_lstm_loss_unlabeled_reverse = _safe_div(trn_lstm_loss_unlabeled_reverse,
                                                          lstm_count_unlabeled_reverse)
            print('local_epoch {} | train_lstm_loss {:.6f} | train_lstm_loss_unlabeled {:.6f} | train_lstm_loss_unlabeled_reverse {:.6f} | train_lstm_acc {:.6f} | dev_lstm_loss {:.6f} | '
                  'dev_lstm_acc {:.6f} | best_dev_acc {:.6f}'.format(epoch, train_lstm_loss, train_lstm_loss_unlabeled, train_lstm_loss_unlabeled_reverse, train_lstm_acc,
                                                                     dev_lstm_loss, dev_lstm_acc, best_val_acc))

            # Anneal beta every 10 epochs. Clamp so that float rounding, or a
            # starting beta that is not a multiple of the step, cannot drive it
            # negative (a negative beta would reward infor_loss).
            if epoch % 10 == 0 and beta > 0.01:
                beta = max(beta - 0.099, 0.01)

    # Reload the best checkpoint and evaluate it on the test split.
    print("args.save: " + str(args.save))
    model = _load_full_model(args.save)
    model.to(torch.device(args.device))
    _, test_acc, pred_total = evaluation_aopc(model, test_text.copy(), test_label.copy())
    print(pred_total)  # model outputs stacked per test example
    print('\nfinal_test_acc {:.6f}'.format(test_acc))


if __name__ == "__main__":
    # Bound at module scope because main() and the evaluation helpers look it
    # up as the global name `criterion`.
    criterion = nn.CrossEntropyLoss()
    # Seed every RNG before main() builds the model and shuffles the data.
    random_seed()
    main()
