import sys
import torch
import matplotlib.pyplot as plt
import numpy as np
import Levenshtein
from emb2emb.hausdorff import hausdorff_similarity
from autoencoders.torch_utils import make_mask


def compute_l0drop_statistics(encoder, data):
    bsize = 16
    l0_losses = []
    diffs = 0.
    ratios = 0.
    in_lens_total = 0.
    out_lens_total = 0.
    for idx in range(0, len(data["Sx"]), bsize):
        batch = data['Sx'][idx: (idx + bsize)]
        (outs, out_lens), l0, in_lens = encoder.encode(
            batch, return_l0loss=True)
        in_lens = torch.tensor(in_lens).float()
        in_lens_total = in_lens_total + in_lens.sum()
        out_lens = out_lens.float()
        out_lens_total = out_lens_total + out_lens.sum()
        diffs = diffs + (in_lens - out_lens).sum()
        ratios = ratios + (out_lens / in_lens).sum()

    diffs = diffs / len(data["Sx"])
    ratios = ratios / len(data["Sx"])
    in_lens_total = in_lens_total / len(data["Sx"])
    out_lens_total = out_lens_total / len(data["Sx"])
    print("Diffs", diffs)
    print("Ratios", ratios)
    print("In Total:", in_lens_total)
    print("Out Total:", out_lens_total)
    sys.exit(0)


def plot_num_words(data):
    """Plot the distribution of the word-count ratio Sy/Sx.

    Word counts are obtained by splitting every sentence of ``data["Sx"]``
    and ``data["Sy"]`` on single spaces. The element-wise ratio
    ``len(Sy) / len(Sx)`` is shown as a 50-bin density histogram restricted
    to the range [0, 2], and its mean and median are printed.
    """
    def _word_counts(sentences):
        # Splitting on " " (not on arbitrary whitespace) always yields at
        # least one token, even for an empty string.
        return np.array([len(sentence.split(" ")) for sentence in sentences])

    source_counts = _word_counts(data["Sx"])
    target_counts = _word_counts(data["Sy"])
    ratio = target_counts / source_counts

    plt.hist(ratio, 50, density=True, label="Sy/Sx", range=(0.0, 2.0))
    plt.legend()

    print("MEAN:", ratio.mean())
    print("MEDIAN:", np.median(ratio))

    plt.show()


def _levenshtein_top_k(data, k, verbose=False):
    """Rank sentences by length-normalised Levenshtein distance.

    Returns an integer array with one row per sentence holding the indices
    of the ``k`` sentences with the largest normalised distance to it.

    NOTE(review): the ranking is in descending order of distance, i.e. it
    selects the most distant sentences — confirm that this is the intended
    notion of "neighbour".
    """
    num_data = len(data)
    matrix = np.zeros((num_data, num_data))
    if verbose:
        print("Computing Levenshtein distance...")
    for i, s1 in enumerate(data):
        if verbose:
            print(f"{i} / {num_data}", flush=True, end="\r")
        # The normalised distance is symmetric and zero on the diagonal, so
        # only the strict upper triangle is computed and then mirrored.
        for j in range(i + 1, num_data):
            s2 = data[j]
            longest = max(len(s1), len(s2))
            # Two empty strings are identical; do not divide by zero.
            d = Levenshtein.distance(s1, s2) / float(longest) if longest else 0.0
            matrix[i, j] = d
            matrix[j, i] = d

    return matrix.argsort(axis=1)[:, ::-1][:, :k]


def _embedding_score_matrix(encoder, data, bsize, verbose=False):
    """Return the matrix of negated Hausdorff similarities between sentences.

    Entry ``[j, i]`` is ``-hausdorff_similarity`` between the encodings of
    ``data[j]`` and ``data[i]``. ``encoder(batch)`` is expected to return an
    indexable whose first element is the batch of embeddings and whose second
    element is what ``make_mask`` takes as lengths.
    """
    num_data = len(data)
    matrix = np.zeros((num_data, num_data))

    if verbose:
        print("Computing distances in embedding space...")

    # Pure evaluation: the scores are detached anyway, so there is no reason
    # to record an autograd graph for every forward pass.
    with torch.no_grad():
        for j in range(num_data):
            if verbose:
                print(f"{j} / {num_data}", flush=True, end="\r")
            s1 = encoder([data[j]])

            # NOTE(review): every batch is re-encoded for each query
            # sentence; caching the encodings would trade memory for speed.
            # Stepping by ``bsize`` also covers a trailing partial batch,
            # which a floored batch count would silently leave at zero.
            for start in range(0, num_data, bsize):
                batch = data[start:start + bsize]
                if not isinstance(batch, list):
                    batch = [batch]

                s2_batch = encoder(batch)

                Y = s2_batch[0]
                # Repeat the single query so it lines up with every item
                # of the current batch.
                X = s1[0].repeat(Y.size(0), 1, 1)
                mask_X = make_mask(X.size(0), X.size(1), s1[1]).to(X.device)
                mask_Y = make_mask(
                    Y.size(0), Y.size(1), s2_batch[1]).to(X.device)

                d = (-1.0) * hausdorff_similarity(
                    X, Y, mask_X, mask_Y, naive=False,
                    differentiable=True).cpu().detach().numpy()
                matrix[j, start:start + bsize] = d

    return matrix


