"""Implements different tasks and defines the processors to convert each dataset
to a sequence to sequence format."""
from collections import OrderedDict
import os
import abc
import datasets
import functools
import logging
import numpy as np
import torch
from ..metrics import metrics
from typing import Callable, Dict, Mapping, List
from transformers import T5Tokenizer

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _load_t5_tokenizer():
    """Loads the T5 tokenizer once and reuses the same instance afterwards.

    compute_task_max_decoding_length is evaluated in the class body of every
    task, so without caching the tokenizer would be re-loaded for each task
    at import time.
    """
    # NOTE(review): the model identifier is kept exactly as it was; confirm
    # that 'google/t5-large' resolves in the target environment.
    return T5Tokenizer.from_pretrained('google/t5-large')


def compute_task_max_decoding_length(word_list):
    """Computes the max decoding length for the given list of words
    Args:
      word_list: A list of strings.
    Returns:
      maximum length after tokenization of the inputs, i.e. the largest
      number of ids returned by `tokenizer.encode` over all the strings;
      0 when `word_list` is empty.
    """
    tokenizer = _load_t5_tokenizer()
    return max((len(tokenizer.encode(word)) for word in word_list), default=0)


class AbstractTaskDataset(abc.ABC):
    """Base class shared by every task defined in this module.

    Class attributes (subclasses override the ones they need):
      name: the name of the task; also used as the default source prefix
          and as the "task" field of every converted example.
      task_specific_config: per-task configuration; the tasks in this
          module store the maximum decoding length under 'max_length',
          because the output space differs from task to task.
      preprocessor: callable turning one raw example into the sequence to
          sequence format.
      metrics: list of callables used to evaluate the task.
      split_to_data_split: maps the requested split to the split that
          actually exists in the dataset, since not every dataset ships
          all of train/validation/test.
      small_datasets_without_all_splits: names of the low-resource tasks
          for which not all train/test/validation splits are available.
      large_data_without_all_splits: names of the high-resource tasks for
          which not all train/test/validation splits are available.
    """
    name = NotImplemented
    task_specific_config: Dict = NotImplemented
    preprocessor: Callable = NotImplemented
    metrics: List[Callable] = NotImplemented
    split_to_data_split: Mapping[str, str] = {
        "train": "train",
        "validation": "validation",
        "test": "test",
    }

    small_datasets_without_all_splits = [
        "cola", "wnli", "rte", "trec", "cb", "sick",
        "mrpc", "stsb", "imdb", "commonsense_qa", "boolq",
    ]
    large_data_without_all_splits = [
        "yelp_polarity", "qqp", "qnli",
        "social_i_qa", "cosmos_qa", "winogrande", "hellaswag", "sst2",
    ]

    def __init__(self, seed=42):
        """Stores the seed on the instance."""
        self.seed = seed

    def get_dataset(self, data, add_prefix=True):
        """Builds a `datasets.Dataset` from `data` and converts every example.

        Args:
          data: list of raw examples, handed to `datasets.Dataset.from_list`.
          add_prefix: forwarded to the task's `preprocessor`.
        Returns:
          the mapped dataset; all original columns are removed so that only
          the fields produced by `preprocessor` remain.
        """
        raw = datasets.Dataset.from_list(data)
        convert = functools.partial(self.preprocessor, add_prefix=add_prefix)
        return raw.map(convert, remove_columns=raw.column_names)

    def seq2seq_format(self, src_strs: List[str], tgt_strs: List[str],
                       add_prefix: bool = False, prefix: str = None):
        """Joins source and target pieces into one sequence to sequence example.

        Args:
          src_strs: source pieces, joined with a blank line between them.
          tgt_strs: target pieces, joined with a single space.
          add_prefix: when true, a prefix is put in front of the source pieces.
          prefix: the prefix to use; the task name is used when it is None.
        Returns:
          a dict with the keys "src_texts", "tgt_texts" and "task".
        """
        if add_prefix:
            leading = prefix if prefix is not None else self.name
            src_strs = [leading] + src_strs
        source_text = '\n\n'.join(src_strs)
        target_text = ' '.join(tgt_strs)
        return {"src_texts": source_text,
                "tgt_texts": target_text,
                "task": self.name}

