import ast
import json
import logging
from collections import defaultdict
from html import unescape

from elasticsearch import Elasticsearch, helpers
from tqdm import tqdm
# from tqdm.auto import tqdm

logger = logging.getLogger(__name__)


def get_valid_links(passage, strict=False, valid_titles=None):
    """Return the usable outgoing hyperlinks of a passage.

    A link is kept when it has at least one mention span, its target differs
    from the passage's own (HTML-unescaped) title, and, if ``valid_titles`` is
    given, its target is contained in ``valid_titles``.

    Args:
        passage: Mapping with a ``'title'`` string and a ``'hyperlinks'`` dict
            of ``target title -> list of mention spans``. In strict mode it
            must also have a non-empty ``'sentence_spans'`` sequence of
            ``(start, end)`` pairs.
        strict: When true, additionally drop every mention span that does not
            satisfy ``first_sentence_start <= start < end < last_sentence_end``
            and drop targets left without any span.
        valid_titles: Optional container of allowed target titles; ``None``
            disables this filter.

    Returns:
        dict mapping each kept target title to its list of mention spans (the
        original list in non-strict mode, a filtered new list in strict mode).

    Raises:
        IndexError: In strict mode when ``passage['sentence_spans']`` is empty.
    """
    # The passage title is loop-invariant, so unescape it only once.
    own_title = unescape(passage['title'])

    def is_candidate(tgt, mention_spans):
        return (len(mention_spans) > 0 and tgt != own_title
                and (valid_titles is None or tgt in valid_titles))

    if not strict:
        # ``.items()`` is required here: iterating the dict directly yields
        # only the keys and cannot be unpacked into (target, spans).
        return {tgt: mention_spans for tgt, mention_spans in passage['hyperlinks'].items()
                if is_candidate(tgt, mention_spans)}

    hyperlinks = {}
    # Character range covered by the passage's first through last sentence.
    abs_start, abs_end = passage['sentence_spans'][0][0], passage['sentence_spans'][-1][1]
    for tgt, mention_spans in passage['hyperlinks'].items():
        if not is_candidate(tgt, mention_spans):
            continue
        valid_mention_spans = [anchor_span for anchor_span in mention_spans
                               if abs_start <= anchor_span[0] < anchor_span[1] < abs_end]
        if len(valid_mention_spans) > 0:
            hyperlinks[tgt] = valid_mention_spans

    return hyperlinks


def load_corpus(corpus_path, for_hotpot=True, require_hyperlinks=False, index_name='enwiki-20171001-paragraph-3',
                es_hosts=None):
    """Load a passage corpus from a TSV file, optionally adding hyperlinks from Elasticsearch.

    Each non-blank line holds tab-separated ``id``, ``text`` and ``title``
    columns; when ``for_hotpot`` is true a fourth column holds a Python/JSON
    literal list of ``[start, end]`` sentence spans. A row whose id is the
    literal ``'id'`` is treated as a header and skipped.

    Args:
        corpus_path: Path of the TSV corpus file (read as UTF-8).
        for_hotpot: When true, parse sentence spans from the fourth column and
            map titles to passage ids; otherwise map titles to document ids
            (the part of the passage id before its first underscore).
        require_hyperlinks: When true, scan the Elasticsearch index and attach
            each hit's ``hyperlinks`` to the passage with the same ``para_id``.
            Hits missing from the TSV are added as new passages.
        index_name: Elasticsearch index to scan.
        es_hosts: Elasticsearch hosts to connect to; ``None`` keeps the
            historical default ``['10.60.0.59:9200']``.

    Returns:
        ``(corpus, title2id)`` where ``corpus`` maps passage id to a dict with
        ``title``, ``text``, ``sentence_spans`` (and ``hyperlinks`` when
        requested), and ``title2id`` maps the HTML-unescaped title to a
        passage id or document id. Later rows overwrite earlier ones.
    """
    corpus = {}
    title2id = {}
    with open(corpus_path, encoding='utf-8') as f:
        doc_id = None
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank (e.g. trailing) lines instead of crashing on unpack.
                continue
            segs = line.split('\t')
            p_id, text, title = (seg.strip() for seg in segs[:3])
            if p_id == 'id':
                continue
            if p_id in corpus:
                logger.warning(f"Duplicate passage id: {p_id} ({title})")
            corpus[p_id] = {
                "title": title,
                "text": text,
                "sentence_spans": []
            }
            if for_hotpot:
                # literal_eval only accepts literals, so corpus content can never execute code
                # the way the previous eval() call allowed.
                corpus[p_id]['sentence_spans'] = [tuple(span) for span in ast.literal_eval(segs[3])]
            unescaped_title = unescape(title)
            # Derive the document id exactly the way it is stored below, so that consecutive
            # passages of one document never trigger a false duplicate-title warning.
            cur_doc_id = p_id.split('_')[0]
            if cur_doc_id != doc_id and unescaped_title in title2id:
                logger.warning(f"Duplicate title: {unescaped_title}")
            doc_id = cur_doc_id
            if for_hotpot:
                title2id[unescaped_title] = p_id  # passage id
            else:
                title2id[unescaped_title] = doc_id  # document id

    if require_hyperlinks:
        es = Elasticsearch(['10.60.0.59:9200'] if es_hosts is None else es_hosts, timeout=30)
        if for_hotpot:
            query = {"query": {"term": {"for_hotpot": True}}}
        else:
            query = {"query": {"match_all": {}}}
        # Count first so that tqdm can show progress over the streaming scan.
        para_num = es.count(index=index_name, body=query)['count']
        for hit in tqdm(helpers.scan(es, query=query, index=index_name), total=para_num):
            para = hit['_source']
            if para['para_id'] in corpus:
                corpus[para['para_id']]['hyperlinks'] = para['hyperlinks']
            else:
                # Paragraphs absent from the TSV are only expected to carry the '_-1' id suffix.
                assert para['para_id'][-3:] == '_-1'
                corpus[para['para_id']] = {
                    "title": para['title'],
                    "text": para['text'],
                    "sentence_spans": [],
                    "hyperlinks": para['hyperlinks']
                }

    logger.info(f"Loaded {len(corpus)} passages from {corpus_path}")

    return corpus, title2id


