import logging
import six
import numpy as np
import json
import re
import os
import pprint
import random
import itertools

# import nlpaug.augmenter.char as nac
# import nlpaug.augmenter.word as naw

# Module-level logger (time_in_range uses it to report malformed time strings).
logger = logging.getLogger(__name__)

# Requestable slot name -> placeholder token. UtterOp.segmentation treats these
# placeholders as atomic markers when splitting delexicalized text.
REQUEST_SLOTS = {
    'ref': '<reference>',
    'address': '<address>',
    'trainid': '<trainid>',
    'phone': '<phone>',
    'postcode': '<postcode>',
    'price': '<price>'
}
# Every tracked "domain-slot" name (30 entries; the __main__ block prints the count).
ALL_SLOTS = [
    "taxi-leaveat",
    "taxi-destination",
    "taxi-departure",
    "taxi-arriveby",
    "restaurant-bookpeople",
    "restaurant-bookday",
    "restaurant-booktime",
    "restaurant-food",
    "restaurant-pricerange",
    "restaurant-name",
    "restaurant-area",
    "hotel-bookpeople",
    "hotel-bookday",
    "hotel-bookstay",
    "hotel-name",
    "hotel-area",
    "hotel-parking",
    "hotel-pricerange",
    "hotel-stars",
    "hotel-internet",
    "hotel-type",
    "attraction-type",
    "attraction-name",
    "attraction-area",
    "train-bookpeople",
    "train-leaveat",
    "train-destination",
    "train-day",
    "train-arriveby",
    "train-departure"
]

# Placeholder tokens for delexicalized values: a few extra markers followed by
# one '<domain-slot>' token per ALL_SLOTS entry, in ALL_SLOTS order. The
# commented-out __main__ code registers these (plus the REQUEST_SLOTS values)
# as additional special tokens of a tokenizer.
# NOTE(review): if this list is passed to change_vocab, its order decides which
# vocab rows are overwritten, so avoid reordering it.
SPEC_TOKENS = ['<hospital-name>', '<police-name>',  '<hospital-department>', '<bookpeople>', '<bookday>'] + \
    ['<%s>'%s for s in ALL_SLOTS]

# Slot-gate class labels. Not referenced in this part of the file; presumably
# the label set for per-slot classification -- confirm against the model code.
CLASS_TYPES = [
            "none",
            "dontcare",
            "copy_value",
            "true",
            "false",
            "refer",
            "inform"
        ]

# Groups of related slots: times, places/names, people counts, days, areas and
# price ranges. NOTE(review): presumably the slots a value may be copied
# ("refer") between; not referenced in this part of the file -- confirm.
REFER_CLUSTER = [
    [
        "taxi-leaveat",
        "taxi-arriveby",
        "restaurant-booktime",
        "train-arriveby",
        "train-leaveat",
    ],
    [
        "taxi-destination",
        "taxi-departure",
        "restaurant-name",
        "hotel-name",
        "attraction-name",
        "train-destination",
        "train-departure"
    ],
    [
        "restaurant-bookpeople",
        "hotel-bookpeople",
        "train-bookpeople",
    ],
    [
        "restaurant-bookday",
        "hotel-bookday",
        "train-day",
    ],
    [
        "restaurant-area",
        "attraction-area",
        "hotel-area",
    ],
    [
        "restaurant-pricerange",
        "hotel-pricerange",
    ],
]
# Digit string -> English number word, for 0-12.
NUM_MAP = {
    '1': 'one',
    '2': 'two',
    '0': 'zero',
    '3': 'three',
    '4': 'four',
    '5': 'five',
    '6': 'six',
    '7': 'seven',
    '8': 'eight',
    '9': 'nine',
    '10': 'ten',
    '11': 'eleven',
    '12': 'twelve'
}
num_map = NUM_MAP  # lowercase alias of NUM_MAP
# Reverse mapping: English number word -> digit string.
inv_num_map = dict(zip(num_map.values(), num_map.keys()))


