import os
import json
import csv
import numpy as np
from .utils import DataProcessor
from .utils import InputPairWiseExample, InputHeadExample, InputAbductiveExample


class ROCPairWiseProcessor(DataProcessor):
    """Processor for ROCStories Dataset, pair-wise data.

    Every pair (i, j) with i != j of sentences inside one story becomes an
    ``InputPairWiseExample`` labelled "ordered" or "unordered".

    Args:
        data_dir: string. Root directory for the dataset. Falls back to
            "data/roc" when None.
        order_criteria: The criteria of determining if a pair is ordered or not.
            "tight" means only strictly consecutive pairs are considered as
            ordered, "loose" means ancestors also count.
        min_story_length: minimum length of sequence for each. Clamped to be
            at least 1 and at most ``max_story_length``.
        max_story_length: maximum length of sequence for each. Clamped to be
            at least 1.
        caption_transforms: accepted but not used by this processor.
        **kwargs: accepted and ignored.
    """

    def __init__(self, data_dir=None, order_criteria="tight",
                 min_story_length=5, max_story_length=5, caption_transforms=None,
                 **kwargs):
        """Init."""
        self.data_dir = data_dir
        if self.data_dir is None:
            self.data_dir = "data/roc"
        assert order_criteria in ["tight", "loose"]
        self.order_criteria = order_criteria

        # Normalise the length bounds so that 1 <= min <= max always holds.
        min_story_length = max(1, min_story_length)
        max_story_length = max(1, max_story_length)
        min_story_length = min(min_story_length, max_story_length)
        self.min_story_length = min_story_length
        self.max_story_length = max_story_length

    def get_labels(self):
        """See base class."""
        return ["unordered", "ordered"]  # 0: unordered, 1: ordered.

    def _read_csv(self, csv_path, split="train"):
        """Reads in csv lines to create the dataset.

        Args:
            csv_path: path of the csv file to read.
            split: one of "train", "val" or "test"; selects which column
                names are read from each row.

        Returns:
            A list with one entry per csv row. Each entry is a list whose
            first item is the story id followed by the story sentences
            (five for "train" and "val", four for "test"). For "val" the
            fifth sentence is the ending column selected by the row's
            "AnswerRightEnding" value.

        Raises:
            ValueError: if `split` is not a known split name.
        """
        if split not in ("train", "val", "test"):
            raise ValueError("No such split: {}".format(split))

        story_seqs = []
        # The context manager guarantees the handle is closed; newline="" is
        # what the csv module expects so quoted line breaks are parsed right.
        with open(csv_path, newline="") as csv_file:
            for row in csv.DictReader(csv_file):
                if split == "train":
                    story_seq = [row["storyid"]]
                    for i in range(1, 6):
                        story_seq.append(row["sentence{}".format(i)])
                else:
                    story_seq = [row["InputStoryid"]]
                    for i in range(1, 5):
                        story_seq.append(row["InputSentence{}".format(i)])
                    if split == "val":
                        ending_key = "RandomFifthSentenceQuiz{}".format(
                            row["AnswerRightEnding"])
                        story_seq.append(row[ending_key])
                assert len(story_seq) >= 4 + 1  # 5 or 4 sentences, 1 for id.
                story_seqs.append(story_seq)

        return story_seqs

    def _create_examples(self, lines):
        """Creates examples for the training, dev and test sets.

        Args:
            lines: iterable of rows as produced by `_read_csv` (story id
                first, then the sentences). The rows are not modified.

        Returns:
            A list of `InputPairWiseExample`, one per pair of distinct
            sentence positions of every story. The guid is
            "<story_id>_<i><j>" with 1-based positions and `distance` is
            the absolute position difference.

        Raises:
            ValueError: if `self.order_criteria` is neither "tight" nor
                "loose".
        """
        paired_examples = []
        for story_seq in lines:
            # Slice instead of pop(0) so the caller's rows stay intact.
            story_id, sentences = story_seq[0], story_seq[1:]
            for i, text_a in enumerate(sentences):
                for j, text_b in enumerate(sentences):
                    if i == j:
                        continue
                    if self.order_criteria == "tight":
                        is_ordered = j == i + 1
                    elif self.order_criteria == "loose":
                        is_ordered = j > i
                    else:
                        raise ValueError("No such order criteria: {}".format(
                            self.order_criteria))
                    label = "ordered" if is_ordered else "unordered"
                    guid = "{}_{}{}".format(story_id, i + 1, j + 1)
                    example = InputPairWiseExample(guid=guid, text_a=text_a,
                                                   text_b=text_b, label=label,
                                                   distance=abs(j - i))
                    paired_examples.append(example)
        return paired_examples

    def _load_examples(self, data_dir, filename, split):
        """Reads `filename` under `data_dir` (or self.data_dir) into examples."""
        if data_dir is None:
            data_dir = self.data_dir
        lines = self._read_csv(os.path.join(data_dir, filename), split=split)
        return self._create_examples(lines)

    def get_train_examples(self, data_dir=None):
        """See base class."""
        return self._load_examples(data_dir, "train.csv", "train")

    def get_dev_examples(self, data_dir=None):
        """See base class."""
        return self._load_examples(data_dir, "cloze_test_val.csv", "val")

    def get_test_examples(self, data_dir=None):
        """See base class."""
        return self._load_examples(data_dir, "cloze_test_test.csv", "test")


