import os
import copy
import json
import logging
import pandas as pd

import torch
from torch.utils.data import TensorDataset

logger = logging.getLogger(__name__)


class InputExample(object):
    """One raw (not yet tokenized) sample for sequence classification.

    Attributes:
        guid: Identifier of the sample.
        text_a: First text segment.
        text_b: Second text segment (paired with ``text_a``).
        label: Label in its raw, unmapped form.
    """

    def __init__(self, guid, text_a, text_b, label):
        self.guid, self.text_a, self.text_b, self.label = (
            guid,
            text_a,
            text_b,
            label,
        )

    def to_dict(self):
        """Return a deep copy of the instance attributes as a plain dict."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Return the attributes as pretty-printed, key-sorted JSON plus a newline."""
        payload = json.dumps(self.to_dict(), sort_keys=True, indent=2)
        return f"{payload}\n"

    def __repr__(self):
        return str(self.to_json_string())


class InputFeatures(object):
    """Tokenized, model-ready form of one example.

    Attributes:
        input_ids: Token ids of the encoded sequence.
        attention_mask: Mask over ``input_ids``.
        token_type_ids: Segment ids over ``input_ids``.
        label: Encoded label value.
    """

    def __init__(self, input_ids, attention_mask, token_type_ids, label):
        self.input_ids, self.attention_mask, self.token_type_ids, self.label = (
            input_ids,
            attention_mask,
            token_type_ids,
            label,
        )

    def to_dict(self):
        """Return a deep copy of the instance attributes as a plain dict."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Return the attributes as pretty-printed, key-sorted JSON plus a newline."""
        payload = json.dumps(self.to_dict(), sort_keys=True, indent=2)
        return f"{payload}\n"

    def __repr__(self):
        return str(self.to_json_string())


def seq_cls_convert_examples_to_features(args, examples, tokenizer, max_length, data):
    """Tokenize ``examples`` into fixed-length ``InputFeatures``.

    Args:
        args: Run configuration, passed to the processor registered for ``data``.
        examples: Sliceable sequence of ``InputExample`` (``text_a``/``text_b`` pairs).
        tokenizer: Object exposing ``batch_encode_plus`` (presumably a Hugging
            Face tokenizer -- confirm against callers).
        max_length: Length every encoded pair is padded / truncated to.
        data: Key into ``seq_cls_processors`` and ``seq_cls_output_modes``.

    Returns:
        List of ``InputFeatures``, one per example, in input order.

    Raises:
        KeyError: If ``data`` is not registered, an example's label is not in
            the processor's label list, or the output mode is neither
            ``"classification"`` nor ``"regression"``.
    """
    processor = seq_cls_processors[data](args)
    label_list = processor.get_labels()
    logger.info("Using label list %s for task %s", label_list, data)
    output_mode = seq_cls_output_modes[data]
    logger.info("Using output mode %s for task %s", output_mode, data)

    # Label ids follow the order of the processor's label list.
    label_map = {label: i for i, label in enumerate(label_list)}

    def label_from_example(example):
        if output_mode == "classification":
            return label_map[example.label]
        elif output_mode == "regression":
            return float(example.label)
        raise KeyError(output_mode)

    labels = [label_from_example(example) for example in examples]
    batch_encoding = tokenizer.batch_encode_plus(
        [(example.text_a, example.text_b) for example in examples],
        max_length=max_length,
        padding="max_length",
        add_special_tokens=True,
        truncation=True,
    )

    # Some tokenizers (e.g. xlm-roberta) return no token_type_ids. The key set
    # belongs to the whole batch, so check it once rather than per example.
    has_token_type_ids = "token_type_ids" in batch_encoding
    features = []
    for i, label in enumerate(labels):
        input_ids = batch_encoding["input_ids"][i]
        if has_token_type_ids:
            token_type_ids = batch_encoding["token_type_ids"][i]
        else:
            token_type_ids = [0] * len(input_ids)
        # Forward only the fields InputFeatures accepts, so any extra keys a
        # tokenizer returns cannot break construction.
        features.append(
            InputFeatures(
                input_ids=input_ids,
                attention_mask=batch_encoding["attention_mask"][i],
                token_type_ids=token_type_ids,
                label=label,
            )
        )

    # Log the first few encoded examples for a quick sanity check.
    for example, feature in zip(examples[:5], features):
        logger.info("*** Example ***")
        logger.info("guid: %s", example.guid)
        logger.info("input_ids: %s", " ".join(str(x) for x in feature.input_ids))
        logger.info(
            "attention_mask: %s", " ".join(str(x) for x in feature.attention_mask)
        )
        logger.info(
            "token_type_ids: %s", " ".join(str(x) for x in feature.token_type_ids)
        )
        logger.info("label: %s", feature.label)

    return features


