import time
import sys
sys.path.append("..")
from MassUtils.MassUtil import *
from MassUtils.Productor import *
import numpy as np
import torch
# from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from collections import Counter
from sklearn.cluster import KMeans, DBSCAN, MiniBatchKMeans
from sklearn.mixture import GaussianMixture as GMM
from MassUtils.spc import SupervisedContrastiveLoss
from MassUtils.supcontrast import SupConLoss

def product_trainer(hp):
    """Build the trainer object selected by ``hp.trainer_name``.

    The name is resolved with ``eval`` and the resulting callable is invoked
    with ``hp`` as its single argument; whatever it returns is handed back.
    """
    # NOTE(review): eval() executes arbitrary expressions, so trainer_name
    # must only ever come from trusted configuration.
    trainer_factory = eval(hp.trainer_name)
    trainer = trainer_factory(hp)
    return trainer

class Similarity_Evaluator(object):
    def __init__(self, hp):
        """Prepare logging, the Siamese model, its optimizer and the data loaders.

        Values read from ``hp``: ``batch_size``, ``LEN_ALL_TAGS``,
        ``HIDDEN_DIM``, ``optim_name``, ``criterion_name``, ``lr``,
        ``trainset`` and ``tgt_lang``.
        """
        # ---- experiment bookkeeping ----
        self.identify_name = get_identify_name(hp)
        self.record_dirs = dict(
            plot_dir='../MassPlotPng/',
            tboard_dir='../MassLogBoard/',
            result_dir='../MassPredictResult/',
            pt_dir='../MassPT/',
            logger_dir='../MassLogInfo/',
            chk_dir='../checkpoints/',
        )
        self.record_dir_dict = create_all_record_dir(self.record_dirs, hp)

        # Per-epoch histories first, then the running "best so far" scalars.
        self.record_result_dict = {
            'train_loss_list': [],
            'dev_F1_list': [],
            'test_F1_list': [],
        }
        self.record_result_dict.update(dict.fromkeys(
            ('VALID_MAX_PREC', 'VALID_MAX_REC', 'VALID_MAX_F1', 'VALID_MAX_EPOCH',
             'TEST_MAX_PREC', 'TEST_MAX_REC', 'TEST_MAX_F1', 'TEST_MAX_EPOCH'),
            0,
        ))

        log_path = self.record_dir_dict['logger_dir'] + 'log'
        self.logger = get_logger('info', log_path)
        # self.tboard_writer = SummaryWriter(self.record_dir_dict['tboard_dir'])
        self.batch_size = hp.batch_size
        self.feature_dim = 768

        # ---- model / optimizer / loss ----
        self.model_name = 'MyBertSiameseCLNER'
        model_cls = eval(self.model_name)
        siamese_net = model_cls(hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM, self.feature_dim)
        self.model = cpu_2_gpu(siamese_net)
        self.optimizer = product_optimizer(hp.optim_name, self.model.parameters(), hp.lr)
        self.criterion = product_criterion(hp.criterion_name)

        # ---- data ----
        self.dataloader_name = "SiameseLoader"
        self.trainset_params = hp.trainset
        self.tgt_lang = hp.tgt_lang

        # es / nl / de live under CoNLL2002; every other language under WikiAnn.
        data_dir = 'CoNLL2002' if self.tgt_lang in ['es', 'nl', 'de'] else 'WikiAnn'

        # A fresh parameter dict is parsed for every loader, as in the
        # original, so the loaders never share one dict object.
        self.trainloader = self.init_dataloader(
            self.dataloader_name,
            hp.trainset,
            self.product_param_dict(self.trainset_params),
            self.logger,
        )
        lang_root = '../data/' + data_dir + '/' + self.tgt_lang
        self.validloader = self.init_dataloader(
            self.dataloader_name,
            lang_root + '/valid.txt',
            self.product_param_dict(self.trainset_params),
            self.logger,
        )
        self.testloader = self.init_dataloader(
            self.dataloader_name,
            lang_root + '/test.txt',
            self.product_param_dict(self.trainset_params),
            self.logger,
        )

    def product_model(self, model_name, batch_size, tag_nums, HIDDEN_DIM):
        """Resolve ``model_name`` with ``eval`` and call it with the three size arguments."""
        # NOTE(review): eval() on a string; model_name must be trusted input.
        model_factory = eval(model_name)
        return model_factory(batch_size, tag_nums, HIDDEN_DIM)

    def product_param_dict(self, param_name):
        """Parse an underscore-separated spec such as ``"100_T_F_T"``.

        Field 0 is converted to ``int`` (``maxSamples``); fields 1-3 become
        booleans that are ``True`` only when the field is exactly ``'T'``
        (``isequal``, ``shuffle``, ``istest``).
        """
        fields = param_name.split('_')
        return {
            'maxSamples': int(fields[0]),
            'isequal': fields[1] == 'T',
            'shuffle': fields[2] == 'T',
            'istest': fields[3] == 'T',
        }

    def init_dataloader(self, dataloader_name, fpath, params, logger):
        """Create the dataset class named ``dataloader_name`` and wrap it in a DataLoader.

        The dataset is built from ``(fpath, params, logger)``; batching uses
        ``self.batch_size``, ``params['shuffle']`` and the dataset's own
        ``pad`` method as collate function.
        """
        dataset_cls = eval(dataloader_name)
        dataset = dataset_cls(fpath, params, logger)
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=params['shuffle'],
            num_workers=4,
            collate_fn=dataset.pad,
        )

    def train(self, model, trainloader, optimizer, criterion):
        """Run one optimisation pass over ``trainloader``.

        Returns the mean of the per-batch loss values. Progress is logged
        every 1000 steps (including step 0).
        """
        model.train()
        batch_losses = []
        started = time.time()

        for step, batch in enumerate(trainloader):
            # Batch fields, in order: words, wordpiece idx, tags, x, y,
            # attention mask, sequence length, word idx. Only four are used.
            _, wordpiece_idx_1_2, tag_1_2, x_1_2, y, _, _, _ = batch

            # Move both inputs, the pair label and both tag tensors in one call.
            x_1, x_2, y, tag_1, tag_2 = cpu_2_gpu(x_1_2 + [y] + tag_1_2)

            optimizer.zero_grad()
            outputs = model([x_1, x_2], wordpiece_idx_1_2)
            # The model returns a 7-tuple; only the similarity score is trained on.
            sim, _, _, _, _, _, _ = outputs
            loss = criterion(sim, y.float())
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

            if step % 1000 != 0:
                continue
            # monitoring
            running_mean = round(np.mean(batch_losses), 11)
            elapsed_min = (time.time() - started) / 60
            self.logger.info(
                f"TRAIN STEP: {step}\tLOSS_MAIN={running_mean}\t\tTIME: {elapsed_min}")

        return np.mean(batch_losses)

    def evaluation(self, model, validloader, epoch_i, is_test=False):
        """Score the similarity model on ``validloader``.

        Predictions are thresholded at 0.5 and compared with the gold pair
        labels via ``eval_F1(..., 'sklearn_f1')``.

        Args:
            model: network called as ``model([x_1, x_2], wordpiece_idx_1_2)``;
                the first element of its 7-tuple output is the prediction.
            validloader: iterable of 8-field batches.
            epoch_i: unused; kept for interface compatibility.
            is_test: unused; kept for interface compatibility.

        Returns:
            ``(precision, recall, f1)`` as returned by ``eval_F1``.
        """
        model.eval()
        start_t = time.time()
        # Only gold labels and predictions feed the metric. The original also
        # accumulated words, masks, tags (as device tensors) and embeddings
        # for the whole set without ever reading them; that is dropped here.
        Y, Y_hat = [], []
        with torch.no_grad():
            for idx, batch in enumerate(validloader):
                words_1_2, wordpiece_idx_1_2, tag_1_2, x_1_2, y, att_mask_1_2, seqlen_1_2, word_idx_1_2 = batch

                x_1, x_2, y, tag_1, tag_2 = cpu_2_gpu(x_1_2 + [y] + tag_1_2)

                prediction, _, _, _, _, _, _ = model([x_1, x_2], wordpiece_idx_1_2)

                Y.extend(y.cpu().numpy().tolist())
                Y_hat.extend(prediction.cpu().numpy().tolist())

                if idx % 1000 == 0:  # monitoring
                    self.logger.info(f"Siamese STEP: {idx}\t\ttime: {(time.time() - start_t) / 60}")

        # Binarise the similarity scores.
        Y_hat = [1 if y_pred >= 0.5 else 0 for y_pred in Y_hat]
        precision, recall, f1 = eval_F1(Y, Y_hat, 'sklearn_f1')

        return precision, recall, f1

    def train_epoch(self, i):
        """Train for one epoch, evaluate on the test loader and track the best F1.

        The epoch loss and test F1 are appended to ``record_result_dict``.
        When the F1 beats ``TEST_MAX_F1`` the model weights are saved as
        ``best_sim_fea.pt`` in the checkpoint directory.
        """
        history = self.record_result_dict

        # Training pass.
        self.logger.info(f"=========================TRAIN AT EPOCH={i}=========================")
        epoch_loss = self.train(self.model, self.trainloader, self.optimizer, self.criterion)
        history['train_loss_list'].append(epoch_loss)

        # Dev-set evaluation is disabled here, so 'dev_F1_list' stays untouched.

        # Test pass.
        self.logger.info(f"=========================TEST AT EPOCH={i}==========================")
        _, _, f1 = self.evaluation(self.model, self.testloader, i, True)
        test_scores = history['test_F1_list']
        test_scores.append(f1)

        if f1 > history['TEST_MAX_F1']:
            history['TEST_MAX_F1'] = f1
            history['TEST_MAX_EPOCH'] = i
            torch.save(self.model.state_dict(), self.record_dir_dict['chk_dir'] + 'best_sim_fea.pt')

        # The reported epoch is the 1-based position in the F1 history.
        best_epoch = np.argmax(test_scores) + 1
        best_f1 = np.max(test_scores)
        self.logger.info(f"Best F1: EPOCH_NUM={best_epoch}\tF1_MAX={best_f1}")
        self.logger.info(f"Best Model Saved in {self.record_dir_dict['chk_dir']}best_sim_fea.pt")

    # end of epoch
    def record_result(self):
        """Export the collected loss / F1 histories to TensorBoard, JSON and PNG.

        TensorBoard export only happens when ``self.tboard_writer`` exists.
        The constructor currently leaves its creation commented out, and the
        unconditional access used to raise ``AttributeError``.
        """
        name_list = ['train_loss', 'dev_F1', 'test_F1']
        data_list = [self.record_result_dict['train_loss_list'],
                     self.record_result_dict['dev_F1_list'],
                     self.record_result_dict['test_F1_list']]

        # write board (skipped when no writer has been configured)
        tboard_writer = getattr(self, 'tboard_writer', None)
        if tboard_writer is not None:
            write_board(tboard_writer, self.record_result_dict['train_loss_list'], '/loss/train')
            write_board(tboard_writer, self.record_result_dict['dev_F1_list'], '/F1/dev')
            write_board(tboard_writer, self.record_result_dict['test_F1_list'], '/F1/test')
        else:
            self.logger.info("TensorBoard writer not configured; skipping board export")

        # save data to file
        write_plot(name_list, data_list, self.record_dir_dict['plot_dir'] + 'result.json', self.logger)

        # plot png
        # NOTE(review): dev and test F1 share the '_f1' target; whether
        # plot_line overlays or overwrites is not visible here. Confirm.
        file_path = self.record_dir_dict['plot_dir']
        plot_line(data_list[0], name_list[0], file_path + '_loss', title='loss', legend_loc=1)
        plot_line(data_list[1], name_list[1], file_path + '_f1', title='f1', legend_loc=2)
        plot_line(data_list[2], name_list[2], file_path + '_f1', title='f1', legend_loc=2)

