import json
from collections import defaultdict
import re
import logging
import pprint
import copy

import uuid

# Diagnostics emitted below (span conflicts, database-based references) go to
# data_prepare.log; filemode='w' truncates the log at the start of every run.
logging.basicConfig(filename='data_prepare.log', level=logging.INFO, filemode='w')

from tqdm import tqdm


# Required for mapping slot names in dialogue_acts.json file
# to proper designations.


from utils import (ALL_SLOTS, normalize_text, normalize_time,
                   time_diff, time_in_range, NUM_MAP, find_pos, val_in_text, simple_delex, UtterOp)

# NOTE(review): judging by the name, values of these slots are not taken from text spans -- confirm.
no_span_slots = ["hotel-type", "hotel-internet", "hotel-parking"]  # these 3 slots do not have 'dontcare'

# We need to delexicalize the dialogue text of each turn and locate the position of each value span.
# The span annotations from dialogue_acts.json are used to find all of these spans.

from utils import REQUEST_SLOTS, SPEC_TOKENS

# The dense_format files are easier to read, and some annotation errors in the
# train/val files have already been fixed in them.
# Files are decoded explicitly as UTF-8 so the result does not depend on the
# platform's default locale encoding.
with open('multiwoz2.1/dense_format/train_dials.json', encoding='utf-8') as f:
    train = json.load(f)
with open('multiwoz2.1/dense_format/val_dials.json', encoding='utf-8') as f:
    val = json.load(f)
with open('multiwoz2.1/dense_format/test_dials.json', encoding='utf-8') as f:
    test = json.load(f)

# Maps a canonical value to the surface form it may take in a sentence
# (get_good_da_span compares num_map[value] with the sentence text).
num_map = NUM_MAP

# Known addresses; delex() replaces the first one found in a sentence with '<address>'.
with open('all_address.json', encoding='utf-8') as f:
    all_address = json.load(f)
# Inverse of num_map (surface form -> canonical value).
inv_num_map = dict(zip(num_map.values(), num_map.keys()))

from ontology import Ontology
onto = Ontology()


def detect_conflicted_span(act_sp, state):
    """Resolve dialogue-act spans that several acts annotate at the same character position.

    Args:
        act_sp: iterable of ``[act, slot, value, start, end]`` records (the 'span_info'
            list of a turn).
        state: dialogue-state dict whose keys are 'domain-slot' strings; its key order
            is used as a recency ranking.

    Returns:
        List of ``[act, slot, value, start, end]`` lists:
          * 'Hospital' and 'general' acts are passed through, 'Police' acts are dropped;
          * a position annotated by one distinct act is kept when the act's domain occurs
            in ``state`` or the act is a 'Booking' act;
          * for a position with several distinct candidates, the candidate whose
            'domain-slot' comes latest in ``state`` wins. When the state has at least 3
            slots, the domain of its last slot is treated as the current domain. A
            candidate from that domain then overrides a non-Booking winner from another
            domain, and is used when no candidate's slot is in the state at all.
    """
    pos_dic = {}
    output_act_sp = []
    for act, s, v, start, end in act_sp:
        domain = act.split('-')[0]
        if domain in ('Hospital', 'general'):
            output_act_sp.append([act, s, v, start, end])
            continue
        if domain == 'Police':
            continue
        pos_dic.setdefault((start, end), []).append((act, s, v))

    mentioned_domains = {slot.split('-')[0] for slot in state}

    slot_li = list(state.keys())
    men_dom = slot_li[-1].split('-')[0] if len(slot_li) >= 3 else None

    for k, asv in pos_dic.items():
        # Deduplicate while keeping annotation order. Iterating a set of strings would
        # make tie-breaking depend on hash randomisation and vary between runs.
        candidates = list(dict.fromkeys(asv))
        if len(candidates) > 1:
            best_idx, best_act, best_s, best_v = -1, None, None, None
            # Reset per position so a candidate of an earlier span is never emitted
            # with the offsets of this one.
            mem_act, mem_s, mem_v = None, None, None
            for act, s, v in candidates:
                this_slot = (act.split('-')[0] + '-' + s).lower()
                idx = slot_li.index(this_slot) if this_slot in slot_li else -1
                if act.split('-')[0].lower() == men_dom:
                    mem_act, mem_s, mem_v = act, s, v
                if idx > best_idx:
                    best_idx = idx
                    best_act, best_s, best_v = act, s, v

            # idx 0 is a valid state position (the first slot), so test >= 0.
            if best_idx >= 0:
                best_dom = best_act.split('-')[0].lower()
                if men_dom and mem_act and best_dom != 'booking' and men_dom != best_dom:
                    # The current domain's candidate overrides a winner from another domain.
                    output_act_sp.append([mem_act, mem_s, mem_v] + list(k))
                else:
                    output_act_sp.append([best_act, best_s, best_v] + list(k))
            elif men_dom and mem_act:
                # No candidate slot in the state: fall back to the current domain.
                output_act_sp.append([mem_act, mem_s, mem_v] + list(k))

        elif asv[0][0].split('-')[0].lower() in mentioned_domains or 'Booking' in asv[0][0]:
            output_act_sp.append(list(asv[0]) + list(k))

    return output_act_sp

    
