import os
import copy
import json
import logging

import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset
from transformers.tokenization_utils_base import TruncationStrategy

logger = logging.getLogger(__name__)

# Earlier version of the KOLD target-group label set (it also contains the coarse
# "others" / "not_offensive" / "offensive" labels). It is not referenced by the code in
# this module; presumably kept for reference or older checkpoints — confirm before removing.
OLD_KOLD_HATE_GROUP_LIST = [
    "others",
    "not_offensive",
    "offensive",
    "gender-female",
    "gender-feminist",
    "gender-male",
    "gender-others",
    "gender-queer",
    "others-age",
    "others-disability",
    "others-disease",
    "others-others",
    "others-physical_appearance",
    "others-socioeconomic_status",
    "politics-conservative",
    "politics-others",
    "politics-progressive",
    "race-asian",
    "race-black",
    "race-chinese",
    "race-indian",
    "race-korean_chinese",
    "race-others",
    "race-southeast_asian",
    "race-white",
    "religion-buddhism",
    "religion-catholic",
    "religion-christian",
    "religion-islam",
    "religion-others",
    "sexual_orientation-gay",
    "sexual_orientation-homosexual",
    "sexual_orientation-lesbian",
    "sexual_orientation-others",
]

# Target-group labels used as the pooled (sentence-level) labels of the "sp-group" task
# (returned by HateSpanProcessor.get_pooled_labels). Label ids are assigned by list
# position via enumerate(), so reordering or inserting entries changes the id of every
# later label.
KOLD_HATE_GROUP_LIST = [
    "gender-LGBTQ+",
    "gender-female",
    "gender-male",
    "gender-others",
    "race-asian",
    "race-black",
    "race-chinese",
    "race-indian",
    "race-korean_chinese",
    "race-southeast_asian",
    "race-white",
    "race-others",
    "politics-conservative",
    "politics-progressive",
    "politics-others",
    "religion-buddhism",
    "religion-catholic",
    "religion-christian",
    "religion-islam",
    "religion-others",
    "others-age",
    "others-disability",
    "others-disease",
    "others-feminist",
    "others-physical_appearance",
    "others-socioeconomic_status",
    "others-others"
]
# Disabled extension of the label set, kept for reference:
# KOLD_HATE_GROUP_LIST += ["multi-group", "offensive", "not_offensive"]


