from argparse import Namespace
from html import unescape
import json
import logging
import random
import re
import os
# os.environ['CLASSPATH'] = 'corenlp/*'
from tqdm.auto import tqdm

import numpy as np
import torch

import faiss
from transformers import AutoConfig, AutoTokenizer
from dpr.indexer.faiss_indexers import DenseHNSWFlatIndexer  # , DenseFlatIndexer
from drqa.reader import Predictor
from mdr.retrieval.models.retriever import RobertaCtxEncoder
from utils.model_utils import load_state

from retriever import DenseRetriever, SparseRetriever
from utils.data_utils import load_corpus_

faiss.omp_set_num_threads(16)


# noinspection PyUnresolvedReferences
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Sets ``PYTHONHASHSEED`` in the process environment, seeds Python's
    ``random`` module, NumPy and PyTorch (CPU plus all CUDA devices), and
    switches cuDNN to deterministic mode with benchmarking disabled.

    Args:
        seed: integer seed handed unchanged to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def set_global_logging_level(level=logging.ERROR, prefixes=("",)):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Prefixes are compared literally with ``str.startswith``; characters such as
    ``.`` or ``(`` have no special meaning.

    Args:
        level: desired level. Optional. Default is logging.ERROR
        prefixes: one or more str prefixes to match (e.g. ["transformers", "torch"]). Optional.
            Default is `("",)` to match all active loggers. A bare string is treated as a
            single prefix. An empty collection also matches all active loggers.
            The match is a case-sensitive `module_name.startswith(prefix)`
    """
    if isinstance(prefixes, str):
        # A lone string would otherwise be iterated character by character.
        prefixes = (prefixes,)
    # `str.startswith` needs a tuple; an empty collection keeps the historical "match everything" meaning.
    prefixes = tuple(prefixes) or ("",)

    # Iterate over a snapshot of the names: getLogger() may touch the manager's registry.
    for name in list(logging.root.manager.loggerDict):
        if name.startswith(prefixes):
            logging.getLogger(name).setLevel(level)


set_seed(0)

# Read `data/hotpot-dev.tsv`: every line carries four tab-separated fields
# (question id, question, answer, JSON-encoded supporting facts).
# Each sample is stored as (q_id, (question, answer, sp_facts)).
samples = []
with open("data/hotpot-dev.tsv") as f:
    for row in map(str.strip, f):
        q_id, question, answer, sp_facts = row.split('\t')
        samples.append((q_id, (question, answer, json.loads(sp_facts))))
print(len(samples))

# Sparse retriever used for the BM25 searches below.
sparse_retriever = SparseRetriever('enwiki-20171001-paragraph-4', ['10.60.0.59:9200'], timeout=30)


def _load_query_generator(ckpt_path):
    """Build a `Predictor` from `ckpt_path` and move it (and its network) to cuda:0."""
    predictor = Predictor(model=ckpt_path, tokenizer=None, embedding_file='data/glove.840B.300d.txt',
                          num_workers=-1)
    predictor.cuda()
    predictor.model.network.to(torch.device('cuda:0'))
    return predictor


# One query generator per hop: qg1 for the initial state, qg2 once a passage has been observed.
qg1 = _load_query_generator('ckpts/golden-retriever/hop1.mdl')
qg2 = _load_query_generator('ckpts/golden-retriever/hop2.mdl')

# Settings shared by the dense query encoder, its index and the dense searches.
args = Namespace(
    model_name="roberta-base",
    model_path="ckpts/mdr/qp_encoder.1.pt",
    index_prefix_path="data/index/mdr/hotpot-paragraph-mdr-1.hnsw",
    index_buffer_size=50000,
    max_q_len=70,
    max_q_sp_len=350,
)

bert_config = AutoConfig.from_pretrained(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Build the encoder and load its checkpoint weights (non-exact matching).
query_encoder = load_state(RobertaCtxEncoder(bert_config, args), args.model_path, exact=False)
device = torch.device('cuda:1')
query_encoder.to(device)
# if torch.cuda.device_count() > 1:
#     query_encoder = torch.nn.DataParallel(query_encoder)
query_encoder.eval()

# The index vectors have the encoder's hidden size.
vector_size = bert_config.hidden_size
dense_indexer = DenseHNSWFlatIndexer(vector_size, args.index_buffer_size)
dense_indexer.deserialize_from(args.index_prefix_path)

dense_retriever = DenseRetriever(dense_indexer, query_encoder, tokenizer)

# `corpus` is indexed by passage id; `title2id` maps a passage title to its id.
corpus, title2id = load_corpus_('data/corpus/hotpot-paragraph-4.tsv', for_hotpot=True, require_hyperlinks=True)

# Question strings with a single trailing '?' removed, in the same order as `samples`.
questions = [
    raw_q[:-1] if raw_q.endswith('?') else raw_q
    for raw_q in (sample[1][0] for sample in samples)
]  # (N,)
# Dense search for all questions at once (presumably top-1000 hits, batch size 32 -- confirm in DenseRetriever).
hits_list1 = dense_retriever.msearch(questions, 1000, args.max_q_len, 32)  # (N, 2, 1000)

# Caches reused across samples: generated queries, BM25 hit ids per query, MDR hits per query.
cached_queries = {}
bm25_searches = {}
mdr_searches = {question: hits_list1[q_idx][0] for q_idx, question in enumerate(questions)}

# set_global_logging_level(logging.WARNING, ["elasticsearch"])
faiss.omp_set_num_threads(1)

# Build one training example per question and write it as a JSON line.
#
# For every question the loop determines, for the initial state (no supporting passage seen) and for
# each "one supporting passage seen" state, which action reaches a (remaining) supporting passage in
# the fewest steps: "BM25", "MDR", "LINK", or "ANSWER" when nothing is found within RET_SIZE hits.
# Along the way it collects hard negatives: passages ranked near a supporting passage, passages that
# hyperlink to a supporting passage, and the top-10 hits of every search.
OBS_SIZE = 1  # number of retrieved passages consumed per step when converting a rank into a step count
RET_SIZE = 500  # number of hits inspected per search; 2 * RET_SIZE is the "not found" rank sentinel
# HOP_MAX_STEPS = (RET_SIZE + OBS_SIZE - 1) // OBS_SIZE
hotpot_filter = {"term": {"for_hotpot": True}}
examples = []
# NOTE(review): records are dumped with ensure_ascii=False and the file uses the platform default
# encoding -- confirm that this is UTF-8 wherever the script runs.
with open('data/hotpot-step-dev.jsonl', 'w') as step_data_file:
    for q_idx, (q_id, qas) in enumerate(tqdm(samples)):
        question, answer, sp_facts = qas
        if len(sp_facts) < 2:
            # NOTE(review): this is only reported; the rank sorting and the "partial" loop below
            # unpack exactly two supporting passages and will raise for any other count.
            print(f"less than 2 supporting facts: {q_id}")
        norm_sp_titles = {unescape(t) for t in sp_facts}
        sp_ids = [title2id[t] for t in norm_sp_titles]

        hard_negs = set()
        state2action = dict()

        # ==================== initial ====================
        # Best (lowest) rank at which each supporting passage is reached, per strategy.
        # "<X>+Link" records the rank of the first hit whose hyperlinks point to the supporting passage.
        sp_ranks = {strategy: {sp_id: 2 * RET_SIZE for sp_id in sp_ids}
                    for strategy in ["BM25", "BM25+Link", "MDR", "MDR+Link"]}

        # BM25
        if question not in cached_queries:
            cached_queries[question] = qg1.predict(question, question)[0][0]
        q1 = cached_queries[question]
        if q1 not in bm25_searches or len(bm25_searches[q1]) < RET_SIZE:
            bm25_searches[q1] = [
                hit['_id']
                for hit in sparse_retriever.search(q1, RET_SIZE, filter_dic=hotpot_filter,
                                                   n_retrieval=RET_SIZE * 2)
            ]
        bm25_hits = bm25_searches[q1]
        for p_idx, p_id in enumerate(bm25_hits[:RET_SIZE]):
            # hyperlinks = {unescape(t): anchors for t, anchors in corpus[p_id]['hyperlinks'].items()}
            hyperlinks = corpus[p_id]['hyperlinks']
            # Supporting titles this hit links to (computed once per hit).
            linked_sp_titles = norm_sp_titles & set(hyperlinks.keys())

            if p_id in sp_ids:
                sp_ranks['BM25'][p_id] = min(p_idx, sp_ranks['BM25'][p_id])
                # Neighbours of a supporting passage in the ranking become hard negatives.
                hard_negs.update(bm25_hits[max(0, p_idx - 5):p_idx + 6])
            elif linked_sp_titles:
                hard_negs.add(p_id)

            for sp_title in linked_sp_titles:
                sp_id = title2id[sp_title]
                sp_ranks['BM25+Link'][sp_id] = min(p_idx, sp_ranks['BM25+Link'][sp_id])

            # Stop once every supporting passage has been reached both directly and via a link.
            if max(list(sp_ranks['BM25'].values()) + list(sp_ranks['BM25+Link'].values())) <= p_idx:
                break
        hard_negs.update(bm25_hits[:10])

        # MDR
        mdr_hits = mdr_searches[questions[q_idx]]
        for p_idx, p_id in enumerate(mdr_hits[:RET_SIZE]):
            # hyperlinks = {unescape(t): anchors for t, anchors in corpus[p_id]['hyperlinks'].items()}
            hyperlinks = corpus[p_id]['hyperlinks']
            linked_sp_titles = norm_sp_titles & set(hyperlinks.keys())

            if p_id in sp_ids:
                sp_ranks['MDR'][p_id] = min(p_idx, sp_ranks['MDR'][p_id])
                hard_negs.update(mdr_hits[max(0, p_idx - 5):p_idx + 6])
            elif linked_sp_titles:
                hard_negs.add(p_id)

            for sp_title in linked_sp_titles:
                sp_id = title2id[sp_title]
                sp_ranks['MDR+Link'][sp_id] = min(p_idx, sp_ranks['MDR+Link'][sp_id])

            if max(list(sp_ranks['MDR'].values()) + list(sp_ranks['MDR+Link'].values())) <= p_idx:
                break
        hard_negs.update(mdr_hits[:10])

        # get the greedy action
        if min(min(_sp_ranks.values()) for _sp_ranks in sp_ranks.values()) >= RET_SIZE:
            print(f"Unable recall SP1 in the first {RET_SIZE} retrieval results: {q_id}")
            state2action['initial'] = {"query": q1, "action": "ANSWER"}
        else:
            # calculate the number of step to get SPs
            easy_steps, hard_steps = {}, {}
            for strategy, _sp_ranks in sp_ranks.items():
                # Requires exactly two supporting passages: the earlier-ranked ("easy") and the later one.
                (easy_sp_id, easy_sp_rank), (hard_sp_id, hard_sp_rank) = sorted(_sp_ranks.items(),
                                                                                key=lambda x: x[1])
                easy_steps[strategy] = (easy_sp_rank + OBS_SIZE) // OBS_SIZE
                hard_steps[strategy] = (hard_sp_rank + OBS_SIZE) // OBS_SIZE
                if strategy.endswith('+Link'):
                    # Link strategies cost one more step (presumably the hop that follows the hyperlink).
                    easy_steps[strategy] += 1
                    hard_steps[strategy] += 1
            # find the fastest strategy
            strategy = min(easy_steps.keys(), key=lambda k: (easy_steps[k], hard_steps[k]))
            state2action['initial'] = {"query": q1, "action": "BM25" if strategy.startswith('BM25') else "MDR"}

        # ==================== partial ====================
        # Treat each supporting passage in turn as the one already observed (sp1) and look for the other (sp2).
        for sp1_id, sp2_id in [sp_ids, sp_ids[::-1]]:
            sp1 = corpus[sp1_id]
            norm_sp1_title = unescape(sp1['title'])
            sp2 = corpus[sp2_id]
            norm_sp2_title = unescape(sp2['title'])
            sp2_ranks = {strategy: 2 * RET_SIZE for strategy in ["BM25", "BM25+Link", "MDR", "MDR+Link"]}

            # BM25
            obs = ' '.join([question, f"<t> {sp1['title']} </t> {sp1['text']}"])
            if obs not in cached_queries:
                cached_queries[obs] = qg2.predict(obs, question)[0][0]
            q2 = cached_queries[obs]
            if q2 not in bm25_searches or len(bm25_searches[q2]) < RET_SIZE:
                bm25_searches[q2] = [
                    hit['_id']
                    for hit in sparse_retriever.search(q2, RET_SIZE, filter_dic=hotpot_filter,
                                                       n_retrieval=RET_SIZE * 2)
                ]
            bm25_hits = bm25_searches[q2]
            for p_idx, p_id in enumerate(bm25_hits[:RET_SIZE]):
                # hyperlinks = {unescape(t): anchors for t, anchors in corpus[p_id]['hyperlinks'].items()}
                hyperlinks = corpus[p_id]['hyperlinks']

                if p_id == sp2_id:
                    sp2_ranks['BM25'] = min(p_idx, sp2_ranks['BM25'])
                    hard_negs.update(bm25_hits[max(0, p_idx - 5):p_idx + 6])

                if norm_sp2_title in hyperlinks.keys():
                    hard_negs.add(p_id)
                    sp2_ranks['BM25+Link'] = min(p_idx, sp2_ranks['BM25+Link'])

                if max(sp2_ranks['BM25'], sp2_ranks['BM25+Link']) <= p_idx:
                    break
            hard_negs.update(bm25_hits[:10])

            # MDR
            expanded_query = (questions[q_idx], sp1['text'] if sp1['text'] else sp1['title'])
            if expanded_query not in mdr_searches or len(mdr_searches[expanded_query]) < RET_SIZE:
                mdr_searches[expanded_query] = dense_retriever.search(expanded_query, max(RET_SIZE, 1000),
                                                                      args.max_q_sp_len)[0]
            mdr_hits = mdr_searches[expanded_query]
            for p_idx, p_id in enumerate(mdr_hits[:RET_SIZE]):
                # hyperlinks = {unescape(t): anchors for t, anchors in corpus[p_id]['hyperlinks'].items()}
                hyperlinks = corpus[p_id]['hyperlinks']

                if p_id == sp2_id:
                    # Compare against the MDR entry itself; using the BM25 entry here would leak
                    # the BM25 rank into the MDR rank.
                    sp2_ranks['MDR'] = min(p_idx, sp2_ranks['MDR'])
                    hard_negs.update(mdr_hits[max(0, p_idx - 5):p_idx + 6])

                if norm_sp2_title in hyperlinks.keys():
                    hard_negs.add(p_id)
                    # Same as above: the MDR+Link rank must only be compared with itself.
                    sp2_ranks['MDR+Link'] = min(p_idx, sp2_ranks['MDR+Link'])

                if max(sp2_ranks['MDR'], sp2_ranks['MDR+Link']) <= p_idx:
                    break
            hard_negs.update(mdr_hits[:10])

            # get the greedy action
            if norm_sp2_title in set(unescape(t) for t in sp1['hyperlinks'].keys()):  # hyperlink
                # sp1 links directly to sp2, so following the link is chosen without comparing ranks.
                state2action[norm_sp1_title] = {"query": q2, "action": "LINK"}
            else:
                if min(sp2_ranks.values()) >= RET_SIZE:
                    print(f"{q_id}: Unable recall SP2({repr(norm_sp2_title)}) "
                          f"in the first {RET_SIZE} retrieval(Q + {repr(norm_sp1_title)}) results")
                    state2action[norm_sp1_title] = {"query": q2, "action": "ANSWER"}
                else:
                    # calculate the number of step to get SP2
                    sp2_steps = dict()
                    for strategy, sp2_rank in sp2_ranks.items():
                        sp2_steps[strategy] = (sp2_rank + OBS_SIZE) // OBS_SIZE
                        if strategy.endswith('+Link'):
                            sp2_steps[strategy] += 1
                    strategy = min(sp2_steps.keys(), key=lambda k: sp2_steps[k])
                    state2action[norm_sp1_title] = {"query": q2,
                                                    "action": "BM25" if strategy.startswith('BM25') else "MDR"}

        # ==================== full ====================
        # state2action['full'] = {"query": None, "action": "ANSWER"}

        # Supporting passages themselves are never negatives.
        hard_negs = hard_negs - set(sp_ids)
        example = {
            "_id": q_id,
            "question": question,
            "answer": answer,
            "sp_facts": {unescape(t): sents for t, sents in sp_facts.items()},
            "hard_negs": list(hard_negs),  # in-neighbors, top ranked passages
            "state2action": state2action
        }
        examples.append(example)
        step_data_file.write(json.dumps(example, ensure_ascii=False) + '\n')
