# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT finetuning task dataset generator."""

import functools
import json
import os
import shutil
import pandas as pd

# Import libraries
from absl import app
from absl import flags
import tensorflow as tf
import tensorflow_datasets as tfds
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib

FLAGS = flags.FLAGS

# TODO(chendouble): consider moving each task to its own binary.
# Which fine-tuning task to produce data for. main() only implements
# "classification"; every other value makes it error out.
flags.DEFINE_enum(
    name="fine_tuning_task_type",
    default="classification",
    enum_values=[
        "classification", "regression", "squad", "retrieval", "tagging"
    ],
    help=("The name of the BERT fine tuning task for which data "
          "will be generated."))

# Classification inputs: where the raw task files live and which task they
# belong to.
flags.DEFINE_string(
    name="input_data_dir",
    default=None,
    help=("The input data dir. Should contain the .tsv files (or other data "
          "files) for the task."))

flags.DEFINE_enum(
    name="classification_task_name",
    default="MNLI",
    enum_values=[
        "RACE", "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP",
        "RTE", "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X"
    ],
    help=("The name of the task to train BERT classifier. The "
          "difference between XTREME-XNLI and XNLI is: 1. the format "
          "of input tsv files; 2. the dev set for XTREME is english "
          "only and for XNLI is all languages combined. Same for "
          "PAWS-X."))

# The task-specific flags below (MNLI, XNLI, PAWS-X, XTREME, retrieval and
# tagging) are not read anywhere in this script; defining them only makes
# them accepted on the command line and listed in --help.
flags.DEFINE_enum(
    name="mnli_type",
    default="matched",
    enum_values=["matched", "mismatched"],
    help="The type of MNLI dataset.")

flags.DEFINE_string(
    name="xnli_language",
    default="en",
    help=("Language of training data for XNLI task. If the value is 'all', "
          "the data of all languages will be used for training."))

flags.DEFINE_string(
    name="pawsx_language",
    default="en",
    help=("Language of training data for PAWS-X task. If the value is 'all', "
          "the data of all languages will be used for training."))

# Only meaningful for the XtremePawsx / XtremeXnli processors.
flags.DEFINE_string(
    name="translated_input_data_dir",
    default=None,
    help=("The translated input data dir. Should contain the .tsv files (or "
          "other data files) for the task."))

flags.DEFINE_enum(
    name="retrieval_task_name",
    default="bucc",
    enum_values=["bucc", "tatoeba"],
    help="The name of sentence retrieval task for scoring")

flags.DEFINE_enum(
    name="tagging_task_name",
    default="panx",
    enum_values=["panx", "udpos"],
    help="The name of BERT tagging (token classification) task.")

flags.DEFINE_bool(
    name="tagging_only_use_en_train",
    default=True,
    help="Whether only use english training data in tagging.")

# BERT SQuAD task-specific flags. None of these are read elsewhere in this
# script; defining them only makes them accepted on the command line.
flags.DEFINE_string(
    "squad_data_file", None,
    "The input data file for generating training data for BERT squad task.")

flags.DEFINE_string(
    "translated_squad_data_folder", None,
    "The translated data folder for generating training data for BERT squad "
    "task.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool(
    "version_2_with_negative", False,
    "If true, the SQuAD examples contain some that do not have an answer.")

flags.DEFINE_bool(
    "xlnet_format", False,
    "If true, then data will be preprocessed in a paragraph, query, class order"
    " instead of the BERT-style class, paragraph, query order.")

# XTREME specific flags.
flags.DEFINE_bool("only_use_en_dev", True,
                  "Whether to only use English dev data.")

# Shared flags across BERT fine-tuning tasks.

# Vocabulary for the WordPiece tokenizer; main() rejects an empty value when
# --tokenization=WordPiece.
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

# Output locations handed to the TFRecord writer.
flags.DEFINE_string(
    "train_data_output_path", None,
    "The path in which generated training input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "eval_data_output_path", None,
    "The path in which generated evaluation input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "test_data_output_path", None,
    "The path in which generated test input data will be written as tf"
    " records. If None, do not generate test data. Must be a pattern template"
    " as test_{}.tfrecords if processor has language specific test data.")

