from typing import Any, Dict, List, Optional, Tuple

import torch

from allennlp.predictors import Predictor
from allennlp.data import Instance
from allennlp.models import BasicClassifier
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.seq2vec_encoders import ClsPooler
from allennlp.modules import FeedForward

def create_labeled_instances(predictor: Predictor, outputs: Dict[str, Any], training_instances: List[Instance], cuda: bool, task: Optional[str] = None):
    """
    Given instances and the output of the model, create new instances
    with the model's predictions as labels.

    Args:
        predictor: predictor whose ``predictions_to_labeled_instances`` is
            used to attach the predicted label to each instance.
        outputs: model output dict. When ``task == "QA"`` the
            ``"best_span"`` entry is read; otherwise the ``"probs"`` tensor
            is read and indexed along its first axis, one row per instance.
        training_instances: instances to relabel, aligned with ``outputs``.
        cuda: retained for backward compatibility and no longer consulted.
            ``.cpu()`` is a no-op for tensors already on the CPU, so the
            conversion below is correct on either device.
        task: ``"QA"`` selects span-based relabelling. Any other value,
            including ``None``, uses the ``"probs"`` output.

    Returns:
        For each input instance, the first instance returned by
        ``predictor.predictions_to_labeled_instances``. In the QA branch,
        ``zip`` stops at the shorter of the instances and the spans.
    """
    new_instances = []

    if task == "QA":
        # NOTE(review): this branch passes the raw span as the second
        # argument, while the branch below passes a {"probs": ...} dict.
        # Confirm this matches the QA predictor's expected signature.
        for instance, span in zip(training_instances, outputs['best_span']):
            new_instance = predictor.predictions_to_labeled_instances(instance, span)[0]
            new_instances.append(new_instance)
    else:
        # Detach first so the device copy is made outside the autograd graph.
        # Always calling .cpu() avoids the ".numpy() on a CUDA tensor" crash
        # that happened when the ``cuda`` flag disagreed with the tensor's
        # actual device.
        probs = outputs["probs"].detach().cpu().numpy()
        for idx, instance in enumerate(training_instances):
            tmp = {"probs": probs[idx]}
            new_instances.append(predictor.predictions_to_labeled_instances(instance, tmp)[0])

    return new_instances

def compute_rank(grads: torch.FloatTensor, idx_set: set) -> List[int]:
    """
    Given a one-dimensional gradient tensor, compute the rank of gradients
    with indices specified in idx_set.

    Positions are ranked by absolute gradient value, largest first, with
    rank 0 being the largest. Equal magnitudes keep their original relative
    order, because Python's sort is stable even with ``reverse=True``.

    Args:
        grads: 1-D tensor of gradient values.
        idx_set: indices into ``grads`` whose ranks should be reported.
            Any container supporting ``in`` works.

    Returns:
        The ranks of the requested indices, listed in ascending rank order
        (not in the order of ``idx_set``). Indices outside ``grads`` are
        ignored.
    """
    # Take |grad| in one vectorised op and sort plain Python numbers.
    # Sorting 0-d tensors instead dispatches a torch comparison for every
    # key pair. The float conversion is exact, so the ordering (including
    # ties) is unchanged.
    magnitudes = torch.abs(grads).tolist()
    order = sorted(range(len(magnitudes)), key=magnitudes.__getitem__, reverse=True)

    return [rank for rank, idx in enumerate(order) if idx in idx_set]

def get_stop_ids(instance: Instance, stop_words: set, attack_target: str = None) -> List[int]:
    """
    Returns a list of the indices of all the stop words that occur
    in the given instance.

    The region that is searched depends on ``attack_target``:
      * ``"premise"``: tokens up to and including the first ``[SEP]``.
      * ``"hypothesis"``: tokens strictly after the first ``[SEP]``.
      * anything else (including ``None``): every token.

    Indices are positions in ``instance['tokens']``. Each token's ``.text``
    is what gets compared against ``stop_words``.
    """
    hypothesis_only = attack_target == "hypothesis"
    halt_at_separator = attack_target == "premise"

    past_separator = False
    found = []
    for position, token in enumerate(instance['tokens']):
        word = token.text

        # In hypothesis mode nothing counts until the first [SEP] has gone by.
        if word in stop_words and (past_separator or not hypothesis_only):
            found.append(position)

        if word == '[SEP]':
            if halt_at_separator:
                break
            past_separator = True

    return found