class ROCAbductiveProcessor(DataProcessor):
    """Processor for ROCStories Dataset, abductive data.

    For every window of three consecutive sentences in a story, one
    "ordered" triple and several "unordered" triples (middle sentence
    replaced by a sentence from outside the window) are produced.

    Args:
        data_dir: string. Root directory for the dataset. Falls back to
            "data/roc" when None.
        pred_method: the method of the predictions, can be binary or
            contrastive. Only "binary" has labels defined in
            `_create_examples`.
        min_story_length: minimum length of sequence for each. Clamped to be
            at least 1 and at most ``max_story_length``.
        max_story_length: maximum length of sequence for each. Clamped to be
            at least 1.
        caption_transforms: accepted but not used by this processor.
        **kwargs: accepted and ignored.
    """

    def __init__(self, data_dir=None, pred_method="binary",
                 max_story_length=5, min_story_length=5, caption_transforms=None,
                 **kwargs):
        """Init."""
        self.data_dir = data_dir
        if self.data_dir is None:
            self.data_dir = "data/roc"
        assert pred_method in ["binary", "contrastive"]
        self.pred_method = pred_method

        # Normalise the length bounds so that 1 <= min <= max always holds.
        min_story_length = max(1, min_story_length)
        max_story_length = max(1, max_story_length)
        min_story_length = min(min_story_length, max_story_length)
        self.min_story_length = min_story_length
        self.max_story_length = max_story_length

    def get_labels(self):
        """See base class."""
        return ["unordered", "ordered"]  # 0: unordered, 1: ordered.

    def _read_csv(self, csv_path, split="train"):
        """Reads in csv lines to create the dataset.

        Args:
            csv_path: path of the csv file to read.
            split: one of "train", "val" or "test"; selects which column
                names are read from each row.

        Returns:
            A list with one entry per csv row. Each entry is a list whose
            first item is the story id followed by the story sentences
            (five for "train" and "val", four for "test"). For "val" the
            fifth sentence is the ending column selected by the row's
            "AnswerRightEnding" value.

        Raises:
            ValueError: if `split` is not a known split name.
        """
        if split not in ("train", "val", "test"):
            raise ValueError("No such split: {}".format(split))

        story_seqs = []
        # The context manager guarantees the handle is closed; newline="" is
        # what the csv module expects so quoted line breaks are parsed right.
        with open(csv_path, newline="") as csv_file:
            for row in csv.DictReader(csv_file):
                if split == "train":
                    story_seq = [row["storyid"]]
                    for i in range(1, 6):
                        story_seq.append(row["sentence{}".format(i)])
                else:
                    story_seq = [row["InputStoryid"]]
                    for i in range(1, 5):
                        story_seq.append(row["InputSentence{}".format(i)])
                    if split == "val":
                        ending_key = "RandomFifthSentenceQuiz{}".format(
                            row["AnswerRightEnding"])
                        story_seq.append(row[ending_key])
                assert len(story_seq) >= 4 + 1  # 5 or 4 sentences, 1 for id.
                story_seqs.append(story_seq)

        return story_seqs

    def _make_example(self, story_id, sentences, abd_idx, is_ordered):
        """Builds one InputAbductiveExample from three sentence indices.

        Raises:
            NotImplementedError: if `self.pred_method` is not "binary",
                because no label is defined for any other method.
        """
        if self.pred_method != "binary":
            # Previously this path hit an unassigned `label` variable.
            raise NotImplementedError(
                "Labels are only defined for pred_method='binary', got "
                "{!r}".format(self.pred_method))
        label = "ordered" if is_ordered else "unordered"
        guid = "{}_{}{}{}".format(story_id, abd_idx[0], abd_idx[1], abd_idx[2])
        return InputAbductiveExample(guid=guid, label=label,
                                     text_h1=sentences[abd_idx[0]],
                                     text_h2=sentences[abd_idx[1]],
                                     text_h3=sentences[abd_idx[2]])

    def _create_examples(self, lines):
        """Creates examples for the training, dev and test sets.

        Args:
            lines: iterable of rows as produced by `_read_csv` (story id
                first, then the sentences). The rows are not modified.

        Returns:
            A list of `InputAbductiveExample`. For each window start i the
            "unordered" triples come first (one per sentence index outside
            the window, ascending), followed by the "ordered" triple
            (i, i+1, i+2). Guids are "<story_id>_<a><b><c>" with 0-based
            indices.

        Raises:
            NotImplementedError: if an example has to be built while
                `self.pred_method` is not "binary".
        """
        abd_examples = []
        for story_seq in lines:
            # Slice instead of pop(0) so the caller's rows stay intact.
            story_id, sentences = story_seq[0], story_seq[1:]
            len_seq = len(sentences)
            for i in range(0, len_seq - 2):
                # Build from range so the index order is explicit rather
                # than relying on set iteration order.
                curr_seq_idx = list(range(i, i + 3))
                left_seq_idx = [k for k in range(len_seq)
                                if k not in curr_seq_idx]

                for k in left_seq_idx:
                    # NOTE(review): the third slot is the window's middle
                    # sentence (i+1), not its last one (i+2) as in the
                    # ordered triple. Kept as-is; confirm this is intended.
                    abd_idx = [curr_seq_idx[0], k, curr_seq_idx[1]]
                    abd_examples.append(self._make_example(
                        story_id, sentences, abd_idx, is_ordered=False))

                abd_examples.append(self._make_example(
                    story_id, sentences, curr_seq_idx, is_ordered=True))
        return abd_examples

    def _load_examples(self, data_dir, filename, split):
        """Reads `filename` under `data_dir` (or self.data_dir) into examples."""
        if data_dir is None:
            data_dir = self.data_dir
        lines = self._read_csv(os.path.join(data_dir, filename), split=split)
        return self._create_examples(lines)

    def get_train_examples(self, data_dir=None):
        """See base class."""
        return self._load_examples(data_dir, "train.csv", "train")

    def get_dev_examples(self, data_dir=None):
        """See base class."""
        return self._load_examples(data_dir, "cloze_test_val.csv", "val")

    def get_test_examples(self, data_dir=None):
        """See base class."""
        return self._load_examples(data_dir, "cloze_test_test.csv", "test")


