import argparse
import json
from lib2to3.pgen2.tokenize import tokenize
import torch
import numpy as np
import random
from collections import Counter
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import wandb
from tqdm import tqdm, trange
import transformers
from transformers import AdamW, get_linear_schedule_with_warmup
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import RandomSampler, DataLoader, SequentialSampler
from transformers import BertConfig, BertTokenizer, BertModel, \
                         RobertaConfig, RobertaTokenizer, RobertaModel
from sklearn.metrics import confusion_matrix

# Registry of supported pretrained-model families. Each entry maps a
# --model_type value to the transformers classes used to build its
# config, tokenizer and encoder.
MODEL_DICT=dict(
    bert=dict(
        config=BertConfig,
        tokenizer=BertTokenizer,
        model=BertModel,
    ),
    roberta=dict(
        config=RobertaConfig,
        tokenizer=RobertaTokenizer,
        model=RobertaModel,
    ),
)

def Get_Args(argv=None):
    """Parse the command-line configuration and seed every RNG in use.

    Args:
        argv: Optional list of argument strings. When None (the default)
            argparse reads ``sys.argv[1:]``, which is what the original
            zero-argument call did. Passing a list allows the function to
            be used from notebooks or other scripts.

    Returns:
        The parsed ``argparse.Namespace`` with one extra attribute,
        ``n_GPU``, holding ``torch.cuda.device_count()``.

    Side effects:
        Seeds ``random``, ``numpy.random`` and ``torch`` (plus all CUDA
        devices when at least one is present) with ``args.seed``.
    """
    # add_help=False: the parser does not register -h/--help.
    parser=argparse.ArgumentParser("Model Config",add_help=False)

    # Data / pretrained-model locations.
    parser.add_argument("--data_dir",default="./data/tacred",type=str,help="The input data directory")
    parser.add_argument("--seed",default=13,type=int,help="Random Seed For Model")
    parser.add_argument("--model_type",default="roberta",type=str,choices=MODEL_DICT.keys(),help="The type of PLM to use")
    parser.add_argument("--model_name_or_path",default="roberta-large",type=str,help="Path to PLM")
    parser.add_argument("--cache_dir",default="./PLM",type=str,help="Where to store the PLM")
    parser.add_argument("--template_dir",default="Bidirectionaltemp.txt",type=str,help="Template txt")
    parser.add_argument("--input_format",default="typed_entity_marker_punct",type=str,help="Input format of sentences.")
    parser.add_argument("--print_distribution",default=0,type=int,help="whether to print out distribution,0 is no, 1 is to print prompt distribution, 2 is to print MLP's.")
    parser.add_argument("--print_distribution_dir",default="./data/tacred/distribution",type=str,help="the output dir of distribution of prompt or MLP.")

    # Prompt layout.
    parser.add_argument("--prompt_direction",default=2,type=int,help="Direction of prompt, 0 is positive direction, 1 is minus direction, 2 is both.")
    parser.add_argument("--prompt_minus_type",default=1,type=int,help="Type of minus prompt, 0 is like 'e1 was [mask] of e2', 1 is like 'e1 was e2's [mask]'.")
    parser.add_argument("--prompt_special_token",default=0,type=int,help="Add special token in prompt, 0 is not, 1 adds special token.")
    parser.add_argument("--prompt_connect_token",default=1,type=int,help="The connect token connecting pos and minus prompt, 0 is none, 1 is conversely,2 is and")

    # Scoring / loss options.
    parser.add_argument("--na_threshold",default=-0.5,type=float,help="threshold of NA.")
    parser.add_argument("--loss_type",default=1,type=int,help="type of loss, 0 is and of bidrection, 1 is max of bidrection, 2 is or of bidrection")
    parser.add_argument("--positive_direction_mask_num",default=5,type=int,help="the number of mask in the positive direction template.")
    parser.add_argument("--minus_direction_mask_num",default=4,type=int,help="the number of mask in the minus direction template.")

    # Model size and optimisation hyper-parameters.
    parser.add_argument("--label_num",default=42,type=int,help="The number of labels.")
    parser.add_argument("--cls_ratio",default=0.6,type=float,help="The ratio of the cls task loss.")
    parser.add_argument("--drop_out_rate",default=0.1,type=float,help="Drop out rate.")
    parser.add_argument("--max_seq_length",default=256,type=int,help="The upper limitation of length of the sentence after tokenization. The sentence longer than this will truncated.")
    parser.add_argument("--output_dir",default="./result/tacred",type=str,help="The output directory of the model")
    parser.add_argument("--batch_size",default=16,type=int,help="Batch size")
    parser.add_argument("--num_epochs",default=4,type=int,help="Number of training epochs.")
    parser.add_argument("--gradient_accumulation_steps",default=4,type=int,help="Help to solve the memory problem")
    parser.add_argument("--max_grad_norm",default=1.0,type=float,help="Max gradient norm.")
    parser.add_argument("--weight_decay",default=1e-2,type=float,help="Weight decay")
    parser.add_argument("--learning_rate",default=3e-5,type=float,help="Adam learning rate")
    parser.add_argument("--adam_epsilon",default=1e-8,type=float,help="Adam epsilon")
    parser.add_argument("--warmup_steps",default=0,type=int,help="Linear warmup")
    parser.add_argument("--learning_rate_for_new_token",default=1e-4,type=float,help="Adam learning rate for new tokens")
    parser.add_argument("--new_tokens",default=5,type=int,help="The length of dict")

    args=parser.parse_args(argv)
    args.n_GPU=torch.cuda.device_count()

    # Seed every RNG so runs are repeatable for a given --seed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_GPU!=0:
        torch.cuda.manual_seed_all(args.seed)
    return args

def Get_Tokenizer(args):
    """Load the pretrained tokenizer selected by ``args``.

    The tokenizer class is looked up in MODEL_DICT via ``args.model_type``
    and loaded from ``args.model_name_or_path``, with ``args.cache_dir``
    passed as the cache directory.
    """
    tokenizer_cls=MODEL_DICT[args.model_type]['tokenizer']
    return tokenizer_cls.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
    )

def _positive_prompt(fields,mask,special):
    """Return (template, labels) for the positive-direction prompt.

    The label words are taken from columns 2..6 of the template row:
    one word, then three words, then one word.
    """
    if special==0:
        template=[['the',mask],[mask,mask,mask],['the',mask]]
    else:
        template=[['*',mask,'*'],[mask,mask,mask],['^',mask,'^']]
    labels=[(fields[2],),(fields[3],fields[4],fields[5]),(fields[6],)]
    return template,labels


