import argparse
import csv
import glob
import json
import math
import os
import random
import sys

import numpy as np
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
from tqdm import tqdm

from .utils import DataProcessor
from .utils import InputPairWiseExample, InputHeadExample, InputAbductiveExample


# Default dataset root (relative to the working directory), used whenever a
# processor or helper below is not given an explicit directory.
WIKIHOW_DATA_ROOT = "data/wikihow"

# Keys of a step's "step_assets" dict that may hold an image path, tried in
# this order when pairing a step with its image.
IMAGE_FIELD_NAMES = ["image-large", "image-src-1"]


class WikiHowPairWiseProcessor(DataProcessor):
    """Processor for WikiHow Steps Dataset, pair-wise data.

    Args:
        data_dir: string. Root directory for the dataset; defaults to
            WIKIHOW_DATA_ROOT.
        order_criteria: The criteria of determining if a pair is ordered or not.
            "tight" means only strictly consecutive pairs are considered as
            ordered, "loose" means ancestors also count.
        paired_with_image: if True, a step is kept only when one of its
            IMAGE_FIELD_NAMES assets resolves to an existing file; other
            steps are dropped and recorded as missing images.
        min_story_length: minimum number of steps a section must keep.
        max_story_length: maximum number of steps kept per section; longer
            sections are truncated.
        caption_transforms: optional object whose ``transform`` method is
            applied to every step text.
        **kwargs: only ``version_text`` is read; when given, the
            ``wikihow-{version_text}-{split}.json`` files are used.
    """

    def __init__(self, data_dir=None, order_criteria="tight",
                 paired_with_image=True,
                 min_story_length=5, max_story_length=5,
                 caption_transforms=None, **kwargs):
        """Init."""
        self.data_dir = data_dir
        if self.data_dir is None:
            self.data_dir = WIKIHOW_DATA_ROOT
        assert order_criteria in ["tight", "loose"]
        self.order_criteria = order_criteria
        self.paired_with_image = paired_with_image

        # Clamp both lengths to >= 1 and make sure min <= max.
        max_story_length = max(1, max_story_length)
        min_story_length = min(max(1, min_story_length), max_story_length)
        self.min_story_length = min_story_length
        self.max_story_length = max_story_length

        self.caption_transforms = caption_transforms
        self.version_text = kwargs.get("version_text")

        # Set to True by _read_json as soon as a record carries
        # "multiref_gt"; sequences read afterwards are wrapped in dicts.
        self.multiref_gt = False

    def get_labels(self):
        """See base class."""
        return ["unordered", "ordered"]  # 0: unordered, 1: ordered.

    def _pair_step_with_image(self, step_curr, combined_text, missing_id,
                              missing_images):
        """Return ``(combined_text, image_path)`` for one step, or None.

        Each key of IMAGE_FIELD_NAMES present in ``step_curr["step_assets"]``
        is tried in order. The relative path is joined onto ``self.data_dir``
        and looked up under ``/www.wikihow.com/images/`` (unless it already
        contains "wikihow.com"), then under ``/wikihow.com/images/``. Every
        failed attempt appends ``missing_id`` to ``missing_images``. The
        search stops at the first key whose resolved path exists.
        """
        element = None
        step_assets = step_curr["step_assets"]
        for image_field_key in IMAGE_FIELD_NAMES:
            if image_field_key not in step_assets:
                continue
            image_path_curr = step_assets[image_field_key]
            image_path_curr_new = None
            if image_path_curr is not None and len(image_path_curr) > 0:
                # NOTE(review): joins onto self.data_dir even when _read_json
                # was called with a different data_dir — confirm intended.
                image_path_curr = os.path.join(self.data_dir, image_path_curr)
                if "wikihow.com" not in image_path_curr:
                    image_path_curr_new = image_path_curr.replace(
                        "/images/", "/www.wikihow.com/images/")
                else:
                    image_path_curr_new = image_path_curr
                if not os.path.exists(image_path_curr_new):
                    # Fall back to the mirror without the "www." prefix.
                    image_path_curr_new = image_path_curr.replace(
                        "/images/", "/wikihow.com/images/")
                if os.path.exists(image_path_curr_new):
                    element = (combined_text, image_path_curr_new)
                else:
                    missing_images.append(missing_id)
                    element = None
            else:
                missing_images.append(missing_id)
                element = None
            if (image_path_curr_new is not None
                    and os.path.exists(image_path_curr_new)):
                break
        return element

    def _read_json(self, data_dir=None, split="train"):
        """Read one split's json-lines file into a list of story sequences.

        Every section of every page becomes one candidate sequence
        ``[page_id, (text, image_path), ...]`` with ``page_id`` equal to
        ``"<url>###<section index>"``. A step's text is its "text" joined
        with its "bullet_points". A sequence is kept only if it has at least
        ``min_story_length`` steps; it is then truncated to
        ``max_story_length`` steps. While ``self.multiref_gt`` is set, kept
        sequences are wrapped as ``{"story_seq": seq, "multiref_gt": ...}``.

        Side effect: the ids of steps whose image could not be found are
        written to ``data/wikihow/missing_images_<split>.txt``.

        Args:
            data_dir: directory holding the split file; defaults to
                ``self.data_dir``.
            split: split name used in the file name ("train", "dev", "test").

        Returns:
            list of story sequences (lists, or dicts in multi-reference mode).

        Raises:
            ValueError: if ``version_text`` is set and its file is missing.
        """
        if data_dir is None:
            data_dir = self.data_dir

        if self.version_text is not None:
            json_path = os.path.join(
                data_dir,
                "wikihow-{}-".format(self.version_text) + split + ".json")
            if not os.path.exists(json_path):
                raise ValueError("File: {} not found!".format(json_path))
        else:
            json_path = os.path.join(data_dir, "wikihow-" + split + ".json")
        print("Using {}".format(json_path))

        with open(json_path) as json_file:
            data = [json.loads(line.strip()) for line in json_file]

        story_seqs = []
        missing_images = []

        # TODO: consistency of human test sets.
        # In this mode a section is kept only if the text before the first
        # "." of one of its steps matches that of a human-study record.
        human_filtered = self.version_text == "human_annot_only_filtered"
        human_check_dict = {}
        if human_filtered:
            human_json = "../wikihow-scrape/mturk/json_files/production_launches/wikihow_human_studies_picked.jsonl"
            with open(human_json) as human_f:
                for line in human_f:
                    dd = json.loads(line.strip())
                    check_key = dd["steps"][0]["text"].split(".")[0]
                    human_check_dict[check_key] = True

        # Each element in a story seq is (text, image) tuple.
        for data_raw in tqdm(data, total=len(data)):
            wikihow_url = data_raw["url"]

            # Multi-reference GTs.
            if "multiref_gt" in data_raw:
                self.multiref_gt = True

            for section_id, section_curr in enumerate(data_raw["sections"]):
                # TODO: GUID using url and section index for now.
                wikihow_page_id = "###".join([wikihow_url, str(section_id)])
                story_seq = [wikihow_page_id]

                include_data = not human_filtered

                for step_id, step_curr in enumerate(section_curr["steps"]):
                    step_text = step_curr["step_text"]["text"]
                    bullet_points = step_curr["step_text"]["bullet_points"]
                    combined_text = " ".join([step_text] + bullet_points)

                    if (human_filtered
                            and combined_text.split(".")[0] in human_check_dict):
                        include_data = True

                    if self.caption_transforms is not None:
                        combined_text = self.caption_transforms.transform(
                            combined_text)

                    if self.paired_with_image:
                        element = self._pair_step_with_image(
                            step_curr, combined_text,
                            wikihow_page_id + "###" + str(step_id),
                            missing_images)
                    else:
                        element = (combined_text, None)

                    if element is not None:
                        story_seq.append(element)

                # TODO: Currently different sections are in different
                # sequences for sorting.
                # story_seq[0] is the id, so N steps means length N + 1.
                if len(story_seq) < self.min_story_length + 1 or not include_data:
                    continue
                story_seq = story_seq[:self.max_story_length + 1]
                if self.multiref_gt:
                    # .get(): records without the key no longer raise KeyError.
                    story_seq = {
                        "story_seq": story_seq,
                        "multiref_gt": data_raw.get("multiref_gt"),
                    }
                story_seqs.append(story_seq)

        print("[WARNING] Number of missing images in {}: {}".format(
            split, len(missing_images)))
        missing_image_paths_f = "data/wikihow/missing_images_{}.txt".format(split)
        with open(missing_image_paths_f, "w") as missing_image_paths_file:
            missing_image_paths_file.writelines(
                missing_image_path + "\n" for missing_image_path in missing_images)
        print("          Saves at: {}".format(missing_image_paths_f))

        print("There are {} valid story sequences in {}".format(
              len(story_seqs), json_path))

        return story_seqs

    def _create_examples(self, lines):
        """Creates examples for the training, dev and test sets.

        For a sequence of N steps this builds one example for every ordered
        pair ``(i, j)`` with ``i != j`` (N * (N - 1) examples), with guid
        ``"<story_id>_<i+1><j+1>"``. The label follows ``order_criteria``:
        "tight" marks only ``j == i + 1`` as ordered, "loose" any ``j > i``.
        The input sequences are not modified.
        """
        paired_examples = []
        for story_seq in lines:
            # Check the item itself rather than self.multiref_gt: the flag
            # can flip part-way through a file, leaving earlier items as
            # plain lists.
            if isinstance(story_seq, dict):
                multiref_gt = story_seq["multiref_gt"]
                story_seq = story_seq["story_seq"]
            else:
                multiref_gt = None
            story_id, steps = story_seq[0], story_seq[1:]
            len_seq = len(steps)
            for i in range(len_seq):
                for j in range(len_seq):
                    if i == j:
                        continue
                    if self.order_criteria == "tight":
                        is_ordered = j == i + 1
                    else:
                        is_ordered = j > i
                    example = InputPairWiseExample(
                        guid="{}_{}{}".format(story_id, i + 1, j + 1),
                        text_a=steps[i][0],
                        text_b=steps[j][0],
                        label="ordered" if is_ordered else "unordered",
                        img_path_a=steps[i][1],
                        img_path_b=steps[j][1],
                        distance=abs(j - i),
                        multiref_gt=multiref_gt)
                    paired_examples.append(example)
        return paired_examples

    def get_train_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_json(data_dir=data_dir, split="train")
        return self._create_examples(lines)

    def get_dev_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_json(data_dir=data_dir, split="dev")
        return self._create_examples(lines)

    def get_test_examples(self, data_dir=None):
        """See base class."""
        lines = self._read_json(data_dir=data_dir, split="test")
        return self._create_examples(lines)


