from collections import defaultdict
import random
import numpy as np
import json

from structures import Function, Question
from descriptions import by_entity, by_concept, by_attribute, by_relation, by_two_conditions_intersect, by_two_conditions_union, by_two_hop_relation
from conditions import sample_by_kv
from utils.misc import and_two_descriptions, or_two_descriptions, hop_two_descriptions, replace_duplicate_variables
from utils.check import check_attr, check_sparql, check_valid
from sparqlEngine import SparqlEngine, legal
from choice import ChoiceGen
import conf

def what_is_entity(data):
    """Build a "Who is ... / What is ..." question whose answer is an entity name.

    A random entity is picked and described by one of: an attribute, a
    relation, the intersection of two conditions, or a two-hop relation.

    Args:
        data: the knowledge-base wrapper shared by every generator in this module.

    Returns:
        A ``Question`` with parse type ``'name'``, or ``None`` when the sampled
        describer could not produce a condition for the entity.
    """
    target = random.choice(list(data.entities.keys()))
    describe = np.random.choice([by_attribute, by_relation, by_two_conditions_intersect, by_two_hop_relation])
    cond = describe(data, target)
    if cond is None:
        return None

    # Composite describers hand back a pair of conditions that must be merged
    # into one (text, program, sparql) description.
    if describe == by_two_conditions_intersect:
        first, second = cond
        desc = and_two_descriptions(first, second)
    elif describe == by_two_hop_relation:
        base, hop = cond
        desc = hop_two_descriptions(base, hop)
    else:
        desc = cond.description()

    phrase, program, where_clause = desc
    template = 'Who is {}' if data.is_human(target) else 'What is {}'
    text = template.format(phrase)
    program = program + [Function('What', [program[-1]], [])]
    sparql = 'SELECT DISTINCT ?e WHERE {{ {} }}'.format(where_clause)
    answer = data.get_name(target)
    choices = ChoiceGen.for_entity(data, target, where_clause)
    info = {'id': target, 'parse_type': 'name'}
    return Question(text, program, sparql, answer, choices, info)



def how_many_entities(data):
    """Build a "How many <description>" counting question.

    The description is sampled from: one attribute, one relation, the
    intersection or union of two conditions, or a two-hop relation. The
    answer is the number of entities returned by the relevant
    ``filter_facts()`` call(s).

    Args:
        data: the knowledge-base wrapper shared by every generator in this module.

    Returns:
        A ``Question`` with parse type ``'count'`` whose ``info`` also carries
        the counted entities, or ``None`` if no condition could be sampled.
    """
    method = np.random.choice([by_attribute, by_relation, by_two_conditions_intersect,
                               by_two_conditions_union, by_two_hop_relation])
    cond = method(data)
    if cond is None:
        return None

    if method in (by_two_conditions_intersect, by_two_conditions_union):
        left, right = cond
        is_and = method == by_two_conditions_intersect
        combine = and_two_descriptions if is_and else or_two_descriptions
        desc = combine(left, right)
        left_hits = set(left.filter_facts())
        right_hits = set(right.filter_facts())
        legal_entities = list(left_hits & right_hits if is_and else left_hits | right_hits)
    elif method == by_two_hop_relation:
        base, hop = cond
        desc = hop_two_descriptions(base, hop)
        # NOTE(review): only the base condition's filter_facts() is counted;
        # confirm it already accounts for the hop condition.
        legal_entities = base.filter_facts()
    else:
        desc = cond.description()
        legal_entities = cond.filter_facts()

    phrase, program, where_clause = desc
    words = phrase.split()
    if words[0] == 'the':
        # Drop a leading article so the text reads "How many X", not "How many the X".
        phrase = ' '.join(words[1:])
    text = 'How many %s' % phrase
    program = program + [Function('Count', [program[-1]], [])]
    sparql = 'SELECT (COUNT(DISTINCT ?e) AS ?count) WHERE {{ {} }}'.format(where_clause)
    answer = len(legal_entities)
    choices = ChoiceGen.for_count(answer)
    return Question(text, program, sparql, answer, choices,
                    {'legal_entities': legal_entities, 'parse_type': 'count'})