class InputExample(object):
    """
    A single training/test example for span labelling and sentence classification.

    Hate-span examples carry a title/comment pair; word-level examples (e.g. the Naver NER
    processor) carry ``words``. Any field a processor does not produce defaults to None,
    so every processor can build examples through the same constructor.

    Attributes:
        guid: unique id of the example.
        title: title text the comment was posted under.
        title_subwords: subword tokens of the title.
        title_labels: one label per title subword.
        comment: comment text (a raw string or a list of words, depending on the processor).
        comment_subwords: subword tokens of the comment.
        pooled_label: sentence-level label.
        labels: span labels, aligned with the comment subwords or with ``words``.
        words: word list for word-level examples.
    """

    def __init__(
        self,
        guid,
        title=None,
        title_subwords=None,
        title_labels=None,
        comment=None,
        comment_subwords=None,
        pooled_label=None,
        labels=None,
        words=None,
    ):
        self.guid = guid
        self.title = title
        self.title_subwords = title_subwords
        self.title_labels = title_labels
        self.comment = comment
        self.comment_subwords = comment_subwords
        self.pooled_label = pooled_label
        self.labels = labels
        self.words = words

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (deep copy of the attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string.

        ``ensure_ascii=False`` keeps non-ASCII (e.g. Korean) text readable instead of
        emitting ``\\uXXXX`` escapes.
        """
        return json.dumps(self.to_dict(), indent=2, sort_keys=True, ensure_ascii=False) + "\n"


class InputFeatures:
    """Model-ready, fixed-length features for one example.

    Attributes:
        input_ids: token ids.
        attention_mask: 1 for real tokens, 0 for padding.
        token_type_ids: segment id per token.
        pooled_label_id: id of the sentence-level label.
        label_ids: per-token label ids (ignored positions carry the pad label id).
    """

    def __init__(self, input_ids, attention_mask, token_type_ids, pooled_label_id, label_ids):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.pooled_label_id = pooled_label_id
        self.label_ids = label_ids

    def __repr__(self):
        return self.to_json_string()

    def to_dict(self):
        """Return a deep copy of the instance attributes as a plain dict."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Return the attributes as indented JSON with sorted keys, followed by a newline."""
        body = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return f"{body}\n"


def get_labels_for_subword(word_tokens, label):
    """Expand a word-level BIO label to one label per subword of that word.

    The first subword keeps ``label``. The remaining subwords continue the span: a
    beginning tag ``"<TYPE>-B"`` becomes ``"<TYPE>-I"``, and any other label (``"O"``,
    ``"<TYPE>-I"``) is repeated. This works for every span type (e.g. ``OFF_TGT-B``),
    not just a fixed list.

    Args:
        word_tokens: subword tokens of the word; only its length is used.
        label: the word's BIO label.

    Returns:
        list of labels, one per subword (at least one, even for an empty token list).
    """
    if label.endswith("-B"):
        continuation = label[: -len("-B")] + "-I"
    else:
        continuation = label
    return [label] + [continuation] * (len(word_tokens) - 1)


def ner_convert_examples_to_features(
    args,
    examples,
    tokenizer,
    max_seq_length,
    task,
    pad_token_label_id=-100,
):
    """Convert InputExamples into fixed-length InputFeatures.

    Two paths, selected by ``args.label_all_tokens``:

    * True (subword path): encodes the (title, comment) pair with ``tokenizer.encode_plus``,
      truncating only the title (``ONLY_FIRST``) and padding to ``max_seq_length``. Every
      title token gets the id of "O", comment tokens get the ids of ``example.labels``, and
      special/padding positions get ``pad_token_label_id``.
    * False (word-level path): tokenizes each entry of ``example.words``; the first subword
      of a word carries the word's label id and the remaining subwords get
      ``pad_token_label_id``. The sequence is truncated to ``max_seq_length - 2``, wrapped
      as [CLS] ... [SEP], given all-zero token_type_ids and padded.

    Args:
        args: run configuration; reads ``args.data`` (processor key) and
            ``args.label_all_tokens``.
        examples: list of InputExample.
        tokenizer: tokenizer used for encoding (presumably a HuggingFace tokenizer).
        max_seq_length: length every feature is padded/truncated to.
        task: not used by this function.
        pad_token_label_id: label id for positions that should be ignored by the loss.

    Returns:
        list of InputFeatures, one per example.

    Raises:
        KeyError: if an example's label or pooled label is not in the processor's lists.
        AssertionError: if a feature does not end up exactly ``max_seq_length`` long.
    """
    # The processor is only used here for its label inventories; ids are list positions.
    processor = ner_processors[args.data](args, tokenizer)
    pooled_label_list = processor.get_pooled_labels()
    pooled_label_map = {label: i for i, label in enumerate(pooled_label_list)}
    label_lst = processor.get_labels()
    label_map = {label: i for i, label in enumerate(label_lst)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example {} of {}".format(ex_index, len(examples)))

        tokens = []
        label_ids = []

        if args.label_all_tokens:
            # Subword path: encode the (title, comment) pair; only the title is truncated.
            encoded_dict = tokenizer.encode_plus(
                example.title,
                example.comment,
                truncation=TruncationStrategy.ONLY_FIRST.value,
                max_length=max_seq_length,
                padding="max_length",
                return_token_type_ids=True,
            )
            input_ids = encoded_dict['input_ids']
            attention_mask = encoded_dict['attention_mask']
            token_type_ids = encoded_dict['token_type_ids']
            # Assumes a [CLS] title [SEP] comment [SEP] layout in which the comment segment is the
            # first run of token_type_id 1, so the title occupies positions 1 .. index(1) - 2.
            # NOTE(review): .index(1) raises ValueError if the tokenizer emits no type-1 ids
            # (e.g. models without segment embeddings) — confirm for the tokenizer in use.
            title_labels = [label_map["O"] for _ in range(1, token_type_ids.index(1) - 1)]
            # Assumes example.labels has exactly one entry per comment subword as produced inside
            # this pair encoding — TODO confirm (HateSpanProcessor tokenizes the comment alone).
            comment_labels = [label_map[_label] for _label in example.labels]
            label_ids = [pad_token_label_id] + title_labels + [pad_token_label_id] + comment_labels + [pad_token_label_id]
            # NOTE(review): ONLY_FIRST can only shorten the title; if the comment alone does not fit,
            # the sequence presumably exceeds max_seq_length and the asserts below fail.
            label_ids += [pad_token_label_id] * (max_seq_length - len(label_ids))
        else:
            for word, label in zip(example.words, example.labels):
                word_tokens = tokenizer.tokenize(word)
                if not word_tokens:
                    word_tokens = [tokenizer.unk_token]  # For handling the bad-encoded word
                tokens.extend(word_tokens)
                # Use the real label id for the first token of the word, and padding ids for the remaining tokens
                label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))

        # Room for [CLS] and [SEP]. Only affects the word-level path: on the subword path
        # `tokens` stays empty.
        special_tokens_count = 2
        if len(tokens) > max_seq_length - special_tokens_count:
            logger.info(f"TRUNCATED {example.guid}: #TOK:{len(tokens)} > MAX_L:{max_seq_length - special_tokens_count}")
            tokens = tokens[: (max_seq_length - special_tokens_count)]
            label_ids = label_ids[: (max_seq_length - special_tokens_count)]

        if not args.label_all_tokens:
            # Add [SEP]
            tokens += [tokenizer.sep_token]
            label_ids += [pad_token_label_id]

            # Add [CLS]
            tokens = [tokenizer.cls_token] + tokens
            label_ids = [pad_token_label_id] + label_ids
            # Single-segment input: every position gets segment id 0.
            token_type_ids = [0] * len(tokens)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            attention_mask = [1] * len(input_ids)

            # Right-pad every sequence up to max_seq_length.
            padding_length = max_seq_length - len(input_ids)
            input_ids += [tokenizer.pad_token_id] * padding_length
            attention_mask += [0] * padding_length
            token_type_ids += [0] * padding_length
            label_ids += [pad_token_label_id] * padding_length

        # Sanity checks: every feature must be exactly max_seq_length long.
        assert len(input_ids) == max_seq_length
        assert len(attention_mask) == max_seq_length
        assert len(token_type_ids) == max_seq_length
        assert len(label_ids) == max_seq_length

        # Log the first few converted examples for inspection.
        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % example.guid)
            if args.label_all_tokens:
                logger.info("tokens: %s" % (example.title + example.comment))
            else:
                logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
            logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
            logger.info("label: %s " % " ".join([str(x) for x in label_ids]))

        features.append(
            InputFeatures(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                # KeyError here means the example's pooled label is not in get_pooled_labels().
                pooled_label_id=pooled_label_map[example.pooled_label],
                label_ids=label_ids,
            )
        )
    return features


