import pdb
import re
import unicodedata
import json
import random

from dataset import pad


# Decoder keyword vocabulary. A keyword's id is its position in this list, so
# the order is part of the data format and must never change. Ids 1 and 2
# ("[CLS]"/"[SEP]") also wrap every target sequence built by process_sql.
keyword_list = (
    ["[PAD]", "[CLS]", "[SEP]"]
    # SQL clauses and modifiers
    + ["select", "distinct", "from", "where", "join", "on", "as", "group",
       "order", "by", "asc", "desc", "having", "limit",
       "intersect", "except", "union"]
    # aggregates
    + ["count", "sum", "min", "max", "avg"]
    # grouping punctuation
    + ["(", ")", ","]
    # comparisons
    + ["<", "<=", "==", ">=", ">", "!=", "between", "in"]
    # boolean / pattern operators
    + ["and", "or", "not", "like"]
    # arithmetic
    + ["+", "-", "*", "/"]
    # numerals "0" .. "10"
    + [str(n) for n in range(11)]
    # miscellaneous ("是"/"否" = "yes"/"no"; "\"" delimits copied values)
    + ["time_now", "\"", "value", "select2", "row", "cluster", "是", "否",
       "@", "all", "inf", "is", "null", "abs", "present_ref",
       "julianday", "length"]
    # "@2" .. "@9"
    + [f"@{n}" for n in range(2, 10)]
)
keywords = dict(zip(keyword_list, range(len(keyword_list))))

