import os
import ast
import json
import pickle
import numpy as np
from tqdm import tqdm
from collections import defaultdict, Counter
from pathlib import Path
from datasets import load_dataset, Dataset
import pandas as pd
import random
import collections
from scipy.stats import pearsonr
random.seed(9001)

import csv

# Base directory read from the ROOT_DIR environment variable; falls back to
# the literal string 'default' when the variable is not set.
ROOT_DIR = os.getenv('ROOT_DIR', 'default')

# Encoder names. Not referenced in the visible part of this file; presumably
# the retriever encoders whose outputs get evaluated (TODO confirm with callers).
encoders = ['ance', 'contriever', 'dpr', 'gtr', 'simcse', 'tasb']

def percent_round(_value):
    """Scale a fraction to a percentage and round it to one decimal place."""
    as_percent = _value * 100
    return round(as_percent, 1)

def append_to_result(result, file):
    """Write ``result`` as one newline-terminated line at the end of ``file``.

    A dict is serialised with ``json.dumps`` first; any other value is
    expected to support ``+ '\\n'`` (i.e. to be a string). The file is created
    when it does not exist yet and appended to otherwise.
    """
    # Pick the open mode from whether the target already exists.
    mode = 'a' if os.path.exists(file) else 'w'
    line = json.dumps(result) if isinstance(result, dict) else result

    with open(file, mode) as handle:
        handle.write(line + '\n')

def answer_passage_hit(answer_list, passage_list):
    """Return whether any gold answer occurs in the retrieved contents.

    The contents are joined with single spaces into one string and each answer
    is looked up as a case-insensitive substring of that string. Because of the
    join, an answer may also match across the boundary of two adjacent
    contents.

    Args:
        answer_list: list of gold answer strings.
        passage_list: list of retrieved content strings.

    Returns:
        True if at least one answer is found, otherwise False (including when
        ``answer_list`` is empty).
    """
    # Lower-case the joined text once instead of once per answer.
    haystack = ' '.join(passage_list).lower()
    return any(answer.lower() in haystack for answer in answer_list)

def answer_passage_hit_pos(answer_list, passage_list):
    """Return the indices of the contents that contain any gold answer.

    Matching is a case-insensitive substring test of each answer against each
    individual content string.

    Args:
        answer_list: list of gold answer strings.
        passage_list: list of retrieved content strings.

    Returns:
        List of unique 0-based indices into ``passage_list``, in ascending
        order. Empty when nothing matches.
    """
    # Normalise the answers once; each passage is lower-cased once as well.
    needles = [answer.lower() for answer in answer_list]

    hit_pos = []
    # Walking the passages in order yields each index at most once and already
    # sorted, so the result is deterministic (no set round-trip needed).
    for passage_idx, passage in enumerate(passage_list):
        lowered = passage.lower()
        if any(needle in lowered for needle in needles):
            hit_pos.append(passage_idx)

    return hit_pos


def remove_duplicates_preserve_order(seq):
    """Return the items of ``seq`` as a list, keeping only first occurrences.

    Items must be hashable. Relative order of the surviving items is the order
    in which they first appear in ``seq``.
    """
    # dict keys are unique and remember insertion order.
    return list(dict.fromkeys(seq))