def find_requestable_slots(good_da_spans, sys_sent_cased, sys_sent):
    """Append regex-detected requestable values to ``good_da_spans`` (in place).

    Characters already covered by ``good_da_spans`` are masked with '*' on local copies
    of both sentences first, so already-annotated text is not detected again. Each
    detector adds at most its first match, in the order ref, trainid, phone, postcode,
    price, as ``(None, slot, value, start, end)``. ``value`` is always sliced from the
    lower-cased ``sys_sent``.

    Args:
        good_da_spans: list of ``(act, slot, value, start, end)``; extended in place.
        sys_sent_cased: sentence in its original casing.
        sys_sent: lower-cased sentence with the same offsets.

    Returns:
        None.
    """
    def _mask(text, lo, hi):
        # Same list-slice semantics as sent_placeholder.
        chars = list(text)
        chars[lo:hi] = '*' * (hi - lo)
        return ''.join(chars)

    for _act, _slot, _val, lo, hi in good_da_spans:
        sys_sent_cased = _mask(sys_sent_cased, lo, hi)
        sys_sent = _mask(sys_sent, lo, hi)

    # (slot name, pattern, search the cased sentence?) -- tuple order fixes append order.
    detectors = (
        ('ref', r'\b[A-Z0-9]{8}(\b|\.)', True),
        ('trainid', r'\bTR\d{4}\b', True),
        ('phone', r'\d{5}\s?\d{6,7}(\b|\.)', False),
        ('postcode', r'\b[cbCB]{2}[A-Za-z0-9]{4,5}(\b|\.)', True),
        ('price', r'\b(\d+\.)?\d+\s?(gbp|pounds|pound)', False),
    )
    for slot_name, pattern, on_cased in detectors:
        hit = re.search(pattern, sys_sent_cased if on_cased else sys_sent)
        if hit is None:
            continue
        lo, hi = hit.span()
        good_da_spans.append((None, slot_name, sys_sent[lo:hi], lo, hi))


def sent_placeholder(sent, start, end):
    """Return ``sent`` with the characters in ``[start, end)`` replaced by '*'.

    The length is unchanged whenever ``0 <= start <= end <= len(sent)``, so offsets
    computed on the original string remain valid on the masked one.
    """
    buf = list(sent)
    buf[start:end] = '*' * (end - start)
    return ''.join(buf)


def get_good_da_span(act_sp, sent, dial_loc=None):
    """Keep the dialogue-act spans that can be located in ``sent``, repairing bad offsets.

    Spans are processed in order of their start offset. Every accepted span is masked
    with '*' (via ``sent_placeholder``) in a local copy of ``sent``, so later fallback
    searches cannot claim the same characters again.

    Args:
        act_sp: iterable of ``(act, slot, value, start, end)`` records, e.g. the output of
            ``detect_conflicted_span``.
        sent: lower-cased sentence the offsets refer to.
        dial_loc: identifier used only in log messages.

    Returns:
        List of ``(act, slot, value, start, end)`` tuples with lower-cased ``value``. For
        Booking name spans, ``slot`` is replaced by the slot returned by
        ``onto.find_name``. Spans that cannot be located are dropped.
    """
    last_end = 0
    good_da_spans = []
    last_attraction_name = ''
    for act, s, v, start, end in sorted(act_sp, key=lambda x: x[3]):
        v = v.lower()
        domain = act.split('-')[0]
        if domain in ['Hotel', 'Taxi'] and s == 'type':
            continue
        if start < last_end:
            # Overlaps a span that was already processed.
            logging.info('\n%s conflict: %s-%s \n  %s' % (dial_loc, act, s, str(sorted(act_sp, key=lambda x: x[3]))))
            # An attraction type contained in the attraction name just seen is still kept.
            if domain == 'Attraction' and s == 'type' and v in last_attraction_name:
                good_da_spans.append((act, s, v, start, end))
            continue
        last_end = end

        if s in REQUEST_SLOTS:
            # Trust the offsets if they match. Otherwise use the first literal occurrence
            # of the value: a plain substring search, because values may contain regex
            # metacharacters.
            if sent[start:end] != v:
                start = sent.find(v)
                if start < 0:
                    continue
                end = start + len(v)
            good_da_spans.append((act, s, v, start, end))
            sent = sent_placeholder(sent, start, end)
            continue

        if v == 'dontcare':
            good_da_spans.append((act, s, v, start, end))
            sent = sent_placeholder(sent, start, end)
            continue

        if s == 'name':
            if sent[start:end] != v:
                # Literal search (not a regex) -- names can contain '(', '.', '+', ...
                start = sent.find(v)
                if start < 0:
                    continue
                end = start + len(v)
            if domain == 'Booking':
                # Booking acts carry no domain; resolve the slot from the name itself.
                norm_s, _, _ = onto.find_name(v)
                if norm_s is not None:
                    good_da_spans.append((act, norm_s, v, start, end))
            else:
                good_da_spans.append((act, s, v, start, end))
            sent = sent_placeholder(sent, start, end)
            if domain == 'Attraction':
                last_attraction_name = v
            continue

        if v in num_map and sent[start:end] == num_map[v]:
            # The sentence contains num_map's surface form of the value.
            good_da_spans.append((act, s, v, start, end))
            sent = sent_placeholder(sent, start, end)
            continue

        if re.search(r'^same', sent[start:end]):
            # Text starting with 'same': keep the surface text as the value.
            good_da_spans.append((act, s, sent[start:end], start, end))
            sent = sent_placeholder(sent, start, end)
            continue

        if sent[start:end] == v:
            good_da_spans.append((act, s, v, start, end))
            sent = sent_placeholder(sent, start, end)
            continue

        # Offsets slightly off: look for the value within 2 characters of the span.
        # Clamp at 0 -- a negative slice start would wrap around to the sentence end.
        lo = max(0, start - 2)
        window = sent[lo:end + 2]
        if v in window:
            start_ = sent.find(v, lo)
            end_ = start_ + len(v)
            assert sent[start_:end_] == v
            good_da_spans.append((act, s, v, start_, end_))
            # Mask the text actually found, not the (wrong) annotated offsets.
            sent = sent_placeholder(sent, start_, end_)
            continue

        if v in num_map and num_map[v] in window:
            num = num_map[v]
            # Locate the surface form; the canonical value is not in the window here.
            start_ = sent.find(num, lo)
            end_ = start_ + len(num)
            assert sent[start_:end_] == num
            good_da_spans.append((act, s, v, start_, end_))
            sent = sent_placeholder(sent, start_, end_)
            continue

        if v == 'centre':
            m = re.search(r'\bcenter\b', sent)
            if m:
                start, end = m.span(0)
                good_da_spans.append((act, s, v, start, end))
                sent = sent_placeholder(sent, start, end)
                continue

        if v in ['north', 'west', 'east', 'south', 'cheap', 'expensive', 'moderate']:
            m = re.search(r'\b%s\b' % v, sent)
            if m:
                start, end = m.span(0)
                good_da_spans.append((act, s, v, start, end))
                sent = sent_placeholder(sent, start, end)
                continue
    return good_da_spans


