import math
import torch
import numpy as np
from sentence_transformers import InputExample
from textattack.search_methods import GreedyWordSwapWIR
from textattack.shared.attacked_text import AttackedText


class PCTAugmentor(GreedyWordSwapWIR):
    def __init__(
        self,
        model_wrapper,
        projector=None,
        unimportant_percentage=0.1,
        important_percentage=0.1,
        important_reduce_at_lease=0,
        max_seq_length=64,
        reduce_threshold=100,
        rank_method="gradient",
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):

        self.model_wrapper = model_wrapper
        self.device = device
        self.projector = projector
        self.unimportant_percentage = unimportant_percentage
        self.important_percentage = important_percentage
        self.important_reduce_at_lease = important_reduce_at_lease
        self.max_seq_length = max_seq_length
        self.reduce_threshold = reduce_threshold
        self.rank_method = rank_method

    def get_score(self, initial_texts, reverse=False, getting_reduce_num=False):

        if getting_reduce_num:
            rank_method = "gradient"
        else:
            rank_method = self.rank_method

        if rank_method == "gradient":
            victim_model = self.model_wrapper
            grad_output = victim_model.get_batch_grad_score(initial_texts)
            attacked_texts = [AttackedText(text) for text in initial_texts]
            max_word_length = max([i.num_words for i in attacked_texts])
            init_saliency_scores = np.zeros((len(initial_texts), max_word_length))

            word2token_mappings = [
                text.align_with_model_tokens(victim_model) for text in attacked_texts
            ]
            for index, (grad, attacked_text) in enumerate(
                zip(grad_output, attacked_texts)
            ):
                for i, word in enumerate(attacked_text.words):
                    try:
                        matched_tokens = word2token_mappings[index][i]
                        if not matched_tokens:
                            init_saliency_scores[index, i] = 0.0
                        else:
                            init_saliency_scores[index, i] = np.mean(
                                grad[matched_tokens], axis=0
                            )
                    except:
                        print("error mapping")

            init_saliency_scores_index_rank = []

        else:
            raise ValueError("Error rank_method")

        if reverse:
            init_saliency_scores = -init_saliency_scores
            init_saliency_scores_index_rank = init_saliency_scores.argsort()

        return init_saliency_scores, init_saliency_scores_index_rank

    def generate_new_attacked_text_with_delete(self, initial_text, indices):
        UNK_initial_text = initial_text.replace_words_at_indices(
            indices, ["DELETE"] * len(indices)
        )
        replace_UNK_initial_text = UNK_initial_text.tokenizer_input.replace(
            "DELETE", ""
        )
        replace_UNK_initial_text = replace_UNK_initial_text.replace("  ", " ")

        return replace_UNK_initial_text

    def get_reduced_text(self, initial_text, target_rubbish_index):
        new_attacked_texts = [
            self.generate_new_attacked_text_with_delete(
                initial_text, target_rubbish_index[:index]
            )
            for index in range(1, len(target_rubbish_index) + 1)
        ]

        return new_attacked_texts

    def softmax(self, x, axis=0):
        """Compute softmax values for each sets of scores in x."""
        e_x = np.exp(x - np.max(x))
        return e_x / e_x.sum(axis=axis)

    def get_reduce_example(self, initial_text, percentage, score, index_rank, reverse=False,):
        """Create reduced variants of a text by deleting ranked words.

        Args:
            initial_text: ``AttackedText`` or a value that ``AttackedText``
                accepts.
            percentage: Fraction of the word count to delete (rounded up).
            score: Per-word scores; only used when ``reverse`` is true.
            index_rank: Word positions in deletion order; used when
                ``reverse`` is false.
            reverse: When true, the deletion order is the ``argsort`` of the
                negated ``score`` instead of ``index_rank``.

        Returns:
            The list produced by ``get_reduced_text`` for the selected
            positions, with ``AttackedText`` entries converted to their
            ``.text``.
        """
        text = (
            initial_text
            if isinstance(initial_text, AttackedText)
            else AttackedText(initial_text)
        )

        # Number of words to delete, capped by the configured threshold.
        delete_count = int(math.ceil(len(text.words) * percentage))
        delete_count = min(self.reduce_threshold, delete_count)

        deletion_order = (-score).argsort() if reverse else index_rank
        positions = deletion_order[:delete_count]

        reduced = self.get_reduced_text(text, positions)
        if len(reduced) > 0 and isinstance(reduced[0], AttackedText):
            reduced = [item.text for item in reduced]

        return reduced

    def get_pair_example(self, initial_text, reduce_text_list, positive=True):
        """Pair the original text with each sufficiently long reduced text.

        Args:
            initial_text: ``AttackedText`` or a value that ``AttackedText``
                accepts.
            reduce_text_list: Iterable of reduced text strings.
            positive: Truthy for label ``1``, falsy for label ``0``.

        Returns:
            A list of ``InputExample`` objects with
            ``texts=[original, reduced]``; reduced texts that split into fewer
            than five pieces on single spaces are skipped.
        """
        source = (
            initial_text
            if isinstance(initial_text, AttackedText)
            else AttackedText(initial_text)
        )
        original_text = source.text
        bias_label = 1 if positive else 0

        return [
            InputExample(texts=[original_text, reduced], label=bias_label)
            for reduced in reduce_text_list
            if len(reduced.split(" ")) >= 5
        ]

    def get_reduce_positive_pair_example(self, initial_text, score, index_rank):
        """Build label-1 pairs by deleting words in ``index_rank`` order.

        Args:
            initial_text: Text to reduce.
            score: Per-word scores, forwarded to ``get_reduce_example``.
            index_rank: Word positions in deletion order.

        Returns:
            The ``InputExample`` list from ``get_pair_example`` with
            ``positive=True``.
        """
        reduced_variants = self.get_reduce_example(
            initial_text,
            self.unimportant_percentage,
            score,
            index_rank,
            reverse=False,
        )
        return self.get_pair_example(initial_text, reduced_variants, positive=True)

    def get_reduce_negative_pair_example(self, initial_text, score, index_rank):
        """Build label-0 pairs by deleting words in reversed score order.

        Args:
            initial_text: Text to reduce.
            score: Per-word scores; negated and re-sorted by
                ``get_reduce_example`` because ``reverse=True`` is passed.
            index_rank: Forwarded to ``get_reduce_example``.

        Returns:
            The negative ``InputExample`` list without its first
            ``self.important_reduce_at_lease`` entries, or an empty list when
            there are not more entries than that.
        """
        reduced_variants = self.get_reduce_example(
            initial_text, self.important_percentage, score, index_rank, reverse=True
        )
        candidates = self.get_pair_example(
            initial_text, reduced_variants, positive=False
        )

        skip = self.important_reduce_at_lease
        if len(candidates) <= skip:
            return []
        return candidates[skip:]

    def get_pathology_example(self, initial_text, score, index_rank):
        """Collect the positive and negative pair examples for one text.

        Args:
            initial_text: Text to reduce.
            score: Per-word scores.
            index_rank: Word positions in deletion order.

        Returns:
            A new list holding the positive examples followed by the negative
            examples.
        """
        positives = self.get_reduce_positive_pair_example(
            initial_text, score, index_rank
        )
        negatives = self.get_reduce_negative_pair_example(
            initial_text, score, index_rank
        )
        return [*positives, *negatives]

    def build_NtXent_input_batch(self, text_list):
        """Flatten a batch of texts and their reduced variants with labels.

        For every text the output contains the text itself, then its positive
        reduced variants, then its negative reduced variants.

        Args:
            text_list: Sequence of raw text strings.

        Returns:
            A ``(flat_example, labels)`` tuple of equally long lists. A text
            at position ``index`` and its positive variants are labelled
            ``index``; its negative variants are labelled ``10000 + index``.
        """
        scores, _ = self.get_score(text_list)
        flat_example = []
        labels = []
        for index, text in enumerate(text_list):
            flat_example.append(text)
            labels.append(index)

            # Rows of `scores` are zero-padded up to the longest text. Strip
            # the padding by word count rather than by dropping zero entries:
            # dropping an interior zero would shift the remaining entries, and
            # the argsort positions would then no longer be word positions.
            num_words = AttackedText(text).num_words
            score = scores[index][:num_words]
            index_rank = score.argsort()

            pathology_example = self.get_pathology_example(text, score, index_rank)
            positive_texts = [
                example.texts[1] for example in pathology_example if example.label == 1
            ]
            negative_texts = [
                example.texts[1] for example in pathology_example if example.label == 0
            ]

            flat_example.extend(positive_texts)
            flat_example.extend(negative_texts)

            labels.extend([index] * len(positive_texts))
            # Offset so negatives never share a label with their source text.
            labels.extend([10000 + index] * len(negative_texts))

        return flat_example, labels

    def batch_to_device(self, batch, target_device):
        """
        send a pytorch batch to a device (CPU/GPU)
        """
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(target_device)
        return batch

    def contrastive_batch_encode(self, batch_texts, labels):
        try:
            tokenized = self.model_wrapper.tokenizer.batch_encode_plus(
                batch_texts,
                add_special_tokens=True,
                max_length=self.max_seq_length,
                padding=True,
                return_tensors="pt",
            )
            self.batch_to_device(tokenized, self.device)
        except:
            tokenized = self.model_wrapper.tokenizer.batch_encode_plus(batch_texts)
            tokenized = torch.tensor(tokenized).to(self.device)

        labels = torch.LongTensor(labels).to(self.device)

        return tokenized, labels

    def batch_project(self, encoded_flat_example, pick_method="cls"):
        """Run the wrapped model and return one representation per example.

        Args:
            encoded_flat_example: Encoded batch. It is passed as keyword
                arguments (after being moved to ``self.device``) unless the
                model is an LSTM/WordCNN classifier, in which case it is
                passed as a single positional argument.
            pick_method: Only consulted when no projector is set: ``"cls"``
                selects position 0 along dimension 1, any other value takes
                the mean over dimension 1.

        Returns:
            For LSTM/WordCNN models, the second element returned by the
            model. Otherwise the last entry of the model output's element 1,
            passed through ``self.projector`` when one is set, or pooled
            according to ``pick_method``.
        """
        # NOTE(review): LSTMForClassification and WordCNNForClassification are
        # not imported in the visible part of this file - confirm they are in
        # scope, otherwise this method raises NameError.
        model = self.model_wrapper.model
        uses_plain_input = isinstance(model, LSTMForClassification) or isinstance(
            model, WordCNNForClassification
        )

        if uses_plain_input:
            _, last_layer_hidden = self.model_wrapper.model(
                encoded_flat_example, output_hidden_states=True
            )
            return last_layer_hidden

        encoded_flat_example = self.batch_to_device(encoded_flat_example, self.device)
        model_output = self.model_wrapper.model(
            **encoded_flat_example, output_hidden_states=True
        )
        # presumably element 1 is the tuple of per-layer hidden states - TODO confirm
        last_layer_hidden = model_output[1][-1]

        if self.projector is not None:
            return self.projector(last_layer_hidden)
        if pick_method == "cls":
            return last_layer_hidden[:, 0]
        return torch.mean(last_layer_hidden, 1)