def normalize_text(text):
    """Normalize an utterance or slot value so that values can be matched.

    Steps, in order: curly apostrophes -> ASCII, ``normalize_time``,
    dataset-specific typo/spelling fixes, "n't" -> " not", number words ->
    digits for star ratings / counts / nights, and finally single spaces
    around punctuation.

    All patterns are lower-case literals without IGNORECASE, so only
    lower-case spellings are rewritten; callers presumably lower-case the
    text first -- TODO confirm.

    Args:
        text: the string to normalize.

    Returns:
        The normalized string.
    """
    # Normalize curly apostrophes first so that "don’t" is also caught by the
    # "n't" -> " not" expansion further down.
    text = text.replace('’', '\'')
    text = normalize_time(text)

    # Dataset-specific typos and spelling variants.
    text = re.sub(r'\bliek\b', 'like', text)
    text = re.sub(r'\beatery\b', 'restaurant', text)
    text = re.sub(r'\bbetweentwo\b', 'between two', text)
    text = re.sub(r'\bnightsstarting\b', 'nights starting', text)
    text = re.sub(r'\b(wensday|wednedady|wednes|wendesday)\b', 'wednesday', text)
    text = re.sub(r'\b(teusday|tuestday)\b', 'tuesday', text)
    text = re.sub(r'\b(saurday|satruday)\b', 'saturday', text)
    text = re.sub(r'\bsat\b', 'saturday', text)
    text = re.sub(r'\b(thurtsday|thurs)\b', 'thursday', text)
    text = re.sub(r'\bwast part\b', 'west part', text)
    text = re.sub(r'\bweat\b', 'west', text)
    text = re.sub(r'\bbirminggam\b', 'birmingham', text)
    text = re.sub(r'\bsteveage\b', 'stevenage', text)
    text = re.sub(r'\bstrretm\b', 'street', text)
    text = re.sub(r'\b(petersborough|pererborough)\b', 'peterborough', text)
    text = re.sub(r'\bstrevenage\b', 'stevenage', text)
    # Previously replaced the typo with itself (a no-op).
    text = re.sub(r'\bbishops storthford\b', 'bishops stortford', text)
    text = re.sub(r'\bbringham new street\b', 'birmingham new street', text)
    text = re.sub(r'\b(cambride|cambidge)\b', 'cambridge', text)
    text = re.sub(r'\bmdoerate\b', 'moderate', text)
    text = re.sub(r'^i s it\b', 'is it', text)
    text = re.sub(r'\bat long as\b', 'as long as', text)
    text = re.sub(r'\bweish\b', 'welsh', text)
    text = re.sub(r'\bafer\b', 'after', text)
    text = re.sub(r'\babut\b', 'about', text)
    text = re.sub(r'\batke\b', 'take', text)
    text = re.sub(r'\bbook a tax\b', 'book a taxi', text)
    text = re.sub(r'\btaht\b', 'that', text)
    # The (?!>) lookahead leaves placeholders such as "<hotel-pricerange>" intact.
    text = re.sub(r'\bpricerange(?!>)', 'price range', text)
    text = re.sub(r'\bpostcode(?!>)', 'post code', text)
    text = re.sub(r'\bbook fro\b', 'book for', text)
    text = re.sub(r'\bswimmingpool', 'swimming pool', text)
    text = re.sub(r"\barchaelogy\b", "archaeology", text)  # Systematic typo
    text = re.sub(r"\bguesthouse", "guest house", text)  # Normalization
    text = re.sub(r'\bnightclub', 'night club', text)
    text = re.sub(r'\bconcerthall', 'concert hall', text)
    text = re.sub(r'\bref\b', 'reference', text)
    text = re.sub(r'\bconf #', 'confirmation number', text)
    text = re.sub(r'\b#\b', 'number', text)
    text = re.sub(r'\binterned\b', 'internet', text)
    text = re.sub(r'\bplcae\b', 'place', text)
    text = re.sub(r'\bdoesn;t\b', 'doesn\'t', text)
    text = re.sub(r'\bdon;t\b', 'don\'t', text)
    text = re.sub(r'\bisn;t\b', 'isn\'t', text)
    text = re.sub(r'\bit;s\b', 'it\'s', text)
    text = re.sub(r'\bmmind\b', 'mind', text)
    text = re.sub(r'\bmutliple\b', 'multiple', text)
    text = re.sub(r'\bthankyou\b', 'thank you', text)
    text = re.sub(r'\bot\b', 'to', text)
    text = re.sub(r'\blcoated\b', 'located', text)
    text = re.sub(r'\bcenre\b', 'centre', text)
    text = re.sub(r'\bneat the hotel\b', 'near the hotel', text)
    text = re.sub(r'\baddressis\b', 'address is', text)
    text = re.sub(r"\bnorth bandb\b", "north b and b", text)
    text = re.sub(r'\bhosptial\b', 'hospital', text)
    text = re.sub(r'\bwi-fi\b', 'wifi', text)
    # Collapse / cap long runs of repeated placeholders.
    text = re.sub(r"(<address>\s?,\s?){1,5}\s?<address>", "<address>", text)
    text = re.sub(r"(<restaurant-food>\s?,\s?){3,8}", "<restaurant-food>, <restaurant-food>, <restaurant-food>, ", text)
    text = re.sub(r'\bnum\b', 'number', text)
    text = re.sub("n't", " not", text)
    text = re.sub(r'\bi ca\b', 'i can', text)  # "i can't" became "i ca not" just above
    text = re.sub(r'\bpacking\b', 'parking', text)
    # "4-star" -> "4 star"; the lookbehind keeps the "<hotel-stars>" placeholder intact.
    text = re.sub(r'(?<!hotel)-star[s]?\b', ' star', text)
    text = re.sub(r'\bliverpoool\b', 'liverpool', text)
    text = re.sub(r'\btowncentre\b', 'town centre', text)
    text = re.sub(r'\bthe folk museum\b', 'the cambridge and county folk museum', text)
    # NOTE(review): an "australasian" -> "australian" mapping was noted here
    # but never implemented.

    # Number rules are driven by one table instead of copy-pasted lines
    # (a copy-paste slip used to turn "star two" into "star of 1").
    number_words = ('zero', 'one', 'two', 'three', 'four', 'five',
                    'six', 'seven', 'eight', 'nine')

    # "<n> star(s)" -> "<digit> star"; "no star(s)" counts as zero.
    for digit, word in enumerate(number_words[:6]):
        spellings = '%s|%d' % (word, digit) + ('|no' if digit == 0 else '')
        text = re.sub(r"\b(%s)\s+star(ts|s|t)?\b" % spellings, '%d star' % digit, text)
    # "star (of) <word>" -> "star of <digit>".
    for digit, word in enumerate(number_words[:6]):
        text = re.sub(r"\bstar (of )?%s\b" % word, 'star of %d' % digit, text)
    text = re.sub(r"\bunrated\b", r"0 star", text)
    text = re.sub(r'\b0star\b', '0 star', text)

    # Number word + whitespace -> digit. "(people|person)" is optional, so any
    # such word is converted (e.g. "one of them" -> "1 of them").
    for digit, word in enumerate(number_words[1:], start=1):
        text = re.sub(r"\b%s\s+(people|person)?" % word, r"%d \1" % digit, text)

    # "<n> night(s)" -> "<digit> night".
    for digit, word in enumerate(number_words[1:], start=1):
        text = re.sub(r"\b(%s|%d)\s+night(s)?\b" % (word, digit), '%d night' % digit, text)
    text = re.sub(r"\bfor a week\b", r"for 7 night", text)
    text = re.sub(r"(^| )b ?& ?b([.,? ]|$)", r"\1bed and breakfast\2", text)  # Normalization

    # Collapse repeated dots, then surround punctuation with single spaces.
    text = re.sub(r"\.+", ".", text)
    text = re.sub(r"\s?([.?!,();])\s?", r" \1 ", text)
    return text


