import sys
import argparse
import json
import random
import time
import traceback
from tqdm import tqdm
from collections import defaultdict, Counter
from multiprocessing import Process, Lock, Queue, Manager

from questions import funcs
from data import Data
from utils.check import check_sparql, check_valid
import conf

# Seconds to wait when enqueueing a task for the workers. If no worker frees a
# queue slot within this time, the put raises queue.Full instead of blocking
# forever.
proc_timeout: int = 30

def worker(data, questions, type_cnt, ans_cnt, ans_max_num, queue, lock=None):
    """Generate questions for the question types received on ``queue``.

    Loops until ``None`` is read from ``queue``. Every other item is a key of
    ``funcs``. The matching generator is called with ``data``. A question is
    kept when it passes ``check_valid``, also passes ``check_sparql`` if
    ``conf.virtuoso_validate`` is set, and its answer is not over its cap.
    A kept question is appended to ``questions`` and the counters are bumped.

    Args:
        data: passed unchanged to every question generator.
        questions: shared list that accepted questions are appended to.
        type_cnt: shared dict ``q_type -> accepted count``. The caller must
            create the ``q_type`` entry before enqueueing that type.
        ans_cnt: shared dict ``str(answer) -> accepted count``.
        ans_max_num: an answer other than 'yes'/'no' is rejected once its
            count exceeds this value, so up to ``ans_max_num + 1`` copies of
            it can be accepted.
        queue: task queue of question-type names. ``None`` stops the worker.
        lock: optional lock shared by all workers. When given, the cap check
            and all shared-state updates happen while holding it. Without it,
            concurrent workers can lose counter updates and overshoot the
            answer cap, because ``d[k] += 1`` on a manager dict proxy is a
            separate read and write.
    """
    while True:
        q_type = queue.get(True)
        if q_type is None:  # None is the terminal signal
            break
        generate = funcs[q_type]
        try:
            question = generate(data)
        except Exception:
            # Best effort: a failing generator just produces no question this
            # round. Uncomment to debug:
            # traceback.print_exc()
            continue
        if not question:
            continue
        ans = str(question.answer)
        if not check_valid(question):
            continue
        if conf.virtuoso_validate and not check_sparql(question):
            continue
        if lock is None:
            _record_question(question, ans, q_type, questions, type_cnt, ans_cnt, ans_max_num)
        else:
            with lock:
                _record_question(question, ans, q_type, questions, type_cnt, ans_cnt, ans_max_num)


def _record_question(question, ans, q_type, questions, type_cnt, ans_cnt, ans_max_num):
    """Store ``question`` and bump the counters unless its answer is capped."""
    # Use .get() rather than "insert 0 if missing": under concurrency, the
    # insert could reset a count another worker has just incremented.
    if ans not in {'yes', 'no'} and ans_cnt.get(ans, 0) > ans_max_num:
        return
    questions.append(question)
    type_cnt[q_type] += 1
    ans_cnt[ans] = ans_cnt.get(ans, 0) + 1

# Question type -> the last program function(s) that identify it. Used only
# for the question-type statistics.
_QTYPE_TO_LAST_FUNCS = {
    'what_is_entity': ['What'],
    'how_many_entities': ['Count'],
    'which_is_most_among': ['SelectAmong'],
    'which_is_more_between': ['SelectBetween'],
    'what_is_attribute': ['QueryAttr', 'QueryAttrUnderCondition'],
    'is_attribute_satisfy': ['VerifyStr', 'VerifyNum', 'VerifyDate', 'VerifyYear'],
    'what_is_attribute_qualifier': ['QueryAttrQualifier'],
    'what_is_relation': ['QueryRelation'],
    'what_is_relation_qualifier': ['QueryRelationQualifier'],
}
_LAST_FUNC_TO_QTYPE = {
    func: qtype
    for qtype, last_funcs in _QTYPE_TO_LAST_FUNCS.items()
    for func in last_funcs
}

# Program functions whose presence marks a question as a qualifier question.
_QUALIFIER_FUNCS = frozenset({
    'QFilterStr', 'QFilterNum', 'QFilterYear', 'QFilterDate',
    'QueryAttrUnderCondition', 'QueryAttrQualifier', 'QueryRelationQualifier',
})


def _percent(num, den):
    """Format ``num / den`` as a percentage string, or 'n/a' if ``den`` is 0."""
    if not den:
        return 'n/a'
    return '{:.2f}%'.format(num * 100 / den)


def _generate_questions(data, per_num, ans_max_num, n_proc):
    """Run ``n_proc`` workers until each type has ``per_num`` questions.

    Returns the de-duplicated questions converted with ``q.dict()``.
    """
    with Manager() as manager:
        type_cnt = manager.dict()
        ans_cnt = manager.dict()
        questions = manager.list()
        queue = Queue(maxsize=n_proc)
        procs = []
        try:
            for _ in range(n_proc):
                p = Process(target=worker, args=(data, questions, type_cnt, ans_cnt, ans_max_num, queue))
                p.start()
                procs.append(p)
            for q_type in funcs:
                type_cnt[q_type] = 0
                t = tqdm(total=per_num, desc=q_type)
                while type_cnt[q_type] < per_num:
                    queue.put(q_type, timeout=proc_timeout)
                    t.update(type_cnt[q_type] - t.n)
                t.close()
            for _ in procs:
                queue.put(None)  # one terminal signal per worker
            for p in procs:
                p.join()
        except BaseException:
            # Workers block on queue.get() and are non-daemon, so they would
            # be joined (forever) at interpreter exit. Stop them explicitly.
            for p in procs:
                p.terminate()
            raise
        return [q.dict() for q in set(questions)]


