from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import cross_entropy, softmax
from peft import LoraConfig, get_peft_model
from typing import Union, List
import torch.nn as nn
import random
import torch


class StandardModel(nn.Module):
    """
    Sequence classifier that uses the logits of a sequence-classification backbone directly as class predictions.

    Parameters
    ----------
    tokenizer: Tokenizer matching the backbone; it is called on a batch of strings and must return
        ``input_ids`` and ``attention_mask`` tensors.
    backbone: Sequence-classification model whose output exposes ``'logits'`` (and ``'loss'`` when labels are given).
    num_classes: Number of classes in the classification problem.
    device: Device on which the backbone is stored and to which every input tensor is moved.
    max_length: Maximum number of tokens per input sequence; longer sequences are truncated.
    """

    def __init__(self, tokenizer: 'AutoTokenizer', backbone: 'AutoModelForSequenceClassification', num_classes: int,
                 device: str, max_length: int = 100):
        super().__init__()

        # store the properties as well as the backbone model and its corresponding tokenizer
        self.num_classes = num_classes
        self.tokenizer = tokenizer
        self.backbone = backbone
        self.device = device
        self.max_length = max_length

        # store the parameters on the desired device
        self.backbone.to(self.device)

    def tokenize_data(self, data: List[str]) -> tuple:
        """
        Tokenizes a batch of input sentences and moves the resulting tensors to ``self.device``.

        Parameters
        ----------
        data: Batch of input sentences.

        Returns
        -------
        token_ids: Token ids of the batch, padded to the longest sequence in the batch.
        attn_mask: Attention mask returned by the tokenizer for ``token_ids``.
        """

        # calling the tokenizer directly replaces the deprecated ``batch_encode_plus`` for batched input
        encodings = self.tokenizer(data, return_tensors='pt', padding='longest', max_length=self.max_length,
                                   truncation=True)

        # use the available devices
        attn_mask = encodings['attention_mask'].to(self.device)
        token_ids = encodings['input_ids'].to(self.device)
        return token_ids, attn_mask

    def compute_loss(self, data: List[str], labels: torch.LongTensor) -> torch.Tensor:
        """
        Computes the classification loss reported by the backbone for a labelled batch.

        Parameters
        ----------
        data: Batch of input sentences.
        labels: Class index of every sentence (a tensor or a plain list of ints).

        Returns
        -------
        loss: Loss value returned by the backbone under the ``'loss'`` key.
        """

        # tokenize the data
        token_ids, attn_mask = self.tokenize_data(data)

        # the labels must live on the same device as the logits
        labels = torch.as_tensor(labels, device=self.device)

        # pass inputs by keyword: not every backbone takes attention_mask as its second positional argument
        # (e.g. GPT-2 style models expect past_key_values there)
        loss = self.backbone(input_ids=token_ids, attention_mask=attn_mask, labels=labels)['loss']
        return loss

    def forward(self, data: List[str]) -> torch.Tensor:
        """
        Computes the class logits for a batch of input sentences.

        Parameters
        ----------
        data: Batch of input sentences.

        Returns
        -------
        logits: Logits returned by the backbone under the ``'logits'`` key.
        """

        # tokenize the data
        token_ids, attn_mask = self.tokenize_data(data)

        # compute the logits (keyword arguments, see compute_loss)
        logits = self.backbone(input_ids=token_ids, attention_mask=attn_mask)['logits']
        return logits