class WikiHowAbductiveProcessor(WikiHowPairWiseProcessor):
    """Processor for WikiHow Steps Dataset, abductive data.

    For every window of three consecutive steps ``(i, i+1, i+2)`` it emits
    one negative triple per step ``k`` outside the window (label
    "unordered"), followed by the window itself (label "ordered").

    Args:
        data_dir: string. Root directory for the dataset; defaults to
            WIKIHOW_DATA_ROOT.
        pred_method: the method of the predictions, can be "binary" or
            "contrastive".
        paired_with_image: will only consider sequence that have perfect image
            pairings.
        min_story_length: minimum length of sequence for each.
        max_story_length: maximum length of sequence for each.
        caption_transforms: optional object whose ``transform`` method is
            applied to every step text.
        version_text: optional dataset version selecting the
            ``wikihow-{version_text}-{split}.json`` files.

    ``get_train_examples`` / ``get_dev_examples`` / ``get_test_examples``
    are inherited from WikiHowPairWiseProcessor.
    """

    def __init__(self, data_dir=None, pred_method="binary",
                 paired_with_image=True,
                 min_story_length=5, max_story_length=5,
                 caption_transforms=None, version_text=None):
        """Init."""
        self.data_dir = data_dir
        if self.data_dir is None:
            self.data_dir = WIKIHOW_DATA_ROOT
        assert pred_method in ["binary", "contrastive"]
        self.pred_method = pred_method
        self.paired_with_image = paired_with_image

        # Clamp both lengths to >= 1 and make sure min <= max.
        max_story_length = max(1, max_story_length)
        min_story_length = min(max(1, min_story_length), max_story_length)
        self.min_story_length = min_story_length
        self.max_story_length = max_story_length

        self.caption_transforms = caption_transforms
        self.version_text = version_text

        self.multiref_gt = False

    def get_labels(self):
        """See base class."""
        return ["unordered", "ordered"]  # 0: unordered, 1: ordered.

    @staticmethod
    def _build_example(story_id, steps, abd_idx, label, multiref_gt):
        """Build one InputAbductiveExample from three step indices.

        ``text_h*`` always receive the step text (element 0) and
        ``img_path_h*`` the image path (element 1), for positive and negative
        triples alike. The guid is ``"<story_id>_<a><b><c>"`` with 0-based
        indices.
        """
        h1, h2, h3 = (steps[idx] for idx in abd_idx)
        return InputAbductiveExample(
            guid="{}_{}{}{}".format(story_id, abd_idx[0], abd_idx[1], abd_idx[2]),
            label=label,
            text_h1=h1[0],
            text_h2=h2[0],
            text_h3=h3[0],
            img_path_h1=h1[1],
            img_path_h2=h2[1],
            img_path_h3=h3[1],
            multiref_gt=multiref_gt)

    def _create_examples(self, lines):
        """Creates examples for the training, dev and test sets.

        The label ("ordered"/"unordered") is set for both prediction
        methods. NOTE(review): "contrastive" previously crashed on an unbound
        ``label``; confirm downstream expects these tags for it too.
        The input sequences are not modified.
        """
        abd_examples = []
        for story_seq in lines:
            # Check the item itself; self.multiref_gt can flip mid-file.
            if isinstance(story_seq, dict):
                multiref_gt = story_seq["multiref_gt"]
                story_seq = story_seq["story_seq"]
            else:
                multiref_gt = None
            story_id, steps = story_seq[0], story_seq[1:]
            len_seq = len(steps)
            for i in range(len_seq - 2):
                # An explicit list (not a set) guarantees the index order.
                window = [i, i + 1, i + 2]
                for k in range(len_seq):
                    if k in window:
                        continue
                    # NOTE(review): negatives are (i, k, i+1), so their last
                    # step differs from the window's (i+2). If only the middle
                    # step should be replaced, use window[2] — confirm.
                    abd_examples.append(self._build_example(
                        story_id, steps, [window[0], k, window[1]],
                        "unordered", multiref_gt))
                abd_examples.append(self._build_example(
                    story_id, steps, window, "ordered", multiref_gt))
        return abd_examples


