import numpy as np
import glob
import cv2
import os
import editdistance
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import time
from threading import Thread


PAD = '<pad>'
BOS = '<bos>'
EOS = '<eos>'
# UNK_TAG = '<unk>'


def get_imgs_from_video(video, ext='jpg', RGB=False):
    """Load all frames of a video as a single numpy array.

    Args:
        video: Either a directory containing frames named ``<int>.<ext>`` or a
            path to a video file readable by ``cv2.VideoCapture``.
        ext: Frame file extension, used only when ``video`` is a directory.
        RGB: If True, reverse the last axis of the result (OpenCV decodes
            images in BGR channel order).

    Returns:
        ``np.array`` built from the list of decoded frames, in numeric file
        order (directory) or in decoding order (video file).
    """
    frames = []
    if os.path.isdir(video):
        paths = glob.glob(os.path.join(video, '*.{}'.format(ext)))
        # Sort numerically on the file stem; basename keeps this portable
        # across path separators.
        paths.sort(key=lambda p: int(os.path.basename(p).split('.')[0]))
        frames = [cv2.imread(p) for p in paths]
    else:
        cap = cv2.VideoCapture(video)
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            # Always free the decoder handle, even if reading fails midway.
            cap.release()

    frames = np.array(frames)
    return frames[..., ::-1] if RGB else frames