def _load_json(path):
    """Load and return the JSON document at ``path``, closing the file afterwards.

    JSON text is UTF-8 by specification, so the encoding is fixed instead of
    depending on the process locale.
    """
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Column-type vocabulary: name -> integer id (ids are used as
# target_ids_columnt labels by process_sql); plus the reverse id -> name map.
columnt = _load_json("data/columnt.json")
columnt_rev_idx = {i: ct for ct, i in columnt.items()}

# Schema column types (a JSON list) -> ids starting at 10. encode() writes these
# ids straight into input_ids as the marker before each column; presumably the
# +10 offset keeps them clear of other reserved token ids (TODO confirm).
col_types = {c: i + 10 for i, c in enumerate(_load_json('data/vocab.col_type'))}


def get_col_type(ct):
    """Map a raw column-type string to its id in ``col_types``.

    The string is stripped and every ``"x,y"`` is normalized to ``"x, y"``
    before the lookup. Empty or unknown types fall back to the ``"*"`` id;
    unknown types written as ``<...>`` map to the ``"expr"`` id.
    """
    ct = ct.strip()
    # Lookarounds do not consume the neighbouring characters, so every comma is
    # normalized ("a,b,c" -> "a, b, c"). A consuming pattern like r"(\w),(\w)"
    # skips a comma that follows a single-character item ("a, b,c").
    ct = re.sub(r"(?<=\w),(?=\w)", ", ", ct)
    if not ct:
        return col_types["*"]
    if ct in col_types:
        return col_types[ct]
    if ct.startswith("<") and ct.endswith(">"):
        return col_types["expr"]
    return col_types["*"]


def phrase_norm(phrase):
    """Normalize a phrase token by token and return a new list of tokens.

    ``phrase`` is either a string (stripped and split on whitespace) or an
    already-tokenized sequence. For each token:
      * the quote tokens `` and '' become a plain double quote;
      * curly quotes, acute/back-ticks and dash variants are unified;
      * backslashes and underscores become spaces; U+00A0 becomes ':';
      * tokens containing a digit lose all their commas;
      * -lrb- / -rrb- become ( / ).
    """
    tokens = phrase.strip().split() if isinstance(phrase, str) else phrase

    # Applied in this order, after the quote-pair check.
    substitutions = (
        (r"[‘’´`]", "'"),
        (r"[“”]", "\""),
        (r"[‐‑‒–—−]", "-"),
        (r'[\\_]', ' '),
        (r'\xa0', ':'),
    )
    brackets = {"-lrb-": "(", "-rrb-": ")"}

    def norm_token(token):
        if token in ("``", "''"):
            token = "\""
        for pattern, replacement in substitutions:
            token = re.sub(pattern, replacement, token)
        # NOTE(review): the search is unanchored, so any token containing a
        # digit loses all of its commas, not only thousands separators.
        if re.search(r"\d+(,\d+)*", token):
            token = re.sub(r',', '', token)
        return brackets.get(token, token)

    return [norm_token(token) for token in tokens]


def normalize(x):
    """Canonicalize a table string (e.g. a column header) for comparison.

    Steps:
      1. Decode ``bytes`` as UTF-8, ignoring undecodable bytes.
      2. Drop combining marks after NFKD decomposition (removes diacritics).
      3. Unify curly quotes and dash variants.
      4. Until nothing changes: strip trailing bracketed references / marker
         symbols (``•♦†‡*#+``), trailing `` (...)`` details, and one pair of
         enclosing double quotes.
      5. Drop one trailing ``.``, collapse whitespace runs to a single space,
         lowercase and strip.
    """
    text = x if isinstance(x, str) else x.decode('utf8', errors='ignore')

    # Drop diacritics: decompose, then discard combining marks (category Mn).
    decomposed = unicodedata.normalize('NFKD', text)
    text = ''.join(ch for ch in decomposed if unicodedata.category(ch) != 'Mn')

    # Unify quote and dash variants.
    for pattern, replacement in ((r"[‘’´`]", "'"),
                                 (r"[“”]", "\""),
                                 (r"[‐‑‒–—−]", "-")):
        text = re.sub(pattern, replacement, text)

    # Peel trailing decorations until a fixed point is reached.
    previous = None
    while text != previous:
        previous = text
        # citations such as "[1]" and footnote marker symbols
        text = re.sub(r"((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$", "", text.strip())
        # trailing " (details)"
        text = re.sub(r"(?<!^)( \([^)]*\))*$", "", text.strip())
        # one pair of enclosing double quotes
        text = re.sub(r'^"([^"]*)"$', r'\1', text.strip())

    if text.endswith('.'):
        text = text[:-1]
    return re.sub(r'\s+', ' ', text, flags=re.U).lower().strip()


def is_continuous(a):
    """Return True if the integers in ``a`` form one run of consecutive values.

    Order does not matter, but duplicates do: ``[1, 1, 4]`` is not a run. This
    is the case even though its sum equals that of ``[1, 2, 3]``, which a
    sum-based check would wrongly accept. An empty sequence is not a run.
    """
    if not a:
        return False
    # n distinct integers spanning exactly n - 1 are necessarily consecutive.
    return len(set(a)) == len(a) and max(a) - min(a) == len(a) - 1


def get_align_map(ex):
    """Map SQL token indices to the NL span they are aligned with.

    ``ex['align']`` is iterated as ``(nl_indices, sql_indices)`` pairs.
    Alignments whose NL indices are not one consecutive run are skipped. For
    the rest, one record dict is created and shared by every SQL index of that
    alignment, so a consumer can emit the span once and mark it via ``flag``.

    Record keys: ``span`` (NL tokens), ``sql`` (``ex['sql'][i][1]`` for each
    SQL index), ``flag`` (initially False), ``is_cont`` (always True here) and
    ``a`` (sorted NL indices).

    The caller's ``ex['align']`` lists are left untouched (sorted copies are
    used). ``ex=None`` yields an empty map.
    """
    amap = {}
    if ex is None:
        return amap

    for nl_idx, sql_idx in ex['align']:
        nl_idx = sorted(nl_idx)
        sql_idx = sorted(sql_idx)
        if not is_continuous(nl_idx):
            continue

        record = {
            'span': [ex['nl'][i] for i in nl_idx],
            'sql': [ex['sql'][i][1] for i in sql_idx],
            'flag': False,
            'is_cont': True,  # non-continuous spans were skipped above
            'a': nl_idx,
        }
        # The same dict object for every SQL index: setting ``flag`` once
        # marks the whole span as emitted.
        for i in sql_idx:
            amap[i] = record
    return amap


def is_number(s):
    """Return True if ``s`` can be parsed by ``int()`` or ``float()``.

    Only the conversion errors are caught; unlike a bare ``except:``, this does
    not swallow KeyboardInterrupt/SystemExit.
    """
    try:
        int(s)
        return True
    except (TypeError, ValueError, OverflowError):
        # OverflowError: e.g. int(float('inf')); float() may still accept it.
        pass
    try:
        float(s)
        return True
    except (TypeError, ValueError, OverflowError):
        return False


# Persistent columns (always first, indices 0-2): *, id, agg
def process_sql(tokenizer, sql, q, tbl, col, ex=None):
    """Encode a tokenized SQL query as decoder target ids.

    Each SQL token is mapped, in this priority order, to:
      * a keyword (``=`` is first rewritten to ``==``) -> ``keywords[tok]``;
      * a column name in ``col`` -> ``C_BASE + column index``;
      * a ``c<N>`` / ``c<N>_<type>`` placeholder -> ``C_BASE + N - 1 + 3``
        (the 3 persistent columns come first); ``<type>`` is looked up in
        ``columnt`` for ``target_ids_columnt``;
      * a table name in ``tbl`` -> ``T_BASE + table index``;
      * a value (quoted literal, question token or number) -> an opening
        ``"`` keyword id, one ``Q_BASE + position`` id per question sub-token
        the value is copied from, and a closing ``"`` keyword id.
    Any other token is reported with ``print`` and dropped.

    Args:
        tokenizer: object with ``tokenize(str) -> list``; used to sub-tokenize
            values so they can be located in ``q``.
        sql: list of SQL token strings.
        q: sub-tokenized question that values are copied from.
        tbl: table names.
        col: column names, starting with the 3 persistent columns.
        ex: example whose ``'align'``/``'nl'``/``'sql'`` fields group SQL tokens
            belonging to one multi-token value (see ``get_align_map``).

    Returns:
        dict with ``target`` (SQL tokens in text form), ``target_ids`` (wrapped
        with 1 / 2, the ``[CLS]`` / ``[SEP]`` keyword ids) and
        ``target_ids_columnt`` (column-type id per target id, -1 where the id
        is not a column).

    Raises:
        AssertionError: if a value's sub-tokens cannot be found in ``q``.
    """
    Q_BASE = 1000
    T_BASE = 2000
    C_BASE = 3000

    target = []
    target_ids = []
    target_ids_columnt = []

    key_set = keywords
    tok_set = {tok: i for i, tok in enumerate(q)}
    tbl_set = {name: i for i, name in enumerate(tbl)}
    col_set = {name: i for i, name in enumerate(col)}

    # SQL token index -> alignment record shared by all tokens of one value.
    amap = get_align_map(ex)

    for i, tok in enumerate(sql):
        cm = re.search(r"^c(\d+)($|\_)", tok)
        if tok == '=':
            tok = '=='

        if tok in key_set:
            target.append(tok)
            target_ids.append(key_set[tok])
            target_ids_columnt.append(-1)
        elif tok in col_set:
            col_idx = col_set[tok]
            target.append(f"`{col[col_idx]}`")
            target_ids.append(C_BASE + col_idx)
            target_ids_columnt.append(columnt["[KEY]"])
        elif cm is not None:
            # c<N> is 1-based; skip the 3 persistent columns (*, id, agg).
            col_idx = int(cm.group(1)) - 1 + 3
            target.append(f"`{col[col_idx]}`")
            target_ids.append(C_BASE + col_idx)
            cns = tok.strip().split("_", 1)

            if len(cns) > 1:
                assert cns[0] == f"c{cm.group(1)}", f"{cns}, {cm.group()}"
                col_columnt = cns[1]
                target_ids_columnt.append(columnt[col_columnt])
            else:
                target_ids_columnt.append(columnt["[KEY]"])
        elif tok in tbl_set:
            target.append(f"`{tok}`")
            target_ids.append(T_BASE + tbl_set[tok])
            target_ids_columnt.append(-1)
        elif (tok[0] == "\'" and tok[-1] == "\'") or \
             (tok[0] == '\"' and tok[-1] == '\"') or \
             tok in tok_set or is_number(tok):
            target.append(tok)

            # NOTE(review): only single quotes are stripped here; a
            # double-quoted literal is searched in ``q`` with its quotes.
            if tok[0] == "\'":
                tok = tok[1:-1]

            # Regard a multi-token value as a single token (for the copy
            # mechanism): the first SQL token of an aligned span emits the
            # whole NL span, later tokens of the same span are skipped.
            if i in amap:
                nl = amap[i]
                if nl['flag']:
                    # Skip before emitting anything, so no unmatched opening
                    # quote id is left behind in target_ids.
                    continue
                tok = nl['span']
                nl['flag'] = True

            target_ids.append(keywords['\"'])
            target_ids_columnt.append(-1)

            tok = phrase_norm(tok)
            tok = tokenizer.tokenize(' '.join(tok))

            # First position where the value's sub-tokens occur in q.
            q_idx = -1
            for start in range(len(q)):
                if tok == q[start:start + len(tok)]:
                    q_idx = start
                    break
            assert q_idx != -1, f"{tok}, {q}"

            for pos in range(q_idx, q_idx + len(tok)):
                target_ids.append(Q_BASE + pos)
                target_ids_columnt.append(-1)

            target_ids.append(keywords['\"'])
            target_ids_columnt.append(-1)
        else:
            print(f"unknown token: {tok}")

    return {
        "target": target,
        "target_ids": [1] + target_ids + [2],
        "target_ids_columnt": [-1] + target_ids_columnt + [-1]
    }


def process_ex(ex, tokenizer):
    """Normalize an example in place and attach SQL targets when available.

    Sets on ``ex``:
      * ``q``: question words after ``phrase_norm``;
      * ``q_tokenize``: ``tokenizer.tokenize`` of the joined words;
      * ``col``: ``["*", "id", "agg"]`` followed by one cleaned, ``normalize``d
        name per entry of ``ex['columns']`` (empty names become ``"#"``);
      * ``tbl``: always ``["w"]``; the previous value is kept once in
        ``ex['tbl_name']``.
    If ``ex`` has an ``'sql'`` field, the targets from ``process_sql`` are
    stored both on ``ex`` and in ``ex['encoding_dict']`` (created if missing).

    Returns:
        The same ``ex`` object.
    """
    q = phrase_norm(ex['nl'])
    ex['q'] = q
    ex['q_tokenize'] = tokenizer.tokenize(' '.join(q))

    # The three persistent columns always come first.
    col = ["*", "id", "agg"]
    for col_info in ex['columns']:
        col_ = ' '.join(col_info[1])
        if col_ in ["", "#"]:
            # NOTE(review): falls back to the text before "(" in the
            # second-to-last column field -- confirm what that field holds.
            col_ = col_info[-2].strip().split("(")[0]

        col_ = re.sub("\n", "", col_)       # real newlines
        col_ = re.sub("\\\\n", "", col_)    # literal backslash + "n"
        col_ = col_.replace('\\', '')
        col_ = col_.replace("-lrb-", "(")
        col_ = col_.replace("-rrb-", ")")
        col_ = normalize(col_)

        if col_ == "":
            col_ = "#"
        col.append(col_)

    # Single-table dataset: the table is always referred to as "w".
    tbl = ["w"]
    if 'tbl_name' not in ex:
        ex['tbl_name'] = ex['tbl']

    ex['tbl'] = tbl
    ex['col'] = col

    if 'sql' in ex:
        sql = [entry[1] for entry in ex['sql']]
        encoding_dict = process_sql(tokenizer, sql, ex['q_tokenize'], tbl, col, ex=ex)

        # encode() pre-creates ex['encoding_dict']; create it when called standalone.
        ex_encoding = ex.setdefault('encoding_dict', {})
        ex_encoding['target_ids'] = encoding_dict['target_ids']
        ex_encoding['target_ids_columnt'] = encoding_dict['target_ids_columnt']

        ex['target'] = encoding_dict['target']
        ex['target_ids'] = encoding_dict['target_ids']
        ex['target_ids_columnt'] = encoding_dict['target_ids_columnt']

    return ex


def get_align_from_amap(amap, with_key=True, with_col=True, with_val=True):
    """Flatten alignment maps into ``"key : value"`` strings plus type tags.

    ``amap`` must provide ``'kmap'``, ``'cmap'`` and ``'vmap'``, each an
    iterable of ``(key, value)`` pairs; they are tagged 0, 1 and 2
    respectively. A group is omitted entirely when its ``with_*`` flag is
    false.

    Returns:
        ``(map_list, map_type_list)``: parallel lists of strings and tags, in
        kmap -> cmap -> vmap order.
    """
    groups = (
        (0, amap['kmap'], with_key),
        (1, amap['cmap'], with_col),
        (2, amap['vmap'], with_val),
    )

    map_list, map_type_list = [], []
    for type_tag, pairs, enabled in groups:
        if not enabled:
            continue
        for key, value in pairs:
            map_list.append(f"{key} : {value}")
            map_type_list.append(type_tag)
    return map_list, map_type_list


def encode(tokenizer, ex, with_target="sql", is_train=True, with_align=True,
           with_reduce=False, with_key=True, with_col=True, with_val=True):
    """Encode one example into encoder inputs (and SQL targets if present).

    ``input_ids`` layout::

        [CLS] question [SEP] schema [SEP] (alignment [SEP])

    The alignment segment and its trailing [SEP] exist only when
    ``with_align`` is true. The schema segment consists of:
      1. the 3 persistent columns, each prefixed by the ``C`` marker id;
      2. then, before the first real column, the table prefixed by ``T``;
      3. then every real column, prefixed by its ``get_col_type`` id.

    Args:
        tokenizer: called as ``tokenizer(list_of_str, add_special_tokens=False,
            ...)['input_ids']`` (one id list per string) and passed to
            ``process_ex``.
        ex: raw example dict. It is mutated: ``ex['encoding_dict']`` is set and
            ``process_ex`` normalizes it. ``ex['amap']`` is read when
            ``with_align`` is true.
        with_target: alignment entries are kept only for ``"sql"``.
        is_train, with_reduce: unused here; kept for interface compatibility.
        with_align: append the alignment segment.
        with_key, with_col, with_val: alignment groups to include.

    Returns:
        ``encoding_dict`` with input ids, masks, span indices and lengths.
        When ``ex`` has SQL, ``process_ex`` also adds ``target_ids`` and
        ``target_ids_columnt``.
    """
    # Fixed token ids used as separators and segment markers.
    CLS = 101
    SEP = 102
    T = 99       # precedes the table
    C = 100      # precedes each persistent column
    KEY = 96     # precedes a keyword alignment entry
    COL = 97     # precedes a column alignment entry
    VAL = 98     # precedes a value alignment entry
    A_PAD = 95   # placeholder when no alignment entries exist

    # process_ex writes SQL targets into ex['encoding_dict'], i.e. into this
    # very dict, so they end up in the returned value.
    encoding_dict = {}
    ex['encoding_dict'] = encoding_dict

    ex = process_ex(ex, tokenizer)
    q = ex['q']
    tbl = ex['tbl']
    col = ex['col']

    if with_align:
        ali, ali_type = get_align_from_amap(ex['amap'], with_key=with_key, with_col=with_col, with_val=with_val)
    else:
        ali, ali_type = [], []

    if with_target != 'sql':
        ali, ali_type = [], []

    def sub_token_ids(texts):
        # One list of sub-token ids per input string, without special tokens.
        return tokenizer(texts, add_special_tokens=False, return_token_type_ids=None,
                         return_attention_mask=False)['input_ids']

    q_ids = sub_token_ids(q)
    tbl_ids = sub_token_ids(tbl)
    col_ids = sub_token_ids(col)
    ali_ids = sub_token_ids(ali) if ali else []

    # Question segment.
    x = [CLS]
    q_start = len(x)

    q_len = 0
    for q_sub_toks in q_ids:
        if q_sub_toks:
            x += q_sub_toks
            q_len += len(q_sub_toks)

    q_sep = len(x)
    x += [SEP]

    # Schema segment.
    t_span_ids = []
    c_span_ids = []
    tbl_parsed = False

    for i, c_sub_toks in enumerate(col_ids):
        if i < 3:  # persistent columns ["*", "id", "agg"]
            col_start = len(x)
            x += [C] + c_sub_toks
            col_end = len(x)
            c_span_ids.append([col_start, col_end])
        else:
            if not tbl_parsed:
                # The table is emitted once, right before the first real
                # column; with no real columns t_span_ids stays empty.
                tbl_start = len(x)
                x += [T] + tbl_ids[0]   # WikiTableQuestion is a single-table dataset
                tbl_end = len(x)

                t_span_ids.append([tbl_start, tbl_end])
                tbl_parsed = True

            col_start = len(x)
            # ex['columns'] lacks the 3 persistent columns, hence i - 3; the
            # column-type id doubles as this column's marker token.
            x += [get_col_type(ex['columns'][i - 3][3])] + c_sub_toks
            col_end = len(x)
            c_span_ids.append([col_start, col_end])

    s_sep = len(x)
    x += [SEP]

    # Optional alignment segment.
    a_span_ids = []
    a_sep = s_sep

    if with_align:
        # ``a_type`` is the tag of one entry; ``ali_type`` stays the full list.
        for ali_sub_toks, a_type in zip(ali_ids, ali_type):
            ali_start = len(x)
            if a_type == 0:
                x += [KEY]
            elif a_type == 1:
                x += [COL]
            else:
                x += [VAL]

            x += ali_sub_toks
            ali_end = len(x)
            a_span_ids.append([ali_start, ali_end])

        if not a_span_ids:
            ali_start = len(x)
            x += [A_PAD]
            ali_end = len(x)
            a_span_ids.append([ali_start, ali_end])

        a_sep = len(x)
        x += [SEP]

    # x = [CLS] question [SEP] schema [SEP] (align [SEP])
    attention_mask = [1] * len(x)
    # Segment 0: everything through the schema [SEP]; segment 1: alignment part.
    token_type_ids = [0] * (s_sep + 1) + [1] * (a_sep - s_sep)
    assert q_sep - q_start == q_len

    encoding_dict['q'] = q
    encoding_dict['q_length'] = q_sep - q_start
    encoding_dict['t_length'] = len(tbl)
    encoding_dict['c_length'] = len(col)
    encoding_dict['q_span_ids'] = [q_start, q_sep]
    encoding_dict['t_span_ids'] = t_span_ids
    encoding_dict['c_span_ids'] = c_span_ids
    encoding_dict['a_span_ids'] = a_span_ids
    encoding_dict['tbl'] = tbl
    encoding_dict['col'] = col
    encoding_dict['ali'] = ali
    encoding_dict['input_ids'] = x
    encoding_dict['attention_mask'] = attention_mask
    encoding_dict['token_type_ids'] = token_type_ids
    ex['encoding_dict'] = encoding_dict

    return encoding_dict


def adjust_target_ids(target_ids, q_base, t_base, c_base):
    """Re-base target ids from fixed offsets onto batch-specific offsets.

    Input ids use fixed ranges:
      * below 1000: keyword (left unchanged);
      * 1000+: question position;
      * 2000+: table index;
      * 3000+: column index.
    Non-keyword ids are shifted so each range starts at ``q_base``, ``t_base``
    or ``c_base`` respectively.

    Entries may be plain ints or 3-element lists ``[_, id, rel]``. The first
    element is discarded. A list entry comes back as ``[adjusted_id, rel]``,
    except when ``rel == -2``, in which case only the adjusted id is returned.
    """
    Q_BASE = 1000
    T_BASE = 2000
    C_BASE = 3000

    target_ids_adjusted = []
    for entry in target_ids:
        if isinstance(entry, list):
            _, tid, rel = entry
        else:
            tid, rel = entry, -2  # -2: no rel, output stays a scalar

        if tid < Q_BASE:      # keyword
            pass
        elif tid < T_BASE:    # question token
            tid = tid - Q_BASE + q_base
        elif tid < C_BASE:    # table
            tid = tid - T_BASE + t_base
        else:                 # column
            tid = tid - C_BASE + c_base

        target_ids_adjusted.append(tid if rel == -2 else [tid, rel])

    return target_ids_adjusted


def tokenize(tokenizer, examples, with_target='sql', is_train=True, with_align=True,
             with_reduce=False, with_key=True, with_col=True, with_val=True):
    """Encode a batch of examples with ``encode`` and collate the results.

    ``input_ids``, ``attention_mask`` and ``token_type_ids`` are padded with
    the project ``pad`` helper on the CUDA device; span and length fields stay
    Python lists. When the examples carry SQL targets, their ids are re-based
    with ``adjust_target_ids``. The bases are ``q_base = len(keywords)``,
    ``t_base = q_base + max q_length`` and ``c_base = t_base + max t_length``.
    The re-based ids are then padded (``target_ids_columnt`` with -1) and
    moved to CUDA.

    Args:
        tokenizer: forwarded to ``encode``.
        examples: list of raw example dicts; each is mutated by ``encode``.
        with_target: only ``'sql'`` is supported.
        is_train, with_align, with_reduce, with_key, with_col, with_val:
            forwarded unchanged to ``encode``.

    Returns:
        The collated ``encoding_dict``, also holding ``bs``, the three bases,
        ``target_length`` (when targets exist) and ``ex``.

    Raises:
        NotImplementedError: if ``with_target != 'sql'`` (after encoding).
    """
    encoding_dict = {
        "input_ids": [],
        "attention_mask": [],
        "token_type_ids": [],

        "q_span_ids": [],
        "t_span_ids": [],
        "c_span_ids": [],
        "a_span_ids": [],

        "q_length": [],
        "t_length": [],
        "c_length": [],

        # target labels
        "target_ids": [],
        "target_ids_columnt": [],
        "target_ids_debug": [],  # kept for compatibility; not filled here
    }

    bs = len(examples)
    encoding_dict['bs'] = bs

    # Per-example fields copied verbatim from each encode() result.
    collected_keys = (
        'input_ids', 'attention_mask', 'token_type_ids',
        'q_span_ids', 't_span_ids', 'c_span_ids', 'a_span_ids',
        'q_length', 't_length', 'c_length',
    )

    for ex in examples:
        input_dict = encode(tokenizer, ex, is_train=is_train, with_target=with_target, with_align=with_align,
                            with_reduce=with_reduce, with_key=with_key, with_col=with_col, with_val=with_val)
        for key in collected_keys:
            encoding_dict[key].append(input_dict[key])

        if 'target_ids' in input_dict:
            encoding_dict['target_ids'].append(input_dict['target_ids'])
            encoding_dict['target_ids_columnt'].append(input_dict['target_ids_columnt'])

    encoding_dict['input_ids'] = pad(encoding_dict['input_ids'], device='cuda')
    encoding_dict['attention_mask'] = pad(encoding_dict['attention_mask'], device='cuda')
    encoding_dict['token_type_ids'] = pad(encoding_dict['token_type_ids'], device='cuda')

    if with_target != 'sql':
        raise NotImplementedError

    # Output space: [keywords | question tokens | tables | columns].
    q_base = len(keywords)
    t_base = q_base + max(encoding_dict['q_length'])
    c_base = t_base + max(encoding_dict['t_length'])

    encoding_dict['q_base'] = q_base
    encoding_dict['t_base'] = t_base
    encoding_dict['c_base'] = c_base

    if encoding_dict['target_ids']:
        encoding_dict['target_ids'] = [adjust_target_ids(target_ids, q_base, t_base, c_base)
                                       for target_ids in encoding_dict['target_ids']]
        encoding_dict['target_length'] = [len(tgt) for tgt in encoding_dict['target_ids']]
        encoding_dict['target_ids'] = pad(encoding_dict['target_ids']).cuda()
        encoding_dict['target_ids_columnt'] = pad(encoding_dict['target_ids_columnt'], padding=-1).cuda()

    # NOTE(review): this stores only the LAST example of the batch (the loop
    # variable); confirm callers do not expect the whole ``examples`` list.
    encoding_dict['ex'] = ex
    return encoding_dict
