"""
"""

import os
import copy
import argparse

def _str2bool(value):
    """Convert a command-line flag value to ``bool``.

    argparse's ``type=bool`` applies ``bool()`` to the raw string, and every
    non-empty string -- including ``'False'`` and ``'0'`` -- is truthy, so a
    flag declared that way can never be switched off explicitly. This parser
    accepts the usual spellings instead. An empty string maps to False, as
    ``bool('')`` did.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('', '0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options, parsed once at import time into the module-level _ARGS.
_ARG_PARSER = argparse.ArgumentParser(description="我的实验，需要指定配置文件")
# Config name; main() loads ./dev/config/<name>.yml.
_ARG_PARSER.add_argument('--yaml', '-y', type=str, default='ada-pg', help='configuration file path.')
# Exported as CUDA_VISIBLE_DEVICES before torch is imported.
_ARG_PARSER.add_argument('--cuda', '-c', type=str, default='0', help='gpu ids, like: 1,2,3')
# The boolean switches use _str2bool instead of ``type=bool`` (which treats
# the string 'False' as True). They accept an explicit value, e.g.
# `--test false`, or can be given bare, e.g. `--test`, meaning True.
# --test: skip training; only load the model and evaluate on the test split.
_ARG_PARSER.add_argument('--test', '-t', type=_str2bool, nargs='?', const=True, default=False,
                         help='进行测试输出')
# Run-name prefix; defaults to the yaml name (plus -m<mode> when --mode is set).
_ARG_PARSER.add_argument('--name', '-n', type=str, default=None, help='save name.')
_ARG_PARSER.add_argument('--seed', '-s', type=int, default=123, help='random seed')
# --all: run once for every seed in SEEDS instead of only --seed.
_ARG_PARSER.add_argument('--all', '-a', type=_str2bool, nargs='?', const=True, default=False,
                         help='all seed?')
# --debug: disable checkpoint saving (save_strategy = 'no').
_ARG_PARSER.add_argument('--debug', '-d', default=False, action="store_true")
# Optional overrides of cfg.model['timing'] / cfg.model['mode'].
_ARG_PARSER.add_argument('--timing', type=int, default=None)
_ARG_PARSER.add_argument('--mode', type=int, default=None)
# --case: start the interactive case study instead of training.
_ARG_PARSER.add_argument('--case', type=_str2bool, nargs='?', const=True, default=False,
                         help='case study')

# Optional model-size overrides; each one is also appended to the run name.
_ARG_PARSER.add_argument('--adapter_size', type=int, default=None)
_ARG_PARSER.add_argument('--lstm_size', type=int, default=None)
_ARG_PARSER.add_argument('--num_adapters', type=int, default=None)
_ARGS = _ARG_PARSER.parse_args()

# Set before torch is imported below, presumably so CUDA only sees the GPUs
# selected with --cuda.
os.environ['CUDA_VISIBLE_DEVICES'] = _ARGS.cuda

# Heavy and project imports are deferred until CUDA_VISIBLE_DEVICES is set.
# NOTE(review): an argparse.Namespace appears to always be truthy, so the
# `else` branch looks unreachable -- the `if` seems to exist only to group the
# deferred imports; confirm before relying on it as a validity check.
if _ARGS:
    import random

    import numpy as np
    import torch

    from transformers import BertTokenizer

    import mynlp
    from mynlp.common.config import load_yaml
    from mynlp.common.util import output, cache_path, load_cache, dump_cache
    from mynlp.common.writer import Writer
    from mynlp.core import Trainer, Vocabulary
    from mynlp.core.optim import build_optimizer

    from dataset import OpinionDataset, group_data, group_mixup
    from models import build_model
else:
    raise Exception('Argument error.')

# Seeds iterated over when --all is given (see main()).
SEEDS = (123, 456, 789, 686, 666, 233, 1024, 2080, 3080, 3090)

# Overrides a module-level setting of mynlp's trainer; presumably the number
# of evaluations without improvement before early stopping -- confirm.
mynlp.core.trainer.EARLY_STOP_THRESHOLD = 5


def set_seed(seed: int = 123):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Logs the process id and the --cuda value first, and limits PyTorch to
    4 CPU threads.
    """
    output(f"Process id: {os.getpid()}, cuda: {_ARGS.cuda}, set seed {seed}")
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    # CPU usage was excessively high without any training speed-up; the root
    # cause has not been found yet, so cap the thread count.
    torch.set_num_threads(4)