class WikiHowGeneralProcessor(WikiHowPairWiseProcessor):
    """Processor for WikiHow Steps Dataset, general sorting prediction.

    Each kept section becomes one InputHeadExample holding the full list of
    step texts and image paths.

    Args:
        data_dir: string. Root directory for the dataset; defaults to
            WIKIHOW_DATA_ROOT.
        max_story_length: maximum length of sequence for each.
        pure_class: if True, ``get_labels`` returns one entry per ordering of
            ``max_story_length`` steps.
        paired_with_image: will only consider sequence that have perfect image
            pairings.
        min_story_length: minimum length of sequence for each.
        caption_transforms: optional object whose ``transform`` method is
            applied to every step text.
        version_text: optional dataset version selecting the
            ``wikihow-{version_text}-{split}.json`` files.

    ``get_train_examples`` / ``get_dev_examples`` / ``get_test_examples``
    are inherited from WikiHowPairWiseProcessor.
    """

    def __init__(self, data_dir=None, max_story_length=5, pure_class=False,
                 paired_with_image=True, min_story_length=5,
                 caption_transforms=None, version_text=None):
        """Init."""
        self.data_dir = data_dir
        if self.data_dir is None:
            self.data_dir = WIKIHOW_DATA_ROOT
        self.pure_class = pure_class
        self.paired_with_image = paired_with_image

        # Clamp both lengths to >= 1 and make sure min <= max.
        max_story_length = max(1, max_story_length)
        min_story_length = min(max(1, min_story_length), max_story_length)
        self.min_story_length = min_story_length
        self.max_story_length = max_story_length

        self.caption_transforms = caption_transforms
        self.version_text = version_text

        self.multiref_gt = False

    def get_labels(self):
        """Return the label list.

        With ``pure_class`` this is ``max_story_length!`` placeholder zeros
        (the number of orderings of that many steps). Otherwise it is the
        step positions ``0 .. max_story_length - 1``.
        """
        if self.pure_class:
            return [0] * math.factorial(self.max_story_length)
        return list(range(self.max_story_length))

    def _create_examples(self, lines):
        """Creates examples for the training, dev and test sets.

        The input sequences are not modified.
        """
        head_examples = []
        for story_seq in lines:
            # Check the item itself; self.multiref_gt can flip mid-file.
            if isinstance(story_seq, dict):
                multiref_gt = story_seq["multiref_gt"]
                story_seq = story_seq["story_seq"]
            else:
                multiref_gt = None
            story_id, steps = story_seq[0], story_seq[1:]
            head_examples.append(InputHeadExample(
                guid=story_id,
                text_seq=[step[0] for step in steps],
                img_path_seq=[step[1] for step in steps],
                multiref_gt=multiref_gt))
        return head_examples


