import numpy as np
import torch
from scipy.optimize import linear_sum_assignment

def sinkhorn_iter(suppliers, demanders, cost, epsilon, n_iterations):
    '''
    Entropy-regularised optimal transport via log-domain Sinkhorn updates.

    # params
        * `suppliers`:
            * 1-D tensor of length M, the mass offered by each supplier (row marginal)
        * `demanders`:
            * 1-D tensor of length N, the mass required by each demander (column marginal)
        * `cost`:
            * M×N tensor, `cost[m, n]` is the price of moving one unit from supplier m to demander n
        * `epsilon`:
            * regularisation strength; the plan is `exp((-cost + f_m + g_n) / epsilon)`
        * `n_iterations`:
            * number of (column step, row step) potential-update rounds
    # return
        * `solution`:
            * M×N transport plan built from the final potentials
        * `total_cost`:
            * `sum(solution * cost)`
    '''
    def log_plan(f, g):
        # Log of the transport plan for row potentials f and column potentials g.
        return (-cost + f.unsqueeze(-1) + g.unsqueeze(-2)) / epsilon

    # The 1e-8 keeps log() finite when a marginal entry is zero.
    log_supply = torch.log(suppliers + 1e-8)
    log_demand = torch.log(demanders + 1e-8)

    f = torch.ones_like(suppliers)
    g = torch.ones_like(demanders)
    for _ in range(n_iterations):
        # Column step: rescale so the plan's column sums match `demanders`.
        g += epsilon * (log_demand - torch.logsumexp(log_plan(f, g).T, dim=-1))
        # Row step: rescale so the plan's row sums match `suppliers`.
        f += epsilon * (log_supply - torch.logsumexp(log_plan(f, g), dim=-1))

    plan = torch.exp(log_plan(f, g))
    return plan, (plan * cost).sum()

def sinkhorn_iteration(suppliers, demanders, cost, epsilon, n_iterations):
    '''
    Entropy-regularised optimal transport with plain (non-log-domain)
    Sinkhorn-Knopp scaling, in NumPy.

    # params
        * `suppliers`:
            * array-like of length M, the goods held by each of the M suppliers
        * `demanders`:
            * array-like of length N, the goods required by each of the N demanders
        * `cost`:
            * M×N matrix, the cost of shipping one unit from the m-th supplier to the n-th demander
        * `epsilon`:
            * regularisation strength; smaller values approach the exact optimum, but
              `exp(-cost / epsilon)` can then underflow/overflow and yield `nan`/`inf`
        * `n_iterations`:
            * number of scaling iterations (0 returns the bare kernel `exp(-cost / epsilon)`)
    # return
        * `solution`:
            * M×N matrix, the plan that ships all suppliers' goods and meets every demand
              at (approximately) minimal total cost
        * `total_cost`:
            * total cost of that plan, `sum(cost * solution)`
    # raises
        * `AssertionError` if `sum(suppliers)` and `sum(demanders)` differ by 1e-5 or more
    # explanation
        * the transported amounts may be fractional
    '''
    suppliers = np.asarray(suppliers)
    demanders = np.asarray(demanders)
    cost = np.asarray(cost)
    assert -1e-5 < np.sum(suppliers) - np.sum(demanders) < 1e-5,\
           "sum(suppliers) should equals sum(demanders), while sum(suppliers) == %f and sum(demanders) == %f." % (np.sum(suppliers), np.sum(demanders))

    # Gibbs kernel. It must use `cost` itself (not `cost**2`) so the minimised
    # objective is the one reported in `total_cost` (and the one `sinkhorn_iter` uses).
    K = np.exp(-cost / epsilon)
    # Initialise both scalings so zero iterations cannot leave `u` undefined.
    u = np.ones(suppliers.shape[0])
    v = np.ones(demanders.shape[0])

    for _ in range(n_iterations):
        u = suppliers / K.dot(v)    # match row sums to `suppliers`
        v = demanders / K.T.dot(u)  # match column sums to `demanders`

    # diag(u) @ K @ diag(v), via broadcasting instead of dense diagonal matrices.
    solution = u[:, None] * K * v[None, :]
    total_cost = np.sum(cost * solution)

    return solution, total_cost

