import torch
from utils import model_utils

from editors.Base import Base
from editors.utils import nethook
from editors.ft_main import apply_ft_to_model



class FT(Base):
    """Editor that applies fine-tuning edits to a language model.

    The model and tokenizer are created once in ``__init__``.  Each call to
    :meth:`edit` first rolls back the weights saved by the previous edit (if
    any) and then applies ``apply_ft_to_model`` to the new batch of requests.
    """

    def __init__(self, args):
        """Build the editor from the parsed run arguments.

        Reads ``args.device``, ``args.train_input_format`` and
        ``args.editor_hparams`` (whose ``model_name`` selects the model to
        load).  The device is also written onto the hparams object so the
        editing routine can see it.
        """
        super().__init__(args)

        self.device = args.device

        self.train_input_format = args.train_input_format

        # Weights returned by the last call to ``apply_algo``; ``None`` until
        # the first edit has been applied.
        self.weights_copy = None

        self.apply_algo = apply_ft_to_model

        self.hparams = args.editor_hparams
        self.hparams.device = args.device

        self.model, self.tokenizer = model_utils.create_LM_tokenizer(self.hparams.model_name, args.device)

    def restore(self):
        """Write the saved weights back into the model in place.

        Does nothing when no edit has been applied yet (``weights_copy`` is
        ``None``).
        """
        if self.weights_copy is None:
            return

        # In-place assignment under no_grad so autograd does not track it.
        with torch.no_grad():
            for name, saved in self.weights_copy.items():
                nethook.get_parameter(self.model, name)[...] = saved.to(self.device)

    def edit(self, edit_data, fact_type):
        """Apply a fine-tuning edit for the facts described by ``edit_data``.

        Args:
            edit_data: Mapping with the keys ``'prompts_full'``, ``'targets'``
                and ``'facts_new'``; ``'questions'`` is additionally required
                when ``train_input_format`` is ``'QA'`` for structured facts.
            fact_type: ``'Struct'`` or ``'Unstruct-triplets'`` for structured
                facts, ``'Unstruct'`` for free-text facts.

        Raises:
            ValueError: If ``fact_type`` or ``self.train_input_format`` is not
                one of the supported values.
            KeyError: If ``edit_data`` lacks a required key.
        """
        prompts_full = edit_data['prompts_full']
        targets = edit_data['targets']
        facts_new = edit_data['facts_new']

        # Undo the previous edit so edits do not accumulate on the model.
        self.restore()

        if fact_type in ('Struct', 'Unstruct-triplets'):
            # Structured facts: prompt/target pairs in the configured format.
            if self.train_input_format == 'cloze':
                input_prompts = prompts_full
                input_targets = targets
            elif self.train_input_format == 'QA':
                questions = edit_data['questions']
                input_prompts = [f"Question: {q}" for q in questions]
                input_targets = [f"Answer: {t}" for t in targets]
            else:
                raise ValueError(
                    f"Unsupported train_input_format: {self.train_input_format!r} "
                    "(expected 'cloze' or 'QA')"
                )
        elif fact_type == 'Unstruct':
            # Unstructured facts: the whole fact text is the target, paired
            # with a single-space placeholder prompt.
            input_prompts = [" "] * len(facts_new)
            input_targets = facts_new
        else:
            raise ValueError(
                f"Unsupported fact_type: {fact_type!r} "
                "(expected 'Struct', 'Unstruct-triplets' or 'Unstruct')"
            )

        requests = [
            {'prompt': prompt, 'target_new': target}
            for prompt, target in zip(input_prompts, input_targets)
        ]

        # The second return value is kept so that ``restore`` can write it
        # back later; presumably it holds the pre-edit values of the
        # parameters the algorithm changed -- confirm against ft_main.
        self.model, self.weights_copy = self.apply_algo(
            self.model,
            self.tokenizer,
            requests,
            self.hparams,
            copy=False,
            return_orig_weights=True,
        )
