import csv
import glob
import json
import math
import os

import numpy as np
from tqdm import tqdm

from .utils import DataProcessor
from .utils import InputPairWiseExample, InputHeadExample, InputAbductiveExample
from .img_utils import read_img_from_filename, resize_img_skimage, skimage_save


# Sub-directory of "<data_dir>/frames" that the processors read per-clip frame
# images from. The commented-out alternative "jpg" is the un-resized frame
# directory used as input in the resize_selected_frames example under __main__.
FRAMES_ROOT = "jpg_resized_256"  # "jpg"


class MPIIMoviePairWiseProcessor(DataProcessor):
    """Processor for MPIIMovie Steps Dataset, pair-wise data.

    Story sequences are sliding windows of consecutive clips of one movie;
    every ordered pair of distinct clips in a window becomes one example.

    Args:
        data_dir: string. Root directory for the dataset
            (defaults to "data/mpii_movie").
        order_criteria: The criteria of determining if a pair is ordered or not.
            "tight" means only strictly consecutive pairs are considered as
            ordered, "loose" means ancestors also count.
        paired_with_image: if True, each clip is paired with the path of its
            first frame ("" when the clip has no frames); if False, examples
            carry ``None`` image paths.
        min_story_length: minimum number of clips in a story sequence.
        max_story_length: maximum number of clips in a story sequence.
        annotation_type: choose from ["original", "someone"]; selects the
            file ``annotations-<annotation_type>.csv``.
        allow_time_interval: int, seconds. A window stops at the first clip
            that ends this many seconds (or more) after the window's first
            clip ends.
        split_ratio: three percentages (train, val, test) summing to 100;
            the split is done over the sorted movie names.

    Raises:
        ValueError: on an unknown ``order_criteria`` / ``annotation_type`` or
            a ``split_ratio`` that does not sum to 100.
    """

    def __init__(self, data_dir=None, order_criteria="tight",
                 paired_with_image=True,
                 min_story_length=5, max_story_length=5,
                 annotation_type="original", allow_time_interval=300,
                 split_ratio=(80, 10, 10)):
        """Init."""
        self.data_dir = data_dir if data_dir is not None else "data/mpii_movie"
        if order_criteria not in ("tight", "loose"):
            raise ValueError("order_criteria must be 'tight' or 'loose', "
                             "got {!r}".format(order_criteria))
        if annotation_type not in ("original", "someone"):
            raise ValueError("annotation_type must be 'original' or "
                             "'someone', got {!r}".format(annotation_type))
        self.annotation_type = annotation_type
        self.order_criteria = order_criteria
        self.paired_with_image = paired_with_image
        # Copy so a caller-owned list cannot be mutated through us (and the
        # default can be an immutable tuple).
        self.split_ratio = list(split_ratio)
        self.allow_time_interval = allow_time_interval

        # Clamp both lengths to >= 1 and make sure min <= max.
        max_story_length = max(1, max_story_length)
        self.min_story_length = min(max(1, min_story_length), max_story_length)
        self.max_story_length = max_story_length

        self.split_movies()

    def get_labels(self):
        """See base class."""
        return ["unordered", "ordered"]  # 0: unordered, 1: ordered.

    def split_movies(self):
        """Load the annotation file and split its movies into train/val/test.

        Reads ``<data_dir>/annotations-<annotation_type>.csv`` (tab separated:
        clip id, script sentence), groups clips by movie name in file order,
        sorts the movie names and slices them by ``split_ratio``. Populates
        ``self.annotations_raw`` ({movie: [(clip_id, script), ...]}) and
        ``self.annotations`` (keys "train", "val", "test").
        """
        self.mpii_root = self.data_dir
        self.imgs_root = os.path.join(self.mpii_root, "frames", FRAMES_ROOT)
        self.annotations_csv = os.path.join(
            self.mpii_root, "annotations-{}.csv".format(self.annotation_type))

        self.annotations_raw = {}
        line_cnt = 0
        # The csv module expects file objects opened with newline="".
        with open(self.annotations_csv, "r", newline="") as csv_file:
            for row in csv.reader(csv_file, delimiter="\t"):
                line_cnt += 1
                if len(row) < 2:
                    # Skip blank / malformed rows instead of crashing.
                    continue
                movie_id, script = row[0].strip(), row[1].strip()
                # The text before the first "." presumably ends with "_<hh>"
                # (the hour of the clip's time range); dropping those three
                # characters leaves the movie name.
                movie_name = movie_id.split(".")[0][:-3]
                self.annotations_raw.setdefault(movie_name, []).append(
                    (movie_id, script))
        print("There are total {} lines".format(line_cnt))

        all_movie_names = sorted(self.annotations_raw)
        num_movies = len(all_movie_names)
        print("There are in total {} movies.".format(num_movies))

        if sum(self.split_ratio) != 100:
            raise ValueError("Ratio needs to sum to 100!")
        train_num = int(num_movies * self.split_ratio[0] / 100)
        dev_num = int(num_movies * self.split_ratio[1] / 100)
        split_movie_names = {
            "train": all_movie_names[:train_num],
            "val": all_movie_names[train_num:train_num + dev_num],
            "test": all_movie_names[train_num + dev_num:],
        }
        self.annotations = {
            split: {m: self.annotations_raw[m] for m in names}
            for split, names in split_movie_names.items()
        }
        for split, shown in (("train", "train"), ("val", "dev"),
                             ("test", "test")):
            print("There are {} movies in {}".format(
                len(self.annotations[split]), shown))

    def _read_image_paths(self, split="train"):
        """Map every clip id of ``split`` to its sorted list of frame paths."""
        img_paths_dict = {}
        for movie, clips in self.annotations[split].items():
            for movie_id, _ in clips:
                img_dir = os.path.join(self.imgs_root, movie, movie_id)
                img_paths_dict[movie_id] = sorted(
                    glob.glob(os.path.join(img_dir, "*")))
        return img_paths_dict

    @staticmethod
    def _clip_end_seconds(movie_id):
        """Return the end time of a clip in seconds.

        The clip id is expected to end with ``_<start>-<end>`` where each
        timestamp is ``hh.mm.ss.frac``; only the end timestamp is used.
        """
        _, time_end = movie_id.split("_")[-1].split("-")
        hh, mm, ss, frac = time_end.split(".")[:4]
        # ``frac`` is the fractional part of the second, so scale it by its
        # own digit count ("926" -> 0.926 s, "92" -> 0.92 s).
        return (int(hh) * 3600 + int(mm) * 60 + int(ss)
                + int(frac) / 10 ** len(frac))

    def _make_item(self, script, movie_id, image_paths):
        """Build the tuple stored for one clip in a story sequence.

        ``(script, first_frame_path)`` when ``paired_with_image`` (the path
        is "" if the clip has no frames), otherwise ``(script,)``.
        """
        if not self.paired_with_image:
            return (script,)
        # TODO: Currently only takes the 1st image.
        frames = image_paths.get(movie_id) or []
        return (script, frames[0] if frames else "")

    @staticmethod
    def _image_of(item):
        """Image path of a story item, or None for text-only items."""
        return item[1] if len(item) > 1 else None

    def _read_csv(self, data_dir=None, split="train"):
        """Build the story sequences (windows of consecutive clips) of a split.

        Args:
            data_dir: optional new dataset root; when given, the annotations
                are re-read and re-split from it.
            split: one of "train", "val", "test".

        Returns:
            A list of story sequences ``[story_id, item_1, item_2, ...]``.
            ``story_id`` is ``"<first clip id>_<window start index>"`` and
            each item is the tuple built by ``_make_item``. A window starts at
            every clip that still has ``max_story_length - 1`` clips after it
            in the same movie, stops early at the first clip ending
            ``allow_time_interval`` seconds or more after the first clip, and
            is kept only if it holds at least ``min_story_length`` clips.
        """
        if data_dir is not None:
            self.data_dir = data_dir
            self.split_movies()

        # Frame directories only need to be scanned when images are used.
        image_paths = (self._read_image_paths(split=split)
                       if self.paired_with_image else {})
        data = self.annotations[split]

        story_seqs = []
        for movie in sorted(data):
            clips = data[movie]
            # "+ 1" so the last full-length window of each movie is included.
            for start in range(len(clips) - self.max_story_length + 1):
                first_id, first_script = clips[start]
                start_secs = self._clip_end_seconds(first_id)
                items = [self._make_item(first_script, first_id, image_paths)]
                following = clips[start + 1:start + self.max_story_length]
                for movie_id, script in following:
                    if (self._clip_end_seconds(movie_id) - start_secs
                            >= self.allow_time_interval):
                        break
                    items.append(self._make_item(script, movie_id,
                                                 image_paths))
                if len(items) >= self.min_story_length:
                    story_seqs.append(
                        ["{}_{}".format(first_id, start)] + items)

        print("There are {} valid story sequences in {}".format(
              len(story_seqs), split))

        return story_seqs

    def _create_examples(self, lines):
        """Creates pair-wise examples for the training, dev and test sets.

        Every ordered pair ``(i, j)`` of distinct positions in a story yields
        one example labelled "ordered" or "unordered" per ``order_criteria``.
        The guid is ``"<story_id>_<i+1><j+1>"``. The input sequences are not
        modified.
        """
        paired_examples = []
        for story_seq in lines:
            story_id, items = story_seq[0], story_seq[1:]
            len_seq = len(items)
            for i in range(len_seq):
                for j in range(len_seq):
                    if i == j:
                        continue
                    if self.order_criteria == "tight":
                        is_ordered = j == i + 1
                    else:  # "loose": any later item counts as ordered.
                        is_ordered = j > i
                    example = InputPairWiseExample(
                        guid="{}_{}{}".format(story_id, i + 1, j + 1),
                        text_a=items[i][0],
                        text_b=items[j][0],
                        label="ordered" if is_ordered else "unordered",
                        img_path_a=self._image_of(items[i]),
                        img_path_b=self._image_of(items[j]),
                        distance=abs(j - i))
                    paired_examples.append(example)
        return paired_examples

    def get_train_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_csv(data_dir=data_dir, split="train")
        return self._create_examples(lines)

    def get_dev_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_csv(data_dir=data_dir, split="val")
        return self._create_examples(lines)

    def get_test_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_csv(data_dir=data_dir, split="test")
        return self._create_examples(lines)


