import json
import os
import shutil
from collections import defaultdict

import numpy as np
import torch
from sklearn.metrics import f1_score, balanced_accuracy_score, confusion_matrix, precision_score, recall_score
from torch_geometric.data import Data, Batch

def transform_graph_geometric(embeddings, edge_index, edge_type):
    """Pack per-sample graphs into one torch_geometric ``Batch`` on the current CUDA device.

    Args:
        embeddings: Iterable of node-feature objects; item ``i`` becomes the
            ``x`` attribute of graph ``i``.
        edge_index: Indexable collection; ``edge_index[i]`` is converted to a
            ``torch.long`` tensor and stored as graph ``i``'s ``edge_index``.
        edge_type: Indexable collection; ``edge_type[i]`` is converted to a
            ``torch.long`` tensor and stored as graph ``i``'s ``y``.

    Returns:
        The ``Batch`` built from all graphs, moved to
        ``cuda:<torch.cuda.current_device()>``.
    """
    graphs = []
    for position, node_features in enumerate(embeddings):
        edges = torch.tensor(edge_index[position], dtype=torch.long)
        # The edge types travel in the ``y`` slot of each Data object.
        types = torch.tensor(edge_type[position], dtype=torch.long)
        graphs.append(Data(x=node_features, edge_index=edges, y=types))

    batch = Batch.from_data_list(graphs)
    return batch.to("cuda:{}".format(torch.cuda.current_device()))


def save_metrics(metrics, output_file):
    """Append ``metrics`` to ``output_file`` as a single JSON line.

    Args:
        metrics: JSON-serialisable object to record.
        output_file: Path of the JSON-lines file. It is created when it does
            not exist yet; otherwise the new line is added at the end.
    """
    # Mode "a" creates a missing file and appends to an existing one, so no
    # separate existence check is needed.
    with open(output_file, "a", encoding="utf-8") as handle:
        line = json.dumps(metrics, ensure_ascii=False)
        handle.write(line + "\n")


def maybe_save_checkpoint(metrics, save_dir, global_step, model, tokenizer):
    """Save a checkpoint when ``metrics["avg_bacc"]`` beats the recorded best.

    The current best is read from ``<save_dir>/best_checkpoint.json`` (first
    JSON line, keys ``avg_bacc`` and ``folder_checkpoint``). When the new
    value is strictly greater, the model and tokenizer are written to
    ``<save_dir>/step_<global_step>``, the record file is replaced, and the
    previously recorded checkpoint folder is deleted.

    Args:
        metrics: Dict containing at least ``"avg_bacc"``. When a checkpoint is
            saved, the key ``"folder_checkpoint"`` is added to this dict in
            place before it is written to the record file.
        save_dir: Directory holding the checkpoints and the record file.
        global_step: Step number used to name the checkpoint folder.
        model: Object serialised as a whole with ``torch.save`` to ``model.pt``.
        tokenizer: Object exposing ``save_pretrained(directory)``.
    """
    best_bacc = 0
    folder_checkpoint = ""

    output_file = os.path.join(save_dir, "best_checkpoint.json")
    if os.path.exists(output_file):
        with open(output_file, encoding="utf-8") as f:
            data = [json.loads(line) for line in f if line.strip()]
        # An empty record file is treated like a missing one instead of
        # failing on data[0].
        if data:
            best_bacc = data[0]["avg_bacc"]
            folder_checkpoint = data[0]["folder_checkpoint"]

    # Written as "not (a > b)" so that a NaN score never triggers a save.
    if not (metrics["avg_bacc"] > best_bacc):
        return

    save_sub_dir = os.path.join(save_dir, "step_{}".format(global_step))
    # Tolerate a missing save_dir and a folder left over from an earlier run.
    os.makedirs(save_sub_dir, exist_ok=True)
    torch.save(model, os.path.join(save_sub_dir, "model.pt"))
    tokenizer.save_pretrained(save_sub_dir)

    # Replace the record before removing the old checkpoint, so the record
    # never points at a folder that has already been deleted.
    metrics["folder_checkpoint"] = save_sub_dir
    try:
        os.remove(output_file)
    except FileNotFoundError:
        pass
    save_metrics(metrics, output_file)

    # Remove the superseded checkpoint without going through a shell, and
    # never delete the folder that was just written.
    if folder_checkpoint and os.path.abspath(folder_checkpoint) != os.path.abspath(save_sub_dir):
        shutil.rmtree(folder_checkpoint, ignore_errors=True)