def delex(sent, act_sp):
    """Replace value spans in ``sent`` with placeholder tokens.

    Args:
        sent: lower-cased sentence (propcess_turn passes its lower-cased text).
        act_sp: ``(act, slot, value, start, end)`` spans, expected to be sorted by
            ``start``; ``act`` is None for values added by ``find_requestable_slots``.

    Returns:
        ``(delex_sent, delex_list, has_NoBook)``:
          * ``delex_sent``: the sentence with placeholders. When spans are given, it is
            also whitespace-normalized and lower-cased. The first address from
            ``all_address`` found in it becomes '<address>'.
          * ``delex_list``: ``[placeholder, value]`` pairs for informable values.
            Requestable-slot placeholders are not recorded.
          * ``has_NoBook``: True if any act name contains 'NoBook' or 'NoOffer'.
    """
    seg_list = []
    delex_list = []
    has_NoBook = False
    if act_sp:
        last_end = 0
        last_sa = None
        last_add_s = None
        # slots containing 'name' that were produced by Booking acts
        mentioned_book_doms = set([s for a, s, _, _, _ in act_sp if 'name' in s and a.split('-')[0] == 'Booking'])

        for a, s, v, begin, end in act_sp:
            if a and ('NoBook' in a or 'NoOffer' in a): has_NoBook = True
            if begin >= last_end:
                intermediate_str = sent[last_end:begin]
                # A suffix ('s', 'ly', 'ly-priced') right after the previous span is
                # appended to the previous placeholder's value (at most once per
                # placeholder) and removed from the text.
                m = re.search(r"^(s|ly|ly-priced)\b", intermediate_str)
                if m and delex_list:
                    if last_add_s!=delex_list[-1][0]:
                        delex_list[-1][1] = delex_list[-1][1]+m.group(0)
                        last_add_s = delex_list[-1][0]
                    intermediate_str = re.sub(r"^(s|ly|ly-priced)\b", "", intermediate_str).strip()
                seg_list.append(intermediate_str)

                # requestable slots: placeholder token from REQUEST_SLOTS, value not recorded
                if s in REQUEST_SLOTS:
                    seg_list.append(REQUEST_SLOTS[s])
                elif a.split('-')[0] == 'Booking':
                    # Booking acts carry no domain: derive the placeholder domain from the slot
                    if s and 'name' in s:
                        seg_list.append('<%s>' % s)
                        delex_list.append(['<%s>' % s, v])
                    elif s and 'booktime' in s:
                        seg_list.append('<restaurant-booktime>')
                        delex_list.append(['<restaurant-booktime>', v])
                    elif s and 'bookstay' in s:
                        seg_list.append('<hotel-bookstay>')
                        delex_list.append(['<hotel-bookstay>', v])
                    elif len(mentioned_book_doms) == 1:
                        # people/day: attribute to the single domain whose name was booked
                        book_dom = list(mentioned_book_doms)[0].split('-')[0]
                        if s and 'bookpeople' in s:
                            seg_list.append('<%s-bookpeople>'%book_dom)
                            delex_list.append(['<%s-bookpeople>'%book_dom, v])
                        if s and 'bookday' in s:
                            seg_list.append('<%s-bookday>'%book_dom)
                            delex_list.append(['<%s-bookday>'%book_dom, v])
                    else:
                        # ambiguous booking domain: keep the original text
                        seg_list.append(sent[begin:end])
                elif s == 'entrancefee' and re.search(r'(\d+\.)?\d+\s?(gbp|pounds|pound)', v):
                    seg_list.append('<price>')
                # slots whose text is kept verbatim (not delexicalised)
                elif s in ['entrancefee', 'choice', 'openhours',  'internet', 'duration', 'parking',
                           # 'stars','bookstay',
                           # 'pricerange'
                           # 'bookpeople', 'bookday', 'stars', 'booktime',
                           # 'day'
                           ]:
                    seg_list.append(sent[begin:end])
                elif a.split('-')[0] in ['Hotel','Taxi'] and s == 'type':
                    seg_list.append(sent[begin:end])
                # NOTE(review): a people count followed by "night" (or a stay followed by
                # "people") presumably indicates a mislabelled slot, so the text is kept.
                elif s == 'bookpeople' and re.search(r"^\s?night", sent[end:]):
                    seg_list.append(sent[begin:end])
                elif s == 'bookstay' and re.search(r"^\s?people", sent[end:]):
                    seg_list.append(sent[begin:end])
                else:
                    # generic informable value -> '<domain-slot>' placeholder
                    domain = a.split('-')[0].lower()
                    s = domain + '-' + s
                    # if 'matter' in v or 'any' in v or 'care' in v or 'prefer' in v:
                    #     seg_list.append(sent[begin:end])
                    if v == 'dontcare':
                        seg_list.append(sent[begin:end])
                    else:
                        seg_list.append('<%s>' % s)
                        delex_list.append(['<%s>' % s, v])

                last_end = end
                # Note: s may already be 'domain-slot' here (rewritten above); last_sa is
                # only compared with 'hotel-type' / 'taxi-type' after the loop.
                if a: last_sa = '%s-%s' % (a.split('-')[0].lower(), s)
                else: last_sa = None
            else:
                # Overlapping span: only an attraction type (nested in an attraction
                # name) is tolerated; it is recorded without a placeholder in the text.
                if a.split('-')[0] == 'Attraction' and s == 'type':
                    delex_list.append(['<attraction-type>', v])
                else:
                    assert 0

        # trailing text after the last span, with the same suffix handling as above
        intermediate_str = sent[last_end:]
        m = re.search(r"^(s|ly|ly-priced)\b", intermediate_str)
        if m and delex_list and last_sa not in ['hotel-type', 'taxi-type']:
            if last_add_s != delex_list[-1][0]:
                delex_list[-1][1] = delex_list[-1][1] + m.group(0)
                last_add_s = delex_list[-1][0]
            intermediate_str = re.sub(r"^(s|ly|ly-priced)\b", "", intermediate_str).strip()
        seg_list.append(intermediate_str)

        # collapse whitespace, lower-case, and glue stray suffixes to the previous token
        output_s = ' '.join(' '.join(seg_list).split()).lower()
        output_s = re.sub(r" (s|ly|ly-priced)\b", r"\1", output_s)
        # process address: only the first matching address is replaced
        for addr in all_address:
            addr = addr.lower()
            if addr in output_s:
                output_s = output_s.replace(addr, '<address>')
                break

        return output_s, delex_list, has_NoBook
    else:
        # No spans: only address replacement; sent is otherwise returned as given.
        for addr in all_address:
            addr = addr.lower()
            if addr in sent:
                sent = sent.replace(addr, '<address>')
                break

        return sent, delex_list, has_NoBook