def hungarian_assign(decode_dist, target, ignore_indices, random=False):
    '''
    One-to-one matching of predicted keyphrases to target keyphrases with the
    Hungarian algorithm (`scipy.optimize.linear_sum_assignment`, maximising score).

    :param decode_dist: (batch_size, max_kp_num, kp_len, vocab_size) per-step token
        distributions of the predictions; its second dim must equal target's max_kp_num
    :param target: (batch_size, max_kp_num, kp_len) target token indices
    :param ignore_indices: token indices whose target positions add nothing to the score
    :param random: if True, skip scoring and return a random permutation per item
    :return: (reorder_rows, reorder_cols)
        * reorder_rows: (batch_size, 1) tensor [[0], [1], ..., [batch_size - 1]]
        * reorder_cols: (batch_size, max_kp_num) numpy array; reorder_cols[b, p] is the
          target matched to prediction p (the `col_ind` of `linear_sum_assignment`)
    '''
    batch_size, max_kp_num, kp_len = target.size()
    reorder_rows = torch.arange(batch_size)[..., None]
    if random:
        reorder_cols = np.concatenate([np.random.permutation(max_kp_num).reshape(1, -1) for _ in range(batch_size)], axis=0)
        return reorder_rows, reorder_cols

    # Target positions holding an ignored token contribute nothing to the score.
    score_mask = torch.zeros_like(target, dtype=torch.bool)
    for i in ignore_indices:
        score_mask |= (target == i)
    score_mask = score_mask.unsqueeze(1)  # (batch_size, 1, max_kp_num, kp_len)

    # index[b, p, l, t] = target[b, t, l], so a single gather yields
    # gathered[b, p, l, t] = decode_dist[b, p, l, target[b, t, l]].
    index = target.long().transpose(1, 2).unsqueeze(1).expand(-1, max_kp_num, -1, -1)
    gathered = decode_dist.gather(-1, index)
    score = gathered.permute(0, 1, 3, 2)  # (batch_size, pred, trg, kp_len)
    score = score.masked_fill(score_mask, 0).sum(-1)  # (batch_size, max_kp_num, max_kp_num)

    # Move all scores to the host once, then solve each item's assignment.
    score_np = score.detach().cpu().numpy()
    reorder_cols = np.concatenate(
        [linear_sum_assignment(score_np[b], maximize=True)[1].reshape(1, -1) for b in range(batch_size)],
        axis=0)
    return reorder_rows, reorder_cols

def _ot_score_matrix(dist, trg):
    '''Return `score[p, t] = sum_l dist[p, l, trg[t, l]]` as a (pred_kp_num, trg_kp_num) tensor.

    `dist` is one item's (pred_kp_num, kp_len, vocab_size) distribution and `trg` its
    (trg_kp_num, steps) token indices; only the decoding steps present in both are summed.
    '''
    steps = min(dist.shape[1], trg.shape[1])
    # index[p, l, t] = trg[t, l]
    index = trg[:, :steps].t().long().to(dist.device).unsqueeze(0).expand(dist.shape[0], -1, -1)
    return dist[:, :steps].gather(-1, index).sum(1)


def _topk_mass(score_nor, top_candidates):
    '''Sum of each target's `top_candidates` largest normalised scores, as a NumPy array.

    `top_candidates` is clipped to the number of predictions so `torch.topk` cannot fail.
    '''
    n_candidates = min(top_candidates, score_nor.shape[1])
    topk_score_nor, _ = torch.topk(score_nor, n_candidates, dim=1)
    return torch.sum(topk_score_nor, dim=1).detach().cpu().numpy()


def _round_robin_decrement(k, order, surplus):
    '''Remove `surplus` units from the NumPy array `k` in place, one unit at a time.

    Cycles through the indices in `order` and lowers an entry only while it is >= 2,
    so no visited entry drops below 1. Raises `RuntimeError` when the entries in
    `order` cannot absorb `surplus` (the cycle would otherwise never terminate).
    '''
    if surplus <= 0:
        return k
    order = [int(i) for i in order]
    capacity = sum(max(int(k[i]) - 1, 0) for i in order)
    if capacity < surplus:
        raise RuntimeError(
            "k assignment is infeasible: need to remove %d unit(s) but only %d are "
            "reducible (k=%s)" % (surplus, capacity, k.tolist()))
    pos = 0
    while surplus > 0:
        idx = order[pos]
        if k[idx] >= 2:
            k[idx] -= 1
            surplus -= 1
        pos = (pos + 1) % len(order)
    return k


