# -*- encoding:utf8 -*-
from init import MyIter, NerDataLoader, KnnNer, CharWordKnnNer, Evaluater
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from common import seed_everything
import numpy as np
import dill as pickle
import os
import time
import torch

# Restrict the visible GPUs before anything can create a CUDA context: once the
# context exists, later changes to CUDA_VISIBLE_DEVICES have no effect. Doing it
# before seeding guards against seed_everything (from `common`, not visible
# here) touching CUDA.
device_ids = [0, 1, 2]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))

# Fixed seed so evaluation runs are reproducible.
seed = 42
seed_everything(seed=seed)


def evaluate_bert_softmax():
    """Evaluate a pickled ChineseBERT NER checkpoint with plain softmax decoding.

    The label <-> index maps are built from the training split of ``data_name``
    and reused for the test split, so predicted indices decode with the same
    vocabulary. The test split is then scored with ``Evaluater.evaluate_ner``.

    Needs CUDA: the model is moved to the GPU and wrapped in ``DataParallel``
    when more than one device is visible.
    """
    max_length = 510
    data_name = "msra"
    language = "zh"
    assert language in ["zh"]

    data_dir = os.path.join(".", "data", "ner_data", data_name, "train.ner.bio")
    vec_dir = "iioSnail/ChineseBERT-large"
    cache_dir = os.path.join(".", "models", "pretrained_models", "chinesebert-large")
    dev_model = os.path.join(
        ".",
        "models",
        "chinesebertNer",
        "{}_model".format(data_name),
        "chinesebert-largener_23236_7_1_100.pkl",
    )

    # ChineseBERT tokenizer; trust_remote_code=True lets the hub repository's
    # own tokenizer code run.
    tokenizer = AutoTokenizer.from_pretrained(
        vec_dir, trust_remote_code=True, cache_dir=cache_dir
    )

    # Derive the label vocabulary from the training split.
    dataloader = NerDataLoader(
        tokenizer,
        in_dir=data_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        language=language,
    )
    _, label_to_index, index_to_label = dataloader.get_soups()
    print(label_to_index, len(index_to_label))

    # NOTE(review): dill/pickle can execute arbitrary code while loading; only
    # load checkpoints from a trusted source. The context manager closes the
    # file handle instead of leaking it.
    with open(dev_model, "rb") as model_file:
        model = pickle.load(model_file)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs")
        model = DataParallel(model)
    model = model.cuda()

    dev_dir = os.path.join(".", "data", "ner_data", data_name, "test.ner.bio")
    dev_dataloader = NerDataLoader(
        tokenizer,
        in_dir=dev_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        language=language,
    )
    # Reuse the training label maps for the evaluation split.
    dev_dataloader.get_soups(
        label_to_index=label_to_index, index_to_label=index_to_label, dev=True
    )

    # The checkpoint path is passed through, but the already-loaded (and
    # possibly DataParallel-wrapped) model is assigned directly.
    evaluate = Evaluater(dev_dataloader, dev_model, max_length=max_length)
    evaluate.model = model
    evaluate.evaluate_ner()


