import torch
from transformers import AutoTokenizer, AutoModel
import os

from typing import Optional, Union, List, Dict, Tuple, Iterable, Callable, Any


class ContrieverScorer:
    """
    ContrieverScorer 用于计算检索器得分。
    """
    def __init__(self, retriever_ckpt_path, device=None, max_batch_size=400):
        """
        初始化函数，用于创建 Retriever 对象。
        
        Args:
            retriever_ckpt_path (str): 包含模型权重的路径。
            device (str, optional): 设备类型，可选值为 "cuda" 或 "cpu"。默认为 None。
            max_batch_size (int, optional): 最大批处理大小。默认为 400。
        
        Returns:
            None
        """
        query_encoder_path = os.path.join(retriever_ckpt_path, 'query_encoder')
        reference_encoder_path = os.path.join(retriever_ckpt_path, 'reference_encoder')
            
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
        self.query_encoder = AutoModel.from_pretrained(query_encoder_path)
        self.reference_encoder = AutoModel.from_pretrained(reference_encoder_path)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if not device else device
        self.query_encoder = self.query_encoder.to(self.device).eval()
        self.reference_encoder = self.reference_encoder.to(self.device).eval()
        assert max_batch_size > 0
        self.max_batch_size = max_batch_size

    def get_query_embeddings(self, sentences):
        """
        获取查询句子的词向量，使用BERT模型进行推理。
        
        Args:
            sentences (list[str]): 一个包含多个句子的列表，每个句子都是一个字符串。
        
        Returns:
            torch.FloatTensor: 返回一个形状为（batch_size, hidden_size）的torch.FloatTensor，其中hidden_size是BERT模型的隐藏层维度大小。
            该tensor代表了输入句子的平均词向量，其中被padding部分的词向量被设置为0。
        
        Raises:
            None
        
        """
        # Tokenization and Inference
        torch.cuda.empty_cache()
        with torch.no_grad():
            inputs = self.tokenizer(sentences, padding=True,
                                    truncation=True, return_tensors='pt')
            for key in inputs:
                inputs[key] = inputs[key].to(self.device)
            outputs = self.query_encoder(**inputs)
            # Mean Pool
            token_embeddings = outputs[0]
            mask = inputs["attention_mask"]
            token_embeddings = token_embeddings.masked_fill(
                ~mask[..., None].bool(), 0.)
            sentence_embeddings = token_embeddings.sum(
                dim=1) / mask.sum(dim=1)[..., None]
            return sentence_embeddings
    
    def get_embeddings(self, sentences):
        """
        获取句子的词向量，返回一个numpy数组，每个元素是一个句子的词向量。
        
        Args:
            sentences (list[str]): 一个列表，包含需要获取词向量的句子，每个句子是一个字符串。
        
        Returns:
            numpy.ndarray: 一个numpy数组，大小为（len(sentences), embedding_size），其中len(sentences)是句子的数量，embedding_size是词向量的维度。
            每个元素是一个句子的词向量。
        
        Raises:
            None
        
        """
        # Tokenization and Inference
        torch.cuda.empty_cache()
        with torch.no_grad():
            inputs = self.tokenizer(sentences, padding=True,
                                    truncation=True, return_tensors='pt')
            for key in inputs:
                inputs[key] = inputs[key].to(self.device)
            outputs = self.reference_encoder(**inputs)
            # Mean Pool
            token_embeddings = outputs[0]
            mask = inputs["attention_mask"]
            token_embeddings = token_embeddings.masked_fill(
                ~mask[..., None].bool(), 0.)
            sentence_embeddings = token_embeddings.sum(
                dim=1) / mask.sum(dim=1)[..., None]
            return sentence_embeddings

    def score_documents_on_query(self, query, documents):
        """
        计算文档的查询分数，返回一个列表，每个元素是一个float值，代表对应文档在查询上的得分。
        
        Args:
            query (str): 查询语句，类型为str。
            documents (List[str]): 一个包含多个文档的列表，每个元素都是一个str类型的字符串，代表一篇文档。
        
        Returns:
            List[float]: 一个列表，长度与输入的documents相同，每个元素是一个float值，代表对应文档在查询上的得分。
        
        Raises:
            None
        """
        query_embedding = self.get_query_embeddings([query])[0]
        document_embeddings = self.get_embeddings(documents)
        return document_embeddings.t()

    def select_topk(self, query, documents, k=1):
        """
        Returns:
            `ret`: `torch.return_types.topk`, use `ret.values` or `ret.indices` to get value or index tensor
        """
        scores = []
        for i in range((len(documents) + self.max_batch_size - 1) // self.max_batch_size):
            scores.append(self.score_documents_on_query(
                query, documents[self.max_batch_size * i:self.max_batch_size * (i + 1)]).to('cpu'))
        scores = torch.concat(scores)
        return scores.topk(min(k, len(scores)))


class ReferenceFilter:
    def __init__(self, retriever_ckpt_path, device=None, max_batch_size=400):
        """
        初始化ContrieverRetriever类，用于调用ContrieverScorer类进行文本检索。
        
        Args:
            retriever_ckpt_path (str): Contriever模型的checkpoint路径。
            device (Optional[str], optional): 可选参数，默认为None，表示使用CPU进行计算。如果想使用GPU，请设置为'cuda'。 Default to None.
            max_batch_size (int, optional): 可选参数，默认为400，表示一次最大处理的样本数量。如果样本数量过大，可能会导致内存不足，建议根据实际情况调整该值。 Default to 400.
        
        Returns:
            None
        
        Raises:
            ValueError: 当device参数非法时，会抛出ValueError异常。
        """
        self.scorer = ContrieverScorer(retriever_ckpt_path, device, max_batch_size)

    def produce_references(self, query, paragraphs, topk=5):
        """Individually calculate scores of each sentence, and return `topk`. paragraphs should be like a list of {title, url, text}."""
        # paragraphs = self._pre_filter(paragraphs)
        texts = [item['text'] for item in paragraphs]
        topk = self.scorer.select_topk(query, texts, topk)
        indices = list(topk.indices.detach().cpu().numpy())
        return [paragraphs[idx] for idx in indices]


