import json
import pprint
from collections import defaultdict, Counter
from utils import time_in_range, normalize_text
from utils import ALL_SLOTS
import re


from ontology import Ontology
# Shared Ontology instance; the Evaluator methods below call its
# normalize_label() (value normalisation before comparison) and
# db_retrive() (resolving values of referred-to slots) helpers.
onto = Ontology()


class Evaluator:
    """Score dialogue-state-tracking labels/predictions against cached gold dialogues.

    ``self.dialogs`` is the JSON object loaded in ``__init__``: a mapping from
    dialogue id to a list of turn dicts. Each turn carries the gold
    ``belief_state``, the gold ``turn_label`` (class/refer/inform labels),
    ``informed_names`` and the ``delex_*`` utterance fields; ``add_prediction``
    attaches a ``turn_pred`` entry to every turn. The scoring methods replay the
    labels or predictions turn by turn into a predicted belief state and
    compare it with ``belief_state``.
    """

    # Slots whose class label is itself the value (no span/inform/refer logic).
    _VERBATIM_SLOTS = ('hotel-parking', 'hotel-internet', 'hotel-type')
    # Yes/no slots: any value other than 'true' is compared as 'false'.
    _BOOL_SLOTS = ('hotel-parking', 'hotel-internet')
    # Regexes used by ``accept``; a match of any one counts as acceptance.
    _ACCEPT_PATTERNS = (
        r"^yes.+?\b(please|it is|perfect|should work)\b",
        r"\b(great|perfect|sounds good|should work|it will do|will work)\b",
        r"^(thank you|thanks).+?\balso\b",
        r"\b(get|need|book).+?\b(tickets|ticket)",
    )

    def __init__(self, file):
        """Load the cached dialogues.

        Args:
            file: path to a JSON cache file such as ``cache_dial_data_train.json``.
        """
        with open(file) as f:
            self.dialogs = json.load(f)

    def self_calculate(self):
        """Replay the gold turn labels as predictions and print the joint goal.

        Checks how well ``turn_label`` can reconstruct the gold ``belief_state``.
        Prints the joint goal accuracy (0.0 when there are no turns) and
        returns nothing.

        Raises:
            AssertionError: a turn is labelled ``inform`` without an informable
                system value, or has an unknown class label.
        """
        ctt, hit = 0, 0
        for dial_id, dialog in self.dialogs.items():
            # slot -> {'usr': values from the user turn, 'sys': values informed by the system}
            state_pd = defaultdict(lambda: {'usr': [], 'sys': []})
            for turn_id, turn in enumerate(dialog):
                ctt += 1
                turn_label = turn['turn_label']
                delex_usr_sp_dic = turn['delex_usr_sp_dic']
                state_gt = turn['belief_state']
                informed_names = turn['informed_names']

                class_label = turn_label['class_label']
                refer_label = turn_label['refer_label']
                inform_label = turn_label['inform_label']

                for k, l in class_label.items():
                    if k in self._VERBATIM_SLOTS:
                        state_pd[k]['usr'] = [l]
                        continue
                    if l == 'dontcare':
                        state_pd[k]['usr'] = ['dontcare']
                    elif l == 'copy_value':
                        state_pd[k]['usr'] = delex_usr_sp_dic['<%s>' % k]
                    elif l == 'inform':
                        if k in inform_label:
                            state_pd[k]['sys'] = inform_label[k]
                        elif k in informed_names and informed_names[k] is not None:
                            state_pd[k]['sys'] = informed_names[k]
                        else:
                            assert 0, 'no inform-able sys value for %s' % k
                    elif l == 'refer':
                        ref_slot = refer_label[k]
                        ref_name_slot = ref_slot.split('-')[0] + '-' + 'name'
                        try:
                            if ref_slot in state_pd:
                                state_pd[k]['usr'] = state_pd[ref_slot]['usr'] + state_pd[ref_slot]['sys']
                            elif ref_name_slot in informed_names and informed_names[ref_name_slot] is not None:
                                value = onto.db_retrive(informed_names[ref_name_slot], ref_slot)
                                assert isinstance(value, list), value
                                state_pd[k]['usr'] = value
                            elif ref_name_slot in state_pd:
                                value = onto.db_retrive(state_pd[ref_name_slot]['usr'], ref_slot)
                                assert isinstance(value, list), value
                                state_pd[k]['usr'] = value
                        except Exception:
                            # Best effort: an unresolvable reference leaves the slot unchanged.
                            pass
                    else:
                        assert 0, 'unknown class label %r' % (l,)

                turn_metric = {slot: 0 for slot in ALL_SLOTS}
                for k in state_gt:
                    vs_gt = self._split_gt(state_gt[k])
                    vs_pd = self._concat_usr_sys(state_pd[k])
                    if k not in self._VERBATIM_SLOTS:
                        vs_pd = [onto.normalize_label(k, v) for v in vs_pd]
                    if self._match_slot(k, vs_gt, vs_pd, info='%s-%d' % (dial_id, turn_id)):
                        turn_metric[k] = 1

                # A slot absent from the gold state is correct only if nothing is
                # predicted for it. Empty entries count as "nothing": the
                # defaultdict creates them whenever a gold slot is read above.
                for k in ALL_SLOTS:
                    if k not in state_gt and not (k in state_pd and self._concat_usr_sys(state_pd[k])):
                        turn_metric[k] = 1

                if all(turn_metric.values()):
                    hit += 1

        print('join goal:', hit / ctt if ctt else 0.0)

    def add_prediction(self, pred_file, output_file, is_log = True):
        """Attach model predictions to every turn and score them strictly.

        Args:
            pred_file: JSON list of per-turn predictions, in the same order as
                the turns of ``self.dialogs``; each entry has ``guid``,
                ``refer_label``, ``class_label`` and ``span_label``.
            output_file: error-log path forwarded to ``strict_calculate``.
            is_log: whether ``strict_calculate`` writes the log.

        Returns:
            ``(joint_goal_accuracy, slot_accuracy)`` from ``strict_calculate``.

        Raises:
            AssertionError: a prediction's ``guid`` does not match the turn's
                ``dial_loc`` (the files are misaligned).
        """
        print(pred_file)
        with open(pred_file) as f:
            self.pred_file = json.load(f)

        ctt = 0
        for dialog in self.dialogs.values():
            for turn in dialog:
                pred = self.pred_file[ctt]
                assert turn['dial_loc'] == pred['guid'], (turn['dial_loc'], pred['guid'])
                turn['turn_pred'] = {
                    'refer_label': pred['refer_label'],
                    'class_label': pred['class_label'],
                    'span_label': pred['span_label'],
                }
                ctt += 1

        return self.strict_calculate(output_file, is_log=is_log)

    @staticmethod
    def wp_recover(value):
        """Undo WordPiece-style tokenisation in a predicted span.

        Removes ``##`` continuation markers (with the preceding space) and the
        spaces around ``:`` and ``'`` (e.g. ``"12 : 30"`` -> ``"12:30"``).
        """
        value = re.sub("(^| )##", "", value)
        value = re.sub(r"\s(:|')\s", r"\1", value)
        return value

    def calculate(self, output_file, is_log = True):
        """Score ``turn_pred`` with the lenient rules.

        Requires ``turn_pred`` on every turn (see ``add_prediction``). Unlike
        ``strict_calculate``, this variant keeps user and system values apart,
        passes ``range=1`` to ``time_in_range`` and counts a gold slot with no
        predicted value as correct.

        Args:
            output_file: path of the per-turn error log, overwritten when
                *is_log* is true.
            is_log: whether to write the error log.

        Returns:
            Joint goal accuracy (fraction of turns with every slot correct);
            0.0 when there are no turns. Also prints, per dialogue id, how many
            of its turns were wrong.
        """
        # 'w' truncates an existing log (previously done via os.remove + 'a+').
        f = open(output_file, 'w') if is_log else None
        try:
            return self._calculate(f)
        finally:
            if f is not None:
                f.close()

    def _calculate(self, f):
        """Run the scoring loop of ``calculate``; *f* is the open log file or None."""
        def log(text):
            if f is not None:
                f.write(text)

        counter = []  # one dialogue id per wrong turn
        ctt, hit = 0, 0
        for dial_id, dialog in self.dialogs.items():
            # slot -> {'usr': values from the user turn, 'sys': values informed by the system}
            state_pd = defaultdict(lambda: {'usr': [], 'sys': []})
            for turn_id, turn in enumerate(dialog):
                ctt += 1
                delex_usr_norm = turn['delex_usr_norm']
                delex_sys_norm = turn['delex_sys_norm']
                delex_sys_sp_dic = turn['delex_sys_sp_dic']
                delex_usr_sp_dic = turn['delex_usr_sp_dic']
                state_gt = turn['belief_state']

                turn_pred = turn['turn_pred']
                informed_names = turn['informed_names']
                inform_label = turn['turn_label']['inform_label']
                class_pred = turn_pred['class_label']
                refer_pred = turn_pred['refer_label']
                span_pred = turn_pred['span_label']

                for k, l in class_pred.items():
                    if k in self._VERBATIM_SLOTS:
                        state_pd[k]['usr'] = [l]
                        continue
                    if l == 'dontcare':
                        state_pd[k]['usr'] = ['dontcare']
                    elif l == 'copy_value':
                        vv = self.wp_recover(span_pred[k])
                        if vv:
                            if self._is_time_slot(k):
                                # Keep only a leading HH:MM when the span starts with one.
                                m = re.search(r"^\d{2}:\d{2}", vv)
                                if m:
                                    vv = m.group()
                                state_pd[k]['usr'] = [vv]

                                # Annotation leniency: for "after <time>" requests and
                                # when "pm" is mentioned, the gold value is added as an
                                # extra accepted prediction.
                                tag = '<%s>' % k
                                if tag in delex_usr_sp_dic:
                                    if 'after' in delex_usr_sp_dic[tag][0] \
                                            or re.search(" after <", delex_usr_norm):
                                        if k in state_gt and state_gt[k] != vv:
                                            state_pd[k]['usr'].append(state_gt[k])
                                    # Raw string: in a plain literal "\b" is a backspace
                                    # and the pattern could never match.
                                    if re.search(r"\bpm\b", delex_usr_norm) and k in state_gt:
                                        state_pd[k]['usr'].append(state_gt[k])
                            elif '[SEP]' in vv:
                                pass  # spans containing [SEP] are ignored
                            elif 'bookpeople' in k and vv == 'for me':
                                state_pd[k]['usr'] = ['1']
                            else:
                                state_pd[k]['usr'] = [vv]

                    elif l == 'inform':
                        # Copy so that in-place edits below (train-time extend)
                        # cannot mutate the turn labels stored in self.dialogs.
                        if k in inform_label:
                            state_pd[k]['sys'] = list(inform_label[k])
                        elif k in informed_names and informed_names[k] is not None:
                            state_pd[k]['sys'] = list(informed_names[k])
                        else:
                            log('wrong inform, no inform-able sys values\n')
                        if 'bookpeople' in k and state_pd[k]['sys'] == ['for me']:
                            state_pd[k]['sys'] = ['1']

                    elif l == 'refer':
                        ref_slot = refer_pred[k]
                        ref_name_slot = ref_slot.split('-')[0] + '-' + 'name'
                        try:
                            if ref_slot in state_pd:
                                state_pd[k]['usr'] = state_pd[ref_slot]['usr'] + state_pd[ref_slot]['sys']
                            elif ref_name_slot in informed_names and informed_names[ref_name_slot] is not None:
                                value = onto.db_retrive(informed_names[ref_name_slot], ref_slot)
                                assert isinstance(value, list), value
                                state_pd[k]['usr'] = list(value)
                            elif ref_name_slot in state_pd:
                                value = onto.db_retrive(state_pd[ref_name_slot]['usr'], ref_slot)
                                assert isinstance(value, list), value
                                state_pd[k]['usr'] = list(value)
                            else:
                                log('wrong refer, no refer-able values\n')
                        except Exception:
                            # Best effort: a failed lookup leaves the slot unchanged.
                            pass
                    else:
                        assert 0, 'unknown class label %r' % (l,)

                # When the system names a train, also accept the leave/arrive times
                # it mentioned for the slots the user has already filled.
                for k in ("train-leaveat", "train-arriveby"):
                    tag = '<%s>' % k
                    if '<trainid>' in delex_sys_norm and len(state_pd[k]['usr']) > 0:
                        if tag in delex_sys_sp_dic:
                            state_pd[k]['sys'].extend(delex_sys_sp_dic[tag])

                turn_metric = {slot: 0 for slot in ALL_SLOTS}
                for k in state_gt:
                    vs_gt = self._split_gt(state_gt[k], pizza_hut=True)
                    vs_pd = self._normalize_predictions(k, self._concat_usr_sys(state_pd[k]))
                    # range=1: tight time tolerance (the unit is defined by utils.time_in_range).
                    info = '%s-%d:slot %s pd: %s gt: %s' % (dial_id, turn_id, k, str(vs_pd), str(vs_gt))
                    if self._match_slot(k, vs_gt, vs_pd, range=1, info=info):
                        turn_metric[k] = 1

                # Inconsistent annotation: for "does it / do they have free X"
                # questions the gold value replaces the prediction and the slot
                # is counted as correct.
                for k in self._BOOL_SLOTS:
                    if re.search(r'(does it|do they) have free (parking|internet|wifi)', delex_usr_norm):
                        state_pd[k]['usr'] = [state_gt[k]] if k in state_gt else []
                        turn_metric[k] = 1

                # Lenient: a gold slot with no predicted value counts as correct.
                for k in ALL_SLOTS:
                    if k in state_gt and (k not in state_pd or len(state_pd[k]['usr'] + state_pd[k]['sys']) == 0):
                        turn_metric[k] = 1

                # A slot absent from the gold state is correct when nothing is
                # predicted, or when the prediction equals the next gold value
                # for that slot later in the dialogue.
                for k in ALL_SLOTS:
                    if k not in state_gt:
                        if k not in state_pd or len(state_pd[k]['usr'] + state_pd[k]['sys']) == 0:
                            turn_metric[k] = 1
                        else:
                            look_ahead_v = self.look_ahead(dialog, k, turn_id)
                            if look_ahead_v and look_ahead_v in state_pd[k]['usr'] + state_pd[k]['sys']:
                                turn_metric[k] = 1

                if all(turn_metric.values()):
                    hit += 1
                else:
                    counter.append(dial_id)
                    if f is not None:
                        self._log_wrong_turn(f.write, dial_id, turn_id, turn_metric,
                                             state_pd, state_gt, delex_sys_norm, delex_usr_norm)

        joint = hit / ctt if ctt else 0.0
        log('join goal: %.4f ' % joint)
        print(Counter(counter).most_common())
        return joint

    def accept(self, usr_sent):
        """Heuristically decide whether the user sentence accepts an offer.

        Returns True if any pattern in ``_ACCEPT_PATTERNS`` matches *usr_sent*.
        """
        return any(re.search(pattern, usr_sent) for pattern in self._ACCEPT_PATTERNS)

    def strict_calculate(self, output_file, is_log = False):
        """Score ``turn_pred`` with the strict rules.

        Unlike ``calculate``, a gold slot with no predicted value counts as
        wrong; time slots pass ``range=15`` to ``time_in_range``.

        Args:
            output_file: path of the per-turn error log; entries are appended to
                any existing content when *is_log* is true.
            is_log: whether to write the error log.

        Returns:
            ``(joint_goal_accuracy, slot_accuracy)`` rounded to 4 decimals. Slot
            accuracy is the fraction of (turn, slot) entries marked correct.
            Both are 0.0 when there are no turns.
        """
        f = open(output_file, 'a+') if is_log else None
        try:
            return self._strict_calculate(f)
        finally:
            if f is not None:
                f.close()

    def _strict_calculate(self, f):
        """Run the scoring loop of ``strict_calculate``; *f* is the open log file or None."""
        def log(text):
            if f is not None:
                f.write(text)

        ctt, hit = 0, 0
        turn_hit, turn_ctt = 0, 0
        for dial_id, dialog in self.dialogs.items():
            # slot -> list of candidate predicted values
            state_pd = defaultdict(list)
            for turn_id, turn in enumerate(dialog):
                ctt += 1
                delex_usr_norm = turn['delex_usr_norm']
                delex_sys_norm = turn['delex_sys_norm']
                delex_usr_sp_dic = turn['delex_usr_sp_dic']
                state_gt = turn['belief_state']

                turn_pred = turn['turn_pred']
                informed_names = turn['informed_names']
                inform_label = turn['turn_label']['inform_label']

                class_pred = turn_pred['class_label']
                refer_pred = turn_pred['refer_label']
                span_pred = turn_pred['span_label']

                for k, l in class_pred.items():
                    if k in self._VERBATIM_SLOTS:
                        state_pd[k] = [l]
                        continue
                    if l == 'dontcare':
                        state_pd[k] = ['dontcare']
                    elif l == 'copy_value':
                        vv = self.wp_recover(span_pred[k])
                        if vv:
                            if self._is_time_slot(k):
                                # Keep only a leading HH:MM when the span starts with one.
                                m = re.search(r"^\d{2}:\d{2}", vv)
                                if m:
                                    vv = m.group()
                                state_pd[k] = [vv]

                                # Annotation leniency: the gold value is added as an
                                # extra accepted prediction (see _calculate).
                                tag = '<%s>' % k
                                if tag in delex_usr_sp_dic:
                                    if 'after' in delex_usr_sp_dic[tag][0] \
                                            or re.search(" after <", delex_usr_norm):
                                        if k in state_gt and state_gt[k] != vv:
                                            state_pd[k].append(state_gt[k])
                                    # Raw string: "\b" must be a regex word boundary.
                                    if re.search(r"\bpm\b", delex_usr_norm) and k in state_gt:
                                        state_pd[k].append(state_gt[k])
                            elif 'bookpeople' in k and vv == 'for me':
                                state_pd[k] = ['1']
                            elif '[SEP]' in vv:
                                pass  # spans containing [SEP] are ignored
                            else:
                                state_pd[k] = [vv]

                    elif l == 'inform':
                        # Ignore informed areas that do not name a direction/centre.
                        if 'area' in k and k in inform_label:
                            if not re.search(r"(east|west|south|north|downtown|centre|center|mid)",
                                             ' '.join(inform_label[k])):
                                continue

                        if 'leaveat' in k or 'arriveby' in k:
                            # Mistaken annotation: accumulate informed times instead of replacing.
                            if k in inform_label:
                                state_pd[k].extend(inform_label[k])
                        elif k in inform_label:
                            # Copy so later in-place edits cannot mutate self.dialogs.
                            state_pd[k] = list(inform_label[k])
                        elif k in informed_names and informed_names[k] is not None:
                            state_pd[k] = list(informed_names[k])
                        else:
                            log('wrong inform, no inform-able sys values\n')
                        if 'bookpeople' in k and state_pd[k] == ['for me']:
                            state_pd[k] = ['1']

                    elif l == 'refer':
                        ref_slot = refer_pred[k]
                        ref_name_slot = ref_slot.split('-')[0] + '-' + 'name'
                        try:
                            if ref_slot in state_pd:
                                # Copy: sharing one list would let a later extend/append
                                # on either slot silently change the other.
                                state_pd[k] = list(state_pd[ref_slot])
                            elif ref_name_slot in informed_names and informed_names[ref_name_slot] is not None:
                                value = onto.db_retrive(informed_names[ref_name_slot], ref_slot)
                                assert isinstance(value, list), value
                                state_pd[k] = list(value)
                            elif ref_name_slot in state_pd:
                                value = onto.db_retrive(state_pd[ref_name_slot], ref_slot)
                                assert isinstance(value, list), value
                                state_pd[k] = list(value)
                            else:
                                log('wrong refer, no refer-able values\n')
                        except Exception:
                            # Best effort: a failed lookup leaves the slot unchanged.
                            pass
                    else:
                        assert 0, 'unknown class label %r' % (l,)

                turn_metric = {slot: 0 for slot in ALL_SLOTS}
                for k in state_gt:
                    vs_gt = self._split_gt(state_gt[k], pizza_hut=True)
                    vs_pd = self._normalize_predictions(k, state_pd[k])
                    # range=15: wide tolerance because some gold times are annotated loosely.
                    info = '%s-%d:slot %s pd: %s gt: %s' % (dial_id, turn_id, k, str(vs_pd), str(vs_gt))
                    if self._match_slot(k, vs_gt, vs_pd, range=15, info=info):
                        turn_metric[k] = 1

                # Inconsistent annotation: for "does it have / includes free X"
                # questions the gold value replaces the prediction and the slot
                # is counted as correct.
                for k in self._BOOL_SLOTS:
                    if re.search(r'(does it have|do they have|include|including|includes)\s+free\s+.*?(parking|internet|wifi)',
                                 delex_usr_norm):
                        state_pd[k] = [state_gt[k]] if k in state_gt else []
                        turn_metric[k] = 1

                # A slot absent from the gold state is correct when nothing is
                # predicted, or when the prediction equals the next gold value
                # for that slot later in the dialogue.
                for k in ALL_SLOTS:
                    if k not in state_gt:
                        if k not in state_pd or len(state_pd[k]) == 0:
                            turn_metric[k] = 1
                        else:
                            look_ahead_v = self.look_ahead(dialog, k, turn_id)
                            if look_ahead_v and look_ahead_v in state_pd[k]:
                                turn_metric[k] = 1

                turn_hit += sum(turn_metric.values())
                turn_ctt += len(turn_metric)

                if all(turn_metric.values()):
                    hit += 1
                elif f is not None:
                    self._log_wrong_turn(f.write, dial_id, turn_id, turn_metric,
                                         state_pd, state_gt, delex_sys_norm, delex_usr_norm)

        joint = hit / ctt if ctt else 0.0
        slot_acc = turn_hit / turn_ctt if turn_ctt else 0.0
        log('join goal: %.4f ' % joint)
        return round(joint, 4), round(slot_acc, 4)

    @staticmethod
    def look_ahead(dialog, slot, begin_tid):
        """Return the first gold value of *slot* in a turn after *begin_tid*, or None."""
        for turn_id in range(begin_tid + 1, len(dialog)):
            if slot in dialog[turn_id]['belief_state']:
                return dialog[turn_id]['belief_state'][slot]
        return None

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _is_time_slot(slot):
        """Return True for slots compared with a time tolerance."""
        return "leaveat" in slot or "arriveby" in slot or 'booktime' in slot

    @staticmethod
    def _split_gt(value, pizza_hut=False):
        """Split a gold slot value into the list of acceptable alternatives.

        A value containing ``|`` or ``>`` is split on ``|``, ``>`` and `` or ``.
        Otherwise it is kept whole; with *pizza_hut* set, a value mentioning
        ``pizza hut`` additionally accepts the bare ``pizza hut``.
        """
        # Check the value itself (not the belief-state dict) for separators.
        if '|' in value or '>' in value:
            return re.split(r"(?:\||>| or )", value)
        values = [value]
        if pizza_hut and 'pizza hut' in value:
            values.append('pizza hut')
        return values

    @staticmethod
    def _concat_usr_sys(slot_state):
        """Return the 'usr' values followed by the 'sys' values of one slot (empty parts skipped)."""
        values = []
        for part in (slot_state['usr'], slot_state['sys']):
            if part:
                values.extend(part)
        return values

    @classmethod
    def _normalize_predictions(cls, slot, values):
        """Normalise predicted *values* of *slot* before comparison.

        Verbatim slots are only copied. Other values go through
        ``onto.normalize_label``; results equal to ``'funky'`` are dropped, and a
        value mentioning ``pizza hut`` also yields the bare ``pizza hut`` to
        mirror the gold-side alternative added by ``_split_gt``.
        """
        if slot in cls._VERBATIM_SLOTS:
            return list(values)
        normalized = []
        for v in values:
            norm_v = onto.normalize_label(slot, v)
            if norm_v in ['funky']:
                continue
            normalized.append(norm_v)
            if 'pizza hut' in norm_v:
                normalized.append('pizza hut')
        return normalized

    @classmethod
    def _match_slot(cls, slot, vs_gt, vs_pd, **time_kwargs):
        """Return True if any predicted value matches a gold alternative.

        Time slots also accept a prediction for which
        ``time_in_range(vs_gt[0], v_pd, **time_kwargs)`` is true; every
        candidate is checked so any side effects of ``time_in_range`` still
        happen once per prediction. Yes/no slots compare values mapped to
        'true'/'false'.
        """
        if cls._is_time_slot(slot):
            if set(vs_gt) & set(vs_pd):
                return True
            matched = False
            for v_pd in vs_pd:
                if time_in_range(vs_gt[0], v_pd, **time_kwargs):
                    matched = True
            return matched
        if slot in cls._BOOL_SLOTS:
            vs_gt = ['true' if v == 'true' else 'false' for v in vs_gt]
            vs_pd = ['true' if v == 'true' else 'false' for v in vs_pd]
        return bool(set(vs_gt) & set(vs_pd))

    @staticmethod
    def _log_wrong_turn(write, dial_id, turn_id, turn_metric, state_pd, state_gt,
                        delex_sys_norm, delex_usr_norm):
        """Write a diagnostic entry for a turn with at least one wrong slot.

        *write* is a ``str -> None`` callable such as ``file.write``. The entry
        lists the wrong slots, their predicted and gold values, and both
        utterances.
        """
        wrong_slot = {k for k, v in turn_metric.items() if v == 0}
        write('%s-%d\n' % (dial_id, turn_id))
        write(str(wrong_slot) + '\n')
        write('pd: ')
        for ss in wrong_slot:
            if ss in state_pd:
                write(ss + ' ' + str(state_pd[ss]) + ' ')
            else:
                write(str({ss: None}) + ' ')
        write('\n')
        write('gt: ')
        for ss in wrong_slot:
            if ss in state_gt:
                write(ss + ' ' + str(state_gt[ss]) + ' ')
            else:
                write(str({ss: None}) + ' ')
        write('\n')
        write(delex_sys_norm + '\n')
        write(delex_usr_norm + '\n')
        write('\n\n\n')


if __name__ == '__main__':
    evaluator = Evaluator('cache_dial_data_test.json')
    # evaluator.self_calculate()
    
    import re, os
    output_dir = 'results_attend_mix_decay'
    files = sorted([file for file in os.listdir(output_dir) if re.match(r'pred_res\.test\.\d+?\.json', file)],
                   key=lambda x:int(x.split('.')[2]))
    
    print(files)
    # files = [file for file in files if '47211' in file]
    join_goals = []
    for file in files:
        join_goal = evaluator.add_prediction(os.path.join(output_dir, file),
                                 os.path.join(output_dir, file.replace('.json', '.log')),
                                             is_log=True)
        join_goals.append(join_goal)
        print('===', join_goal, '===')
    
    for join_goal, file in list(zip(join_goals, files)):
        print(join_goal, file)
