import os
import time
import torch
import queue
import threading
import numpy as np
from tqdm import tqdm

from collections import defaultdict, OrderedDict

from colbert.utils.runs import Run
from colbert.modeling.inference import ModelInference
from colbert.evaluation.ranking_logger import RankingLogger

from colbert.utils.utils import print_message, flatten, zipstar
from colbert.indexing.loaders import get_parts


from colbert.training.lazy_batcher import load_expansion_pt



""" unordered.tsv
"""
# from colbert.ranking.faiss_index import FaissIndex
def per_query_retrieve(query, inference, faiss_index, faiss_depth=1024):
    """
    Run candidate retrieval for a single query text.

    The query is wrapped in a one-element list and encoded with
    ``inference.queryFromText``; the encoding is then handed to
    ``faiss_index.retrieve`` together with ``faiss_depth`` (``verbose=True``).

    Args:
        query: the query text to encode.
        inference: object exposing ``queryFromText(list_of_texts)``.
        faiss_index: object exposing ``retrieve(depth, Q, verbose=...)``.
        faiss_depth: depth forwarded to ``faiss_index.retrieve``.

    Returns:
        Whatever ``faiss_index.retrieve`` returns, unchanged (presumably the
        candidate pids -- confirm against the index implementation).
    """
    query_embedding = inference.queryFromText([query])
    return faiss_index.retrieve(faiss_depth, query_embedding, verbose=True)
def retrieve(qbatch_text, inference, faiss_index, faiss_depth=1024):
    """
    Encode a batch of query texts and run candidate retrieval for them.

    Args:
        qbatch_text: the query texts, passed as-is to
            ``inference.queryFromText``.
        inference: object exposing ``queryFromText(texts)``.
        faiss_index: object exposing ``retrieve(depth, Q, verbose=...)``.
        faiss_depth: depth forwarded to ``faiss_index.retrieve``.

    Returns:
        A pair ``(Q, all_pids)``: the query encoding produced by
        ``inference.queryFromText`` and the unmodified result of
        ``faiss_index.retrieve`` for it (presumably one pid list per query --
        confirm against the index implementation).
    """
    query_embeddings = inference.queryFromText(qbatch_text)
    candidate_pids = faiss_index.retrieve(faiss_depth, query_embeddings, verbose=True)
    return query_embeddings, candidate_pids
            


def exact_nn_search(index, Q, Q_wts, all_pids, ranking_template):
    """
    Score candidate ``(query_index, pid)`` pairs with ``index.batch_rank`` and
    append the results, per query, to ``ranking_template`` (mutated in place).

    Args:
        index: object exposing ``batch_rank(all_query_embeddings=...,
            query_indexes=..., pids=..., sorted_pids=..., all_query_weights=...)``.
        Q: query embeddings, forwarded as ``all_query_embeddings``
            (callers in this file pass ``(n_queries, dim, query_len)``).
        Q_wts: per-token query weights, forwarded as ``all_query_weights``.
        all_pids: list of ``(query_index, pid)`` pairs to score.
        ranking_template: two mappings ``[pids_by_query, scores_by_query]``,
            each mapping a query index to a list that is appended to.

    Returns:
        None. An empty ``all_pids`` leaves ``ranking_template`` untouched.
    """
    # No candidates: there is nothing to split into columns or to score, and
    # unpacking the two columns below would presumably fail on an empty list.
    if len(all_pids) == 0:
        return

    print_message("#> Sorting by PID..")
    query_indexes, pids = zipstar(all_pids)

    # Order the pairs by pid (``batch_rank`` is told so via ``sorted_pids=True``)
    # and keep the query indexes aligned with their pids.
    pid_order = torch.tensor(pids).sort()
    query_indexes = torch.tensor(query_indexes)[pid_order.indices]
    pids = pid_order.values

    #!@ custom
    scores = index.batch_rank(all_query_embeddings=Q, query_indexes=query_indexes, pids=pids, sorted_pids=True, all_query_weights=Q_wts)

    pids_by_query = ranking_template[0]
    scores_by_query = ranking_template[1]
    for query_index, pid, score in zip(query_indexes.tolist(), pids.tolist(), scores):
        pids_by_query[query_index].append(pid)
        scores_by_query[query_index].append(score)