class KoldProcessor(object):
    """Processor for the KOLD Dataset.

    Reads records from a JSON / JSON Lines file and converts them into
    ``InputExample`` objects (``text_a`` = ``title``, ``text_b`` = ``comment``).
    The extracted label, and which records are kept, depend on ``args.task``:

    * ``"off"``: ``"OFF"`` / ``"NOT"`` from the ``OFF`` flag; every record kept.
    * ``"tgt"``: the raw ``TGT`` value; only records with a truthy ``OFF``.
    * ``"grp"``: the ``group`` value; only offensive records whose ``TGT`` is
      ``"group"`` and whose ``group`` does not contain ``"&"``.
    """

    def __init__(self, args):
        # args must provide ``task``; get_examples() also reads ``data_dir``,
        # ``data`` and ``train_file`` / ``dev_file`` / ``test_file``.
        self.args = args

    def get_num_labels(self):
        """Return the number of labels for the configured task."""
        return len(self.get_labels())

    def get_labels(self):
        """Return the ordered label list for ``args.task``.

        The list order is what downstream code enumerates into label ids.

        Raises:
            Exception: If ``args.task`` is not ``off``, ``tgt`` or ``grp``.
        """
        if self.args.task == "off":
            labels = ["NOT", "OFF"]
        elif self.args.task == "tgt":
            labels = ["UNT", "IND", "OTH", "GRP"]
        elif self.args.task == "grp":
            labels = [
                "gender-female",
                "gender-male",
                "gender-LGBTQ+",
                "gender-others",
                "others-age",
                "others-disability",
                "others-disease",
                "others-others",
                "others-physical_appearance",
                "others-socioeconomic_status",
                "others-feminist",
                "politics-conservative",
                "politics-others",
                "politics-progressive",
                "race-asian",
                "race-black",
                "race-chinese",
                "race-indian",
                "race-korean_chinese",
                "race-others",
                "race-southeast_asian",
                "race-white",
                "religion-buddhism",
                "religion-catholic",
                "religion-christian",
                "religion-islam",
                "religion-others",
            ]
        else:
            raise Exception("Unsupported task!")
        return labels

    @classmethod
    def _read_file(cls, input_file):
        """Read records from a ``.jsonl`` or ``.json`` file.

        Args:
            input_file: Path to the data file; the format is chosen by suffix.

        Returns:
            For ``jsonl``: a list with one decoded value per non-blank line.
            For ``json``: the file's top-level value (expected to be a list of
            records).

        Raises:
            ValueError: If the name ends in neither ``jsonl`` nor ``json``.
        """
        if input_file.endswith("jsonl"):
            # Explicit UTF-8 (matching the json branch) so that decoding of
            # non-ASCII text does not depend on the platform's locale.
            with open(input_file, encoding="utf-8") as f_r:
                # Skip blank lines (e.g. a trailing empty line) rather than
                # letting json.loads fail on them.
                return [json.loads(line) for line in f_r if line.strip()]
        if input_file.endswith("json"):
            with open(input_file, encoding="utf-8") as f_r:
                return json.load(f_r)
        # Raise instead of exiting: exit(0) would end the process with a
        # success status and hide the misconfiguration.
        raise ValueError(f"Error : {input_file} should be in 'jsonl, json' format")

    def _create_examples(self, lines, set_type):
        """Convert raw records into ``InputExample`` objects.

        Args:
            lines: Decoded records with ``title``, ``comment`` and ``OFF`` keys,
                plus ``TGT`` / ``group`` for the tgt / grp tasks.
            set_type: Prefix for each guid (``"<set_type>-<index>"``).

        Returns:
            List of examples. Records filtered out by the task are skipped, but
            guids keep the record's original index.

        Raises:
            Exception: If ``args.task`` is unsupported (raised on the first
                record, so not at all for empty ``lines``).
        """
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = line["title"]
            text_b = line["comment"]

            if self.args.task == "off":
                label = "OFF" if line["OFF"] else "NOT"
            elif self.args.task == "tgt":
                if not line["OFF"]:
                    continue
                # NOTE(review): this uses the raw TGT value, while get_labels()
                # lists "UNT"/"IND"/"OTH"/"GRP" and the grp branch below
                # compares TGT against "group". Confirm the TGT vocabulary
                # matches get_labels(); otherwise the downstream label-id
                # lookup will raise KeyError.
                label = line["TGT"]
            elif self.args.task == "grp":
                if not (line["OFF"] and line["TGT"] == "group"):
                    continue
                # Records naming several groups ("a&b") are dropped.
                if "&" in line["group"]:
                    continue
                label = line["group"]
            else:
                print(self.args.task)
                raise Exception("Unsupported Task!")

            if i % 1000 == 0:
                logger.info(line)
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
            )
        return examples

    def get_examples(self, mode):
        """Load and convert the examples of one split.

        Args:
            mode: train, dev, test

        Returns:
            List of ``InputExample`` read from
            ``<args.data_dir>/<args.data>/<file configured for mode>``.

        Raises:
            ValueError: If ``mode`` is not one of train/dev/test.
        """
        file_attr_by_mode = {
            "train": "train_file",
            "dev": "dev_file",
            "test": "test_file",
        }
        if mode not in file_attr_by_mode:
            # Without this check os.path.join(..., None) fails with an
            # unhelpful TypeError.
            raise ValueError("For mode, only train, dev, test is available")
        file_to_read = getattr(self.args, file_attr_by_mode[mode])
        input_path = os.path.join(self.args.data_dir, self.args.data, file_to_read)
        logger.info("LOOKING AT {}".format(input_path))
        return self._create_examples(self._read_file(input_path), mode)


