
import torch
import numpy as np
from dataloader.dataloader import train_unshuffle_loader
from sklearn import cluster
import pandas as pd


def get_entity_mask(features, model):
    """Locate the entity marker tokens in each tokenized sequence.

    Args:
        features: mapping whose ``'input_ids'`` entry is a tensor of token ids,
            assumed 2-D (batch, seq_len) -- column 1 of the ``nonzero``
            coordinates is taken as the position. TODO confirm with callers.
        model: object exposing ``tokenizer.convert_tokens_to_ids``.

    Returns:
        ``(e1_idx, e2_idx)``: each a tensor with two columns holding the
        positions of the opening and closing marker (``'<e1>'``/``'</e1>'``
        and ``'<e2>'``/``'</e2>'`` respectively), one row per marker hit.

    NOTE(review): rows line up with batch rows only if every sequence contains
    each marker exactly once. A missing (e.g. truncated away) or repeated marker
    either makes ``torch.cat`` raise on mismatched sizes or silently pairs
    positions that come from different sequences.
    """
    input_ids = features['input_ids']
    to_id = model.tokenizer.convert_tokens_to_ids

    def marker_positions(token):
        # nonzero yields (row, col) pairs in row-major order, so column 1 lists
        # the hit positions row by row; unsqueeze turns it into a (k, 1) column.
        hits = torch.nonzero(input_ids == to_id(token), as_tuple=False)
        return hits[:, 1].unsqueeze(1)

    e1_idx = torch.cat([marker_positions('<e1>'), marker_positions('</e1>')], dim=-1)
    e2_idx = torch.cat([marker_positions('<e2>'), marker_positions('</e2>')], dim=-1)
    return e1_idx, e2_idx

def prepare_task_input(model, batch, args, is_contrastive=False):
    """Tokenize the four text views of *batch* and move all tensors to CUDA.

    Args:
        model: object exposing ``tokenizer`` (used for ``batch_encode_plus``
            and, via ``get_entity_mask``, ``convert_tokens_to_ids``).
        batch: mapping with keys ``'text'``, ``'text1'``, ``'text2'``,
            ``'text3'`` (lists of strings to tokenize) and ``'label'``
            (a tensor).
        args: namespace providing ``max_length`` for tokenization.
        is_contrastive: accepted so callers may pass it (the evaluation code
            does); it is currently ignored, and all four views are always
            returned. NOTE(review): confirm whether non-contrastive callers
            were meant to receive only the original view.

    Returns:
        ``(feat, class_label)``. ``feat`` is a list of four tokenizer outputs,
        one per view in the order text, text1, text2, text3. Each has extra
        ``'e1_idx'``/``'e2_idx'`` entries, and every entry is moved to CUDA.
        ``class_label`` is the CUDA label tensor, detached.

    NOTE(review): ``truncation=True`` can cut off entity markers in long
    inputs, which breaks the one-marker-per-row assumption of
    ``get_entity_mask``.
    """
    views = (batch['text'], batch['text1'], batch['text2'], batch['text3'])
    class_label = batch['label'].cuda()

    feat = []
    for view in views:
        features = model.tokenizer.batch_encode_plus(
            view,
            max_length=args.max_length,
            return_tensors='pt',
            padding='longest',
            truncation=True,
        )
        e1_idx, e2_idx = get_entity_mask(features, model)
        features['e1_idx'] = e1_idx
        features['e2_idx'] = e2_idx
        # Snapshot the keys so reassigning values never interferes with iteration.
        for key in list(features.keys()):
            features[key] = features[key].cuda()
        feat.append(features)
    return feat, class_label.detach()
    
from scipy.sparse import coo_matrix
from sklearn.preprocessing import normalize
def contingency_matrix(ref_labels, sys_labels):
    """Count co-occurrences of reference labels and system (predicted) labels.

    Args:
        ref_labels: array-like of reference (gold) labels, one per item.
        sys_labels: array-like of system labels, same length as ``ref_labels``.

    Returns:
        ``(cmatrix, ref_classes, sys_classes)``:
        - ``cmatrix`` is a ``scipy.sparse.coo_matrix`` of integers with shape
          ``(len(ref_classes), len(sys_classes))``. Entry ``(i, j)`` counts
          items labelled ``ref_classes[i]`` in the reference and
          ``sys_classes[j]`` by the system. The COO form stores one ``1`` per
          item, and duplicates are summed on conversion (e.g. ``toarray``,
          ``tocsr``).
        - ``ref_classes`` and ``sys_classes`` are the sorted unique labels
          from ``np.unique``.

    Raises:
        ValueError: if the two label sequences differ in length.
    """
    # Accept lists/tuples too: the original relied on ``.size``, which only
    # ndarrays have.
    ref_labels = np.asarray(ref_labels)
    sys_labels = np.asarray(sys_labels)
    if ref_labels.size != sys_labels.size:
        raise ValueError(
            "ref_labels and sys_labels must have the same length "
            "({} != {})".format(ref_labels.size, sys_labels.size))

    ref_classes, ref_class_inds = np.unique(ref_labels, return_inverse=True)
    sys_classes, sys_class_inds = np.unique(sys_labels, return_inverse=True)
    n_frames = ref_labels.size
    cmatrix = coo_matrix(
        (np.ones(n_frames), (ref_class_inds, sys_class_inds)),
        shape=(ref_classes.size, sys_classes.size),
        # ``np.int`` was removed in NumPy 1.24; it was an alias of builtin int.
        dtype=int)
    return cmatrix, ref_classes, sys_classes


