import sys
sys.path.append('/home/mila/x/xiyuan.zou/research/icl-mechanism-project')
from cache_based_inference.kv_cache import StartRecentKVCache
from cache_based_inference.pos_shift.modify_llama import enable_llama_pos_shift_attention
from cache_based_inference.pos_shift.modify_falcon import enable_falcon_pos_shift_attention
from cache_based_inference.pos_shift.modify_gpt_neox import enable_gpt_neox_pos_shift_attention

def enable_streaming_llm(model, cache_type, start_size=None, recent_size=None):
    """Patch ``model`` for its architecture and build a ``StartRecentKVCache``.

    The architecture is chosen by substring match on
    ``model.config.model_type``; the checks run in the order llama, mpt,
    gpt_neox, falcon and the first match wins.

    Args:
        model: Object exposing ``config.model_type`` (a string). For llama,
            gpt_neox and falcon it is handed to the matching
            ``enable_*_pos_shift_attention`` helper.
        cache_type: Accepted but not used by this function.
        start_size: Passed through unchanged to ``StartRecentKVCache``.
        recent_size: Passed through unchanged to ``StartRecentKVCache``.

    Returns:
        The newly constructed ``StartRecentKVCache``.

    Raises:
        ValueError: If ``model.config.model_type`` contains none of the
            supported architecture names.
    """
    arch = model.config.model_type

    # Each branch picks the attention patch (if any) and the
    # (k_seq_dim, v_seq_dim) pair for that architecture.
    # NOTE(review): the dims presumably index the sequence axis of the cached
    # key / value tensors -- confirm against StartRecentKVCache.
    if "llama" in arch:
        patch_attention, seq_dims = enable_llama_pos_shift_attention, (2, 2)
    elif "mpt" in arch:
        # No attention helper is invoked for mpt; only the cache dims differ.
        patch_attention, seq_dims = None, (3, 2)
    elif "gpt_neox" in arch:
        patch_attention, seq_dims = enable_gpt_neox_pos_shift_attention, (2, 2)
    elif "falcon" in arch:
        patch_attention, seq_dims = enable_falcon_pos_shift_attention, (1, 1)
    else:
        raise ValueError(f"got {model.config.model_type}")

    if patch_attention is not None:
        patch_attention(model)

    key_dim, value_dim = seq_dims
    return StartRecentKVCache(
        start_size=start_size,
        recent_size=recent_size,
        k_seq_dim=key_dim,
        v_seq_dim=value_dim,
    )

