
import os 
from utils.lemmatize import nltk_lemmatize

import json





def get_word_lemma_opensource(word_list, pos):
    """Lemmatize the words in ``word_list`` with the open-source NLTK backend.

    Thin wrapper around ``utils.lemmatize.nltk_lemmatize``; both arguments are
    forwarded unchanged. Callers in this file pass a single-letter PoS tag
    ("v", "n", "a" or "r").

    Returns whatever ``nltk_lemmatize`` returns. Presumably this is one lemma
    per input word, in input order (TODO confirm against utils.lemmatize).
    """
    lemmas = nltk_lemmatize(word_list, pos)
    return lemmas

### load the word-pos map
# `vocab` maps a word to the PoS tags it can take; membership is tested with
# `pos in vocab[word]`. Presumably the values are lists of tags such as "NOUN"
# or "VERB" (TODO confirm against the JSON file).
# NOTE(review): the path is resolved relative to the current working
# directory, not to this module's location, so run from the project root.
word_pos_fp = "./vocab/word_pos.json"
# Explicit UTF-8 so decoding does not depend on the platform's locale encoding.
with open(word_pos_fp, "r", encoding="utf-8") as f:
    vocab = json.load(f)

def post_process_on_candidates_v2(info):
    """Filter and rank masked-LM substitute candidates for a target word.

    Args:
        info: dict with at least these keys:
            - 'sent_lst': [left context, target word, right context].
              NOTE: modified in place. Element 1 is re-padded to
              " <word> ", and elements 0 and 2 are stripped.
            - 'pos': PoS tag of the target word. It must be one of "VERB",
              "NOUN", "ADJ" or "ADV"; any other value raises KeyError.
            - 'word': the target word (compared against candidates as its
              lemma).
            - 'candidates': list of candidate strings, or one
              space-separated string.
            - 'candidate_probas': model probabilities aligned with
              'candidates'. zip() silently truncates to the shorter of the two.

    Returns:
        A list of [proba, candidate, form] triples, sorted descending (by
        proba; ties are broken by comparing the strings).

        Each candidate is lower-cased and stripped. Its `form` is the
        candidate's lemma when the lemma can take the target PoS in `vocab`;
        otherwise it is the candidate itself when that can. Candidates with
        neither are dropped. Forms equal to the target word or to its surface
        form in the sentence are dropped, and each form is kept only once
        (first occurrence wins).

        If that leaves nothing, the PoS restriction is lifted. The result is
        then [proba, candidate, lemma] triples, excluding candidates whose
        word or lemma equals the target, with each candidate kept once.
    """
    pos_map = {"VERB": "v", "NOUN": "n", "ADJ": "a", "ADV": "r"}

    # Normalise the sentence parts in place (the original contract mutates info).
    sent_lst = info['sent_lst']
    sent_lst[1] = " " + sent_lst[1].strip() + " "
    sent_lst[0] = sent_lst[0].strip()
    sent_lst[2] = sent_lst[2].strip()

    word_pos = info['pos']
    word_lemma = info['word']
    # Surface form of the target word as it appears in the given sentence.
    word_format = sent_lst[1].strip().lower()
    excluded = (word_lemma, word_format)

    candidates = info['candidates']
    if not isinstance(candidates, list):
        candidates = candidates.strip().split(' ')
    # Probabilities from the masked-LM (BERT) prediction, aligned with candidates.
    candidate_probas = info['candidate_probas']
    candidates_lemma = get_word_lemma_opensource(candidates, pos_map[word_pos])

    seen = set()  # forms / words already emitted, so each appears only once
    ranked = []

    for word, proba, word_l in zip(candidates, candidate_probas, candidates_lemma):
        word = word.lower().strip()
        if not word:
            continue

        # Pick the form that can play the target PoS; the lemma wins over the
        # surface word when both qualify.
        word_form = None
        if word in vocab and word_pos in vocab[word]:
            word_form = word
        if word_l in vocab and word_pos in vocab[word_l]:
            word_form = word_l

        # Skip: no PoS-compatible form, same as the original word, or a repeat.
        if word_form is None or word_form in excluded or word_form in seen:
            continue

        seen.add(word_form)
        ranked.append([proba, word, word_form])

    # No candidate survived the PoS filter: retry without the PoS restriction.
    # `seen` is necessarily empty here, since it only grows alongside `ranked`.
    if not ranked:
        for word, proba, word_l in zip(candidates, candidate_probas, candidates_lemma):
            word = word.lower().strip()
            if not word:
                continue
            if word in excluded or word_l in excluded:
                continue
            # Deduplicate, consistently with the PoS-filtered pass above.
            if word in seen:
                continue
            seen.add(word)
            ranked.append([proba, word, word_l])

    # Sort to find the best candidate (highest probability first).
    ranked.sort(reverse=True)
    return ranked