def _minus_prompt(fields,mask,special,minus_type):
    """Return (template, labels) for the minus-direction prompt.

    The label words are taken from columns 7..10 of the template row.
    ``minus_type`` decides whether the two-mask slot comes second
    (``minus_type==0``) or last (any other value).
    """
    if special==0:
        first_slot,other_slot=['the',mask],['the',mask]
    else:
        first_slot,other_slot=['^',mask,'^'],['*',mask,'*']
    pair_slot=[mask,mask]
    if minus_type==0:
        template=[first_slot,pair_slot,other_slot]
        labels=[(fields[7],),(fields[9],fields[10]),(fields[8],)]
    else:
        template=[first_slot,other_slot,pair_slot]
        labels=[(fields[7],),(fields[8],),(fields[9],fields[10])]
    return template,labels


def Get_Template(args,tokenizer):
    """Read the prompt-template file and build one template per relation.

    The file ``args.data_dir + "/" + args.template_dir`` is read line by
    line; each non-blank line is stripped and split on tabs. Column 1 is
    the relation name, columns 2..6 hold the positive-direction label
    words and columns 7..10 the minus-direction label words.

    Args:
        args: Namespace providing ``data_dir``, ``template_dir``,
            ``prompt_direction`` (0 positive, 1 minus, 2 both),
            ``prompt_special_token`` and ``prompt_minus_type``.
        tokenizer: Object exposing ``mask_token``.

    Returns:
        dict mapping relation name -> ``{'relation', 'template', 'labels'}``.
        ``template`` is a list of word lists containing mask tokens and
        ``labels`` a parallel list of tuples with the gold word for each
        mask. For a ``prompt_direction`` other than 0/1/2 only the
        ``'relation'`` key is present.

    Raises:
        IndexError: if a non-blank line has fewer columns than the
            selected prompt direction needs.
    """
    mask=tokenizer.mask_token
    template={}
    # utf-8 to match how Promptdataset reads its data file.
    with open(args.data_dir+"/"+args.template_dir,"r",encoding="utf-8") as f:
        for line in f:
            fields=line.strip().split("\t")
            if fields==['']:
                # Blank (or whitespace-only) line: nothing to parse.
                continue

            template_dict={'relation':fields[1].strip()}

            parts=[]
            if args.prompt_direction==0 or args.prompt_direction==2:
                parts.append(_positive_prompt(fields,mask,args.prompt_special_token))
            if args.prompt_direction==1 or args.prompt_direction==2:
                parts.append(_minus_prompt(fields,mask,args.prompt_special_token,args.prompt_minus_type))

            if parts:
                # With both directions the positive slots come first.
                slots=[]
                words=[]
                for part_template,part_labels in parts:
                    slots.extend(part_template)
                    words.extend(part_labels)
                template_dict['template']=slots
                template_dict['labels']=words

            template[template_dict['relation']]=template_dict
    return template

class DictDataset(Dataset):
    """Dataset backed by named tensors that share their first dimension.

    Indexing returns a dict with the same keys, each holding the indexed
    entry of the corresponding tensor.
    """
    def __init__(self,**tensors):
        self.tensors=tensors

    def __getitem__(self,index):
        item={}
        for field,values in self.tensors.items():
            item[field]=values[index]
        return item

    def __len__(self):
        # The length is taken from the first stored tensor.
        first_tensor=next(iter(self.tensors.values()))
        return first_tensor.size(0)

    def cuda(self):
        """Replace every stored tensor with its ``.cuda()`` copy."""
        for field,values in list(self.tensors.items()):
            self.tensors[field]=values.cuda()

    def cpu(self):
        """Replace every stored tensor with its ``.cpu()`` copy."""
        for field,values in list(self.tensors.items()):
            self.tensors[field]=values.cpu()


