# -*- coding: utf-8 -*-

import csv
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from extract_dataset import generate_samples

import numpy as np
from sklearn.linear_model import SGDClassifier

from roberta_embeddings import extract_dataset_features


def linear_scores(X, y):
    """Fit a linear probe on (X, y) and evaluate it on the same data.

    The probe is a scikit-learn ``SGDClassifier`` with default settings.
    No held-out split is used: both returned values are computed on the
    training data itself.

    Args:
        X: feature matrix passed to ``SGDClassifier.fit``.
        y: target labels, one per row of ``X``.

    Returns:
        A ``(scores, train_acc)`` tuple, where ``scores`` is the output of
        ``decision_function(X)`` (raw confidence scores, not normalised
        probabilities) and ``train_acc`` is the output of ``score(X, y)``.
    """
    # fit() returns the estimator itself, so construction and training chain
    probe = SGDClassifier().fit(X, y)

    # accuracy of the probe on the very data it was trained on
    training_accuracy = probe.score(X, y)

    # raw decision values of the probe for every training sample
    decision_values = probe.decision_function(X)

    return decision_values, training_accuracy


def insertion_scores(datafolder, contexts, insertions, suffix="_emb_pre.csv", rand=False):
    """Train the "leq" and "geq" insertion probes on stored embeddings.

    For every context index ``i`` the file
    ``<datafolder>features/c_<i><suffix>`` is loaded and all files are
    stacked row-wise (context-major order). Two binary linear probes are
    then trained on the stacked matrix via ``linear_scores``.

    Args:
        datafolder: path prefix of the data folder; it is concatenated
            directly with ``"features/c_"``, so it must end with a path
            separator.
        contexts: sequence of contexts; only its length is used here (one
            embedding file is expected per context).
        insertions: sequence of rows whose field at index 2 is compared
            with ``"leq"`` and ``"geq"`` to build the binary labels.
        suffix: file-name suffix selecting which embedding file to load.
        rand: if True, both label lists are shuffled in place (control
            experiment with randomised labels).

    Returns:
        A tuple ``(scores_leq, scores_geq, t_acc_leq, t_acc_geq)`` with
        the decision scores and the training accuracy of each probe.

    Raises:
        ValueError: if the number of loaded embedding rows differs from
            ``len(contexts) * len(insertions)``.
    """
    # load embeddings
    print(">> run.py: loading embeddings")
    embeddings = []
    for i in range(len(contexts)):
        filepath = datafolder + "features/c_" + str(i)
        # ndmin=2 keeps a one-row file as a (1, dim) matrix; without it
        # loadtxt squeezes the row to 1-D and the row-wise concatenation
        # below would no longer stack samples
        emb = np.loadtxt(filepath + suffix, delimiter=",", ndmin=2)
        embeddings.append(emb)
    embeddings = np.concatenate(embeddings, axis=0)

    # explicit check instead of assert, which is stripped under ``python -O``
    n_expected = len(contexts) * len(insertions)
    if embeddings.shape[0] != n_expected:
        raise ValueError(
            "Loaded {} embedding rows but expected {} "
            "({} contexts x {} insertions)".format(
                embeddings.shape[0], n_expected, len(contexts), len(insertions)))

    # compute binary ground-truth; the per-insertion labels are repeated once
    # per context, matching the context-major row order of the embeddings
    lab_leq = [ins[2] == "leq" for ins in insertions] * len(contexts)
    lab_geq = [ins[2] == "geq" for ins in insertions] * len(contexts)

    # activate this for a control experiment with randomised labels
    if rand:
        np.random.shuffle(lab_leq)
        np.random.shuffle(lab_geq)

    # train one linear probe per label set and collect its decision scores
    print(">> run.py: training linear classifiers")
    scores_leq, t_acc_leq = linear_scores(embeddings, lab_leq)
    scores_geq, t_acc_geq = linear_scores(embeddings, lab_geq)

    return (scores_leq, scores_geq, t_acc_leq, t_acc_geq)


