# -*- encoding:utf8 -*-
from init import (
    MyIter,
    NerDataLoader,
    KnnNer,
    CharWordKnnNer,
    Evaluater,
)
from transformers import AutoModel, AutoTokenizer
from tokenizers import BertWordPieceTokenizer
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from common import seed_everything
import numpy as np
import dill as pickle
import torch
import os
import time

seed = 42
# GPUs this script may use. Inside the process they are renumbered, so physical
# GPUs 1, 2, 3 show up as cuda:0, cuda:1, cuda:2.
# Shell equivalent: CUDA_VISIBLE_DEVICES=1,2,3 python classes/ner_tester.py
device_ids = [1, 2, 3]
# CUDA reads this variable only once, when it first initialises. Set it before
# calling the seeding helper, which may seed CUDA RNGs, so the mask is
# guaranteed to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
seed_everything(seed=seed)


def evaluate_bert_softmax():
    """Evaluate a pickled BERT+softmax NER model on the CoNLL-03 test split.

    The training split is loaded only to build the label vocabulary
    (``label_to_index`` / ``index_to_label``). That vocabulary is reused when
    preparing the test split, so label ids match the ones the model was
    trained with. The model is moved to GPU, wrapped in ``DataParallel`` when
    more than one GPU is visible, and scored with ``Evaluater.evaluate_ner``.
    """
    max_length = 510
    language = "en_bert"
    data_name = "en_conll03"
    assert language in ["zh", "en_roberta", "en_bert"]

    data_dir = os.path.join(".", "data", "ner_data", data_name, "train.ner.bio")
    vec_dir = os.path.join(".", "models", "pretrained_models", "bert-base-cased")
    dev_model = os.path.join(
        ".",
        "models",
        "bertNer",
        "{}_model".format(data_name),
        "bertner_768_9_1_50.pkl",
    )

    if language in ("zh", "en_roberta"):
        # zh / en_roberta: HuggingFace tokenizer loaded from the model directory.
        tokenizer = AutoTokenizer.from_pretrained(vec_dir)
    else:
        # en_bert: WordPiece tokenizer built from vocab.txt, case preserved.
        tokenizer = BertWordPieceTokenizer(
            os.path.join(vec_dir, "vocab.txt"), lowercase=False
        )

    # Training split: used only to derive the label <-> index mappings.
    train_loader = NerDataLoader(
        tokenizer,
        in_dir=data_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        language=language,
    )
    _, label_to_index, index_to_label = train_loader.get_soups()

    print(label_to_index, len(index_to_label))

    # dill can execute arbitrary code while unpickling: only load trusted
    # checkpoints. Use `with` so the file handle is closed promptly.
    with open(dev_model, "rb") as model_file:
        model = pickle.load(model_file)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs")
        model = DataParallel(model)
    model = model.cuda()

    dev_dir = os.path.join(".", "data", "ner_data", data_name, "test.ner.bio")
    dev_dataloader = NerDataLoader(
        tokenizer,
        in_dir=dev_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        language=language,
    )
    # Reuse the training label vocabulary so label ids agree with the model.
    dev_dataloader.get_soups(
        label_to_index=label_to_index, index_to_label=index_to_label, dev=True
    )

    evaluate = Evaluater(dev_dataloader, dev_model, max_length=max_length)
    # Override with the already-loaded (and possibly DataParallel-wrapped) model.
    evaluate.model = model
    evaluate.evaluate_ner()


