import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from transformers import AutoModel, AutoTokenizer
from utils.config import device

# Define the model
# `transformer_path` is the directory that holds the pretrained backbone
# (joined below with "paraphrase-distilroberta-base-v1"); `checkpoint_path`
# is the state-dict file passed to torch.load.
# NOTE(review): "$USER" is stored literally in these strings and nothing in
# this file expands environment variables (e.g. os.path.expandvars) -- TODO
# confirm the paths are substituted before the module is imported.
# Alternative locations (disabled):
# transformer_path = "/exp/data_luar/pretrained_weights"
# checkpoint_path = "/exp/$USER/sbert_mud.pt"
transformer_path = "/exp/$USER/pretrained_uar/pretrained_weights"
checkpoint_path = "/exp/$USER/pretrained_uar/sbert_mud.pt"

class SelfAttention(nn.Module):
    """Parameter-free scaled dot-product attention ("Attention is All You Need").

    Note the argument order ``(k, q, v)``: the score matrix is ``k @ q^T``, so
    the softmax normalises each key row over the query positions.
    """
    def __init__(self):
        super().__init__()

    def forward(self, k, q, v):
        """Return the attention-weighted combination of ``v``.

        Args:
            k: Tensor whose last two axes are (positions, features).
            q: Tensor with the same feature width as ``k``.
            v: Tensor with one row per position of ``q``.

        Returns:
            ``softmax(k @ q^T / sqrt(d)) @ v`` where ``d`` is the size of the
            last axis of ``q``.
        """
        # Scale by the square root of the feature width before the softmax
        scale = math.sqrt(q.size(-1))
        weights = (k @ q.transpose(-2, -1) / scale).softmax(dim=-1)

        return weights @ v

