
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

def get_data_from_rawdata(tokenizer, triple_data, entity2id, relation2id, modes):
    """Turn (head, relation, tail) triples into masked-tail model inputs.

    For every valid triple the token sequence
    ``[CLS] head relation [MASK] [SEP]`` is built, where ``head`` and
    ``relation`` are each looked up as a single token via
    ``tokenizer.token_to_id``.

    Args:
        tokenizer: Object exposing ``token_to_id(token)``.
            NOTE(review): if ``token_to_id`` returns ``None`` for unknown
            tokens (HuggingFace ``tokenizers`` does this -- confirm for the
            tokenizer used here), that ``None`` is stored in the sequence
            unchanged.
        triple_data: Sized iterable of triples; each element must support
            ``len()`` and unpack into ``(head, relation, tail)``.
        entity2id: Mapping from entity name to integer id, indexed with
            both the head and the tail of every triple.
        relation2id: Mapping from relation name to integer id.
        modes: Only ``'tail-batch'`` produces data; any other value yields
            four empty lists.

    Returns:
        Tuple ``(input_ids, heads, relations, labels)`` of parallel lists:
        the token id sequences, head entity ids, relation ids and tail
        entity ids (the prediction targets).

    Raises:
        KeyError: If an entity or relation is missing from its mapping
            (when the mappings are plain dicts).
    """
    cls_id = tokenizer.token_to_id('[CLS]')
    sep_id = tokenizer.token_to_id('[SEP]')
    mask_id = tokenizer.token_to_id('[MASK]')

    input_ids = []
    labels = []
    heads = []
    relations = []

    if modes == 'tail-batch':
        for line in tqdm(triple_data, total=len(triple_data)):
            # Check the arity BEFORE unpacking: unpacking a line that does
            # not have exactly three fields raises ValueError, which would
            # abort the whole run instead of skipping the bad line.
            if len(line) != 3:
                print(f"Ignoring invalid line: {line}")
                continue
            head, relation, tail = line

            # The tail entity id is the label the model has to predict.
            labels.append(entity2id[tail])
            heads.append(entity2id[head])
            relations.append(relation2id[relation])

            head_token = tokenizer.token_to_id(head)
            relation_token = tokenizer.token_to_id(relation)
            # The tail position is replaced by [MASK].
            input_ids.append(
                [cls_id, head_token, relation_token, mask_id, sep_id]
            )

    return input_ids, heads, relations, labels

class OnlyTailDataset(Dataset):
    """Dataset over four parallel sequences describing tail-prediction samples.

    Each index yields the token id sequence, head id, relation id and label
    stored at that position, each converted to a ``torch.Tensor``.
    """

    def __init__(self, input_ids, heads, relations, labels):
        """Store the four parallel sequences without copying them."""
        self.input_ids, self.labels = input_ids, labels
        self.heads, self.relations = heads, relations

    def __len__(self):
        """Return the number of samples, taken from ``input_ids``."""
        return len(self.input_ids)

    def __getitem__(self, item):
        """Return ``(input_ids, heads, relations, labels)`` tensors for ``item``."""
        # Convert in the order ids -> label -> head -> relation, then hand
        # the tensors back in the (ids, head, relation, label) order.
        sources = (self.input_ids, self.labels, self.heads, self.relations)
        token_tensor, label_tensor, head_tensor, relation_tensor = (
            torch.tensor(source[item]) for source in sources
        )
        return token_tensor, head_tensor, relation_tensor, label_tensor




