import os
import torch
import csv
import logging
# Relative directory name "model". NOTE(review): not referenced in this chunk;
# presumably where model files are saved/loaded elsewhere — confirm with callers.
model_dir = "model"
def read_csv_data(path, label2id, p_id=0, h_id=1, l_id=2, encoding=None):
    """Read a tab-separated NLI file into parallel premise/hypothesis/label lists.

    Each non-empty line is split on tab characters. The premise, hypothesis and
    label text are taken from columns ``p_id``, ``h_id`` and ``l_id``. The label
    text is mapped to an id through ``label2id``. Rows whose label is not in
    ``label2id``, or that have too few columns, are skipped with a warning. This
    keeps the three returned lists the same length and aligned index-by-index.

    Args:
        path: Path to the tab-separated file.
        label2id: Mapping from label text to id. If it is empty/None or lacks an
            ``'entailment'`` key, the default mapping
            ``{'entailment': 0, 'neutral': 1, 'contradiction': 2}`` is used.
        p_id: Column index of the premise.
        h_id: Column index of the hypothesis.
        l_id: Column index of the label.
        encoding: Text encoding passed to ``open``. ``None`` uses the locale
            default, as before.

    Returns:
        A ``(premise, hypothesis, labels)`` tuple of equal-length lists.
    """
    if not label2id or 'entailment' not in label2id:
        print('label2id error, using default mapping config')
        label2id = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
    premise = []
    hypothesis = []
    labels = []
    # Split on tabs directly. The old comma-split + "".join dropped every comma
    # from the text. QUOTE_NONE keeps quote characters verbatim, so a stray quote
    # can never swallow following columns/lines. newline='' is what csv expects.
    with open(path, newline='', encoding=encoding) as f:
        for fields in csv.reader(f, delimiter='\t', quoting=csv.QUOTE_NONE):
            if not fields:
                continue  # blank line
            line = '\t'.join(fields)
            try:
                p_text, h_text, l_text = fields[p_id], fields[h_id], fields[l_id]
            except IndexError:
                logging.warning(f'Malformed row (too few columns) in {path}: {line}')
                continue
            # Validate the label BEFORE appending anything, so a skipped row
            # does not leave a dangling premise/hypothesis behind.
            if l_text not in label2id:
                logging.warning(f'No label find in {line}')
                continue
            premise.append(p_text)
            hypothesis.append(h_text)
            labels.append(label2id[l_text])
    return premise, hypothesis, labels
def createDir(filePath):
    """Create directory ``filePath``, including missing parents, if it is absent.

    Does nothing when ``filePath`` already exists, including when it exists as
    a regular file. Other failures, such as permission errors, propagate as
    ``OSError`` instead of being masked by a broad ``except``.
    """
    if os.path.exists(filePath):
        return
    # makedirs creates intermediate directories. exist_ok tolerates another
    # process creating the directory between the check above and this call.
    os.makedirs(filePath, exist_ok=True)
class BindDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing per-feature encodings with optional labels.

    Args:
        encodings: Mapping from feature name to a sequence indexed by example.
            Presumably a tokenizer output (``input_ids`` etc.); confirm with callers.
        labels: Optional sequence of per-example labels. ``None`` or an empty
            sequence means the dataset is unlabeled (e.g. for prediction).
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        # No shared mutable default. ``[]`` is kept as the "no labels" value so
        # code reading ``self.labels`` still sees a list.
        self.labels = [] if labels is None else labels

    def _has_labels(self):
        # len() works for lists, tuples, numpy arrays and tensors alike, unlike
        # ``!= []``, which broadcasts elementwise for array types.
        return len(self.labels) > 0

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, plus ``labels`` if labeled."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self._has_labels():
            item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the label count if labeled, otherwise the length of the first encoding."""
        if self._has_labels():
            return len(self.labels)
        # Unlabeled datasets used to report length 0, so a DataLoader yielded
        # nothing. Use the length of any encoding field instead.
        for values in self.encodings.values():
            return len(values)
        return 0