import torch
from utils import model_utils

from editors.Base import Base
from editors.utils import nethook
from editors.memit.memit_main import apply_memit_to_model


class MEMIT(Base):
    """Editor that applies MEMIT edits to a causal language model.

    Each call to :meth:`edit` first writes back the weights saved by the
    previous edit (if any), then applies the new batch of requests through
    ``apply_memit_to_model`` with ``copy=False`` (the model held in
    ``self.model`` is the one that gets modified).
    """

    def __init__(self, args):
        super().__init__(args)

        self.device = args.device

        self.train_input_format = args.train_input_format

        # Mapping of parameter name -> saved tensor, as returned by the edit
        # algorithm when called with return_orig_weights=True. None until the
        # first edit has been applied.
        self.weights_copy = None

        self.apply_algo = apply_memit_to_model

        # NOTE: this mutates the (possibly shared) args.editor_hparams object
        # in place so the algorithm sees the same device as the editor.
        self.hparams = args.editor_hparams
        self.hparams.device = args.device

        self.model, self.tokenizer = model_utils.create_LM_tokenizer(self.hparams.model_name, args.device)

    def restore(self):
        """Write the weights saved by the last edit back into ``self.model``.

        This is a no-op when no edit has been applied yet (``weights_copy`` is
        None), instead of failing on ``None.items()``.
        """
        if self.weights_copy is None:
            return
        with torch.no_grad():
            for name, saved in self.weights_copy.items():
                # In-place slice assignment keeps the existing Parameter object
                # (and anything referencing it) intact.
                nethook.get_parameter(self.model, name)[...] = saved.to(self.device)

    def edit(self, edit_data, fact_type):
        """Apply one batch of structured-fact edits to ``self.model``.

        Args:
            edit_data: mapping with ``'prompts'``, ``'targets'`` and
                ``'subjects'`` entries; the i-th elements of the three
                sequences together form one edit request.
            fact_type: kind of fact being edited. ``'Unstruct'`` is not
                supported by this editor.

        Raises:
            KeyError: if one of the required keys is missing from ``edit_data``.
            ValueError: if ``fact_type`` is ``'Unstruct'`` or the three
                sequences have different lengths.
        """
        prompts = list(edit_data['prompts'])
        targets = list(edit_data['targets'])
        subjects = list(edit_data['subjects'])

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if fact_type == 'Unstruct':
            raise ValueError("MEMIT does not support unstructured facts (fact_type='Unstruct')")

        # zip() would silently drop trailing items on a length mismatch.
        if not (len(prompts) == len(targets) == len(subjects)):
            raise ValueError(
                "prompts, targets and subjects must have equal lengths, got "
                f"{len(prompts)}, {len(targets)} and {len(subjects)}"
            )

        # Undo the previous edit (no-op on the first call) before applying the
        # new batch. Whether this fully resets the model presumably depends on
        # each edit touching the same parameters (fixed hparams) -- TODO confirm.
        self.restore()

        # structured facts
        requests = [
            {
                'prompt': p,
                'target_new': t,
                'subject': s,
            }
            for p, t, s in zip(prompts, targets, subjects)
        ]

        self.model, self.weights_copy = self.apply_algo(
            self.model,
            self.tokenizer,
            requests,
            self.hparams,
            copy=False,
            return_orig_weights=True,
        )
