import os
import time
import torch
import ujson
import numpy as np

import itertools
import threading
import queue

from pruner.modeling.inference import ModelInference
from pruner.evaluation.loaders import load_pruner
from pruner.utils.utils import print_message



class CollectionEncoder():
    """Score, sort, filter, and save the tokens of every passage in a TSV collection.

    Each line of ``args.collection`` is expected to be ``pid<TAB>passage``.
    The collection is cut into chunks (see ``_batch_passages``) that are
    assigned round-robin to ``num_processes`` workers; this instance only
    handles the chunks owned by ``process_idx``. For each passage the pruner
    model scores every token, tokens are sorted by descending score, tokens
    with a score <= 1e-4 or contained in ``pruner.skiplist`` are dropped, and
    the result is written to ``<args.output>/<batch_idx>.tsv`` by a background
    saver thread.
    """

    def __init__(self, args, process_idx, num_processes):
        self.args = args
        self.collection = args.collection
        self.process_idx = process_idx
        self.num_processes = num_processes

        # Candidate chunk sizes (in passages); one is drawn per chunk.
        # Use e.g. [1000] for quick debugging runs.
        self.possible_subset_sizes = [100000]

        self.print_main("#> Local args.bsize =", args.bsize)
        self.print_main("#> args.output =", args.output)
        self.print_main(f"#> self.possible_subset_sizes = {self.possible_subset_sizes}")

        self._load_model()

        self.iterator = self._initialize_iterator()

    def _initialize_iterator(self):
        """Open the collection file; ``prune`` consumes it line by line and closes it."""
        return open(self.collection)

    def _saver_thread(self):
        """Run save jobs from ``saver_queue`` until the ``None`` sentinel arrives.

        If a save fails, the exception is stored in ``self._saver_error`` and the
        remaining jobs are drained without being saved. This way the producer in
        ``prune`` never blocks forever on the bounded queue; ``prune`` re-raises
        the stored error once the thread has been joined.
        """
        for job in iter(self.saver_queue.get, None):
            if self._saver_error is not None:
                continue  # keep draining so saver_queue.put() cannot deadlock
            try:
                self._save_batch(*job)
            except Exception as exc:  # surfaced to the main thread by prune()
                self._saver_error = exc

    def _load_model(self):
        """Load the pruner checkpoint, move it to CUDA in eval mode, and wrap it for inference."""
        self.pruner, self.checkpoint = load_pruner(self.args, do_print=(self.process_idx == 0))
        self.pruner = self.pruner.cuda()
        self.pruner.eval()

        self.inference = ModelInference(self.pruner, amp=self.args.amp)

    def prune(self):
        """Score and save every chunk of the collection owned by this process.

        Scoring runs on the calling thread. Writing runs on a background saver
        thread fed through a bounded queue (at most 3 pending chunks), so the
        "this saving" timing below measures time spent enqueuing. The saver
        thread is always stopped and joined, even if scoring raises, and the
        collection file is closed afterwards.

        Raises:
            Exception: whatever ``_save_batch`` raised on the saver thread.
        """
        self.saver_queue = queue.Queue(maxsize=3)
        self._saver_error = None
        thread = threading.Thread(target=self._saver_thread)
        thread.start()

        t0 = time.time()
        local_docs_processed = 0

        try:
            for batch_idx, (offset, lines, owner) in enumerate(self._batch_passages(self.iterator)):
                if owner != self.process_idx:
                    continue
                if self._saver_error is not None:
                    break  # saving already failed; stop spending GPU time

                t1 = time.time()
                pids, passages = self._preprocess_batch(offset, lines)
                assert len(pids) == len(passages)

                positions_batch, scores_batch, tokens_batch = self._encode_passages(passages)

                t2 = time.time()
                # Consumed by _save_batch on the saver thread.
                self.saver_queue.put((batch_idx, pids, positions_batch, scores_batch, tokens_batch))

                t3 = time.time()
                local_docs_processed += len(lines)
                overall_throughput = compute_throughput(local_docs_processed, t0, t3)
                this_encoding_throughput = compute_throughput(len(lines), t1, t2)
                this_saving_throughput = compute_throughput(len(lines), t2, t3)

                self.print(f'#> Completed batch #{batch_idx} (starting at passage #{offset}) \t\t'
                           f'Passages/min: {overall_throughput} (overall), ',
                           f'{this_encoding_throughput} (this encoding), ',
                           f'{this_saving_throughput} (this saving)')
        finally:
            # Always send the sentinel: otherwise the non-daemon saver thread
            # blocks on get() forever and keeps the process alive.
            self.saver_queue.put(None)

            self.print("#> Joining saver thread.")
            thread.join()

            close = getattr(self.iterator, 'close', None)
            if close is not None:
                close()

        if self._saver_error is not None:
            raise self._saver_error

    def _encode_passages(self, passages):
        """Score and sort the tokens of ``passages`` in mini-batches of ``args.bsize``.

        Returns:
            (positions, scores, tokens): three parallel lists with one entry per
            passage, as produced by ``_sort_by_scores_with_mmr``.
        """
        positions_batch, scores_batch, tokens_batch = [], [], []
        bsize = self.args.bsize
        for start in range(0, len(passages), bsize):
            _scores, _tokens, _embs = self._score_batch(docs=passages[start:start + bsize])
            _positions, _scores, _tokens = self._sort_by_scores_with_mmr(scores=_scores, tokens=_tokens, embs=_embs)
            positions_batch.extend(_positions)
            scores_batch.extend(_scores)
            tokens_batch.extend(_tokens)
        return positions_batch, scores_batch, tokens_batch

    def _batch_passages(self, fi):
        """Yield ``(offset, lines, owner)`` chunks of ``fi``, assigning owners round-robin.

        Must use the same seed across processes! Every process reads the whole
        file and all of them must agree on chunk boundaries. The chunk sizes are
        therefore drawn from ``possible_subset_sizes`` right after reseeding
        numpy's global RNG with 0.

        ``offset`` is the index of the chunk's first line within ``fi``.
        """
        np.random.seed(0)

        offset = 0
        for owner in itertools.cycle(range(self.num_processes)):
            batch_size = int(np.random.choice(self.possible_subset_sizes))

            chunk = list(itertools.islice(fi, batch_size))
            if not chunk:
                break  # EOF

            yield (offset, chunk, owner)
            offset += len(chunk)

            if len(chunk) < batch_size:
                break  # EOF

        self.print("[NOTE] Done with local share.")

    def _preprocess_batch(self, offset, lines):
        """Split raw ``pid<TAB>passage`` lines into parallel lists of pids and passages.

        ``offset`` (the collection index of ``lines[0]``) is only needed by the
        disabled MS MARCO check that each pid equals its line index.

        Returns:
            (pids, passages): pids are kept as strings; passages are stripped.
        Raises:
            ValueError: if a line does not have exactly two tab-separated fields.
            AssertionError: if a passage is empty after stripping.
        """
        pids = []
        passages = []

        for line in lines:
            pid, passage = line.strip().split('\t')
            # MSMARCO Passage Ranking only: pid = int(pid); assert pid == line index.

            passage = passage.strip()
            assert len(passage) >= 1

            pids.append(pid)
            passages.append(passage)

        return pids, passages

    def _score_batch(self, docs):
        """Score every token of ``docs`` with the pruner (no gradients).

        Returns:
            (scores, tokens, embs): per-doc lists of token scores, per-doc lists
            of tokens, and an embedding tensor whose first dimension is
            ``len(docs)`` and last dimension is ``pruner.dim``. Each embedding
            must have an L2 norm of ~1, or ~0 (presumably padding).
        """
        with torch.no_grad():
            scores, tokens, embs = self.inference.scoreFromText(docs=docs, return_tokembs=True)
            assert type(scores) is list
            assert len(scores) == len(docs)
            assert type(tokens) is list
            assert len(tokens) == len(docs)
            assert embs.size(0) == len(docs)
            # size() must be called: tuple() of the bound method would raise TypeError.
            assert embs.size(2) == self.inference.pruner.dim, f'embs.size()={tuple(embs.size())}'

            __embs_norm = embs.norm(p=2, dim=2)
            assert (0.99 <= __embs_norm.min().item() <= 1.01) or (0.00 <= __embs_norm.min().item() <= 0.01), f'Min(emb norm)={__embs_norm.min().item()}'
            assert (0.99 <= __embs_norm.max().item() <= 1.01) or (0.00 <= __embs_norm.max().item() <= 0.01), f'Max(emb norm)={__embs_norm.max().item()}'

        return scores, tokens, embs

    def _sort_by_scores_with_mmr(self, scores, tokens, embs):
        """Sort each passage's tokens by descending score and drop near-zero-score tokens.

        NOTE(review): despite the name, no maximal-marginal-relevance
        re-ranking is applied yet; ``embs`` is accepted for that purpose but is
        currently unused.

        Args:
            scores: per-passage lists of per-token scores.
            tokens: per-passage lists of tokens, aligned with ``scores``.
            embs: token embeddings, one entry per passage (unused).

        Returns:
            (positions, scores, tokens): per-passage lists sorted by descending
            score, keeping only tokens whose score (as float32) exceeds 1e-4.
        """
        positions_sorted = []
        scores_sorted = []
        tokens_sorted = []

        for score_1dlist, token_1dlist in zip(scores, tokens):
            order = np.argsort(score_1dlist)[::-1].tolist()
            ordered_scores = [score_1dlist[position] for position in order]

            # Near-zero scores (e.g. punctuation) are filtered out.
            keep = (np.array(ordered_scores, dtype=np.float32) > 1e-4).tolist()
            kept = [(position, score) for position, score, k in zip(order, ordered_scores, keep) if k]

            positions_sorted.append([position for position, _ in kept])
            scores_sorted.append([score for _, score in kept])
            tokens_sorted.append([token_1dlist[position] for position, _ in kept])

        return positions_sorted, scores_sorted, tokens_sorted

    def _save_batch(self, batch_idx, pids, positions, scores, tokens):
        """Write one chunk to ``<args.output>/<batch_idx>.tsv``.

        Each output line is ``pid<TAB>positions<TAB>scores<TAB>tokens``, with the
        last three fields JSON-encoded lists. Tokens found in
        ``self.pruner.skiplist`` are dropped. A passage whose tokens are all
        dropped is written with empty lists instead of raising.
        """
        start_time = time.time()

        output_path = os.path.join(self.args.output, "{}.tsv".format(batch_idx))
        skiplist = self.pruner.skiplist

        with open(output_path, 'w') as ofile:
            for pid, position, score, token in zip(pids, positions, scores, tokens):
                # Filter punctuation / skiplisted tokens.
                kept = [(p, s, t) for p, s, t in zip(position, score, token) if t not in skiplist]
                kept_positions = [p for p, _, _ in kept]
                kept_scores = [s for _, s, _ in kept]
                kept_tokens = [t for _, _, t in kept]

                outline = f'{pid}\t{ujson.dumps(kept_positions)}\t{ujson.dumps(kept_scores)}\t{ujson.dumps(kept_tokens)}\n'
                ofile.write(outline)

        throughput = compute_throughput(len(scores), start_time, time.time())
        self.print_main("#> Saved batch #{} to {} \t\t".format(batch_idx, output_path),
                        "Saving Throughput =", throughput, "passages per minute.\n")

    def print(self, *args):
        """Print a message prefixed with this process's index."""
        print_message("[" + str(self.process_idx) + "]", "\t\t", *args)

    def print_main(self, *args):
        """Print only from process 0."""
        if self.process_idx == 0:
            self.print(*args)


def compute_throughput(size, t0, t1):
    """Format ``size`` items processed between times ``t0`` and ``t1`` (seconds) as items/minute.

    Rates above one million per minute are shown in millions with an ``M``
    suffix; all other rates are shown in thousands with a ``k`` suffix. Both
    are rounded to one decimal place.

    Returns ``'inf'`` when no time (or negative time) elapsed, e.g. because of
    coarse ``time.time()`` resolution, instead of raising ``ZeroDivisionError``.
    """
    elapsed = t1 - t0
    if elapsed <= 0:
        return 'inf'

    per_minute = size / elapsed * 60

    if per_minute > 1000 * 1000:
        return '{}M'.format(round(per_minute / (1000 * 1000), 1))

    return '{}k'.format(round(per_minute / 1000, 1))