def _per_class_recall(y_true, y_pred, class_labels):
    """Return per-class recall aligned with ``class_labels``.

    Entry ``i`` is the recall of ``class_labels[i]``. Classes without any true
    sample yield NaN, and the result is padded with NaN to at least two
    entries so that indices 0 and 1 can always be read.
    """
    matrix = confusion_matrix(y_true, y_pred, labels=class_labels)
    support = matrix.sum(axis=1)
    n_classes = len(class_labels)
    recalls = np.full(max(n_classes, 2), np.nan)
    np.divide(matrix.diagonal(), support, out=recalls[:n_classes], where=support > 0)
    return recalls


def calculate_metrics(global_step, pred_labels, validation_data, best_metric):
    """Compute overall and per-domain validation metrics.

    Args:
        global_step: Value stored under ``"step"`` in the result.
        pred_labels: Predicted labels, one per validation example.
        validation_data: Mapping read as ``validation_data["train"]["domain"]``
            and ``validation_data["train"]["label"]``; both must have the same
            length as ``pred_labels``.
        best_metric: ``None``, or a sequence whose first item is the best
            average balanced accuracy seen so far.

    Returns:
        A ``(best_metric, metrics)`` tuple. ``metrics`` holds the overall
        scores, one nested dict per domain, and ``"avg_bacc"`` (mean of the
        per-domain balanced accuracies). When ``best_metric`` is not ``None``
        it is replaced by ``[avg_bacc, f1]`` if the new average is higher, and
        stored under ``metrics["best_bacc"]``.
    """
    data_source = np.array(validation_data["train"]["domain"])
    labels_val = np.array(validation_data["train"]["label"])
    pred_labels = np.array(pred_labels)
    assert len(labels_val) == len(pred_labels) == len(data_source)

    # Fix the class order once for the whole validation set, so index 0
    # ("incorrect") and index 1 ("correct") refer to the same classes in every
    # per-domain confusion matrix, even when a domain lacks one of them.
    class_labels = np.unique(np.concatenate([labels_val, pred_labels]))

    labels_source = defaultdict(list)
    preds_source = defaultdict(list)
    for ds, lv, pv in zip(data_source, labels_val, pred_labels):
        labels_source[ds].append(lv)
        preds_source[ds].append(pv)

    # np.round is used throughout: the np.round_ alias no longer exists in
    # NumPy 2.0.
    metrics = {
        "step": global_step,
        "accuracy": np.round((pred_labels == labels_val).astype(np.float32).mean().item(), 4),
        "bacc": np.round(balanced_accuracy_score(y_true=labels_val, y_pred=pred_labels), 4),
    }

    recalls = _per_class_recall(labels_val, pred_labels, class_labels)
    metrics["recall_incorrect"] = np.round(recalls[0], 4)
    metrics["recall_correct"] = np.round(recalls[1], 4)
    metrics["f1"] = np.round(f1_score(y_true=labels_val, y_pred=pred_labels, average="micro"), 4)
    metrics["f1_macro"] = np.round(f1_score(y_true=labels_val, y_pred=pred_labels, average="macro"), 4)
    metrics["precision"] = np.round(precision_score(y_true=labels_val, y_pred=pred_labels, average="micro"), 4)
    metrics["recall"] = np.round(recall_score(y_true=labels_val, y_pred=pred_labels, average="micro"), 4)
    metrics["size"] = len(labels_val)

    selected_metrics = []
    for source, source_labels in labels_source.items():
        # A domain whose name collides with an overall metric key is skipped
        # rather than overwriting that metric.
        if source in metrics:
            continue
        source_preds = preds_source[source]
        source_recalls = _per_class_recall(source_labels, source_preds, class_labels)
        metrics[source] = {
            "bacc": np.round(balanced_accuracy_score(y_true=source_labels, y_pred=source_preds), 4),
            "f1": np.round(f1_score(y_true=source_labels, y_pred=source_preds, average="micro"), 4),
            "recall_incorrect": np.round(source_recalls[0], 4),
            "recall_correct": np.round(source_recalls[1], 4),
            "size": len(source_labels),
        }
        selected_metrics.append(metrics[source]["bacc"])

    selected_metrics = np.mean(selected_metrics)
    metrics["avg_bacc"] = np.round(selected_metrics, 4)

    if best_metric is not None:
        if selected_metrics > best_metric[0]:
            best_metric = [np.round(selected_metrics, 4),
                           np.round(metrics["f1"], 4)]

        metrics["best_bacc"] = best_metric

    return best_metric, metrics
