import os
import torch
import pandas as pd
import numpy as np
from abc import ABC, abstractmethod
from collections import namedtuple
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from torch.nn import CosineSimilarity
from tqdm import tqdm
from nlp_safety_prop.models.representations import Representer
from nlp_safety_prop.utils.paths import get_data_path
from nlp_safety_prop.models.embeddings.roberta import RobertaModelPreTrained
from torch.utils.data import TensorDataset, DataLoader


class ConsistencyRule(ABC):
    """Abstract interface for rules that generate and test hypotheses.

    A concrete rule checks a condition on labelled data and turns every match
    into a hypothesis about the label a consistent model should produce. It
    then tests those hypotheses against the model's actual labels.
    Subclasses must implement all three methods below.
    """

    @abstractmethod
    def check_condition(self, data: pd.DataFrame):
        """Check the rule's precondition on ``data``."""

    @abstractmethod
    def generate_hypotheses(self, data: pd.DataFrame):
        """Build the hypotheses implied by ``data``."""

    @abstractmethod
    def test(self, hypotheses, data_left, data_right, representer, hypothesized_labeler):
        """Evaluate ``hypotheses`` and report the results."""


class Labeler(ABC):
    """Abstract interface for objects that assign labels to representations."""

    @abstractmethod
    def label(self, representations):
        """Return labels for ``representations``."""


def check_consistency(triple_data: pd.DataFrame,
                      representer: Representer,
                      hypothesized_labeler,  # object with .label(representations) -> labels (hypothesized model behaviour)
                      rule: ConsistencyRule,
                      ):
    """Label both word pairs of every triple, then let ``rule`` test its hypotheses.

    Args:
        triple_data: frame with columns ``word_1``, ``word_2``, ``word_3``. Each
            row is split into the pairs (word_1, word_2) and (word_2, word_3).
            NOTE: modified in place -- columns ``model_label_left`` and
            ``model_label_right`` are added.
        representer: object whose ``get_representations(frame)`` accepts a frame
            with ``word_1``/``word_2`` columns and returns the input expected by
            ``hypothesized_labeler.label``.
        hypothesized_labeler: object whose ``label(representations)`` returns
            one label per pair, in row order.
        rule: consistency rule used to generate and test hypotheses.

    Returns:
        Whatever ``rule.test`` returns.
    """
    # TODO: Speed up by removing unnecessary rows before condition checking?

    # Explicit copies. Plain column selections may be views of triple_data,
    # and adding a column to them triggers pandas' SettingWithCopyWarning.
    data_left = triple_data[["word_1", "word_2"]].copy()
    data_right = triple_data[["word_2", "word_3"]].copy()
    data_right = data_right.rename(columns={"word_2": "word_1", "word_3": "word_2"})

    representations_left = representer.get_representations(data_left)
    representations_right = representer.get_representations(data_right)

    # Batch-predict hypothesized model labels. Convert to a list so assignment
    # is positional. A labeler returning a pd.Series with a fresh RangeIndex
    # would otherwise be aligned on index, giving NaN when triple_data does
    # not use a default index.
    data_left["model_label"] = list(hypothesized_labeler.label(representations_left))
    data_right["model_label"] = list(hypothesized_labeler.label(representations_right))

    # Rows are in triple_data's order, so copy positionally. This also avoids
    # reindexing errors when triple_data's index has duplicate labels.
    triple_data["model_label_left"] = data_left["model_label"].to_numpy()
    triple_data["model_label_right"] = data_right["model_label"].to_numpy()

    # Generate hypotheses by checking the rule's condition on every triple.
    hypotheses = rule.generate_hypotheses(triple_data)

    return rule.test(hypotheses, data_left, data_right, representer, hypothesized_labeler)