class Promptdataset(DictDataset):
    """Dataset that appends a masked prompt to every sentence.

    Construction either tokenizes a raw data file (``path``/``name``) or
    wraps already-built ``features`` (see ``load``). In both cases the
    relation -> id map is read from ``rel2idjson`` and the prompt label
    tables are built from ``template`` by ``get_labels``.

    Attributes set here:
        args: the configuration namespace passed in.
        rel2id: relation name -> class id, read from ``rel2idjson``.
        NA_num: id of the ``'NA'`` relation, or of ``'no_relation'`` when
            ``'NA'`` is not a key.
        num_class: number of entries in ``rel2id``.
        template: the per-relation template dict passed in.
    """
    def __init__(self,path=None,name=None,rel2idjson=None,template=None,tokenizer=None,features=None,args=None):
        self.args=args
        with open(rel2idjson,"r") as f:
            self.rel2id=json.loads(f.read())
        if 'NA' not in self.rel2id:
            self.NA_num=self.rel2id['no_relation']
        else:
            self.NA_num=self.rel2id['NA']
        self.num_class=len(self.rel2id)
        self.template=template
        self.get_labels(tokenizer)
        if features is None:
            with open(path+"/"+name,"r",encoding='utf-8') as f:
                features=[]
                for line in f:
                    line=line.rstrip()
                    if len(line)>0:
                        # SECURITY NOTE(review): eval() executes arbitrary
                        # code from the data file. Only use trusted files;
                        # consider ast.literal_eval / json.loads if the
                        # file format allows it.
                        features.append(eval(line))
            features=self.list2tensor(features,tokenizer)
        super().__init__(**features)

    def get_labels(self,tokenizer):
        """Build the per-mask label vocabularies from ``self.template``.

        Sets:
            template_ids[relation]['mask_ids']: one token-id list per
                template slot, encoded with the mask tokens in place.
            template_ids[relation]['label_ids']: for each mask position,
                the index of the relation's gold word inside
                ``label_set`` for that position.
            label_set: one sorted list of candidate token ids per mask
                position, collected over all relations.
            prompt_id_2_label: long tensor of shape
                (number of relations in the template, number of mask
                positions); row ``rel2id[relation]`` holds that
                relation's ``label_ids``. Moved to CUDA when available.
            prompt_label_idx: ``label_set`` as a list of long tensors.
        """
        total={}
        self.template_ids={}

        for relation in self.template:
            last=0  # running mask position across all slots of this relation
            self.template_ids[relation]={}
            self.template_ids[relation]['label_ids']=[]
            self.template_ids[relation]['mask_ids']=[]

            for index,template in enumerate(self.template[relation]['template']):
                unmask_template=template.copy()
                label_index=[]
                # Replace each mask token by its gold label word, remembering
                # which positions inside the slot were masks.
                for i in range(len(unmask_template)):
                    if unmask_template[i]==tokenizer.mask_token:
                        unmask_template[i]=self.template[relation]['labels'][index][len(label_index)]
                        label_index.append(i)
                mask_template_encode=tokenizer.encode(" ".join(template),add_special_tokens=False)
                unmask_template_encode=tokenizer.encode(" ".join(unmask_template),add_special_tokens=False)
                # NOTE(review): indexing the encoded ids with word positions
                # assumes every word of the slot encodes to exactly one
                # token - confirm for the label vocabulary in use.
                self.template_ids[relation]['label_ids']+=[unmask_template_encode[i] for i in label_index]
                self.template_ids[relation]['mask_ids'].append(mask_template_encode)

                # Collect, per mask position, every token id used as a label.
                for i in label_index:
                    if last not in total:
                        total[last]={}
                    total[last][unmask_template_encode[i]]=1
                    last+=1

        # One sorted candidate-id list per mask position.
        self.label_set=[sorted(total[i]) for i in range(len(total))]

        # Turn each relation's gold token ids into indices into the sorted
        # candidate list of the corresponding mask position.
        for relation in self.template_ids:
            for i in range(len(self.template_ids[relation]['label_ids'])):
                self.template_ids[relation]['label_ids'][i]=self.label_set[i].index(self.template_ids[relation]['label_ids'][i])

        # Entry (r, j): index, within label_set[j], of the gold word of the
        # relation whose class id is r at mask position j.
        # NOTE(review): rows are addressed by rel2id[relation], so this
        # presumes rel2id values are < len(self.template_ids) - confirm.
        self.prompt_id_2_label=torch.zeros(len(self.template_ids),len(self.label_set)).long()
        for relation in self.template_ids:
            for i in range(len(self.prompt_id_2_label[self.rel2id[relation]])):
                self.prompt_id_2_label[self.rel2id[relation]][i]=self.template_ids[relation]['label_ids'][i]
        self.prompt_id_2_label=self.prompt_id_2_label.long()
        # Keep the table on the GPU when there is one; stay on CPU otherwise
        # instead of failing on CUDA-less machines.
        if torch.cuda.is_available():
            self.prompt_id_2_label=self.prompt_id_2_label.cuda()
        self.prompt_label_idx=[torch.Tensor(i).long() for i in self.label_set]

    def list2tensor(self,data,tokenizer):
        """Tokenize every sample and stack the results into long tensors.

        Args:
            data: list of sample dicts (see ``tokenize``); each must also
                carry a ``'relation'`` present in ``self.rel2id``.
            tokenizer: tokenizer providing ``pad_token_id`` and
                ``mask_token_id`` plus what ``tokenize`` needs.

        Returns:
            dict with keys ``input_ids``, ``token_type_ids``,
            ``input_flags``, ``attention_mask``, ``labels``,
            ``mlm_labels`` and ``sample_id``, each a long tensor whose
            first dimension is ``len(data)``.
        """
        res={}
        res['input_ids']=[]
        res['token_type_ids']=[]
        res['input_flags']=[]
        res['attention_mask']=[]
        res['labels']=[]
        res['mlm_labels']=[]
        res['sample_id']=[]
        for sample_id,sample in enumerate(data):
            input_ids,token_type_ids,input_flags=self.tokenize(sample,tokenizer)
            attention_mask=[1]*len(input_ids)
            # Right-pad everything up to max_seq_length.
            padding_length=self.args.max_seq_length-len(input_ids)
            if padding_length>0:
                input_ids=input_ids+[tokenizer.pad_token_id]*padding_length
                token_type_ids=token_type_ids+[0]*padding_length
                input_flags=input_flags+[0]*padding_length
                attention_mask=attention_mask+[0]*padding_length
            label=self.rel2id[sample['relation']]
            res['input_ids'].append(np.array(input_ids))            # token ids
            res['token_type_ids'].append(np.array(token_type_ids))  # segment ids
            res['input_flags'].append(np.array(input_flags))        # flags from tokenize(), 0-padded
            res['attention_mask'].append(np.array(attention_mask))  # 1 for real tokens, 0 for padding
            res['labels'].append(np.array(label))                   # relation class id
            res['sample_id'].append(np.array(sample_id))
            # mlm_labels: 1 at mask-token positions, -1 everywhere else.
            mask_pos=np.where(res['input_ids'][-1]==tokenizer.mask_token_id)[0]
            mlm_labels=np.ones(self.args.max_seq_length)*(-1)
            mlm_labels[mask_pos]=1
            res['mlm_labels'].append(mlm_labels)
        for key in res:
            res[key]=np.array(res[key])
            res[key]=torch.Tensor(res[key]).long()
        return res

    def convert_token(self,token):
        """Map bracket placeholders such as ``-LRB-`` back to the bracket.

        Matching is case-insensitive; any other token is returned as is.
        """
        brackets={
            "-lrb-":"(",
            "-rrb-":")",
            "-lsb-":"[",
            "-rsb-":"]",
            "-lcb-":"{",
            "-rcb-":"}",
        }
        return brackets.get(token.lower(),token)

    def tokenize(self,sample,tokenizer):
        """Encode one sample as sentence tokens followed by its prompt.

        Args:
            sample: dict with ``'token'`` (list of words), ``'relation'``
                and ``'h'``/``'t'`` sub-dicts, each providing ``'pos'``
                (two indices), ``'type'`` and ``'name'``.
            tokenizer: tokenizer providing ``encode``,
                ``num_special_tokens_to_add``,
                ``build_inputs_with_special_tokens`` and
                ``create_token_type_ids_from_sequences``.

        Returns:
            ``(input_ids, token_type_ids, input_flags)`` as Python lists,
            special tokens included, not yet padded.

        All configuration is read from ``self.args`` (the original code
        mixed ``self.args`` with a module-level ``args`` global).
        """
        cfg=self.args

        # Wrap head/tail entities with typed markers:
        #   @ * head-type * head ... @     and     # ^ tail-type ^ tail ... #
        # NOTE(review): the elif chain adds no closing marker when
        # pos[0]==pos[1]; whether pos[1] is inclusive cannot be told from
        # this file - confirm against the data format.
        sample_tmp=[]
        if cfg.input_format=='typed_entity_marker_punct':
            for i in range(0,len(sample['token'])):
                if i==sample['h']['pos'][0]:
                    sample_tmp.extend(['@']+['*']+[sample['h']['type'].lower()]+['*']+[sample['token'][i]])
                elif i==sample['h']['pos'][1]:
                    sample_tmp.extend([sample['token'][i]]+['@'])
                elif i==sample['t']['pos'][0]:
                    sample_tmp.extend(['#']+['^']+[sample['t']['type'].lower()]+['^']+[sample['token'][i]])
                elif i==sample['t']['pos'][1]:
                    sample_tmp.extend([sample['token'][i]]+['#'])
                else:
                    sample_tmp.extend([sample['token'][i]])
        sample_tmp=[self.convert_token(token) for token in sample_tmp]
        sentence=tokenizer.encode(" ".join(sample_tmp),add_special_tokens=False)

        # Entity names are encoded after a leading "was" whose first token is
        # then dropped - presumably so the name gets its mid-sentence
        # tokenization (assumes "was" is a single token; TODO confirm).
        e1=tokenizer.encode(" ".join(['was', sample['h']['name']]),add_special_tokens=False)[1:]
        e2=tokenizer.encode(" ".join(['was', sample['t']['name']]),add_special_tokens=False)[1:]
        e1_edj=tokenizer.encode(" ".join([sample['h']['name'],'\'s']),add_special_tokens=False)
        was_code=tokenizer.encode("was",add_special_tokens=False)
        dot_code=tokenizer.encode(",",add_special_tokens=False)
        of_code=tokenizer.encode("of",add_special_tokens=False)
        at_code=tokenizer.encode("@",add_special_tokens=False)
        sharp_code=tokenizer.encode("#",add_special_tokens=False)

        conversely_code=tokenizer.encode("conversely",add_special_tokens=False)
        and_code=tokenizer.encode("and",add_special_tokens=False)

        mask_ids=self.template_ids[sample['relation']]['mask_ids']

        prompt=[]
        prompt_pos=[]
        prompt_minus=[]

        # Positive direction: slots 0..2 around head (e1) then tail (e2).
        if cfg.prompt_direction==0 or cfg.prompt_direction==2:
            if cfg.prompt_special_token==0:
                prompt_pos=mask_ids[0]+e1+mask_ids[1]+mask_ids[2]+e2
            else:
                prompt_pos=at_code+mask_ids[0]+e1+at_code+mask_ids[1]+sharp_code+mask_ids[2]+e2+sharp_code

        # Minus direction: tail first. Its three slots are 0..2 when it is
        # the only prompt and 3..5 when it follows the positive prompt.
        if cfg.prompt_direction==1 or cfg.prompt_direction==2:
            first=0 if cfg.prompt_direction==1 else 3
            if cfg.prompt_special_token==0:
                if cfg.prompt_minus_type==0:
                    prompt_minus=mask_ids[first]+e2+was_code+mask_ids[first+1]+of_code+mask_ids[first+2]+e1
                else:
                    prompt_minus=mask_ids[first]+e2+was_code+mask_ids[first+1]+e1_edj+mask_ids[first+2]
            else:
                if cfg.prompt_minus_type==0:
                    prompt_minus=sharp_code+mask_ids[first]+e2+sharp_code+was_code+mask_ids[first+1]+of_code+at_code+mask_ids[first+2]+e1+at_code
                else:
                    prompt_minus=sharp_code+mask_ids[first]+e2+sharp_code+was_code+at_code+mask_ids[first+1]+e1_edj+at_code+mask_ids[first+2]

        if cfg.prompt_direction==0:
            prompt=prompt_pos
        elif cfg.prompt_direction==1:
            prompt=prompt_minus
        elif cfg.prompt_direction==2:
            # Join both prompts with ",", ", conversely" or ", and".
            # Any other prompt_connect_token leaves the prompt empty.
            if cfg.prompt_connect_token==0:
                prompt=prompt_pos+dot_code+prompt_minus
            elif cfg.prompt_connect_token==1:
                prompt=prompt_pos+dot_code+conversely_code+prompt_minus
            elif cfg.prompt_connect_token==2:
                prompt=prompt_pos+dot_code+and_code+prompt_minus

        tokens=sentence+prompt
        flags=[0]*len(tokens)

        # Truncate from the front so the prompt at the end is kept.
        max_len=cfg.max_seq_length-tokenizer.num_special_tokens_to_add(False)
        tokens=tokens if len(tokens)<=max_len else tokens[len(tokens)-max_len:]
        flags=flags if len(flags)<=max_len else flags[len(flags)-max_len:]

        # Add the model's special tokens around the sequence.
        input_ids=tokenizer.build_inputs_with_special_tokens(tokens)
        # Only a single sequence is passed, so no second segment is marked.
        token_type_ids=tokenizer.create_token_type_ids_from_sequences(tokens)
        input_flags = tokenizer.build_inputs_with_special_tokens(flags)
        # The first and last entries are special-token ids; reset them to 0.
        input_flags[0]=0
        input_flags[-1]=0

        return input_ids, token_type_ids, input_flags

    def save(self,path=None,name=None):
        """Write every feature tensor to ``path/name/<key>.npy``.

        The directory ``path/name`` is not created here.
        """
        path=path+"/"+name+"/"
        np.save(path+"input_ids",self.tensors['input_ids'].numpy())
        np.save(path+"token_type_ids", self.tensors['token_type_ids'].numpy())
        np.save(path+"attention_mask", self.tensors['attention_mask'].numpy())
        np.save(path+"labels", self.tensors['labels'].numpy())
        np.save(path+"mlm_labels", self.tensors['mlm_labels'].numpy())
        np.save(path+"input_flags", self.tensors['input_flags'].numpy())
        np.save(path+"sample_id", self.tensors['sample_id'].numpy())

    @classmethod
    def load(cls, path = None, name = None, rel2idjson = None, template = None, tokenizer = None, args = None):
        """Rebuild a dataset from the ``.npy`` files written by ``save``.

        Args:
            path, name: directory parts; files are read from
                ``path/name/``.
            rel2idjson, template, tokenizer: forwarded to the constructor.
            args: configuration namespace to attach to the dataset. When
                omitted, a module-level ``args`` global is used if one
                exists (the behaviour callers previously relied on).

        Returns:
            A new instance of ``cls`` wrapping the loaded tensors.
        """
        if args is None:
            # Backward compatibility with callers that relied on the global.
            args = globals().get('args')
        path = path + "/" + name  + "/"
        features = {}
        features['input_ids'] = torch.Tensor(np.load(path+"input_ids.npy")).long()
        features['token_type_ids'] = torch.Tensor(np.load(path+"token_type_ids.npy")).long()
        features['attention_mask'] = torch.Tensor(np.load(path+"attention_mask.npy")).long()
        features['labels'] = torch.Tensor(np.load(path+"labels.npy")).long()
        features['input_flags'] = torch.Tensor(np.load(path+"input_flags.npy")).long()
        features['mlm_labels'] = torch.Tensor(np.load(path+"mlm_labels.npy")).long()
        features['sample_id'] = torch.Tensor(np.load(path+"sample_id.npy")).long()
        res = cls(rel2idjson = rel2idjson, features = features, template=template, tokenizer = tokenizer,args=args)
        return res