# for CTC from https://github.com/arxrean/LipRead-seq2seq
class GridDataset(Dataset):
    """GRID dataset producing padded frame stacks and character label arrays.

    Every sample is a directory of numbered ``.jpg`` frames together with a
    ``<name>.align`` transcription file located under ``align_path``.
    """

    # Character vocabulary. Index 0 is the word separator; txt2arr/arr2txt
    # shift indices by ``start`` so lower label ids stay free for the caller.
    letters = [' ', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
               'U', 'V', 'W', 'X', 'Y', 'Z']

    def __init__(self, video_path, align_path, file_list, vid_pad, txt_pad, phase='train'):
        """
        Args:
            video_path: Root directory of the frame folders.
            align_path: Directory holding the ``.align`` transcription files.
            file_list: Text file with one sample path (relative to
                ``video_path``) per line.
            vid_pad: Number of frames every video is padded to.
            txt_pad: Number of labels every transcription is padded to.
            phase: ``'train'`` selects the augmenting transform; any other
                value selects the deterministic one.
        """
        self.align_path = align_path  # transcription dir path
        self.video_path = video_path  # video frame dir path
        self.vid_pad = vid_pad
        self.txt_pad = txt_pad
        self.phase = phase

        with open(file_list, 'r') as f:
            # Blank lines would otherwise become a bogus sample pointing at
            # video_path itself.
            self.videos = [os.path.join(video_path, line.strip()) for line in f if line.strip()]

        # (frame directory, sample name) pairs; the name is the last path
        # component and is used to locate the .align file.
        self.data = [(vid, vid.split(os.path.sep)[-1]) for vid in self.videos]

        self.transform = self.get_transform(phase)

    def __getitem__(self, idx):
        """Return a dict with keys ``vid``, ``txt``, ``txt_len`` and ``vid_len``.

        ``vid_len`` / ``txt_len`` are the lengths before padding.
        """
        (vid, name) = self.data[idx]
        vid = self._load_vid(vid)
        align = self._load_align(os.path.join(self.align_path, name + '.align'))

        # NOTE(review): self.transform starts with ToPILImage, which is meant
        # for single images, while ``vid`` here is a (T, H, W, C) float32
        # stack -- confirm this pipeline actually runs as intended.
        vid = self.transform(vid)
        vid_len = vid.shape[0]
        align_len = align.shape[0]
        vid = self._padding(vid, self.vid_pad)
        align = self._padding(align, self.txt_pad)

        return {'vid': torch.FloatTensor(vid.transpose(3, 0, 1, 2)),  # C, T, H, W
                'txt': torch.LongTensor(align),
                'txt_len': align_len,
                'vid_len': vid_len}

    def __len__(self):
        return len(self.data)

    def get_transform(self, phase='train'):
        """Build the torchvision preprocessing pipeline (3-channel input).

        Training: resize to 96x96, random 88x88 crop, random horizontal flip.
        Otherwise: resize to 96x96, center 88x88 crop.
        """
        # Normalize with mean 0 / std 1 leaves the values unchanged; it is
        # kept as a placeholder for real per-channel statistics.
        if phase == 'train':
            return transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((96, 96)),
                transforms.RandomCrop((88, 88)),  # for training
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
            ])
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((96, 96)),
            transforms.CenterCrop((88, 88)),  # for testing
            transforms.ToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ])

    def _load_vid(self, p):
        """Load the lip image sequence in directory ``p``.

        Frames are read in numeric file-name order, unreadable files are
        dropped, and each frame is resized to 128x64 (width x height).

        Returns:
            float32 array of stacked frames.
        """
        # Suffix match: a substring match would also pick up e.g. '1.jpg.bak'.
        files = [name for name in os.listdir(p) if name.endswith('.jpg')]
        files.sort(key=lambda name: int(os.path.splitext(name)[0]))
        frames = (cv2.imread(os.path.join(p, name)) for name in files)
        frames = [cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4)
                  for im in frames if im is not None]
        return np.stack(frames, axis=0).astype(np.float32)

    def _load_align(self, name):
        """Parse a GRID ``.align`` file into a label array.

        Example file (start, end, token per line)::

            0 14500 sil      # silence
            14500 22500 bin
            22500 30750 blue
            30750 39250 at
            39250 40000 sp   # short pause
            40000 48500 f
            48500 57500 one
            57500 71000 soon
            71000 74500 sil  # silence

        Silence tokens (``sil`` / ``sp``) are dropped and the remaining words
        are joined by single spaces and upper-cased, e.g.
        ``BIN BLUE AT F ONE SOON``.
        """
        with open(name, 'r') as f:
            rows = [line.strip().split(' ') for line in f]
        # Rows with fewer than three fields (e.g. blank lines) carry no token.
        words = [row[2] for row in rows if len(row) > 2]
        words = [w for w in words if w.upper() not in ('SIL', 'SP')]
        return GridDataset.txt2arr(' '.join(words).upper(), 1)

    def _padding(self, array, length):
        """Zero-pad ``array`` along axis 0 up to ``length`` entries.

        The padding uses the dtype of ``array`` so integer labels stay integer.

        Raises:
            ValueError: if ``array`` already has more than ``length`` entries.
        """
        missing = length - len(array)
        if missing < 0:
            raise ValueError(
                'cannot pad array of length {} to shorter length {}'.format(len(array), length))
        filler = np.zeros([missing] + list(array.shape[1:]), dtype=array.dtype)
        return np.concatenate([array, filler])

    @staticmethod
    def txt2arr(txt, start):
        """Map each character of ``txt`` to ``letters.index(char) + start``.

        Raises:
            ValueError: if a character is not in ``letters``.
        """
        return np.array([GridDataset.letters.index(c) + start for c in txt], dtype=np.int64)

    @staticmethod
    def arr2txt(arr, start):
        """Inverse of :meth:`txt2arr`; labels below ``start`` are skipped."""
        chars = [GridDataset.letters[n - start] for n in arr if n >= start]
        return ''.join(chars).strip()

    @staticmethod
    def ctc_arr2txt(arr, start):
        """Decode a per-frame label sequence.

        Consecutive repeated labels are collapsed, labels below ``start`` are
        dropped, and runs of spaces are merged into one.
        """
        pre = -1
        txt = []
        for n in arr:
            if pre != n and n >= start:
                ch = GridDataset.letters[n - start]
                # Do not emit a second space directly after a space.
                if not (txt and txt[-1] == ' ' and ch == ' '):
                    txt.append(ch)
            pre = n
        return ''.join(txt).strip()

    @staticmethod
    def wer(predict, truth):
        """Per-sample word error rate: word-level edit distance / reference length."""
        rates = []
        for hyp, ref in zip(predict, truth):
            hyp_words, ref_words = hyp.split(' '), ref.split(' ')
            rates.append(editdistance.eval(hyp_words, ref_words) / max(len(ref_words), 1))
        return rates

    @staticmethod
    def cer(predict, truth):
        """Per-sample character error rate: edit distance / reference length.

        An empty reference is treated as length 1 to avoid division by zero.
        """
        return [editdistance.eval(hyp, ref) / max(len(ref), 1) for hyp, ref in zip(predict, truth)]


