import re
import json
import torch.utils.data as data


def splitCamelCase(string):
    """Split a camelCase / PascalCase identifier into lower-cased words.

    Three regex substitutions are applied in order; each inserts a single
    space between its two captured groups:

    1. any character followed by an upper-case letter plus one or more
       lower-case letters (``birthPlace`` -> ``birth Place``);
    2. any character followed by a run of digits (``area51`` -> ``area 51``).
       The leading ``(.)`` consumes one character, so a digit run at the very
       start of the string is split after its first digit
       (``123`` -> ``1 23``);
    3. a lower-case letter or digit directly followed by an upper-case letter.

    Args:
        string: The identifier to split.

    Returns:
        The space-separated, lower-cased form of ``string``.
    """
    boundary_patterns = (
        '(.)([A-Z][a-z]+)',
        '(.)([0-9]+)',
        '([a-z0-9])([A-Z])',
    )
    spaced = string
    # Order matters: each pass sees the spaces inserted by the previous one.
    for pattern in boundary_patterns:
        spaced = re.sub(pattern, r'\1 \2', spaced)
    return spaced.lower()


class IODataset(data.Dataset):
    """Dataset of relation-detection examples with KGE relation ids.

    Examples are loaded from a JSON file. Each record is read positionally as
    ``[relation_mention, question, candidate_relations, label]``. A
    tab-separated ``relation<TAB>id`` file supplies the mapping from relation
    identifiers to integer ids.
    """

    # Id substituted for a candidate relation that is absent from
    # ``relation2id``.
    # NOTE(review): presumably a reserved/placeholder row of the KGE relation
    # table -- confirm against the embedding checkpoint.
    _UNKNOWN_RELATION_ID = 1

    def __init__(self, config, path,
                 relation2id_path='/DATA_PATH/ImRL/new_relation2id.txt'):
        """Load the examples and the relation-to-id mapping.

        Args:
            config: Stored unchanged on ``self.config``; not otherwise used
                by this class.
            path: Path to the UTF-8 JSON file holding the examples.
            relation2id_path: Path to the UTF-8 ``relation<TAB>id`` mapping
                file. Defaults to the location the class originally
                hard-coded.

        Raises:
            OSError: If either file cannot be opened.
            ValueError: If an id column is not an integer (this also covers
                malformed JSON, since ``json.JSONDecodeError`` subclasses it).
            IndexError: If a non-blank mapping line has no tab-separated id
                column.
        """
        super().__init__()

        self.model_name = 'RotatE'
        self.is_use_property = True

        # Checkpoint locations are only recorded here; nothing in this class
        # reads the arrays.
        self.global_entity_path = '/DATA_PATH/ImRL/KGE/ckpts/' + self.model_name + '/entity.npy'
        self.global_relation_path = '/DATA_PATH/ImRL/KGE/ckpts/' + self.model_name + '/relation.npy'

        self.config = config

        with open(path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        # KGE: relation identifier -> integer id.
        self.relation2id = self._load_relation2id(relation2id_path)

    @staticmethod
    def _load_relation2id(mapping_path):
        """Parse a ``relation<TAB>id`` file into a dict, skipping blank lines."""
        relation2id = {}
        with open(mapping_path, 'r', encoding='utf-8') as f:
            for line in f:
                # A blank (e.g. trailing) line has no id column; skip it
                # rather than fail with IndexError.
                if not line.strip():
                    continue
                fields = line.rstrip('\n').split('\t')
                relation2id[fields[0]] = int(fields[1])
        return relation2id

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.data)

    def __getitem__(self, index):
        """Build one example.

        Args:
            index: Position of the record in the loaded data.

        Returns:
            A 5-tuple ``(relation_mention, question, candi_relation_local_names,
            candi_relation_kge_ids, label)`` where

            * ``relation_mention`` is the relation phrase, unchanged;
            * ``question`` is ``'[R] ' + relation_mention + ' [Q] ' + question``;
            * ``candi_relation_local_names`` is one string: the local name of
              every candidate (text after the last ``/`` and ``#``), split
              with ``splitCamelCase`` and joined by single spaces;
            * ``candi_relation_kge_ids`` is the list of ids looked up in
              ``relation2id``, with ``_UNKNOWN_RELATION_ID`` for unknown
              relations;
            * ``label`` is the record's label, unchanged.
        """
        record = self.data[index]
        relation_mention = record[0]
        question = record[1]
        relations = record[2]
        label = record[3]

        candi_relation_local_names = ' '.join(
            splitCamelCase(rel.split('/')[-1].split('#')[-1])
            for rel in relations
        )

        candi_relation_kge_ids = [
            self.relation2id.get(rel, self._UNKNOWN_RELATION_ID)
            for rel in relations
        ]

        question = '[R] ' + relation_mention + ' [Q] ' + question
        return relation_mention, question, candi_relation_local_names, candi_relation_kge_ids, label