from sklearn.metrics import precision_recall_curve
from pathlib import Path
import copy
import time
import sklearn
from sklearn.metrics import classification_report
#from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
import torch
from pytorch_models import LogisticRegression, pytorch_sigmoid, numpy_sigmoid
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import random
import tqdm
import numpy as np
import argparse
import ipdb

from sampling import sample_noise

from data_utils import make_data_split, make_data_split_3way, MainDataset, convert_str_to_array

# Limit PyTorch to one thread for intra-op CPU parallelism and report the
# resulting setting so it shows up in the run log.
torch.set_num_threads(1)
print("Using CPU threads:", torch.get_num_threads())


def OneHot(data, label):
    """Build a 0/1 indicator matrix comparing ``data`` against each label.

    For a 1-D NumPy array ``data``, the result has one row per element of
    ``data`` and one column per entry of ``label``; entry ``[i, j]`` is 1
    when ``data[i] == label[j]`` and 0 otherwise.
    """
    masks = []
    for cls in label:
        # One boolean mask per class; stacked as rows, transposed below.
        masks.append(data == cls)
    return np.transpose(masks).astype(int)


def classification_with_thresholds(y_true, y_score, n_steps=20, fix_thresh=None):
    """Compute binary confusion counts and F1 for a set of score thresholds.

    A sample is predicted positive when its score is strictly greater than
    the threshold. Only labels equal to 1 (positive) or 0 (negative) are
    counted; any other label value is ignored.

    Args:
        y_true: sequence of ground-truth labels (0 or 1).
        y_score: sequence or array of scores, one per label.
        n_steps: when ``fix_thresh`` is None, thresholds are the
            ``n_steps + 1`` evenly spaced values from 0 to 1 inclusive.
        fix_thresh: optional threshold or sequence of thresholds to use
            instead of the evenly spaced grid.

    Returns:
        dict mapping each threshold to
        ``{"f1", "tp", "fp", "tn", "fn"}``. F1 is reported as 0.0 when
        ``tp + fp + fn == 0`` (no positives and no positive predictions)
        instead of raising ``ZeroDivisionError``.
    """
    if fix_thresh is None:
        thresholds = np.linspace(0, 1, n_steps + 1)
    else:
        # atleast_1d lets a bare scalar threshold be iterated as well.
        thresholds = np.atleast_1d(fix_thresh)

    # Coerce to flat arrays so plain Python lists are accepted too.
    labels = np.asarray(y_true).ravel()
    scores = np.asarray(y_score).ravel()
    assert len(labels) == len(scores), "y_true and y_score must have the same length"

    is_pos = labels == 1
    is_neg = labels == 0

    output = {}
    for thresh in thresholds:
        pred = scores > thresh
        tp = int(np.count_nonzero(is_pos & pred))
        fn = int(np.count_nonzero(is_pos & ~pred))
        fp = int(np.count_nonzero(is_neg & pred))
        tn = int(np.count_nonzero(is_neg & ~pred))

        # F1 = tp / (tp + 1/2 * (fp + fn)); undefined when the denominator is 0.
        denom = tp + 0.5 * (fp + fn)
        f1 = tp / denom if denom > 0 else 0.0

        output[thresh] = {"f1": f1,
                          "tp": tp,
                          "fp": fp,
                          "tn": tn,
                          "fn": fn,
                         }
    return output

def evaluate(weights, data_loader, fix_thresh=None, perturb=0):
    """Score the linear one-vs-rest classifier on every batch of a loader.

    Each feature batch is optionally perturbed with ``sample_noise``,
    extended with a constant column of ones (bias term), multiplied by the
    transposed ``weights`` and passed through ``numpy_sigmoid``. The
    predicted class of a row is the argmax over its per-class scores.

    Args:
        weights: per-class weight rows; the last column multiplies the
            appended column of ones.
        data_loader: iterable yielding ``(z, y, text)`` batches, where ``z``
            supports ``.numpy()`` and ``y`` supports ``.tolist()``.
        fix_thresh: accepted for call compatibility; not used by the
            current multiclass evaluation.
        perturb: when > 0, noise drawn with
            ``sample_noise(768, 1/perturb, batch_rows, 768)`` is added to
            the features before scoring.
            NOTE(review): 768 is hard-coded, presumably the feature width -
            confirm against ``sample_noise``.

    Returns:
        dict with
        ``'test_accuracy'``: macro-averaged F1 (despite the key name),
        ``'auc'``: mean of the per-class one-vs-rest ROC AUCs,
        ``'best_f1'``: ``(0, macro_f1)`` (threshold slot fixed to 0).
    """
    y_test = []
    y_pred = []
    score_batches = []
    for z, y, _text in data_loader:
        y_test.extend(y.tolist())

        z = z.numpy()
        if perturb > 0:
            z = z + sample_noise(768, 1/perturb, z.shape[0], 768)

        # Append the bias column so the last weight acts as the intercept.
        bias_col = np.ones((z.shape[0], 1)).astype(np.float32)
        z = np.concatenate([z, bias_col], axis=1)

        scores = numpy_sigmoid(np.matmul(z, np.transpose(weights)))
        y_pred.extend(np.argmax(scores, axis=1).tolist())  # multiclass
        score_batches.append(scores)

    y_score = np.concatenate(score_batches, 0)
    macro_f1 = sklearn.metrics.f1_score(y_test, y_pred, average='macro')

    # One-vs-rest ROC AUC per class. The class count comes from the score
    # matrix so it always matches the number of weight rows.
    y_true = np.array(y_test)
    class_aucs = []
    for class_id in range(y_score.shape[1]):
        fpr, tpr, _thresholds = sklearn.metrics.roc_curve(
            y_true == class_id, y_score[:, class_id], pos_label=1)  # Order: y_true, y_score
        class_aucs.append(sklearn.metrics.auc(fpr, tpr))

    return {'test_accuracy': macro_f1,
            'auc': np.mean(class_aucs),
            'best_f1': (0, macro_f1),
            }

