from typing import Literal, Optional, Tuple, List
from math import ceil
import torch
from functools import partial
from utils import kurtosis, entropy, hoyer_sparsity
import einops
from torch import nn
from transformers import DynamicCache


def _select_sort_fn(sort_by):
    """
    Return the callable that scores every cached token for the given metric.

    The returned callable takes a tensor whose last dimension is the head
    dimension and reduces over that last dimension.

    Raises:
        ValueError: if `sort_by` is not one of the supported metrics.
    """
    if sort_by == 'norm':
        return partial(torch.norm, p=2, dim=-1)
    if sort_by == 'kurtosis':
        return partial(kurtosis, dim=-1)
    if sort_by == 'random':
        # one uniform random score per (batch, head, token)
        return lambda x: torch.rand(x.size(0), x.size(1), x.size(2), device=x.device)
    if sort_by == 'entropy':
        return partial(entropy, dim=-1)
    if sort_by == 'hoyer':
        return partial(hoyer_sparsity, dim=-1)
    raise ValueError(
        'Invalid sort_by argument. Please choose from '
        '"norm", "kurtosis", "random", "entropy", "hoyer"'
    )


def adjust_by_norm(past_key_values,
                   sort: Literal['key', 'value'] = 'value',
                   keep_ratio: float = 1,
                   prune_after: int = 2048,
                   skip_layers: Optional[list] = None,
                   descending: bool = False,
                   sort_by: Literal['norm', 'random', 'kurtosis', 'entropy', 'hoyer'] = 'norm',
                   ):
    """
    Prune the key-value cache by keeping, per layer and head, the best-ranked tokens.

    For every layer that is long enough and not skipped, each cached token gets
    a scalar score (computed from its key or its value vector), tokens are
    sorted by that score, and only the first `ceil(keep_ratio * seq_len)`
    tokens in sorted order are kept. Keys and values are reordered with the
    same indices, so the surviving entries are stored in score order, not in
    their original sequence order.

    Args:
        past_key_values: iterable with one (key, value) tuple per layer. Both
            tensors are expected to have shape
            (batch_size, num_heads, seq_len, head_dim).
        sort: which tensor the score is computed from. 'key' scores the keys;
            any other value scores the values. Default is 'value'.
        keep_ratio: fraction of tokens to keep in each pruned layer. Default
            is 1, which keeps all tokens (but still reorders them).
        prune_after: layers whose seq_len is smaller than this value are left
            untouched. Default is 2048.
        skip_layers: indices of layers that must not be pruned. Default is
            None, meaning no layer is skipped.
        descending: whether to sort scores in descending order. Default is
            False, i.e. the tokens with the smallest scores are kept.
        sort_by: the metric used to score tokens. Default is 'norm'.

    Returns:
        A list with one (key, value) tuple per layer. Skipped or too-short
        layers keep their original tuple.

    Raises:
        ValueError: if `sort_by` is not a supported metric.
    """
    # need a list not a tuple, because layers are replaced by index
    past_key_values = list(past_key_values)
    sort_fn = _select_sort_fn(sort_by)
    skipped = set(skip_layers) if skip_layers is not None else set()

    for layer, kv in enumerate(past_key_values):
        # Skipped layers are never written back, so do not waste time scoring
        # and sorting them (with sort_by='random' this also means no random
        # numbers are drawn for them).
        if layer in skipped:
            continue
        if kv[0].size(2) < prune_after:
            continue

        keys, values = kv
        token_dim = keys.shape[-1]
        tokens_to_keep = ceil(keep_ratio * keys.size(2))

        token_scalars = sort_fn(keys if sort == 'key' else values)
        # Only drop a trailing singleton left by a keepdim-style metric.
        # An unconditional squeeze(-1) would remove the sequence axis itself
        # when seq_len == 1 and break the expand below.
        if token_scalars.dim() == keys.dim():
            token_scalars = token_scalars.squeeze(-1)

        # Slice the indices before gathering: same result as gathering the
        # whole cache and slicing afterwards, but without the full-size copy.
        kept_indices = token_scalars.argsort(descending=descending, dim=-1)[..., :tokens_to_keep]
        gather_index = kept_indices.unsqueeze(-1).expand(-1, -1, -1, token_dim)

        past_key_values[layer] = (
            torch.gather(keys, dim=2, index=gather_index),
            torch.gather(values, dim=2, index=gather_index),
        )

    return past_key_values