def split_wikihow_json_files(data_dir, out_dir=None, train_ratio=0.9,
                             dev_ratio=0.05, test_ratio=0.05):
    """Shuffle all ``postprocessed-wikihow-*.json`` files into train/dev/test.

    Records from every matching file in ``data_dir`` (read in sorted file
    order) are shuffled with ``random.seed(42)``. They are written as json
    lines to ``wikihow-train.json``, ``wikihow-dev.json`` and
    ``wikihow-test.json`` in ``out_dir`` (defaults to ``data_dir``). The test
    split receives every record left after train and dev.

    Raises:
        AssertionError: if the three ratios do not sum to 1 (within float
            tolerance).
        ValueError: if no records were found.
        SystemExit: always, after writing — the process exits with status -1.
    """
    # isclose: an exact == rejects valid ratios such as 0.7 / 0.2 / 0.1.
    assert math.isclose(train_ratio + dev_ratio + test_ratio, 1.0), (
        "Split rations do not sum up to 1!")
    if out_dir is None:
        out_dir = data_dir
    json_paths = glob.glob(os.path.join(
        data_dir, "postprocessed-wikihow-*.json"))

    data = []
    for json_path in sorted(json_paths):
        with open(json_path) as fin:
            data.extend(json.loads(line.strip()) for line in fin)
    if not data:
        raise ValueError("No postprocessed-wikihow-*.json records found in "
                         "{}".format(data_dir))

    line_cnt = len(data)
    print("Total line counts: {}".format(line_cnt))
    random.seed(42)
    print(data[0]["url"])
    random.shuffle(data)
    print(data[0]["url"])

    train_cnt = int(float(line_cnt) * train_ratio)
    dev_cnt = int(float(line_cnt) * dev_ratio)
    train_data = data[:train_cnt]
    dev_data = data[train_cnt:train_cnt + dev_cnt]
    test_data = data[train_cnt + dev_cnt:]
    print("WikiHow Train: {}".format(len(train_data)))
    print("WikiHow Dev:   {}".format(len(dev_data)))
    print("WikiHow Test:  {}".format(len(test_data)))

    for split_name, split_data in (("train", train_data),
                                   ("dev", dev_data),
                                   ("test", test_data)):
        out_path = os.path.join(out_dir, "wikihow-{}.json".format(split_name))
        with open(out_path, "w") as outf:
            outf.writelines(json.dumps(d) + "\n" for d in split_data)

    # Leave.
    sys.exit(-1)