def _normal_k(score_nor, pred_kp_num, has_null_b, top_candidates, device):
    '''Per-target supply for k_strategy="normal".

    Without a null target every target supplies exactly 1 unit. With one (the last
    row), each real target supplies ceil of its top-candidate mass and the null
    target receives whatever remains of `pred_kp_num` (never below 0).
    '''
    trg_kp_num = score_nor.shape[0]
    if not has_null_b:
        return torch.ones((trg_kp_num, )).to(device)
    sum_topk_score_nor = _topk_mass(score_nor, top_candidates)
    k = np.ceil(sum_topk_score_nor)
    left_k = int(pred_kp_num - np.sum(k[:-1]))
    if left_k >= 0:
        k[-1] = left_k
    else:
        # The real targets already claim more than pred_kp_num units, so the null
        # target gets none and the real targets whose ceil() rounded up the most
        # give units back, round-robin, until the totals match.
        k[-1] = 0.0
        k_sorted_arg = (k - sum_topk_score_nor)[:-1].argsort()[::-1]
        _round_robin_decrement(k, k_sorted_arg, -left_k)
    return torch.from_numpy(k).to(device)


def _null_protection_k(score_nor, pred_kp_num, top_candidates, device):
    '''Per-target supply for k_strategy="null_protection".

    Every target supplies ceil of its top-candidate mass; any excess over
    `pred_kp_num` is trimmed round-robin, starting from the largest supplies.
    '''
    k = np.ceil(_topk_mass(score_nor, top_candidates))
    surplus_k = int(np.sum(k) - pred_kp_num)
    k_sorted_arg = k.argsort()[::-1]
    _round_robin_decrement(k, k_sorted_arg, surplus_k)
    return torch.from_numpy(k).to(device)


def optimal_transport_assign(decode_dist, target, assign_steps, epsilon=1e-3, n_iterations=100, has_null=None, k_strategy="", top_candidates=10, temperature=1):
    '''
    Assign every predicted keyphrase to a target keyphrase with entropy-regularised
    optimal transport (one target may receive several predictions).

    # params
        * `decode_dist`:
            * (batch_size, pred_kp_num, kp_len, vocab_size)
            * each word is a probability distribution with a sample size of `vocab_size`
        * `target`:
            * (batch_size, trg_kp_num, kp_len), or a list of `batch_size` tensors of shape (trg_kp_num_b, kp_len)
            * each word is an index of vocabulary, ranging from 0 to `vocab_size` - 1
        * `assign_steps`:
            * only the first `assign_steps` decoding steps (and at most `kp_len`) are scored
        * `epsilon`:
            * param for Sinkhorn Iteration (the smaller `epsilon`, the more precise the matching result;
              however too small an `epsilon` may generate `nan`)
        * `n_iterations`:
            * param for Sinkhorn Iteration (number of iterations)
        * `has_null`:
            * list, `len(has_null) == batch_size`, or `None` (treated as all `False`)
            * which items have a background (null) keyphrase appended as their LAST target row;
              only read by `k_strategy="normal"`
        * `k_strategy`:
            * `"normal"` or `"null_protection"`; the empty default is rejected, so pass it explicitly
            * `"normal"`: each target supplies 1 unit; for items with a null target, real targets supply
              ceil(top-candidate mass) and the null row takes the rest of `pred_kp_num`
            * `"null_protection"`: each target supplies ceil(top-candidate mass), trimmed so the total
              does not exceed `pred_kp_num`
        * `top_candidates`:
            * how many of a target's largest normalised scores are summed to estimate its supply
              (clipped to `pred_kp_num`)
        * `temperature`:
            * scores are raised to `1 / temperature` before normalisation
    # return
        * `scores`:
            * (batch_size, pred_kp_num), the raw score of each prediction against its matched target
        * `rematch_cols`:
            * (batch_size, pred_kp_num) long tensor; prediction p of item b is matched to target `rematch_cols[b, p]`
        * `matching_scale`:
            * (batch_size, pred_kp_num), the temperature-sharpened matched score, rescaled so each row
              sums to `pred_kp_num`
    # raises
        * `AssertionError` for an unknown `k_strategy`, an item without targets, or a negative null supply
        * `RuntimeError` when the supplies cannot be trimmed to fit (more real targets than predictions)
    '''
    assert k_strategy in ["null_protection", "normal"], \
        "k_strategy can only be `null_protection` or `normal`, while k is %s!" % k_strategy

    batch_size, pred_kp_num, _, _ = decode_dist.shape
    device = decode_dist.device
    # `has_null=None` means no item carries an appended null target.
    null_flags = [False] * batch_size if has_null is None else has_null
    rematch_cols = []
    matching_scale = []

    scores = decode_dist.new_zeros(batch_size, pred_kp_num)
    for b in range(batch_size):
        trg = target[b]
        trg_kp_num = trg.shape[0]
        assert trg_kp_num > 0

        # score[p, t]: over the first `assign_steps` decoding steps, the summed
        # probability that prediction p gives to target t's token at each step.
        score = _ot_score_matrix(decode_dist[b], trg[:, :assign_steps])  # (pred_kp_num, trg_kp_num)

        # Temperature-sharpened scores, normalised over targets for each prediction.
        score_t = score.T ** (1 / temperature)
        score_nor = score_t / torch.sum(score_t, dim=0)

        # Dynamic k strategy: how many predictions each target may absorb.
        if k_strategy == "normal":
            k = _normal_k(score_nor, pred_kp_num, null_flags[b], top_candidates, device)
            if (k[:-1] == 0).any():  # debug, check `k`
                print("expected `k` are assigned as zero! \n[%d]: k: " % b, k)
        elif k_strategy == "null_protection":
            k = _null_protection_k(score_nor, pred_kp_num, top_candidates, device)
            if (k == 0).any():  # debug, check `k`
                print("expected `k` are assigned as zero! \n[%d]: k: " % b, k)
        else:
            raise NotImplementedError

        suppliers = torch.ones((trg_kp_num, )).to(device) * k
        demanders = torch.ones((pred_kp_num, )).to(device)
        assert suppliers[-1] >= 0, \
               "background < 0! \npred_kp_num: {} \nk: {} \n".format(pred_kp_num, k)
        # Only the argmax of the plan is used below, so Sinkhorn runs on a
        # detached cost and builds no autograd graph.
        solution, _ = sinkhorn_iter(suppliers, demanders, -score_nor.detach(), epsilon, n_iterations)

        # For every prediction, the target that receives most of its mass.
        cols = torch.argmax(solution, dim=0).reshape(1, -1)
        rematch_cols.append(cols)
        temp = torch.gather(score_t, 0, cols)
        scores[b] = torch.gather(score.T, 0, cols).squeeze(0)
        matching_scale.append(temp.reshape(1, -1) * (pred_kp_num / torch.sum(temp)))

    rematch_cols = torch.cat(rematch_cols, dim=0)
    matching_scale = torch.cat(matching_scale, dim=0)

    return scores, rematch_cols, matching_scale


