#!/usr/bin/env python3

import sys
import json
import argparse
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, binned_statistic, sem
from scipy import ndimage as nd

from lm_polygraph.utils.manager import UEManager


def fill(data, invalid=None):
    """
    Overwrite every invalid cell of 'data' with the value of the closest
    valid cell (closest in the Euclidean distance transform sense).

    Input:
        data:    numpy array of any dimension
        invalid: boolean array shaped like 'data'; True marks the cells
                 whose values get replaced.
                 When None (default), np.isnan(data) is used instead.

    Output:
        A new array of the same shape as 'data' with the invalid cells
        filled in.
    """
    mask = np.isnan(data) if invalid is None else invalid

    # With return_indices=True the transform yields, for each cell, the
    # coordinates of the nearest cell where the mask is False.
    nearest = nd.distance_transform_edt(
        mask,
        return_distances=False,
        return_indices=True,
    )
    return data[tuple(nearest)]


def get_args():
    """Parse the command-line options of this script.

    Returns an argparse namespace with the attributes ``ue_methods``,
    ``generation_metrics`` and ``man_paths`` (lists, empty by default) and
    ``output_dir`` (``"./workdir"`` by default).
    """
    cli = argparse.ArgumentParser()

    # The three list-valued options share the same shape: one or more
    # values, defaulting to an empty list.
    for flag in ("--ue_methods", "--generation_metrics", "--man_paths"):
        cli.add_argument(flag, nargs="+", default=[])
    cli.add_argument("--output_dir", default="./workdir")

    return cli.parse_args()


def _load_metric_sources(mans, man_names, metric_key):
    """Return ``(manager, path, values)`` for each manager holding *metric_key*.

    Managers without the metric are reported on stdout and left out, so
    that later per-UE collection only pairs data from the same managers.
    """
    sources = []
    for man, man_path in zip(mans, man_names):
        try:
            values = man.gen_metrics[metric_key]
        except KeyError:
            print(f"No metric {metric_key} in manager archive {man_path}")
            continue
        sources.append((man, man_path, np.asarray(values)))
    return sources


def _collect_pairs(metric_sources, ue_key):
    """Concatenate metric and UE values over managers that provide both.

    Returns ``(metric, ue)`` arrays, or ``None`` when no manager has both.
    A manager lacking the UE is reported and skipped together with its
    metric values, which keeps the two concatenated arrays aligned.
    """
    metric_parts = []
    ue_parts = []
    for man, man_path, metric_values in metric_sources:
        try:
            ue_values = np.array(man.estimations[ue_key])
        except KeyError:
            print(f"No UE {ue_key} in manager archive {man_path}")
            continue
        metric_parts.append(metric_values)
        ue_parts.append(ue_values)

    if not ue_parts:
        return None
    return np.concatenate(metric_parts), np.concatenate(ue_parts)


def _bin_statistics(ue, metric, num_bins, min_support):
    """Bin *metric* by *ue* and return per-bin mean, std and standard error.

    Bins holding fewer than *min_support* points are blanked and then
    replaced by the value of the nearest well-supported bin (via ``fill``).

    Returns ``(bin_edges, counts, bin_mean, bin_std, bin_sem)``, or ``None``
    when every bin is below *min_support* (nothing to fill from).
    """
    mean_bins = binned_statistic(ue, metric, bins=num_bins, statistic='mean')
    std_bins = binned_statistic(ue, metric, bins=num_bins, statistic='std')
    sem_bins = binned_statistic(ue, metric, bins=num_bins, statistic=sem)

    # binnumber is 1-based for regular bins; slots 0 and num_bins + 1 are
    # reserved for values outside the bin range, so they are sliced away.
    counts = np.bincount(mean_bins.binnumber,
                         minlength=num_bins + 2)[1:num_bins + 1]

    low_support = counts < min_support
    if low_support.all():
        return None

    filled = []
    for result in (mean_bins, std_bins, sem_bins):
        stat = result.statistic
        stat[low_support] = np.nan
        filled.append(fill(stat))

    return (mean_bins.bin_edges, counts, *filled)