class Recognizer(object):
    def __init__(self, hp):
        """Prepare logging, the NER tagger, its optimizer/losses and the data loaders.

        Values read from ``hp``: ``batch_size``, ``top_crf``, ``left_isner``,
        ``loss_alpha``, ``add_sim``, ``add_spc`` (plus ``spc_tem`` when
        ``add_spc`` is set), ``LEN_ALL_TAGS``, ``HIDDEN_DIM``, ``lr``,
        ``trainset`` and ``tgt_lang``.
        """
        # ---- experiment bookkeeping ----
        self.identify_name = get_identify_name(hp)
        self.record_dirs = dict(
            plot_dir='../MassPlotPng/',
            tboard_dir='../MassLogBoard/',
            result_dir='../MassPredictResult/',
            pt_dir='../MassPT/',
            logger_dir='../MassLogInfo/',
            chk_dir='../checkpoints/',
        )
        self.record_dir_dict = create_all_record_dir(self.record_dirs, hp)

        # Per-epoch histories first, then the running "best so far" scalars.
        self.record_result_dict = {
            'train_loss_list': [],
            'dev_F1_list': [],
            'test_F1_list': [],
        }
        self.record_result_dict.update(dict.fromkeys(
            ('VALID_MAX_PREC', 'VALID_MAX_REC', 'VALID_MAX_F1', 'VALID_MAX_EPOCH',
             'TEST_MAX_PREC', 'TEST_MAX_REC', 'TEST_MAX_F1', 'TEST_MAX_EPOCH'),
            0,
        ))

        log_path = self.record_dir_dict['logger_dir'] + self.identify_name
        self.logger = get_logger('info', log_path)
        # self.tboard_writer = SummaryWriter(self.record_dir_dict['tboard_dir'])

        # ---- hyper-parameters copied from hp ----
        self.batch_size = hp.batch_size
        self.top_crf = hp.top_crf
        self.left_isner = hp.left_isner
        self.loss_alpha = hp.loss_alpha
        self.ix_to_tag = ix_to_tag
        self.add_sim = hp.add_sim
        self.add_spc = hp.add_spc

        # ---- model ----
        self.model_name = 'MyBertRnnCrfIsner'
        self.optim_name = "Adam"
        self.criterion_name = "CrossEntropyLoss"
        tagger = product_model(self.model_name, hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM)
        self.model = cpu_2_gpu(tagger)

        # Disabled experiments kept for reference:
        # clip = 1
        # nn.utils.clip_grad_norm_(self.model.parameters(), clip)
        # Freeze part of the network layers
        # freeze_layers = ['bert.encoder.layer.0.', 'bert.encoder.layer.1.', 'bert.encoder.layer.2.',
        #                  'bert.embeddings.']  # 'bert.'
        # for name, param in self.model.named_parameters():
        #     param.requires_grad = True
        #     for ele in freeze_layers:
        #         if ele in name:
        #             param.requires_grad = False
        #             break
        #
        # for name, param in self.model.named_parameters():
        #     print(name, param.requires_grad)

        # ---- optimizer and losses ----
        self.optimizer = product_optimizer(self.optim_name, self.model.parameters(), hp.lr)
        self.criterion = product_criterion(self.criterion_name)
        if self.add_spc:
            # The supervised-contrastive loss only exists when requested.
            self.spc_tem = hp.spc_tem
            self.spc_criterion = SupervisedContrastiveLoss(temperature=self.spc_tem)

        # ---- data ----
        self.dataloader_name = "ConllIsentityLoader"
        self.tgt_lang = hp.tgt_lang

        # es / nl / de live under CoNLL2002; every other language under WikiAnn.
        self.data_dir = '../data/CoNLL2002/' if self.tgt_lang in ['es', 'nl', 'de'] else '../data/WikiAnn/'

        self.trainloader = self.init_dataloader(self.dataloader_name, hp.trainset, {'shuffle': True})
        valid_path = self.data_dir + self.tgt_lang + '/valid.txt'
        self.validloader = self.init_dataloader(self.dataloader_name, valid_path, {'shuffle': False})
        test_path = self.data_dir + self.tgt_lang + '/test.txt'
        self.testloader = self.init_dataloader(self.dataloader_name, test_path, {'shuffle': False})

        # Not created here (train() refers to both when add_sim is set):
        # self.SeqSimilarity = SeqSimilarity()
        # self.contrastive_loss = ContrastiveLoss()

    def init_dataloader(self, dataloader_name, fpath, params):
        """Create the dataset class named ``dataloader_name`` from ``fpath`` and wrap it in a DataLoader.

        Batching uses ``self.batch_size``, ``params['shuffle']`` and the
        dataset's own ``pad`` method as collate function.
        """
        dataset_cls = eval(dataloader_name)
        dataset = dataset_cls(fpath)
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=params['shuffle'],
            num_workers=4,
            collate_fn=dataset.pad,
        )

    def train(self, model, trainloader, optimizer, criterion):
        """Run one optimisation pass over ``trainloader`` and return the mean loss.

        Loss composition per batch:

        * main NER loss: negative CRF log-likelihood when ``self.top_crf``,
          otherwise ``criterion`` on the positions where ``is_heads == 1``;
        * ``self.left_isner``: blended with an is-entity loss using
          ``self.loss_alpha``;
        * ``self.add_sim`` (non-CRF path only): a contrastive similarity
          loss is added; requires ``self.SeqSimilarity`` and
          ``self.contrastive_loss`` to have been set on the instance;
        * ``self.add_spc`` (non-CRF path only):
          ``loss_alpha * loss + (1 - loss_alpha) * supconloss``.

        Fixes relative to the previous version: the combined ``loss`` is what
        gets back-propagated (it used to be the plain cross-entropy, so the
        contrastive term was logged but never optimised); every flag
        combination now yields a defined loss instead of a ``NameError``;
        head positions are selected with a boolean mask instead of a Python
        loop.

        Args:
            model: network called as ``model(x, is_heads, seqlens)`` and
                returning a 3-tuple; must expose ``crf`` when ``top_crf``.
            trainloader: iterable of 7-field batches.
            optimizer: optimizer stepping ``model``'s parameters.
            criterion: classification loss for the non-CRF path.

        Returns:
            Mean of the per-batch combined loss values.
        """
        model.train()
        loss_train, celoss_train, supconloss_train = [], [], []
        start_t = time.time()

        for idx, batch in enumerate(trainloader):
            words, x, is_heads, tags, y, seqlens, is_entity = batch
            x, y, is_heads = cpu_2_gpu([x, y, is_heads])

            optimizer.zero_grad()
            logits, prediction, embeds = model(x, is_heads, seqlens)  # logits: (B, S, Tag), y: (N, T)

            supconloss = None
            loss_sim = None

            if self.top_crf:
                # The CRF layer returns a log-likelihood; minimise its negation.
                loss_ner = -1 * model.crf(logits, y, reduction='mean')
            else:
                # Keep only the rows whose is_heads flag is 1.
                head_mask = is_heads.view(-1) == 1
                logits_cons = logits.view(-1, logits.shape[-1])[head_mask]
                _y_cons = y.view(-1)[head_mask]
                embeds = embeds.view(-1, embeds.shape[-1])[head_mask]

                loss_ner = criterion(logits_cons, _y_cons)  # crossentropy

                if self.add_spc:
                    supconloss = self.spc_criterion(F.normalize(embeds), _y_cons)

                if self.add_sim:
                    embeds_1, embeds_2, y_sim = self.SeqSimilarity.construct(embeds, y, is_heads.view(-1))
                    loss_sim = self.contrastive_loss(embeds_1, embeds_2, y_sim)

            loss = loss_ner

            if self.left_isner:
                # NOTE(review): as in the original, the third output of a
                # second forward pass is treated as is-entity logits. Confirm
                # the model really returns those in this mode.
                _, _, isentity_logits = model(x, is_heads, seqlens)
                isentity_logits = isentity_logits.view(-1, isentity_logits.shape[-1])
                loss_isentity = criterion(isentity_logits, is_entity.view(-1))
                loss = self.loss_alpha * loss + (1 - self.loss_alpha) * loss_isentity

            if loss_sim is not None:
                loss = loss + loss_sim

            if supconloss is not None:
                loss = self.loss_alpha * loss + (1 - self.loss_alpha) * supconloss
                supconloss_train.append(supconloss.item())

            # Back-propagate the combined objective, not just the CE term.
            loss.backward()
            optimizer.step()

            loss_train.append(loss.item())
            celoss_train.append(loss_ner.item())

            if idx % 100 == 0:  # monitoring
                # Avoid numpy's empty-mean warning when SupCon is disabled;
                # the logged value is still "nan" in that case.
                mean_supcon = np.mean(supconloss_train) if supconloss_train else float('nan')
                self.logger.info(
                    f"STEP: {idx}\tLOSS={round(np.mean(loss_train), 11)}\tCELOSS={round(np.mean(celoss_train), 11)}\tSupConLOSS={round(mean_supcon, 11)}\t\ttime: {(time.time() - start_t) / 60}")

        return np.mean(loss_train)


    def evaluation(self, model, validloader, i, is_test=False):
        """Tag every batch of ``validloader`` and score the result with conlleval.

        Gold and predicted label ids are collected only at positions where
        ``is_heads == 1``, dropping the first and last such position of each
        sentence. Every 10th value of ``i`` (including 0) the per-token
        predictions are written to the result directory and embedding plots
        are produced.

        Args:
            model: network called as ``model(x, is_heads, seqlens)``,
                returning ``(logits, prediction, embeds)``.
            validloader: iterable of 7-field batches.
            i: epoch index; controls the periodic dump and names its files.
            is_test: selects the ``"test_"`` vs ``"valid_"`` file prefix.

        Returns:
            ``(precision, recall, f1)`` from ``eval_F1(..., 'conlleval')``.
        """
        model.eval()
        Words, Is_heads, Tags, Y, Y_hat, Y_hat_List = [], [], [], [], [], []
        # Token-level parallel lists for plotting (plot_tag_list is never filled).
        plot_embed_list, plot_tag_list, plot_words_list, plot_sents_list = [], [], [], []

        with torch.no_grad():
            for idx, batch in enumerate(validloader):
                words, x, is_heads, tags, y, seqlens, is_entity = batch
                x, y, is_heads = cpu_2_gpu([x, y, is_heads])
                _y = y
                # is_heads = is_heads.to(DEVICE)
                logits, prediction, embeds = model(x, is_heads, seqlens)
                # Sentence-level records, used later by write_ner_result.
                Words.extend(words)
                Is_heads.extend(is_heads.cpu().numpy().tolist())
                Tags.extend(tags)
                Y_hat_List.extend(prediction.cpu().numpy().tolist())

                # Per sentence: keep head positions only, then strip the first and last
                # kept entry (presumably the sentence start/end markers; confirm).
                for t_i, pred, heads, embed, ws in zip(_y.cpu().numpy().tolist(), prediction.cpu().numpy().tolist(), is_heads.cpu().numpy().tolist(), embeds.cpu().numpy().tolist(), words):
                    Y.extend([t for head, t in zip(heads, t_i) if head == 1][1:-1])
                    # Predicted id 9 is remapped to 0 (presumably a special tag id; TODO confirm).
                    pred = [0 if p == 9 else p for p in pred]
                    Y_hat.extend([p for head, p in zip(heads, pred) if head == 1][1:-1])
                    plot_embed_list.extend([e for head, e in zip(heads, embed) if head == 1][1:-1])
                    ws_list = ws.split()[1:-1]
                    # One copy of the sentence per token, so the list stays token-aligned.
                    plot_sents_list.extend([ws_list]*len(ws_list))
                    plot_words_list.extend(ws_list)
                    # print([ix_to_tag[t] for head, t in zip(heads, t_i) if head == 1][1:-1])
                    # print([ix_to_tag[pred] for head, pred in zip(heads, pred) if head == 1][1:-1])

        # self.logger.info(f"============seq classification Eval by sklearn_f1:============")
        # precision, recall, f1 = eval_F1([ix_to_tag[t_ix] for t_ix in Y], [ix_to_tag[t_ix] for t_ix in Y_hat],
        #                                 'sklearn_f1')
        # self.logger.info("PRE=%.5f\t\tREC=%.5f\t\tF1=%.5f" % (precision, recall, f1))

        self.logger.info(f"============seq classification Eval by conlleval.py:============")
        precision, recall, f1 = eval_F1([ix_to_tag[t_ix] for t_ix in Y], [ix_to_tag[t_ix] for t_ix in Y_hat], 'conlleval')
        self.logger.info("PRE=%.5f\t\tREC=%.5f\t\tF1=%.5f" % (precision, recall, f1))


        # Periodic dump: predictions file plus embedding plots.
        if(i%10==0):
            suffix = "test_" if is_test else "valid_"
            self.write_ner_result(self.record_dir_dict['result_dir'] + suffix + str(i),
                             data=(Words, Tags, Y_hat_List, Is_heads),
                             ix_to_tag=ix_to_tag)
            # Plot at most plot_num tokens per gold tag.
            plot_num = 200
            tag_count = {tag: plot_num for tag in ALL_TAGS_}

            plot_embed_list_use, Y_plot, Y_hat_plot, plot_words, plot_sents = [], [], [], [], []

            # Take tokens in order until each tag's remaining quota reaches zero.
            for y_temp, pred_temp, embed_temp, words_temp, sents_temp in zip(Y, Y_hat, plot_embed_list, plot_words_list, plot_sents_list):
                if(tag_count[ix_to_tag[y_temp]]>0):
                    tag_count[ix_to_tag[y_temp]] = tag_count[ix_to_tag[y_temp]]-1
                    plot_embed_list_use.append(embed_temp)
                    Y_plot.append(y_temp)
                    Y_hat_plot.append(pred_temp)
                    plot_words.append(words_temp)
                    plot_sents.append(sents_temp)
                else:
                    continue

            # Same embeddings, labelled once by gold tag and once by predicted tag.
            plot_dim_embeds(plot_embed_list_use, Y_plot,
                            self.record_dir_dict['plot_dir'] + 'real_tag_' + str(i),
                            "real tag " + str(i))
            plot_dim_embeds(plot_embed_list_use, Y_hat_plot,
                            self.record_dir_dict['plot_dir'] + 'predict_tag_' + str(i),
                            "predict tag " + str(i))

            plot_dim_sentidx(plot_embed_list_use, plot_words, plot_sents, Y_plot, self.record_dir_dict['plot_dir'] + 'real_sent_' + str(i),
                            self.logger, "real sent " + str(i))

        return precision, recall, f1

    def train_epoch(self, i):
        """Train for one epoch, evaluate on the test loader and track the best F1.

        The epoch loss and test F1 are appended to ``record_result_dict``; a
        placeholder 0 is appended to ``dev_F1_list`` because dev evaluation
        is disabled. When the F1 beats ``TEST_MAX_F1`` the model weights are
        saved as ``best_test.pt`` in the checkpoint directory.
        """
        history = self.record_result_dict

        # Training pass.
        self.logger.info(f"=========================TRAIN AT EPOCH={i}=========================")
        epoch_loss = self.train(self.model, self.trainloader, self.optimizer, self.criterion)
        history['train_loss_list'].append(epoch_loss)

        # Dev evaluation is switched off; record 0 so the lists stay aligned.
        history['dev_F1_list'].append(0)

        # Test pass.
        self.logger.info(f"=========================TEST AT EPOCH={i}==========================")
        _, _, f1 = self.evaluation(self.model, self.testloader, i, True)
        test_scores = history['test_F1_list']
        test_scores.append(f1)

        if f1 > history['TEST_MAX_F1']:
            history['TEST_MAX_F1'] = f1
            history['TEST_MAX_EPOCH'] = i
            torch.save(self.model.state_dict(), self.record_dir_dict['chk_dir'] + 'best_test.pt')

        # The reported epoch is the 1-based position in the F1 history.
        best_epoch = np.argmax(test_scores) + 1
        best_f1 = np.max(test_scores)
        self.logger.info(f"Best Test F1: EPOCH_NUM={best_epoch}\tF1_MAX={best_f1}")
        self.logger.info(f"Best MODEL SAVED in {self.record_dir_dict['chk_dir']}best_test.pt")

    # end of epoch
    def record_result(self):
        """Export the collected loss / F1 histories to TensorBoard, JSON and PNG.

        TensorBoard export only happens when ``self.tboard_writer`` exists.
        The constructor currently leaves its creation commented out, and the
        unconditional access used to raise ``AttributeError``.
        """
        name_list = ['train_loss', 'dev_F1', 'test_F1']
        data_list = [self.record_result_dict['train_loss_list'],
                     self.record_result_dict['dev_F1_list'],
                     self.record_result_dict['test_F1_list']]

        # write board (skipped when no writer has been configured)
        tboard_writer = getattr(self, 'tboard_writer', None)
        if tboard_writer is not None:
            write_board(tboard_writer, self.record_result_dict['train_loss_list'], '/loss/train')
            write_board(tboard_writer, self.record_result_dict['dev_F1_list'], '/F1/dev')
            write_board(tboard_writer, self.record_result_dict['test_F1_list'], '/F1/test')
        else:
            self.logger.info("TensorBoard writer not configured; skipping board export")

        # save data to file
        write_plot(['train_loss', 'dev_F1', 'test_F1'], data_list, self.record_dir_dict['plot_dir'] + 'result.json', self.logger)

        # plot png
        # NOTE(review): dev and test F1 share the '_f1' target; whether
        # plot_line overlays or overwrites is not visible here. Confirm.
        file_path = self.record_dir_dict['plot_dir'] + self.identify_name
        plot_line(data_list[0], name_list[0], file_path + '_loss', title='loss', legend_loc=1)
        plot_line(data_list[1], name_list[1], file_path + '_f1', title='f1', legend_loc=2)
        plot_line(data_list[2], name_list[2], file_path + '_f1', title='f1', legend_loc=2)
        self.logger.info(f'Plot Png Saved in {file_path}.png')

    def write_ner_result(self, resultfile, data, ix_to_tag):
        """Write ``word gold_tag predicted_tag`` lines, one sentence per paragraph.

        Args:
            resultfile: path of the output file (overwritten; written as UTF-8).
            data: ``(Words, Tags, Y_hat, Is_heads)``. ``Words`` and ``Tags``
                hold whitespace-separated strings per sentence; ``Y_hat`` and
                ``Is_heads`` hold per-position label ids and 0/1 flags.
            ix_to_tag: mapping from label id to tag string.

        Raises:
            AssertionError: when the number of head predictions, words and
                tags of a sentence disagree.
        """
        # ep: word real_tag pred_tag
        Words, Tags, Y_hat, Is_heads = data
        # Explicit encoding: the corpora are multilingual and the platform
        # default may not be able to encode them.
        with open(resultfile, 'w', encoding='utf-8') as fout:
            for words, tags, y_hat, is_heads in zip(Words, Tags, Y_hat, Is_heads):
                word_list = words.split()
                tag_list = tags.split()
                # Keep predictions only where the is_heads flag is 1.
                preds = [ix_to_tag[hat] for head, hat in zip(is_heads, y_hat) if head == 1]
                if not (len(preds) == len(word_list) == len(tag_list)):
                    # Raised explicitly so the check survives ``python -O`` and
                    # the message carries the data instead of ``None``.
                    raise AssertionError(
                        f"length mismatch: preds={preds} words={word_list} tags={tag_list}")
                # The first and last entries of each sentence are not written.
                for w, t, p in zip(word_list[1:-1], tag_list[1:-1], preds[1:-1]):
                    fout.write(f"{w} {t} {p}\n")
                fout.write("\n")

