import json
import pprint
from collections import defaultdict
from utils import time_in_range, normalize_text
from utils import ALL_SLOTS
import re

# Load the cached training dialogues. The evaluation loop further down iterates
# this object with ``.items()`` as ``dial_id -> sequence of turns``.
# The encoding is given explicitly so that parsing does not depend on the
# platform's default text encoding.
with open('cache_dial_data_train.json', encoding='utf-8') as f:
    dialogs = json.load(f)

# Not read anywhere in the visible part of this file.
last_dial_id = None

def look_ahead(dialog, slot, begin_tid):
    """Return the value of ``slot`` from the first later turn that has it.

    Turns at indices ``begin_tid + 1`` up to the end of ``dialog`` are checked
    in order; the value stored under ``slot`` in the first turn containing
    that key is returned. ``None`` is returned when no such turn exists.
    """
    later_turns = (dialog[idx] for idx in range(begin_tid + 1, len(dialog)))
    return next((entry[slot] for entry in later_turns if slot in entry), None)

# Single Ontology instance shared by the evaluation loop below, which calls
# its find_name, db_retrive and normalize_label methods.
from ontology import Ontology
onto = Ontology()

# TODO: consider whether time slots need to be accumulated

# Replay the per-turn labels of every dialogue into a predicted state
# (``state_pd``), compare it slot by slot with the annotated ``belief_state``
# and print the fraction of turns in which every slot in ALL_SLOTS matched.

# Id of the last dialogue for which a mismatch report was printed. It is only
# written in this loop; nothing here reads it back.
l_dial_id = None
ctt, hit = 0, 0  # turns processed / turns where every slot was correct

# Slots whose class label is stored as the value itself and which are not
# passed through onto.normalize_label before comparison.
_VERBATIM_SLOTS = ('hotel-parking', 'hotel-internet', 'hotel-type')

