import transformers
import torch
import numpy as np
import re
import copy
import json

from preprocess import keyword_list, keywords
from fuzzywuzzy import fuzz


# Machine epsilon of the Python float type; added inside safe_log() so that
# log(0) is never evaluated.
EPSILON = float(np.finfo(float).eps)
# NOTE: despite the name, this literal is a float (1e31), not an int.
HUGE_INT = 1e31

# Lookup table from lowercase English words to numeric values, consulted by
# parse_number() before it falls back to regex extraction.
NUM_MAPPING = {
    # Fractions and cardinal numbers.
    'half': 0.5,
    'one': 1,
    'two': 2,
    'three': 3,
    'four': 4,
    'five': 5,
    'six': 6,
    'seven': 7,
    'eight': 8,
    'nine': 9,
    'ten': 10,
    'eleven': 11,
    'twelve': 12,
    'twenty': 20,
    'thirty': 30,
    # Multiplicative adverbs.
    'once': 1,
    'twice': 2,
    # Ordinals.
    'first': 1,
    'second': 2,
    'third': 3,
    'fourth': 4,
    'fifth': 5,
    'sixth': 6,
    'seventh': 7,
    'eighth': 8,
    'ninth': 9,
    'tenth': 10,
    # Magnitudes.
    'hundred': 100,
    'thousand': 1000,
    'million': 1000000,
    # Month abbreviations ('may' doubles as the full month name).
    'jan': 1,
    'feb': 2,
    'mar': 3,
    'apr': 4,
    'may': 5,
    'jun': 6,
    'jul': 7,
    'aug': 8,
    'sep': 9,
    'oct': 10,
    'nov': 11,
    'dec': 12,
    # Full month names.
    'january': 1,
    'february': 2,
    'march': 3,
    'april': 4,
    'june': 6,
    'july': 7,
    'august': 8,
    'september': 9,
    'october': 10,
    'november': 11,
    'december': 12,
}


def safe_log(x):
    """Return ``torch.log(x + EPSILON)`` so that zero entries do not yield -inf."""
    shifted = x + EPSILON
    return torch.log(shifted)


def decode_sql(target_ids, q_base, t_base, c_base, q, t, c, table=None):
    """
    Map predicted ids back to tokens.

    The id space is split into consecutive ranges:
      * ``id < q_base``            -> ``keyword_list[id]``
      * ``q_base <= id < t_base``  -> ``q[id - q_base]``
      * ``t_base <= id < c_base``  -> ``t[id - t_base]``
      * ``id >= c_base``           -> a column: the first three column slots are
        read from ``c``; every later slot is rendered as ``c1``, ``c2``, ...

    ``table`` is accepted but not used.

    Returns the list of decoded tokens, one per entry of ``target_ids``.
    """
    # Number of leading column slots whose text is taken from `c` verbatim.
    PERSIST_COL = 3

    def _decode_one(token_id):
        if token_id < q_base:
            return keyword_list[token_id]
        if token_id < t_base:
            return q[token_id - q_base]
        if token_id < c_base:
            return t[token_id - t_base]

        col_idx = token_id - c_base
        if col_idx < PERSIST_COL:
            return c[col_idx]
        return f'c{col_idx+1-PERSIST_COL}'

    return [_decode_one(token_id) for token_id in target_ids]


def get_value_map(ex):
    """
    Build a map from question phrases to the SQL literals they are aligned with.

    For every ``(a, b)`` pair in ``ex['align']`` (``a``: indices into
    ``ex['nl']``, ``b``: indices into ``ex['sql']``), the first SQL token in
    ``b`` whose type is ``Literal.String`` or ``Literal.Number`` is taken as
    the value.  The key is the space-joined question tokens of ``a`` after a
    few normalisations (leading commas dropped, bracket/quote placeholders
    rewritten, quotes padded with spaces).

    Side effect: the index lists inside ``ex['align']`` are sorted in place.

    Returns:
        dict mapping the normalised question phrase to the SQL literal text.
    """
    val_map = {}
    for a, b in ex['align']:
        # In-place sort is kept on purpose: callers may rely on the alignment
        # lists being ordered afterwards.
        a.sort()
        b.sort()

        q_tok = ex['nl']
        sql_tok = ex['sql']

        for idx in b:
            if sql_tok[idx][0] not in ("Literal.String", "Literal.Number"):
                continue

            key = [q_tok[q_idx] for q_idx in a]
            # Drop a comma glued to the front of a token (",foo" -> "foo").
            key = [tok[1:] if len(tok) > 1 and tok[0] == ',' else tok for tok in key]
            key = ' '.join(key)

            # NOTE(review): r'\(' / r'\)' are not recognised escapes in a
            # re.sub replacement template, so the backslash is kept and the
            # key ends up containing the two characters "\(" / "\)".
            # Presumably downstream matching expects that -- confirm.
            key = re.sub(r'\-lrb\-', r'\(', key)
            key = re.sub(r'\-rrb\-', r'\)', key)

            # Collapse PTB-style quotes (`` and '') to a single quote.
            key = re.sub('``', "'", key)
            key = re.sub("''", "'", key)

            # Put a space on each side of a quote that touches other text.
            key = find_and_sub(key, r"\S'", "'", " '")
            key = find_and_sub(key, r"'\S", "'", "' ")

            val_map[key] = sql_tok[idx][1]
            break

    return val_map