def split_wrt_genres(data_dir, filter_keywords=None, split="train",
                     save_term="physical", out_dir="."):
    """Split one WikiHow split into "social" and "physical" pages by title.

    A page is "social" when any keyword equals one of its lower-cased,
    word-tokenized title tokens or one of their Porter stems. Every other
    page is "physical". The group selected by ``save_term`` ("physical"
    selects the physical pages, any other value the social pages) is written
    as json lines to ``<out_dir>/wikihow-<save_term>-<split>.json``.

    Args:
        data_dir: directory holding ``wikihow-<split>.json``; defaults to
            WIKIHOW_DATA_ROOT when None.
        filter_keywords: keywords marking a social page. When empty or None,
            they are read one per line from
            ``datasets/filter_genres_keywords_wikihow.txt``. The caller's list
            is never modified.
        split: split name used in the input/output file names.
        save_term: which group to save; also used in the output file name.
        out_dir: output directory.

    Raises:
        SystemExit: always, after writing — the process exits with status -1.
    """
    if data_dir is None:
        data_dir = WIKIHOW_DATA_ROOT
    json_path = os.path.join(data_dir, "wikihow-" + split + ".json")
    json_out_path = os.path.join(out_dir, "wikihow-{}-{}.json".format(save_term, split))

    with open(json_path) as json_file:
        data = [json.loads(line.strip()) for line in json_file]

    # Stemmer.
    ps = PorterStemmer()

    # Build a fresh list: the old mutable default ([]) was appended to here
    # and so kept its keywords across calls.
    if not filter_keywords:
        with open("datasets/filter_genres_keywords_wikihow.txt") as fi:
            filter_keywords = [line.strip() for line in fi]
    print(filter_keywords)

    social_context_data = []
    physical_context_data = []

    for data_raw in tqdm(data, total=len(data)):
        title_tokens = word_tokenize(data_raw["title"].lower())
        candidate_tokens = set(title_tokens)
        candidate_tokens.update(ps.stem(w) for w in title_tokens)
        if any(keyword in candidate_tokens for keyword in filter_keywords):
            social_context_data.append(data_raw)
        else:
            physical_context_data.append(data_raw)

    # Show a few sample titles from each group.
    print_max = 10
    print("##### Physical: {} #####".format(len(physical_context_data)))
    for datum in physical_context_data[:print_max]:
        print(datum["title"])
    print()

    print("##### Social: {} #####".format(len(social_context_data)))
    for datum in social_context_data[:print_max]:
        print(datum["title"])

    if save_term == "physical":
        save_data = physical_context_data
    else:
        save_data = social_context_data

    with open(json_out_path, "w") as outf:
        outf.writelines(json.dumps(d) + "\n" for d in save_data)

    # Leave.
    sys.exit(-1)


# Get category mappings
def read_in_wikihow_categories(cat_path=None, cat_level=1):
    """Read the WikiHow category json-lines file into two lookup tables.

    Each line is a record with a "url" and a "categories" list whose items
    carry a "category title". A page's category is the entry at index
    ``cat_level`` when it exists. Otherwise it is the last entry when there
    are at least two, and "Root" when there are fewer. Blank lines are
    skipped.

    Args:
        cat_path: path of the file; defaults to
            ``<WIKIHOW_DATA_ROOT>/wikihow-categories-output.json``.
        cat_level: index into each record's "categories" list.

    Returns:
        ``(url2cat, cat2url)``. The first maps url to category title. The
        second maps category title to the list of urls, in file order.
    """
    if cat_path is None:
        json_f = os.path.join(WIKIHOW_DATA_ROOT, "wikihow-categories-output.json")
    else:
        json_f = cat_path
    url2cat = {}
    cat2url = {}
    with open(json_f, "r") as json_in:
        for line in json_in:
            line = line.strip()
            if not line:
                continue
            cat = json.loads(line)
            categories = cat["categories"]
            if len(categories) - 1 >= cat_level:
                cat_level_desc = categories[cat_level]["category title"]
            elif len(categories) - 1 >= 1:
                cat_level_desc = categories[-1]["category title"]
            else:
                cat_level_desc = "Root"
            url2cat[cat["url"]] = cat_level_desc
            cat2url.setdefault(cat_level_desc, []).append(cat["url"])
    return url2cat, cat2url


