import sys
import os

sys.path.append("..")

import json
import numpy as np
from random import Random

import tensorflow as tf
from datasets import load_dataset
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import matthews_corrcoef, f1_score, accuracy_score

from transformers import BertTokenizer

from utils.classification_utils import convertors as classification_convertors
from utils.glue_utils import glue_examples_to_tfdataset, glue_tasks_metrics, glue_tasks_num_labels
from utils.rationale_utils import eraser_to_token_level_rationale_data, num_label_tasks
from utils.rationale_utils import convertors as rationale_convertors

from directories import *

def sals_to_tfdataset(sals):
    """Expose an indexable collection of saliency matrices as a generator-backed dataset.

    Each element ``sals[k]`` is yielded in order as ``tf.float32`` with static shape
    ``(12, None)`` (first axis fixed at 12, second axis of unknown length).

    Args:
        sals: any object supporting ``len()`` and integer indexing (e.g. a list or ndarray).

    Returns:
        A ``tf.data.Dataset`` produced by ``tf.data.Dataset.from_generator``.
    """
    def _emit_rows():
        position = 0
        total = len(sals)  # read once per pass over the data
        while position < total:
            yield sals[position]
            position += 1

    element_shape = tf.TensorShape([12, None])
    return tf.data.Dataset.from_generator(_emit_rows, tf.float32, element_shape)

def get_best_epochs(dir):
    """Rank training epochs by their score on the dev set, worst first.

    The JSON log is expected to be a list with one entry per epoch. Each entry maps a
    dev-set name to a dict of ``metric -> score``. The first dev set and the first metric
    listed in the *first* entry are used to score every epoch.

    Args:
        dir: path to the JSON log file. Despite the name this is a file path; the name is
            kept for keyword-argument compatibility.

    Returns:
        ``np.ndarray`` of 0-based epoch indices sorted by ascending score. With a
        higher-is-better metric this runs from worst to best, so ``[-k:]`` gives the top k.
    """
    with open(dir, "r", encoding='utf-8') as log_file:
        epoch_logs = json.load(log_file)

    first_entry = epoch_logs[0]
    dev_set = list(first_entry.keys())[0]
    dev_metric = list(first_entry[dev_set].keys())[0]

    epoch_scores = [entry[dev_set][dev_metric] for entry in epoch_logs]
    return np.argsort(epoch_scores)


