# -*- coding: utf-8 -*-


import os
import json
import pandas as pd
import matplotlib
matplotlib.use('Agg')

def get_prob_dic(models_path, modelname, step):
    """Return the selection percentage for every sentence position.

    Reads the JSON file ending in ``step<step>.selectedIdx`` from
    ``<models_path><modelname>/eval/``. The file is loaded as a list of lists
    of selected sentence indices, which are flattened into one list
    (presumably one inner list per document -- TODO confirm against the code
    that writes these files).

    Args:
        models_path: Prefix of the model folders. It is concatenated directly,
            so it should end with a path separator.
        modelname: Name of the model folder.
        step: Training step whose evaluation output is read (int or str).

    Returns:
        A list where element ``i`` is the percentage (0-100) of all selected
        indices equal to ``i``. Positions that were never selected hold ``0``.
        The list has length ``max(index) + 1``, or is empty when the file
        contains no selected indices.

    Raises:
        FileNotFoundError: If no matching ``.selectedIdx`` file exists.
    """
    from collections import Counter

    path = models_path + modelname + '/eval/'
    suffix = 'step' + str(step) + '.selectedIdx'
    # Sort so the choice is deterministic when several files match
    # (os.listdir order is arbitrary).
    matches = sorted(name for name in os.listdir(path) if name.endswith(suffix))
    if not matches:
        raise FileNotFoundError('no file ending with %r in %s' % (suffix, path))
    file_path = path + matches[0]

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    flat = [item for sublist in data for item in sublist]

    print('######', file_path)
    if not flat:
        # Nothing was selected: max() below would raise on an empty list.
        print('no selected sentences found')
        return []
    print('max sentence index', max(flat))

    # One counting pass instead of calling flat.count() once per distinct index.
    counts = Counter(flat)
    total = len(flat)
    prob_dic = [0] * (max(flat) + 1)
    for idx, count in counts.items():
        prob_dic[idx] = (count / total) * 100

    print('prob_dic, len:', len(prob_dic))
    print(prob_dic)
    return prob_dic

def pop_zeros(items):
    """Remove trailing zeros from ``items`` in place.

    Popping stops at the last non-zero element. An empty or all-zero list
    ends up empty instead of raising ``IndexError``.

    Args:
        items: A mutable sequence supporting ``pop()`` (e.g. a list).

    Returns:
        None; ``items`` is modified in place.
    """
    while items and items[-1] == 0:
        items.pop()
    
def plot_summ_distribution(models_path, dataset, step_models, PNG_FILENAME,  MAX_SENT_POS):
    """Plot a grouped bar chart of selected-sentence positions per model.

    For every ``[model_folder_name, step, model_name, ...]`` entry in
    ``step_models``, the function:

    * loads the per-position selection percentages with ``get_prob_dic``;
    * keeps only the first ``MAX_SENT_POS`` positions;
    * pads them with zeros to a common length;
    * draws them as one bar series labelled ``model_name``.

    The chart is saved to ``models_path + dataset.upper() + PNG_FILENAME``.

    Args:
        models_path: Prefix of the model folders and of the output file
            (concatenated directly, so it should end with a separator).
        dataset: Dataset name, used in the title and (upper-cased) in the
            output file name.
        step_models: Sequence of ``[model_folder_name, step, model_name]``
            entries.
        PNG_FILENAME: Suffix of the output PNG file name.
        MAX_SENT_POS: Number of leading sentence positions to plot.

    Raises:
        ValueError: If ``step_models`` is empty or no model has any selected
            sentence to plot.
    """
    import math

    import matplotlib.pyplot as plt

    print("=================================================")
    print("Plotting summary distribution...")
    if not step_models:
        raise ValueError('step_models must contain at least one model')

    prob_dics = {}
    for model in step_models:
        folder_name, step, model_name = model[:3]
        prob_dic = get_prob_dic(models_path, folder_name, step)[:MAX_SENT_POS]
        print('prob_dic[:MAX_SENT_POS], len:', len(prob_dic))
        print(prob_dic)
        prob_dics[model_name] = prob_dic

    print(prob_dics)
    max_prob = max((p for values in prob_dics.values() for p in values), default=0)
    print('max_prob', max_prob)

    png_file = models_path + dataset.upper() + PNG_FILENAME

    max_le = max(len(values) for values in prob_dics.values())
    if max_le == 0:
        raise ValueError('no selected sentences to plot for dataset %s' % dataset)
    # Pad shorter series with zeros so every column shares the same index.
    prob_dics = {
        name: values + [0] * (max_le - len(values))
        for name, values in prob_dics.items()
    }

    df = pd.DataFrame(prob_dics, index=list(range(max_le)))
    ax = df.plot.bar(rot=0, figsize=(15, 6), title='summary distribution, dataset %s' % (dataset))

    # Round the upper limit UP to a multiple of 5. round() could cut off the
    # tallest bar (e.g. 12 -> 10) or collapse the axis to (0, 0) below 2.5.
    ax.set_ylim(0, max(5, math.ceil(max_prob / 5) * 5))
    ax.set_xlabel("linear sentence index in source text")
    ax.set_ylabel("proportion of selected sentences (%)")

    fig = ax.get_figure()
    fig.savefig(png_file)
    # pyplot keeps every figure it created alive until it is closed.
    plt.close(fig)

    print("Plot summary distribution...DONE")