class MPIIMovieAbductiveProcessor(DataProcessor):
    """Processor for MPIIMovie Steps Dataset, abductive data.

    Story sequences are sliding windows of consecutive clips of one movie.
    For each run of three consecutive clips, examples are built that keep the
    first and last clip and vary the middle one.

    Args:
        data_dir: string. Root directory for the dataset
            (defaults to "data/mpii_movie").
        pred_method: the method of the predictions, can be "binary" (each
            triplet labelled "ordered"/"unordered") or "contrastive" (label
            left as None).
        paired_with_image: if True, each clip is paired with the path of its
            first frame ("" when the clip has no frames); if False, examples
            carry ``None`` image paths.
        min_story_length: minimum number of clips in a story sequence.
        max_story_length: maximum number of clips in a story sequence.
        annotation_type: choose from ["original", "someone"]; selects the
            file ``annotations-<annotation_type>.csv``.
        allow_time_interval: int, seconds. A window stops at the first clip
            that ends this many seconds (or more) after the window's first
            clip ends.
        split_ratio: three percentages (train, val, test) summing to 100;
            the split is done over the sorted movie names.

    Raises:
        ValueError: on an unknown ``pred_method`` / ``annotation_type`` or a
            ``split_ratio`` that does not sum to 100.
    """

    def __init__(self, data_dir=None, pred_method="binary",
                 paired_with_image=True,
                 min_story_length=5, max_story_length=5,
                 annotation_type="original", allow_time_interval=300,
                 split_ratio=(80, 10, 10)):
        """Init."""
        self.data_dir = data_dir if data_dir is not None else "data/mpii_movie"
        if pred_method not in ("binary", "contrastive"):
            raise ValueError("pred_method must be 'binary' or 'contrastive', "
                             "got {!r}".format(pred_method))
        if annotation_type not in ("original", "someone"):
            raise ValueError("annotation_type must be 'original' or "
                             "'someone', got {!r}".format(annotation_type))
        self.annotation_type = annotation_type
        self.pred_method = pred_method
        self.paired_with_image = paired_with_image
        # Copy so a caller-owned list cannot be mutated through us (and the
        # default can be an immutable tuple).
        self.split_ratio = list(split_ratio)
        self.allow_time_interval = allow_time_interval

        # Clamp both lengths to >= 1 and make sure min <= max.
        max_story_length = max(1, max_story_length)
        self.min_story_length = min(max(1, min_story_length), max_story_length)
        self.max_story_length = max_story_length

        self.split_movies()

    def get_labels(self):
        """See base class."""
        return ["unordered", "ordered"]  # 0: unordered, 1: ordered.

    def split_movies(self):
        """Load the annotation file and split its movies into train/val/test.

        Reads ``<data_dir>/annotations-<annotation_type>.csv`` (tab separated:
        clip id, script sentence), groups clips by movie name in file order,
        sorts the movie names and slices them by ``split_ratio``. Populates
        ``self.annotations_raw`` ({movie: [(clip_id, script), ...]}) and
        ``self.annotations`` (keys "train", "val", "test").
        """
        self.mpii_root = self.data_dir
        self.imgs_root = os.path.join(self.mpii_root, "frames", FRAMES_ROOT)
        self.annotations_csv = os.path.join(
            self.mpii_root, "annotations-{}.csv".format(self.annotation_type))

        self.annotations_raw = {}
        line_cnt = 0
        # The csv module expects file objects opened with newline="".
        with open(self.annotations_csv, "r", newline="") as csv_file:
            for row in csv.reader(csv_file, delimiter="\t"):
                line_cnt += 1
                if len(row) < 2:
                    # Skip blank / malformed rows instead of crashing.
                    continue
                movie_id, script = row[0].strip(), row[1].strip()
                # The text before the first "." presumably ends with "_<hh>"
                # (the hour of the clip's time range); dropping those three
                # characters leaves the movie name.
                movie_name = movie_id.split(".")[0][:-3]
                self.annotations_raw.setdefault(movie_name, []).append(
                    (movie_id, script))
        print("There are total {} lines".format(line_cnt))

        all_movie_names = sorted(self.annotations_raw)
        num_movies = len(all_movie_names)
        print("There are in total {} movies.".format(num_movies))

        if sum(self.split_ratio) != 100:
            raise ValueError("Ratio needs to sum to 100!")
        train_num = int(num_movies * self.split_ratio[0] / 100)
        dev_num = int(num_movies * self.split_ratio[1] / 100)
        split_movie_names = {
            "train": all_movie_names[:train_num],
            "val": all_movie_names[train_num:train_num + dev_num],
            "test": all_movie_names[train_num + dev_num:],
        }
        self.annotations = {
            split: {m: self.annotations_raw[m] for m in names}
            for split, names in split_movie_names.items()
        }
        for split, shown in (("train", "train"), ("val", "dev"),
                             ("test", "test")):
            print("There are {} movies in {}".format(
                len(self.annotations[split]), shown))

    def _read_image_paths(self, split="train"):
        """Map every clip id of ``split`` to its sorted list of frame paths."""
        img_paths_dict = {}
        for movie, clips in self.annotations[split].items():
            for movie_id, _ in clips:
                img_dir = os.path.join(self.imgs_root, movie, movie_id)
                img_paths_dict[movie_id] = sorted(
                    glob.glob(os.path.join(img_dir, "*")))
        return img_paths_dict

    @staticmethod
    def _clip_end_seconds(movie_id):
        """Return the end time of a clip in seconds.

        The clip id is expected to end with ``_<start>-<end>`` where each
        timestamp is ``hh.mm.ss.frac``; only the end timestamp is used.
        """
        _, time_end = movie_id.split("_")[-1].split("-")
        hh, mm, ss, frac = time_end.split(".")[:4]
        # ``frac`` is the fractional part of the second, so scale it by its
        # own digit count ("926" -> 0.926 s, "92" -> 0.92 s).
        return (int(hh) * 3600 + int(mm) * 60 + int(ss)
                + int(frac) / 10 ** len(frac))

    def _make_item(self, script, movie_id, image_paths):
        """Build the tuple stored for one clip in a story sequence.

        ``(script, first_frame_path)`` when ``paired_with_image`` (the path
        is "" if the clip has no frames), otherwise ``(script,)``.
        """
        if not self.paired_with_image:
            return (script,)
        # TODO: Currently only takes the 1st image.
        frames = image_paths.get(movie_id) or []
        return (script, frames[0] if frames else "")

    @staticmethod
    def _image_of(item):
        """Image path of a story item, or None for text-only items."""
        return item[1] if len(item) > 1 else None

    def _read_csv(self, data_dir=None, split="train"):
        """Build the story sequences (windows of consecutive clips) of a split.

        Args:
            data_dir: optional new dataset root; when given, the annotations
                are re-read and re-split from it.
            split: one of "train", "val", "test".

        Returns:
            A list of story sequences ``[story_id, item_1, item_2, ...]``.
            ``story_id`` is ``"<first clip id>_<window start index>"`` and
            each item is the tuple built by ``_make_item``. A window starts at
            every clip that still has ``max_story_length - 1`` clips after it
            in the same movie, stops early at the first clip ending
            ``allow_time_interval`` seconds or more after the first clip, and
            is kept only if it holds at least ``min_story_length`` clips.
        """
        if data_dir is not None:
            self.data_dir = data_dir
            self.split_movies()

        # Frame directories only need to be scanned when images are used.
        image_paths = (self._read_image_paths(split=split)
                       if self.paired_with_image else {})
        data = self.annotations[split]

        story_seqs = []
        for movie in sorted(data):
            clips = data[movie]
            # "+ 1" so the last full-length window of each movie is included.
            for start in range(len(clips) - self.max_story_length + 1):
                first_id, first_script = clips[start]
                start_secs = self._clip_end_seconds(first_id)
                items = [self._make_item(first_script, first_id, image_paths)]
                following = clips[start + 1:start + self.max_story_length]
                for movie_id, script in following:
                    if (self._clip_end_seconds(movie_id) - start_secs
                            >= self.allow_time_interval):
                        break
                    items.append(self._make_item(script, movie_id,
                                                 image_paths))
                if len(items) >= self.min_story_length:
                    story_seqs.append(
                        ["{}_{}".format(first_id, start)] + items)

        print("There are {} valid story sequences in {}".format(
              len(story_seqs), split))

        return story_seqs

    def _make_abd_example(self, story_id, items, abd_idx, binary_label):
        """Build one InputAbductiveExample from the story items at ``abd_idx``.

        ``binary_label`` is used only when ``pred_method == "binary"``; the
        guid is ``"<story_id>_<idx1><idx2><idx3>"``.
        """
        # NOTE(review): no binary label applies in "contrastive" mode, so the
        # label is None there; confirm what the contrastive consumer expects.
        label = binary_label if self.pred_method == "binary" else None
        h1, h2, h3 = (items[idx] for idx in abd_idx)
        guid = "{}_{}{}{}".format(story_id, *abd_idx)
        return InputAbductiveExample(guid=guid, label=label,
                                     text_h1=h1[0],
                                     text_h2=h2[0],
                                     text_h3=h3[0],
                                     img_path_h1=self._image_of(h1),
                                     img_path_h2=self._image_of(h2),
                                     img_path_h3=self._image_of(h3))

    def _create_examples(self, lines):
        """Creates abductive examples for the training, dev and test sets.

        For every run of three consecutive items ``(i, i+1, i+2)`` of a story:
          * one negative ("unordered") for each other position ``k`` of the
            story, with ``k`` substituted as the middle item: ``(i, k, i+2)``;
          * one positive ("ordered"): the original triplet ``(i, i+1, i+2)``.
        The input sequences are not modified.
        """
        abd_examples = []
        for story_seq in lines:
            story_id, items = story_seq[0], story_seq[1:]
            len_seq = len(items)
            for start in range(len_seq - 2):
                # Explicit ascending list; list(set(...)) does not guarantee
                # this ordering.
                gold_idx = [start, start + 1, start + 2]
                for k in range(len_seq):
                    if k in gold_idx:
                        continue
                    # Keep the gold triplet's first and last items and swap a
                    # distractor in as the middle one.
                    abd_examples.append(self._make_abd_example(
                        story_id, items, [gold_idx[0], k, gold_idx[2]],
                        "unordered"))
                abd_examples.append(self._make_abd_example(
                    story_id, items, gold_idx, "ordered"))
        return abd_examples

    def get_train_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_csv(data_dir=data_dir, split="train")
        return self._create_examples(lines)

    def get_dev_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_csv(data_dir=data_dir, split="val")
        return self._create_examples(lines)

    def get_test_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_csv(data_dir=data_dir, split="test")
        return self._create_examples(lines)