class Transformer(nn.Module):
    """Defines the SBERT model.

    A pretrained transformer backbone encodes every episode, the token states
    are mean-pooled into one vector per episode, the episodes of each author
    sample are mixed with dot-product self-attention, max-pooled over the
    episode axis and finally projected by a linear layer.
    """
    def __init__(self):
        super(Transformer, self).__init__()

        self.create_transformer()
        self.attn_fn = SelfAttention()
        # Layer sizes are fixed: the module-level checkpoint is loaded with
        # strict=True, so the parameter shapes have to match it exactly.
        self.linear = nn.Linear(768, 512)

    def create_transformer(self):
        """Creates the transformer model.

        Loads the backbone from the "paraphrase-distilroberta-base-v1"
        directory under the module-level ``transformer_path``.
        """
        model_path = os.path.join(transformer_path, "paraphrase-distilroberta-base-v1")
        self.sbert = AutoModel.from_pretrained(model_path)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Mean Pooling as described in the SBERT paper.

        Args:
            token_embeddings: Tensor of shape (b, l, d) with per-token states.
            attention_mask: Tensor of shape (b, l); non-zero marks real tokens.

        Returns:
            Tensor of shape (b, d): the mask-weighted mean over the token axis.
            Rows whose mask is entirely zero yield zeros instead of NaN.
        """
        # Broadcast the mask over the feature axis using the embeddings' own
        # width, so the pooling is not tied to one particular hidden size.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
        # Clamp the token count so fully padded rows do not divide by zero
        sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
        return sum_embeddings / sum_mask

    def get_author_embedding(self, text):
        """Computes the Author Embedding.

        Args:
            text: Sequence whose first two items are ``input_ids`` and
                ``attention_mask``, both shaped
                (batch, samples_per_author, episodes, tokens).

        Returns:
            Tensor of shape (batch * samples_per_author, 512).
        """
        input_ids, attention_mask = text[0], text[1]
        B, N, E, _ = input_ids.shape

        # The backbone expects 2-D inputs: flatten every episode into the batch axis
        input_ids = rearrange(input_ids, 'b n e l -> (b n e) l')
        attention_mask = rearrange(attention_mask, 'b n e l -> (b n e) l')

        # Only the final layer is consumed, so the intermediate hidden states
        # are not requested (they would just be held in memory).
        outputs = self.sbert(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True)

        embedded_episode = self.mean_pooling(outputs['last_hidden_state'], attention_mask)
        # Regroup the episode vectors per author sample ("l" is the feature axis here)
        embedded_episode = rearrange(embedded_episode, '(b n e) l -> (b n) e l', b=B, n=N, e=E)

        # Self-attention across episodes, then element-wise max over the episode axis
        embedded_features = reduce(self.attn_fn(embedded_episode, embedded_episode, embedded_episode),
                                   'b e l -> b l', 'max')

        author_embedding = self.linear(embedded_features)

        return author_embedding

    def forward(self, data):
        """Calculates a fixed-length feature vector for a batch of episode samples.

        Args:
            data: Same structure as the ``text`` argument of
                ``get_author_embedding``.

        Returns:
            The tensor produced by ``get_author_embedding``.
        """
        output = self.get_author_embedding(data)

        return output

# Load model
model = Transformer()
# Deserialize onto the CPU first so a checkpoint written from a GPU process
# still loads on a host without that device; model.to() below does the final
# placement.
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.to(torch.device(device))
# This module only serves embeddings, so keep every submodule in inference mode.
model.eval()
tokenizer = AutoTokenizer.from_pretrained(
    os.path.join(transformer_path, "paraphrase-distilroberta-base-v1")
    )

# Define get_*_embedding function
def get_uar_embedding_32(texts):
    """Embed a set of texts as a single author, using up to 32 tokens per text.

    Args:
        texts: A string or a list of strings. Each text is padded/truncated
            to 32 tokens and treated as one episode of the same author.

    Returns:
        1-D numpy array holding the L2-normalised author embedding.
    """
    data = tokenizer(texts, padding="max_length", truncation=True, max_length=32, return_tensors='pt')
    # Select the tensors by key rather than by position in the tokenizer output
    input_ids, attention_mask = data["input_ids"], data["attention_mask"]

    # (batch_size, num_samples_per_author, num_episodes, max_length)
    # Read the episode count from the tensor: len(texts) would count
    # characters when a single string is passed.
    num_episodes, max_length = input_ids.shape
    target = torch.device(device)
    input_ids = input_ids.reshape(1, 1, num_episodes, max_length).to(target)
    attention_mask = attention_mask.reshape(1, 1, num_episodes, max_length).to(target)

    # Inference only: skip building an autograd graph that would be detached anyway
    with torch.no_grad():
        embeddings = model([input_ids, attention_mask])
    return torch.nn.functional.normalize(embeddings).detach().cpu().numpy()[0]

# Define get_*_embedding function
def get_uar_embedding_512(texts):
    """Embed a set of texts as a single author, using up to 512 tokens per text.

    Args:
        texts: A string or a list of strings. Each text is padded/truncated
            to 512 tokens and treated as one episode of the same author.

    Returns:
        1-D numpy array holding the L2-normalised author embedding.
    """
    data = tokenizer(texts, padding="max_length", truncation=True, max_length=512, return_tensors='pt')
    # Select the tensors by key rather than by position in the tokenizer output
    input_ids, attention_mask = data["input_ids"], data["attention_mask"]

    # (batch_size, num_samples_per_author, num_episodes, max_length)
    # Read the episode count from the tensor: len(texts) would count
    # characters when a single string is passed.
    num_episodes, max_length = input_ids.shape
    target = torch.device(device)
    input_ids = input_ids.reshape(1, 1, num_episodes, max_length).to(target)
    attention_mask = attention_mask.reshape(1, 1, num_episodes, max_length).to(target)

    # Inference only: skip building an autograd graph that would be detached anyway
    with torch.no_grad():
        embeddings = model([input_ids, attention_mask])
    return torch.nn.functional.normalize(embeddings).detach().cpu().numpy()[0]

# Define get_*_embedding function
def get_uar_embedding_all(texts):
    """Embed a set of texts as a single author, using up to 2048 tokens per text.

    Every text is padded/truncated to 2048 tokens and cut into consecutive
    windows of 32 tokens; each window that contains at least one real token
    becomes one episode of the same author.

    Args:
        texts: A string or a list of strings.

    Returns:
        1-D numpy array holding the L2-normalised author embedding.
    """
    window = 32
    # Tokenize once. With padding="max_length" the width is always 2048
    # (already a multiple of 32), so re-tokenizing to a "rounded" length
    # would reproduce exactly the same tensors.
    data = tokenizer(texts, padding="max_length", truncation=True, max_length=2048, return_tensors='pt')

    # Split the full post into many instances of size 32, stacked along axis 0
    input_ids = torch.cat(data["input_ids"].split(window, dim=1))
    attention_mask = torch.cat(data["attention_mask"].split(window, dim=1))

    # Ignore any episodes filled with padding (attention mask filled with zeros)
    non_empty = attention_mask.sum(dim=1) != 0
    input_ids = input_ids[non_empty]
    attention_mask = attention_mask[non_empty]

    # (batch_size, num_samples_per_author, num_episodes, max_length)
    num_episodes, max_length = input_ids.shape
    target = torch.device(device)
    input_ids = input_ids.reshape(1, 1, num_episodes, max_length).to(target)
    attention_mask = attention_mask.reshape(1, 1, num_episodes, max_length).to(target)

    # Inference only: skip building an autograd graph that would be detached anyway
    with torch.no_grad():
        embeddings = model([input_ids, attention_mask])
    return torch.nn.functional.normalize(embeddings).detach().cpu().numpy()[0]

# Define the variant we will use
# `get_uar_embedding` is bound to the 512-tokens-per-text implementation.
get_uar_embedding = get_uar_embedding_512