def batch_query_expansion(args, colbert_prf, queries, fb_pids, all_query_embeddings, all_query_weights=None):
    """
    Append ``args.fb_k`` expansion embeddings (and their weights) to every query.

    For each query that has feedback pids, ``colbert_prf.expand`` supplies the
    expansion embeddings and weights; queries without feedback get all-zero
    expansion slots (zero embeddings with zero weights).

    Args:
        args: namespace providing ``fb_k``, ``beta`` and ``dim``.
        colbert_prf: object exposing ``expand(q_embs=..., fb_pids=...)`` that
            returns ``(exp_embs, exp_weights, exp_tokens)``.
        queries: iterable of qids; its order must match the first axis of
            ``all_query_embeddings``.
        fb_pids: mapping ``qid -> list of feedback pids``; qids may be missing.
        all_query_embeddings: tensor of size ``(n_queries, dim, query_len)``.
        all_query_weights: optional tensor of size ``(n_queries, query_len)``;
            when ``None``, a weight of 1 is used for every original query token.

    Returns:
        ``(all_query_embeddings, all_query_weights)`` of sizes
        ``(n_queries, dim, query_len + fb_k)`` and ``(n_queries, query_len + fb_k)``.
    """
    with torch.no_grad():

        assert args.fb_k > 0 and args.beta > 0.0

        n_queries = len(all_query_embeddings)
        # Remember the original query length before the expansion is appended.
        org_qlen = all_query_embeddings.size(2)
        dtype, device = all_query_embeddings.dtype, all_query_embeddings.device

        print_message(f'#> query expansion from feedback documents; (dim {args.dim}, org_qlen {org_qlen}) -> (dim {args.dim}, org_qlen {org_qlen} + fb_k {args.fb_k}) (beta={args.beta})')
        n_fb_pids = [len(pids) for pids in fb_pids.values()]
        if n_fb_pids:
            print_message(f'#> The number of feedback documents for {len(n_fb_pids)} queries: min {np.min(n_fb_pids)}, max {np.max(n_fb_pids)}, mean {np.mean(n_fb_pids):.3f}')
        else:
            # np.min / np.max raise on an empty sequence, so skip the statistics.
            print_message('#> The number of feedback documents for 0 queries: no feedback available')

        # There can be queries without any feedback documents.
        rel_docs_pids = [fb_pids.get(qid, []) for qid in queries]
        # rel_docs_pids: for each qid (outer list), the list of its feedback pids (inner list)

        # Expand query
        print_message(f'#> Expand query')
        # Zero-initialised, so queries without feedback keep all-zero expansion slots.
        all_exp_embeddings = torch.zeros(n_queries, args.fb_k, args.dim, dtype=dtype, device=device)
        all_exp_weights = torch.zeros(n_queries, args.fb_k, dtype=dtype, device=device)

        # Extract expansion query terms from the feedback documents, for each query
        for query_index, _pids in enumerate(tqdm(rel_docs_pids)):
            if len(_pids) == 0:
                continue

            exp_embs, exp_weights, _exp_tokens = colbert_prf.expand(
                q_embs=all_query_embeddings[query_index].transpose(0, 1),
                fb_pids=_pids,
            )
            # exp_embs: presumably (<= fb_k, dim) expansion embeddings -- confirm against ColbertPRF.expand
            # exp_weights: presumably (<= fb_k,) weights of the expansion embeddings

            # Fewer than fb_k expansion terms leave the remaining slots at zero.
            all_exp_embeddings[query_index, :len(exp_embs)] = exp_embs
            all_exp_weights[query_index, :len(exp_weights)] = exp_weights

        # Expand query embeddings, along with corresponding weights
        all_exp_embeddings = all_exp_embeddings.permute(0, 2, 1).contiguous()
        # all_exp_embeddings: size (n_queries, dim, fb_k)
        all_query_embeddings = torch.cat((all_query_embeddings, all_exp_embeddings), dim=-1)
        # all_query_embeddings: size (n_queries, dim, org_qlen + fb_k)
        if all_query_weights is None:
            # Size the default weights from the embeddings themselves so that
            # weights and embeddings always agree on the query length.
            all_query_weights = torch.ones(n_queries, org_qlen, dtype=all_exp_weights.dtype, device=all_exp_weights.device)
        all_query_weights = torch.cat((all_query_weights, all_exp_weights), dim=1)
        # all_query_weights: size (n_queries, org_qlen + fb_k)

    print('\n\n\n')
    print_message('#> Done!')
    print_message(f'all_query_embeddings (shape {all_query_embeddings.shape}, dtype {all_query_embeddings.dtype}, device {all_query_embeddings.device})')
    print_message(f'all_query_weights (shape {all_query_weights.shape}, dtype {all_query_weights.dtype}, device {all_query_weights.device})')
    print('\n\n\n')
    return all_query_embeddings, all_query_weights