for dial_id, dialog in dialogs.items():
    # slot -> {'usr': [...], 'sys': [...]}; accumulated over the whole dialogue.
    state_pd = defaultdict(lambda: {'usr': [], 'sys': []})
    # Names most recently found in the system delexicalisation dict, per slot.
    informed_names = {
        'restaurant-name': None,
        'hotel-name': None,
        'attraction-name': None,
    }
    for turn_id, turn in enumerate(dialog):
        ctt += 1
        turn_label = turn['turn_label']
        delex_usr_norm = turn['delex_usr_norm']
        delex_sys_norm = turn['delex_sys_norm']
        delex_sys_sp_dic = turn['delex_sys_sp_dic']
        delex_usr_sp_dic = turn['delex_usr_sp_dic']
        state_gt = turn['belief_state']

        class_label = turn_label['class_label']
        refer_label = turn_label['refer_label']
        inform_label = turn_label['inform_label']

        # Refresh the informed names with the normalised form of every name
        # listed under the '<slot>' placeholder of the system dict.
        for k in informed_names:
            placeholder = '<%s>' % k
            if placeholder in delex_sys_sp_dic:
                this_names = []
                for name in delex_sys_sp_dic[placeholder]:
                    _, norm_k, _ = onto.find_name(name)
                    this_names.append(norm_k)
                informed_names[k] = this_names

        # Apply this turn's class labels to the predicted state.
        for k, l in class_label.items():
            if k in _VERBATIM_SLOTS:
                state_pd[k]['usr'] = [l]
                continue
            if l == 'dontcare':
                state_pd[k]['usr'] = ['dontcare']
            elif l == 'copy_value':
                state_pd[k]['usr'] = delex_usr_sp_dic['<%s>' % k]
            elif l == 'inform':
                if k in inform_label:
                    state_pd[k]['sys'] = inform_label[k]
                elif k in informed_names and informed_names[k] is not None:
                    state_pd[k]['sys'] = informed_names[k]
                else:
                    # Raised explicitly (instead of ``assert 0``) so the check
                    # is not stripped when Python runs with -O.
                    raise AssertionError(
                        "no informed value for slot %r (dialogue %s, turn %d)"
                        % (k, dial_id, turn_id))
            elif l == 'refer':
                ref_slot = refer_label[k]
                ref_name_slot = ref_slot.split('-')[0] + '-' + 'name'
                # Deliberately best effort: any failure while resolving the
                # reference leaves the slot untouched, and a non-list result
                # from db_retrive is ignored.
                # NOTE(review): the comparison step below indexes state_pd[k]
                # for every ground-truth slot, which inserts empty entries into
                # the defaultdict, so ``ref_slot in state_pd`` is also true for
                # slots that have no predicted value yet. Confirm that this is
                # intended.
                try:
                    if ref_slot in state_pd:
                        state_pd[k]['usr'] = state_pd[ref_slot]['usr'] + state_pd[ref_slot]['sys']
                    elif ref_name_slot in informed_names and informed_names[ref_name_slot] is not None:
                        value = onto.db_retrive(informed_names[ref_name_slot], ref_slot)
                        if isinstance(value, list):
                            state_pd[k]['usr'] = value
                    elif ref_name_slot in state_pd:
                        value = onto.db_retrive(state_pd[ref_name_slot]['usr'], ref_slot)
                        if isinstance(value, list):
                            state_pd[k]['usr'] = value
                except Exception:
                    pass
            # Any other class label leaves the predicted state unchanged.

        turn_metric = {slot: 0 for slot in ALL_SLOTS}

        # Compare the prediction with the ground truth, slot by slot.
        for k in state_gt:
            gt_value = state_gt[k]
            # A ground-truth value may list alternatives. The separator test
            # must look at the value itself, not at the keys of state_gt.
            if '|' in gt_value or '>' in gt_value:
                vs_gt = re.split(r"(?:\||>| or )", gt_value)
            else:
                vs_gt = [gt_value]

            vs_pd = []
            if state_pd[k]['usr']: vs_pd += state_pd[k]['usr']
            if state_pd[k]['sys']: vs_pd += state_pd[k]['sys']

            if k not in _VERBATIM_SLOTS:
                vs_pd = [onto.normalize_label(k, v) for v in vs_pd]

            if "leaveat" in k or "arriveby" in k or 'booktime' in k:
                # Time slots: an exact match, or any predicted value accepted
                # by time_in_range against the first ground-truth value.
                if set(vs_gt) & set(vs_pd) or any(
                        time_in_range(vs_gt[0], v_pd) for v_pd in vs_pd):
                    turn_metric[k] = 1
            elif k == 'hotel-parking' or k == 'hotel-internet':
                # Collapse every value other than 'yes' to 'no' on both sides.
                vs_gt_ = {'yes' if v == 'yes' else 'no' for v in vs_gt}
                vs_pd_ = {'yes' if v == 'yes' else 'no' for v in vs_pd}
                if vs_gt_ & vs_pd_:
                    turn_metric[k] = 1
            else:
                # Every other slot: at least one value in common.
                if set(vs_gt) & set(vs_pd):
                    turn_metric[k] = 1

            # Tolerated annotation mistakes: these ground-truth values are
            # accepted when nothing at all was predicted for the slot.
            if ("destination" in k or "departure" in k) and vs_gt[0] == 'cambridge' and not vs_pd:
                turn_metric[k] = 1
            if "bookpeople" in k and vs_gt[0] == '1' and not vs_pd:
                turn_metric[k] = 1

        # Slots absent from the ground truth always count as correct,
        # whatever the prediction holds for them.
        for k in ALL_SLOTS:
            if k not in state_gt:
                turn_metric[k] = 1

        if not all(turn_metric.values()):
            # Dump everything needed to inspect the failing turn.
            print({k for k, v in turn_metric.items() if v == 0})
            print(dial_id, turn_id)
            print(delex_sys_norm)
            print(delex_usr_norm)
            pprint.pprint(state_pd)
            pprint.pprint(state_gt)
            print()
            l_dial_id = dial_id
        else:
            hit += 1

# Guard against an empty dataset, which would otherwise divide by zero.
print('join goal:', hit / ctt if ctt else 0.0)
        
        
        
                
                
        
                
        
        
    
    
    
    
    
    
    
    
    
        
        