def _pm_to_24h(match):
    """``re.sub`` callback: turn an "HH" + ":MM" pm match into 24-hour "HH:MM"."""
    hour, minutes = int(match.group(1)), match.group(2)
    if hour < 12:
        hour += 12
    return str(hour) + minutes


def normalize_time(text):
    """Rewrite clock times in ``text`` to the 24-hour "HH:MM" form.

    Handles 4-digit times ("1400"), glued or spaced am/pm markers, missing or
    wrong hour/minute separators, bare hours after a time preposition, missing
    leading zeros, and "24:xx".

    Args:
        text: the string to normalize.

    Returns:
        The string with times rewritten.
    """
    # 4-digit clock times, optionally followed by "hrs": "1400" -> "14:00".
    text = re.sub(r"\b(\d{2})(\d{2})(hrs)?\b", r"\1:\2", text)
    # am/pm without space: "5pm" -> "5 pm".
    text = re.sub(r"(\d{1})(a\.?m\.?|p\.?m\.?)", r"\1 \2", text)
    # am/pm short to long form: "5 pm" -> "5:00 pm".
    text = re.sub(r"(^| |\()(\d{1,2}) (a\.?m\.?|p\.?m\.?)", r"\1\2:00 \3", text)
    # Missing ':' after a time preposition (at/from/by/until/after), e.g.
    # "at 5 30, ..." -> "at 5:30, ...". The lookahead skips numbers followed by
    # words such as minutes, stars, hours, pounds or nights.
    text = re.sub(
        r"(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$) (?!minu|star|hour|diff|opti|choi|town|<add|poun|nigh|in f|to .)",
        r"\1\2 \3:\4\5", text)
    # Wrong separator: "14.30" / "14;30" / "14,30" -> "14:30".
    text = re.sub(r"(^| )(\d{2})[;.,](\d{2})", r"\1\2:\3", text)
    # Bare full hour after a time preposition: "at 5" -> "at 5:00" (same exclusions).
    text = re.sub(
        r"(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)(?!minu|star|hour|diff|opti|choi|town|<add|poun|nigh|in f|to .)",
        r"\1\2 \3:00\4", text)
    # Add missing leading 0: "5:00" -> "05:00".
    text = re.sub(r"(^| |\()(\d{1}:\d{2})", r"\g<1>0\2", text)

    # Map 12 hour times to 24 hour times: pm hours below 12 get +12 ...
    text = re.sub(r"(\d{2})(:\d{2}) ?p\.?m\.?", _pm_to_24h, text)
    # ... and "am" is dropped, except in "am or pm" phrasing.
    text = re.sub(r"(\d{2})(:\d{2}) ?a\.?m\.?(?! or pm)", r"\1\2", text)
    # Correct times that use 24 as hour.
    text = re.sub(r"(^| |\()24:(\d{2})", r"\g<1>00:\2", text)
    return text


