import os, numpy as np, csv, random
import ujson
from tqdm import tqdm

from functools import partial
from pruner.utils.utils import print_message
from pruner.modeling.tokenization import DocTokenizer, tensorize

from pruner.utils.runs import Run


def load_data(path, rank, nranks):
    """Read this rank's share of (passage, label) examples from a TSV file.

    Rows are dealt out round-robin: row ``i`` is kept only when
    ``i % nranks == rank``; all other rows are skipped.

    NOTE: For distributed sampling this is not equivalent to perfectly uniform
    sampling, because every rank's subset contributes to every global batch.
    The original author considered this harmless on the assumption that the
    underlying file is pre-shuffled and no triple is visited twice.

    Data format (``pseudo_query_extractor.dev.small.tsv`` and
    ``pseudo_query_extractor.train.tsv``), four tab-separated columns:

        pid         e.g. ``0``
        passage     e.g. ``the presence of communication amid scientific minds ...``
        label       JSON list with the argmax positions of WordPiece tokens,
                    e.g. ``[0, 9, 12, 13, 14, 15, 16, 17, 18, 21, 22, 39, 55]``
        soft_label  JSON list with the max cosine similarity score for each
                    WordPiece token,
                    e.g. ``[0.857, 0.556, 0.572, ..., -1.0, ..., 0.796]``

    The WordPiece tokens include the special tokens ([CLS], [D], and [SEP]) in
    addition to the tokens of the original passage.

    Args:
        path: Path of the UTF-8 encoded TSV file.
        rank: Index of the current process; selects which rows are kept.
        nranks: Total number of processes sharing the file.

    Returns:
        A list of ``(passage, label)`` tuples. ``label`` is a float32 array with
        one entry per soft-label score, set to 1.0 at the positions listed in
        the ``label`` column and 0.0 everywhere else.

    Raises:
        ValueError: If a selected row does not have exactly four columns.
        AssertionError: If a selected row has more than 180 soft-label scores.
    """
    print_message("#> Loading data...")

    examples = []

    with open(path, 'r', encoding='utf-8') as ifile:
        for row_idx, fields in enumerate(csv.reader(ifile, delimiter='\t')):
            # Rows that belong to another rank are skipped without being parsed.
            if row_idx % nranks != rank:
                continue

            _pid, passage, positions_json, scores_json = fields
            positive_positions = ujson.loads(positions_json)

            # The scores themselves are unused; only their count matters, since
            # it fixes the length of the label vector.
            num_bert_tokens = len(ujson.loads(scores_json))
            # NOTE(review): 180 is presumably the maximum document length the
            # file was generated with -- confirm against the preprocessing code.
            assert num_bert_tokens <= 180

            target = np.zeros((num_bert_tokens), dtype=np.float32)
            target[positive_positions] = 1.0
            examples.append((passage, target))

    return examples



class Batcher():
    """Endless iterator over fixed-size batches of (passage, label) examples.

    Batches are served in file order until fewer than ``bsize`` examples
    remain. The leftover tail is then dropped, the data is reshuffled in
    place, and iteration restarts from the beginning. ``StopIteration`` is
    never raised, so the caller decides when to stop.
    """

    def __init__(self, args, rank=0, nranks=1):
        """Load this rank's examples and set up tokenization.

        Args:
            args: Object providing ``bsize``, ``accumsteps``, ``doc_maxlen``
                and ``data`` (path of the TSV file passed to ``load_data``).
            rank: Index of the current process.
            nranks: Total number of processes sharing the data file.
        """
        self.bsize, self.accumsteps = args.bsize, args.accumsteps

        self.doc_tokenizer = DocTokenizer(args.doc_maxlen)
        self.tensorize = partial(tensorize, self.doc_tokenizer) #TODO
        self.position = 0

        self.data = load_data(args.data, rank, nranks)
        # self.data: List[Tuple[str, ndarray(float32)]] = list of (passage, label) tuples

    def __iter__(self): return self

    # Number of loaded examples (not the number of batches).
    def __len__(self): return len(self.data)

    def __next__(self):
        """Return the next tensorized batch, reshuffling once the data runs out.

        Returns:
            Whatever ``self.tensorize`` returns for the batch. Per the original
            author's note this is a list of sub-batches, each consisting of
            (token ids, token masks, token labels) -- see
            ``pruner.modeling.tokenization.tensorize``.

        Raises:
            IndexError: If the dataset holds fewer than ``bsize`` examples, so
                that no full batch can ever be formed.
        """
        offset = self.position
        endpos = offset + self.bsize

        if endpos > len(self.data):
            # Not enough examples left for a full batch: start a new pass.
            if len(self.data) < self.bsize:
                raise IndexError(
                    f"cannot form a batch of {self.bsize} examples "
                    f"from only {len(self.data)} loaded examples"
                )

            random.shuffle(self.data)
            offset, endpos = 0, self.bsize

        # Always advance past the batch being served. Resetting the cursor to 0
        # after a reshuffle would hand out the same first batch twice in a row.
        self.position = endpos

        batch = self.data[offset:endpos]
        passages = [passage for passage, _ in batch]
        labels = [label for _, label in batch]

        return self.collate(passages, labels)

    def collate(self, passages, labels):
        """Tensorize one full batch of passages and their labels.

        The third argument handed to ``self.tensorize`` is
        ``bsize // accumsteps`` -- presumably the sub-batch size used for
        gradient accumulation; confirm against ``tensorize``.
        """
        assert len(passages) == len(labels) == self.bsize

        return self.tensorize(passages, labels, self.bsize // self.accumsteps)