class WordRelationTransitivity(ConsistencyRule):
    """Transitivity rule for word-pair relation labels.

    Take two chained pairs (a, b) and (b, d) over three distinct words. If the
    labeler gives both pairs the same label and that label is not "RANDOM",
    the rule hypothesizes that (a, d) should receive that label as well.
    ``test`` then measures how often the labeler agrees.
    """

    # Built once here; it used to be re-created on every check_condition call.
    Hypothesis = namedtuple("Hypothesis", ["word_1", "word_2", "intermediate", "expected_label"])

    def __init__(self) -> None:
        super().__init__()

    def generate_hypotheses(self, data: pd.DataFrame):
        """Collect one hypothesis per triple whose two pairs satisfy the rule.

        data: pd.DataFrame with columns titled
                'word_1', 'word_2', 'word_3', 'model_label_left', 'model_label_right'

        Returns:
            List of ``Hypothesis`` namedtuples in row order. Repeated triples
            in ``data`` produce repeated hypotheses.
        """
        print('Generating Hypotheses from Conditions: \n')
        hypotheses = []

        for triple_row in data.itertuples(index=False):
            row1 = {"word_1": triple_row.word_1, "word_2": triple_row.word_2,
                    "model_label": triple_row.model_label_left}
            row2 = {"word_1": triple_row.word_2, "word_2": triple_row.word_3,
                    "model_label": triple_row.model_label_right}

            hypothesis = self.check_condition((row1, row2))
            if hypothesis is not None:
                hypotheses.append(hypothesis)

        return hypotheses

    def check_condition(self, row_pair):
        """Return a ``Hypothesis`` for (a, d) if the pair chain is transitive, else None.

        Args:
            row_pair: two mappings (a, b) and (c, d), each with keys
                ``word_1``, ``word_2`` and ``model_label``.
        """
        row1, row2 = row_pair
        a, b = row1["word_1"], row1["word_2"]
        c, d = row2["word_1"], row2["word_2"]
        label_1, label_2 = row1["model_label"], row2["model_label"]

        # "RANDOM" means no relation, so there is nothing to propagate.
        if label_1 == "RANDOM" or label_2 == "RANDOM":
            return None
        # The pairs must chain through a shared middle word (b == c) and the
        # three words a, b, d must be distinct.
        if not (b == c and a != d and a != b and c != d):
            return None
        if label_1 != label_2:
            return None

        return self.Hypothesis(word_1=a, word_2=d, intermediate=b, expected_label=label_1)

    def test(self, hypotheses, data_left, data_right, representer, hypothesized_labeler):
        """Compare each hypothesis' expected label with the labeler's label for (a, d).

        Args:
            hypotheses: iterable of ``Hypothesis`` records.
            data_left, data_right: frames with ``word_1``, ``word_2`` and
                ``model_label`` columns holding pairs that are already labelled.
            representer, hypothesized_labeler: used to label hypothesis pairs
                that do not appear in ``data_left``/``data_right``.

        Returns:
            dict with ``hyp_acc`` (accuracy of expected vs. model labels; NaN if
            there are no hypotheses), ``num_hypotheses`` (distinct hypothesis
            rows) and ``hypotheses_dataframe``.
        """
        columns = ['word_1', 'word_2', 'intermediate', 'expected_label']
        hypotheses = pd.DataFrame.from_records(hypotheses, columns=columns)

        # Reuse labels already computed for pairs in the original data. Each
        # pair appears once per triple that contains it, so the label tables
        # are deduplicated before merging. Otherwise every hypothesis row is
        # multiplied by the pair's frequency.
        known = pd.concat([data_left, data_right], ignore_index=True).drop_duplicates()
        hypotheses = pd.merge(hypotheses, known, how='left', on=['word_1', 'word_2'])
        # Identical hypotheses can still come from repeated triples.
        hypotheses = hypotheses.drop_duplicates()

        # Label pairs not seen in the data. A NaN label would otherwise make
        # accuracy_score fail on mixed str/float targets.
        missing = hypotheses['model_label'].isna()
        if missing.any():
            unchecked = hypotheses.loc[missing, ['word_1', 'word_2']]
            unchecked_representations = representer.get_representations(unchecked)
            new_labels = list(hypothesized_labeler.label(unchecked_representations))
            # object dtype so string labels can replace the NaN placeholders.
            hypotheses['model_label'] = hypotheses['model_label'].astype(object)
            # Positional assignment; labeler output may carry its own index.
            hypotheses.loc[missing, 'model_label'] = new_labels

        if hypotheses.empty:
            hyp_acc = float('nan')
        else:
            hyp_acc = accuracy_score(hypotheses['expected_label'], hypotheses['model_label'])

        results = {
            "hyp_acc": hyp_acc,
            "num_hypotheses": len(hypotheses),
            "hypotheses_dataframe": hypotheses
        }

        return results