def case_study(pg, vocab, device):
    """Interactively compare the taggings of several trained models.

    Loads the ``ada-all``, ``ada-vote``, ``ada``, ``ada-mix`` and
    ``lstm-crowd`` configs from ``./dev/config/`` alongside ``pg``. For each
    config it builds the model, restores weights from
    ``trainer['pre_train_path']`` and switches to eval mode. It then
    repeatedly reads a sentence from stdin and prints every model's tagging.
    Characters tagged 4/5 are shown in blue and 2/3 in red (presumably the
    NEG and POS B-/I- tags, given the label order set in main()). An empty
    line or end of input stops the loop.

    Args:
        pg: argparse.Namespace config of the current run (main passes its
            cfg); its ``model`` and ``trainer`` dicts are used like the
            loaded configs.
        vocab: vocabulary used to build the models and index the input text.
        device: torch device the models and inputs are placed on.

    Returns:
        dict mapping a short display name to the loaded model.
    """
    def load_cfg(name):
        return argparse.Namespace(**load_yaml(f"./dev/config/{name}.yml"))

    # Renamed from `all` to avoid shadowing the builtin.
    all_cfg = load_cfg("ada-all")
    vote_cfg = load_cfg("ada-vote")
    silver_cfg = load_cfg("ada")
    mix_cfg = load_cfg("ada-mix")
    lstm_cfg = load_cfg("lstm-crowd")
    # Insertion order fixes the order in which predictions are printed.
    cfgs = dict(silver=silver_cfg, vanila=pg, annmix=mix_cfg, lstm_c=lstm_cfg,
                all___=all_cfg, mv____=vote_cfg)
    for cfg in cfgs.values():
        cfg.model['allowed'] = allowed_transition(vocab)

    models = dict()
    for name, cfg in cfgs.items():
        model = build_model(vocab=vocab, **cfg.model)
        model.load(cfg.trainer['pre_train_path'], device)
        model.to(device)
        model.eval()
        models[name] = model
        print(f"loaded <{cfg.trainer['pre_train_path']}>")

    def text_to_batch(text):
        # Character-level indexing of the raw sentence.
        input_ids = vocab.indices_of(list(text), 'words')
        # NOTE(review): 101/102 are hard-coded, presumably BERT's [CLS]/[SEP]
        # ids; confirm they are right for non-BERT configs (e.g. lstm-crowd).
        input_ids = [101] + input_ids + [102]
        words = torch.tensor(input_ids, device=device).unsqueeze(0)
        mask = torch.ones_like(words)
        return dict(words=words, mask=mask, lengths=None)

    red, blue, reset = "\033[31m", "\033[34m", "\033[0m"

    def colorize(tags, text):
        # Renamed from `format` to avoid shadowing the builtin.
        pieces = []
        for tag, char in zip(tags, text):
            if tag in (4, 5):
                pieces.append(f"{blue}{char}{reset}")
            elif tag in (2, 3):
                pieces.append(f"{red}{char}{reset}")
            else:
                pieces.append(char)
        return "".join(pieces)

    def read_sentence():
        # Treat end of input (Ctrl-D / closed stdin) like an empty line.
        try:
            return input("\n测试句子: ")
        except EOFError:
            return ""

    with torch.no_grad():
        text = read_sentence()
        while text.strip():
            batch = text_to_batch(text)
            for name, model in models.items():
                pred = model.forward(**batch)['predicted_tags']
                # Drop the positions of the two added special ids.
                print(name, ": ", colorize(pred[0][1:-1], text))
            text = read_sentence()

    return models


