# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.  
# SPDX-License-Identifier: CC-BY-NC-4.0
from typing import Optional
from da4er.gan.preprocessing import GANPreprocessor
import torch
import numpy as np
from transformers import BertTokenizer, BertModel, BartTokenizer
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)



class BARTGAN(GANPreprocessor):
    """Turn text into a fixed-size matrix of BART encoder token embeddings.

    Text is tokenized with the ``facebook/bart-base`` tokenizer. Each token is
    replaced by its row of the encoder's input embedding table
    (``model.encoder.embed_tokens.weight``), and the rows are laid out as
    columns of a ``(hidden, dimension)`` matrix. Unused columns stay zero.

    Args:
        dimension: Maximum number of tokens (excluding the first and last,
            special, tokens) a text may have; also the column count of the
            returned matrix.
        hidden: Row count of the returned matrix. It must equal the width of
            the encoder embedding table, otherwise ``preprocess`` raises
            ``ValueError`` when writing a token column.
    """

    def __init__(self, dimension, hidden=768):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
        self.model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-base')
        # Encoder token-embedding table: indexed by token id, one embedding
        # vector per row.
        self.encoder_embed = self.model.state_dict()['model.encoder.embed_tokens.weight']
        self.model.eval()
        self.dimension = dimension
        self.hidden = hidden

    def preprocess(self, txt: str) -> Optional[np.ndarray]:
        """Embed ``txt`` as a ``(hidden, dimension)`` float64 matrix.

        Args:
            txt: Text to encode.

        Returns:
            An array of shape ``(self.hidden, self.dimension)``. Column ``i``
            holds the encoder embedding of the ``i``-th token; the remaining
            columns are zero. Returns ``None`` if the text has more than
            ``self.dimension`` tokens.
        """
        # Exclude the first and last ids: the start/end-of-sentence tokens
        # added by add_special_tokens=True.
        token_ids = self.tokenizer.encode(txt, add_special_tokens=True)[1:-1]

        # Keep only texts with at most `dimension` tokens; longer ones do not
        # fit in the fixed-size matrix.
        if len(token_ids) > self.dimension:
            return None

        train_embed = np.zeros((self.hidden, self.dimension))
        if token_ids:
            # Gather all token rows in one indexing op and store them as
            # columns. Convert explicitly so GPU / half-precision tables work.
            rows = self.encoder_embed[torch.tensor(token_ids, dtype=torch.long)]
            train_embed[:, :len(token_ids)] = rows.detach().cpu().double().numpy().T

        return train_embed
