import numpy as np
import glob
import cv2
import os
import editdistance
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import time
from threading import Thread


# Textual forms of the special vocabulary tokens. GridSeq2Seq places them, in
# this order, at the head of its character list.
PAD_TAG = '<pad>'
BOS_TAG = '<bos>'
EOS_TAG = '<eos>'
# UNK_TAG = '<unk>'

# Integer ids of the special tokens; they equal the positions of the tags
# above in GridSeq2Seq.char_list.
PAD = 0
BOS = 1
EOS = 2
# UNK = 3


def get_imgs_from_video(video, ext='jpg', RGB=False):
    """Load all frames of a video into a single NumPy array.

    Args:
        video: Either a directory of frame images whose file names start with
            an integer (``0.jpg``, ``1.jpg``, ...) or a path that
            ``cv2.VideoCapture`` can open.
        ext: Extension (without the dot) of the frame images; only used when
            ``video`` is a directory.
        RGB: When true, the last axis of the result is reversed (OpenCV
            decodes to BGR channel order).

    Returns:
        ``np.array`` of the frames, in numeric file order for a directory or
        in decoding order for a video file. The array is empty when no frame
        was found.

    Raises:
        ValueError: If a frame file name does not start with an integer.
    """
    frames = []
    if os.path.isdir(video):
        paths = glob.glob(os.path.join(video, '*.{}'.format(ext)))
        # basename() instead of splitting on '/', so the numeric sort also
        # works with platform-specific path separators.
        paths.sort(key=lambda p: int(os.path.basename(p).split('.')[0]))
        frames = [cv2.imread(p) for p in paths]
    else:
        cap = cv2.VideoCapture(video)
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            # Always give the decoder/file handle back, even if reading fails.
            cap.release()

    frames = np.array(frames)
    if RGB:
        return frames[..., ::-1]
    return frames