class TSL(object):
    def __init__(self, hp):
        """Load the three pre-trained taggers, create the student and the data loaders.

        Values read from ``hp``: ``batch_size``, ``low_resource``,
        ``n_clusters``, ``only_siamese``, ``self_learn``, ``add_cluster``,
        ``add_visual``, ``tgt_lang``, ``LEN_ALL_TAGS``, ``HIDDEN_DIM``,
        ``lr`` and ``valideset``.

        Changes relative to the previous version: the first
        ``tea_model_dict`` literal (duplicate ``'es'`` key, immediately
        overwritten) is removed, and checkpoints are loaded with
        ``map_location='cpu'`` so that they can be read on a machine without
        the device they were saved from. ``load_state_dict`` then copies the
        weights into the models that ``cpu_2_gpu`` has already placed.

        Raises:
            KeyError: if ``hp.tgt_lang`` is not one of ``'es'``, ``'nl'``,
                ``'de'`` (the only keys in the path tables below).
        """
        # initialize the experiment
        self.identify_name = get_identify_name(hp)
        self.record_dirs = {
            'plot_dir': '../MassPlotPng/',
            'tboard_dir': '../MassLogBoard/',
            'result_dir': '../MassPredictResult/',
            'pt_dir': '../MassPT/',
            'logger_dir': '../MassLogInfo/',
            'chk_dir': '../checkpoints/'
        }
        self.record_dir_dict = create_all_record_dir(self.record_dirs, hp)
        self.record_result_dict = {
            'train_loss_list': [],
            'test_F1_list': [],
            'SRC_F1': 0,
            'TRANS_F1': 0,
            'TEA_F1': 0,
            'TEST_MAX_PREC': 0,
            'TEST_MAX_REC': 0,
            'TEST_MAX_F1': 0,
            'TEST_MAX_EPOCH': 0
        }
        self.logger = get_logger('info', self.record_dir_dict['logger_dir'] + 'log')
        # self.tboard_writer = SummaryWriter(self.record_dir_dict['tboard_dir'])
        self.batch_size = hp.batch_size
        self.top_N = 400
        self.low_resource = hp.low_resource
        self.n_clusters = hp.n_clusters
        self.feature_dim = 768
        self.only_siamese = hp.only_siamese
        self.self_learn = hp.self_learn
        self.add_cluster = hp.add_cluster
        self.add_visual = hp.add_visual

        # Checkpoint tables, keyed by target language.
        self.src_model_dict = {
            'es': '../checkpoints/2021-08-26 17_02_59/best_test.pt',
            'nl': '../checkpoints/2021-08-26 21_41_38/best_test.pt',
            'de': '../checkpoints/2021-08-27 17_10_14/best_test.pt',
        }

        self.trans_model_dict = {
            'es': '../checkpoints/2021-08-31 20_07_10/best_test.pt',
            'nl': '../checkpoints/2021-08-31 15_22_09/best_test.pt',
            'de': '../checkpoints/2021-08-31 10_51_57/best_test.pt',
        }

        # Earlier teacher checkpoints that were listed but never in effect:
        #   es: 2021-08-26 17_02_59, 2021-09-13 21_04_04; nl: 2021-09-01 15_52_42
        self.tea_model_dict = {
            'es': '../checkpoints/2021-09-13 23_33_02/best_test.pt',
            'nl': '../checkpoints/2021-09-02 16_14_40/best_test.pt',
            'de': '../checkpoints/2021-09-13 23_38_55/best_test.pt',
        }

        self.test_set_dict = {
            'es': '../data/CoNLL2002/es/test.txt',
            'nl': '../data/CoNLL2002/nl/test.txt',
            'de': '../data/CoNLL2002/de/IOB1/test.txt',
        }
        self.testset = self.test_set_dict[hp.tgt_lang]

        self.train_set_dict = {
            'es': '../data/CoNLL2002/es/train.txt',
            'nl': '../data/CoNLL2002/nl/train.txt',
            'de': '../data/CoNLL2002/de/IOB1/train.txt',
        }
        self.trainset = self.train_set_dict[hp.tgt_lang]

        # Load on CPU; load_state_dict below copies into the placed models.
        self.src_model_path = self.src_model_dict[hp.tgt_lang]
        self.src_state_dict = torch.load(self.src_model_path, map_location='cpu')
        self.trans_model_path = self.trans_model_dict[hp.tgt_lang]
        self.trans_state_dict = torch.load(self.trans_model_path, map_location='cpu')
        self.tea_model_path = self.tea_model_dict[hp.tgt_lang]
        self.tea_state_dict = torch.load(self.tea_model_path, map_location='cpu')

        # Build the network structure and load the saved parameters into the new models.
        self.model_name = 'MyBertRnnCrfIsner'

        def _build_model():
            # One freshly initialised tagger, placed via cpu_2_gpu.
            return cpu_2_gpu(product_model(self.model_name, hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM))

        self.src_model = _build_model()
        self.src_model.load_state_dict(self.src_state_dict)

        self.trans_model = _build_model()
        self.trans_model.load_state_dict(self.trans_state_dict)

        self.tea_model = _build_model()
        self.tea_model.load_state_dict(self.tea_state_dict)

        # The student starts from fresh initialisation; it is the only model optimised.
        self.stu_model = _build_model()

        self.model_name, self.optim_name, self.criterion_name = 'MyBertRnnCrfIsner', "Adam", "CrossEntropyLoss"
        # Freeze part of the network layers (disabled)
        # freeze_layers = ['bert.encoder.layer.0.', 'bert.encoder.layer.1.', 'bert.encoder.layer.2.',
        #                  'bert.embeddings.']  # 'bert.'
        # for name, param in self.model.named_parameters():
        #     param.requires_grad = True
        #     for ele in freeze_layers:
        #         if ele in name:
        #             param.requires_grad = False
        #             break
        #
        # for name, param in self.model.named_parameters():
        #     print(name, param.requires_grad)
        self.optimizer = product_optimizer(self.optim_name, self.stu_model.parameters(), hp.lr)
        self.hard_criterion = product_criterion('CrossEntropyLoss')
        self.soft_criterion = product_criterion('MSELoss')

        self.dataloader_name = "ConllIsentityLoader"
        self.testloader = self.init_cls_dataloader(self.dataloader_name, self.testset, {'shuffle': False})
        self.trainloader = self.init_cls_dataloader(self.dataloader_name, self.trainset, {'shuffle': True})
        # NOTE(review): 'valideset' looks like a misspelling of 'validset', and the
        # validation loader is shuffled. Confirm both against the hp definition.
        self.validloader = self.init_cls_dataloader(self.dataloader_name, hp.valideset, {'shuffle': True})


    def train(self, src_model, trans_model, tea_model, stu_model, trainloader, tealoader, optimizer, hard_criterion, soft_criterion):
        """Run one distillation epoch over ``stu_model`` and return the mean loss.

        ``stu_model`` is put in training mode; ``src_model``, ``trans_model``
        and ``tea_model`` are put in eval mode and evaluated without gradient
        tracking, since only their outputs are consumed here.  ``trainloader``
        and ``tealoader`` are iterated in lockstep with ``zip``, so the epoch
        stops at the shorter of the two.

        The per-step loss depends on two flags read from ``self``:

        * ``add_cluster``: inputs and labels come from the ``tealoader`` batch
          and the loss is ``hard_criterion(student_logits, labels)``.
        * ``low_resource`` (without ``add_cluster``): ``soft_criterion``
          between student and teacher logits on every step; on steps where
          ``idx % (100 / self.batch_size) == 0`` the ``hard_criterion`` of the
          first head position against its label from ``trainloader`` is added.
        * otherwise: ``soft_criterion`` plus ``hard_criterion`` against the
          teacher prediction, restricted to head positions where the three
          frozen models predict the same tag.

        All losses are computed only on positions where ``is_heads == 1``.

        Args:
            src_model, trans_model, tea_model: frozen models, each called as
                ``model(x, is_heads, seqlens)`` and returning
                ``(logits, prediction, embeds)``.
            stu_model: model being optimised (same call convention).
            trainloader, tealoader: iterables yielding 7-tuples
                ``(words, x, is_heads, tags, y, seqlens, is_entity)``.
            optimizer: optimizer stepped once per batch.
            hard_criterion: loss taking ``(logits, target_indices)``.
            soft_criterion: loss taking ``(student_logits, teacher_logits)``.

        Returns:
            ``np.mean`` of the per-step loss values.
        """
        stu_model.train()
        src_model.eval()
        trans_model.eval()
        tea_model.eval()
        loss_train = []
        start_t = time.time()

        for idx, batch in enumerate(zip(trainloader, tealoader)):
            words, x, is_heads, tags, y, seqlens, is_entity = batch[0]
            if self.add_cluster:
                # The teacher-loader batch replaces the inputs; its labels are
                # the training targets for this mode.
                words, x, is_heads, tags, y_tea, seqlens, is_entity = batch[1]
                y_tea = cpu_2_gpu(y_tea)

            x, y, is_heads = cpu_2_gpu([x, y, is_heads])

            optimizer.zero_grad()
            # The frozen models only provide targets: skip autograd so no graph
            # is kept and no gradients pile up on parameters nobody steps.
            with torch.no_grad():
                _, prediction_src, _ = src_model(x, is_heads, seqlens)
                _, prediction_trans, _ = trans_model(x, is_heads, seqlens)
                logits_tea, prediction_tea, _ = tea_model(x, is_heads, seqlens)
            logits_stu, _, _ = stu_model(x, is_heads, seqlens)  # logits: (B, S, Tag)

            # Flatten to one row per position and keep only head positions.
            head_mask = is_heads.view(-1) == 1  # (N*T,)
            logits_tea_cons = logits_tea.view(-1, logits_tea.shape[-1])[head_mask]
            logits_stu_cons = logits_stu.view(-1, logits_stu.shape[-1])[head_mask]

            if self.add_cluster:
                y_tea_cons = y_tea.view(-1)[head_mask]
                loss = hard_criterion(logits_stu_cons, y_tea_cons)
            elif self.low_resource:
                loss = soft_criterion(logits_stu_cons, logits_tea_cons)
                if idx % (100 / self.batch_size) == 0:
                    # Supervised signal on the periodic step: only the first
                    # head position is selected, exactly as before.
                    y_cons = y.view(-1)[head_mask]
                    hard_index = torch.zeros(logits_stu_cons.shape[0], dtype=torch.bool,
                                             device=logits_stu_cons.device)
                    hard_index[0] = True
                    loss = hard_criterion(logits_stu_cons[hard_index], y_cons[hard_index]) + loss
            else:
                prediction_src_cons = prediction_src.view(-1)[head_mask]
                prediction_trans_cons = prediction_trans.view(-1)[head_mask]
                prediction_tea_cons = prediction_tea.view(-1)[head_mask]
                hard_index = (prediction_src_cons == prediction_trans_cons) & \
                             (prediction_trans_cons == prediction_tea_cons)
                loss = soft_criterion(logits_stu_cons, logits_tea_cons)
                # A mean-reduced criterion over zero rows is NaN and would
                # poison the update, so the hard term is skipped when the
                # models agree nowhere in this batch.
                if hard_index.any():
                    loss = hard_criterion(logits_stu_cons[hard_index],
                                          prediction_tea_cons[hard_index]) + loss

            loss_train.append(loss.item())
            loss.backward()
            optimizer.step()

            if idx % 100 == 0:  # monitoring
                self.logger.info(
                    f"STEP: {idx}\tLOSS={round(np.mean(loss_train), 11)}\t\ttime: {(time.time() - start_t) / 60}")

        return np.mean(loss_train)

    def cls_evaluation(self, model, validloader,  is_test=False):
        """Evaluate ``model`` on ``validloader`` and return the conlleval F1.

        For every batch the model is called as ``model(x, is_heads, seqlens)``
        under ``torch.no_grad()``.  Gold and predicted tag indices are kept
        only at positions where ``is_heads == 1``, and the first and last kept
        position of each sentence are dropped before scoring.  The raw
        per-sentence predictions are also written through
        ``self.write_cls_result`` to ``result_dir + "test_cls"`` or
        ``result_dir + "valid_cls"``.

        Args:
            model: model returning ``(logits, prediction, embeds)``.
            validloader: iterable yielding 7-tuples
                ``(words, x, is_heads, tags, y, seqlens, is_entity)``.
            is_test: selects the result-file suffix only.

        Returns:
            The F1 value returned by ``eval_F1(..., 'conlleval')``.
        """
        model.eval()
        Words, Is_heads, Tags, Y, Y_hat, Y_hat_List = [], [], [], [], [], []

        with torch.no_grad():
            for batch in validloader:
                words, x, is_heads, tags, y, seqlens, is_entity = batch
                x, y, is_heads = cpu_2_gpu([x, y, is_heads])
                _, prediction, _ = model(x, is_heads, seqlens)

                # Convert each tensor to nested lists once per batch.
                gold_rows = y.cpu().numpy().tolist()
                pred_rows = prediction.cpu().numpy().tolist()
                head_rows = is_heads.cpu().numpy().tolist()

                Words.extend(words)
                Is_heads.extend(head_rows)
                Tags.extend(tags)
                Y_hat_List.extend(pred_rows)  # unmodified predictions for the result file

                for gold, pred, heads in zip(gold_rows, pred_rows, head_rows):
                    Y.extend([t for head, t in zip(heads, gold) if head == 1][1:-1])
                    # Index 9 is scored as index 0.
                    # NOTE(review): presumably 9 is an auxiliary/padding tag
                    # folded into "O" -- confirm against the tag vocabulary.
                    remapped = [0 if p == 9 else p for p in pred]
                    Y_hat.extend([p for head, p in zip(heads, remapped) if head == 1][1:-1])

        self.logger.info("============seq classification Eval by conlleval.py:============")
        precision, recall, f1 = eval_F1([ix_to_tag[t_ix] for t_ix in Y], [ix_to_tag[t_ix] for t_ix in Y_hat],
                                        'conlleval')
        self.logger.info("PRE=%.5f\t\tREC=%.5f\t\tF1=%.5f" % (precision, recall, f1))

        suffix = "test_cls" if is_test else "valid_cls"
        self.write_cls_result(self.record_dir_dict['result_dir'] + suffix,
                              data=(Words, Tags, Y_hat_List, Is_heads),
                              ix_to_tag=ix_to_tag)

        return f1

    def train_epoch(self, i):
        """Run epoch ``i``: distil into the student, test it, keep the best.

        On ``i == 1`` the source, translation and teacher models are first
        scored on ``self.testloader`` and stored under ``'SRC_F1'``,
        ``'TRANS_F1'`` and ``'TEA_F1'`` in ``self.record_result_dict``.  Every
        call then runs ``self.train``, scores the student on the test loader,
        and saves the student's ``state_dict`` to ``chk_dir + 'best_test.pt'``
        whenever the F1 exceeds the stored ``'TEST_MAX_F1'``.
        """
        results = self.record_result_dict
        log = self.logger.info

        # test set : unlabeled data: es, nl, de, ch...
        if i == 1:
            log(f"=========================TEST SRC|TRANS|TEA AT EPOCH={i}==========================")
            for key, attr in (('SRC_F1', 'src_model'), ('TRANS_F1', 'trans_model'), ('TEA_F1', 'tea_model')):
                results[key] = self.cls_evaluation(getattr(self, attr), self.testloader, True)

        log(f"=========================TEACHING AT EPOCH={i}==========================")
        epoch_loss = self.train(
            self.src_model,
            self.trans_model,
            self.tea_model,
            self.stu_model,
            self.trainloader,
            self.tealoader,
            self.optimizer,
            self.hard_criterion,
            self.soft_criterion,
        )
        results['train_loss_list'].append(epoch_loss)

        # test set : unlabeled data: es, nl, de, ch...
        log(f"=========================STU TEST AT EPOCH={i}==========================")
        student_f1 = self.cls_evaluation(self.stu_model, self.testloader, True)
        results['test_F1_list'].append(student_f1)

        if results['TEST_MAX_F1'] < student_f1:
            results['TEST_MAX_F1'] = student_f1
            results['TEST_MAX_EPOCH'] = i
            torch.save(self.stu_model.state_dict(), self.record_dir_dict['chk_dir'] + 'best_test.pt')

        for label, key in (('SRC', 'SRC_F1'), ('TRANS', 'TRANS_F1'), ('TEA', 'TEA_F1')):
            log(f"{label} F1={results[key]}")
        log(f"STU MAX F1={results['TEST_MAX_F1']} AT EPOCH {results['TEST_MAX_EPOCH']}\t")
        log(f"STU MODEL STORED IN {self.record_dir_dict['chk_dir']}best_test.pt")

    def write_cls_result(self, resultfile, data, ix_to_tag):
        """Write predictions to ``resultfile`` as ``word gold_tag pred_tag`` lines.

        Args:
            resultfile: path of the file to (over)write.
            data: 4-tuple ``(Words, Tags, Y_hat, Is_heads)`` of parallel
                per-sentence sequences.  ``Words`` and ``Tags`` hold
                whitespace-separated strings; ``Y_hat`` holds predicted tag
                indices per position and ``Is_heads`` the matching 0/1 flags.
            ix_to_tag: mapping from tag index to tag string.

        Only predictions at positions whose flag is 1 are kept.  The first and
        last token of every sentence are not written, and each sentence is
        followed by an empty line.

        Raises:
            AssertionError: if a sentence's kept predictions, words and tags
                differ in length; the message lists all three sequences.
        """
        Words, Tags, Y_hat, Is_heads = data
        # The corpora are multilingual, so do not rely on the platform's
        # default encoding when writing tokens.
        with open(resultfile, 'w', encoding='utf-8') as fout:
            for words, tags, y_hat, is_heads in zip(Words, Tags, Y_hat, Is_heads):
                tokens = words.split()
                golds = tags.split()
                preds = [ix_to_tag[hat] for head, hat in zip(is_heads, y_hat) if head == 1]
                if not (len(preds) == len(tokens) == len(golds)):
                    # Raised explicitly (same exception type as before) so the
                    # check carries a real message and still runs under -O.
                    raise AssertionError(
                        f"length mismatch: preds={preds} words={tokens} tags={golds}")
                for w, t, p in zip(tokens[1:-1], golds[1:-1], preds[1:-1]):
                    fout.write(f"{w} {t} {p}\n")
                fout.write("\n")

    def record_result(self):
        """Placeholder hook: performs no work and returns ``None``."""
        return None

