import torch
from utils import file_utils


# before editing
class Base:
    """Shared prompt-building, generation and answer-extraction helpers.

    ``answer_multihop`` and ``batch_test`` read ``self.model`` and
    ``self.tokenizer``, which ``__init__`` here does not set.
    NOTE(review): presumably subclasses assign them -- confirm.
    """

    # Marker that precedes the final answer in chain-of-thought output.
    _ANSWER_MARKER = 'Answer: '

    # Maps eva_params.test_input_format to the eva_data key holding the inputs.
    _INPUT_KEYS = {'cloze': 'prompts_full', 'question': 'questions'}

    # Evaluation modes understood by batch_test.
    _EVA_MODES = ('edit', 'multihop', 'multihop-cot')

    def __init__(self, args):
        """Remember the device that tokenized inputs are moved to."""
        self.device = args.device

    def build_prompt(self, prompt_template, q):
        """Return ``prompt_template`` with every '[[QUESTION]]' replaced by ``q``."""
        return prompt_template.replace('[[QUESTION]]', q)

    def generate(self, model, tokenizer, input_texts, max_new_tokens):
        """Generate text for a batch of prompts.

        The prompts are tokenized with padding and truncation enabled, moved
        to ``self.device``, passed to ``model.generate`` without gradient
        tracking, and decoded with special tokens skipped.

        Args:
            model: Object exposing ``generate(input_ids, attention_mask=...,
                max_new_tokens=...)``.
            tokenizer: Callable tokenizer that also provides ``batch_decode``.
            input_texts: List of prompt strings.
            max_new_tokens: Forwarded to ``model.generate``.

        Returns:
            The decoded strings, one per generated sequence.
        """
        inputs = tokenizer(
            input_texts, return_tensors="pt", truncation=True, padding=True
        ).to(self.device)

        with torch.no_grad():
            gen_tokens = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
            )

        return tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)

    @classmethod
    def _extract_answer(cls, completion, use_cot):
        """Pick the answer text out of one prompt-stripped completion.

        Without chain-of-thought the answer is the first line. With
        chain-of-thought the second line is used when it contains the answer
        marker, otherwise the last line; in both cases everything up to and
        including the first marker is dropped when the marker is present.
        """
        lines = completion.split('\n')
        if not use_cot:
            return lines[0]

        marker = cls._ANSWER_MARKER
        if len(lines) >= 2 and marker in lines[1]:
            candidate = lines[1]
        else:
            candidate = lines[-1]

        # Split at the marker instead of slicing a fixed number of characters
        # from the start of the line, so a marker that is not at column 0
        # (e.g. "So the Answer: Paris") is still handled correctly.
        _, found, tail = candidate.partition(marker)
        return tail if found else candidate

    def predict_answer(self, model, tokenizer, input_texts, targets, max_new_tokens, use_cot):
        """Generate completions and reduce each one to a predicted answer.

        Each prompt is removed from its generated text, the answer is
        extracted (see ``_extract_answer``), and the result is cut to the
        length of the matching target string.

        Args:
            model: Forwarded to ``generate``.
            tokenizer: Forwarded to ``generate``.
            input_texts: List of prompt strings.
            targets: Target strings, index-aligned with ``input_texts``.
            max_new_tokens: Forwarded to ``generate``.
            use_cot: Whether completions follow the chain-of-thought layout.

        Returns:
            List of predicted answer strings; empty when ``input_texts`` is
            empty.
        """
        # Nothing to generate; also avoids indexing element 0 in the debug
        # prints below.
        if not input_texts:
            return []

        gen_texts = self.generate(model, tokenizer, input_texts, max_new_tokens)

        answers_pred = []
        for i, gen in enumerate(gen_texts):
            completion = gen.replace(input_texts[i], '').strip()
            ans = self._extract_answer(completion, use_cot)

            # cut string by the length of targets.
            answers_pred.append(ans.strip()[:len(targets[i])])

        print(f'===> {input_texts[0]}')
        print(f'===> {answers_pred[0]}')

        return answers_pred

    # don't break even if one answer is correct for faster inference.
    def answer_multihop(self, inputs, targets, prompt_template, max_new_tokens, use_cot):
        """Answer every question of every multihop sample in one batch.

        Args:
            inputs: List of lists of questions, one inner list per sample.
            targets: One target per sample; it is repeated for each of that
                sample's questions.
            prompt_template: Template passed to ``build_prompt``.
            max_new_tokens: Forwarded to ``predict_answer``.
            use_cot: Forwarded to ``predict_answer``.

        Returns:
            The result of ``file_utils.chunk_by_size`` applied to the flat
            prediction list and the per-sample question counts.
            NOTE(review): presumably one list of predictions per sample --
            confirm against ``file_utils``.
        """
        # Flatten list(list) into one batch, repeating each sample's target.
        input_texts = [
            self.build_prompt(prompt_template, q) for qs in inputs for q in qs
        ]
        input_targets = [targets[i] for i, qs in enumerate(inputs) for _ in qs]

        answers_pred = self.predict_answer(
            self.model, self.tokenizer, input_texts, input_targets, max_new_tokens, use_cot
        )

        sizes = [len(qs) for qs in inputs]
        return file_utils.chunk_by_size(answers_pred, sizes)

    def batch_test(self, eva_data, eva_mode, eva_params):
        """Run one evaluation pass and return the predicted answers.

        Args:
            eva_data: Mapping with 'targets' and, depending on the input
                format, 'prompts_full' ('cloze') or 'questions' ('question').
            eva_mode: One of 'edit', 'multihop' or 'multihop-cot'.
            eva_params: Object with ``prompt`` (passed to
                ``file_utils.read_texts``), ``max_new_tokens`` and
                ``test_input_format`` attributes.

        Returns:
            For 'edit', a flat list of predictions; for the multihop modes,
            whatever ``answer_multihop`` returns.

        Raises:
            ValueError: If ``eva_mode`` or ``eva_params.test_input_format``
                is not one of the supported values.
        """
        # Validate up front so a bad configuration fails with a clear message
        # before any file I/O or generation happens.
        if eva_mode not in self._EVA_MODES:
            raise ValueError(
                f'unsupported eva_mode {eva_mode!r}; expected one of {self._EVA_MODES}'
            )

        test_input_format = eva_params.test_input_format
        if test_input_format not in self._INPUT_KEYS:
            raise ValueError(
                f'unsupported test_input_format {test_input_format!r}; '
                f'expected one of {tuple(self._INPUT_KEYS)}'
            )

        use_cot = (eva_mode == 'multihop-cot')
        prompt_template = file_utils.read_texts(eva_params.prompt)
        max_new_tokens = eva_params.max_new_tokens

        targets = eva_data['targets']
        inputs = eva_data[self._INPUT_KEYS[test_input_format]]

        if eva_mode == 'edit':
            input_texts = [self.build_prompt(prompt_template, item) for item in inputs]
            return self.predict_answer(
                self.model, self.tokenizer, input_texts, targets, max_new_tokens, use_cot
            )

        # one multihop sample has multiple questions.
        return self.answer_multihop(inputs, targets, prompt_template, max_new_tokens, use_cot)