class Model(torch.nn.Module):
    """Pretrained encoder with a prompt (mask) scoring head and a [CLS]
    relation classifier.

    ``prompt_label_idx`` is a list with one tensor of candidate token ids
    per mask position; ``forward`` scores each mask position against the
    word embeddings of its candidates.
    """
    def __init__(self,args,tokenizer=None,prompt_label_idx=None):
        super().__init__()
        encoder_cls=MODEL_DICT[args.model_type]['model']
        self.prompt_label_idx=prompt_label_idx
        self.PLM=encoder_cls.from_pretrained(
            args.model_name_or_path,
            return_dict=False,
            cache_dir=args.cache_dir or None
        )
        width=self.PLM.config.hidden_size

        # Re-parameterisation of the learnable extra-token embeddings.
        self.mlp=nn.Sequential(
            nn.Linear(width,width),
            nn.ReLU(),
            nn.Linear(width,width)
        )
        self.extra_token_embeddings=nn.Embedding(args.new_tokens,width)

        # Classifier applied to the first-position hidden state.
        self.relation_cls=nn.Sequential(
            nn.Linear(width,width),
            nn.ReLU(),
            nn.Dropout(args.drop_out_rate),
            nn.Linear(width,args.label_num)
        )

    def forward(self,input_ids,attention_mask,token_type_ids,input_flags,mlm_labels,labels,sample_id):
        """Score every mask position and classify the relation.

        Returns a list: one logits tensor per entry of
        ``self.prompt_label_idx`` (batch x number of candidates for that
        mask), followed by the classifier logits and finally ``sample_id``
        passed through unchanged. ``labels`` is accepted but not used.
        """
        word_embeddings=self.PLM.embeddings.word_embeddings
        token_vectors=word_embeddings(input_ids)

        # Positions whose flag is > 0 use the learned extra-token vector
        # selected by that flag instead of the pretrained word embedding.
        learned_vectors=self.mlp(self.extra_token_embeddings.weight)[input_flags]
        use_learned=input_flags.unsqueeze(-1)>0
        inputs_embeddings=torch.where(use_learned,learned_vectors,token_vectors)

        # With return_dict=False the encoder returns a tuple; the first item
        # is the sequence of last-layer hidden states.
        hidden_states,_=self.PLM(
            inputs_embeds=inputs_embeddings,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )

        batch_size=input_ids.size(0)
        cls_logits=self.relation_cls(hidden_states[:,0,:].view(batch_size,-1))

        # Keep only the mask positions (mlm_labels >= 0) and regroup them per
        # sample; the view requires exactly len(prompt_label_idx) masks each.
        mask_states=hidden_states[mlm_labels>=0].view(batch_size,len(self.prompt_label_idx),-1)

        logits=[]
        for slot,candidate_ids in enumerate(self.prompt_label_idx):
            # Dot product between the mask hidden state and the word
            # embedding of every candidate label token (not normalised).
            candidate_vectors=word_embeddings.weight[candidate_ids]
            logits.append(torch.mm(mask_states[:,slot,:],candidate_vectors.transpose(1,0)))
        logits.append(cls_logits)
        logits.append(sample_id)
        return logits

