# -*- coding: utf-8 -*-
import torch
from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast as autocast
from transformers import AutoConfig, AutoModelForTokenClassification
from tqdm import tqdm

import time
import hanlp
import faiss
import jieba
import random
import copy
import math

import os
import dill as pickle
import numpy as np
import copy


class MyIter:
    """Re-iterable wrapper that shows a tqdm progress bar on every pass.

    The wrapped iterable is stored in ``iters``; a fresh ``tqdm`` is created
    each time ``__iter__`` is called.
    """

    def __init__(self, iters):
        self.iters = iters

    def __iter__(self):
        # Delegate straight to tqdm so items are yielded unchanged.
        yield from tqdm(self.iters)


class Utils(object):
    """Turn token-level BIO files into chunk-level ("word") BIO files.

    One hanlp pipeline is loaded per instance: a coarse Chinese tokenizer for
    ``"zh"`` and a multi-task (tokenization + POS + ...) pipeline for ``"en"``.
    """

    def __init__(self, language):
        """Load the hanlp pipeline for ``language`` (``"zh"`` or ``"en"``)."""
        assert language in ["zh", "en"]
        self.language = language
        self.seg = None
        if self.language == "zh":
            self.seg = hanlp.load(hanlp.pretrained.tok.COARSE_ELECTRA_SMALL_ZH)
        else:
            self.seg = hanlp.load(
                hanlp.pretrained.mtl.UD_ONTONOTES_TOK_POS_LEM_FEA_NER_SRL_DEP_SDP_CON_XLMR_BASE
            )

    def alignTok(self, pos, tok, text):
        """Project tokenizer POS tags back onto the original words.

        Args:
            pos: POS tag per tokenizer token (parallel to ``tok``).
            tok: tokenizer tokens; a token containing spaces is split on
                ``" "`` and every piece inherits the token's tag.
            text: the original words of the sentence.

        Returns:
            One tag per word of ``text``: the tag of the first token that the
            word starts with. If the tokens cannot be concatenated back into a
            word, the tags aligned so far are returned, so the result is
            shorter than ``text`` and the caller can detect the failure by
            comparing lengths.
        """
        mid_pos, mid_tok = [], []
        for tag, token in zip(pos, tok):
            pieces = token.split(" ")
            mid_pos.extend([tag] * len(pieces))
            mid_tok.extend(pieces)
        pos, tok = mid_pos, mid_tok

        res_pos = []
        i, j = 0, 0
        while i < len(tok) and j < len(text):
            # ``pos`` is parallel to ``tok``, so it must be indexed with the
            # token cursor ``i`` (indexing with ``j`` drifts after a split).
            res_pos.append(pos[i])
            if tok[i] == text[j]:
                i += 1
            else:
                # The word was split into several tokens: glue them together.
                merged = ""
                while merged != text[j]:
                    if i >= len(tok):
                        # Tokens exhausted without rebuilding the word.
                        res_pos.pop()
                        return res_pos
                    merged += tok[i]
                    i += 1
            j += 1
        return res_pos

    def en2Word(self, pos, text):
        """Group a POS-tagged sentence into chunks.

        A chunk starts at an ADJ/NOUN/NUM/PROPN tag and extends over following
        PROPN/NOUN/NUM tags and ``-`` punctuation. Everything between chunks
        forms one filler chunk.

        Returns:
            ``(chunks, labels)`` where every chunk is the string ``"O"``
            repeated once per covered word and every label is ``"O"``.
        """
        head_tags = ("ADJ", "NOUN", "NUM", "PROPN")
        body_tags = ("PROPN", "NOUN", "NUM")

        idxs = []
        i = 0
        while i < len(pos):
            if pos[i] not in head_tags:
                i += 1
                continue
            j = i + 1
            while j < len(pos) and (
                pos[j] in body_tags or (pos[j] == "PUNCT" and text[j] == "-")
            ):
                j += 1
            idxs.append([i, j - 1])
            i = j

        covered = 0  # index of the first word not yet emitted
        chunks, labels = [], []
        for first, last in idxs:
            if first > covered:
                chunks.append("O" * (first - covered))
                labels.append("O")
            chunks.append("O" * (last - first + 1))
            labels.append("O")
            covered = last + 1
        if covered < len(pos):
            chunks.append("O" * (len(pos) - covered))
            labels.append("O")
        return chunks, labels

    def bio2Word(self, bio_dir, word_dir):
        """Write the chunk-level file ``word_dir`` for the BIO file ``bio_dir``.

        Input lines are ``token<TAB>label``; an empty line ends a sentence.
        Each sentence is cut wherever the label type (the ``B-``/``I-`` prefix
        is ignored) changes. For Chinese, ``O`` runs are further split with
        jieba and chunks are the joined characters; for English every chunk is
        ``"O"`` repeated once per word. All output labels are ``"O"``.

        The hanlp segmentation is only used as a sanity check: if it cannot be
        aligned with a sentence, a message is printed and the method returns
        ``None`` without writing anything.
        """
        with open(bio_dir, "r", encoding="utf-8") as f:
            bio = f.read().split("\n")

        all_texts, all_labels = [], []
        text, label = [], []
        for t in tqdm(bio):
            if t == "":
                all_texts.append(text)
                all_labels.append(label)
                text, label = [], []
                continue
            if t[0] == " ":
                # The token itself is a space character.
                a = " "
                b = t.split("\t")[-1]
            else:
                a, b = t.split("\t")
            if b != "O":
                b = b.split("-")[-1]  # keep only the entity type
            text.append(a)
            label.append(b)

        word_texts, word_labels = [], []
        for sent_no, (t, l) in enumerate(zip(MyIter(all_texts), all_labels)):
            if self.language == "zh":
                joined = "".join(t)
                text = self.seg(joined, coarse=True)
                if sum(len(item) for item in text) != len(t):
                    # To handle garbled words
                    text = list(jieba.cut(joined, cut_all=False))
                    if sum(len(item) for item in text) != len(t):
                        print(
                            "bio2Word aborted: segmentation does not cover "
                            "sentence {}".format(sent_no)
                        )
                        return
            else:
                seg_res = self.seg(" ".join(t))
                pos, tok = seg_res["pos"], seg_res["tok"]
                if len(pos) != len(l):  # pos and tok don't align to the text and label
                    pos = self.alignTok(pos, tok, t)
                    if len(pos) != len(t):
                        print(
                            "bio2Word aborted: tokens cannot be aligned with "
                            "sentence {}".format(sent_no)
                        )
                        return

            # Start index of every run of identical label types.
            starts = [i for i in range(len(l)) if i == 0 or l[i] != l[i - 1]]
            run_labels = [l[i] for i in starts]
            ends = starts[1:] + [len(t)]
            res = [
                "".join(t[s:e]) if self.language == "zh" else len(t[s:e]) * "O"
                for s, e in zip(starts, ends)
            ]

            text = []
            for g, x in zip(res, run_labels):
                if self.language == "zh" and x == "O":
                    # To handle garbled words
                    text.extend(jieba.cut(g, cut_all=False))
                else:
                    text.append(g)
            word_texts.append(text)
            word_labels.append(["O"] * len(text))

        blocks = []
        for t, l in zip(MyIter(word_texts), word_labels):
            blocks.append("\n".join("\t".join([i, j]) for i, j in zip(t, l)) + "\n\n")
        # Drop one trailing newline so the file ends with a single "\n".
        word_bio = "".join(blocks)[:-1]

        with open(word_dir, "w", encoding="utf-8") as f:
            f.write(word_bio)