def train(train_loader, eval_loader, weights, v_weights, gamma, lr_rate, epochs, perturb=0):
    """Train one-vs-rest logistic weights in place and evaluate after each epoch.

    For every batch the features are optionally perturbed, extended with a
    constant column of ones, and for each class ``c`` multiplied row-wise by
    the +1/-1 target (``+1`` where the label equals ``c``). With ``zc`` the
    signed features and ``v = v_weights[c]``, each class is updated as::

        step         = lr_rate * zc.T @ numpy_sigmoid(-(zc @ v)) / n_rows
        weights[c]   = v + step
        v_weights[c] = (1 - gamma) * weights[c] + gamma * previous weights[c]

    Args:
        train_loader: iterable yielding ``(z, y, text)`` batches; ``z`` and
            ``y`` must support ``.numpy()``.
        eval_loader: loader passed to ``evaluate`` after every epoch.
        weights: array with one weight row per class; updated IN PLACE.
        v_weights: auxiliary array of the same shape; updated IN PLACE.
        gamma: mixing coefficient between the new and previous weights.
        lr_rate: step size applied to the batch-averaged gradient term.
        epochs: number of passes over ``train_loader`` (cast to int).
        perturb: when > 0, noise drawn with
            ``sample_noise(768, 1/perturb, batch_rows, 768)`` is added to
            the features.
            NOTE(review): 768 is hard-coded, presumably the feature width -
            confirm against ``sample_noise``.

    Returns:
        ``(results_log, elapsed)``: a dict mapping epoch index to the
        ``evaluate`` result for that epoch, and the wall-clock seconds spent.
    """
    start = time.time()
    results_log = {}

    # Classes are labelled 0..K-1, one per weight row, so the label set
    # always agrees with the shape of ``weights``.
    num_classes = len(weights)
    class_ids = np.arange(num_classes)

    for epoch in range(int(epochs)):
        for z, y, _text in train_loader:
            z = z.numpy()

            # perturb
            if perturb > 0:
                z = z + sample_noise(768, 1/perturb, z.shape[0], 768)

            # Bias column: the last weight acts as the intercept.
            z = np.concatenate([z, np.ones((z.shape[0], 1)).astype(np.float32)], axis=1)

            # +1 where the row belongs to the class, -1 otherwise: (rows, classes).
            signs = 2 * OneHot(y.numpy(), class_ids) - 1
            # signed[c, i, :] = z[i, :] * signs[i, c]  -> (classes, rows, cols)
            signed = signs.T[:, :, np.newaxis] * z[np.newaxis, :, :]

            for c in range(num_classes):
                z_c = signed[c]
                # Copy: weights[c] is overwritten before its old value is mixed in.
                w_prev = weights[c].copy()
                v = v_weights[c]

                # Compute grad
                grad = numpy_sigmoid(-np.matmul(z_c, v))
                grad = lr_rate * np.matmul(np.transpose(z_c), grad) / z_c.shape[0]

                # Update
                weights[c] = v + grad
                v_weights[c] = (1 - gamma) * weights[c] + gamma * w_prev

        results_log[epoch] = evaluate(weights, eval_loader, perturb=perturb)

    elapsed = time.time() - start
    return results_log, elapsed


def model_setup(lr_rate, args):
    """Create the randomly initialised weight matrices for training.

    Two independent float32 matrices of shape
    ``(args.num_class, args.num_cols)`` are drawn from ``np.random.rand``
    (uniform on [0, 1)) and divided by ``sqrt(args.num_cols)``. The first
    draw becomes ``weights``, the second ``v_weights``.

    ``lr_rate`` is accepted but not used here.

    Returns:
        ``(weights, v_weights)``
    """
    n_rows, n_cols = args.num_class, args.num_cols
    scale = np.sqrt(n_cols)

    def _draw():
        # Each call consumes the global NumPy RNG, so call order matters.
        return (np.random.rand(n_rows, n_cols) / scale).astype(np.float32)

    weights = _draw()
    v_weights = _draw()
    return weights, v_weights