class MPIIMovieGeneralProcessor(DataProcessor):
    """Processor for MPIIMovie Steps Dataset, general sorting prediction.

    Story sequences are sliding windows of consecutive clips of one movie;
    each window becomes one InputHeadExample holding the whole sequence.

    Args:
        data_dir: string. Root directory for the dataset
            (defaults to "data/mpii_movie").
        max_story_length: maximum number of clips in a story sequence.
        pure_class: if True, ``get_labels`` returns one entry per permutation
            of a ``max_story_length`` story.
        paired_with_image: if True, each clip is paired with the path of its
            first frame ("" when the clip has no frames); if False, examples
            carry ``None`` image paths.
        min_story_length: minimum number of clips in a story sequence.
        annotation_type: choose from ["original", "someone"]; selects the
            file ``annotations-<annotation_type>.csv``.
        allow_time_interval: int, seconds. A window stops at the first clip
            that ends this many seconds (or more) after the window's first
            clip ends.
        split_ratio: three percentages (train, val, test) summing to 100;
            the split is done over the sorted movie names.

    Raises:
        ValueError: on an unknown ``annotation_type`` or a ``split_ratio``
            that does not sum to 100.
    """

    def __init__(self, data_dir=None, max_story_length=5, pure_class=False,
                 paired_with_image=True, min_story_length=5,
                 annotation_type="original", allow_time_interval=300,
                 split_ratio=(80, 10, 10)):
        """Init."""
        self.data_dir = data_dir if data_dir is not None else "data/mpii_movie"
        if annotation_type not in ("original", "someone"):
            raise ValueError("annotation_type must be 'original' or "
                             "'someone', got {!r}".format(annotation_type))
        self.annotation_type = annotation_type
        self.pure_class = pure_class
        self.paired_with_image = paired_with_image
        # Copy so a caller-owned list cannot be mutated through us (and the
        # default can be an immutable tuple).
        self.split_ratio = list(split_ratio)
        self.allow_time_interval = allow_time_interval

        # Clamp both lengths to >= 1 and make sure min <= max.
        max_story_length = max(1, max_story_length)
        self.min_story_length = min(max(1, min_story_length), max_story_length)
        self.max_story_length = max_story_length

        self.split_movies()

    def get_labels(self):
        """See base class.

        With ``pure_class`` this is a list of ``max_story_length!`` zeros (one
        slot per permutation); otherwise the positions
        ``0 .. max_story_length - 1``.
        """
        if self.pure_class:
            return [0] * math.factorial(self.max_story_length)
        return list(range(self.max_story_length))

    def split_movies(self):
        """Load the annotation file and split its movies into train/val/test.

        Reads ``<data_dir>/annotations-<annotation_type>.csv`` (tab separated:
        clip id, script sentence), groups clips by movie name in file order,
        sorts the movie names and slices them by ``split_ratio``. Populates
        ``self.annotations_raw`` ({movie: [(clip_id, script), ...]}) and
        ``self.annotations`` (keys "train", "val", "test").
        """
        self.mpii_root = self.data_dir
        self.imgs_root = os.path.join(self.mpii_root, "frames", FRAMES_ROOT)
        self.annotations_csv = os.path.join(
            self.mpii_root, "annotations-{}.csv".format(self.annotation_type))

        self.annotations_raw = {}
        line_cnt = 0
        # The csv module expects file objects opened with newline="".
        with open(self.annotations_csv, "r", newline="") as csv_file:
            for row in csv.reader(csv_file, delimiter="\t"):
                line_cnt += 1
                if len(row) < 2:
                    # Skip blank / malformed rows instead of crashing.
                    continue
                movie_id, script = row[0].strip(), row[1].strip()
                # The text before the first "." presumably ends with "_<hh>"
                # (the hour of the clip's time range); dropping those three
                # characters leaves the movie name.
                movie_name = movie_id.split(".")[0][:-3]
                self.annotations_raw.setdefault(movie_name, []).append(
                    (movie_id, script))
        print("There are total {} lines".format(line_cnt))

        all_movie_names = sorted(self.annotations_raw)
        num_movies = len(all_movie_names)
        print("There are in total {} movies.".format(num_movies))

        if sum(self.split_ratio) != 100:
            raise ValueError("Ratio needs to sum to 100!")
        train_num = int(num_movies * self.split_ratio[0] / 100)
        dev_num = int(num_movies * self.split_ratio[1] / 100)
        split_movie_names = {
            "train": all_movie_names[:train_num],
            "val": all_movie_names[train_num:train_num + dev_num],
            "test": all_movie_names[train_num + dev_num:],
        }
        self.annotations = {
            split: {m: self.annotations_raw[m] for m in names}
            for split, names in split_movie_names.items()
        }
        for split, shown in (("train", "train"), ("val", "dev"),
                             ("test", "test")):
            print("There are {} movies in {}".format(
                len(self.annotations[split]), shown))

    def _read_image_paths(self, split="train"):
        """Map every clip id of ``split`` to its sorted list of frame paths."""
        img_paths_dict = {}
        for movie, clips in self.annotations[split].items():
            for movie_id, _ in clips:
                img_dir = os.path.join(self.imgs_root, movie, movie_id)
                img_paths_dict[movie_id] = sorted(
                    glob.glob(os.path.join(img_dir, "*")))
        return img_paths_dict

    @staticmethod
    def _clip_end_seconds(movie_id):
        """Return the end time of a clip in seconds.

        The clip id is expected to end with ``_<start>-<end>`` where each
        timestamp is ``hh.mm.ss.frac``; only the end timestamp is used.
        """
        _, time_end = movie_id.split("_")[-1].split("-")
        hh, mm, ss, frac = time_end.split(".")[:4]
        # ``frac`` is the fractional part of the second, so scale it by its
        # own digit count ("926" -> 0.926 s, "92" -> 0.92 s).
        return (int(hh) * 3600 + int(mm) * 60 + int(ss)
                + int(frac) / 10 ** len(frac))

    def _make_item(self, script, movie_id, image_paths):
        """Build the tuple stored for one clip in a story sequence.

        ``(script, first_frame_path)`` when ``paired_with_image`` (the path
        is "" if the clip has no frames), otherwise ``(script,)``.
        """
        if not self.paired_with_image:
            return (script,)
        # TODO: Currently only takes the 1st image.
        frames = image_paths.get(movie_id) or []
        return (script, frames[0] if frames else "")

    @staticmethod
    def _image_of(item):
        """Image path of a story item, or None for text-only items."""
        return item[1] if len(item) > 1 else None

    def _read_csv(self, data_dir=None, split="train"):
        """Build the story sequences (windows of consecutive clips) of a split.

        Args:
            data_dir: optional new dataset root; when given, the annotations
                are re-read and re-split from it.
            split: one of "train", "val", "test".

        Returns:
            A list of story sequences ``[story_id, item_1, item_2, ...]``.
            ``story_id`` is ``"<first clip id>_<window start index>"`` and
            each item is the tuple built by ``_make_item``. A window starts at
            every clip that still has ``max_story_length - 1`` clips after it
            in the same movie, stops early at the first clip ending
            ``allow_time_interval`` seconds or more after the first clip, and
            is kept only if it holds at least ``min_story_length`` clips.
        """
        if data_dir is not None:
            self.data_dir = data_dir
            self.split_movies()

        # Frame directories only need to be scanned when images are used.
        image_paths = (self._read_image_paths(split=split)
                       if self.paired_with_image else {})
        data = self.annotations[split]

        story_seqs = []
        for movie in sorted(data):
            clips = data[movie]
            # "+ 1" so the last full-length window of each movie is included.
            for start in range(len(clips) - self.max_story_length + 1):
                first_id, first_script = clips[start]
                start_secs = self._clip_end_seconds(first_id)
                items = [self._make_item(first_script, first_id, image_paths)]
                following = clips[start + 1:start + self.max_story_length]
                for movie_id, script in following:
                    if (self._clip_end_seconds(movie_id) - start_secs
                            >= self.allow_time_interval):
                        break
                    items.append(self._make_item(script, movie_id,
                                                 image_paths))
                if len(items) >= self.min_story_length:
                    story_seqs.append(
                        ["{}_{}".format(first_id, start)] + items)

        print("There are {} valid story sequences in {}".format(
              len(story_seqs), split))

        return story_seqs

    def _create_examples(self, lines):
        """Creates one whole-sequence example per story.

        The guid is the story id; texts and image paths keep story order.
        The input sequences are not modified.
        """
        head_examples = []
        for story_seq in lines:
            story_id, items = story_seq[0], story_seq[1:]
            example = InputHeadExample(
                guid=story_id,
                text_seq=[item[0] for item in items],
                img_path_seq=[self._image_of(item) for item in items])
            head_examples.append(example)
        return head_examples

    def get_train_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_csv(data_dir=data_dir, split="train")
        return self._create_examples(lines)

    def get_dev_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_csv(data_dir=data_dir, split="val")
        return self._create_examples(lines)

    def get_test_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_csv(data_dir=data_dir, split="test")
        return self._create_examples(lines)


