# -*- coding: utf-8 -*-
import torch
import numpy as np
from data.vocabulary import Vocabulary, LabelVocabulary
from data.dataset import NERDataset, ner_shuffle
from data.data_iterator import DataIterator
from utils.ner_param import HyperParam
from utils.common_utils import *
from metric.F1_metric import ExactMatch, f1_score
from driver.ner_helper import NERHelper, Statistics
from model.base import BaseModel
from model.SE_base import SEModel
from model.adapter import CWSPOSModel
from model.domain_adapter import DomainModel
from model.share_adapter import ShareModel
from model.lebert import LexiconModel
from optim import Optimizer
from optim.lr_scheduler import ReduceOnPlateauScheduler, NoamScheduler

import subprocess
import argparse
import random
import ntpath
import time
import os
import re


# Set the OMP_NUM_THREADS environment variable to '10' for this process.
# NOTE(review): this assignment runs after `torch` and `numpy` have already
# been imported at the top of the file; OpenMP runtimes commonly read this
# variable at initialisation, so it may have no effect here -- confirm, and
# consider moving it above the imports if the thread cap matters.
os.environ["OMP_NUM_THREADS"] = '10'


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (plus all CUDA devices when CUDA is available), then switches cuDNN to
    its deterministic mode.

    Args:
        seed: the seed value handed to each generator.
    """
    # The generators are independent of each other, so the seeding order
    # does not matter.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True


def _build_model(model_name, hp, label_vocab):
    """Instantiate the NER model selected by ``model_name``.

    Args:
        model_name: one of 'base', 'se', 'lexicon', 'adapter' or 'share'.
        hp: hyper-parameter object; ``bert_path``, ``bert_size`` and,
            depending on the model, ``use_feature`` / ``trainsets_len``
            are read from it.
        label_vocab: label vocabulary passed to the model constructor.

    Returns:
        The constructed model.

    Raises:
        NameError: if ``model_name`` is not a supported name.
    """
    if model_name == 'base':
        return BaseModel(hp.bert_path, label_vocab, d_model=hp.bert_size, for_ner=True, use_feature=hp.use_feature)
    if model_name == 'se':
        return SEModel(hp.bert_path, label_vocab, d_model=hp.bert_size, for_ner=True, use_feature=hp.use_feature)
    if model_name == 'lexicon':
        return LexiconModel(hp.bert_path, label_vocab, d_model=hp.bert_size, use_feature=hp.use_feature)
    if model_name == 'adapter':
        # 'adapter' selects DomainModel (from model.domain_adapter).
        return DomainModel(hp.bert_path, label_vocab, d_model=hp.bert_size, trainsets_len=hp.trainsets_len)
    if model_name == 'share':
        return ShareModel(hp.bert_path, label_vocab, d_model=hp.bert_size, trainsets_len=hp.trainsets_len)
    raise NameError(f'no model named {model_name}')


def train(restore=False,
          mode='transformer',
          gpu_use=-1,
          name='base',
          model_name='base'):
    """Train an NER model, validating and checkpointing periodically.

    Every ``hp.valid_freq`` steps the model is evaluated on the dev and test
    data; a checkpoint is saved whenever the dev F1 improves.

    Args:
        restore: when exactly ``True``, the latest checkpoint is loaded into
            the model before training starts.
        mode: forwarded to ``HyperParam`` to select the hyper-parameters.
        gpu_use: CUDA device index. CUDA is used only when it is available
            and ``gpu_use >= 0``.
        name: run name; checkpoints are written under ``./save/ner/<name>``.
        model_name: which model to build, see ``_build_model``.

    Raises:
        NameError: if ``model_name`` is unknown.
        RuntimeError: any runtime error raised during a training step other
            than an out-of-memory error, which only skips the current batch.
    """
    hp = HyperParam(mode=mode)
    hp._print_items()

    gpu = torch.cuda.is_available()
    print('begin with mode {}'.format(mode))
    print("GPU available: ", gpu)
    print("CuDNN: \n", torch.backends.cudnn.enabled)

    global_step = 0

    use_cuda = False
    if gpu and gpu_use >= 0:
        use_cuda = True
        torch.cuda.set_device(gpu_use)
        print("GPU ID: ", gpu_use)

    set_seed(1234)

    label_vocab = LabelVocabulary(hp.vocabulary_type, hp.label_vocab)
    model = _build_model(model_name, hp, label_vocab)
    # model.init_model(param_path=hp.pos_path, device=gpu_use)

    # The helper is built the same way for every model; only the shuffle
    # function depends on the configuration.
    helper = NERHelper(model,
                       label_vocab,
                       hp,
                       use_cuda=use_cuda,
                       shuffle=ner_shuffle if hp.shuffle is True else None)

    optim = Optimizer(name=hp.optim,
                      model=model,
                      lr=hp.lr,
                      grad_clip=-1.0,
                      optim_args=None)

    if hp.schedule_method == 'noam':
        # NOTE(review): d_model is fixed at 512 here while the model itself
        # is built with hp.bert_size -- confirm this is intentional.
        scheduler = NoamScheduler(optimizer=optim,
                                  d_model=512,
                                  warmup_steps=hp.warmup_steps)
    else:
        scheduler = None

    print('begin training:')

    # makedirs also creates the missing parents ('./save', './save/ner').
    save_dir = './save/ner/' + name
    os.makedirs(save_dir, exist_ok=True)

    best_f1 = -1
    best_wF = -1
    best_test_f1 = -1
    best_test_step = -1
    checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(
        os.path.join(save_dir, name)),
                             num_max_keeping=20)

    # Identity check on purpose: callers may pass a truthy non-bool value
    # (e.g. the string 'False') that must not trigger a restore.
    if restore is True:
        checkpoint_saver.load_latest(device=gpu_use, model=model)
        print('restore successful')

    # NOTE(review): this runs after the optional restore above; confirm that
    # loading hp.general_path on top of a restored checkpoint is intended.
    if hp.general_path is not None:
        model._load_param(hp.general_path, gpu_use)

    for epoch in range(hp.epoch_num):
        total_stats = Statistics()
        training_iter = helper.training_iterator.build_generator()
        batch_iter, total_iters = 0, len(helper.training_iterator)

        for batch in training_iter:
            global_step += 1
            # A scheduler object only exists for the 'noam' schedule; guard
            # on the object so other schedule names cannot call into None.
            if scheduler is not None:
                scheduler.step(global_step=global_step)

            seqs, label, pos = batch

            n_samples_t = len(seqs)
            batch_iter += n_samples_t
            n_words_t = sum(len(s) for s in seqs)

            lrate = list(optim.get_lrate())[0]
            optim.zero_grad()

            # Pre-bind the shard variables so the error report below cannot
            # fail (or show a stale shard) when the error happens before the
            # first shard of this batch has been produced.
            seqs_txt_t = seqs_label_t = None
            try:
                for seqs_txt_t, seqs_label_t, pos_t in split_shard(seqs, label, pos, split_size=hp.update_cycle):
                    stat = helper.train_batch(
                        seqs_txt_t,
                        seqs_label_t,
                        pos_t,
                        n_samples_t,
                        global_step=global_step,
                        finetune=hp.finetune
                    )
                    total_stats.update(stat)

                total_stats.print_out(global_step - 1, epoch, batch_iter,
                                      total_iters, lrate, n_words_t, best_wF, best_f1)
                optim.step()
            except RuntimeError as e:
                print('seqs_txt_t is:{}\nshape is: {}'.format(
                    seqs_txt_t, np.shape(seqs_txt_t)))
                print('seqs_label_t is:{}\nshape is: {}'.format(
                    seqs_label_t, np.shape(seqs_label_t)))
                message = str(e)
                if 'out of memory' in message:
                    # Drop the partial gradients and carry on with the next
                    # batch, as the warning announces.
                    print('| WARNING: ran out of memory, skipping batch')
                    optim.zero_grad()
                else:
                    if 'cuda runtime error' in message or 'CUDA error' in message:
                        print(e)
                        print('| ERROR: unknown cuda error:{}, aborting'.
                              format(message))
                        optim.zero_grad()
                    # Anything that is not an OOM is fatal.
                    raise

            if global_step % hp.valid_freq == 0:
                # Time the two evaluations separately so each log line
                # reports its own duration.
                dev_start_time = time.time()
                f1, precision, recall = evaluate(helper, hp, global_step, name, label_vocab)
                dev_time = float(time.time() - dev_start_time)

                test_start_time = time.time()
                test_f1, test_precision, test_recall = evaluate(helper, hp, global_step, name, label_vocab, is_test=True)
                test_time = float(time.time() - test_start_time)

                print("step %d, epoch %d: dev ner p: %.4f, r: %.4f, f1: %.4f, time %.2f" % (global_step, epoch, precision, recall, f1, dev_time))
                print("step %d, epoch %d: test ner p: %.4f, r: %.4f, f1: %.4f, time %.2f, best test step %d" % (global_step, epoch, test_precision, test_recall, test_f1, test_time, best_test_step))

                if test_f1 > best_test_f1:
                    # Report before updating so "history" is the old best.
                    print("exceed best test f1: history = %.2f, current = %.2f, at = %d" % (best_test_f1, test_f1, global_step))
                    best_test_f1 = test_f1
                    best_test_step = global_step

                if f1 > best_f1: # - 0.0005:
                    print("exceed best dev f1: history = %.2f, current = %.2f, lr_ratio = %.6f" % (best_f1, f1, lrate))
                    best_f1 = f1
                    checkpoint_saver.save(
                        global_step=global_step,
                        model=helper.model,
                        optim=optim,
                        lr_scheduler=scheduler
                    )


def evaluate(helper: NERHelper, hp: HyperParam, global_step, name, label_vocab, is_test=False):
    """Decode the dev (or test) data and return its F1, precision and recall.

    For the dev data the decoded lines are also written, in the original
    dataset order, to ``./save/ner/<name>/<dev file name>.<global_step>``.

    Args:
        helper: provides the model and ``translate_batch``.
        hp: hyper-parameters; ``batch_size`` and ``dev_data`` / ``test_data``
            are read from it.
        global_step: current training step, used as the output file suffix.
        name: run name, used as the output directory name.
        label_vocab: label vocabulary used to set up the metric.
        is_test: ``False`` evaluates ``hp.dev_data``; anything else
            evaluates ``hp.test_data`` and writes no output file.

    Returns:
        A ``(F1, precision, recall)`` tuple taken from the metric.
    """
    batch_size = hp.batch_size
    on_dev = is_test is False

    if on_dev:
        dev_dataset = NERDataset(data_paths=[hp.dev_data], max_len=500)
    else:
        print('------------- doing test -----------')
        dev_dataset = NERDataset(data_paths=[hp.test_data], max_len=500)

    dev_iterator = DataIterator(dataset=dev_dataset,
                                batch_size=10,
                                use_bucket=True,
                                buffer_size=100,
                                numbering=True)

    metric = ExactMatch(label_vocab.token2id('O'), label_vocab._token2id_feq)

    numbers = []
    trans = []

    # Remember whether the model was training so that evaluating in the
    # middle of a training run does not leave it stuck in eval mode.
    was_training = getattr(helper.model, 'training', False)
    helper.model.eval()
    try:
        dev_iter = dev_iterator.build_generator(batch_size=batch_size)
        for batch in dev_iter:
            seq_nums, seqs, labels, pos = batch
            numbers += seq_nums

            sub_trans, labels_out, labels_all, length = helper.translate_batch(seqs, labels, pos, is_test=is_test)
            trans += sub_trans
            metric(labels_out, labels_all, length)
    finally:
        if was_training:
            helper.model.train()

    # The iterator numbers each sample; sorting by that number puts the
    # decoded lines back into dataset order.
    origin_order = np.argsort(numbers).tolist()
    trans = [trans[ii] for ii in origin_order]

    if on_dev:
        out_dir = './save/ner/' + name
        os.makedirs(out_dir, exist_ok=True)
        # ntpath.basename splits on both '/' and '\\'.
        tail = ntpath.basename(hp.dev_data)
        hyp_path = out_dir + '/' + tail + "." + str(global_step)
        with open(hyp_path, 'w', encoding='utf-8') as f:
            for line in trans:
                f.write('%s\n' % re.sub('@@ ', '', line))

    res_dic = metric.get_metric()

    return res_dic['F1'], res_dic['precision'], res_dic['recall']


if __name__ == '__main__':
    # Command-line entry point: parse the options and start training.
    parser = argparse.ArgumentParser()
    # The default must be the boolean False: with action='store_true' a
    # string default such as 'False' is truthy whenever the flag is absent.
    parser.add_argument('--restore',
                        default=False,
                        action='store_true',
                        help="to restore the last ckpt.")
    # Parsed but not forwarded to train().
    parser.add_argument('--trans_flow',
                        default=None,
                        type=str,
                        help='the flow of multi_seq2seq')
    parser.add_argument('--mode',
                        default='transformer',
                        type=str,
                        help='the mode passed to HyperParam to select hyper-parameters')
    parser.add_argument('--GPU',
                        '-g',
                        default=0,
                        type=int,
                        help='choose the gpu to use')
    # The default value is kept as-is (including its spelling) because it
    # determines the save directory of existing runs.
    parser.add_argument('--name',
                        '-n',
                        default='defalut',
                        type=str,
                        help='the name of model')
    parser.add_argument('--model',
                        '-m',
                        default='base',
                        type=str,
                        help='choose the model to use')

    args = parser.parse_args()

    train(restore=args.restore,
          mode=args.mode,
          gpu_use=args.GPU,
          name=args.name,
          model_name=args.model)