def propcess_turn(act, sent, state, dial_loc):
    """Clean one side of a turn and delexicalise it using its dialogue-act spans.

    Args:
        act: JSON string with 'dialog_act' and 'span_info' keys, or a falsy value.
        sent: raw utterance text, or a falsy value.
        state: dialogue-state dict (slot -> value) used to resolve conflicting spans.
        dial_loc: identifier used only in log messages.

    Returns:
        7-tuple ``(sent, delex_sent, act_da, good_da_spans, delex_span_list,
        delex_sp_dic, hasNobook)``. ``sent`` is whitespace-normalized and lower-cased.
        ``delex_sp_dic`` groups the delexicalised values by placeholder.
        When ``act`` or ``sent`` is empty: ``('', '', {}, [], [], {}, False)``.
    """
    if not (act and sent):
        return '', '', {}, [], [], {}, False
    parsed = json.loads(act)
    dialog_acts, raw_spans = parsed['dialog_act'], parsed['span_info']
    cased = ' '.join(sent.split())
    lowered = cased.lower()
    if 'type:' in lowered and 'number:' in lowered:
        lowered = lowered.replace('type:', 'type : ').replace('number:', 'number : ')
    # Resolve conflicts, keep spans that match the text, then add regex-found requestables.
    resolved = detect_conflicted_span(raw_spans, state)
    spans = get_good_da_span(resolved, lowered, dial_loc)
    find_requestable_slots(spans, cased, lowered)  # extends `spans` in place
    spans = sorted(spans, key=lambda item: item[3])
    delex_text, delex_pairs, has_nobook = delex(lowered, spans)
    by_placeholder = defaultdict(list)
    for placeholder, value in delex_pairs:
        by_placeholder[placeholder].append(value)
    return lowered, delex_text, dialog_acts, spans, delex_pairs, by_placeholder, has_nobook
    
    