def resize_selected_frames(imgs_read_dir, imgs_write_dir, new_size,
                           selected_frames=None,
                           annotations_csv="data/mpii_movie/annotations-original.csv"):
    """Resize (a subset of) the frames of every annotated clip.

    For each row of ``annotations_csv`` (tab separated, clip id first), the
    frames in ``<imgs_read_dir>/<movie_name>/<clip_id>/`` are read in sorted
    order, resized and written with the same file name into
    ``<imgs_write_dir>/<movie_name>/<clip_id>/``.

    Args:
        imgs_read_dir: root directory of the source frames.
        imgs_write_dir: root directory the resized frames are written to.
        new_size: tuple; its first two entries are forwarded to
            ``resize_img_skimage``.
        selected_frames: None to process every frame, or an int / list of
            indices into each clip's sorted frame list. Indices that are out
            of range for a given clip are skipped.
        annotations_csv: path of the annotation file listing the clip ids.

    Raises:
        TypeError: if ``selected_frames`` is not None, an int or a list.
    """
    if selected_frames is not None:
        if isinstance(selected_frames, int):
            selected_frames = [selected_frames]
        elif not isinstance(selected_frames, list):
            raise TypeError("selected_frames"
                " should not be of type {}".format(type(selected_frames)))

    assert isinstance(new_size, tuple)

    # Read the rows once; the count also serves as the progress-bar total.
    with open(annotations_csv, "r", newline="") as csv_file:
        rows = list(csv.reader(csv_file, delimiter="\t"))
    line_cnt = len(rows)

    for row in tqdm(rows, total=line_cnt):
        if len(row) < 2:
            # Skip blank / malformed rows instead of crashing.
            continue
        movie_id = row[0]
        # Same movie-name derivation as the processors' split_movies.
        movie_name = movie_id.split(".")[0][:-3]
        script_dir = os.path.join(imgs_write_dir, movie_name, movie_id)
        os.makedirs(script_dir, exist_ok=True)

        imgs_dir = os.path.join(imgs_read_dir, movie_name, movie_id)
        img_paths = sorted(glob.glob(os.path.join(imgs_dir, "*")))
        if selected_frames is not None:
            num_imgs = len(img_paths)
            # Keep only indices valid for this clip (negatives allowed, as
            # with normal Python indexing).
            img_paths = [img_paths[x] for x in selected_frames
                         if -num_imgs <= x < num_imgs]

        for img_path in img_paths:
            img = read_img_from_filename(img_path)
            img = resize_img_skimage(img, new_size[0], new_size[1])
            save_path = os.path.join(script_dir, os.path.basename(img_path))
            skimage_save(save_path, img, no_lossy=True)

    print("There are total {} lines".format(line_cnt))


if __name__ == "__main__":

    # To regenerate the resized frames that FRAMES_ROOT points at, call:
    #   resize_selected_frames("data/mpii_movie/frames/jpg",
    #                          "data/mpii_movie/frames/jpg_resized_256",
    #                          (256, 256), selected_frames=0)

    # Smoke test: build every split with each processor and show one
    # training example per processor.
    for processor_cls in (MPIIMoviePairWiseProcessor,
                          MPIIMovieGeneralProcessor,
                          MPIIMovieAbductiveProcessor):
        proc = processor_cls()
        train_examples = proc.get_train_examples()
        val_examples = proc.get_dev_examples()
        test_examples = proc.get_test_examples()
        print(train_examples[0])
        print()
