# -*- coding: utf-8 -*-

import csv
import random
import numpy as np

import original_text
import mod_text
from nlp_safety_prop.models.embeddings.roberta import RobertaModelPreTrained
from nlp_safety_prop.utils.paths import get_data_path


def generate_datasets(datafolder: str, roberta: RobertaModelPreTrained, first_mod: int = 4):
    """Score the original reviews and their modified variants, saving each to CSV.

    Reads ``original_rt_snippets.txt`` from *datafolder*, scores it with
    *roberta* and writes ``scores_orig.csv``. Then, for each modifier row read
    from ``mod_list.csv`` (starting at index *first_mod*), wraps every review
    with that modifier's prefix/suffix, scores the result and writes
    ``scores_mod_<id>.csv``.

    Args:
        datafolder: directory holding the inputs; outputs are written there
            too. Paths are built by plain string concatenation, so it must end
            with a path separator.
        roberta: model wrapper whose ``compute_all_scores(texts)`` returns one
            score row per text.
        first_mod: index of the first modifier to (re)score. Earlier modifiers
            are skipped. Defaults to 4 to keep the previous behaviour.

    Returns:
        The ids of *all* modifiers in ``mod_list.csv``, including skipped ones.
    """
    # read the input dataset
    reviews = original_text.extract_from_file(datafolder + "original_rt_snippets.txt")

    # score the unmodified reviews with the same wrapper used for the
    # modified variants below
    scores = roberta.compute_all_scores(reviews)
    np.savetxt(datafolder + "scores_orig.csv", scores, delimiter=",")

    # each modifier row holds (id, prefix, suffix) in its first three fields
    mods = mod_text.read_mod_from_file(datafolder + "mod_list.csv")

    # NOTE(review): modifiers before ``first_mod`` are not re-scored, but their
    # ids are still returned below, so their scores_mod_<id>.csv files must
    # already exist from an earlier run. Pass first_mod=0 for a fresh run.
    for mod in mods[first_mod:]:
        mod_id, prefix, suffix = mod[0], mod[1], mod[2]

        # generate modified inputs
        mod_reviews = mod_text.concatenate_mod_to_all(reviews, prefix, suffix)

        # compute and save the scores for this modified dataset
        scores = roberta.compute_all_scores(mod_reviews)
        np.savetxt(f"{datafolder}scores_mod_{mod_id}.csv", scores, delimiter=",")

    # return a list with all mod ids
    return [mod[0] for mod in mods]


def check_property(scores):
    """Count, per entry, the pairs whose relative order is the same in every context.

    Args:
        scores: sequence of score vectors, one per context (e.g. original and
            modified inputs), all indexed by the same entries. It is converted
            to a float array of shape (n_contexts, n_entries).

    Returns:
        ``(n_safe, safe_ratio)``.

        ``n_safe[a]`` is the number of partners ``b != a`` for which
        ``sign(scores[c][a] - scores[c][b])`` is identical across all contexts
        ``c``. Ties count as a consistent order; NaN never does.

        ``safe_ratio`` is the fraction of unordered pairs that are safe, or
        NaN when there are fewer than two entries (no pairs exist).
    """
    table = np.asarray(scores, dtype=float)
    n_entries = table.shape[1]
    n_safe = np.zeros(n_entries)

    for a in range(n_entries - 1):
        # order[c, j] = sign(scores[c][a] - scores[c][a + 1 + j]) for every context c
        order = np.sign(table[:, a:a + 1] - table[:, a + 1:])
        # a pair is safe when every context agrees with the first one
        safe = np.all(order == order[0], axis=0)
        n_safe[a] += np.count_nonzero(safe)
        n_safe[a + 1:] += safe

    # sum(n_safe) counts each safe pair twice (once per endpoint), which
    # matches the ordered-pair count n * (n - 1) used as the denominator
    n_pairs = n_entries * (n_entries - 1)
    safe_ratio = np.sum(n_safe) / n_pairs if n_pairs else np.nan

    return (n_safe, safe_ratio)

