import numpy as np
import random
from collections import defaultdict

import conf
from structures import Function, ValueClass
from sparqlEngine import SparqlEngine
from utils.misc import get_concept_name, reverse_dir
from utils.check import check_attr, check_pred


def sample_by_kv(data, concept, key, tgt_value=None, prob_op=conf.PROB_OP):
    """
    Sample a condition on `key`.

    Args:
        data : knowledge-base object; `key_type`, `key_values`,
            `concept_key_values` and `ku2grid` are read from it
        concept : limit possible values to those recorded for this concept
            (ignored when falsy or when it has no value recorded for `key`;
            the quantity sampler and the date sampler with a target never
            consult it)
        key : key the condition is about
        tgt_value : should satisfy the sampled condition. When given, its
            `.type` selects the sampler; otherwise `data.key_type[key]` does
        prob_op : probabilities handed to `np.random.choice` over `conf.OPS`
    Return:
        condition (key, '</=/>/!=', value); string keys always use '='
    Raises:
        IndexError : no candidate value is left to choose from
        ValueError : unsupported value type, or unsupported operator while a
            target value has to be satisfied
    """
    def values_for_key():
        # Prefer the values seen under `concept`; otherwise use every value of `key`.
        if concept and len(data.concept_key_values[concept][key]) > 0:
            return data.concept_key_values[concept][key]
        return data.key_values[key]

    def draw_op():
        return np.random.choice(conf.OPS, p=prob_op)

    def value_against_target(op, candidates):
        # Pick a value such that `tgt_value <op> value` holds ('=' is handled by the callers).
        if op == '<': # tgt < value
            pool = [v for v in candidates if v > tgt_value]
        elif op == '>': # tgt > value
            pool = [v for v in candidates if v < tgt_value]
        elif op == '!=': # tgt != value
            pool = [v for v in candidates if v != tgt_value]
        else:
            raise ValueError('Unsupported operator: {!r}'.format(op))
        return random.choice(pool)

    def string_condition():
        if tgt_value:
            value = tgt_value
        else:
            value = random.choice(list(values_for_key()))
        return (key, '=', value)

    def quantity_condition():
        op = draw_op()
        if tgt_value:
            unit = tgt_value.unit
        else:
            unit = random.choice(list(data.ku2grid[key].keys()))
        possible_values = [ValueClass('quantity', v, unit) for v in data.ku2grid[key][unit]]
        if not tgt_value:
            value = random.choice(possible_values)
        elif op == '=':
            value = tgt_value
        else:
            value = value_against_target(op, possible_values)
        return (key, op, value)

    def date_condition():
        op = draw_op()
        if tgt_value:
            # inequality conditions are only expressed against years
            possible_values = [v.convert_to_year() for v in data.key_values[key]]
            if op == '=':
                # either the exact date or just its year
                if random.random() < 0.5:
                    value = tgt_value
                else:
                    value = tgt_value.convert_to_year()
            else:
                value = value_against_target(op, possible_values)
            return (key, op, value)

        # Without a target: inequalities always use a year, '=' uses one half of the time.
        use_year = op in ('<', '>', '!=') or random.random() < 0.5
        if use_year:
            possible_values = [v.convert_to_year() for v in values_for_key()]
        else:
            possible_values = []
            if concept and len(data.concept_key_values[concept][key]) > 0:
                # as we filter out some v whose type != 'date', possible_values may be empty
                possible_values = [v for v in data.concept_key_values[concept][key] if v.type == 'date']
            if not possible_values:
                possible_values = [v for v in data.key_values[key] if v.type == 'date']
        return (key, op, random.choice(possible_values))

    def year_condition():
        op = draw_op()
        possible_values = [v.convert_to_year() for v in values_for_key()]
        if not tgt_value:
            value = random.choice(possible_values)
        elif op == '=':
            value = tgt_value
        else:
            value = value_against_target(op, possible_values)
        return (key, op, value)

    if tgt_value:
        key_type = tgt_value.type
    else:
        key_type = data.key_type[key]

    if key_type == 'string':
        return string_condition()
    if key_type == 'quantity':
        return quantity_condition()
    if key_type == 'date':
        return date_condition()
    if key_type == 'year':
        return year_condition()
    raise ValueError('Unsupported value type for key {!r}: {!r}'.format(key, key_type))