def find_and_sub(sql, match_p, sub_p, sub_str):
    """
    Repeatedly rewrite the text matched by ``match_p`` until nothing matches.

    On each pass the first regex match is taken, ``sub_p`` is replaced by
    ``sub_str`` inside the matched text (plain string replacement), and every
    occurrence of the matched text in ``sql`` is swapped for the rewritten
    version.

    Args:
        sql: the string to rewrite.
        match_p: regular expression locating the span to fix.
        sub_p: literal substring to replace inside the matched span.
        sub_str: literal replacement for ``sub_p``.

    Returns:
        The rewritten string.
    """
    pattern = re.compile(match_p)
    while True:
        match = pattern.search(sql)
        if match is None:
            break

        group = match.group()
        updated = sql.replace(group, group.replace(sub_p, sub_str))
        if updated == sql:
            # The match does not contain `sub_p` (or is empty), so another
            # pass would find the same match again: stop instead of spinning.
            break
        sql = updated
    return sql


def get_cells(tbl_name):
    """
    Collect every ``(column name, cell text)`` pair of a table.

    The table is read from ``tables/json/<tbl_name>.json``.  The first two
    entries of ``table["contents"]`` are skipped; for the remaining ones each
    column contributes one pair per value in ``col["data"]`` (values are
    converted with ``str``).  Columns typed ``"LIST TEXT"`` hold lists of
    values, which are flattened.

    Returns:
        set of ``(col_name, str(value))`` tuples.
    """
    with open("tables/json/{}.json".format(tbl_name), "r") as f:
        table = json.load(f)

    cells = set()
    for content in table["contents"][2:]:
        for col in content:
            if col["type"] == "LIST TEXT":
                values = (x for lst in col["data"] for x in lst)
            else:
                values = col["data"]
            cells.update((col["col"], str(x)) for x in values)

    return cells


def is_column(tok):
    """
    Return True if ``tok`` looks like a column reference: a ``'c'`` followed
    by an ASCII digit (e.g. ``c1``, ``c3_number``).

    Tokens shorter than two characters are never columns.
    """
    if len(tok) < 2 or tok[0] != 'c':
        return False

    # Explicit range check (not str.isdigit) so only ASCII 0-9 qualifies.
    return '0' <= tok[1] <= '9'


def best_match(candidates, query, col=None):
    """
    Return the candidate ``(column, text)`` pair whose text is most similar to
    ``query`` according to ``fuzz.ratio``.

    Among candidates with the same ratio, one whose column equals ``col`` is
    preferred.  Raises ``ValueError`` if ``candidates`` is empty.
    """
    def _score(cell):
        # Similarity first; the column match only acts as a tie-breaker.
        return (fuzz.ratio(cell[1], query), col == cell[0])

    return max(candidates, key=_score)


def get_column_type(col_name, table):
    """
    Look up the declared type of column ``col_name`` in a loaded table dict.

    The first two entries of ``table["contents"]`` are skipped.  Returns the
    ``'type'`` of the first column whose ``'col'`` equals ``col_name``, or
    ``None`` if there is no such column.
    """
    columns = (col for content in table["contents"][2:] for col in content)
    return next((col['type'] for col in columns if col['col'] == col_name), None)


