import json
import re
from fuzzywuzzy.fuzz import WRatio
from utils import REFER_CLUSTER, normalize_time, NUM_MAP
import logging
import random
import os

logger = logging.getLogger(__name__)


class Ontology:
    """Lookup tables and helpers for normalizing and sampling dialogue slot values.

    On construction the following files are read from the current working
    directory:

    * ``ontology.json`` -- ``{domain: {norm_value: [variant, ...]}}``, used for
      entity names and food values;
    * ``label_map.json`` -- ``{norm_value: [synonym, ...]}``, used for other values;
    * ``multiwoz2.1/origin/{attraction,hotel,restaurant}_db.json`` -- entity
      records, indexed here by their ``'name'`` field;
    * ``names_from_web.json`` -- extra entity names used when sampling replacements;
    * ``all_slot_values.json`` -- source of the ``restaurant-food`` value list.
    """

    # Patterns used by ``core`` to drop a leading article and a trailing
    # generic venue word before fuzzy matching. Word boundaries keep words such
    # as "theatre" or "crossbar" from being truncated to "atre" / "cross".
    _CORE_PREFIX = re.compile(r'^the\b')
    _CORE_SUFFIX = re.compile(
        r'\b(restaurant bar|restaurant|hotel|guesthouse|guest house|house|college|bar|church|food)$')

    def __init__(self):
        with open('ontology.json', encoding='utf-8') as f:
            self.onto = json.load(f)  # entity names and food values
        # Inverse index: any value (normalized or variant) -> (domain, norm_value).
        self.onto_inv = {}
        for domain, values in self.onto.items():
            for norm_v, other_vs in values.items():
                self.onto_inv[norm_v] = (domain, norm_v)
                for v in other_vs:
                    self.onto_inv[v] = (domain, norm_v)

        with open('label_map.json', encoding='utf-8') as f:
            self.label_map = json.load(f)  # synonyms for all other values
        # Inverse index: any value (normalized or synonym) -> (norm_value, synonyms).
        self.label_map_inv = {}
        for norm_v, other_vs in self.label_map.items():
            self.label_map_inv[norm_v] = (norm_v, other_vs)
            for v in other_vs:
                self.label_map_inv[v] = (norm_v, other_vs)

        # Entity databases: {domain: {entity_name: record}}.
        self.db = {}
        for domain in ('attraction', 'hotel', 'restaurant'):
            with open('multiwoz2.1/origin/%s_db.json' % domain, encoding='utf-8') as f:
                self.db[domain] = {item['name']: item for item in json.load(f)}

        with open('names_from_web.json', encoding='utf-8') as f:
            self.names_web = json.load(f)

        # Every price / area value together with its label-map synonyms.
        self.prices = []
        for p in ['cheap', 'moderate', 'expensive']:
            self.prices.append(p)
            self.prices.extend(self.label_map[p])

        self.areas = []
        for p in ['centre', 'north', 'east', 'west', 'south']:
            self.areas.append(p)
            self.areas.extend(self.label_map[p])

        with open('all_slot_values.json', encoding='utf-8') as f:
            self.all_slot_values = json.load(f)
        # Drop values containing the word "or" (presumably disjunctions such as
        # "x or y"). Matching whole words keeps e.g. "portuguese" or
        # "north american", which a plain substring test would discard.
        food = self.all_slot_values['restaurant-food']
        self.food = [n for n in food if 'or' not in n.split()]

    @staticmethod
    def get_ref_slots(s):
        """Return the other slots sharing a ``REFER_CLUSTER`` cluster with ``s``.

        Only the first cluster containing ``s`` is used. For train/taxi slots,
        slots whose names contain ``'train'`` or ``'taxi'`` are excluded from the
        result. Returns ``[]`` when no cluster contains ``s``.
        """
        for clu in REFER_CLUSTER:
            if s in clu:
                output_clu = clu.copy()
                output_clu.remove(s)
                if 'train-' in s or 'taxi-' in s:
                    return [i for i in output_clu if 'train' not in i and 'taxi' not in i]
                return output_clu
        return []

    def find_name(self, name):
        """Resolve a (possibly noisy) name to its normalized ontology entry.

        Lookup order: exact match in ``ontology.json``; then an exact match in
        the label map (e.g. train station names); then the best fuzzy match
        over all ontology values, accepted when its score is >= 0.9.

        Returns:
            ``(domain, norm_value, variants)`` for an ontology hit,
            ``(None, norm_value, synonyms)`` for a label-map hit, or
            ``(None, None, None)`` when nothing matches.
        """
        name = name.strip().lower()
        if name in self.onto_inv:
            domain, norm_v = self.onto_inv[name]
            return domain, norm_v, self.onto[domain][norm_v]

        # Not an ontology value: try the label map (covers train station names).
        norm_v, cands = self.find_value(name)
        if norm_v:
            return None, norm_v, cands

        best_score, best_v = 0, None
        for possible_v in self.onto_inv:
            score = self.score(possible_v, name)
            if score > best_score:
                best_score, best_v = score, possible_v

        if best_v is not None and best_score >= 0.9:
            domain, norm_v = self.onto_inv[best_v]
            logger.info('similar: %s | %s |%f', name, best_v, best_score)
            return domain, norm_v, self.onto[domain][norm_v]

        return None, None, None

    def find_value(self, value):
        """Look ``value`` up in the label map.

        Returns ``(norm_value, synonyms)`` when found. Otherwise returns
        ``(None, [value, value + 's', singular])`` where ``singular`` is ``value``
        with one trailing ``'s'`` removed (if present).
        """
        value = value.strip()
        if value in self.label_map_inv:
            norm_v, cands = self.label_map_inv[value]
            return norm_v, cands
        # Remove exactly one trailing 's'; str.strip('s') would also eat leading
        # and repeated s's ("sushi" -> "ushi").
        singular = value[:-1] if value.endswith('s') else value
        return None, [value, value + 's', singular]

    def compare_values(self, v1, v2):
        """Return True if ``v1`` and ``v2`` share a label-map normal form or fuzzy-match (> 0.9)."""
        if v1 in self.label_map_inv and v2 in self.label_map_inv:
            if self.label_map_inv[v1][0] == self.label_map_inv[v2][0]:
                return True
        if self.score(v1, v2) > 0.9:
            logger.info('fuzzy match %s %s', v1, v2)
            return True
        return False

    def score(self, v1, v2):
        """Fuzzy similarity in [0, 1]: the max ``WRatio`` of the raw and the ``core`` forms."""
        score = WRatio(v1, v2) / 100
        score_ = WRatio(self.core(v1), self.core(v2)) / 100
        return max(score, score_)

    def core(self, v):
        """Strip a leading "the", a trailing generic venue word, and apostrophes from ``v``."""
        v = self._CORE_PREFIX.sub('', v)
        v = self._CORE_SUFFIX.sub('', v)
        v = v.replace("'", '')
        return v.strip()

    def normalize_label(self, slot, value_label):
        """Map a raw label ``value_label`` for ``slot`` to its normalized form.

        Time slots are canonicalized and passed to ``normalize_time``. Name,
        destination and departure slots are resolved via ``find_name``. Hotel
        parking/internet/type become ``"true"``/``"false"``/``"dontcare"``. Area
        values collapse to a compass word or ``"centre"``. Anything else is mapped
        through the label map when possible, else returned stripped.
        """
        value_label = value_label.strip().strip('.').strip(',').strip('?')

        # Normalization of time slots.
        if "leaveat" in slot or "arriveby" in slot or slot == 'restaurant-booktime':
            # "afer" presumably covers a misspelling of "after" seen in the data.
            value_label = re.sub(r'^(after|before|by|at|afer)\b', '', value_label).strip().strip('.').strip(',')
            value_label = re.sub(r'^(\d{2})[./]?(\d{2})$', r'\1:\2', value_label)  # "1230" / "12.30" -> "12:30"
            value_label = re.sub(r'^(\d)\.(\d{2})$', r'0\1:\2', value_label)       # "9.30" -> "09:30"
            value_label = re.sub(r'^(\d)$', r'0\1:00', value_label)                # "9" -> "09:00"
            value_label = re.sub(r'^(\d{2})$', r'\1:00', value_label)              # "12" -> "12:00"
            # NOTE(review): 'noon' also matches "afternoon" -- confirm that is intended.
            if 'noon' in value_label or 'lunch time' in value_label:
                value_label = '12:00'
            return normalize_time(value_label)

        # Normalization of name slots.
        if "name" in slot or "destination" in slot or "departure" in slot:
            value_label = value_label.replace("guesthouse", "guest house")
            _, norm_name, _ = self.find_name(value_label)
            return norm_name if norm_name else value_label

        # Map to boolean slots. For internet/parking, no/dontcare/none really
        # belong to one class; that merge is done in post-processing.
        if slot == 'hotel-parking' or slot == 'hotel-internet':
            if value_label == 'yes' or value_label == 'free':
                return "true"
            if value_label == "no":
                return "false"
            if value_label == "dontcare":
                return "dontcare"

        if slot == 'hotel-type':
            if "hotel" in value_label:
                return "true"
            if value_label in ["guest house", "guesthouse"]:
                return "false"
            if value_label == "dontcare":
                return "dontcare"

        if 'area' in slot:
            if re.search(r"(center|centre)", value_label):
                return "centre"
            for direction in ("south", "north", "west", "east"):
                if direction in value_label:
                    return direction

        # Values covered by the label map.
        if value_label in self.label_map_inv:
            norm_value, _ = self.label_map_inv[value_label]
            return norm_value

        return value_label

    def db_retrive(self, name, search_slot):
        """Look up ``search_slot``'s field for ``name`` (a str) or for each name in a list.

        Returns a single value (or None) for a str, a list of values for a list,
        and ``[]`` for any other input type.
        """
        if isinstance(name, str):
            return self._db_retrive(name, search_slot)
        if isinstance(name, list):
            return [self._db_retrive(n, search_slot) for n in name]
        return []

    def _db_retrive(self, name, search_slot):
        """Return the DB field named by ``search_slot`` (e.g. ``'hotel-area'`` -> ``'area'``).

        Returns None when the name cannot be resolved, its domain has no
        database, the entity is not in the database, or the record lacks the field.
        """
        dom, norm_v, _ = self.find_name(name)
        # Explicitly excluded names (presumably lacking a usable DB record).
        if norm_v in ['nil', 'barbakan', 'wise buddha']:
            return None
        if not (norm_v and dom):
            return None
        info = self.db.get(dom.split('-')[0], {}).get(norm_v)
        if info is None:
            return None
        key = search_slot.split('-')[1]  # e.g. pricerange / area
        return info.get(key)

    def db_search(self, domain, slot_value):
        """Return ``(records, names)`` of ``domain`` DB entries matching every ``slot_value`` pair.

        Records lacking one of the queried fields do not match.
        """
        items, names = [], []
        for item in self.db[domain].values():
            if all(s in item and item[s] == v for s, v in slot_value.items()):
                items.append(item)
                names.append(item['name'])
        return items, names

    def recommend_value(self, slot, value):
        """Randomly sample a replacement for ``value`` in ``slot``.

        Name slots: pick a surface variant of the same entity, or, when there
        are fewer than two variants, another entity of the same hotel type /
        restaurant food / attraction type (plus web names for hotels and
        restaurants). Unknown names fall back to a random DB entity of the
        slot's domain. Time slots get a random ``HH:MM``. People/stay/star slots
        get a random number in digit or word form. Other slots get a random
        label-map synonym. ``value`` itself is returned when nothing can be sampled.
        """
        if 'name' in slot:
            _, norm_v, cands = self.find_name(value)
            if norm_v is None:
                domain = {'restaurant-name': 'restaurant',
                          'hotel-name': 'hotel',
                          'attraction-name': 'attraction'}.get(slot)
                if domain is not None and self.db[domain]:
                    return random.choice(list(self.db[domain].keys()))
                return value
            if len(cands) >= 2:
                return random.choice(cands)

            # Too few variants: sample a different entity of a similar kind.
            if 'hotel-name' in slot:
                hotel_type = self.db_retrive(norm_v, 'hotel-type')
                _, pool = self.db_search('hotel', {'type': hotel_type})
                if hotel_type == 'guesthouse':
                    pool += self.names_web['guesthouse_names']
                if hotel_type == 'hotel':
                    pool += self.names_web['hotel_names']
            elif 'restaurant-name' in slot:
                restaurant_food = self.db_retrive(norm_v, 'restaurant-food')
                _, pool = self.db_search('restaurant', {'food': restaurant_food})
                pool += self.names_web['restaurant_names']
            elif 'attraction-name' in slot:
                attraction_type = self.db_retrive(norm_v, 'attraction-type')
                _, pool = self.db_search('attraction', {'type': attraction_type})
            else:
                return value
            # Guard against random.choice([]) when the lookup found nothing.
            return random.choice(pool) if pool else value

        elif "leaveat" in slot or "arriveby" in slot or slot == 'restaurant-booktime':
            hours = random.randint(0, 23)
            mins = random.randint(0, 59)
            return "%02d:%02d" % (hours, mins)

        elif 'bookpeople' in slot or 'bookstay' in slot or 'star' in slot:
            num_map = NUM_MAP.copy()
            num_map['1'] = 'one'
            nums = list(num_map.keys()) + list(num_map.values())
            return random.choice(nums)

        else:
            norm_v, cands = self.find_value(value)
            if norm_v is None or len(cands) < 2:
                return value
            return random.choice(cands)

    def recommend_value_different(self, slot):
        """Sample a random value for ``slot`` regardless of any current value.

        Returns None for slots this method does not handle.
        """
        if 'restaurant-name' in slot:
            return random.choice(list(self.db['restaurant'].keys()))
        elif 'hotel-name' in slot:
            return random.choice(list(self.db['hotel'].keys()))
        elif 'attraction-name' in slot:
            return random.choice(list(self.db['attraction'].keys()))

        elif "leaveat" in slot or "arriveby" in slot or 'restaurant-booktime' in slot:
            hours = random.randint(0, 23)
            mins = random.randint(0, 59)
            return "%02d:%02d" % (hours, mins)

        elif 'bookpeople' in slot or 'bookstay' in slot or 'star' in slot:
            num_map = NUM_MAP.copy()
            # Exclude 10-12; pop() tolerates NUM_MAP not containing them.
            for k in ('10', '11', '12'):
                num_map.pop(k, None)
            num_map['1'] = 'one'
            nums = list(num_map.keys()) + list(num_map.values())
            return random.choice(nums)
        elif 'departure' in slot or 'destination' in slot:
            stations = ['birmingham new street', 'bishops stortford', 'broxbourne', 'cambridge', 'ely',
                        'kings lynn', 'leicester', 'london kings cross', 'london liverpool',
                        'norwich', 'peterborough', 'stansted airport', 'stevenage']
            return random.choice(stations)
        elif 'pricerange' in slot:
            return random.choice(self.prices)
        elif 'area' in slot:
            return random.choice(self.areas)
        elif 'attraction-type' in slot:
            types = ['architecture', 'cinema', 'college', 'concert hall', 'entertainment', 'museum',
                     'multiple sports', 'night club', 'park', 'swimming pool', 'theatre']
            return random.choice(types)
        elif 'food' in slot:
            return random.choice(self.food)
        elif 'bookday' in slot or 'train-day' in slot:
            days = ['monday', 'tuesday', 'thursday', 'wednesday', 'friday', 'saturday', 'sunday']
            return random.choice(days)
        else:
            return None
            

if __name__ == '__main__':
    # Manual smoke check: resolve a sample entity name against the ontology.
    # Requires the JSON resources read by Ontology.__init__ in the working directory.
    onto = Ontology()
    print(onto.find_name('cambridge museum of technology'))
    