import numpy as np

def get_perplexity(loss, round=2, base=2):
    """Convert an average loss value into a perplexity.

    Args:
        loss: average loss, or ``None``. Presumably measured in log-``base``
            units (e.g. bits for ``base=2``) -- confirm against callers.
        round: number of decimals passed to ``np.round`` (shadows the builtin
            name, kept for caller compatibility).
        base: base of the exponential.

    Returns:
        ``base ** min(loss, 30)`` rounded to ``round`` decimals, or ``0.`` when
        ``loss`` is ``None``.
    """
    if loss is not None:
        # Cap the exponent at 30 so a diverging loss still gives a bounded value.
        exponent = min(loss, 30)
        perplexity = np.power(base, exponent)
        return np.round(perplexity, round)
    return 0.

def rm_space_in_str(str):
    """Return ``str`` with every plain space character (" ") removed.

    Only U+0020 is removed; tabs, newlines and other whitespace are kept.
    The parameter name shadows the builtin but is kept for caller compatibility.
    """
    return str.replace(" ", "")

def rm_prefix_in_key_of_dict(d, prefix):
    """Return a new dict whose keys have a leading ``prefix`` removed.

    Keys that do not start with ``prefix`` are copied unchanged, and ``d`` itself
    is not modified. If stripping makes two keys equal, the key keeps the
    position of its first occurrence and the value of its last.
    """
    def _strip(key):
        cut = len(prefix)
        return key[cut:] if key[:cut] == prefix else key

    return {_strip(key): value for key, value in d.items()}

def load_text_from_given_slot(file, slot_id):
    """Read a tab-separated file and tokenize one field of every line.

    Args:
        file: path of the tab-separated text file.
        slot_id: index of the tab-separated field to take from each line.

    Returns:
        A list with one entry per line: the whitespace-split tokens of field
        ``slot_id`` of that line.

    Raises:
        IndexError: if some line has no field ``slot_id``.
    """
    datas = []
    # `with` guarantees the file is closed even when a malformed line raises.
    with open(file, "r") as f:
        for lineno, line in enumerate(f, 1):
            slots = line.strip("\n").split("\t")
            if len(slots) <= slot_id:
                # Previously this only printed and then crashed on the lookup
                # below; fail explicitly with the location instead.
                raise IndexError(
                    "load_text_from_given_slot: line %d of %s has %d "
                    "tab-separated fields, but slot_id is %s"
                    % (lineno, file, len(slots), slot_id))
            datas.append(slots[slot_id].split())
    return datas

def load_emb_text(file):
    """Load word embeddings from a whitespace-separated text file.

    Each line is expected to be ``word v1 v2 ...``. Blank lines and lines with
    exactly two fields (e.g. a word2vec-style "<count> <dim>" header) are
    skipped. Only the first occurrence of a repeated word is kept.

    Args:
        file: path of the embedding text file.

    Returns:
        ``(word2id, embs)``: ``word2id`` maps each word to its row index in
        ``embs``, and ``embs`` is the list of float vectors in file order.

    Raises:
        ValueError: if a vector component cannot be parsed as a float.
    """
    embs = []
    word2id = {}
    # `with` closes the file; the previous version never closed it.
    with open(file, "r") as f:
        for line in f:
            slots = line.split()
            if not slots:
                # Blank line: previously crashed with IndexError on slots[0].
                continue
            if len(slots) == 2:
                # Treated as a header line such as "<count> <dim>".
                continue
            word = slots[0]
            emb = [float(x) for x in slots[1:]]
            if word in word2id:
                continue
            word2id[word] = len(embs)
            embs.append(emb)
    print("load_emb_text done. file:", file, "total num:", len(embs))
    return word2id, embs

def print_tensor_by_flatten(d, name=""):
    """Print every element of ``d`` on a single tab-separated line.

    The printed line is ``PrintTensorValue:<TAB><name><TAB><v0> <v1> ...``,
    where the values come from ``d.flatten().cpu().tolist()``, i.e. all
    elements in flattened order, separated by single spaces.
    ``d`` is presumably a torch tensor (it must provide flatten/cpu/tolist).
    """
    values = d.flatten().cpu().tolist()
    fields = ["PrintTensorValue:", name, " ".join(str(v) for v in values)]
    print("\t".join(fields))

def _ids2texts_for_print(ids, vocab):
    """Decode one sequence of token ids into a list of words for printing.

    ``ids`` is converted with ``.int().cpu().tolist()`` (presumably a 1-D
    torch tensor -- confirm). Tokens equal to ``vocab.bos_token_id`` are
    skipped, and decoding stops at the first ``vocab.eos_token_id``. Ids
    ``>= vocab.size`` are reported on stdout and skipped. All other ids are
    mapped through ``vocab.index2word``.

    Returns:
        The list of decoded words (not joined).
    """
    words = []
    for token in ids.int().cpu().tolist():
        if token == vocab.bos_token_id:
            continue
        elif token == vocab.eos_token_id:
            break
        elif token < vocab.size:
            words.append(vocab.index2word[token])
        else:
            print("error. oov in tool._ids2texts_for_print in tool.print_texts_from_batch. id:", token)
    return words