###### Arguments ######
# Script configuration read by the __main__ block at the bottom of the file.

# Only the first MAX_SENT_POS sentence positions are plotted.
MAX_SENT_POS = 50
# Dataset to plot (uncomment exactly one). Its upper-cased name prefixes the
# output PNG file name.
#DATASET = 'cnndm'
DATASET = 'pubmed'
#DATASET = 'arxiv'
# Root prefix of the models directory. Paths are built by plain string
# concatenation, so keep the trailing separator on MODELS_PATH.
PATH = ''
MODELS_PATH = PATH+'models/'
# Suffix appended to MODELS_PATH + DATASET.upper() to form the output file.
PNG_FILENAME = '.summ.dist.png'
# Each STEP_MODELS entry is [model_folder_name, step, model_name]:
#   model_folder_name -- folder under MODELS_PATH whose eval/ directory holds
#                        the '*step<step>.selectedIdx' file
#   step              -- training step used to pick that file
#   model_name        -- series label (DataFrame column) in the plot
# Uncomment the STEP_MODELS group that matches DATASET.
#STEP_MODELS=[['cnndm_oracle_bert_mp1024','0','cnndm_oracle_bert_mp1024'],
#             ['cnndm_bert__bs200ac2ws10000ts50000_mp1024ms407N_3gpu','28000','BERT-base_mp1024'],
#             ['cnndm_hs_bert_s_la_sum_bs200ac2ws10000ts50000_mp1024ms407N_3gpu','29000','HiStruct + BERT-base_mp1024']]
#STEP_MODELS=[['cnndm_oracle_roberta_mp1024','0','cnndm_oracle_roberta_mp1024'],
#             ['cnndm_robertaB__bs250ac2ws10000ts50000_mp1024mns407N_3gpu_FT','16000','RoBERTa-base_mp1024'],
#             ['cnndm_hs_robertaB_s_la_sum_bs250ac2ws10000ts50000_mp1024mns407N_3gpu','12000','HiStruct + RoBERTa-base_mp1024']]

STEP_MODELS=[['pubmed_oracle_roberta_mp15000','0','pubmed_oracle_roberta_mp15000'],
             ['pubmed_longformerB__bs500ac2ws10000ts70000_mp15000mns450N_law1024_F-finetune_T-globatt__3gpu','53000','Longformer-base_mp15000'],
             ['pubmed_hs_longformerB_s_la_sum_bs500ac2ws10000ts70000_mp15000mns450N_law1024_F-finetune_T-globatt_sn-longformerB-sumCLS8_3gpu','61000','HiStruct + Longformer-base_mp15000']]

#STEP_MODELS=[['arxiv_oracle_roberta_mp15000','0','arxiv_oracle_roberta_mp15000'],
#             ['arxiv_longformerB__bs500ac2ws10000ts70000_mp15000mns720N_law1024_F-finetune_T-globatt__3gpu','68000','Longformer-base_mp15000'],
#             ['arxiv_hs_longformerB_s_la_sum_bs500ac2ws10000ts70000_mp15000mns720N_law1024_F-finetune_T-globatt_sn-longformerB-sum_3gpu','70000','HiStruct + Longformer-base_mp15000']]

#STEP_MODELS=[['arxiv_oracle_roberta_mp28000','0','arxiv_oracle_roberta_mp28000'],
#             ['arxiv_longformerB__bs500ac2ws10000ts100000_mp28000mns1300N_law1024_F-finetune_T-globatt__3gpu','62000','Longformer-base_mp28000'],
#             ['arxiv_hs_longformerB_s_la_sum_bs500ac2ws10000ts100000_mp28000mns1300N_law1024_F-finetune_T-globatt_sn-longformerB-sum_3gpu','80000','HiStruct + Longformer-base_mp28000']]

if __name__ == '__main__':
    # Script entry point: plot the summary distribution for the configured
    # dataset and models, passing every setting by name for clarity.
    plot_summ_distribution(
        models_path=MODELS_PATH,
        dataset=DATASET,
        step_models=STEP_MODELS,
        PNG_FILENAME=PNG_FILENAME,
        MAX_SENT_POS=MAX_SENT_POS,
    )
    
     

    
    
    

    