import argparse
import json
import os
import sys
import time
from importlib import import_module

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from data_helper import DataHelper

# os.environ["CUDA_VISIBLE_DEVICES"] = "2"

def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--freeze False`` would silently enable freezing.
    This accepts the usual spellings, case-insensitively.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
parser.add_argument('--max_epochs', type=int, default=60)
parser.add_argument('--max_steps', type=int, default=10000000)
parser.add_argument('--batch_size', type=int, default=20)
parser.add_argument('--dim_embd', type=int, default=300)
parser.add_argument('--dim_model', type=int, default=300)
parser.add_argument('--dim_inner', type=int, default=1024)
parser.add_argument('--optimizer', type=str, default='adam')
parser.add_argument('--lr', type=float, default=5e-5)
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--seed', type=int, default=666)
parser.add_argument('--num_class', type=int, default=6)
parser.add_argument('--l2_reg', type=float, default=0.0000001)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--norm', type=float, default=1e-5)

# Boolean flags take an explicit value, e.g. `--freeze true` / `--freeze false`.
parser.add_argument('--embedding_pretrained', default=True, type=_str2bool)
parser.add_argument('--num_head', default=4, type=int)
parser.add_argument('--num_encoder', default=2, type=int)
# Non-positive means "use the dataset's default sequence length".
parser.add_argument('--max_length', default=-1, type=int)
parser.add_argument('--freeze', default=False, type=_str2bool)

# None means "use int(sqrt(max_length))".
parser.add_argument('--n_centroid', default=None, type=int)
# Weights of the auxiliary loss terms returned by Clusterformer / Routing.
parser.add_argument('--commitment1', default=0, type=float)
parser.add_argument('--commitment2', default=0, type=float)

parser.add_argument('--modelName', type=str, default='Clusterformer', choices=['Transformer', 'Clusterformer', 'Reformer', 'Routing'])
parser.add_argument('--dataName', type=str, default='20NEWS')

args = parser.parse_args()


def _seed_everything(seed):
    """Seed NumPy's global RNG and PyTorch's CPU and CUDA generators with ``seed``."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Seeding happens at import time, before any model or data is built.
_seed_everything(args.seed)

def get_param_numbers(model):
    """Return the number of scalar parameters in ``model`` that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


# Per-dataset constants, keyed by --dataName:
#   (vocab size, default max sequence length, train_len, text_len, num classes)
# train_len / text_len are passed unchanged to DataHelper.load_dataset;
# presumably they are the train / test example counts -- confirm in data_helper.
DATASET_CONFIGS = {
    "TREC": (9593, 37, 5452, 500, 6),
    "SUBJ": (21323, 120, 9000, 1000, 2),
    "CR": (5340, 105, 3393, 377, 2),
    "MPQA": (6247, 36, 9540, 1063, 2),
    "MR": (18765, 56, 9594, 1068, 2),
    "20NEWS": (100177, 400, 11314, 7532, 20),
}


def _configure_dataset(args):
    """Fill the dataset-dependent fields of ``args`` and return ``(train_len, text_len)``.

    Sets ``args.save``, ``args.data_file``, ``args.num_class`` and
    ``args.num_vocab``, and replaces a non-positive ``args.max_length`` with
    the dataset's default sequence length.

    Raises:
        NotImplementedError: if ``args.dataName`` is not in DATASET_CONFIGS.
    """
    try:
        vocab, sequence_max_length, train_len, text_len, num_class = DATASET_CONFIGS[args.dataName]
    except KeyError:
        raise NotImplementedError('unknown dataset: {}'.format(args.dataName)) from None
    args.save = "checkpoints/{}/{}".format(args.dataName, args.modelName)
    args.data_file = "data/{}/".format(args.dataName)
    args.num_class = num_class
    args.num_vocab = vocab
    if args.max_length <= 0:
        args.max_length = sequence_max_length
    return train_len, text_len


def _align_max_length(args):
    """Default ``args.n_centroid`` and pad ``args.max_length`` for chunked models.

    ``n_centroid`` defaults to int(sqrt(max_length)). Clusterformer and Routing
    round max_length up to a multiple of n_centroid, and Reformer to a multiple
    of 2 * n_centroid. Other models keep max_length unchanged.

    NOTE(review): when max_length is already a multiple of the chunk, this
    still adds one whole chunk (e.g. 20NEWS + Clusterformer: 400 -> 420).
    Kept as-is because the value is handed to the model constructor and may
    fix its input size -- confirm whether the extra chunk is intended.
    """
    if args.n_centroid is None:
        args.n_centroid = int(args.max_length ** 0.5)
    if args.modelName in ('Clusterformer', 'Routing'):
        chunk = args.n_centroid
    elif args.modelName == 'Reformer':
        chunk = args.n_centroid * 2
    else:
        return
    args.max_length = chunk - args.max_length % chunk + args.max_length


def _label_tensor(label, device):
    """Convert a NumPy label batch to a 1-D tensor on ``device``.

    Assumes one class index per example. Uses ``reshape(-1)`` rather than
    ``squeeze()`` because squeezing a batch of one yields a 0-d tensor, which
    CrossEntropyLoss rejects against (1, num_class) logits.
    """
    return torch.from_numpy(label).reshape(-1).to(device)


