# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import json
import logging
import torch
from tqdm import tqdm

from utils import WORLDS, Stats


def get_topk_predictions(biranker, dls, candidate_encoding, logger, top_k=64,
                         device="cuda"):
    """Evaluate top-k retrieval of a bi-encoder ranker, per domain and overall.

    For every batch of every domain's data loader, the contexts are scored
    against that domain's candidate encodings. The ``top_k`` highest-scoring
    candidates are taken, and the position of the gold label among them is
    recorded in a per-domain ``Stats``; the position is -1 when the gold label
    is not in the top k. Per-domain results are logged for every domain that
    saw at least one example, then merged and logged as a whole.

    Args:
        biranker: ranker exposing ``model`` (switched to eval mode here and
            left that way) and ``score_candidate(context_input, None,
            cand_encs=...)``.
        dls: mapping of domain name -> iterable of ``(context_input,
            label_ids)`` tensor batches.
        candidate_encoding: mapping of domain name -> tensor of candidate
            encodings scored for that domain.
        logger: logger used for progress and result reporting.
        top_k: number of highest-scoring candidates to inspect per context.
        device: device that batches and candidate encodings are moved to.
            Defaults to ``"cuda"``, the value previously hard-coded.

    Returns:
        The ``Stats`` object merging all non-empty domains. (Previously the
        function returned ``None``; callers ignoring the result are
        unaffected.)
    """
    biranker.model.eval()
    logger.info("Getting top %d predictions." % top_k)
    # NOTE: this is the number of domains in ``dls``, not distributed workers.
    world_size = len(dls)
    logger.info("World size : %d" % world_size)
    stats = {domain: Stats(top_k) for domain in dls}

    # scoring
    for domain, dl in dls.items():
        # Transfer the candidate encodings once per domain, not once per batch.
        cand_encs = candidate_encoding[domain].to(device)
        _score_domain(biranker, dl, cand_encs, stats[domain], domain, top_k,
                      device)

    res = Stats(top_k)
    for domain in dls:
        if stats[domain].cnt == 0:
            continue

        logger.info("In world " + domain)
        logger.info(stats[domain].output())

        # merge results
        res.extend(stats[domain])

    # all results
    logger.info(res.output())
    return res


def _score_domain(biranker, dl, cand_encs, domain_stats, domain, top_k,
                  device):
    """Score every batch of one domain and record each gold label's top-k rank."""
    for batch in tqdm(dl, desc="Evaluate " + domain):
        context_input, label_ids = (t.to(device) for t in batch)
        with torch.no_grad():
            scores = biranker.score_candidate(
                context_input, None, cand_encs=cand_encs)
            # topk raises if k exceeds the number of candidates, so clamp it.
            k = min(top_k, scores.size(-1))
            _, indices = scores.topk(k)
            # Build a (rows, k) boolean hit matrix and copy it to the host
            # once, instead of syncing via .item() for every comparison.
            # reshape(-1, 1) accepts label_ids of shape (B,) or (B, 1).
            hits = indices.eq(label_ids.reshape(-1, 1)).tolist()
        for row in hits:
            # topk indices are distinct, so a row holds at most one True;
            # -1 means the gold label was not retrieved in the top k.
            domain_stats.add(row.index(True) if True in row else -1)

