import json
import logging
import random
import time
from multiprocessing import cpu_count
from multiprocessing.pool import Pool
from pyserini.index import IndexReader

# Module-wide logging setup: everything goes through the root logger at INFO
# level, written to a single stream handler created here.
console = logging.StreamHandler()

logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Discard handlers installed before this module was imported so that `console`
# is the only handler attached by this point.
if logger.hasHandlers():
    del logger.handlers[:]
logger.addHandler(console)


class BM25:
    """Score every (question, passage) pair against a prebuilt Pyserini index.

    Scores come from ``IndexReader.compute_query_document_score``. The work can
    be split across several worker processes, one contiguous shard of questions
    per worker.
    """

    def __init__(self, index_name='wikipedia-dpr', num_processes=None, log_every=None):
        """Load the prebuilt index and store the parallelism/logging settings.

        Args:
            index_name: Name passed to ``IndexReader.from_prebuilt_index``.
            num_processes: Number of worker processes; ``None`` means
                ``cpu_count()``. A value of 1 scores in the calling process.
            log_every: Log progress every this many questions (first shard
                only). ``None`` or 0 disables progress logging.
        """
        self.index_reader = IndexReader.from_prebuilt_index(index_name)
        self.num_processes = num_processes if num_processes is not None else cpu_count()
        self.log_every = log_every

    @staticmethod
    def _shard_size(num_questions, num_processes):
        """Return the number of questions per shard: ceil(n / p), at least 1."""
        return max(1, -(-num_questions // num_processes))

    def _calculate_helper(self, questions, ctxs, shard_idx=0):
        """Score each question in one shard against every passage id.

        Args:
            questions: Questions belonging to this shard.
            ctxs: Passage ids (already converted to strings by the caller).
            shard_idx: Position of this shard; only shard 0 logs progress.

        Returns:
            A list with one row per question; row ``i`` holds the score of
            every entry of ``ctxs`` for ``questions[i]``, in ``ctxs`` order.
        """
        results = []
        start_time = time.time()
        for i, question in enumerate(questions):
            # A falsy log_every (None or 0) disables logging; this also avoids
            # a modulo-by-zero when 0 is passed.
            if shard_idx == 0 and self.log_every and i > 0 and i % self.log_every == 0:
                logger.info(f"Finished {i} questions in shard 0, took {(time.time()-start_time)/60:0.1f} minutes")
            results.append([self.index_reader.compute_query_document_score(ctx, question) for ctx in ctxs])
        return results

    def calculate_scores(self, questions, ctxs):
        """Return the score of every passage for every question.

        Args:
            questions: Sequence of questions (must support ``len`` and
                slicing).
            ctxs: Passages, each either a dict with a ``"passage_id"`` key or
                a bare id; ids are converted to ``str`` before scoring.

        Returns:
            A list of ``len(questions)`` rows, each with ``len(ctxs)`` scores,
            in the same order as the inputs.
        """
        ctxs = [(str(ctx["passage_id"]) if isinstance(ctx, dict) else str(ctx)) for ctx in ctxs]
        if self.num_processes == 1:
            return self._calculate_helper(questions, ctxs)

        num_questions = len(questions)
        if num_questions == 0:
            # Nothing to score; do not start a worker pool for nothing.
            return []

        # Ceiling division spreads the questions over as many workers as
        # possible (e.g. 8 questions on 4 processes -> 4 shards of 2).
        shard_size = self._shard_size(num_questions, self.num_processes)
        params = [(questions[start:start + shard_size], ctxs, shard_idx)
                  for shard_idx, start in enumerate(range(0, num_questions, shard_size))]
        # NOTE(review): starmap pickles the bound method, so `self` (including
        # `index_reader`) must be picklable -- confirm for the reader in use.
        # Never start more workers than there are shards to process.
        with Pool(min(self.num_processes, len(params))) as p:
            shard_results = p.starmap(self._calculate_helper, params)
        # starmap preserves shard order, so a linear flatten restores the
        # original question order.
        return [row for shard in shard_results for row in shard]


class ApproximateBM25:
    """Cheap approximation of a full question-by-passage score matrix.

    For each question only three passages are actually scored with the index
    reader: its positive (``ctxs[2 * i]``), its negative (``ctxs[2 * i + 1]``)
    and one randomly chosen other passage. Every other entry of the row is
    filled with the minimum of those scores. Running counters record how often
    each of the three passages had the highest / lowest score, and are
    periodically appended to ``stats_file`` as JSON lines.
    """

    def __init__(self, stats_file, stats_print_interval=1000, index_name='wikipedia-dpr', num_hard_negatives=1):
        """Load the prebuilt index and reset the statistics counters.

        Args:
            stats_file: Path of the file that statistics lines are appended to.
            stats_print_interval: Write statistics every this many calls to
                ``calculate_scores``.
            index_name: Name passed to ``IndexReader.from_prebuilt_index``.
            num_hard_negatives: Must be 1; other values are rejected.

        Raises:
            AssertionError: If ``num_hard_negatives`` is not 1.
        """
        # Validate before loading the index so a bad argument fails fast.
        assert num_hard_negatives == 1, "num_hard_negatives != 1 isn't supported"
        self.stats_file = stats_file
        self.stats_print_interval = stats_print_interval
        self.index_reader = IndexReader.from_prebuilt_index(index_name)
        self.num_pos_highest, self.num_neg_highest, self.num_random_highest = 0, 0, 0
        self.num_pos_lowest, self.num_neg_lowest, self.num_random_lowest = 0, 0, 0
        self.steps = 0

    def calculate_scores(self, questions, ctxs):
        """Return an approximate score row for every question.

        Args:
            questions: Sequence of questions.
            ctxs: Passages, each either a dict with a ``"passage_id"`` key or
                a bare id. Must contain exactly two entries per question,
                laid out as ``[pos_0, neg_0, pos_1, neg_1, ...]``.

        Returns:
            A list of ``len(questions)`` rows of ``len(ctxs)`` scores. In row
            ``i`` the positive and negative slots hold their real scores; all
            other slots hold the minimum of the scores computed for that
            question.

        Raises:
            AssertionError: If ``len(ctxs) != 2 * len(questions)``.
        """
        ctxs = [(str(ctx["passage_id"]) if isinstance(ctx, dict) else str(ctx)) for ctx in ctxs]
        assert len(ctxs) == (2 * len(questions))
        results = []
        self.steps += 1
        score = self.index_reader.compute_query_document_score
        for i, question in enumerate(questions):
            pos_idx, neg_idx = 2 * i, 2 * i + 1
            random_psg_idx = self._get_random_ctx(len(ctxs), {pos_idx, neg_idx})
            pos_psg_score = score(ctxs[pos_idx], question)
            neg_psg_score = score(ctxs[neg_idx], question)
            if random_psg_idx is None:
                # Single-question batch: there is no third passage to sample,
                # and no other slot in the row that would need a filler value.
                random_psg_score = None
                min_score = min(pos_psg_score, neg_psg_score)
            else:
                random_psg_score = score(ctxs[random_psg_idx], question)
                min_score = min(pos_psg_score, neg_psg_score, random_psg_score)
            all_scores = [min_score] * len(ctxs)
            all_scores[pos_idx] = pos_psg_score
            all_scores[neg_idx] = neg_psg_score
            results.append(all_scores)
            self._update_stats(pos_psg_score, neg_psg_score, random_psg_score)
        if self.steps % self.stats_print_interval == 0:
            self._print_stats()
        return results

    def _get_random_ctx(self, num_ctxs, ctxs_to_exclude):
        """Pick a random index in ``range(num_ctxs)`` outside ``ctxs_to_exclude``.

        Returns:
            The sampled index, or ``None`` when every index is excluded
            (previously this case looped forever).
        """
        excluded_in_range = sum(1 for idx in set(ctxs_to_exclude) if 0 <= idx < num_ctxs)
        if num_ctxs - excluded_in_range <= 0:
            return None
        # Rejection sampling: cheap because only a couple of indices are
        # excluded and at least one valid candidate is known to exist.
        ctx_idx = random.randrange(num_ctxs)
        while ctx_idx in ctxs_to_exclude:
            ctx_idx = random.randrange(num_ctxs)
        return ctx_idx

    def _update_stats(self, pos_psg_score, neg_psg_score, random_psg_score):
        """Update the highest/lowest counters for one question.

        ``random_psg_score`` may be ``None`` when no random passage was
        scored; it is then ignored. Ties increment every tied counter, so the
        totals can exceed the number of questions seen.
        """
        scores = [pos_psg_score, neg_psg_score]
        if random_psg_score is not None:
            scores.append(random_psg_score)

        max_score = max(scores)
        if max_score == pos_psg_score:
            self.num_pos_highest += 1
        if max_score == neg_psg_score:
            self.num_neg_highest += 1
        if random_psg_score is not None and max_score == random_psg_score:
            self.num_random_highest += 1

        min_score = min(scores)
        if min_score == pos_psg_score:
            self.num_pos_lowest += 1
        if min_score == neg_psg_score:
            self.num_neg_lowest += 1
        if random_psg_score is not None and min_score == random_psg_score:
            self.num_random_lowest += 1

    def _print_stats(self):
        """Append the current counters to ``stats_file`` as one JSON line."""
        stats = {
            "steps": self.steps,
            "pos_high": self.num_pos_highest,
            "neg_high": self.num_neg_highest,
            "random_high": self.num_random_highest,
            "total_high": (self.num_pos_highest + self.num_neg_highest + self.num_random_highest),
            "pos_low": self.num_pos_lowest,
            "neg_low": self.num_neg_lowest,
            "random_low": self.num_random_lowest,
            "total_low": (self.num_pos_lowest + self.num_neg_lowest + self.num_random_lowest),
        }
        with open(self.stats_file, "a") as f:
            f.write(json.dumps(stats) + "\n")