def human_annotated_to_test(data_dir, out_dir=None, train_ratio=0.9,
                            dev_ratio=0.05, test_ratio=0.05,
                            categories_to_exclude=None):
    """Build train/dev/test splits that hold human-annotated pages out of train.

    Command-line arguments are parsed from ``sys.argv`` inside this function:
        --human_annotated_json_files: one or more jsonl files of human
            annotations (required).
        --human_annotated_version: name used in the output file names.

    Pages from ``postprocessed-wikihow-*.json`` in ``data_dir`` whose
    ``"<url>###<title>"`` matches the first two "###" fields of a human
    record's "guid" are held out. They are appended to the test split and
    also written on their own. Without ``categories_to_exclude`` the remaining
    pages are split by ratio. With it, pages that share a level-5 category
    with the human pages form the dev/test pool and every other page goes to
    train. If the human pages yield no such categories, pages whose level-1
    category is in ``categories_to_exclude`` (or that have no category) go to
    train and the rest form the pool. The RNG is seeded with 42 and the
    sequence of shuffles is fixed, so the splits are reproducible.

    Output files in ``out_dir`` (defaults to ``data_dir``):
    ``wikihow-<version>-{train,dev,test}.json`` and
    ``wikihow-<version>_only-test.json``.

    Raises:
        AssertionError: on missing CLI files, bad ratios, a too-small dev/test
            pool, or any dev/test/human url that also appears in train.
        ValueError: if no non-human records were found.
    """
    random.seed(42)

    # Arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--human_annotated_json_files',
        type=str,
        default=None,
        nargs="+",
        help='The jsonl files used for human annotations.'
    )
    parser.add_argument(
        '--human_annotated_version',
        type=str,
        default="human_annot",
        help='The name for the output files.'
    )
    args = parser.parse_args()

    # Read in the human jsonl files.
    assert args.human_annotated_json_files is not None

    human_annotated_dats = {}
    duplicating_cnt = 0
    for human_annotated_json_file in args.human_annotated_json_files:
        with open(human_annotated_json_file, "r") as inf:
            for line in inf:
                datum = json.loads(line.strip())
                key = "###".join(datum["guid"].split("###")[0:2])
                if key in human_annotated_dats:
                    print("Duplicating key: {}".format(key))
                    duplicating_cnt += 1
                human_annotated_dats[key] = datum
    print("Duplicating human test count: {}".format(duplicating_cnt))

    # Read in data json files.
    json_paths = glob.glob(os.path.join(
        data_dir, "postprocessed-wikihow-*.json"))
    data = []
    human_data = []
    for json_path in sorted(json_paths):
        print("Processing json file: {}".format(json_path))
        with open(json_path) as fin:
            for line in fin:
                d = json.loads(line.strip())
                curr_id = "###".join([d["url"], d["title"]])
                if curr_id in human_annotated_dats:
                    human_data.append(d)
                else:
                    data.append(d)
    line_cnt = len(data)
    human_line_cnt = len(human_data)

    print("Total line counts: {}".format(line_cnt))
    print("Total human line counts: {}".format(human_line_cnt))
    if not data:
        raise ValueError("No non-human WikiHow records found in {}".format(data_dir))
    print(data[0]["url"])
    if categories_to_exclude is None:
        # Shuffled again in the ratio split below; both shuffles are kept so
        # existing seeded splits stay reproducible.
        random.shuffle(data)
    print(data[0]["url"])

    url2cat, cat2url = read_in_wikihow_categories()
    url2cat_split, cat2url_split = read_in_wikihow_categories(cat_level=5)
    total = 0
    for cat in sorted(cat2url):
        total += len(cat2url[cat])
        print("Category: {}  Num of Data: {}".format(cat, len(cat2url[cat])))
    print(total)

    # Data splits.
    assert math.isclose(train_ratio + dev_ratio + test_ratio, 1.0), (
        "Split rations do not sum up to 1!")
    if out_dir is None:
        out_dir = data_dir
    train_cnt = int(float(line_cnt) * train_ratio)
    dev_cnt = int(float(line_cnt) * dev_ratio)
    test_cnt = int(float(line_cnt) * test_ratio)

    if categories_to_exclude is not None:
        all_root_cats = list(cat2url.keys())
        human_root_cats = []
        human_sub_cats = []
        for datum in human_data:
            url = datum["url"]
            if url in url2cat:
                # Check before printing: the print indexes url2cat_split.
                assert url in url2cat_split
                print(url2cat[url], "::", url2cat_split[url])
                assert url2cat[url] not in categories_to_exclude
                if url2cat[url] not in human_root_cats:
                    human_root_cats.append(url2cat[url])
                if url2cat_split[url] not in human_sub_cats:
                    human_sub_cats.append(url2cat_split[url])
        not_in_human_root_cats = sorted(list(set(all_root_cats)-set(human_root_cats)))
        print("Root Categories in human data:     {}".format(human_root_cats))
        print("Root Categories not in human data: {}".format(not_in_human_root_cats))
        print("Sub-Categories in human data:      {}".format(human_sub_cats))

        exclude_urls_for_humans = []
        for cat in human_sub_cats:
            exclude_urls_for_humans += cat2url_split[cat]
        print("Excluding {} URLs for human data.".format(len(exclude_urls_for_humans)))
        # Set for O(1) membership; the list above only feeds the printed count.
        exclude_url_set = set(exclude_urls_for_humans)

        excl_data = []
        kept_data = []
        for datum in data:
            url = datum["url"]
            if exclude_url_set:
                # Pages sharing a sub-category with human data are reserved
                # for dev/test; everything else goes to train.
                if url in exclude_url_set:
                    kept_data.append(datum)
                else:
                    url2cat.setdefault(url, "None")
                    excl_data.append(datum)
            elif url in url2cat:
                if url2cat[url] in categories_to_exclude:
                    excl_data.append(datum)
                else:
                    kept_data.append(datum)
            else:
                url2cat[url] = "None"
                excl_data.append(datum)

        print("Category excluded data: {}".format(len(excl_data)))
        print("Category kept data: {}".format(len(kept_data)))

        assert len(kept_data) >= dev_cnt + test_cnt

        random.shuffle(kept_data)
        dev_data = kept_data[:dev_cnt]
        test_data = kept_data[dev_cnt:dev_cnt+test_cnt]
        to_train_data = kept_data[dev_cnt+test_cnt:]
        train_data = excl_data + to_train_data
        random.shuffle(train_data)

        dev_cats = sorted(list(set([url2cat[d["url"]] for d in dev_data])))
        test_cats = sorted(list(set([url2cat[d["url"]] for d in test_data])))
        train_cats = sorted(list(set([url2cat[d["url"]] for d in train_data])))

        print("Train Categories: {}".format(train_cats))
        print("Dev   Categories: {}".format(dev_cats))
        print("Test  Categories: {}".format(test_cats))

    else:
        random.shuffle(data)
        train_data = data[:train_cnt]
        dev_data = data[train_cnt:train_cnt+dev_cnt]
        test_data = data[train_cnt+dev_cnt:]

    dev_urls = [d["url"] for d in dev_data]
    test_urls = [d["url"] for d in test_data]
    human_urls = [d["url"] for d in human_data]
    train_url_set = set(d["url"] for d in train_data)

    # Leakage check: no evaluation page may also appear in train.
    for url_list in (dev_urls, test_urls, human_urls):
        for url in url_list:
            assert url not in train_url_set, "url: {} is in train!".format(url)

    print("WikiHow Train:  {}".format(len(train_data)))
    print("WikiHow Dev:    {}".format(len(dev_data)))
    print("WikiHow Test:   {}".format(len(test_data)))
    test_data = test_data + human_data
    print("WikiHow Test-H: {}".format(len(test_data)))

    version = args.human_annotated_version
    outputs = (
        ("wikihow-{}-train.json", train_data),
        ("wikihow-{}-dev.json", dev_data),
        ("wikihow-{}-test.json", test_data),
        ("wikihow-{}_only-test.json", human_data),
    )
    for name_fmt, records in outputs:
        with open(os.path.join(out_dir, name_fmt.format(version)), "w") as outf:
            outf.writelines(json.dumps(d) + "\n" for d in records)