# Destination of the JSON meta data; it is marked required before app.run().
flags.DEFINE_string("meta_data_file_path", None,
                    "The path in which input meta data will be written.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

# Model file for the SentencePiece tokenizer; main() rejects an empty value
# when --tokenization=SentencePiece.
flags.DEFINE_string("sp_model_file", "",
                    "The path to the model used by sentence piece tokenizer.")

flags.DEFINE_enum(
    "tokenization", "WordPiece", ["WordPiece", "SentencePiece"],
    "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
    "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
    "while ALBERT uses SentencePiece tokenizer.")

flags.DEFINE_string(
    "tfds_params", "", "Comma-separated list of TFDS parameter assignments for "
    "generic classification data import (for more details "
    "see the TfdsProcessor class documentation).")


def _race_to_dataframe(dataset):
  """Flattens RACE examples into one row per (question, option) pair.

  Args:
    dataset: Iterable of RACE examples as yielded by `tfds.load`. Each example
      maps "article" to a string tensor, and "questions", "answers" and
      "options" to tensors whose `.numpy().tolist()` is, respectively, a list
      of question bytes, a list of answer letters (b"A", b"B", ...) and a
      list holding each question's option bytes.

  Returns:
    A `pd.DataFrame` with columns "article", "questions_option" (the question
    combined with one candidate option), "label" (1 for the correct option,
    0 otherwise) and "weight" (0.75 for the correct option, 0.25 otherwise).
  """
  rows = {"article": [], "questions_option": [], "label": [], "weight": []}
  for example in dataset:
    answers = example["answers"].numpy().tolist()
    questions = example["questions"].numpy().tolist()
    options = example["options"].numpy().tolist()
    article = example["article"].numpy().decode("utf-8")
    for i, answer in enumerate(answers):
      # Answers are letters, so "A" maps to option 0, "B" to option 1, ...
      correct_index = ord(answer) - ord("A")
      question = questions[i].decode("utf-8")
      for j, option_bytes in enumerate(options[i]):
        option = option_bytes.decode("utf-8")
        if "_" in question:
          # Cloze question: put the option into the blank. Replacing the bare
          # "_" (not only an exact "  _  " pattern) guarantees the option
          # appears whenever this branch is taken, and keeps the whitespace
          # around the blank so the option is not glued to neighbouring words.
          question_option = question.replace("_", option)
        else:
          question_option = " ".join([question, option])
        label = int(correct_index == j)
        rows["article"].append(article)
        rows["questions_option"].append(question_option)
        rows["label"].append(label)
        # The correct option gets 0.75 and each other option 0.25; with four
        # options both classes carry the same total weight.
        rows["weight"].append(0.75 if label == 1 else 0.25)
  return pd.DataFrame.from_dict(rows)


def _dataframe_to_input_examples(frame):
  """Wraps each row of a flattened RACE frame in an `InputExample`.

  The article becomes `text_a` and the question+option text `text_b`; the
  "label" and "weight" columns are passed through unchanged.

  Returns:
    The object produced by `frame.apply(..., axis=1)` (a `pd.Series` of
    `InputExample`s for a non-empty frame).
  """
  return frame.apply(
      lambda row: classifier_data_lib.InputExample(
          guid=None,  # No per-example bookkeeping ID is needed.
          text_a=row["article"],
          text_b=row["questions_option"],
          label=row["label"],
          weight=row["weight"],
      ),
      axis=1,
  )


def generate_classifier_dataset():
  """Generates RACE ("race/middle") TFRecords and returns the input meta data.

  Loads the TFDS train and test splits, expands every question into one
  binary example per answer option, and passes the examples, the tokenizer
  and the `*_data_output_path` / `max_seq_length` flags to
  `classifier_data_lib.generate_tf_record_from_data_file_RACE_ONLY_2`. The
  TFDS test split serves as the evaluation set.

  --input_data_dir, --classification_task_name and --tfds_params are not
  consulted: the data always comes from TFDS.

  Returns:
    Whatever `generate_tf_record_from_data_file_RACE_ONLY_2` returns (the
    input meta data).
  """
  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  else:
    assert FLAGS.tokenization == "SentencePiece"
    # NOTE(review): --do_lower_case is not applied on this path; only the
    # model file is passed to the SentencePiece tokenizer.
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)

  train = tfds.load("race/middle", split="train", shuffle_files=True)
  test = tfds.load("race/middle", split="test", shuffle_files=True)

  train_df = _race_to_dataframe(train)
  test_df = _race_to_dataframe(test)

  return classifier_data_lib.generate_tf_record_from_data_file_RACE_ONLY_2(
      _dataframe_to_input_examples(train_df),
      _dataframe_to_input_examples(test_df),
      tokenizer,
      train_data_output_path=FLAGS.train_data_output_path,
      eval_data_output_path=FLAGS.eval_data_output_path,
      test_data_output_path=FLAGS.test_data_output_path,
      max_seq_length=FLAGS.max_seq_length,
  )

def main(_):
  """Validates flags, generates the RACE data and writes the meta data JSON.

  Raises:
    ValueError: If the tokenizer's vocabulary/model file flag is empty, if
      --fine_tuning_task_type is not "classification", or if
      --train_data_output_path is empty.
  """
  if FLAGS.tokenization == "WordPiece":
    if not FLAGS.vocab_file:
      raise ValueError(
          "FLAG vocab_file for word-piece tokenizer is not specified.")
  else:
    assert FLAGS.tokenization == "SentencePiece"
    if not FLAGS.sp_model_file:
      raise ValueError(
          "FLAG sp_model_file for sentence-piece tokenizer is not specified.")

  # Only classification (RACE) is implemented. A plain string cannot be
  # raised in Python 3 (it produces a TypeError), so raise a real exception.
  if FLAGS.fine_tuning_task_type != "classification":
    raise ValueError(
        "Only fine_tuning_task_type='classification' is supported: this "
        "script is solely for RACE dataset processing.")

  # absl validates required flags while parsing the command line, which has
  # already happened by the time main() runs, so a mark_flag_as_required()
  # call here would not reject this invocation. Check explicitly instead.
  if not FLAGS.train_data_output_path:
    raise ValueError("FLAG train_data_output_path is not specified.")

  input_meta_data = generate_classifier_dataset()

  meta_dir = os.path.dirname(FLAGS.meta_data_file_path)
  if meta_dir:
    # A bare file name has no directory component, so there is nothing to
    # create (and "" is not a valid directory path).
    tf.io.gfile.makedirs(meta_dir)
  with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
    writer.write(json.dumps(input_meta_data, indent=4) + "\n")


if __name__ == "__main__":
  # Required flags must be registered before app.run() parses the command
  # line; this is where the metadata output path is made mandatory.
  required_flags = ("meta_data_file_path",)
  for flag_name in required_flags:
    flags.mark_flag_as_required(flag_name)
  app.run(main)