def evaluate_bert_softmax_knn():
    """Evaluate BERT+softmax NER on Weibo, with and without char/word kNN retrieval.

    Steps:
      1. Load the training split as the kNN datastore (``data_size`` sets its
         size) and take the label vocabulary from it.
      2. Load the pickled NER model onto GPU (``DataParallel`` when more than
         one GPU is visible).
      3. Build a ``CharWordKnnNer`` over the datastore, extract its entity
         representations and load the word-level file for the test split
         (``testword.ner.bio``, produced with HanLP).
      4. Score the test split with the plain model and with kNN. Then sweep
         ``t`` with ``λ`` fixed, and ``λ`` with ``t`` fixed, printing each
         setting before it is evaluated.
    """
    max_length = 510
    language = "zh"
    data_size = 1  # size of the datastore
    data_name = "weibo"

    k = 512
    t = 0.1526
    λ = 0.0027

    assert language in ["zh", "en_roberta", "en_bert"]

    data_dir = os.path.join(".", "data", "ner_data", data_name, "train.ner.bio")
    # testword.ner.bio: word-level version of the test split, built with HanLP.
    word_dir = os.path.join(".", "data", "ner_data", data_name, "testword.ner.bio")

    dev_dir = os.path.join(".", "data", "ner_data", data_name, "test.ner.bio")

    vec_dir = os.path.join(".", "models", "pretrained_models", "bert-base-chinese")
    dev_model = os.path.join(
        ".",
        "models",
        "bertNer",
        "{}_model".format(data_name),
        "bertner_768_17_1_100.pkl",
    )

    if language in ("zh", "en_roberta"):
        # zh / en_roberta: HuggingFace tokenizer loaded from the model directory.
        tokenizer = AutoTokenizer.from_pretrained(vec_dir)
    else:
        # en_bert: WordPiece tokenizer built from vocab.txt, case preserved.
        tokenizer = BertWordPieceTokenizer(
            os.path.join(vec_dir, "vocab.txt"), lowercase=False
        )

    # Datastore: the training split, loaded and tokenized in dev mode.
    dataloader = NerDataLoader(
        tokenizer,
        in_dir=data_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        data_size=data_size,
        language=language,
    )
    dataloader.load()
    label_to_index, index_to_label = (
        dataloader.label_to_index,
        dataloader.index_to_label,
    )
    dataloader.tokenize(dev=True)
    # _, label_to_index, index_to_label = dataloader.get_soups()

    print(label_to_index, len(index_to_label))

    # dill can execute arbitrary code while unpickling: only load trusted
    # checkpoints. Use `with` so the file handle is closed promptly.
    with open(dev_model, "rb") as model_file:
        model = pickle.load(model_file)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs")
        model = DataParallel(model)
    model = model.cuda()

    knn = CharWordKnnNer(
        tokenizer,
        model,
        dataloader=dataloader,
        label_to_index=label_to_index,
        index_to_label=index_to_label,
        k=k,
        t=t,
        λ=λ,
        language=language,
    )
    knn.get_entitys()
    knn.load_word(dev_dir=dev_dir, word_dir=word_dir)
    print(knn.entitys.shape, knn.entitys_label.shape)

    dev_dataloader = NerDataLoader(
        tokenizer,
        in_dir=dev_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        language=language,
    )
    # Reuse the training label vocabulary so label ids agree with the model.
    dev_dataloader.get_soups(
        label_to_index=label_to_index, index_to_label=index_to_label, dev=True
    )
    evaluate = Evaluater(dev_dataloader, model_dir=dev_model, max_length=max_length)
    evaluate.model = model
    evaluate.evaluate_ner()
    evaluate.evaluate_knnner(knn)

    # for λ in np.linspace(0.01, 0.95, 30):
    #     for t in np.linspace(0.01, 0.95, 30):
    #         knn.t = t
    #         knn.λ = λ
    #         print('t:{:.4f} λ: {:.4f}'.format(t, λ))
    #         evaluate.evaluate_knnner(knn)

    # Sweep t with λ pinned to the configured value. Use the local so the two
    # settings cannot drift apart.
    knn.λ = λ
    for t in MyIter(np.linspace(0.0001, 1, 60)):
        knn.t = t
        print("t: {:.4f}".format(t))
        dev_f1 = evaluate.evaluate_knnner(knn)
        # with open(os.path.join('.', 'ablation.txt'), 'a+', encoding='utf-8') as f:
        #     f.write('t:{:.4f} λ: {:.4f} f1: {:.4f}\n'.format(knn.t, knn.λ, dev_f1))

    # knn.t = 0.0512
    # for k in MyIter(range(100, 1000, 100)):
    #     knn.k = k
    #     print('k: {}'.format(k))
    #     evaluate.evaluate_knnner(knn)

    # Sweep λ with t pinned to a separately tuned value.
    knn.t = 0.0559
    for λ in np.linspace(0, 0.2, 60):
        knn.λ = λ
        print("λ: {:.4f}".format(λ))
        evaluate.evaluate_knnner(knn)


