# -*- coding: utf-8 -*-
import torch
import numpy as np
from data.vocabulary import Vocabulary, LabelVocabulary
from data.dataset import NERDataset, ner_shuffle
from data.data_iterator import DataIterator
from utils.ner_param import HyperParam
from utils.common_utils import *
from metric.F1_metric import ExactMatch
from driver.ner_helper import NERHelper, Statistics
from model.base import BaseModel
from model.adapter import CWSPOSModel
from model.domain_adapter import DomainModel
from model.share_adapter import ShareModel
from optim import Optimizer
from optim.lr_scheduler import ReduceOnPlateauScheduler, NoamScheduler

import subprocess
import argparse
import random
import ntpath
import time
import os
import re


# Cap the number of OpenMP threads at 10 through the process environment.
# NOTE(review): this assignment runs after torch and numpy have already been
# imported above; libraries that read OMP_NUM_THREADS only at import time may
# not pick it up -- confirm it has the intended effect.
os.environ["OMP_NUM_THREADS"] = '10'


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (plus
    all CUDA devices when CUDA is available), then switches cuDNN into
    deterministic mode.

    Args:
        seed: The seed value handed to each generator.
    """
    # The CPU-side generators are independent of each other, so the order in
    # which they are seeded does not matter.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True


def train(
    restore: bool = False,
    mode: str = 'transformer',
    gpu_use: int = -1,
    name: str = 'base',
    model_name: str = 'base'
) -> None:
    """Build the NER model, optionally restore a checkpoint, and report test metrics.

    As written, this function performs no optimisation steps: it constructs
    the model, helper, optimizer and (optionally) a Noam scheduler, makes sure
    the checkpoint directory exists, optionally loads the latest checkpoint,
    and then calls ``evaluate`` once with ``is_test=True`` and prints the
    resulting precision / recall / F1.

    Args:
        restore: When exactly ``True``, load the latest checkpoint into the
            helper's model before evaluating.
        mode: Passed to ``HyperParam`` to select the hyper-parameter set.
        gpu_use: CUDA device index. CUDA is used only when it is available
            and this value is ``>= 0``.
        name: Run name; used for the checkpoint directory
            ``./save/ner/<name>`` and the checkpoint file prefix.
        model_name: One of ``'base'``, ``'adapter'`` or ``'share'``.

    Raises:
        NameError: If ``model_name`` is not one of the supported names.
    """
    hp = HyperParam(mode=mode)
    hp._print_items()

    gpu = torch.cuda.is_available()
    print('begin with mode {}'.format(mode))
    print("GPU available: ", gpu)
    print("CuDNN: \n", torch.backends.cudnn.enabled)

    global_step = 0

    use_cuda = gpu and gpu_use >= 0
    if use_cuda:
        torch.cuda.set_device(gpu_use)
        print("GPU ID: ", gpu_use)

    # Seed before the model is built so that anything constructed below sees
    # the seeded generators.
    set_seed(1234)

    label_vocab = LabelVocabulary(hp.vocabulary_type, hp.label_vocab)

    if model_name == 'base':
        model = BaseModel(hp.bert_path, label_vocab, d_model=hp.bert_size, for_ner=True)
    elif model_name == 'adapter':
        model = DomainModel(hp.bert_path, label_vocab, d_model=hp.bert_size, trainsets_len=hp.trainsets_len)
    elif model_name == 'share':
        model = ShareModel(hp.bert_path, label_vocab, d_model=hp.bert_size, trainsets_len=hp.trainsets_len)
    else:
        raise NameError(f'no model named {model_name}')

    # The helper is built the same way for every model; only the shuffle
    # callback depends on the hyper-parameters. The identity check is kept so
    # that only a real boolean True enables shuffling.
    shuffle_fn = ner_shuffle if hp.shuffle is True else None
    helper = NERHelper(model,
                       label_vocab,
                       hp,
                       use_cuda=use_cuda,
                       shuffle=shuffle_fn)

    optim = Optimizer(name=hp.optim,
                      model=model,
                      lr=hp.lr,
                      grad_clip=-1.0,
                      optim_args=None)

    # NOTE(review): the scheduler is constructed but never stepped in this
    # function, and d_model is hard-coded to 512 rather than hp.bert_size --
    # confirm both are intended.
    if hp.schedule_method == 'noam':
        scheduler = NoamScheduler(optimizer=optim,
                                  d_model=512,
                                  warmup_steps=hp.warmup_steps)
    else:
        scheduler = None

    print('begin training:')

    # makedirs also creates missing parent directories and tolerates the
    # directory already existing.
    save_dir = './save/ner/' + name
    os.makedirs(save_dir, exist_ok=True)

    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(os.path.join(save_dir, name)),
        num_max_keeping=20)

    # Identity check on purpose: only a real boolean True triggers a restore,
    # so a truthy non-bool value (e.g. a string) does not.
    if restore is True:
        checkpoint_saver.load_latest(device=gpu_use, model=helper.model)
        print('restore successful')

    f1, precision, recall = evaluate(helper, hp, global_step, name, label_vocab, is_test=True)

    print('Precision: %.4f, Recall: %.4f, F1: %.4f' % (precision, recall, f1))


def evaluate(helper: NERHelper, hp: HyperParam, global_step, name, label_vocab, is_test=False):
    """Run the helper's model over the dev or test set and return its metrics.

    The dataset is read from ``hp.dev_data`` when ``is_test`` is ``False`` and
    from ``hp.test_data`` otherwise. Each batch is passed to
    ``helper.translate_batch`` and the results are accumulated in an
    ``ExactMatch`` metric. For dev runs only, the outputs are written (one per
    line, with every ``'@@ '`` occurrence removed) to
    ``./save/evaluation/ner/<name>-<dev file name>.<global_step>``.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        helper: Helper wrapping the model; provides ``model`` and
            ``translate_batch``.
        hp: Hyper-parameters; ``batch_size``, ``dev_data`` and ``test_data``
            are read here.
        global_step: Used only as the suffix of the dev output file name.
        name: Run name used as the prefix of the dev output file name.
        label_vocab: Label vocabulary used to configure the metric.
        is_test: Evaluate on the test set and skip writing the output file.

    Returns:
        A ``(F1, precision, recall)`` tuple taken from the metric's result.
    """
    batch_size = hp.batch_size

    if is_test is False:
        data_path = hp.dev_data
    else:
        print('------------- doing test -----------')
        data_path = hp.test_data
    dev_dataset = NERDataset(data_paths=[data_path], max_len=500)

    # NOTE(review): the iterator is created with batch_size=10 while the
    # generator below is built with hp.batch_size -- confirm which one wins.
    dev_iterator = DataIterator(
        dataset=dev_dataset,
        batch_size=10,
        use_bucket=True,
        buffer_size=100,
        numbering=True
    )

    helper.model.eval()
    metric = ExactMatch(label_vocab.token2id('O'), label_vocab._token2id_feq)

    numbers = []
    trans = []

    # No gradients are needed for evaluation; disabling autograd avoids
    # building graphs for every batch.
    with torch.no_grad():
        for seq_nums, seqs, labels, pos in dev_iterator.build_generator(batch_size=batch_size):
            numbers += seq_nums

            sub_trans, labels_out, labels_all, length = helper.translate_batch(
                seqs, labels, pos, is_test=is_test)
            trans += sub_trans
            metric(labels_out, labels_all, length)

    # Reorder the outputs by their sequence numbers, since the iterator
    # (bucketing enabled) is not assumed to yield them in order.
    origin_order = np.argsort(numbers).tolist()
    trans = [trans[ii] for ii in origin_order]

    if is_test is False:
        tail = ntpath.basename(hp.dev_data)
        hyp_path = './save/evaluation/ner/' + name + '-' + tail + "." + str(global_step)
        # Make sure the output directory exists before opening the file.
        os.makedirs(os.path.dirname(hyp_path), exist_ok=True)
        with open(hyp_path, 'w', encoding='utf-8') as f:
            f.writelines('%s\n' % line.replace('@@ ', '') for line in trans)

    res_dic = metric.get_metric()

    return res_dic['F1'], res_dic['precision'], res_dic['recall']


if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to train().
    parser = argparse.ArgumentParser()
    # The default must be the boolean False: with action='store_true' a
    # string default such as 'False' would be passed through unchanged and is
    # truthy.
    parser.add_argument('--restore',
                        default=False,
                        action='store_true',
                        help="to restore the last ckpt.")
    # Parsed for command-line compatibility; not forwarded to train().
    parser.add_argument('--trans_flow',
                        default=None,
                        type=str,
                        help='the flow of multi_seq2seq')
    parser.add_argument('--mode',
                        default='transformer',
                        type=str,
                        help='the hyper-parameter mode to use')
    parser.add_argument('--GPU',
                        '-g',
                        default=0,
                        type=int,
                        help='choose the gpu to use')
    # NOTE(review): the default 'defalut' looks like a typo of 'default', but
    # it determines the checkpoint directory name, so it is left unchanged.
    parser.add_argument('--name',
                        '-n',
                        default='defalut',
                        type=str,
                        help='the name of model')
    parser.add_argument('--model',
                        '-m',
                        default='base',
                        type=str,
                        help='choose the model to use')

    args = parser.parse_args()

    train(restore=args.restore,
          mode=args.mode,
          gpu_use=args.GPU,
          name=args.name,
          model_name=args.model)