class MetricModel(nn.Module):
    """
    Metric-learning classifier: the backbone output is treated as a feature representation, which a separate
    linear layer maps to class logits.

    Parameters
    ----------
    tokenizer: Tokenizer matching the backbone; it is called on a batch of strings and must return
        ``input_ids`` and ``attention_mask`` tensors.
    backbone: Sequence-classification model whose ``'logits'`` output (of size ``num_features``) is used as features.
    num_features: Number of dimensions of the feature representation.
    num_classes: Number of classes in the classification problem.
    device: Device on which all parameters are stored and to which every input tensor is moved.
    max_length: Maximum number of tokens per input sequence; longer sequences are truncated.
    """

    def __init__(self, tokenizer: 'AutoTokenizer', backbone: 'AutoModelForSequenceClassification',
                 num_features: int, num_classes: int, device: str, max_length: int = 100):
        super().__init__()

        # store the properties as well as the backbone model and its corresponding tokenizer
        self.num_features = num_features
        self.num_classes = num_classes
        self.tokenizer = tokenizer
        self.backbone = backbone
        self.device = device
        self.max_length = max_length

        # create a classification model on top of the extracted features
        self.classifier = nn.Linear(self.num_features, self.num_classes)

        # store the parameters on the desired device
        self.backbone.to(self.device)
        self.classifier.to(self.device)

    def tokenize_data(self, data: List[str]) -> tuple:
        """
        Tokenizes a batch of input sentences and moves the resulting tensors to ``self.device``.

        Parameters
        ----------
        data: Batch of input sentences.

        Returns
        -------
        token_ids: Token ids of the batch, padded to the longest sequence in the batch.
        attn_mask: Attention mask returned by the tokenizer for ``token_ids``.
        """

        # calling the tokenizer directly replaces the deprecated ``batch_encode_plus`` for batched input
        encodings = self.tokenizer(data, return_tensors='pt', padding='longest', max_length=self.max_length,
                                   truncation=True)

        # use the available devices
        attn_mask = encodings['attention_mask'].to(self.device)
        token_ids = encodings['input_ids'].to(self.device)
        return token_ids, attn_mask

    # noinspection PyCallingNonCallable
    def feature_extractor(self, data: List[str]) -> torch.Tensor:
        """
        Extracts feature representations from a batch of input samples.

        Parameters
        ----------
        data: Batch of input samples, for which feature representations should be extracted.

        Returns
        -------
        features: Extracted feature representations of the input batch (the backbone's ``'logits'`` output).
        """

        # tokenize the input sentences
        token_ids, attn_mask = self.tokenize_data(data)

        # pass inputs by keyword: not every backbone takes attention_mask as its second positional argument
        # (e.g. GPT-2 style models expect past_key_values there)
        features = self.backbone(input_ids=token_ids, attention_mask=attn_mask)['logits']
        return features

    def forward(self, data: List[str]) -> tuple:
        """
        Forward pass of the metric model, which consists of a feature extraction step, followed by a data
        classification step.

        Parameters
        ----------
        data: Batch of input samples, for which feature representations should be extracted.

        Returns
        -------
        logits: Model predictions for the provided input batch.
        features: Extracted feature representations of the input batch.
        """

        # extract the feature representation of the data
        features = self.feature_extractor(data)

        # compute output logits based on the computed feature representations
        logits = self.classifier(features)
        return logits, features


def load_model(model_name: str, num_classes: int, num_features: Union[None, int] = None, efficient: bool = False,
               standard: bool = False, device: str = 'cpu') -> Union[MetricModel, StandardModel]:
    """
    Initializes and returns either a StandardModel or a MetricModel built on a pretrained backbone.

    Parameters
    ----------
    model_name: Name of the backbone model, which is then downloaded from the huggingface hub.
    num_classes: Number of classes, in the classification problem.
    num_features: Number of dimensions of the feature representation. Required for a metric model; ignored for a
        standard model, whose backbone head outputs one logit per class.
    efficient: Whether to use parameter-efficient (LoRA) model training or not.
    standard: Whether to load a standard or a metric model.
    device: Device, on which the model is stored.

    Returns
    -------
    model: StandardModel if ``standard`` is True, otherwise MetricModel, initialized with the desired properties.

    Raises
    ------
    ValueError: If a metric model is requested without specifying ``num_features``.
    """

    # a standard model uses the backbone logits directly as class scores, so its head needs num_classes outputs;
    # a metric model uses them as num_features-dimensional features that feed a separate linear classifier
    if standard:
        num_labels = num_classes
    elif num_features is None:
        raise ValueError('num_features must be specified when initializing a metric model')
    else:
        num_labels = num_features

    # initialize a tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # initialize the model
    backbone = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True, num_labels=num_labels)

    # if we want to train our model in a parameter efficient way
    if efficient:

        # initialize the configurations for parameter efficient finetuning using LoRA
        peft_config = LoraConfig(
            task_type="SEQ_CLS",
            inference_mode=False,
            r=8,
            lora_alpha=16,
            lora_dropout=0.1
        )

        # apply the LoRA approach to the backbone model
        backbone = get_peft_model(backbone, peft_config)

    # create a model; both constructors already move their parameters to `device`
    if standard:
        print('Initializing a standard model')
        model = StandardModel(tokenizer, backbone, num_classes, device)
    else:
        print('Initializing a metric model')
        model = MetricModel(tokenizer, backbone, num_features, num_classes, device)
    return model