def load_imdb(tokenizer, max_length, training_batch_size, eval_batch_size, seed=42, include_sals=False):
    """Load IMDB as batched ``tf.data`` pipelines.

    The train split is shuffled with ``seed`` and converted once. Its first 4096 examples
    become the validation set and the remainder is used for training. The test split is
    used unshuffled.

    Args:
        tokenizer: tokenizer passed to the IMDB convertor. When ``include_sals`` is set,
            ``tokenizer.name_or_path`` selects the "bert" vs "albert" saliency directory.
        max_length: maximum sequence length for conversion and saliency truncation.
        training_batch_size: batch size of the train (and saliency) datasets.
        eval_batch_size: batch size of the validation and test datasets.
        seed: shuffle seed applied before the train/validation split.
        include_sals: if True, append a dataset of averaged saliencies that is batched
            like the training dataset.

    Returns:
        ``(datasets, info, metrics)``, followed by ``sal_dataset`` when ``include_sals`` is
        True:
          - ``datasets``: ``{"train", "evals": {"validation"}, "tests": {"test"}}``.
          - ``info``: train/val/test step and example counts, plus ``num_labels``.
          - ``metrics``: ``{"accuracy": accuracy_score}``.
    """
    num_eval_examples = 4096  # held-out slice taken from the shuffled train split

    train_data = load_dataset("imdb", split='train').shuffle(seed=seed)
    test_data = load_dataset("imdb", split='test')

    num_train_examples = len(train_data) - num_eval_examples
    num_test_examples = len(test_data)

    full_train_dataset = classification_convertors["imdb"](train_data, tokenizer, max_length=max_length).cache()
    eval_dataset = full_train_dataset.take(num_eval_examples).batch(eval_batch_size)
    train_dataset = full_train_dataset.skip(num_eval_examples).batch(training_batch_size)
    test_dataset = classification_convertors["imdb"](test_data, tokenizer, max_length=max_length).cache().batch(eval_batch_size)

    train_steps = int(np.ceil(num_train_examples / training_batch_size))
    eval_steps = int(np.ceil(num_eval_examples / eval_batch_size))
    test_steps = int(np.ceil(num_test_examples / eval_batch_size))

    outputs = (
        {
            "train": train_dataset,
            "evals": {"validation": eval_dataset},
            "tests": {"test": test_dataset}
        },
        {
            "train_steps": train_steps,
            "train_examples": num_train_examples,
            # Same extra keys as load_eraser_task; previously computed but discarded.
            "val_steps": eval_steps,
            "val_examples": num_eval_examples,
            "test_steps": test_steps,
            "test_examples": num_test_examples,
            "num_labels": train_data.features["label"].num_classes
        },
        {
            'accuracy': accuracy_score
        }
    )

    if include_sals:
        model_name = "bert" if tokenizer.name_or_path.startswith("bert") else "albert"
        saliencies_path = WORKER_SALIENCIES_PATH_to_READ("imdb", model_name)
        # argsort indices are 0-based; +1 maps them to the epoch numbers in the file names.
        # NOTE(review): the log path is always the BERT run, even for ALBERT. Confirm intended.
        best_epochs = get_best_epochs("./directory/bert/imdb/logs/ft_logs.json")[-3:] + 1

        # Only index 0 of axis 1 (12 entries) is ever returned, so accumulate just that
        # slice. The values are identical to summing all 12 and slicing afterwards, but
        # memory use is about 12x lower. mmap avoids reading the unused parts of each file.
        mean_sal = np.zeros((num_train_examples, max_length), dtype=np.float32)
        for epoch in best_epochs:
            loaded = np.load(saliencies_path + f"{epoch}.npy", mmap_mode="r")
            mean_sal += loaded[:num_train_examples, 0, :max_length]
        mean_sal /= len(best_epochs)

        sal_dataset = tf.data.Dataset.from_tensor_slices(mean_sal).batch(training_batch_size)
        outputs = outputs + (sal_dataset,)

    return outputs


