import numpy as np
import torch


class EarlyStopping(object):
    """Track a validation metric and decide whether training should go on.

    ``step`` is called once per epoch with the latest metric value and
    returns ``True`` to continue training or ``False`` to stop.

    Args:
        min_delta: margin by which a metric must beat ``best`` to count as
            an improvement.
        patience: number of consecutive non-improving epochs after which
            ``step`` returns ``False``. ``0`` disables patience-based
            stopping (only ``max_epochs`` and NaN metrics stop training).
        max_epochs: upper bound on the number of accepted epochs.
        mode: ``'min'`` if lower metrics are better, ``'max'`` otherwise.

    Raises:
        ValueError: if ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, min_delta, patience, max_epochs, mode='min'):
        if mode not in {'min', 'max'}:
            # format() instead of '+' so a non-string mode still raises
            # ValueError rather than a TypeError from concatenation.
            raise ValueError('mode {} is unknown!'.format(mode))
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.max_epochs = max_epochs
        self.current_epoch = 0

        if patience == 0:
            # step() does not consult this attribute; it is retained only so
            # that existing attribute access keeps working.
            self.is_better = lambda a, b: True

    def step(self, metrics):
        """Record one epoch's metric; return True to continue, False to stop."""
        if self.current_epoch >= self.max_epochs:
            return False

        # Check NaN before the first-call branch: a NaN stored as `best`
        # would make every later comparison evaluate to False.
        if np.isnan(metrics):
            return False

        if self.best is None:
            self.best = metrics
            self.current_epoch += 1
            return True

        if self.mode == 'min':
            improved = self._min_mode(metrics)
        else:
            improved = self._max_mode(metrics)

        if improved:
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        # patience == 0 means "never stop for lack of improvement"; without
        # this guard `num_bad_epochs >= 0` is always true and training would
        # stop on the second call.
        if self.patience != 0 and self.num_bad_epochs >= self.patience:
            return False

        self.current_epoch += 1
        return True

    def _min_mode(self, metric):
        """True when `metric` is lower than `best` by more than `min_delta`."""
        return metric < self.best - self.min_delta

    def _max_mode(self, metric):
        """True when `metric` is higher than `best` by more than `min_delta`."""
        return metric > self.best + self.min_delta

    def epoch(self):
        """Return the number of epochs for which `step` answered True."""
        return self.current_epoch


def sample_eval_tasks(meta_dataset, args):
    """Draw ``args.num_samples`` tasks from ``meta_dataset``.

    Each task comes from ``meta_dataset.sample_task`` called with
    ``args.batch_size``, ``args.n``, ``args.k`` and
    ``num_queries=args.num_queries``. Returns the tasks as a list, in the
    order they were sampled.
    """
    tasks = []
    for _ in range(args.num_samples):
        task = meta_dataset.sample_task(
            args.batch_size, args.n, args.k, num_queries=args.num_queries)
        tasks.append(task)
    return tasks


def get_entities(y):
    '''
        Extract entity spans from label sequences.

        For every sequence in `y`, each maximal run of identical labels is
        reported as a tuple (T, S, E): T is the label, S the index of the
        first element of the run and E the index one past its last element.
        Runs whose label is 1, or whose label is falsy (e.g. 0), are not
        reported. The result holds one list of tuples per input sequence, in
        input order.
    '''
    spans_per_sequence = []
    for labels in y:
        total = len(labels)
        spans = []
        # `active` is the label of the run currently being scanned, or None
        # while scanning label 1; `begin` is where that run started.
        active, begin = None, 0
        for pos, label in enumerate(labels):
            if label != active:
                # The run changed: close the one that just ended (if it is
                # reportable) and start a new one here.
                if active:
                    spans.append((active, begin, pos))
                begin = pos
            active = label if label != 1 else None
        if active:
            # The sequence ended while a reportable run was still open.
            spans.append((active, begin, total))
        spans_per_sequence.append(spans)
    return spans_per_sequence