def evaluate(qid_answer_dict,
             qid_content_list_dict,
             topk=20):
    """Compute top-k answer-hit accuracy per dataset plus macro/micro averages.

    A query counts as a hit when ``answer_passage_hit`` finds one of its
    answers in the first ``topk`` contents after duplicates are removed.
    Queries are assigned to a dataset by the prefix of their id.

    Args:
        qid_answer_dict: {qid: [answer1, answer2]}. Every qid must start with
            'natural_questions', 'web_questions', 'entity_question', 'squad'
            or 'trivia'.
        qid_content_list_dict: {qid: [content_str1, content_str2]} (content can
            be passage or proposition). Must contain every qid of
            ``qid_answer_dict``.
        topk: number of de-duplicated contents inspected per query.

    Returns:
        dict with
          * 'nq', 'squad', 'tqa': accuracy in percent, one decimal
            (0.0 for a dataset without queries),
          * 'macro': mean of the reported accuracies of the non-empty datasets,
          * 'micro': accuracy over all nq/squad/tqa queries pooled together,
          * 'nq_hit_qids', 'squad_hit_qids', 'tqa_hit_qids': ids of hit queries.
        web_questions and entity_question queries are tallied but not reported.

    Raises:
        Exception: if a query id does not start with a known dataset prefix.
    """
    # (qid prefix, short dataset name); no prefix is a prefix of another one.
    prefix_to_name = (
        ('natural_questions', 'nq'),
        ('web_questions', 'webq'),
        ('entity_question', 'eq'),
        ('squad', 'squad'),
        ('trivia', 'tqa'),
    )
    # Datasets that end up in the returned metrics, in output order.
    reported = ('nq', 'squad', 'tqa')

    # the query id mapped to whether its top-k contents contain an answer
    query_hits = {}
    for qid in tqdm(qid_answer_dict.keys()):
        passages_deduplicate = remove_duplicates_preserve_order(qid_content_list_dict[qid])
        query_hits[qid] = int(answer_passage_hit(qid_answer_dict[qid], passages_deduplicate[:topk]))

    # bucket the hit flags (and the ids of hit queries) per dataset
    hits = {name: [] for _, name in prefix_to_name}
    hit_qids = {name: [] for _, name in prefix_to_name}
    for qid, hit in query_hits.items():
        for prefix, name in prefix_to_name:
            if qid.startswith(prefix):
                hits[name].append(hit)
                if hit:
                    hit_qids[name].append(qid)
                break
        else:
            raise Exception(f'The query ID {qid} is not valid.')

    results = {}
    for name in reported:
        dataset_hits = hits[name]
        # An absent dataset scores 0.0 instead of raising ZeroDivisionError.
        results[name] = round(sum(dataset_hits) / len(dataset_hits) * 100, 1) if dataset_hits else 0.0

    # Macro average only over datasets that actually have queries, so a
    # missing dataset does not drag the average down.
    present_scores = [results[name] for name in reported if hits[name]]
    results['macro'] = round(sum(present_scores) / len(present_scores), 1) if present_scores else 0.0

    total_hits = sum(sum(hits[name]) for name in reported)
    total_queries = sum(len(hits[name]) for name in reported)
    results['micro'] = round(total_hits / total_queries * 100, 1) if total_queries else 0.0

    print("Dataset {0:20}; Length: {1:5}; Acc: {2:4}".format('natural_questions', len(hits['nq']), results['nq']))
    print("Dataset {0:20}; Length: {1:5}; Acc: {2:4}".format('trivia_questions', len(hits['tqa']), results['tqa']))
    print("Dataset {0:20}; Length: {1:5}; Acc: {2:4}".format('squad', len(hits['squad']), results['squad']))

    print("Dataset {0:20}; Length: {1:5}; Acc: {2:4}".format('Macro Avg', '', results['macro']))
    # This line reports the micro average (it used to be mislabelled as macro).
    print("Dataset {0:20}; Length: {1:5}; Acc: {2:4}".format('Micro Avg', '', results['micro']))

    for name in reported:
        results[name + '_hit_qids'] = hit_qids[name]

    return results