def append_qualifier_description(data, qualifier_condition, fact_repr, text, program, sparql):
    """
    Extend a (text, program, sparql) description with a qualifier restriction.

    Args:
        data : object providing `describe_attribute_fact`
        qualifier_condition : (key, op, value) of the qualifier; op must be '='
        fact_repr : passed to `SparqlEngine.gen_attribute_query` as `e`
        text, program, sparql : the description built so far; `program` is
            extended with `+=`, so a list argument is modified in place
    Return:
        text, program, sparql
    """
    q_key, q_op, q_value = qualifier_condition
    assert q_op == '='

    filter_by_type = {
        'string': 'QFilterStr',
        'quantity': 'QFilterNum',
        'year': 'QFilterYear',
        'date': 'QFilterDate',
    }
    value_type = q_value.type
    if value_type in filter_by_type:
        # the string filter takes no operator argument
        if value_type == 'string':
            fn_inputs = [q_key, q_value]
        else:
            fn_inputs = [q_key, q_value, q_op]
        program += [ Function(filter_by_type[value_type], [program[-1]], fn_inputs) ]

    qualifier_text = data.describe_attribute_fact('this statement', q_key, q_value)
    text += ' ({})'.format(qualifier_text)
    sparql += SparqlEngine.gen_attribute_query(q_key, q_value, q_op, e=fact_repr, in_qualifier=True)
    return text, program, sparql



class EntityCondition():
    """A description (text, program, sparql) that refers to one entity."""

    def __init__(self, data, desc):
        self.data = data
        self.desc = desc # text, program, sparql

    @classmethod
    def sample(cls, data, ent_id):
        """
        Build a description of `ent_id`.

        If the entity is not ambiguous, its name alone is used. Otherwise one
        attribute or relation fact that the other non-concept entities with
        the same name do not share is picked at random and added to the
        description.

        Return:
            an EntityCondition, or None when the entity is an ambiguous
            concept or no distinguishing fact exists
        """
        def unique_attributes():
            same_name = data.get_name(ent_id)
            is_unique = {}
            for (own_k, own_v, _) in data.get_attribute_facts(ent_id):
                is_unique[(own_k, own_v)] = True
            for other in data.get_ids(same_name):
                if i_is_skipped(other):
                    continue
                for (other_k, other_v, _) in data.get_attribute_facts(other):
                    # we must compare values as ValueClass, instead of hash values
                    for (own_k, own_v) in is_unique:
                        if other_k != own_k:
                            continue
                        # a year cannot clash with a specific date, even though they are 'equal'
                        if own_v.type == 'date' and other_v.type == 'year':
                            continue
                        if other_v.can_compare(own_v) and other_v == own_v:
                            is_unique[(own_k, own_v)] = False
            return [pair for pair, ok in is_unique.items() if ok]

        def unique_relations():
            same_name = data.get_name(ent_id)
            is_unique = {}
            for (p, o, d, _) in data.get_relation_facts(ent_id):
                if not data.is_ambiguous(o):
                    is_unique[(p, o, d)] = True
            for other in data.get_ids(same_name):
                if i_is_skipped(other):
                    continue
                for (p, o, d, _) in data.get_relation_facts(other):
                    is_unique[(p, o, d)] = False
            return [triple for triple, ok in is_unique.items() if ok]

        def i_is_skipped(other):
            # the entity itself and concepts do not compete for the name
            return other == ent_id or data.is_concept(other)

        ent_name = data.get_name(ent_id)

        if not data.is_ambiguous(ent_id):
            text = ent_name
            program = [ Function('Find', [], [ent_name]) ]
            sparql = SparqlEngine.gen_name_query(ent_name)
            return cls(data, (text, program, sparql))

        if data.is_concept(ent_id):
            return None
        att_choices = unique_attributes()
        rel_choices = unique_relations()
        n_att, n_rel = len(att_choices), len(rel_choices)
        if n_att == 0 and n_rel == 0:
            return None

        pick = random.randint(0, n_att + n_rel - 1)
        if pick >= n_att:
            pred, obj, direction = rel_choices[pick - n_att]
            obj_name = data.get_name(obj)

            text = data.describe_relation_subject(ent_name, pred, obj_name, direction, the_one=True)

            find_obj = Function('Find', [], [obj_name])
            # starting from obj and walking back, the direction is reversed
            relate = Function('Relate', [find_obj], [pred, reverse_dir(direction)])
            find_ent = Function('Find', [], [ent_name])
            both = Function('And', [relate, find_ent], [])
            program = [ find_obj, relate, find_ent, both ]

            obj_sparql, obj_variable = SparqlEngine.replace_variable(SparqlEngine.gen_name_query(obj_name), '?e')
            sparql = SparqlEngine.gen_name_query(ent_name) + SparqlEngine.gen_relation_query(pred, direction, obj_sparql, obj_variable)
        else:
            k, v = att_choices[pick]

            text = '{} ({})'.format(ent_name, data.describe_attribute_subject('the one', k, v))

            f1 = Function('Find', [], [ent_name])
            if v.type == 'string':
                f2 = Function('FilterStr', [f1], [k, v])
            elif v.type == 'quantity':
                f2 = Function('FilterNum', [f1], [k, v, '='])
            elif v.type == 'date':
                f2 = Function('FilterDate', [f1], [k, v, '='])
            elif v.type == 'year':
                f2 = Function('FilterYear', [f1], [k, v, '='])
            program = [ f1, f2 ]
            sparql = SparqlEngine.gen_name_query(ent_name) + SparqlEngine.gen_attribute_query(k, v)

        return cls(data, (text, program, sparql))

    def description(self):
        """Return the stored (text, program, sparql) tuple."""
        return self.desc


