from abc import abstractmethod
import numpy as np
import torch.nn as nn
import torch
from scipy.special import softmax
from config import GPU_DEVICE

class Evaluator:
    """Base class that accumulates predictions and labels and scores them with a metric.

    Subclasses implement :meth:`evaluate_step`, which runs ``self.model`` on a
    batch and appends to ``self.preds`` / ``self.labels``. The accumulated
    arrays are scored through ``metric.compute_results(labels, preds)``.

    Args:
        model: module exposing ``.to(device)`` (a ``torch.nn.Module`` in the
            subclasses of this file); it is moved to ``self.device`` here.
        metric: object exposing ``compute_results(labels, preds)``.
        num_classes: when None, ``preds`` starts as an empty 1-D int array
            (e.g. predicted class ids). Otherwise ``preds`` starts as an empty
            float array of shape ``(0, num_classes)`` (e.g. per-class
            probabilities), to which 2-D rows are appended.
    """

    def __init__(self, model, metric, num_classes=None):

        self.model = model

        # Use the configured GPU when CUDA is available, otherwise the CPU.
        self.device = torch.device(GPU_DEVICE) if torch.cuda.is_available() else torch.device("cpu")
        self.model.to(self.device)

        self.metric = metric
        self.num_classes = num_classes

        self._init_buffers()

    def _init_buffers(self):
        """(Re)create empty label/prediction buffers whose shape matches ``num_classes``."""
        self.labels = np.empty(0, int)
        if self.num_classes is None:
            self.preds = np.empty(0, int)
        else:
            self.preds = np.empty((0, self.num_classes), float)

    # NOTE: Evaluator does not use ABCMeta, so this decorator is not enforced
    # at instantiation time; it only documents that subclasses must override.
    @abstractmethod
    def evaluate_step(self, *args, **kwargs):
        """Process one batch and append its predictions/labels (implemented by subclasses)."""
        pass

    def reset(self):
        """Discard all accumulated predictions and labels.

        The prediction buffer is rebuilt exactly as in ``__init__``. When
        ``num_classes`` is set it stays 2-D, so probability rows can still be
        appended after a reset.
        """
        self._init_buffers()

    @property
    def evaluation_results(self):
        """Return ``metric.compute_results(labels, preds)`` over everything accumulated so far."""
        return self.metric.compute_results(self.labels, self.preds)


class TransfomerEvaluator(Evaluator):
    """Evaluator for models called as ``model(input_ids, attention_mask=mask)``.

    Each :meth:`evaluate_step` call runs the model on one batch and appends
    the post-processed first model output and the batch labels to the
    running ``preds`` / ``labels`` buffers.
    """

    def __init__(self, model, metric, num_classes=None):
        super().__init__(model, metric, num_classes)

    def evaluate_step(self, batch):
        """Run the model on one batch and accumulate predictions and labels.

        Args:
            batch: iterable of exactly three tensors
                ``(input_ids, attention_mask, labels)``; every element is moved
                to ``self.device`` before use.

        The first model output (``outputs[0]``) is converted to NumPy and then:
        flattened to 1-D when it has at most one column; otherwise reduced to
        argmax ids when ``num_classes`` is None, or turned into row-wise
        softmax probabilities when ``num_classes`` is set.
        """
        self.model.eval()

        on_device = [tensor.to(self.device) for tensor in batch]
        input_ids, attention, gold = on_device

        with torch.no_grad():
            model_out = self.model(input_ids, attention_mask=attention)

        scores = model_out[0].detach().cpu().numpy()
        if scores.shape[1] <= 1:
            scores = scores.flatten()
        elif self.num_classes is None:
            scores = np.argmax(scores, axis=1)
        else:
            scores = softmax(scores, axis=-1)

        gold_np = gold.to('cpu').numpy()
        self.preds = np.append(self.preds, scores, axis=0)
        self.labels = np.append(self.labels, gold_np, axis=0)


class ClassifierEvaluator(Evaluator):
    """Evaluator for a classifier applied to (precomputed) embeddings.

    Embeddings are optionally passed through ``transformation_net`` and/or
    column-stacked with extra features before being fed to ``model``.

    Args:
        model, metric, num_classes: as for :class:`Evaluator`.
        transformation_net: optional module applied to the embeddings before
            ``model``; moved to ``self.device`` when given.
    """

    def __init__(self, model, metric, num_classes=None, transformation_net=None):
        super().__init__(model, metric, num_classes)
        self.transformation_net = transformation_net
        # Use `is not None`, not truthiness: modules such as an empty
        # nn.Sequential define __len__ and would be skipped here even though
        # evaluate_step still applies them.
        if self.transformation_net is not None:
            self.transformation_net.to(self.device)

        # Maps multi-column model outputs to the values stored in `preds`.
        # With num_classes set, the base class allocates an (N, num_classes)
        # float buffer, so store softmax probabilities (as TransfomerEvaluator
        # does). Appending 1-D argmax ids to it would fail. It is a plain
        # attribute, so it can be swapped (e.g. for _binary_output_to_preds).
        if num_classes is None:
            self.output_transformation = self._multiclass_output_to_preds
        else:
            self.output_transformation = self._multiclass_output_to_probs

    def evaluate_step(self, b_embeddings, b_labels, concat_embeddings=None):
        """Run the classifier on one batch and accumulate predictions and labels.

        Args:
            b_embeddings: input tensor for the batch.
            b_labels: label tensor for the batch.
            concat_embeddings: optional tensor stacked column-wise
                (``torch.hstack``) onto the (possibly transformed) embeddings.

        Outputs with more than one column go through
        ``self.output_transformation``. 1-D or single-column outputs are
        flattened and stored as raw scores.
        """
        self.model.eval()

        with torch.no_grad():
            # Inputs must be on the same device as the model(s), mirroring
            # TransfomerEvaluator; .to() is a no-op when they already are.
            b_embeddings = b_embeddings.to(self.device)
            if self.transformation_net is not None:
                self.transformation_net.eval()
                b_embeddings = self.transformation_net(b_embeddings)

            if concat_embeddings is not None:
                b_embeddings = torch.hstack((b_embeddings, concat_embeddings.to(self.device)))

            outputs = self.model(b_embeddings).cpu()

        # Guard on dim() so a 1-D output does not raise IndexError on shape[1].
        if outputs.dim() > 1 and outputs.shape[1] > 1:
            b_preds = self.output_transformation(outputs)
        else:
            b_preds = outputs.flatten()

        # Convert explicitly rather than relying on np.append's implicit
        # tensor -> ndarray conversion.
        b_preds = np.asarray(b_preds)
        b_labels = b_labels.to('cpu').numpy()

        self.preds = np.append(self.preds, b_preds, axis=0)
        self.labels = np.append(self.labels, b_labels, axis=0)

    @staticmethod
    def _binary_output_to_preds(outputs, threshold=0.5):
        """Return int 0/1 predictions: sigmoid of the flattened outputs compared with ``threshold``."""
        return (nn.Sigmoid()(outputs.flatten()) > threshold).int()

    @staticmethod
    def _multiclass_output_to_preds(outputs):
        """Return argmax indices along axis 1, flattened to 1-D."""
        return np.argmax(outputs, axis=1).flatten()

    @staticmethod
    def _multiclass_output_to_probs(outputs):
        """Return softmax probabilities over the last axis as a NumPy array."""
        return softmax(np.asarray(outputs), axis=-1)
