from collections import OrderedDict
import ujson, numpy as np

from collections import defaultdict
from pruner.utils.runs import Run


class Metrics:
    """Accumulate token-level Recall@k over a stream of passages.

    For every passage the tokens are ranked by their pruning score (highest
    first) and, for each configured depth ``k``, the fraction of gold tokens
    found among the top-``k`` ranked positions is added to a running sum.
    The sums are turned into averages only when printed or written out.
    """

    def __init__(self, recall_depths: set, total_passages=None):
        """Create an empty accumulator.

        Args:
            recall_depths: The depths ``k`` for which Recall@k is tracked.
            total_passages: Expected number of passages; only checked in
                ``output_final_metrics``.
        """
        self.recall_sums = {ntokens: 0.0 for ntokens in recall_depths}
        self.total_passages = total_passages

        self.max_passage_idx = -1
        self.num_passages_added = 0

    def add(self, prediction, gold_tokens):
        """Add one passage to the running recall sums.

        Args:
            prediction: ndarray (float), the pruning scores for each token in
                the passage.
            gold_tokens: set(int), the positions of the gold tokens.

        A passage with no predictions or no gold tokens is still counted in
        ``num_passages_added`` but contributes 0 to every recall sum.
        """
        self.num_passages_added += 1

        # Recall is undefined without gold tokens; dividing a NumPy integer by
        # zero would silently yield nan/inf and poison the running sums.
        if len(prediction) == 0 or len(gold_tokens) == 0:
            return

        # Token positions ordered from highest to lowest score.
        ranked_positions = np.argsort(prediction)[::-1]

        hits = [1 if token in gold_tokens else 0 for token in ranked_positions]
        cumulative_hits = np.cumsum(hits)

        num_gold = len(gold_tokens)
        last_index = len(cumulative_hits) - 1

        for ntokens in self.recall_sums:
            if ntokens <= 0:
                # Nothing is retrieved at depth 0; index -1 would wrongly
                # read the recall of the whole passage.
                continue
            # Depths beyond the passage length are clamped to the last token.
            found = cumulative_hits[min(ntokens - 1, last_index)]
            self.recall_sums[ntokens] += float(found) / num_gold

    def print_metrics(self, n_samples):
        """Print the average Recall@k for every depth, in ascending order."""
        for ntokens in sorted(self.recall_sums):
            print("Recall@" + str(ntokens), "=", self.recall_sums[ntokens] / n_samples)

    def output_final_metrics(self, path, n_samples, num_passages):
        """Print the averaged metrics and write them to ``path`` as JSON.

        Args:
            path: Destination file; it is overwritten.
            n_samples: Number of samples to average over.
            num_passages: Must equal both ``n_samples`` and the
                ``total_passages`` given at construction time.
        """
        assert n_samples == num_passages == self.total_passages

        self.print_metrics(n_samples)

        output = defaultdict(dict)

        for ntokens in sorted(self.recall_sums):
            score = self.recall_sums[ntokens] / n_samples
            output['recall'][ntokens] = score

        with open(path, 'w') as f:
            ujson.dump(output, f, indent=4)
            f.write('\n')


class RankingMetrics:
    """Accumulate MRR@10 of a passage when only its top-k tokens are kept.

    For each configured ``topk`` the passage is re-scored using only its
    ``topk`` highest-scoring token positions, inserted among the other
    candidates, and the reciprocal rank of the first gold passage (if it is
    within the first 10 ranks) is added to a running sum.
    """

    def __init__(self, topks: set, total_num_data=None):
        """Create an empty accumulator.

        Args:
            topks: The numbers of kept tokens for which MRR@10 is tracked.
            total_num_data: Expected number of data points; only checked in
                ``output_final_metrics``.
        """
        self.mrr10_sums = OrderedDict()
        for topk in topks:
            self.mrr10_sums[topk] = 0.0
        self.total_num_data = total_num_data

        self.num_data_added = 0

    def add(self, pid, qd_sim, score, gold_pids, candidate_rels):
        """Add one (query, passage) data point to the running MRR@10 sums.

        Args:
            pid: Id of the passage being pruned.
            qd_sim: ``(query_maxlen, doc_maxlen)`` similarity matrix.
            score: List[float], one pruning score per passage token position.
            gold_pids: set(int), ids of the relevant passages.
            candidate_rels: List[(pid, score)], the other candidates the
                pruned passage is ranked against.

        If no gold passage appears in the resulting ranking the data point
        contributes 0 (it is still counted in ``num_data_added``).
        """
        self.num_data_added += 1

        # Token positions ordered from highest to lowest pruning score.
        position_sorted = np.argsort(score)[::-1]

        def relevance(entry):
            return entry[1]

        for topk in self.mrr10_sums:
            selected_positions = position_sorted[:topk]

            # Per query row, the best similarity among the kept positions;
            # the passage relevance is the sum of those maxima.
            maxsim = qd_sim[:, selected_positions].max(1)  # (``query_maxlen``)
            rel = maxsim.sum(0)

            # The stable sort keeps the pruned passage after existing
            # candidates that have an equal relevance.
            ranking = sorted(candidate_rels + [(pid, rel)], key=relevance, reverse=True)

            first_positive = next(
                (rank for rank, (cand_pid, _) in enumerate(ranking) if cand_pid in gold_pids),
                None,
            )

            # Only a gold passage within the first 10 ranks scores (MRR@10).
            if first_positive is not None and first_positive < 10:
                self.mrr10_sums[topk] += 1.0 / (first_positive + 1.0)

    def print_metrics(self, n_samples):
        """Print the average MRR@10 for every ``topk``, in ascending order."""
        for topk in sorted(self.mrr10_sums):
            print("MRR@10 - Top" + str(topk), "tokens=", self.mrr10_sums[topk] / n_samples)

    def output_final_metrics(self, path, n_samples, num_data):
        """Print the averaged metrics and write them to ``path`` as JSON.

        Args:
            path: Destination file; it is overwritten.
            n_samples: Number of samples to average over.
            num_data: Must equal both ``n_samples`` and the
                ``total_num_data`` given at construction time.
        """
        assert n_samples == num_data == self.total_num_data

        self.print_metrics(n_samples)

        output = defaultdict(dict)

        for topk in sorted(self.mrr10_sums):
            score = self.mrr10_sums[topk] / n_samples
            output['mrr@10-topk'][topk] = score

        with open(path, 'w') as f:
            ujson.dump(output, f, indent=4)
            f.write('\n')



