import json

from transformers import BertTokenizer
from preprocess import encode


# Module-level BERT tokenizer, created once at import time from the relative
# path below and passed to encode() for every example handled by process().
tokenizer = BertTokenizer.from_pretrained("../bert-model/bert-base-uncased/")


def is_number(s):
    """Return True if ``s`` can be converted by ``int()`` or ``float()``.

    ``int()`` is tried first and ``float()`` second, so integer strings,
    decimal/exponent strings, and numeric objects all count as numbers.
    Anything neither conversion accepts (e.g. ``"abc"``, ``""``, ``None``)
    yields False.

    Only the exceptions the conversions raise for unsuitable input are
    caught; ``KeyboardInterrupt`` / ``SystemExit`` are allowed to propagate.
    """
    # OverflowError covers int(float('inf')); float() then accepts the value.
    conversion_errors = (TypeError, ValueError, OverflowError)
    for convert in (int, float):
        try:
            convert(s)
        except conversion_errors:
            continue
        return True
    return False


def process(src_path, tgt_path, align_path=None):
    """Encode the examples in ``src_path`` and write them to ``tgt_path``.

    The source file is loaded as JSON and iterated example by example. Each
    example is handed to ``encode`` together with the module-level
    ``tokenizer`` (``with_target='sql'``, ``with_align=False``) and then
    written to ``tgt_path`` as one JSON object per line.

    When ``align_path`` is given, that file is loaded as a JSON list of
    records, indexed by each record's ``'nt'`` value, and the matching
    record's ``'amap'`` is stored on the example under ``'amap'`` before it
    is written.

    An example that raises anywhere in this pipeline is reported on stdout
    as ``<index> <exception>``, counted as an error, and skipped. After the
    loop the error count, the number of examples, and the error ratio are
    printed.

    Args:
        src_path: Path of the JSON file holding the examples.
        tgt_path: Path of the JSON-lines file to create (overwritten).
        align_path: Optional path of the JSON alignment file.
    """
    with open(src_path, 'r', encoding='utf-8') as fr:
        data = json.load(fr)

    # Index alignment records by their 'nt' value for O(1) lookup per example.
    align_data_dict = None
    if align_path is not None:
        with open(align_path, 'r', encoding='utf-8') as fr:
            align_data_dict = {rec['nt']: rec for rec in json.load(fr)}

    tgt = 'sql'
    err = 0

    with open(tgt_path, 'w', encoding='utf-8') as fw:
        for i, ex in enumerate(data):
            # Best-effort: one bad example must not abort the whole run, so
            # any failure is logged, counted, and the example is skipped.
            try:
                # NOTE(review): the return value is unused, so encode()
                # presumably updates `ex` in place - confirm in preprocess.
                encode(tokenizer, ex, with_target=tgt, with_align=False)
                if align_data_dict is not None:
                    ex['amap'] = align_data_dict[ex['nt']]['amap']

                line = json.dumps(ex, ensure_ascii=False)
                fw.write(f"{line}\n")
            except Exception as e:
                print(i, e)
                err += 1

    total = len(data)
    # Guard against an empty source file, which would otherwise raise
    # ZeroDivisionError after the output had already been written.
    ratio = err / total if total else 0.0
    print(err, total, ratio)


if __name__ == "__main__":
    # Base directories for the input data and the alignment files.
    data_dir = 'data/squall/'
    align_dir = 'pred_align/'

    # Active run: train split 0 with its alignment file.
    gold_dir = align_dir + 'gold-0-splited/'
    process(
        f'{data_dir}train-0.json',
        f'{gold_dir}train-0.ali-gen-gold.splited.json',
        f'{gold_dir}result.ali-gen-gold.splited.train-0.json',
    )

    # Disabled alternative runs over the 5 splits (dev / train / wtq-test),
    # kept for reference:
    # n_split = 5
    # for i in range(n_split):
    #     process(data_dir + f'dev-{i}.json', f'pred_align/aug-split-{i}/dev-{i}.aug-amap.processed.json', f'pred_align/aug-split-{i}/result.{i}-split-fixed.dev.json')
    #     # process(data_dir + f'train-{i}.json', f'pred_align/gold-{i}/train-{i}.ali-gen-gold.processed.json', f'pred_align/gold-{i}/result.ali-gen-gold.train-{i}.json')
    #     process(data_dir + 'wtq-test.json', f'pred_align/aug-split-{i}/wtq-test.aug-amap.processed.json', f'pred_align/aug-split-{i}/result.{i}-split-fixed.test.json')