import numpy as np
from utils import file_utils


class MQuAKEDatasetHandler:
    """Load a JSON dataset of knowledge edits and reshape it for evaluation.

    Each dataset item is a dict. The methods below read these keys from it:
    ``requested_rewrite`` (list of edit dicts), ``single_hops`` /
    ``new_single_hops`` (lists of single-hop question dicts), ``questions``,
    ``answer`` / ``new_answer`` and the matching ``*_alias`` lists.

    ``after_editing`` selects whether the post-edit answers (``answer_new``,
    ``new_single_hops``, ``new_answer``) or the pre-edit ones are used.
    """

    # Key inside each requested rewrite that holds the new-fact text,
    # indexed by the ``fact_type`` argument.
    _FACT_KEYS = {
        'Struct': 'fact_new',
        'Unstruct': 'fact_new_uns',
        'Unstruct-triplets': 'fact_new_uns',
    }

    def __init__(self, dataset_dir, dataset_name, after_editing=True):
        """Store the dataset location and the before/after-editing switch."""
        self.dataset_dir = dataset_dir
        self.dataset_name = dataset_name
        self.after_editing = after_editing

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @classmethod
    def _fact_key(cls, fact_type, allowed):
        """Return the rewrite key for ``fact_type``.

        Raises:
            ValueError: if ``fact_type`` is not one of ``allowed``.
        """
        if fact_type not in allowed:
            raise ValueError(
                f"unsupported fact_type {fact_type!r}; expected one of {sorted(allowed)}"
            )
        return cls._FACT_KEYS[fact_type]

    def _single_hop_key(self):
        """Return the dataset key of the single-hop list matching ``after_editing``."""
        return 'new_single_hops' if self.after_editing else 'single_hops'

    @staticmethod
    def _select_unique(columns, prompts_full):
        """Keep, in every column, only the first row of each distinct full prompt.

        ``np.unique(..., return_index=True)`` yields first-occurrence indices
        ordered by the sorted unique prompts, so the rows come back sorted by
        ``prompts_full`` rather than in dataset order.
        """
        _, unique_idx = np.unique(prompts_full, return_index=True)
        return {
            key: [values[i] for i in unique_idx]
            for key, values in columns.items()
        }

    # ------------------------------------------------------------------
    # loading
    # ------------------------------------------------------------------
    def get_dataset(self):
        """Read ``{dataset_dir}/{dataset_name}.json`` and enrich every rewrite.

        Each dict in ``requested_rewrite`` gains, in place, the keys
        ``prompt_full`` (prompt template filled with the subject),
        ``fact_new`` (``"<prompt_full> <target_new>."``), ``answer_true`` and
        ``answer_new``.

        Returns:
            The loaded dataset (whatever ``file_utils.read_json`` returned),
            after the in-place enrichment.
        """
        dataset = file_utils.read_json(f'{self.dataset_dir}/{self.dataset_name}.json')

        for d in dataset:
            for r in d["requested_rewrite"]:
                prompt_full = r["prompt"].format(r["subject"])
                target_new = r["target_new"]["str"]
                r['prompt_full'] = prompt_full
                r['fact_new'] = f"{prompt_full} {target_new}."
                r['answer_true'] = r["target_true"]["str"]
                r['answer_new'] = target_new

        return dataset

    # ------------------------------------------------------------------
    # evaluation data
    # ------------------------------------------------------------------
    def get_eva_data(self, dataset, eva_mode, fact_type):
        """Dispatch to the extractor matching ``eva_mode``.

        ``fact_type`` is only forwarded for ``eva_mode == 'edit'``; the
        single-hop and multi-hop extractors do not use it.

        Raises:
            KeyError: if ``eva_mode`` is not 'edit', 'singlehop', 'multihop'
                or 'multihop-cot'.
        """
        func_dict = {
            'edit': self.get_edit_eva,
            'singlehop': self.get_singlehop,
            'multihop': self.get_multihop,
            'multihop-cot': self.get_multihop,
        }

        func = func_dict[eva_mode]
        if eva_mode == 'edit':
            return func(dataset, fact_type)
        return func(dataset)

    def get_edit_eva(self, dataset_chunk, fact_type):
        """Collect the per-rewrite fields, de-duplicated by full prompt.

        Args:
            dataset_chunk: iterable of dataset items whose rewrites already
                carry the keys added by ``get_dataset``.
            fact_type: 'Struct' (uses ``fact_new``) or 'Unstruct' /
                'Unstruct-triplets' (use ``fact_new_uns``).

        Returns:
            Dict of equally long lists with keys ``answers``, ``targets``,
            ``questions``, ``facts_new``, ``prompts``, ``prompts_full``,
            ``relations`` and ``subjects``. ``answers`` and ``targets`` hold
            the same values (new or true answer, per ``after_editing``).

        Raises:
            ValueError: if ``fact_type`` is not recognised.
        """
        keyname_fact = self._fact_key(fact_type, self._FACT_KEYS.keys())
        key_name_answer = 'answer_new' if self.after_editing else 'answer_true'

        rewrites = [r for d in dataset_chunk for r in d["requested_rewrite"]]

        edit_data = {
            'answers': [r[key_name_answer] for r in rewrites],
            'targets': [r[key_name_answer] for r in rewrites],
            'questions': [r['question'] for r in rewrites],
            'facts_new': [r[keyname_fact] for r in rewrites],
            'prompts': [r['prompt'] for r in rewrites],
            'prompts_full': [r['prompt_full'] for r in rewrites],
            'relations': [r['relation_id'] for r in rewrites],
            'subjects': [r['subject'] for r in rewrites],
        }

        # Only keep unique editing samples.
        return self._select_unique(edit_data, edit_data['prompts_full'])

    def get_edit(self, dataset_chunk, fact_type):
        """Return the data used to perform the edits.

        For 'Struct' and 'Unstruct' this is exactly ``get_edit_eva``. Any
        other ``fact_type`` (intended: 'Unstruct-triplets') is handled by
        expanding each rewrite's ``unsfact_triplets_GPT`` list into one edit
        per triplet; the result then has only the keys ``targets``,
        ``facts_new``, ``prompts``, ``prompts_full`` and ``subjects``.
        """
        if fact_type == 'Struct' or fact_type == 'Unstruct':
            return self.get_edit_eva(dataset_chunk, fact_type)

        # Triplet path: ONLY used for editing, never for evaluation.
        facts_new = []
        prompts = []
        prompts_full = []
        targets = []
        subjects = []

        for d in dataset_chunk:
            for r in d["requested_rewrite"]:
                for triplet in r['unsfact_triplets_GPT']:
                    prompt_full = triplet['prompt'].format(triplet['subject'])
                    prompts.append(triplet['prompt'])
                    prompts_full.append(prompt_full)
                    targets.append(triplet['target'])
                    subjects.append(triplet['subject'])
                    facts_new.append(f"{prompt_full} {triplet['target']}.")

        edit_data = {
            'targets': targets,
            'facts_new': facts_new,
            'prompts': prompts,
            'prompts_full': prompts_full,
            'subjects': subjects,
        }

        # Only keep unique editing samples. Selecting by index on the plain
        # lists avoids numpy coercing a mixed-type column to strings.
        return self._select_unique(edit_data, prompts_full)

    def get_facts_new_bysample(self, dataset_chunk, fact_type):
        """Return, for each dataset item, the list of its rewrites' new facts.

        Args:
            fact_type: 'Struct' (``fact_new``) or 'Unstruct' (``fact_new_uns``).

        Raises:
            ValueError: for any other ``fact_type``.
        """
        keyname_fact = self._fact_key(fact_type, ('Struct', 'Unstruct'))

        return [
            [r[keyname_fact] for r in d["requested_rewrite"]]
            for d in dataset_chunk
        ]

    def get_singlehop(self, dataset_chunk):
        """Flatten the single-hop questions of every dataset item.

        Returns:
            Dict of equally long lists: ``questions``, ``answers`` (the answer
            followed by its aliases), ``targets`` (the answer alone) and
            ``prompts_full`` (each hop's ``cloze`` text).
        """
        key_name_hop = self._single_hop_key()

        questions = []
        answers = []
        targets = []
        prompts_full = []

        for d in dataset_chunk:
            for hop in d[key_name_hop]:
                questions.append(hop['question'])
                targets.append(hop['answer'])
                answers.append([hop['answer']] + hop['answer_alias'])
                prompts_full.append(hop['cloze'])

        return {
            'questions': questions,
            'answers': answers,
            'targets': targets,
            # Previously this returned `targets` by mistake, so callers got
            # answers where they expected the cloze prompts.
            'prompts_full': prompts_full,
        }

    def get_multihop(self, dataset_chunk):
        """Collect the multi-hop questions, one entry per dataset item.

        Returns:
            Dict with ``questions`` (each entry is the item's whole
            ``questions`` value, i.e. several phrasings of one question),
            ``answers`` (answer followed by its aliases) and ``targets``.
        """
        key_name_answer = 'new_answer' if self.after_editing else 'answer'
        key_name_alias = f'{key_name_answer}_alias'

        questions = []
        answers = []
        targets = []

        for d in dataset_chunk:
            questions.append(d['questions'])
            answers.append([d[key_name_answer]] + d[key_name_alias])
            targets.append(d[key_name_answer])

        return {
            'questions': questions,
            'answers': answers,
            'targets': targets,
        }

    # ------------------------------------------------------------------
    # merging predictions back for saving
    # ------------------------------------------------------------------
    def merge_answers_pred(self, dataset_save, answers_pred, eva_mode):
        """Write predicted answers into ``dataset_save`` (in place) and return it.

        Raises:
            KeyError: if ``eva_mode`` is not 'edit', 'singlehop', 'multihop'
                or 'multihop-cot'.
        """
        func_dict = {
            'edit': self.merge_edit,
            'singlehop': self.merge_singlehop,
            'multihop': self.merge_multihop,
            'multihop-cot': self.merge_multihop,
        }

        return func_dict[eva_mode](answers_pred, dataset_save)

    def merge_edit(self, answers_pred, dataset_save):
        """Attach ``answer_pred`` to every rewrite.

        ``answers_pred`` must be aligned with the de-duplicated output of
        ``get_edit(dataset_save, fact_type='Struct')``; each rewrite receives
        the prediction whose ``fact_new`` matches its own, so duplicated
        rewrites share one prediction.

        Raises:
            AssertionError: if the number of predictions differs from the
                number of unique edits.
            ValueError: if a rewrite's ``fact_new`` is not among the unique
                edits.
        """
        facts_new = self.get_edit(dataset_save, fact_type='Struct')['facts_new']

        assert len(facts_new) == len(answers_pred), (
            f"expected {len(facts_new)} predictions, got {len(answers_pred)}"
        )

        # First-occurrence index per fact: same result as list.index, but a
        # single pass instead of a linear scan for every rewrite.
        position = {}
        for idx, fact in enumerate(facts_new):
            position.setdefault(fact, idx)

        for d in dataset_save:
            for r in d["requested_rewrite"]:
                try:
                    idx = position[r['fact_new']]
                except KeyError:
                    raise ValueError(
                        f"{r['fact_new']!r} is not in the list of unique edits"
                    ) from None
                r['answer_pred'] = answers_pred[idx]

        return dataset_save

    def merge_singlehop(self, answers_pred, dataset_save):
        """Attach ``answer_pred`` to every single-hop entry, in dataset order.

        Raises:
            AssertionError: if predictions are left over after all hops have
                been filled.
        """
        key_name_hop = self._single_hop_key()

        n_used = 0
        for d in dataset_save:
            for hop in d[key_name_hop]:
                hop['answer_pred'] = answers_pred[n_used]
                n_used += 1

        assert n_used == len(answers_pred), (
            f"expected {n_used} predictions, got {len(answers_pred)}"
        )
        return dataset_save

    def merge_multihop(self, answers_pred, dataset_save):
        """Attach one ``answer_pred`` to every dataset item, in order.

        Raises:
            AssertionError: if predictions are left over after all items have
                been filled.
        """
        n_used = 0
        for d in dataset_save:
            d['answer_pred'] = answers_pred[n_used]
            n_used += 1

        assert n_used == len(answers_pred), (
            f"expected {n_used} predictions, got {len(answers_pred)}"
        )
        return dataset_save
