import random
from conditions import AttributeCondition, RelationCondition, EntityCondition, ConceptCondition
# Condition types the multi-condition samplers in this module pick from via
# random.choice (attribute- and relation-based conditions).
conditionClasses = (AttributeCondition, RelationCondition)


def by_entity(data, ent_id):
    """Sample an entity condition for ``ent_id``.

    Delegates to ``EntityCondition.sample(data, ent_id)`` and hands back
    whatever it produces, unchanged.
    """
    sampled = EntityCondition.sample(data, ent_id)
    return sampled

def by_concept(data):
    """Sample a concept condition.

    Delegates to ``ConceptCondition.sample(data)`` and hands back whatever it
    produces, unchanged.
    """
    sampled = ConceptCondition.sample(data)
    return sampled

def _sample_unique_condition(condition_class, data, ent_id, max_num):
    """Sample ``condition_class`` up to ``max_num`` times; return the first usable sample.

    A sample is usable when it is truthy and, if ``ent_id`` is given, its
    ``filter_facts()`` returns exactly one result. Shared by ``by_attribute``
    and ``by_relation``, which differ only in the condition class.

    Args:
        condition_class: class exposing ``sample(data, ent_id)``.
        data: knowledge-base object forwarded to ``sample``.
        ent_id: target entity id, or None/falsy for an unconstrained sample.
        max_num: maximum number of sampling attempts (int).

    Returns:
        The first usable condition, or None if every attempt fails.
    """
    for _ in range(max_num):
        condition = condition_class.sample(data, ent_id)
        if not condition:
            continue
        # Without a target entity any sample is accepted; with one, the
        # condition must match exactly one result on its own.
        if not ent_id or len(condition.filter_facts()) == 1:
            return condition
    return None


def by_attribute(data, ent_id=None, max_num=100):
    """Sample an ``AttributeCondition``, optionally one that singles out ``ent_id``.

    Args:
        data: knowledge-base object; must provide ``get_attribute_facts``.
        ent_id: when given, the returned condition's ``filter_facts()`` must
            yield exactly one result.
        max_num: maximum number of sampling attempts.

    Returns:
        An ``AttributeCondition``, or None if ``ent_id`` has no attribute facts
        or no attempt succeeds.
    """
    # Short-circuit: no attribute facts for this entity, so don't even sample.
    if ent_id and len(data.get_attribute_facts(ent_id)) == 0:
        return None
    return _sample_unique_condition(AttributeCondition, data, ent_id, max_num)


def by_relation(data, ent_id=None, max_num=100):
    """Sample a ``RelationCondition``, optionally one that singles out ``ent_id``.

    Args:
        data: knowledge-base object; must provide ``get_relation_facts``.
        ent_id: when given, the returned condition's ``filter_facts()`` must
            yield exactly one result.
        max_num: maximum number of sampling attempts.

    Returns:
        A ``RelationCondition``, or None if ``ent_id`` has no relation facts
        or no attempt succeeds.
    """
    # Short-circuit: no relation facts for this entity, so don't even sample.
    if ent_id and len(data.get_relation_facts(ent_id)) == 0:
        return None
    return _sample_unique_condition(RelationCondition, data, ent_id, max_num)


def _sample_condition_pair(data, *sample_args):
    """Draw one candidate pair of conditions that share a concept.

    Both condition classes are chosen at random from ``conditionClasses``. The
    second condition is sampled with ``concept=condition1.concept`` so both
    refer to the same concept. ``sample_args`` is forwarded positionally after
    ``data`` to both ``sample`` calls (the intersect sampler passes ``ent_id``;
    the union sampler passes nothing).

    Returns:
        ``(condition1, condition2)``, or None when either sample fails or the
        two are the same class with identical first two ``condition`` entries.
    """
    class1 = random.choice(conditionClasses)
    condition1 = class1.sample(data, *sample_args)
    if condition1 is None:
        return None

    class2 = random.choice(conditionClasses)
    condition2 = class2.sample(data, *sample_args, concept=condition1.concept)
    if condition2 is None:
        return None

    # Reject a pair that states the same condition twice.
    is_repeat = (
        class1 == class2
        and condition1.condition[0] == condition2.condition[0]
        and condition1.condition[1] == condition2.condition[1]
    )
    if is_repeat:
        return None
    return condition1, condition2