def load_corpus_(corpus_path, for_hotpot=True, require_hyperlinks=False):
    """Load a passage corpus whose hyperlinks are stored inline in the TSV file.

    Each non-blank line holds tab-separated ``id``, ``text``, ``title`` and
    ``hyperlinks`` columns; when ``for_hotpot`` is true the last column holds
    a Python/JSON literal list of ``[start, end]`` sentence spans. A row whose
    id is the literal ``'id'`` is treated as a header and skipped.

    Args:
        corpus_path: Path of the TSV corpus file (read as UTF-8).
        for_hotpot: When true, parse sentence spans and map every title to a
            single passage id; otherwise map every title to the list of all
            passage ids carrying it.
        require_hyperlinks: When true, parse the ``hyperlinks`` column, a JSON
            object ``title -> [{"span": [start, end], ...}, ...]``, into
            ``unescaped title -> [(start, end), ...]``, dropping empty spans.

    Returns:
        ``(corpus, title2id)`` where ``corpus`` maps passage id to a dict with
        ``title``, ``text``, ``sentence_spans`` (and ``hyperlinks`` when
        requested), and ``title2id`` is a dict (HotpotQA mode) or a
        ``defaultdict(list)`` keyed by HTML-unescaped title.

    Raises:
        AssertionError: In HotpotQA mode when a row lacks the sentence-span column.
    """
    corpus = {}
    title2id = {} if for_hotpot else defaultdict(list)
    with open(corpus_path, encoding='utf-8') as f:
        doc_id = None
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank (e.g. trailing) lines instead of crashing on unpack.
                continue
            segs = line.split('\t')
            p_id, text, title, hyperlinks = segs[:4]
            p_id, text, title = p_id.strip(), text.strip(), title.strip()
            if p_id == 'id':
                continue
            unescaped_title = unescape(title)
            if p_id in corpus:
                logger.warning(f"Duplicate passage id: {p_id} ({title})")
            corpus[p_id] = {
                "title": title,
                "text": text,
                "sentence_spans": []
            }
            if require_hyperlinks:
                # Zero-length anchors (start == end) carry no mention text and are dropped.
                corpus[p_id]['hyperlinks'] = {
                    unescape(t): [tuple(a['span']) for a in anchors if a['span'][0] != a['span'][1]]
                    for t, anchors in json.loads(hyperlinks).items()
                }
            if for_hotpot:
                assert len(segs) > 4, f"Missing sentence spans column for passage {p_id}"
                # literal_eval only accepts literals, so corpus content can never execute code
                # the way the previous eval() call allowed.
                corpus[p_id]['sentence_spans'] = [tuple(span) for span in ast.literal_eval(segs[-1])]
            # Document id = passage id without its final '_<index>' part. The same expression is
            # used for the comparison and for the remembered value, so consecutive passages of
            # one document never trigger a false duplicate-title warning.
            cur_doc_id = p_id.rsplit('_', 1)[0]
            if cur_doc_id != doc_id and unescaped_title in title2id:
                logger.warning(f"Duplicate title: {unescaped_title}")
            doc_id = cur_doc_id
            if for_hotpot:
                title2id[unescaped_title] = p_id
            else:
                title2id[unescaped_title].append(p_id)
    logger.info(f"Loaded {len(corpus)} passages from {corpus_path}")

    return corpus, title2id


def load_qas(file_path):
    """Read question/answer samples from a four-column TSV file.

    Every line must contain exactly ``q_id``, ``question``, ``answer`` and a
    JSON-encoded supporting-facts value, separated by tabs.

    Returns:
        list of ``(q_id, (question, answer, sp_facts))`` tuples in file order,
        with ``sp_facts`` already decoded from JSON.
    """
    def parse(raw_line):
        # Unpacking enforces the exact four-column layout.
        q_id, question, answer, sp_facts = raw_line.strip().split('\t')
        return q_id, (question, answer, json.loads(sp_facts))

    with open(file_path) as f:
        qas_samples = [parse(raw_line) for raw_line in f]
    logger.info(f"Loaded {len(qas_samples)} samples from the {file_path}")
    return qas_samples


def load_samples(file_path, test=True):
    """Read samples from a tab-separated file, with or without gold annotations.

    The first two columns of every line are ``q_id`` and ``question``. When
    ``test`` is false, the third column is kept verbatim and the fourth column
    is decoded from JSON.

    Returns:
        list of ``(q_id, (question,))`` tuples when ``test`` is true, otherwise
        ``(q_id, (question, third_column, decoded_fourth_column))`` tuples.
    """
    samples = []
    with open(file_path) as f:
        for raw_line in f:
            fields = raw_line.strip().split('\t')
            q_id, question = fields[:2]
            # Test files carry only the question; other files add two annotation columns.
            payload = (question,) if test else (question, fields[2], json.loads(fields[3]))
            samples.append((q_id, payload))
    logger.info(f"Loaded {len(samples)} samples from the {file_path}")
    return samples