def evaluate_qrels(qid_pids_dict, qrels_dict, topk=5):
    """
    Evaluate the retrieval results based on the qrels (recall@topk in percent).

    Args:
        qid_pids_dict: {qid: [pid1, pid2, ...]} retrieved ids in rank order.
            Duplicate pids are dropped (first occurrence kept) before the
            top-k cut.
        qrels_dict: {qid: gold}. ``gold`` may be
              * a single pid,
              * a list/tuple/set of pids, or
              * a {pid: score} mapping such as ``load_qrels`` returns, in
                which case only pids with score > 0 count as relevant.
        topk: number of de-duplicated pids inspected per query.

    Returns:
        Percentage of queries with at least one gold pid among the top-k
        retrieved pids, rounded to two decimals.

    Raises:
        KeyError: if a qid of ``qid_pids_dict`` is missing from ``qrels_dict``.
    """
    recalls = []
    for _qid, _pids in qid_pids_dict.items():
        # Order-preserving de-duplication, then keep the top-k.
        top_pids = list(dict.fromkeys(_pids))[:topk]

        gold = qrels_dict[_qid]
        if isinstance(gold, dict):
            # Nested qrels: zero-scored entries are judged but not relevant.
            gold_pids = [pid for pid, score in gold.items() if score > 0]
        elif isinstance(gold, (list, tuple, set, frozenset)):
            gold_pids = gold
        else:
            gold_pids = [gold]

        recalls.append(int(any(pid in top_pids for pid in gold_pids)))

    return round(np.mean(recalls) * 100, 2)


def load_qrels(qrels_file):
    """Load a tab-separated qrels file into a nested dict.

    The first row is treated as a header and skipped. For every following
    non-empty row the first three columns are read as query id, corpus id and
    integer score.

    Args:
        qrels_file: path of the UTF-8 encoded TSV file.

    Returns:
        {query_id: {corpus_id: score}}; an empty dict for an empty file.

    Raises:
        ValueError: if a score column is not an integer.
        IndexError: if a non-empty row has fewer than three columns.
    """
    qrels = {}

    # The context manager closes the handle; newline='' is what the csv module
    # asks for so that it can do its own newline handling.
    with open(qrels_file, encoding="utf-8", newline='') as f:
        reader = csv.reader(f, delimiter='\t', quoting=csv.QUOTE_MINIMAL)

        # Skip the header row; the default keeps an empty file from raising.
        next(reader, None)

        for _row in reader:
            if not _row:
                # blank line, e.g. trailing newline at the end of the file
                continue
            query_id, corpus_id, score = _row[0], _row[1], int(_row[2])
            qrels.setdefault(query_id, {})[corpus_id] = score

    return qrels

def _fetch_content(searcher, pid):
    """Return the 'contents' field of document ``pid`` flattened to one line."""
    doc = searcher.doc(pid)
    # Newlines become the two characters '\n' and tabs become spaces so the
    # text stays on a single TSV-safe line.
    return json.loads(doc.raw())['contents'].replace('\n', '\\n').replace('\t', ' ')


def get_records(file_path, searcher=None):
    """
        Get the records from the file
        searcher is a pyserini searcher; the code only needs ``searcher.doc(pid)``
        to return an object whose ``raw()`` is a JSON string with a 'contents' key.
        There will be two types of inputs:
        1. .aug (also any path ending in 'tsv' or containing 'rrf'):
           tab-separated file with a header row, e.g. qid, pid, content
        2. .txt: the pyserini retrieval result (space separated, no header;
           columns 0, 2 and 4 are read as qid, pid and score)

        For input type 1 the rows are returned as read; when a searcher is given
        and the rows have no 'content' column, it is filled in from the searcher.
        For input type 2 'content' is looked up through the searcher, or None
        when no searcher is given.

        Return: List of Dict
        [{
            'qid': str,
            'pid': str,
            'content': str
        }]
        (.txt records additionally carry 'score')

        Raises: ValueError if the file type is not one of the supported ones.
    """

    if file_path.endswith(('.aug', 'tsv')) or 'rrf' in file_path:
        # read pd dataframe
        records = pd.read_csv(file_path, sep='\t').to_dict('records')
        # Only hit the index when asked to, when there is something to fill,
        # and when the file does not already carry the content.
        if searcher is not None and records and 'content' not in records[0]:
            for record in records:
                record['content'] = _fetch_content(searcher, record['pid'])
        return records

    if file_path.endswith('.txt'):
        records = []
        # to_records() prepends the index, so file column k sits at position k + 1.
        list_records = pd.read_csv(file_path, sep=' ', header=None).to_records()
        for row in tqdm(list_records):
            content = _fetch_content(searcher, row[3]) if searcher is not None else None
            records.append({
                'qid': row[1],
                'pid': row[3],
                'score': row[5],
                'content': content,
            })
        return records

    # Previously this fell through to an UnboundLocalError on `records`.
    raise ValueError(f'Unsupported retrieval file type: {file_path}')