def time_diff(time_str, diff=15):
    """Return the time ``diff`` minutes before ``time_str`` as "HH:MM".

    Wraps around midnight, e.g. ``time_diff("00:05")`` -> "23:50". Computing
    on total minutes also handles ``mis == diff`` correctly ("10:15" -> "10:00");
    the previous branchy version produced "09:60" there.

    Args:
        time_str: time as "H:MM" or "HH:MM".
        diff: minutes to subtract (callers use 15 or 5).

    Returns:
        The shifted time, zero-padded as "HH:MM".
    """
    hrs, mis = time_str.split(':')
    minutes_per_day = 24 * 60
    total = (int(hrs) * 60 + int(mis) - diff) % minutes_per_day
    return '%02d:%02d' % divmod(total, 60)


def _coerce_hhmm(time_str, info=None):
    """Return ``time_str`` as an "HH:MM" string, or None if that is impossible.

    Malformed values are logged (and ``info`` printed, if given); a value that
    merely starts with "HH:MM" (e.g. "10:15:00") is truncated to five chars.
    """
    if re.match(r"^\d{2}:\d{2}$", time_str):
        return time_str
    logger.warning('bad time value for time_str %s', time_str)
    if info:
        print(info)
    if re.match(r"^\d{2}:\d{2}", time_str):
        return time_str[:5]
    return None