class RerankingMetrics:
    """Running MRR / Success / Recall sums over a sequence of ranked queries.

    Rankings are stored per query key, and the per-depth sums are divided by
    ``query_idx + 1`` whenever they are printed, logged or written out.
    """

    def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None):
        """Start every tracked depth at 0.0 and remember the expected query count."""
        self.results = {}
        self.mrr_sums = dict.fromkeys(mrr_depths, 0.0)
        self.recall_sums = dict.fromkeys(recall_depths, 0.0)
        self.success_sums = dict.fromkeys(success_depths, 0.0)
        self.total_queries = total_queries

        self.max_query_idx = -1
        self.num_queries_added = 0

    def add(self, query_idx, query_key, ranking, gold_positives):
        """Record the ranking of one query and update every metric sum.

        ``ranking`` is a sequence of 3-tuples whose middle element is the
        passage id; ``gold_positives`` holds the relevant passage ids.
        A query without any gold passage in its ranking adds nothing.
        """
        self.num_queries_added += 1

        assert query_key not in self.results
        assert len(self.results) <= query_idx
        assert len(set(gold_positives)) == len(gold_positives)
        assert len({pid for _, pid, _ in ranking}) == len(ranking)

        self.results[query_key] = ranking

        # 0-based ranks at which a gold passage shows up, in ranking order.
        hit_ranks = []
        for rank, (_, pid, _) in enumerate(ranking):
            if pid in gold_positives:
                hit_ranks.append(rank)

        if not hit_ranks:
            return

        best_rank = hit_ranks[0]
        reciprocal_rank = 1.0 / (best_rank + 1.0)

        for depth in self.mrr_sums:
            if best_rank < depth:
                self.mrr_sums[depth] += reciprocal_rank

        for depth in self.success_sums:
            if best_rank < depth:
                self.success_sums[depth] += 1.0

        num_gold = len(gold_positives)
        for depth in self.recall_sums:
            hits_within_depth = sum(1 for rank in hit_ranks if rank < depth)
            self.recall_sums[depth] += hits_within_depth / num_gold

    def print_metrics(self, query_idx):
        """Print each metric, averaged over ``query_idx + 1`` queries."""
        denominator = query_idx + 1.0
        families = (
            ("MRR@", self.mrr_sums),
            ("Success@", self.success_sums),
            ("Recall@", self.recall_sums),
        )

        for label, sums in families:
            for depth in sorted(sums):
                print(label + str(depth), "=", sums[depth] / denominator)

    def log(self, query_idx):
        """Send the current averages to ``Run.log_metric`` at step ``query_idx``."""
        assert query_idx >= self.max_query_idx
        self.max_query_idx = query_idx

        Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
        Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)

        denominator = query_idx + 1.0
        families = (
            ("ranking/MRR.", self.mrr_sums),
            ("ranking/Success.", self.success_sums),
            ("ranking/Recall.", self.recall_sums),
        )

        for prefix, sums in families:
            for depth in sorted(sums):
                Run.log_metric(prefix + str(depth), sums[depth] / denominator, query_idx)

    def output_final_metrics(self, path, query_idx, num_queries):
        """Log (if not yet logged), print, and dump the final averages to ``path``."""
        assert query_idx + 1 == num_queries
        assert num_queries == self.total_queries

        if self.max_query_idx < query_idx:
            self.log(query_idx)

        self.print_metrics(query_idx)

        denominator = query_idx + 1.0
        output = defaultdict(dict)
        families = (
            ('mrr', self.mrr_sums),
            ('success', self.success_sums),
            ('recall', self.recall_sums),
        )

        # A family only appears in the output once it has at least one depth.
        for key, sums in families:
            for depth in sorted(sums):
                output[key][depth] = sums[depth] / denominator

        with open(path, 'w') as f:
            ujson.dump(output, f, indent=4)
            f.write('\n')
