# Copyright 2021 Reranker Author. All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import os
import pickle as pkl
import random
from typing import Union, List, Tuple, Dict

import datasets
import torch
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding
from transformers import PreTrainedTokenizer, BatchEncoding
import json
from .arguments import DataArguments, RerankerTrainingArguments


class TrainDataset(Dataset):
    """Passage-level training set: one group of encoded candidates per query.

    Every row of the JSON data is expected to expose three keys, as read by
    ``__getitem__``: ``'qry'`` (an id looked up in the query table), ``'pos'``
    (ids of positive passages) and ``'neg'`` (ids of negative passages).
    Query and passage texts come from two tab-separated files whose names are
    appended to ``args.corpus_path``.
    """

    def __init__(
            self,
            args: "DataArguments",
            path_to_tsv: Union[List[str], str],
            tokenizer: "PreTrainedTokenizer",
            train_args: "RerankerTrainingArguments" = None,
    ):
        """Load the training groups and the id -> text lookup tables.

        Args:
            args: Data options; ``corpus_path``, ``max_len`` and
                ``train_group_size`` are read by this class.
            path_to_tsv: JSON file(s) handed to ``datasets.load_dataset``
                (the name is historical; the loader used is ``'json'``).
            tokenizer: Tokenizer used to encode every candidate.
            train_args: Stored on the instance; not used by this class.
        """
        self.nlp_dataset = datasets.load_dataset(
            'json',
            data_files=path_to_tsv,
        )['train']

        self.tok = tokenizer
        self.SEP = [self.tok.sep_token_id]
        self.args = args
        self.total_len = len(self.nlp_dataset)
        self.train_args = train_args
        # corpus_path is used as a plain string prefix, so it has to end with
        # a path separator for these to resolve inside a directory.
        self.qid2txt = self.read_txt(args.corpus_path + 'queries.all.tsv')
        self.pid2txt = self.read_txt(args.corpus_path + 'collection.tsv')

    def read_txt(self, query_file):
        """Read a ``id<TAB>text`` file into a dict mapping id to text.

        Only the first two tab-separated columns are used. Blank lines are
        skipped; a non-blank line without a second column raises
        ``IndexError``.
        """
        qid2txt = {}
        # Context manager so the handle is closed even if a line is malformed.
        with open(query_file, 'r', encoding='utf-8') as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                items = line.split('\t')
                qid2txt[items[0]] = items[1]
        return qid2txt

    def __len__(self):
        """Return the number of query groups."""
        return self.total_len

    def create_one_example(self, doc_encoding: str):
        """Tokenize one already-joined query/passage string without padding."""
        return self.tok.encode_plus(
            doc_encoding,
            truncation=True,
            max_length=self.args.max_len,
            padding=False,
        )

    def __getitem__(self, item) -> "List[BatchEncoding]":
        """Build the candidate group for the query at index ``item``.

        Returns:
            ``train_group_size`` encodings. Exactly one slot, chosen at
            random, holds a randomly chosen positive; every encoding carries
            that slot's index under ``'label'``.
        """
        group = self.nlp_dataset[item]
        qry = self.qid2txt[group['qry']]
        pos_pid = random.choice(group['pos'])
        neg_group = group['neg']
        group_size = self.args.train_group_size
        if len(neg_group) < group_size:
            # Too few negatives to fill the group: draw with replacement.
            negs = random.choices(neg_group, k=group_size)
        else:
            negs = random.sample(neg_group, k=group_size)
        # Overwrite one random slot with the positive passage.
        idx = random.randint(0, group_size - 1)
        negs[idx] = pos_pid
        group_batch = []
        for pid in negs:
            encoded = self.create_one_example(qry + ', text: ' + self.pid2txt[pid])
            encoded['label'] = idx
            group_batch.append(encoded)
        return group_batch