def time_in_range(time_str1, time_str2, range=30, info=None):
    """Return True if two "HH:MM" times are at most ``range`` minutes apart.

    The distance is measured around the clock, so "23:50" and "00:10" are 20
    minutes apart (the previous midnight branch tested an impossible 24-hour
    difference and never fired). Times exactly 12 hours apart with equal
    minutes also count as in range, to tolerate am/pm annotation errors.

    Args:
        time_str1, time_str2: times as "HH:MM"; longer strings starting with
            "HH:MM" are truncated, anything else makes the result False.
        range: maximum allowed gap in minutes (shadows the builtin; the name
            is kept for caller compatibility).
        info: optional context printed when a time string is malformed.
    """
    time_str1 = _coerce_hhmm(time_str1, info)
    if time_str1 is None:
        return False
    time_str2 = _coerce_hhmm(time_str2, info)
    if time_str2 is None:
        return False

    hrs1, mis1 = (int(part) for part in time_str1.split(':'))
    hrs2, mis2 = (int(part) for part in time_str2.split(':'))

    # Annotations are sometimes off by exactly 12 hours; accept those pairs.
    if abs(hrs1 - hrs2) == 12 and mis1 == mis2:
        return True

    minutes_per_day = 24 * 60
    gap = abs((hrs1 * 60 + mis1) - (hrs2 * 60 + mis2)) % minutes_per_day
    gap = min(gap, minutes_per_day - gap)  # shortest way around midnight
    return gap <= range

# def contain(phrase, string):
#     # whether the phrase is contained in the string
#     string = re.sub(r"<.+?>", "",string)
#     str_tok = [tok for tok in map(str.strip, re.split("(\W+)", string)) if len(tok) > 0]
#     phr_tok = [tok for tok in map(str.strip, re.split("(\W+)", phrase)) if len(tok) > 0]

def find_pos(value, string):
    """Return the start offset of ``value`` in ``string`` as a whole phrase.

    A match must be bounded on each side by a space or by the start/end of
    ``string``, so hits inside longer words are ignored. Checks, in order: the
    whole string, a prefix, a suffix, then the first interior occurrence.

    Returns:
        The character offset of the match, or -1 if there is none (including
        for an empty ``value``).
    """
    if not value:
        return -1
    # A string consisting solely of the value used to be missed.
    if string == value or string.startswith(value + ' '):
        return 0
    if string.endswith(' ' + value):
        return len(string) - len(value)
    hit = string.find(' %s ' % value)
    return hit + 1 if hit >= 0 else -1


def val_in_text(variants, sent):
    """Locate the first of ``variants`` that ``find_pos`` finds in ``sent``.

    Returns:
        ``(variant, (start, end))`` with ``end`` exclusive, for the first
        matching variant in iteration order; ``(None, None)`` if none match.
    """
    for candidate in variants:
        begin = find_pos(candidate, sent)
        if begin < 0:
            continue
        return candidate, (begin, begin + len(candidate))
    return None, None


def simple_delex(text, spans):
    """Replace annotated character spans of ``text`` with slot tokens.

    Args:
        text: the original utterance.
        spans: iterable of ``(slot_token, value, (begin, end))`` triples with
            character offsets into ``text``; ``value`` is not used.

    Returns:
        If ``spans`` is empty/None: ``text`` unchanged (not lower-cased).
        Otherwise the spans are applied in order of ``begin``, skipping any
        span that starts inside an already replaced one; whitespace is
        collapsed to single spaces and the result is lower-cased.
    """
    if not spans:
        return text
    pieces = []
    cursor = 0
    for slot_token, _value, (begin, end) in sorted(spans, key=lambda span: span[2][0]):
        if begin < cursor:
            continue  # overlaps a span that was already replaced
        pieces.extend((text[cursor:begin], slot_token))
        cursor = end
    pieces.append(text[cursor:])
    return ' '.join(' '.join(pieces).split()).lower()


