import os, numpy as np
import ujson
from tqdm import tqdm

from functools import partial
from colbert.utils.utils import print_message
from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples

from colbert.utils.runs import Run

#!@ custom
def load_pruned_index(path, topk, collection_size=None, return_scores=False):
    """Load a pruned token index from a tab-separated file.

    Each line of ``path`` must be ``pid<TAB>positions<TAB>scores[<TAB>...]``.
    ``positions`` and ``scores`` are JSON lists, and ``pid`` equals the
    0-based line number. Only the first ``topk`` entries of each list are kept.
    The final log message states the lists are sorted by score; this function
    does not sort them itself.

    Args:
        path: Path to the pruner output file.
        topk: Number of leading token positions (and scores) kept per passage.
        collection_size: Optional expected number of passages. If given and it
            differs from the number of lines loaded, a warning is printed.
        return_scores: If True, also return the truncated score lists.

    Returns:
        ``pruned_index`` (one list of token positions per pid), or
        ``(pruned_index, pruner_scores)`` when ``return_scores`` is True.

    Raises:
        AssertionError: if a line's pid does not equal its line number.
    """
    print_message(f'#> Load {path}')

    pruned_index = []
    pruner_scores = []

    with open(path, 'r', encoding='utf-8') as ifile:
        for line_idx, line in enumerate(ifile):

            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            # Any columns after the third are ignored.
            pid, positions, scores, *_ = line.strip().split('\t')
            assert int(pid) == line_idx

            pruned_index.append(ujson.loads(positions)[:topk])

            if return_scores:
                pruner_scores.append(ujson.loads(scores)[:topk])
    print()

    print_message(f'#> The size of pruned_index={len(pruned_index)} (topk={topk})')
    if collection_size is not None and len(pruned_index) != collection_size:
        # Entries are looked up by pid (e.g. LazyBatcher.__next__), so a size
        # mismatch with the collection is worth surfacing early.
        print_message(f'#> WARNING: pruned_index has {len(pruned_index)} entries '
                      f'but collection_size={collection_size}')

    if not return_scores:
        return pruned_index

    print_message(f'#> Return pruner scores along with token positions, sorted by scores')
    return pruned_index, pruner_scores

#!@ custom
def load_collection(path):
    """Load a tab-separated collection into a list of passages indexed by pid.

    Each line must be ``pid<TAB>passage``, where ``pid`` equals the 0-based
    line number. Leading and trailing whitespace of the whole line is stripped
    before splitting. A line whose passage is empty (``"pid<TAB>"``) yields
    ``''`` instead of failing to unpack.

    Args:
        path: Path to the UTF-8 collection file.

    Returns:
        A list where element ``i`` is the passage text for pid ``i``.

    Raises:
        AssertionError: if a line's pid does not equal its line number.
        ValueError: if a line has more than two tab-separated fields or a
            non-integer pid.
    """
    print_message("#> Loading collection...")

    collection = []

    with open(path, 'r', encoding='utf-8') as f:
        for line_idx, line in enumerate(f):
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            fields = line.strip().split('\t')
            if len(fields) == 1:
                # strip() also removed the tab before an empty passage.
                fields.append('')
            pid, passage = fields
            assert int(pid) == line_idx
            collection.append(passage)

    print()

    return collection