class PredictionDataset(Dataset):
    """Passage-level inference set: encodes the candidate list of each query.

    Every row of the JSON data is expected to expose ``'qry'`` (an id looked
    up in the query table) and ``'neg'`` (candidate passage ids), as read by
    ``__getitem__``. Texts come from two tab-separated files whose names are
    appended to ``args.corpus_path``.
    """

    query_columns = ['qid', 'query']
    document_columns = ['pid', 'passage']

    def __init__(self, args: "DataArguments", path_to_json: List[str], tokenizer: "PreTrainedTokenizer", max_len=128):
        """Load the candidate groups and the id -> text lookup tables.

        Args:
            args: Data options; ``corpus_path``, ``max_len`` and
                ``train_group_size`` are read by this class.
            path_to_json: JSON file(s) handed to ``datasets.load_dataset``.
            tokenizer: Tokenizer used to encode every candidate.
            max_len: Stored as ``self.max_len``; note that encoding itself
                truncates to ``args.max_len``, not to this value.
        """
        self.nlp_dataset = datasets.load_dataset(
            'json',
            data_files=path_to_json,
        )['train']
        self.args = args
        self.tok = tokenizer
        self.max_len = max_len
        # corpus_path is used as a plain string prefix, so it has to end with
        # a path separator for these to resolve inside a directory.
        self.qid2txt = self.read_txt(args.corpus_path + 'queries.all.tsv')
        self.pid2txt = self.read_txt(args.corpus_path + 'collection.tsv')

    def __len__(self):
        """Return the number of query groups."""
        return len(self.nlp_dataset)

    def read_txt(self, query_file):
        """Read a ``id<TAB>text`` file into a dict mapping id to text.

        Only the first two tab-separated columns are used. Blank lines are
        skipped; a non-blank line without a second column raises
        ``IndexError``.
        """
        qid2txt = {}
        # Context manager so the handle is closed even if a line is malformed.
        with open(query_file, 'r', encoding='utf-8') as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                items = line.split('\t')
                qid2txt[items[0]] = items[1]
        return qid2txt

    def create_one_example(self, doc_encoding: str):
        """Tokenize one already-joined query/passage string without padding."""
        return self.tok.encode_plus(
            doc_encoding,
            truncation=True,
            max_length=self.args.max_len,
            padding=False,
        )

    def __getitem__(self, item) -> "List[BatchEncoding]":
        """Encode the candidates of the query at index ``item``.

        Returns:
            One encoding per candidate id, in the order stored in the row,
            limited to the first ``args.train_group_size`` candidates.
        """
        group = self.nlp_dataset[item]
        qry = self.qid2txt[group['qry']]
        # NOTE(review): the cut-off is train_group_size even though this is
        # the prediction dataset - confirm this is intended.
        candidates = group['neg'][:self.args.train_group_size]
        return [
            self.create_one_example(qry + ', text: ' + self.pid2txt[pid])
            for pid in candidates
        ]