def bcubed_sparse(ref_labels, sys_labels, cm=None):
    """Compute B-cubed precision, recall and F1 from a sparse contingency matrix.

    Rows of the contingency matrix index reference classes and columns index
    system clusters, matching the layout ``contingency_matrix`` returns. It is
    built from the labels unless ``cm`` is supplied.

    Each cell's joint probability (count / total) is weighted by:
    - its share of its system cluster (column-normalized), giving precision;
    - its share of its reference class (row-normalized), giving recall.

    Args:
        ref_labels: reference labels (anything ``np.array`` accepts).
        sys_labels: system labels, aligned with ``ref_labels``.
        cm: optional precomputed sparse contingency matrix.

    Returns:
        ``(precision, recall, f1)``.
    """
    ref = np.array(ref_labels)
    hyp = np.array(sys_labels)
    counts = contingency_matrix(ref, hyp)[0] if cm is None else cm
    counts = counts.astype('float64')

    joint = counts / counts.sum()
    per_class = normalize(counts, norm="l1", axis=1)    # rows sum to 1
    per_cluster = normalize(counts, norm='l1', axis=0)  # columns sum to 1

    precision = np.sum(joint.multiply(per_cluster))
    recall = np.sum(joint.multiply(per_class))
    f1 = 2 * (precision * recall) / (precision + recall)
    return precision, recall, f1

# Best MEAN score seen so far across calls to evaluate_embedding; beating it
# triggers a checkpoint save.
best_avg_score = 0


def _collect_embeddings(model, dataloader, args):
    """Encode every batch from *dataloader* and return ``(labels, embeddings)``.

    Per-batch tensors are gathered in lists and concatenated once, which avoids
    the quadratic cost of re-concatenating a growing tensor on every batch.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    label_chunks = []
    embedding_chunks = []
    with torch.no_grad():
        for batch in dataloader:
            features, _ = prepare_task_input(model, batch, args)
            embeddings = model.get_sentence_encoding(features)
            label_chunks.append(batch['label'])
            embedding_chunks.append(embeddings.detach())
    if not embedding_chunks:
        raise ValueError("evaluation dataloader yielded no batches")
    return torch.cat(label_chunks, dim=0), torch.cat(embedding_chunks, dim=0)


def evaluate_embedding(model, table, args, step):
    """Cluster the model's sentence embeddings with KMeans and report metrics.

    The data comes from ``train_unshuffle_loader(args)``. The embeddings are
    clustered into ``args.num_classes`` groups, and the clusters are scored
    against the gold labels with:
    - B-cubed precision, recall and F1;
    - homogeneity, completeness and V-measure;
    - ARI and NMI.

    One row is appended to *table* and printed. The MEAN score is the average
    of B-cubed F1, V-measure and ARI. When it beats the best seen so far, the
    whole model is saved to ``'model.pt'``.

    Args:
        model: encoder with ``tokenizer``, ``get_sentence_encoding`` and the
            ``eval``/``train`` mode API. It is put in eval mode during
            evaluation, and its previous mode is restored afterwards.
        table: table object receiving the row via ``add_row``. Presumably a
            PrettyTable with a ``"MEAN"`` column (``get_string(sortby="MEAN")``
            is used) -- TODO confirm.
        args: namespace providing at least ``num_classes`` and ``seed``, plus
            whatever the loader and tokenization read.
        step: unused.

    Returns:
        None.

    Raises:
        ValueError: if the evaluation loader yields no batches.
    """
    global best_avg_score
    was_training = model.training
    model.eval()
    try:
        dataloader = train_unshuffle_loader(args)
        print('---- {} evaluation batches ----'.format(len(dataloader)))
        all_labels, all_embeddings = _collect_embeddings(model, dataloader, args)

        kmeans = cluster.KMeans(n_clusters=args.num_classes, random_state=args.seed)
        kmeans.fit(all_embeddings.cpu().numpy())
        # ``np.int`` was removed in NumPy 1.24; the builtin int is what it aliased.
        pred_labels = kmeans.labels_.astype(int)
        true_labels = all_labels.cpu().numpy()

        # NMI, ARI, V_measure
        from sklearn.metrics import (
            adjusted_rand_score,
            completeness_score,
            homogeneity_score,
            normalized_mutual_info_score,
            v_measure_score,
        )
        nmi = normalized_mutual_info_score(true_labels, pred_labels)
        ari = adjusted_rand_score(true_labels, pred_labels)
        hom = homogeneity_score(true_labels, pred_labels)
        com = completeness_score(true_labels, pred_labels)
        v_f1 = v_measure_score(true_labels, pred_labels)
        b3_prec, b3_rec, b3_f1 = bcubed_sparse(true_labels, pred_labels)
        # Arithmetic mean of the three headline scores.
        avg_score = round((b3_f1 + v_f1 + ari) / 3, 4)

        table.add_row([b3_prec, b3_rec, b3_f1, hom, com, v_f1, ari, nmi, avg_score])
        print(table)
        if avg_score > best_avg_score:
            best_avg_score = avg_score
            # Saved while still in eval mode, as before.
            torch.save(model, 'model.pt')
            print(table.get_string(sortby="MEAN", reversesort=True))
            print("best avg score:{}".format(best_avg_score))
    finally:
        # Don't leave a model that was training stuck in eval mode
        # (e.g. with dropout disabled) after a mid-training evaluation.
        model.train(was_training)
    return None






