# Dual Decoding for Rocket Inference Speed of Large Language Models

Dual Decoding is a decoding algorithm that generates text much faster than Speculative Decoding and Lookahead Decoding, and it requires no additional training.

## Install

### Install from pip

```shell
pip install dualdec
```

### Install from source code

- First, clone this repo as `dualdec`.
- Then, run the following commands:

```shell
cd dualdec
pip install -e .
```

## Quick Start

### An Example

- Here is an example of using Dual Decoding. Note: the current version of `dualdec` supports only Llama model inference. Import `LlamaForCausalLM` from `dualdec.models` instead of `transformers`.

- The following code compares the output of Dual Decoding with greedy decoding. If `dualdec` is installed correctly, it should print `True`.

```python
import torch
from dualdec import dualdec
from transformers import AutoTokenizer
from dualdec.models import LlamaForCausalLM

# Decoding hyperparameters (window_size must currently equal guess_set_size)
window_size = 20
guess_set_size = 20
lookahead_level = 7
gamma = 12

# Load the draft (small) model and the target model; replace "yourpath" with your checkpoint paths
small_model = LlamaForCausalLM.from_pretrained("yourpath", torch_dtype=torch.float16, device_map='cuda')
target_model = LlamaForCausalLM.from_pretrained("yourpath", torch_dtype=torch.float16, device_map='cuda')

tokenizer = AutoTokenizer.from_pretrained("yourpath")

prompt = "Please summarize the following paragraph. Officers searched properties in the Waterfront Park and Colonsay View areas of the city on Wednesday. Detectives said three firearms, ammunition and a five-figure sum of money were recovered. A 26-year-old man who was arrested and charged appeared at Edinburgh Sheriff Court on Thursday. Summary: "

# Tokenize the prompt and move the token ids to the GPU
input_ids = tokenizer(prompt, return_tensors='pt').to('cuda')['input_ids']

# Generate with Dual Decoding
dualdec_output = dualdec(input_ids, small_model, target_model, max_len=64, gamma=gamma, window_size=window_size, guess_set_size=guess_set_size, lookahead_level=lookahead_level)

# Generate the greedy reference with the target model alone
std_output = target_model.generate(input_ids, do_sample=False, min_length=64, max_length=64)

# Prints True when the first 64 token ids of both outputs match
print(dualdec_output[:,:64].equal(std_output[:,:64]))
```
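
The equality check above works because draft-and-verify decoding can reproduce greedy decoding exactly: the target model has the final say on every token. The following is a minimal toy sketch of that general idea in pure Python. It is not the `dualdec` implementation, and every name in it (`greedy_decode`, `draft_and_verify`, and the two toy models) is hypothetical and chosen only for illustration.

```python
from typing import Callable, List

# A toy "model" maps a token sequence to its greedy next token.
NextToken = Callable[[List[int]], int]


def greedy_decode(model: NextToken, prefix: List[int], max_len: int) -> List[int]:
    """Plain greedy decoding: one model call per generated token."""
    tokens = list(prefix)
    while len(tokens) < max_len:
        tokens.append(model(tokens))
    return tokens


def draft_and_verify(draft: NextToken, target: NextToken,
                     prefix: List[int], max_len: int, gamma: int) -> List[int]:
    """Toy greedy draft-and-verify loop (illustration only, not dualdec)."""
    if gamma < 1:
        raise ValueError("gamma must be at least 1")
    tokens = list(prefix)
    while len(tokens) < max_len:
        # 1) The draft model proposes up to `gamma` tokens.
        proposal: List[int] = []
        for _ in range(min(gamma, max_len - len(tokens))):
            proposal.append(draft(tokens + proposal))
        # 2) The target checks each proposed position. A real system would
        #    score all positions in one batched forward pass; here we loop.
        for guess in proposal:
            expected = target(tokens)
            tokens.append(expected)  # always keep the target's own token
            if guess != expected:
                break  # first mismatch: discard the rest of the proposal
    return tokens


def target_model(tokens: List[int]) -> int:
    # Toy deterministic target: depends only on the last token.
    return (tokens[-1] * 3 + 1) % 7


def draft_model(tokens: List[int]) -> int:
    # Toy draft: agrees with the target except after token 4.
    return 0 if tokens[-1] == 4 else target_model(tokens)


prefix = [1, 2]
reference = greedy_decode(target_model, prefix, max_len=16)
fast = draft_and_verify(draft_model, target_model, prefix, max_len=16, gamma=4)
print(fast == reference)  # prints True
```

In this sketch the draft model only affects how many tokens are accepted per verification step, never which tokens end up in the output, which is why the comparison with greedy decoding holds.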

### Using Dual Decoding 

- First, prepare two Hugging Face Transformers models (the target model and the draft model). Load both models with `LlamaForCausalLM` from `dualdec.models`.
- Then, import the generation function `dualdec` from `dualdec`.
- Call the function with the following parameters:
```python 
@torch.no_grad()
def dualdec(
    prefix: torch.Tensor,
    approx_model: torch.nn.Module,
    target_model: torch.nn.Module,
    ngram_cache: CacheEngine = None,
    max_len: int = 512,
    gamma: int = 4,
    window_size=20,
    guess_set_size=20,
    lookahead_level=7,
    eos_token_id=2,
) -> torch.Tensor:
    """
    Performs dual decoding with an approximate model and a target model to generate a sequence of tokens.

    Args:
        prefix (torch.Tensor): The initial sequence of tokens to start the generation from.
        approx_model (torch.nn.Module): The approximate model used for initial token generation. The model should support huggingface transformers model methods.
        target_model (torch.nn.Module): The target model used for refining the generated tokens. The model should support huggingface transformers model methods.
        ngram_cache (CacheEngine, optional): A cache engine for storing and retrieving n-gram predictions. Defaults to None, in which case a new cache engine is created.
        max_len (int, optional): The maximum length of the generated sequence. Defaults to 512.
        gamma (int, optional): The lookahead parameter for generation. Defaults to 4.
        window_size (int, optional): The window size used for n-gram generation. Defaults to 20. Currently, must be equal to guess_set_size.
        guess_set_size (int, optional): The size of the guess set for n-gram retrieving. Defaults to 20. Currently, must be equal to window_size.
        lookahead_level (int, optional): The level of lookahead decoding. Defaults to 7.
        eos_token_id (int, optional): The token id representing the end-of-sequence token. Defaults to 2. Should be given by tokenizer.eos_token_id.

    Returns:
        torch.Tensor: The generated sequence of tokens, including the initial prefix and any additional tokens generated by the function.
    """
```
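
Because `window_size` and `guess_set_size` must currently be equal, it can help to validate the settings before starting a long generation run. The helper below is a minimal sketch: `build_dualdec_kwargs` is a hypothetical name that is not part of `dualdec`, and it only collects the keyword arguments documented above.

```python
from typing import Any, Dict


def build_dualdec_kwargs(max_len: int = 512, gamma: int = 4,
                         window_size: int = 20, guess_set_size: int = 20,
                         lookahead_level: int = 7,
                         eos_token_id: int = 2) -> Dict[str, Any]:
    """Hypothetical helper: validate settings before calling `dualdec`."""
    # The docstring above states that these two values must currently match.
    if window_size != guess_set_size:
        raise ValueError(
            f"window_size ({window_size}) must equal "
            f"guess_set_size ({guess_set_size})"
        )
    return {
        "max_len": max_len,
        "gamma": gamma,
        "window_size": window_size,
        "guess_set_size": guess_set_size,
        "lookahead_level": lookahead_level,
        "eos_token_id": eos_token_id,
    }


kwargs = build_dualdec_kwargs(max_len=64, gamma=12)
print(kwargs["window_size"] == kwargs["guess_set_size"])
```

With the models and tokenizer from the example loaded, the result could be passed along as `dualdec(input_ids, small_model, target_model, **build_dualdec_kwargs(eos_token_id=tokenizer.eos_token_id))`.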

## Reproduction of Experimental Results

Please refer to [reproduction.md](reproduction.md).