def which_is_most_among(data):
    """Build a "Which one has the largest/smallest <key> among <description>" question.

    Entities matching a sampled concept/attribute/relation description are
    ranked by a quantity attribute that all of them have.

    Args:
        data: the knowledge-base wrapper shared by every generator in this module.

    Returns:
        A ``Question`` with parse type ``'name'``, or ``None`` when fewer than
        two entities match, when they share no quantity ``(key, unit)``, when
        some entity has more than one fact for the chosen key, or when the top
        two values tie.
    """
    pick = np.random.choice([by_concept, by_attribute, by_relation])
    cond = pick(data)
    if cond is None:
        return None

    candidates = cond.filter_facts()
    if len(candidates) <= 1:
        return None

    # Quantity (key, unit) pairs of each candidate; only pairs common to all
    # candidates can be used for the ranking.
    pairs_per_entity = [
        {(k, v.unit) for (k, v, _) in data.get_attribute_facts(ent) if v.type == 'quantity'}
        for ent in candidates
    ]
    shared = pairs_per_entity[0]
    for pairs in pairs_per_entity:
        shared = shared & pairs
    if not shared:
        return None

    key, _unit = random.choice(list(shared))
    values = []
    for ent in candidates:
        matching = data.get_attribute_facts(ent, key)
        # Several values for one entity (maybe in different units) make the
        # ranking ambiguous, so the whole sample is skipped.
        if len(matching) != 1:
            return None
        _k, val, _q = matching[0]
        values.append(val)

    text, program, sparql = cond.description()
    op = 'largest' if random.random() < 0.5 else 'smallest'
    text = 'Which one has the {} {} among {}'.format(op, data.describe_attribute_key(key), text)
    program = program + [Function('SelectAmong', [program[-1]], [key, op])]
    sparql = SparqlEngine.append_attribute_value_query(sparql, key, 'quantity')
    ranking = np.argsort(values)
    if op == 'largest':
        # The winner is the last index in ascending order; the runner-up precedes it.
        runner_up, best = ranking[-2:]
        sparql = 'SELECT ?e WHERE {{ {} }} ORDER BY DESC(?v) LIMIT 1'.format(sparql)
    else:
        best, runner_up = ranking[0:2]
        sparql = 'SELECT ?e WHERE {{ {} }} ORDER BY ?v LIMIT 1'.format(sparql)
    # A tie for the extreme value has no unique answer.
    if values[best] == values[runner_up]:
        return None

    winner = candidates[best]
    answer = data.get_name(winner)
    choices = ChoiceGen.for_entity_selection(data, winner, candidates)
    info = {
        'parse_type': 'name',
        'id': winner,
        'value': str(values[best]),
        'candidates': {ent: str(val) for ent, val in zip(candidates, values)},
    }
    return Question(text, program, sparql, answer, choices, info)


def _single_valued_quantity_keys(data, entity):
    """Return ``(key, unit)`` pairs for ``entity``'s quantity attributes that have exactly one value.

    Keys with several quantity values (e.g., population is 10,000 in 1990 and
    11,000 in 1991) are dropped because a comparison on them is ambiguous.
    Pairs appear in the order their keys first occur in
    ``data.get_attribute_facts(entity)``.
    """
    values_by_key = defaultdict(list)
    for (k, v, _) in data.get_attribute_facts(entity):
        if v.type == 'quantity':
            values_by_key[k].append(v)
    return [(k, vs[0].unit) for k, vs in values_by_key.items() if len(vs) == 1]