"""faiss"""


class CharWordKnnNer(object):
    """kNN label distribution (faiss index) combined with word-level voting.

    The datastore is built from the training data and labels given as
    ``dataloader``; ``word_to_char`` maps every word chunk of an evaluated
    sentence to its token span.
    """

    def __init__(
        self,
        tokenizer,
        model,
        dataloader,
        label_to_index,
        index_to_label,
        k=128,
        t=1,
        λ=0.5,
        language="zh",
    ):
        """Store the configuration; nothing is loaded or built here.

        Args:
            tokenizer: tokenizer used by ``load_word`` to count sub-words.
            model: encoder called as ``model(text)`` to obtain token vectors.
            dataloader: dataset providing ``(text, label, index)`` items.
            label_to_index / index_to_label: label vocabulary.
            k: number of neighbours retrieved per token.
            t: temperature applied to the similarities.
            λ: interpolation weight read by the evaluation code.
            language: ``"zh"``, ``"en_roberta"`` or anything else (handled as
                a tokenizer whose ``encode`` result exposes ``.ids``).
        """
        self.tokenizer = tokenizer
        self.model = model
        self.texts = None  # list of sentences (tokens) loaded by load_bio
        self.labels = None
        self.dataloader = dataloader
        self.word_to_char = None  # per sentence: [start, end) span of each word
        self.label_to_index = label_to_index
        self.index_to_label = index_to_label
        self.k = k
        self.t = t
        self.λ = λ
        self.language = language
        self.entitys = None  # numpy [len * vec_size]
        self.entitys_label = None  # numpy [len]
        self.faiss = None

    def load_bio(self, dev_dir):
        """Read the first column of the BIO file ``dev_dir`` into ``self.texts``.

        An empty line closes a sentence; tokens after the last empty line are
        not stored.
        """
        with open(dev_dir, "r", encoding="utf-8") as f:
            alls = f.read().split("\n")
        all_text = []
        text = []
        for t in alls:
            if t == "":
                all_text.append(text)
                text = []
            else:
                text.append(t.split("\t")[0])
        self.texts = all_text

    def _subword_count(self, word):
        """Return how many sub-word tokens ``self.tokenizer`` produces for ``word``."""
        encoded = self.tokenizer.encode(word, add_special_tokens=False)
        if self.language != "en_roberta":
            # Other tokenizers return an object whose ids live in ``.ids``.
            encoded = encoded.ids
        return len(encoded)

    def load_word(self, dev_dir, word_dir):
        """Build ``self.word_to_char`` from the word file ``word_dir``.

        For ``"zh"`` a word spans one position per character. For the other
        languages each character of a chunk stands for one original word of
        ``dev_dir`` (loaded through ``load_bio``) and the span is the number
        of sub-word tokens of those words.
        """
        is_zh = self.language == "zh"
        if not is_zh:
            self.load_bio(dev_dir)
        with open(word_dir, "r", encoding="utf-8") as f:
            alls = f.read().split("\n")

        word_to_char = []
        word = []
        index = 0  # running token/sub-word offset inside the sentence
        sent, r = 0, 0  # sentence index and original-word index
        for t in alls:
            if t == "":
                word_to_char.append(word)
                word = []
                sent += 1
                r = 0
                index = 0
                continue
            th = t.split("\t")[0]
            if is_zh:
                lens = len(th)
            else:
                lens = 0
                for _ in range(len(th)):
                    lens += self._subword_count(self.texts[sent][r])
                    r += 1
            word.append([index, index + lens])
            index += lens
        self.word_to_char = word_to_char

    def get_entitys(self):
        """Build (or load from cache) the token datastore and its faiss index.

        When the cache file is missing, every token vector of the training
        data is computed with ``self.model``; all non-``O`` tokens and a random
        30% of the ``O`` tokens are pickled in chunks of 40000. The cache is
        written to a temporary file and moved into place only when complete,
        so an interrupted run never leaves a truncated cache behind.
        """
        print("Selecting Entitys...")
        entitys, entitys_label = [], []
        sep = 40000  # tokens per pickled chunk
        entity_keep, o_keep = 1, 0.3  # sampling rates for non-O / O tokens
        datastore_dir = os.path.join(".", "datastores", "weibo", "100O_30.pkl")
        if not os.path.exists(datastore_dir):
            os.makedirs(os.path.dirname(datastore_dir), exist_ok=True)
            tmp_path = datastore_dir + ".tmp"
            dataloader = DataLoader(self.dataloader, batch_size=1)
            self.model.eval()
            try:
                with torch.no_grad(), open(tmp_path, "wb") as f:
                    for text, label, _ in MyIter(dataloader):
                        text = (text[0].cuda(), text[1].cuda())
                        label = label.numpy().tolist()
                        # Fine-tuned encoder output, L2-normalised per token.
                        vectors = (
                            F.normalize(self.model(text), dim=-1).cpu().numpy().tolist()
                        )
                        for t, l in zip(vectors, label):
                            for i, j in zip(t, l):
                                if len(entitys) == sep:
                                    pickle.dump((entitys, entitys_label), f)
                                    entitys, entitys_label = [], []
                                if self.index_to_label[j] != "O":
                                    if random.uniform(0, 1) < entity_keep:
                                        entitys.append(i)
                                        entitys_label.append(j)
                                elif random.uniform(0, 1) < o_keep:
                                    entitys.append(i)
                                    entitys_label.append(j)
                    if len(entitys) != 0:
                        pickle.dump((entitys, entitys_label), f)
                        entitys, entitys_label = [], []
            finally:
                self.model.train()
            os.replace(tmp_path, datastore_dir)

        # NOTE: unpickling can execute arbitrary code; only use trusted caches.
        with open(datastore_dir, "rb") as f:
            while True:
                try:
                    e, l = pickle.load(f)
                except EOFError:
                    break
                entitys.extend(e)
                entitys_label.extend(l)

        # Count from the loaded labels so the figures are also right when the
        # datastore came from the cache.
        O = sum(1 for j in entitys_label if self.index_to_label[j] == "O")
        entitys = np.array(entitys, dtype=np.float32)
        entitys_label = np.array(entitys_label)
        print("!O_size:{} O_size: {}".format(len(entitys) - O, O))

        start = time.time()
        index = faiss.index_factory(
            entitys.shape[-1], "IVF9784_HNSW64,Flat", faiss.METRIC_INNER_PRODUCT
        )

        if not index.is_trained:
            # Run the IVF clustering on GPU.
            index_ivf = faiss.extract_index_ivf(index)
            clustering_index = faiss.index_cpu_to_all_gpus(
                faiss.IndexFlatIP(index_ivf.d)
            )
            index_ivf.clustering_index = clustering_index
            index.train(entitys)
        index.add(entitys)
        print(
            "Building datastore {:.2f}s storesize: {}".format(
                time.time() - start, len(entitys)
            )
        )

        self.faiss = index
        self.entitys = entitys
        self.entitys_label = entitys_label

    def get_maxword(self, labels):
        """Return one shared distribution per token for the majority label.

        ``labels`` holds entity types (or ``"O"``) of the tokens of one word.
        The most frequent one gets probability 0.5 on both its ``B-`` and
        ``I-`` index, or 1 on ``O``. An empty input is returned unchanged.
        """
        if not labels:
            return labels
        distribution = [0] * len(self.index_to_label)
        label = max(labels, key=labels.count)
        if label != "O":
            distribution[self.label_to_index["B-" + label]] = 0.5
            distribution[self.label_to_index["I-" + label]] = 0.5
        else:
            distribution[self.label_to_index[label]] = 1
        return [distribution for _ in range(len(labels))]

    def get_averageword(self, dis):
        """Return the mean of ``dis`` (one row per token) repeated per token."""
        if not dis:
            return dis
        distribution = np.mean(np.array(dis), axis=0).tolist()
        return [distribution for _ in range(len(dis))]

    def caculate_word_distribution(self, line, dis):
        """Word-level distribution for sentence ``line``.

        ``dis`` holds the per-token distributions without the two special
        positions; zero rows are added for them before the softmax.
        """
        blank = [0] * len(self.index_to_label)
        label = [self.index_to_label[i.index(max(i))] for i in dis]
        label = [i if i == "O" else i.split("-")[-1] for i in label]

        distributions = [blank]
        for w in self.word_to_char[line]:
            distributions.extend(self.get_maxword(label[w[0] : w[-1]]))
        distributions.append([0] * len(self.index_to_label))

        return F.softmax(torch.tensor(distributions, dtype=torch.float).cuda(), dim=-1)

    def caculate_knn_distribution(self, query, line, sequence_mask):
        """Combine the kNN distribution of ``query`` with the word distribution.

        Args:
            query: token vectors, one row per position (searched in faiss).
            line: sentence index used to look up ``word_to_char``.
            sequence_mask: predicted label index per position; neighbours are
                only accumulated where the prediction is not ``O``.

        Returns:
            Softmax of (kNN softmax + word distribution) with a leading batch
            dimension of 1.
        """
        query = (
            F.normalize(query.cpu(), dim=-1).numpy().astype("float32")
        )  # [sequence_size, dim_size]
        sim, index = self.faiss.search(query, self.k)

        distributions = []
        for idx, (s, x) in enumerate(zip(sim, index)):
            distribution = [0] * len(self.index_to_label)
            if self.index_to_label[sequence_mask[idx]] != "O":
                for i, j in zip(s, x):
                    if j != -1:  # faiss pads missing neighbours with -1
                        # NOTE(review): the index metric is inner product, so
                        # exp(-sim / t) weights closer neighbours less -
                        # confirm this is intended.
                        distribution[self.entitys_label[j]] += math.exp((-i / self.t))
            distributions.append(distribution)

        knn_part = F.softmax(
            torch.tensor(distributions, dtype=torch.float).cuda(), dim=-1
        )
        word_part = self.caculate_word_distribution(line, distributions[1:-1])

        # Final normalisation over the label dimension.
        return F.softmax(torch.unsqueeze(knn_part + word_part, dim=0), dim=-1)


