import os
import copy
import json
import logging

import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset

logger = logging.getLogger(__name__)


class InputExample(object):
    """
    One word-level example: an identifier plus parallel lists of words and labels.
    """

    def __init__(self, guid, words, labels):
        # Assigned left to right, so the attribute (and to_dict key) order is guid, words, labels.
        self.guid, self.words, self.labels = guid, words, labels

    def __repr__(self):
        return self.to_json_string()

    def to_dict(self):
        """Return a deep copy of the instance attributes as a plain dict."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Return the attributes as indented, key-sorted JSON with a trailing newline."""
        serialized = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return f"{serialized}\n"


class InputFeatures(object):
    """The padded model inputs and label ids produced for a single example."""

    def __init__(self, input_ids, attention_mask, token_type_ids, label_ids):
        # Assignment order fixes the key order returned by to_dict().
        self.input_ids, self.attention_mask = input_ids, attention_mask
        self.token_type_ids, self.label_ids = token_type_ids, label_ids

    def __repr__(self):
        return self.to_json_string()

    def to_dict(self):
        """Return a deep copy of the instance attributes as a plain dict."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Return the attributes as indented, key-sorted JSON with a trailing newline."""
        serialized = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return f"{serialized}\n"


def ner_convert_examples_to_features(
    args,
    examples,
    tokenizer,
    max_seq_length,
    task,
    pad_token_label_id=-100,
):
    """Convert word-level examples into padded, fixed-length token features.

    Every word is split into sub-tokens with ``tokenizer.tokenize``. An empty
    result is replaced by ``tokenizer.unk_token``. The word's label id goes on
    its first sub-token. The remaining sub-tokens get ``pad_token_label_id``,
    or, when ``args.label_all_tokens`` is set, the word's label id mapped through
    the "B" -> "I" table built below. Token sequences are truncated to
    ``max_seq_length - 2``, wrapped as ``[CLS] ... [SEP]`` and right-padded to
    ``max_seq_length``.

    Args:
        args: run configuration. Reads ``args.data`` (key into ``ner_processors``,
            whose processor supplies the label list) and ``args.label_all_tokens``.
        examples: sized sequence of ``InputExample`` with parallel ``words`` and
            ``labels`` lists.
        tokenizer: must provide ``tokenize``, ``convert_tokens_to_ids``,
            ``unk_token``, ``cls_token``, ``sep_token`` and ``pad_token_id``.
        max_seq_length: length of every output sequence, special tokens included.
        task: not used by this function; kept for call-site compatibility.
        pad_token_label_id: label id written at positions that should be ignored
            (special tokens, padding and, by default, non-first sub-tokens).

    Returns:
        list of ``InputFeatures``, one per example. Every field has length
        ``max_seq_length``.

    Raises:
        ValueError: if ``max_seq_length`` cannot hold the two special tokens.
    """
    special_tokens_count = 2  # [CLS] and [SEP]
    if max_seq_length < special_tokens_count:
        raise ValueError(
            "max_seq_length must be at least {}, got {}".format(
                special_tokens_count, max_seq_length
            )
        )
    max_tokens = max_seq_length - special_tokens_count

    label_lst = ner_processors[args.data](args).get_labels()
    label_map = {label: i for i, label in enumerate(label_lst)}

    # Label id used on continuation sub-tokens when label_all_tokens is set.
    # A label ending in "B" maps to the same label with that final "B" replaced
    # by "I" (e.g. "OFF-B" -> "OFF-I") if such a label exists; any other label
    # maps to itself. Only the trailing "B" is replaced: str.replace would also
    # rewrite any other "B" in the label. The table does not depend on the
    # example, so it is built once here rather than per example.
    b_to_i_label = []
    for idx, label in enumerate(label_lst):
        inside_label = label[:-1] + "I"
        if label.endswith("B") and inside_label in label_map:
            b_to_i_label.append(label_map[inside_label])
        else:
            b_to_i_label.append(idx)

    features = []
    for ex_index, example in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %s of %s", ex_index, len(examples))

        tokens = []
        label_ids = []
        for word, label in zip(example.words, example.labels):
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                # The tokenizer produced nothing (e.g. a badly encoded word); keep
                # one token so words and labels stay aligned.
                word_tokens = [tokenizer.unk_token]
            tokens.extend(word_tokens)

            label_id = label_map[label]
            if args.label_all_tokens:
                continuation_id = b_to_i_label[label_id]
            else:
                continuation_id = pad_token_label_id
            label_ids.extend([label_id] + [continuation_id] * (len(word_tokens) - 1))

        # Leave room for [CLS] and [SEP].
        tokens = tokens[:max_tokens]
        label_ids = label_ids[:max_tokens]

        tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]
        label_ids = [pad_token_label_id] + label_ids + [pad_token_label_id]

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        token_type_ids = [0] * len(tokens)
        attention_mask = [1] * len(input_ids)

        padding_length = max_seq_length - len(input_ids)
        input_ids += [tokenizer.pad_token_id] * padding_length
        attention_mask += [0] * padding_length
        token_type_ids += [0] * padding_length
        label_ids += [pad_token_label_id] * padding_length

        assert len(input_ids) == max_seq_length
        assert len(attention_mask) == max_seq_length
        assert len(token_type_ids) == max_seq_length
        assert len(label_ids) == max_seq_length

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join(str(x) for x in tokens))
            logger.info("input_ids: %s", " ".join(str(x) for x in input_ids))
            logger.info(
                "attention_mask: %s", " ".join(str(x) for x in attention_mask)
            )
            logger.info(
                "token_type_ids: %s", " ".join(str(x) for x in token_type_ids)
            )
            logger.info("label: %s ", " ".join(str(x) for x in label_ids))

        features.append(
            InputFeatures(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                label_ids=label_ids,
            )
        )
    return features


