import torch
from sentence_transformers import SentenceTransformer, util

from editors.Base import Base
from utils import log
from utils import file_utils, model_utils


class IKE(Base):
    """In-context knowledge editing: new facts are retrieved and put in the prompt.

    ``edit`` embeds the new facts with a sentence-embedding model and keeps them as
    a retrieval corpus. At test time the facts most similar to each question are
    retrieved and prepended to it, and the language model answers from that context.
    """

    # Prefix written in front of every stored fact by ``edit`` and stripped again
    # by ``batch_retrieve``; kept in one place so the two stay in sync.
    _FACT_PREFIX = 'New Fact:'
    # Evaluation modes in which ``questions`` is a list of question groups (list of lists).
    _MULTIHOP_MODES = ('multihop-cot', 'multihop')

    def __init__(self, args):
        """Load the sentence retriever, the language model and its tokenizer.

        Args:
            args: configuration namespace; reads ``sentence_model_name``,
                ``model_name`` and, if present, ``tok_eos2pad``.
        """
        super().__init__(args)
        # NOTE(review): self.device is presumably set by Base.__init__ (not shown here).
        self.sentence_model = SentenceTransformer(args.sentence_model_name).to(self.device)
        self.model, self.tokenizer = model_utils.create_LM_tokenizer(args.model_name, self.device)
        # Left padding — presumably so batched prompts end right where generation starts.
        self.tokenizer.padding_side = 'left'

        if 'tok_eos2pad' in args and args.tok_eos2pad:
            # Reuse the EOS token id as the padding token id.
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            log.logger.info('===>tok_eos2pad == True')

    @staticmethod
    def _build_prompt(prompt_template, facts, question):
        """Concatenate the prompt template, the retrieved facts and the question."""
        return prompt_template + f"{''.join(facts).strip()}\nQuestion: {question}\nAnswer:"

    def retrieve_facts(self, questions, sentences, embeddings, top_k):
        """Return the ``top_k`` stored facts most similar to each question.

        Args:
            questions: list of query strings.
            sentences: stored fact strings, aligned row-for-row with ``embeddings``.
            embeddings: normalized embeddings of ``sentences`` (as built by ``edit``).
            top_k: number of facts to retrieve per question.

        Returns:
            One list of fact strings per question, in the same order as ``questions``.
            Each list is reversed relative to ``util.semantic_search`` (which ranks hits
            by decreasing score), so the most similar fact comes last — i.e. nearest to
            the question once placed in the prompt.
        """
        with torch.no_grad():
            query_embedding = self.sentence_model.encode(questions, convert_to_tensor=True, device=self.device, show_progress_bar=False)
            # With both sides normalized (edit normalizes the corpus), dot_score is cosine similarity.
            query_embedding = util.normalize_embeddings(query_embedding)

        hits = util.semantic_search(query_embedding, embeddings, score_function=util.dot_score, top_k=top_k)

        return [[sentences[item["corpus_id"]] for item in hit][::-1] for hit in hits]

    def edit(self, edit_data, fact_type):
        """Replace the retrieval corpus with the de-duplicated new facts.

        Args:
            edit_data: mapping with key ``'facts_new'`` holding the fact strings.
            fact_type: not used by this editor (kept for the common editor interface).
        """
        facts_new = edit_data['facts_new']
        # dict.fromkeys de-duplicates while keeping first-occurrence order, so the corpus
        # order (and hence retrieval tie-breaking) is reproducible across runs — iterating
        # a set of strings is not, because string hashing is randomized per process.
        self.sentences = [f"{self._FACT_PREFIX} {f}\n" for f in dict.fromkeys(facts_new)]

        with torch.no_grad():
            embeddings = self.sentence_model.encode(self.sentences, convert_to_tensor=True, device=self.device, show_progress_bar=False)
            self.embeddings = util.normalize_embeddings(embeddings)

        log.logger.info(f"===>facts_new size: {len(facts_new)}")

    def answer_multihop(self, questions, targets, prompt_template, max_new_tokens, use_cot, num_retrieved_facts):
        """Answer groups of (multi-hop) questions with retrieved facts in context.

        Args:
            questions: list of question groups (list of lists of strings).
            targets: indexed per group; ``targets[i]`` is paired with every question
                in ``questions[i]``.
            prompt_template: text prepended to every prompt.
            max_new_tokens: generation budget forwarded to ``predict_answer``.
            use_cot: forwarded to ``predict_answer``.
            num_retrieved_facts: number of facts retrieved per question.

        Returns:
            The predictions regrouped by ``file_utils.chunk_by_size`` using the group
            sizes of ``questions`` (presumably one list per question group).
        """
        # Retrieve for all questions in one call; related_facts is aligned with flat_questions.
        flat_questions = [q for qs in questions for q in qs]
        related_facts = self.retrieve_facts(flat_questions, self.sentences, self.embeddings, top_k=num_retrieved_facts)

        input_texts = []
        input_targets = []
        # Position of the current question within flat_questions / related_facts.
        # (Group index + in-group index is NOT the flat position once there are several groups.)
        flat_idx = 0
        for i, qs in enumerate(questions):
            for q in qs:
                input_texts.append(self._build_prompt(prompt_template, related_facts[flat_idx], q))
                input_targets.append(targets[i])
                flat_idx += 1

        answers_pred = self.predict_answer(self.model, self.tokenizer, input_texts, input_targets, max_new_tokens, use_cot)

        sizes = [len(qs) for qs in questions]
        return file_utils.chunk_by_size(answers_pred, sizes)

    def batch_retrieve(self, eva_data, eva_mode, eva_params):
        """Return the retrieved facts (without the stored prefix) for each question.

        Args:
            eva_data: mapping with key ``'questions'``; a flat list of questions in
                ``'edit'`` mode, a list of question groups in the multi-hop modes.
            eva_mode: ``'edit'``, ``'multihop'`` or ``'multihop-cot'``.
            eva_params: reads ``num_retrieved_facts``.

        Returns:
            One list of fact strings per question (flattened across groups in the
            multi-hop modes).

        Raises:
            ValueError: if ``eva_mode`` is not supported.
        """
        questions = eva_data['questions']

        if eva_mode == 'edit':
            queries = questions
        elif eva_mode in self._MULTIHOP_MODES:
            queries = [q for qs in questions for q in qs]
        else:
            raise ValueError(f"Unsupported eva_mode: {eva_mode!r}")

        related_facts = self.retrieve_facts(queries, self.sentences, self.embeddings, top_k=eva_params.num_retrieved_facts)

        # Strip the prefix added in edit; strip() also drops the separating space and trailing newline.
        prefix_len = len(self._FACT_PREFIX)
        return [[f[prefix_len:].strip() for f in fs] for fs in related_facts]

    def batch_test(self, eva_data, eva_mode, eva_params):
        """Generate answers for a batch of evaluation questions.

        Args:
            eva_data: mapping with keys ``'questions'`` and ``'targets'``.
            eva_mode: ``'edit'``, ``'multihop'`` or ``'multihop-cot'``
                (the latter enables ``use_cot``).
            eva_params: reads ``prompt`` (template path), ``max_new_tokens`` and
                ``num_retrieved_facts``.

        Returns:
            Predictions from ``predict_answer`` (one per question in ``'edit'`` mode,
            regrouped per question group in the multi-hop modes).

        Raises:
            ValueError: if ``eva_mode`` is not supported.
        """
        use_cot = (eva_mode == 'multihop-cot')
        prompt_template = file_utils.read_texts(eva_params.prompt)
        max_new_tokens = eva_params.max_new_tokens
        targets = eva_data['targets']
        questions = eva_data['questions']

        if eva_mode == 'edit':
            related_facts = self.retrieve_facts(questions, self.sentences, self.embeddings, top_k=eva_params.num_retrieved_facts)
            input_texts = [self._build_prompt(prompt_template, facts, q) for facts, q in zip(related_facts, questions)]
            return self.predict_answer(self.model, self.tokenizer, input_texts, targets, max_new_tokens, use_cot)

        if eva_mode in self._MULTIHOP_MODES:
            return self.answer_multihop(questions, targets, prompt_template, max_new_tokens, use_cot, eva_params.num_retrieved_facts)

        raise ValueError(f"Unsupported eva_mode: {eva_mode!r}")