class ConceptCondition():
    """Condition "is an instance of `concept`"."""

    def __init__(self, data, concept):
        self.data = data
        self.concept = concept

    @classmethod
    def sample(cls, data):
        """Pick a random concept from `data.high_freq_concepts`."""
        concept = random.choice(data.high_freq_concepts)
        return cls(data, concept)

    def filter_facts(self, candidates=None):
        """
        Return the entities that belong to the concept.
        Args:
            candidates : optional collection of entity ids to intersect with.
                When it is None or empty, every entity listed under the
                concept is returned.
        Return:
            list of entity ids (order is unspecified when intersecting)
        """
        concept_entities = self.data.concept_to_entity[self.concept]
        if candidates:
            # the original referenced a misspelled name here and raised NameError
            return list(set(concept_entities) & set(candidates))
        return concept_entities

    def description(self):
        """
        Convert current condition to text, program, and sparql
        """
        concept_name = self.data.get_name(self.concept)

        # program: FindAll -> FilterConcept(concept_name)
        f1 = Function('FindAll', [], [])
        f2 = Function('FilterConcept', [f1], [concept_name])

        text = concept_name
        program = [f1, f2]
        sparql = SparqlEngine.gen_concept_query(concept_name)
        return text, program, sparql


class AttributeCondition():
    """Condition on an attribute value, optionally restricted by a concept and by one qualifier."""

    def __init__(self, data, concept, condition, qualifier_condition):
        self.data = data
        self.concept = concept
        self.condition = condition # key, op, value
        self.qualifier_condition = qualifier_condition # key, op, value of qualifier

    @classmethod
    def sample(cls, data, ent_id=None, concept=None):
        """
        Sample an attribute condition.
        Args:
            ent_id : if given, the condition is built from one of this
                entity's attribute facts so that the entity satisfies it
            concept : concept restriction; when None, one is drawn with
                probability conf.PROB_USE_CONCEPT
        Return:
            an AttributeCondition, or None when nothing can be sampled
        """
        if ent_id:
            if concept is None:
                concept = random.choice(data.get_all_concepts(ent_id)) if random.random() < conf.PROB_USE_CONCEPT else None
            try:
                key, value, qualifiers = random.choice(data.get_attribute_facts(ent_id))
                condition = sample_by_kv(data, concept, key, value)
            except IndexError: # empty possible values
                return None

            qualifier_condition = None
            if len(qualifiers) > 0 and random.random() < conf.PROB_USE_QUALIFIER:
                qk, qvs = random.choice(list(qualifiers.items()))
                qv = random.choice(qvs)
                # qualifier conditions are always equalities
                qualifier_condition = sample_by_kv(data, None, qk, qv, prob_op=(0,1,0,0))

            return cls(data, concept, condition, qualifier_condition)

        if concept is None:
            concept = random.choice(data.high_freq_concepts) if random.random() < conf.PROB_USE_CONCEPT else None
        if concept:
            # one entry per fact, so frequent keys are more likely to be drawn
            keys = [
                k
                for member in data.concept_to_entity[concept]
                for (k, _v, _q) in data.get_attribute_facts(member)
            ]
        else:
            keys = data.attribute_keys
        if len(keys) == 0:
            return None
        key = random.choice(keys)
        condition = sample_by_kv(data, concept, key)
        return cls(data, concept, condition, None)

    def filter_facts(self, candidates=None):
        """
        Filter all entities that satisfy the condition
        Args:
            candidates (list) : entity ids; None or empty means "no restriction"
        Return:
            list of entity ids having at least one attribute fact that
            satisfies the condition and, when a qualifier condition is set,
            carries a matching qualifier
        """
        concept = self.concept
        if concept:
            pool = self.data.concept_to_entity[concept]
        else:
            pool = list(self.data.entities.keys())
        if candidates:
            # the original referenced a misspelled name here and raised NameError
            pool = list(set(pool) & set(candidates))

        results = []
        for ent_id in pool:
            for (k, v, qualifiers) in self.data.get_attribute_facts(ent_id):
                if not check_attr(k, v, self.condition):
                    continue
                if self.qualifier_condition:
                    # A qualifier must actually match; a fact without any
                    # qualifier no longer passes vacuously.
                    qualifier_ok = any(
                        check_attr(qk, qv, self.qualifier_condition)
                        for qk, qvs in qualifiers.items()
                        for qv in qvs
                    )
                    if not qualifier_ok:
                        continue
                results.append(ent_id)
                break
        return results


    def description(self):
        """
        Convert current condition to text, program, and sparql
        Raises:
            ValueError : unsupported value type or operator
            Exception : date value combined with an operator other than '='
        """
        key, op, value = self.condition
        concept_name = get_concept_name(self.data, self.concept)

        if self.concept:
            sparql = SparqlEngine.gen_concept_query(concept_name)
        else:
            sparql = ''

        program = [ Function('FindAll', [], []) ]

        if value.type == 'string':
            program += [ Function('FilterStr', [program[-1]], [key, value]) ]
            op_desc = ''

        elif value.type == 'quantity':
            program += [ Function('FilterNum', [program[-1]], [key, value, op]) ]
            if op == '<':
                op_desc = 'less than'
            elif op == '>':
                op_desc = 'greater than'
            elif op == '=':
                op_desc = 'equal to'
            elif op == '!=':
                op_desc = 'not equal to'
            else:
                raise ValueError('Unsupported operator for quantity: {!r}'.format(op))

        elif value.type == 'date':
            program += [ Function('FilterDate', [program[-1]], [key, value, op]) ]
            if op == '=':
                op_desc = 'on'
            else:
                raise Exception('Not support not-equal opration for date')

        elif value.type == 'year':
            program += [ Function('FilterYear', [program[-1]], [key, value, op]) ]
            if op == '<':
                op_desc = 'before'
            elif op == '>':
                op_desc = 'after'
            elif op == '=':
                op_desc = 'in'
            elif op == '!=':
                op_desc = 'not in'
            else:
                raise ValueError('Unsupported operator for year: {!r}'.format(op))

        else:
            raise ValueError('Unsupported value type: {!r}'.format(value.type))

        text = self.data.describe_attribute_subject('the '+concept_name, key, value, op_desc)
        sparql += SparqlEngine.gen_attribute_query(key, value, op=op)

        if self.qualifier_condition:
            fact_repr = SparqlEngine.gen_attr_fact_node(key)
            text, program, sparql = append_qualifier_description(
                     self.data, self.qualifier_condition, fact_repr, text, program, sparql)

        # the concept filter comes last in the program
        if self.concept:
            program += [ Function('FilterConcept', [program[-1]], [concept_name]) ]

        return text, program, sparql




