import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

from utils import log
from utils import file_utils, model_utils
from editors.Base import Base


class Mello(Base):
    """Retrieval-augmented multi-hop editor in the style of MeLLo.

    Edited facts are embedded with a Contriever-style encoder and kept in
    memory. At test time the language model proposes a subquestion and a
    tentative answer. The stored fact most similar to the subquestion is
    appended to the prompt, and the model continues until it emits a
    "Final answer: " line, produces a malformed step, or the hop budget
    runs out.
    """

    def __init__(self, args):
        """Load the language model and the retrieval encoder.

        Args:
            args: config object; this class reads ``model_name`` and
                ``contriever_name`` (``Base.__init__`` may read more).
        """
        super().__init__(args)
        # NOTE(review): self.device is assumed to be set by Base.__init__.
        self.model, self.tokenizer = model_utils.create_LM_tokenizer(args.model_name, self.device)

        self.contriever = AutoModel.from_pretrained(args.contriever_name).to(self.device)
        self.contriever_tokenizer = AutoTokenizer.from_pretrained(args.contriever_name)

    def edit(self, edit_data, fact_type):
        """Build the in-memory retrieval index from the edited facts.

        Args:
            edit_data: mapping with key ``'facts_new'``, a list of fact
                sentences (presumably strings; verify against caller).
            fact_type: not used by this editor; presumably kept for
                interface parity with other editors (confirm against Base).
        """
        # Row j of new_facts_embs embeds facts_new[j]; retrieve_facts relies on this alignment.
        self.facts_new = edit_data['facts_new']
        # Shape (len(facts_new), hidden_size); the original note says 768.
        self.new_facts_embs = self.get_sent_embeddings(self.facts_new, self.contriever, self.contriever_tokenizer)

        log.logger.info(f"===>facts_new size: {len(self.facts_new)}")

    #### Functions for retrieval models (Contriever)
    def mean_pooling(self, token_embeddings, mask):
        """Average token embeddings over positions where ``mask`` is non-zero.

        Args:
            token_embeddings: assumed (batch, seq_len, hidden), i.e. the
                encoder's first output; confirm.
            mask: (batch, seq_len) attention mask, non-zero for real tokens.

        Returns:
            (batch, hidden) tensor of mean-pooled sentence embeddings.
        """
        # Zero out padding positions (~ inverts the boolean mask) so they do not contribute to the sum.
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
        return sentence_embeddings

    def get_sent_embeddings(self, sents, contriever, tok, BSZ=32):
        """Embed ``sents`` in batches of ``BSZ`` sentences.

        Embeddings are computed without gradients and moved to CPU.

        Args:
            sents: sequence of sentences to embed.
            contriever: encoder model called as ``contriever(**inputs)``.
            tok: tokenizer matching ``contriever``.
            BSZ: batch size for encoding.

        Returns:
            Tensor with one row per sentence, in the same order as ``sents``.
            An empty ``sents`` makes ``torch.vstack`` fail.
        """
        all_embs = []
        for i in tqdm(range(0, len(sents), BSZ)):
            sent_batch = sents[i:i + BSZ]
            inputs = tok(sent_batch, padding=True, truncation=True, return_tensors='pt').to(self.device)
            with torch.no_grad():
                outputs = contriever(**inputs)
                embeddings = self.mean_pooling(outputs[0], inputs['attention_mask'])
            all_embs.append(embeddings.cpu())
        return torch.vstack(all_embs)

    def retrieve_facts(self, query, fact_embs, contriever, tok, k=1):
        """Return the stored new fact whose embedding best matches ``query``.

        Similarity is the dot product between the mean-pooled query embedding
        and each row of ``fact_embs``.

        Args:
            query: query sentence (here: a generated subquestion).
            fact_embs: fact embeddings whose rows align with ``self.facts_new``.
            contriever: encoder model.
            tok: tokenizer matching ``contriever``.
            k: number of neighbours to rank; only the best one is returned.

        Returns:
            The sentence from ``self.facts_new`` with the highest similarity.
        """
        inputs = tok([query], padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            outputs = contriever(**inputs)
            query_emb = self.mean_pooling(outputs[0], inputs['attention_mask']).cpu()
        sim = (query_emb @ fact_embs.T)[0]
        # topk raises if k exceeds the number of facts; clamp since only the top hit is used.
        knn = sim.topk(min(k, sim.numel()), largest=True)
        return self.facts_new[knn.indices[0].item()]

    def _answer_one_question(self, prompt_template, question, num_hops, max_new_tokens):
        """Answer one question by alternating LM generation with fact retrieval.

        Returns:
            ``(answer, transcript)``. ``answer`` is the text after the last
            "Final answer: " line, or None if no such line was produced.
            ``transcript`` is the accumulated prompt plus any final generation
            that was not folded into it (used only for logging).
        """
        # NOTE(review): "Qustion" is misspelled. It is left unchanged because it
        # is fed to the model; confirm against the prompt template before fixing.
        prompt = prompt_template + "\n\nQustion: " + question + '\n'

        for _ in range(num_hops):
            # Prompt the model to generate a subquestion and a tentative answer.
            gen = self.generate(self.model, self.tokenizer, prompt, max_new_tokens)[0]
            gen = gen.replace(prompt, '').strip()
            # Keep at most 4 lines of the continuation (presumably one reasoning step).
            gen = '\n'.join(gen.split('\n')[:4])

            if 'Final answer: ' in gen:
                ans = None
                for sent in reversed(gen.split('\n')):
                    if sent.startswith('Final answer: '):
                        ans = sent[len("Final answer: "):]
                        break
                return ans, prompt + gen

            # Drop anything the model produced from "Retrieved fact: " on; the real fact is inserted below.
            if "Retrieved fact: " in gen:
                gen = gen[:gen.index("Retrieved fact: ")]

            # Otherwise, the generated subquestion is expected on the second-to-last line.
            lines = gen.strip().split('\n')
            if len(lines) < 2:
                return None, prompt + gen  # failed case
            subquestion = lines[-2]
            if not subquestion.startswith('Subquestion: '):
                return None, prompt + gen  # failed case
            subquestion = subquestion[len("Subquestion: "):]

            # Retrieve an edited fact using the generated subquestion.
            fact_sent = self.retrieve_facts(subquestion, self.new_facts_embs, self.contriever, self.contriever_tokenizer)

            # gen was stripped above, so restore the line breaks around it:
            # the previous "Retrieved fact: ...." line needs to end, and
            # "Retrieved fact: " must start on its own line.
            if not prompt.endswith('\n'):
                prompt += '\n'
            if not gen.endswith('\n'):
                gen += '\n'
            # Put the retrieved fact at the end of the prompt; the model self-checks if it contradicts.
            prompt = prompt + gen + 'Retrieved fact: ' + fact_sent + '.'

        # Hop budget exhausted: every generation is already part of `prompt`.
        return None, prompt

    def batch_test(self, eva_data, eva_mode, eva_params):
        """Run multi-hop chain-of-thought evaluation with fact retrieval.

        Args:
            eva_data: mapping with ``'questions'`` (one list of questions per
                sample) and ``'answers'`` (per-sample collection of accepted
                answers, tested with ``in``).
            eva_mode: must be ``'multihop-cot'``.
            eva_params: provides ``prompt`` (path read via
                ``file_utils.read_texts``), ``max_new_tokens`` and ``num_hops``.

        Returns:
            One list per sample with the predicted answer (or None) for each
            question that was attempted.

        Raises:
            ValueError: if ``eva_mode`` is not ``'multihop-cot'``.
        """
        if eva_mode != 'multihop-cot':
            raise ValueError(f"Mello only supports eva_mode='multihop-cot', got {eva_mode!r}")

        prompt_template = file_utils.read_texts(eva_params.prompt)
        max_new_tokens = eva_params.max_new_tokens
        answers = eva_data['answers']
        questions = eva_data['questions']
        num_hops = eva_params.num_hops  # original note: default 4

        answers_pred = []
        for i, sample_qs in enumerate(tqdm(questions)):
            ans_pred = []
            for q in sample_qs:
                ans, transcript = self._answer_one_question(prompt_template, q, num_hops, max_new_tokens)
                is_correct = ans in answers[i]
                log.logger.info(f"===>final_prompt {is_correct}: \n" + transcript.replace(prompt_template, '').strip() + '\n')

                ans_pred.append(ans)
                # Skip the sample's remaining questions once one is answered correctly, to save time.
                if is_correct:
                    break

            answers_pred.append(ans_pred)

        return answers_pred