"""topk"""


class KnnNer(object):
    """Exact top-k kNN label distribution computed with dense tensor ops.

    The datastore is built from the training data and labels given as
    ``dataloader`` and kept on the GPU.
    """

    def __init__(
        self,
        tokenizer,
        model,
        dataloader,
        label_to_index,
        index_to_label,
        k=128,
        t=1,
        λ=0.5,
    ):
        """Store the configuration; the datastore is built by ``get_entitys``.

        Args:
            tokenizer: stored, not used by this class itself.
            model: encoder called as ``model(text)`` to obtain token vectors.
            dataloader: dataset providing ``(text, label, index)`` items.
            label_to_index / index_to_label: label vocabulary.
            k: neighbours kept per token (``-1`` keeps all).
            t: softmax temperature applied to the similarities.
            λ: interpolation weight read by the evaluation code.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.texts = None  # list
        self.labels = None
        self.dataloader = dataloader
        self.label_to_index = label_to_index
        self.index_to_label = index_to_label
        self.k = k
        self.t = t
        self.λ = λ
        self.entitys = None  # tensor [token_num, feature_size]
        self.entitys_label = None  # tensor [token_num]
        self.faiss = None
        self.norm_1 = None  # datastore vector norms [1, token_num]
        self.norm_2 = None  # kept for backward compatibility; not used here

    def get_entitys(self):
        """Build (or load from cache) the token datastore.

        When the cache file is missing, every token vector of the training
        data is computed with ``self.model`` and pickled in chunks of 40000.
        The cache is written to a temporary file and moved into place only
        when complete, so an interrupted run never leaves a truncated cache.
        """
        print("Selecting Entitys...")
        entitys, entitys_label = [], []
        sep = 40000  # tokens per pickled chunk
        entity_keep, o_keep = 1, 1  # sampling rates for non-O / O tokens
        datastore_dir = os.path.join(
            ".", "datastores", "ontonote4", "chinese-100O_100.pkl"
        )
        if not os.path.exists(datastore_dir):
            os.makedirs(os.path.dirname(datastore_dir), exist_ok=True)
            tmp_path = datastore_dir + ".tmp"
            dataloader = DataLoader(self.dataloader, batch_size=1)
            self.model.eval()
            try:
                with torch.no_grad(), open(tmp_path, "wb") as f:
                    for text, label, _ in MyIter(dataloader):
                        text = (text[0].cuda(), text[1].cuda())
                        label = label.numpy().tolist()
                        # Fine-tuned encoder output, L2-normalised per token.
                        vectors = (
                            F.normalize(self.model(text), dim=-1).cpu().numpy().tolist()
                        )
                        for t, l in zip(vectors, label):
                            for i, j in zip(t, l):
                                if len(entitys) == sep:
                                    pickle.dump((entitys, entitys_label), f)
                                    entitys, entitys_label = [], []
                                if self.index_to_label[j] != "O":
                                    if random.uniform(0, 1) < entity_keep:
                                        entitys.append(i)
                                        entitys_label.append(j)
                                elif random.uniform(0, 1) < o_keep:
                                    entitys.append(i)
                                    entitys_label.append(j)
                    if len(entitys) != 0:
                        pickle.dump((entitys, entitys_label), f)
                        entitys, entitys_label = [], []
            finally:
                self.model.train()
            os.replace(tmp_path, datastore_dir)

        # NOTE: unpickling can execute arbitrary code; only use trusted caches.
        with open(datastore_dir, "rb") as f:
            while True:
                try:
                    e, l = pickle.load(f)
                except EOFError:
                    break
                entitys.extend(e)
                entitys_label.extend(l)

        # Count from the loaded labels so the figures are also right when the
        # datastore came from the cache.
        O = sum(1 for j in entitys_label if self.index_to_label[j] == "O")
        print("!O_size:{} O_size: {}".format(len(entitys) - O, O))
        entitys = torch.tensor(entitys, dtype=torch.float32).cuda()
        entitys_label = torch.tensor(entitys_label).cuda()

        self.entitys = entitys  # [token_num, feature_size]
        self.entitys_label = entitys_label  # [token_num]
        self.norm_1 = (
            (entitys.t() ** 2).sum(dim=0, keepdim=True).sqrt()
        )  # [1, token_num]

    def caculate_knn_distribution(self, query, line, sequence_mask):
        """Return the kNN label distribution for every position of ``query``.

        Args:
            query: token vectors ``[seq_len, feature_size]``; a batch
                dimension of 1 is added internally.
            line: unused; accepted for signature parity with
                ``CharWordKnnNer.caculate_knn_distribution``.
            sequence_mask: unused; accepted for the same reason.

        Returns:
            Tensor ``[1, seq_len, num_labels]``: the softmax over the cosine
            similarities (divided by ``t``) of the ``k`` nearest datastore
            entries, summed per label.
        """
        query = query.unsqueeze(0)  # [1, seq_len, feature_size]

        batch_size, sent_len, hidden_size = (
            query.shape[0],
            query.shape[1],
            query.shape[-1],
        )
        token_num = self.entitys.shape[0]

        # reshape (not view) so non-contiguous inputs are accepted as well.
        query = query.reshape(-1, hidden_size)  # [bsz*sent_len, feature_size]
        sim = torch.mm(query, self.entitys.t())  # [bsz*sent_len, token_num]
        norm_2 = (query**2).sum(dim=1, keepdim=True).sqrt()  # [bsz*sent_len, 1]
        scores = (sim / (self.norm_1 + 1e-10) / (norm_2 + 1e-10)).view(
            batch_size, sent_len, -1
        )  # [bsz, sent_len, token_num]
        knn_labels = self.entitys_label.view(1, 1, token_num).expand(
            batch_size, sent_len, token_num
        )  # [bsz, sent_len, token_num]

        if self.k != -1 and scores.shape[-1] > self.k:
            scores, topk_idxs = torch.topk(
                scores, dim=-1, k=self.k
            )  # [bsz, sent_len, topk]
            knn_labels = knn_labels.gather(
                dim=-1, index=topk_idxs
            )  # [bsz, sent_len, topk]

        sim_probs = torch.softmax(scores / self.t, dim=-1)  # [bsz, sent_len, topk]

        # Sum the neighbour probabilities per label.
        knn_probabilities = sim_probs.new_zeros(
            (batch_size, sent_len, len(self.index_to_label))
        ).scatter_add(
            dim=2, index=knn_labels, src=sim_probs
        )  # [bsz, sent_len, num_labels]

        return knn_probabilities


class NerDataLoader(Dataset):
    """Dataset over a BIO file, yielding ``((input_ids, pinyin_ids), labels, index)``."""

    def __init__(
        self,
        tokenizer,
        in_dir,
        vec_dir,
        out_dir=None,
        max_length=100,
        data_size=1,
        language="zh",
    ):
        """Store the configuration; call ``get_soups`` to load and tokenize.

        ``data_size`` is the probability with which each sentence of the file
        is kept by ``load``.
        """
        self.tokenizer = tokenizer
        self.in_dir = in_dir
        self.out_dir = out_dir
        self.max_length = max_length
        self.vec_dir = vec_dir
        self.data_size = data_size
        self.language = language
        self.soups = None  # raw tokens, never converted to ids
        self.soups_label = None  # raw labels, never converted to ids
        self.texts = None
        self.labels = None
        # Label vocabulary derived from the source file.
        self.label_to_index = None
        self.index_to_label = None

    def get_soups(self, label_to_index=None, index_to_label=None, dev=False):
        """Load and tokenize the file.

        With ``dev=True`` the label vocabulary passed in replaces the one
        derived from the file. Returns ``(soups, label_to_index,
        index_to_label)``.
        """
        self.load()
        if dev:
            self.label_to_index = label_to_index
            self.index_to_label = index_to_label
        self.tokenize(dev)
        return self.soups, self.label_to_index, self.index_to_label

    def load(self):
        """Read the BIO file; it is expected to end with exactly one newline.

        An empty line closes a sentence, which is kept with probability
        ``data_size``. Labels are numbered in order of first appearance, with
        ``"O"`` fixed at 0. Returns ``(label_to_index, index_to_label)``.
        """
        print("Loading WordBIO files...")
        with open(self.in_dir, "r", encoding="utf-8") as f:
            rows = f.read().split("\n")

        label_to_index = {"O": 0}
        sentences, tag_seqs = [], []
        tokens, tags = [], []
        for row in rows:
            if row != "":
                fields = row.split("\t")
                token, tag = fields[0], fields[1]
                tokens.append(token)
                tags.append(tag)
                label_to_index.setdefault(tag, len(label_to_index))
                continue
            # Sentence boundary: sub-sample the corpus here.
            if random.uniform(0, 1) <= self.data_size:
                sentences.append(tokens)
                tag_seqs.append(tags)
            tokens, tags = [], []

        self.texts = sentences
        self.labels = tag_seqs
        self.soups = copy.deepcopy(sentences)
        self.soups_label = copy.deepcopy(tag_seqs)
        self.label_to_index = label_to_index
        self.index_to_label = list(label_to_index)
        return self.label_to_index, self.index_to_label

    def tokenize(self, dev=False):
        """Replace ``texts`` / ``labels`` by id tensors (``"zh"`` only).

        In dev mode sentences are cut at 510 tokens and not padded; otherwise
        they are cut and padded to ``max_length``. ``[CLS]``/``[SEP]`` are
        added and receive label index 0, as do ``[PAD]`` positions. Other
        languages are left untouched.
        """
        print("Tokenize start...")
        if self.language != "zh":
            return

        limit = self.max_length
        for index in MyIter(range(len(self.texts))):
            if dev:
                # warning burn out the length of 510
                pieces = ["[CLS]"] + self.texts[index][:510] + ["[SEP]"]
            else:
                clipped = self.texts[index][:limit]
                padding = ["[PAD]"] * (limit - len(clipped))
                pieces = ["[CLS]"] + clipped + ["[SEP]"] + padding
            input_ids = self.tokenizer.convert_tokens_to_ids(pieces)
            pinyin_ids = self.tokenizer.convert_ids_to_pinyin_ids(input_ids)
            self.texts[index] = {
                "input_ids": torch.tensor(input_ids),
                "pinyin_ids": torch.tensor(pinyin_ids),
            }

        for index in MyIter(range(len(self.labels))):
            if dev:
                kept = self.labels[index][:510]
                ids = [0] + [self.label_to_index[i] for i in kept] + [0]
            else:
                # Labels must line up with the tokens: two extra positions
                # for the special tokens, then the padding.
                kept = self.labels[index][:limit]
                ids = (
                    [0]
                    + [self.label_to_index[i] for i in kept]
                    + [0]
                    + [0] * (limit - len(kept))
                )
            self.labels[index] = torch.tensor(ids)

    def __getitem__(self, index):
        item = self.texts[index]
        return (item["input_ids"], item["pinyin_ids"]), self.labels[index], index

    def __len__(self):
        return len(self.texts)


class ChineseNerModel(Module):
    """Token-classification wrapper around a pretrained encoder (bert, wobert)."""

    def __init__(
        self,
        vec_dir,
        cache_dir,
        hidden_size,
        label_size,
        num_layers=1,
        bi=False,
        hidden_dropout_prob=0.1,
        drop_out=0,
    ):
        """Load config and weights from ``vec_dir``.

        ``hidden_size``, ``label_size``, ``num_layers`` and ``bi`` are stored
        and used by ``save`` to build the file name; ``label_size`` is also
        the number of output labels. ``drop_out`` is accepted but unused.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.label_size = label_size
        self.num_layers = num_layers
        self.bi = bi
        self.bert_config = AutoConfig.from_pretrained(
            vec_dir,
            num_labels=self.label_size,
            hidden_dropout_prob=hidden_dropout_prob,
            output_hidden_states=True,
            cache_dir=cache_dir,
        )
        self.bert = AutoModelForTokenClassification.from_pretrained(
            vec_dir,
            trust_remote_code=True,
            config=self.bert_config,
            cache_dir=cache_dir,
        )

    def forward(self, text, label=None):
        """Run the encoder on ``text = (input_ids, pinyin_ids)``.

        Positions whose input id is 0 are masked out. ``label`` only selects
        the output: when it is given the classification logits are returned,
        otherwise the last hidden states.
        """
        input_ids, pinyin_ids = text[0], text[1]
        attention_mask = (input_ids != 0).long()
        outputs = self.bert(
            input_ids, attention_mask=attention_mask, pinyin_ids=pinyin_ids
        )
        # Identity check: comparing a tensor with ``!= None`` is not a
        # reliable presence test.
        if label is not None:
            return outputs.logits
        return outputs.hidden_states[-1]

    def save(self, model_name, model_dir, model):
        """Pickle ``model`` into ``model_dir``.

        The file is named
        ``{model_name}ner_{hidden_size}_{label_size}_{num_layers}_{bi}.pkl``.
        """
        model_name = "{}ner_{}_{}_{}_{}.pkl".format(
            model_name, self.hidden_size, self.label_size, self.num_layers, int(self.bi)
        )
        # Context manager so the handle is flushed and closed deterministically.
        with open(os.path.join(model_dir, model_name), "wb") as f:
            pickle.dump(model, f)