def Get_Optimizer(model,train_dataloader):
    """Create the two AdamW optimizers and their linear schedules.

    Hyper-parameters are read from the module-level ``args``.

    Returns:
        ``(optimizer, scheduler, embedding_optimizer, embedding_scheduler)``.
        The first pair covers every named parameter of the model, with
        weight decay disabled for bias and LayerNorm weights. The second
        pair covers the ``mlp`` and ``extra_token_embeddings`` parameters
        with ``args.learning_rate_for_new_token``.

    NOTE(review): the ``mlp`` / ``extra_token_embeddings`` parameters are
    therefore present in both optimizers - confirm this is intended.
    NOTE(review): both schedules span ``len(train_dataloader)`` steps,
    with no epoch or gradient-accumulation factor - confirm against the
    training loop.
    """
    total_steps=len(train_dataloader)
    # Unwrap a DataParallel-style wrapper if there is one.
    core=getattr(model,'module',model)

    # Split parameters by whether weight decay applies to them.
    decay_exempt=('bias','LayerNorm.weight')
    decayed=[]
    exempt=[]
    for param_name,param in core.named_parameters():
        if any(tag in param_name for tag in decay_exempt):
            exempt.append(param)
        else:
            decayed.append(param)

    optimizer=AdamW(
        [
            {'params':decayed,'weight_decay':args.weight_decay},
            {'params':exempt,'weight_decay':0.0}
        ],
        lr=args.learning_rate,
        eps=args.adam_epsilon
    )
    scheduler=get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps
    )

    embedding_optimizer=AdamW(
        [
            {'params':list(core.mlp.parameters())},
            {'params':list(core.extra_token_embeddings.parameters())}
        ],
        lr=args.learning_rate_for_new_token,
        eps=args.adam_epsilon
    )
    embedding_scheduler=get_linear_schedule_with_warmup(
        embedding_optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps
    )
    return optimizer,scheduler,embedding_optimizer,embedding_scheduler

