import os    
import tqdm
import jsonlines
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import roc_curve, auc

def Find_Optimal_Cutoff(TPR, FPR, threshold, label, predict):
    """
    Select the ROC operating point that maximises Youden's J (TPR - FPR).

    :param TPR: array-like, shape = [n_thresholds], true positive rates
    :param FPR: array-like, shape = [n_thresholds], false positive rates
    :param threshold: array-like, shape = [n_thresholds]
    :param label: unused; kept so existing callers keep working
    :param predict: unused; kept so existing callers keep working
    :return: (optimal_threshold, [fpr, tpr]) taken at the selected index
    :raises ValueError: if the inputs are empty or differ in shape
    """
    # Coerce to arrays so plain Python lists are accepted as well.
    tpr = np.asarray(TPR)
    fpr = np.asarray(FPR)
    cutoffs = np.asarray(threshold)

    if tpr.size == 0:
        raise ValueError("TPR, FPR and threshold must not be empty")
    if not (tpr.shape == fpr.shape == cutoffs.shape):
        raise ValueError("TPR, FPR and threshold must have the same shape")

    # np.argmax returns the first occurrence when several points tie.
    best = np.argmax(tpr - fpr)
    optimal_threshold = cutoffs[best]
    point = [fpr[best], tpr[best]]
    return optimal_threshold, point

def ROC(label, y_prob):
    """
    Build the ROC curve for a set of scores and locate its best cutoff.

    The curve comes from sklearn's ``roc_curve``, the area from ``auc`` and
    the cutoff from ``Find_Optimal_Cutoff``.

    :param label: (n, ) ground-truth labels
    :param y_prob: (n, ) predicted scores
    :return: fpr, tpr, thresholds, roc_auc, optimal_th, optimal_point
    """
    false_pos_rate, true_pos_rate, cutoffs = roc_curve(label, y_prob)
    best_cutoff, best_point = Find_Optimal_Cutoff(
        TPR=true_pos_rate,
        FPR=false_pos_rate,
        threshold=cutoffs,
        label=label,
        predict=y_prob,
    )
    area = auc(false_pos_rate, true_pos_rate)
    return false_pos_rate, true_pos_rate, cutoffs, area, best_cutoff, best_point

# Number of leading 'cci_scores' entries whose CTI scores are averaged
# per instance.
topk = 1
# Score transformation applied by expand(): 'original', 'square' or 'exp'.
strategy = 'original'


def expand(num, mode=None):
    """
    Transform a raw score according to the selected strategy.

    :param num: score (scalar or array-like accepted by numpy)
    :param mode: 'original', 'square' or 'exp'; defaults to the module-level
        ``strategy`` when omitted
    :return: ``num`` unchanged, ``num ** 2`` or ``np.exp(num)``
    :raises ValueError: if the strategy name is not one of the three above
    """
    if mode is None:
        mode = strategy
    if mode == 'original':
        return num
    if mode == 'square':
        return num ** 2
    if mode == 'exp':
        return np.exp(num)
    # A misspelled strategy must not silently fall back to the exponential.
    raise ValueError("Unknown strategy: {!r}".format(mode))

# Directory holding the record files and directory holding the per-instance
# result JSON files that are looked up for each record.
dir_path_input = './record-full/'
dir_path_output = './pecora-results-full/'

files = os.listdir(dir_path_input)

# Languages for which scores are collected and thresholds are reported.
langs = ['te', 'bn', 'ja', 'fi', 'ru']

# Per-language CTI thresholds (full dataset, top1). Alternatives kept for
# reference:
# mean - 0.5SD
#threshold = {'te': 1.3244373356358734, 'bn': 0.5818090328144909, 'ja': 0.7374812658326557, 'fi': 2.6087579173802933, 'ru': 1.3846934058095868}
# mean + 0SD
#threshold = {'te': 3.719334466382861, 'bn': 2.4173193188011646, 'ja': 3.4615943045867605, 'fi': 5.882648429386318, 'ru': 4.833288226388395}
# mean + 1SD (active)
threshold = {
    'te': 8.509128727876837,
    'bn': 6.088339890774511,
    'ja': 8.90982038209497,
    'fi': 12.430429453398368,
    'ru': 11.73047786754601,
}

# Filled by the evaluation loop with the newly derived threshold per language.
res_thres = {}

# Score every selected record file against the fixed per-language CTI
# thresholds, then derive new thresholds (mean + 1 SD) from the observed scores.
for fname in files:
    # Skip files whose name contains 'neg'; keep only those containing 'val'
    # (swap 'val' for 'train' to process the training split instead).
    if 'neg' in fname or 'val' not in fname:
        continue
    # The language code is the second '-'-separated field of the file stem.
    lang = fname.split('.')[0].split('-')[1]
    print('*****************************')
    print(lang)

    # Scores of this file, grouped by query language and by the 'ais' label.
    pos_ins = {code: [] for code in langs}
    neg_ins = {code: [] for code in langs}

    accuracy = 0
    accuracy_tot = 0

    with open(os.path.join(dir_path_input, fname), encoding='utf-8') as f:
        # idx is the position of the record inside the file; it also selects
        # the result file that belongs to the record.
        for idx, item in enumerate(jsonlines.Reader(f)):
            # NOTE: yes/no questions (item['prediction'] in ["yes", "no"])
            # are currently kept; skip them here if they should be excluded.
            save_path = os.path.join(
                dir_path_output, '{}-{}.json'.format(lang, idx))
            with open(save_path, encoding='utf-8') as r:
                res_pecora = json.load(r)

            # Average the transformed CTI score over the first `topk` entries.
            top_scores = [expand(entry['cti_score'])
                          for entry in res_pecora['cci_scores'][:topk]]
            if not top_scores:
                # Fail with the offending file instead of a bare
                # ZeroDivisionError.
                raise ValueError(
                    "No CTI scores to average in {}".format(save_path))
            score = sum(top_scores) / len(top_scores)

            bucket = pos_ins if item['ais'] else neg_ins
            bucket[item['query_language']].append(score)

            # Predict "attributable" when the score reaches the threshold of
            # the query language, and compare with the gold 'ais' label.
            threshold_CTI = threshold[item['query_language']]
            mark_attri = score >= threshold_CTI
            if mark_attri == item['ais']:
                accuracy += 1
            accuracy_tot += 1

    # An empty record file would otherwise divide by zero.
    ratio = accuracy / accuracy_tot if accuracy_tot else float('nan')
    print("Accuracy:    {}/{}={}".format(accuracy, accuracy_tot, ratio))
    print()

    for code in langs:
        all_scores = pos_ins[code] + neg_ins[code]
        if not all_scores:
            # No instance of this language in the current file: skip it so a
            # threshold computed from an earlier file is not replaced by NaN.
            continue

        print(code)
        print("pos_mean:    {}".format(np.mean(pos_ins[code])))
        print("neg_mean:    {}".format(np.mean(neg_ins[code])))
        print("pos_median:  {}".format(np.median(pos_ins[code])))
        print("neg_median:  {}".format(np.median(neg_ins[code])))

        print("all_mean:    {}".format(np.mean(all_scores)))
        print("all_median:  {}".format(np.median(all_scores)))

        # Use average + sd as threshold
        res_thres[code] = np.mean(all_scores) + 1 * np.std(all_scores)
        # Alternative: use the best ROC threshold instead, i.e.
        #   label = np.array([1] * len(pos_ins[code]) + [0] * len(neg_ins[code]))
        #   predict = np.array(all_scores)
        #   _, _, _, _, optimal_th, _ = ROC(label, predict)
        #   res_thres[code] = optimal_th
        print("=============")

print(res_thres)