class MTST(object):
    def __init__(self, hp):
        """Set up directories, checkpoints, models, optimizer and data loaders.

        ``hp`` is the hyper-parameter object; the attributes read here are
        ``batch_size``, ``low_resource``, ``n_clusters``, ``only_siamese``,
        ``self_learn``, ``add_cluster``, ``add_visual``, ``knn_k``,
        ``tgt_lang``, ``new_num``, ``LEN_ALL_TAGS``, ``HIDDEN_DIM``,
        ``loss_alpha``, ``lr``, ``weight_loss``, ``trainset_params``,
        ``testset_params``, ``validset`` and ``validset_sim_params``.

        Side effects: creates the record directories through
        ``create_all_record_dir``, opens a logger, and loads two checkpoints
        from disk with ``torch.load``.
        """
        # initialize the experiment
        self.identify_name = get_identify_name(hp)
        # Base output locations handed to create_all_record_dir below.
        self.record_dirs = {
            'plot_dir': '../MassPlotPng/',
            'tboard_dir': '../MassLogBoard/',
            'result_dir': '../MassPredictResult/',
            'pt_dir': '../MassPT/',
            'logger_dir': '../MassLogInfo/',
            'chk_dir': '../checkpoints/'
        }
        self.record_dir_dict = create_all_record_dir(self.record_dirs, hp)
        # Running metrics for the experiment.
        # NOTE(review): there is no 'VALID_MAX_F1' entry here (only
        # 'VALID_MAX_F1_target'); anything tracking the best validation F1
        # under that key has to create it first.
        self.record_result_dict = {
            'test_F1_n_clusters': [],
            'train_loss_list': [],
            'tgt_train_uncls_F1_list': [],
            'tgt_train_cls_F1_list': [],
            'dev_F1_sim_list': [],
            'test_F1_sim_list': [],
            'test_F1_main_list': [],
            'test_F1_assist_list': [],
            'test_F1_fine_list': [],
            'test_F1_vote_list': [],
            'test_F1_base_list': [],
            'VALID_MAX_PREC': 0,
            'VALID_MAX_REC': 0,
            'BASE_MAX_F1': 0,
            'VALID_MAX_EPOCH': 0,
            'TEST_MAX_PREC': 0,
            'TEST_MAX_REC': 0,
            'TEST_MAX_F1': 0,
            'TEST_MAX_EPOCH': 0,
            'TEST_Siamese_MAX_F1': 0,
            'TEST_Siamese_MAX_EPOCH': 0,
            'VALID_MAX_PREC_target': 0,
            'VALID_MAX_REC_target': 0,
            'VALID_MAX_F1_target': 0,
            'VALID_MAX_EPOCH_target': 0,
            'TEST_MAX_PREC_target': 0,
            'TEST_MAX_REC_target': 0,
            'TEST_MAX_F1_target': 0,
            'TEST_MAX_EPOCH_target': 0,
            'test_F1_knn_k': [],
            'test_F1_knn_max_sim': []
        }
        self.logger = get_logger('info', self.record_dir_dict['logger_dir'] + 'log')
        # self.tboard_writer = SummaryWriter(self.record_dir_dict['tboard_dir'])
        self.batch_size = hp.batch_size
        self.top_N = 400
        self.low_resource = hp.low_resource
        self.n_clusters = hp.n_clusters
        self.feature_dim = 768
        self.only_siamese = hp.only_siamese
        self.self_learn = hp.self_learn
        self.add_cluster = hp.add_cluster
        self.add_visual = hp.add_visual
        self.knn_k = hp.knn_k

        # Checkpoint tables keyed by bare language code ('nl', 'de', ...).
        # base_model_dict and cls_model_dict hold the same paths and neither
        # is read anywhere in this method.
        self.base_model_dict = {
                                'nl': '../checkpoints/2021-08-26 21_41_38/best_test.pt',  
                                'de': '../checkpoints/2021-08-28 13_21_29/best_test.pt',
                                'es': '../checkpoints/2021-08-31 10_35_48/best_test.pt',
                               
                                'ar': '../checkpoints/2021-10-27 15_15_37/best_test.pt', 
                                'hi': '../checkpoints/2021-10-28 11_30_28/best_test.pt',  
                                'zh': '../checkpoints/2021-10-28 11_30_40/best_test.pt',
                                }

        self.cls_model_dict = {'nl': '../checkpoints/2021-08-26 21_41_38/best_test.pt',
                                'de': '../checkpoints/2021-08-28 13_21_29/best_test.pt',
                                'es': '../checkpoints/2021-08-31 10_35_48/best_test.pt',

                                'ar': '../checkpoints/2021-10-27 15_15_37/best_test.pt',
                                'hi': '../checkpoints/2021-10-28 11_30_28/best_test.pt',
                                'zh': '../checkpoints/2021-10-28 11_30_40/best_test.pt',  }


        self.sim_model_dict = {'nl': '../checkpoints/2021-08-26 21_41_38/best_sim_fea.pt',
                                'de': '../checkpoints/2021-08-28 13_21_29/best_sim_fea.pt',
                                'es': '../checkpoints/2021-08-31 10_35_48/best_sim_fea.pt',

                                'ar': '../checkpoints/2021-10-27 15_15_37/best_sim_fea.pt',
                                'hi': '../checkpoints/2021-10-28 11_30_28/best_sim_fea.pt',
                                'zh': '../checkpoints/2021-10-28 11_30_40/best_sim_fea.pt',

                               }
        # Dataset tables keyed by '<lang>_new'.
        # NOTE(review): 'de_new' is listed twice (same value), and these keys
        # differ from the bare language codes of the checkpoint tables above
        # although both are indexed with hp.tgt_lang below -- as written, one
        # of the two lookups raises KeyError for any single tgt_lang value.
        self.test_set_dict = {'es_new': '../data/CoNLL2002/es/test.txt',
                              'nl_new': '../data/CoNLL2002/nl/test.txt',
                              'de_new': '../data/CoNLL2002/de/test.txt',
                              'de_new': '../data/CoNLL2002/de/test.txt',
                              'ar_new': '../data/WikiAnn/ar/test.txt',
                              'hi_new': '../data/WikiAnn/hi/test.txt',
                              'zh_new': '../data/WikiAnn/zh/test.txt',}
        self.tgt_lang = hp.tgt_lang
        self.testset = self.test_set_dict[hp.tgt_lang]

        # Not read in this method; target_trainset below is a fixed path.
        self.tgt_train_set_dict = {'es_new': '../data/CoNLL2002/es/train.txt',
                                   'nl_new': '../data/CoNLL2002/nl/train.txt',
                                   'de_new_IOB1': '../data/CoNLL2002/de/train.txt',
                                   'de_new': '../data/CoNLL2002/de/train.txt',
                                   'ar_new': '../data/WikiAnn/ar/train.txt',
                                   'hi_new': '../data/WikiAnn/hi/train.txt',
                                   'zh_new': '../data/WikiAnn/zh/train.txt',}

        # NOTE(review): hard-coded to the English training file regardless of
        # hp.tgt_lang -- confirm this is intended.
        self.target_trainset = '../data/CoNLL2002/en/train.txt'

        # self.base_model_path = self.base_model_dict[hp.tgt_lang]
        # self.base_state_dict = torch.load(self.base_model_path)

        self.sim_model_path = self.sim_model_dict[hp.tgt_lang]
        self.sim_state_dict = torch.load(self.sim_model_path)

        # NOTE(review): self.unitrans_model_dict is never assigned in this
        # method, so this line raises AttributeError unless the attribute is
        # supplied from elsewhere (e.g. a subclass).  The original note names
        # base_model_dict as the alternative table.
        self.cls_model_path = self.unitrans_model_dict[hp.tgt_lang + hp.new_num]  # unitrans_model_dict  base_model_dict
        self.cls_state_dict = torch.load(self.cls_model_path)



        # Teacher classifier: restored from the checkpoint loaded above.
        self.cls_model_name = 'MyBertRnnCrfIsner'
        self.cls_model = cpu_2_gpu(product_model(self.cls_model_name, hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM))
        self.cls_model.load_state_dict(self.cls_state_dict)

        # Student classifier: same architecture, no checkpoint loaded here.
        self.cls_model_stu = cpu_2_gpu(product_model(self.cls_model_name, hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM))

        # Similarity model, instantiated by evaluating its class name.
        self.sim_model_name = 'MyBertSiameseCLNER'
        self.sim_model = cpu_2_gpu(
            (eval(self.sim_model_name)(hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM, self.feature_dim)))
        self.sim_model.load_state_dict(self.sim_state_dict)

        # Student similarity model; its checkpoint load is commented out.
        self.sim_model_stu = cpu_2_gpu(
            (eval(self.sim_model_name)(hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM, self.feature_dim)))
        # self.sim_model_stu.load_state_dict(self.sim_state_dict)
        # Weight of the similarity term in the joint training loss.
        self.alpha = hp.loss_alpha
        # freeze_layers = ["embeddings"] + ["layer." + str(layer_i) + "." for layer_i in range(12) if layer_i < 3]#['bert.encoder.layer.0.', 'bert.encoder.layer.1.', 'bert.encoder.layer.2.',
        #                  # 'bert.embeddings.']  # 'bert.'
        #
        # # freeze_layers = ["bert."]
        # for name, param in self.model.named_parameters():
        #     param.requires_grad = True
        #     for ele in freeze_layers:
        #         if ele in name:
        #             param.requires_grad = False
        #             break
        # for name, param in self.model.named_parameters():
        #     print(name, param.requires_grad)
        self.dataloader_name = "ConllIsentityLoader"

        # NOTE(review): init_cls_dataloader, init_dataloader and
        # product_param_dict are not defined on this class as shown; they
        # must come from elsewhere for construction to succeed.
        self.testloader = self.init_cls_dataloader(self.dataloader_name, self.testset, {'shuffle': False})
        # Two parameter groups for the student: its `linear` layer is trained
        # at hp.lr / 1000, every other parameter at hp.lr.
        linear_params = list(map(id, self.cls_model_stu.linear.parameters()))
        base_params = filter(lambda p: id(p) not in linear_params,
                             self.cls_model_stu.parameters())

        self.optimizer = optim.Adam([
            {'params': base_params},
            {'params': self.cls_model_stu.linear.parameters(), 'lr': hp.lr / 1000}], lr=hp.lr)


        # With weight_loss the criteria return per-element losses
        # (reduction='none'); otherwise they use their default reduction.
        self.weight_loss = hp.weight_loss
        if(self.weight_loss):
            self.criterion_BCE = nn.BCELoss(reduction='none')
            self.criterion_CE = nn.CrossEntropyLoss(reduction='none')
            self.criterion_MSE = nn.MSELoss(reduction='none')
        else:
            self.criterion_BCE = nn.BCELoss()
            self.criterion_CE = nn.CrossEntropyLoss()
            self.criterion_MSE = product_criterion('MSELoss')

        # From here on the pair-based loader is used.
        self.dataloader_name = 'SiameseLoader'

        self.trainset_params = hp.trainset_params
        self.testset_params =  hp.testset_params
        # NOTE(review): unlike the two calls below, this one passes no logger.
        self.validloader_sim = self.init_dataloader(self.dataloader_name, hp.validset, self.product_param_dict(hp.validset_sim_params))
        self.testloader_sim = self.init_dataloader(self.dataloader_name, self.testset,
                                                   self.product_param_dict(self.testset_params), self.logger)
        self.trainloader_tgt = self.init_dataloader(self.dataloader_name, self.target_trainset, self.product_param_dict(self.trainset_params), self.logger)

    def train(self,  cls_model_stu, target_trainloader, i):
        """Fine-tune ``cls_model_stu`` for one pass with the joint loss.

        Each batch holds a pair of inputs.  The student is run on both, and
        for every pair the logits and embeddings at position
        ``wordpiece_idx + 1`` are selected.  The loss is::

            alpha * BCE(sigmoid(cosine(emb_1, emb_2)), y)
                + 0.5 * CE(logits_1, tag_1) + 0.5 * CE(logits_2, tag_2)

        Each term is reduced with ``.mean()`` before being combined.  This is
        a no-op for the default (already scalar) criteria and makes the
        ``reduction='none'`` criteria selected by ``hp.weight_loss`` usable,
        since ``backward()`` and ``.item()`` need a scalar.

        Args:
            cls_model_stu: model called as ``model(x, att_mask, seqlen)`` and
                returning ``(logits, tag_pred, embeds)``.
            target_trainloader: iterable yielding 8-tuples ``(words_1_2,
                wordpiece_idx_1_2, tag_1_2, x_1_2, y, att_mask_1_2,
                seqlen_1_2, word_idx_1_2)``.
            i: epoch number; not used inside this method.

        Returns:
            ``np.mean`` of the per-step joint loss.
        """
        cls_model_stu.train()
        start_t = time.time()
        self.cossim = nn.CosineSimilarity(dim=1, eps=1e-6)

        def select_rows(tensor, columns):
            # One entry per batch row: tensor[row, columns[row]].
            return tensor[list(range(tensor.size()[0])), columns]

        loss_train, mseloss1_train, mseloss2_train, bceloss_train = [], [], [], []
        for idx, batch in enumerate(target_trainloader):
            # prepare input data
            words_1_2, wordpiece_idx_1_2, tag_1_2, x_1_2, y, att_mask_1_2, seqlen_1_2, word_idx_1_2 = batch
            seqlen_1, seqlen_2 = seqlen_1_2[0], seqlen_1_2[1]
            wordpiece_idx_1, wordpiece_idx_2 = wordpiece_idx_1_2[0], wordpiece_idx_1_2[1]
            tag_1, tag_2 = tag_1_2[0], tag_1_2[1]
            x_1, x_2, att_mask_1, att_mask_2, y = cpu_2_gpu(x_1_2 + att_mask_1_2 + [y])

            logits_1_stu, _, embeds_cls_1_stu = cls_model_stu(x_1, att_mask_1, seqlen_1)
            logits_2_stu, _, embeds_cls_2_stu = cls_model_stu(x_2, att_mask_2, seqlen_2)

            # Positions are shifted by one before indexing.
            # NOTE(review): presumably to skip a leading special token -- confirm.
            wordpiece_idx_1 = [wp_i + 1 for wp_i in wordpiece_idx_1]
            wordpiece_idx_2 = [wp_i + 1 for wp_i in wordpiece_idx_2]
            logits_1_stu = select_rows(logits_1_stu, wordpiece_idx_1)
            logits_2_stu = select_rows(logits_2_stu, wordpiece_idx_2)
            embeds_cls_1_stu = select_rows(embeds_cls_1_stu, wordpiece_idx_1)
            embeds_cls_2_stu = select_rows(embeds_cls_2_stu, wordpiece_idx_2)

            embeds_cls_1_stu, embeds_cls_2_stu = cpu_2_gpu([embeds_cls_1_stu, embeds_cls_2_stu])
            # Cosine similarity squashed by a sigmoid so it is a valid BCE input.
            prediction_stu = torch.sigmoid(self.cossim(embeds_cls_1_stu, embeds_cls_2_stu))

            self.optimizer.zero_grad()

            logits_1_stu, logits_2_stu, tag_1, tag_2, y = cpu_2_gpu([logits_1_stu, logits_2_stu, tag_1, tag_2, y])

            loss_ner_1 = self.criterion_CE(logits_1_stu.float(), tag_1).mean()
            loss_ner_2 = self.criterion_CE(logits_2_stu.float(), tag_2).mean()
            loss_sim = self.criterion_BCE(prediction_stu, y.float()).mean()

            loss = self.alpha*loss_sim + 1/2*loss_ner_1 + 1/2*loss_ner_2

            loss.backward()
            self.optimizer.step()

            loss_train.append(loss.item())
            # The "MSELOSS" labels in the log line below are historical; the
            # values are the two cross-entropy terms.
            mseloss1_train.append(loss_ner_1.item())
            mseloss2_train.append(loss_ner_2.item())
            bceloss_train.append(loss_sim.item())
            if idx % 1000 == 0:  # monitoring
                self.logger.info(
                    f"STEP: {idx}\tLOSS={round(np.mean(loss_train), 11)}\tMSELOSS1={round(np.mean(mseloss1_train), 11)}\tMSELOSS2={round(np.mean(mseloss2_train), 11)}\tBCELOSS={round(np.mean(bceloss_train), 11)}\t\ttime: {(time.time() - start_t) / 60}")

        return np.mean(loss_train)

 
    def evaluation(self, cls_model, target_trainloader, i):
        """Score ``cls_model`` on a pair loader and dump analysis artefacts.

        Only the first input of every pair is scored.  For each item the
        prediction, logits and embedding at position ``wordpiece_idx + 1`` are
        used; predicted index 9 is scored as index 0.  ``self.cls_model`` and
        ``self.sim_model`` are also run on every batch to collect their
        embeddings for the plots.

        Files written:

        * ``plot_dir + 'stu_<i>.json'``: up to 50 student embeddings per tag.
        * when ``i == 1``: ``cls_weight_test.json``, ``cls_tea.json``,
          ``sim_tea.json`` and the ``self.plot_weight_analyse`` output.
        * ``result_dir + ('tea_' if i == 1 else 'stu_') + str(i) + '.txt'``
          via ``self.write_ner_result_word_tag_pred``.

        Args:
            cls_model: model called as ``model(x, att_mask, seqlen)`` and
                returning ``(logits, tag_pred, embeds)``.
            target_trainloader: iterable yielding the same 8-tuples as in
                ``train``.
            i: epoch number; selects file names and the ``i == 1`` extras.

        Returns:
            The F1 value returned by ``eval_F1(Y, Y_HAT, 'conlleval')``.
        """
        cls_model.eval()
        start_t = time.time()
        Y, Y_HAT = [], []
        Conf_list = []
        Words, Word_Idx_list = [], []
        tag_list, embed_list_cls, embed_list_cls_stu, embed_list_sim = [], [], [], []
        softmax = nn.Softmax(dim=-1)  # stateless, so built once for all batches
        with torch.no_grad():
            for idx, batch in enumerate(target_trainloader):
                # prepare input data
                words_1_2, wordpiece_idx_1_2, tag_1_2, x_1_2, y, att_mask_1_2, seqlen_1_2, word_idx_1_2 = batch
                words_1 = words_1_2[0]
                seqlen_1 = seqlen_1_2[0]
                word_idx_1 = word_idx_1_2[0]
                wordpiece_idx_1 = wordpiece_idx_1_2[0]
                tag_1 = tag_1_2[0]

                Words.extend(words_1)
                Word_Idx_list.extend(word_idx_1)

                x_1, x_2, att_mask_1, att_mask_2 = cpu_2_gpu(x_1_2 + att_mask_1_2)

                # model output data
                logits_1, tag_pred_1, embeds_cls_1_stu = cls_model(x_1, att_mask_1, seqlen_1)

                # Confidence = max softmax probability at the target position,
                # computed on CPU in float32.
                # NOTE(review): the +1 offset presumably skips a leading
                # special token -- confirm against the loader.
                target_logits = torch.stack(
                    [l_i[wp_idx + 1] for l_i, wp_idx in zip(logits_1, wordpiece_idx_1)])
                probs = softmax(target_logits.float().cpu())
                Conf_list.extend(torch.max(probs, dim=1).values.numpy().tolist())

                embed_list_cls_stu.extend(
                    [emd[w_idx + 1] for emd, w_idx in
                     zip(embeds_cls_1_stu.cpu().numpy().tolist(), wordpiece_idx_1)])

                # Embeddings from the loaded classifier and similarity model,
                # collected for the i == 1 plots.
                _, _, embeds_cls = self.cls_model(x_1, att_mask_1, seqlen_1)
                _, _, _, _, _, embeds_1_sim, _ = self.sim_model([x_1, x_2],
                                                                wordpiece_idx_1_2)

                embed_list_cls.extend(
                    [emd[w_idx + 1] for emd, w_idx in
                     zip(embeds_cls.cpu().numpy().tolist(), wordpiece_idx_1)])
                embed_list_sim.extend(embeds_1_sim.cpu().numpy().tolist())

                pred_ids = [t_i[wp_idx + 1] for t_i, wp_idx in
                            zip(tag_pred_1.cpu().numpy().tolist(), wordpiece_idx_1)]
                # Index 9 is scored as index 0.
                # NOTE(review): presumably an auxiliary/padding tag -- confirm.
                Y_HAT.extend([ix_to_tag[0] if p == 9 else ix_to_tag[p] for p in pred_ids])

                gold_ids = [t.item() for t in tag_1]
                tag_list.extend(gold_ids)
                Y.extend([ix_to_tag[t] for t in gold_ids])

                if idx % 1000 == 0:  # monitoring
                    self.logger.info(f"TEST STEP: {idx}\t\tTIME: {(time.time() - start_t) / 60}")

        self.logger.info("============Eval by conlleval:============")
        precision, recall, f1 = eval_F1(Y, Y_HAT, 'conlleval')
        self.logger.info("PRE=%.5f\t\tREC=%.5f\t\tF1=%.5f" % (precision, recall, f1))

        # Keep at most `plot_num` examples per gold tag for the embedding plots.
        plot_num = 50
        tag_count = {tag: plot_num for tag in ALL_TAGS_}

        embed_list_use_stu, embed_list_use_cls, embed_list_use_sim, tag_list_use = [], [], [], []
        for embed_stu_temp, embed_cls_temp, embed_sim_temp, y_temp in zip(
                embed_list_cls_stu, embed_list_cls, embed_list_sim, tag_list):
            tag_name = ix_to_tag[y_temp]
            if tag_count[tag_name] > 0:
                tag_count[tag_name] -= 1
                embed_list_use_stu.append(embed_stu_temp)
                embed_list_use_cls.append(embed_cls_temp)
                embed_list_use_sim.append(embed_sim_temp)
                tag_list_use.append(y_temp)

        plot_dir = self.record_dir_dict['plot_dir']
        if i == 1:
            write_plot(['Y', 'Y_HAT', 'Conf_list'], [Y, Y_HAT, Conf_list], plot_dir + 'cls_weight_test.json', self.logger)
            write_plot(['embed_list_use_cls', 'tag_list_use'], [embed_list_use_cls, tag_list_use], plot_dir + 'cls_tea.json', self.logger)
            write_plot(['embed_list_use_sim', 'tag_list_use'], [embed_list_use_sim, tag_list_use], plot_dir + 'sim_tea.json', self.logger)

            self.plot_weight_analyse(Y, Y_HAT, Conf_list, plot_dir + 'cls_weight_test')
        write_plot(['embed_list_use_stu', 'tag_list_use'], [embed_list_use_stu, tag_list_use],
                   plot_dir + 'stu_' + str(i)+'.json', self.logger)

        suffix = "tea_" if i==1 else "stu_"
        result_path = self.record_dir_dict['result_dir'] + suffix + str(i)
        self.write_ner_result_word_tag_pred(result_path,
                                            data=(Words, Y, Y_HAT, Word_Idx_list))
        self.logger.info(f"predict result stored in {result_path}")

        return f1


    def train_epoch_param(self, i):
        if(i==1):
            f1 = self.cls_evaluation(self.cls_model_stu, self.testloader_sim, i)


        # predict pseudo label
        self.logger.info(f"======================FINE TUNE BY JOINT LOSS={i}=====================")
        loss = self.train(self.cls_model_stu, self.trainloader_tgt, i)
        self.logger.info(f"======================TEST BY JOINT LOSS={i}=====================")
        f1 = self.evaluation(self.cls_model_stu, self.validloader_sim, i)



        # f1 = self.cls_evaluation(self.cls_model_stu,  self.testloader_sim, i)
        if(f1>self.record_result_dict['VALID_MAX_F1']):
            self.record_result_dict['VALID_MAX_F1'] = f1
            torch.save(self.cls_model_stu.state_dict(), self.record_dir_dict['chk_dir'] + 'best_cls_test.pt')


        f1 = self.evaluation(self.cls_model_stu, self.testloader_sim, i)
        if (f1 > self.record_result_dict['TEST_MAX_F1']):
            self.record_result_dict['TEST_MAX_F1'] = f1
        self.record_result_dict['tgt_train_cls_F1_list'].append(f1)

        self.logger.info(f"Teacher: {self.cls_model_path}")
        self.logger.info(f"Teacher F1: {self.record_result_dict['BASE_MAX_F1']}")

        self.logger.info(
            f"Best CLUSTER_F1: BEST EPOCH_NUM={np.argmax(self.record_result_dict['tgt_train_cls_F1_list']) + 1}\t F1={np.max(self.record_result_dict['tgt_train_cls_F1_list'])}")
        self.logger.info(f"MODEL SAVED IN: {self.record_dir_dict['chk_dir']}best_cls_test.pt & best_sim_test.pt")

    def write_ner_result(self, resultfile, data, ix_to_tag):
        # ep: word real_tag pred_tag
        Words, Word_Idx, Y_list, Tag_Pred_NER_list = data
        with open(resultfile + '.txt', 'w') as fout:
            for words_batch, word_idx_batch, tags_batch, tag_ner_batch \
                    in zip(Words, Word_Idx, Y_list, Tag_Pred_NER_list):
                for k, words in enumerate(words_batch):
                    assert (len(words[1:-1]) > word_idx_batch[
                        k]), f"sentence length: {len(words[1:-1])}, word idx: {word_idx_batch[k]}"
                    t = tags_batch[k]
                    p_ner = tag_ner_batch[k]
                    w_i = word_idx_batch[k]
                    if (w_i == 0):
                        fout.write("\n")
                    # print(f"{words[1:-1]} {[w_i]} \n")  # {ix_to_tag} {[t]} {[p]}
                    fout.write(
                        f"{words[1:-1][w_i]}\t\t{ix_to_tag[t]}\t\t{ix_to_tag[p_ner]}\n")

    def write_ner_result_word_tag_pred(self, resultfile, data):
        # ep: word real_tag pred_tag
        Words, Y_list, Tag_Pred_list, Word_Idx_list = data
        with open(resultfile + '.txt', 'w') as fout:
            for word, tag, pred, w_idx in zip(Words, Y_list, Tag_Pred_list, Word_Idx_list):
                if (w_idx == 0):
                    fout.write(f"\n")
                fout.write(f"{word[1:-1][w_idx]}\t\t{tag}\t\t{pred}\n")

    def train_epoch(self, i):
        self.train_epoch_param(i)

    def record_result(self):
        pass                #   #