def which_is_more_between(data):
    """Build a "Which one has greater/less <key>, A or B" question.

    ``ent_1`` is random. ``ent_2`` is drawn from other entities that share at
    least one concept with ``ent_1`` and carry a single value for the same
    quantity ``(key, unit)``.

    Args:
        data: the knowledge-base wrapper shared by every generator in this module.

    Returns:
        A ``Question`` with parse type ``'name'``, or ``None`` when no usable
        key or peer exists, either entity has more than one fact for the key,
        the units differ, or the two values are equal.
    """
    ent_1 = random.choice(list(data.entities.keys()))
    key_units = _single_valued_quantity_keys(data, ent_1)
    if len(key_units) == 0:
        return None

    key, unit = random.choice(key_units)
    # ent_1 itself is excluded: pairing it with itself always yields equal
    # values and hence a wasted sample. `visited` scans each entity once even
    # when it appears under several concepts. A list (in discovery order)
    # keeps the draw from depending on set iteration order.
    visited = {ent_1}
    candidate_entities = []
    for c in data.get_all_concepts(ent_1):
        for e in data.concept_to_entity[c]:
            if e in visited:
                continue
            visited.add(e)
            if (key, unit) in _single_valued_quantity_keys(data, e):
                candidate_entities.append(e)
    if len(candidate_entities) == 0:
        return None

    ent_2 = random.choice(candidate_entities)
    facts1 = data.get_attribute_facts(ent_1, key, unit)
    facts2 = data.get_attribute_facts(ent_2, key, unit)
    if len(facts1) != 1 or len(facts2) != 1:
        return None
    value1 = facts1[0][1]
    value2 = facts2[0][1]
    # Distinct entities can still tie; a tie has no well-defined answer.
    if value1.unit != value2.unit or value1 == value2:
        return None

    condition1 = by_entity(data, ent_1)
    condition2 = by_entity(data, ent_2)
    if condition1 is None or condition2 is None:
        return None
    text1, program1, sparql1 = condition1.description()
    text2, program2, sparql2 = condition2.description()
    op = 'greater' if random.random() < 0.5 else 'less'
    text = 'Which one has {} {}, {} or {}'.format(op, data.describe_attribute_key(key), text1, text2)
    f = Function('SelectBetween', [program1[-1], program2[-1]], [key, op])
    program = program1 + program2 + [f]
    sparql1, sparql2 = replace_duplicate_variables(sparql1, sparql2)
    sparql = '{{ {} }} UNION {{ {} }} '.format(sparql1, sparql2)
    sparql = SparqlEngine.append_attribute_value_query(sparql, key, 'quantity')
    if op == 'greater':
        target = ent_1 if value1 > value2 else ent_2
        sparql = 'SELECT ?e WHERE {{ {} }} ORDER BY DESC(?v) LIMIT 1'.format(sparql)
    else:
        target = ent_1 if value1 < value2 else ent_2
        sparql = 'SELECT ?e WHERE {{ {} }} ORDER BY ?v LIMIT 1'.format(sparql)
    answer = data.get_name(target)
    choices = ChoiceGen.for_entity_selection(data, target, [ent_1, ent_2])
    info = { 'parse_type': 'name', 'values': [str(value1), str(value2)]}
    return Question(text, program, sparql, answer, choices, info)



def what_is_attribute(data):
    """Build a "For <entity>, what is its/his/her <key>" question.

    A random entity with at least one attribute fact is described by itself,
    by an attribute, or by a relation. An attribute key that does not already
    appear in that description is asked about. When the entity has several
    facts for the key, a qualifier ``(qk, qv)`` of one randomly chosen fact is
    added to the question. The sample is skipped if that qualifier does not
    single the fact out.

    Args:
        data: the knowledge-base wrapper shared by every generator in this module.

    Returns:
        A ``Question`` whose answer is the attribute value (parse type
        ``'attr_<value type>'``), or ``None`` if this sample cannot yield an
        unambiguous question.
    """
    target = random.choice(list(data.entities.keys()))
    if len(data.get_attribute_facts(target)) == 0:
        return None
    methods = [by_entity, by_attribute, by_relation]
    method = np.random.choice(methods)
    condition = method(data, target)
    if condition is None:
        return None

    text, program, sparql = condition.description()
    # Skip keys that already appear in the description to avoid repetition.
    keys = [k for (k, _, _) in data.get_attribute_facts(target) if k not in text]
    if not keys:
        return None
    key = random.choice(keys)
    facts = data.get_attribute_facts(target, key)
    prefix = 'his/her' if data.is_human(target) else 'its'
    if len(facts) == 1:
        value = facts[0][1]
        text = 'For {}, what is {} {}'.format(text, prefix, data.describe_attribute_key(key))
        f = Function('QueryAttr', [program[-1]], [key])
    else:
        _, value, qualifiers = random.choice(facts)
        if not qualifiers:
            return None
        qk, qvs = random.choice(list(qualifiers.items()))
        qv = random.choice(qvs)
        # The (qk, qv) qualifier must identify exactly one fact, so count its
        # occurrences across EACH fact's own qualifiers. Iterating the chosen
        # fact's qualifiers once per fact would always give a count of at least
        # len(facts) >= 2 and reject every sample.
        occurrences = sum(
            1
            for (_, _, fact_qualifiers) in facts
            for qk_, qvs_ in fact_qualifiers.items()
            if qk_ == qk
            for qv_ in qvs_
            if qv_ == qv
        )
        if occurrences > 1:
            return None
        text = 'For {}, what is {} {} ({})'.format(text, prefix, data.describe_attribute_key(key), data.describe_attribute_fact('this statement', qk, qv))
        f = Function('QueryAttrUnderCondition', [program[-1]], [key, qk, qv])
    program = program + [f]
    # use '' as v_type to get ?pv as the target variable
    sparql = SparqlEngine.append_attribute_value_query(sparql, key, '')
    sparql = 'SELECT DISTINCT ?pv WHERE {{ {} }}'.format(sparql)
    answer = value
    choices = ChoiceGen.for_attribute_value(data, key, value)
    info = {'parse_type': 'attr_{}'.format(value.type)}
    return Question(text, program, sparql, answer, choices, info)