# for CTC from https://github.com/arxrean/LipRead-seq2seq
class GridDataset(Dataset):
    """GRID corpus dataset yielding padded frame stacks and label arrays.

    Every sample is a directory of ``<int>.jpg`` frames plus a
    ``<name>.align`` transcription file. Transcriptions are encoded against
    ``letters`` with an offset of 1, so the zero used for padding never
    collides with a letter id.
    """

    letters = [' ', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
               'U', 'V', 'W', 'X', 'Y', 'Z']

    def __init__(self, video_path, align_path, file_list, vid_pad, txt_pad, phase='train'):
        """Index the samples listed in ``file_list``.

        Args:
            video_path: Root directory of the frame directories.
            align_path: Directory holding the ``.align`` transcriptions.
            file_list: Text file with one sample path (relative to
                ``video_path``) per line.
            vid_pad: Number of frames every video is padded to.
            txt_pad: Number of labels every transcription is padded to.
            phase: ``'train'`` selects the augmenting transform, anything
                else the deterministic one.
        """
        self.align_path = align_path  # transcription dir path
        self.video_path = video_path  # video frame dir path
        self.vid_pad = vid_pad
        self.txt_pad = txt_pad
        self.phase = phase

        with open(file_list, 'r') as f:
            # Blank lines are skipped so they do not turn into bogus samples.
            self.videos = [os.path.join(video_path, line.strip()) for line in f if line.strip()]

        # (frame directory, sample name); the name selects the .align file.
        self.data = [(vid, vid.split(os.path.sep)[-1]) for vid in self.videos]

        self.transform = self.get_transform(phase)

    def __getitem__(self, idx):
        """Return one sample as a dict with keys vid, txt, txt_len, vid_len."""
        (vid, name) = self.data[idx]
        vid = self._load_vid(vid)
        align = self._load_align(os.path.join(self.align_path, name + '.align'))

        # NOTE(review): self.transform starts with ToPILImage, which handles a
        # single 2-D/3-D image, whereas `vid` is the 4-D (T, H, W, C) stack
        # from _load_vid, and the transpose below expects a NumPy array in
        # that layout. Confirm the intended per-frame preprocessing.
        vid = self.transform(vid)
        # Unpadded lengths are recorded before padding.
        vid_len = vid.shape[0]
        align_len = align.shape[0]
        vid = self._padding(vid, self.vid_pad)
        align = self._padding(align, self.txt_pad)

        return {'vid': torch.FloatTensor(vid.transpose(3, 0, 1, 2)),  # C, T, H, W
                'txt': torch.LongTensor(align),
                'txt_len': align_len,
                'vid_len': vid_len}

    def __len__(self):
        return len(self.data)

    def get_transform(self, phase='train'):
        """Build the torchvision preprocessing pipeline for ``phase``.

        Training uses a random 88x88 crop and a random horizontal flip;
        every other phase uses a center crop. Both resize to 96x96 first and
        keep the three colour channels.
        """
        if phase == 'train':
            return transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((96, 96)),
                transforms.RandomCrop((88, 88)),  # for training
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                # mean 0 / std 1 leaves the values unchanged; replace with
                # dataset statistics to actually standardise each channel.
                transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
            ])
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((96, 96)),
            transforms.CenterCrop((88, 88)),  # for testing
            transforms.ToTensor(),
            # Identity normalisation, see the training branch.
            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
        ])

    def _load_vid(self, p):
        """Load the lip-frame sequence stored in directory ``p``.

        Frames are read in numeric file-name order, unreadable images are
        dropped, and every frame is resized to 128x64 (width x height).

        Returns:
            float32 array of shape (T, 64, 128, C).

        Raises:
            ValueError: If the directory holds no readable ``.jpg`` frame.
        """
        # endswith() rather than a substring test, so names such as
        # '3.jpg.bak' are not picked up (and then fail the int() sort key).
        files = [name for name in os.listdir(p) if name.endswith('.jpg')]
        files.sort(key=lambda name: int(os.path.splitext(name)[0]))
        frames = [cv2.imread(os.path.join(p, name)) for name in files]
        frames = [im for im in frames if im is not None]
        if not frames:
            raise ValueError('no readable .jpg frames in {}'.format(p))
        frames = [cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4) for im in frames]
        return np.stack(frames, axis=0).astype(np.float32)

    def _load_align(self, name):
        ''' Read a GRID .align file and encode its words as label ids.

            Sample GRID align format:
            command(4) + color(4) + preposition(4) + letter(25) + digit(10) + adverb(4)
            0 14500 sil   # silence
            14500 22500 bin
            22500 30750 blue
            30750 39250 at
            39250 40000 sp  # short pause
            40000 48500 f
            48500 57500 one
            57500 71000 soon
            71000 74500 sil  # silence

            The third column is kept, 'sil'/'sp' entries are removed, and the
            remaining words are joined with single spaces and upper-cased.
        '''
        with open(name, 'r') as f:
            rows = [line.split() for line in f]
        # Rows with fewer than three columns (e.g. a trailing blank line)
        # carry no word and are ignored instead of raising IndexError.
        words = [row[2] for row in rows if len(row) >= 3]
        words = [w for w in words if w.upper() not in ('SIL', 'SP')]
        # e.g. BIN BLUE AT F ONE SOON
        return GridDataset.txt2arr(' '.join(words).upper(), 1)

    def _padding(self, array, length):
        """Zero-pad ``array`` along its first axis up to ``length`` entries.

        The padding keeps the dtype of ``array``, so integer labels stay
        integers.

        Raises:
            ValueError: If ``array`` is already longer than ``length``.
        """
        array = np.asarray(array)
        missing = length - len(array)
        if missing < 0:
            raise ValueError(
                'cannot pad an array of length {} to the shorter length {}'.format(len(array), length))
        filler = np.zeros((missing,) + tuple(array.shape[1:]), dtype=array.dtype)
        return np.concatenate([array, filler])

    @staticmethod
    def txt2arr(txt, start):
        """Encode ``txt`` as an array of ``letters`` indices offset by ``start``.

        Raises:
            ValueError: If a character is not in ``letters``.
        """
        return np.array([GridDataset.letters.index(c) + start for c in txt])

    @staticmethod
    def arr2txt(arr, start):
        """Decode label ids to text, skipping every id below ``start``."""
        chars = [GridDataset.letters[n - start] for n in arr if n >= start]
        return ''.join(chars).strip()

    @staticmethod
    def ctc_arr2txt(arr, start):
        """Decode a per-frame id sequence to text.

        Consecutive repeats are collapsed, ids below ``start`` are skipped,
        and runs of spaces are reduced to a single space.
        """
        pre = -1
        txt = []
        for n in arr:
            if pre != n and n >= start:
                ch = GridDataset.letters[n - start]
                # Never emit two spaces in a row.
                if not (txt and txt[-1] == ' ' and ch == ' '):
                    txt.append(ch)
            pre = n
        return ''.join(txt).strip()

    @staticmethod
    def wer(predict, truth):
        """Per-sample word error rate: word edit distance / reference words."""
        rates = []
        for hyp, ref in zip(predict, truth):
            hyp_words, ref_words = hyp.split(' '), ref.split(' ')
            rates.append(editdistance.eval(hyp_words, ref_words) / len(ref_words))
        return rates

    @staticmethod
    def cer(predict, truth):
        """Per-sample character error rate: edit distance / reference length.

        An empty reference is treated as having length 1, so the rate is the
        raw edit distance instead of a ZeroDivisionError.
        """
        return [editdistance.eval(hyp, ref) / max(len(ref), 1) for hyp, ref in zip(predict, truth)]