def slide_kv_cache(past_kv_values, max_context_len: Optional[int]):
    """
    Cap the key-value cache of every layer at max_context_len tokens.

    For each layer whose cached sequence is longer than `max_context_len`, the
    keys and values are truncated to the FIRST `max_context_len` positions
    along the sequence axis (index 2).

    NOTE(review): the function name suggests a sliding window over the most
    recent tokens, but the slice keeps the leading entries. This looks
    deliberate when the cache has been sorted beforehand (best-ranked tokens
    first) - confirm before changing the slice.

    Args:
        past_kv_values: iterable with one (key, value) tuple per layer, or
            None. Both tensors are expected to have shape
            (batch_size, num_heads, seq_len, head_dim).
        max_context_len: the maximum number of tokens to keep for each layer.
            None disables truncation.

    Returns:
        None if `past_kv_values` is None, otherwise a list with one
        (key, value) tuple per layer.
    """
    # nothing cached yet (e.g. first forward pass without a cache)
    if past_kv_values is None:
        return None

    # need a list not a tuple, because layers are replaced by index
    past_kv_values = list(past_kv_values)

    # no limit configured: leave every layer as it is
    if max_context_len is None:
        return past_kv_values

    for layer, kv in enumerate(past_kv_values):
        keys, values = kv

        # already within the limit, nothing to drop
        if keys.size(2) <= max_context_len:
            continue

        past_kv_values[layer] = (
            keys[:, :, :max_context_len, :],
            values[:, :, :max_context_len, :],
        )

    return past_kv_values



############################################################################################################
# The following code is used to prune the KV cache of the decoder layers in the LLaMA model