# Dataset name (``args.data``) -> processor class that reads and labels it.
seq_cls_processors = dict(kold=KoldProcessor)


# Dataset name (``args.data``) -> how labels are encoded: "classification"
# labels are mapped to ids via the label list, "regression" labels are cast to
# float.
seq_cls_output_modes = dict(kold="classification")

def seq_cls_load_and_cache_examples(args, tokenizer, mode):
    """Build a ``TensorDataset`` for one split, using an on-disk feature cache.

    The cache file lives in ``args.data_dir``. Its name is built from the task,
    the last component of ``args.model_dir``, ``args.max_seq_len`` and
    ``mode``. If that file exists it is loaded. Otherwise the examples are read
    through the registered processor, converted to features, and saved there.

    Args:
        args: Run configuration; reads ``data``, ``data_dir``, ``task``,
            ``model_dir`` and ``max_seq_len`` (plus whatever the processor reads).
        tokenizer: Passed to ``seq_cls_convert_examples_to_features``.
        mode: One of ``"train"``, ``"dev"``, ``"test"``.

    Returns:
        ``TensorDataset`` of (input_ids, attention_mask, token_type_ids, labels).

    Raises:
        ValueError: If ``mode`` is not train/dev/test.
        KeyError: If the dataset's output mode is neither classification nor
            regression.
    """
    if mode not in ("train", "dev", "test"):
        raise ValueError("For mode, only train, dev, test is available")

    processor = seq_cls_processors[args.data](args)
    output_mode = seq_cls_output_modes[args.data]
    label_dtypes = {"classification": torch.long, "regression": torch.float}
    if output_mode not in label_dtypes:
        # Previously fell through and hit an UnboundLocalError on all_labels.
        raise KeyError(output_mode)

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            str(args.task),
            list(filter(None, args.model_dir.split("/"))).pop(),
            str(args.max_seq_len),
            mode,
        ),
    )
    if os.path.exists(cached_features_file):
        logger.info("Loading features from cached file %s", cached_features_file)
        # The cache is a pickled list of InputFeatures objects, not plain
        # tensors. torch >= 2.6 defaults to weights_only=True and would reject
        # it, while torch < 1.13 does not accept the keyword at all.
        # NOTE: this unpickles arbitrary objects -- only load caches you created.
        try:
            features = torch.load(cached_features_file, weights_only=False)
        except TypeError:
            features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        examples = processor.get_examples(mode)
        features = seq_cls_convert_examples_to_features(
            args, examples, tokenizer, max_length=args.max_seq_len, data=args.data
        )
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor(
        [f.attention_mask for f in features], dtype=torch.long
    )
    all_token_type_ids = torch.tensor(
        [f.token_type_ids for f in features], dtype=torch.long
    )
    all_labels = torch.tensor(
        [f.label for f in features], dtype=label_dtypes[output_mode]
    )

    dataset = TensorDataset(
        all_input_ids, all_attention_mask, all_token_type_ids, all_labels
    )
    return dataset
