import json
import logging
import re
import time
from tqdm import trange
# from tqdm.auto import trange

from elasticsearch import Elasticsearch
import torch

from transformers import PreTrainedTokenizer
from mdr.retrieval.models.retriever import RobertaCtxEncoder
from dpr.indexer.faiss_indexers import DenseIndexer

from utils.model_utils import get_device
from utils.tensor_utils import to_device

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Group 1 captures the leading run of the title that contains no parentheses and does not
# end in whitespace; an optional "(...)" qualifier right after it is consumed but not captured.
core_title_pattern = re.compile(r'([^()]+[^\s()])(?:\s*\(.+\))?')


def filter_core_title(x):
    """Return the part of a title that precedes its first parenthesis.

    Args:
        x (str): title, e.g. ``"Apple (company)"``

    Returns:
        str: the captured core title (``"Apple"`` for the example above), or ``x`` unchanged
            when the pattern does not match at the start of the string (for instance when
            ``x`` starts with a parenthesis or has fewer than two characters).
    """
    # Match once and reuse the result instead of running the regex twice.
    matched = core_title_pattern.match(x)
    return matched.group(1) if matched else x


class SparseRetriever(object):
    """Keyword retriever that queries an Elasticsearch index and re-ranks hits by title match."""

    def __init__(self, index_name='enwiki-20171001-paragraph-5', hosts=('10.60.0.59:9200',),
                 max_retries=3, timeout=30, **kwargs):
        """

        Args:
            index_name (str): index every search of this retriever is issued against
            hosts: host addresses, passed as the first positional argument to ``Elasticsearch``
            max_retries (int): forwarded to the ``Elasticsearch`` client
            timeout (int): forwarded to the ``Elasticsearch`` client
            **kwargs: extra keyword arguments forwarded to the ``Elasticsearch`` client
        """
        self.index_name = index_name
        self.es = Elasticsearch(hosts, max_retries=max_retries, timeout=timeout, retry_on_timeout=True, **kwargs)

    @staticmethod
    def pack_query(query, must_not=None, filter_dic=None, offset=0, size=50):
        """Build the search request body for a single query.

        Args:
            query (str): text matched against the title and text fields (and their bigram sub-fields)
            must_not: clause stored under ``bool.must_not``; omitted only when it is None
            filter_dic (dict): clause stored under ``bool.filter``; omitted when empty or None,
                e.g. ``{"term": {"for_hotpot": True}}``
            offset (int): value of the ``from`` key
            size (int): value of the ``size`` key

        Returns:
            dict: the request body
        """
        dsl = {
            "query": {
                "bool": {
                    "must": {
                        "multi_match": {
                            "query": query,
                            "fields": ["title^1.25", "title_unescaped^1.25", "text",
                                       "title.bigram^1.25", "title_unescaped.bigram^1.25", "text.bigram"]
                        }
                    }
                }
            },
            "from": offset,
            "size": size
        }
        if must_not is not None:
            dsl['query']['bool']['must_not'] = must_not
        if filter_dic:
            dsl['query']['bool']['filter'] = filter_dic
        return dsl

    def search(self, query, n_rerank=10, must_not=None, filter_dic=None, n_retrieval=50, **kwargs):
        """Retrieve hits for one query.

        Args:
            query (str): the query text
            n_rerank (int): number of hits kept after re-ranking; when it is not positive the
                retrieved hits are returned as-is, without re-ranking or truncation
            must_not: see ``pack_query``
            filter_dic (dict): see ``pack_query``
            n_retrieval (int): number of hits requested; raised to ``n_rerank`` if smaller
            **kwargs: forwarded to the client's ``search`` call

        Returns:
            list[dict]: the hit dicts found under ``['hits']['hits']`` in the response
        """
        n_retrieval = max(n_rerank, n_retrieval)
        dsl = self.pack_query(query, must_not, filter_dic, size=n_retrieval)
        # Keyword arguments: the positional order of (body, index) is not the same in every
        # elasticsearch-py release.
        response = self.es.search(body=dsl, index=self.index_name, **kwargs)
        hits = list(response['hits']['hits'])
        if n_rerank > 0:
            hits = self.rerank_with_query(query, hits)[:n_rerank]

        return hits

    def msearch(self, queries, n_rerank=10, must_not=None, filter_dic=None, n_retrieval=50, **kwargs):
        """Retrieve hits for several queries with a single multi-search request.

        Args:
            queries (list): (N,) query texts
            n_rerank (int): number of hits kept per query after re-ranking; when it is not
                positive the retrieved hits are returned as-is
            must_not: see ``pack_query`` (shared by all queries)
            filter_dic (dict): see ``pack_query`` (shared by all queries)
            n_retrieval (int): number of hits requested per query; raised to ``n_rerank`` if smaller
            **kwargs: forwarded to the client's ``msearch`` call

        Returns:
            list[list[dict]]: one list of hit dicts per query; ``[]`` when there are no queries
        """
        n_retrieval = max(n_rerank, n_retrieval)
        # Every search is an empty header line ("{}") followed by its JSON-encoded body.
        body = ["{}\n" + json.dumps(self.pack_query(q, must_not, filter_dic, size=n_retrieval)) for q in queries]
        if not body:
            # Nothing to search for: skip the request instead of sending an empty body.
            return []
        responses = self.es.msearch(body='\n'.join(body), index=self.index_name, **kwargs)['responses']
        hits_list = [list(r['hits']['hits']) for r in responses]
        if n_rerank > 0:
            hits_list = [self.rerank_with_query(query, hits)[:n_rerank] for query, hits in zip(queries, hits_list)]

        return hits_list

    @staticmethod
    def rerank_with_query(query, hits):
        """Boost hits whose title matches the query and sort them by descending score.

        The boost is written back into ``hit['_score']``, so the passed hit dicts are modified
        in place. Only the first applicable rule below is applied to a hit.

        Args:
            query (str): the query text
            hits (list[dict]): hits carrying ``_score`` and ``_source.title_unescaped``

        Returns:
            list[dict]: the same hit dicts, sorted by boosted score (highest first)
        """
        # Variant of the query without a leading article, plus lower-cased forms; these only
        # depend on the query, so they are computed once rather than per hit.
        stripped = query[4:] if query.startswith(('The ', 'the ')) else query
        query_lower = query.lower()
        stripped_lower = stripped.lower()

        def score_boost(hit):
            title = hit['_source']['title_unescaped']
            core_title = filter_core_title(title)
            title_lower = title.lower()
            core_lower = core_title.lower()

            score = hit['_score']
            if title in (query, stripped):
                score *= 1.5
            elif title_lower in (query_lower, stripped_lower):
                score *= 1.2
            elif title_lower in query_lower:
                # Both sides lower-cased, as in the core-title containment rule below;
                # comparing against the raw query missed titles containing capitals.
                score *= 1.1
            elif core_title in (query, stripped):
                score *= 1.2
            elif core_lower in (query_lower, stripped_lower):
                score *= 1.1
            elif core_lower in query_lower:
                score *= 1.05
            hit['_score'] = score

            return hit

        return sorted([score_boost(hit) for hit in hits], key=lambda hit: -hit['_score'])


