# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from torch.optim import AdamW


def init_network(model, method='xavier', exclude='embedding', seed=123):
    """Re-initialise the parameters of ``model`` in place.

    Args:
        model: An ``nn.Module`` whose ``named_parameters()`` are visited.
        method: ``'xavier'`` -> ``nn.init.xavier_normal_``, ``'kaiming'`` ->
            ``nn.init.kaiming_normal_``; any other value falls back to
            ``nn.init.normal_``.
        exclude: Parameters whose name contains this substring are left
            untouched.
        seed: Accepted for interface compatibility; this function does not
            use it and does not touch any random-number generator state.

    Parameters whose name contains ``'weight'`` are re-initialised only when
    they have at least two dimensions. Parameters whose name contains
    ``'bias'`` (and not ``'weight'``) are set to zero. Everything else is
    left as is.
    """
    for name, w in model.named_parameters():
        if exclude in name:
            continue
        if 'weight' in name:
            # Xavier/Kaiming need fan-in and fan-out, which are undefined for
            # tensors with fewer than two dimensions (e.g. normalisation
            # scales), so those weights keep their current values.
            if w.dim() < 2:
                continue
            if method == 'xavier':
                nn.init.xavier_normal_(w)
            elif method == 'kaiming':
                nn.init.kaiming_normal_(w)
            else:
                nn.init.normal_(w)
        elif 'bias' in name:
            # Biases are normally 1-D; the rank check must not apply here,
            # otherwise this branch would never run for them.
            nn.init.constant_(w, 0)


def train(config, model, train_iter, dev_iter, test_iter):
    """Train ``model``, checkpoint the best dev macro-F1, then run ``test``.

    Args:
        config: Object providing ``learning_rate``, ``num_epochs``,
            ``save_path`` and ``require_improvement`` (``evaluate`` and
            ``test`` read further attributes from it).
        model: Module called as ``model(batch)``; its output is fed to
            ``F.cross_entropy`` together with the batch labels.
        train_iter: Iterable yielding ``(inputs, labels)`` training batches.
        dev_iter: Validation batches, evaluated every 20 training batches.
        test_iter: Test batches, passed to ``test`` once training ends.

    Whenever the dev macro-F1 improves, a checkpoint dict (epoch, batch
    counter, model/optimizer/scheduler state, last loss, best score) is
    written to ``config.save_path``. Training stops early once more than
    ``config.require_improvement`` batches pass without an improvement.
    """
    start_time = time.time()

    start_epoch = 0
    total_batch = 0  # optimisation steps taken so far
    # Start below any attainable score so the first evaluation always writes
    # a checkpoint; test() at the end of this function loads it from disk.
    dev_best_macro_f1 = float('-inf')
    last_improve = 0  # value of total_batch at the last dev improvement

    # Parameters matching these name fragments get no weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    decay_params, no_decay_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if any(nd in n for nd in no_decay):
            no_decay_params.append(p)
        else:
            decay_params.append(p)
    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': 0.01},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=config.learning_rate,
                      weight_decay=0.01)
    # The scheduler is stepped once per batch below, so T_max counts batches.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    stop_early = False
    model.train()
    for epoch in range(start_epoch, config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))

        for trains, labels in train_iter:
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)

            loss.backward()
            optimizer.step()
            scheduler.step()

            if total_batch % 20 == 0:
                # Accuracy on the current training batch only.
                true = labels.detach().cpu()
                predic = torch.max(outputs.detach(), 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_macro_f1, micro_f1, dev_loss = evaluate(config, model, dev_iter)
                if dev_macro_f1 > dev_best_macro_f1:
                    dev_best_macro_f1 = dev_macro_f1
                    improve = '*'
                    last_improve = total_batch
                    checkpoint = {
                        'epoch': epoch,
                        'total_batch': total_batch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': loss,
                        # sklearn may hand back a NumPy scalar; a plain float
                        # keeps the file loadable by torch.load's restricted
                        # (weights_only) unpickler.
                        'dev_best_macro_f1': float(dev_best_macro_f1),
                        'last_improve': last_improve,
                    }

                    torch.save(checkpoint, config.save_path)
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = (
                    'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%}, '
                    'Val Macro F1: {5:>6.2%}, Time: {6} {7} ')
                print(
                    msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, dev_macro_f1,
                               time_dif, improve,
                               ))
                # evaluate() leaves the model in eval mode; switch back.
                model.train()

            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                print("No optimization for a long time, auto-stopping...")
                stop_early = True
                break

        if stop_early:
            break
    test(config, model, test_iter)