def f1_score(output,label,rel_num,na_num):
    """Compute micro F1 over non-NA relations plus per-relation F1.

    Class ids are first remapped so the NA class becomes 0: ``na_num``
    maps to 0, ids below ``na_num`` are shifted up by one and ids above
    it are unchanged. Predictions and golds equal to 0 after remapping do
    not count as guesses / golds.

    Args:
        output: indexable sequence of predicted class ids.
        label: indexable sequence of gold class ids, at least as long as
            ``output``.
        rel_num: total number of classes, NA included.
        na_num: class id of the NA relation.

    Returns:
        ``(micro_f1, f1_by_relation)``. ``micro_f1`` is 0 when nothing was
        guessed or nothing was correct. ``f1_by_relation`` is a Counter
        keyed by remapped id (1..rel_num-1) that only contains relations
        with non-zero F1.
    """
    def _shift(idx):
        # int() detaches the value from its container, so tensor inputs
        # are neither mutated in place nor used as Counter keys.
        idx=int(idx)
        if idx==na_num:
            return 0
        return idx+1 if idx<na_num else idx

    correct_by_relation = Counter()
    guess_by_relation = Counter()
    gold_by_relation = Counter()
    for i in range(len(output)):
        guess=_shift(output[i])
        gold=_shift(label[i])
        if guess!=0:
            guess_by_relation[guess]+=1
        if gold!=0:
            gold_by_relation[gold]+=1
        if gold!=0 and gold==guess:
            correct_by_relation[gold]+=1

    f1_by_relation = Counter()
    for i in range(1,rel_num):
        recall=0
        if gold_by_relation[i]>0:
            recall=correct_by_relation[i]/gold_by_relation[i]
        precision=0
        if guess_by_relation[i]>0:
            precision=correct_by_relation[i]/guess_by_relation[i]
        if recall+precision>0:
            f1_by_relation[i]=2*recall*precision/(recall+precision)

    micro_f1=0
    total_correct=sum(correct_by_relation.values())
    total_guess=sum(guess_by_relation.values())
    # total_correct > 0 implies at least one gold, so neither division
    # below can hit zero.
    if total_guess!=0 and total_correct!=0:
        recall=total_correct/sum(gold_by_relation.values())
        prec=total_correct/total_guess
        micro_f1=2*recall*prec/(recall+prec)

    return micro_f1,f1_by_relation

def Print_Confusion(output,label,rel_num):
    """Plot the confusion matrix of ``label`` (rows) against ``output``
    (columns), save it to ./figs/TACRED/TACRED-minus.png and show it.

    The y-axis tick names come from the hard-coded relation list below;
    ``rel_num`` is not used.
    """
    relation_names=[
        "org:founded", "org:subsidiaries", "per:date_of_birth", "per:cause_of_death",
        "per:age", "per:stateorprovince_of_birth", "per:countries_of_residence",
        "per:country_of_birth", "per:stateorprovinces_of_residence", "org:website",
        "per:cities_of_residence", "per:parents", "per:employee_of", "NA",
        "per:city_of_birth", "org:parents", "org:political/religious_affiliation",
        "per:schools_attended", "per:country_of_death", "per:children",
        "org:top_members/employees", "per:date_of_death", "org:members",
        "org:alternate_names", "per:religion", "org:member_of",
        "org:city_of_headquarters", "per:origin", "org:shareholders", "per:charges",
        "per:title", "org:number_of_employees/members", "org:dissolved",
        "org:country_of_headquarters", "per:alternate_names", "per:siblings",
        "org:stateorprovince_of_headquarters", "per:spouse", "per:other_family",
        "per:city_of_death", "per:stateorprovince_of_death", "org:founded_by",
    ]
    conf_mat=confusion_matrix(label,output)
    plt.matshow(conf_mat)

    # Colour bar drawn on a fixed set of count boundaries, capped by the
    # largest cell of the matrix.
    color_map=mpl.cm.viridis
    boundaries=[0, 1, 5, 10, 50,100, np.max(conf_mat)]
    boundary_norm=mpl.colors.BoundaryNorm(boundaries, color_map.N, extend='neither')
    plt.colorbar(mpl.cm.ScalarMappable(norm=boundary_norm, cmap=color_map))

    plt.xlabel('predict')
    plt.ylabel('correct')
    # Both axes get one tick per matrix column.
    tick_positions=np.arange(conf_mat.shape[1])
    plt.xticks(tick_positions)
    plt.yticks(tick_positions,relation_names)
    plt.savefig("./figs/TACRED/TACRED-minus.png")
    plt.show()
        
def Print_Distribution(scores,all_labels,rel_num,dataset_type):
    """Dump labels, and optionally scores, as tab-separated text files.

    Files go to ``args.print_distribution_dir`` (module-level ``args``).
    ``dataset_type==0`` writes the ``test`` files, any other value the
    ``train`` files. The labels file is always written; the scores file
    is written as ``prompt_<split>.txt`` when ``args.print_distribution``
    is 1, as ``MLP_<split>.txt`` when it is 2, and not at all otherwise.
    ``rel_num`` is not used.
    """
    split='test' if dataset_type==0 else 'train'
    np.savetxt(args.print_distribution_dir+'/'+split+'_label.txt', all_labels,fmt='%d',delimiter='\t')

    score_prefix={1:'prompt',2:'MLP'}.get(args.print_distribution)
    if score_prefix is None:
        return
    np.savetxt(args.print_distribution_dir+'/'+score_prefix+'_'+split+'.txt', scores,fmt='%f',delimiter='\t')


def _direction_scores(mask_logits,verbalizer,pos_mask_num,minus_mask_num):
    """Average the verbalizer-token logits of each template direction per class.

    For every class row in ``verbalizer`` the logits at mask positions
    ``< pos_mask_num`` are summed and divided by ``pos_mask_num`` (forward
    direction); the remaining positions are summed and divided by
    ``minus_mask_num`` (reverse direction).

    Returns:
        ``(forward, reverse)``: two lists with one entry per class, each entry
        holding one value per sample of the batch.
    """
    forward_scores=[]
    reverse_scores=[]
    for label_tokens in verbalizer:
        forward=0.0
        reverse=0.0
        for pos,token in enumerate(label_tokens):
            if pos<pos_mask_num:
                forward+=mask_logits[pos][:,token]
            else:
                reverse+=mask_logits[pos][:,token]
        forward/=pos_mask_num
        reverse/=minus_mask_num
        forward_scores.append(forward)
        reverse_scores.append(reverse)
    return forward_scores,reverse_scores