class LazyBatcher():
    """Yield training batches built from (query, positive, negatives) triples.

    At construction time it reads:
    - the triples, sharded across ``nranks`` by line index;
    - the queries;
    - the collection;
    - the pruned token index, when ``args.prune_tokens`` is set.

    Each ``next()`` call takes the next ``bsize`` triples and resolves query
    and passage text. It pairs every passage with the token positions to use:
    the pid's pruned-index entry, or an empty list when pruning is off. The
    batch is then passed to ``tensorize_triples``. Iteration stops before a
    partial final batch, because ``collate`` requires exactly ``bsize``
    examples.
    """

    def __init__(self, args, rank=0, nranks=1):
        self.bsize, self.accumsteps = args.bsize, args.accumsteps

        self.query_tokenizer = QueryTokenizer(args.query_maxlen)
        self.doc_tokenizer = DocTokenizer(args.doc_maxlen)
        self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)
        self.position = 0

        self.triples = self._load_triples(args.triples, rank, nranks)
        self.queries = self._load_queries(args.queries)

        #!@ custom
        self.collection = load_collection(args.collection)

        #!@ custom
        self.prune_tokens = args.prune_tokens
        if self.prune_tokens:
            # List[List[int]]: remaining top-k token positions per pid after pruning.
            self.pruned_index = load_pruned_index(path=args.pruner_filepath, topk=args.pruned_index_size,
                                                  collection_size=len(self.collection),
                                                  return_scores=False)

    def _load_triples(self, path, rank, nranks):
        """
        NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling.
        In particular, each subset is perfectly represented in every batch! However, since we never
        repeat passes over the data, we never repeat any particular triple, and the split across
        nodes is random (since the underlying file is pre-shuffled), there's no concern here.

        Each line is a JSON array ``[qid, pos_pid, neg_pid_1, neg_pid_2, ...]``.
        This rank keeps the lines where ``line_idx % nranks == rank`` and stores
        each as ``(qid, pos_pid, [neg_pids...])``.
        """
        print_message("#> Loading triples...")

        triples = []

        with open(path, encoding='utf-8') as f:
            for line_idx, line in enumerate(f):
                if line_idx % nranks == rank:
                    # A query may come with several negatives.
                    qid, pos, *negs = ujson.loads(line)
                    triples.append((qid, pos, negs))

        return triples

    def _load_queries(self, path):
        """Read a ``qid<TAB>query`` file into a ``{int qid: query text}`` dict."""
        print_message("#> Loading queries...")

        queries = {}

        with open(path, encoding='utf-8') as f:
            for line in f:
                qid, query = line.strip().split('\t')
                qid = int(qid)
                queries[qid] = query

        return queries

    def __iter__(self):
        return self

    def __len__(self):
        # Number of triples assigned to this rank, not the number of batches.
        return len(self.triples)

    def __next__(self):
        """Return the next collated batch; raise StopIteration if fewer than bsize triples remain."""
        offset, endpos = self.position, min(self.position + self.bsize, len(self.triples))
        self.position = endpos

        if offset + self.bsize > len(self.triples):
            raise StopIteration

        queries, positives, negatives = [], [], []

        for position in range(offset, endpos):

            qid, ppid, npids = self.triples[position]
            query, pos = self.queries[qid], self.collection[ppid]
            negs = [self.collection[npid] for npid in npids]

            if self.prune_tokens:
                # Copy so downstream code cannot mutate the shared pruned index.
                pos_tok_used = list(self.pruned_index[ppid])  # List[int]
                negs_tok_used = [list(self.pruned_index[npid]) for npid in npids]
            else:
                # Empty position lists keep the tuple shape uniform for tensorize_triples.
                pos_tok_used = []
                negs_tok_used = [[] for _ in range(len(npids))]
            pos = (pos, pos_tok_used)
            negs = [(neg, neg_tok_used) for neg, neg_tok_used in zip(negs, negs_tok_used)]

            queries.append(query)
            positives.append(pos)
            negatives.append(negs)

        return self.collate(queries, positives, negatives)

    def collate(self, queries, positives, negatives):
        """Tensorize one full batch, split into ``bsize // accumsteps``-sized pieces.

        ``positives[i]`` is a ``(passage text, token positions)`` tuple.
        ``negatives[i]`` is a list of such tuples for the i-th query.
        """
        assert len(queries) == len(positives) == len(negatives) == self.bsize

        return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps)

    def skip_to_batch(self, batch_idx, intended_batch_size):
        """Move the read position to the start of batch ``batch_idx``, for resuming training."""
        Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')
        self.position = intended_batch_size * batch_idx
