# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This is the Data Loading Pipeline for Sentence Classifier Task adapted from:
    `https://github.com/google-research/bert/blob/master/run_classifier.py`
"""

import csv
import logging
import os

import texar.torch as tx


class InputExample:
    r"""One raw (untokenized) example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        r"""Stores the given fields on the instance, unchanged.

        Args:
            guid: Unique id for the example.
            text_a: string. Untokenized text of the first sequence. This is
                the only text required for single sequence tasks.
            text_b: (Optional) string. Untokenized text of the second
                sequence; only needed for sequence pair tasks.
            label: (Optional) string. Label of the example. Expected for
                train and dev examples, but not for test examples.
        """
        (self.guid,
         self.text_a,
         self.text_b,
         self.label) = (guid, text_a, text_b, label)


class InputFeatures:
    r"""Holds the feature fields produced for a single example.

    The four constructor arguments are stored as same-named attributes
    without any conversion or validation.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        (self.input_ids,
         self.input_mask,
         self.segment_ids,
         self.label_id) = (input_ids, input_mask, segment_ids, label_id)


class DataProcessor:
    r"""Base class for data converters for sequence classification data sets.

    Subclasses are expected to override the ``get_*_examples`` methods; the
    versions defined here only raise :exc:`NotImplementedError`.
    """

    def get_train_examples(self, data_dir):
        r"""Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        r"""Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        r"""Gets a collection of `InputExample`s for prediction."""
        raise NotImplementedError()

    def get_labels(self):
        r"""Gets the list of labels for this data set.

        The base implementation returns the binary labels ``["0", "1"]``.
        """
        return ["0", "1"]

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        r"""Reads a tab separated value file.

        Args:
            input_file: Path of the file to read.
            quotechar: (Optional) Quote character forwarded to
                :func:`csv.reader`.

        Returns:
            A list with one entry per row, each entry being the list of
            string fields of that row.
        """
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default encoding, and pass ``newline=""`` as the csv
        # module recommends for files handed to ``csv.reader``.
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            return list(reader)


class SSTProcessor(DataProcessor):
    r"""Processor for the SST single-sentence data set.

    The binary labels and the ``train.tsv`` / ``dev.tsv`` / ``test.tsv`` file
    names suggest the GLUE SST-2 release (presumably; confirm against the
    data actually used).
    """

    def get_train_examples(self, data_dir):
        r"""See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        r"""See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        r"""See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        r"""See base class."""
        return ["0", "1"]

    @staticmethod
    def _create_examples(lines, set_type):
        r"""Creates examples for the train, dev and test sets.

        Args:
            lines: Rows of the TSV file, header row first. For ``"train"``
                and ``"dev"`` the text is read from column 0 and the label
                from column 1; for ``"test"`` the text is read from column 1.
            set_type: One of ``"train"``, ``"dev"`` or ``"test"``. Any other
                value produces an empty list.

        Returns:
            A list of `InputExample`s with ``text_b`` set to `None`. Test
            examples carry the placeholder label ``"0"``.
        """
        if set_type not in ("train", "dev", "test"):
            # Unknown splits yield no examples rather than raising.
            return []

        is_test = set_type == "test"
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                # The first row is the header.
                continue
            if is_test:
                text_a = tx.utils.compat_as_text(line[1])
                label = "0"  # arbitrary placeholder, test rows have no label
            else:
                text_a = tx.utils.compat_as_text(line[0])
                label = tx.utils.compat_as_text(line[1])
            # Single sentence classification, so there is no text_b.
            examples.append(InputExample(guid=f"{set_type}-{i}",
                                         text_a=text_a,
                                         text_b=None, label=label))
        return examples


