import sys
import re
import json
import torch
import pdb
import os

from torch import optim
from transformers import BertModel
from dataset import load_data
from trainer import App, Checkpoint, TQDM, TensorBoard
from model import BertSQL
from utils import decode_sql, BertSQLLRScheduler, postprocess

# torch.autograd.set_detect_anomaly(True)


if __name__ == "__main__":
    # Command line: <config.json> <is_train> <run name> <tag>
    if len(sys.argv) < 5:
        sys.exit(f"usage: {sys.argv[0]} CONFIG_JSON IS_TRAIN NAME TAG")

    # Read the configuration inside a context manager so the file handle is
    # closed right away instead of being left to the garbage collector.
    with open(sys.argv[1], encoding="utf-8") as config_file:
        config = json.load(config_file)

    # bool() of any non-empty string is True, so "0" / "false" used to be
    # treated as True. Only explicit false-like spellings disable the flag;
    # every other value keeps the previous (True) result.
    # NOTE(review): is_train is not referenced again in the visible script.
    is_train = sys.argv[2].strip().lower() not in (
        '', '0', 'false', 'f', 'no', 'n', 'off', 'none',
    )
    tag = sys.argv[4]

    # Run identifier, used for the App name and the output directory.
    name = f"{sys.argv[3]}-{tag}"
    # Pretrained model name; the same value is passed to load_data and to
    # BertModel.from_pretrained.
    ptm_name = config.get("ptm_name", 'bert-base-uncased')
    data = load_data({
        'dataset': {
            'train': config['data']['train'],
            'dev': config['data']['dev'],
            'test': config['data']['test']
        },
        'ptm_name': ptm_name,
        'bs': config['bs'],
        'batch_first': True,
        'shuffle': True,
        'device': 'cuda',
        'fields': [
            'nl', 'align', 'columns', 'sql', 'tbl',
            'tbl_name', 'nt', 'nl_typebio', 'nl_typebio_col'
        ],
        # Alternative field list (disabled):
        # 'fields': [
        #     'nl', 'amap', 'align', 'columns', 'sql',
        #     'tbl', 'tbl_name', 'nt'
        # ],
    })

    # The tokenizer comes from the loaded data bundle and is handed to the
    # model together with the BERT encoder and the full config.
    tokenizer = data['tokenizer']
    model = BertSQL(bert=BertModel.from_pretrained(ptm_name),
                    tokenizer=tokenizer,
                    config=config)
    app = App(model, name=f'BertSQL-{name}')

    # Trainer extensions: checkpointing, progress bar, TensorBoard logging.
    app.extend(Checkpoint())
    app.extend(TQDM())
    app.extend(TensorBoard())


    @app.on("evaluate")
    def bert_eval(e):
        """Run inference on ``e.batch`` and report prediction vs. gold.

        The batch is tokenized with the feature switches from
        ``config['model']``, passed through the model with ``is_train=False``,
        and the first predicted sequence is turned into SQL text with
        ``decode_sql`` followed by ``postprocess``.

        Returns a dict describing the first example of the batch: its input
        fields, the gold SQL string, the predicted SQL string, the predicted
        and gold id sequences, the predicted and gold column types, and the
        ``q_base`` / ``t_base`` / ``c_base`` values from the model output.
        """
        model_cfg = config['model']
        switches = {
            flag: model_cfg[flag]
            for flag in ('with_align', 'with_reduce', 'with_key',
                         'with_col', 'with_val')
        }
        encoded = e.model.tokenize(e.batch, is_train=False, **switches)
        y = e.model(
            encoded,
            is_train=False,
            beam_size=model_cfg['beam_size'],
            max_seq=model_cfg['max_len'],
        )

        # [1, 2] is the fallback sequence whenever the model output carries no
        # usable prediction / target (presumably start and end token ids --
        # TODO confirm). Two separate list objects are kept on purpose.
        pred, gold = [1, 2], [1, 2]
        if 'predict' in y:
            if y['predict']:
                pred = y['predict'][0][0]
            if 'target_ids' in y:
                gold = y['target_ids'][0]
                if isinstance(gold, torch.Tensor):
                    gold = gold.view(-1).tolist()

        # Column types: the prediction is taken as-is; gold pairs are
        # [target id, mapped column type] for every position whose type id is
        # not -1 and is present in the model's reverse mapping.
        pred_ct, gold_ct = [], []
        if 'predict_column_type' in y:
            pred_ct = y['predict_column_type']
            if 'target_ids_columnt' in y:
                type_ids = y['target_ids_columnt'].view(-1).tolist()
                token_ids = y['target_ids'].view(-1).tolist()
                gold_ct = [
                    [tok, e.model.column_type_rev_ids[ct]]
                    for ct, tok in zip(type_ids, token_ids)
                    if ct != -1 and ct in e.model.column_type_rev_ids
                ]

        # Only the first example of the batch is reported.
        sample = e.batch[0]

        pred_sql = ''
        if pred:
            columns = sample['col']
            # The first three column entries stay unquoted; the remaining ones
            # are wrapped in backticks before decoding.
            quoted = columns[:3] + [f"`{c}`" for c in columns[3:]]
            raw_sql = decode_sql(
                pred, y['q_base'], y['t_base'], y['c_base'],
                sample['q_tokenize'], ["`w`"], quoted,
            )
            pred_sql = postprocess(pred, raw_sql, pred_ct)

        return {
            "nt": sample['nt'],
            "nl": sample['nl'],
            "loss": 0,
            "q": sample['q'],
            "q_tokenize": sample['q_tokenize'],
            'sql': ' '.join(sample['target']),
            'pred_sql': pred_sql,
            'tbl': sample['tbl'],
            'col': sample['col'],
            'tbl_name': sample['tbl_name'],
            "pred": pred,
            "gold": gold,
            "pred_columnt": pred_ct,
            "gold_columnt": gold_ct,
            "q_base": y['q_base'],
            "t_base": y['t_base'],
            "c_base": y['c_base'],
        }


    # Results go to output/<name>/. makedirs creates every missing level and
    # tolerates directories that already exist, which avoids the
    # check-then-create race of separate exists()/mkdir() calls.
    output_dir = os.path.join('output', name)
    os.makedirs(output_dir, exist_ok=True)

    # Configure the app, then evaluate on the test iterator. The returned
    # results are serialized to JSON below.
    results = (
        app.fastforward()
        .to("cuda")
        .with_seed(config['seed'])
        .save_every(epochs=1)
        .build()
        .eval(data['iterator']['test'])
    )

    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # encoding is pinned to UTF-8 instead of depending on the locale default.
    with open(f"{output_dir}/result.{tag}.test.json", "w", encoding="utf-8") as fw:
        json.dump(results, fw, ensure_ascii=False, indent=2)