def print_texts_from_batch(batch, vocab):
    """Decode, print and return the source/target texts of a batch.

    ``batch['input_batch']`` and ``batch['target_batch']`` are transposed on
    dims 0/1 and then iterated row by row. Presumably they are stored as
    (seq_len, batch_size), so each row is one example -- confirm against the
    data loader. Each row is decoded with ``_ids2texts_for_print`` and joined
    with spaces.

    Returns:
        ``(src_texts, tgt_texts, srctgt_texts)``, where ``srctgt_texts`` pairs
        each source and target with a tab.
    """
    # Both transposes happen before any decoding, matching the original order.
    src_rows = batch['input_batch'].transpose(0, 1)
    tgt_rows = batch['target_batch'].transpose(0, 1)

    src_texts = [" ".join(_ids2texts_for_print(row, vocab)) for row in src_rows]
    tgt_texts = [" ".join(_ids2texts_for_print(row, vocab)) for row in tgt_rows]
    srctgt_texts = ["\t".join((src, tgt)) for src, tgt in zip(src_texts, tgt_texts)]

    report = (
        ("src_texts:", "len src_texts:", src_texts),
        ("tgt_texts:", "len tgt_texts:", tgt_texts),
        ("srctgt_texts:", "len srctgt_texts:", srctgt_texts),
    )
    for label, _, texts in report:
        print(label, texts)
    for _, len_label, texts in report:
        print(len_label, len(texts))
    return src_texts, tgt_texts, srctgt_texts

def change_special_tokenid_to_ori_trs_model(config, vocab):
    """Switch special-token ids to the original transformer ("ori trs") layout.

    Intended for when an original transformer model is used as the
    pre-trained model. Its special ids are 0 PAD, 1 SOS (bos), 2 EOS, 3 UNK,
    while the maml vocabulary uses 0 UNK, 1 PAD, 2 EOS, 3 SOS.

    Mutates both arguments in place and returns ``None``:
      * sets ``config.PAD_idx/SOS_idx/EOS_idx/UNK_idx`` to 0/1/2/3;
      * relabels those slots of ``vocab.index2word`` as "PAD"/"SOS"/"EOS"/"UNK";
      * points ``vocab.bos_token_id`` / ``vocab.eos_token_id`` at SOS / EOS.
    """
    layout = (("PAD", 0), ("SOS", 1), ("EOS", 2), ("UNK", 3))
    # Phase 1: all config ids first, exactly as before.
    for token, idx in layout:
        setattr(config, token + "_idx", idx)
    # Phase 2: relabel vocab slots using the ids just stored on config.
    for token, _ in layout:
        vocab.index2word[getattr(config, token + "_idx")] = token
    vocab.bos_token_id = config.SOS_idx
    vocab.eos_token_id = config.EOS_idx
    
def add_begin_str_in_keys(str, dict):
    """Return an OrderedDict with ``str`` prepended to every key of ``dict``.

    Iteration order of ``dict`` is preserved and the input is not modified.
    The parameter names shadow builtins but are kept for caller compatibility.
    """
    from collections import OrderedDict
    prefix, source = str, dict
    return OrderedDict((prefix + key, value) for key, value in source.items())

def rm_begin_str_in_keys(str, dict):
    """Return an OrderedDict whose keys have a leading ``str`` removed.

    Keys not starting with ``str`` are kept as-is, and iteration order follows
    ``dict``. If stripping makes two keys equal, the key keeps the position of
    its first occurrence and the value of its last. The input is not modified.
    """
    from collections import OrderedDict
    prefix, source = str, dict

    def _strip(key):
        cut = len(prefix)
        return key[cut:] if prefix == key[:cut] else key

    return OrderedDict((_strip(key), value) for key, value in source.items())
    
def replace_id_in_batch_input(batch, srcid, tgtid):
    """Return a copy of ``batch`` where every element equal to ``srcid`` becomes ``tgtid``.

    Done arithmetically: result = batch + hit * (tgtid - srcid), where ``hit``
    is 1 exactly where ``batch - srcid`` is zero. ``batch`` is presumably a
    torch tensor (uses ``.bool()`` / ``.int()``) and is not modified.
    """
    # int32 indicator: 1 where batch == srcid, 0 elsewhere.
    hit = (batch - srcid).bool().logical_not().int()
    return batch + hit * (tgtid - srcid)

def replace_end_pad_1_with_0_for_all(ids):
    """Return a copy of ``ids`` with every 1 set to 0, except in column 0.

    Index 0 along dim 1 (``ids[:, 0]``) is left untouched. Presumably ``ids``
    is (batch, seq_len) and column 0 holds the SOS token, which is id 1 in
    the ori-trs layout, while later 1s are maml PAD ids to be turned into
    ori-trs PAD (0) -- confirm against callers.

    The previous implementation negated ``ids.transpose(0, 1)[0]`` in place.
    Because transpose returns a view, that silently negated column 0 of the
    caller's tensor. This version builds an explicit mask and never modifies
    ``ids``.

    Args:
        ids: integer tensor with at least 2 dims (presumably a torch tensor).

    Returns:
        A new tensor of the same dtype and shape.
    """
    is_one = ids == 1
    # Protect the first position of every sequence.
    is_one[:, 0] = False
    return ids.masked_fill(is_one, 0)

def print_keys(dict, mark):
    """Print ``mark``, then each key of ``dict`` as "<index> <key>", one per line.

    The parameter name shadows the builtin but is kept for caller compatibility.
    """
    print(mark)
    for numbered_key in enumerate(dict.keys()):
        print(*numbered_key)