class LlamaDecoderLayerKVCacheDrop(nn.Module):
    """
    Wrap a decoder layer and prune the key-value cache before calling it.

    Args:
        original_decoder_layer (nn.Module): The decoder layer being wrapped.
        layer_idx (int): Index of the wrapped layer in the model's layer list.
        keep_ratio (float, optional): Fraction of cached tokens to keep when
            pruning. Values >= 1.0 disable pruning. Defaults to 1.0.
        prune_after (int, optional): Minimum cached sequence length before
            pruning is applied. Defaults to 2048.
        sort_metric (str, optional): Metric used to rank cached tokens.
            Defaults to 'norm'.
        descending (bool, optional): Whether tokens are ranked in descending
            order of the metric. Defaults to False.
        max_context_len (int, optional): Maximum number of cached tokens kept
            per layer; None disables this cap. Defaults to 4096.
    """

    def __init__(self, 
                 original_decoder_layer, 
                 layer_idx, 
                 keep_ratio=1.0, 
                 prune_after=2048, 
                 sort_metric='norm',
                 descending: bool = False,
                 max_context_len: int = 4096
                 
                 ):
        super().__init__()
        self.sort_metric = sort_metric
        self.keep_ratio = keep_ratio
        self.prune_after = prune_after
        self.original_decoder_layer = original_decoder_layer
        self.layer_idx = layer_idx
        self.descending = descending
        self.max_context_len = max_context_len

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Prune and cap the cache, then run the wrapped decoder layer.

        Args:
            hidden_states (torch.Tensor): The input hidden states.
            attention_mask (torch.Tensor, optional): The attention mask. Defaults to None.
            position_ids (torch.LongTensor, optional): The position IDs. Defaults to None.
            past_key_value (optional): The cached key-value states, indexable
                per layer. Defaults to None.
            output_attentions (bool, optional): Whether to output attentions. Defaults to False.
            use_cache (bool, optional): Whether to use cache. Defaults to False.
            cache_position (torch.LongTensor, optional): The cache position. Defaults to None.
            **kwargs: Additional keyword arguments forwarded to the wrapped layer.

        Returns:
            Whatever the wrapped decoder layer returns.
        """
        if self.keep_ratio < 1.0 and past_key_value is not None and len(past_key_value) >= 1:
            # Skip every cached layer except the one at index layer_idx - 1,
            # so this wrapper prunes the entry of the preceding layer only.
            # For layer_idx == 0 that index is not present and nothing is pruned.
            # NOTE(review): presumably the preceding layer is targeted because
            # its cache already contains the current tokens - confirm.
            skip_layers = list(range(len(past_key_value)))
            if self.layer_idx - 1 in skip_layers:
                skip_layers.remove(self.layer_idx - 1)

            past_key_value = adjust_by_norm(
                past_key_value,
                sort='key',
                keep_ratio=self.keep_ratio,
                prune_after=self.prune_after,
                sort_by=self.sort_metric,
                skip_layers=skip_layers,
                descending=self.descending,
            )

        # Cap the cache length. Guarded because both the cache and the limit
        # are optional: without the guard a missing cache or a limit of None
        # makes slide_kv_cache raise a TypeError.
        if past_key_value is not None and self.max_context_len is not None:
            past_key_value = slide_kv_cache(past_key_value, max_context_len=self.max_context_len)

        # Call the module itself rather than .forward() so that hooks
        # registered on the wrapped layer still run.
        return self.original_decoder_layer(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )


def cast_to_kv_drop(
        model, 
        skip_layers: List[int],
        keep_ratio=1.0, 
        prune_after=2048,
        sort_descending: bool = False,
        sort_metric: Literal['norm', 'random', 'kurtosis', 'entropy', 'hoyer'] = 'norm',
        max_context_len: Optional[int] = None
        ):
    """
    Replace every decoder layer of `model` with a KV-cache-pruning wrapper.

    Each entry of `model.model.layers` is wrapped, in place, in a
    LlamaDecoderLayerKVCacheDrop that receives the original layer and its index.

    Args:
        model: object exposing its decoder layers as `model.model.layers`.
        skip_layers: indices whose wrapper is created with keep_ratio 1.0
            instead of `keep_ratio`.
        keep_ratio: keep ratio given to the wrappers of all other layers.
        prune_after: forwarded unchanged to every wrapper.
        sort_descending: forwarded to every wrapper as `descending`.
        sort_metric: forwarded unchanged to every wrapper.
        max_context_len: forwarded unchanged to every wrapper (including None).

    Returns:
        The same `model` object, with its layers replaced.
    """
    decoder_layers = model.model.layers

    for position, wrapped in enumerate(decoder_layers):
        # layers listed in skip_layers get a wrapper that keeps everything
        layer_keep_ratio = keep_ratio
        if position in skip_layers:
            layer_keep_ratio = 1.0

        decoder_layers[position] = LlamaDecoderLayerKVCacheDrop(
            wrapped,
            layer_idx=position,
            keep_ratio=layer_keep_ratio,
            prune_after=prune_after,
            sort_metric=sort_metric,
            descending=sort_descending,
            max_context_len=max_context_len,
        )

    return model




""""def merge_layer(past_kv_values_layer, merge_ratio: float, strategy: Literal['every_n'] = 'every_n'):
    # past_kv_values_layer is a tuple, where each tuple contains the key and value tensors.
    # both key and value have shape (batch_size, num_heads, seq_len, head_dim)

    if strategy == 'every_n':
        # we basically do a mean pooling every n tokens
        keys, values = past_kv_values_layer
        batch_size, num_heads, seq_len, head_dim = keys.size()

        # every n token in the sequence, merge
        n = merge_ratio * seq_len
        keys = einops.reduce(keys, 'b h (l n) d -> b h l d', 'mean', n=n)
        values = einops.reduce(values, 'b h (l n) d -> b h l d', 'mean', n=n)
    else:
        raise ValueError('Invalid strategy argument. Please choose from "every_n"')

    return (keys, values)


def adjust_by_merge(past_key_values, merge_ratio: float, strategy: Literal['every_n'] = 'every_n'):
    # past_key_values is a list of tuples, where each tuple contains the key and value tensors.
    # both key and value have shape (batch_size, num_heads, seq_len, head_dim)

    # iterate over the past key values
    for layer, kv in enumerate(past_key_values):
        past_key_values[layer] = merge_layer(kv, merge_ratio, strategy)

    return past_key_values
"""