def _print_answer_distribution(questions):
    """Print the ten most frequent answers and their share of all questions."""
    cnt_ans = Counter(q['answer'] for q in questions)
    print('='*30)
    print('top 10 answer distribution:')
    for k, v in cnt_ans.most_common(10):
        print('{}: {}'.format(k, _percent(v, len(questions))))


def _print_qtype_distribution(questions):
    """Print the share of each question type and return the per-type counts.

    The type is inferred from the last function of each question's program.
    Functions missing from ``_LAST_FUNC_TO_QTYPE`` are counted as 'other'
    rather than aborting the statistics with a KeyError.
    """
    cnt_qtype = Counter(
        _LAST_FUNC_TO_QTYPE.get(q['program'][-1]['function'], 'other')
        for q in questions
    )
    print('='*30)
    print('question type distribution')
    for k, v in cnt_qtype.items():
        print('{}: {}'.format(k, _percent(v, len(questions))))
    return cnt_qtype


def _print_multi_hop(questions, cnt_qtype):
    """Print how many questions contain a 'Relate' step, in absolute and relative terms."""
    cnt_multi_hop = sum(
        1 for q in questions
        if any(f['function'] == 'Relate' for f in q['program'])
    )
    n_entity = cnt_qtype['what_is_entity'] + cnt_qtype['how_many_entities']
    print('='*30)
    print('multi-hop questions: {}, {} of all, {} of what_is_entity/how_many_entities questions'.format(
        cnt_multi_hop,
        _percent(cnt_multi_hop, len(questions)),
        _percent(cnt_multi_hop, n_entity),
        ))


def _print_qualifier(questions):
    """Print how many questions use at least one qualifier function."""
    cnt_qual = sum(
        1 for q in questions
        if any(f['function'] in _QUALIFIER_FUNCS for f in q['program'])
    )
    print('='*30)
    print('qualifier questions: {}, {} of all'.format(
        cnt_qual,
        _percent(cnt_qual, len(questions)),
        ))


def _print_op_distribution(questions):
    """Print the share of each comparison operator used in Filter*/Verify* steps."""
    cnt_op = defaultdict(int)
    for q in questions:
        for f in q['program']:
            op = None
            # Qualifier filters (QFilter*) are intentionally left out of this count.
            if f['function'] in {'VerifyNum', 'VerifyDate', 'VerifyYear'}:
                op = f['inputs'][1]
            elif f['function'] in {'FilterNum', 'FilterYear', 'FilterDate'}:
                op = f['inputs'][2]
            elif f['function'] in {'FilterStr', 'VerifyStr'}:
                op = '='
            if op:
                cnt_op[op] += 1
    total_num = sum(cnt_op.values())
    print('='*30)
    print('op distribution')
    for k, v in cnt_op.items():
        print('{}: {}'.format(k, _percent(v, total_num)))


def main():
    """Generate questions, save them to ``--out_json`` and print statistics.

    Roughly ``--total_num / len(funcs)`` questions are generated per type by
    ``--n_proc`` worker processes. Output is shuffled and de-duplicated.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--total_num", type=int, required=True)
    parser.add_argument("--freq_high_thresh", type=float, default=0.2, help="For each question type, one answer's ratio cannot exceed this threshold")
    parser.add_argument("--out_json", required=True)
    parser.add_argument("--n_proc", type=int, default=1)
    args = parser.parse_args()

    per_num = args.total_num // len(funcs)
    # Fail fast: per_num == 0 generates nothing, and n_proc < 1 would leave
    # the generation loop waiting forever with no workers.
    if per_num < 1:
        parser.error('--total_num must be at least the number of question types ({})'.format(len(funcs)))
    if args.n_proc < 1:
        parser.error('--n_proc must be at least 1')

    data = Data()
    print('Start generating. There are {} types, {} questions, each has about {}.'.format(len(funcs), args.total_num, per_num))
    tic = time.time()

    questions = _generate_questions(data, per_num, int(per_num*args.freq_high_thresh), args.n_proc)
    random.shuffle(questions)
    print('Finish generation. Take {} minutes. Get {} questions, save into {}'.format((time.time()-tic)//60, len(questions), args.out_json))
    with open(args.out_json, 'w') as f:
        json.dump(questions, f)

    print('Statistics...')
    _print_answer_distribution(questions)
    cnt_qtype = _print_qtype_distribution(questions)
    _print_multi_hop(questions, cnt_qtype)
    _print_qualifier(questions)
    _print_op_distribution(questions)


# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