def get_turn_label_each_slot(slot, state, last_state, usr_da_sp, sys_da_sp, usr_sent, sys_sent,
                             informed_names, dial_loc=None):
    """Label where the current value of one dialogue-state slot comes from.

    Args:
        slot: 'domain-slot' key of ``state`` to label.
        state: normalized dialogue state of this turn; ``state[slot]`` must exist.
        last_state: dialogue state of the previous turn.
        usr_da_sp, sys_da_sp: sequences of ``[placeholder, value]`` pairs for the user /
            system utterance; entries whose placeholder is ``'<slot>'`` are this slot's
            spans (presumably the delex lists built by ``delex`` -- verify at call site).
        usr_sent, sys_sent: user / system utterance text searched for the value.
        informed_names: maps 'restaurant-name', 'hotel-name' and 'attraction-name' to
            the list of names the system most recently offered, or None.
        dial_loc: identifier used only in log messages.

    Returns:
        A tuple ``(source, status, *extra)``. ``status`` is 'inherit' (same value as in
        ``last_state``), 'change' (different value) or 'new' (absent from
        ``last_state``). ``source`` is one of:
          * 'dontcare';
          * 'usr_span' / 'sys_span': the value is among the user / system spans;
          * 'usr_sent' / 'sys_sent': found in the raw text; extra = (value, position);
          * 'sys_span_multi': a new name that is among ``informed_names[slot]``;
          * 'refer': equals related slots (``onto.get_ref_slots``); extra = ([slots],);
          * 'refer_db': recovered via a database lookup of a name; extra = ([slot],);
          * 'none': no source found.

    Raises:
        AssertionError: if ``slot`` belongs to a slot family not handled below.
    """
    if slot in last_state and last_state[slot] == state[slot]: status = 'inherit'
    elif slot in last_state and last_state[slot] != state[slot]: status = 'change'
    elif slot not in last_state: status = 'new'
    else: status = None  # unreachable: the three cases above are exhaustive

    # Collect this slot's span values, normalized like the state values they are compared to.
    if re.search(r"(name|destination|departure)", slot):
        usr_sp_vals, sys_sp_vals = [], []
        for item in usr_da_sp:
            if item[0] == '<%s>' % slot:
                _, norm_name, _ = onto.find_name(item[1])
                if norm_name: usr_sp_vals.append(norm_name)
                else: usr_sp_vals.append(item[1])
        for item in sys_da_sp:
            if item[0] == '<%s>' % slot:
                _, norm_name, _ = onto.find_name(item[1])
                if norm_name: sys_sp_vals.append(norm_name)
                else: sys_sp_vals.append(item[1])
    else:
        usr_sp_vals = [onto.normalize_label(slot, item[1]) for item in usr_da_sp if item[0] == '<%s>' % slot]
        sys_sp_vals = [onto.normalize_label(slot, item[1]) for item in sys_da_sp if item[0] == '<%s>' % slot]

    s_v = state[slot]

    # Each branch below checks sources in priority order:
    # user span > user text > system span > system text > references.
    if re.search(r"(name|destination|departure)", slot):
        # name
        if s_v == 'dontcare': return 'dontcare', status

        _, norm_s_v, variants = onto.find_name(s_v)
        if norm_s_v is None: variants = [s_v]

        usr_sent_v, usr_sent_pos = val_in_text(variants, usr_sent)
        sys_sent_v, sys_sent_pos = val_in_text(variants, sys_sent)

        if s_v in usr_sp_vals:
            return 'usr_span', status
            # if usr_sp_vals[0] != normalize_label(slot, s_v):
        elif usr_sent_pos:
            return 'usr_sent', status, usr_sent_v, usr_sent_pos
        elif s_v in sys_sp_vals:
            return 'sys_span', status
        elif sys_sent_pos:
            return 'sys_sent', status, sys_sent_v, sys_sent_pos
        else:
            refs = []
            if re.search(r"(destination|departure)", slot):
                # taxi/train places may copy another slot of the state ...
                for ref_s in onto.get_ref_slots(slot):
                    if ref_s in state and state[ref_s] == s_v:
                        refs.append(ref_s)
                if refs and status in ['new', 'change']: return 'refer', status, refs
                # ... or a name the system offered earlier
                for ref_s in onto.get_ref_slots(slot):
                    if ref_s in informed_names and informed_names[ref_s] and s_v in informed_names[ref_s]:
                        refs.append(ref_s)
                if refs and status in ['new', 'change']: return 'refer', status, refs
            else:
                if status == 'new' and informed_names[slot] and s_v in informed_names[slot]:
                    return 'sys_span_multi', status
            return 'none', status

    elif re.search(r"bookpeople", slot):
        # people
        if s_v == 'dontcare': return 'dontcare', status
        norm_v, variants = copy.deepcopy(onto.find_value(s_v))
        if norm_v in variants: variants.remove(norm_v)

        usr_sent_v, usr_sent_pos = val_in_text(variants, usr_sent)
        # "for N people" / "for N person" / "for N": the number starts 4 chars after the match
        if 'for %s people'%s_v in usr_sent:
            usr_num_pos = find_pos('for %s people'%s_v, usr_sent)
        elif 'for %s person'%s_v in usr_sent:
            usr_num_pos = find_pos('for %s person' % s_v, usr_sent)
        else:
            usr_num_pos = find_pos('for %s' % s_v, usr_sent)

        usr_num_pos2 = find_pos('%s of us' % s_v, usr_sent) # more rules: "they'll be 6 of us"
        sys_sent_v, sys_sent_pos = val_in_text(variants, sys_sent)
        sys_num_pos = find_pos('for %s' % s_v, sys_sent)

        if s_v in usr_sp_vals:
            return 'usr_span', status
            # if usr_sp_vals[0] != normalize_label(slot, s_v):
        elif usr_num_pos>=0:
            return 'usr_sent', status, s_v, (usr_num_pos+4, usr_num_pos+4+len(s_v))
        elif usr_num_pos2>=0:
            return 'usr_sent', status, s_v, (usr_num_pos2, usr_num_pos2+len(s_v))
        elif usr_sent_pos:
            # weak variants ('me', 'one', ...) do not count for an inherited value
            if usr_sent_v in ['for me', 'me', 'one'] and status == 'inherit': return 'none', status
            else: return 'usr_sent', status, usr_sent_v, usr_sent_pos
        elif s_v in sys_sp_vals:
            return 'sys_span', status
        elif sys_num_pos>=0:
            return 'sys_sent', status, s_v, (sys_num_pos+4, sys_num_pos+4+len(s_v))
        elif sys_sent_pos:
            return 'sys_sent', status, sys_sent_v, sys_sent_pos
        else:
            # referred in dialog state
            refs = []
            for ref_s in onto.get_ref_slots(slot):
                if ref_s in state and state[ref_s] == s_v:
                    refs.append(ref_s)
            if refs and status in ['new', 'change']:
                return 'refer', status, refs
            return 'none', status

    elif re.search(r"star", slot):
        # print(usr_sp_vals)
        if s_v == 'dontcare':
            return 'dontcare', status
        if s_v in usr_sp_vals:
            return 'usr_span', status
            # if usr_sp_vals[0] != normalize_label(slot, s_v):
        # Note: the text branches return the matched digit, which is not compared with s_v.
        elif re.search(r"\bstar (rating |rate )?of (only )?(\d)\b", usr_sent):
            m = re.search(r"\bstar (rating |rate )?of (only )?(\d)\b", usr_sent)
            return 'usr_sent', status, m.groups()[2], (m.start(3), m.start(3) + 1)
        elif re.search(r"\b(\d) (rating|stars|star)\b", usr_sent):
            m = re.search(r"\b(\d) (rating|stars|star)\b", usr_sent)
            return 'usr_sent', status, m.groups()[0], (m.start(1), m.start(1) + 1)
        elif s_v in sys_sp_vals:  # mentioned in a sys span
            return 'sys_span', status
        elif re.search(r"\bstar (rating |rate )?of (only )?\d\b", sys_sent):
            m = re.search(r"\bstar (rating |rate )?of (only )?(\d)\b", sys_sent)
            return 'sys_sent', status, m.groups()[2], (m.start(3), m.start(3) + 1)
        elif re.search(r"\b\d (rating|stars|star)\b", sys_sent):
            m = re.search(r"\b(\d) (rating|stars|star)\b", sys_sent)
            return 'sys_sent', status, m.groups()[0], (m.start(1), m.start(1) + 1)
        else:  # refer or  not mentioned
            # referred in dialog state
            # refs = []
            # for ref_s in onto.get_ref_slots(slot):
            #     if ref_s in state and onto.normalize_label(ref_s, state[ref_s]) == s_v:
            #         refs.append(ref_s)
            # if refs and status in ['new', 'change']:
            #     return 'refer', status, refs
            return 'none', status

    elif re.search(r"area|pricerange", slot):
        if s_v == 'dontcare':
            return 'dontcare', status

        norm_v, variants = copy.deepcopy(onto.find_value(s_v))
        usr_sent_v, usr_sent_pos = val_in_text(variants, usr_sent)
        sys_sent_v, sys_sent_pos = val_in_text(variants, sys_sent)
        if s_v in usr_sp_vals:
            return 'usr_span', status
        elif usr_sent_pos:
            return 'usr_sent', status, usr_sent_v, usr_sent_pos
        elif s_v in sys_sp_vals:
            return 'sys_span', status
        elif sys_sent_pos:
            return 'sys_sent', status, sys_sent_v, sys_sent_pos
        else:
            refs = []
            for ref_s in onto.get_ref_slots(slot):
                if ref_s in state and state[ref_s] == s_v: refs.append(ref_s)
            if refs and status in ['new', 'change']:
                return 'refer', status, refs
            # maybe already ref a previous entity, if ' same ' in usr_sent:
            # Database lookups: first via the same domain's names (system-offered, then
            # the state), then via the other domains' names.
            all_name_slots = ['restaurant-name', 'hotel-name', 'attraction-name']
            name_slot = slot.split('-')[0] + '-name'
            if informed_names[name_slot]:
                db_v = onto.db_retrive(informed_names[name_slot], slot)
                if s_v in db_v and status in ['new', 'change']:
                    logging.info('%s, sys DataBase Ref '%dial_loc)
                    return 'refer_db', status, [slot]

            if name_slot in state:
                db_v = onto.db_retrive(state[name_slot], slot)
                if db_v == s_v and status in ['new', 'change']:
                    logging.info('%s, usr DataBase Ref ' % dial_loc)
                    return 'refer_db', status, [slot]

            all_name_slots.remove(name_slot)
            for n_s in all_name_slots:
                new_slot = n_s.split('-')[0] +'-'+ slot.split('-')[1]
                if informed_names[n_s] and status in ['new', 'change']:
                    db_v = onto.db_retrive(informed_names[n_s], slot)
                    if s_v in db_v:
                        logging.info('%s, sys DataBase Ref ' % dial_loc)
                        return 'refer_db', status, [new_slot]
                if n_s in state and status in ['new', 'change']:
                    db_v = onto.db_retrive(state[n_s], slot)
                    if db_v == s_v:
                        logging.info('%s, usr DataBase Ref ' % dial_loc)
                        # return 'refer', status, []
                        return 'refer_db', status, [new_slot]
            return 'none', status

    elif re.search(r"bookstay", slot):
        if s_v == 'dontcare':
            return 'dontcare', status
        if s_v in usr_sp_vals:
            return 'usr_span', status
        elif re.search(r"\b\d night[s]?\b", usr_sent):
            m = re.search(r"\b(\d) night[s]?\b", usr_sent)
            return 'usr_sent', status, m.groups()[0], (m.start(1), m.start(1)+1)
        elif s_v in sys_sp_vals:
            return 'sys_span', status
        # elif re.search(r"\b\d+ night[s]?\b", sys_sent):
        #     return 'sys_sent', status, re.search(r"\bstar (rating |rate )?of (only )?\d\b", sys_sent).span()
        # elif re.search(r"\b\d (rating|stars|star)\b", sys_sent):
        #     return 'sys_sent', status, re.search(r"\b\d (rating|stars|star)\b", sys_sent).span()
        else:
            return 'none', status

    elif re.search(r"bookday|train-day|food", slot):
        if s_v == 'dontcare':
            return 'dontcare', status

        norm_v, variants = copy.deepcopy(onto.find_value(s_v))
        usr_sent_v, usr_sent_pos = val_in_text(variants, usr_sent)
        sys_sent_v, sys_sent_pos = val_in_text(variants, sys_sent)

        if s_v in usr_sp_vals:
            return 'usr_span', status
        elif usr_sent_pos:
            return 'usr_sent', status, usr_sent_v, usr_sent_pos
        elif s_v in sys_sp_vals:
            return 'sys_span', status
        elif sys_sent_pos:
            return 'sys_sent', status, sys_sent_v, sys_sent_pos
        else:  # refer, or not mentioned
            refs = []
            for ref_s in onto.get_ref_slots(slot):
                if ref_s in state and state[ref_s] == s_v: refs.append(ref_s)
            if refs and status in ['new', 'change']:
                return 'refer', status, refs
            return 'none', status

    elif re.search(r"(leaveat|arriveby|booktime)", slot):
        # time-valued slots
        usr_time_pos = find_pos(s_v, usr_sent)
        sys_time_pos = find_pos(s_v, sys_sent)
        if s_v == 'dontcare':
            return 'dontcare', status
        # any user span for this slot counts, even if its value differs from s_v
        elif usr_sp_vals: # and time_in_range(s_v, normalize_label(slot, usr_sp_vals[0])):
            return 'usr_span', status
            # if usr_sp_vals[0] != normalize_label(slot, s_v):
        elif usr_time_pos>=0:
            return 'usr_sent', status, s_v, (usr_time_pos, usr_time_pos+len(s_v))
        elif any(s_v == sys_v for sys_v in sys_sp_vals):
            return 'sys_span', status
        elif sys_time_pos>=0:
            return 'sys_sent', status, s_v, (sys_time_pos, sys_time_pos+len(s_v))
        else:
            refs = []
            if re.search(r'(leaveat|arriveby)', slot):
                for ref_s in onto.get_ref_slots(slot):
                    if ref_s in state and time_in_range(state[ref_s], s_v):
                        refs.append(ref_s)
            elif 'booktime' in slot:
                for ref_s in onto.get_ref_slots(slot):
                    if ref_s in state and state[ref_s] == s_v:
                        refs.append(ref_s)
            if refs and status in ['new', 'change']:
                return 'refer', status, refs
            return 'none', status

    elif slot == 'attraction-type':
        if s_v == 'dontcare': return 'dontcare', status
        norm_v, variants = copy.deepcopy(onto.find_value(s_v))
        usr_sent_v, usr_sent_pos = val_in_text(variants, usr_sent)
        sys_sent_v, sys_sent_pos = val_in_text(variants, sys_sent)

        if s_v in usr_sp_vals:
            return 'usr_span', status
        elif usr_sent_pos:
            return 'usr_sent', status, usr_sent_v, usr_sent_pos
        # containment: the first user span is part of the attraction name in the state
        elif usr_sp_vals and 'attraction-name' in state and usr_sp_vals[0] in state['attraction-name'] and status in ['new', 'change']:
            return 'usr_span', status
        elif s_v in sys_sp_vals:
            return 'sys_span', status
        elif sys_sent_pos:
            return 'sys_sent', status, sys_sent_v, sys_sent_pos
        else:
            # database lookup via the attraction name (state first, then system-offered names)
            name_slot = 'attraction-name'
            if name_slot in state:
                db_v = onto.db_retrive(state[name_slot], slot)
                if db_v == s_v and status in ['new', 'change']:
                    logging.info('%s, usr DataBase Ref ' % dial_loc)
                    return 'refer_db', status, [slot]
            if informed_names[name_slot]:
                db_v = onto.db_retrive(informed_names[name_slot], slot)
                if s_v in db_v and status in ['new', 'change']:
                    logging.info('%s, sys DataBase Ref ' % dial_loc)
                    return 'refer_db', status, [slot]

            return 'none', status
    else:
        assert 0, 'uncovered:%s'% slot