def by_two_conditions_intersect(data, ent_id=None, not_located_by_single=False, max_num=100):
    """Sample two same-concept conditions whose conjunction singles out ``ent_id``.

    Args:
        data: knowledge-base object forwarded to the condition samplers.
        ent_id: when given, both conditions are sampled for this entity and the
            intersection of their ``filter_facts()`` results must contain
            exactly one element. When falsy, the first valid pair is returned.
        not_located_by_single: when True (and ``ent_id`` is given), each
            condition alone must match more than one result, i.e. only the
            combination is unambiguous.
        max_num: maximum number of sampling attempts.

    Returns:
        ``(condition1, condition2)``, or None if no attempt succeeds.
    """
    for _ in range(max_num):
        pair = _sample_condition_pair(data, ent_id)
        if pair is None:
            continue
        if not ent_id:
            # Nothing to pin down: any valid pair is accepted.
            return pair

        condition1, condition2 = pair
        results1 = condition1.filter_facts()
        results2 = condition2.filter_facts()
        # Optionally require that neither condition is unambiguous on its own.
        if not_located_by_single and (len(results1) <= 1 or len(results2) <= 1):
            continue
        # Exactly one shared result already implies both result lists are
        # non-empty, so no separate emptiness check is needed.
        if len(set(results1) & set(results2)) == 1:
            return pair
    return None


def by_two_conditions_union(data, max_num=100):
    """Sample two distinct conditions on the same concept, with no entity constraint.

    Args:
        data: knowledge-base object forwarded to the condition samplers.
        max_num: maximum number of sampling attempts.

    Returns:
        ``(condition1, condition2)``, or None if no attempt succeeds.
    """
    for _ in range(max_num):
        pair = _sample_condition_pair(data)
        if pair is not None:
            return pair
    return None


def by_two_hop_relation(data, ent_id=None, max_num=100, inner_max_num=10):
    """Sample a relation condition plus a condition that uniquely describes its object.

    The first hop is a ``RelationCondition`` sampled for ``ent_id``. Its object
    (``condition.condition[1]``) is then described by a second condition, drawn
    from ``conditionClasses``, whose ``filter_facts()`` must return exactly one
    result.

    Args:
        data: knowledge-base object; ``data.concepts`` must support ``in``.
        ent_id: when not None, the first-hop condition must itself match
            exactly one result, and the second hop may not be a relation that
            points straight back to ``ent_id``.
        max_num: maximum number of first-hop samples.
        inner_max_num: maximum number of second-hop samples per first hop.

    Returns:
        ``(condition, hop_condition)``, or None if no attempt succeeds.
    """
    for _ in range(max_num):
        condition = RelationCondition.sample(data, ent_id)
        if condition is None:
            continue
        # With a target entity, the first hop must single it out on its own.
        # (filter_facts() is skipped entirely when there is no target.)
        if ent_id is not None and len(condition.filter_facts()) != 1:
            continue

        # The hop entity is fixed for this first hop, so check it once.
        hop_entity = condition.condition[1]
        if hop_entity in data.concepts:
            # The object is a concept, not an entity; resample the first hop.
            continue

        for _ in range(inner_max_num):
            hop_class = random.choice(conditionClasses)
            hop_condition = hop_class.sample(data, hop_entity)
            if hop_condition is None:
                continue
            # Avoid a circular path that leads straight back to the target.
            if ent_id and hop_class == RelationCondition and hop_condition.condition[1] == ent_id:
                continue
            if len(hop_condition.filter_facts()) == 1:
                return (condition, hop_condition)
    return None



def test(num_samples=50, max_attempts=None):
    """Manual smoke test: print sampled conditions for random entities.

    Loads ``./data/kb.json`` and ``./data/relation_descriptions.json``, then
    repeatedly picks a random entity and runs ``by_two_conditions_intersect``
    on it with ``not_located_by_single=True``. Each success is rendered with
    ``and_two_descriptions`` and printed, until ``num_samples`` have been shown.

    Args:
        num_samples: number of successful samples to print (default 50).
        max_attempts: optional cap on the number of entities tried. ``None``
            (the default) keeps the loop unbounded; set it when the KB may
            contain no qualifying entity, which would otherwise loop forever.
    """
    from data import Data
    from utils import convert_program_list, and_two_descriptions, hop_two_descriptions
    data = Data('./data/kb.json', './data/relation_descriptions.json')
    # Build the id list once; rebuilding it on every attempt costs
    # O(#entities) per attempt, and many attempts fail and retry.
    entity_ids = list(data.entities.keys())
    cnt = 0
    attempts = 0
    while cnt < num_samples:
        if max_attempts is not None and attempts >= max_attempts:
            break
        attempts += 1
        ent_id = random.choice(entity_ids)
        print('-'*10)
        print('> sampled entity: {} {}'.format(ent_id, data.entities[ent_id]['name']))
        # Alternative samplers to exercise instead:
        # condition = by_entity(data, ent_id)
        # condition = by_attribute(data, ent_id)
        # condition = by_relation(data, ent_id)
        condition = by_two_conditions_intersect(data, ent_id, not_located_by_single=True)
        # condition = by_two_hop_relation(data, ent_id)
        if condition:
            cnt += 1
            # desc = condition.description()
            desc = and_two_descriptions(condition[0], condition[1])
            # desc = hop_two_descriptions(condition[0], condition[1])
            print(desc[0])
            print([str(f) for f in convert_program_list(desc[1])])
            print(desc[2])
            print('='*50)
        

# Run the sampling smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    test()