def _plot_calibration(path, title, panels, counts, bin_edges):
    """Draw the five per-bin curves plus the bin-count histogram to *path*.

    *panels* is a sequence of ``(series, ylabel)`` pairs, one per line plot.
    The figure is always closed afterwards so that repeated calls do not
    accumulate open figures.
    """
    fig, axes = plt.subplots(6, 1, figsize=(5, 25))
    try:
        fig.suptitle(title)

        for axis, (series, ylabel) in zip(axes, panels):
            axis.plot(range(len(series)), series)
            axis.set_ylabel(ylabel)
            axis.set_xlabel('Bin')
            axis.grid()

        axes[5].stairs(counts, bin_edges)
        axes[5].set_ylabel('Num points in bin')
        axes[5].set_xlabel('UE in bin')
        axes[5].grid()

        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)


def main():
    """Bin generation metrics by uncertainty estimate, plot and archive them.

    For every (generation metric, UE method) pair given on the command
    line, the metric and UE values of all loaded managers are concatenated,
    entries where either value is NaN are dropped, the metric is min-max
    normalised and averaged inside 100 equal-width UE bins. A six-panel
    figure is saved as ``<ue>_<metric>.jpg`` in the output directory, and
    the bin edges with the resulting confidence curves are written to
    ``calibrated_ues.json``.

    Pairs for which no usable data exists (missing in all managers, all
    NaN, constant metric, or no bin with enough points) are reported on
    stdout and skipped.
    """
    args = get_args()

    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    num_bins = 100
    # Bins with fewer points than this are treated as unreliable.
    min_support = 50

    man_names = args.man_paths
    mans = [UEManager.load(man_name) for man_name in man_names]

    archive = {}
    for metric_name in args.generation_metrics:
        # that's just the way keys are in man archive
        metric_key = ('sequence', metric_name)
        sources = _load_metric_sources(mans, man_names, metric_key)

        for ue_name in args.ue_methods:
            ue_key = ('sequence', ue_name)

            pair = _collect_pairs(sources, ue_key)
            if pair is None:
                print(f"No data for UE {ue_key} and metric {metric_key}, "
                      "skipping")
                continue
            metric, ue = pair

            # Keep only samples where both the metric and the UE are defined.
            keep = ~(np.isnan(metric) | np.isnan(ue))
            filtered_ue = ue[keep]
            filtered_metric = metric[keep]
            if filtered_metric.size == 0:
                print(f"Only NaN values for UE {ue_key} and metric "
                      f"{metric_key}, skipping")
                continue

            shifted_metric = filtered_metric - filtered_metric.min()
            metric_range = shifted_metric.max()
            if metric_range == 0:
                # A constant metric cannot be min-max normalised (0 / 0).
                print(f"Metric {metric_key} is constant for UE {ue_key}, "
                      "skipping")
                continue
            normed_metric = shifted_metric / metric_range

            stats = _bin_statistics(filtered_ue, normed_metric,
                                    num_bins, min_support)
            if stats is None:
                print(f"No bin with at least {min_support} points for UE "
                      f"{ue_key} and metric {metric_key}, skipping")
                continue
            bin_edges, counts, bin_metric, bin_std, bin_sem = stats

            unnormalized_conf = bin_metric * 100
            min_metric = bin_metric.min()
            max_metric = bin_metric.max()
            # NOTE(review): if the binned curve is flat (max == min) this
            # divides by zero and yields NaN, as it did before — confirm
            # what value downstream consumers expect in that case.
            normalized_conf = (bin_metric - min_metric) / (max_metric - min_metric)
            normalized_conf = normalized_conf * 100

            panels = [
                (bin_metric, metric_name),
                (bin_std, f'{metric_name} StD'),
                (bin_sem, f'{metric_name} Standard Error'),
                (unnormalized_conf, 'Unnormalized conf, %'),
                (normalized_conf, 'Normalized conf, %'),
            ]
            _plot_calibration(out / f'{ue_name}_{metric_name}.jpg',
                              metric_name, panels, counts, bin_edges)

            # NOTE(review): entries are keyed by UE name only, so when
            # several generation metrics are given the last one processed
            # overwrites the earlier ones — confirm this is intended.
            archive[ue_name] = {
                'ues': list(bin_edges),
                'normed_conf': list(normalized_conf),
                'unnormed_conf': list(unnormalized_conf),
            }

    with open(out / 'calibrated_ues.json', 'w') as handle:
        handle.write(json.dumps(archive))


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
