
from typing import List
import torch

def _padded_row(values: List[int], length: int, fill: int, like: torch.Tensor) -> torch.Tensor:
    """Return a 1-D tensor of size ``length`` holding ``values`` followed by ``fill``.

    The result uses ``like``'s dtype and device, so it can be assigned into a row of
    ``like`` without a float round-trip or an implicit cross-device copy.
    """
    row = torch.full((length,), fill, dtype=like.dtype, device=like.device)
    if values:
        row[:len(values)] = torch.tensor(values, dtype=like.dtype, device=like.device)
    return row


def inject_knowledge(input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor,
                     masked_lm_labels: torch.Tensor, next_sentence_label: torch.Tensor,
                     knowledge_sentences: List[List[int]], knowledge_labels: List[List[int]],
                     cls_id: int, sep_id: int, pad_id: int):
    """Overwrite segment A or segment B of some batch rows with knowledge sentences.

    Every row of ``input_ids`` is parsed as ``[CLS] A [SEP] B [SEP] <pad>...``.
    Parsing uses the first ``sep_id`` and the count of non-``pad_id`` tokens, so
    padding is assumed to sit only at the end of a row.

    Knowledge sentences are placed greedily. They are processed shortest first,
    and each takes the shortest segment that is long enough and belongs to a row
    not used yet. At most one knowledge sentence is injected per row.

    For every modified row this function:

    * rewrites ``input_ids``, ``attention_mask``, ``token_type_ids`` and
      ``masked_lm_labels``;
    * sets ``next_sentence_label`` to 1.

    The input tensors are modified in place and also returned.

    Args:
        input_ids: token ids indexed ``[row, position]``. Presumably
            ``(batch, seq_len)``; the code only relies on ``size(1)`` being the
            row length.
        attention_mask: same indexing as ``input_ids``.
        token_type_ids: same indexing as ``input_ids``.
        masked_lm_labels: MLM labels with the same indexing. -1 is written at
            special-token and padding positions.
        next_sentence_label: one entry per row.
        knowledge_sentences: token ids of each knowledge sentence, without
            special tokens.
        knowledge_labels: MLM labels aligned token-by-token with
            ``knowledge_sentences``.
        cls_id: id of the ``[CLS]`` token.
        sep_id: id of the ``[SEP]`` token.
        pad_id: id of the padding token.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, masked_lm_labels,
        next_sentence_label, masked_knowledge_labels)``.
        ``masked_knowledge_labels`` has ``masked_lm_labels``'s shape and dtype.
        It holds only the labels of the injected knowledge tokens; every other
        position is -1.

    Raises:
        ValueError: if the two knowledge lists (or a sentence and its labels)
            differ in length, or if a knowledge sentence fits in no free segment.
    """
    if len(knowledge_sentences) != len(knowledge_labels):
        raise ValueError(f"got {len(knowledge_sentences)} knowledge sentences "
                         f"but {len(knowledge_labels)} label lists")

    seq_len = input_ids.size(1)

    # Describe both segments of every row and remember their original tokens/labels.
    slots = []
    row_segments = {}
    for row, sentence in enumerate(input_ids):
        sent_length = int((sentence != pad_id).sum().item())
        sep_pos = int(torch.where(sentence == sep_id)[0][0].item())
        # Segment A = positions 1 .. sep_pos-1.
        # Segment B = positions sep_pos+1 .. sent_length-2.
        slots.append(dict(row=row, length=sep_pos - 1, pos='a'))
        slots.append(dict(row=row, length=sent_length - sep_pos - 2, pos='b'))
        row_segments[row] = (
            sentence[1:sep_pos].tolist(),
            sentence[sep_pos + 1:sent_length - 1].tolist(),
            masked_lm_labels[row][1:sep_pos].tolist(),
            masked_lm_labels[row][sep_pos + 1:sent_length - 1].tolist(),
        )
    slots.sort(key=lambda slot: slot['length'])

    # Shortest knowledge sentence first (a stable sort keeps input order on ties).
    # Sorting indices, rather than using zip(*...), also works for an empty list.
    order = sorted(range(len(knowledge_sentences)), key=lambda i: len(knowledge_sentences[i]))

    # Greedy shortest-fit matching, never reusing a batch row. (Hungarian matching
    # would not enforce the one-knowledge-sentence-per-row constraint directly.)
    pairs = []
    used_rows = set()
    for i in order:
        ks, kl = knowledge_sentences[i], knowledge_labels[i]
        if len(kl) != len(ks):
            raise ValueError(f"knowledge sentence {i} has {len(ks)} tokens but {len(kl)} labels")
        for slot in slots:
            if slot['row'] not in used_rows and len(ks) <= slot['length']:
                pairs.append((ks, kl, slot))
                used_rows.add(slot['row'])
                break
        else:
            raise ValueError(f"knowledge sentence {i} (length {len(ks)}) does not fit in any free segment")

    # Labels restricted to the injected knowledge tokens only.
    masked_knowledge_labels = torch.full_like(masked_lm_labels, -1)

    # Rebuild each chosen row. The knowledge sentence is never longer than the
    # segment it replaces, so the new row never exceeds seq_len.
    for ks, kl, slot in pairs:
        row = slot['row']
        tokens_a, tokens_b, labels_a, labels_b = row_segments[row]
        if slot['pos'] == 'a':
            tokens_a, labels_a = ks, kl
            knowledge_a, knowledge_b = kl, [-1] * len(labels_b)
        else:
            tokens_b, labels_b = ks, kl
            knowledge_a, knowledge_b = [-1] * len(labels_a), kl

        tokens = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
        n_tokens = len(tokens)
        # BERT segment ids: [CLS] A [SEP] -> 0, then B [SEP] -> 1, padding -> 0.
        first_b = len(tokens_a) + 2

        input_ids[row] = _padded_row(tokens, seq_len, pad_id, input_ids)
        attention_mask[row] = _padded_row([1] * n_tokens, seq_len, 0, attention_mask)
        token_type_ids[row] = _padded_row([0] * first_b + [1] * (n_tokens - first_b),
                                          seq_len, 0, token_type_ids)
        masked_lm_labels[row] = _padded_row([-1] + labels_a + [-1] + labels_b + [-1],
                                            seq_len, -1, masked_lm_labels)
        masked_knowledge_labels[row] = _padded_row([-1] + knowledge_a + [-1] + knowledge_b + [-1],
                                                   seq_len, -1, masked_knowledge_labels)
        next_sentence_label[row] = 1

    return input_ids, attention_mask, token_type_ids, masked_lm_labels, next_sentence_label, masked_knowledge_labels