def safety_comparison(datafolder, mod_ids):
    """Compare the original scores against each modifier and save per-review counts.

    For every id in *mod_ids*, loads ``scores_orig.csv`` and
    ``scores_mod_<id>.csv`` from *datafolder*, checks pair-order consistency
    between the two with ``check_property``, prints the overall safe ratio,
    and writes the per-review safe-pair counts to ``safety_mod_<id>.csv``.

    Args:
        datafolder: directory holding the score files; must end with a path
            separator because paths are built by string concatenation.
        mod_ids: modifier ids whose score files were already written.
    """
    # load original scores once; they are shared by every comparison
    orig_scores = np.loadtxt(datafolder + "scores_orig.csv", delimiter=",")

    for mod_id in mod_ids:
        # load one modifier's scores at a time instead of keeping all in memory
        mod_scores = np.loadtxt(f"{datafolder}scores_mod_{mod_id}.csv", delimiter=",")

        n_safe, safe_ratio = check_property([orig_scores, mod_scores])

        print("> run.py: mod id", mod_id)
        print("> run.py: overall safety ratio of", safe_ratio)

        np.savetxt(f"{datafolder}safety_mod_{mod_id}.csv", n_safe, delimiter=",")

def unstable_inputs(datafolder, mod_ids, sample_size=1024):
    """Rank reviews from least to most order-stable across all modifiers.

    Sums the per-review safe-pair counts stored in ``safety_mod_<id>.csv``
    over *mod_ids*, then normalises them by the number of pairs each review
    takes part in: ``n_entries - 1`` per modifier. Three files are written to
    *datafolder*, all with rows in ascending order of that fraction:

    * ``sorted_inputs.csv``: tab-separated, with a header row. Columns are
      review id, original score, safe-pair fraction and review text.
    * ``correlation.csv``: ``score,fraction`` rows.
    * ``corr_sampled.csv``: a random subset of at most *sample_size* rows of
      ``correlation.csv``.

    Args:
        datafolder: directory holding the inputs; must end with a path
            separator because paths are built by string concatenation.
        mod_ids: ids whose ``safety_mod_<id>.csv`` files already exist.
        sample_size: maximum number of rows in ``corr_sampled.csv``. It is
            clamped to the number of reviews.

    Raises:
        ValueError: if *mod_ids* is empty.
    """
    if not mod_ids:
        raise ValueError("mod_ids must contain at least one modifier id")

    # total safe-pair count per review, summed over all modifiers
    n_safe = sum(
        np.loadtxt(f"{datafolder}safety_mod_{mod_id}.csv", delimiter=",")
        for mod_id in mod_ids
    )

    n_entries = len(n_safe)
    # each review is paired with every other review once per modifier,
    # so this yields a fraction in [0, 1]
    n_safe = n_safe / (len(mod_ids) * (n_entries - 1))

    ascending_ids = np.argsort(n_safe)

    reviews = original_text.extract_from_file(datafolder + "original_rt_snippets.txt")
    orig_scores = np.loadtxt(datafolder + "scores_orig.csv", delimiter=",")

    # NOTE: the third column holds a fraction in [0, 1] despite its header name
    sorted_inputs = [[i, orig_scores[i], n_safe[i], reviews[i]] for i in ascending_ids]
    _write_rows(
        datafolder + "sorted_inputs.csv",
        sorted_inputs,
        delimiter="\t",
        header=['review id', 'positive sentiment log-odds', 'percentage of safe pairs', 'review text'],
    )

    corr = [[orig_scores[i], n_safe[i]] for i in ascending_ids]
    _write_rows(datafolder + "correlation.csv", corr, delimiter=",")

    # clamp so datasets smaller than sample_size don't make random.sample raise
    picked = random.sample(range(n_entries), min(sample_size, n_entries))
    _write_rows(datafolder + "corr_sampled.csv", [corr[i] for i in picked], delimiter=",")


def _write_rows(path, rows, delimiter, header=None):
    """Write *rows* to *path* as UTF-8 delimited text, preceded by *header* if given."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter=delimiter)
        if header is not None:
            writer.writerow(header)
        writer.writerows(rows)

def main():
    """Run the full pipeline: score inputs, check the safety property, rank unstable inputs."""
    data_dir = get_data_path("raw/sentiment_invariance/")

    print("> run.py: loading ML model")
    model = RobertaModelPreTrained("siebert/sentiment-roberta-large-english")

    print("> run.py: generating scores")
    ids = generate_datasets(data_dir, model)

    # each later stage reads files that the earlier stages wrote into data_dir
    for banner, stage in (
        ("> run.py: checking safety property", safety_comparison),
        ("> run.py: sorting unstable inputs", unstable_inputs),
    ):
        print(banner)
        stage(data_dir, ids)


if __name__ == "__main__":
    main()
