import re
import json
import torch.utils.data as data

def splitCamelCase(string):
    """Split a camelCase / PascalCase identifier into lower-cased words.

    Three regex passes are applied in order:

    1. insert a space before every capitalised word (``[A-Z][a-z]+``) that is
       not at the very start of the string;
    2. insert a space before every run of digits that is not at the very start;
    3. insert a space between a lowercase letter or digit and a following
       uppercase letter.

    The result is lower-cased, e.g. ``"birthPlace"`` -> ``"birth place"``.

    NOTE: pass 2 splits a digit run from whatever single character precedes
    it, including another digit, so ``"123"`` becomes ``"1 23"``.
    """
    passes = (
        (r'(.)([A-Z][a-z]+)', r'\1 \2'),
        (r'(.)([0-9]+)', r'\1 \2'),
        (r'([a-z0-9])([A-Z])', r'\1 \2'),
    )
    spaced = string
    for pattern, replacement in passes:
        spaced = re.sub(pattern, replacement, spaced)
    return spaced.lower()


class TestDataset(data.Dataset):
    """Evaluation dataset pairing each relation mention with each candidate.

    Every item of the JSON file at ``path`` is expanded into one sample per
    (triple, candidate relation) pair, so ``len(self)`` is the total number of
    candidates over all triples of all questions.
    """

    # Location used when no ``relation2id_path`` is passed (the original
    # hard-coded path, kept as the default for backward compatibility).
    DEFAULT_RELATION2ID_PATH = '/DATA_PATH/ImRL/data/kge/new_relation2id.txt'
    # KGE id used for relations that are missing from ``relation2id``.
    # NOTE(review): what id 1 denotes in the KGE table is not visible here.
    UNKNOWN_RELATION_ID = 1

    def __init__(self, config, path, relation2id_path=None, num_candidates=30):
        """Load the evaluation samples and the relation -> KGE id table.

        Args:
            config: Stored as ``self.config``; not otherwise used by this class.
            path: JSON file containing a list of items. Each item has a
                ``'question'`` and a list of ``'triples'``; each triple has a
                ``'relation_mention'`` and a ``'candidate_relation'`` list.
            relation2id_path: Tab-separated file with ``relation<TAB>id`` per
                line. Defaults to ``DEFAULT_RELATION2ID_PATH``.
            num_candidates: Exact number of candidates every triple must have.

        Raises:
            ValueError: if a triple does not have exactly ``num_candidates``
                candidate relations.
        """
        self.config = config
        with open(path, 'r', encoding='utf-8') as f:
            items = json.load(f)

        self.data = []
        for item in items:
            question = item['question']
            for triple in item['triples']:
                relation_mention = triple['relation_mention']
                candidates = triple['candidate_relation']
                # Explicit check rather than ``assert`` so it survives ``python -O``.
                if len(candidates) != num_candidates:
                    raise ValueError(
                        f'expected {num_candidates} candidate relations for '
                        f'{relation_mention!r}, got {len(candidates)}'
                    )
                self.data.extend(
                    (relation_mention, candidate, question) for candidate in candidates
                )

        # KGE vocabulary: relation -> integer id.
        if relation2id_path is None:
            relation2id_path = self.DEFAULT_RELATION2ID_PATH
        self.relation2id = {}
        with open(relation2id_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip('\r\n')
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline at EOF
                fields = line.split('\t')
                self.relation2id[fields[0]] = int(fields[1])

    def __len__(self):
        """Return the number of (mention, candidate, question) samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Build the model inputs for sample ``index``.

        Returns:
            A 5-tuple ``(relation_mention, question, candidate_names,
            candidate_kge_ids, relations)``:

            * ``relation_mention``: the mention string from the triple.
            * ``question``: ``'[R] <mention> [Q] <question>'``.
            * ``candidate_names``: the local name of each relation in the
              candidate (text after the last ``/`` and then the last ``#``),
              split with ``splitCamelCase`` and joined by single spaces.
            * ``candidate_kge_ids``: one KGE id per relation in the candidate,
              ``UNKNOWN_RELATION_ID`` for relations absent from ``relation2id``.
            * ``relations``: the raw candidate, unchanged.
        """
        relation_mention, relations, question = self.data[index]
        # NOTE(review): ``relations`` is iterated as a sequence of relation
        # strings, so it presumably is a relation path (list of URIs); confirm.

        candi_relation_local_names = ' '.join(
            splitCamelCase(rel.split('/')[-1].split('#')[-1]) for rel in relations
        )
        candi_relation_kge_ids = [
            self.relation2id.get(rel, self.UNKNOWN_RELATION_ID) for rel in relations
        ]
        question = '[R] ' + relation_mention + ' [Q] ' + question

        return (
            relation_mention,
            question,
            candi_relation_local_names,
            candi_relation_kge_ids,
            relations,
        )