def cutout(img, n_holes=1, scale=0.1):
    '''
    Zero out randomly placed square patches of an image.

        img (ndarray): image of shape (H, W) or (H, W, C).
        n_holes (int): number of patches to cut out of the image.
        scale (float): patch side length as a fraction of min(H, W).
    Returns:
        ndarray: copy of img with the patches set to 0. Patches are clipped at
        the image border, so they may be smaller than the nominal size.
    '''
    h, w = img.shape[0], img.shape[1]
    length = int(scale * min(h, w))
    half = length // 2
    mask = np.ones((h, w), np.uint8)
    for _ in range(n_holes):
        # Draw the patch centre (row first, then column).
        y = np.random.randint(h)
        x = np.random.randint(w)
        y1 = np.clip(y - half, 0, h)
        y2 = np.clip(y + half, 0, h)
        x1 = np.clip(x - half, 0, w)
        x2 = np.clip(x + half, 0, w)
        mask[y1: y2, x1: x2] = 0
    # Append singleton axes so the (H, W) mask also broadcasts over channels.
    mask = mask.reshape(mask.shape + (1,) * (img.ndim - 2))
    return img * mask


def random_erasing(img, rv=None, n_holes=1, scale=0.1):
    '''
    Overwrite randomly placed square patches of an image with a constant.

        img (ndarray): image indexed as img[row, col]; modified IN PLACE.
        rv (float or None): fill value. When None, a fresh value is drawn from
            np.random.rand() on every call.
        n_holes (int): number of patches to overwrite.
        scale (float): patch side length as a fraction of min(H, W).
    Returns:
        ndarray: the same img object, with the patches filled with rv.
    '''
    if rv is None:
        # Draw per call: a default of ``np.random.rand()`` in the signature
        # would be evaluated only once, at function definition time.
        rv = np.random.rand()
    h, w = img.shape[0], img.shape[1]
    length = int(scale * min(h, w))
    half = length // 2
    for _ in range(n_holes):
        # Draw the patch centre (row first, then column).
        y = np.random.randint(h)
        x = np.random.randint(w)
        y1 = np.clip(y - half, 0, h)
        y2 = np.clip(y + half, 0, h)
        x1 = np.clip(x - half, 0, w)
        x2 = np.clip(x + half, 0, w)
        img[y1: y2, x1: x2] = rv
    return img