class KoldProcessor(object):
    """Processor for the KOLD data set.

    Reads raw records from json/jsonl files and turns offensive comments into
    word-level ``InputExample`` objects for the span tasks ("sp-off", "sp-tgt").
    """

    def __init__(self, args):
        # Attributes read from args: task, data_dir, data, target,
        # train_file, dev_file, test_file.
        self.args = args

    def get_labels(self):
        """Return the label list for ``self.args.task``.

        Raises:
            ValueError: if the task is not "sp-all", "sp-off" or "sp-tgt".
        """
        task = self.args.task
        if task == "sp-all":
            return ["O", "OFF-B", "OFF-I", "TGT-B", "TGT-I", "OFF_TGT-B", "OFF_TGT-I"]
        if task == "sp-off":
            return ["OFF", "NOT"]
        if task == "sp-tgt":
            return ["TGT", "NOT"]
        # ValueError subclasses Exception, so existing `except Exception` handlers still work.
        raise ValueError("Unsupported Task!")

    @classmethod
    def _read_file(cls, input_file):
        """Read a .jsonl file (one JSON value per line) or a .json file.

        Args:
            input_file: path whose name ends in "jsonl" or "json".

        Returns:
            For jsonl, the list of parsed lines (blank lines are skipped). For
            json, the parsed document.

        Raises:
            ValueError: for any other file extension.
        """
        if input_file.endswith("jsonl"):
            with open(input_file, "r", encoding="utf-8") as f:
                return [json.loads(line) for line in f if line.strip()]

        if input_file.endswith("json"):
            with open(input_file, "r", encoding="utf-8") as f:
                return json.load(f)

        # Must be `raise`: `assert ValueError(...)` is always truthy and never fails.
        raise ValueError(f"Unsupported file extension: {input_file}")

    def _create_examples(self, dataset, set_type):
        """Convert raw KOLD records into word-level span-labelling examples.

        Keys read from each record:
          - "OFF": records where this is falsy are skipped, because spans are
            only predicted for offensive comments;
          - "comment": the text, split on whitespace into words;
          - "off_span_list" (task "sp-off") or "tgt_span_list" (task "sp-tgt"):
            integer tags indexed by character offset into "comment".

        A word is labelled by the tag at its first character. A value >= 2
        gives "OFF" (sp-off) or "TGT" (sp-tgt); anything else gives "NOT".

        Example (task "sp-off"):
            {"OFF": 1, "comment": "aa bb", "off_span_list": [0, 0, 0, 2, 2], ...}
            -> words ["aa", "bb"], labels ["NOT", "OFF"]

        Args:
            dataset: iterable of records as described above.
            set_type: prefix for each example's guid ("<set_type>-<record index>").

        Returns:
            list of ``InputExample``.

        Raises:
            ValueError: for any task other than "sp-off" / "sp-tgt". For example,
                "sp-all" uses B/I labels that this method does not produce.
        """
        task = self.args.task
        if task == "sp-off":
            span_key, positive_label = "off_span_list", "OFF"
        elif task == "sp-tgt":
            span_key, positive_label = "tgt_span_list", "TGT"
        else:
            raise ValueError("Unsupported task for span examples: {}".format(task))

        examples = []
        for i, data in enumerate(dataset):
            # Do not predict spans when the comment is not offensive.
            if not data["OFF"]:
                continue

            text = data["comment"]
            span_list = data[span_key]
            words = text.split()
            labels = []
            char_idx = 0
            for word in words:
                # split() collapses runs of whitespace (and drops leading
                # whitespace), so stepping by len(word) + 1 drifts. Finding the
                # word from the end of the previous one yields its true start
                # offset, because only whitespace lies in between.
                char_idx = text.find(word, char_idx)
                labels.append(positive_label if span_list[char_idx] >= 2 else "NOT")
                char_idx += len(word)

            if i % 10000 == 0:
                logger.info(data)
            examples.append(
                InputExample(guid="%s-%s" % (set_type, i), words=words, labels=labels)
            )
        return examples

    def get_examples(self, mode):
        """Read and convert the data file for one split.

        The file is looked up at ``data_dir/data[/target]/<file for mode>``, where
        ``target`` is included only when it is truthy.

        Args:
            mode: "train", "dev" or "test".

        Raises:
            ValueError: for any other mode.
        """
        if mode == "train":
            file_to_read = self.args.train_file
        elif mode == "dev":
            file_to_read = self.args.dev_file
        elif mode == "test":
            file_to_read = self.args.test_file
        else:
            raise ValueError("Unsupported mode: {}".format(mode))

        path_parts = [self.args.data_dir, self.args.data]
        if self.args.target:
            path_parts.append(self.args.target)
        input_path = os.path.join(*path_parts, file_to_read)

        logger.info("LOOKING AT {}".format(input_path))
        return self._create_examples(self._read_file(input_path), mode)