def run_once(cfg, dataset, vocab, device, writer=None, seed=123):
    """Build, optionally train, and evaluate one model for a single seed.

    Training is skipped when --test is given. In both cases the trainer's
    saved model is loaded and evaluated on ``dataset.test``.

    Args:
        cfg: argparse.Namespace config with ``model``, ``optim`` and
            ``trainer`` dicts.
        dataset: namespace holding the ``train``/``dev``/``test`` splits;
            depending on the config, extra ``mixup``/``ann``/``raw``
            attributes are attached to it.
        vocab: vocabulary shared by data and model.
        device: torch device for the model.
        writer: optional logging writer forwarded to the Trainer.
        seed: stored on the model as ``model.seed``.

    Returns:
        tuple ``(model.metric.best, test_metric)``: the best metric tracked by
        the model's metric object (reported as the dev score by main) and the
        result of ``trainer.test(dataset.test)``.
    """
    model = build_model(vocab=vocab, **cfg.model)
    model.seed = seed
    para_num = sum(p.numel() for p in model.parameters())
    # `.4f` = four decimals (the previous `4f` was a minimum field width).
    output(f'param num: {para_num}, {para_num / 1000000:.4f}M')
    model.to(device=device)

    optimizer = build_optimizer(model, **cfg.optim)
    trainer = Trainer(vars(cfg), dataset, vocab, model, optimizer, None, None,
                      writer, device, **cfg.trainer)

    # NOTE(review): the default --yaml is 'ada-pg', which is not in this tuple
    # while 'pg-ada' is -- confirm this is not a typo.
    if _ARGS.yaml in ('pg-ada', 'ada-pg-batch', 'ada-pgo'):
        # NOTE(review): group_data's result is not assigned, so it appears to
        # modify dataset.train in place. With --all this runs once per seed
        # on the same dataset -- confirm it is idempotent.
        group_data(dataset.train, by_ins=False)
    elif 'mix' in cfg.model['name']:
        # Work on copies so dataset.train itself stays untouched.
        mixup, ann = copy.deepcopy(dataset.train), copy.deepcopy(dataset.train)
        group_mixup(mixup)
        group_data(ann)
        dataset.mixup = mixup
        dataset.ann = ann
        dataset.raw = dataset.train

    if not _ARGS.test:
        trainer.train()
        output(model.metric.data_info)

    trainer.load()
    test_metric = trainer.test(dataset.test)
    return model.metric.best, test_metric


def allowed_transition(vocab):
    """Return the permitted (previous_tag, next_tag) label-index pairs.

    The tag set is O plus B-/I- tags for two polarities, POS and NEG. Allowed
    moves are O -> O / B-POS / B-NEG. From any B-X or I-X tag the allowed
    moves are O, I-X (continue the same span) or B-Y with Y the opposite
    polarity. Every other pair (e.g. O -> I-X, B-POS -> I-NEG) is omitted.
    NOTE(review): X -> B-X (two adjacent spans of the same polarity) is not
    allowed either; confirm this is intended.

    Args:
        vocab: vocabulary whose ``index_of(tag, 'labels')`` maps tag names
            to label ids.

    Returns:
        list of ``(from_index, to_index)`` tuples in a fixed order.
    """
    def label_id(tag):
        return vocab.index_of(tag, 'labels')

    pairs = [('O', 'O'), ('O', 'B-POS'), ('O', 'B-NEG')]
    for position in ('B', 'I'):
        for polarity, opposite in (('POS', 'NEG'), ('NEG', 'POS')):
            current = f'{position}-{polarity}'
            pairs.extend([
                (current, 'O'),
                (current, f'I-{polarity}'),
                (current, f'B-{opposite}'),
            ])
    return [(label_id(src), label_id(dst)) for src, dst in pairs]