def parse_number(s):
    """
    Extract a number from the string ``s``.

    Returns:
        * the ``NUM_MAPPING`` value (an int or float) when ``s`` is one of the
          known number words;
        * otherwise the first numeric substring found in ``s`` (commas are
          removed first), as a lightly normalised *string*:
          one trailing zero after a non-zero decimal digit is dropped
          ("1.50" -> "1.5"), a trailing ".0" is dropped ("100.0" -> "100"),
          and a leading "." gets a "0" prefix (".5" -> "0.5");
        * ``None`` when no number is found.
    """
    if s in NUM_MAPPING:
        return NUM_MAPPING[s]

    s = s.replace(',', '')
    # https://stackoverflow.com/questions/4289331/python-extract-numbers-from-a-string
    match = re.search(r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?", s)
    if match is None:
        return None

    ret = match.group()
    # With an exponent the final characters belong to the exponent, so the
    # trailing-zero trim must not touch them ("1.5e10" must stay "1.5e10").
    has_exponent = 'e' in ret or 'E' in ret

    if not has_exponent and len(ret) > 3 and '.' in ret and ret[-1] == '0' \
            and '0' < ret[-2] <= '9':
        ret = ret[:-1]
    elif len(ret) > 2 and ret[-2:] == '.0':
        ret = ret[:-2]
    elif len(ret) > 1 and ret[0] == '.':
        ret = '0' + ret
    return ret


def postprocess(pred_ids, pred_toks, pred_columnt, tbl_name, ex=None):
    """
    Turn a predicted token sequence into a SQL string and repair its literals
    against the table contents.

    Steps, in order:
      1. Clean individual tokens in place: strip surrounding backticks, map
         '==' to '=', and replace a '"' whose id is not the '"' keyword id by
         its ``repr``.
      2. For each ``(tok_id, col_suffix)`` in ``pred_columnt``, scan forward
         in ``pred_ids`` for ``tok_id`` and append ``'_' + col_suffix`` to the
         token at that position (suffix ``'[KEY]'`` leaves it untouched).
      3. Drop the first and last token and join the rest, gluing ``##`` word
         pieces and the text directly inside double quotes.
      4. Undo tokenisation spacing around '/', '.', '!!' and '-'.
      5. Rebuild the statement atom by atom.  Each double-quoted value that
         follows ``<col> =`` / ``<col> !=`` or sits after a ``<col> [not] in``
         is replaced by the best fuzzy-matching table cell, but only when that
         cell belongs to the same column.  If the nearest preceding column
         atom is typed INTEGER or REAL, the literal is passed through
         ``parse_number``.  Literals are emitted with ``repr``.
      6. Remove the single quotes around purely numeric literals.

    Args:
        pred_ids: predicted ids, indexed in parallel with ``pred_toks``.
        pred_toks: predicted token strings; modified in place by steps 1-2.
        pred_columnt: iterable of ``(tok_id, col_suffix)`` pairs.
        tbl_name: table name; ``tables/json/<tbl_name>.json`` is read.
        ex: unused.

    Returns:
        The post-processed SQL string ("" when no token is left after
        dropping the first and last one).
    """
    for i, tok in enumerate(pred_toks):
        if len(tok) > 2 and tok[0] == '`' and tok[-1] == '`':
            pred_toks[i] = pred_toks[i][1:-1]
        if tok == '==':
            pred_toks[i] = '='
        if tok == '"' and pred_ids[i] != keywords['"']:  # from question
            pred_toks[i] = repr(tok)

    # `j` only moves forward, so entries of pred_columnt are matched to
    # positions in pred_ids in order of appearance.
    j = 0
    for item in pred_columnt:
        tok_id, col_suffix = item
        while j < len(pred_ids):
            if pred_ids[j] == tok_id:
                break
            j += 1

        if col_suffix != '[KEY]':
            pred_toks[j] += '_' + col_suffix
        j += 1

    pred_toks = pred_toks[1:-1]
    if not pred_toks:
        return ""

    pred_sql = pred_toks[0]
    in_quote = False

    for i in range(1, len(pred_toks)):
        tok = pred_toks[i]
        if tok[:2] == '##':
            # Word-piece continuation: glue to the previous token.
            tok = tok[2:]
            pred_sql += tok
        elif tok == '"' and in_quote:
            # Closing quote: no space before it.
            pred_sql += tok
            in_quote = False
        elif pred_toks[i-1] == '"' and in_quote:
            # First token after an opening quote: no space after the quote.
            pred_sql += tok
        else:
            if tok == '"':
                assert in_quote is False
                in_quote = True
            pred_sql += ' ' + tok

    pred_sql = re.sub(" / ", "/", pred_sql)
    # Double quotes are kept as-is here: the literal-matching pass below
    # detects values by their surrounding double quotes.

    pred_sql = find_and_sub(pred_sql, r'\d+ \. \d+', ' . ', '.')
    pred_sql = find_and_sub(pred_sql, r'\. \d+', '. ', '.')
    pred_sql = find_and_sub(pred_sql, r" '[^']+ \! \! [^']+'(\ |$)", '! !', '!!')
    pred_sql = find_and_sub(pred_sql, r" '[^']+ \- [^']+'(\ |$)", ' - ', '-')
    pred_sql = find_and_sub(pred_sql, r" '[^']+ \.[^']*'(\ |$)", ' .', '.')
    pred_sql = find_and_sub(pred_sql, r" '[^']*\. \w\.[^']*'(\ |$)", '. ', '.')
    pred_sql = find_and_sub(pred_sql, r" '[^']*\. \w( [^']*)?'(\ |$)", '. ', '.')

    # Handling mismatch value
    json_file = "tables/json/{}.json".format(tbl_name)
    with open(json_file, "r") as f:
        table = json.load(f)

    # Empty pieces (from consecutive spaces) carry no content and would break
    # the tok[0] / tok[-1] checks below, so they are dropped.
    pred_toks = [tok for tok in pred_sql.split(' ') if tok]
    cells = get_cells(tbl_name)
    in_quote = False
    start = -1
    pred_atoms = []

    for i, tok in enumerate(pred_toks):
        if tok[0] == '"' and not in_quote:
            in_quote = True
            start = i

        if not in_quote:
            pred_atoms.append(tok)
        elif in_quote and tok[-1] == '"':
            # pred_toks[start:i+1] spans one complete quoted value.
            value_str = ' '.join(pred_toks[start:i+1])
            value_str = value_str[1:-1]
            literal = None

            if start >= 2 and is_column(pred_toks[start-2]) and pred_toks[start-1] in ['=', '!=']:
                col, literal = best_match(cells, value_str, pred_toks[start-2])
                if col != pred_toks[start-2]:
                    # Best cell is in another column: keep the predicted text.
                    literal = value_str
            elif len(pred_atoms) >= 2:  # col in ( )
                # Find the closest preceding 'in' keyword.
                j = len(pred_atoms) - 1
                while j >= 0:
                    if pred_atoms[j] == 'in':
                        break
                    j -= 1

                if j != -1:
                    if j >= 1 and is_column(pred_atoms[j-1]):
                        corr_col = pred_atoms[j-1]
                    elif j >= 2 and pred_atoms[j-1] == 'not' and is_column(pred_atoms[j-2]):
                        corr_col = pred_atoms[j-2]
                    else:
                        corr_col = None

                    col, literal = best_match(cells, value_str, corr_col)
                    if col != corr_col:
                        literal = value_str

            if not literal:
                literal = value_str

            # Handling numeric value
            # Find the closest preceding column atom to learn the value type.
            j = len(pred_atoms) - 1
            while j >= 0:
                if is_column(pred_atoms[j]):
                    break
                j -= 1

            if j != -1:
                col_type = get_column_type(pred_atoms[j], table)
                if col_type in ["INTEGER", "REAL"]:
                    number_str = parse_number(literal)
                    if number_str is not None:
                        literal = number_str

            pred_atoms.append(repr(literal))
            in_quote = False

    pred_sql = ' '.join(pred_atoms)

    # Delete the quotation marks around the number
    pred_sql = find_and_sub(pred_sql, r"'[\d\.]+'", "'", '')

    return pred_sql


def pack_bidirectional_lstm_state(state, num_layers):
    """
    Pack the hidden state of a BiLSTM so that its first dimension equals the
    number of layers.

    ``state`` has shape ``(2 * num_layers, batch, hidden)``; consecutive
    pairs along the first dimension are concatenated on the last one, giving
    ``(num_layers, batch, 2 * hidden)``.
    """
    assert (len(state) == 2 * num_layers)
    _, batch_size, hidden_dim = state.shape

    # (layers, 2, batch, hidden) -> (layers, batch, 2, hidden)
    per_layer = state.view(num_layers, 2, batch_size, hidden_dim)
    per_layer = per_layer.permute(0, 2, 1, 3).contiguous()

    # Merge the pair axis into the hidden axis.
    return per_layer.view(num_layers, batch_size, -1)


class BertSQLLRScheduler(object):
    """
    Deferred factory for a polynomial-decay-with-warmup learning-rate schedule.

    The schedule hyper-parameters are captured at construction time; the
    schedule itself is created once an optimizer is supplied, either through
    ``build(optimizer)`` or by calling the instance.

    ``verbose`` is stored on the instance but is not forwarded to the
    schedule constructor.
    """

    def __init__(self, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1, verbose=False):
        # Step counts.
        self.num_warmup_steps, self.num_training_steps = num_warmup_steps, num_training_steps
        # Shape of the decay.
        self.lr_end, self.power = lr_end, power
        # Bookkeeping options.
        self.last_epoch, self.verbose = last_epoch, verbose

    def build(self, optimizer):
        """Create the schedule for ``optimizer`` from the stored settings."""
        schedule_kwargs = dict(
            optimizer=optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.num_training_steps,
            lr_end=self.lr_end,
            power=self.power,
            last_epoch=self.last_epoch,
        )
        return transformers.get_polynomial_decay_schedule_with_warmup(**schedule_kwargs)

    def __call__(self, optimizer):
        """Shorthand for ``build(optimizer)``."""
        return self.build(optimizer)