if __name__ == '__main__':
    # Smoke demo for the assignment helpers.
    #   decoder_dist: (batch_size, pred_kp_num, kp_len, vocab_size)
    #   target: list with one (trg_kp_num, kp_len) token-index tensor per batch item

    def print_res(assign_method, decoder_dist, target):
        """Run one assignment method and print the targets lined up with the predictions."""
        print('--------------------------------')
        if assign_method == "hungarian":
            print('|       %s_assign       |' % assign_method)
            # hungarian_assign needs a dense (batch, max_kp_num, kp_len) target with
            # exactly as many rows as there are predictions.
            _, reorder_cols = hungarian_assign(decoder_dist, torch.stack(target), [0])
        else:  # optimal_transport
            print('|   %s_assign   |' % assign_method)
            # The last target row of every item is the null/background keyphrase.
            _, reorder_cols, _ = optimal_transport_assign(
                decoder_dist, target, assign_steps=2, k_strategy="normal",
                has_null=[True] * len(target))
        print('--------------------------------')
        print("reorder_cols: \n", reorder_cols)
        print(" ------------------------------ ")
        # Line every prediction up with the target keyphrase it was matched to.
        new_target = torch.stack([target[b][reorder_cols[b]] for b in range(len(target))])
        print("new_target: \n", new_target)

    torch.manual_seed(2343)

    # One batch item with 20 predicted keyphrases of 2 words each; every word is a
    # probability distribution over a vocabulary of 4 tokens.
    decoder_dist = torch.rand((1, 20, 2, 4)).softmax(-1)
    # 3 random ground-truth keyphrases plus an all-zero null/background row appended
    # last, which `has_null` tells the assignment to expect.
    target = [torch.cat([torch.randint(4, (3, 2)), torch.zeros((1, 2), dtype=torch.long)], dim=0)]

    print_res("optimal_transport", decoder_dist, target)