def _sync_device(device):
    """Wait for queued CUDA work so that wall-clock timings cover actual execution."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def _auxiliary_loss(output, args):
    """Return the weighted auxiliary loss carried in ``output[1]``, or 0 if the model returns none.

    Clusterformer: commitment1 * cc_loss[-1][0] + commitment2 * cc_loss[-1][1].
    Routing: commitment1 * cc_loss[-1].

    Raises:
        NotImplementedError: if the model returns an auxiliary term but is
            neither Clusterformer nor Routing.
    """
    if len(output) <= 1:
        return 0
    cc_loss = output[1]
    if args.modelName == 'Clusterformer':
        return cc_loss[-1][0] * args.commitment1 + cc_loss[-1][1] * args.commitment2
    if args.modelName == 'Routing':
        return cc_loss[-1] * args.commitment1
    raise NotImplementedError


def _evaluate(model, batches, device):
    """Return ``(num_correct, num_items)`` for ``model`` over ``batches``.

    ``batches`` yields ``(data, target)`` NumPy pairs. The model's first output
    is taken as the logits. Puts the model in eval mode and disables gradients.
    """
    model.eval()
    correct = 0
    num_items = 0
    with torch.no_grad():
        for data, target in batches:
            data = torch.from_numpy(data).to(device)
            target = _label_tensor(target, device)
            logits = model(data)[0]
            correct += (torch.argmax(logits, dim=1) == target).sum().item()
            num_items += target.numel()
    return correct, num_items


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args.device = device
    train_len, text_len = _configure_dataset(args)
    os.makedirs(args.save, exist_ok=True)
    _align_max_length(args)

    print(args.modelName, args.dataName, args.max_length, args.batch_size, args.max_steps)

    batch_size = args.batch_size
    max_epochs = args.max_epochs
    database_path = args.data_file
    with open('{}wordmap.json'.format(database_path), 'r') as j:
        vocab_dict = json.load(j)

    data_helper = DataHelper(vocab_dict, sequence_max_length=args.max_length)

    model = import_module('models.' + args.modelName).Model(args).to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2_reg)

    epoch_time_all = []
    train_loss_all = []
    accuracy_all = []
    # print('Trained model has {} parameters'.format(get_param_numbers(model)))
    max_acc = 0
    best_epoch = 0

    forward_times = []
    backward_times = []

    # The load arguments never change between epochs, so read the dataset once
    # instead of once per epoch (epoch_time therefore excludes data loading).
    train_data, train_label, test_data, test_label = data_helper.load_dataset(database_path, train_len, text_len)

    for epoch in range(max_epochs):
        model.train()
        train_losses = []
        train_correct = 0
        train_items = 0
        t_start = time.perf_counter()

        train_batches = data_helper.batch_iter(np.column_stack((train_data, train_label)), batch_size, max_epochs)
        for train_data_b, label in train_batches:
            train_data_b = torch.from_numpy(train_data_b).to(device)
            label = _label_tensor(label, device)
            model.zero_grad()

            # CUDA launches are asynchronous. Without synchronizing, these
            # timestamps would measure only kernel-launch overhead.
            _sync_device(device)
            s_time = time.perf_counter()
            output = model(train_data_b)
            _sync_device(device)
            forward_times.append(time.perf_counter() - s_time)

            s_time = time.perf_counter()
            logits = output[0]
            loss = criterion(logits, label) + _auxiliary_loss(output, args)
            loss.backward()
            optimizer.step()
            _sync_device(device)
            backward_times.append(time.perf_counter() - s_time)

            # Benchmark mode: after max_steps steps, report mean per-step times (ms) and stop.
            if len(backward_times) == args.max_steps:
                aver_forward_time = sum(forward_times) / args.max_steps * 1000
                aver_backward_time = sum(backward_times) / args.max_steps * 1000
                aver_all_time = (sum(forward_times) + sum(backward_times)) / args.max_steps * 1000
                print("forward time %.1f \t backward time %.1f \t all time %.1f \n"
                      % (aver_forward_time, aver_backward_time, aver_all_time))
                sys.exit()

            train_losses.append(loss.item())
            train_correct += (torch.argmax(logits, dim=1) == label).sum().item()
            train_items += label.numel()

        epoch_time_all.append(time.perf_counter() - t_start)
        train_loss_all.append(train_losses)
        # Divide by the examples actually seen rather than the dataset size, so
        # the figure stays correct if batch_iter skips a partial batch.
        train_accuracy = train_correct * 100 / train_items if train_items else 0.0
        # print("Train loss: {:.4f}      Train accuracy: {}%".format(np.mean(train_losses), np.round(train_accuracy, 1)))

        test_loader = data_helper.batch_iter(np.column_stack((test_data, test_label)), batch_size, max_epochs, shuffle=False)
        correct, num_items = _evaluate(model, test_loader, device)
        accuracy = correct * 100 / num_items if num_items else 0.0
        accuracy_all.append(np.round(accuracy, 1))
        # best_epoch == 0 guarantees a checkpoint even if every epoch scores 0%.
        if best_epoch == 0 or accuracy > max_acc:
            max_acc = accuracy
            best_epoch = epoch + 1
            torch.save({'state_dict': model.state_dict()}, os.path.join(args.save, 'model.pth'))

        # Write the log every epoch so it holds the full history, not only the
        # epochs up to the last improvement.
        torch.save({'accuracy': accuracy_all,
                    'train_loss': train_loss_all,
                    'epoch_time': epoch_time_all}, os.path.join(args.save, 'log.pkl'))
        # print("Test Accuracy: {}%      Best Accuracy: {}%      Best epoch :{}".format(np.round(accuracy, 1), np.round(max_acc,1), best_epoch))

    print("Best Accuracy: {}%      Best epoch :{}".format(np.round(max_acc, 1), best_epoch))