def change_vocab(vocab_file, add_tokens):
    """Write ``add_tokens`` into rows 1..len(add_tokens) of a vocab file in place.

    The file is read as one token per line. Row 0 is kept, ``add_tokens[i]``
    replaces row ``i + 1``, and every other row is written back unchanged.
    NOTE(review): presumably a BERT-style vocab.txt (the original noted 30522
    rows) whose rows 1.. are unused placeholder entries -- confirm before use.

    Args:
        vocab_file: path of the vocabulary file (UTF-8).
        add_tokens: sequence of fewer than 90 tokens.

    Raises:
        ValueError: if 90 or more tokens are given, or the file has too few
            rows to hold them. The check raises explicitly (not ``assert``) so
            it still holds under ``python -O``.
    """
    if len(add_tokens) >= 90:
        raise ValueError('change_vocab supports fewer than 90 tokens, got %d' % len(add_tokens))
    with open(vocab_file, "r", encoding="utf-8") as reader:
        vocabs = [line.rstrip('\n') for line in reader]
    if add_tokens and len(vocabs) < len(add_tokens) + 1:
        raise ValueError('%s has %d rows, too few for %d added tokens'
                         % (vocab_file, len(vocabs), len(add_tokens)))

    vocabs[1:len(add_tokens) + 1] = add_tokens

    with open(vocab_file, "w", encoding="utf-8") as writer:
        writer.writelines('%s\n' % tok for tok in vocabs)
            
            

def get_token_pos(tok_list, value_list):
    """Find every occurrence of the token sequence ``value_list`` in ``tok_list``.

    Returns:
        ``(found, positions)``: ``positions`` lists ``(start, exclusive_end)``
        token index pairs for each (possibly overlapping) occurrence, left to
        right; ``found`` is True iff ``positions`` is non-empty. An empty
        ``value_list`` matches nowhere (it used to "match" at every index).
    """
    width = len(value_list)
    if width == 0:
        return False, []
    # Named `positions` so it no longer shadows the module-level find_pos().
    positions = [(i, i + width)
                 for i in range(len(tok_list) - width + 1)
                 if tok_list[i:i + width] == value_list]
    return bool(positions), positions


