from stanza.server import CoreNLPClient
from argparse import ArgumentParser

from qa.datadump.utils import *


# Module-level CoreNLP client shared by test_semantic_coverage (passed to
# infer_type) and shut down at the end of main(). Only the tokenizer and NER
# annotators are requested.
client = CoreNLPClient(
    memory='16G',
    timeout=30000,
    annotators=['tokenize', 'ner'],
)


def set2list(dic):
    """Replace each value of ``dic`` with ``list(value)`` in place.

    Used to make set-valued maps JSON-serializable. The same dict object
    is returned for convenience.
    """
    for key, value in dic.items():
        # Re-assigning existing keys does not change the dict's size, so it
        # is safe while iterating over items().
        dic[key] = list(value)
    return dic

def generate_semantic_map(root_dir=None, data_dir=None):
    """Build the index_name <-> index semantic map and dump it as JSON.

    Scans every file whose name contains ``'Structure'`` in
    ``<root_dir>/<data_dir>/dataschema``, parses it as XML, and for every
    ``str:Codelist`` records its normalized English ``com:Name``
    (the *index_name*) together with the normalized English names of its
    ``str:Code`` children (the *indices*).

    The result is written to ``<root_dir>/<data_dir>/semantic_map.json`` as
    ``{"map": {index_name: [index, ...]}, "re_map": {index: [index_name, ...]}}``
    and summary statistics are printed.

    Args:
        root_dir: Base directory. Defaults to the ``--root_dir`` CLI value.
        data_dir: Data directory joined under ``root_dir``. Defaults to the
            ``--data_dir`` CLI value.
    """
    if root_dir is None:
        root_dir = args.root_dir
    if data_dir is None:
        data_dir = args.data_dir
    schema_dir = os.path.join(root_dir, data_dir, 'dataschema')

    dic, rev_dic = {}, {}
    cnt_index_name, cnt_file = 0, 0
    for file in os.listdir(schema_dir):
        if 'Structure' not in file:
            continue
        cnt_file += 1
        with open(os.path.join(schema_dir, file), 'r', encoding='utf-8') as f:
            soup = BeautifulSoup(f.read(), 'xml')
        codelists = soup.find('str:Codelists')
        if codelists is None:  # structure file without codelists: nothing to map
            continue
        for index_set in codelists.find_all('str:Codelist'):
            index_name = normalize(index_set.find("com:Name", attrs={'xml:lang': 'en'}).text)
            cnt_index_name += 1
            for index in index_set.find_all('str:Code'):
                dic.setdefault(index_name, set()).add(
                    normalize(index.find('com:Name', attrs={'xml:lang': 'en'}).text))

    # Reverse map: index -> set of index_names that contain it.
    for index_name, indices in dic.items():
        for index in indices:
            rev_dic.setdefault(index, set()).add(index_name)
    # Count each index once, no matter how many index_names share it
    # (counting inside the loop above would over-count indices with >2 names).
    cnt_multi_iname = sum(1 for names in rev_dic.values() if len(names) > 1)

    dic, rev_dic = set2list(dic), set2list(rev_dic)

    print(f"num of structure file: {cnt_file}")
    print(f"num of index_name(maybe duplicate): {cnt_index_name}")
    print(f"num of different index_name: {len(dic)}")
    print(f"num of different index: {len(rev_dic)}")
    proportion = cnt_multi_iname / len(rev_dic) if rev_dic else 0.0
    print(f"num of index with >1 index_name: {cnt_multi_iname}; proportion: {proportion}")

    out_path = os.path.join(root_dir, data_dir, 'semantic_map.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump({'map': dic, 're_map': rev_dic}, f)


def _is_header_cell(cell):
    """Return True if a BeautifulSoup <th>/<td> tag should count as a header.

    Every <th> counts. A <td> counts when one of its CSS classes is
    ``row-heading``. BeautifulSoup exposes ``class`` as a list, so we test
    membership instead of only the first entry.
    """
    if cell.name == 'th':
        return True
    return cell.name == 'td' and 'row-heading' in (cell.get('class') or [])


def test_semantic_coverage(table_dict, root_dir=None, data_dir=None):
    """Measure how many table header cells can be assigned a semantic type.

    Loads the reverse map (``re_map``: index -> list of index_names) from
    ``<root_dir>/<data_dir>/semantic_map.json``. Then, for every non-empty
    header cell of every table in ``table_dict``:

    * if its normalized text is a key of the reverse map, it has semantics.
      It has *single* semantics when exactly one index_name maps to it.
    * otherwise ``infer_type`` is called with the module-level CoreNLP
      ``client``. The types ``'datetime'`` and ``'number'`` count as single
      semantics.
    * anything else is logged as a failure.

    Matches go to ``success_file.txt`` and misses to ``failure_file.txt`` in
    the same directory. Summary counts and proportions are printed.

    Args:
        table_dict: Mapping whose values expose a BeautifulSoup ``html``
            attribute (presumably the output of ``load_tables`` — confirm).
        root_dir: Base directory. Defaults to the ``--root_dir`` CLI value.
        data_dir: Data directory joined under ``root_dir``. Defaults to the
            ``--data_dir`` CLI value.
    """
    if root_dir is None:
        root_dir = args.root_dir
    if data_dir is None:
        data_dir = args.data_dir
    out_dir = os.path.join(root_dir, data_dir)

    with open(os.path.join(out_dir, 'semantic_map.json'), encoding='utf-8') as f:
        re_semantic_map = json.load(f)['re_map']

    cnt_header, cnt_header_w_sem, cnt_header_w_single_sem = 0, 0, 0
    # Context managers make sure both logs are flushed and closed even if a
    # parse or CoreNLP call raises. Explicit UTF-8 avoids locale-dependent
    # encoding errors on non-ASCII header text.
    with open(os.path.join(out_dir, 'success_file.txt'), 'w', encoding='utf-8') as success_file, \
            open(os.path.join(out_dir, 'failure_file.txt'), 'w', encoding='utf-8') as failure_file:
        for table in table_dict.values():
            for row in table.html.find_all('tr'):
                for cell in row.find_all(['th', 'td']):
                    if not _is_header_cell(cell):
                        continue
                    text = normalize(clear_footer(cell))
                    if text == '':
                        continue
                    cnt_header += 1
                    if text in re_semantic_map:  # in StatCan knowledge base
                        index_names = re_semantic_map[text]
                        print(f"{text}: {index_names}", file=success_file)
                        cnt_header_w_sem += 1
                        if len(index_names) == 1:
                            cnt_header_w_single_sem += 1
                        continue
                    # Not in the knowledge base: ask CoreNLP. Only "datetime"
                    # and "number" are accepted as semantic types.
                    _, sem_type = infer_type(text, client)
                    if sem_type in ('datetime', 'number'):
                        print(f"{text}: {sem_type}", file=success_file)
                        cnt_header_w_sem += 1
                        cnt_header_w_single_sem += 1
                    else:  # no semantic index_name # TODO: add other sources, like DBpedia
                        print(text, file=failure_file)

    print(f"num of header: {cnt_header}")
    print(f"num of header w semantic: {cnt_header_w_sem}")
    print(f"num of header w single semantic: {cnt_header_w_single_sem}")
    if cnt_header == 0:
        # Proportions are undefined without any header cells.
        print("no header cells found; proportions undefined")
        return
    print(f"header w semantic proportion: {cnt_header_w_sem/cnt_header}")
    print(f"header w single semantic proportion: {cnt_header_w_single_sem/cnt_header}")


def main():
    """Load the HTML tables and report their semantic coverage.

    Reads paths from the module-level CLI ``args``. The module-level CoreNLP
    ``client`` is always stopped on exit, including when loading or the
    coverage test raises.
    """
    try:
        # To (re)build semantic_map.json from the dataschema files first,
        # uncomment the following:
        # print("---------------------Generate semantic map----------------")
        # generate_semantic_map()
        # print("Done.")

        print("---------------------Loading-----------------")
        table_dict = load_tables(args.root_dir, os.path.join(args.data_dir, 'html/'))
        print("Done.")

        print("---------------------Test semantic coverage---------------")
        test_semantic_coverage(table_dict)
        print("Done.")
    finally:
        # Previously an exception above skipped this call and left the
        # CoreNLP client/server running.
        client.stop()


if __name__ == "__main__":
    # Command-line options. The functions above join both values with
    # os.path.join, so --data_dir is placed under --root_dir.
    parser = ArgumentParser()
    for option, default_value in (
        ('--root_dir', '/data/home/hdd3000/USER/HMT/'),
        ('--data_dir', 'qa/data/'),
    ):
        parser.add_argument(option, type=str, default=default_value)

    # `args` must be a module-level global: main() and the helpers read it.
    args = parser.parse_args()
    main()