def is_attribute_satisfy(data):
    """Build a yes/no question asking whether an entity's attribute satisfies a sampled condition.

    A random entity is described (by itself, by an attribute, or by a
    relation). One of its attribute keys that does not appear in the
    description is picked, and a condition ``(key, op, value)`` is obtained
    from ``sample_by_kv``. When the entity has several facts for the key, a
    qualifier ``(qk, qv)`` of a randomly chosen fact pins the statement down.
    The sample is skipped if that qualifier pair occurs more than once across
    the key's facts.

    Args:
        data: the knowledge-base wrapper shared by every generator in this module.

    Returns:
        A ``Question`` whose answer is ``'yes'`` or ``'no'`` (parse type
        ``'bool'``), or ``None`` when this sample cannot produce a question.

    Raises:
        Exception: ``'Not support'`` when a date-typed condition uses an
            operator other than ``'='``.
    """
    target = random.choice(list(data.entities.keys()))
    if len(data.get_attribute_facts(target)) == 0:
        return None
    methods = [by_entity, by_attribute, by_relation]
    method = np.random.choice(methods)
    condition = method(data, target)
    if condition is None:
        return None

    text, program, sparql = condition.description()
    keys = []
    for (k, v, _) in data.get_attribute_facts(target):
        if k not in text: # avoid repeat
            keys.append(k)
    if len(keys) == 0:
        return None
    key = random.choice(keys)
    facts = data.get_attribute_facts(target, key)
    prefix = 'his/her' if data.is_human(target) else 'its'
    if len(facts) == 1:
        # Single fact: query the attribute directly; no qualifier text needed.
        text = 'For {}, is {} {}'.format(text, prefix, data.describe_attribute_key(key))
        qual_text = ''
        f = Function('QueryAttr', [program[-1]], [key])
        program = program + [f]
        real_value = facts[0][1]
    else:
        # Several facts for this key: pick one and identify it by a qualifier.
        k, v, qualifiers = random.choice(facts)
        if len(qualifiers) == 0:
            return None
        qk, qvs = random.choice(list(qualifiers.items()))
        qv = random.choice(qvs)
        # if multiple values have the same (qk, qv) across the key's facts, then skip
        cnt = 0
        for (k_, v_, qualifiers_) in facts:
            for qk_, qvs_ in qualifiers_.items():
                for qv_ in qvs_:
                    if qk_ == qk and qv_ == qv:
                        cnt += 1
        if cnt > 1:
            return None
        text = 'For {}, is {} {}'.format(text, prefix, data.describe_attribute_key(key))
        qual_text = data.describe_attribute_fact('this statement', qk, qv)
        f = Function('QueryAttrUnderCondition', [program[-1]], [key, qk, qv])
        program = program + [f]
        real_value = v
    # to make yes/no nearly balanced, we sample a yes-condition with a probability 40%
    if random.random() < 0.4:
        condition = sample_by_kv(data, None, key, real_value)
    else:
        condition = sample_by_kv(data, None, key) # TODO: should use the same unit!!!
    if condition is None:
        return None

    # `key` is rebound from the sampled condition here; presumably it equals the
    # key passed to sample_by_kv -- confirm against conditions.sample_by_kv.
    key, op, value = condition
    # NOTE(review): op_desc is only bound for the operators listed in each branch
    # and only for value types string/quantity/date/year. Any other combination
    # leaves op_desc unbound, and the text formatting below would fail with
    # UnboundLocalError. Confirm that sample_by_kv cannot produce such a condition.
    if value.type == 'string':
        op_desc = 'equal to'
        f = Function('VerifyStr', [program[-1]], [value])
    elif value.type == 'quantity':
        if op == '<':
            op_desc = 'less than'
        elif op == '>':
            op_desc = 'greater than'
        elif op == '=':
            op_desc = 'equal to'
        elif op == '!=':
            op_desc = 'not equal to'
        f = Function('VerifyNum', [program[-1]], [value, op])
    elif value.type == 'date':
        f = Function('VerifyDate', [program[-1]], [value, op])
        if op == '=':
            op_desc = 'on'
        else:
            raise Exception('Not support')
    elif value.type == 'year':
        f = Function('VerifyYear', [program[-1]], [value, op])
        if op == '<':
            op_desc = 'before'
        elif op == '>':
            op_desc = 'after'
        elif op == '=':
            op_desc = 'in'
        elif op == '!=':
            op_desc = 'not in'

    # NOTE(review): the no-qualifier text has a space before '?' while the
    # qualifier text does not -- confirm whether that asymmetry is intended.
    if qual_text:
        text = '{} {} {} ({})?'.format(text, op_desc, value, qual_text)
    else:
        text = '{} {} {} ?'.format(text, op_desc, value)
    program = program + [f]

    # Replace the ?pv / ?v variables in the description's SPARQL before the
    # attribute pattern that reuses those names is appended. The exact renaming
    # is done by SparqlEngine.replace_variable -- confirm there.
    sparql, _ = SparqlEngine.replace_variable(sparql, '?pv')
    sparql, _ = SparqlEngine.replace_variable(sparql, '?v')
    sparql = sparql + SparqlEngine.gen_attribute_query(key, value, op)
    if len(facts) > 1:
        # Several facts: also constrain the fact node by the chosen qualifier (qk = qv).
        fact_node = SparqlEngine.gen_attr_fact_node(key)
        sparql, _ = SparqlEngine.replace_variable(sparql, '?qpv')
        sparql, _ = SparqlEngine.replace_variable(sparql, '?qv')
        sparql = sparql + SparqlEngine.gen_attribute_query(qk, qv, '=', e=fact_node, in_qualifier=True)
    sparql = 'ASK {{ {} }}'.format(sparql)

    # The answer checks the entity's actual value against the sampled condition.
    answer = 'yes' if check_attr(key, real_value, condition) else 'no'
    choices = ChoiceGen.for_binary()
    info = {'parse_type': 'bool', 'id': target, 'name': data.get_name(target), 'real_value': str(real_value)}
    return Question(text, program, sparql, answer, choices, info)