class DocTrainDataset(Dataset):
    """Document-level training set that splits each document body into parts.

    Every row of the JSON data is expected to expose ``'qry'``, ``'pos'`` and
    ``'neg'``, as read by ``__getitem__``. Each selected document contributes
    ``args.part_num`` encodings, one per window of ``args.part_len``
    whitespace-separated body tokens.
    """

    def __init__(
            self,
            args: "DataArguments",
            path_to_tsv: Union[List[str], str],
            tokenizer: "PreTrainedTokenizer",
            train_args: "RerankerTrainingArguments" = None,
    ):
        """Load the training groups and the document/query corpora.

        Args:
            args: Data options; ``corpus_path``, ``max_len``,
                ``train_group_size``, ``part_len`` and ``part_num`` are read.
            path_to_tsv: JSON file(s) handed to ``datasets.load_dataset``
                (the name is historical; the loader used is ``'json'``).
            tokenizer: Tokenizer used to encode every query/document pair.
            train_args: Stored on the instance; not used by this class.
        """
        self.nlp_dataset = datasets.load_dataset(
            'json',
            data_files=path_to_tsv,
        )['train']

        self.tok = tokenizer
        self.args = args
        self.total_len = len(self.nlp_dataset)
        self.train_args = train_args
        self.read_corpus()
        self.part_len = args.part_len
        self.part_num = args.part_num

    def read_corpus(self):
        """Load documents and queries and build the id lookup tables.

        Sets ``self.collection`` (the document table), ``self.doc_map``
        (document id -> row index in ``self.collection``) and ``self.qry_map``
        (query id as ``str`` -> query text). When ``./doc_map.pkl`` does not
        exist, both maps are also pickled into the working directory; nothing
        in this class reads those pickles back.
        """
        corpus_dir = self.args.corpus_path
        # NOTE(review): `ignore_verifications` may be deprecated in favour of
        # `verification_mode` in newer `datasets` releases - check the pinned
        # version before upgrading.
        self.collection = datasets.load_dataset(
            'csv',
            data_files=[corpus_dir + '/msmarco-docs.tsv'],
            column_names=['did', 'url', 'title', 'body'],
            delimiter='\t',
            ignore_verifications=True,
        )['train']
        qry_collection = datasets.load_dataset(
            'csv',
            data_files=[corpus_dir + '/msmarco-doctrain-queries.tsv'],
            column_names=['qid', 'qry'],
            delimiter='\t',
            ignore_verifications=True,
        )['train']

        self.doc_map = {row['did']: idx for idx, row in enumerate(self.collection)}
        self.qry_map = {str(row['qid']): row['qry'] for row in qry_collection}

        if not os.path.exists('./doc_map.pkl'):
            # Context managers make sure the pickles are flushed and closed.
            with open('doc_map.pkl', 'wb') as handle:
                pkl.dump(self.doc_map, handle)
            with open('qry_map.pkl', 'wb') as handle:
                pkl.dump(self.qry_map, handle)

    def create_one_example(self, qry_encoding: str, doc_encoding: str):
        """Tokenize a (query, document part) pair, truncating only the document."""
        return self.tok.encode_plus(
            qry_encoding,
            doc_encoding,
            truncation='only_second',
            max_length=self.args.max_len,
            padding=False,
        )

    def __len__(self):
        """Return the number of query groups."""
        return self.total_len

    def __getitem__(self, item) -> "List[BatchEncoding]":
        """Build the flattened group of document parts for one query.

        Returns:
            ``train_group_size * part_num`` encodings, document by document.
            Each carries ``'label'`` (the slot of the positive document within
            the group) and ``'part'`` (1 for a real body window or the first
            window, 0 for a ``'null'`` filler used when the body is exhausted).
        """
        group = self.nlp_dataset[item]
        # qry_map keys are built with str(), so normalise the id the same way.
        qry = self.qry_map[str(group['qry'])]
        pos_pid = random.choice(group['pos'])
        ngroup = group['neg']
        group_size = self.args.train_group_size
        if len(ngroup) < group_size:
            # Too few negatives to fill the group: draw with replacement.
            negs = random.choices(ngroup, k=group_size)
        else:
            negs = random.sample(ngroup, k=group_size)
        # Overwrite one random slot with the positive document.
        idx = random.randint(0, group_size - 1)
        negs[idx] = pos_pid

        sep = self.tok.sep_token
        group_batch = []
        for did in negs:
            obj = self.collection[self.doc_map[did]]
            # Missing columns come back falsy; treat them as empty strings.
            url, title, body = (obj[key] or '' for key in ('url', 'title', 'body'))
            prefix = url + sep + title + sep
            body_token = body.split()
            for i in range(self.part_num):
                body_select = body_token[i * self.part_len: (i + 1) * self.part_len]
                if body_select or i == 0:
                    # The first window is always kept, even for an empty body.
                    doc_detail = prefix + ' '.join(body_select)
                    part = 1
                else:
                    doc_detail = prefix + 'null'
                    part = 0
                encoded = self.create_one_example(qry, doc_detail)
                encoded['label'] = idx
                encoded['part'] = part
                group_batch.append(encoded)
        return group_batch


