import os
import json
import pickle
import numpy as np
from tqdm import tqdm
import collections
from collections import defaultdict, Counter
from pathlib import Path
from datasets import load_dataset, Dataset
import pandas as pd
import random

def _hf_to_records(dataset, id_key, input_key, output_key):
    """Normalize an iterable of HF examples into {'id', 'input', 'output'} dicts."""
    return [
        {
            'id': ins[id_key],
            'input': ins[input_key],
            'output': ins[output_key],
        }
        for ins in dataset
    ]


def _load_entity_questions(root_dir, split):
    """Read every ``*.json`` file of one EntityQuestions split under ``root_dir``.

    Each file is expected to hold a JSON list of objects with 'question' and
    'answers' keys; record ids are ``<file prefix>-<index within file>``.
    """
    if root_dir is None:
        raise ValueError('Please provide the correct root directory for entity_question.')
    # The on-disk directory for the validation split is called 'dev'.
    split_dir = 'dev' if split == 'validation' else split
    eq_dir = os.path.join(root_dir, 'data', 'dataset', split_dir)
    # os.listdir order is arbitrary; sort so output order (and ids) are reproducible.
    data_files = sorted(f for f in os.listdir(eq_dir) if f.endswith('.json'))

    records = []
    for data_file in data_files:
        with open(os.path.join(eq_dir, data_file), 'r', encoding='utf-8') as f:
            ins_list = json.load(f)
        # Prefix before the first dot, e.g. 'P17' for 'P17.dev.json'.
        file_index = data_file.split('.')[0]
        records.extend(
            {
                'id': f'{file_index}-{ins_idx}',
                'input': ins['question'],
                'output': ins['answers'],
            }
            for ins_idx, ins in enumerate(ins_list)
        )
    return records


def get_dataset(task, split, root_dir=None):
    """
        Load one split of a DenseX evaluation dataset.

        Supported tasks: nq, tqa, webq, squad, entity_question.

        Args:
            task: dataset name (one of the above).
            split: split name, e.g. 'train', 'validation', 'test'.
            root_dir: repository root; required only for 'entity_question'.

        Returns:
            List of dicts with keys 'id', 'input' (the question) and 'output'
            (the dataset's answer field, in its native format). An empty list
            is returned for splits the dataset does not provide
            (webq/validation, squad/test).

        Raises:
            ValueError: for an unknown task, or entity_question without root_dir.
    """
    if task == 'nq':
        # Natural Questions is taken from KILT, whose config name is 'nq'.
        dataset = load_dataset('kilt_tasks', name='nq', split=split)
        return _hf_to_records(dataset, 'id', 'input', 'output')

    if task == 'tqa':
        dataset = load_dataset('trivia_qa', name='rc', split=split)
        return _hf_to_records(dataset, 'question_id', 'question', 'answer')

    if task == 'webq':
        # web_questions has no validation split.
        if split == 'validation':
            return []
        dataset = load_dataset('web_questions', split=split)
        return _hf_to_records(dataset, 'url', 'question', 'answers')

    if task == 'squad':
        # squad has no test split.
        if split == 'test':
            return []
        dataset = load_dataset('squad', split=split)
        return _hf_to_records(dataset, 'id', 'question', 'answers')

    if task == 'entity_question':
        return _load_entity_questions(root_dir, split)

    raise ValueError(f"Dataset {task} is beyond the consideration.")

def get_wiki_factoid():
    """
        Return the wiki factoid passages as an id -> contents mapping.

        Reads ``<ROOT_DIR>/retrieval/topics/passages/factoid-wiki.jsonl``, where
        ROOT_DIR comes from the environment (falls back to the literal
        'default'). Each non-blank line must be a JSON object with 'id' and
        'contents' keys; if an id repeats, the last occurrence wins.

        Returns:
            {wiki_passage_id (str): content (str)}
    """
    root_dir = os.getenv('ROOT_DIR', 'default')
    factoid_path = os.path.join(root_dir, 'retrieval/topics/passages/factoid-wiki.jsonl')

    factoid_id_content_dict = {}
    with open(factoid_path, 'r', encoding='utf-8') as f:
        # Stream straight into the dict instead of first materializing a list of
        # every parsed record, which roughly halves peak memory on a large corpus.
        for line in tqdm(f, desc="Reading file"):
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads
            record = json.loads(line)
            factoid_id_content_dict[record['id']] = record['contents']

    return factoid_id_content_dict