def output_to_tsv(out_dir):
    """Export the WikiHow splits as tokenized TSV files.

    For every split, writes ``<out_dir>/<split>.tsv`` where each line is one
    example: its sentences lower-cased, NLTK word-tokenized, re-joined with
    single spaces, and separated from each other by `` <eos> ``. For the
    splits whose name contains ``"test"`` (``test`` and ``human_test``) it
    also writes ``<out_dir>/<split>_examples.json`` with one JSON object per
    line, ``{"url": <guid before '###'>, "id": <guid>}``, aligned
    line-by-line with the TSV.

    The train/dev/test splits come from the ``human_annot`` version with the
    ``train_max_sentence_5`` caption transformation; ``human_test`` comes
    from the ``human_annot_only_filtered`` version with
    ``eval_max_sentence_5``.

    Args:
        out_dir: Output directory; created if it does not exist.

    Note:
        Terminates the interpreter via ``sys.exit(-1)`` once all files are
        written, so it never returns to its caller.
    """
    # Local imports: the caption utilities live in the trainers package and
    # contextlib is only needed by this function.
    import contextlib
    from trainers.caption_utils import CaptionTransformations

    # NOTE(review): CaptionTransformations receives args=None here;
    # presumably it needs no parsed CLI args for these transformations.
    args, task = None, "wikihow"

    train_transforms = CaptionTransformations(
        args, task, caption_transformation_list=["train_max_sentence_5"])
    proc = WikiHowGeneralProcessor(data_dir=WIKIHOW_DATA_ROOT,
                                   version_text="human_annot",
                                   caption_transforms=train_transforms)
    eval_transforms = CaptionTransformations(
        args, task, caption_transformation_list=["eval_max_sentence_5"])
    proc_human = WikiHowGeneralProcessor(
        data_dir=WIKIHOW_DATA_ROOT,
        version_text="human_annot_only_filtered",
        caption_transforms=eval_transforms)

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(out_dir, exist_ok=True)

    # List elements are evaluated left to right, preserving the original
    # order of the get_*_examples() calls.
    all_examples = [
        ("train", proc.get_train_examples()),
        ("dev", proc.get_dev_examples()),
        ("test", proc.get_test_examples()),
        ("human_test", proc_human.get_test_examples()),
    ]

    for split, examples in all_examples:
        tsv_path = os.path.join(out_dir, "{}.tsv".format(split))
        write_ids = "test" in split
        # ExitStack guarantees every opened file is closed, even if
        # tokenization or serialization raises partway through a split.
        with contextlib.ExitStack() as stack:
            out_tsv = stack.enter_context(
                open(tsv_path, "w", encoding="utf-8"))
            out_json = None
            if write_ids:
                json_path = os.path.join(
                    out_dir, "{}_examples.json".format(split))
                out_json = stack.enter_context(
                    open(json_path, "w", encoding="utf-8"))

            for example in tqdm(examples, desc="idx"):
                new_sents = [" ".join(word_tokenize(sent.lower()))
                             for sent in example.text_seq]
                out_tsv.write(" <eos> ".join(new_sents) + "\n")

                if out_json is not None:
                    d = {
                        "url": example.guid.split("###")[0],
                        "id": example.guid,
                    }
                    out_json.write(json.dumps(d) + "\n")
        print("Writing files to {}".format(tsv_path))

    # Leave: this is a one-off export step; stop the script here.
    sys.exit(-1)


