import os
import shutil
from PIL import Image
import numpy as np
import torch

from utils import get_tasks_list, get_ontology, prune_ontology, plot

def get(
    tasks_path="/home/CrossFit/dataloader/custom_tasks_splits/random.json",
    split="train",
    ontology_path="/home/LayerDrop/playground/plot_utils/ontology.json",
):
    """Return the flattened task list of the pruned ontology and the reverse ontology.

    The task list for ``split`` is read from ``tasks_path`` with
    ``get_tasks_list``, the ontology is read from ``ontology_path`` with
    ``get_ontology``, and the ontology is pruned against those tasks with
    ``prune_ontology``. All three helpers come from the project's ``utils``
    module; their exact semantics are not visible here.

    Args:
        tasks_path: Path passed to ``get_tasks_list``. Defaults to the
            location this script originally hard-coded.
        split: Split name passed to ``get_tasks_list``.
        ontology_path: Path passed to ``get_ontology``.

    Returns:
        A ``(tasks, reverse_ontology)`` tuple. ``tasks`` is a flat list built
        by concatenating every leaf of the pruned ontology in iteration order;
        ``reverse_ontology`` is the second value returned by ``get_ontology``.
    """
    split_tasks = get_tasks_list(tasks_path, split)

    ontology, reverse_ontology = get_ontology(ontology_path)
    # Only the first element of prune_ontology's 4-tuple is needed here.
    pruned_ontology, _, _, _ = prune_ontology(split_tasks, ontology)

    # The pruned ontology is walked as a two-level mapping whose leaves are
    # iterables; the keys at both levels are irrelevant for the flat list.
    tasks = [
        task
        for groups in pruned_ontology.values()
        for members in groups.values()
        for task in members
    ]

    return tasks, reverse_ontology

# Number of best (lowest-loss) checkpoints to consider.
TOPN = 3

# Directory of the training run whose checkpoints are analysed.
output_dir = "/home/LayerDrop/models/mar29/test_v2_linear_layer_specific_min0.3_anneal0.0003"

# Loss recorded for each checkpoint, keyed by training step.
losses = {
    750: 3.755605180884323,
    1500: 3.081802110935779,
    2250: 2.86347183561432,
    3000: 2.7795245045264747,
    3750: 2.7181414632259293,
    4500: 2.684855261199583,
    5250: 2.6628627758375525,
    6000: 2.6497475506535557,
    6750: 2.6342990110832525,
    7500: 2.625563418204321,
    8250: 2.621306574591228,
    9000: 2.6202434010091875,
    9750: 2.6160212053571428,
    10500: 2.616335781543696,
    11250: 2.611971137248366,
    12000: 2.6073156291280393,
    12750: 2.60859849566578,
    13500: 2.6134448845826714,
    14250: 2.616999078912909,
    15000: 2.617102379903959,
    15750: 2.620870887066539,
    16500: 2.6159190101639207,
    17250: 2.616690615879445,
    18000: 2.623408489050226,
    18750: 2.6279690902276216,
    19500: 2.6260272816250274,
    20250: 2.625906400818848,
    21000: 2.6286516908313278,
    21750: 2.6278865828160924,
    22500: 2.6327229308836917,
    23250: 2.6314497703310566,
    24000: 2.6352943407338665,
}

# (step, loss) pairs ordered from lowest to highest loss.
sorted_losses = list(losses.items())
sorted_losses.sort(key=lambda step_and_loss: step_and_loss[1])
print(sorted_losses)

# topn_dir = os.path.join(output_dir, "topn")
# os.makedirs(topn_dir, exist_ok=True)

# shutil.copyfile(os.path.join(output_dir, "{}-steps".format(sorted_losses[0][0]), "route.png"), os.path.join(topn_dir, "top1-{}-steps.png".format(sorted_losses[0][0])))
# shutil.copyfile(os.path.join(output_dir, "{}-steps".format(sorted_losses[1][0]), "route.png"), os.path.join(topn_dir, "top2-{}-steps.png".format(sorted_losses[1][0])))
# shutil.copyfile(os.path.join(output_dir, "{}-steps".format(sorted_losses[2][0]), "route.png"), os.path.join(topn_dir, "top3-{}-steps.png".format(sorted_losses[2][0])))

# Load the saved routing array of each of the TOPN lowest-loss checkpoints.
# Driving this from TOPN (rather than three copy-pasted loads) keeps the
# constant and the analysis in sync; with TOPN == 3 the result is unchanged.
top_routes = [
    np.load(os.path.join(output_dir, "{}-steps".format(step), "route.npy"))
    for step, _ in sorted_losses[:TOPN]
]

# Stack along a new leading axis so axis 0 indexes the checkpoint.
# NOTE(review): np.stack requires every route.npy to have the same shape.
cat = np.stack(top_routes, axis=0)
print(cat.shape)
# Element-wise standard deviation across the selected checkpoints.
std = np.std(cat, axis=0)
print(std.shape)

tasks, reverse_ontology = get()
path = os.path.join(output_dir, "std.png")

# `plot` comes from the project's utils module; it is handed the std-dev as a
# tensor together with the task list, the output image path and the reverse
# ontology.
plot(torch.tensor(std), tasks, path, reverse_ontology)

# create a gif of the routing changes
# checkpoints = os.listdir(output_dir)
# checkpoints = list(filter(lambda x: x.endswith("-steps"), checkpoints))
# checkpoints = sorted(checkpoints, key=lambda x: int(x[:-6]))

# frames = [Image.open(os.path.join(output_dir, checkpoint, "route.png")) for checkpoint in checkpoints]
# frame_one = frames[0]
# frame_one.save(os.path.join(output_dir, "routes-dynamics.gif"), format="GIF", append_images=frames,
#             save_all=True, duration=200, loop=0)
