import torch
from torch import nn, Tensor
import torch.nn.functional as F
import logging
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import time
from transformers.trainer_callback import TrainerState
from transformers.optimization import Adafactor, AdamW, get_scheduler, get_linear_schedule_with_warmup

from transformers.file_utils import ModelOutput
from transformers.trainer_pt_utils import get_parameter_names
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
import json

import data_loader
from data_loader import MAX_NUM_1, MAX_NUM_2, MAX_NUM_3
from dataclasses import dataclass
import math
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from mlm_inference import MLM_model_inferencer
import re
import numpy as np
import os

# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(__name__)


class GenerativeModelTrainer:
    def __init__(self, args, config, model, tokenizer,
                 mlm_model, mlm_tokenizer, banned_token_ids, generator_map2_mlm,
                 device, mlm_filter_indices, mlm_index_list, alpha=0.75):
        """
        Keep references to the generator, the masked language model and their lookup tables.

        Args:
            args: Run arguments; ``max_seq_len`` and ``max_input_seq_len`` are read here.
            config: Generator model configuration.
            model: Generator model.
            tokenizer: Tokenizer of the generator; its special-token ids are cached on the instance.
            mlm_model: Masked language model handed to ``MLM_model_inferencer``.
            mlm_tokenizer: Tokenizer of the masked language model.
            banned_token_ids: Token ids stored for use when processing generation scores.
            generator_map2_mlm: Mapping from a generator token id to the matching masked-LM token id.
            device: Device on which the id lookup tensors are placed.
            mlm_filter_indices: Stored as-is and later forwarded to the metric evaluator.
            mlm_index_list: Stored as-is and later forwarded to the metric evaluator.
            alpha: Forwarded to ``MLM_model_inferencer``.
        """
        # Run configuration.
        self.args = args
        self.config = config
        self.is_in_train = False
        self.device = device

        # Generator side.
        self.model = model
        self.tokenizer = tokenizer
        self.max_seq_len = args.max_seq_len
        self.max_input_seq_len = args.max_input_seq_len
        self.top_k = 100
        self.min_tokens_to_keep = 1

        # Special token ids of the generator tokenizer.
        self.bos_token_id = self.tokenizer.bos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id
        self.unk_token_id = self.tokenizer.unk_token_id
        self.mask_token_id = self.tokenizer.mask_token_id
        self.pad_token_id = self.tokenizer.pad_token_id
        # The mask token doubles as the "None" placeholder token.
        self.None_token_id = self.mask_token_id
        self.banned_token_ids = banned_token_ids
        self.generator_map2_mlm = generator_map2_mlm

        # Masked Language Model side.
        self.mlm_model = mlm_model
        self.mlm_tokenizer = mlm_tokenizer
        mlm_vocab = list(self.mlm_tokenizer.get_vocab().keys())
        self.metric_evaluator = MLM_model_inferencer(self.mlm_model, self.mlm_tokenizer,
                                                     mlm_vocab, alpha)
        self.mlm_filter_indices = mlm_filter_indices
        self.mlm_index_list = mlm_index_list

        # Parallel id tensors: entry k of common_g_ids (generator vocabulary) corresponds to
        # entry k of common_mlm_ids (masked-LM vocabulary).
        generator_ids = list(self.generator_map2_mlm)
        mlm_ids = [self.generator_map2_mlm[g_id] for g_id in generator_ids]
        self.common_mlm_ids = torch.tensor(mlm_ids).long().to(self.device)
        self.common_g_ids = torch.tensor(generator_ids).long().to(self.device)

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Build the optimizer first, then the learning-rate scheduler that is attached to it.

        The work is delegated to :obj:`create_optimizer` and :obj:`create_scheduler`; subclasses can
        override either of them to customise the setup.

        Args:
            num_training_steps (int): Total number of optimization steps the scheduler should cover.
        """
        self.create_optimizer()
        # The scheduler is created for the optimizer that was just built.
        optimizer = self.optimizer
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)


    def create_optimizer(self):
        """
        Build the AdamW optimizer and store it on ``self.optimizer``.

        The model parameters are split into two groups:

        * names returned by ``get_parameter_names(self.model, [nn.LayerNorm])`` that do not contain
          ``"bias"`` are trained with ``args.weight_decay``;
        * every other parameter is trained with a weight decay of ``0.0``.

        ``adam_beta1``, ``adam_beta2``, ``adam_epsilon`` and ``learning_rate`` are read from ``self.args``.
        """
        # A set keeps the per-parameter membership test O(1); the original list made the
        # split quadratic in the number of parameters.
        decay_names = {
            name
            for name in get_parameter_names(self.model, [nn.LayerNorm])
            if "bias" not in name
        }

        decay_params = []
        no_decay_params = []
        for name, param in self.model.named_parameters():
            if name in decay_names:
                decay_params.append(param)
            else:
                no_decay_params.append(param)

        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.args.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        optimizer_kwargs = {
            "betas": (self.args.adam_beta1, self.args.adam_beta2),
            "eps": self.args.adam_epsilon,
            "lr": self.args.learning_rate,
        }
        self.optimizer = AdamW(optimizer_grouped_parameters, **optimizer_kwargs)


    def get_warmup_steps(self, num_training_steps: int):
        """
        Return the number of steps spent in linear warmup.

        A positive ``args.warmup_steps`` is used as-is; otherwise the count is
        ``ceil(num_training_steps * args.warmup_ratio)``.
        """
        if self.args.warmup_steps > 0:
            return self.args.warmup_steps
        return math.ceil(num_training_steps * self.args.warmup_ratio)


    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Set up a linear-with-warmup learning-rate scheduler and store it on ``self.lr_scheduler``.

        The optimizer must either have been created already (``self.optimizer``) or be passed in.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (torch.optim.Optimizer, optional): Optimizer to schedule. Defaults to
                ``self.optimizer`` when omitted.
        """
        # Honour the argument; previously it was accepted but silently ignored.
        if optimizer is None:
            optimizer = self.optimizer
        self.lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            self.get_warmup_steps(num_training_steps),
            num_training_steps,
        )


    def num_examples(self, dataloader) -> int:
        """
        Count the samples served by ``dataloader`` by summing ``len(batch)`` over one full pass.

        Note that this iterates the whole dataloader once. An exception is raised if a yielded
        batch does not implement :obj:`__len__`.
        """
        return sum(len(batch) for batch in dataloader)


    def _get_input_tensors_batch_train(self, samples_list, wiki):
        """
        Encode a list of samples and collate them into right-padded batch tensors.

        Each sample is a dict providing ``"sub_label"``, ``"obj_label"``, ``"masked_sentences"``,
        ``"template"`` (only its first entry is used) and ``"predicate_id"``.

        Args:
            samples_list: Samples of one batch.
            wiki: Forwarded to the per-sample encoder; enables the masked-sentence context.

        Returns:
            A tuple ``(input_ids, input_attention, X_ids, X_attention, mlm_X_ids, mlm_X_attention,
            mlm_label_ids, predicate_id_list)``. The six id/attention tensors have one row per sample
            and are ``None`` when ``samples_list`` is empty. Generator ids are padded with
            ``self.pad_token_id``, masked-LM ids with ``self.mlm_tokenizer.pad_token_id`` and
            attention masks with 0. ``mlm_label_ids`` is a ``LongTensor`` built from the per-sample
            label id lists.
        """
        def pad_row(ids, mask, width, pad_id):
            # Right-pad a (1, L) id tensor with pad_id, and its mask with 0, up to `width` columns.
            missing = width - ids.shape[1]
            if missing > 0:
                ids = torch.cat((ids, torch.full([1, missing], pad_id, dtype=torch.long)), dim=1)
                mask = torch.cat((mask, torch.full([1, missing], 0, dtype=torch.long)), dim=1)
            return ids, mask

        def stack_rows(rows):
            # One concatenation per output instead of growing the batch row by row.
            return torch.cat(rows, dim=0) if rows else None

        encoded = []
        mlm_label_ids = []
        predicate_id_list = []
        for sample in samples_list:
            predicate_id_list.append(sample["predicate_id"])
            *tensors, mlm_label_id = self.__get_BART_input_tensors(
                sample["sub_label"], sample["obj_label"], sample["template"][0],
                sample["masked_sentences"], wiki=wiki)
            encoded.append(tensors)
            mlm_label_ids.append(mlm_label_id)
        mlm_label_ids = torch.LongTensor(mlm_label_ids)

        # Widest sequence of each kind decides the padded width of that output.
        max_tokens = max((row[0].shape[1] for row in encoded), default=0)
        max_X_tokens = max((row[2].shape[1] for row in encoded), default=0)
        max_mlm_X = max((row[4].shape[1] for row in encoded), default=0)

        token_rows, token_masks = [], []
        X_rows, X_masks = [], []
        mlm_X_rows, mlm_X_masks = [], []
        for tokens, tokens_att, X, X_att, mlm_X, mlm_X_att in encoded:
            tokens, tokens_att = pad_row(tokens, tokens_att, max_tokens, self.pad_token_id)
            X, X_att = pad_row(X, X_att, max_X_tokens, self.pad_token_id)
            mlm_X, mlm_X_att = pad_row(mlm_X, mlm_X_att, max_mlm_X, self.mlm_tokenizer.pad_token_id)
            token_rows.append(tokens)
            token_masks.append(tokens_att)
            X_rows.append(X)
            X_masks.append(X_att)
            mlm_X_rows.append(mlm_X)
            mlm_X_masks.append(mlm_X_att)

        return stack_rows(token_rows), stack_rows(token_masks), \
               stack_rows(X_rows), stack_rows(X_masks), \
               stack_rows(mlm_X_rows), stack_rows(mlm_X_masks), mlm_label_ids, predicate_id_list


    def __get_BART_input_tensors(self, subject, mlm_label, relation_meta, masked_sentences, wiki):
        """
        Encode one sample for the generator and for the masked language model.

        The generator input is ``<bos> subject <eos> relation_meta <eos>``. When ``wiki`` is truthy
        the masked sentences are appended (with ``[MASK]`` rewritten to the generator's mask token)
        until the next sentence would push the input past ``self.max_input_seq_len`` tokens,
        followed by a closing ``<eos>``.

        Args:
            subject: Subject surface form.
            mlm_label: Object label; it must map to exactly one masked-LM token.
            relation_meta: Relation template text.
            masked_sentences: Context sentences, only used when ``wiki`` is truthy.
            wiki: Whether to append the masked sentences to the generator input.

        Returns:
            ``(tokens, attention_mask, subject_ids, subject_attention, mlm_subject_ids,
            mlm_subject_attention, label_id)`` where every tensor has shape ``(1, length)`` and
            ``label_id`` is a one-element list of masked-LM token ids.

        Raises:
            AssertionError: If ``mlm_label`` is not a single masked-LM token.
        """
        tokenized_subject = self.tokenizer.tokenize(" " + subject)
        tokenized_input = [self.tokenizer.bos_token] + tokenized_subject + [self.tokenizer.eos_token]
        tokenized_input += self.tokenizer.tokenize(" " + relation_meta) + [self.tokenizer.eos_token]
        if wiki:
            for sentence in masked_sentences:
                # str.replace returns a new string; the result used to be discarded, so the
                # "[MASK]" placeholder was never rewritten to the generator's mask token.
                sentence = sentence.replace("[MASK]", self.tokenizer.mask_token)
                words = self.tokenizer.tokenize(" " + sentence)
                if (len(tokenized_input) + len(words)) > self.max_input_seq_len:
                    break
                tokenized_input += words
            tokenized_input += [self.tokenizer.eos_token]

        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_input)
        tokens_tensor = torch.tensor([indexed_tokens]).long()
        attention_mask = torch.ones(tokens_tensor.size()).long()

        # Subject alone, in the generator vocabulary.
        indexed_subject = self.tokenizer.convert_tokens_to_ids(tokenized_subject)
        subject_tensor = torch.tensor([indexed_subject]).long()
        subject_attention = torch.ones(subject_tensor.size()).long()

        # Subject alone, in the masked-LM vocabulary.
        mlm_indexed_sub = self.mlm_tokenizer.convert_tokens_to_ids(self.mlm_tokenizer.tokenize(" " + subject))
        mlm_sub_tensor = torch.tensor([mlm_indexed_sub]).long()
        mlm_sub_att = torch.ones(mlm_sub_tensor.size()).long()

        label_id = self.mlm_tokenizer.convert_tokens_to_ids(self.mlm_tokenizer.tokenize(' ' + mlm_label))
        assert len(label_id) == 1, f"label {mlm_label!r} is not a single masked-LM token"
        return tokens_tensor, attention_mask, subject_tensor, subject_attention, \
               mlm_sub_tensor, mlm_sub_att, label_id


    def train(self, train_dataloader, dev_dataloader_dict, wiki):
        """
        Run the training loop and periodically evaluate on the dev sets.

        For every batch the samples are encoded with :obj:`_get_input_tensors_batch_train`, moved to
        ``self.device`` and passed to :obj:`training_step`; gradients are then clipped to ``args.clip``
        and the optimizer and scheduler are stepped. Every ``args.logging_step`` steps the running
        average losses and the first decoded sequence of the batch are logged. Every
        ``args.save_step`` steps each dev dataloader is evaluated with ``self.evaluate`` and
        ``self.save_optiprompt(key)`` is called whenever the returned precision beats the best
        value seen so far for that key.

        Args:
            train_dataloader: Sized iterable yielding batches (sequences of sample dicts).
            dev_dataloader_dict: Mapping from a dev-set name to the dataloader given to ``self.evaluate``.
            wiki: Forwarded to the batch encoder and to ``self.evaluate``.
        """
        args = self.args
        self.is_in_train = True

        total_train_batch_size = args.train_batch_size
        num_update_steps_per_epoch = max(len(train_dataloader), 1)
        if args.max_steps > 0:
            max_steps = args.max_steps
            # Round up so that a partial last epoch is still counted.
            num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
            num_train_epochs = math.ceil(args.num_train_epochs)

        self.create_optimizer_and_scheduler(num_training_steps=max_steps)
        self.state = TrainerState()

        # Train!
        # Counting the examples makes one full pass over the dataloader.
        num_examples = self.num_examples(train_dataloader)

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  alpha = {args.alpha}")
        logger.info(f"  Num Epochs = {num_train_epochs}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
        logger.info(f"  Total optimization steps = {max_steps}")

        self.state.epoch = 0
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs

        # Running sums of the per-step losses, kept as tensors.
        tr_loss = torch.tensor(0.0)
        pred_loss = torch.tensor(0.0)
        info_loss = torch.tensor(0.0)
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        self.model.zero_grad()

        best_result = {}
        # Counts optimizer steps across all epochs; it is never reset between epochs, so the
        # logged averages cover the whole run.
        steps_done = 0
        # NOTE(review): the loop always runs `num_train_epochs` full epochs; `max_steps` only
        # sizes the scheduler and does not stop training early - confirm this is intended.
        for epoch in range(num_train_epochs):
            for step, samples in enumerate(train_dataloader):
                input_tokens_tensor, input_att_mask, X_token_tensor, X_att_mask, \
                mlm_X_tensor, mlm_X_att, mlm_label, predict_list = self._get_input_tensors_batch_train(samples, wiki)
                input_tokens_tensor = input_tokens_tensor.to(self.device)
                input_att_mask = input_att_mask.to(self.device)
                X_token_tensor = X_token_tensor.to(self.device)
                X_att_mask = X_att_mask.to(self.device)
                mlm_X_tensor = mlm_X_tensor.to(self.device)
                mlm_X_att = mlm_X_att.to(self.device)
                mlm_label = mlm_label.to(self.device)

                # training_step runs the backward pass; the update itself happens here.
                tr_loss_step, pred_loss_step, info_loss_step, decoder_ids = self.training_step(
                    input_tokens_tensor, input_att_mask, X_token_tensor, X_att_mask,
                    mlm_X_tensor, mlm_X_att, mlm_label, predict_list)
                tr_loss += tr_loss_step
                pred_loss += pred_loss_step
                info_loss += info_loss_step
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.clip)
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                steps_done += 1
                if steps_done % self.args.logging_step == 0:
                    # Use the module logger (not the root `logging` module) like the rest of the file.
                    logger.info("%d, Loss: [%f, %f, %f]",
                                steps_done,
                                tr_loss / steps_done,
                                pred_loss / steps_done,
                                info_loss / steps_done)
                    # Show the first generated sequence of the batch as a sanity check.
                    seq = self.tokenizer.decode(list(decoder_ids[0, :].numpy()))
                    logger.info("%s: %s", predict_list[0], seq)

                if steps_done % self.args.save_step == 0:
                    for key in dev_dataloader_dict:
                        precision, _, _ = self.evaluate(dev_dataloader_dict[key], wiki=wiki)
                        best_result.setdefault(key, 0.0)
                        if precision > best_result[key]:
                            best_result[key] = precision
                            logger.info('!!! Best valid (epoch=%d) for %s: %.2f' % (epoch, key, precision * 100))
                            self.save_optiprompt(key)
                        else:
                            logger.info('(epoch=%d) for %s: %.2f' % (epoch, key, precision * 100))

        logger.info("Result Summarization: ")
        for key in best_result:
            logger.info("%s, %.2f" % (key, best_result[key] * 100))


    def training_step(self, input_ids, attention_mask, X_ids, X_att,
                      mlm_X_ids, mlm_X_att, label, predict_list) -> Tuple[float, float, float, torch.Tensor]:
        """
        Run one forward/backward pass on a batch.

        The model is put in training mode, :obj:`compute_loss` is called in ``"training"`` mode with
        ``self.eos_token_id`` as the decoder start token, and the total loss is back-propagated.
        The optimizer is not stepped here.

        Returns:
            ``(loss, pred_loss, info_loss, decoder_ids)``: the three losses as Python floats and the
            generated token ids moved to the CPU.
        """
        self.model.train()
        outputs, decoder_ids = self.compute_loss(input_ids=input_ids,
                                                 attention_mask=attention_mask,
                                                 X_ids=X_ids,
                                                 X_att=X_att,
                                                 mlm_X_ids=mlm_X_ids,
                                                 mlm_X_att=mlm_X_att,
                                                 label=label,
                                                 decoder_start_token_id=self.eos_token_id,
                                                 predict_list=predict_list,
                                                 mode="training")
        decoder_ids = decoder_ids.cpu()
        loss, pred_loss, info_loss = outputs
        loss.backward()
        # .item() already yields a detached Python number, whatever device the tensor lives on.
        return loss.item(), pred_loss.item(), info_loss.item(), decoder_ids


    def _prepare_decoder_input_ids_for_generation(self,
        input_ids: torch.LongTensor,
        decoder_start_token_id: int = None
    ):
        """
        Build the first decoder input as a one-hot distribution over the generator vocabulary.

        Returns a float tensor of shape ``(batch, 1, tokenizer.vocab_size)`` on the device of
        ``input_ids``, with 1.0 at ``decoder_start_token_id`` and 0.0 elsewhere.
        """
        batch_size = input_ids.shape[0]
        vocab_size = self.tokenizer.vocab_size
        one_hot_start = input_ids.new_zeros((batch_size, 1, vocab_size), dtype=torch.float)
        one_hot_start[:, 0, decoder_start_token_id] = 1.0
        return one_hot_start


    def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: torch.LongTensor, model_kwargs):
        """
        Run the encoder once and store its output in ``model_kwargs["encoder_outputs"]``.

        Every entry of ``model_kwargs`` whose name does not start with ``"decoder_"`` or
        ``"cross_attn"`` is forwarded to the encoder. ``model_kwargs`` is updated in place and
        also returned.
        """
        encoder_kwargs = {}
        for argument, value in model_kwargs.items():
            # Decoder-only and cross-attention arguments are not meant for the encoder.
            if argument.startswith(("decoder_", "cross_attn")):
                continue
            encoder_kwargs[argument] = value

        encoder = self.model.get_encoder()
        model_kwargs["encoder_outputs"] = encoder(input_ids, return_dict=True, **encoder_kwargs)
        return model_kwargs


    def prepare_inputs_for_generation(self,
                                      decoder_input_ids,
                                      past=None,
                                      attention_mask=None,
                                      head_mask=None,
                                      decoder_head_mask=None,
                                      cross_attn_head_mask=None,
                                      use_cache=None,
                                      encoder_outputs=None,
                                      **kwargs):
        """
        Assemble the keyword arguments for one decoding step of the model.

        ``input_ids`` is always ``None`` because the encoder output is passed explicitly, and
        ``past`` is exposed under the ``"past_key_values"`` key. Extra ``kwargs`` are ignored.
        """
        model_inputs = dict(
            input_ids=None,  # encoder_outputs is supplied, so the raw input ids are not needed
            encoder_outputs=encoder_outputs,
            past_key_values=past,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
        )
        return model_inputs


    def compute_loss(self,
                     input_ids, attention_mask, X_ids, X_att, mlm_X_ids, mlm_X_att,
                     predict_list=None,
                     decoder_start_token_id: Optional[int] = None,
                     use_cache: Optional[bool] = None,
                     output_attentions: Optional[bool] = None,
                     output_hidden_states: Optional[bool] = None,
                     label: Optional[torch.LongTensor] = None,
                     mode: str = "training",
                     **model_kwargs):
        """
        Sample a sequence from the generator and score it with the masked language model.

        Steps:
            1. Encode ``input_ids`` and call :obj:`sample` to obtain the generated ids, the per-step
               scores and the per-example segment lengths (``part1_len``, ``X_len``, ``part2_len``,
               ``part3_len``).
            2. Per example, drop some score rows, softmax the rest and insert placeholder rows
               (one-hot on ``self.None_token_id``) where the subject goes.
            3. Right-pad all examples to a common length with one-hot pad rows and build the
               matching attention mask.
            4. Move the probabilities from the generator vocabulary to the masked-LM vocabulary
               through ``self.common_g_ids`` / ``self.common_mlm_ids``.
            5. Build a second copy whose placeholder rows are replaced by one-hot rows of
               ``mlm_X_ids`` and hand both copies to ``self.metric_evaluator.mask_filling``.

        Args:
            input_ids, attention_mask: Encoder input batch and its attention mask.
            X_ids, X_att: Subject ids and mask in the generator vocabulary, forwarded to :obj:`sample`.
            mlm_X_ids, mlm_X_att: Subject ids and mask in the masked-LM vocabulary.
            predict_list: Forwarded unchanged to ``mask_filling``.
            decoder_start_token_id: Token used to start decoding.
            use_cache, output_attentions, output_hidden_states: Forwarded through ``model_kwargs``;
                the latter two default to the values in ``self.config``.
            label: Forwarded to ``mask_filling`` as ``label_ids``.
            mode: Forwarded to ``mask_filling``.

        Returns:
            ``(results, decoder_ids)`` where ``results`` is whatever ``mask_filling`` returns
            (``training_step`` unpacks it as ``(loss, pred_loss, info_loss)``) and ``decoder_ids``
            are the sampled sequences.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states
        model_kwargs["attention_mask"] = attention_mask

        # encoder_input_ids = input_ids if self.config.is_encoder_decoder else None

        # add encoder_outputs to model_kwargs
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)

        # One-hot start token, shaped like the per-step distributions.
        decoder_input_ids = self._prepare_decoder_input_ids_for_generation(
                    input_ids, decoder_start_token_id=decoder_start_token_id)

        # set model_kwargs
        model_kwargs["use_cache"] = use_cache
        # cur_len = decoder_input_ids.shape[-1]

        # Sample
        outputs = self.sample(decoder_input_ids=decoder_input_ids,
                              max_length = self.max_seq_len,
                              X_input_ids=X_ids,
                              X_att_mask=X_att,
                              Y_id=self.tokenizer.mask_token_id,
                              return_dict_in_generate=True,
                              eos_token_id=self.eos_token_id,
                              **model_kwargs)
        decoder_ids = outputs["sequences"]
        decoder_ids_scores = outputs["scores"]
        # Per-example segment lengths, flattened to one entry per example.
        # NOTE(review): presumably part1/part2/part3 are the generated prompt pieces before the
        # subject X, between X and the object slot, and after it - confirm against `sample`.
        part1_len_list = outputs["part1_len"].view(-1) # (B, )
        part2_len_list = outputs["part2_len"].view(-1) # (B, )
        part3_len_list = outputs["part3_len"].view(-1) # (B, )
        x_len_list = outputs["X_len"].view(-1) # (B, )

        batch_size = decoder_ids.size(0)
        # From here on `attention_mask` is rebuilt for the masked-LM input; the encoder mask
        # received as an argument has already been stored in `model_kwargs` above.
        attention_mask = []
        masked_indices = []

        dummy_token_probs_tensor = []
        # One-hot rows over the generator vocabulary for the placeholder and the pad token.
        # NOTE(review): both are created with an integer dtype and later concatenated with
        # float probabilities, which relies on torch.cat promoting dtypes - confirm the
        # minimum supported torch version.
        None_token_prob = torch.zeros([1, decoder_ids_scores.size(-1)]).long().to(self.device)
        None_token_prob[0, self.None_token_id] = 1.0
        Pad_token_prob = torch.zeros([1, decoder_ids_scores.size(-1)]).long().to(self.device)
        Pad_token_prob[0, self.pad_token_id] = 1.0
        # Number of unpadded masked-LM subject tokens per example.
        mlm_X_len_tensor = torch.sum(mlm_X_att, dim=-1)

        max_mlm_seq_len = 0
        for i in range(batch_size):
            part1_len = part1_len_list[i].item()
            X_len = x_len_list[i].item()
            part2_len = part2_len_list[i].item()
            part3_len = part3_len_list[i].item()
            max_att_len = part1_len+ X_len+ part2_len+ 1 + part3_len
            # Score rows that are dropped: the last row of part 1, all X rows and the last
            # row of part 2.
            remove_seq_idxs = set([part1_len- 1]+ list(range(part1_len, part1_len+ X_len))+ \
                              [part1_len+ X_len+ part2_len-1])
            chosen_seq_idxs = []
            for idx in range(max_att_len):
                if idx in remove_seq_idxs:
                    continue
                chosen_seq_idxs.append(idx)
            # Turn the kept score rows into probability distributions over the vocabulary.
            token_probs = F.softmax(decoder_ids_scores[i, chosen_seq_idxs, :], dim=-1)

            mlm_X_len = mlm_X_len_tensor[i].item()
            mlm_seq_len = token_probs.size(0)+ mlm_X_len
            if max_mlm_seq_len < mlm_seq_len:
                max_mlm_seq_len = mlm_seq_len
            # Insert `mlm_X_len` placeholder rows right after the first `part1_len - 1` rows.
            dummy_None_tensor = None_token_prob.expand(mlm_X_len, -1).to(self.device)
            dummy_token_probs = torch.cat([token_probs[:part1_len-1, :], dummy_None_tensor, token_probs[part1_len-1:, :]], dim=0)
            dummy_token_probs_tensor.append(dummy_token_probs)
            attention = [1]* mlm_seq_len
            attention_mask.append(attention)
            # Row handed to `mask_filling` as the masked position: it follows the kept part 1
            # rows, the subject rows and the kept part 2 rows (presumably the object slot).
            mask_indice = part1_len-1+ mlm_X_len+ part2_len- 1
            masked_indices.append(mask_indice)

        # Right-pad every example to the longest sequence of the batch.
        for i in range(batch_size):
            dummy_token_probs = dummy_token_probs_tensor[i]
            current_seq_len = dummy_token_probs.size(0)
            if current_seq_len < max_mlm_seq_len:
                # PADDING
                pad_len = max_mlm_seq_len- current_seq_len
                pad_tensor = Pad_token_prob.expand(pad_len, -1).to(self.device)
                dummy_token_probs_tensor[i] = torch.cat([dummy_token_probs, pad_tensor], dim=0)
                attention_mask[i] += [0]* pad_len
            attention_mask[i] = torch.tensor([attention_mask[i]])
            dummy_token_probs_tensor[i] = dummy_token_probs_tensor[i].unsqueeze(0)
        attention_mask = torch.cat(attention_mask, dim=0).to(self.device)
        dummy_token_probs_tensor = torch.cat(dummy_token_probs_tensor, dim=0).to(self.device)

        # convert to MLM model token ids
        mlm_token_probs = []
        mlm_dummy_token_probs = torch.zeros([dummy_token_probs_tensor.size(0), dummy_token_probs_tensor.size(1), self.mlm_tokenizer.vocab_size]).to(self.device)
        # Copy the probability of every generator token that has a masked-LM counterpart into
        # the matching masked-LM column; all other masked-LM columns stay 0.
        mlm_dummy_token_probs[:, :, self.common_mlm_ids] = dummy_token_probs_tensor[:, :, self.common_g_ids]

        # Second copy: replace the placeholder rows with one-hot rows of the real subject tokens.
        for i in range(batch_size):
            part1_len = part1_len_list[i].item()
            mlm_X_len = mlm_X_len_tensor[i].item()
            mlm_X_tensor = torch.zeros([mlm_X_len, self.mlm_tokenizer.vocab_size]).float().to(self.device)
            for idx, t_idx in enumerate(mlm_X_ids[i, :mlm_X_len]):
                mlm_X_tensor[idx, t_idx] = 1.0
            token_probs = torch.cat([mlm_dummy_token_probs[i, :part1_len-1, :], mlm_X_tensor, mlm_dummy_token_probs[i, (part1_len-1+ mlm_X_len):, :]], dim=0)
            mlm_token_probs.append(token_probs.unsqueeze(0))
        mlm_token_probs = torch.cat(mlm_token_probs, dim=0)

        # Both variants share the same attention mask and masked positions.
        results = self.metric_evaluator.mask_filling(token_probs=mlm_token_probs,
                                                     attention_mask=attention_mask,
                                                     label_ids=label,
                                                     masked_indices=masked_indices,
                                                     dummy_token_probs= mlm_dummy_token_probs,
                                                     dummy_att_mask=attention_mask,
                                                     dummy_masked_indices=masked_indices,
                                                     mode=mode,
                                                     filter_indices=self.mlm_filter_indices,
                                                     index_list=self.mlm_index_list,
                                                     predict_list=predict_list,
                                                     device=self.device)
        return results, decoder_ids


    def logits_process(self, input_ids, scores,
                       eos_token_id,
                       cur_part1_len, unfinished_part1_sequences,
                       unfinished_X_sequences,
                       cur_part2_len, unfinished_part2_sequences,
                       cur_part3_len, unfinished_part3_sequences,
                       unfinished_Y_sequences):
        """
        Constrain the next-token scores in place and return them.

        * Banned tokens always get a score of ``-inf``.
        * On the first step (``input_ids`` has a single column) only ``self.bos_token_id`` stays allowed.
        * Otherwise, every row whose active segment counter equals ``MAX_NUM_k - 1`` is forced to
          ``eos_token_id``. Part 1 counts while it is unfinished, part 2 once X is finished and
          part 3 once Y is finished (presumably ``MAX_NUM_k`` is the length budget of part k).
        """
        neg_inf = -float("inf")
        batch_size = input_ids.size(0)
        cur_len = input_ids.size(1)
        scores[:, self.banned_token_ids] = neg_inf

        # Forced BOS token: everything but BOS is ruled out on the very first step.
        if cur_len == 1:
            scores[:] = neg_inf
            scores[:, self.bos_token_id] = 0
            return scores

        # A counter is zeroed out when its segment is not the one currently being generated.
        part1_progress = unfinished_part1_sequences * cur_part1_len
        part2_progress = (1 - unfinished_X_sequences) * unfinished_part2_sequences * cur_part2_len
        part3_progress = (1 - unfinished_Y_sequences) * unfinished_part3_sequences * cur_part3_len
        exhausted_rows = [
            row for row in range(batch_size)
            if part1_progress[row] == MAX_NUM_1 - 1
            or part2_progress[row] == MAX_NUM_2 - 1
            or part3_progress[row] == MAX_NUM_3 - 1
        ]

        # Rows that hit their limit may only emit the end-of-segment token.
        scores[exhausted_rows, :] = neg_inf
        scores[exhausted_rows, eos_token_id] = 0
        return scores


    def _update_model_kwargs_for_generation(self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
    ) -> Dict[str, Any]:
        """
        Refresh ``model_kwargs`` after a decoding step and return it.

        * ``model_kwargs["past"]`` receives the first cache found in ``outputs`` among
          ``past_key_values``, ``mems`` and ``past_buckets_states``, or ``None`` if there is none.
        * ``token_type_ids``, when present, is extended by repeating its last column.
        * ``attention_mask``, when present and the model is not an encoder-decoder, is extended by a
          column of ones.
        """
        # Carry the decoding cache forward under whichever name the model exposes it.
        for cache_name in ("past_key_values", "mems", "past_buckets_states"):
            if cache_name in outputs:
                model_kwargs["past"] = getattr(outputs, cache_name)
                break
        else:
            model_kwargs["past"] = None

        # The newly generated position reuses the token type of the previous one.
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            last_column = token_type_ids[:, -1].unsqueeze(-1)
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, last_column], dim=-1)

        # The newly generated position must be attended to.
        if not is_encoder_decoder and "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            extra_column = attention_mask.new_ones((attention_mask.shape[0], 1))
            model_kwargs["attention_mask"] = torch.cat([attention_mask, extra_column], dim=-1)
        return model_kwargs


    def sample(self,
               decoder_input_ids,
               X_input_ids: torch.LongTensor,
               X_att_mask: torch.LongTensor,
               Y_id: int,
               eos_token_id: int,
               max_length: int,
               output_attentions: Optional[bool] = None,
               output_hidden_states: Optional[bool] = None,
               return_dict_in_generate: Optional[bool] = None,
               **model_kwargs,
               ):
        """Auto-regressively decode a template of the form ``part1 X part2 Y part3``.

        Every batch row walks through the following phases, one token per step:

        1. *part1*  - scores come from ``self.logits_process``; ends when the
           greedy token is ``eos_token_id``.
        2. *X*      - tokens are copied one by one from ``X_input_ids``; ends after
           ``X_att_mask.sum(-1)`` tokens.
        3. *part2*  - like part1; ends on ``eos_token_id``.
        4. *Y*      - exactly one step in which ``Y_id`` is forced.
        5. *part3*  - like part1; ends on ``eos_token_id``.
        6. afterwards ``self.pad_token_id`` is forced.

        A forced token is realised by setting the row's scores to ``-inf``
        everywhere except the forced id (score ``0``). The decoder input is
        extended with the detached softmax over the processed scores rather than
        with a token id.
        NOTE(review): this implies ``decoder_input_ids`` is a
        ``(batch, length, vocab)`` tensor of distributions - confirm with the
        caller.

        Args:
            decoder_input_ids: Initial decoder input; ``size(0)`` is the batch
                size and ``size(1)`` the starting length.
            X_input_ids: Token ids to copy during the X phase, one row per sample.
            X_att_mask: Mask for ``X_input_ids``; its row sums give the number of
                tokens to copy.
            Y_id: Token id forced during the Y phase.
            eos_token_id: Token id that terminates part1, part2 and part3.
            max_length: Decoding stops once the decoder input reaches this length.
            output_attentions: Collect attention weights of every step.
            output_hidden_states: Collect hidden states of every step.
            return_dict_in_generate: Accepted for interface compatibility only; a
                ``SampleEncoderDecoderOutput`` is always returned.
            **model_kwargs: Forwarded to
                ``self.model.prepare_inputs_for_generation``; must contain
                ``"encoder_outputs"``.

        Returns:
            SampleEncoderDecoderOutput with the greedy tokens of the generated
            steps (``sequences``), the processed per-step scores concatenated along
            dim 1 (``scores``), the optional attention / hidden-state collections
            and the per-row lengths of part1, part2, part3 and X.
        """
        scores = ()
        # Accumulators exist whenever collection is requested. They used to depend
        # on `return_dict_in_generate` as well, which made `None += (...)` raise a
        # TypeError when only `output_attentions` / `output_hidden_states` was set.
        decoder_attentions = () if output_attentions else None
        cross_attentions = () if output_attentions else None
        decoder_hidden_states = () if output_hidden_states else None

        # retrieve encoder attention weights and hidden states when requested
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

        batch_size = decoder_input_ids.size(0)

        def _flags(value):
            # Fresh per-row int64 state vector on the trainer's device.
            return torch.full((batch_size,), value, dtype=torch.long, device=self.device)

        # copy pointer into X and number of X tokens to copy, per row
        copy_mechanism_X_cur_idx = _flags(0)
        copy_mechanism_X_max_len = torch.sum(X_att_mask, dim=-1).to(self.device)

        # 1 while the corresponding phase has not finished for a row, 0 afterwards
        unfinished_part1_sequences = _flags(1)
        unfinished_X_sequences = _flags(1)
        unfinished_part2_sequences = _flags(1)
        unfinished_Y_sequences = _flags(1)
        unfinished_part3_sequences = _flags(1)

        # number of steps each row spent in the free-generation phases
        cur_part1_len = _flags(0)
        cur_part2_len = _flags(0)
        cur_part3_len = _flags(0)

        cur_len = decoder_input_ids.size(1)
        seqs = []
        # auto-regressive generation
        while True:
            # prepare model inputs
            model_inputs = self.model.prepare_inputs_for_generation(decoder_input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self.model(**model_inputs,
                                 return_dict=True,
                                 output_attentions=output_attentions,
                                 output_hidden_states=output_hidden_states,)
            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = self.logits_process(decoder_input_ids, next_token_logits,
                                                    eos_token_id,
                                                    cur_part1_len, unfinished_part1_sequences,
                                                    unfinished_X_sequences,
                                                    cur_part2_len, unfinished_part2_sequences,
                                                    cur_part3_len, unfinished_part3_sequences,
                                                    unfinished_Y_sequences)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

            # Row masks for the three kinds of forced tokens, all derived from the
            # phase flags as they stand *before* this step's update.
            copy_mask = (1 - unfinished_part1_sequences) * unfinished_X_sequences
            force_Y_mask = (1 - unfinished_part2_sequences) * unfinished_Y_sequences
            # finished sentences should have their next token be a padding token
            pad_mask = 1 - unfinished_part3_sequences

            # X phase: force the token the copy pointer currently points at.
            # One indexed assignment replaces the former per-row Python loop
            # (which synchronised with the device through `.item()` for each row).
            copy_rows = torch.nonzero(copy_mask == 1, as_tuple=False).view(-1)
            cur_X_ids = torch.gather(
                X_input_ids[copy_rows, :], 1, copy_mechanism_X_cur_idx[copy_rows].view(-1, 1)
            ).view(-1)
            next_token_scores[copy_rows, :] = -float("inf")
            next_token_scores[copy_rows, cur_X_ids] = 0

            # Y phase: force the Y token.
            y_rows = torch.nonzero(force_Y_mask == 1, as_tuple=False).view(-1)
            next_token_scores[y_rows, :] = -float("inf")
            next_token_scores[y_rows, Y_id] = 0

            # After part3: force padding.
            pad_rows = torch.nonzero(pad_mask == 1, as_tuple=False).view(-1)
            next_token_scores[pad_rows, :] = -float("inf")
            next_token_scores[pad_rows, self.pad_token_id] = 0

            # update generated ids, model inputs, and length for next step
            scores += (next_token_scores.unsqueeze(1),)
            next_probs = F.softmax(next_token_scores, dim=-1)
            # The detached distribution (not a token id) becomes the next decoder input.
            decoder_input_ids = torch.cat([decoder_input_ids, next_probs[:, None].detach()], dim=1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            cur_len = cur_len + 1
            cur_part1_len += unfinished_part1_sequences
            copy_mechanism_X_cur_idx += copy_mask
            cur_part2_len += (1 - unfinished_X_sequences) * unfinished_part2_sequences
            cur_part3_len += (1 - unfinished_Y_sequences) * unfinished_part3_sequences

            # Advance the phase flags. The order matters: every line reads the
            # *previous* value of the flag that is updated on a later line.
            greedy_tokens = torch.argmax(next_probs, dim=-1)
            hit_eos = (greedy_tokens == eos_token_id).long()
            unfinished_part3_sequences = (1 - ((1 - unfinished_Y_sequences) * hit_eos)) * unfinished_part3_sequences
            # Y lags part2 by one step, which makes the Y phase exactly one step long.
            unfinished_Y_sequences = unfinished_part2_sequences
            unfinished_part2_sequences = (1 - ((1 - unfinished_X_sequences) * hit_eos)) * unfinished_part2_sequences
            unfinished_X_sequences = (copy_mechanism_X_cur_idx < copy_mechanism_X_max_len).long()
            unfinished_part1_sequences = unfinished_part1_sequences * (1 - hit_eos)

            seqs.append(greedy_tokens.unsqueeze(1))
            # Stop when each sentence is finished, or if we reach the maximum
            # length. `>=` (instead of `==`) also terminates when the initial
            # decoder input is already at or beyond `max_length`.
            if unfinished_part3_sequences.max() == 0 or cur_len >= max_length:
                break
        scores = torch.cat(scores, dim=1).float()
        seqs = torch.cat(seqs, dim=1)
        return SampleEncoderDecoderOutput(
            sequences=seqs,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
            part1_len = cur_part1_len,
            part2_len = cur_part2_len,
            part3_len = cur_part3_len,
            X_len = copy_mechanism_X_max_len,
        )

    def save_optiprompt(self, key):
        if not os.path.exists(self.args.save_model_dir):
            os.makedirs(self.args.save_model_dir)
        if not os.path.exists(os.path.join(self.args.save_model_dir, key)):
            os.makedirs(os.path.join(self.args.save_model_dir, key))
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        model_to_save.save_pretrained(os.path.join(self.args.save_model_dir, key))
        # Save training arguments together with the trained model
        torch.save(self.args, os.path.join(self.args.save_model_dir, key, "training_args.bin"))
        logging.info("Saving model checkpoint to %s" % os.path.join(self.args.save_model_dir, key))

    # @staticmethod
    # def load_optiprompt(model_dir):
    #     # Check whether model exists
    #     if not os.path.exists(model_dir):
    #         raise Exception("Model doesn't exists! Train first!")
    #     args = torch.load(os.path.join(model_dir, "training_args.bin"))
    #     generator_name = args.generative_model_dir
    #     generator_config = AutoConfig.from_pretrained(generator_name)
    #     generator_model = AutoModelForSeq2SeqLM.from_pretrained(generator_name, config=generator_config)
    #     model = generator_model.from_pretrained(model_dir)
    #     logging.info("***** Model Loaded *****")
    #     return model, args

    def evaluate(self, eval_dataloader, wiki, output_topk=None):
        """Evaluate the model with ``compute_loss`` running in ``"eval"`` mode.

        Args:
            eval_dataloader: Iterable yielding batches (lists of sample dicts).
            wiki: Passed through to ``self._get_input_tensors_batch_train``.
            output_topk: Optional directory; when given, the top-k predictions
                are written there as one ``<relation>_predictions.jsonl`` file
                per relation.

        Returns:
            Tuple ``(micro, result, generated_templates)`` where ``micro`` comes
            from ``data_loader.output_result``, ``result`` maps each relation to
            ``(correct, total, accuracy, summed_vocab_loss)`` and
            ``generated_templates`` holds the decoder ids of every batch.
        """
        return self._run_evaluation_loop(eval_dataloader, wiki, output_topk, mode="eval")


    def evaluate_dummy(self, eval_dataloader, wiki, output_topk=None):
        """Same as :meth:`evaluate`, but ``compute_loss`` runs in ``"eval_dummy"`` mode.

        Arguments and return value are identical to :meth:`evaluate`.
        """
        return self._run_evaluation_loop(eval_dataloader, wiki, output_topk, mode="eval_dummy")


    def _run_evaluation_loop(self, eval_dataloader, wiki, output_topk, mode):
        """Shared implementation of :meth:`evaluate` and :meth:`evaluate_dummy`.

        Puts ``self.model`` into eval mode, runs ``self.compute_loss`` on every
        batch without gradient tracking, aggregates per-relation statistics and
        optionally dumps the top-k predictions.

        Args:
            eval_dataloader: Iterable yielding batches (lists of sample dicts).
            wiki: Passed through to ``self._get_input_tensors_batch_train``.
            output_topk: Directory for the prediction dumps, or ``None``.
            mode: Value handed to ``self.compute_loss`` as ``mode``.

        Returns:
            Tuple ``(micro, result, generated_templates)``.
        """
        result = {}               # relation -> (correct, total, accuracy, summed vocab loss)
        list_of_predictions = {}  # relation -> list of per-sample prediction records
        eval_loss = 0.0
        generated_templates = []
        num_examples = self.num_examples(eval_dataloader)
        logger.info("***** Running evaluation *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Total evaluate batch size (w. parallel, distributed & accumulation) = {self.args.eval_batch_size}")

        self.model.eval()
        for samples in eval_dataloader:
            (input_tokens_tensor, input_att_mask, X_token_tensor, X_att_mask,
             mlm_X_token_tensor, mlm_X_att_mask, mlm_label, predict_list) = \
                self._get_input_tensors_batch_train(samples, wiki)
            input_tokens_tensor = input_tokens_tensor.to(self.device)
            input_att_mask = input_att_mask.to(self.device)
            X_token_tensor = X_token_tensor.to(self.device)
            X_att_mask = X_att_mask.to(self.device)
            mlm_X_token_tensor = mlm_X_token_tensor.to(self.device)
            mlm_X_att_mask = mlm_X_att_mask.to(self.device)
            mlm_label = mlm_label.to(self.device)
            with torch.no_grad():
                outputs, decoder_ids = self.compute_loss(input_ids=input_tokens_tensor,
                                                         attention_mask=input_att_mask,
                                                         X_ids=X_token_tensor,
                                                         X_att=X_att_mask,
                                                         mlm_X_ids=mlm_X_token_tensor,
                                                         mlm_X_att=mlm_X_att_mask,
                                                         label=mlm_label,
                                                         predict_list=predict_list,
                                                         decoder_start_token_id=self.eos_token_id,
                                                         mode=mode)
                generated_templates.append(decoder_ids)
                # Only the loss, the batch size and the per-sample outputs are used.
                loss, _log_probs, _cor_b, tot_b, pred_b, topk_preds, common_vocab_loss = outputs

                for pred, sample, topk, vocab_loss in zip(pred_b, samples, topk_preds, common_vocab_loss):
                    rel = sample['predicate_id']
                    if rel not in result:
                        result[rel] = (0, 0, 0, 0.0)
                        list_of_predictions[rel] = []
                    cor, tot, _, rel_tot_loss = result[rel]
                    tot += 1
                    cor += pred
                    rel_tot_loss += vocab_loss
                    # `tot` is at least 1 here, so the division is always defined.
                    result[rel] = (cor, tot, cor / tot, rel_tot_loss)
                    list_of_predictions[rel].append({
                        'uuid': sample['uuid'],
                        'relation': sample['predicate_id'],
                        'sub_label': sample['sub_label'],
                        'obj_label': sample['obj_label'],
                        # 'masked_sentences': sample['input_sentences'],
                        'topk': topk,
                    })

                # Weight the batch loss by the batch size reported by compute_loss.
                eval_loss += loss.item() * tot_b

        if output_topk is not None:
            logger.info('Output top-k prediction to %s..', output_topk)
            # exist_ok avoids the race between an existence check and the creation.
            os.makedirs(output_topk, exist_ok=True)
            for rel, predictions in list_of_predictions.items():
                dump_path = os.path.join(output_topk, '%s_predictions.jsonl' % rel)
                with open(dump_path, 'w') as f:
                    f.write('\n'.join(json.dumps(record) for record in predictions))
        micro, _macro = data_loader.output_result(result, eval_loss)
        return micro, result, generated_templates

@dataclass
class SampleEncoderDecoderOutput(ModelOutput):
    """
    Output container of :obj:`GenerativeModelTrainer.sample`. Hidden states and attention weights of
    the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
    attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)

    The shapes given below for :obj:`sequences`, :obj:`scores` and the length fields describe what
    :obj:`GenerativeModelTrainer.sample` stores in them; the attention / hidden-state descriptions are the
    generic ones of encoder-decoder generation outputs.

    Args:
        sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, generated_length)`):
            The greedy (arg-max) token of every generation step. The second dimension counts generated steps only;
            decoding stops at :obj:`max_length` or earlier once every row has finished.
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, vocab_size)`):
            Processed prediction scores (scores for each vocabulary token before SoftMax) of every generation
            step, concatenated along dimension 1.
        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, set when ``output_attentions=True`` is passed):
            Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape
            :obj:`(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, set when ``output_hidden_states=True`` is passed):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, set when ``output_attentions=True`` is passed):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, set when ``output_attentions=True`` is passed):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, set when ``output_hidden_states=True`` is passed):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
        part1_len (:obj:`torch.Tensor` of shape :obj:`(batch_size,)`):
            Number of generation steps each row spent in the first free-generation part.
        part2_len (:obj:`torch.Tensor` of shape :obj:`(batch_size,)`):
            Number of generation steps each row spent in the second free-generation part.
        part3_len (:obj:`torch.Tensor` of shape :obj:`(batch_size,)`):
            Number of generation steps each row spent in the third free-generation part.
        X_len (:obj:`torch.Tensor` of shape :obj:`(batch_size,)`):
            Number of X tokens to copy per row, i.e. the row sums of the X attention mask.
    """

    # Generated tokens and their per-step scores.
    sequences: torch.LongTensor = None
    scores: torch.FloatTensor = None
    # Optional diagnostics; None unless the corresponding output flag was requested.
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

    # Per-row segment lengths of the generated template.
    part1_len: torch.Tensor = None
    part2_len: torch.Tensor = None
    part3_len: torch.Tensor = None
    X_len: torch.Tensor = None

