#from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from dataclasses import dataclass, field
import logging
import random
import json
import jsonlines
import math
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizerFast
from tqdm import tqdm
from typing import Optional

from .tokenizer import get_tokenizer
from .metric import compute_f1_sets, compute_exact_sets

logger = logging.getLogger(__name__)

@dataclass
class DataPredArguments:
    """Arguments describing the data used for a prediction run.

    Raises:
        ValueError: if ``pred_file`` is left as ``None``.
    """

    # Path to the file with the examples to predict on (required).
    pred_file: Optional[str] = field(default=None)
    # Length that tokenized inputs are padded/truncated to.
    max_length: Optional[int] = field(default=128)

    def __post_init__(self):
        # pred_file defaults to None only so it can be declared as an optional
        # field; a prediction run cannot proceed without it.
        if self.pred_file is None:
            raise ValueError("Need a prediction file (pred_file).")

class DPRPredDataset(Dataset):
    """Question-only dataset used at prediction time.

    Reads a JSON Lines file, collects the ``"question"`` field of every
    record and tokenizes all questions up front. Each item exposes the raw
    question string together with its encoded tensors.

    Args:
        input_file: path of the JSON Lines file to read.
        tokenizer: object providing ``batch_encode_plus``; the encoding it
            returns is indexed with ``"input_ids"``, ``"token_type_ids"`` and
            ``"attention_mask"`` (presumably a Hugging Face BERT-style
            tokenizer -- confirm against the caller).
        config: object exposing ``max_length``.

    Raises:
        ValueError: if the number of encoded rows differs from the number of
            questions read.
    """

    def __init__(self, input_file, tokenizer, config):
        self.config = config
        self.tokenizer = tokenizer
        self.max_length = self.config.max_length
        # get_data() returns exactly (questions, questions_tok). Unpacking
        # anything more than that (answers, BM25 fields, ...) fails with a
        # ValueError before the dataset can be built.
        self.questions, self.questions_tok = self.get_data(input_file)
        n_encoded = len(self.questions_tok["input_ids"])
        if n_encoded != len(self.questions):
            raise ValueError(
                "Tokenizer returned {} encodings for {} questions.".format(
                    n_encoded, len(self.questions)
                )
            )

    def get_data(self, fname):
        """Load the questions from ``fname`` and tokenize them.

        Args:
            fname: path of a JSON Lines file whose records each carry a
                ``"question"`` key.

        Returns:
            A ``(questions, questions_tok)`` tuple: the list of question
            values and the tokenizer's batch encoding, padded/truncated to
            ``self.max_length`` and requested as PyTorch tensors.
        """
        # All records are materialized so the file can be closed before the
        # (potentially slow) tokenization step.
        with jsonlines.open(fname, "r") as reader:
            qas = [
                r for r in tqdm(
                    reader, desc="Reading {}".format(fname.split('/')[-1]))
            ]

        questions = [
            qa["question"] for qa in tqdm(qas, desc="Building dataset")
        ]

        print("Tokenizing questions")
        questions_tok = self.tokenizer.batch_encode_plus(
            questions,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        )
        return questions, questions_tok

    def __len__(self):
        """Return the number of encoded questions."""
        return len(self.questions_tok["input_ids"])

    def __getitem__(self, idx):
        """Return the question at ``idx`` with its encoded fields."""
        encoded = self.questions_tok
        return {
            "question": self.questions[idx],
            "q_input_ids": encoded["input_ids"][idx],
            "q_token_type_ids": encoded["token_type_ids"][idx],
            "q_attention_mask": encoded["attention_mask"][idx],
        }

def pred_data_collator(samples):
    """Merge a list of dataset items into a single prediction batch.

    Args:
        samples: items carrying the keys ``"question"``, ``"q_input_ids"``,
            ``"q_token_type_ids"`` and ``"q_attention_mask"``.

    Returns:
        An empty dict when ``samples`` is empty; otherwise a dict holding the
        batch size (``"n_samples"``), the list of questions and the three
        encoded fields stacked along a new leading dimension.
    """
    batch_size = len(samples)
    if batch_size == 0:
        return {}

    def _stack(key):
        # torch.stack joins the per-sample tensors along a new first axis.
        return torch.stack([sample[key] for sample in samples])

    batch = {"n_samples": batch_size}
    batch["questions"] = [sample["question"] for sample in samples]
    batch["input_ids"] = _stack("q_input_ids")
    batch["token_type_ids"] = _stack("q_token_type_ids")
    batch["attention_mask"] = _stack("q_attention_mask")
    return batch