if __name__ == "__main__":
    # Demo: rank substitutes for the noun "side" in a sample sentence.
    # The record holds the target word, its PoS, the sentence split around
    # the target, an 'answer' map and a full 'sent' string (neither is read by
    # post_process_on_candidates_v2), and the model's candidate substitutes
    # with their probabilities.
    left_ctx = "on which"
    target = " side "
    right_ctx = "shall i be , when all these transitory things are done away with , when the dead have risen from their graves , when the great congregation shall stand upon the land , and upon the sea , when every valley , and every mountain , and every river , and every sea , shall be crowded with multitudes standing in thick array ?"

    # Model predictions with probabilities, index-aligned.
    demo_candidates = [
        "side", "sides", "Side", "place", "part",
        "side", "position", "seat", "corner", "spot",
        "wing", "portion", "stand", "shore", "end",
        "edge", "body", "line", "viewpoint", "Side",
        "point", "bed", "hand", "one", "right",
        "front", "view", "plane", "way", "location",
        "field", "border", "court", "flank", "bench",
        "team", "left", "pole", "site", "land",
        "shoulder", "base", "day", "axis", "sided",
        "cross", "turn", "hill", "elevation", "sideline",
    ]
    demo_probas = [
        0.9999999403953552, 6.139466535159954e-08, 2.3229125289958574e-08, 1.011034478359818e-09, 6.427997578661859e-10,
        4.936482889128513e-10, 2.7711463679302994e-10, 2.4150620392404676e-10, 6.76913802788448e-11, 2.905519921880817e-11,
        1.937968362175635e-11, 1.7043293859542175e-11, 1.5158778746138246e-11, 1.3433417572761286e-11, 1.3025916283104788e-11,
        1.2044727194793214e-11, 1.0754034815430025e-11, 1.0541566751454123e-11, 8.908785167904831e-12, 8.088345097856386e-12,
        7.980797446516252e-12, 6.66594297121792e-12, 6.2644507636822055e-12, 5.451371559023199e-12, 5.3857595466721975e-12,
        5.1128320327498145e-12, 4.574105417348706e-12, 4.158833000200701e-12, 3.840624866746634e-12, 3.3106893962409067e-12,
        3.2471130818889637e-12, 2.2624684244132443e-12, 1.7995134462406304e-12, 1.6174360028403667e-12, 1.523695015645532e-12,
        1.5045597146629008e-12, 1.3853530952609439e-12, 1.3799835840017094e-12, 1.3416779623062758e-12, 1.220891877179442e-12,
        1.1135095666706119e-12, 1.0435784181250796e-12, 1.0421213588254763e-12, 1.0287276672876766e-12, 1.0195059857096012e-12,
        9.802540723580222e-13, 9.201803815445109e-13, 8.882423876283019e-13, 8.784894469857085e-13, 8.761158573795946e-13,
    ]

    info = {
        "idx": 307,
        "word": "side",
        "pos": "NOUN",
        "sent_lst": [left_ctx, target, right_ctx],
        "sent": "on which side shall i be , when all these transitory things are done away with , when the dead have risen from their graves , when the great congregation shall stand upon the land , and upon the sea , when every valley , and every mountain , and every river , and every sea , shall be crowded with multitudes standing in thick array ?",
        "answer": {"faction": 2, "position": 1, "team": 1},
        "candidates": demo_candidates,
        "candidate_probas": demo_probas,
    }

    ranked = post_process_on_candidates_v2(info)
    print(ranked)

    # Sample output: [proba, candidate, form] triples, best first, e.g.
    # [[1.011034478359818e-09, 'place', 'place'], [6.427997578661859e-10, 'part', 'part'], ... ,]