def load_fixed_data(batch_size, data_dir="/home/XXXXX/workspace/privacy_project/data/snips"):
    """Load the pre-split train/val/test data and wrap each split in a DataLoader.

    For every split ``<split>`` in ``train``, ``val``, ``test`` two files
    are read from ``data_dir``:

    * ``<split>.csv``: one vector per line, parsed by ``convert_str_to_array``;
    * ``<split>_xy.txt``: one ``label<TAB>text`` pair per line, with an
      integer label.

    Args:
        batch_size: batch size for every returned loader.
        data_dir: directory holding the split files. Defaults to the snips
            data set; a "3way" directory exists next to it as an alternative
            (/home/XXXXX/workspace/privacy_project/data/3way).

    Returns:
        list ``[train_loader, val_loader, test_loader]`` of non-shuffling
        ``torch.utils.data.DataLoader`` objects.
    """
    data_dir = Path(data_dir)
    loaders = []

    for data_split in ('train', 'val', 'test'):
        # Context managers make sure both files are closed after reading.
        with open(data_dir / f"{data_split}.csv") as vec_file:
            X_list = [convert_str_to_array(line) for line in vec_file]

        y_list = []
        text_list = []
        with open(data_dir / f"{data_split}_xy.txt") as xy_file:
            for item in xy_file:
                # Split on the first tab only so tabs inside the text survive.
                label, text = item.rstrip('\n').split('\t', 1)
                y_list.append(int(label))
                text_list.append(text)

        dataset = MainDataset(X_list, y_list, text_list)
        loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=0)
        loaders.append(loader)
    return loaders

def run(args):
    """Run the (single-point) hyper-parameter grid and report the best result.

    For every grid combination the RNGs are seeded from ``args.seed``, the
    fixed data splits are loaded, fresh weights are created with
    ``model_setup`` and trained with ``train`` while evaluating on the
    validation loader after each epoch. The epoch with the highest
    ``'best_f1'`` score is selected (earliest epoch wins ties); whenever it
    beats the best score seen so far, the model is also evaluated on the
    test loader and both results are printed.

    Only ``args.seed``, ``args.num_class``/``args.num_cols`` (through
    ``model_setup``) and ``args.perturb`` are read; batch size, eta,
    learning rate and epoch count come from the hard-coded grids below.

    NOTE(review): the test evaluation uses ``weights`` as left by the final
    training epoch, which is not necessarily the best validation epoch that
    is reported, because ``train`` updates the weights in place without
    keeping snapshots.
    """
    curr_best_f1 = -1

    # Gridsearch hparams
    for seed in [args.seed]:
        args.seed = seed
        # Fix seeds
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

        for batch_size in [128]:
            train_loader, val_loader, test_loader = load_fixed_data(batch_size)
            for eta in [0.1]:
                for lr_rate in [3.0]:
                    for epochs in [10]:
                        print(f"Seed {seed}, LR {lr_rate}, Epochs {epochs}, ETA {eta}, Batch size {batch_size}")
                        weights, v_weights = model_setup(lr_rate, args)

                        results_log, _elapsed = train(train_loader, val_loader, weights, v_weights, eta, lr_rate, epochs, perturb=args.perturb)

                        # max() returns the first maximal entry, i.e. the
                        # earliest epoch among ties.
                        best_epoch = max(results_log.items(), key=lambda item: item[1]['best_f1'][1])
                        opt_f1 = best_epoch[1]['best_f1'][1]

                        if opt_f1 > curr_best_f1:
                            curr_best_f1 = opt_f1
                            opt_thresh = best_epoch[1]['best_f1'][0]
                            print(f"\nNew best F1: {curr_best_f1} - Seed {seed}, LR {lr_rate}, Gamma {eta}, Epochs {epochs}, Batch Size {batch_size}, N Train {len(train_loader.dataset)}")
                            print(f"Evaluation on val: {best_epoch}")

                            # Inference on test
                            inference_result_dict = evaluate(weights, test_loader, fix_thresh=[opt_thresh], perturb=args.perturb)
                            print(f"Inference on test: {inference_result_dict}")

            # Drop the loaders before a next batch size would load new ones.
            del train_loader, val_loader, test_loader

if __name__=="__main__":
    # Command-line options as (flag, type, default, help) rows, registered
    # in this order. A help of None means "no help text", which is also
    # argparse's default.
    cli_options = [
        ('--all_vectors', str, './sentence_vectors.csv', 'path to training data csv'),
        ('--all_labels', str, './xy.txt', 'path to training data info'),
        ('--seed', int, 1, None),
        ('--num_class', int, 7, None),
        ('--num_cols', int, 769, None),
        ('--test_ratio', float, 0.15, None),
        ('--val_ratio', float, 0.1, None),
        ('--lr', float, 0.5, None),
        ('--perturb', float, 0.0, None),
        ('--n_epochs', int, 40, 'num training epochs'),
    ]

    parser = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    # Parse the command line and hand the namespace to the experiment driver.
    run(parser.parse_args())