def evaluate(model,dataset,dataloader,confusion_flag,distribution_flag,dataset_type):
    """Score every batch of ``dataloader`` and return the F1 values from ``f1_score``.

    Reads the module-level ``args`` (``loss_type``, mask counts, ``cls_ratio``,
    ``print_distribution``) and ``train_dataset.prompt_id_2_label``, which maps
    each class to one verbalizer token index per mask position.

    Args:
        model: called as ``model(**batch)``; its output is indexed so that the
            second-to-last element is the classifier logits and everything
            before it is one logits tensor per mask position. The last
            element is not used here.
        dataset: provides ``num_class`` and ``NA_num`` for ``f1_score``.
        dataloader: iterable of batches; each batch has a ``'labels'`` entry.
        confusion_flag: currently unused (the confusion-matrix call is disabled).
        distribution_flag: when truthy and ``args.print_distribution != 0``,
            scores and labels are dumped via ``Print_Distribution``.
        dataset_type: 0 for test, 1 for train; forwarded to ``Print_Distribution``.

    Returns:
        The ``(mi_f1, ma_f1)`` pair returned by ``f1_score`` (presumably
        micro/macro F1 -- confirm against ``f1_score``).

    Raises:
        ValueError: if ``args.loss_type`` is not 0, 1 or 2.
    """
    model.eval()# evaluation mode: Dropout / BatchNorm are not in training behaviour
    scores=[]
    all_labels=[]
    softmax=nn.Softmax(dim=1)
    verbalizer=train_dataset.prompt_id_2_label

    with torch.no_grad():
        for batch in tqdm(dataloader):
            outputs=model(**batch)
            cls_logits=outputs[-2]
            mask_logits=outputs[:-2]

            if args.loss_type==0:# product of the per-mask probabilities (sum of log-probs)
                # log_softmax instead of log(softmax(x)): the latter underflows
                # to -inf for very unlikely tokens.
                log_probs=[torch.log_softmax(position_logits,dim=1) for position_logits in mask_logits]
                per_class=[]
                for label_tokens in verbalizer:
                    total=0.0
                    for pos,token in enumerate(label_tokens):
                        total+=log_probs[pos][:,token]
                    per_class.append(total)
                class_scores=torch.stack(per_class,0).transpose(1,0)

            elif args.loss_type==1:# maximum over the two template directions
                forward,reverse=_direction_scores(mask_logits,verbalizer,args.positive_direction_mask_num,args.minus_direction_mask_num)
                per_class=[torch.max(f,r) for f,r in zip(forward,reverse)]
                class_scores=torch.stack(per_class,0).transpose(1,0)
                # Blend the prompt scores with the classifier-head logits.
                class_scores=class_scores*(1-args.cls_ratio)+args.cls_ratio*cls_logits

            elif args.loss_type==2:# "or" of the two directions: sum of their probabilities
                forward,reverse=_direction_scores(mask_logits,verbalizer,args.positive_direction_mask_num,args.minus_direction_mask_num)
                p_pos=softmax(torch.stack(forward,0).transpose(1,0))
                p_minus=softmax(torch.stack(reverse,0).transpose(1,0))
                class_scores=torch.log(p_pos+p_minus)

            else:
                raise ValueError("Unsupported loss_type: {}".format(args.loss_type))

            all_labels.extend(batch['labels'].detach().cpu().tolist())
            scores.append(class_scores.detach().cpu())

    scores=torch.cat(scores,0).numpy()
    all_labels=np.array(all_labels)

    pred=np.argmax(scores,axis=-1)
    if distribution_flag and args.print_distribution!=0:
        Print_Distribution(scores,all_labels,dataset.num_class,dataset_type)
    # NOTE: confusion-matrix output (Print_Confusion) is disabled, so
    # confusion_flag has no effect.
    mi_f1,ma_f1=f1_score(pred,all_labels,dataset.num_class,dataset.NA_num)
    return mi_f1,ma_f1

