"""
1. Load ranking.tsv files
2. Ensemble (merging rankings)
-- averaging relevance scores
ranker #a ==> (d1, r1a), (d2, r2a)
ranker #b ==> (d1, r1b), (d3, r3b)
Averaging option 1: (d1, mean(r1a, r1b)), (d2, mean(r2a, 0)), (d3, mean(0, r3b))
Averaging option 2: (d1, mean(r1a, r1b)), (d2, mean(r2a, min(r*b))), (d3, mean(min(r*a), r3b))
3. Save ranking.tsv for top-1000 documents
"""

import argparse
import os
from collections import OrderedDict
import numpy as np

def load_ranking(path):
    """Load a ranking TSV file into per-query lists.

    Each non-blank line must contain four tab-separated fields:
    ``qid``, ``pid``, ``rank`` (parsed as int) and ``score`` (parsed as float).
    Blank lines are skipped.

    Args:
        path: Path to the ranking file (read as UTF-8).

    Returns:
        OrderedDict mapping qid -> list of ``(pid, rank, score)`` tuples.
        qids appear in first-seen order; each list keeps file order.

    Raises:
        ValueError: if a non-blank line does not have exactly four fields,
            or a field cannot be parsed as a number.
    """
    print("#> Loading ranking from", path, "...")

    ranking = OrderedDict()
    with open(path, mode='r', encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line at EOF)
                # instead of failing the 4-field unpack below.
                continue
            try:
                qid, pid, rank, score = line.split('\t')
                entry = (int(pid), int(rank), float(score))
                qid = int(qid)
            except ValueError as err:
                raise ValueError(f"{path}:{line_no}: malformed ranking line {line!r}") from err
            ranking.setdefault(qid, []).append(entry)

    return ranking

def load_ranking_files(files):
    """Load several ranking TSV files with ``load_ranking``.

    Every path is checked for existence before any file is parsed, so a bad
    path fails fast instead of after the other files have been loaded.

    Args:
        files: Iterable of paths to ranking files.

    Returns:
        List of rankings (qid -> list of ``(pid, rank, score)``), in the same
        order as ``files``.

    Raises:
        FileNotFoundError: if any of the paths does not exist.
    """
    # Materialise once: the paths are iterated twice (check, then load), which
    # would silently load nothing if a one-shot iterator were passed in.
    files = list(files)

    # An explicit raise (not assert) so the check survives `python -O`.
    missing = [path for path in files if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(f"Ranking file(s) not found: {', '.join(map(str, missing))}")

    return [load_ranking(path) for path in files]
        
def mean_score(ranking_list, depth=1000, fill_na='zero'):
    """Merge several rankings by averaging their per-document scores.

    For every qid of the first ranking, the candidates are the union of pids
    retrieved by any ranker for that qid. Each candidate's ensemble score is
    the mean over all rankers. A ranker that did not retrieve the pid
    contributes a fill value chosen by ``fill_na``:

    * ``'zero'``: 0
    * ``'min'``: the lowest score that ranker assigned to any pid for the qid
      (0 if that ranker returned an empty list for the qid)

    Args:
        ranking_list: List of rankings, each a dict qid -> list of
            ``(pid, rank, score)`` tuples.
        depth: Maximum number of documents kept per qid.
        fill_na: ``'zero'`` or ``'min'`` (see above).

    Returns:
        OrderedDict qid -> list of ``(pid, rank, score)`` tuples. Each list
        is sorted by descending mean score, uses 1-based ranks, and is
        truncated to ``depth``. qids follow the first ranking's order; qids
        that appear only in later rankings are ignored. Empty if
        ``ranking_list`` is empty.

    Raises:
        NotImplementedError: if ``fill_na`` is not ``'zero'`` or ``'min'``.
        KeyError: if a qid of the first ranking is missing from another ranking.
        ValueError: if a ranker lists the same pid twice for one qid.
    """
    print(f'#> Ensemble {len(ranking_list)} rankings.')

    # Validate once up front rather than inside the per-query/per-ranker loops.
    if fill_na not in ('zero', 'min'):
        raise NotImplementedError(f"Unsupported fill_na: {fill_na!r}")

    ensemble_ranking = OrderedDict()
    if not ranking_list:
        return ensemble_ranking

    for qid in ranking_list[0]:
        # For each ranker: a {pid: score} lookup and the value used for
        # pids that ranker did not retrieve.
        per_ranker_scores = []
        per_ranker_fill = []
        for ranker_idx, ranking in enumerate(ranking_list):
            if qid not in ranking:
                raise KeyError(f"qid {qid} is missing from ranking #{ranker_idx}")

            scores = {}
            for pid, _rank, score in ranking[qid]:
                if pid in scores:
                    raise ValueError(
                        f"pid {pid} appears more than once for qid {qid} in ranking #{ranker_idx}"
                    )
                scores[pid] = score

            # Use the true minimum rather than the last entry, so 'min' does not
            # depend on the input list being sorted by descending score.
            if fill_na == 'zero' or not scores:
                fill_value = 0
            else:
                fill_value = min(scores.values())

            per_ranker_scores.append(scores)
            per_ranker_fill.append(fill_value)

        # Union of candidate pids in first-seen order (dicts keep insertion
        # order), so ties are broken deterministically.
        candidate_pids = dict.fromkeys(pid for scores in per_ranker_scores for pid in scores)

        merged = [
            (pid, np.mean([scores.get(pid, fill)
                           for scores, fill in zip(per_ranker_scores, per_ranker_fill)]))
            for pid in candidate_pids
        ]
        # list.sort is stable (also with reverse=True): equal scores keep first-seen order.
        merged.sort(key=lambda item: item[1], reverse=True)

        ensemble_ranking[qid] = [
            (pid, rank, score) for rank, (pid, score) in enumerate(merged[:depth], start=1)
        ]

    return ensemble_ranking
    # return: ensemble_ranking: Dict[ qid: List[ Tuple(pid, rank, score) ] ]

if __name__=='__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--ranking_files', type=str, nargs="+", required=True)
    parser.add_argument('--fill_na', choices=['zero', 'min'], default='min', help="How to fill not available relevance scores, i.e., when a document in top-k ranking of ranker A is not included that of ranker B.")
    parser.add_argument('--depth', type=int, default=1000)
    parser.add_argument('--output', type=str, required=True)

    args = parser.parse_args()

    # Load ranking.tsv files
    ranking_list = load_ranking_files(args.ranking_files)
    # ranking_list: List[ Dict[ qid: List[ Tuple(pid, rank, score) ] ] ]

    # Ensemble
    ranking = mean_score(ranking_list, depth=args.depth, fill_na=args.fill_na)
    # ranking: Dict[ qid: List[ Tuple(pid, rank, score) ] ]

    # Save result. os.path.dirname() is '' for a bare filename such as
    # "out.tsv", and os.makedirs('') raises, so only create a directory
    # when the output path actually has one.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        print(f'make directory: {output_dir}')
        os.makedirs(output_dir, exist_ok=True)

    # Same tab-separated "qid pid rank score" layout that load_ranking reads,
    # written as UTF-8 to match the reader.
    with open(args.output, 'w', encoding='utf-8') as outfile:
        outfile.writelines(
            f'{qid}\t{pid}\t{rank}\t{score}\n'
            for qid, triples in ranking.items()
            for pid, rank, score in triples
        )
    print(f'\n\n\toutfile: {args.output}\n\n')