# with open('review/test_dials_aug.json') as f:
#     aug = json.load(f)
# with open('review/test_dials_aug2.json') as f:
#     aug2 = json.load(f)
# for k in aug2:
#     aug[k] = copy.deepcopy(aug2[k])


# for k in val: aug[k] = copy.deepcopy(val[k])
# Separators used in raw belief-state values that list several alternatives
# (e.g. "a|b", "a>b", "a or b"); only the last alternative is kept.
_STATE_VALUE_SEP = re.compile(r"(?:\||>| or )")

# Slots whose name contains any of these substrings are not labelled at all.
_IGNORED_SLOT_TOKENS = ('hospital', 'police', 'entrancefee', 'bus')


def _update_informed_names(informed_names, delex_sys_sp_dic):
    """Refresh, in place, the venue names the system mentioned this turn.

    For every ``<domain-name>`` placeholder present in ``delex_sys_sp_dic``,
    ``informed_names[domain-name]`` is replaced by the names that
    ``onto.find_name`` resolved (names it could not resolve are dropped).
    Domains without such a placeholder keep their previous value.
    """
    for k in informed_names:
        placeholder = '<%s>' % k
        if placeholder not in delex_sys_sp_dic:
            continue
        resolved = []
        for name in delex_sys_sp_dic[placeholder]:
            _, norm_k, _ = onto.find_name(name)
            if norm_k:
                resolved.append(norm_k)
        informed_names[k] = copy.deepcopy(resolved)