def cosine_angle(a, b):
    """Return the angle in radians between vectors ``a`` and ``b``.

    Vectors live on the last axis. Leading axes are broadcast against each
    other, so a single vector can be compared with a batch of vectors.

    Args:
        a, b: array-likes whose last axes have equal length.

    Returns:
        Angle(s) in [0, pi]; a scalar for two 1-D inputs, otherwise an array
        with the broadcast leading shape.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    dot = (a * b).sum(axis=-1)
    # Small epsilon guards against division by zero for zero-length vectors.
    denom = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + 1e-9
    # Clip to the valid arccos domain to absorb floating-point overshoot.
    cos_val = np.clip(dot / denom, -1., 1.)
    return np.arccos(cos_val)  # / np.pi * 180


def random_crop(img, dw, dh):
    """Return a random ``dh`` x ``dw`` window of ``img``.

    Args:
        img: array indexed as img[row, col, ...].
        dw: crop width (columns).
        dh: crop height (rows).

    Returns:
        A view ``img[top: top+dh, left: left+dw]`` with a uniformly drawn
        top-left corner.

    Raises:
        ValueError: if the crop is larger than the image.
    """
    h, w = img.shape[0], img.shape[1]
    if dh > h or dw > w:
        raise ValueError('crop size ({}, {}) exceeds image size ({}, {})'.format(dh, dw, h, w))
    # randint's upper bound is exclusive; +1 makes the last valid offset
    # reachable and allows a crop equal to the full image size.
    top = np.random.randint(0, h - dh + 1)
    left = np.random.randint(0, w - dw + 1)
    return img[top: top + dh, left: left + dw]


def HorizontalFlip(img, p=0.5):
    """Mirror ``img`` along its second axis with probability ``p``.

    Accepts (H, W) or (H, W, C) arrays; returns either a reversed view or the
    input object itself.
    """
    if np.random.rand() < p:
        return img[:, ::-1, ...]
    return img


def time_mask(x, T=25*0.4, n_mask=1):
    """Return a copy of ``x`` with random spans along axis 0 replaced by the mean.

    Args:
        x: array whose first axis is time.
        T: exclusive upper bound on the span length, in frames (truncated to
            int). It is additionally capped at the sequence length so short
            clips do not fail.
        n_mask: number of spans to mask.

    Returns:
        A masked copy; ``x`` itself is left untouched. Each span is filled
        with the mean of the working copy at the time that span is applied.
    """
    cloned = x.copy()
    len_raw = cloned.shape[0]
    max_t = min(int(T), len_raw)
    if max_t <= 0:
        # Nothing can be masked (empty input or T < 1).
        return cloned
    for _ in range(n_mask):
        t = np.random.randint(0, max_t)
        # +1: randint's upper bound is exclusive, and start len_raw - t is
        # still a valid position for a span of length t.
        t0 = np.random.randint(0, len_raw - t + 1)
        cloned[t0: t0 + t] = cloned.mean()   # or 0.
    return cloned


def adjust_img(img):
    """Randomly change contrast and brightness of ``img``; returns uint8.

    The result is ``clip(|gain * img + bias|, 0, 255)`` where
    gain (contrast, >1 increases / <1 decreases) is drawn from [0.5, 1.5) and
    bias (brightness, >0 brighter / <0 darker) is an integer from [-30, 30).
    """
    gain = 0.5 + np.random.rand()
    bias = np.random.randint(-30, 30)
    scaled = np.abs(gain * img + bias)
    return np.clip(scaled, 0, 255).astype(np.uint8)


class GridSeq2Seq(Dataset):
    """GRID dataset for sequence-to-sequence lip reading (unseen-speaker split).

    Samples are directories matched by ``<opt.video_root>/s*/*``. Speakers
    listed in ``train_spks`` form the training set; all remaining speakers are
    used for validation / testing.

    ``opt`` must provide ``video_root``, ``align_root``, ``max_vid_len`` and
    ``max_dec_len``; the ``val`` phase additionally reads ``val_batch`` and
    ``batch_size``.
    """

    def __init__(self, opt, phase='train'):
        """
        Args:
            opt: options object (see class docstring for the attributes read).
            phase: one of ``'train'``, ``'val'``, ``'test'``, ``'pretrain'``.
        """
        assert phase in ['train', 'val', 'test', 'pretrain']
        self.phase = phase
        self.opt = opt

        # Unseen-speaker setting: speakers 1, 2, 20, 22 are held out; 21 is
        # excluded as well ("21 is removed" in the original notes).
        self.data = glob.glob(os.path.join(opt.video_root, 's*', '*'))
        self.train_spks = ['s' + str(i) for i in range(1, 35) if i not in [1, 2, 20, 22, 21]]
        train_spk_set = set(self.train_spks)

        # Keep only sample directories with more than 70 entries; the speaker
        # id is the parent directory name, e.g. .../s1/swbc3a -> 's1'.
        want_train = phase in ('train', 'pretrain')
        self.cur_data = [
            x for x in self.data
            if os.path.isdir(x)
            and len(os.listdir(x)) > 70
            and (x.split(os.path.sep)[-2] in train_spk_set) == want_train
        ]
        if phase == 'val':
            # Fixed-seed shuffle, then keep val_batch * batch_size samples.
            # NOTE(review): this reseeds numpy's GLOBAL generator, and the
            # glob order above is not sorted, so the chosen subset may differ
            # between machines -- confirm whether that is acceptable.
            np.random.seed(123)
            np.random.shuffle(self.cur_data)
            self.cur_data = self.cur_data[:opt.val_batch * opt.batch_size]

        print(len(self.cur_data))
        # Vocabulary: PAD, space, A-Z, EOS, BOS (30 symbols).
        self.char_dict = [PAD] + [' ', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] + [EOS, BOS]   # 30
        self.ps = 24  # side length of the landmark patches
        print(f'patch size: {self.ps}')
        self.scale = 1

    def adjust_scale(self, i):
        """Set ``self.scale`` to 0.5, 1 or 2 depending on ``i % 3``."""
        self.scale = [0.5, 1, 2][i % 3]
        print('adjust scale:', self.scale)

    def sample_patch(self, i=0):
        """Select the patch size 16, 18, ..., 24 by ``i % 5`` (small to large).

        Negative ``i`` is treated as 0. The value is stored in ``self.ps`` and
        returned.
        """
        if i < 0:
            i = 0
        self.ps = list(range(16, 26, 2))[i % 5]
        print('sampled patch size:', self.ps)
        return self.ps

    def __getitem__(self, idx):
        """Return one sample; the tuple layout depends on the phase.

        * ``pretrain``: (lips, padded_align, vid_len, align_len)
        * ``test``: (lips, patches, pts, motion, padded_align, vid_len,
          align_txt, item_path)
        * otherwise: (spk_id, lips, patches, pts, motion, padded_align,
          vid_len, align_len), where ``spk_id`` is the index of the speaker in
          ``train_spks`` or -1 for a speaker that is not in that list.
        """
        item = self.cur_data[idx]   # e.g. <video_root>/s26/prim7p
        align_path = item.replace(self.opt.video_root, self.opt.align_root)
        align_txt = self.load_align('{}.align'.format(align_path))
        align_idx = self.align2idx(align_txt)  # with [BOS] and [EOS] attached
        align_len = len(align_idx) - 2   # excluding [BOS] and [EOS]
        padded_align = self.align_pad(align_idx)

        if self.phase == 'pretrain':
            lips, vid_len = self.load_pretrain_landmark_face(item)
            return torch.FloatTensor(lips), padded_align, vid_len, align_len

        lips, video, pts, motion, vid_len = self.load_landmark_face(item, self.ps)
        if self.phase == 'test':
            return torch.FloatTensor(lips), torch.FloatTensor(video), torch.FloatTensor(pts), torch.FloatTensor(motion), padded_align, vid_len, align_txt, item

        speaker = item.split(os.path.sep)[-2]
        # Validation speakers are by construction absent from train_spks, so
        # an unconditional .index() would raise ValueError for every val item.
        spk_id = self.train_spks.index(speaker) if speaker in self.train_spks else -1
        return spk_id, torch.FloatTensor(lips), torch.FloatTensor(video), torch.FloatTensor(pts), torch.FloatTensor(motion), padded_align, vid_len, align_len

    def __len__(self):
        return len(self.cur_data)

    def load_landmark_face(self, path, patch_size=16):
        """Build landmark-centred features for every frame in ``path``.

        ``path`` must contain numbered ``<n>.xy`` landmark files, each with a
        grayscale ``<n>.jpg`` image next to it. The landmark rows are indexed
        up to 66 and the last 20 rows are used as lip points (presumably the
        68-point face layout -- confirm against the landmark extractor).

        Args:
            path: sample directory.
            patch_size: side length of the returned square patches.

        Returns:
            Tuple ``(lip_crops, lm_patches, lm_pts, lm_motion, vid_len)``:

            * lip_crops: zeros of shape (T, 1, 1, 1); placeholder, never filled.
            * lm_patches: (T, 20, 1, patch_size, patch_size) image patches
              around each lip landmark, scaled to [0, 1].
            * lm_pts: (T, 20, 38) offsets from each lip landmark to the
              other 19 lip landmarks.
            * lm_motion: (T, 44) frame-to-frame difference of 4 lip
              distances plus 40 coordinates relative to landmark 33,
              standardised over the valid frames of the clip.
            * vid_len: number of valid frames, capped at ``T``.

            ``T`` is ``self.opt.max_vid_len``; frames beyond ``vid_len`` stay 0.
        """
        files = [f for f in os.listdir(path) if f.endswith('.xy')]  # 2D points
        files.sort(key=lambda f: int(os.path.splitext(f)[0]))
        face_lms = [np.loadtxt(os.path.join(path, f)) for f in files]
        faces = [cv2.imread(os.path.join(path, f.replace('.xy', '.jpg')), 0) for f in files]  # grayscale
        vid_pad = self.opt.max_vid_len
        vid_len = min(vid_pad, len(faces))

        N = 20
        h, w = 1, 1
        lip_crops = np.zeros((vid_pad, 1, h, w), dtype=np.float32)
        lm_patches = np.zeros((vid_pad, N, 1, patch_size, patch_size), dtype=np.float32)
        lm_pts = np.zeros((vid_pad, N, (N - 1) * 2), dtype=np.float32)
        # Boolean mask selecting every (i, j) pair with i != j.
        dig_mask = (np.ones((N, N)) - np.eye(N)).astype(bool)
        lm_motion = np.zeros((vid_pad, 44), dtype=np.float32)

        # Scale augmentation: during training, half of the clips crop with a
        # random window size and are resized back to patch_size.
        ps = np.random.choice(list(range(16, 34, 2))) if self.phase == 'train' and np.random.rand() < 0.5 else patch_size

        pre = None
        for i, (face, lms) in enumerate(zip(faces, face_lms)):
            if i >= vid_pad:
                break
            lip_lms = lms[-N:].copy()

            for j, (x, y) in enumerate(lip_lms):
                lx, ly = max(0, int(x - ps / 2)), max(0, int(y - ps / 2))   # w, h
                rx, ry = min(face.shape[1], lx + ps), min(face.shape[0], ly + ps)
                patch = face[ly: ry, lx: rx].copy()
                if patch.size == 0:
                    # Landmark lies outside the image: leave the patch zero
                    # instead of handing an empty array to cv2.resize.
                    continue
                if ps != patch_size:
                    patch = cv2.resize(patch, (patch_size, patch_size))
                # Border patches may be smaller; they are placed top-left.
                lm_patches[i, j, :, :patch.shape[0], :patch.shape[1]] = patch[None] / 255.

            pts = lip_lms[:, :2][:, None] - lip_lms[:, :2]   # 20 x 20 x 2
            lm_pts[i] = pts[dig_mask, :].reshape(N, -1)    # 20 x (19 x 2)

            dist = np.array([
                np.linalg.norm(lms[51, :2] - lms[57, :2], axis=-1),  # height
                np.linalg.norm(lms[62, :2] - lms[66, :2], axis=-1),
                np.linalg.norm(lms[48, :2] - lms[54, :2], axis=-1),  # width
                np.linalg.norm(lms[60, :2] - lms[64, :2], axis=-1),
            ])
            # Use landmark 33 (nose tip, per the original note) as the origin
            # to reduce the bias introduced by head movement.
            coord = (lip_lms - lms[33, :2]).reshape(-1)
            motion = np.concatenate((dist, coord))
            if pre is not None:
                # The first frame has no predecessor and keeps zero motion.
                lm_motion[i] = motion - pre
            pre = motion

        if vid_len > 0:
            mot_mean = np.mean(lm_motion[:vid_len], axis=0)
            mot_std = np.std(lm_motion[:vid_len], axis=0)
            lm_motion[:vid_len] = (lm_motion[:vid_len] - mot_mean) / (mot_std + 1e-9)
        return lip_crops, lm_patches, lm_pts, lm_motion, vid_len

    def load_pretrain_landmark_face(self, path):
        """Crop a 48x96 lip region around the mean lip landmark of every frame.

        Each frame is scaled to [0, 1], flipped horizontally with probability
        0.5 and, with probability 0.2, partially overwritten by
        ``random_erasing``.

        Returns:
            ``(lip_crops, vid_len)`` with ``lip_crops`` of shape
            (max_vid_len, 1, 48, 96); crops truncated at the image border are
            placed top-left and the remainder stays 0.
        """
        files = [f for f in os.listdir(path) if f.endswith('.xy')]  # 2D points
        files.sort(key=lambda f: int(os.path.splitext(f)[0]))
        face_lms = [np.loadtxt(os.path.join(path, f)) for f in files]
        faces = [cv2.imread(os.path.join(path, f.replace('.xy', '.jpg')), 0) for f in files]  # grayscale
        vid_pad = self.opt.max_vid_len
        vid_len = min(vid_pad, len(faces))

        N = 20
        h, w = 48, 96
        lip_crops = np.zeros((vid_pad, 1, h, w), dtype=np.float32)
        for i, (face, lms) in enumerate(zip(faces, face_lms)):
            if i >= vid_pad:
                break
            face = face / 255.
            center_x, center_y = np.mean(lms[-N:], axis=0)[:2]
            # int(...) guarantees integer slice bounds whatever type round()
            # returns for numpy scalars.
            top = max(0, int(round(center_y - h / 2)))
            bottom = min(int(round(center_y + h / 2)), face.shape[0])
            left = max(0, int(round(center_x - w / 2)))
            right = min(int(round(center_x + w / 2)), face.shape[1])
            lip_crop = face[top: bottom, left: right].copy()
            # NOTE(review): flip and erasing are decided per frame, so frames
            # of one clip can be mirrored inconsistently -- confirm intended.
            lip_crop = HorizontalFlip(lip_crop, 0.5)
            if np.random.rand() < 0.2:
                lip_crop = random_erasing(lip_crop, np.random.rand(), 1, 0.2)
            lip_crops[i, :, :lip_crop.shape[0], :lip_crop.shape[1]] = lip_crop[None]
        return lip_crops, vid_len

    def load_landmark(self, name):
        """Load per-frame landmark ``.txt`` files from directory ``name``.

        Each file is centred per column and divided by its overall standard
        deviation; the sequence is zero-padded or truncated to
        ``opt.max_vid_len`` frames.

        Returns:
            Array of shape (max_vid_len, ...) -- (T, 68, 3) according to the
            original notes.
        """
        def normalize(xyz):
            xyz -= np.mean(xyz, axis=0)
            xyz /= np.std(xyz)
            return xyz

        files = [f for f in os.listdir(name) if f.endswith('.txt')]
        files.sort(key=lambda f: int(os.path.splitext(f)[0]))
        array = [normalize(np.loadtxt(os.path.join(name, f))) for f in files]
        vid_pad = self.opt.max_vid_len
        if len(array) < vid_pad:
            filler = np.zeros([vid_pad - len(array)] + list(array[0].shape))
            return np.concatenate([array, filler])
        return np.asarray(array[:vid_pad])

    def _frame_transform(self):
        """Return the per-frame transform, creating it on first use."""
        transform = getattr(self, 'transform', None)
        if transform is None:
            transform = self.transform = self.get_transform(self.phase)
        return transform

    def _pad_and_transform(self, frames):
        """Pad/truncate ``frames`` to ``opt.max_vid_len`` and transform each one."""
        vid_pad = self.opt.max_vid_len
        if len(frames) < vid_pad:
            filler = np.zeros([vid_pad - len(frames)] + list(frames[0].shape)).astype(np.uint8)
            frames = np.concatenate([frames, filler])
        else:
            frames = frames[:vid_pad]
        transform = self._frame_transform()
        return np.stack([transform(img) for img in frames], axis=0).astype(np.float32)

    def load_video(self, name):
        """Load the ``.jpg`` frames in ``name`` (3-channel) and preprocess them.

        Returns:
            float32 array with one transformed frame per time step, padded or
            truncated to ``opt.max_vid_len`` frames.
        """
        files = [f for f in os.listdir(name) if f.endswith('.jpg')]
        files.sort(key=lambda f: int(os.path.splitext(f)[0]))
        frames = [cv2.imread(os.path.join(name, f)) for f in files]
        return self._pad_and_transform(frames)

    def fast_load_video(self, name):
        """Same result as :meth:`load_video`, reading the frames with one thread each."""
        files = [f for f in os.listdir(name) if f.endswith('.jpg')]
        files.sort(key=lambda f: int(os.path.splitext(f)[0]))
        frames = [None] * len(files)

        def read_img(id_, fn):
            # Each thread writes to its own slot, which preserves frame order.
            frames[id_] = cv2.imread(fn)

        ths = []
        for i, f in enumerate(files):
            th = Thread(target=read_img, args=(i, os.path.join(name, f)))
            th.start()
            ths.append(th)
        for th in ths:
            th.join()
        return self._pad_and_transform(frames)

    def load_align(self, name):
        """Read a GRID ``.align`` file and return the upper-cased sentence.

        Silence tokens (``sil`` / ``sp``) and lines without a third field
        (e.g. blank lines) are ignored.
        """
        with open(name, 'r') as f:
            rows = [line.split() for line in f]
        words = [row[2] for row in rows if len(row) > 2]
        words = [w for w in words if w.upper() not in ('SIL', 'SP')]
        return ' '.join(words).upper()

    def align2idx(self, text):
        """Encode ``text`` as ``[BOS] + character indices + [EOS]``.

        Raises:
            ValueError: if a character is not part of ``char_dict``.
        """
        lookup = self.char_dict.index
        return [lookup(BOS)] + [lookup(ch) for ch in text] + [lookup(EOS)]

    def align_pad(self, align):
        """Right-pad ``align`` with PAD indices up to ``opt.max_dec_len``.

        Always returns an ndarray so every sample has the same container
        type. Sequences that are already longer are returned unpadded and
        untruncated.
        """
        n_pad = self.opt.max_dec_len - len(align)
        if n_pad > 0:
            align = list(align) + [self.char_dict.index(PAD)] * n_pad
        return np.asarray(align)

    def get_transform(self, phase='train'):
        """Build the per-frame torchvision pipeline (grayscale, 72x90).

        The ``'train'`` phase additionally applies a random horizontal flip.
        ToTensor yields a (C, H, W) tensor in [0, 1].
        """
        steps = [
            transforms.ToPILImage(),  # for 2 or 3 dimensional input
            transforms.Grayscale(),
            transforms.Resize((72, 90)),  # full face
        ]
        if phase == 'train':
            steps.append(transforms.RandomHorizontalFlip(p=0.5))
        steps.append(transforms.ToTensor())
        return transforms.Compose(steps)

