from structures import ValueClass
from sparqlEngine import SparqlEngine, legal

def cmp(a, b, op):
    """Evaluate ``a <op> b`` for an operator given as a string.

    Supported operators are '=', '!=', '<' and '>'.

    When ``b`` is a time value (``b.isTime()`` is truthy), '=' and '!=' mean
    membership rather than equality: '=' returns ``b.contains(a)`` and '!='
    returns its negation. '<' and '>' always use the ordinary Python
    comparison operators, even for time values.

    Returns:
        The result of the comparison, or None for an unsupported ``op``.
    """
    is_time = b.isTime()
    if is_time and op in ('=', '!='):
        # for time values, '=' means "a falls inside b" and '!=' means it does not
        inside = b.contains(a)
        return inside if op == '=' else not inside

    # checked in the same order as the operators were originally tested
    ordered_ops = (
        ('=', lambda lhs, rhs: lhs == rhs),
        ('<', lambda lhs, rhs: lhs < rhs),
        ('>', lambda lhs, rhs: lhs > rhs),
        ('!=', lambda lhs, rhs: lhs != rhs),
    )
    for symbol, compare in ordered_ops:
        if op == symbol:
            return compare(a, b)
    return None

def check_attr(k, v, condition):
    """
    Test whether the key/value pair ``(k, v)`` satisfies ``condition``.

    ``condition`` may describe either an attribute or a qualifier constraint.

    Args:
        - k (str): the key being tested
        - v (ValueClass): its value
        - condition: a ``(key, op, value)`` triple; ``op`` is passed to ``cmp``
    Returns:
        False when ``k`` differs from the condition's key or ``v`` cannot be
        compared with the condition's value; otherwise ``cmp(v, value, op)``.
    """
    wanted_key, op, target = condition
    if k != wanted_key or not v.can_compare(target):
        return False
    return cmp(v, target, op)


def check_pred(pred, obj, direction, condition):
    """Return True if ``(pred, obj, direction)`` matches ``condition`` element-wise.

    ``condition`` is the target ``(pred, object, direction)`` sequence. The
    pairing uses ``zip``, so only positions present in both sequences are
    compared.
    """
    actual = (pred, obj, direction)
    return not any(got != want for got, want in zip(actual, condition))


def _query_rows(sparql):
    """Run a SELECT query on Virtuoso and return its bindings as a list of rows.

    Each row lists the bound values in the order of the query's projected
    variables (``res.vars``).
    """
    res = SparqlEngine.query_virtuoso(sparql)
    return [[binding[v] for v in res.vars] for binding in res.bindings]