def test(config, model, test_iter):
    """Load the checkpoint at ``config.save_path`` and report test metrics.

    Args:
        config: Object providing ``save_path``; also forwarded to
            ``evaluate``.
        model: Module that receives the checkpoint's ``'model_state_dict'``.
        test_iter: Iterable of test batches forwarded to ``evaluate``.

    Prints loss, accuracy, macro/micro F1, the classification report, the
    confusion matrix and the elapsed evaluation time. Returns ``None``.
    """
    # Deserialise onto CPU so a checkpoint written from a GPU run can still be
    # read on a host without that device; load_state_dict then copies the
    # tensors onto whatever device the model's parameters live on.
    checkpoint = torch.load(config.save_path, map_location='cpu')

    model_state_dict = checkpoint['model_state_dict']

    model.load_state_dict(model_state_dict)
    model.eval()
    start_time = time.time()

    (test_acc, test_macro_f1, test_micro_f1,
     test_loss, test_report, test_confusion) = evaluate(config, model, test_iter, test=True)

    msg = 'Test Loss: {0:>5.2f}, Test Acc: {1:>6.2%}, Test Macro F1: {2:>6.2%}, Test Micro F1: {3:>6.2%}'
    print(msg.format(test_loss, test_acc, test_macro_f1, test_micro_f1))

    print("Precision, Recall and F1-Score...")
    print(test_report)

    print("Confusion Matrix...")
    print(test_confusion)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def evaluate(config, model, data_iter, test=False):
    """Run ``model`` over ``data_iter`` without gradients and score it.

    Args:
        config: Object providing ``class_list`` (only read when ``test`` is
            true, as the label names of the classification report).
        model: Module called as ``model(batch)``; its output is fed to
            ``F.cross_entropy`` and arg-maxed over dimension 1.
        data_iter: Iterable of ``(inputs, labels)`` batches that also
            supports ``len()``; the summed batch losses are divided by it.
        test: When true, additionally return a classification report and a
            confusion matrix.

    Returns:
        ``(acc, macro_f1, micro_f1, mean_loss)``, or with ``test=True``
        ``(acc, macro_f1, micro_f1, mean_loss, report, confusion)``.
        ``mean_loss`` is the sum of per-batch losses (a tensor as produced by
        ``F.cross_entropy``) divided by ``len(data_iter)``.

    Note:
        The model is switched to eval mode and is not switched back.
    """
    model.eval()
    loss_total = 0
    # Collect per-batch arrays and join them once after the loop; appending
    # to a NumPy array inside the loop would recopy everything each batch.
    # The leading empty int array keeps the result's integer dtype and makes
    # the concatenation valid even if data_iter yields no batches.
    label_chunks = [np.array([], dtype=int)]
    predict_chunks = [np.array([], dtype=int)]

    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss_total += F.cross_entropy(outputs, labels)
            label_chunks.append(labels.detach().cpu().numpy())
            predict_chunks.append(torch.max(outputs.detach(), 1)[1].cpu().numpy())

    # axis=None flattens every chunk before joining.
    labels_all = np.concatenate(label_chunks, axis=None)
    predict_all = np.concatenate(predict_chunks, axis=None)

    acc = metrics.accuracy_score(labels_all, predict_all)
    macro_f1 = metrics.f1_score(labels_all, predict_all, average='macro')
    micro_f1 = metrics.f1_score(labels_all, predict_all, average='micro')

    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, macro_f1, micro_f1, loss_total / len(data_iter), report, confusion

    return acc, macro_f1, micro_f1, loss_total / len(data_iter)
