import torch
import os
from torch.utils.data import Dataset
#import cv2
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
#from albumentations import (Compose, Normalize, HorizontalFlip,
 #                           ShiftScaleRotate, Transpose
  #                          )
#from albumentations.pytorch import ToTensor
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
class JGWDataset(Dataset):
    """Image/label dataset backed by a DataFrame-like table.

    Each row supplies a file name (column ``'name'``) that is resolved
    relative to ``IMAGE_ROOT`` and a label (column ``'label_number'``).
    Rows are addressed with ``df.loc[idx, ...]``, so the table's index must
    contain every ``idx`` requested as well as ``0`` (the fallback row) --
    e.g. a default ``RangeIndex``.

    Args:
        df: Table with ``'name'`` and ``'label_number'`` columns.
        SIZE: Side length in pixels; every image is resized to
            ``(SIZE, SIZE)``.
        transform: Stored on the instance as ``self.transform``.
            NOTE(review): ``__getitem__`` does not apply it; only
            ``transforms.ToTensor()`` is used. Confirm whether that is
            intended before wiring it in, since it changes the tensors
            that callers receive.
    """

    # Directory prefix prepended verbatim to every file name.
    IMAGE_ROOT = '../data/oracle_images2/'  # jiagu_trans_50mix

    def __init__(self, df, SIZE=512, transform=None):
        self.data = df
        self.size = SIZE
        self.transform = transform
        self.loader = transforms.Compose([
            transforms.ToTensor()])

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.data)

    def _load_sample(self, idx, img_name):
        """Load, resize and tensorise ``img_name`` and pair it with row ``idx``'s label."""
        img_path = self.IMAGE_ROOT + img_name
        # The context manager releases the file handle as soon as the
        # resized copy has been produced.
        with Image.open(img_path) as raw:
            resized = raw.resize((self.size, self.size))
        img = self.loader(resized)
        labels = torch.tensor(self.data.loc[idx, 'label_number'])
        return {'image': img, 'labels': labels}

    def __getitem__(self, idx):
        """Return ``{'image': tensor, 'labels': tensor}`` for row ``idx``.

        If the sample cannot be loaded, the offending file name is printed
        and the sample at index 0 is returned instead, so one unreadable
        image does not abort a whole epoch.

        Raises:
            Exception: whatever loading row 0 raises when the fallback
                sample is unreadable too.
        """
        img_name = self.data.loc[idx, 'name']
        try:
            return self._load_sample(idx, img_name)
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit still propagate.
            print('img_name', img_name)

        # Fallback: reuse the normal loading path for row 0. The previous
        # fallback wrapped the PIL image in ``torch.tensor`` before calling
        # ``ToTensor``, which only accepts PIL images or ndarrays, so it
        # could never succeed.
        fallback_idx = 0
        return self._load_sample(fallback_idx, self.data.loc[fallback_idx, 'name'])

def datasets(trn_df,test_df):
    """Build the train and test DataLoaders (no validation split).

    Both tables are wrapped in ``JGWDataset`` with ``SIZE=100`` and handed
    the same ``ToTensor`` + ``Normalize(0.5, 0.5)`` pipeline. Both loaders
    use batches of 64, no shuffling, no worker processes and pinned memory.

    Args:
        trn_df: Table passed to the training ``JGWDataset``.
        test_df: Table passed to the test ``JGWDataset``.

    Returns:
        Tuple ``(trnloader, tstloader)``.
    """
    pipeline_steps = [
        transforms.ToTensor(),
        # transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
    transform = transforms.Compose(pipeline_steps)

    # no val
    train_set = JGWDataset(trn_df, SIZE=100, transform=transform)
    test_set = JGWDataset(test_df, SIZE=100, transform=transform)

    # Settings shared by both loaders.
    # Original author's note on the training loader: test batch=15174.
    shared = {
        'batch_size': 64,
        'shuffle': False,
        'num_workers': 0,
        'pin_memory': True,
    }
    trnloader = DataLoader(train_set, **shared)
    tstloader = DataLoader(test_set, **shared)
    return trnloader, tstloader
def datasets_test(test_df):
    """Build a DataLoader over ``test_df`` for evaluation.

    The table is wrapped in ``JGWDataset`` with ``SIZE=100`` and no
    transform argument. The loader uses batches of 512, no shuffling, no
    worker processes and pinned memory.

    Args:
        test_df: Table passed to ``JGWDataset``.

    Returns:
        The configured ``DataLoader``.
    """
    eval_set = JGWDataset(test_df, SIZE=100)
    return DataLoader(
        eval_set,
        batch_size=512,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
