# -*- coding: utf-8 -*-

import torch
import numpy as np


def find_insertion_tokens(tokens, dec_words, text, start, end):
    """Return the sorted ids of the tokens that cover ``text[start:end]``.

    The decoded tokens are concatenated into one string, ``text`` is located
    inside it (first occurrence), and the character span ``[start, end)`` of
    ``text`` is mapped back onto the indices of the tokens it touches.

    Args:
        tokens: unused; kept so existing callers keep working.
        dec_words: list of decoded token strings, one entry per token.
        text: the text to locate inside ``"".join(dec_words)``.
        start: start character offset of the span, relative to ``text``.
        end: end character offset (exclusive) of the span, relative to
            ``text``.

    Returns:
        Sorted list of unique token indices (positions in ``dec_words``)
        overlapping the span. Empty if the span is empty.

    Raises:
        ValueError: if ``text`` does not occur in the concatenated decoded
            tokens. Without this check ``str.find`` would return -1 and the
            span would silently be shifted by one character.
    """
    # assign the correct token id to each char position
    char_2_tid = [tid for tid, word in enumerate(dec_words) for _ in word]

    # locate the text to account for extra tokens (e.g. special tokens)
    dec_text = "".join(dec_words)
    offset = dec_text.find(text)
    if offset < 0:
        raise ValueError(
            "text {!r} not found in decoded tokens {!r}".format(text, dec_text)
        )

    # collect all tokens that span the insertion text
    return sorted(set(char_2_tid[offset + start:offset + end]))


def feature_pooling(features, tids):
    """Pool the rows of ``features`` selected by ``tids`` into one vector.

    Args:
        features: 2-D array indexed as ``features[row, :]``.
        tids: sequence of row indices to pool.

    Returns:
        The selected row itself when exactly one index is given, otherwise
        the element-wise mean (over axis 0) of the selected rows.
    """
    # any count other than one goes through the mean
    if len(tids) != 1:
        selected = features[tids, :]
        return np.mean(selected, axis=0)

    # a single token needs no averaging
    only_tid = tids[0]
    return features[only_tid, :]


def extract_features(model, tokenizer, sample):
    """Extract pooled hidden-state features and logits for one sample.

    Args:
        model: callable taking the encoded tokens; its output must expose
            ``hidden_states`` and ``logits``.
        tokenizer: object providing ``encode`` and ``decode``.
        sample: indexable whose first five entries are
            ``(text_1, text_2, start, end_1, end_2)``. ``start``/``end_1``
            delimit the span of interest in ``text_1`` and ``start``/``end_2``
            the span in ``text_2``.

    Returns:
        Tuple ``(feat_pre, feat_last, logits)``. ``feat_pre`` and
        ``feat_last`` are the pooled span features of both texts,
        concatenated, taken from the second-to-last and the last entry of
        ``output.hidden_states`` respectively. ``logits`` is
        ``output.logits[0]`` as a NumPy array (no softmax is applied here).
    """
    def as_array(tensor):
        # move to host memory and drop the autograd graph before converting
        return tensor.cpu().detach().numpy()

    # unpack the sample
    text_1, text_2, start, end_1, end_2 = (sample[k] for k in range(5))

    # encode the premise/hypothesis pair and run the model once
    tokens = tokenizer.encode(text_1, text_2, return_tensors='pt')
    output = model(tokens)
    emb_pre = as_array(output.hidden_states[-2][0])
    emb_last = as_array(output.hidden_states[-1][0])

    # map each text's character span onto token indices
    dec_words = [tokenizer.decode(tok) for tok in tokens[0]]
    tids_1 = find_insertion_tokens(tokens, dec_words, text_1, start, end_1)
    tids_2 = find_insertion_tokens(tokens, dec_words, text_2, start, end_2)
    span_tids = (tids_1, tids_2)

    # pool the span features per layer and join both texts
    feat_pre = np.concatenate(
        [feature_pooling(emb_pre, tids) for tids in span_tids])
    feat_last = np.concatenate(
        [feature_pooling(emb_last, tids) for tids in span_tids])

    # raw classification logits of the single batch entry
    logits = as_array(output.logits[0])

    return (feat_pre, feat_last, logits)


def extract_dataset_features(model, tokenizer, samples, verbose=True):
    """Run ``extract_features`` over all samples and stack the results.

    The output widths are taken from the first extracted sample instead of
    being fixed, so models with a different hidden size or number of labels
    work as well. For an empty ``samples`` the previous fixed shapes
    (2048 feature columns, 3 logit columns) are returned.

    Args:
        model: passed through to ``extract_features``.
        tokenizer: passed through to ``extract_features``.
        samples: sized sequence of samples (``len`` must be supported).
        verbose: if true, print a progress line after every 1000 samples.

    Returns:
        Tuple ``(embed_pre, embed_last, softmax)`` of float arrays with one
        row per sample, holding the three values returned by
        ``extract_features`` in that order. Despite its name, ``softmax``
        holds the third value unchanged; no softmax is applied here.
    """
    n_entries = len(samples)
    embed_pre = embed_last = softmax = None

    # inference only: no gradients needed
    with torch.no_grad():
        for i, sample in enumerate(samples):

            (feat_pre, feat_last, logits) = extract_features(model,
                                                             tokenizer,
                                                             sample)

            if embed_pre is None:
                # size the outputs from the first result rather than
                # assuming one particular model configuration
                embed_pre = np.zeros([n_entries, len(feat_pre)])
                embed_last = np.zeros([n_entries, len(feat_last)])
                softmax = np.zeros([n_entries, len(logits)])

            embed_pre[i, :] = feat_pre
            embed_last[i, :] = feat_last
            softmax[i, :] = logits

            if verbose and i % 1000 == 999:
                print("> roberta_embeddings.py:", str(i+1),
                      "out of", str(n_entries), sample[0])

    if embed_pre is None:
        # no samples: keep the shapes the fixed-size version returned
        embed_pre = np.zeros([n_entries, 2048])
        embed_last = np.zeros([n_entries, 2048])
        softmax = np.zeros([n_entries, 3])

    return (embed_pre, embed_last, softmax)