if __name__ == '__main__':
    # Script entry point: parse arguments, preprocess the three data splits,
    # train for args.num_epochs epochs and evaluate on the test split after
    # every epoch. The names assigned here are module-level globals; evaluate()
    # reads `args` and `train_dataset` from them.
    args=Get_Args()
    tokenizer=Get_Tokenizer(args)
    template=Get_Template(args,tokenizer)
    #print(template)

    # Build every split from its raw text file and save it under args.output_dir,
    # then load the saved splits back.
    dataset=Promptdataset(path=args.data_dir,name="train.txt",rel2idjson=args.data_dir+"/"+"rel2id.json",template=template,tokenizer=tokenizer,args=args)
    dataset.save(path=args.output_dir,name="train")
    dataset=Promptdataset(path=args.data_dir,name="val.txt",rel2idjson=args.data_dir+"/"+"rel2id.json",template=template,tokenizer=tokenizer,args=args)
    dataset.save(path=args.output_dir,name="val")
    dataset=Promptdataset(path=args.data_dir,name="test.txt",rel2idjson=args.data_dir+"/"+"rel2id.json",template=template,tokenizer=tokenizer,args=args)
    dataset.save(path=args.output_dir,name="test")
    train_dataset=Promptdataset.load(path=args.output_dir,name="train",template=template,tokenizer=tokenizer,rel2idjson=args.data_dir+"/"+"rel2id.json")
    val_dataset=Promptdataset.load(path=args.output_dir,name="val",template=template,tokenizer=tokenizer,rel2idjson=args.data_dir+"/"+"rel2id.json")
    test_dataset=Promptdataset.load(path=args.output_dir,name="test",template=template,tokenizer=tokenizer,rel2idjson=args.data_dir+"/"+"rel2id.json")
    
    # NOTE(review): `device` is not used further in this block -- datasets and
    # model are moved with .cuda() below, which needs a CUDA device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # single GPU or CPU
    args.n_gpu = torch.cuda.device_count()

    # The per-device batch size is multiplied by the number of GPUs (at least 1).
    batch_size=args.batch_size*max(1,args.n_gpu)
    train_dataset.cuda()
    #train_dataset.to(device)
    train_sampler=RandomSampler(train_dataset)
    train_dataloader=DataLoader(train_dataset,sampler=train_sampler,batch_size=batch_size)
    # NOTE(review): val/test also use RandomSampler, so the row order of any
    # dumped distribution file differs between runs; val_dataloader is not
    # used later in this block.
    val_dataset.cuda()
    #val_dataset.to(device)
    val_sampler=RandomSampler(val_dataset)
    val_dataloader=DataLoader(val_dataset,sampler=val_sampler,batch_size=batch_size)
    test_dataset.cuda()
    #test_dataset.to(device)
    test_sampler=RandomSampler(test_dataset)
    test_dataloader=DataLoader(test_dataset,sampler=test_sampler,batch_size=batch_size)

    model=Model(args,tokenizer,train_dataset.prompt_label_idx)
    # Wrap in DataParallel when more than one GPU is visible.
    if torch.cuda.device_count()>1:
        model=torch.nn.DataParallel(model)
    model.cuda()
    #model = model.to(device)
    
    # Two optimizer/scheduler pairs are stepped together below; the second pair
    # presumably covers newly added token parameters -- confirm in Get_Optimizer.
    optimizer,scheduler,optimizer_new_token,scheduler_new_token = Get_Optimizer(model, train_dataloader)
    criterion=nn.CrossEntropyLoss()
    softmax=nn.Softmax(dim=1)
    loss_func=nn.NLLLoss()  # applied to log-probabilities in the loss_type==2 branch

    # NOTE(review): these trackers are initialised but never updated in this
    # block; max_res is only printed.
    max_res=0.0
    mi_f1_record=[]
    ma_f1_record=[]
    max_epoch=None
    last_epoch=None
    
    #wandb.init(project="my-project",name="test_project")

    for epoch in range(0,args.num_epochs):
        model.train()
        model.zero_grad()
        tr_loss=0.0
        global_step=0  # optimizer steps taken in this epoch

        for step,batch in enumerate(tqdm(train_dataloader)):
            # logits: 5 (the number of masks) tensors, each of size
            # batch_size * label-set size, i.e. a score for every label in the set.
            # The last two model outputs are split off below.
            logits=model(**batch)
            sample_id=logits[-1]
            cls_logits=logits[-2]
            logits=logits[:-2]
            #print(logits)
            #print(cls_logits.shape)
            #wandb.log({'logits':logits})
            # labels: batch_size*5, the gold label token for every mask position
            labels=train_dataset.prompt_id_2_label[batch['labels']]
            #print(labels)
            #wandb.log({'labels':labels})
            loss=0.0

            if args.loss_type==0:# multiply the probabilities
                # Mean cross-entropy over all mask positions.
                for index,i in enumerate(logits):
                    loss+=criterion(i,labels[:,index])
                loss/=len(logits)
                '''
                res=[]
                for i in train_dataset.prompt_id_2_label:
                    _res=0.0
                    for j in range(len(i)):
                        _res+=logits[j][:,i[j]]
                    res.append(_res)
                final_logits=torch.stack(res,0).transpose(1,0)
                loss+=criterion(final_logits,batch['labels'])
                '''
                
            elif args.loss_type==1:# take the maximum

                # Per sample: average the cross-entropy over the forward-direction
                # masks and over the reverse-direction masks, then add the smaller
                # of the two to the loss (summed over the batch).
                sample_num=logits[0].shape[0]
                for j in range(0,sample_num):
                    forward_loss=0.0
                    reverse_loss=0.0
                    for index in range(0,len(logits)):
                        if index<args.positive_direction_mask_num:
                            forward_loss+=criterion(logits[index][j,:].unsqueeze(0),labels[j,index].unsqueeze(0))
                        else:
                            reverse_loss+=criterion(logits[index][j,:].unsqueeze(0),labels[j,index].unsqueeze(0))
                    forward_loss/=args.positive_direction_mask_num
                    reverse_loss/=args.minus_direction_mask_num
                    loss+=min(forward_loss,reverse_loss)

                # Class-level logits: per class, the larger of the averaged
                # forward and reverse verbalizer-token logits.
                res=[]
                for i in train_dataset.prompt_id_2_label:
                    _res_pos=0.0
                    _res_minus=0.0
                    for j in range(len(i)):
                        if j<args.positive_direction_mask_num:
                            _res_pos+=logits[j][:,i[j]]
                        else:
                            _res_minus+=logits[j][:,i[j]]
                    _res_pos/=args.positive_direction_mask_num
                    _res_minus/=args.minus_direction_mask_num
                    res.append(torch.max(_res_pos,_res_minus))
                final_logits=torch.stack(res,0).transpose(1,0)
                loss+=criterion(final_logits,batch['labels'])

                # Mix the prompt loss with the classifier-head loss by args.cls_ratio.
                cls_loss=criterion(cls_logits,batch['labels'])
                loss=loss*(1-args.cls_ratio)+args.cls_ratio*cls_loss

            elif args.loss_type==2:# take the OR
                # Per class, average the verbalizer-token logits of each direction,
                # softmax each direction over the classes and sum the two
                # distributions; NLLLoss is applied to the log of that sum.
                res_pos=[]
                res_minus=[]
                for i in train_dataset.prompt_id_2_label:
                    _res_pos=0.0
                    _res_minus=0.0
                    for j in range(len(i)):
                        if j<args.positive_direction_mask_num:
                            _res_pos+=logits[j][:,i[j]]
                        else:
                            _res_minus+=logits[j][:,i[j]]
                    _res_pos/=args.positive_direction_mask_num
                    _res_minus/=args.minus_direction_mask_num
                    res_pos.append(_res_pos)
                    res_minus.append(_res_minus)
                res_pos=torch.stack(res_pos,0).transpose(1,0)
                res_minus=torch.stack(res_minus,0).transpose(1,0)
                p_pos=softmax(res_pos)
                p_minus=softmax(res_minus)
                p=p_pos+p_minus#-torch.mul(p_pos,p_minus)
                p=torch.log(p)
                #print(torch.max(p))
                #print(torch.min(p))
                loss+=loss_func(p,batch['labels'])


            # NOTE(review): for any other loss_type `loss` is still the float 0.0
            # here, so loss.backward() below fails.
            # Scale the loss so gradients accumulated over several steps average out.
            if args.gradient_accumulation_steps>1:
                loss/=args.gradient_accumulation_steps
            loss.backward()
            tr_loss+=loss.item()

            # Step both optimizers/schedulers once every gradient_accumulation_steps batches.
            if (step+1)%args.gradient_accumulation_steps==0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),args.max_grad_norm)# guards against exploding gradients
                optimizer.step()
                scheduler.step()
                optimizer_new_token.step()
                scheduler_new_token.step()
                model.zero_grad()
                global_step+=1
            # Progress print every 5000 steps (step 0 is skipped by step>=10).
            if step%5000==0 and step>=10:
                print(epoch,tr_loss/global_step,max_res)
        
        # Evaluate on the test split after every epoch; the confusion and
        # distribution flags are only true from epoch 3 on.
        mi_f1,_=evaluate(model,test_dataset,test_dataloader,epoch>=3,epoch>=3,0)
        print(mi_f1)
    # After training, optionally dump the distribution over the training split
    # (dataset_type=1).
    if args.print_distribution!=0:
        mi_f1,_=evaluate(model,train_dataset,train_dataloader,0,True,1)