from utils import file_utils


class CounterfactDatasetHandler:
    """Load a rewrite-style JSON dataset and reshape it for editing/evaluation.

    Every record is expected to carry a ``'requested_rewrite'`` dict; the
    methods below read fields from that dict and return them as parallel
    lists keyed by name.
    """

    # Maps each supported ``fact_type`` to the record key holding the
    # new-fact text used by ``get_edit_eva``.
    _FACT_KEYS = {
        'Unstruct': 'fact_new_uns',
        'Unstruct-triplets': 'fact_new_uns',
        'Struct': 'fact_new',
    }

    def __init__(self, dataset_dir, dataset_name, after_editing):
        """Store where the dataset lives and which answer field to read.

        Args:
            dataset_dir: Directory containing the dataset JSON file.
            dataset_name: File name without the ``.json`` extension.
            after_editing: When truthy, ``get_edit_eva`` reads
                ``'answer_new'``; otherwise it reads ``'answer_true'``.
        """
        self.dataset_dir = dataset_dir
        self.dataset_name = dataset_name
        self.after_editing = after_editing

    def get_dataset(self):
        """Read the dataset file and add ``'prompt_full'`` to every record.

        ``'prompt_full'`` is the record's ``'prompt'`` template formatted
        with its ``'subject'``. The records are modified in place.

        Returns:
            The loaded dataset, as returned by ``file_utils.read_json``.
        """
        dataset = file_utils.read_json(f'{self.dataset_dir}/{self.dataset_name}.json')

        for record in dataset:
            rewrite = record['requested_rewrite']
            rewrite['prompt_full'] = rewrite['prompt'].format(rewrite['subject'])

        return dataset

    def get_eva_data(self, dataset_chunk, eva_mode, fact_type):
        """Return evaluation data for ``dataset_chunk``.

        ``eva_mode`` is accepted but not used by this handler; the call is
        forwarded to ``get_edit_eva``.
        """
        return self.get_edit_eva(dataset_chunk, fact_type)

    def get_edit_eva(self, dataset_chunk, fact_type):
        """Collect per-record fields of ``dataset_chunk`` into parallel lists.

        Args:
            dataset_chunk: Iterable of records, each with a
                ``'requested_rewrite'`` dict that already contains
                ``'prompt_full'`` (see ``get_dataset``).
            fact_type: ``'Struct'``, ``'Unstruct'`` or ``'Unstruct-triplets'``;
                selects which new-fact field is read.

        Returns:
            A dict with the keys ``'targets'``, ``'answers'``, ``'questions'``,
            ``'facts_new'``, ``'prompts_full'``, ``'prompts'`` and
            ``'subjects'``, each a list with one entry per record. An
            ``'answers'`` entry is the bare answer when the record has no
            alias field, otherwise a list of the answer followed by its
            aliases.

        Raises:
            ValueError: If ``fact_type`` is not supported and the chunk
                contains at least one record.
        """
        answer_key = 'answer_new' if self.after_editing else 'answer_true'
        alias_key = f'{answer_key}_alias'
        fact_key = self._FACT_KEYS.get(fact_type)

        questions = []
        answers = []
        targets = []
        facts_new = []
        prompts_full = []
        prompts = []
        subjects = []

        for record in dataset_chunk:
            # Checked lazily so an empty chunk still yields empty lists for
            # any fact_type; a non-empty chunk used to fail here with an
            # opaque UnboundLocalError instead of a clear message.
            if fact_key is None:
                supported = ', '.join(repr(name) for name in self._FACT_KEYS)
                raise ValueError(
                    f'Unsupported fact_type {fact_type!r}; expected one of {supported}.'
                )

            rewrite = record['requested_rewrite']
            answer = rewrite[answer_key]

            questions.append(rewrite['question'])
            targets.append(answer)

            if alias_key in rewrite:
                # WikiUpdate has answer aliases.
                answers.append([answer] + rewrite[alias_key])
            else:
                # counterfact doesn't have answer aliases.
                answers.append(answer)

            facts_new.append(rewrite[fact_key])
            prompts_full.append(rewrite['prompt_full'])
            prompts.append(rewrite['prompt'])
            subjects.append(rewrite['subject'])

        return {
            'targets': targets,
            'answers': answers,
            'questions': questions,
            'facts_new': facts_new,
            'prompts_full': prompts_full,
            'prompts': prompts,
            'subjects': subjects,
        }

    def get_edit(self, dataset_chunk, fact_type):
        """Build the data used to apply edits for ``dataset_chunk``.

        For ``'Struct'`` and ``'Unstruct'`` this is exactly the result of
        ``get_edit_eva``. For any other ``fact_type`` the edits are built
        from each record's ``'unsfact_triplets_GPT'`` list, one entry per
        triplet, so the output lists may be longer than the chunk.

        Returns:
            In the triplet case, a dict with the keys ``'targets'``,
            ``'facts_new'``, ``'prompts_full'``, ``'prompts'`` and
            ``'subjects'``.
        """
        if fact_type in ('Struct', 'Unstruct'):
            return self.get_edit_eva(dataset_chunk, fact_type)

        targets = []
        facts_new = []
        prompts_full = []
        prompts = []
        subjects = []

        for record in dataset_chunk:
            for triplet in record['requested_rewrite']['unsfact_triplets_GPT']:
                prompt = triplet['prompt']
                subject = triplet['subject']
                target = triplet['target']
                prompt_full = prompt.format(subject)

                prompts.append(prompt)
                prompts_full.append(prompt_full)
                targets.append(target)
                subjects.append(subject)
                # The new fact is the filled-in prompt completed by the target.
                facts_new.append(f"{prompt_full} {target}.")

        return {
            'targets': targets,
            'facts_new': facts_new,
            'prompts_full': prompts_full,
            'prompts': prompts,
            'subjects': subjects,
        }

    def merge_answers_pred(self, dataset_save, answers_pred, eva_mode):
        """Store each prediction on its record as ``'answer_pred'``.

        ``dataset_save`` and ``answers_pred`` must have the same length; the
        records are modified in place. ``eva_mode`` is accepted but not used
        by this handler.

        Returns:
            ``dataset_save``, the same object that was passed in.
        """
        assert len(dataset_save) == len(answers_pred), (
            'dataset_save and answers_pred must have the same length'
        )

        for record, answer_pred in zip(dataset_save, answers_pred):
            record['requested_rewrite']['answer_pred'] = answer_pred

        return dataset_save