from metric import SpanEntityScore, print_table


class Evaluater:
    """Evaluate a pickled NER model, optionally interpolated with a kNN model."""

    def __init__(
        self, dev_dataloader, model_dir, out_dir=None, e=0.7, max_length=100, eps=0.5
    ):
        """Store the configuration.

        Args:
            dev_dataloader: evaluation dataset; its ``index_to_label`` and
                ``label_to_index`` are used to decode and score predictions.
            model_dir: path of the pickled model read by ``load_model``.
            out_dir: path for processed data; not needed for NER.
            e, max_length, eps: stored, not used by the methods of this class.
        """
        self.dev_dataloader = dev_dataloader
        self.model_dir = model_dir
        self.out_dir = out_dir
        self.model = None
        self.matrix = None
        self.e = e
        self.max_length = max_length
        self.eps = eps

    def load_model(self):
        """Unpickle the model from ``model_dir``, store and return it."""
        print("Loading evaluate model...")
        start = time.time()
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(self.model_dir, "rb") as f:
            self.model = pickle.load(f)
        end = time.time()
        print("Loading evaluate model time: {} s".format(end - start))
        return self.model

    def _decode(self, scores, label):
        """Return ``(predicted, gold)`` label names without special positions.

        The first and last gold positions are dropped; predictions are cut to
        the same window.
        """
        names = self.dev_dataloader.index_to_label
        predicted = torch.argmax(scores, -1).reshape(-1).cpu().numpy().tolist()
        gold = label.reshape(-1).cpu().numpy().tolist()
        gold = [names[i] for i in gold[1:-1]]
        predicted = [names[i] for i in predicted[1 : (len(gold) + 1)]]
        return predicted, gold

    def _report(self, real_labels, pre_labels):
        """Compute the span metric, print the tables and return the F1 score."""
        span = SpanEntityScore(
            id2label=self.dev_dataloader.label_to_index, markup="bio"
        )
        span.update(real_labels, pre_labels)
        score, label = span.result()
        print_table(label)
        f1, precision, recall = (
            score["f1"],
            score["precision"],
            score["recall"],
        )
        print(
            "f1_score:{:.4f} precision_score:{:.4f} recall_score:{:.4f}".format(
                f1, precision, recall
            )
        )
        return f1

    def evaluate_ner(self):
        """Score the plain model on the dev set and return its span F1.

        Evaluation stops early (and scores what was collected so far) if a
        prediction cannot be aligned with its gold labels.
        """
        dev_dataloader = DataLoader(self.dev_dataloader, shuffle=False)
        pre_labels = []
        real_labels = []

        self.model.eval()
        try:
            with torch.no_grad():
                for text, label, _ in MyIter(dev_dataloader):
                    text, label = (text[0].cuda(), text[1].cuda()), label.cuda()
                    with autocast():
                        scores = self.model(text, label)
                    predicted, gold = self._decode(scores, label)
                    if len(gold) != len(predicted):
                        print("label!=real_label！", gold, predicted)
                        break
                    pre_labels.append(predicted)
                    real_labels.append(gold)
        finally:
            # Always hand the model back in training mode.
            self.model.train()

        return self._report(real_labels, pre_labels)

    def evaluate_knnner(self, knn):
        """Score the model interpolated with ``knn`` and return its span F1.

        The final distribution is ``λ * softmax(logits) + (1 - λ) * kNN``,
        where the kNN part is queried with the first element of
        ``knn.model(text)``. Evaluation stops early (and scores what was
        collected so far) if a prediction cannot be aligned with its gold
        labels.
        """
        dev_dataloader = DataLoader(self.dev_dataloader, shuffle=False)
        pre_labels = []
        real_labels = []

        self.model.eval()
        try:
            with torch.no_grad():
                for text, label, line in MyIter(dev_dataloader):
                    text, label = (text[0].cuda(), text[1].cuda()), label.cuda()
                    with autocast():
                        logits = self.model(text, label)
                    # Fine-tuned encoder: token vectors of the only sentence.
                    querys = knn.model(text)[0]
                    sequence_mask = (
                        torch.argmax(logits, dim=-1).reshape(-1).cpu().numpy().tolist()
                    )
                    mixed = knn.λ * F.softmax(logits, dim=-1) + (
                        1 - knn.λ
                    ) * knn.caculate_knn_distribution(
                        querys, int(line), sequence_mask=sequence_mask
                    )

                    predicted, gold = self._decode(mixed, label)
                    if len(gold) != len(predicted):
                        print("label!=real_label！", gold, predicted)
                        break
                    pre_labels.append(predicted)
                    real_labels.append(gold)
        finally:
            # Always hand the model back in training mode.
            self.model.train()

        return self._report(real_labels, pre_labels)


if __name__ == "__main__":
    # Corpora to convert, grouped by language; one segmenter is loaded per
    # group and reused for all of its files.
    jobs = (
        (
            "en",
            (
                (
                    "./data/ner_data/en_conll03/test.ner.bio",
                    "./data/ner_data/en_conll03/testword.ner.bio",
                ),
                (
                    "./data/ner_data/ontonote5/test.ner.bio",
                    "./data/ner_data/ontonote5/testword.ner.bio",
                ),
            ),
        ),
        (
            "zh",
            (
                (
                    "./data/ner_data/ontonote4/test.ner.bio",
                    "./data/ner_data/ontonote4/testword.ner.bio",
                ),
                (
                    "./data/ner_data/msra/test.ner.bio",
                    "./data/ner_data/msra/testword.ner.bio",
                ),
                (
                    "./data/ner_data/weibo/test.ner.bio",
                    "./data/ner_data/weibo/testword.ner.bio",
                ),
            ),
        ),
    )
    # ``language`` is deliberately a module-level name here, as in the
    # original script.
    for language, corpora in jobs:
        util = Utils(language=language)
        for bio_path, word_path in corpora:
            util.bio2Word(bio_path, word_path)