def main():
    """Run the experiment selected on the command line.

    Steps:
      1. Load ``./dev/config/<--yaml>.yml``.
      2. If the word embedding path contains 'bert', load its tokenizer and
         use its unk/pad tokens for the vocabulary.
      3. Build the dataset and vocabulary (cached under the yaml name) and
         index every split with the vocabulary.
      4. With --case, start the interactive case study and return.
      5. Otherwise apply the --timing/--mode/--lstm_size/--adapter_size/
         --num_adapters overrides, derive a run-name prefix, and call
         run_once for --seed (or for every entry of SEEDS with --all).
    """
    cfg = argparse.Namespace(**load_yaml(f"./dev/config/{_ARGS.yaml}.yml"))

    device = torch.device("cuda:0")
    vocab_kwargs = dict(cfg.vocab)
    use_bert = 'bert' in cfg.model['word_embedding']['name_or_path']

    # With BERT, the vocabulary must use the tokenizer's special tokens.
    # NOTE(review): the tokenizer used to be stored in an unused `data_kwargs`
    # dict; it has never been passed to OpinionDataset.build.
    if use_bert:
        tokenizer = BertTokenizer.from_pretrained(
            cfg.model['word_embedding']['name_or_path'],
            do_lower_case=False)
        print("I'm batman!  ",
              tokenizer.tokenize("I'm batman!"))  # quick tokenizer sanity check
        vocab_kwargs['oov_token'] = tokenizer.unk_token
        vocab_kwargs['padding_token'] = tokenizer.pad_token
    else:
        tokenizer = None

    # Dataset + vocabulary are cached per config name.
    cache_name = _ARGS.yaml
    if not os.path.exists(cache_path(cache_name)):
        cfg.data['data_dir'] = 'data/'
        dataset = argparse.Namespace(**OpinionDataset.build(**cfg.data))
        vocab = Vocabulary.from_data(dataset, **vocab_kwargs)
        vocab.set_field(['[PAD]', 'O', 'B-POS', 'I-POS', 'B-NEG', 'I-NEG'], 'labels')

        if use_bert:
            # Replace the 'words' vocabulary with BERT's own vocabulary.
            vocab.token_to_index['words'] = tokenizer.vocab
            vocab.index_to_token['words'] = tokenizer.ids_to_tokens
        dump_cache((dataset, vocab), cache_name)
    else:
        dataset, vocab = load_cache(cache_name)

    for split in (dataset.train, dataset.dev, dataset.test):
        split.index_with(vocab)

    cfg.model['allowed'] = allowed_transition(vocab)

    if _ARGS.case:
        set_seed(_ARGS.seed)
        case_study(cfg, vocab, device)
        return

    # Command-line overrides of the model config; the mode is also recorded
    # in the default run name.
    if _ARGS.timing is not None:
        cfg.model['timing'] = _ARGS.timing
    if _ARGS.mode is not None:
        cfg.model['mode'] = _ARGS.mode
        prefix = _ARGS.name if _ARGS.name else f"{_ARGS.yaml}-m{_ARGS.mode}"
    else:
        prefix = _ARGS.name if _ARGS.name else _ARGS.yaml

    if isinstance(_ARGS.lstm_size, int):
        cfg.model['lstm_size'] = _ARGS.lstm_size
        prefix += f'-l{_ARGS.lstm_size}'
    if isinstance(_ARGS.adapter_size, int):
        cfg.model['adapter_size'] = _ARGS.adapter_size
        prefix += f'-a{_ARGS.adapter_size}'
    if isinstance(_ARGS.num_adapters, int):
        cfg.model['num_adapters'] = _ARGS.num_adapters
        prefix += f'-n{_ARGS.num_adapters}'
    print(cfg.model)

    info = list()
    # TensorBoard logging is currently disabled in both modes; re-enable by
    # creating a per-prefix directory, e.g. log_dir = f"./dev/tblog/{prefix}".
    log_dir = None
    if _ARGS.debug:
        cfg.trainer['save_strategy'] = 'no'

    # Decide once whether the config supplies its own checkpoint path. The
    # loop writes this key, so checking it per seed made every seed after
    # the first reuse the first seed's path.
    has_custom_pretrain_path = 'pre_train_path' in cfg.trainer

    seeds = SEEDS if _ARGS.all else [_ARGS.seed]
    for seed in seeds:
        print('\n')
        set_seed(seed)
        cfg.trainer['prefix'] = f"{prefix}_{seed}"
        if not has_custom_pretrain_path:
            cfg.trainer['pre_train_path'] = os.path.normpath(
                f"./dev/model/{cfg.trainer['prefix']}_best.pth")
        writer = Writer(log_dir, str(seed), 'tensorboard') if log_dir else None
        info.append(run_once(cfg, dataset, vocab, device, writer, seed))

    # print('\nAVG DEV: ', merge_dicts([i[0] for i in info], avg=True))
    # print('AVG TEST: ', merge_dicts([i[1] for i in info], avg=True))


if __name__ == "__main__":
    main()