def what_is_attribute_qualifier(data):
    """Build a question asking for a qualifier value of one of an entity's attribute facts.

    The question text has the form "For <entity>, its <key> is <value>,
    <qualifier question>". Only qualified facts whose (key, value) pair is
    unique are considered. Time values that contain another time value of
    the same key are dropped. Only qualifier keys with exactly one value are
    asked about.

    Args:
        data: the knowledge-base wrapper shared by every generator in this module.

    Returns:
        A ``Question`` whose answer is the qualifier value (parse type
        ``'attr_<value type>'``), or ``None``.
    """
    target = random.choice(list(data.entities.keys()))
    if len(data.get_attribute_facts(target)) == 0:
        return None
    method = np.random.choice([by_entity, by_attribute, by_relation])
    condition = method(data, target)
    if condition is None:
        return None

    text, program, sparql = condition.description()
    # Candidate keys: those not already mentioned in the description.
    keys = [k for (k, _v, _q) in data.get_attribute_facts(target) if k not in text]
    if not keys:
        return None
    key = random.choice(keys)
    qualified = [fact for fact in data.get_attribute_facts(target, key) if len(fact[2]) > 0]

    # Drop facts whose (key, value) pair repeats (e.g., the inflation rate of a
    # country is 1.5 in 1999 and in 2007): the value alone would not pick one.
    def kv_label(k, v):
        return '{}_{}'.format(k, v)

    tally = defaultdict(int)
    for k, v, _q in qualified:
        tally[kv_label(k, v)] += 1
    qualified = [(k, v, q) for (k, v, q) in qualified if tally[kv_label(k, v)] == 1]

    # Drop time values that contain another time value of the same key
    # (e.g., values 1990 and 1990-01-01: the coarser 1990 is removed).
    kept = []
    for idx, (k_i, v_i, _qi) in enumerate(qualified):
        covered = v_i.isTime() and any(
            other != idx and k_i == k_j and v_j.isTime() and v_i.contains(v_j)
            for other, (k_j, v_j, _qj) in enumerate(qualified)
        )
        if not covered:
            kept.append(qualified[idx])
    if not kept:
        return None

    key, value, qualifiers = random.choice(kept)
    single = [(qk, qvs[0]) for qk, qvs in qualifiers.items() if len(qvs) == 1]
    if not single:
        return None
    qk, qv = random.choice(single)
    prefix = 'his/her' if data.is_human(target) else 'its'
    text = 'For {}, {} {} is {}, {}'.format(text, prefix, key, value, data.ask_qualifier_key(qk))
    program = program + [Function('QueryAttrQualifier', [program[-1]], [key, str(value), qk])]

    for var in ('?pv', '?v'):
        sparql, _ = SparqlEngine.replace_variable(sparql, var)
    sparql += SparqlEngine.gen_attribute_query(key, value)
    fact_node = SparqlEngine.gen_attr_fact_node(key)
    sparql = SparqlEngine.append_attribute_value_query(sparql, qk, '', e=fact_node, in_qualifier=True)
    sparql = 'SELECT DISTINCT ?qpv WHERE {{ {} }}'.format(sparql)

    choices = ChoiceGen.for_attribute_value(data, qk, qv)
    return Question(text, program, sparql, qv, choices, {'parse_type': 'attr_{}'.format(qv.type)})