if __name__ == "__main__":
    # One-off preprocessing steps; uncomment to re-run them.
    # split_wikihow_json_files(WIKIHOW_DATA_ROOT, "data/wikihow")

    # split_wrt_genres(WIKIHOW_DATA_ROOT, split="train", save_term="physical",
    #                  out_dir="data/wikihow")

    # Categories handed to human_annotated_to_test as categories_to_exclude.
    # The commented-out entries are candidates that are currently NOT
    # excluded.
    excluded_categories = [
        "Relationships",
        "Philosophy and Religion",
        "Screenplays",
        "Root",
        # "Family Life",
        # "Holidays and Traditions",
        # "Work World",
        # "Health",
        # "Youth",
        # "Personal Care and Style",
    ]

    human_annotated_to_test(WIKIHOW_DATA_ROOT, "data/wikihow",
                            categories_to_exclude=excluded_categories)

    # output_to_tsv(out_dir="../prior_works/berson_roc/glue_data_new/wikihow")

    # Pair-wise processor: train/dev from the default version, test from the
    # "human_annot_only" version.
    pairwise = WikiHowPairWiseProcessor(data_dir=WIKIHOW_DATA_ROOT)
    pairwise_train = pairwise.get_train_examples()
    pairwise_dev = pairwise.get_dev_examples()
    pairwise_human = WikiHowPairWiseProcessor(
        data_dir=WIKIHOW_DATA_ROOT, version_text="human_annot_only")
    pairwise_test = pairwise_human.get_test_examples()
    print(pairwise_test[0])
    print()

    # General processor: show the first test example, then one random test
    # example sentence by sentence.
    general = WikiHowGeneralProcessor(data_dir=WIKIHOW_DATA_ROOT)
    general_train = general.get_train_examples()
    general_dev = general.get_dev_examples()
    general_test = general.get_test_examples()
    print(general_test[0])
    sample_idx = np.random.randint(0, len(general_test))
    sample = general_test[sample_idx]
    separator = "-" * 50
    print("\nData: {}".format(sample_idx))
    print(separator)
    for sentence in sample.text_seq:
        print(sentence)
        print(separator)
    print()

    # Abductive processor: show the first test example.
    abductive = WikiHowAbductiveProcessor(data_dir=WIKIHOW_DATA_ROOT)
    abductive_train = abductive.get_train_examples()
    abductive_dev = abductive.get_dev_examples()
    abductive_test = abductive.get_test_examples()
    print(abductive_test[0])
    print()
