from dataclasses import dataclass
from typing import Optional, Tuple, Any

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import GPT2LMHeadModel, GPT2Model, GPT2Config
from transformers.file_utils import ModelOutput
from transformers.generation_logits_process import LogitsProcessor


class CustomMinLengthLogitsProcessor(LogitsProcessor):
    """Block the EOS token for rows that are still shorter than `min_length`.

    On the first call the processor counts, per row, how many positions of
    `input_ids` equal `eos_token_id` and caches that count in
    `prompt_lengths`. On every call a row's effective length is the current
    sequence length minus that cached count; while it is below `min_length`
    the row's EOS score is set to `-inf`.

    NOTE(review): despite its name, `prompt_lengths` holds the number of EOS
    tokens seen on the first call, not the prompt length. This presumably
    relies on prompts being padded with the EOS token -- confirm with callers.

    Args:
        min_length: Minimum effective length before EOS may be produced.
        eos_token_id: Id of the end-of-sequence token.
    """

    def __init__(self, min_length: int, eos_token_id: int):
        self.min_length = min_length
        self.eos_token_id = eos_token_id
        # Filled lazily on the first __call__ and reused afterwards, so one
        # instance should only serve a single batch.
        self.prompt_lengths = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask the EOS score of every row that is still too short.

        Args:
            input_ids: Token ids generated so far; indexed as (row, position).
            scores: Next-token scores; indexed as (row, token id).

        Returns:
            The same `scores` tensor, modified in place.
        """
        if self.prompt_lengths is None:
            self.prompt_lengths = (input_ids == self.eos_token_id).sum(dim=1)
        cur_len = input_ids.shape[-1]
        # One vectorized comparison instead of a Python loop over rows; the
        # slice keeps the mask aligned with the rows present in `scores`.
        too_short = ((cur_len - self.prompt_lengths) < self.min_length)[: scores.shape[0]]
        scores[too_short, self.eos_token_id] = -float("inf")
        return scores


@dataclass
class CausalLMOutputWithCrossAttentionsAndValues(ModelOutput):
    """
    A custom variant of `CausalLMOutputWithCrossAttentions` intended to also carry the value predicted by a value head.

    The class is declared as a dataclass, as `ModelOutput` subclasses are expected to be, so that the
    fields below are the ones `ModelOutput` uses for attribute, key and index access.

    NOTE(review): no value field is declared here yet, so only the standard causal-LM outputs are stored.
    """
    # Language-modelling loss; None when it was not computed.
    loss: Optional[torch.FloatTensor] = None
    # Next-token prediction scores.
    logits: torch.FloatTensor = None
    # Cached key/value states returned by the transformer.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Hidden states returned by the transformer, if requested.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Self-attention weights returned by the transformer, if requested.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Cross-attention weights returned by the transformer, if requested.
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None

class GPT2LMAndValueHeadModel(GPT2LMHeadModel):
    """GPT-2 language model whose forward pass returns `CausalLMOutputWithCrossAttentionsAndValues`.

    It also registers a `CustomMinLengthLogitsProcessor` on top of the logits
    processors built by the parent class.

    NOTE(review): the class name and the ignored checkpoint keys mention a
    value head, but no value head is created in this class.
    """
    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight",
                                       'value_head.head.weight', 'value_head.head.bias']

    def __init__(
        self,
        config: GPT2Config,
    ):
        """Build the transformer backbone and the LM head from `config`."""
        super().__init__(config)
        # NOTE(review): the parent constructor presumably builds these modules
        # already; they are re-created here -- confirm this is intended.
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Not set here; expected to be attached by the caller if needed.
        self.tokenizer = None
        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        logits_config: Optional[dict[str, Any]] = None
    ):
        """Run the transformer and the LM head.

        Every argument except `labels` and `logits_config` is forwarded
        unchanged to `self.transformer`.

        Args:
            labels: Optional target token ids. When given, a shifted
                cross-entropy loss is computed for compatibility.
            return_dict: Whether to return an output object; defaults to
                `self.config.use_return_dict` when None.
            logits_config: Accepted for interface compatibility; not used by
                this method.

        Returns:
            A `CausalLMOutputWithCrossAttentionsAndValues` when `return_dict`
            is true. Otherwise a tuple `(logits, *extra_transformer_outputs)`,
            prefixed with the loss when `labels` was given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)
        lm_logits = self.lm_head(hidden_states)

        # Standard loss computation kept for compatibility; the loss actually used is computed outside the model
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            # With return_dict disabled the backbone returns a plain tuple that
            # has no named attributes, so build a tuple result from it.
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentionsAndValues(
            loss=loss,  # we will always compute loss outside this class
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    def _get_logits_processor(self, *args, **kwargs):
        """Extend the parent's logits processors with `CustomMinLengthLogitsProcessor`.

        `min_length` and `eos_token_id` are read from the keyword arguments
        and fall back to the model config when absent or None. The custom
        processor is only added when both are known and `min_length` is
        positive; with a non-positive minimum it would never mask anything.
        """
        logits_processors = super()._get_logits_processor(*args, **kwargs)

        min_length = kwargs.get('min_length')
        if min_length is None:
            min_length = getattr(self.config, 'min_length', None)
        eos_token_id = kwargs.get('eos_token_id')
        if eos_token_id is None:
            eos_token_id = getattr(self.config, 'eos_token_id', None)

        # A processor built with None for either value fails on its first call.
        if min_length is not None and eos_token_id is not None and min_length > 0:
            logits_processors.append(CustomMinLengthLogitsProcessor(min_length, eos_token_id))
        return logits_processors
    