from model.ops import mask_lengths
import re
import torch
import numpy as np

from copy import deepcopy

def top_k_logits(logits, k):
    """Mask every logit that is not among the ``k`` largest of its row.

    Ranking is done along the last dimension, so the input may have any
    rank (a single vector, a batch of vectors, a batch of sequences, ...).

    Args:
        logits: Floating-point tensor of scores.
        k: Number of entries to keep along the last dimension. ``0`` means
            "no truncation". Values larger than the last dimension keep
            every entry.

    Returns:
        ``logits`` itself when ``k == 0``; otherwise a new tensor of the same
        shape where every entry strictly below the k-th largest value of its
        row is replaced by ``-1e10``. Entries tied with the k-th value are
        kept.
    """
    if k == 0:
        # no truncation
        return logits
    # torch.topk raises if k exceeds the ranked dimension; asking for more
    # entries than exist simply means "keep everything".
    k = min(k, logits.size(-1))
    values, _ = torch.topk(logits, k=k)
    # topk returns values in descending order, so the last one is the
    # threshold. The ellipsis keeps this valid for any number of leading dims.
    kth_largest = values[..., -1, None]
    return torch.where(
        logits < kth_largest,
        torch.ones_like(logits, dtype=logits.dtype) * -1e10,
        logits,
    )


def top_p_logits(logits, p):
    """Nucleus (top-p) filtering along the last dimension.

    Each row is sorted in descending order and the cumulative softmax mass
    is computed. The threshold is the logit of the last sorted entry whose
    cumulative mass is still ``<= p`` (or the largest logit if even that one
    exceeds ``p``, so at least one entry always survives).

    Args:
        logits: Floating-point tensor of scores; any rank >= 1.
        p: Cumulative probability mass used to pick the threshold.

    Returns:
        A new tensor of the same shape where every entry strictly below its
        row's threshold is replaced by ``-1e10``.
    """
    sorted_logits, _ = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    # Sorted-order position of the last entry still inside the nucleus.
    # Counting gives "how many fit"; minus one turns that into an index, and
    # the clamp covers the case where not even the top entry fits.
    cutoff = torch.sum(cumulative_probs <= p, dim=-1, keepdim=True) - 1
    cutoff = cutoff.clamp(min=0)
    # keepdim + gather leaves a trailing singleton dim, so the threshold
    # broadcasts against logits whatever the number of leading dims.
    min_values = sorted_logits.gather(-1, cutoff)
    return torch.where(
        logits < min_values,
        torch.ones_like(logits) * -1e10,
        logits,
    )


def gathered_input(indexed):
    """Right-pad a 2-D batch of indices with one zero column.

    Args:
        indexed: 2-D tensor; its two sizes are returned as the batch size
            and the sequence length.

    Returns:
        ``(batch_size, seq_len, (padded, lens))`` where ``padded`` is
        ``indexed`` with an extra column of zeros appended on the right and
        ``lens`` is a LongTensor of ``batch_size`` entries, each equal to
        ``seq_len + 1``. Both tensors live on the device of ``indexed``.
    """
    batch_size, seq_len = indexed.size()
    target = indexed.device
    zero_column = torch.zeros(batch_size, 1, dtype=torch.long, device=target)
    padded = torch.cat((indexed, zero_column), dim=1)
    padded_lens = torch.full(
        (batch_size,), seq_len + 1, dtype=torch.long, device=target
    )
    return batch_size, seq_len, (padded, padded_lens)


def divided_input(indexed):
    """Build a single-step input from a ``(title, content, ...)`` tuple.

    Only the last column of ``content`` is kept; a zero column is appended
    to it and a fresh length vector filled with ``2`` is created. The fourth
    element of ``indexed`` is unpacked but not used.

    Args:
        indexed: 4-tuple ``(title, content, title_len, context_len)`` of
            tensors; ``title`` and ``content`` must be 2-D.

    Returns:
        ``(batch_size, step_len, (title, step_content, title_len, lens))``
        where ``batch_size`` comes from ``title``, ``step_len`` is the width
        of the kept content slice, ``step_content`` is that slice followed by
        a zero column, and ``lens`` is a LongTensor of twos. ``title`` and
        ``title_len`` are passed through unchanged.
    """
    title, content, title_len, _ = indexed
    target = title.device
    batch_size, _ = title.size()
    last_token = content[:, -1:]
    _, step_len = last_token.size()
    # Last token plus the appended zero -> two positions per row.
    content_lens = torch.full((batch_size,), 2, dtype=torch.long, device=target)
    zero_column = torch.zeros(batch_size, 1, dtype=torch.long, device=target)
    step_content = torch.cat((last_token, zero_column), dim=1)
    return batch_size, step_len, (title, step_content, title_len, content_lens)