class UtterOp:
    """Lexicalize delexicalized utterances into tokens plus value spans.

    Given delexicalized text (slot values replaced by placeholder tokens such
    as "<hotel-name>") and a dict mapping each placeholder to its values,
    ``lex`` 1) re-inserts the values and tokenizes the result, recording the
    token span of every inserted value, and 2) can optionally swap the values
    for ontology-recommended ones first.
    """

    def __init__(self, added_spec_tokens, tokenizer=None):
        """
        Args:
            added_spec_tokens: list of placeholder tokens that mark
                delexicalized values; a shallow copy is kept.
            tokenizer: object with a ``tokenize(str) -> list`` method; per the
                original author it should already have the special tokens
                added. When falsy, plain whitespace splitting is used.
        """
        self.spec_tokens = added_spec_tokens.copy()
        if tokenizer:
            self.tokenizer = tokenizer
        else:
            # Prints "using basetokenizer split": fall back to str.split.
            print('使用 basetokenizer split')
            BaseTokenizer = type('BaseTokenizer', (), dict(tokenize=lambda self, string: str.split(string)))
            self.tokenizer = BaseTokenizer()

    @staticmethod
    def split_on_token(tok, text):
        """Split ``text`` on ``tok``, keeping ``tok`` as its own list item.

        Pieces between occurrences are stripped and empty pieces dropped, e.g.
        ``split_on_token('<a>', 'x <a> y') == ['x', '<a>', 'y']``. Text that
        does not contain ``tok`` yields ``[text.strip()]`` (or ``[]`` if blank).
        """
        result = []
        pieces = text.split(tok)
        for i, piece in enumerate(pieces):
            piece = piece.strip()
            if piece:
                result.append(piece)
            if i < len(pieces) - 1:  # a `tok` sat between this piece and the next
                result.append(tok)
        return result

    def segmentation(self, text):
        """Split ``text`` into placeholder tokens and the stripped text between them.

        Placeholders are ``self.spec_tokens`` plus the ``REQUEST_SLOTS`` values;
        returns ``[]`` for blank text.
        """
        if not text.strip():
            return []
        markers = self.spec_tokens + list(REQUEST_SLOTS.values())
        marker_set = set(markers)  # built once instead of per segment
        segments = [text]
        for marker in markers:
            next_segments = []
            for seg in segments:
                if seg in marker_set:
                    next_segments.append(seg)
                else:
                    next_segments.extend(self.split_on_token(marker, seg))
            segments = next_segments
        return segments

    def _attraction_type_span(self, span_dic, name_tokens, offset):
        """Locate the '<attraction-type>' value inside attraction-name tokens.

        Returns the absolute ``(start, exclusive_end)`` token span of the first
        occurrence (``offset`` is where ``name_tokens`` start in the output),
        or None when there is no type value or it does not occur.
        """
        type_values = span_dic.get('<attraction-type>')
        if not type_values:
            return None
        # Normalize the type exactly like the name so spellings agree
        # (e.g. "nightclub" -> "night club"); the old check compared the raw
        # type against the normalized name.
        type_tokens = self.tokenizer.tokenize(normalize_text(type_values[0]))
        if not type_tokens:
            return None
        found, positions = get_token_pos(name_tokens, type_tokens)
        if not found:
            return None
        start, end = positions[0]
        return offset + start, offset + end

    def lex(self, delex_text, _span_dic, onto, use_aug=False):
        """Re-insert slot values into ``delex_text`` and tokenize the result.

        Args:
            delex_text: text containing placeholder tokens.
            _span_dic: maps each placeholder to a list of values; the i-th
                occurrence of a placeholder uses the i-th value, and the last
                value is reused once the list runs out. Not modified.
            onto: object providing ``recommend_value(slot, value)``; only used
                when ``use_aug`` is true.
            use_aug: replace every value with ``onto.recommend_value`` first.

        Returns:
            ``(token_list, span_dic_pos, text)``: ``span_dic_pos`` maps each key
            of ``_span_dic`` to ``(start, exclusive_end)`` token spans of its
            inserted values, and ``text`` is ``' '.join(token_list)``.

        Raises:
            KeyError / IndexError: if a placeholder in ``delex_text`` has no
                entry (or an empty list) in ``_span_dic``.
        """
        # Copy the value lists too: augmentation rewrites them, and a shallow
        # dict copy used to leak those rewrites into the caller's dict.
        span_dic = {k: list(vs) for k, vs in _span_dic.items()}
        if use_aug:
            for k, vs in span_dic.items():
                slot = k.strip('<').strip('>')
                span_dic[k] = [onto.recommend_value(slot, v) for v in vs]

        seg_list = self.segmentation(delex_text)
        spec_set = set(self.spec_tokens)
        request_markers = set(REQUEST_SLOTS.values())
        token_list = []
        span_dic_count = {k: 0 for k in span_dic}
        span_dic_pos = {k: [] for k in span_dic}
        for seg in seg_list:
            if seg in request_markers:
                # Checked first: request placeholders are not in spec_tokens,
                # so this branch was previously unreachable.
                tokens = [seg]
            elif seg not in spec_set:
                tokens = self.tokenizer.tokenize(seg)
            else:
                idx = span_dic_count[seg]
                value = span_dic[seg][idx]
                # Advance, but keep reusing the last value if the placeholder
                # occurs more often than it has values.
                span_dic_count[seg] = min(idx + 1, len(span_dic[seg]) - 1)
                value = normalize_text(value)
                tokens = self.tokenizer.tokenize(value)
                start = len(token_list)
                span_dic_pos[seg].append((start, start + len(tokens)))

                # Attraction names may contain the attraction type; label it too.
                if seg == '<attraction-name>':
                    type_span = self._attraction_type_span(span_dic, tokens, start)
                    if type_span is not None:
                        span_dic_pos['<attraction-type>'].append(type_span)

            token_list.extend(tokens)
        return token_list, span_dic_pos, ' '.join(token_list)