def what_is_relation(data):
    """Build a "What is the relation from A to B" question.

    A random entity is picked (``None`` if it has no relation facts), and one
    of its relation facts supplies the other endpoint. The sample is skipped
    in two cases: the endpoint is one of the entity's concepts, or the two
    entities are linked by more than one fact in the same direction.

    Args:
        data: the knowledge-base wrapper shared by every generator in this module.

    Returns:
        A ``Question`` whose answer is the predicate (parse type ``'pred'``),
        or ``None``.
    """
    subject = random.choice(list(data.entities.keys()))
    rel_facts = data.get_relation_facts(subject)
    if len(rel_facts) == 0:
        return None
    predicate, other, direction, _qualifiers = random.choice(rel_facts)
    if other in data.get_all_concepts(subject):
        return None
    # More than one relation between the pair in this direction -> ambiguous answer.
    parallel = sum(1 for (_p, e, d, _q) in rel_facts if e == other and d == direction)
    if parallel > 1:
        return None

    # Orient the pair so the first entity is the fact's subject.
    head, tail = (other, subject) if direction == 'backward' else (subject, other)
    head_cond = by_entity(data, head)
    tail_cond = by_entity(data, tail)
    if head_cond is None or tail_cond is None:
        return None
    head_text, head_prog, head_sparql = head_cond.description()
    tail_text, tail_prog, tail_sparql = tail_cond.description()
    text = 'What is the relation from {} to {}'.format(head_text, tail_text)
    program = head_prog + tail_prog + [Function('QueryRelation', [head_prog[-1], tail_prog[-1]], [])]

    head_sparql, tail_sparql = replace_duplicate_variables(head_sparql, tail_sparql, same_sub=False)
    sparql = head_sparql + tail_sparql + '?e_1 ?p ?e_2 . '
    sparql = 'SELECT DISTINCT ?p WHERE {{ {} }}'.format(sparql)

    choices = ChoiceGen.for_relation(data, predicate)
    return Question(text, program, sparql, predicate, choices, {'parse_type': 'pred'})