class DenseRetriever(object):
    """Does passage retrieving over the provided index and question encoder"""

    def __init__(self, dense_indexer, dense_encoder, tokenizer):
        """

        Args:
            dense_indexer (DenseIndexer): index queried through ``search_knn``
            dense_encoder (RobertaCtxEncoder): model that embeds the tokenized queries
            tokenizer (PreTrainedTokenizer): tokenizer used to batch-encode the raw queries
        """
        self.dense_indexer = dense_indexer
        self.dense_encoder = dense_encoder
        self.tokenizer = tokenizer

    def encode_queries(self, queries, max_length=None, batch_size=None):
        """Embed the queries with the dense encoder, batch by batch.

        The encoder is switched to eval mode and gradients are disabled while encoding.

        Args:
            queries (list): (N,) must be non-empty, an empty list can neither be batched
                nor concatenated
            max_length (int): truncation length handed to the tokenizer
            batch_size (int): queries per forward pass; None or a non-positive value encodes
                all queries in a single batch

        Returns:
            np.array: (N, H)
        """
        total = len(queries)
        if batch_size is None or batch_size <= 0:
            batch_size = total
        vectors = []
        self.dense_encoder.eval()
        with torch.no_grad():
            # The progress bar is disabled when everything fits in one batch.
            for batch_start in trange(0, total, batch_size, disable=total <= batch_size):
                inputs = self.tokenizer.batch_encode_plus(queries[batch_start:batch_start + batch_size],
                                                          padding=True, truncation=True, max_length=max_length,
                                                          return_tensors="pt")
                # The encoder receives the attention mask under the key "input_mask".
                inputs = to_device({"input_ids": inputs["input_ids"], "input_mask": inputs["attention_mask"]},
                                   get_device(self.dense_encoder))
                embeddings = self.dense_encoder(inputs)['embed']
                vectors.append(embeddings.cpu())
        vectors = torch.cat(vectors, dim=0).contiguous()
        logger.debug(f'Encoded queries into {vectors.shape} vectors')
        assert vectors.shape[0] == total

        return vectors.numpy()

    def msearch_(self, vectors, size=100):
        """Retrieve the best matching passages given the query vectors batch

        Args:
            vectors (np.array): (N, H)
            size (int): number of neighbours requested per query vector

        Returns:
            list[tuple[list[object], list[float]]]: (N, 2, size) list of (p_id, score)
        """
        # perf_counter is monotonic, so the measured duration is not affected by clock changes.
        t0 = time.perf_counter()
        hits_list = self.dense_indexer.search_knn(vectors, size)
        logger.debug(f'dense search time: {time.perf_counter() - t0}s')
        return hits_list

    def msearch(self, queries, size=100, max_length=None, batch_size=None):
        """Encode the queries and retrieve the best matching passages for each of them.

        Args:
            queries (list): (N,)
            size (int): number of passages retrieved per query
            max_length (int): truncation length handed to the tokenizer
            batch_size (int): encoding batch size, see ``encode_queries``

        Returns:
            list[tuple[list[object], list[float]]]: (N, 2, size); ``[]`` when ``queries`` is empty
        """
        if len(queries) == 0:
            # An empty batch cannot be encoded; there is nothing to retrieve anyway.
            return []
        vectors = self.encode_queries(queries, max_length, batch_size)
        hits_list = self.msearch_(vectors, size)
        return hits_list

    def search(self, query, size=100, max_length=None):
        """Retrieve the best matching passages for a single query.

        Args:
            query (): a single query, wrapped into a one-element batch for ``msearch``
            size (int): number of passages retrieved
            max_length (int): truncation length handed to the tokenizer

        Returns:
            tuple[list[object], list[float]]: (2, size) (ids, scores)
        """
        return self.msearch([query], size, max_length)[0]


if __name__ == "__main__":
    # Manual smoke test: runs one single search and one multi-search and prints the hit titles.
    sparse_retriever = SparseRetriever('enwiki-20171001-paragraph-3.1', ['10.60.0.59:9200'], timeout=30)
    question = "In which city did Mark Zuckerberg go to college?"
    # Document fields live under '_source' in each hit, not at the top level of the hit dict.
    print([hit['_source']['title'] for hit in sparse_retriever.search(question)])
    print([[hit['_source']['title'] for hit in hits]
           for hits in sparse_retriever.msearch([question])])