def load_agnews(tokenizer, max_length, training_batch_size, eval_batch_size, seed=42, include_sals=False):
    """Load AG News as batched ``tf.data`` pipelines.

    The train split is shuffled with ``seed`` and converted once. Its first 6000 examples
    become the validation set and the remainder is used for training. The test split is
    used unshuffled.

    Args:
        tokenizer: tokenizer passed to the AG News convertor. When ``include_sals`` is set,
            ``tokenizer.name_or_path`` selects the "bert" vs "albert" saliency directory.
        max_length: maximum sequence length for conversion and saliency truncation.
        training_batch_size: batch size of the train (and saliency) datasets.
        eval_batch_size: batch size of the validation and test datasets.
        seed: shuffle seed applied before the train/validation split.
        include_sals: if True, append a dataset of averaged saliencies that is batched
            like the training dataset.

    Returns:
        ``(datasets, info, metrics)``, followed by ``sal_dataset`` when ``include_sals`` is
        True:
          - ``datasets``: ``{"train", "evals": {"validation"}, "tests": {"test"}}``.
          - ``info``: train/val/test step and example counts, plus ``num_labels``.
          - ``metrics``: ``{"accuracy": accuracy_score}``.
    """
    num_eval_examples = 6000  # held-out slice taken from the shuffled train split

    train_data = load_dataset("ag_news", split='train').shuffle(seed=seed)
    test_data = load_dataset("ag_news", split='test')

    num_train_examples = len(train_data) - num_eval_examples
    num_test_examples = len(test_data)

    full_train_dataset = classification_convertors["ag_news"](train_data, tokenizer, max_length=max_length).cache()
    eval_dataset = full_train_dataset.take(num_eval_examples).batch(eval_batch_size)
    train_dataset = full_train_dataset.skip(num_eval_examples).batch(training_batch_size)
    test_dataset = classification_convertors["ag_news"](test_data, tokenizer, max_length=max_length).cache().batch(eval_batch_size)

    train_steps = int(np.ceil(num_train_examples / training_batch_size))
    eval_steps = int(np.ceil(num_eval_examples / eval_batch_size))
    test_steps = int(np.ceil(num_test_examples / eval_batch_size))

    outputs = (
        {
            "train": train_dataset,
            "evals": {"validation": eval_dataset},
            "tests": {"test": test_dataset}
        },
        {
            "train_steps": train_steps,
            "train_examples": num_train_examples,
            # Same extra keys as load_eraser_task; previously computed but discarded.
            "val_steps": eval_steps,
            "val_examples": num_eval_examples,
            "test_steps": test_steps,
            "test_examples": num_test_examples,
            "num_labels": train_data.features["label"].num_classes
        },
        {
            'accuracy': accuracy_score
        }
    )

    if include_sals:
        model_name = "bert" if tokenizer.name_or_path.startswith("bert") else "albert"
        saliencies_path = WORKER_SALIENCIES_PATH_to_READ("ag_news", model_name)
        # argsort indices are 0-based; +1 maps them to the epoch numbers in the file names.
        # NOTE(review): the log path is always the BERT run, even for ALBERT. Confirm intended.
        best_epochs = get_best_epochs("./directory/bert/ag_news/logs/ft_logs.json")[-3:] + 1

        # Only index 0 of axis 1 (12 entries) is ever returned, so accumulate just that
        # slice. The values are identical to summing all 12 and slicing afterwards, but
        # memory use is about 12x lower. mmap avoids reading the unused parts of each file.
        mean_sal = np.zeros((num_train_examples, max_length), dtype=np.float32)
        for epoch in best_epochs:
            loaded = np.load(saliencies_path + f"{epoch}.npy", mmap_mode="r")
            mean_sal += loaded[:num_train_examples, 0, :max_length]
        mean_sal /= len(best_epochs)

        sal_dataset = tf.data.Dataset.from_tensor_slices(mean_sal).batch(training_batch_size)
        outputs = outputs + (sal_dataset,)

    return outputs