class MTMT(object):
    def __init__(self, hp):
        """Set up the teacher/student experiment: records, models, optimizer, loaders.

        Steps, in order: create the record directories and logger, copy
        hyper-parameters from ``hp``, load the similarity and classifier
        teacher checkpoints, build teacher and student models, build the Adam
        optimizer over the classifier student, pick the loss criteria and
        build the data loaders.

        Args:
            hp: hyper-parameter namespace. Attributes read here:
                ``batch_size``, ``low_resource``, ``n_clusters``,
                ``only_siamese``, ``self_learn``, ``add_cluster``,
                ``add_visual``, ``knn_k``, ``tgt_lang``, ``new_num``,
                ``LEN_ALL_TAGS``, ``HIDDEN_DIM``, ``loss_alpha``, ``lr``,
                ``weight_loss``, ``trainset_params``, ``testset_params``,
                ``validset`` and ``validset_sim_params``. ``hp`` is also
                passed to ``get_identify_name`` and ``create_all_record_dir``.
        """
        # initialize the experiment
        self.identify_name = get_identify_name(hp)
        self.record_dirs = {
            'plot_dir': '../MassPlotPng/',
            'tboard_dir': '../MassLogBoard/',
            'result_dir': '../MassPredictResult/',
            'pt_dir': '../MassPT/',
            'logger_dir': '../MassLogInfo/',
            'chk_dir': '../checkpoints/'
        }
        self.record_dir_dict = create_all_record_dir(self.record_dirs, hp)
        # Running metrics. NOTE(review): there is no 'VALID_MAX_F1' key here,
        # yet train_epoch_param reads record_result_dict['VALID_MAX_F1'].
        self.record_result_dict = {
            'test_F1_n_clusters': [],
            'train_loss_list': [],
            'tgt_train_uncls_F1_list': [],
            'tgt_train_cls_F1_list': [],
            'dev_F1_sim_list': [],
            'test_F1_sim_list': [],
            'test_F1_main_list': [],
            'test_F1_assist_list': [],
            'test_F1_fine_list': [],
            'test_F1_vote_list': [],
            'test_F1_base_list': [],
            'VALID_MAX_PREC': 0,
            'VALID_MAX_REC': 0,
            'BASE_MAX_F1': 0,
            'VALID_MAX_EPOCH': 0,
            'TEST_MAX_PREC': 0,
            'TEST_MAX_REC': 0,
            'TEST_MAX_F1': 0,
            'TEST_MAX_EPOCH': 0,
            'TEST_Siamese_MAX_F1': 0,
            'TEST_Siamese_MAX_EPOCH': 0,
            'VALID_MAX_PREC_target': 0,
            'VALID_MAX_REC_target': 0,
            'VALID_MAX_F1_target': 0,
            'VALID_MAX_EPOCH_target': 0,
            'TEST_MAX_PREC_target': 0,
            'TEST_MAX_REC_target': 0,
            'TEST_MAX_F1_target': 0,
            'TEST_MAX_EPOCH_target': 0,
            'test_F1_knn_k': [],
            'test_F1_knn_max_sim': []
        }
        self.logger = get_logger('info', self.record_dir_dict['logger_dir'] + 'log')
        # self.tboard_writer = SummaryWriter(self.record_dir_dict['tboard_dir'])
        self.batch_size = hp.batch_size
        self.top_N = 400
        self.low_resource = hp.low_resource
        self.n_clusters = hp.n_clusters
        self.feature_dim = 768
        self.only_siamese = hp.only_siamese
        self.self_learn = hp.self_learn
        self.add_cluster = hp.add_cluster
        self.add_visual = hp.add_visual
        self.knn_k = hp.knn_k

        # Checkpoint lookup tables keyed by language code.
        # NOTE(review): base_model_dict and cls_model_dict hold the same
        # paths, and neither is read by live code in this class.
        self.base_model_dict = {
            'nl': '../checkpoints/2021-08-26 21_41_38/best_test.pt',
            'de': '../checkpoints/2021-08-28 13_21_29/best_test.pt',
            'es': '../checkpoints/2021-08-31 10_35_48/best_test.pt',

            'ar': '../checkpoints/2021-10-27 15_15_37/best_test.pt',
            'hi': '../checkpoints/2021-10-28 11_30_28/best_test.pt',
            'zh': '../checkpoints/2021-10-28 11_30_40/best_test.pt',
        }

        self.cls_model_dict = {'nl': '../checkpoints/2021-08-26 21_41_38/best_test.pt',
                               'de': '../checkpoints/2021-08-28 13_21_29/best_test.pt',
                               'es': '../checkpoints/2021-08-31 10_35_48/best_test.pt',

                               'ar': '../checkpoints/2021-10-27 15_15_37/best_test.pt',
                               'hi': '../checkpoints/2021-10-28 11_30_28/best_test.pt',
                               'zh': '../checkpoints/2021-10-28 11_30_40/best_test.pt', }

        self.sim_model_dict = {'nl': '../checkpoints/2021-08-26 21_41_38/best_sim_fea.pt',
                               'de': '../checkpoints/2021-08-28 13_21_29/best_sim_fea.pt',
                               'es': '../checkpoints/2021-08-31 10_35_48/best_sim_fea.pt',

                               'ar': '../checkpoints/2021-10-27 15_15_37/best_sim_fea.pt',
                               'hi': '../checkpoints/2021-10-28 11_30_28/best_sim_fea.pt',
                               'zh': '../checkpoints/2021-10-28 11_30_40/best_sim_fea.pt',

                               }
        # NOTE(review): 'de_new' appears twice below; the second entry silently
        # replaces the first (both map to the same path). The matching train
        # dict uses 'de_new_IOB1' for one of them -- confirm which was meant.
        self.test_set_dict = {'es_new': '../data/CoNLL2002/es/test.txt',
                              'nl_new': '../data/CoNLL2002/nl/test.txt',
                              'de_new': '../data/CoNLL2002/de/test.txt',
                              'de_new': '../data/CoNLL2002/de/test.txt',
                              'ar_new': '../data/WikiAnn/ar/test.txt',
                              'hi_new': '../data/WikiAnn/hi/test.txt',
                              'zh_new': '../data/WikiAnn/zh/test.txt', }
        self.tgt_lang = hp.tgt_lang
        # NOTE(review): test_set_dict is keyed by '<lang>_new' while
        # sim_model_dict (indexed with the same hp.tgt_lang further down) is
        # keyed by '<lang>'. With the keys shown here, no single tgt_lang value
        # satisfies both lookups, so one of them raises KeyError.
        self.testset = self.test_set_dict[hp.tgt_lang]

        # NOTE(review): tgt_train_set_dict is not read in this method; the
        # training set is hard-coded to the English file just below.
        self.tgt_train_set_dict = {'es_new': '../data/CoNLL2002/es/train.txt',
                                   'nl_new': '../data/CoNLL2002/nl/train.txt',
                                   'de_new_IOB1': '../data/CoNLL2002/de/train.txt',
                                   'de_new': '../data/CoNLL2002/de/train.txt',
                                   'ar_new': '../data/WikiAnn/ar/train.txt',
                                   'hi_new': '../data/WikiAnn/hi/train.txt',
                                   'zh_new': '../data/WikiAnn/zh/train.txt', }

        self.target_trainset = '../data/CoNLL2002/en/train.txt'

        # self.base_model_path = self.base_model_dict[hp.tgt_lang]
        # self.base_state_dict = torch.load(self.base_model_path)

        # Similarity (siamese) teacher weights.
        self.sim_model_path = self.sim_model_dict[hp.tgt_lang]
        self.sim_state_dict = torch.load(self.sim_model_path)

        # Classifier teacher weights.
        # NOTE(review): self.unitrans_model_dict is never assigned in this
        # class, so this line raises AttributeError unless the attribute is
        # provided elsewhere -- confirm the intended dict.
        self.cls_model_path = self.unitrans_model_dict[hp.tgt_lang + hp.new_num]  # unitrans_model_dict  base_model_dict
        self.cls_state_dict = torch.load(self.cls_model_path)

        # Classifier teacher (checkpoint loaded) and student (left at the
        # initial weights produced by product_model).
        self.cls_model_name = 'MyBertRnnCrfIsner'
        self.cls_model = cpu_2_gpu(product_model(self.cls_model_name, hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM))
        self.cls_model.load_state_dict(self.cls_state_dict)

        self.cls_model_stu = cpu_2_gpu(
            product_model(self.cls_model_name, hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM))

        # Similarity teacher (checkpoint loaded) and student (checkpoint load
        # is commented out). The class is resolved by name through eval().
        self.sim_model_name = 'MyBertSiameseCLNER'
        self.sim_model = cpu_2_gpu(
            (eval(self.sim_model_name)(hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM, self.feature_dim)))
        self.sim_model.load_state_dict(self.sim_state_dict)

        self.sim_model_stu = cpu_2_gpu(
            (eval(self.sim_model_name)(hp.batch_size, hp.LEN_ALL_TAGS, hp.HIDDEN_DIM, self.feature_dim)))
        # self.sim_model_stu.load_state_dict(self.sim_state_dict)
        # Weight of the similarity loss term in train().
        self.alpha = hp.loss_alpha
        # freeze_layers = ["embeddings"] + ["layer." + str(layer_i) + "." for layer_i in range(12) if layer_i < 3]#['bert.encoder.layer.0.', 'bert.encoder.layer.1.', 'bert.encoder.layer.2.',
        #                  # 'bert.embeddings.']  # 'bert.'
        #
        # # freeze_layers = ["bert."]
        # for name, param in self.model.named_parameters():
        #     param.requires_grad = True
        #     for ele in freeze_layers:
        #         if ele in name:
        #             param.requires_grad = False
        #             break
        # for name, param in self.model.named_parameters():
        #     print(name, param.requires_grad)
        # Loader used for the plain classification test set; dataloader_name is
        # reassigned to 'SiameseLoader' further down.
        self.dataloader_name = "ConllIsentityLoader"

        # NOTE(review): init_cls_dataloader, init_dataloader and
        # product_param_dict are not defined in this class -- confirm where
        # they are expected to come from.
        self.testloader = self.init_cls_dataloader(self.dataloader_name, self.testset, {'shuffle': False})
        # Two optimizer parameter groups over the classifier student: the
        # `linear` layer trains at hp.lr / 1000, every other parameter at hp.lr.
        linear_params = list(map(id, self.cls_model_stu.linear.parameters()))
        base_params = filter(lambda p: id(p) not in linear_params,
                             self.cls_model_stu.parameters())

        self.optimizer = optim.Adam([
            {'params': base_params},
            {'params': self.cls_model_stu.linear.parameters(), 'lr': hp.lr / 1000}], lr=hp.lr)

        # With weight_loss the criteria return per-element losses
        # (reduction='none') so train() can weight them before averaging.
        self.weight_loss = hp.weight_loss
        if (self.weight_loss):
            self.criterion_BCE = nn.BCELoss(reduction='none')
            self.criterion_CE = nn.CrossEntropyLoss(reduction='none')
            self.criterion_MSE = nn.MSELoss(reduction='none')
        else:
            self.criterion_BCE = nn.BCELoss()
            self.criterion_CE = nn.CrossEntropyLoss()
            self.criterion_MSE = product_criterion('MSELoss')

        self.dataloader_name = 'SiameseLoader'

        self.trainset_params = hp.trainset_params
        self.testset_params = hp.testset_params
        # NOTE(review): the validation loader is built without the logger
        # argument that the other two loaders receive.
        self.validloader_sim = self.init_dataloader(self.dataloader_name, hp.validset,
                                                    self.product_param_dict(hp.validset_sim_params))
        self.testloader_sim = self.init_dataloader(self.dataloader_name, self.testset,
                                                   self.product_param_dict(self.testset_params), self.logger)
        self.trainloader_tgt = self.init_dataloader(self.dataloader_name, self.target_trainset,
                                                    self.product_param_dict(self.trainset_params), self.logger)

    def train(self, sim_model, cls_model, sim_model_stu, cls_model_stu, target_trainloader, i):
        """Fine-tune the classifier student for one pass over ``target_trainloader``.

        Each batch holds word pairs. The loss combines the MSE between student
        and teacher tag distributions for both words with a BCE term that
        pulls the student-embedding similarity toward the similarity teacher's
        prediction::

            loss = self.alpha * loss_sim + 0.5 * loss_ner_1 + 0.5 * loss_ner_2

        When ``self.weight_loss`` is set, the per-sample losses are weighted
        before averaging, and on epoch ``i == 1`` the weights are dumped to
        JSON files for analysis.

        Args:
            sim_model: similarity teacher (put in eval mode).
            cls_model: classifier teacher (put in eval mode).
            sim_model_stu: similarity student; unused by the live code here.
            cls_model_stu: classifier student (put in train mode); its
                parameters are the ones updated by ``self.optimizer``.
            target_trainloader: iterable of 8-field batches (see the unpacking
                in the loop).
            i: epoch counter; only ``i == 1`` triggers the analysis dumps.

        Returns:
            Mean of the per-batch total loss.
        """
        # sim_model_stu.train()
        cls_model_stu.train()
        sim_model.eval()
        cls_model.eval()
        start_t = time.time()
        loss_train, mseloss1_train, mseloss2_train, bceloss_train = [], [], [], []
        # NOTE(review): Conf_list, Y, Y_hat and the embed/plot lists below are
        # never filled by the live code in this method.
        Conf_list, Sim_list, Gamma_list, Y, Y_hat, Y_sim, Y_sim_hat = [], [], [], [], [], [], []
        Y1_list, P1_list, Y2_list, P2_list = [], [], [], []
        embed_list_cls, embed_list_sim, psd_tag_list, real_tag_list, plot_word_list, plot_sent_list = [], [], [], [], [], []
        for idx, batch in enumerate(target_trainloader):
            # prepare input data
            # Fields suffixed _1_2 are indexed below as [0] (first word of the
            # pair) and [1] (second word); y is collected as the pair label.
            words_1_2, wordpiece_idx_1_2, tag_1_2, x_1_2, y, att_mask_1_2, seqlen_1_2, word_idx_1_2 = batch
            # words_1, words_2 = words_1_2[0], words_1_2[1]
            seqlen_1, seqlen_2 = seqlen_1_2[0], seqlen_1_2[1]
            wordpiece_idx_1, wordpiece_idx_2 = wordpiece_idx_1_2[0], wordpiece_idx_1_2[1]
            tag_1, tag_2 = tag_1_2[0], tag_1_2[1]
            # x_1_2 + att_mask_1_2 + [y] is list concatenation, so both are
            # expected to be 2-element lists here.
            x_1, x_2, att_mask_1, att_mask_2, y = cpu_2_gpu(x_1_2 + att_mask_1_2 + [y])
            # tea model output data
            # NOTE(review): the teacher forward passes are not wrapped in
            # torch.no_grad(), and logits_tea_* are not detached when used as
            # MSE targets below.
            logits_1, tag_pred_1, embeds_cls_1 = cls_model(x_1, att_mask_1, seqlen_1)
            logits_2, tag_pred_2, embeds_cls_2 = cls_model(x_2, att_mask_2, seqlen_2)

            # prediction, _, _, _, _, embeds_1, _ = sim_model([x_1, x_2],
            #                                                 wordpiece_idx_1_2)
            # Only `prediction` survives from this call: the two embedding
            # outputs are overwritten by the student forward passes just below.
            prediction, _, _, _, _, embeds_cls_1_stu, embeds_cls_2_stu = sim_model([x_1, x_2],
                                                            wordpiece_idx_1_2)
            # stu model output data
            logits_1_stu, tag_pred_1_stu, embeds_cls_1_stu = cls_model_stu(x_1, att_mask_1, seqlen_1)
            logits_2_stu, tag_pred_2_stu, embeds_cls_2_stu = cls_model_stu(x_2, att_mask_2, seqlen_2)

            # NOTE(review): self.softmax is not assigned anywhere in this
            # class -- confirm where it is expected to come from.
            logits_1 = self.softmax(logits_1)
            logits_2 = self.softmax(logits_2)
            logits_1_stu = self.softmax(logits_1_stu)
            logits_2_stu = self.softmax(logits_2_stu)
            # Shift word-piece positions by one (presumably to skip a leading
            # [CLS] token -- TODO confirm with the loader).
            wordpiece_idx_1 = [wp_i+1 for wp_i in wordpiece_idx_1]
            wordpiece_idx_2 = [wp_i+1 for wp_i in wordpiece_idx_2]
            # Pick, for every item in the batch, the row at the target word's
            # position: result has one row per batch item.
            logits_tea_1 = logits_1[list(range(logits_1.size()[0])), wordpiece_idx_1]
            logits_tea_2 = logits_2[list(range(logits_2.size()[0])), wordpiece_idx_2]
            tag_pred_w1 = tag_pred_1[list(range(tag_pred_1.size()[0])), wordpiece_idx_1]
            tag_pred_w2 = tag_pred_2[list(range(tag_pred_2.size()[0])), wordpiece_idx_2]

            logits_stu_1 = logits_1_stu[list(range(logits_1_stu.size()[0])), wordpiece_idx_1]
            logits_stu_2 = logits_2_stu[list(range(logits_2_stu.size()[0])), wordpiece_idx_2]
            embeds_cls_1_stu = embeds_cls_1_stu[list(range(embeds_cls_1_stu.size()[0])), wordpiece_idx_1]
            embeds_cls_2_stu = embeds_cls_2_stu[list(range(embeds_cls_2_stu.size()[0])), wordpiece_idx_2]

            embeds_cls_1_stu, embeds_cls_2_stu = cpu_2_gpu([embeds_cls_1_stu, embeds_cls_2_stu])
            # Similarity of the two student embeddings, scored with the
            # teacher's similarity head.
            prediction_stu = sim_model.similarity(embeds_cls_1_stu, embeds_cls_2_stu)

            # embeds_sim_1 = sim_model.siamese(x_1, wordpiece_idx_1)
            # prediction_stu[tag_pred_w1 == tag_pred_w2] = prediction_stu_one[tag_pred_w1 == tag_pred_w2]

            self.optimizer.zero_grad()
            # self.optimizer2.zero_grad()
            # logits_tea_1 = torch.FloatTensor(logits_tea_1)
            # logits_tea_2 = torch.FloatTensor(logits_tea_2)
            # logits_stu_1 = torch.FloatTensor(logits_stu_1)
            # logits_stu_2 = torch.FloatTensor(logits_stu_2)
            logits_stu_1, logits_stu_2, logits_tea_1, logits_tea_2, tag_pred_w1, tag_pred_w2, prediction_stu = cpu_2_gpu([logits_stu_1, logits_stu_2, logits_tea_1, logits_tea_2, tag_pred_w1, tag_pred_w2, prediction_stu])
            # tag_1 = torch.FloatTensor(tag_1)

            # loss_ner_1 = self.criterion_CE(logits_stu_1.float(), tag_pred_w1.long())
            # loss_ner_2 = self.criterion_CE(logits_stu_2.float(), tag_pred_w2.long())
            # Distillation term: student tag distribution vs teacher's.
            loss_ner_1 = self.criterion_MSE(logits_stu_1.float(), logits_tea_1.float())
            loss_ner_2 = self.criterion_MSE(logits_stu_2.float(), logits_tea_2.float())

            # byte = prediction
            # byte[prediction>=0.5] = prediction[prediction>=0.5]
            # byte[prediction<0.5] = prediction[prediction<0.5] * (-1)
            # logits_stu_3 = logits_stu_2

            # prediction_logits = sim_model.similarity(logits_stu_1.float(), logits_stu_2.float())
            #
            # loss_sim = self.criterion_BCE(prediction_logits, prediction.detach())

            # print(byte.shape, F.pairwise_distance(logits_stu_1.float(),logits_stu_3.float(),p=2).shape)
            # loss_ner_3 = F.pairwise_distance(logits_stu_1.float(), logits_stu_3.detach().float(),p=2).mul(byte).sum()

            # Similarity term: student pair similarity vs teacher's (detached).
            loss_sim = self.criterion_BCE(prediction_stu, prediction.detach())
            # alpha = 100
            if (self.weight_loss):
                # Criteria were built with reduction='none' in this mode, so
                # collapse the per-class MSE into one value per sample first.
                loss_ner_1 = torch.sum(loss_ner_1, dim=-1)
                loss_ner_2 = torch.sum(loss_ner_2, dim=-1)

                # Classifier weight: squared max probability of the teacher.
                confidence1 = torch.max(logits_tea_1.detach(), dim=1).values
                weight_cls1 = confidence1 ** 2
                confidence2 = torch.max(logits_tea_2.detach(), dim=1).values
                weight_cls2 = confidence2 ** 2

                if (i == 1):
                    # Conf_list.extend(confidence.cpu().numpy().tolist())
                    Sim_list.extend(prediction.detach().cpu().numpy().tolist())
                    # Y.extend([ix_to_tag[t.item()] for t in tag_1])
                    # tag_pred_1 = [ix_to_tag[t_i[wp_idx + 1]] for t_i, wp_idx in
                    #               zip(tag_pred_1.cpu().numpy().tolist(), wordpiece_idx_1)]
                    # Y_hat.extend(tag_pred_1)

                Y_sim.extend(y.detach().cpu().numpy().tolist())

                # print(prediction.detach().cpu().numpy().tolist())
                Y_sim_hat.extend(prediction.detach().cpu().numpy().tolist())


                # Similarity weight: (2p - 1)^2 + 0.5, capped at 1.
                one_like_weight = torch.ones_like(prediction.detach())
                weight_sim = (prediction.detach()*2 -one_like_weight)**2 + 0.5
                weight_sim = torch.where(weight_sim > 1, one_like_weight, weight_sim)

                # Agreement weight: 1 - |sigmoid(cossim(teacher dists)) - p|.
                # NOTE(review): self.cossim is not assigned in this class;
                # presumably a cosine-similarity module -- confirm.
                prediction_sim = self.cossim(logits_tea_1.detach(), logits_tea_2.detach())
                # print(logits_tea_1.size(), logits_tea_2.size())
                # prediction_sim = self.sim_model.similarity(logits_tea_1.detach(), logits_tea_2.detach())
                prediction_sim = torch.sigmoid(prediction_sim)
                weight_L = one_like_weight - torch.abs(prediction_sim - prediction.detach())

                if (i == 1):
                    # ix_to_tag is a module-level name not defined in this
                    # class (presumably from a star import). Predicted id 9 is
                    # mapped to tag id 0 -- TODO confirm what id 9 stands for.
                    Gamma_list.extend(weight_L.cpu().numpy().tolist())
                    Y1_list.extend([ix_to_tag[t.item()] for t in tag_1])
                    P1_list.extend([ix_to_tag[0] if p == 9 else ix_to_tag[p] for p in tag_pred_w1.cpu().numpy().tolist()])
                    Y2_list.extend([ix_to_tag[t.item()] for t in tag_2])
                    P2_list.extend([ix_to_tag[0] if p == 9 else ix_to_tag[p] for p in tag_pred_w2.cpu().numpy().tolist()])

                # self.logger.info(f"{weight_cls1},{weight_cls2},{weight_sim}")
                # self.logger.info(f"{torch.mul(weight_L, weight_cls1)},{torch.mul(weight_L, weight_cls2)},{torch.mul(weight_L, weight_sim)},{weight_L}")

                # Apply the weights per sample, then average to scalars.
                loss_ner_1 = torch.mean(torch.mul(torch.mul(weight_L, weight_cls1), loss_ner_1))
                loss_ner_2 = torch.mean(torch.mul(torch.mul(weight_L, weight_cls2), loss_ner_2))
                loss_sim = torch.mean(torch.mul(torch.mul(weight_L, weight_sim), loss_sim))

                # self.logger.info(f"{loss_ner_1.item()},{loss_ner_2.item()},{loss_sim.item()},{weight_L}")

            loss = self.alpha*loss_sim + 1/2* loss_ner_1  + 1/2* loss_ner_2   #+ loss_ner_1   # #+ loss_ner_1   #+ loss_sim_2
            # loss = alpha*loss_ner_1 + (1-alpha)*loss_sim
            # loss_sim.backward()
            # loss = loss_ner_1 + loss_ner_2

            loss.backward()

            # for name, parms in cls_model_stu.named_parameters():
            #     if idx == 0 and ("layer.11." in name or "linear" in name):
            #         self.logger.info(f'-->name:{name} -->grad_requirs:{parms.requires_grad}-->grad_value:{parms.grad}')


            self.optimizer.step()
            # #
            # self.optimizer2.step()
            loss_train.append(loss.item())
            mseloss1_train.append(loss_ner_1.item())
            mseloss2_train.append(loss_ner_2.item())
            bceloss_train.append(loss_sim.item())
            if idx % 1000 == 0:  # monitoring
                # Logged values are running means since the start of the epoch.
                self.logger.info(
                    f"STEP: {idx}\tLOSS={round(np.mean(loss_train), 11)}\tMSELOSS1={round(np.mean(mseloss1_train), 11)}\tMSELOSS2={round(np.mean(mseloss2_train), 11)}\tBCELOSS={round(np.mean(bceloss_train), 11)}\t\ttime: {(time.time() - start_t) / 60}")
        if (self.weight_loss and i == 1):
            # self.plot_weight_analyse(Y, Y_hat, Conf_list, self.record_dir_dict['plot_dir'] + 'cls_weight')
            # Binarise the teacher similarity at 0.5 before dumping.
            Y_sim_hat = [1 if y_pred >= 0.5 else 0 for y_pred in Y_sim_hat]
            write_plot(['Y_sim', 'Y_sim_hat', 'Sim_list'], [Y_sim, Y_sim_hat, Sim_list], self.record_dir_dict['plot_dir'] + 'sim_weight.json', self.logger)
            write_plot(['Y_sim', 'Y_sim_hat', 'Gamma_list','Y1_list', 'P1_list', 'Y2_list', 'P2_list'], [Y_sim, Y_sim_hat, Gamma_list, Y1_list, P1_list, Y2_list, P2_list], self.record_dir_dict['plot_dir'] + 'gamma_weight.json', self.logger)
            # NOTE(review): plot_weight_analyse is not defined in this class.
            self.plot_weight_analyse(Y_sim, Y_sim_hat, Sim_list, self.record_dir_dict['plot_dir'] + 'sim_weight', 'sim')

        # self.logger.info(f"============Pseudo Trainset Eval by conlleval:============")
        # precision, recall, f1 = eval_F1(Y, Y_HAT, 'conlleval')
        # self.logger.info("PRE=%.5f\t\tREC=%.5f\t\tF1=%.5f" % (precision, recall, f1))
        return np.mean(loss_train)

    def evaluation(self, cls_model, target_trainloader, i):
        """Score ``cls_model`` on a pair loader and dump analysis artefacts.

        Only the first word of each pair is tagged and scored. Besides the F1,
        the method collects the model's max-softmax confidence and the
        target-word embeddings of ``cls_model``, of the classifier teacher
        ``self.cls_model`` and of the similarity teacher ``self.sim_model``.
        Up to ``plot_num`` embeddings per gold tag are written as JSON, and
        the word/gold/prediction triples are written as a text file.

        Args:
            cls_model: classifier to evaluate (put in eval mode).
            target_trainloader: iterable of 8-field batches; despite the name
                it is whatever loader the caller passes.
            i: epoch counter. ``i == 1`` additionally writes the teacher
                dumps, and the result file is prefixed ``tea_`` instead of
                ``stu_``.

        Returns:
            The F1 value returned by ``eval_F1(Y, Y_HAT, 'conlleval')``.
        """
        cls_model.eval()
        start_t = time.time()
        Y, Y_HAT, Embeds_1 = [], [], []
        Conf_list = []
        Words, Word_Idx_list = [], []
        # NOTE(review): embed_list_cls and embed_list_sim are initialised on
        # both of the next two lines; Embeds_1 and the psd/real/plot lists are
        # never used afterwards.
        pred_list, tag_list, embed_list_cls, embed_list_cls_stu, embed_list_sim = [], [], [], [], []
        embed_list_cls, embed_list_sim, psd_tag_list, real_tag_list, plot_word_list, plot_sent_list = [], [], [], [], [], []
        with torch.no_grad():
            for idx, batch in enumerate(target_trainloader):
                # prepare input data
                words_1_2, wordpiece_idx_1_2, tag_1_2, x_1_2, y, att_mask_1_2, seqlen_1_2, word_idx_1_2 = batch
                words_1, words_2 = words_1_2[0], words_1_2[1]

                seqlen_1, seqlen_2 = seqlen_1_2[0], seqlen_1_2[1]
                word_idx_1, word_idx_2 = word_idx_1_2[0], word_idx_1_2[1]
                wordpiece_idx_1, wordpiece_idx_2 = wordpiece_idx_1_2[0], wordpiece_idx_1_2[1]

                Words.extend(words_1)
                Word_Idx_list.extend(word_idx_1)

                tag_1, tag_2 = tag_1_2[0], tag_1_2[1]
                x_1, x_2, att_mask_1, att_mask_2 = cpu_2_gpu(x_1_2 + att_mask_1_2)

                # model output data
                logits_1, tag_pred_1, embeds_cls_1_stu = cls_model(x_1, att_mask_1, seqlen_1)

                # Copy the logits row at each target word (position + 1) into a
                # fresh CPU tensor, one row per batch item, then softmax it.
                logits_stu_1 = torch.zeros((logits_1.shape[0], logits_1.shape[-1]))
                # logits_stu_1 = cpu_2_gpu(logits_stu_1)
                logits_1_stu = [l_i[wp_idx + 1] for l_i, wp_idx in zip(logits_1, wordpiece_idx_1)]
                for k in range(logits_stu_1.shape[0]):
                    logits_stu_1[k] = logits_1_stu[k]
                logits_stu_1 = torch.FloatTensor(logits_stu_1)
                softmax = nn.Softmax(dim=-1)
                logits_stu_1 = softmax(logits_stu_1)

                # Confidence = highest class probability per item.
                confidence = torch.max(logits_stu_1.detach(), dim=1).values.numpy().tolist()
                Conf_list.extend(confidence)

                # Target-word embedding of the evaluated model.
                embeds_cls_1_stu = [emd[w_idx + 1] for emd, w_idx in
                                    zip(embeds_cls_1_stu.cpu().numpy().tolist(), wordpiece_idx_1)]
                embed_list_cls_stu.extend(embeds_cls_1_stu)

                # Teacher embeddings always come from self.cls_model and
                # self.sim_model, regardless of the cls_model argument.
                _, _, embeds_cls = self.cls_model(x_1, att_mask_1, seqlen_1)
                _, _, _, _, _, embeds_1_sim, _ = self.sim_model([x_1, x_2],
                                                                wordpiece_idx_1_2)
                # logits_1_stu = self.cls_model.linear(embeds_1_sim)
                # tag_pred_1 = logits_1_stu.argmax(-1)

                embeds_cls = [emd[w_idx + 1] for emd, w_idx in
                              zip(embeds_cls.cpu().numpy().tolist(), wordpiece_idx_1)]
                embed_list_cls.extend(embeds_cls)

                embed_list_sim.extend(embeds_1_sim.cpu().numpy().tolist())

                # Predicted tag id at each target word position.
                tag_pred_1 = [t_i[wp_idx + 1] for t_i, wp_idx in
                              zip(tag_pred_1.cpu().numpy().tolist(), wordpiece_idx_1)]
                # Predicted id 9 is replaced by tag id 0 before mapping to a
                # string -- TODO confirm what id 9 stands for.
                tag_pred_1 = [ix_to_tag[0] if p == 9 else ix_to_tag[p] for p in tag_pred_1]

                pred_list.extend([tag_to_ix[t] for t in tag_pred_1])
                # pred_list.extend([t.item() for t in tag_pred_1])
                tag_list.extend([t.item() for t in tag_1])

                # prediction, _, _, _, _, embeds_1, _ = sim_model([x_1, x_2], wordpiece_idx_1_2)
                # embed_list_cls.extend(embeds_1.cpu().numpy().tolist())

                Y.extend([ix_to_tag[t.item()] for t in tag_1])
                Y_HAT.extend(tag_pred_1)
                if idx % 1000 == 0:  # monitoring
                    self.logger.info(f"TEST STEP: {idx}\t\tTIME: {(time.time() - start_t) / 60}")

        self.logger.info(f"============Eval by conlleval:============")
        precision, recall, f1 = eval_F1(Y, Y_HAT, 'conlleval')
        self.logger.info("PRE=%.5f\t\tREC=%.5f\t\tF1=%.5f" % (precision, recall, f1))

        # Keep at most plot_num samples per gold tag for the embedding dumps;
        # tag_count holds the remaining budget of every tag in ALL_TAGS_.
        plot_num = 50
        tag_count = {tag: plot_num for tag in ALL_TAGS_}

        embed_list_use_stu, embed_list_use_cls, embed_list_use_sim, tag_list_use, pred_list_use = [], [], [], [], []
        for embed_stu_temp, embed_cls_temp, embed_sim_temp, y_temp, pred_temp in zip(embed_list_cls_stu, embed_list_cls,
                                                                                     embed_list_sim, tag_list,
                                                                                     pred_list):

            if (tag_count[ix_to_tag[y_temp]] > 0):
                tag_count[ix_to_tag[y_temp]] = tag_count[ix_to_tag[y_temp]] - 1
                embed_list_use_stu.append(embed_stu_temp)
                embed_list_use_cls.append(embed_cls_temp)
                embed_list_use_sim.append(embed_sim_temp)
                tag_list_use.append(y_temp)
                pred_list_use.append(pred_temp)

        if (i == 1):
            # Teacher-side dumps are only written on the first epoch.
            write_plot(['Y', 'Y_HAT', 'Conf_list'], [Y, Y_HAT, Conf_list],
                       self.record_dir_dict['plot_dir'] + 'cls_weight_test.json', self.logger)
            write_plot(['embed_list_use_cls', 'tag_list_use'], [embed_list_use_cls, tag_list_use],
                       self.record_dir_dict['plot_dir'] + 'cls_tea.json', self.logger)
            write_plot(['embed_list_use_sim', 'tag_list_use'], [embed_list_use_sim, tag_list_use],
                       self.record_dir_dict['plot_dir'] + 'sim_tea.json', self.logger)

            # NOTE(review): plot_weight_analyse is not defined in this class.
            self.plot_weight_analyse(Y, Y_HAT, Conf_list, self.record_dir_dict['plot_dir'] + 'cls_weight_test')
            # embeds_plot_trainer.plot_dim_embeds(self, embed_list_use_cls, tag_list_use, self.record_dir_dict['plot_dir'] + 'cls_tea_' + str(i), 'cls tea '+str(i))
            # embeds_plot_trainer.plot_dim_embeds(self, embed_list_use_sim, tag_list_use, self.record_dir_dict['plot_dir'] + 'sim_tea_' + str(i), 'sim tea '+str(i))
        write_plot(['embed_list_use_stu', 'tag_list_use'], [embed_list_use_stu, tag_list_use],
                   self.record_dir_dict['plot_dir'] + 'stu_' + str(i) + '.json', self.logger)
        # embeds_plot_trainer.plot_dim_embeds(self, embed_list_use_stu, tag_list_use, self.record_dir_dict['plot_dir'] + 'stu_' + str(i), 'stu '+str(i))

        # `suffix` is used as a file-name prefix: result_dir + 'tea_1' or
        # result_dir + 'stu_<i>'. The file is rewritten on every call with the
        # same i, whichever loader was evaluated.
        suffix = "tea_" if i == 1 else "stu_"
        self.write_ner_result_word_tag_pred(self.record_dir_dict['result_dir'] + suffix + str(i),
                                            data=(Words, Y, Y_HAT, Word_Idx_list))
        self.logger.info(f"predict result stored in {self.record_dir_dict['result_dir'] + suffix + str(i)}")

        return f1

    def train_epoch_param(self, i):
        if (i == 1):
            f1 = self.cls_evaluation(self.cls_model_stu, self.testloader_sim, i)

        #     self.record_result_dict['BASE_MAX_F1'] = self.evaluation(self.cls_model_stu, self.testloader_sim, i)
        # self.record_result_dict['BASE_MAX_F1'] = self.evaluation(self.cls_model, self.sim_model_stu, self.trainloader_tgt, i)

        # predict pseudo label
        self.logger.info(f"======================FINE TUNE BY JOINT LOSS={i}=====================")
        loss = self.train(self.cls_model_stu, self.trainloader_tgt, i)
        self.logger.info(f"======================TEST BY JOINT LOSS={i}=====================")
        f1 = self.evaluation(self.cls_model_stu, self.validloader_sim, i)

        # f1 = self.cls_evaluation(self.cls_model_stu,  self.testloader_sim, i)
        if (f1 > self.record_result_dict['VALID_MAX_F1']):
            self.record_result_dict['VALID_MAX_F1'] = f1
            torch.save(self.cls_model_stu.state_dict(), self.record_dir_dict['chk_dir'] + 'best_cls_test.pt')

        f1 = self.evaluation(self.cls_model_stu, self.testloader_sim, i)
        if (f1 > self.record_result_dict['TEST_MAX_F1']):
            self.record_result_dict['TEST_MAX_F1'] = f1
        self.record_result_dict['tgt_train_cls_F1_list'].append(f1)

        self.logger.info(f"Teacher: {self.cls_model_path}")
        self.logger.info(f"Teacher F1: {self.record_result_dict['BASE_MAX_F1']}")

        self.logger.info(
            f"Best CLUSTER_F1: BEST EPOCH_NUM={np.argmax(self.record_result_dict['tgt_train_cls_F1_list']) + 1}\t F1={np.max(self.record_result_dict['tgt_train_cls_F1_list'])}")
        self.logger.info(f"MODEL SAVED IN: {self.record_dir_dict['chk_dir']}best_cls_test.pt & best_sim_test.pt")



    def write_ner_result(self, resultfile, data, ix_to_tag):
        """Write batched NER predictions to ``resultfile + '.txt'``.

        One line is written per target word, formatted as
        ``word<TAB><TAB>gold_tag<TAB><TAB>predicted_tag``. An empty line is
        written before every word whose index is 0, which separates sentences.

        Args:
            resultfile: output path without the ``.txt`` extension.
            data: 4-tuple ``(Words, Word_Idx, Y_list, Tag_Pred_NER_list)`` of
                per-batch sequences. For item ``k`` of a batch, ``Words`` holds
                the token list of the sentence, ``Word_Idx`` the position of the
                target word, and the last two the gold / predicted tag ids.
            ix_to_tag: mapping from tag id to tag string.

        Raises:
            AssertionError: if a word index is not smaller than the sentence
                length (first and last token excluded).
        """
        Words, Word_Idx, Y_list, Tag_Pred_NER_list = data
        # Force UTF-8: the words are multilingual text and the locale default
        # encoding cannot represent them on every platform.
        with open(resultfile + '.txt', 'w', encoding='utf-8') as fout:
            for words_batch, word_idx_batch, tags_batch, tag_ner_batch in zip(
                    Words, Word_Idx, Y_list, Tag_Pred_NER_list):
                for k, words in enumerate(words_batch):
                    # First and last tokens are dropped (presumably the
                    # [CLS]/[SEP] markers -- TODO confirm with the loader).
                    tokens = words[1:-1]
                    assert (len(tokens) > word_idx_batch[
                        k]), f"sentence length: {len(tokens)}, word idx: {word_idx_batch[k]}"
                    t = tags_batch[k]
                    p_ner = tag_ner_batch[k]
                    w_i = word_idx_batch[k]
                    if w_i == 0:
                        fout.write("\n")
                    fout.write(
                        f"{tokens[w_i]}\t\t{ix_to_tag[t]}\t\t{ix_to_tag[p_ner]}\n")

    def write_ner_result_word_tag_pred(self, resultfile, data):
        """Write flat (un-batched) NER predictions to ``resultfile + '.txt'``.

        One line is written per target word, formatted as
        ``word<TAB><TAB>gold_tag<TAB><TAB>predicted_tag``. An empty line is
        written before every word whose index is 0, which separates sentences.

        Args:
            resultfile: output path without the ``.txt`` extension.
            data: 4-tuple ``(Words, Y_list, Tag_Pred_list, Word_Idx_list)`` of
                parallel sequences: the sentence token list, the gold tag, the
                predicted tag and the target word position for each item. Tags
                are written as given (no id-to-string mapping happens here).
        """
        Words, Y_list, Tag_Pred_list, Word_Idx_list = data
        # Force UTF-8: the words are multilingual text and the locale default
        # encoding cannot represent them on every platform.
        with open(resultfile + '.txt', 'w', encoding='utf-8') as fout:
            for word, tag, pred, w_idx in zip(Words, Y_list, Tag_Pred_list, Word_Idx_list):
                if w_idx == 0:
                    fout.write("\n")
                # First and last tokens are dropped before indexing (presumably
                # the [CLS]/[SEP] markers -- TODO confirm with the loader).
                fout.write(f"{word[1:-1][w_idx]}\t\t{tag}\t\t{pred}\n")

    def train_epoch(self, i):
        """Run one epoch by delegating to ``train_epoch_param``.

        Args:
            i: epoch counter, forwarded unchanged.
        """
        self.train_epoch_param(i)

    def record_result(self):
        """No-op placeholder: nothing is recorded by this trainer."""
        pass