def _recall_of_top_k(top_k_embeddingspace, top_k_levensthein,
                     num_neighbors_in_discrete_space, num_data):
    """Average per-row overlap between two index rankings.

    For every row, the number of indices present in both rankings is divided
    by ``num_neighbors_in_discrete_space``; the result is averaged over
    ``num_data`` rows.
    """
    all_embeddings = np.hstack([top_k_embeddingspace, top_k_levensthein])
    recall = 0.
    for row in all_embeddings:
        _, unique_counts = np.unique(row, axis=None, return_counts=True)
        # Indices are unique within each ranking, so an index that shows up
        # twice in the concatenated row is shared by both rankings.
        recall += (unique_counts == 2).sum() / float(
            num_neighbors_in_discrete_space)
    return recall / num_data


def compute_neighborhood_preservation(
        encoder, valid, params,
        num_neighbors_in_discrete_space=10,
        num_neighbors_in_embedding_space=(5, 10, 20, 50),
        verbose=False,
        levenshtein_precomputed=None,
        return_levenshtein=False):
    """Measure how well the embedding-space ranking recovers the Levenshtein ranking.

    The sentences of ``valid["Sx"]`` and ``valid["Sy"]`` are merged. For each
    sentence, the top ``num_neighbors_in_discrete_space`` sentences by
    length-normalised Levenshtein distance are compared with the top ``k``
    sentences by negated Hausdorff similarity of the encoder outputs, for
    every ``k`` in ``num_neighbors_in_embedding_space``.

    NOTE(review): both rankings are sorted in descending order, i.e. they
    pick the largest distance / negated similarity — confirm that this is
    the intended definition of "neighbourhood".

    Args:
        encoder: Module with ``eval()`` that is called as ``encoder(batch)``
            on a list of sentences.
        valid: Mapping with the keys ``"Sx"`` and ``"Sy"`` whose values can
            be concatenated with ``+``.
        params: Object providing ``batch_size`` (and ``modeldir`` when
            ``verbose`` is set).
        num_neighbors_in_discrete_space: Number of Levenshtein-ranked
            sentences kept per row; also the denominator of the recall.
        num_neighbors_in_embedding_space: Iterable of ``k`` values for the
            embedding-space ranking.
        verbose: Print progress information.
        levenshtein_precomputed: Previously returned Levenshtein ranking;
            when given, the pairwise Levenshtein computation is skipped.
        return_levenshtein: Also return the Levenshtein ranking so it can be
            reused in later calls.

    Returns:
        A list with one recall value per ``k``; if ``return_levenshtein`` is
        true, a tuple ``(recall_values, levenshtein_ranking)``.
    """
    encoder.eval()
    # merge both styles
    data = valid["Sx"] + valid["Sy"]
    num_data = len(data)

    if verbose:
        print(
            f"Computing the neighborhood preservation on a dataset of size {num_data}...")
        print(
            f"Encoder: {params.modeldir}")

    if levenshtein_precomputed is None:
        top_k_levensthein = _levenshtein_top_k(
            data, num_neighbors_in_discrete_space, verbose=verbose)
    else:
        top_k_levensthein = levenshtein_precomputed

    matrix = _embedding_score_matrix(
        encoder, data, params.batch_size, verbose=verbose)

    # The full ranking does not depend on k, so sort once and slice per k.
    ranking = matrix.argsort(axis=1)[:, ::-1]
    recall_vals = [
        _recall_of_top_k(ranking[:, :k], top_k_levensthein,
                         num_neighbors_in_discrete_space, num_data)
        for k in num_neighbors_in_embedding_space
    ]

    if return_levenshtein:
        return recall_vals, top_k_levensthein
    return recall_vals