def check_safety(context, s_leq, s_geq, s_entail):
    """Measure how often the insertion probe and the entailment score agree.

    For every pair of entries ``(a, b)`` the sign of the difference of the
    insertion scores is compared with the sign of the difference of the
    entailment scores. A pair is "safe" when both signs are equal (ties
    count as sign 0). Each entry is rated by the fraction of its pairs that
    are safe.

    Args:
        context: row whose field at index 2 is ``"up"`` or ``"down"`` and
            whose field at index 0 is used in the error message.
        s_leq: insertion scores used when the context is ``"up"``.
        s_geq: insertion scores used when the context is ``"down"``.
        s_entail: entailment scores; its length defines the number of
            entries.

    Returns:
        1-D float array with one safety ratio per entry. With fewer than
        two entries no pair exists, so every ratio is NaN.

    Raises:
        ValueError: if ``context[2]`` is neither ``"up"`` nor ``"down"``.
        IndexError: if the selected insertion scores have fewer entries
            than ``s_entail``.
    """
    # select the relevant insertion probe
    if context[2] == "up":
        s_insert = s_leq
    elif context[2] == "down":
        s_insert = s_geq
    else:
        raise ValueError("Unknown context monotonicity", context[2],
                         "for context", context[0])

    n_entries = len(s_entail)
    if n_entries < 2:
        # no pair can be formed, the ratio is 0/0: report NaN explicitly
        # rather than relying on a division that emits a RuntimeWarning
        return np.full(n_entries, np.nan)

    if len(s_insert) < n_entries:
        raise IndexError("insertion scores have {} entries, expected at least {}"
                         .format(len(s_insert), n_entries))

    insert = np.asarray(s_insert, dtype=float)[:n_entries]
    entail = np.asarray(s_entail, dtype=float)

    # compare entry a against all entries at once; the sign comparison is
    # symmetric in (a, b), so counting per row visits every pair of a
    n_safe = np.zeros(n_entries)
    for a in range(n_entries):
        agree = np.sign(insert[a] - insert) == np.sign(entail[a] - entail)
        agree[a] = False  # an entry is never paired with itself
        n_safe[a] = np.count_nonzero(agree)

    # every entry takes part in exactly n_entries - 1 pairs
    return n_safe / (n_entries - 1)


def print_sorted(filepath, items, scores):
    """Write ``items`` to a CSV file, ordered by ascending score.

    Each output row consists of the fields of one item followed by that
    item's score. The file at ``filepath`` is overwritten.

    Args:
        filepath: destination of the CSV file.
        items: indexable collection of rows (iterables of fields).
        scores: one score per item, in the same order as ``items``.
    """
    # indices that put the scores in ascending order
    order = np.argsort(scores)

    # assemble all rows before touching the output file
    rows = []
    for idx in order:
        rows.append(tuple(items[idx]) + (scores[idx],))

    with open(filepath, 'w', newline='') as handle:
        csv.writer(handle, delimiter=',').writerows(rows)


def _feature_prefix(datafolder, index):
    """Return the path prefix shared by all per-context files of one context."""
    return datafolder + "features/c_" + str(index)


def _read_rows(filepath):
    """Load every row of a comma-separated file as a list of field lists."""
    with open(filepath, 'r', newline='') as f:
        return list(csv.reader(f, delimiter=','))


