import torch
from tqdm import tqdm
import torch

from dualdec.kv_cache_model import KVCacheModelSimple

def _near_match_mask(draft: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return a bool mask of draft positions that match the target exactly or with a one-step shift.

    Position ``i`` is marked when ``target[i]`` equals ``draft[i]``,
    ``draft[i + 1]`` or ``draft[i - 1]``. Both inputs are token-id tensors
    of the same ``(batch, k)`` shape. The padding is built with
    ``zeros_like``, so it follows the inputs' device and batch size.
    """
    mask = draft == target
    # target[i] == draft[i + 1]; the last column has no right neighbour.
    left = torch.zeros_like(mask)
    left[:, :-1] = target[:, :-1] == draft[:, 1:]
    # target[i] == draft[i - 1]; the first column has no left neighbour.
    right = torch.zeros_like(mask)
    right[:, 1:] = target[:, 1:] == draft[:, :-1]
    return mask | left | right


@torch.no_grad()
def speculative_greedy_sampling(prefix: torch.Tensor, approx_model: torch.nn.Module, target_model: torch.nn.Module,
                                max_len: int, gamma: int = 4, eos_token_id: int = 2, profiling=False) -> torch.Tensor:
    """Greedy speculative decoding.

    Each round drafts ``gamma`` tokens with ``approx_model``, then keeps the
    longest run of drafted tokens that agree with ``target_model``'s argmax.
    It then appends one token taken from the target: the target's
    replacement at the first mismatch, or its next-token prediction when
    every draft was accepted.

    Args:
        prefix: prompt token ids, indexed as ``(batch, seq)``. The per-token
            ``if`` comparisons below only work for a batch size of 1.
        approx_model: draft model, wrapped in ``KVCacheModelSimple``.
        target_model: verifier model, wrapped in ``KVCacheModelSimple``.
        max_len: maximum number of new tokens appended to ``prefix``.
        gamma: number of draft tokens proposed per round.
        eos_token_id: generation stops at the first occurrence of this id
            among the newly produced tokens (the prompt is not checked).
        profiling: when True, also collect and return acceptance statistics.

    Returns:
        The sequence (prompt included), truncated to at most
        ``prefix.shape[1] + max_len`` tokens and just after the first EOS.
        When ``profiling`` is True, the tuple
        ``(seq, acc_len, spec_time, cor_len, cor_len_b1, sub_match, consq_0, cont_num)``
        is returned instead. Its fields are:

        * ``acc_len``: number of drafted tokens kept.
        * ``spec_time``: number of draft/verify rounds.
        * ``cor_len``: number of drafted tokens equal to the target argmax at
          the same position, counted past the first mismatch too.
        * ``cor_len_b1``: the same count, but also allowing a one-position
          shift.
        * ``sub_match``, ``consq_0``, ``cont_num``: placeholders that are
          never updated.
    """
    seq_len = prefix.shape[1]
    T = seq_len + max_len

    approx_model_cache = KVCacheModelSimple(approx_model)
    target_model_cache = KVCacheModelSimple(target_model)

    resample_count = 0
    target_sample_count = 0
    accepted_count = 0

    # Exclusive end index of the output; set once, at the first EOS produced.
    end_pos = None

    if profiling:
        spec_time = 0
        acc_len = 0
        cor_len = 0
        cor_len_b1 = 0
        # Placeholders kept so the profiling return tuple is unchanged.
        sub_match = []
        consq_0 = 0
        cont_num = [0] * 20

    while prefix.shape[1] < T:
        prefix_len = prefix.shape[1]

        x = approx_model_cache.generate(prefix, gamma)
        # Only the target's per-position history (_prob_history) over x is
        # used below; the sequence it returns is discarded.
        _ = target_model_cache.generate(x, 1)

        if profiling:
            drafted = x[:, prefix_len:]
            # Target argmax at the positions that predict each drafted token.
            expected = target_model_cache._prob_history[:, prefix_len - 1:-1, :].argmax(dim=-1)
            cor_len += (drafted == expected).sum()
            cor_len_b1 += _near_match_mask(drafted, expected).sum()

        # Index of the last kept token: the final draft position unless a
        # mismatch is found below.
        n = prefix_len + gamma - 1
        for i in range(gamma):
            j = x[:, prefix_len + i]
            t_tok = target_model_cache._prob_history[:, prefix_len + i - 1, :].argmax(dim=-1, keepdim=True)
            if t_tok != j:
                # reject this draft token and everything after it
                n = prefix_len + i - 1
                break
            if end_pos is None and j == eos_token_id:
                end_pos = prefix_len + i + 1
            accepted_count += 1

        prefix = x[:, :n + 1]
        if profiling:
            spec_time += 1
            acc_len += prefix.shape[-1] - prefix_len

        # Presumably truncates the draft cache to the kept n + 1 positions.
        approx_model_cache.rollback(n + 1)

        if n < prefix_len + gamma - 1:
            # A draft token was rejected: use the target's argmax at position n
            # as the replacement token.
            t = target_model_cache._prob_history[:, n, :].argmax(dim=-1, keepdim=True)
            resample_count += 1
            target_model_cache.rollback(n + 1)
        else:
            # All drafts accepted: take the target's last-position prediction.
            t = target_model_cache._prob_history[:, -1, :].argmax(dim=-1, keepdim=True)
            target_sample_count += 1

        if end_pos is None and t == eos_token_id:
            end_pos = n + 2

        prefix = torch.cat((prefix, t), dim=1)

        if end_pos is not None:
            prefix = prefix[:, :end_pos]
            break

    if profiling:
        return prefix[:, :T], acc_len, spec_time, cor_len, cor_len_b1, sub_match, consq_0, cont_num
    return prefix[:, :T]