def load_eraser_task(task, tokenizer, max_length, training_batch_size, eval_batch_size, seed=42, include_sals=False):
    """Load an ERASER rationale task from ``../ERASER/eraserbenchmark/data/<task>``.

    The train examples are shuffled in place with ``Random(seed)``. All three splits are
    converted with ``rationale_convertors[task]``, cached, and batched with
    ``drop_remainder=True``, so every batch is full and step counts are floor divisions.

    Args:
        task: ERASER task directory name and key into ``rationale_convertors``.
        tokenizer: tokenizer passed to the convertor. When ``include_sals`` is set,
            ``tokenizer.name_or_path`` selects the "bert" vs "albert" saliency directory.
        max_length: maximum sequence length for conversion and saliency truncation.
        training_batch_size: batch size of the train (and saliency) datasets.
        eval_batch_size: batch size of the validation and test datasets.
        seed: seed for shuffling the train examples.
        include_sals: if True, append a dataset of averaged saliencies that is batched
            like the training dataset.

    Returns:
        ``(datasets, info, metrics)``, followed by ``sal_dataset`` when ``include_sals`` is
        True. ``info`` holds train/val/test step and example counts (examples = steps *
        batch size, i.e. excluding the dropped remainder) and ``num_labels`` = 2.
        ``metrics`` holds ``accuracy`` and ``f1``.
    """
    data_root = os.path.join('../ERASER/eraserbenchmark/data', task)
    data = eraser_to_token_level_rationale_data(data_root)
    train_data = data['train']
    eval_data = data['validation']
    test_data = data['test']

    Random(seed).shuffle(train_data)  # in place; deterministic for a given seed

    def _convert_and_batch(examples, batch_size):
        # Convert, cache, and batch; drop_remainder keeps every batch the same size.
        converted = rationale_convertors[task](examples, tokenizer, max_length=max_length).cache()
        return converted.batch(batch_size, drop_remainder=True)

    def _count_batches(dataset):
        # Stream through the dataset to count batches instead of materialising them all
        # in a list (the cardinality of generator-backed datasets is unknown).
        return sum(1 for _ in dataset)

    train_dataset = _convert_and_batch(train_data, training_batch_size)
    eval_dataset = _convert_and_batch(eval_data, eval_batch_size)
    test_dataset = _convert_and_batch(test_data, eval_batch_size)

    if task == "movie_rationales":
        # Hard-coded converted-example counts for this task skip a full pass over the data.
        # NOTE(review): these must match the convertor's output; re-check if either changes.
        train_steps = int(3178 // training_batch_size)
        eval_steps = int(388 // eval_batch_size)
        test_steps = int(423 // eval_batch_size)
    else:
        train_steps = _count_batches(train_dataset)
        eval_steps = _count_batches(eval_dataset)
        test_steps = _count_batches(test_dataset)
    print(f"train_steps={train_steps} eval_steps={eval_steps} test_steps={test_steps}")

    num_train_examples = train_steps * training_batch_size

    outputs = (
        {
            "train": train_dataset,
            "evals": {"validation": eval_dataset},
            "tests": {"test": test_dataset}
        },
        {
            "train_steps": train_steps,
            "train_examples": num_train_examples,
            "val_steps": eval_steps,
            "val_examples": eval_steps * eval_batch_size,
            "test_steps": test_steps,
            "test_examples": test_steps * eval_batch_size,
            "num_labels": 2
        },
        {
            'accuracy': accuracy_score,
            "f1": f1_score
        }
    )

    if include_sals:
        model_name = "bert" if tokenizer.name_or_path.startswith("bert") else "albert"
        saliencies_path = WORKER_SALIENCIES_PATH_to_READ(task, model_name)
        # argsort indices are 0-based; +1 maps them to the epoch numbers in the file names.
        # NOTE(review): the log path is always the BERT run, even for ALBERT. Confirm intended.
        best_epochs = get_best_epochs(f"./directory/bert/{task}/logs/ft_logs.json")[-3:] + 1

        # Only index 0 of axis 1 (12 entries) is ever returned, so accumulate just that
        # slice. The values are identical to summing all 12 and slicing afterwards, but
        # memory use is about 12x lower. mmap avoids reading the unused parts of each file.
        mean_sal = np.zeros((num_train_examples, max_length), dtype=np.float32)
        for epoch in best_epochs:
            loaded = np.load(saliencies_path + f"{epoch}.npy", mmap_mode="r")
            mean_sal += loaded[:num_train_examples, 0, :max_length]
        mean_sal /= len(best_epochs)

        sal_dataset = tf.data.Dataset.from_tensor_slices(mean_sal).batch(training_batch_size)
        outputs = outputs + (sal_dataset,)

    return outputs

def load_glue_task(task, tokenizer, max_length, training_batch_size, eval_batch_size, seed=42, include_sals=False):
    """Load a GLUE task as batched ``tf.data`` pipelines.

    The train split is shuffled with ``seed``. MNLI has no plain "validation"/"test"
    splits, so its matched and mismatched splits are returned under ``-m`` / ``-mm`` keys.

    Args:
        task: GLUE config name (e.g. "mnli"); also the key into ``glue_tasks_num_labels``
            and ``glue_tasks_metrics``.
        tokenizer: tokenizer passed to ``glue_examples_to_tfdataset``. When ``include_sals``
            is set, ``tokenizer.name_or_path`` selects the "bert" vs "albert" saliency directory.
        max_length: maximum sequence length for conversion and saliency truncation.
        training_batch_size: batch size of the train (and saliency) datasets.
        eval_batch_size: batch size of every validation and test dataset.
        seed: shuffle seed for the train split.
        include_sals: if True, append a dataset of averaged saliencies that is batched
            like the training dataset.

    Returns:
        ``(datasets, info, metrics)``, followed by ``sal_dataset`` when ``include_sals`` is
        True. ``datasets["evals"]`` / ``datasets["tests"]`` are keyed ``validation`` / ``test``,
        or ``validation-m``, ``validation-mm`` / ``test-m``, ``test-mm`` for MNLI.
    """
    # Output key -> HF split name.
    if task == "mnli":
        eval_splits = {"validation-m": "validation_matched", "validation-mm": "validation_mismatched"}
        test_splits = {"test-m": "test_matched", "test-mm": "test_mismatched"}
    else:
        eval_splits = {"validation": "validation"}
        test_splits = {"test": "test"}

    def _to_tfdataset(examples):
        # Single conversion path shared by every split.
        return glue_examples_to_tfdataset(examples, tokenizer, task, max_length=max_length).cache()

    def _load_eval_split(split):
        # Load, convert, and batch one validation/test split.
        return _to_tfdataset(load_dataset("glue", task, split=split)).batch(eval_batch_size)

    train_data = load_dataset("glue", task, split='train').shuffle(seed=seed)
    num_train_examples = len(train_data)
    train_dataset = _to_tfdataset(train_data).batch(training_batch_size)
    train_steps = int(np.ceil(num_train_examples / training_batch_size))

    outputs = (
        {
            "train": train_dataset,
            "evals": {name: _load_eval_split(split) for name, split in eval_splits.items()},
            "tests": {name: _load_eval_split(split) for name, split in test_splits.items()}
        },
        {
            "train_steps": train_steps,
            "train_examples": num_train_examples,
            "num_labels": glue_tasks_num_labels[task]
        },
        glue_tasks_metrics[task]
    )

    if include_sals:
        model_name = "bert" if tokenizer.name_or_path.startswith("bert") else "albert"
        saliencies_path = WORKER_SALIENCIES_PATH_to_READ(task, model_name)
        # argsort indices are 0-based; +1 maps them to the epoch numbers in the file names.
        # NOTE(review): the log path is always the BERT run, even for ALBERT. Confirm intended.
        best_epochs = get_best_epochs(f"./directory/bert/{task}/logs/ft_logs.json")[-3:] + 1

        # Only index 0 of axis 1 (12 entries) is ever returned, so accumulate just that
        # slice. The values are identical to summing all 12 and slicing afterwards, but
        # memory use is about 12x lower. mmap avoids reading the unused parts of each file.
        mean_sal = np.zeros((num_train_examples, max_length), dtype=np.float32)
        for epoch in best_epochs:
            loaded = np.load(saliencies_path + f"{epoch}.npy", mmap_mode="r")
            mean_sal += loaded[:num_train_examples, 0, :max_length]
        mean_sal /= len(best_epochs)

        sal_dataset = tf.data.Dataset.from_tensor_slices(mean_sal).batch(training_batch_size)
        outputs = outputs + (sal_dataset,)

    return outputs

def load_movie(*args, **kwargs):
    """Load the ERASER ``movie_rationales`` task via ``load_eraser_task``.

    Accepts the same positional and keyword arguments as ``load_eraser_task`` after its
    ``task`` parameter. Positional forwarding lets ``load_task["movie_rationales"]`` be
    called the same way as the other registry entries; previously only keyword calls worked.
    """
    return load_eraser_task("movie_rationales", *args, **kwargs)

# Registry mapping a dataset name to its loader function. load_glue_task is not listed
# here because it takes an extra leading `task` argument.
load_task = dict(
    imdb=load_imdb,
    ag_news=load_agnews,
    movie_rationales=load_movie,
)