def cutout(img, n_holes=1, length=20):
    '''Zero out random square regions of an image (cutout augmentation).

        img (ndarray): Image of shape (H, W) or (H, W, C, ...).
        n_holes (int): Number of patches to cut out of each image.
        length (int): Nominal side length (in pixels) of each patch; a patch
            spans ``length // 2`` pixels on either side of its random centre
            and is clipped at the image border.
    Returns:
        ndarray: Copy of ``img`` with the patches set to zero.
    '''
    h, w = img.shape[0], img.shape[1]
    mask = np.ones((h, w), np.uint8)
    half = length // 2
    for _ in range(n_holes):
        cy = np.random.randint(h)
        cx = np.random.randint(w)
        y1 = np.clip(cy - half, 0, h)
        y2 = np.clip(cy + half, 0, h)
        x1 = np.clip(cx - half, 0, w)
        x2 = np.clip(cx + half, 0, w)
        mask[y1: y2, x1: x2] = 0
    if img.ndim > 2:
        # Add singleton axes so the (H, W) mask broadcasts over the channels.
        mask = mask.reshape((h, w) + (1,) * (img.ndim - 2))
    return img * mask


def cosine_angle(a, b):
    """Return the angle in radians between vectors ``a`` and ``b``.

    Vectors live on the last axis. Two 1-D inputs give a scalar; otherwise
    the inputs are broadcast against each other (a single vector may be
    paired with a batch) and one angle per vector pair is returned.

    A small epsilon in the denominator avoids division by zero, so a
    zero-length vector yields pi/2, and the cosine is clipped to [-1, 1]
    before ``arccos``.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim == 1 and b.ndim == 1:
        cos_val = a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9)
    else:
        # Reduce over the vector axis only; everything else broadcasts.
        cos_val = (a * b).sum(axis=-1) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-9)
    cos_val = np.clip(cos_val, -1., 1.)
    return np.arccos(cos_val)  # / np.pi * 180


def random_crop(img, dw, dh):
    """Return a random ``dh`` x ``dw`` (height x width) window of ``img``.

    Args:
        img: Array of shape (H, W) or (H, W, C).
        dw: Crop width in pixels.
        dh: Crop height in pixels.

    Raises:
        ValueError: If the crop is larger than the image.
    """
    h, w = img.shape[0], img.shape[1]
    if dh > h or dw > w:
        raise ValueError('crop {}x{} does not fit into image {}x{}'.format(dh, dw, h, w))
    # randint's upper bound is exclusive: +1 makes the last valid offset
    # reachable and allows a crop that is exactly as large as the image.
    top = np.random.randint(0, h - dh + 1)
    left = np.random.randint(0, w - dw + 1)
    return img[top: top + dh, left: left + dw]


def HorizontalFlip(img, p=0.5):
    """Mirror ``img`` along its second axis with probability ``p``.

    Accepts (H, W) or (H, W, C) arrays. One uniform random number is drawn;
    the image is returned untouched when that number exceeds ``p``.
    """
    keep_original = np.random.rand() > p
    if keep_original:
        return img
    return img[:, ::-1, ...]


def adjust_img(img):
    """Apply a random contrast and brightness change to an image.

    The result is ``alpha * img + beta`` clamped to [0, 255] and cast to
    uint8, where

    * alpha (contrast gain) is uniform in [0.5, 1.5): above 1 raises the
      contrast, below 1 lowers it;
    * beta (brightness offset) is an integer in [-30, 30): above 0
      brightens, below 0 darkens.
    """
    alpha = np.random.rand() + 0.5
    beta = np.random.randint(-30, 30)
    # Clamp directly: taking abs() first would fold pixels that went below
    # zero back to bright values instead of saturating them at black.
    img_adjusted = np.clip(alpha * img + beta, 0, 255)
    return img_adjusted.astype(np.uint8)


class GridSeq2Seq(Dataset):
    """GRID dataset serving lip-landmark patches and geometry features.

    A sample is a directory ``<video_root>/s<k>/<clip>`` holding, per frame,
    a landmark file ``<int>.xy`` and the matching image ``<int>.jpg``.
    ``__getitem__`` returns image patches around the last 20 landmarks at
    three scales, pairwise offsets between those landmarks, and
    frame-to-frame motion features.
    """

    # Width of one motion-feature row returned by load_landmark_face.
    # NOTE(review): the features currently computed fill only the first 40
    # columns (2 distances + 4 lip angles + 2 x 17 jaw angles); the remaining
    # columns stay zero. Confirm the width the consuming model expects.
    MOTION_DIM = 50

    def __init__(self, opt, phase='train'):
        """Collect the clip directories that belong to ``phase``.

        Args:
            opt: Options object. This class reads ``video_root`` and
                ``max_vid_len``, ``max_dec_len`` in ``align_pad``, and for
                ``phase='val'`` also ``val_batch`` and ``batch_size``.
            phase: One of 'train', 'val', 'test'. Speakers s1, s2, s20 and
                s22 are held out of training and used for 'val' and 'test'.
        """
        self.phase = phase
        self.opt = opt
        assert phase in ['train', 'val', 'test']

        # glob returns entries in arbitrary order; sorting makes the sample
        # order, and therefore the seeded validation subset, reproducible.
        self.data = sorted(glob.glob(os.path.join(opt.video_root, 's*', '*')))
        print('total:', len(self.data))
        test_spks = ['s1', 's2', 's20', 's22']

        def held_out(sample):
            # The speaker id is the parent directory of the clip.
            return sample.split(os.path.sep)[-2] in test_spks

        if phase == 'train':
            self.cur_data = [x for x in self.data if not held_out(x) and len(os.listdir(x)) >= 40]
        elif phase == 'val':
            self.cur_data = [x for x in self.data if held_out(x) and len(os.listdir(x)) >= 40]
            # Fixed seed: every run validates on the same subset. Note that
            # this also reseeds NumPy's global random state.
            np.random.seed(123)
            np.random.shuffle(self.cur_data)
            self.cur_data = self.cur_data[:opt.val_batch * opt.batch_size]
        else:
            self.cur_data = [x for x in self.data if held_out(x) and len(os.listdir(x)) > 0]

        print(len(self.cur_data))

        self.char_list = [PAD_TAG, BOS_TAG, EOS_TAG] + [' ', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J',
                                                                 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
        self.char_dict = {ch: idx for idx, ch in enumerate(self.char_list)}
        self.idx_dict = {idx: ch for idx, ch in enumerate(self.char_list)}
        self.transform = self.get_transform(phase)
        self.ps = 16  # side length of the landmark patches, in pixels
        print(f'patch size: {self.ps}')

    def __getitem__(self, index):
        """Return the landmark-based features of clip ``index`` as a dict."""
        item = self.cur_data[index]   # e.g. E:\\GRID\\LIP_160x80\\lip\\s26\\prim7p
        vids, pts, motion, vid_len = self.load_landmark_face(item, self.ps)  # landmark path

        return {'org_patch': torch.FloatTensor(vids[0]),
                'small_patch': torch.FloatTensor(vids[1]),
                'large_patch': torch.FloatTensor(vids[2]),
                'org_pts': torch.FloatTensor(pts[0]),
                'small_pts': torch.FloatTensor(pts[1]),
                'large_pts': torch.FloatTensor(pts[2]),
                'motion': torch.FloatTensor(motion),
                'vid_len': vid_len}

    def __len__(self):
        return len(self.cur_data)

    @staticmethod
    def _motion_features(lms):
        """Geometry features of one frame's landmarks (indices 0..67 are read).

        Concatenates two mouth-opening distances (51-57 and 62-66), four
        angles at the mouth corners, and the angles under which the segments
        48-54 and 60-64 are seen from each of the first 17 landmarks.
        """
        hwr = np.array([np.linalg.norm(lms[51, :2] - lms[57, :2], axis=-1),
                        np.linalg.norm(lms[62, :2] - lms[66, :2], axis=-1)])
        lip_angle = np.array([
            cosine_angle(lms[51, :2] - lms[48, :2], lms[57, :2] - lms[48, :2]),     # left outer
            cosine_angle(lms[62, :2] - lms[60, :2], lms[66, :2] - lms[60, :2]),     # left inner
            cosine_angle(lms[51, :2] - lms[54, :2], lms[57, :2] - lms[54, :2]),     # right outer
            cosine_angle(lms[62, :2] - lms[64, :2], lms[66, :2] - lms[64, :2]),     # right inner
        ])
        chin_angle1 = cosine_angle(lms[48, :2] - lms[:17, :2], lms[54, :2] - lms[:17, :2])   # 17
        chin_angle2 = cosine_angle(lms[60, :2] - lms[:17, :2], lms[64, :2] - lms[:17, :2])   # 17
        return np.concatenate((hwr, lip_angle, chin_angle1, chin_angle2))

    def load_landmark_face(self, path, patch_size=16):
        """Build patch, offset and motion features for the clip in ``path``.

        At most ``opt.max_vid_len`` frames are used; shorter clips are
        zero-padded to that length. For each frame and each scale factor in
        (1, 0.5, 2) the landmarks are divided by the factor, the grayscale
        face is resized by its inverse, and a ``patch_size`` square is cut
        with its top-left corner half a patch above/left of each of the last
        20 landmarks (clipped at the image border, values scaled to [0, 1]).

        Returns:
            ``(org_patch, small_patch, large_patch)``: float32 arrays of
            shape (T, 20, 1, patch_size, patch_size); 'small' is factor 2,
            'large' is factor 0.5.
            ``(org_pts, small_pts, large_pts)``: float32 arrays of shape
            (T, 20, 38) with the x/y offsets from the 19 other lip
            landmarks to each lip landmark.
            ``lm_motion``: float32 array (T, MOTION_DIM) of frame-to-frame
            feature differences, standardised per column over the valid
            frames (padded rows are shifted by the same statistics).
            ``vid_len``: number of valid frames.

        Raises:
            IOError: If the image of a landmark file cannot be read.
        """
        files = [f for f in os.listdir(path) if f.endswith('.xy')]  # 2D point
        files.sort(key=lambda f: int(os.path.splitext(f)[0]))
        vid_pad = self.opt.max_vid_len
        # Frames beyond vid_pad are never used, so they are not decoded.
        files = files[:vid_pad]
        vid_len = len(files)

        N = 20
        patch_shape = (vid_pad, N, 1, patch_size, patch_size)
        org_patch = np.zeros(patch_shape, dtype=np.float32)
        small_patch = np.zeros(patch_shape, dtype=np.float32)
        large_patch = np.zeros(patch_shape, dtype=np.float32)

        pts_shape = (vid_pad, N, (N - 1) * 2)
        org_pts = np.zeros(pts_shape, dtype=np.float32)
        small_pts = np.zeros(pts_shape, dtype=np.float32)
        large_pts = np.zeros(pts_shape, dtype=np.float32)

        lm_motion = np.zeros((vid_pad, self.MOTION_DIM), dtype=np.float32)
        # Selects every (i, j) landmark pair with i != j.
        dlg_mask = (np.ones((N, N)) - np.eye(N)).astype(bool)

        # (scale factor, patch buffer, offset buffer)
        scales = [(1, org_patch, org_pts),
                  (0.5, large_patch, large_pts),
                  (2, small_patch, small_pts)]

        pre = None
        for i, name in enumerate(files):
            lms = np.loadtxt(os.path.join(path, name))
            img_path = os.path.join(path, os.path.splitext(name)[0] + '.jpg')
            face = cv2.imread(img_path, 0)  # grayscale
            if face is None:
                raise IOError('cannot read frame image {}'.format(img_path))

            for scale, patch_buf, pts_buf in scales:
                # The image must follow the landmark scaling in every phase;
                # the unscaled pass uses the frame as loaded.
                if scale == 1:
                    scale_face = face
                else:
                    scale_face = cv2.resize(face, dsize=None, fx=1. / scale, fy=1. / scale)
                lip_lms = (lms / scale)[-N:]
                for j, pt in enumerate(lip_lms):
                    x, y = pt[0], pt[1]
                    lx, ly = max(0, int(x - patch_size / 2)), max(0, int(y - patch_size / 2))   # w, h
                    rx, ry = min(scale_face.shape[1], lx + patch_size), min(scale_face.shape[0], ly + patch_size)
                    patch = scale_face[ly: ry, lx: rx] / 255.
                    # Border patches may be smaller; the rest stays zero.
                    patch_buf[i, j, :, :patch.shape[0], :patch.shape[1]] = patch[None]

                offsets = lip_lms[:, :2][:, None] - lip_lms[:, :2]   # 20 x 20 x 2
                pts_buf[i] = offsets[dlg_mask, :].reshape(N, -1)    # 20 x (19 x 2)

            motion = self._motion_features(lms)
            if i != 0:
                # Row 0 stays zero: the first frame has no predecessor.
                lm_motion[i, :motion.size] = motion - pre
            pre = motion

        if vid_len > 0:
            # An empty clip has no statistics; its motion stays all-zero
            # instead of becoming NaN.
            mot_mean = np.mean(lm_motion[:vid_len], axis=0)
            mot_std = np.std(lm_motion[:vid_len], axis=0)
            lm_motion = (lm_motion - mot_mean) / (mot_std + 1e-9)
        return (org_patch, small_patch, large_patch), (org_pts, small_pts, large_pts), lm_motion, vid_len

    @staticmethod
    def _fit_length(frames, length):
        """Zero-pad (as an array) or truncate ``frames`` to ``length`` items."""
        if len(frames) < length:
            filler = np.zeros([length - len(frames)] + list(frames[0].shape)).astype(np.uint8)
            return np.concatenate([frames, filler])
        return frames[:length]

    def _pad_and_transform(self, frames):
        """Fit ``frames`` to ``opt.max_vid_len``, transform each, and stack."""
        frames = self._fit_length(frames, self.opt.max_vid_len)
        # NOTE(review): the training transform flips every frame
        # independently, so one clip can mix mirrored and unmirrored frames.
        frames = [self.transform(img) for img in frames]
        return np.stack(frames, axis=0).astype(np.float32)

    def load_landmark(self, name):
        """Load per-frame landmark files (``<int>.txt``) from ``name``.

        Each frame is centred on its per-column mean and divided by its
        overall standard deviation (skipped when that is zero), then the
        sequence is zero-padded or truncated to ``opt.max_vid_len`` frames.
        """
        def normalize(xyz):
            xyz -= np.mean(xyz, axis=0)
            std = np.std(xyz)
            if std > 0:
                xyz /= std
            return xyz

        files = [f for f in os.listdir(name) if f.endswith('.txt')]
        files.sort(key=lambda f: int(os.path.splitext(f)[0]))
        array = [normalize(np.loadtxt(os.path.join(name, f))) for f in files]  # [(68, 3), ....]
        return np.asarray(self._fit_length(array, self.opt.max_vid_len))   # (T, 68, 3)

    def load_video(self, name):
        """Load, pad/truncate and transform the ``.jpg`` frames of a clip.

        Returns:
            float32 array with one transformed frame per row.
        """
        files = [f for f in os.listdir(name) if f.endswith('.jpg')]
        files.sort(key=lambda f: int(os.path.splitext(f)[0]))
        frames = [cv2.imread(os.path.join(name, f)) for f in files]    # three channels
        return self._pad_and_transform(frames)

    def fast_load_video(self, name):
        """Same result as ``load_video``, decoding the frames in parallel."""
        from concurrent.futures import ThreadPoolExecutor

        files = [f for f in os.listdir(name) if f.endswith('.jpg')]
        files.sort(key=lambda f: int(os.path.splitext(f)[0]))
        paths = [os.path.join(name, f) for f in files]
        # A bounded pool instead of one thread per frame; map() keeps the
        # input order, so no re-sorting is needed afterwards.
        with ThreadPoolExecutor() as pool:
            frames = list(pool.map(cv2.imread, paths))
        return self._pad_and_transform(frames)

    def load_align(self, name):
        """Return the upper-cased transcription of a GRID ``.align`` file.

        The third column of every line is kept, 'sil' and 'sp' entries are
        dropped, and the remaining words are joined with single spaces.
        Lines with fewer than three columns (e.g. blank lines) are ignored.
        """
        with open(name, 'r') as f:
            rows = [line.split() for line in f]
        words = [row[2] for row in rows if len(row) >= 3]
        words = [w for w in words if w.upper() not in ('SIL', 'SP')]
        return ' '.join(words).upper()

    def align2idx(self, text):
        """Encode ``text`` as ``[BOS] + character ids + [EOS]``.

        Raises:
            KeyError: If a character is not in ``char_list``.
        """
        return [BOS] + [self.char_dict[x] for x in text] + [EOS]

    def align_pad(self, align):
        """Right-pad ``align`` with PAD up to ``opt.max_dec_len + 2`` ids.

        The ``+ 2`` accounts for the <bos> and <eos> tokens. A longer input
        is returned without padding or truncation.
        """
        target = self.opt.max_dec_len + 2  # including <bos> and <eos> token
        if len(align) == target:
            return align
        return align + [PAD] * (target - len(align))

    def get_transform(self, phase='train'):
        '''Build the per-frame torchvision pipeline for ``phase``.

        Every frame is converted to a grayscale PIL image, resized to
        72 x 90 (H x W) and turned into a (C, H, W) tensor in [0, 1].
        Training additionally applies a random horizontal flip.
        '''
        steps = [
            transforms.ToPILImage(),  # for 2 or 3 dimensional
            transforms.Grayscale(),
            transforms.Resize((72, 90)),  # full face
        ]
        if phase == 'train':
            steps.append(transforms.RandomHorizontalFlip(p=0.5))
        steps.append(transforms.ToTensor())   # [0, 1] (C, H, W)
        return transforms.Compose(steps)