def preprocess_beir_dataset(retrieval_dir):
    """Rewrite the corpus and query files of ``retrieval_dir`` into id/text JSONL.

    Reads ``corpus.jsonl`` and ``queries.jsonl`` from ``retrieval_dir`` and
    hands the reformatted lines to ``append_to_result`` for
    ``corpus.reformat.jsonl`` and ``queries.reformat.jsonl`` respectively:

      * corpus rows become {'id': _id, 'contents': title + newline + text},
        with newlines inside the text replaced by the two characters '\\n';
      * query rows become {'id': _id, 'title': text}.
    """
    # --- corpus ---
    with open(os.path.join(retrieval_dir, 'corpus.jsonl'), 'r') as corpus_in:
        corpus_docs = [json.loads(raw) for raw in corpus_in]

    corpus_rows = [
        {
            'id': doc['_id'],
            'contents': doc['title'] + '\n' + doc['text'].replace('\n', '\\n'),
        }
        for doc in corpus_docs
    ]
    corpus_out = os.path.join(retrieval_dir, 'corpus.reformat.jsonl')
    append_to_result('\n'.join(json.dumps(row) for row in corpus_rows), corpus_out)
    # the reformatted corpus can then be passed to chunk.py for corpus preprocessing

    # --- queries ---
    with open(os.path.join(retrieval_dir, 'queries.jsonl'), 'r') as queries_in:
        query_docs = [json.loads(raw) for raw in queries_in]

    query_rows = [{'id': doc['_id'], 'title': doc['text']} for doc in query_docs]
    queries_out = os.path.join(retrieval_dir, "queries.reformat.jsonl")
    append_to_result('\n'.join(json.dumps(row) for row in query_rows), queries_out)


def process_decomposed_query(query_decomposition_file):
    """Split decomposed queries into a "whole" and a "multi" query file.

    Each line of ``query_decomposition_file`` is a JSON object with the keys
    'id', 'input' and 'prop_generation'. Objects whose 'prop_generation' has
    at most one element are ignored. For the others:

      * {'id': id, 'title': input} goes to ``queries.whole.jsonl``;
      * one {'id': id + '#' + index, 'title': proposition} per element of
        'prop_generation' goes to ``queries.multi.jsonl``.

    Both output files live next to the input file and are written through
    ``append_to_result``.
    """
    out_dir = os.path.dirname(query_decomposition_file)

    with open(query_decomposition_file, 'r') as source:
        entries = [json.loads(raw) for raw in source]

    whole_rows, multi_rows = [], []
    for entry in entries:
        # Only queries that were split into several propositions are kept.
        if len(entry['prop_generation']) <= 1:
            continue

        query_id = entry['id']
        whole_rows.append({'id': query_id, 'title': entry['input']})
        multi_rows.extend(
            {'id': query_id + '#' + str(position), 'title': proposition}
            for position, proposition in enumerate(entry['prop_generation'])
        )

    append_to_result('\n'.join(json.dumps(row) for row in whole_rows),
                     os.path.join(out_dir, 'queries.whole.jsonl'))
    append_to_result('\n'.join(json.dumps(row) for row in multi_rows),
                     os.path.join(out_dir, 'queries.multi.jsonl'))