import random
import rdflib
from conf import NUM_CHOICES, PAD_CHOICE
from sparqlEngine import SparqlEngine

class ChoiceGen():
    """Build multiple-choice candidate lists (distractors plus the answer).

    None of the methods use instance state, so all of them are static and can
    be called either as ``ChoiceGen.for_xxx(...)`` or on an instance.

    Methods that build a fixed-size list accept an optional ``num_choices``;
    when it is None the module-level ``NUM_CHOICES`` is used.
    """

    @staticmethod
    def _resolve_num_choices(num_choices):
        """Return ``num_choices``, falling back to NUM_CHOICES when it is None."""
        return NUM_CHOICES if num_choices is None else num_choices

    @staticmethod
    def _finalize_entity_choices(data, target, choices, num_choices):
        """Pad entity-name distractors, then add the answer and shuffle.

        Arguments:
            - data: knowledge base exposing get_name, get_all_concepts,
              concept_to_entity and entities
            - target: target entity id
            - choices: distractor names collected so far (extended in place)
            - num_choices: total size of the returned list

        Padding order: first other entities of the target's concepts, then
        random entities. Empty/None names, the target's own name and names
        already collected are never added.

        Raises:
            ValueError: if fewer than num_choices-1 distinct distractor names
            exist (the unbounded retry loop this replaces would never end).
        """
        num_distractors = num_choices - 1
        target_name = data.get_name(target)
        seen = set(choices)

        def try_add(entity):
            name = data.get_name(entity)  # may be None, e.g. for a blank node
            if name and name != target_name and name not in seen:
                seen.add(name)
                choices.append(name)

        # pad choices using other children of its concept
        if len(choices) < num_distractors:
            siblings = (
                e
                for concept in data.get_all_concepts(target)
                for e in data.concept_to_entity[concept]
            )
            for e in siblings:
                try_add(e)
                if len(choices) >= num_distractors:
                    break

        # if still less than required number, then pad using random entities;
        # one pass over a shuffled copy visits entities in uniformly random
        # order and is guaranteed to terminate
        if len(choices) < num_distractors:
            pool = list(data.entities.keys())
            random.shuffle(pool)
            for e in pool:
                try_add(e)
                if len(choices) >= num_distractors:
                    break

        if len(choices) < num_distractors:
            raise ValueError(
                'need {} distractors but only {} distinct names are available'.format(
                    num_distractors, len(choices)))

        # add the correct answer into choices
        result = random.sample(choices, num_distractors)
        result.append(target_name)
        random.shuffle(result)
        return result

    @staticmethod
    def for_entity(data, target, sparql, num_choices=None):
        """
        Arguments:
            - data: knowledge base (see _finalize_entity_choices)
            - target: target entity id
            - sparql: sparql description of target, which can be split into multiple clauses
            - num_choices: size of the returned list (default: NUM_CHOICES)

        For each sub-clause in turn, remove it from sparql, run the relaxed
        query and take at most 3 random result entities as distractors.
        If the number of choices is less than required, pad using children of
        the target's concepts, then random entities.

        Return:
            shuffled list of num_choices names containing the target's name
        """
        num_choices = ChoiceGen._resolve_num_choices(num_choices)
        target_name = data.get_name(target)
        clauses = SparqlEngine.get_all_clauses(sparql)
        choices = []
        for i, clause in enumerate(clauses):
            # a lone brace is structural, not a constraint that can be dropped
            if clause.strip() in {'{', '}'}:
                continue
            # remove the i-th sub-clause
            part_clauses = [c for j, c in enumerate(clauses) if j != i]
            part_sparql = SparqlEngine.ensemble_clauses(part_clauses)
            part_sparql = 'SELECT DISTINCT ?e WHERE {{ {} }}'.format(part_sparql)
            res = SparqlEngine.query_virtuoso(part_sparql)
            entities = [str(binding[rdflib.term.Variable('e')]) for binding in res.bindings]
            if not entities:
                continue
            # limit the maximum number of a part sparql to 3
            if len(entities) > 3:
                entities = random.sample(entities, 3)
            # add its name into choices
            for e in entities:
                name = data.get_name(e)  # e may be a blank node, and name is None
                if name and name != target_name and name not in choices:
                    choices.append(name)

        return ChoiceGen._finalize_entity_choices(data, target, choices, num_choices)

    @staticmethod
    def for_entity_selection(data, target, candidates, num_choices=None):
        """
        Args:
            - data: knowledge base (see _finalize_entity_choices)
            - target: target entity id
            - candidates: a list contains entity ids. Note target is included.
            - num_choices: size of the returned list (default: NUM_CHOICES)

        The names of the non-target candidates are used as distractors.
        Candidates without a name, or whose name equals the target's name, are
        dropped so the correct answer appears exactly once in the result.

        Return:
            shuffled list of num_choices names containing the target's name
        """
        num_choices = ChoiceGen._resolve_num_choices(num_choices)
        target_name = data.get_name(target)
        names = (data.get_name(e) for e in candidates if e != target)
        # dict.fromkeys de-duplicates while keeping first-seen order
        choices = list(dict.fromkeys(
            name for name in names if name and name != target_name))

        return ChoiceGen._finalize_entity_choices(data, target, choices, num_choices)

    @staticmethod
    def for_count(answer, num_choices=10):
        """
        Args:
            - answer: the correct count, a non-negative integer
            - num_choices: number of consecutive integers to return. The
              default of 10 is independent of NUM_CHOICES.

        Return:
            ascending list of num_choices consecutive non-negative integers
            that contains answer at a random position (not shuffled)

        Raises:
            ValueError: if answer is negative
        """
        if answer < 0:
            raise ValueError('answer must be non-negative, got {}'.format(answer))
        pos = random.randrange(num_choices)
        # clamp at 0 so that no negative count is offered
        start = max(answer - pos, 0)
        return list(range(start, start + num_choices))

    @staticmethod
    def for_attribute_value(data, key, answer, num_choices=None):
        """
        Args:
            - data: knowledge base exposing key_values and attribute_keys
            - key: an attribute or qualifier key
            - answer: target value
            - num_choices: size of the returned list (default: NUM_CHOICES)
        Return:
            values of str format

        Find other values of the same attribute key.
        If not enough, randomly extract some other values for padding.

        Raises:
            ValueError: if padding is needed but no attribute key has values
        """
        num_choices = ChoiceGen._resolve_num_choices(num_choices)
        num_distractors = num_choices - 1
        answer_text = str(answer)
        choices = []
        seen = set()

        def try_add(value):
            text = str(value)
            if text == answer_text or text in seen:
                return
            if answer.isTime() and value.isTime() and value.contains(answer):
                # if answer is 1990-01-01, value is 1990, then value should not be added into choices
                return
            seen.add(text)
            choices.append(text)

        for v in data.key_values[key]:
            try_add(v)

        if len(choices) < num_distractors:
            # drawing uniformly from keys that have values is equivalent to
            # drawing from all keys and retrying on an empty one
            keys = [k for k in data.attribute_keys if len(data.key_values[k]) > 0]
            if not keys:
                raise ValueError('no attribute values available for padding')
            # NOTE: assumes the knowledge base holds enough distinct values;
            # otherwise this retry loop does not terminate
            while len(choices) < num_distractors:
                k = random.choice(keys)
                try_add(random.choice(data.key_values[k]))

        # add the correct answer into choices
        choices = random.sample(choices, num_distractors)
        choices.append(answer_text)
        random.shuffle(choices)
        return choices

    @staticmethod
    def for_binary(num_choices=None):
        """Return 'yes' and 'no' followed by PAD_CHOICE up to num_choices items."""
        num_choices = ChoiceGen._resolve_num_choices(num_choices)
        return ['yes', 'no'] + [PAD_CHOICE] * (num_choices - 2)

    @staticmethod
    def for_relation(data, answer, num_choices=None):
        """
        Randomly select distractors from all predicates.

        Args:
            - data: knowledge base exposing predicates
            - answer: the correct predicate
            - num_choices: size of the returned list (default: NUM_CHOICES)

        Return:
            shuffled list of num_choices distinct predicates containing answer

        Raises:
            ValueError: if there are fewer than num_choices-1 distinct
            predicates other than answer
        """
        num_choices = ChoiceGen._resolve_num_choices(num_choices)
        num_distractors = num_choices - 1
        candidates = list(dict.fromkeys(p for p in data.predicates if p != answer))
        if len(candidates) < num_distractors:
            raise ValueError(
                'need {} distractor predicates but only {} are available'.format(
                    num_distractors, len(candidates)))
        choices = random.sample(candidates, num_distractors)
        choices.append(answer)
        random.shuffle(choices)
        return choices