def evaluate_bert_softmax_knn():
    """Evaluate a ChineseBERT NER checkpoint with and without kNN retrieval.

    A ``CharWordKnnNer`` datastore is built from the training split of
    ``data_name`` and fed a second copy of the evaluation split from
    ``word_dir`` (presumably word-segmented, per the hanlp note). The split is
    scored once with plain softmax decoding and once with kNN decoding. Then the
    kNN ``t`` and ``λ`` settings are swept one at a time, printing each value
    before it is scored.

    Needs CUDA: the model is moved to the GPU and wrapped in ``DataParallel``
    when more than one device is visible.
    """
    max_length = 510
    language = "zh"
    data_name = "weibo"
    data_size = 1  # size of the datastore

    # kNN settings forwarded to CharWordKnnNer. Presumably k = neighbour count,
    # t = temperature, λ = kNN/model mixing weight -- confirm in init.
    k = 512
    t = 0.1520
    λ = 0.0044

    assert language in ["zh"]

    data_dir = os.path.join(".", "data", "ner_data", data_name, "train.ner.bio")
    word_dir = os.path.join(
        ".", "data", "ner_data", data_name, "testword.ner.bio"
    )  # testword.ner.bmes: hanlp

    # NOTE(review): despite the "dev" naming this points at the *test* split,
    # so the sweeps below tune t/λ directly on test data.
    dev_dir = os.path.join(".", "data", "ner_data", data_name, "test.ner.bio")
    vec_dir = "iioSnail/ChineseBERT-large"
    cache_dir = os.path.join(".", "models", "pretrained_models", "chinesebert-large")
    dev_model = os.path.join(
        ".",
        "models",
        "chinesebertNer",
        "{}_model".format(data_name),
        "chinesebert-largener_23236_17_1_100.pkl",
    )

    # ChineseBERT tokenizer; trust_remote_code=True lets the hub repository's
    # own tokenizer code run.
    tokenizer = AutoTokenizer.from_pretrained(
        vec_dir, trust_remote_code=True, cache_dir=cache_dir
    )

    # Training split: source of the label vocabulary and of the kNN datastore.
    dataloader = NerDataLoader(
        tokenizer,
        in_dir=data_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        data_size=data_size,
        language=language,
    )
    dataloader.load()
    label_to_index, index_to_label = (
        dataloader.label_to_index,
        dataloader.index_to_label,
    )
    dataloader.tokenize(dev=True)
    # _, label_to_index, index_to_label = dataloader.get_soups()

    print(label_to_index, len(index_to_label))

    # NOTE(review): dill/pickle can execute arbitrary code while loading; only
    # load checkpoints from a trusted source. The context manager closes the
    # file handle instead of leaking it.
    with open(dev_model, "rb") as model_file:
        model = pickle.load(model_file)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs")
        model = DataParallel(model)
    model = model.cuda()

    knn = CharWordKnnNer(
        tokenizer,
        model,
        dataloader=dataloader,
        label_to_index=label_to_index,
        index_to_label=index_to_label,
        k=k,
        t=t,
        λ=λ,
        language=language,
    )
    knn.get_entitys()
    knn.load_word(dev_dir=dev_dir, word_dir=word_dir)
    print(knn.entitys.shape, knn.entitys_label.shape)

    dev_dataloader = NerDataLoader(
        tokenizer,
        in_dir=dev_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        language=language,
    )
    # Reuse the training label maps for the evaluation split.
    dev_dataloader.get_soups(
        label_to_index=label_to_index, index_to_label=index_to_label, dev=True
    )
    evaluate = Evaluater(dev_dataloader, model_dir=dev_model, max_length=max_length)
    # Use the already-loaded (possibly DataParallel-wrapped) model.
    evaluate.model = model
    evaluate.evaluate_ner()
    evaluate.evaluate_knnner(knn)

    # for λ in np.linspace(0.01, 0.95, 30):
    #     for t in np.linspace(0.01, 0.95, 30):
    #         knn.t = t
    #         knn.λ = λ
    #         print('t:{:.4f} λ: {:.4f}'.format(t, λ))
    #         evaluate.evaluate_knnner(knn)

    # Sweep t with λ held fixed.
    knn.λ = 0.0027
    for t in MyIter(np.linspace(0.001, 0.3, 60)):
        knn.t = t
        print("t: {:.4f}".format(t))
        evaluate.evaluate_knnner(knn)

    # knn.t = 0.0512
    # for k in MyIter(range(100, 1000, 100)):
    #     knn.k = k
    #     print('k: {}'.format(k))
    #     evaluate.evaluate_knnner(knn)

    # Sweep λ with t held fixed.
    knn.t = 0.0577
    for λ in np.linspace(0.001, 0.1, 60):
        knn.λ = λ
        print("λ: {:.4f}".format(λ))
        evaluate.evaluate_knnner(knn)