class ROCGeneralProcessor(DataProcessor):
    """Processor for ROCStories Dataset, general sorting prediction.

    Each story becomes a single ``InputHeadExample`` holding the whole
    sentence sequence.

    Args:
        data_dir: string. Root directory for the dataset. Falls back to
            "data/roc" when None.
        max_story_length: maximum length of sequence for each. Clamped to be
            at least 1.
        pure_class: when True, `get_labels` returns one entry per
            permutation of ``max_story_length`` items instead of one entry
            per position.
        min_story_length: minimum length of sequence for each. Clamped to be
            at least 1 and at most ``max_story_length``.
        caption_transforms: accepted but not used by this processor.
        **kwargs: accepted and ignored.
    """

    def __init__(self, data_dir=None, max_story_length=5, pure_class=False,
                 min_story_length=5, caption_transforms=None,
                 **kwargs):
        """Init."""
        self.data_dir = data_dir
        if self.data_dir is None:
            self.data_dir = "data/roc"
        self.pure_class = pure_class

        # Normalise the length bounds so that 1 <= min <= max always holds.
        min_story_length = max(1, min_story_length)
        max_story_length = max(1, max_story_length)
        min_story_length = min(min_story_length, max_story_length)
        self.min_story_length = min_story_length
        self.max_story_length = max_story_length

    def get_labels(self):
        """See base class.

        Returns:
            With `pure_class` set: a list of ``max_story_length!`` zeros.
            Every entry is 0, so presumably only the list length is
            consumed downstream -- verify against the caller.
            Otherwise: ``[0, 1, ..., max_story_length - 1]``.
        """
        if self.pure_class:
            import math  # local import: only this branch needs it.
            return [0] * math.factorial(self.max_story_length)

        return list(range(self.max_story_length))

    def _read_csv(self, csv_path, split="train"):
        """Reads in csv lines to create the dataset.

        Args:
            csv_path: path of the csv file to read.
            split: one of "train", "val" or "test"; selects which column
                names are read from each row.

        Returns:
            A list with one entry per csv row. Each entry is a list whose
            first item is the story id followed by the story sentences
            (five for "train" and "val", four for "test"). For "val" the
            fifth sentence is the ending column selected by the row's
            "AnswerRightEnding" value.

        Raises:
            ValueError: if `split` is not a known split name.
        """
        if split not in ("train", "val", "test"):
            raise ValueError("No such split: {}".format(split))

        story_seqs = []
        # The context manager guarantees the handle is closed; newline="" is
        # what the csv module expects so quoted line breaks are parsed right.
        with open(csv_path, newline="") as csv_file:
            for row in csv.DictReader(csv_file):
                if split == "train":
                    story_seq = [row["storyid"]]
                    for i in range(1, 6):
                        story_seq.append(row["sentence{}".format(i)])
                else:
                    story_seq = [row["InputStoryid"]]
                    for i in range(1, 5):
                        story_seq.append(row["InputSentence{}".format(i)])
                    if split == "val":
                        ending_key = "RandomFifthSentenceQuiz{}".format(
                            row["AnswerRightEnding"])
                        story_seq.append(row[ending_key])
                assert len(story_seq) >= 4 + 1  # 5 or 4 sentences, 1 for id.
                story_seqs.append(story_seq)

        return story_seqs

    def _create_examples(self, lines):
        """Creates examples for the training, dev and test sets.

        Args:
            lines: iterable of rows as produced by `_read_csv` (story id
                first, then the sentences). The rows are not modified.

        Returns:
            A list of `InputHeadExample`, one per story, with the story id
            as guid and the list of sentences as `text_seq`.
        """
        head_examples = []
        for story_seq in lines:
            # Slice instead of pop(0) so the caller's rows stay intact.
            story_id, sentences = story_seq[0], story_seq[1:]
            head_examples.append(
                InputHeadExample(guid=story_id, text_seq=sentences))
        return head_examples

    def _load_examples(self, data_dir, filename, split):
        """Reads `filename` under `data_dir` (or self.data_dir) into examples."""
        if data_dir is None:
            data_dir = self.data_dir
        lines = self._read_csv(os.path.join(data_dir, filename), split=split)
        return self._create_examples(lines)

    def get_train_examples(self, data_dir=None):
        """See base class."""
        return self._load_examples(data_dir, "train.csv", "train")

    def get_dev_examples(self, data_dir=None):
        """See base class."""
        return self._load_examples(data_dir, "cloze_test_val.csv", "val")

    def get_test_examples(self, data_dir=None):
        """See base class."""
        return self._load_examples(data_dir, "cloze_test_test.csv", "test")


if __name__ == "__main__":
    # Smoke run: build each processor with its defaults and load the train,
    # dev and test examples, in the order pair-wise, general, abductive.
    for processor_cls in (ROCPairWiseProcessor, ROCGeneralProcessor,
                          ROCAbductiveProcessor):
        proc = processor_cls()
        train_examples = proc.get_train_examples()
        val_examples = proc.get_dev_examples()
        test_examples = proc.get_test_examples()
