import numpy as np
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.optim as optim
import time
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from collections import Counter
import random

class dataset(Dataset):
    """In-memory dataset built from two per-class score files and a label file.

    Each feature file is a tab-separated text matrix with one row per sample
    and ``num_class`` columns.  Both matrices are layer-normalised row by row
    and concatenated column-wise, so every sample is a vector of length
    ``2 * num_class``.
    """

    def __init__(self, feature1, feature2, label, num_class):
        """Load the files, normalise the features and cache everything as tensors.

        Args:
            feature1: Path to the first feature file (per the original author's
                note, the prompt model's output probability distribution).
            feature2: Path to the second feature file.
            label: Path to the label file holding one integer per sample.
            num_class: Number of columns in each feature file; used as the
                normalised dimension of the layer norm.
        """
        # ndmin keeps the arrays 2-D (features) / 1-D (labels) even when a file
        # holds a single sample; without it loadtxt squeezes the result and the
        # column-wise concatenation / per-item indexing below would fail.
        f1 = torch.from_numpy(
            np.loadtxt(feature1, delimiter='\t', dtype=np.float32, ndmin=2))
        f2 = torch.from_numpy(
            np.loadtxt(feature2, delimiter='\t', dtype=np.float32, ndmin=2))
        labels = np.loadtxt(label, delimiter='\t', dtype=np.int64, ndmin=1)

        # Parameter-free layer norm over the last dimension.  This matches a
        # freshly constructed torch.nn.LayerNorm (weight one, bias zero) without
        # allocating throw-away modules or detaching through numpy.
        normalized_shape = (num_class,)
        f1 = torch.nn.functional.layer_norm(f1, normalized_shape, eps=1e-6)
        f2 = torch.nn.functional.layer_norm(f2, normalized_shape, eps=1e-6)

        self.x_data = torch.cat([f1, f2], dim=1)
        self.y_data = torch.from_numpy(labels)
        self.len = f1.shape[0]

    def __len__(self):
        """Return the number of samples (rows of the first feature file)."""
        return self.len

    def __getitem__(self, item):
        """Return the ``(features, label)`` pair stored at index ``item``."""
        return self.x_data[item], self.y_data[item]

class Model(torch.nn.Module):
    """Single-hidden-layer perceptron: linear -> ReLU -> linear.

    The output of the last linear layer is returned as-is; no softmax or other
    normalisation is applied inside the module.
    """

    def __init__(self, num_i, num_h, num_o):
        """Create the layers.

        Args:
            num_i: Size of the input vector.
            num_h: Width of the hidden layer.
            num_o: Size of the output vector.
        """
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features=num_i, out_features=num_h)
        self.relu2 = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(in_features=num_h, out_features=num_o)

    def forward(self, x):
        """Run ``x`` through the hidden layer and the output projection."""
        hidden = self.relu2(self.linear1(x))
        return self.linear3(hidden)
 
 
def _to_relation_index(class_id, na_num):
    """Map a raw class id to the scoring index space where "no relation" is 0."""
    # int() yields a plain, hashable Python int for Python ints, numpy scalars
    # and 0-d tensors alike, and guarantees the caller's data is never mutated.
    class_id = int(class_id)
    if class_id == na_num:
        return 0
    if class_id < na_num:
        # Ids below the "no relation" id shift up by one to make room for 0.
        return class_id + 1
    return class_id


