""" Example interactive environment"""

import json
import os

from transformers.models.bert.tokenization_bert import BertTokenizer
from qa.table.utils import hmt_score
from qa.nsm.env import QAProgrammingEnv
from qa.nsm.execution.executor import HMTExecutor
from qa.nsm.computer import LispInterpreter
from qa.table_bert.hm_table import *
from qa.examples.example_data import EXAMPLE_DICT_JULY_4, KG_JULY_4


# Base directory that main() prefixes to DATA_PATH when locating input files.
# This is a hard-coded absolute path, so adjust it for the local machine.
ROOT_PATH = '/home/TableSense/USER/HMT/'
# Data directory, joined onto ROOT_PATH by main().
DATA_PATH = 'qa/data/'


def test_interpreter(interpreter: LispInterpreter, program: str):
    """Step a program through the interpreter token by token, echoing progress.

    The program text is stripped of surrounding whitespace and split with
    ``interpreter.tokenize``; every token is then passed to
    ``interpreter.read_token``.  After each token the value of
    ``interpreter.valid_tokens()`` is printed, and whenever ``read_token``
    returns a truthy value that value is printed too.

    Args:
        interpreter: Interpreter that tokenizes and consumes the program.
        program: Source text of the program to step through.
    """
    for token in interpreter.tokenize(program.strip()):
        outcome = interpreter.read_token(token)
        print('Read in [{}], valid tokens: {}'.format(token, interpreter.valid_tokens()))
        if not outcome:
            # Falsy return values (including None) are not echoed.
            continue
        print('Result: ', outcome)


def create_environment_wrapper(train_dict, table_dict, id, max_mem, max_n_exp):
    """Assemble a ``QAProgrammingEnv`` for a single training example.

    Args:
        train_dict: Mapping from example id to example record.  The record's
            ``'context'``, ``'answer'`` and ``'id'`` entries are read here.
        table_dict: Mapping indexed by the example's ``'context'`` value; the
            selected record's ``'kg'`` entry is passed to ``HMTable.from_dict``.
        id: Key of the example to look up in ``train_dict``.
        max_mem: Forwarded to ``LispInterpreter`` as ``max_mem``.
        max_n_exp: Forwarded to ``LispInterpreter`` as ``max_n_exp``.

    Returns:
        The constructed ``QAProgrammingEnv``.
    """
    example = train_dict[id]
    table_record = table_dict[example['context']]
    kg_info = HMTable.from_dict(table_record['kg']).build_kg_info()

    # The executor exposes the pieces the interpreter and environment need.
    api = HMTExecutor(kg_info).get_api()
    type_hierarchy = api['type_hierarchy']
    func_dict = api['func_dict']
    constant_dict = api['constant_dict']

    hmt = kg_info['kg']
    interpreter = LispInterpreter(
        op_region=hmt.get_init_op_region(),
        type_hierarchy=type_hierarchy,
        max_mem=max_mem,
        max_n_exp=max_n_exp,
        hmt=hmt,
        assisted=True)

    # Each func_dict value is a keyword-argument bundle for add_function.
    for func_spec in func_dict.values():
        interpreter.add_function(**func_spec)

    # The vocabulary is requested here but its return value is not used
    # anywhere else in this function.
    interpreter.get_vocab()

    # NOTE(review): the stored answer is turned into a Python object with
    # eval(); that is only acceptable while the dataset files are trusted.
    return QAProgrammingEnv(
        question_annotation=example,
        kg=kg_info,
        answer=eval(example['answer']),
        constants=constant_dict.values(),
        interpreter=interpreter,
        score_fn=hmt_score,
        max_cache_size=50000,
        name=example['id'])


def _load_jsonl_by_key(path, key):
    """Read a JSON Lines file into a dict indexed by one field of each record.

    Args:
        path: Location of the JSON Lines file.
        key: Name of the field whose value becomes the dictionary key.

    Returns:
        Dict mapping ``record[key]`` to the decoded record.  If several
        records share a key, the last one read wins.
    """
    records = {}
    # Decode explicitly as UTF-8 (the conventional JSON encoding) instead of
    # relying on the platform's locale default.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines rather than failing inside json.loads.
                continue
            record = json.loads(line)
            records[record[key]] = record
    return records


def main():
    """Load the training examples and tables, then start an interactive session.

    Reads ``processed_input/train_samples.jsonl`` (indexed by ``'id'``) and
    ``processed_input/tables.jsonl`` (indexed by ``'name'``) from
    ``ROOT_PATH/DATA_PATH``, builds the environment for example ``'1007'``,
    prints its vocabulary and namespace, and calls ``env.interactive()``.
    """
    # Load dataset.
    train_shard_file = os.path.join(ROOT_PATH, DATA_PATH, 'processed_input/train_samples.jsonl')
    print('working on shard {}'.format(train_shard_file))
    train_dict = _load_jsonl_by_key(train_shard_file, 'id')
    print('{} examples in training set.'.format(len(train_dict)))

    table_file = os.path.join(ROOT_PATH, DATA_PATH, 'processed_input/tables.jsonl')
    table_dict = _load_jsonl_by_key(table_file, 'name')
    print('{} tables.'.format(len(table_dict)))

    # Interactive environment
    example_id = '1007'  # avoid shadowing the builtin ``id``
    env = create_environment_wrapper(train_dict, table_dict, example_id, 120, 10)
    print("Vocab =>")
    for k, v in env.de_vocab.rev_vocab.items():
        print(f"{k}:{v} | ", end='')
    print('\nNamespace =>')
    for key, val in env.interpreter.namespace.items():
        type_ancestors = env.interpreter.get_type_ancestors(val['type'])
        # Entries whose type ancestors include 'head' are listed as functions;
        # everything else is listed as an entity together with its value.
        if 'head' in type_ancestors:
            print('Function: {}'.format(key))
        else:
            print('Entity: {}, Value: {}'.format(key, val))

    env.interactive()

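    # NOTE: The scripted checks below are disabled and out of date: tests #2-#5
    # call create_environment_wrapper() without the arguments it requires, and
    # the program in test #5 is left unfinished.
    # print("\ntest#1 pipeline" + '-'*80)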
    # test_interpreter(env.interpreter, """
    #                                     (filter_tree_str_contain v9 v0 v1) \
    #                                     (filter_tree_str_contain v9 v2) \
    #                                     (filter_level v11) \
    #                                     (filter_tree_str_contain v12 v5) \
    #                                     (filter_tree_str_contain v12 v7 v8) \
    #                                     (filter_level v14) \
    #                                     (argmax)
    #                                   """)
    # print("\ntest#2 MAX_N_FILTER_TREE_LEFT=1, and filter_level left before top" + '-'*80)
    # env = create_environment_wrapper()
    # test_interpreter(env.interpreter, """
    #                                     (filter_tree_str_contain v9 v0) \
    #                                     (filter_level v11) \
    #                                     (filter_level v14)
    #                                  """)
    # print("\ntest#3 filter_level on line_idx=None non-leaf node; if default filter the current deepest." + '-'*80)
    # env = create_environment_wrapper()
    # test_interpreter(env.interpreter, """
    #                                     (filter_level v11) \
    #                                     (filter_tree_str_contain v12 v5 v6) \
    #                                     (filter_level v14) \
    #                                  """)
    # print("\ntest#4 filter_tree_str_not_contain" + '-'*80)
    # env = create_environment_wrapper()
    # test_interpreter(env.interpreter, """
    #                                     (filter_tree_str_not_contain v9 v0) \
    #                                     (filter_level v10) \
    #                                     (filter_tree_str_not_contain v12 v5) \
    #                                     (filter_level v14) \
    #                                     (sum v11)
    #                                  """)
    # print("\ntest#5 diff" + '-'*80)
    # env = create_environment_wrapper()
    # test_interpreter(env.interpreter, """
    #                                     (filter_tree_str_not_contain v9 v0) \
    #                                     (filter_level v10) \
    #                                     (filter_tree_str_not_contain v12 v6) \
    #                                     (filter_tree_str_not_contain v12 v7) \
    #                                     (filter_level v14) \
    #                                     (
    #                                  """)


# Run the interactive session only when this file is executed as a script.
if __name__ == "__main__":
    main()
