import torch
import torch.nn as nn
from transformers import T5ForConditionalGeneration, T5Config


class CustomMLP(nn.Module):
    """Turn a stack of encoder hidden states into ``num_layers`` d_model vectors.

    The input tensors are weighted per layer and per feature by
    ``layer_weights``, summed over layers, sum-pooled over the sequence axis and
    projected down to ``dim_reduce`` features. For each of the ``num_layers``
    learned ``layer_modifiers``, the bottleneck vector is shifted by the
    modifier, passed through a shared ``dim_reduce -> dim_reduce`` linear layer
    with ReLU, and projected back to ``config.d_model``.

    Args:
        config: model config; only ``config.d_model`` is read.
        num_layers: number of layer-weight rows and of produced outputs.
        dim_reduce: width of the bottleneck.
    """

    def __init__(self, config, num_layers, dim_reduce=200):
        super().__init__()
        # One weight per (input layer, feature), used to mix the input layers.
        self.layer_weights = nn.Parameter(torch.randn(num_layers, config.d_model))
        self.reduce_mlp = nn.Sequential(
            nn.Linear(config.d_model, dim_reduce),
            nn.ReLU(),
        )

        # One learned offset per produced output.
        self.layer_modifiers = nn.ParameterList([
            nn.Parameter(torch.randn(dim_reduce)) for _ in range(num_layers)
        ])
        # ``torch.ReLU`` does not exist; the module lives in ``torch.nn``.
        self.relu = nn.ReLU()
        self.t2t = nn.Linear(dim_reduce, dim_reduce)
        self.expand_mlps = nn.Linear(dim_reduce, config.d_model)

    def forward(self, encoder_layer_outputs):
        """Return a list of ``num_layers`` tensors, each ``(batch, d_model)``.

        Args:
            encoder_layer_outputs: non-empty sequence of tensors, each of shape
                ``(batch, seq_len, d_model)``, with at most ``num_layers``
                entries. Entry ``k`` is weighted by row ``k`` of
                ``layer_weights``; when fewer than ``num_layers`` entries are
                given, the trailing weight rows are unused.

        Raises:
            ValueError: if the sequence is empty, has more than ``num_layers``
                entries, or its tensors are not 3-D.
        """
        layers = list(encoder_layer_outputs)
        if not layers:
            raise ValueError("encoder_layer_outputs must contain at least one tensor")
        num_weights = self.layer_weights.size(0)
        if len(layers) > num_weights:
            raise ValueError(
                f"got {len(layers)} encoder layer outputs but only "
                f"{num_weights} layer weights"
            )

        stacked = torch.stack(layers, dim=0)  # (L, batch, seq_len, d_model)
        if stacked.dim() != 4:
            raise ValueError(
                "each encoder layer output must be a (batch, seq_len, d_model) tensor"
            )

        # Broadcast each layer's weight row over batch and sequence, then sum
        # over layers. (The former matmul/permute(1, 2, 0) could not run on
        # 4-D stacked input.)
        weights = self.layer_weights[: len(layers), None, None, :]
        weighted_sum = (stacked * weights).sum(dim=0)  # (batch, seq_len, d_model)

        # Sum-pool over the sequence axis: one vector per example.
        reduced_output = self.reduce_mlp(weighted_sum.sum(dim=1))  # (batch, dim_reduce)

        return [
            self.expand_mlps(self.relu(self.t2t(reduced_output + modifier)))
            for modifier in self.layer_modifiers
        ]


class CustomT5Model(T5ForConditionalGeneration):
    """T5 variant whose decoder hidden states are modulated by ``CustomMLP``.

    ``forward`` runs the encoder, passes its hidden states to ``custom_mlp``,
    runs the decoder, and rescales each of the first ``len(mlp_outputs)``
    decoder positions ``h_i`` to ``h_i + h_i * mlp_outputs[i]``. It returns the
    decoder outputs as a tuple whose first element is the modulated hidden
    state. It does not apply ``lm_head`` and does not compute a loss.
    """

    def __init__(self, config):
        super().__init__(config)
        # Produces config.num_decoder_layers modulation vectors per forward.
        self.custom_mlp = CustomMLP(config, config.num_decoder_layers)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=None,
            decoder_attention_mask=None,
            head_mask=None,
            decoder_head_mask=None,
            encoder_outputs=None,
            past_key_values=None,
            inputs_embeds=None,
            decoder_inputs_embeds=None,
            labels=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        """Encode, compute MLP modulations, decode, and modulate decoder states.

        Parameters that are forwarded go to ``self.encoder`` / ``self.decoder``
        unchanged. ``labels`` is accepted for signature compatibility but is
        not used.

        Returns:
            tuple: ``(sequence_output, *rest)``, where ``sequence_output`` is
            the modulated decoder hidden state and ``rest`` are the remaining
            decoder outputs.

        Raises:
            ValueError: if neither ``decoder_input_ids`` nor
                ``decoder_inputs_embeds`` is given.
        """
        # NOTE(review): ``labels`` is ignored and no logits/loss are produced;
        # confirm callers really want decoder hidden states rather than the
        # LM output the parent class returns.

        # Validate before paying for the encoder and MLP passes.
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        encoder_hidden_states = encoder_outputs[0]

        # Tuple outputs (return_dict=False, or caller-supplied) have no
        # ``hidden_states`` attribute; follow transformers' tuple layout where
        # index 1 holds the hidden states when they were requested. Fall back to
        # the last hidden state when none are available.
        if isinstance(encoder_outputs, tuple):
            all_hidden_states = encoder_outputs[1] if len(encoder_outputs) > 1 else None
        else:
            all_hidden_states = getattr(encoder_outputs, "hidden_states", None)
        if output_hidden_states and all_hidden_states is not None:
            # NOTE(review): presumably config.num_layers + 1 tensors here, while
            # custom_mlp was built with config.num_decoder_layers — confirm the
            # counts are compatible.
            encoder_layer_outputs = all_hidden_states
        else:
            encoder_layer_outputs = [encoder_hidden_states]

        mlp_outputs = self.custom_mlp(encoder_layer_outputs)

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            head_mask=decoder_head_mask,
            encoder_attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Position i < len(mlp_outputs) becomes h_i + h_i * mlp_outputs[i];
        # later positions are unchanged. Built out-of-place: writing slices of
        # sequence_output in place overwrote tensors autograd had saved for the
        # multiply, making backward() fail, and also mutated the decoder's own
        # returned hidden state.
        num_modulated = min(sequence_output.size(1), len(mlp_outputs))
        if num_modulated > 0:
            # Stack on dim=-2 so the stacked axis lines up with the sequence
            # axis whether each modulation vector is (batch, d_model) or (d_model,).
            scale = torch.stack(list(mlp_outputs[:num_modulated]), dim=-2)
            head = sequence_output[:, :num_modulated, :]
            sequence_output = torch.cat(
                [head + head * scale, sequence_output[:, num_modulated:, :]],
                dim=1,
            )
        # NOTE(review): modulation is indexed by position within this call's
        # decoder input, so with past_key_values (one new token per step) every
        # step gets mlp_outputs[0] — confirm this matches training behavior.

        return (sequence_output,) + tuple(decoder_outputs[1:])