def _merge_bookname_slots(state):
    """Fold ``<domain>-bookname`` into ``<domain>-name`` in place.

    If a bookname slot exists, its value overwrites (or creates) the plain
    name slot and the bookname key is removed.
    """
    for domain in ('restaurant', 'hotel', 'attraction'):
        book_key = '%s-bookname' % domain
        if book_key in state:
            state['%s-name' % domain] = state[book_key]
            del state[book_key]


def _normalize_state(state):
    """Return a deep copy of ``state`` with every value ontology-normalised.

    Multi-alternative values are first reduced to their last alternative.
    """
    norm_state = {}
    for slot, value in state.items():
        if '|' in value or '>' in value or ' or ' in value:
            value = _STATE_VALUE_SEP.split(value)[-1]
        norm_state[slot] = onto.normalize_label(slot, value)
    return copy.deepcopy(norm_state)


def _build_inform_label(delex_sys_sp_dic):
    """Map each system-informed slot (brackets stripped) to its values.

    Time-like slots (``leaveat``, ``arriveby``, ``restaurant-booktime``) keep
    only their first value, normalised through ``onto.normalize_label``.
    Other slots map to the value list from ``delex_sys_sp_dic`` as-is.
    """
    inform_label = {}
    for placeholder, values in delex_sys_sp_dic.items():
        slot = placeholder.strip('<').strip('>')
        if "leaveat" in slot or "arriveby" in slot or 'restaurant-booktime' in slot:
            # Pass the bare slot name, as every other normalize_label call in
            # this script does, not the bracketed '<slot>' placeholder.
            inform_label[slot] = [onto.normalize_label(slot, values[0])]
        else:
            inform_label[slot] = values
    return inform_label