def what_is_relation_qualifier(data):
    """Build a question asking for a qualifier value of a relation fact between two entities.

    Only qualified relation facts whose (predicate, object, direction) triple
    is unique for the entity are considered (e.g. "A wins B" recorded several
    times is skipped). Only qualifier keys with exactly one value are asked
    about.

    Args:
        data: the knowledge-base wrapper shared by every generator in this module.

    Returns:
        A ``Question`` whose answer is the qualifier value (parse type
        ``'attr_<value type>'``), or ``None``.
    """
    subject = random.choice(list(data.entities.keys()))
    qualified = [fact for fact in data.get_relation_facts(subject) if len(fact[3]) > 0]

    def signature(p, o, d):
        return '{}_{}_{}'.format(p, o, d)

    tally = defaultdict(int)
    for p, o, d, _q in qualified:
        tally[signature(p, o, d)] += 1
    unique = [(p, o, d, q) for (p, o, d, q) in qualified if tally[signature(p, o, d)] == 1]
    if not unique:
        return None

    predicate, other, direction, qualifiers = random.choice(unique)
    single = [(qk, qvs[0]) for qk, qvs in qualifiers.items() if len(qvs) == 1]
    if not single:
        return None
    qk, qv = random.choice(single)

    # Orient the pair so the first entity is the fact's subject.
    head, tail = (other, subject) if direction == 'backward' else (subject, other)
    head_cond = by_entity(data, head)
    tail_cond = by_entity(data, tail)
    if head_cond is None or tail_cond is None:
        return None
    head_text, head_prog, head_sparql = head_cond.description()
    tail_text, tail_prog, tail_sparql = tail_cond.description()

    fact_text = data.describe_relation_fact(head_text, predicate, tail_text, 'forward')
    text = '{}, {}'.format(fact_text, data.ask_qualifier_key(qk))
    step = Function('QueryRelationQualifier', [head_prog[-1], tail_prog[-1]], [predicate, qk])
    program = head_prog + tail_prog + [step]

    head_sparql, tail_sparql = replace_duplicate_variables(head_sparql, tail_sparql, same_sub=False)
    sparql = head_sparql + tail_sparql + '?e_1 <{}> ?e_2 . '.format(legal(predicate))
    node = SparqlEngine.gen_rel_fact_node(predicate, 'forward', '?e_2', '?e_1')
    sparql = SparqlEngine.append_attribute_value_query(sparql, qk, '', e=node, in_qualifier=True)
    sparql = 'SELECT DISTINCT ?qpv WHERE {{ {} }}'.format(sparql)

    choices = ChoiceGen.for_attribute_value(data, qk, qv)
    return Question(text, program, sparql, qv, choices, {'parse_type': 'attr_{}'.format(qv.type)})



# Registry of the question generators driven by test_all(), keyed by function
# name (the key also names each template's output file). Every generator takes
# the data object and returns a Question or None.
funcs = {
    generator.__name__: generator
    for generator in (
        what_is_entity,
        how_many_entities,
        which_is_most_among,
        which_is_more_between,
        what_is_attribute,
        is_attribute_satisfy,
        what_is_attribute_qualifier,
        what_is_relation,
        what_is_relation_qualifier,
    )
}



def test_all(times=100, out_json=None):
    """Generate ``times`` accepted questions for every template in ``funcs``.

    Accepted questions for template ``<name>`` are written to
    ``results/<name>.txt``. Questions that pass ``check_valid`` but fail
    SPARQL validation go to ``results/err.txt``; that validation only runs
    when ``conf.virtuoso_validate`` is set.

    Args:
        times: number of accepted questions to collect per template. Sampling
            repeats until this many are accepted, so a template that never
            succeeds will not terminate.
        out_json: optional path. When given, every accepted question's
            ``dict()`` is dumped there as a JSON list.
    """
    import os
    from data import Data

    data = Data()
    # Create the output directory instead of failing on the first open().
    os.makedirs('results', exist_ok=True)
    questions = []

    # Context managers close the output files even if a generator raises;
    # explicit UTF-8 avoids locale-dependent failures on non-ASCII entity names.
    with open('results/err.txt', 'w', encoding='utf-8') as stderr:
        for desc, gen in funcs.items():
            with open('results/{}.txt'.format(desc), 'w', encoding='utf-8') as stdout:
                cnt = 0
                while cnt < times:
                    question = gen(data)
                    if not question or not check_valid(question):
                        continue

                    correct = True
                    if conf.virtuoso_validate:
                        correct = check_sparql(question)

                    if correct:
                        stdout.write(str(question))
                        questions.append(question.dict())
                        cnt += 1
                        print('{}: {} / {}'.format(desc, cnt, times))
                    else:
                        stderr.write(str(question))

    if out_json:
        print('get {} valid questions, save into {}'.format(len(questions), out_json))
        with open(out_json, 'w') as fp:
            json.dump(questions, fp)


if __name__ == '__main__':
    # Default run: test_all() with its defaults (100 accepted questions per
    # template, written to results/<template>.txt, no JSON dump).
    test_all()
    # Larger run that also saves every accepted question as JSON:
    # test_all(250, 'results/sample2k.json')
