# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.  
# SPDX-License-Identifier: CC-BY-NC-4.0

import numpy as np
import tensorflow as tf
import random

from da4er.gan import GANAugmentation
from da4er.gan.preprocessing.transformers_preprocessor import BARTGAN as BARTGANPreprocessor


class BARTGAN(GANAugmentation):
    """
    Generate synthetic data using a GAN trained with the BART encoder.

    The generator is applied to a noisy copy of the input's embedding. Each token position of
    the result is compared (cosine similarity) with every vocabulary embedding of the encoder.
    Vocabulary tokens scoring strictly between ``bart_score_low`` and ``bart_score_high`` become
    candidates that are randomly swapped in for the original token.
    """

    # Leading-space markers a token may carry: "Ġ" in ``tokenize`` output, a plain " " in the
    # strings returned by ``decode``. Both are treated as equivalent throughout this class.
    _SPACE_MARKERS = (" ", "Ġ")

    def __init__(self, preprocessor: BARTGANPreprocessor, generator, dimension, bart_score_low=0.3,
                 bart_score_high=0.7, accept_prob=None, top_k=50):
        """
        :param preprocessor: preprocessor exposing ``tokenizer``, ``encoder_embed`` and ``preprocess``.
        :param generator: trained GAN generator, called as ``generator(x, training=False)``.
        :param dimension: size of the second axis of the preprocessed embedding. After the transpose
            in ``augment`` its rows are treated as token positions (presumably the max sequence
            length; confirm against the preprocessor).
        :param bart_score_low: exclusive lower bound on a candidate's cosine similarity.
        :param bart_score_high: exclusive upper bound on a candidate's cosine similarity; tokens
            scoring at or above it are ignored.
        :param accept_prob: probability of replacing a token that has at least one coherent
            candidate. Defaults to ``bart_score_high``.
        :param top_k: number of most-similar vocabulary tokens examined per position.
        """
        super().__init__(preprocessor, generator)
        self.preprocessor = preprocessor
        self.dim = dimension
        self.low = bart_score_low
        self.high = bart_score_high
        self.accept_prob = bart_score_high if accept_prob is None else accept_prob
        self.top_k = top_k
        self.normalized_dictionary = self._get_normalized_dictionary()

    def _get_normalized_dictionary(self):
        """Return the encoder's vocabulary embeddings, L2-normalised along the last axis, as numpy."""
        candidates = self.preprocessor.encoder_embed
        candidates = tf.linalg.l2_normalize(candidates, axis=-1)
        return candidates.numpy()

    def augment(self, txt: str) -> str:
        """
        Return a synthetic variant of ``txt`` in which some tokens are replaced by embedding-similar ones.

        ``txt`` is returned unchanged when ``preprocessor.preprocess`` returns ``None``.
        """
        tokenized_input = self.preprocessor.tokenizer.tokenize(txt)
        embedding = self.preprocessor.preprocess(txt)
        if embedding is None:
            return txt

        target_len = self._find_target_length(embedding, len(tokenized_input))
        token_vectors = self._synthetic_token_vectors(embedding, target_len)
        candidates_per_position = self._collect_candidates(token_vectors, tokenized_input)

        synthetic_tokens = [
            self._pick_token(tokenized_input[position], candidates_per_position[position])
            for position in range(target_len)
        ]
        return self.preprocessor.tokenizer.convert_tokens_to_string(synthetic_tokens)

    def _find_target_length(self, embedding, num_tokens):
        """Return how many leading token positions should be considered for augmentation."""
        # A zero in the first row of the embedding is taken to mark the first padded position.
        zero_positions = np.where(embedding[0] == 0)[0]
        target_len = zero_positions[0] if len(zero_positions) != 0 else self.dim
        # Never exceed the tokenized text: every kept position indexes ``tokenized_input``.
        return min(int(target_len), num_tokens)

    def _synthetic_token_vectors(self, embedding, target_len):
        """Return L2-normalised (target_len, features) vectors derived from the generator output."""
        embedding_size = np.shape(embedding)[0]
        input_data = np.reshape(embedding, [1, embedding_size, self.dim])
        noise = tf.random.normal([1, embedding_size, self.dim])  # Random noise
        generated = self.generator(input_data + noise, training=False)  # Target's synthetic embedding
        # Average synthetic and original data so the augmentation stays close to the input
        blended = (tf.reshape(generated, shape=(embedding_size, self.dim)) + embedding) / 2
        # (tokens, embeddings), truncated to the relevant tokens only
        token_vectors = tf.transpose(blended)[:target_len, :]
        return tf.linalg.l2_normalize(token_vectors, axis=-1)

    def _collect_candidates(self, token_vectors, tokenized_input):
        """
        Return, per position, a list of ``(token, score)`` replacement candidates, best score first.

        Only the ``top_k`` most similar vocabulary entries whose score lies strictly inside
        ``(self.low, self.high)`` are kept. A non-blank candidate equal to the original token
        (ignoring the leading-space marker) is dropped.
        """
        # Both sides are L2-normalised, so this is the cosine similarity: (positions, vocab)
        similarity = tf.linalg.matmul(token_vectors, self.normalized_dictionary, transpose_b=True)
        top_indices = tf.argsort(similarity, axis=-1, direction='DESCENDING')[:, :self.top_k].numpy()
        scores = similarity.numpy()

        matches = [
            (position, token_id, scores[position, token_id])
            for position, row in enumerate(top_indices)
            for token_id in row
            if self.low < scores[position, token_id] < self.high
        ]
        # Stable sort: equal scores keep their (position, rank) order
        matches.sort(key=lambda match: match[2], reverse=True)

        candidates = [[] for _ in range(len(top_indices))]
        for position, token_id, score in matches:
            token = self.preprocessor.tokenizer.decode(token_id, skip_special_tokens=True)
            if len(token.strip()) > 0 and self._same_token(tokenized_input[position], token):
                continue
            candidates[position].append((token, score))
        return candidates

    def _pick_token(self, old_fragment, candidates):
        """Return a coherent replacement for ``old_fragment``, or ``old_fragment`` itself."""
        # Old tokens and the replacing tokens should be coherent: same leading space, same capitalisation
        coherent = [
            candidate for candidate, _ in candidates
            if self._has_space_marker(candidate) == self._has_space_marker(old_fragment)
            and self._starts_uppercase(candidate) == self._starts_uppercase(old_fragment)
        ]
        if not coherent or random.random() > self.accept_prob:
            return old_fragment
        # Walk the ranked candidates, taking each with probability 0.5; fall back to the best one
        for candidate in coherent:
            if random.random() >= 0.5:
                return candidate
        return coherent[0]

    @classmethod
    def _has_space_marker(cls, piece):
        """Whether ``piece`` starts with a leading-space marker ("Ġ" or " ")."""
        return piece.startswith(cls._SPACE_MARKERS)

    @classmethod
    def _starts_uppercase(cls, piece):
        """Whether the first character after any leading-space marker is uppercase."""
        # "Ġ" itself is an uppercase letter, so it must be skipped before testing case.
        core = piece[1:] if cls._has_space_marker(piece) else piece
        return core[:1].isupper()

    @classmethod
    def _same_token(cls, original, candidate):
        """Whether two pieces are the same token once "Ġ" and " " space markers are unified."""
        def unify(piece):
            return " " + piece[1:] if piece.startswith("Ġ") else piece
        return unify(original) == unify(candidate)