class XnliProcessor(DataProcessor):
    r"""Processor for the XNLI data set.

    Only the train and dev splits are implemented here;
    ``get_test_examples`` is not overridden.
    """

    def __init__(self, language="zh"):
        r"""Constructs the processor.

        Args:
            language: Language code used to pick the training file
                (``multinli.train.<language>.tsv``) and to filter the rows
                of the dev file. Defaults to ``"zh"``.
        """
        self.language = language

    def get_train_examples(self, data_dir):
        r"""See base class.

        Reads ``multinli/multinli.train.<language>.tsv`` under `data_dir`,
        skipping the header row. Columns 0 and 1 hold the two texts and
        column 2 the label.
        """
        lines = self._read_tsv(
            os.path.join(data_dir, "multinli",
                         f"multinli.train.{self.language}.tsv"))
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = f"train-{i}"
            text_a = tx.utils.compat_as_text(line[0])
            text_b = tx.utils.compat_as_text(line[1])
            label = tx.utils.compat_as_text(line[2])
            # Map the training file's spelling onto the name returned by
            # `get_labels`.
            if label == "contradictory":
                label = "contradiction"
            examples.append(InputExample(guid=guid, text_a=text_a,
                                         text_b=text_b, label=label))
        return examples

    def get_dev_examples(self, data_dir):
        r"""See base class.

        Reads ``xnli.dev.tsv`` under `data_dir`, skipping the header row and
        every row whose column 0 differs from the configured language.
        Columns 6 and 7 hold the two texts and column 1 the label.
        """
        lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
        # Loop-invariant, so convert the configured language only once.
        wanted_language = tx.utils.compat_as_text(self.language)
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = f"dev-{i}"
            language = tx.utils.compat_as_text(line[0])
            if language != wanted_language:
                continue
            text_a = tx.utils.compat_as_text(line[6])
            text_b = tx.utils.compat_as_text(line[7])
            label = tx.utils.compat_as_text(line[1])
            examples.append(InputExample(guid=guid, text_a=text_a,
                                         text_b=text_b, label=label))
        return examples

    def get_labels(self):
        r"""See base class."""
        return ["contradiction", "entailment", "neutral"]


class MnliProcessor(DataProcessor):
    r"""Processor for the MultiNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        r"""See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        r"""See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev_matched.tsv"))
        return self._create_examples(rows, "dev_matched")

    def get_test_examples(self, data_dir):
        r"""See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test_matched.tsv"))
        return self._create_examples(rows, "test")

    def get_labels(self):
        r"""See base class."""
        return ["contradiction", "entailment", "neutral"]

    @staticmethod
    def _create_examples(lines, set_type):
        r"""Builds `InputExample`s from the rows of an MNLI TSV file.

        The first row is treated as a header and skipped. For every other
        row, column 0 supplies the id used in the guid, and columns 8 and 9
        supply the two texts. The label is the last column, except for
        ``set_type == "test"`` where the placeholder ``"contradiction"`` is
        used instead.
        """
        is_test = set_type == "test"
        rows = iter(lines)
        next(rows, None)  # drop the header row, if any

        examples = []
        for row in rows:
            guid = f"{set_type}-{tx.utils.compat_as_text(row[0])}"
            premise = tx.utils.compat_as_text(row[8])
            hypothesis = tx.utils.compat_as_text(row[9])
            label = ("contradiction" if is_test
                     else tx.utils.compat_as_text(row[-1]))
            examples.append(InputExample(guid=guid, text_a=premise,
                                         text_b=hypothesis, label=label))
        return examples


class MrpcProcessor(DataProcessor):
    r"""Processor for the MRPC data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        r"""See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        r"""See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        r"""See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(rows, "test")

    def get_labels(self):
        r"""See base class."""
        return ["0", "1"]

    @staticmethod
    def _create_examples(lines, set_type):
        r"""Builds `InputExample`s from the rows of an MRPC TSV file.

        The first row is treated as a header and skipped. The two texts come
        from columns 3 and 4. The label comes from column 0, except for
        ``set_type == "test"`` where the placeholder ``"0"`` is used. The
        guid is ``"<set_type>-<row index>"`` with the header counted as
        row 0.
        """
        is_test = set_type == "test"
        rows = iter(lines)
        next(rows, None)  # drop the header row, if any

        examples = []
        # Start at 1 so that indices keep counting the header as row 0.
        for row_no, row in enumerate(rows, start=1):
            first = tx.utils.compat_as_text(row[3])
            second = tx.utils.compat_as_text(row[4])
            label = "0" if is_test else tx.utils.compat_as_text(row[0])
            examples.append(InputExample(guid=f"{set_type}-{row_no}",
                                         text_a=first, text_b=second,
                                         label=label))
        return examples