from torch.utils.data import Dataset


class TensorListDataset(Dataset):
    r"""Dataset wrapping tensors, tensor dicts and tensor lists.

    Arguments:
        *data (Tensor or dict or list of Tensors): tensors that have the same
            size of the first dimension; that size is the dataset length.

    ``dataset[i]`` returns a tuple with one entry per argument, in order:
    ``tensor[i]`` for a tensor, ``{name: tensor[i]}`` for a dict and
    ``[tensor[i], ...]`` for a list of tensors.
    """

    def __init__(self, *data):
        if not data:
            raise ValueError('TensorListDataset needs at least one tensor, dict or list')
        size = self._tensors_of(data[0])[0].size(0)
        for element in data:
            assert all(size == tensor.size(0) for tensor in self._tensors_of(element)), \
                'Size mismatch between tensors'
        self.size = size
        self.data = data

    @staticmethod
    def _tensors_of(element):
        """Return the tensors held by one ``data`` element as a list."""
        if isinstance(element, dict):
            return list(element.values())
        if isinstance(element, list):
            return element
        return [element]

    def __getitem__(self, index):
        result = []
        for element in self.data:
            if isinstance(element, dict):
                result.append({k: v[index] for k, v in element.items()})
            elif isinstance(element, list):
                # Build a real list: a generator expression here (as before)
                # can be consumed only once and cannot be indexed.
                result.append([v[index] for v in element])
            else:
                result.append(element[index])
        return tuple(result)

    def __len__(self):
        return self.size




# Scratch/debug entry point: prints the number of tracked slots. The rest is a
# commented-out experiment that builds a BERT tokenizer with the placeholder
# tokens and runs UtterOp.lex on a sample utterance.
if __name__ == '__main__':
    # pass
    print(len(ALL_SLOTS))
    # print(normalize_time('the total fee is 35.21 GBP payable'))
    # NOTE(review): these imports are unused by the live code above; they only
    # serve the commented-out experiment, yet they make this entry point
    # require the `transformers` package.
    from transformers import BertTokenizer as tokenizer_class
    from transformers import BertModel

    # model_path = '../download_models/dialoglue/bert'
    # special_tokens_dict = {'additional_special_tokens': SPEC_TOKENS + list(REQUEST_SLOTS.values())}
    # tokenizer = tokenizer_class.from_pretrained(model_path, do_lower_case=True, do_basic_tokenize=False)
    # tokenizer.add_special_tokens(special_tokens_dict)  # TODO: remember to resize the model's embeddings (model.resize)
    #
    # text = "am looking for a place to to stay that has <hotel-pricerange> price range it should be in a type of hotel"
    # text_dic = {'<hotel-pricerange>': ['cheap']}
    #
    # from ontology import Ontology
    #
    # onto = Ontology()
    # uttop = UtterOp(SPEC_TOKENS, tokenizer)
    #
    # tokens, dic, text = uttop.lex(text, text_dic, onto, False)
    # print(text)

    # tokens, dic, text = uttop.lex(text, text_dic, onto, True)
    # print(text)
    
    # new_tokens = []
    # for tok in text.split():
    #     if random.random() > 0: new_tokens.append(aug.augment(tok))
    #     else: new_tokens.append(tok)
    # print(' '.join(new_tokens))