class SuperGLUEBoolQTaskDataset(AbstractTaskDataset):
    """BoolQ from SuperGLUE: a passage and a question answered with yes/no.

    The "test" split is mapped onto the "validation" split.
    """
    name = "boolq"
    label_list = ['no', 'yes']
    task_specific_config = {'max_length': compute_task_max_decoding_length(label_list)}
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}
    metrics = [metrics.accuracy]

    def load_dataset(self, split):
        """Loads the requested split of the 'boolq' config of 'super_glue'."""
        return datasets.load_dataset('super_glue', 'boolq', split=split)

    def preprocessor(self, example, add_prefix=True):
        """Converts one example into the sequence to sequence format.

        Reads the "passage", "question" and "label" fields of `example`.
        A label whose string form is '1' becomes 'yes'; anything else
        becomes 'no'.
        """
        # Earlier templates, kept for reference:
        #   ["question:", example["question"], "passage:", example["passage"]]  (original)
        #   [example["passage"], example["question"]]
        # The template below is the one marked "boolq-best".
        passage_part = f"Text: {example['passage']}"
        question_part = f"Question: {example['question']}"
        answer = 'yes' if str(example["label"]) == '1' else 'no'
        return self.seq2seq_format([passage_part, question_part], [answer], add_prefix)

class SuperGLUECBTaskDataset(AbstractTaskDataset):
    """CB from SuperGLUE: premise/hypothesis pairs with three answers.

    The "test" split is mapped onto the "validation" split.
    """
    name = "cb"
    # Left unchanged so that existing users of this attribute keep working.
    label_list = ['0', '1', '2']
    # Strings the preprocessor actually emits. Position 0 and 1 are used for
    # the labels '0' and '1'; the last one is used for every other label.
    target_list = ['yes', 'no', 'it is not possible to tell']
    # The decoding length has to fit the emitted targets. Deriving it from
    # label_list only budgets for single-digit strings, which is shorter
    # than the longest target above.
    task_specific_config = {'max_length': compute_task_max_decoding_length(target_list)}
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}
    metrics = [metrics.accuracy]

    def load_dataset(self, split):
        """Loads the requested split of the 'cb' config of 'super_glue'."""
        return datasets.load_dataset('super_glue', 'cb', split=split)

    def preprocessor(self, example, add_prefix=True):
        """Converts one example into the sequence to sequence format.

        Reads the "premise", "hypothesis" and "label" fields of `example`.
        Label '0' becomes 'yes', label '1' becomes 'no', and any other
        label becomes 'it is not possible to tell'.
        """
        # Earlier template, kept for reference:
        #   ["premise:", example["premise"], "hypothesis:", example["hypothesis"]]  (original)
        # The template below is the one marked "MNLI".
        src_texts = ["Here is a premise:", example["premise"],
                     "Here is a hypothesis:", example["hypothesis"],
                     "Is it possible to conclude that if the premise is true, then so is the hypothesis?"]
        label = str(example["label"])
        if label == '0':
            tgt_texts = [self.target_list[0]]
        elif label == '1':
            tgt_texts = [self.target_list[1]]
        else:
            tgt_texts = [self.target_list[2]]
        return self.seq2seq_format(src_texts, tgt_texts, add_prefix)

class ScitailTaskDataset(AbstractTaskDataset):
    """SciTail (snli_format): premise/hypothesis pairs answered with yes/no."""
    name = "scitail"
    # Left unchanged so that existing users of this attribute keep working.
    label_list = ["0", "1"]
    # Earlier mapping, kept for reference: {"entailment": 0, "neutral": 1}
    label_map = {"entailment": 'yes', "neutral": 'no'}
    # The decoding length is derived from the strings that are actually
    # emitted as targets (the values of label_map), not from label_list.
    task_specific_config = {
        'max_length': compute_task_max_decoding_length(list(label_map.values()))}
    metrics = [metrics.accuracy]

    def map_label(self, label):
        """Returns the target string for a gold label.

        Raises:
          KeyError: if `label` is not a key of `label_map`.
        """
        return self.label_map[label]

    def load_dataset(self, split):
        """Loads the requested split of the 'snli_format' config of 'scitail'."""
        return datasets.load_dataset("scitail", "snli_format",
                                     split=split)

    def preprocessor(self, example, add_prefix=True):
        """Converts one example into the sequence to sequence format.

        Reads the "sentence1", "sentence2" and "gold_label" fields of
        `example`; the gold label is translated through `map_label`.
        """
        # Earlier templates, kept for reference:
        #   ["sentence1:", example['sentence1'], "sentence2:", example["sentence2"]]  (original)
        #   ["Here is a premise:", example["sentence1"], "Here is a hypothesis:", example["sentence2"],
        #    "Is it possible to conclude that if the premise is true, then so is the hypothesis?"]  (MNLI)
        #   ["If \"{}\", does it logically follow that \"{}\"".format(example["sentence1"], example["sentence2"])]  (snli-best)
        # The template below is the one marked "autodemo-best".
        src_texts = ["Premise:", example["sentence1"], "Hypothesis:", example["sentence2"], ".Given the premise, can we conclude the hypothesis?"]
        # To increase the transfer performance, we modified the targets to be similar to other datasets.
        tgt_texts = [str(self.map_label(example['gold_label']))]
        return self.seq2seq_format(src_texts, tgt_texts, add_prefix)

