# from colbert.modeling.tokenization.doc_tokenization import DocTokenizer
import torch
import numpy as np

def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
    """Tokenize training triples and group them into ``(Q, D)`` batches.

    Args:
        query_tokenizer: object whose ``tensorize(queries)`` returns
            ``(ids, mask)``.
        doc_tokenizer: object whose ``tensorize(passages)`` returns
            ``(ids, mask, pruning_mask)``. All three are sliced and reshaped
            below, so ``pruning_mask`` must not be ``None`` here.
        queries: list of N queries.
        positives: list of N positive passages, one per query (must be a
            ``list`` because it is concatenated with the flattened negatives).
        negatives: list of N lists of negative passages, one list per query;
            every inner list must have the same length.
        bsize: number of queries per batch; must divide N. ``None`` means a
            single batch holding all N queries.

    Passages are handed to ``doc_tokenizer`` unchanged. The original note
    describes each passage as a ``(text, token positions used for matching)``
    tuple -- NOTE(review): confirm against the tokenizer in use.

    Returns:
        A list of ``(Q, D)`` pairs. ``Q`` is ``(q_ids, q_mask)`` for one batch
        of queries. ``D`` is ``(ids, mask, pruning_mask)`` whose leading rows
        are that batch's positives, followed by its negatives grouped by
        query (``n_negatives`` consecutive rows per query). Empty input
        yields ``[]``.
    """
    assert len(queries) == len(positives) == len(negatives)
    assert bsize is None or len(queries) % bsize == 0

    N = len(queries)
    if N == 0:
        # Nothing to tokenize; also avoids indexing ``negatives[0]`` below.
        return []
    if bsize is None:
        # The assert above accepts None, but ``range`` needs an integer step.
        bsize = N

    n_negatives = len(negatives[0])
    # ``view(N, n_negatives, -1)`` below would silently fuse passages together
    # if the per-query negative counts differed but still divided evenly.
    assert all(len(group) == n_negatives for group in negatives)

    Q_ids, Q_mask = query_tokenizer.tensorize(queries)
    query_batches = _split_into_batches(Q_ids, Q_mask, bsize=bsize)

    # Tokenize positives and all negatives in one call; the first N rows are
    # the positives, the rest are the negatives in query order.
    flat_negatives = [passage for negative_passages in negatives for passage in negative_passages]
    D_ids, D_mask, D_pruning_mask = doc_tokenizer.tensorize(positives + flat_negatives)

    positive_ids = D_ids[:N]
    positive_mask = D_mask[:N]
    positive_pruning_mask = D_pruning_mask[:N]
    # positive_ids, positive_mask: (``N, doc_maxlen``)
    positive_batches = _split_into_batches_doc(positive_ids, positive_mask, positive_pruning_mask, bsize=bsize)

    # Regroup the negatives per query so they can be batched along dim 0.
    negative_ids = D_ids[N:].view(N, n_negatives, -1)
    negative_mask = D_mask[N:].view(N, n_negatives, -1)
    negative_pruning_mask = D_pruning_mask[N:].view(N, n_negatives, -1)
    # negative_ids, negative_mask: (``N, n_negatives, doc_maxlen``)
    negative_batches = _split_into_batches_doc(negative_ids, negative_mask, negative_pruning_mask, bsize=bsize)

    batches = []
    for (q_ids, q_mask), (p_ids, p_mask, p_pruning_mask), (n_ids, n_mask, n_pruning_mask) \
            in zip(query_batches, positive_batches, negative_batches):

        # Queries are not duplicated per document.
        Q = (q_ids, q_mask)

        # Merge the (query, negative) dims so the negatives can be stacked
        # underneath the positives along dim 0.
        n_ids = n_ids.flatten(0, 1)
        n_mask = n_mask.flatten(0, 1)
        n_pruning_mask = n_pruning_mask.flatten(0, 1)
        D = (
            torch.cat((p_ids, n_ids)),
            torch.cat((p_mask, n_mask)),
            torch.cat((p_pruning_mask, n_pruning_mask)),
        )

        batches.append((Q, D))

    return batches

#!@ original
# def _sort_by_length(ids, mask, bsize):
#     if ids.size(0) <= bsize:
#         return ids, mask, torch.arange(ids.size(0))

#     indices = mask.sum(-1).sort().indices
#     reverse_indices = indices.sort().indices

#     return ids[indices], mask[indices], reverse_indices

#!@ custom: Add ``pruning_mask`` parameter
def _sort_by_length(ids, mask, pruning_mask, bsize):
    """Reorder rows by ascending ``mask.sum(-1)``.

    When there are at most ``bsize`` rows the inputs are returned untouched,
    together with an identity index. Otherwise ``ids``, ``mask`` and
    ``pruning_mask`` are all permuted with the same ordering.

    Returns:
        ``(ids, mask, pruning_mask, reverse_indices)``; indexing the sorted
        tensors with ``reverse_indices`` restores the original row order.
    """
    n_rows = ids.size(0)
    if n_rows <= bsize:
        return ids, mask, pruning_mask, torch.arange(n_rows)

    lengths = mask.sum(-1)
    _, order = torch.sort(lengths)
    # Sorting the permutation itself yields its inverse.
    _, inverse_order = torch.sort(order)

    return ids[order], mask[order], pruning_mask[order], inverse_order


#!@ original
# def _split_into_batches(ids, mask, bsize):
#     batches = []
#     for offset in range(0, ids.size(0), bsize):
#         batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))

#     return batches

#!@ custom: accepts any number of parallel inputs; ``bsize`` is keyword-only
def _split_into_batches(*inputs, bsize):
    """Slice several parallel sequences into aligned chunks of ``bsize`` items.

    The length of the first input decides how many chunks are produced; the
    last chunk may be shorter. Each batch is a tuple holding one slice per
    input, in the order the inputs were given.
    """
    total = len(inputs[0])
    return [
        tuple(seq[start:start + bsize] for seq in inputs)
        for start in range(0, total, bsize)
    ]

#!@ custom: signature changed from (ids, mask, bsize) to (ids, mask, pruning_mask, bsize)
def _split_into_batches_doc(ids, mask, pruning_mask, bsize):
    """Slice document tensors along dim 0 into chunks of ``bsize`` rows.

    Each batch is ``(ids, mask, pruning_mask)``, or just ``(ids, mask)`` when
    ``pruning_mask`` is ``None``. The last chunk may hold fewer rows.
    """
    # Whether a pruning mask is present does not change between chunks, so
    # decide once which tensors take part.
    if pruning_mask is None:
        tensors = (ids, mask)
    else:
        tensors = (ids, mask, pruning_mask)

    return [
        tuple(t[start:start + bsize] for t in tensors)
        for start in range(0, ids.size(0), bsize)
    ]
