# LongHeads: Multi-Head Attention is Secretly a Long Context Processor
### Overview
LongHeads is a training-free framework for extending the context window of large language models (LLMs) to more than 32 times their original pre-training length. LongHeads runs in linear time, works with many LLMs that use relative positional encoding, and can be combined with popular extrapolation methods such as [Positional Interpolation (PI)](https://arxiv.org/abs/2306.15595) and [NTK-Dynamic RoPE](https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/).

![schemes](figures/longheads.png)

### 🚀Quick Start
#### Environment Setup
```bash
pip install -r requirements.txt
# We use flash-attn==2.3.6 (FlashAttention >= 2.3.6 is required)
pip install "flash-attn>=2.3.6" --no-build-isolation
```
#### Load model with LongHeads
```python
# Load a model with LongHeads
import torch
from modeling_longheads import LlamaForCausalLM

longheads_config = {
    # Chunk size used by LongHeads
    'window_size': 256,
    # Attention window length of LongHeads (atten_length should not exceed the model's pretrained length)
    'atten_length': 4096,
    # During the encoding phase, this parameter sets the length at which streaming encoding
    # of the long context with the chunk selection strategy begins
    'begin_selective_length': 4096,
    # Whether to offload the KV cache to CPU memory; if True, LongHeads can generate with 128k+ context length
    'cpu_offload': False,
    # Whether to use batch encoding to accelerate the encoding phase; if True, more memory is needed
    'batch_encoding': False,
    # Batch size hyperparameter for batch encoding
    'encoding_batch_size': 128,
}
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, **longheads_config)
```
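To get a feel for how `window_size` and `atten_length` relate, the sketch below works through the arithmetic for the default values. It assumes the context is split into fixed-size chunks of `window_size` tokens and that an attention window holds whole chunks. The helper names are hypothetical and are not part of `modeling_longheads`.

```python
# Illustrative arithmetic only; the helpers below are hypothetical and the
# chunk/window relationship is an assumption, not code from modeling_longheads.
window_size = 256      # chunk size from longheads_config
atten_length = 4096    # attention window length from longheads_config


def chunks_in_window(atten_length: int, window_size: int) -> int:
    """Number of whole chunks that fit in one attention window."""
    return atten_length // window_size


def chunks_in_context(context_length: int, window_size: int) -> int:
    """Number of chunks needed to cover a context (the last chunk may be partial)."""
    return -(-context_length // window_size)  # ceiling division


print(chunks_in_window(atten_length, window_size))   # 16
print(chunks_in_context(128 * 1024, window_size))    # 512
```

Under these assumptions, a 128k-token context (taken here as 131072 tokens) spans 512 chunks, of which at most 16 fit in one attention window.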
#### Run Inference Example
```bash 
python example.py
```

### Passkey Retrieval
#### LongHeads 128k
We successfully extend LLaMA-2-7b to **128k** with LongHeads without additional training and achieve 100% accuracy with **128k** context on the passkey retrieval task!
After offloading the KV cache to CPU, peak GPU memory usage is 26.51 GB and 44.48 GB for inference with 64k and 128k context, respectively.
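A back-of-the-envelope estimate shows why CPU offloading matters at these lengths. The sketch below computes the size of a full float16 KV cache for LLaMA-2-7b (32 layers, hidden size 4096), assuming that 64k and 128k mean 65536 and 131072 tokens and that the cache is stored unquantized. It is not how the repository measures memory, and `kv_cache_gib` is a hypothetical helper.

```python
# LLaMA-2-7b architecture values from the public model config.
NUM_LAYERS = 32
HIDDEN_SIZE = 4096      # 32 heads x 128 dimensions per head
BYTES_PER_VALUE = 2     # float16


def kv_cache_gib(num_tokens: int) -> float:
    """Size in GiB of keys and values across all layers for one sequence."""
    # Factor of 2 accounts for storing both keys and values.
    bytes_per_token = 2 * NUM_LAYERS * HIDDEN_SIZE * BYTES_PER_VALUE
    return num_tokens * bytes_per_token / 2**30


for tokens in (64 * 1024, 128 * 1024):
    print(f"{tokens} tokens: {kv_cache_gib(tokens):.1f} GiB")
# 65536 tokens: 32.0 GiB
# 131072 tokens: 64.0 GiB
```

Under these assumptions, the full cache at 128k is larger than the reported peak GPU usage, which is consistent with the cache being held in CPU memory.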

<div  align="center">    
    <img src="figures/passkey_128k.png" alt="passkey_128k" width="50%" height="50%">
</div>



#### Evaluation
```bash
bash passkey_retrieval/passkey_retrieval_script.sh
```
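For readers unfamiliar with the task, the sketch below builds a passkey retrieval prompt in the format commonly used for this benchmark: a random number is hidden among repeated filler sentences and the model is asked to recall it. The prompt wording, `build_passkey_prompt`, and `is_correct` are illustrative assumptions; the evaluation script in this repository may construct and score prompts differently.

```python
import random

FILLER = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."


def build_passkey_prompt(n_filler: int, seed: int = 0) -> tuple[str, str]:
    """Return (prompt, passkey) with the passkey hidden at a random depth."""
    rng = random.Random(seed)
    passkey = str(rng.randint(10000, 99999))
    position = rng.randint(0, n_filler)  # number of filler lines before the passkey
    parts = ["There is a pass key hidden in the text below. Find it and memorize it."]
    parts += [FILLER] * position
    parts.append(f"The pass key is {passkey}. Remember it. {passkey} is the pass key.")
    parts += [FILLER] * (n_filler - position)
    parts.append("What is the pass key? The pass key is")
    return "\n".join(parts), passkey


def is_correct(model_output: str, passkey: str) -> bool:
    """Count the answer as correct if the passkey appears in the model output."""
    return passkey in model_output


prompt, passkey = build_passkey_prompt(n_filler=100)
print(len(prompt.split()), "words in prompt")
```

Increasing `n_filler` lengthens the prompt, so the same helper can be used to probe different context lengths.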


