# -*- encoding:utf8 -*-
from init import (
    MyIter,
    NerDataLoader,
    NerModel,
    Evaluater,
)
from transformers import (
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from tokenizers import BertWordPieceTokenizer
from torch.cuda.amp import autocast as autocast
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from common import seed_everything
import torch
import torch.nn as nn
import os
import argparse

# Seed value handed to the project helper ``seed_everything`` at import time.
seed = 42
seed_everything(seed=seed)

# GPU indices exported to CUDA through the CUDA_VISIBLE_DEVICES variable.
device_ids = [0, 1, 2, 3]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in device_ids)


def get_parseargs():
    """Declare the training command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option below, holding either
        the value given on the command line or the listed default.
    """
    # (flag, add_argument keyword arguments), kept in declaration order so
    # the generated --help output lists the options in the same sequence.
    option_table = [
        ("--epochs", {"type": int, "default": 10}),
        ("--max_length", {"type": int, "default": 510}),
        ("--data_size", {"type": int, "default": 1}),
        ("--hidden_size", {"type": int, "default": 1024}),
        ("--batch_size", {"type": int, "default": 8}),
        ("--lr", {"type": float, "default": 3e-5}),
        ("--weight_decay", {"type": float, "default": 0}),
        ("--hidden_dropout_prob", {"type": float, "default": 0.1}),
        ("--warmup_proportion", {"type": float, "default": 1e-1}),
        ("--accumulation_steps", {"type": int, "default": 2}),
        (
            "--language",
            {
                "type": str,
                "default": "en_roberta",
                "choices": ["zh", "en_roberta", "en_bert"],
            },
        ),
        ("--data_name", {"type": str, "default": "weibo"}),
        ("--pre_model", {"type": str, "default": "bert-base-cased"}),
        (
            "--model_type",
            {
                "type": str,
                "default": "roberta",
                "choices": ["roberta", "bert"],
            },
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()


def bert_softmax(args):
    """Fine-tune the NER model with a cross-entropy loss and keep the best checkpoint.

    Builds the tokenizer, the train/dev ``NerDataLoader`` objects and the
    ``NerModel`` from ``args``, trains for ``args.epochs`` epochs under CUDA
    autocast with gradient accumulation, evaluates on the dev split after
    every epoch and calls the model's ``save`` method whenever the dev F1 is
    at least as high as the best value seen so far.

    Requires a CUDA device: the model and every batch are moved with
    ``.cuda()``.

    Args:
        args: namespace with the attributes defined by ``get_parseargs``
            (``epochs``, ``max_length``, ``data_size``, ``hidden_size``,
            ``batch_size``, ``lr``, ``weight_decay``, ``hidden_dropout_prob``,
            ``warmup_proportion``, ``accumulation_steps``, ``language``,
            ``data_name``, ``pre_model``, ``model_type``).

    Returns:
        None. Progress, losses and the best epoch/F1 are printed.
    """
    epochs = args.epochs
    max_length = args.max_length
    data_size = args.data_size
    hidden_size = args.hidden_size
    batch_size = args.batch_size
    bi = data_size * 100
    lr = args.lr
    weight_decay = args.weight_decay
    hidden_dropout_prob = args.hidden_dropout_prob
    warmup_proportion = args.warmup_proportion
    accumulation_steps = args.accumulation_steps
    language = args.language
    data_name = args.data_name
    pre_model = args.pre_model
    model_type = args.model_type

    assert language in ["zh", "en_roberta", "en_bert"]

    # ---- paths -----------------------------------------------------------
    data_dir = os.path.join(".", "data", "ner_data", data_name, "train.ner.bio")
    dev_dir = os.path.join(".", "data", "ner_data", data_name, "dev.ner.bio")
    vec_dir = os.path.join(".", "models", "pretrained_models", pre_model)
    # Checkpoint path handed to Evaluater; how Evaluater uses it is not
    # visible from this file.
    dev_model = os.path.join(
        ".",
        "models",
        "bertNer",
        "{}_model".format(data_name),
        "bertner_768_37_1_90.pkl",
    )
    # Destination passed to the model's save() whenever the dev F1 improves.
    save_dir = os.path.join(
        ".", "models", "{}Ner".format(model_type), "{}_model".format(data_name)
    )

    # ---- tokenizer -------------------------------------------------------
    if language in ("zh", "en_roberta"):
        tokenizer = AutoTokenizer.from_pretrained(vec_dir)
    else:
        # "en_bert": cased WordPiece tokenizer built straight from vocab.txt.
        tokenizer = BertWordPieceTokenizer(
            os.path.join(vec_dir, "vocab.txt"), lowercase=False
        )

    # ---- training data ---------------------------------------------------
    train_set = NerDataLoader(
        tokenizer,
        in_dir=data_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        data_size=data_size,
        language=language,
    )
    _, label_to_index, index_to_label = train_set.get_soups()

    print(label_to_index, len(index_to_label))

    dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    # ---- model -----------------------------------------------------------
    model = NerModel(
        vec_dir=vec_dir,
        hidden_size=hidden_size,
        bi=bi,
        label_size=len(index_to_label),
        hidden_dropout_prob=hidden_dropout_prob,
    )

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs")
        model = DataParallel(model)
    model = model.cuda()

    # ---- optimisation ----------------------------------------------------
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    num_batches = len(dataloader)
    # The scheduler is advanced once per batch (not once per optimizer step),
    # so the schedule length is counted in batches as well.
    total_steps = num_batches * epochs
    warmup_steps = int(warmup_proportion * total_steps)
    scheduler = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    criterion = nn.CrossEntropyLoss()

    # ---- dev data / evaluator --------------------------------------------
    dev_dataloader = NerDataLoader(
        tokenizer,
        in_dir=dev_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        language=language,
    )
    # Reuse the label vocabulary produced from the training split.
    dev_dataloader.get_soups(
        label_to_index=label_to_index, index_to_label=index_to_label, dev=True
    )
    evaluate = Evaluater(dev_dataloader, dev_model, max_length=max_length)

    # Print the running loss roughly ten times per epoch. Clamp to 1 so that
    # fewer than 10 batches cannot produce a modulo-by-zero.
    log_every = max(1, num_batches // 10)

    best_f1, best_epoch = 0, 0
    for e in MyIter(range(epochs)):
        print("epoch:{}".format(e))
        # The same model object is handed to the evaluator after each epoch;
        # re-enter train mode in case evaluation switched it to eval mode.
        model.train()
        running_loss = 0
        for index, (text, label, _) in enumerate(dataloader):
            text, label = text.cuda(), label.cuda()
            # NOTE(review): autocast is used without a GradScaler; consider
            # adding loss scaling if fp16 gradient underflow is observed.
            with autocast():
                # Positions whose token id is 0 are excluded from the loss.
                # NOTE(review): this assumes NerDataLoader pads with id 0;
                # RoBERTa vocabularies usually reserve another id for <pad> —
                # confirm against the data loader.
                loss_mask = (text != 0).reshape(-1)
                logits = model(text, label)
                logits = logits.reshape(-1, logits.size(-1))
                active_logits = logits[loss_mask]
                active_label = label.reshape(-1)[loss_mask]
                # Scale so that the accumulated gradient averages the group.
                batch_loss = criterion(active_logits, active_label) / accumulation_steps
            # Backward runs outside the autocast region, as recommended for AMP.
            batch_loss.backward()
            # Detach so the logged sum does not keep autograd graphs alive.
            running_loss = running_loss + batch_loss.detach()

            # Step on every full accumulation group, and also on the final
            # batch so a trailing partial group is applied instead of leaking
            # stale gradients into the next epoch.
            end_of_group = (index + 1) % accumulation_steps == 0
            end_of_epoch = (index + 1) == num_batches
            if end_of_group or end_of_epoch:
                opt.step()
                opt.zero_grad(set_to_none=True)
            scheduler.step()

            if index % log_every == 0:
                print("loss:{:.5f}".format(running_loss))
                running_loss = 0

        evaluate.model = model
        dev_f1 = evaluate.evaluate_ner()

        # "<=" keeps the most recent epoch when the score ties.
        if best_f1 <= dev_f1:
            best_f1 = dev_f1
            best_epoch = e

            # Save the underlying module, not the DataParallel wrapper.
            origin_model = model.module if isinstance(model, DataParallel) else model
            origin_model.save(model_type, save_dir, origin_model)

    print("best epoch: {} f1_score: {:.4f}".format(best_epoch, best_f1))


def bert_softmax_lowersource(args):
    """Run the softmax NER training pipeline via ``bert_softmax``.

    This entry point uses exactly the same data loading, model, optimisation
    and evaluation steps as ``bert_softmax`` and therefore delegates to it
    rather than carrying a second copy of the loop.

    Args:
        args: namespace with the attributes defined by ``get_parseargs``.
            The supplied namespace is used as given; only when ``None`` is
            passed are the options parsed from the command line.

    Returns:
        None.
    """
    # Honour the caller's arguments; fall back to CLI parsing only when
    # nothing was supplied.
    if args is None:
        args = get_parseargs()
    bert_softmax(args)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and start the training run.
    bert_softmax(get_parseargs())
    # Alternative entry point: bert_softmax_lowersource(get_parseargs())