from colbert.utils.utils import batch
from colbert.end_to_end_ranking.faiss_index import FaissIndex
from colbert.end_to_end_ranking.prf_expansion import ColbertPRF
from colbert.labeling.index_part import IndexPartRF
def _pair_with_query_index(pids_per_query, n_queries):
    """Flatten per-query pid lists into a single list of ``(query_index, pid)`` pairs."""
    return flatten([
        [(query_index, pid) for pid in pids_per_query[query_index]]
        for query_index in range(n_queries)
    ])


def _select_top_k(pids, scores, depth):
    """Return the (at most) ``depth`` highest-scoring ``(pids, scores)``, best first."""
    k = min(depth, len(scores))
    if k == 0:
        return [], []

    top = torch.tensor(scores).topk(k, largest=True, sorted=True)
    return torch.tensor(pids)[top.indices].tolist(), top.values.tolist()


def ranking(args):
    """
    Run end-to-end retrieval for ``args.queries`` and log the rankings.

    Pipeline:
      1. Encode all queries and retrieve candidates with the FAISS index.
      2. Score the candidates with ``exact_nn_search``.
      3. If ``args.prf`` is set: take the top ``args.fb_docs`` passages of each
         query as feedback, expand the queries with ``batch_query_expansion``
         and either re-score the first-stage candidates
         (``args.reranking_topk``) or retrieve and score new candidates with
         the expanded queries. The expanded-query ranking is the one logged.
      4. Write the elapsed time to ``elapsed.txt`` under ``Run.path`` and the
         top ``args.depth`` passages per query to ``ranking.tsv`` through
         ``RankingLogger``.

    Args:
        args: namespace providing ``colbert``, ``amp``, ``index_path``,
            ``faiss_index_path``, ``nprobe``, ``part_range``, ``prf``,
            ``queries`` (mapping qid -> query text), ``faiss_depth``, ``depth``,
            ``log_scores`` and, when ``prf`` is set, ``fb_docs`` and
            ``reranking_topk`` plus whatever ``ColbertPRF`` and
            ``batch_query_expansion`` read.
    """
    inference = ModelInference(args.colbert, amp=args.amp)

    ann_faiss_index = FaissIndex(args.index_path, args.faiss_index_path, args.nprobe, args.part_range, inference=inference)
    index = IndexPartRF(args.index_path, dim=inference.colbert.dim, part_range=args.part_range, verbose=True)

    if args.prf:
        colbert_prf = ColbertPRF(args=args, faiss_index=ann_faiss_index, index=index, inference=inference)

    queries = args.queries
    n_queries = len(queries)

    start_time = time.time()

    # End-to-end Retrieval
    with torch.no_grad():

        queries_in_order = list(queries.values())

        # Each ranking holds two mappings indexed by query position:
        #   pids = all_query_rankings[0][query_index]
        #   scores = all_query_rankings[1][query_index]
        all_query_rankings = [defaultdict(list), defaultdict(list)]
        if args.prf:
            all_query_rankings_prf = [defaultdict(list), defaultdict(list)]

        # Encode queries
        Q = inference.queryFromText(queries_in_order)
        # Q: query vectors, presumably (n_queries, query_maxlen, dim), e.g. torch.Size([43, 32, 128])

        # ANN Search
        all_pids = ann_faiss_index.retrieve(args.faiss_depth, Q, verbose=True)
        # all_pids: ANN result pids, one list of pids per query, e.g.
        #   all_pids[0][:6] = [7602183, 6717451, 3833877, 5046296, 4292638, 1409060]
        #   all_pids[1][:6] = [425985, 425987, 425989, 425990, 425993, 1015819]

        # Exact-NN Search: every original query token gets a weight of 1.
        Q_wts = torch.ones(Q.size(0), Q.size(1), dtype=Q.dtype, device=Q.device)
        all_pids = _pair_with_query_index(all_pids, n_queries)
        exact_nn_search(
            index=index,
            Q=Q.transpose(1, 2).contiguous(), Q_wts=Q_wts,
            all_pids=all_pids,
            ranking_template=all_query_rankings,
        )
        # ``all_query_rankings`` is updated

        if args.prf:

            # Obtain PRF: the best ``fb_docs`` first-stage passages of each query.
            fb_ranking = OrderedDict()

            for query_index, qid in enumerate(queries):
                top_pids, _ = _select_top_k(
                    all_query_rankings[0][query_index],
                    all_query_rankings[1][query_index],
                    args.depth,
                )
                # Queries without any scored candidate get no feedback entry.
                if not top_pids:
                    continue

                fb_ranking[qid] = top_pids[:args.fb_docs]

            # Query expansion (expects and returns (n_queries, dim, query_len) embeddings)
            Q, Q_wts = batch_query_expansion(args, colbert_prf, queries, fb_ranking, Q.transpose(1, 2).contiguous(), Q_wts)
            Q = Q.transpose(1, 2).contiguous()

            if args.reranking_topk:
                # QE: ReRanking -- re-score the first-stage candidates with the expanded queries.
                exact_nn_search(
                    index=index,
                    Q=Q.transpose(1, 2).contiguous(), Q_wts=Q_wts,
                    all_pids=all_pids,
                    ranking_template=all_query_rankings_prf,
                )

            else:
                # QE: Ranking -- retrieve new candidates with the expanded queries.
                # ANN Search
                expanded_pids = ann_faiss_index.retrieve(args.faiss_depth, Q, verbose=True)
                # expanded_pids: ANN result pids, one list of pids per query

                # Exact-NN Search
                expanded_pids = _pair_with_query_index(expanded_pids, n_queries)
                exact_nn_search(
                    index=index,
                    Q=Q.transpose(1, 2).contiguous(), Q_wts=Q_wts,
                    all_pids=expanded_pids,
                    ranking_template=all_query_rankings_prf,
                )

    end_time = time.time()
    with open(os.path.join(Run.path, 'elapsed.txt'), 'w') as file:
        file.write(f'{end_time-start_time}\n')

    # With PRF enabled, the ranking obtained from the expanded queries is the final one.
    final_rankings = all_query_rankings_prf if args.prf else all_query_rankings

    # Save results
    ranking_logger = RankingLogger(Run.path, qrels=None, log_scores=args.log_scores)
    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        with torch.no_grad():
            for query_index, qid in enumerate(queries):
                if query_index % 1000 == 0:
                    print_message("#> Logging query #{} (qid {}) now...".format(query_index, qid))

                pids, scores = _select_top_k(
                    final_rankings[0][query_index],
                    final_rankings[1][query_index],
                    args.depth,
                )

                # Nothing was scored for this query: it is left out of the output.
                if not pids:
                    continue

                ranked = [(score, pid, None) for pid, score in zip(pids, scores)]
                assert len(ranked) <= args.depth, (len(ranked), args.depth)

                rlogger.log(qid, ranked, is_ranked=True, print_positions=[1, 2] if query_index % 100 == 0 else [])

    print(ranking_logger.filename)
    print_message('#> Done.\n')