def main():
    """Run the whole experiment, stage by stage.

    Stages (each one communicates with the next through CSV files under
    ``<datafolder>features/``):

    1. load the insertion and context tables;
    2. load the ``roberta-large-mnli`` checkpoint and store, per context,
       the two embedding matrices and the logits returned by
       ``extract_dataset_features``;
    3. train the insertion probes on each embedding layer and store their
       scores per context;
    4. reload the scores and rate every sample with ``check_safety``;
    5. aggregate the safety ratios per insertion, per context and overall.
    """
    import os  # local import: only needed to create the output folder

    datafolder = "../../data/raw/entailment_composition/"
    # (file tag, description used in the log) of the two probed layers
    layers = (("pre", "second to last layer"), ("last", "very last layer"))

    print("> run.py: loading raw data")
    insertions = _read_rows(datafolder + "insertions_clean.csv")
    contexts = _read_rows(datafolder + "contexts_clean.csv")
    n_ins = len(insertions)
    n_ctx = len(contexts)

    print("> run.py: loading ML model")
    checkpoint = "roberta-large-mnli"
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint, output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model.eval()

    print("> run.py: computing embeddings")
    # np.savetxt does not create missing folders
    os.makedirs(datafolder + "features", exist_ok=True)
    for i, context in enumerate(contexts):
        print(context[0])

        samples = generate_samples(context, insertions)

        (emb_pre, emb_last, logits) = extract_dataset_features(model, tokenizer, samples)

        prefix = _feature_prefix(datafolder, i)
        np.savetxt(prefix + "_emb_pre.csv", emb_pre, delimiter=",")
        np.savetxt(prefix + "_emb_last.csv", emb_last, delimiter=",")
        np.savetxt(prefix + "_logits.csv", logits, delimiter=",")

    for tag, layer_name in layers:
        print("> run.py: training probe (" + layer_name + ")")
        s_leq, s_geq, a_leq, a_geq = insertion_scores(datafolder,
                                                      contexts,
                                                      insertions,
                                                      suffix="_emb_" + tag + ".csv",
                                                      rand=False)
        # the probe scores are context-major: split them back per context
        for i in range(n_ctx):
            rows = slice(i * n_ins, (i + 1) * n_ins)
            prefix = _feature_prefix(datafolder, i)
            np.savetxt(prefix + "_s_leq_" + tag + ".csv", s_leq[rows], delimiter=",")
            np.savetxt(prefix + "_s_geq_" + tag + ".csv", s_geq[rows], delimiter=",")
        print("train_acc_leq (" + tag + "):", a_leq)
        print("train_acc_geq (" + tag + "):", a_geq)

    print("> run.py: loading scores")
    s_leq = {tag: [] for tag, _ in layers}
    s_geq = {tag: [] for tag, _ in layers}
    s_entail = []
    for i in range(n_ctx):
        prefix = _feature_prefix(datafolder, i)

        # ndmin keeps the array rank stable when a file holds a single row
        for tag, _ in layers:
            s_leq[tag].append(np.loadtxt(prefix + "_s_leq_" + tag + ".csv", delimiter=",", ndmin=1))
            s_geq[tag].append(np.loadtxt(prefix + "_s_geq_" + tag + ".csv", delimiter=",", ndmin=1))

        # ignore neutral entailment label: only logit columns 2 and 0 are
        # used (presumably entailment minus contradiction for this
        # checkpoint -- confirm against the model's label mapping)
        logits = np.loadtxt(prefix + "_logits.csv", delimiter=",", ndmin=2)
        s_entail.append(logits[:, 2] - logits[:, 0])

    print("> run.py: checking safety")
    for i, context in enumerate(contexts):
        print(">>", str(i+1), "-", context[0])

        prefix = _feature_prefix(datafolder, i)
        for tag, _ in layers:
            r_safe = check_safety(context, s_leq[tag][i], s_geq[tag][i], s_entail[i])
            np.savetxt(prefix + "_safety_" + tag + ".csv", r_safe, delimiter=",")

    print("> run.py: computing metrics")
    r_safe = {
        tag: [np.loadtxt(_feature_prefix(datafolder, i) + "_safety_" + tag + ".csv",
                         delimiter=",", ndmin=1)
              for i in range(n_ctx)]
        for tag, _ in layers
    }

    for tag, _ in layers:
        per_context = r_safe[tag]

        # safety of each insertion across contexts
        insertion_safety = np.mean(np.stack(per_context), axis=0)
        np.savetxt(datafolder + "insertion_safety_" + tag + ".csv", insertion_safety, delimiter=",")
        print_sorted(datafolder + "insertion_sorted_" + tag + ".csv", insertions, insertion_safety)

        # safety of each context across insertions
        context_safety = np.array([np.mean(r) for r in per_context])
        np.savetxt(datafolder + "context_safety_" + tag + ".csv", context_safety, delimiter=",")
        print_sorted(datafolder + "context_sorted_" + tag + ".csv", contexts, context_safety)

    # overall safety across the whole dataset
    for tag, _ in layers:
        print("overall_safety_" + tag + ":", np.mean(np.concatenate(r_safe[tag])))

    return


# Run the pipeline only when this file is executed as a script, so that
# importing the module does not trigger model loading or file output.
if __name__ == "__main__":
    main()