class DocPredictionDataset(Dataset):
    """Document-level inference set that splits each candidate body into parts.

    Every row of the JSON data is expected to expose ``'qry'`` and ``'neg'``
    (candidate document ids), as read by ``__getitem__``. Each group is padded
    so that it always holds ``args.eval_group_size * args.eval_part_num``
    encodings.
    """

    def __init__(self, args: "DataArguments", path_to_json: List[str], tokenizer: "PreTrainedTokenizer", max_len=128):
        """Load the candidate groups and the document/query corpora.

        Args:
            args: Data options; ``corpus_path``, ``max_len``, ``part_len``,
                ``eval_group_size`` and ``eval_part_num`` are read.
            path_to_json: JSON file(s) handed to ``datasets.load_dataset``.
            tokenizer: Tokenizer used to encode every query/document pair.
            max_len: Maximum length used only for the filler encodings that
                pad an under-filled group; real candidates use ``args.max_len``.
        """
        self.nlp_dataset = datasets.load_dataset(
            'json',
            data_files=path_to_json,
        )['train']
        self.tok = tokenizer
        self.max_len = max_len
        self.args = args
        self.read_corpus()
        self.part_len = args.part_len

    def read_corpus(self):
        """Load documents and queries and build the id lookup tables.

        Sets ``self.collection`` (the document table), ``self.doc_map``
        (document id -> row index in ``self.collection``) and ``self.qry_map``
        (query id as ``str`` -> query text). When ``./doc_map.pkl`` does not
        exist, both maps are also pickled into the working directory; nothing
        in this class reads those pickles back.
        """
        corpus_dir = self.args.corpus_path
        # NOTE(review): `ignore_verifications` may be deprecated in favour of
        # `verification_mode` in newer `datasets` releases - check the pinned
        # version before upgrading.
        self.collection = datasets.load_dataset(
            'csv',
            data_files=[corpus_dir + '/msmarco-docs.tsv'],
            column_names=['did', 'url', 'title', 'body'],
            delimiter='\t',
            ignore_verifications=True,
        )['train']
        qry_collection = datasets.load_dataset(
            'csv',
            data_files=[corpus_dir + '/msmarco-doctrain-queries.tsv'],
            column_names=['qid', 'qry'],
            delimiter='\t',
            ignore_verifications=True,
        )['train']

        self.doc_map = {row['did']: idx for idx, row in enumerate(self.collection)}
        self.qry_map = {str(row['qid']): row['qry'] for row in qry_collection}

        if not os.path.exists('./doc_map.pkl'):
            # Context managers make sure the pickles are flushed and closed.
            with open('doc_map.pkl', 'wb') as handle:
                pkl.dump(self.doc_map, handle)
            with open('qry_map.pkl', 'wb') as handle:
                pkl.dump(self.qry_map, handle)

    def __len__(self):
        """Return the number of query groups."""
        return len(self.nlp_dataset)

    def create_one_example(self, qry_encoding: str, doc_encoding: str):
        """Tokenize a (query, document part) pair, truncating only the document."""
        return self.tok.encode_plus(
            qry_encoding,
            doc_encoding,
            truncation='only_second',
            max_length=self.args.max_len,
            padding=False,
        )

    def __getitem__(self, item):
        """Encode every part of every candidate document for one query.

        Returns:
            A list of exactly ``eval_group_size * eval_part_num`` encodings
            whenever the row has at most ``eval_group_size`` candidates: the
            real parts first (document by document), followed by filler
            encodings of the query with an empty document. Each carries
            ``'part'``: 1 for a real body window or the first window of a
            document, 0 for a ``'null'`` window or a filler.
        """
        group = self.nlp_dataset[item]
        # qry_map keys are built with str(), so normalise the id the same way.
        qry = self.qry_map[str(group['qry'])]
        negs = group['neg'][:self.args.eval_group_size]
        part_num = self.args.eval_part_num
        sep = self.tok.sep_token

        group_batch = []
        for did in negs:
            obj = self.collection[self.doc_map[did]]
            # Missing columns come back falsy; treat them as empty strings.
            url, title, body = (obj[key] or '' for key in ('url', 'title', 'body'))
            prefix = url + sep + title + sep
            body_token = body.split()
            for i in range(part_num):
                body_select = body_token[i * self.part_len: (i + 1) * self.part_len]
                if body_select or i == 0:
                    # The first window is always kept, even for an empty body.
                    doc_detail = prefix + ' '.join(body_select)
                    part = 1
                else:
                    doc_detail = prefix + 'null'
                    part = 0
                encoded = self.create_one_example(qry, doc_detail)
                encoded['part'] = part
                group_batch.append(encoded)

        # Pad short groups so every group has the same number of entries.
        # Fillers must go into group_batch; `group` is the raw dataset row.
        missing = self.args.eval_group_size * part_num - len(group_batch)
        for _ in range(missing):
            filler = self.tok.encode_plus(
                qry,
                '',
                truncation='only_second',
                max_length=self.max_len,
                padding=False,
            )
            filler['part'] = 0
            group_batch.append(filler)
        return group_batch

@dataclass
class GroupCollator(DataCollatorWithPadding):
    """
    Collator for datasets whose items are groups (lists) of encodings.

    When the first feature is a list, the per-item groups are flattened into
    one list of encodings before being handed to the parent collator's
    ``__call__``. Features that are already flat are passed through unchanged.
    """

    def __call__(
            self, features
    ) -> Dict[str, torch.Tensor]:
        """Flatten grouped features (if needed) and delegate to the parent.

        Args:
            features: Either a list of encodings or a list of lists of
                encodings; only the first element is inspected to decide.

        Returns:
            Whatever the parent collator returns for the flattened features.
        """
        if isinstance(features[0], list):
            # Single-pass flatten; sum(features, []) copies the accumulator
            # on every step and is quadratic in the number of groups.
            features = [encoding for group in features for encoding in group]
        return super().__call__(features)