def check_sparql(question):
    """Execute ``question.sparql`` and check its result against ``question.answer``.

    The result is interpreted according to ``question.info['parse_type']``:

    - ``'name'``: the single returned node's value for
      ``SparqlEngine.PRED_NAME`` is compared with the answer.
    - ``'count'``: the single returned value is compared with the answer.
    - ``'attr_<type>'`` (``string``/``quantity``/``year``/``date``): the single
      returned node's value (plus unit for quantities) is fetched, wrapped in a
      ``ValueClass`` and compared with the answer, which must be a ``ValueClass``.
    - ``'bool'``: the query's ``askAnswer`` is compared with
      ``given_answer == 'yes'``.
    - ``'pred'``: the single returned term, as a string, is compared with
      ``legal(given_answer)``.

    Returns:
        True if the query reproduces the answer, otherwise False. This
        includes the case where a SELECT query does not return exactly one
        row, and the case where the follow-up name/value lookup returns no
        rows. An unrecognised ``parse_type`` falls through and returns None.

    Raises:
        Exception: for an ``attr_`` parse type with an unsupported value type.
        AssertionError: if a query without projected variables is paired with
            a parse type other than ``'bool'``, or an ``attr_`` answer is not a
            ``ValueClass``.
    """
    sparql, given_answer = question.sparql, question.answer
    parse_type = question.info['parse_type']

    res = SparqlEngine.query_virtuoso(sparql)
    if res.vars:
        res = [[binding[v] for v in res.vars] for binding in res.bindings]
        if len(res) != 1:
            return False
    else:
        # no projected variables: use the boolean (ASK) answer
        res = res.askAnswer
        assert parse_type == 'bool'

    if parse_type == 'name':
        node = res[0][0]
        sp = 'SELECT DISTINCT ?v WHERE {{ <{}> <{}> ?v .  }}'.format(node, SparqlEngine.PRED_NAME)
        rows = _query_rows(sp)
        if not rows:
            # the node has no name, so it cannot match the expected one
            return False
        return rows[0][0].value == given_answer
    elif parse_type == 'count':
        count = res[0][0].value
        return count == given_answer
    elif parse_type.startswith('attr_'):
        assert isinstance(given_answer, ValueClass)
        node = res[0][0]
        v_type = parse_type.split('_')[1]
        unit = None
        if v_type == 'string':
            sp = 'SELECT DISTINCT ?v WHERE {{ <{}> <{}> ?v .  }}'.format(node, SparqlEngine.PRED_VALUE)
        elif v_type == 'quantity':
            # Note: For those large number, ?v is truncated by virtuoso (e.g., 14756087 to 1.47561e+07)
            # To obtain the accurate ?v, we need to cast it to str
            sp = 'SELECT DISTINCT ?v,?u,(str(?v) as ?sv) WHERE {{ <{}> <{}> ?v ; <{}> ?u .  }}'.format(node, SparqlEngine.PRED_VALUE, SparqlEngine.PRED_UNIT)
        elif v_type == 'year':
            sp = 'SELECT DISTINCT ?v WHERE {{ <{}> <{}> ?v .  }}'.format(node, SparqlEngine.PRED_YEAR)
        elif v_type == 'date':
            sp = 'SELECT DISTINCT ?v WHERE {{ <{}> <{}> ?v .  }}'.format(node, SparqlEngine.PRED_DATE)
        else:
            raise Exception('unsupported parse type')
        rows = _query_rows(sp)
        if not rows:
            # the node has no value (for quantities, also none with a unit)
            return False
        if v_type == 'quantity':
            # columns follow the projection (?v, ?u, ?sv); ?sv is the exact string form
            value = float(rows[0][2].value)
            unit = rows[0][1].value
        else:
            value = rows[0][0].value
        value = ValueClass(v_type, value, unit)
        return value.can_compare(given_answer) and value == given_answer
    elif parse_type == 'bool':
        return res == (given_answer == 'yes')
    elif parse_type == 'pred':
        return str(res[0][0]) == legal(given_answer)


def check_valid(question):
    '''
    Reject programs that contain redundant steps or too many qualifier constraints.

    Example of a redundant step: "what is the university of John (the one whose
    university is Tsinghua University)". Redundancy is currently detected
    between these adjacent function pairs:
        - Filter(Str/Num/Year/Date) with op '=' directly followed by
          QueryAttr(UnderCondition) on the same attribute
        - Relate directly followed by Relate with the same predicate but a
          different direction
    Other kinds of duplication are not yet detected.

    Args:
        - question: object whose ``program`` is a sequence of steps, each with
          a ``function`` name and an ``inputs`` list
    Returns:
        False if a redundant pair is found or more than one qualifier
        constraint (QFilter* / QueryAttrUnderCondition) is used, else True.
    '''
    program = question.program
    cnt_qual = 0
    for i, f in enumerate(program):
        # program[i - 1] would wrap around to the LAST step when i == 0, so only
        # look back when a real predecessor exists.
        prev = program[i - 1] if i > 0 else None

        if prev is not None and f.function in {'QueryAttr', 'QueryAttrUnderCondition'}:
            # only consider the case that Query directly follows an equality Filter
            is_eq_filter = prev.function == 'FilterStr' or (
                prev.function in {'FilterNum', 'FilterYear', 'FilterDate'} and prev.inputs[-1] == '=')
            if is_eq_filter and prev.inputs[0] == f.inputs[0]:
                return False

        if prev is not None and f.function == 'Relate' and prev.function == 'Relate' and \
                f.inputs[0] == prev.inputs[0] and f.inputs[1] != prev.inputs[1]:
            # adjacent Relate, same predicate, different direction
            return False

        if f.function.startswith('QFilter') or f.function == 'QueryAttrUnderCondition':
            cnt_qual += 1

    if cnt_qual > 1:
        # there are too many qualifier restrictions
        return False

    return True