class RelationCondition():
    """Condition on a relation to an object entity, optionally restricted by a concept and by one qualifier."""

    def __init__(self, data, concept, condition, qualifier_condition):
        self.data = data
        self.concept = concept
        self.condition = condition # pred, object, direction
        self.qualifier_condition = qualifier_condition # key, op, value of qualifier

    @classmethod
    def sample(cls, data, ent_id=None, concept=None):
        """
        Sample a relation condition.
        Args:
            ent_id : if given, the condition is taken from one of this
                entity's relation facts
            concept : concept restriction; when None, one is drawn with
                probability conf.PROB_USE_CONCEPT
        Return:
            a RelationCondition, or None when there is no relation fact to
            use or the chosen object cannot be described unambiguously
        """
        if ent_id:
            if concept is None:
                concept = random.choice(data.get_all_concepts(ent_id)) if random.random() < conf.PROB_USE_CONCEPT else None
            relation_facts = data.get_relation_facts(ent_id)
            if len(relation_facts) == 0:
                return None

            predicate, obj, direction, qualifiers = random.choice(relation_facts)
            condition = (predicate, obj, direction)

            # if the object cannot be disambiguated, this relation condition is not usable
            if EntityCondition.sample(data, obj) is None:
                return None

            qualifier_condition = None
            if len(qualifiers) > 0 and random.random() < conf.PROB_USE_QUALIFIER:
                qk, qvs = random.choice(list(qualifiers.items()))
                qv = random.choice(qvs)
                # qualifier conditions are always equalities
                qualifier_condition = sample_by_kv(data, None, qk, qv, prob_op=(0,1,0,0))

            return cls(data, concept, condition, qualifier_condition)

        if concept is None:
            concept = random.choice(data.high_freq_concepts) if random.random() < conf.PROB_USE_CONCEPT else None

        if concept:
            entities = data.concept_to_entity[concept]
        else:
            entities = data.entities

        # (predicate, direction) -> objects seen with it among the subject entities
        pred_dir_to_obj = defaultdict(list)
        for subject in entities:
            for (p, o, d, q) in data.get_relation_facts(subject):
                pred_dir_to_obj[(p, d)].append(o)

        if len(pred_dir_to_obj) == 0:
            return None
        predicate, direction = random.choice(list(pred_dir_to_obj.keys()))
        valid_objects = pred_dir_to_obj[(predicate, direction)]

        siblings = set()
        for o in valid_objects:
            for c in data.get_direct_concepts(o):
                siblings.update(set(data.concept_to_entity[c]))
        # valid objects + 3 additional entities which are siblings of a valid object
        if len(siblings) > 3:
            additional = random.sample(list(siblings), 3)
        else:
            additional = list(siblings)
        obj_candidates = valid_objects + additional
        if len(obj_candidates) == 0:
            return None
        obj = random.choice(list(obj_candidates))
        # if the object cannot be disambiguated, this relation condition is not usable
        if EntityCondition.sample(data, obj) is None:
            return None
        condition = (predicate, obj, direction)

        return cls(data, concept, condition, None)


    def filter_facts(self, candidates=None):
        """
        Filter all entities that satisfy the condition
        Args:
            candidates (list) : entity ids; None or empty means "no restriction"
        Return:
            list of entity ids having at least one relation fact that
            satisfies the condition and, when a qualifier condition is set,
            carries a matching qualifier
        """
        concept = self.concept
        if concept:
            pool = self.data.concept_to_entity[concept]
        else:
            pool = list(self.data.entities.keys())
        if candidates:
            pool = list(set(pool) & set(candidates))

        results = []
        for ent_id in pool:
            for (p, o, d, qualifiers) in self.data.get_relation_facts(ent_id):
                if not check_pred(p, o, d, self.condition):
                    continue
                if self.qualifier_condition:
                    # A qualifier must actually match; a fact without any
                    # qualifier no longer passes vacuously.
                    qualifier_ok = any(
                        check_attr(qk, qv, self.qualifier_condition)
                        for qk, qvs in qualifiers.items()
                        for qv in qvs
                    )
                    if not qualifier_ok:
                        continue
                results.append(ent_id)
                break
        return results

    def description(self, obj_desc=None):
        """
        Convert current condition to text, program, and sparql
        Args:
            - obj_desc : if not None, it is the description of object, used for multi-hop description of current condition.
              Expected as (obj_text, obj_program, obj_sparql, obj_variable); its program list is not modified.
        """
        pred, obj, direction = self.condition
        concept_name = get_concept_name(self.data, self.concept)

        if self.concept:
            sparql = SparqlEngine.gen_concept_query(concept_name)
        else:
            sparql = ''

        if obj_desc is None:
            obj_text, obj_program, obj_sparql = EntityCondition.sample(self.data, obj).description()
            obj_sparql, obj_variable = SparqlEngine.replace_variable(obj_sparql, '?e')
        else:
            obj_text, obj_program, obj_sparql, obj_variable = obj_desc

        # Work on a copy: extending the object's program in place would
        # silently alter the description the caller passed in.
        program = list(obj_program)

        program += [ Function('Relate', [program[-1]], [pred, reverse_dir(direction)]) ]
        text = self.data.describe_relation_subject('the '+concept_name, pred, obj_text, direction)

        sparql += SparqlEngine.gen_relation_query(pred, direction, obj_sparql, obj_variable)

        if self.qualifier_condition:
            fact_repr = SparqlEngine.gen_rel_fact_node(pred, direction, obj_variable)
            text, program, sparql = append_qualifier_description(
                    self.data, self.qualifier_condition, fact_repr, text, program, sparql)

        # the concept filter comes last in the program
        if self.concept:
            program += [ Function('FilterConcept', [program[-1]], [concept_name]) ]

        return text, program, sparql



if __name__=='__main__':
    # Runs only when this file is executed as a script. The visible part just
    # imports `Data` from the `data` module (NOTE(review): nothing after this
    # import is visible in this view - confirm whether the block continues).
    from data import Data