def extract_premise(nli_input: List[str]) -> List[str]:
    """
    Given an NLI input to BERT, extract only the tokens
    for the premise.

    The premise is every token before the first ``"[SEP]"``, with any
    ``"[CLS]"`` tokens dropped. If the input contains no ``"[SEP]"``, every
    non-``"[CLS]"`` token is returned.

    Args:
        nli_input: token strings, expected to look like
            ``[CLS] premise... [SEP] hypothesis... [SEP]``.

    Returns:
        A new list holding the premise tokens in their original order.
    """
    tokens = []

    for token in nli_input:
        if token == "[SEP]":
            break

        if token != "[CLS]":
            tokens.append(token)

    return tokens

def extract_question(qa_input: List[str]) -> List[str]:
    """
    Given a QA input to BERT, extract only the tokens
    for the question.

    The question is taken to be the segment before the first ``"[SEP]"``,
    with ``"[CLS]"`` dropped. That is exactly what ``extract_premise``
    returns, so this delegates to it.
    """
    return extract_premise(qa_input)

def extract_hypothesis(nli_input: List[str]) -> List[str]:
    """
    Given an NLI input to BERT, extract only the tokens
    for the hypothesis.

    The hypothesis is every token after the first ``"[SEP]"``, with all
    ``"[SEP]"`` tokens removed, so the trailing separator is dropped too.
    If the input contains no ``"[SEP]"``, the result is empty.

    Args:
        nli_input: token strings, expected to look like
            ``[CLS] premise... [SEP] hypothesis... [SEP]``.

    Returns:
        A new list holding the hypothesis tokens in their original order.
    """
    tokens = []

    encountered_sep = False
    for token in nli_input:
        if token == "[SEP]":
            encountered_sep = True
        elif encountered_sep:
            tokens.append(token)

    return tokens

# #################
# TESTING FUNCTIONS
# #################

def test_extract_premise():
    """Check premise extraction on well-formed and degenerate inputs."""
    input_1 = ["[CLS]", "I", "love", "food", "[SEP]", "He", "went", "home", ".", "[SEP]"]
    input_2 = ["[CLS]", "I", "[SEP]", "He", "[SEP]"]

    assert extract_premise(input_1) == ["I", "love", "food"]
    assert extract_premise(input_2) == ["I"]

    # Degenerate inputs: nothing at all, no separator, and an empty premise.
    assert extract_premise([]) == []
    assert extract_premise(["[CLS]", "I", "love"]) == ["I", "love"]
    assert extract_premise(["[CLS]", "[SEP]", "He", "[SEP]"]) == []

def test_extract_hypothesis():
    """Check hypothesis extraction on well-formed and degenerate inputs."""
    input_1 = ["[CLS]", "I", "love", "food", "[SEP]", "He", "went", "home", ".", "[SEP]"]
    input_2 = ["[CLS]", "I", "[SEP]", "He", "[SEP]"]

    assert extract_hypothesis(input_1) == ["He", "went", "home", "."]
    assert extract_hypothesis(input_2) == ["He"]

    # Degenerate inputs: nothing at all, no separator, and an empty hypothesis.
    assert extract_hypothesis([]) == []
    assert extract_hypothesis(["[CLS]", "I"]) == []
    assert extract_hypothesis(["[CLS]", "I", "[SEP]", "[SEP]"]) == []

if __name__ == "__main__":
    # Run each self-check in turn. The first failing assertion raises before
    # the success message below can be printed.
    for self_check in (test_extract_hypothesis, test_extract_premise):
        self_check()

    print("All Tests Passed.")