class TrainedRobertaClassifier(Labeler):
    """Labeler backed by a fine-tuned RoBERTa sequence-classification model.

    Each (input_ids, attention_mask) batch from the loader is classified. The
    arg-max logit index is reported as the corresponding entry of ``self.labels``.
    """

    def __init__(self, model_path: str, device=None) -> None:
        """
        Args:
            model_path: model location, resolved through ``get_data_path``.
            device: torch device (or device string) used for inference.
                Defaults to CUDA when available, otherwise CPU. CUDA used to
                be hard-coded, which failed on CPU-only machines.
        """
        super().__init__()
        self.model_path = get_data_path(model_path)
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        # Classifier output index i is reported as self.labels[i].
        # NOTE(review): assumes the model was trained with this label-id
        # order -- confirm against the training setup.
        self.labels = ["ANT", "HYP", "RANDOM", "SYN"]
        self.labels_dict = dict(enumerate(self.labels))

        self.roberta = RobertaModelPreTrained(self.model_path,
                                              model_args={"num_labels": len(self.labels)})

        print('Loaded Classification Model From: ', model_path)
        self.xlmr_tokenizer = self.roberta.tokenizer
        self.xlmr_model = self.roberta.model
        # Kept for backward compatibility: Module.to returns the moved module.
        self.desc = self.xlmr_model.to(self.device)
        self.le = preprocessing.LabelEncoder()
        self.le.fit(self.labels)

    def label(self, representations: DataLoader):
        """Predict one label per example.

        Args:
            representations: loader yielding (input_ids, attention_mask) batches.

        Returns:
            List of label strings in loader order (empty if the loader is empty).
        """
        self.xlmr_model.eval()
        batch_logits = []

        # Inference only: disable autograd for the whole pass.
        with torch.no_grad():
            for batch in tqdm(representations):
                input_ids = batch[0].to(self.device)
                attention_mask = batch[1].to(self.device)
                output = self.xlmr_model(input_ids,
                                         token_type_ids=None,
                                         attention_mask=attention_mask)
                batch_logits.append(output.logits.cpu().numpy())

        # np.concatenate raises on an empty list.
        if not batch_logits:
            return []

        predicted_ids = np.concatenate(batch_logits, axis=0).argmax(axis=1)
        return [self.labels_dict[int(idx)] for idx in predicted_ids]


class CosineChecker(Labeler):
    """Label a pair "SYN" when its two vectors' cosine similarity exceeds a threshold.

    Every other pair is labelled "RANDOM".
    """

    def __init__(self, threshold: float = 0.0) -> None:
        """
        Args:
            threshold: similarity strictly above this value is labelled "SYN".
                The default 0.0 matches the previous hard-coded behaviour.
        """
        super().__init__()
        self.threshold = threshold

    def label(self, representations: torch.Tensor):
        """Return a pd.Series of "SYN"/"RANDOM" labels, one per pair.

        Args:
            representations: tensor indexed as ``[pair, member, feature]``,
                with members 0 and 1 being the two words of the pair.
        """
        cos = CosineSimilarity(dim=1)
        x1, x2 = representations[:, 0, :], representations[:, 1, :]

        # detach/cpu: .numpy() raises on CUDA tensors and on tensors that
        # require grad.
        is_syn = (cos(x1, x2) > self.threshold).detach().cpu().numpy()
        return pd.Series(is_syn).map({True: "SYN", False: "RANDOM"})


class ModelInputRepresenter(Representer):
    """Tokenize word pairs into padded model inputs for a RoBERTa classifier."""

    def __init__(self, model_name: str = "xlm-roberta-base", max_len: int = 16) -> None:
        """
        Args:
            model_name: model whose tokenizer is used (default keeps the
                previous hard-coded value).
            max_len: every encoded pair is padded/truncated to this length.
        """
        super().__init__()
        self.roberta = RobertaModelPreTrained(model_name, model_args={"num_labels": 4})
        self.xlmr_tokenizer = self.roberta.tokenizer
        self.max_len = max_len

    def get_representations(self, data: pd.DataFrame, batch_size: int = 128) -> DataLoader:
        """Encode each (word_1, word_2) row as a text pair.

        Args:
            data: frame with ``word_1`` and ``word_2`` columns.
            batch_size: batch size of the returned loader.

        Returns:
            DataLoader over (input_ids, attention_mask) tensors, each of shape
            (len(data), max_len), in row order. Not shuffled.
        """
        print('Encoding data:')

        # torch.cat / the tokenizer fail on empty input; return an empty loader.
        if len(data) == 0:
            empty_ids = torch.empty((0, self.max_len), dtype=torch.long)
            return DataLoader(TensorDataset(empty_ids, empty_ids.clone()), batch_size=batch_size)

        # One batched call replaces a per-row encode_plus loop over iterrows.
        # Options are identical, so each row encodes the same way.
        encoded = self.xlmr_tokenizer(data['word_1'].tolist(),
                                      data['word_2'].tolist(),
                                      max_length=self.max_len,
                                      padding='max_length',
                                      truncation=True,
                                      return_tensors='pt')

        dataset = TensorDataset(encoded['input_ids'], encoded['attention_mask'])
        return DataLoader(dataset, batch_size=batch_size)