# Dataset name (args.data) -> processor class that reads and labels that dataset.
ner_processors = {
    "kold": KoldProcessor,
}

# Span task -> number of labels; equals len(KoldProcessor.get_labels()) for that task.
ner_tasks_num_labels = {
    "sp-off": 2,
    "sp-tgt": 2,
}


def ner_load_and_cache_examples(args, tokenizer, mode):
    """Build a ``TensorDataset`` of NER features for one data split.

    Features are always regenerated from the dataset files and then saved with
    ``torch.save`` to a cache file under ``args.data_dir``. This function never
    reads that cache back.

    Args:
        args: run configuration. Reads ``data``, ``data_dir``, ``task``,
            ``model_dir`` and ``max_seq_len``, plus whatever the processor and
            feature converter read.
        tokenizer: passed through to ``ner_convert_examples_to_features``.
        mode: "train", "dev" or "test".

    Returns:
        ``TensorDataset`` of (input_ids, attention_mask, token_type_ids,
        label_ids), each a ``torch.long`` tensor with one row per example.

    Raises:
        ValueError: if ``mode`` is not a supported split.
    """
    processor = ner_processors[args.data](args)

    # Cache name: cached_<task>_<last non-empty component of model_dir>_<max_seq_len>_<mode>.
    # NOTE(review): the key does not include args.data, args.target or
    # args.label_all_tokens, so reading this file back could return features
    # built from a different configuration.
    model_name = [part for part in args.model_dir.split("/") if part][-1]
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            str(args.task), model_name, str(args.max_seq_len), mode
        ),
    )

    logger.info("Creating features from dataset file at %s", args.data_dir)
    if mode not in ("train", "dev", "test"):
        raise ValueError("For mode, only train, dev, test is avaiable")
    examples = processor.get_examples(mode)

    # Label id that CrossEntropyLoss ignores by default, used for positions
    # that must not contribute to the loss.
    pad_token_label_id = CrossEntropyLoss().ignore_index
    features = ner_convert_examples_to_features(
        args,
        examples,
        tokenizer,
        max_seq_length=args.max_seq_len,
        task=args.task,
        pad_token_label_id=pad_token_label_id,
    )
    logger.info("Saving features into cached file %s", cached_features_file)
    torch.save(features, cached_features_file)

    # Convert to tensors in the fixed field order expected by consumers.
    tensors = [
        torch.tensor([getattr(f, field) for f in features], dtype=torch.long)
        for field in ("input_ids", "attention_mask", "token_type_ids", "label_ids")
    ]
    return TensorDataset(*tensors)