from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from matplotlib.pyplot import MultipleLocator
import matplotlib.ticker as ticker
class analyse_exps(object):
    def __init__(self, hp):
        pass
    def record_result(self):
        """Do nothing; this class keeps no results to record."""
        return None
    def train_epoch(self,i):
        """Regenerate the analysis figures; the argument ``i`` is not used.

        Calls, in order: the gamma plot, two weight-analysis plots, three
        embedding plots and finally the weighting-curve plot.
        """
        self.plot_weight_gamma('de', 3)

        # (language, teacher, run number) for each weight-analysis figure
        weight_runs = (('nl', 'sim', 0), ('es', 'cls', 1))
        for language, teacher, run in weight_runs:
            self.plot_weight_analyse(language, teacher, run)

        # teacher embeddings first, then student number 3
        for teacher in ('cls', 'sim'):
            self.plot_dim_embeds('nl', teacher, 1)
        self.plot_dim_embeds('nl', 'stu', 1, 3)

        self.loss_weight()

    def loss_weight(self):
        """Plot the two weighting curves on [0, 1] and save them as EPS files.

        ``y = x**2`` is saved to ``../MassPlotPng/exps/weight_alpha.eps`` and
        ``y = (2x - 1)**2`` to ``../MassPlotPng/exps/weight_beta.eps``.  Both are
        drawn on the same figure, which is cleared in between and closed at the end.
        """
        delta = 0.01
        x = np.arange(0, 1 + delta, delta)

        plt.figure(dpi=300, figsize=(8, 6))

        plt.grid()
        plt.plot(x, x ** 2, linewidth=3, color='tab:green')
        plt.tick_params(labelsize=22)
        plt.savefig('../MassPlotPng/exps/weight_alpha.eps', bbox_inches='tight', pad_inches=0.0)

        # Reuse the same axes for the second curve.
        plt.cla()
        plt.grid()
        plt.plot(x, (2 * x - 1) ** 2, linewidth=3, color='tab:green')
        plt.tick_params(labelsize=22)
        plt.savefig('../MassPlotPng/exps/weight_beta.eps', bbox_inches='tight', pad_inches=0.0)

        # Release the figure, as the other plotting methods do; pyplot otherwise
        # keeps it alive for the rest of the process.
        plt.close()

    def plot_weight_gamma(self, lang, number=0):
        """Plot binned F1 scores against the gamma weight for one recorded run.

        Loads the ``gamma_weight.json`` registered for ``lang + str(number)``,
        splits the samples into gamma intervals of width 0.1 covering (0, 1] and
        scores each interval with ``eval_F1`` for three prediction sets
        (``Y1``/``P1``, ``Y2``/``P2`` and ``Y_sim``/``Y_sim_hat``).  Intervals with
        fewer than 20 samples are reported as 0.  The grouped bar chart is saved
        to ``../MassPlotPng/exps/analyse_gamma.eps`` and ``.png``.

        :param lang: language code used to build the run key (e.g. ``'de'``).
        :param number: run number appended to ``lang`` to form the run key.
        :raises KeyError: if no run is registered for ``lang + str(number)``.
        """
        file_gamma_dict = {
            'es2': '../MassPlotPng/2021-11-13 17_28_59/gamma_weight.json',
            'nl2': '../MassPlotPng/2021-11-13 18_54_25/gamma_weight.json',
            'nl3': '../MassPlotPng/2021-11-13 19_59_02/gamma_weight.json',
            'de3': '../MassPlotPng/2021-11-14 00_16_23/gamma_weight.json',
        }
        filename = file_gamma_dict[lang + str(number)]

        # Context manager so the handle is closed even if parsing fails.
        with open(filename) as f:
            data = json.load(f)

        YS, PS, X = data['Y_sim'], data['Y_sim_hat'], data['Gamma_list']
        P1, P2 = data['P1_list'], data['P2_list']
        Y1, Y2 = data['Y1_list'], data['Y2_list']

        delta = 0.1
        range_list = np.arange(delta, 1 + delta, delta)

        # NOTE(review): these whole-dataset scores are never used below; the calls
        # are kept only in case eval_F1 has side effects (e.g. printing a report)
        # — confirm and drop.
        eval_F1(Y1, P1, 'conlleval')
        eval_F1(Y2, P2, 'conlleval')
        eval_F1(YS, PS, 'sklearn_f1')

        # Arrays allow boolean-mask selection of the samples in each gamma bin.
        P1 = np.array(P1)
        P2 = np.array(P2)
        Y1 = np.array(Y1)
        Y2 = np.array(Y2)
        YS = np.array(YS)
        PS = np.array(PS)
        X = np.array(X)

        ran_min = 0
        X_index = []
        F1_range_p1, F1_range_p2, F1_range_s = [], [], []
        for ran_max in range_list:
            # Half-open bin (ran_min, ran_max].
            index_range = (X > ran_min) & (X <= ran_max)
            YS_range = YS[index_range]
            PS_range = PS[index_range]
            Y1_range = Y1[index_range]
            P1_range = P1[index_range]
            Y2_range = Y2[index_range]
            P2_range = P2[index_range]
            if len(YS_range) < 20:
                # Too few samples in this bin: report 0 instead of scoring.
                f1_p1, f1_p2, f1_s = 0, 0, 0
            else:
                _, _, f1_p1 = eval_F1(Y1_range, P1_range, 'conlleval')
                _, _, f1_p2 = eval_F1(Y2_range, P2_range, 'conlleval')
                _, _, f1_s = eval_F1(YS_range, PS_range, 'sklearn_f1')

            # Scores are divided by 100 (presumably percentages rescaled to [0, 1]).
            F1_range_p1.append(f1_p1 / 100)
            F1_range_p2.append(f1_p2 / 100)
            F1_range_s.append(f1_s / 100)

            # Bars are centred on the middle of the bin.
            X_index.append(ran_min + delta / 2)

            ran_min = ran_max

        # NOTE(review): sorting detaches the y / y' scores from their gamma bins
        # (F1_range_s stays in bin order), so these bars no longer line up with
        # X_index.  Behaviour is preserved here; confirm whether this is intended.
        F1_range_p1.sort()
        F1_range_p2.sort()
        print(F1_range_p1)
        print(F1_range_p2)

        print(F1_range_s)
        print(X_index)

        plt.figure(dpi=300, figsize=(8, 6))
        bar_centers = np.array(X_index)
        plt.bar(bar_centers - delta / 4, F1_range_p1, color='tab:orange', width=delta / 4)
        plt.bar(bar_centers, F1_range_p2, color='tab:green', width=delta / 4)
        plt.bar(bar_centers + delta / 4, F1_range_s, color='tab:blue', width=delta / 4)

        legend_elements = [Patch(facecolor='tab:orange', edgecolor='tab:orange',
                                 label='y'),
                           Patch(facecolor='tab:green', edgecolor='tab:green',
                                 label='y\''),
                           Patch(facecolor='tab:blue', edgecolor='tab:blue',
                                 label='t')]
        plt.legend(handles=legend_elements, loc=2, fontsize=24)

        ax = plt.gca()
        ax.xaxis.set_major_locator(MultipleLocator(0.1))
        ax.yaxis.set_major_locator(MultipleLocator(0.2))

        # Only the upper half of the gamma range is shown.
        plt.xlim(0.55, 1.05)
        plt.ylim(0, 1.05)
        plt.tick_params(labelsize=22)

        filename = '../MassPlotPng/exps/analyse_gamma.eps'
        plt.savefig(filename, bbox_inches='tight', pad_inches=0.0)
        filename = '../MassPlotPng/exps/analyse_gamma.png'
        plt.savefig(filename, bbox_inches='tight', pad_inches=0.0)
        plt.close()

    def plot_weight_analyse(self, lang, tea_name = 'sim', number=0):
        """Plot binned F1 against a teacher's score, overlaid with its weight curve.

        For ``tea_name == 'cls'`` the samples are binned by ``Conf_list`` and each
        bin is scored with ``eval_F1(..., 'conlleval')``.  For ``'sim'`` they are
        binned by ``Sim_list``, predictions are recomputed by thresholding the
        similarity at 0.5, and each bin is scored with a weighted ``f1_score``.
        Bins have width 0.1 over (0, 1]; bins with fewer than 20 samples are
        reported as 0.  The curve labelled ``w`` is ``x**2`` for ``'cls'`` and
        ``(2x - 1)**2`` for ``'sim'``.  The figure is saved to
        ``../MassPlotPng/exps/analyse_<tea_name>.png`` and ``.eps``.

        :param lang: language code used to build the run key (e.g. ``'nl'``).
        :param tea_name: ``'sim'`` or ``'cls'``.
        :param number: run number appended to ``lang`` to form the run key.
        :raises ValueError: if ``tea_name`` is neither ``'sim'`` nor ``'cls'``.
        :raises KeyError: if no run is registered for ``lang + str(number)``.
        """
        print(f"lang:{lang}, tea_name:{tea_name},number:{number} ")

        file_cls_dict = {
            'es0': '../MassPlotPng/2021-11-10 14_41_07/cls_weight_test.json',
            'de0': '../MassPlotPng/2021-11-10 14_41_19/cls_weight_test.json',
            'nl0': '../MassPlotPng/2021-11-10 14_40_40/cls_weight_test.json',
            'de1': '../MassPlotPng/2021-11-12 20_08_22/cls_weight_test.json',
            'nl1': '../MassPlotPng/2021-11-12 14_11_33/cls_weight_test.json',
            'es1': '../MassPlotPng/2021-11-12 20_08_56/cls_weight_test.json',

            # runs with complete negative samples
            'de2': '../MassPlotPng/2021-11-13 16_45_02/cls_weight_test.json',
            'nl2': '../MassPlotPng/2021-11-13 18_54_25/cls_weight_test.json',
            'es2': '../MassPlotPng/2021-11-13 17_28_59/cls_weight_test.json',
            'nl3': '../MassPlotPng/2021-11-13 19_59_02/cls_weight_test.json',
        }
        file_sim_dict = {
            'es0': '../MassPlotPng/2021-11-10 14_41_07/sim_weight.json',
            'de0': '../MassPlotPng/2021-11-10 14_41_19/sim_weight.json',
            'nl0': '../MassPlotPng/2021-11-10 14_40_40/sim_weight.json',
            'de1': '../MassPlotPng/2021-11-12 20_08_22/sim_weight.json',
            'nl1': '../MassPlotPng/2021-11-12 14_11_33/sim_weight.json',
            'es1': '../MassPlotPng/2021-11-12 20_08_56/sim_weight.json',

            # runs with complete negative samples
            'de2': '../MassPlotPng/2021-11-13 16_45_02/sim_weight.json',
            'es2': '../MassPlotPng/2021-11-13 17_28_59/sim_weight.json',
            'nl2': '../MassPlotPng/2021-11-13 18_54_25/sim_weight.json',
            # an earlier 'nl3' run (2021-11-13 19_59_02) is superseded by the MTMT entry below

            # MTMT
            'de3': '../MassPlotPng/2021-11-15 00_55_54/sim_weight.json',
            'es3': '../MassPlotPng/2021-11-15 00_56_33/sim_weight.json',
            'nl3': '../MassPlotPng/2021-11-15 00_56_52/sim_weight.json',
        }

        if tea_name == 'sim':
            filename = file_sim_dict[lang + str(number)]
        elif tea_name == 'cls':
            filename = file_cls_dict[lang + str(number)]
        else:
            # Previously this fell through to an UnboundLocalError on `filename`.
            raise ValueError(f"tea_name must be 'sim' or 'cls', got {tea_name!r}")

        # Context manager so the handle is closed even if parsing fails.
        with open(filename) as f:
            data = json.load(f)

        # Both teachers use the same binning: width 0.1 over (0, 1].
        delta = 0.1
        range_list = np.arange(delta, 1 + delta, delta)

        if tea_name == 'cls':
            Y, Y_pred, X = data['Y'], data['Y_HAT'], data['Conf_list']
        else:
            Y, X = data['Y_sim'], data['Sim_list']
            # The hard predictions stored in the file are not used; they are
            # recomputed by thresholding the similarity score at 0.5.
            Y_pred = [1 if score > 0.5 else 0 for score in X]
            _, _, overall_f1 = eval_F1(Y, Y_pred, 'sklearn_f1')
            print(overall_f1)

        # Arrays allow boolean-mask selection of the samples in each bin.
        Y = np.array(Y)
        Y_pred = np.array(Y_pred)
        X = np.array(X)

        ran_min = 0
        X_index = []
        F1_range = []
        for ran_max in range_list:
            # Half-open bin (ran_min, ran_max].
            index_range = (X > ran_min) & (X <= ran_max)
            Y_range = Y[index_range]
            P_range = Y_pred[index_range]

            if len(Y_range) < 20:
                # Too few samples in this bin: report 0 instead of scoring.
                f1 = 0
            elif tea_name == 'sim':
                # f1_score is a fraction; scale it so both branches share the
                # division by 100 below.
                f1 = f1_score(Y_range, P_range, average='weighted', labels=np.unique(P_range))
                f1 = f1 * 100
            else:
                _, _, f1 = eval_F1(Y_range, P_range, 'conlleval')

            F1_range.append(f1 / 100)

            # Bars are centred on the middle of the bin.
            X_index.append(ran_min + delta / 2)

            ran_min = ran_max

        filename = '../MassPlotPng/exps/analyse_' + tea_name

        print(F1_range)
        print(X_index)

        # All axis configuration happens after this call: anything configured
        # earlier would land on a different, never-saved figure.
        plt.figure(dpi=300, figsize=(8, 6))
        plt.bar(X_index, F1_range, color='tab:orange', width=delta / 2 + 0.001)

        # Weighting curve sampled on a finer grid than the bins.
        curve_step = 0.02
        x = np.arange(0, 1 + curve_step, curve_step)
        if tea_name == 'cls':
            y = x ** 2
        else:
            y = [(2 * a - 1) ** 2 for a in x]
        plt.plot(x, y, linewidth=3, color='tab:green')

        legend_elements = [Line2D([0], [0], color='tab:green', lw=3, label='w'),
                           Patch(facecolor='tab:orange', edgecolor='tab:orange',
                                 label='F1')]
        plt.legend(handles=legend_elements, loc=2, fontsize=22)

        ax = plt.gca()
        ax.xaxis.set_major_locator(MultipleLocator(0.2))
        ax.yaxis.set_major_locator(MultipleLocator(0.2))
        if tea_name == 'sim':
            plt.xlim(0.2, 0.8)
        else:
            plt.xlim(-0.05, 1.05)
        plt.ylim(0, 1.05)
        plt.tick_params(labelsize=22)
        plt.savefig(filename + '.png', bbox_inches='tight', pad_inches=0.0, format='png')

        plt.savefig(filename + '.eps', bbox_inches='tight', pad_inches=0.0, format='eps')
        plt.close()


    def plot_dim_embeds(self, lang, tea_name = 'sim', number=0, stu_num=1):
        """Project stored embeddings to 2-D with t-SNE and save a coloured scatter plot.

        The input file is ``<run dir>/cls_tea.json`` or ``<run dir>/sim_tea.json``
        for the two teachers, and ``<run dir>/<tea_name>_<stu_num>.json`` for any
        other ``tea_name``.  Points are coloured by the entries of
        ``tag_list_use``, which are used as indices into a fixed colour list.  The
        figure is saved to ``../MassPlotPng/exps/fea_<tea_name>.eps`` and ``.png``.

        :param lang: language code used to build the run key (e.g. ``'nl'``).
        :param tea_name: ``'cls'``, ``'sim'``, or another name (e.g. ``'stu'``)
            that is treated as a student model.
        :param number: run number appended to ``lang`` to form the run key.  With
            ``number == 0`` the bare ``lang`` key is used when ``lang + '0'`` is
            not registered.
        :param stu_num: student number used in the file name; only read when
            ``tea_name`` is neither ``'cls'`` nor ``'sim'``.
        :raises KeyError: if no run directory is registered for the key.
        """
        # NOTE(review): the three un-numbered entries all point at the same run
        # directory, whereas plot_weight_analyse maps es0/de0/nl0 to three
        # different runs — verify these paths.
        file_cls_dict = {
            'es': '../MassPlotPng/2021-11-10 14_41_07/',
            'de': '../MassPlotPng/2021-11-10 14_41_07/',
            'nl': '../MassPlotPng/2021-11-10 14_41_07/',

            'de1': '../MassPlotPng/2021-11-12 20_08_22/',
            'nl1': '../MassPlotPng/2021-11-12 14_11_33/',
            'es1': '../MassPlotPng/2021-11-12 20_08_56/',
        }

        run_key = lang + str(number)
        if run_key not in file_cls_dict and number == 0:
            # Without this fallback the default number=0 could never match the
            # un-numbered keys and always raised KeyError.
            run_key = lang
        run_dir = file_cls_dict[run_key]

        if tea_name == 'cls':
            filename, embed_key = run_dir + tea_name + '_tea.json', 'embed_list_use_cls'
        elif tea_name == 'sim':
            filename, embed_key = run_dir + tea_name + '_tea.json', 'embed_list_use_sim'
        else:
            filename = run_dir + tea_name + '_' + str(stu_num) + '.json'
            embed_key = 'embed_list_use_stu'

        # Context manager so the handle is closed even if parsing fails.
        with open(filename) as f:
            data = json.load(f)
        embeds, labels = data[embed_key], data['tag_list_use']

        if isinstance(embeds, list):
            embeds = np.array(embeds)

        # Fixed random_state keeps the projection reproducible between runs.
        tsne = manifold.TSNE(n_components=2, init='pca', random_state=501)
        lowDWeights = tsne.fit_transform(embeds)

        plt.figure(dpi=300, figsize=(8, 6))
        X, Y = lowDWeights[:, 0], lowDWeights[:, 1]

        c_list = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
                  'tab:brown', 'tab:pink', 'tab:olive', 'tab:cyan', 'tab:gray']

        # One colour per point.  zip() stops at the shorter of points/labels, which
        # matches the previous per-point loop; a single scatter call avoids
        # creating one artist per point.
        point_colors = [c_list[s] for _, s in zip(X, labels)]
        n_points = len(point_colors)
        if point_colors:
            plt.scatter(X[:n_points], Y[:n_points], c=point_colors, s=50, alpha=0.8)

        # Legend entry i uses colour i, i.e. label value i is shown as ALL_TAGS[i].
        ALL_TAGS = ['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-MISC', 'I-MISC']
        legend_list = [
            Line2D([0], [0], marker='o', color='w', label=tag,
                   markerfacecolor=c_list[tag_idx], markersize=10)
            for tag_idx, tag in enumerate(ALL_TAGS)
        ]
        plt.legend(handles=legend_list, loc=2, fontsize=16)
        plt.tick_params(labelsize=16)

        filename = '../MassPlotPng/exps/fea_' + tea_name + '.eps'
        plt.margins(0, 0)
        plt.savefig(filename, bbox_inches='tight', pad_inches=0.0)
        filename = '../MassPlotPng/exps/fea_' + tea_name + '.png'
        plt.savefig(filename, bbox_inches='tight', pad_inches=0.0)
        plt.close()