class WicTaskDataset(AbstractTaskDataset):
    """WiC from SuperGLUE: is a word used with the same meaning in two sentences.

    The "test" split is mapped onto the "validation" split.
    """
    name = "wic"
    label_list = ['yes', 'no']
    task_specific_config = {'max_length': compute_task_max_decoding_length(label_list)}
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}
    metrics = [metrics.accuracy]

    def load_dataset(self, split):
        """Loads the requested split of the 'wic' config of 'super_glue'."""
        return datasets.load_dataset('super_glue', 'wic', split=split)

    def preprocessor(self, example, add_prefix=True):
        """Converts one example into the sequence to sequence format.

        Reads the "word", "sentence1", "sentence2" and "label" fields of
        `example`. A label whose string form is '1' becomes 'yes'; anything
        else becomes 'no'.
        """
        # Earlier template, kept for reference:
        #   ["sentence1:", example["sentence1"], "sentence2:", example["sentence2"],
        #    'Does {} mean the same thing in these two sentences?'.format(example['word'])]
        question = f"Does word {example['word']} have the same meaning in the following two sentences?"
        pieces = [question, example["sentence1"], example["sentence2"]]
        answer = 'yes' if str(example["label"]) == '1' else 'no'
        return self.seq2seq_format(pieces, [answer], add_prefix)

class WscTaskDataset(AbstractTaskDataset):
    """WSC: do two text spans refer to the same thing in a sentence.

    The "test" split is mapped onto the "validation" split.
    """
    # NOTE(review): unlike the other tasks in this module, this class defines
    # no `load_dataset` method; confirm whether that is intentional.
    name = "wsc"
    label_list = ['yes', 'no']
    task_specific_config = {'max_length': compute_task_max_decoding_length(label_list)}
    split_to_data_split = {"train": "train",
                           "validation": "validation",
                           "test": "validation"}
    metrics = [metrics.accuracy]

    def preprocessor(self, example, add_prefix=True):
        """Converts one example into the sequence to sequence format.

        Reads the "span1_text", "span2_text", "text" and "label" fields of
        `example`. A label whose string form is '1' becomes 'yes'; anything
        else becomes 'no'.
        """
        # Earlier template, kept for reference:
        #   [example['text'], 'Do {} and {} have the same meaning?'.format(
        #       example['span1_text'], example['span2_text'])]
        question = f"Are {example['span1_text']} and {example['span2_text']} the same in this sentence?"
        pieces = [question, example['text']]
        answer = 'yes' if str(example["label"]) == '1' else 'no'
        return self.seq2seq_format(pieces, [answer], add_prefix)


# Registry of the available tasks: task name -> task dataset class.
# Entries are inserted in the order in which they should be listed.
TASK_MAPPING = OrderedDict()
TASK_MAPPING['boolq'] = SuperGLUEBoolQTaskDataset
TASK_MAPPING['cb'] = SuperGLUECBTaskDataset
TASK_MAPPING['scitail'] = ScitailTaskDataset
TASK_MAPPING['wic'] = WicTaskDataset
TASK_MAPPING['wsc'] = WscTaskDataset

class AutoTask:
    """Factory that builds a task dataset object from its registered name."""

    @classmethod
    def get(cls, task_name, seed=42):
        """Instantiates the task registered under `task_name`.

        Args:
          task_name: a key of TASK_MAPPING.
          seed: passed to the constructor of the task class.
        Returns:
          an instance of the class registered under `task_name`.
        Raises:
          ValueError: if `task_name` is not a key of TASK_MAPPING.
        """
        if task_name in TASK_MAPPING:
            return TASK_MAPPING[task_name](seed)
        # The message has exactly two placeholders, matching the two
        # arguments; a mismatch would make `format` raise IndexError and
        # hide the intended ValueError.
        raise ValueError(
            "Unrecognized task {} for AutoTask.\n"
            "Task name should be one of {}.".format(
                task_name, ", ".join(TASK_MAPPING)
            )
        )