class ColaProcessor(DataProcessor):
    r"""Processor for the CoLA data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        r"""See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        r"""See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        r"""See base class."""
        rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(rows, "test")

    def get_labels(self):
        r"""See base class."""
        return ["0", "1"]

    @staticmethod
    def _create_examples(lines, set_type):
        r"""Builds `InputExample`s from the rows of a CoLA TSV file.

        For ``set_type == "test"`` the first row is skipped as a header, the
        text is read from column 1 and the placeholder label ``"0"`` is
        used. For any other `set_type` every row is used, with the text
        taken from column 3 and the label from column 1. ``text_b`` is
        always `None`.
        """
        is_test = set_type == "test"
        text_column = 1 if is_test else 3

        examples = []
        for idx, fields in enumerate(lines):
            # Only the test set has a header
            if is_test and idx == 0:
                continue
            sentence = tx.utils.compat_as_text(fields[text_column])
            label = "0" if is_test else tx.utils.compat_as_text(fields[1])
            examples.append(InputExample(guid=f"{set_type}-{idx}",
                                         text_a=sentence,
                                         text_b=None, label=label))
        return examples


def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
    r"""Converts a single `InputExample` into a single `InputFeatures`.

    Args:
        ex_index: Index of the example; only used to decide whether the
            example is logged.
        example: The `InputExample` to convert.
        label_list: Labels, in the order that defines their integer ids.
        max_seq_length: Forwarded to ``tokenizer.encode_text``.
        tokenizer: Object providing ``encode_text(text_a, text_b,
            max_seq_length)``, whose result unpacks into input ids, segment
            ids and input mask, in that order.

    Returns:
        The `InputFeatures` built from the encoded text and the label id.
    """
    label_to_id = {name: idx for idx, name in enumerate(label_list)}

    encoded = tokenizer.encode_text(text_a=example.text_a,
                                    text_b=example.text_b,
                                    max_seq_length=max_seq_length)
    input_ids, segment_ids, input_mask = encoded

    label_id = label_to_id[example.label]

    # Verbose printing is disabled: this branch is only taken for negative
    # indices.
    if ex_index < 0:
        logging.info("*** Example ***")
        logging.info("guid: %s", example.guid)
        logging.info("input_ids: %s", " ".join(map(str, input_ids)))
        logging.info("input_ids length: %d", len(input_ids))
        logging.info("input_mask: %s", " ".join(map(str, input_mask)))
        logging.info("segment_ids: %s", " ".join(map(str, segment_ids)))
        logging.info("label: %s (id = %d)", example.label, label_id)

    return InputFeatures(input_ids=input_ids,
                         input_mask=input_mask,
                         segment_ids=segment_ids,
                         label_id=label_id)


def convert_examples_to_features_and_output_to_files(
        examples, label_list, max_seq_length, tokenizer, output_file,
        feature_types):
    r"""Converts `InputExample`s to features and writes them to a file.

    Each example is converted with `convert_single_example` and written, in
    order, as one record through the writer returned by
    ``tx.data.RecordData.writer(output_file, feature_types)``.
    """
    record_writer = tx.data.RecordData.writer(output_file, feature_types)
    with record_writer as writer:
        for position, item in enumerate(examples):
            converted = convert_single_example(
                position, item, label_list, max_seq_length, tokenizer)
            writer.write({
                "input_ids": converted.input_ids,
                "input_mask": converted.input_mask,
                "segment_ids": converted.segment_ids,
                "label_ids": converted.label_id,
            })


def prepare_record_data(processor, tokenizer,
                        data_dir, max_seq_length, output_dir,
                        feature_types):
    r"""Converts the train, dev and test splits and writes one file each.

    The splits are processed in the order train, dev, test and written to
    ``train.pkl``, ``eval.pkl`` and ``predict.pkl`` inside `output_dir`.

    Args:
        processor: Data Preprocessor, which must have get_labels,
            get_train/dev/test/examples methods defined.
        tokenizer: The Sentence Tokenizer. Generally should be
            SentencePiece Model.
        data_dir: The input data directory.
        max_seq_length: Max sequence length.
        output_dir: The directory to save the output files in.
        feature_types: The original type of the feature.
    """
    label_list = processor.get_labels()

    # (processor method name, output file name), in processing order.
    splits = (
        ("get_train_examples", "train.pkl"),
        ("get_dev_examples", "eval.pkl"),
        ("get_test_examples", "predict.pkl"),
    )
    for getter_name, file_name in splits:
        # Look up and call each getter only when its turn comes, so a split
        # is fully written before the next one is loaded.
        split_examples = getattr(processor, getter_name)(data_dir)
        target_path = os.path.join(output_dir, file_name)
        convert_examples_to_features_and_output_to_files(
            split_examples, label_list, max_seq_length,
            tokenizer, target_path, feature_types)