def get_mem(model, inp):
    """Run the model over everything but the last position to obtain memory.

    Two input layouts are supported:

    * a tuple ``(title, context, title_len, context_len)``: the last column
      of ``context`` is dropped, ``context_len`` is decremented (clamped at
      zero), and ``model.compute_hidden`` is called with a single 5-tuple
      whose last element is ``None``;
    * a single 2-D tensor: its last column is dropped and
      ``model.compute_hidden`` is called with three positional arguments
      ``(prefix, None, lens)`` where ``lens`` holds ``seq_len - 1`` per row.

    Gradients are disabled for the call.

    Args:
        model: Object exposing ``compute_hidden``; the second element of its
            return value is taken as the memory.
        inp: Tuple or tensor as described above.

    Returns:
        ``(mem, inp)`` with ``inp`` returned unmodified.
    """
    with torch.no_grad():
        if not isinstance(inp, tuple):
            batch_size, seq_len = inp.size()
            prefix_lens = torch.full(
                (batch_size,), seq_len - 1, dtype=torch.long, device=inp.device
            )
            _, mem = model.compute_hidden(inp[:, :-1], None, prefix_lens)
            return mem, inp

        title, context, title_len, context_len = inp
        hidden_args = (
            title,
            context[:, :-1],
            title_len,
            torch.clamp_min(context_len - 1, 0),
            None,
        )
        _, mem = model.compute_hidden(hidden_args)
    return mem, inp


def sample(model, lengths, inp, top_w, temparature, experimental_loss, sampling_mode=0):
    """Autoregressively sample ``lengths`` tokens from ``model``.

    The memory for the prompt is computed once with ``get_mem``; afterwards
    each step feeds only the most recent token (plus a zero pad, see
    ``gathered_input`` / ``divided_input``) together with the accumulated
    memory, filters the logits, applies the temperature and draws one token
    per row.

    Args:
        model: Callable model. It is invoked as ``model(step_input + (None,
            mem))``, or as ``model.sampling(step_input + (mem,
            sampling_mode, top_w))`` when ``experimental_loss`` is truthy.
            Both must return ``(logits, new_mem)``.
        lengths: Number of tokens to generate. ``0`` returns the prompt
            with empty probability rows.
        inp: Either a 2-D tensor of token ids, or a tuple
            ``(title, context, title_len, context_len)``.
        top_w: Filtering threshold. An ``int`` selects top-k filtering,
            anything else is passed to nucleus (top-p) filtering.
        temparature: Divisor applied to the filtered logits before softmax.
        experimental_loss: Selects ``model.sampling`` over ``model(...)``.
        sampling_mode: Forwarded to ``model.sampling`` only.

    Returns:
        ``(tokens, probs)`` as nested Python lists: ``tokens`` is the prompt
        with the sampled ids appended per row, ``probs`` holds, per row, the
        probability each sampled id had under the filtered,
        temperature-scaled distribution.
    """
    # An int threshold means "keep k entries"; anything else is a top-p mass.
    filter_logits = top_k_logits if isinstance(top_w, int) else top_p_logits
    istuple = isinstance(inp, tuple)

    mem, inp = get_mem(model, inp)
    mem = [m.to(torch.float) for m in mem]

    # Sampled ids are appended to a tensor, so for tuple input start from the
    # context tensor rather than the tuple itself (torch.cat cannot
    # concatenate a tuple with a tensor).
    # NOTE(review): assumes the context is the sequence callers want
    # extended — confirm against call sites.
    res = inp[1] if istuple else inp
    step_probs = []

    # NOTE: generation always runs for exactly `lengths` steps; there is no
    # early stop on an end-of-sequence id.
    with torch.no_grad():
        for _ in range(lengths):
            if istuple:
                bs, l, inp = divided_input(inp)
            else:
                bs, l, inp = gathered_input(inp[:, -1:])

            if experimental_loss:
                logits, new_mem = model.sampling(inp + (mem, sampling_mode, top_w))
            else:
                logits, new_mem = model(inp + (None, mem))

            # Extend each memory with the new step's states; the final
            # position of new_mem is dropped before appending.
            mem = [torch.cat([old, new[:, :-1]], 1) for old, new in zip(mem, new_mem)]

            # Filtering happens before temperature scaling (order kept from
            # the original implementation).
            logits = filter_logits(logits, top_w)
            logits = logits.view(bs, l, -1)[:, -1, :] / temparature

            token_probs = torch.softmax(logits, -1)
            sampled = torch.multinomial(token_probs, 1)
            res = torch.cat([res, sampled], 1)
            # Probability of the id actually drawn, one column per step.
            step_probs.append(token_probs.gather(1, sampled))

            if istuple:
                title, _, title_len, content_lens = inp
                inp = (title, sampled, title_len, content_lens)
            else:
                inp = sampled

    if step_probs:
        probs = torch.cat(step_probs, 1).tolist()
    else:
        # Nothing was generated: one empty probability row per sequence.
        probs = [[] for _ in range(res.size(0))]
    return res.tolist(), probs

