import os
import time
import torch
import ujson
import numpy as np

import itertools
import threading
import queue

from pruner.modeling.inference import ModelInference
from pruner.evaluation.loaders import load_pruner
from pruner.utils.utils import print_message



class CollectionEncoder():
    """Score every passage of a collection file with a pruner model.

    The collection is read line by line; each line must be ``pid<TAB>passage``
    with ``pid`` equal to the line's 0-based position in the file. Lines are
    grouped into chunks whose size is drawn from ``possible_subset_sizes``,
    chunks are assigned to the ``num_processes`` processes round-robin, and
    this instance scores only the chunks owned by ``process_idx``. Each scored
    chunk is written by a background thread to ``<args.output>/<chunk_idx>.tsv``.
    """

    def __init__(self, args, process_idx, num_processes):
        """Load the model and open the collection.

        Args:
            args: run configuration. This class reads ``collection``,
                ``bsize``, ``output`` and ``amp`` (``load_pruner`` may read
                more).
            process_idx: index of this process, in ``range(num_processes)``.
            num_processes: total number of processes sharing the collection.
        """
        self.args = args
        self.collection = args.collection
        self.process_idx = process_idx
        self.num_processes = num_processes

        # Candidate chunk sizes (lines per output file). A small value such as
        # [1000] is convenient when debugging.
        self.possible_subset_sizes = [100000]

        self.print_main("#> Local args.bsize =", args.bsize)
        self.print_main("#> args.output =", args.output)
        self.print_main(f"#> self.possible_subset_sizes = {self.possible_subset_sizes}")

        self._load_model()

        self.iterator = self._initialize_iterator()

    def _initialize_iterator(self):
        """Open the collection file for line iteration (closed by ``scoring``)."""
        return open(self.collection)

    def _saver_thread(self):
        """Save queued ``(batch_idx, pids, scores, tokens)`` items until ``None``.

        Runs on the background thread started by ``scoring``. If ``_save_batch``
        raises, the exception is stored in ``self._saver_error`` and the loop
        keeps draining the queue without saving, so the producer can never
        block forever on the bounded ``saver_queue``.
        """
        for item in iter(self.saver_queue.get, None):
            if self._saver_error is not None:
                continue  # an earlier save failed; only drain from now on
            try:
                self._save_batch(*item)
            except Exception as exc:  # re-raised by scoring() after join
                self._saver_error = exc

    def _load_model(self):
        """Load the pruner, move it to CUDA in eval mode, and wrap it for inference."""
        # Only process 0 prints while loading, to avoid duplicated logs.
        self.pruner, self.checkpoint = load_pruner(self.args, do_print=(self.process_idx == 0))
        self.pruner = self.pruner.cuda()
        self.pruner.eval()

        self.inference = ModelInference(self.pruner, amp=self.args.amp)

    def scoring(self):
        """Score this process's share of the collection and save the results.

        Owned chunks are parsed, scored in mini-batches of ``args.bsize``
        passages, and handed to a background saver thread through a bounded
        queue (at most 3 pending chunks), so saving overlaps with scoring.

        The saver thread is always sent its ``None`` sentinel and joined, and
        the collection file is closed, even if scoring raises. Without this,
        the non-daemon saver thread would keep the process alive. If saving a
        chunk fails, scoring stops at the next owned chunk and the saver's
        exception is re-raised here.
        """
        t0 = time.time()
        local_docs_processed = 0

        self.saver_queue = queue.Queue(maxsize=3)
        self._saver_error = None
        thread = threading.Thread(target=self._saver_thread)
        thread.start()

        try:
            bsize = self.args.bsize
            for batch_idx, (offset, lines, owner) in enumerate(self._batch_passages(self.iterator)):
                if owner != self.process_idx:
                    continue
                if self._saver_error is not None:
                    break  # saving failed; the error is re-raised below

                t1 = time.time()
                pids, passages = self._preprocess_batch(offset, lines)
                assert len(pids) == len(passages)

                # Score in mini-batches. Each _score_batch call returns one
                # score list and one token list per passage.
                scores_batch, tokens_batch = [], []
                for start in range(0, len(pids), bsize):
                    _scores, _tokens = self._score_batch(docs=passages[start:start + bsize])
                    scores_batch.extend(_scores)
                    tokens_batch.extend(_tokens)

                t2 = time.time()
                # Handled by _save_batch on the saver thread. The t2..t3 window
                # measures only the enqueue, which blocks while 3 chunks are
                # already pending.
                self.saver_queue.put((batch_idx, pids, scores_batch, tokens_batch))

                t3 = time.time()
                local_docs_processed += len(lines)
                overall_throughput = compute_throughput(local_docs_processed, t0, t3)
                this_encoding_throughput = compute_throughput(len(lines), t1, t2)
                this_saving_throughput = compute_throughput(len(lines), t2, t3)

                # Adjacent f-strings concatenate into a single log message.
                self.print(f'#> Completed batch #{batch_idx} (starting at passage #{offset}) \t\t'
                           f'Passages/min: {overall_throughput} (overall), '
                           f'{this_encoding_throughput} (this encoding), '
                           f'{this_saving_throughput} (this saving)')
        finally:
            # Always stop the saver thread so an exception above cannot leave
            # the (non-daemon) thread, and hence the process, hanging.
            self.saver_queue.put(None)

            self.print("#> Joining saver thread.")
            thread.join()

            close = getattr(self.iterator, 'close', None)
            if close is not None:
                close()

        if self._saver_error is not None:
            raise self._saver_error

    def _batch_passages(self, fi):
        """Split the lines of ``fi`` into chunks and assign each chunk an owner.

        Yields ``(offset, lines, owner)``:
        - ``offset`` is the index of the chunk's first line;
        - ``lines`` holds the raw lines;
        - ``owner`` cycles round-robin over ``range(self.num_processes)``.

        Every process must draw the same sequence of chunk sizes, so the RNG
        is always seeded with 0. A private ``RandomState(0)`` produces the same
        draws as seeding the global NumPy RNG, without clobbering global state.
        """
        rng = np.random.RandomState(0)

        offset = 0
        for owner in itertools.cycle(range(self.num_processes)):
            batch_size = int(rng.choice(self.possible_subset_sizes))

            chunk = list(itertools.islice(fi, batch_size))
            if not chunk:
                break  # EOF

            yield (offset, chunk, owner)
            offset += len(chunk)

            if len(chunk) < batch_size:
                break  # EOF

        self.print("[NOTE] Done with local share.")

    def _preprocess_batch(self, offset, lines):
        """Parse ``pid<TAB>passage`` lines into parallel pid and passage lists.

        Args:
            offset: 0-based position of ``lines[0]`` in the collection.
            lines: raw collection lines.

        Returns:
            ``(pids, passages)``: integer pids and whitespace-stripped passages.

        Raises:
            ValueError: if a line does not have exactly two tab-separated
                fields, its pid is not an integer equal to its position, or
                its passage is empty. Explicit raises (rather than ``assert``)
                keep these checks active under ``python -O``.
        """
        pids = []
        passages = []

        for line_idx, line in enumerate(lines, start=offset):
            line_parts = line.strip().split('\t')
            if len(line_parts) != 2:
                raise ValueError(
                    f"collection line {line_idx}: expected 2 tab-separated fields, "
                    f"got {len(line_parts)}")
            pid, passage = line_parts

            pid = int(pid)
            if pid != line_idx:
                raise ValueError(
                    f"collection line {line_idx}: pid {pid} does not match line position")

            passage = passage.strip()
            if not passage:
                raise ValueError(f"collection line {line_idx}: empty passage")

            pids.append(pid)
            passages.append(passage)

        return pids, passages

    def _score_batch(self, docs):
        """Score a list of passage strings with the model (no gradients).

        Returns:
            ``(scores, tokens)``: two lists with one entry per doc. Each entry
            holds that doc's per-token scores and its tokens, respectively.
        """
        with torch.no_grad():
            scores, tokens = self.inference.scoreFromText(docs=docs, return_tokembs=False)
            assert type(scores) is list
            assert len(scores) == len(docs)
            assert type(tokens) is list
            assert len(tokens) == len(docs)

        return scores, tokens

    def _save_batch(self, batch_idx, pids, scores, tokens):
        """Write one chunk to ``<args.output>/<batch_idx>.tsv``.

        Each output line is ``pid<TAB>json(tokens)<TAB>json(scores)``.
        The original notes say scores/tokens arrive sorted by score. That
        ordering comes from ``scoreFromText`` and is not enforced here.
        """
        start_time = time.time()

        output_path = os.path.join(self.args.output, "{}.tsv".format(batch_idx))

        with open(output_path, 'w') as ofile:
            for pid, score, token in zip(pids, scores, tokens):
                ofile.write(f'{pid}\t{ujson.dumps(token)}\t{ujson.dumps(score)}\n')

        throughput = compute_throughput(len(scores), start_time, time.time())
        self.print_main("#> Saved batch #{} to {} \t\t".format(batch_idx, output_path),
                        "Saving Throughput =", throughput, "passages per minute.\n")

    def print(self, *args):
        """Log ``args`` prefixed with this process's index."""
        print_message("[" + str(self.process_idx) + "]", "\t\t", *args)

    def print_main(self, *args):
        """Log ``args`` only from process 0."""
        if self.process_idx == 0:
            self.print(*args)


def compute_throughput(size, t0, t1):
    """Format a processing rate as items per minute.

    Args:
        size: number of items processed.
        t0: start timestamp, in seconds.
        t1: end timestamp, in seconds.

    Returns:
        A string such as ``'12.3k'`` (thousands per minute) or ``'1.5M'``
        (millions per minute), rounded to one decimal. Returns ``'inf'`` when
        ``t1 <= t0`` instead of raising ZeroDivisionError. Coarse clocks can
        report zero elapsed time for very fast steps, e.g. a non-blocking
        queue ``put``.
    """
    elapsed = t1 - t0
    if elapsed <= 0:
        return 'inf'

    throughput = size / elapsed * 60

    if throughput >= 1000 * 1000:
        return '{}M'.format(round(throughput / (1000 * 1000), 1))

    return '{}k'.format(round(throughput / 1000, 1))