def evaluate_bert_softmax_onlyknn():
    """Grid-search the ``KnnNer`` settings ``t`` and ``λ`` on the Weibo test split.

    The training split serves as the kNN datastore and provides the label
    vocabulary. The pickled NER model is loaded onto GPU (``DataParallel``
    when more than one GPU is visible). The plain model and the initial kNN
    configuration are scored first. Then a 10x10 grid over ``t`` and ``λ`` in
    [0.0001, 1] is evaluated. The best (t, λ, f1) is printed, written to
    ``./output.txt`` (overwriting it) and re-evaluated once more.

    Assumes ``Evaluater.evaluate_knnner`` returns a comparable dev F1 score.
    That is how its result is used below; confirm against ``Evaluater``.
    """
    max_length = 510
    language = "zh"
    data_size = 1  # size of the datastore
    data_name = "weibo"

    k = 256
    t = 0.0001
    λ = 0.6107

    assert language in ["zh", "en_roberta", "en_bert"]

    data_dir = os.path.join(".", "data", "ner_data", data_name, "train.ner.bio")
    dev_dir = os.path.join(".", "data", "ner_data", data_name, "test.ner.bio")

    vec_dir = os.path.join(".", "models", "pretrained_models", "bert-base-chinese")
    dev_model = os.path.join(
        ".",
        "models",
        "bertNer",
        "{}_model".format(data_name),
        "bertner_768_17_1_100.pkl",
    )

    if language in ("zh", "en_roberta"):
        # zh / en_roberta: HuggingFace tokenizer loaded from the model directory.
        tokenizer = AutoTokenizer.from_pretrained(vec_dir)
    else:
        # en_bert: WordPiece tokenizer built from vocab.txt, case preserved.
        tokenizer = BertWordPieceTokenizer(
            os.path.join(vec_dir, "vocab.txt"), lowercase=False
        )

    # Datastore: the training split, loaded and tokenized in dev mode.
    dataloader = NerDataLoader(
        tokenizer,
        in_dir=data_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        data_size=data_size,
        language=language,
    )
    dataloader.load()
    label_to_index, index_to_label = (
        dataloader.label_to_index,
        dataloader.index_to_label,
    )
    dataloader.tokenize(dev=True)
    # _, label_to_index, index_to_label = dataloader.get_soups()

    print(label_to_index, len(index_to_label))

    # dill can execute arbitrary code while unpickling: only load trusted
    # checkpoints. Use `with` so the file handle is closed promptly.
    with open(dev_model, "rb") as model_file:
        model = pickle.load(model_file)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs")
        model = DataParallel(model)
    model = model.cuda()

    knn = KnnNer(
        tokenizer,
        model,
        dataloader=dataloader,
        label_to_index=label_to_index,
        index_to_label=index_to_label,
        k=k,
        t=t,
        λ=λ,
    )
    knn.get_entitys()
    print(knn.entitys.shape, knn.entitys_label.shape)

    dev_dataloader = NerDataLoader(
        tokenizer,
        in_dir=dev_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        language=language,
    )
    # Reuse the training label vocabulary so label ids agree with the model.
    dev_dataloader.get_soups(
        label_to_index=label_to_index, index_to_label=index_to_label, dev=True
    )
    evaluate = Evaluater(dev_dataloader, model_dir=dev_model, max_length=max_length)
    evaluate.model = model
    evaluate.evaluate_ner()
    evaluate.evaluate_knnner(knn)

    best_f1, best_t, best_λ = 0, 0, 0
    for t in np.linspace(0.0001, 1, 10):
        for λ in np.linspace(0.0001, 1, 10):
            knn.t = t
            knn.λ = λ
            print("t:{:.4f} λ: {:.4f}".format(t, λ))
            dev_f1 = evaluate.evaluate_knnner(knn)
            # Label the score with the configuration that produced it. The
            # running best (from earlier grid points) is shown for comparison.
            print(
                "t:{:.4f} λ: {:.4f} f1: {:.4f} best_f1: {:.4f}".format(
                    t, λ, dev_f1, best_f1
                )
            )

            # `<=` lets a later grid point win a tie.
            if best_f1 <= dev_f1:
                best_f1 = dev_f1
                best_t = t
                best_λ = λ

    summary = "t:{:.4f} λ: {:.4f} f1: {:.4f}".format(best_t, best_λ, best_f1)
    print(summary)
    with open(os.path.join(".", "output.txt"), "w", encoding="utf-8") as f:
        f.write(summary)
    knn.t = best_t
    knn.λ = best_λ
    evaluate.evaluate_knnner(knn)

    # knn.λ = 0.9655
    # for t in MyIter(np.linspace(0.0001, 0.2, 65)):
    #     knn.t = t
    #     print('t: {:.4f}'.format(t))
    #     evaluate.evaluate_knnner(knn)

    # knn.t = 0.0512
    # for k in MyIter(range(100, 1000, 100)):
    #     knn.k = k
    #     print('k: {}'.format(k))
    #     evaluate.evaluate_knnner(knn)

    # knn.t = 0.00131
    # for λ in np.linspace(0.1, 0.5, 30):
    #     knn.λ = λ
    #     print('λ: {:.4f}'.format(λ))
    #     evaluate.evaluate_knnner(knn)


if __name__ == "__main__":
    # Choose which experiment to run. The alternatives are
    # evaluate_bert_softmax and evaluate_bert_softmax_knn.
    experiment = evaluate_bert_softmax_onlyknn
    experiment()