class NaverNerProcessor(object):
    """Processor for the Naver NER data set.

    Each line of a data file holds space-separated words and space-separated BIO labels,
    separated by a single tab character: "<w1 w2 ...><TAB><l1 l2 ...>".

    NOTE(review): ner_convert_examples_to_features also calls ``get_pooled_labels()`` and
    looks up ``example.pooled_label``; this processor provides neither, so the naver-ner
    path presumably still fails there — confirm what pooled label NER examples should carry.
    """

    def __init__(self, args, tokenizer=None):
        # ``tokenizer`` is accepted because processors are built as
        # ``ner_processors[args.data](args, tokenizer)``; this processor does not use it.
        self.args = args
        self.tokenizer = tokenizer

    def get_labels(self):
        """Return the BIO tag inventory; the position of each tag is its label id."""
        return [
            "O",
            "PER-B",
            "PER-I",
            "FLD-B",
            "FLD-I",
            "AFW-B",
            "AFW-I",
            "ORG-B",
            "ORG-I",
            "LOC-B",
            "LOC-I",
            "CVL-B",
            "CVL-I",
            "DAT-B",
            "DAT-I",
            "TIM-B",
            "TIM-I",
            "NUM-B",
            "NUM-I",
            "EVT-B",
            "EVT-I",
            "ANM-B",
            "ANM-I",
            "PLT-B",
            "PLT-I",
            "MAT-B",
            "MAT-I",
            "TRM-B",
            "TRM-I",
        ]

    @classmethod
    def _read_file(cls, input_file):
        """Read a tsv file and return its non-blank lines, stripped of surrounding whitespace.

        Blank lines are skipped because they cannot be split into words and labels.
        """
        with open(input_file, "r", encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [line for line in stripped if line]

    def _create_examples(self, dataset, set_type):
        """Create word-level examples from "words<TAB>labels" lines.

        Args:
            dataset: lines as returned by ``_read_file``.
            set_type: split name, used as the guid prefix.

        Raises:
            ValueError: if a line has a different number of words and labels.
        """
        examples = []
        for i, data in enumerate(dataset):
            words, labels = data.split("\t")
            words = words.split()
            labels = labels.split()
            guid = "%s-%s" % (set_type, i)

            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if len(words) != len(labels):
                raise ValueError(f"{guid}: {len(words)} words but {len(labels)} labels")

            if i % 10000 == 0:
                logger.info(data)
            examples.append(InputExample(guid=guid, words=words, labels=labels))
        return examples

    def get_examples(self, mode):
        """
        Args:
            mode: train, dev, test

        Raises:
            ValueError: for any other mode.
        """
        if mode == "train":
            file_to_read = self.args.train_file
        elif mode == "dev":
            file_to_read = self.args.dev_file
        elif mode == "test":
            file_to_read = self.args.test_file
        else:
            raise ValueError(f"mode must be one of train, dev, test (got {mode!r})")

        file_path = os.path.join(self.args.data_dir, self.args.task, file_to_read)
        logger.info("LOOKING AT {}".format(file_path))
        return self._create_examples(self._read_file(file_path), mode)


class HateSpanProcessor(object):
    """Processor for the Hate Span data set.

    Converts hate-span records into InputExample objects carrying BIO span labels and a
    sentence-level ("pooled") label. Which spans and labels are produced depends on
    ``args.task``. ``args.label_all_tokens`` selects subword-level examples (True) or
    word-level examples (False).
    """

    # Tasks whose token labels come from the offensive-span annotation ("OFF-*").
    _OFF_SPAN_TASKS = ("sp-off", "sc-off")
    # Tasks whose token labels come from the target-span annotation ("TGT-*").
    _TGT_SPAN_TASKS = ("sp-tgt", "sc-tgt", "sp-group")

    def __init__(self, args, tokenizer):
        self.args = args
        self.min_offensiveness = args.min_offensiveness
        self.tokenizer = tokenizer

    def get_pooled_labels(self):
        """Return the sentence-level label inventory for ``args.task``.

        The position of each label in the returned list is its label id.

        Raises:
            ValueError: if the task is not supported.
        """
        task = self.args.task
        # "sp-all" only has a word-level path (_create_examples), whose pooled labels are
        # always "offensive" / "not_offensive".
        if task in ("sc-off", "sc-tgt", "sp-off", "sp-all"):
            return ["not_offensive", "offensive"]
        if task == "sp-hate":
            return ["not_offensive", "offensive", "hate"]
        if task == "sp-tgt":
            return ["UNT", "IND", "OTH", "GRP"]
        if task == "sp-group":
            # Copy so callers cannot mutate the module-level list.
            return list(KOLD_HATE_GROUP_LIST)
        raise ValueError("Unsupported task!")

    def get_labels(self):
        """Return the token-level BIO label inventory for ``args.task``.

        Raises:
            ValueError: if the task is not supported.
        """
        task = self.args.task
        if task == "sp-all":
            return ["O", "OFF-B", "OFF-I", "TGT-B", "TGT-I", "OFF_TGT-B", "OFF_TGT-I"]
        if task in self._OFF_SPAN_TASKS:
            return ["O", "OFF-B", "OFF-I"]
        if task in self._TGT_SPAN_TASKS:
            return ["O", "TGT-B", "TGT-I"]
        raise ValueError("Unsupported Task!")

    @classmethod
    def _read_file(cls, input_file):
        """Read a .jsonl (one JSON object per line) or .json file.

        Raises:
            ValueError: for any other file extension.
        """
        if input_file.endswith("jsonl"):
            with open(input_file, "r", encoding="utf-8") as f:
                # Skip blank lines (e.g. a trailing empty line) instead of failing in json.loads.
                return [json.loads(line) for line in f if line.strip()]
        if input_file.endswith("json"):
            with open(input_file, "r", encoding="utf-8") as f:
                return json.load(f)
        # Must be `raise`: `assert ValueError(...)` is always truthy and never fires.
        raise ValueError(f"Unsupported file extension: {input_file}")

    @staticmethod
    def _to_bio(span_types):
        """Convert per-token span types ("O" or a type such as "OFF") into BIO tags.

        A token starts a new span ("-B") unless the previous token has exactly the same
        span type, in which case it continues that span ("-I"). Types are compared by
        equality, so e.g. "OFF" after "OFF_TGT" correctly starts a new span.
        """
        tags = []
        prev_type = "O"
        for span_type in span_types:
            if span_type == "O":
                tags.append("O")
            elif span_type == prev_type:
                tags.append(f"{span_type}-I")
            else:
                tags.append(f"{span_type}-B")
            prev_type = span_type
        return tags

    @staticmethod
    def _word_offsets(text):
        """Return ``(word, start_char)`` for each whitespace-separated word of ``text``.

        Offsets are real positions in ``text``, so repeated or leading whitespace is
        handled correctly.
        """
        offsets = []
        pos = 0
        for word in text.split():
            # Only whitespace lies between `pos` and the word, so the first match is the word itself.
            pos = text.index(word, pos)
            offsets.append((word, pos))
            pos += len(word)
        return offsets

    def _word_span_type(self, off_flag, tgt_flag):
        """Return the span type ("OFF", "TGT", "OFF_TGT" or "O") of a word for ``args.task``.

        Raises:
            ValueError: if the task has no word-level span labels.
        """
        task = self.args.task
        if task == "sp-all":
            if off_flag and tgt_flag:
                return "OFF_TGT"
            if off_flag:
                return "OFF"
            if tgt_flag:
                return "TGT"
            return "O"
        if task in self._OFF_SPAN_TASKS:
            return "OFF" if off_flag else "O"
        if task in self._TGT_SPAN_TASKS:
            return "TGT" if tgt_flag else "O"
        raise ValueError(f"{task} NOT SUPPORTED")

    def _create_examples(self, dataset, set_type):
        """Build word-level examples with BIO span labels (one label per whitespace word).

        Each record must provide ``title``, ``text``, ``text_id``, ``off_span``, ``tgt_span``
        and ``offensiveness``. ``off_span`` / ``tgt_span`` are indexed by character offsets
        into ``text``. A word is flagged when the score at its first character is >= 2. The
        pooled label is "offensive" when ``offensiveness >= min_offensiveness``, otherwise
        "not_offensive". ``set_type`` is not used.

        NOTE(review): these field names differ from the ones read by
        _create_subword_examples (comment / off_span_list / tgt_span_list); confirm which
        schema the word-level files actually use.
        """
        # TODO: decide whether non-offensive records (offensiveness <= 1) should be skipped here.
        examples = []
        for i, data in enumerate(dataset):
            text = data["text"]
            off_span_list = data["off_span"]
            tgt_span_list = data["tgt_span"]
            pooled_label = "offensive" if data["offensiveness"] >= self.min_offensiveness else "not_offensive"

            comment = []
            span_types = []
            for word, start in self._word_offsets(text):
                comment.append(word)
                span_types.append(self._word_span_type(off_span_list[start] >= 2, tgt_span_list[start] >= 2))
            labels = self._to_bio(span_types)

            if i % 10000 == 0:
                logger.info(data)
            example = InputExample(
                guid=data["text_id"], title=data["title"], title_subwords=None, title_labels=None,
                comment=comment, comment_subwords=None, pooled_label=pooled_label, labels=labels,
            )
            # ner_convert_examples_to_features iterates `example.words` on its word-level path.
            example.words = comment
            examples.append(example)
        return examples

    @staticmethod
    def _find_next_positions_by_offset(curr_idx, offset):
        """Return the end of the (start, end) span in ``offset`` containing ``curr_idx``, or -1."""
        for start_idx, end_idx in offset:
            if start_idx <= curr_idx < end_idx:
                return end_idx
        return -1

    def _subword_span_source(self):
        """Return (record key of the per-character span scores, span type) for ``args.task``.

        Raises:
            ValueError: if the task has no subword span labels.
        """
        task = self.args.task
        if task in self._OFF_SPAN_TASKS:
            return "off_span_list", "OFF"
        if task in self._TGT_SPAN_TASKS:
            return "tgt_span_list", "TGT"
        raise ValueError(f"{task} NOT SUPPORTED")

    def _subword_pooled_label(self, data, group_labels):
        """Return the pooled label of a record, or None if the record should be skipped.

        Raises:
            ValueError: for an unknown target, a group missing from ``group_labels``,
                sp-group without ``args.predict_tgt_group``, or an unsupported task.
        """
        task = self.args.task
        if task in ("sp-off", "sc-off", "sc-tgt"):
            return "offensive" if data["offensive"] else "not_offensive"
        if task == "sp-tgt":
            if not data["targeted_insult"]:
                return "UNT"
            target_to_label = {"group": "GRP", "individual": "IND", "other": "OTH"}
            try:
                return target_to_label[data["target"]]
            except KeyError:
                raise ValueError(f"Unknown target: {data['target']!r}") from None
        if task == "sp-group":
            if not self.args.predict_tgt_group:
                raise ValueError("sp-group requires args.predict_tgt_group to be set")
            if data["target"] != "group":
                return None
            group = data["group"]
            if "&" in group:  # records targeting several groups are skipped
                return None
            if group not in group_labels:
                raise ValueError(f"{group} not in group_labels")
            return group
        raise ValueError(f"{task} NOT SUPPORTED")

    def _create_subword_examples(self, dataset, set_type):
        """Build subword-level examples aligned with the tokenizer's subwords.

        Each record provides ``title``, ``comment``, ``text_id`` and the task's span-score
        list (``off_span_list`` or ``tgt_span_list``, indexed by comment character), plus the
        fields used for the pooled label. A comment subword gets the task's span type when
        any character it covers has a score >= ``min_offensiveness``. Title subwords are all
        "O". Records without a pooled label (see _subword_pooled_label) are skipped.
        ``set_type`` is not used.
        """
        span_key, span_type = self._subword_span_source()
        group_labels = self.get_pooled_labels() if self.args.task == "sp-group" else []

        examples = []
        for i, data in enumerate(dataset):
            pooled_label = self._subword_pooled_label(data, group_labels)
            if pooled_label is None:
                continue

            title = data["title"]
            comment = data["comment"]
            golds = data[span_key]

            title_tokenized = self.tokenizer(title, return_offsets_mapping=True, add_special_tokens=False)
            title_subwords = self.tokenizer.convert_ids_to_tokens(title_tokenized["input_ids"])
            title_labels = ["O"] * len(title_tokenized["input_ids"])

            comment_tokenized = self.tokenizer(comment, return_offsets_mapping=True, add_special_tokens=False)
            comment_subwords = self.tokenizer.convert_ids_to_tokens(comment_tokenized["input_ids"])
            # One span type per (start, end) character range, i.e. one per comment subword.
            span_types = [
                span_type
                if any(golds[idx] >= self.min_offensiveness for idx in range(start, end))
                else "O"
                for start, end in comment_tokenized["offset_mapping"]
            ]
            labels = self._to_bio(span_types)

            if i % 10000 == 0:
                logger.info(data)
            examples.append(InputExample(
                guid=data["text_id"], pooled_label=pooled_label,
                title=title, title_subwords=title_subwords, title_labels=title_labels,
                comment=comment, comment_subwords=comment_subwords, labels=labels,
            ))
        return examples

    def get_examples(self, mode):
        """
        Args:
            mode: train, dev, test

        Raises:
            ValueError: for any other mode.
        """
        if mode == "train":
            file_to_read = self.args.train_file
        elif mode == "dev":
            file_to_read = self.args.dev_file
        elif mode == "test":
            file_to_read = self.args.test_file
        else:
            raise ValueError(f"mode must be one of train, dev, test (got {mode!r})")

        # Optional `target` sub-directory between the dataset directory and the file.
        path_parts = [self.args.data_dir, self.args.data]
        if self.args.target:
            path_parts.append(self.args.target)
        file_path = os.path.join(*path_parts, file_to_read)
        logger.info("LOOKING AT {}".format(file_path))

        dataset = self._read_file(file_path)
        if self.args.label_all_tokens:
            return self._create_subword_examples(dataset, mode)
        return self._create_examples(dataset, mode)


# Registry mapping ``args.data`` to the processor class that reads that dataset.
ner_processors = {
    "naver-ner": NaverNerProcessor,
    "hatespan": HateSpanProcessor,
}

# Number of token-level (BIO) labels per task; each value equals the length of the
# corresponding processor's get_labels() list.
ner_tasks_num_labels = {
    "naver-ner": 29,
    "sp-all": 7,
    "sp-off": 3,
    "sp-tgt": 3,
    "sc-off": 3,
    "sc-tgt": 3,
    "sp-group": 3,
}


def ner_load_and_cache_examples(args, tokenizer, mode):
    """Read one data split, convert it to features and wrap them in a TensorDataset.

    Features are always rebuilt from the dataset file. They are also saved with
    ``torch.save`` to ``<data_dir>/cached_<task>_<model name>_<max_seq_len>_<mode>``, but
    that file is never read back here.

    Args:
        args: run configuration (reads ``data``, ``data_dir``, ``task``, ``model_dir``,
            ``max_seq_len``, plus whatever the processor and feature conversion read).
        tokenizer: tokenizer passed to the processor and to feature conversion.
        mode: one of "train", "dev", "test".

    Returns:
        TensorDataset of (input_ids, attention_mask, token_type_ids, pooled_label_ids,
        label_ids), all ``torch.long`` tensors.

    Raises:
        ValueError: if ``mode`` is not one of the three splits.
    """
    # Validate before doing any work (building the processor, reading files).
    if mode not in ("train", "dev", "test"):
        raise ValueError("For mode, only train, dev, test is available")

    processor = ner_processors[args.data](args, tokenizer)
    # Last non-empty path component of model_dir, e.g. "models/bert/" -> "bert".
    model_name = [part for part in args.model_dir.split("/") if part][-1]
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(str(args.task), model_name, str(args.max_seq_len), mode),
    )

    logger.info("Creating features from dataset file at %s", args.data_dir)
    examples = processor.get_examples(mode)

    # Positions labelled with this id are ignored by CrossEntropyLoss.
    pad_token_label_id = CrossEntropyLoss().ignore_index
    features = ner_convert_examples_to_features(
        args,
        examples,
        tokenizer,
        max_seq_length=args.max_seq_len,
        task=args.task,
        pad_token_label_id=pad_token_label_id,
    )
    logger.info("Saving features into cached file %s", cached_features_file)
    torch.save(features, cached_features_file)

    def _to_tensor(field):
        """Stack one InputFeatures field across all features into a long tensor."""
        return torch.tensor([getattr(f, field) for f in features], dtype=torch.long)

    return TensorDataset(
        _to_tensor("input_ids"),
        _to_tensor("attention_mask"),
        _to_tensor("token_type_ids"),
        _to_tensor("pooled_label_id"),
        _to_tensor("label_ids"),
    )
