import json
import os
from collections import defaultdict
from queue import Queue
from structures import ValueClass
from datetime import date


class Data(object):
    """
    Loads the knowledge base and its description templates from JSON files and
    offers lookup helpers plus simple natural-language verbalization of facts.

    knowledge json format:
        'concepts':
        {
            'id':
            {
                'name': '',
                'instanceOf': ['<concept_id>'],
            }
        },
        'entities': # exclude concepts
        {
            'id':
            {
                'name': '<entity_name>',
                'instanceOf': ['<concept_id>'],
                'attributes':
                [
                    {
                        'key': '<key>',
                        'value':
                        {
                            'type': 'string'/'quantity'/'date'/'year'
                            'value':  # float or int for quantity, int for year, 'yyyy/mm/dd' for date
                            'unit':   # for quantity
                        },
                        'qualifiers':
                        {
                            '<qk>':
                            [
                                <qv>, # each qv is a dictionary like value, including keys type,value,unit
                            ]
                        }
                    }
                ],
                'relations':
                [
                    {
                        'predicate': '<predicate>',
                        'object': '<object_id>', # NOTE: it may be a concept id
                        'direction': 'forward' or 'backward',
                        'qualifiers':
                        {
                            '<qk>':
                            [
                                <qv>, # each qv is a dictionary like value
                            ]
                        }
                    }
                ]
            }
        },
    relation description json format:
        {
            'key':
            {
                'forward': '', # contains A/B as placeholder of subject/object
                'backward': '',
            }
        }
    attribute description json format:
        {
            'key': 'desc', # desc can be a noun phrase or a verb phrase started with 'is'
        }
    qualifier ask json format:
        {
            'key': 'desc'
        }
    """

    def __init__(self, data_dir='./data'):
        """
        Read every JSON resource from `data_dir` and build the lookup tables.

        Files read: kb.json, ku2grid.json, relation_descriptions.json,
        attribute_descriptions.json and qualifier_ask.json.
        Prints a short summary of the loaded knowledge base to stdout.
        """
        kb = self._load_json(os.path.join(data_dir, 'kb.json'))
        self.concepts = kb['concepts']
        self.entities = kb['entities']
        self.ku2grid = self._load_json(os.path.join(data_dir, 'ku2grid.json'))
        self.relation_descriptions = self._load_json(
            os.path.join(data_dir, 'relation_descriptions.json'))
        self.attribute_descriptions = self._load_json(
            os.path.join(data_dir, 'attribute_descriptions.json'))
        self.qualifier_ask = self._load_json(os.path.join(data_dir, 'qualifier_ask.json'))

        # collapse runs of whitespace (spaces, tabs, ...) in names, which may
        # cause errors when building sparql query
        for con_info in self.concepts.values():
            con_info['name'] = ' '.join(con_info['name'].split())
        for ent_info in self.entities.values():
            ent_info['name'] = ' '.join(ent_info['name'].split())

        # name -> ids; entities are registered before concepts
        self.name_to_id = defaultdict(list)
        for ent_id, ent_info in self.entities.items():
            self.name_to_id[ent_info['name']].append(ent_id)
        for con_id, con_info in self.concepts.items():
            self.name_to_id[con_info['name']].append(con_id)

        # concept -> entities, where an entity is merged into every ancestor concept
        members = defaultdict(set)
        for ent_id in self.entities:
            for c in self.get_all_concepts(ent_id):
                members[c].add(ent_id)
        self.concept_to_entity = {k: list(v) for k, v in members.items()}
        self.high_freq_concepts = [k for k, v in self.concept_to_entity.items() if len(v) > 3]

        # get all attribute keys and predicates
        # NOTE: attributes of all entities are scanned before relations; when a
        # key shows up with several types, the last one seen wins, so the two
        # passes are deliberately kept separate.
        attribute_keys = set()
        predicates = set()
        key_type = {}
        for ent_info in self.entities.values():
            for attr_info in ent_info['attributes']:
                attribute_keys.add(attr_info['key'])
                key_type[attr_info['key']] = attr_info['value']['type']
                for qk, qvs in attr_info['qualifiers'].items():
                    attribute_keys.add(qk)
                    for qv in qvs:
                        key_type[qk] = qv['type']
        for ent_info in self.entities.values():
            for rel_info in ent_info['relations']:
                predicates.add(rel_info['predicate'])
                for qk, qvs in rel_info['qualifiers'].items():
                    attribute_keys.add(qk)
                    for qv in qvs:
                        key_type[qk] = qv['type']
        self.attribute_keys = list(attribute_keys)
        self.predicates = list(predicates)
        # Note: key_type is one of string/quantity/date, but date means the key may have values of type year
        self.key_type = {k: ('date' if v == 'year' else v) for k, v in key_type.items()}

        # parse values into ValueClass object (in place; must happen after the
        # type scan above, which reads the raw dictionaries)
        for ent_info in self.entities.values():
            for attr_info in ent_info['attributes']:
                attr_info['value'] = self._parse_value(attr_info['value'])
                self._parse_qualifiers(attr_info['qualifiers'])
        for ent_info in self.entities.values():
            for rel_info in ent_info['relations']:
                self._parse_qualifiers(rel_info['qualifiers'])

        print('='*50)
        print('number of concepts: %d' % len(self.concepts))
        print('number of entities: %d' % len(self.entities))
        print('number of attribute keys: %d' % len(self.attribute_keys))
        print('number of predicates: %d' % len(self.predicates))

        self.key_values, self.concept_key_values = self._find_seen_values()

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at `path`, closing the file handle afterwards."""
        with open(path, encoding='utf-8') as f:
            return json.load(f)

    def _parse_qualifiers(self, qualifiers):
        """Replace, in place, every raw qualifier value dict by its parsed value."""
        for qk, qvs in qualifiers.items():
            # re-assigning an existing key does not resize the dict, so it is
            # safe while iterating over items()
            qualifiers[qk] = [self._parse_value(qv) for qv in qvs]

    def _find_seen_values(self):
        """
        Collect the distinct values observed for each key.

        Returns a pair (key_values, concept_key_values):
          - key_values[k]: distinct values of attribute key k, merged with the
            values of qualifier key k (from both attributes and relations);
          - concept_key_values[c][k]: distinct attribute values of key k among
            entities under concept c (not include qualifier values).
        """
        key_values = defaultdict(list)
        concept_key_values = defaultdict(lambda: defaultdict(list))
        for ent_id, ent_info in self.entities.items():
            ancestors = self.get_all_concepts(ent_id)  # invariant per entity
            for attr_info in ent_info['attributes']:
                k, v = attr_info['key'], attr_info['value']
                key_values[k].append(v)
                for c in ancestors:
                    concept_key_values[c][k].append(v)
                # merge qualifier statistics into attribute
                for qk, qvs in attr_info['qualifiers'].items():
                    key_values[qk].extend(qvs)

        for ent_info in self.entities.values():
            for rel_info in ent_info['relations']:
                # merge qualifier statistics into attribute
                for qk, qvs in rel_info['qualifiers'].items():
                    key_values[qk].extend(qvs)

        # remove duplicate
        for k in key_values:
            key_values[k] = list(set(key_values[k]))
        for per_key in concept_key_values.values():
            for k in per_key:
                per_key[k] = list(set(per_key[k]))

        return key_values, concept_key_values

    def _parse_value(self, value):
        """
        Convert a raw value dictionary (keys 'type', 'value' and, for
        quantities, 'unit') into a ValueClass object.

        Raises Exception when the type is not date/year/string/quantity.
        """
        if value['type'] == 'date':
            # dates are stored as 'yyyy/mm/dd'
            x = value['value']
            p1, p2 = x.find('/'), x.rfind('/')
            y, m, d = int(x[:p1]), int(x[p1+1:p2]), int(x[p2+1:])
            result = ValueClass('date', date(y, m, d))
        elif value['type'] == 'year':
            result = ValueClass('year', value['value'])
        elif value['type'] == 'string':
            result = ValueClass('string', value['value'])
        elif value['type'] == 'quantity':
            result = ValueClass('quantity', value['value'], value['unit'])
        else:
            raise Exception('unsupport value type')
        return result

    def describe_relation_subject(self, sub_text, predicate, obj_text, direction, the_one=False):
        """
        Describe a subject by one of its relations, e.g. "<sub> that ... <obj>"
        or "<sub> whose ... <obj>"; with the_one=True the clause is wrapped as
        "<sub> (the one ...)".

        The template is looked up by predicate and direction and must begin
        with the placeholder A or B, otherwise an Exception is raised.
        """
        template = self.relation_descriptions[predicate][direction]
        # the leading placeholder stands for the subject and becomes a relative pronoun
        if template.startswith(("A's", "B's")):
            template = 'whose' + template[3:]
        elif template.startswith(('A', 'B')):
            template = 'that' + template[1:]
        else:
            raise Exception('template of {} must begin with A/B'.format(predicate))
        # the remaining placeholder is filled with the object text
        clause = template.replace('A', '{}').replace('B', '{}').format(obj_text)
        if the_one:
            return '{} (the one {})'.format(sub_text, clause)
        return sub_text + ' ' + clause

    def describe_relation_fact(self, sub_text, predicate, obj_text, direction):
        """
        Verbalize a relation fact as a sentence: the first placeholder of the
        template receives sub_text and the second one obj_text.
        """
        template = self.relation_descriptions[predicate][direction]
        template = template.replace('A', '{}').replace('B', '{}')
        return template.format(sub_text, obj_text)

    def describe_attribute_subject(self, sub_text, k, v, op=''):
        """
        Describe a subject by one of its attributes. Descriptions starting with
        'is ' give "<sub> which is ... <v>" (op must be empty); others give
        "<sub> whose <key> is [<op>] <v>". Keys without a description are used verbatim.
        """
        key_text = self.attribute_descriptions.get(k, k)
        if key_text.startswith('is '):
            assert op == ''
            return '{} which {} {}'.format(sub_text, key_text, v)
        if op:
            return '{} whose {} is {} {}'.format(sub_text, key_text, op, v)
        return '{} whose {} is {}'.format(sub_text, key_text, v)

    def describe_attribute_fact(self, sub_text, k, v, op=''):
        """
        Verbalize an attribute fact. Descriptions starting with 'is ' give
        "<sub> is ... <v>" (op must be empty); others give
        "the <key> of <sub> is [<op>] <v>". Keys without a description are used verbatim.
        """
        key_text = self.attribute_descriptions.get(k, k)
        if key_text.startswith('is '):
            assert op == ''
            return '{} {} {}'.format(sub_text, key_text, v)
        if op:
            return 'the {} of {} is {} {}'.format(key_text, sub_text, op, v)
        return 'the {} of {} is {}'.format(key_text, sub_text, v)

    def describe_attribute_key(self, k):
        """
        Return the description of attribute key k (k itself when none exists);
        verb phrases starting with 'is ' are wrapped in single quotes.
        """
        key_text = self.attribute_descriptions.get(k, k)
        if key_text.startswith('is '):
            key_text = "'{}'".format(key_text)
        return key_text

    def ask_qualifier_key(self, k):
        """Return the question text for qualifier key k, or a generic fallback question."""
        if k in self.qualifier_ask:
            return self.qualifier_ask[k]
        return 'what is the {} of this statement'.format(k)

    def get_direct_concepts(self, ent_id):
        """
        return the direct concept id of given entity/concept

        Raises Exception when the id is neither an entity nor a concept.
        """
        if ent_id in self.entities:
            return self.entities[ent_id]['instanceOf']
        elif ent_id in self.concepts:
            return self.concepts[ent_id]['instanceOf']
        else:
            raise Exception('unknown id')

    def get_all_concepts(self, ent_id):
        """
        return a concept id list

        The list holds every ancestor concept of the given entity/concept in
        breadth-first order, each id exactly once. Already visited concepts are
        skipped, so shared ancestors are not repeated and a cyclic 'instanceOf'
        chain cannot loop forever.
        """
        ancestors = []
        seen = set()
        frontier = list(self.get_direct_concepts(ent_id))
        while frontier:
            next_frontier = []
            for con_id in frontier:
                if con_id in seen:
                    continue
                seen.add(con_id)
                ancestors.append(con_id)
                next_frontier.extend(self.concepts[con_id]['instanceOf'])
            frontier = next_frontier
        return ancestors

    def get_name(self, ent_id):
        """Return the name of the entity/concept, or None for an unknown id."""
        if ent_id in self.entities:
            return self.entities[ent_id]['name']
        elif ent_id in self.concepts:
            return self.concepts[ent_id]['name']
        else:
            return None

    def get_ids(self, name):
        """Return the ids (entities and concepts) carrying `name`; empty list if unknown."""
        # .get avoids the defaultdict inserting an empty entry for every unknown name
        return self.name_to_id.get(name, [])

    def is_ambiguous(self, ent_id):
        """Return True when more than one id shares the name of `ent_id`."""
        name = self.get_name(ent_id)
        return len(self.name_to_id.get(name, [])) > 1

    def is_concept(self, ent_id):
        """Return True when `ent_id` is a concept id."""
        return ent_id in self.concepts

    def is_human(self, ent_id):
        """Return True when one of the ancestor concepts of `ent_id` is named 'human'."""
        return any(self.concepts[i]['name'] == 'human' for i in self.get_all_concepts(ent_id))

    def get_attribute_facts(self, ent_id, key=None, unit=None):
        """
        Return the attribute facts of an entity as (key, value, qualifiers) tuples.

        When `key` is given only facts with that key are returned; `unit`
        further restricts them to values whose unit matches. `unit` is ignored
        when `key` is not given.
        """
        facts = self.entities[ent_id]['attributes']
        if key:
            facts = [
                f for f in facts
                if f['key'] == key and (not unit or f['value'].unit == unit)
            ]
        return [(f['key'], f['value'], f['qualifiers']) for f in facts]

    def get_relation_facts(self, ent_id):
        """Return the relation facts of an entity as (predicate, object, direction, qualifiers) tuples."""
        facts = self.entities[ent_id]['relations']
        return [(f['predicate'], f['object'], f['direction'], f['qualifiers']) for f in facts]

    def print_statistics(self):
        """Print the number of relation, attribute and qualifier facts over all entities."""
        cnt_rel, cnt_attr, cnt_qual = 0, 0, 0
        for ent_info in self.entities.values():
            for attr_info in ent_info['attributes']:
                cnt_attr += 1
                cnt_qual += sum(len(qvs) for qvs in attr_info['qualifiers'].values())
            for rel_info in ent_info['relations']:
                cnt_rel += 1
                cnt_qual += sum(len(qvs) for qvs in rel_info['qualifiers'].values())

        print('number of relation knowledge: %d' % cnt_rel)
        print('number of attribute knowledge: %d' % cnt_attr)
        print('number of qualifier knowledge: %d' % cnt_qual)


if __name__ == '__main__':
    # Script entry point: load the knowledge base from the default data
    # directory, then report how many facts of each kind it contains.
    data = Data()
    data.print_statistics()