def micro_slot_f1(y_true, y_pred, labels):
    """Entity-level micro-averaged F1 between reference and predicted labels.

    Entities are extracted from both inputs with ``get_entities`` and
    compared utterance by utterance; a predicted entity counts as correct
    only if the identical (type, start, end) tuple occurs in the reference
    for the same utterance. ``labels`` is accepted but not used.

    Returns 1 when the reference contains no entities at all (whatever was
    predicted), 0 when nothing was predicted or nothing matched, and
    otherwise the harmonic mean of precision and recall.
    """
    gold_by_utterance = get_entities(y_true)
    hyp_by_utterance = get_entities(y_pred)

    matched = 0
    gold_total = 0
    hyp_total = 0
    for gold, hyp in zip(gold_by_utterance, hyp_by_utterance):
        gold_lookup = set(gold)
        matched += sum(1 for span in hyp if span in gold_lookup)
        gold_total += len(gold)
        hyp_total += len(hyp)

    if gold_total == 0:
        return 1
    if hyp_total == 0 or matched == 0:
        return 0

    precision = matched / hyp_total
    recall = matched / gold_total
    return 2 * (precision * recall) / (precision + recall)


def process_labels(labels):
    """Convert a batch of label ids to nested Python lists without padding.

    ``labels`` is moved to the CPU and converted via NumPy to nested lists;
    every entry equal to 0 is then removed from each row (0 is assumed to be
    the padding id).
    """
    rows = labels.cpu().numpy().tolist()
    stripped = []
    for row in rows:
        stripped.append([label for label in row if label != 0])
    return stripped


@torch.no_grad()
def evaluate(dataset, ner_predictor, device):
    """Score ``ner_predictor`` on the query batches of ``dataset``.

    Puts the predictor in eval mode, runs it on every batch yielded by
    ``dataset.iter_queries(device=device)``, strips padding from reference
    and predicted labels with ``process_labels``, and returns the micro slot
    F1 multiplied by 100.
    """
    ner_predictor.eval()
    gold_sequences = []
    predicted_sequences = []

    for tokens, intents, slots in dataset.iter_queries(device=device):
        preds, extra = ner_predictor(tokens, device)
        gold_sequences.extend(process_labels(slots))
        # NOTE(review): predictions are filtered on their own zeros, not on
        # the reference padding mask - presumably the predictor emits 0 at
        # padded positions only; confirm.
        predicted_sequences.extend(process_labels(preds))
        # Drop the references to the model outputs before the next batch.
        del extra, preds

    score = micro_slot_f1(gold_sequences, predicted_sequences,
                          labels=dataset.tgt_slots)
    return score * 100


def compute_ci(values):
    """Return the mean of ``values`` and 1.96 standard errors of that mean.

    The standard error is ``np.std(values)`` (NumPy's default, i.e. ddof=0)
    divided by the square root of the number of elements.
    """
    samples = np.asarray(values)
    centre = np.mean(samples)
    scaled_spread = 1.96 * np.std(samples)
    half_width = scaled_spread / samples.size ** 0.5
    return centre, half_width


def aggregate_metrics(metrics):
    """Aggregate per-sample metric curves into means and interval widths.

    Args:
        metrics: array-like of shape (num_samples, max_epochs, num_metrics).

    Returns:
        A pair ``(mean_metrics, std_metrics)`` of arrays, each of shape
        (num_metrics, max_epochs). ``mean_metrics`` is the mean over the
        sample axis; ``std_metrics`` is ``1.96 * std / sqrt(num_samples)``
        over the same axis, the same formula ``compute_ci`` applies to a
        single series.
    """
    # incoming layout: (num_samples, max_epochs, num_metrics)
    metrics = np.asarray(metrics)
    if not np.issubdtype(metrics.dtype, np.inexact):
        # All-integer input (e.g. scores of exactly 0 and 100) must not give
        # integer outputs, which would silently truncate the mean/interval.
        metrics = metrics.astype(np.float64)
    # layout after transpose: (num_metrics, max_epochs, num_samples)
    metrics = np.transpose(metrics, (2, 1, 0))
    num_samples = metrics.shape[-1]
    # Reduce over the sample axis in one call instead of a Python loop over
    # every (metric, epoch) pair.
    mean_metrics = np.ascontiguousarray(np.mean(metrics, axis=-1))
    std_metrics = np.ascontiguousarray(
        1.96 * np.std(metrics, axis=-1) / num_samples ** 0.5)
    return mean_metrics, std_metrics