def f1_score(output, label, rel_num, na_num):
    """Compute micro-F1 and per-relation F1, ignoring the "no relation" class.

    Class ids are first remapped so that ``na_num`` becomes 0, ids below
    ``na_num`` move up by one and ids above it stay unchanged.  Samples where
    both prediction and gold label are "no relation" contribute nothing.

    Args:
        output: Sequence of predicted class ids (list, numpy array or tensor).
        label: Sequence of gold class ids, same length as ``output``.
        rel_num: Total number of classes; per-relation scores are reported for
            remapped indices ``1 .. rel_num - 1``.
        na_num: Raw class id that denotes "no relation".

    Returns:
        ``(micro_f1, f1_by_relation)`` where ``micro_f1`` is 0 when nothing was
        predicted correctly and ``f1_by_relation`` is a ``Counter`` keyed by
        remapped relation index that only holds relations with non-zero F1.

    Raises:
        ValueError: If ``output`` and ``label`` differ in length.
    """
    if len(output) != len(label):
        raise ValueError(
            "output and label must have the same length, got "
            f"{len(output)} and {len(label)}"
        )

    correct_by_relation = Counter()
    guess_by_relation = Counter()
    gold_by_relation = Counter()
    for raw_guess, raw_gold in zip(output, label):
        guess = _to_relation_index(raw_guess, na_num)
        gold = _to_relation_index(raw_gold, na_num)
        if guess != 0:
            guess_by_relation[guess] += 1
        if gold != 0:
            gold_by_relation[gold] += 1
            # gold is non-zero here, so equality implies a real relation hit.
            if guess == gold:
                correct_by_relation[gold] += 1

    f1_by_relation = Counter()
    for rel in range(1, rel_num):
        hits = correct_by_relation[rel]
        gold_total = gold_by_relation[rel]
        guess_total = guess_by_relation[rel]
        recall = hits / gold_total if gold_total > 0 else 0
        precision = hits / guess_total if guess_total > 0 else 0
        if recall + precision > 0:
            f1_by_relation[rel] = 2 * recall * precision / (recall + precision)

    micro_f1 = 0
    total_correct = sum(correct_by_relation.values())
    # A correct prediction implies at least one guess and one gold relation,
    # so both denominators below are non-zero whenever total_correct is.
    if total_correct != 0:
        recall = total_correct / sum(gold_by_relation.values())
        prec = total_correct / sum(guess_by_relation.values())
        micro_f1 = 2 * recall * prec / (recall + prec)

    return micro_f1, f1_by_relation

def evaluate():
    """Score the module-level ``model`` on ``test_dataloader`` and return micro-F1.

    Reads the globals ``model``, ``test_dataloader``, ``num_class`` and
    ``NA_num`` that the script entry point assigns, so it can only be called
    after they exist.  The model is switched to eval mode and left there.

    Returns:
        The micro-F1 value computed by ``f1_score`` over the whole test set.
    """
    model.eval()
    all_labels = []
    score_batches = []
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            # Nothing is tracked under no_grad, so no detach() is required.
            score_batches.append(model(inputs).cpu())
            # as_tensor reuses an existing tensor; torch.tensor(tensor) would
            # copy-construct it and emit a UserWarning on every batch.
            all_labels.extend(torch.as_tensor(labels).cpu().tolist())

    scores = torch.cat(score_batches, 0).numpy()
    pred = np.argmax(scores, axis=-1)
    all_labels = np.array(all_labels)
    mi_f1, _ = f1_score(pred, all_labels, num_class, NA_num)
    return mi_f1

if __name__ == "__main__":
    # NOTE: model, test_dataloader, num_class and NA_num are read as globals by
    # evaluate(), so they must keep these names and stay at module level.
    num_class = 40
    NA_num = 0
    num_i = num_class * 2  # input layer size: two concatenated score vectors
    num_h = 512            # hidden layer size
    num_o = num_class      # output layer size
    batch_size = 100000
    epochs = 215
    seed = 23

    # Seed every random number generator in use so runs are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    train_dataset = dataset("./data/retacred/distribution/prompt_train.txt",
                            "./data/retacred/distribution/MLP_train.txt",
                            "./data/retacred/distribution/train_label.txt",
                            num_class)
    test_dataset = dataset("./data/retacred/distribution/prompt_test.txt",
                           "./data/retacred/distribution/MLP_test.txt",
                           "./data/retacred/distribution/test_label.txt",
                           num_class)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                                  shuffle=False)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                                 shuffle=False)

    model = Model(num_i, num_h, num_o)

    cost = torch.nn.CrossEntropyLoss()
    # Adam with the library-default hyperparameters.
    optimizer = torch.optim.Adam(model.parameters())

    model.train()
    for epoch in tqdm(range(epochs)):
        for inputs, labels in train_dataloader:
            # Clear stale gradients before computing this batch's gradients.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = cost(outputs, labels)
            loss.backward()
            optimizer.step()

    mi_f1 = evaluate()
    print(mi_f1)