def evaluate_bert_softmax_onlyknn():
    """Evaluate a ChineseBERT NER checkpoint with a character-level kNN datastore.

    A ``KnnNer`` datastore is built from the training split of ``data_name``.
    The evaluation split is scored with softmax and with kNN decoding, and then
    a 10x10 grid over (``t``, ``λ``) is run. The best-scoring pair, using the F1
    value returned by ``Evaluater.evaluate_knnner``, is printed and re-evaluated
    at the end.

    Needs CUDA: the model is moved to the GPU and wrapped in ``DataParallel``
    when more than one device is visible.
    """
    max_length = 510
    language = "zh"
    data_name = "ontonote4"
    data_size = 1  # size of the datastore

    # Initial kNN settings forwarded to KnnNer. Presumably k = neighbour count,
    # t = temperature, λ = kNN/model mixing weight -- confirm in init.
    k = 256
    t = 0.2155
    λ = 0.9056

    assert language in ["zh"]

    data_dir = os.path.join(".", "data", "ner_data", data_name, "train.ner.bio")
    # Not used by this function; kept alongside the sibling kNN evaluation.
    word_dir = os.path.join(
        ".", "data", "ner_data", data_name, "testword.ner.bio"
    )  # testword.ner.bmes: hanlp

    # NOTE(review): despite the "dev" naming this points at the *test* split,
    # so the grid search below selects t/λ on test data.
    dev_dir = os.path.join(".", "data", "ner_data", data_name, "test.ner.bio")
    vec_dir = "iioSnail/ChineseBERT-large"
    cache_dir = os.path.join(".", "models", "pretrained_models", "chinesebert-large")
    dev_model = os.path.join(
        ".",
        "models",
        "chinesebertNer",
        "{}_model".format(data_name),
        "chinesebert-largener_23236_9_1_100.pkl",
    )

    # ChineseBERT tokenizer; trust_remote_code=True lets the hub repository's
    # own tokenizer code run.
    tokenizer = AutoTokenizer.from_pretrained(
        vec_dir, trust_remote_code=True, cache_dir=cache_dir
    )

    # Training split: source of the label vocabulary and of the kNN datastore.
    dataloader = NerDataLoader(
        tokenizer,
        in_dir=data_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        data_size=data_size,
        language=language,
    )
    dataloader.load()
    label_to_index, index_to_label = (
        dataloader.label_to_index,
        dataloader.index_to_label,
    )
    dataloader.tokenize(dev=True)
    # _, label_to_index, index_to_label = dataloader.get_soups()

    print(label_to_index, len(index_to_label))

    # NOTE(review): dill/pickle can execute arbitrary code while loading; only
    # load checkpoints from a trusted source. The context manager closes the
    # file handle instead of leaking it.
    with open(dev_model, "rb") as model_file:
        model = pickle.load(model_file)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs")
        model = DataParallel(model)
    model = model.cuda()

    knn = KnnNer(
        tokenizer,
        model,
        dataloader=dataloader,
        label_to_index=label_to_index,
        index_to_label=index_to_label,
        k=k,
        t=t,
        λ=λ,
    )
    knn.get_entitys()
    print(knn.entitys.shape, knn.entitys_label.shape)

    dev_dataloader = NerDataLoader(
        tokenizer,
        in_dir=dev_dir,
        vec_dir=vec_dir,
        max_length=max_length,
        language=language,
    )
    # Reuse the training label maps for the evaluation split.
    dev_dataloader.get_soups(
        label_to_index=label_to_index, index_to_label=index_to_label, dev=True
    )
    evaluate = Evaluater(dev_dataloader, model_dir=dev_model, max_length=max_length)
    # Use the already-loaded (possibly DataParallel-wrapped) model.
    evaluate.model = model
    evaluate.evaluate_ner()
    evaluate.evaluate_knnner(knn)

    # Grid search over (t, λ). `<=` means ties go to the later setting.
    best_f1, best_t, best_λ = 0, 0, 0
    for t in np.linspace(0.0001, 0.5, 10):
        for λ in np.linspace(0.001, 0.95, 10):
            knn.t = t
            knn.λ = λ
            print("t:{:.4f} λ: {:.4f}".format(t, λ))
            dev_f1 = evaluate.evaluate_knnner(knn)

            if best_f1 <= dev_f1:
                best_f1 = dev_f1
                best_t = t
                best_λ = λ
            # Report the setting just scored alongside the best F1 seen so far
            # (including this one).
            print(
                "t:{:.4f} λ: {:.4f} f1: {:.4f} best_f1: {:.4f}".format(
                    t, λ, dev_f1, best_f1
                )
            )
    print("t:{:.4f} λ: {:.4f} f1: {:.4f}".format(best_t, best_λ, best_f1))
    knn.t = best_t
    knn.λ = best_λ
    evaluate.evaluate_knnner(knn)

    # knn.λ = 0.8892
    # for t in MyIter(np.linspace(0.001, 0.6, 30)):
    #     knn.t = t
    #     print('t: {:.4f}'.format(t))
    #     evaluate.evaluate_knnner(knn)

    # knn.t = 0.0512
    # for k in MyIter(range(100, 1000, 100)):
    #     knn.k = k
    #     print('k: {}'.format(k))
    #     evaluate.evaluate_knnner(knn)

    # knn.t = 0.0577
    # for λ in np.linspace(0.001, 1, 10):
    #     knn.λ = λ
    #     print('λ: {:.4f}'.format(λ))
    #     evaluate.evaluate_knnner(knn)


if __name__ == "__main__":
    # Exactly one evaluation runs per invocation. Point `run_evaluation` at
    # evaluate_bert_softmax or evaluate_bert_softmax_onlyknn to switch.
    run_evaluation = evaluate_bert_softmax_knn
    run_evaluation()