# Build per-turn labels for every split and cache them as
# cache_dial_data_<split>.json.
for dataset, dialogs in zip(['train', 'val', 'test'], [train, val, test]):
    all_dial_data = {}
    for dial_id, dial in tqdm(dialogs.items(), desc='dataset %s' % dataset):
        last_state = {}
        # Per-domain names most recently offered by the system (None until then).
        informed_names = {'restaurant-name': None, 'hotel-name': None, 'attraction-name': None}
        this_dialog = []

        for turn_id, turn in enumerate(dial):
            dial_loc = '%s-%s-%s' % (dataset, dial_id, turn_id)
            # NOTE: `state` aliases turn['state']; _merge_bookname_slots below
            # edits the loaded dataset in place before `state` is rebound.
            state = turn['state']
            sys_act = turn['acts']['sys']
            usr_act = turn['acts']['usr']

            # propcess_turn sees the raw state (before bookname merge / normalisation).
            sys_sent, delex_sys_sent, sys_act_da, sys_act_sp, delex_sys_sp, delex_sys_sp_dic, hasNobook = \
                propcess_turn(sys_act, turn['sys_sent'], state, dial_loc)
            delex_sys_norm = normalize_text(delex_sys_sent)

            _update_informed_names(informed_names, delex_sys_sp_dic)

            usr_sent, delex_usr_sent, usr_act_da, usr_act_sp, delex_usr_sp, delex_usr_sp_dic, _ = \
                propcess_turn(usr_act, turn['usr_sent'], state, dial_loc)
            delex_usr_norm = normalize_text(delex_usr_sent)

            _merge_bookname_slots(state)
            state = _normalize_state(state)

            # Key order is kept stable so the cached JSON layout does not change.
            turn_label = {
                'inform_label': {}, 'inform_slot_axu': {}, 'refer_label': {},
                'dial_state_aux': {s: 1 for s in last_state}, 'class_label': {}}

            # Spans recovered from raw sentences rather than from dialogue acts.
            new_sys_sp, new_usr_sp = [], []

            # Slots are visited in reverse insertion order; this order decides
            # the order in which recovered spans are appended below.
            for slot in list(state.keys())[::-1]:
                if any(tok in slot for tok in _IGNORED_SLOT_TOKENS):
                    continue
                if slot in no_span_slots:
                    # Classification-only slots: label the value when new or changed.
                    if slot not in last_state or state[slot] != last_state[slot]:
                        turn_label['class_label'][slot] = state[slot]
                    continue

                outputs = get_turn_label_each_slot(
                    slot, state, last_state, delex_usr_sp, delex_sys_sp, delex_usr_norm, delex_sys_norm,
                    informed_names=informed_names,
                    dial_loc=dial_loc)
                # outputs[0] is the label source; for the *_sent cases outputs[2]
                # is the value and outputs[3] is passed on to simple_delex
                # (presumably its position — TODO confirm).
                label_type = outputs[0]

                if label_type == 'dontcare' and outputs[1] != 'inherit':
                    turn_label['class_label'][slot] = 'dontcare'
                elif label_type == 'usr_span':
                    turn_label['class_label'][slot] = 'copy_value'
                elif label_type == 'usr_sent':
                    turn_label['class_label'][slot] = 'copy_value'
                    delex_usr_sp_dic['<%s>' % slot].append(outputs[2])
                    new_usr_sp.append(('<%s>' % slot, outputs[2], outputs[3]))
                elif label_type in ('sys_span', 'sys_span_multi'):
                    turn_label['class_label'][slot] = 'inform'
                elif label_type == 'sys_sent':
                    turn_label['class_label'][slot] = 'inform'
                    delex_sys_sp_dic['<%s>' % slot].append(outputs[2])
                    new_sys_sp.append(('<%s>' % slot, outputs[2], outputs[3]))
                elif label_type in ('refer', 'refer_db'):
                    turn_label['class_label'][slot] = 'refer'
                    turn_label['refer_label'][slot] = outputs[2][0]

            # Built after the slot loop, which may add sys spans to delex_sys_sp_dic.
            turn_label['inform_label'] = _build_inform_label(delex_sys_sp_dic)
            # Keys here keep their '<...>' brackets, as in the original output.
            turn_label['inform_slot_axu'] = {s: 1 for s in delex_sys_sp_dic}

            if new_usr_sp:
                delex_usr_norm = simple_delex(delex_usr_norm, new_usr_sp)
            if new_sys_sp:
                delex_sys_norm = simple_delex(delex_sys_norm, new_sys_sp)

            this_dialog.append(
                {
                    'dial_loc': dial_loc,
                    'delex_sys_norm': delex_sys_norm,
                    'delex_usr_norm': delex_usr_norm,
                    'turn_label': copy.deepcopy(turn_label),
                    'delex_sys_sp_dic': copy.deepcopy(delex_sys_sp_dic),
                    'delex_usr_sp_dic': copy.deepcopy(delex_usr_sp_dic),
                    'belief_state': copy.deepcopy(state),
                    'informed_names': copy.deepcopy(informed_names)
                }
            )

            last_state = copy.deepcopy(state)

        all_dial_data[dial_id] = copy.deepcopy(this_dialog)

    with open('cache_dial_data_%s.json' % dataset, 'w', encoding='utf-8') as f:
        json.dump(all_dial_data, f)








