import os
import copy
import json
import logging

import torch
from torch.utils.data import TensorDataset

logger = logging.getLogger(__name__)

class InputExample(object):
    """
    One raw example as read from a dataset line.

    Stores ``guid``, ``text_a``, ``text_b``, ``label`` and the list of span
    labels. The span labels are kept under the attribute name ``speaker``,
    which is the name the feature-conversion code reads.
    """

    def __init__(self, guid, text_a, text_b, label, span_label):
        self.guid, self.text_a, self.text_b = guid, text_a, text_b
        self.label = label
        # Exposed as `speaker` (not `span_label`); callers read example.speaker.
        self.speaker = span_label

    def __repr__(self):
        return self.to_json_string()

    def to_dict(self):
        """Return a deep copy of the instance attributes as a plain dict."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Return the attributes as pretty-printed, key-sorted JSON plus a newline."""
        serialized = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return "{}\n".format(serialized)

class InputFeatures(object):
    """
    Tokenized features for one example.

    The constructor's ``label`` argument is stored as ``class_label`` and
    ``spans`` is stored as ``span_label``; the other arguments keep their names.
    """

    def __init__(self, input_ids, attention_mask, token_type_ids, label, p_mask, spans):
        self.input_ids, self.attention_mask, self.token_type_ids = (
            input_ids,
            attention_mask,
            token_type_ids,
        )
        self.class_label = label
        self.p_mask = p_mask
        self.span_label = spans

    def __repr__(self):
        return self.to_json_string()

    def to_dict(self):
        """Return a deep copy of the instance attributes as a plain dict."""
        return copy.deepcopy(vars(self))

    def to_json_string(self):
        """Return the attributes as pretty-printed, key-sorted JSON plus a newline."""
        serialized = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return "{}\n".format(serialized)


def _span_labels_for_ids(input_ids, span_token_id, example, span_label_map):
    """Return one span-label id per token: successive example labels at span markers, 0 elsewhere."""
    span_ids = []
    next_span = 0  # index into example.speaker; advances once per span marker
    for token in input_ids:
        if token != span_token_id:
            span_ids.append(0)
            continue
        if next_span >= len(example.speaker):
            raise ValueError(
                "Example {} has more span tokens than span labels ({})".format(
                    example.guid, len(example.speaker)
                )
            )
        span_ids.append(span_label_map[example.speaker[next_span]])
        next_span += 1
    return span_ids


def seq_cls_convert_examples_to_features(args, examples, tokenizer, max_length, span_token_id):
    """Tokenize examples into InputFeatures with class label, p_mask and per-token span labels.

    Args:
        args: config object forwarded to ``MultiProcessor``.
        examples: list of InputExample. ``example.speaker`` holds the span
            labels ("n"/"g"/"p"), consumed in order at each occurrence of
            ``span_token_id`` in the encoded ``text_a``.
        tokenizer: object exposing ``batch_encode_plus`` (assumed HF-style —
            confirm against caller).
        max_length: pad/truncate length passed to the tokenizer.
        span_token_id: vocabulary id of the token marking a span position.

    Returns:
        list of InputFeatures, one per example, in input order.

    Raises:
        KeyError: an example's class label or span label is not in the label lists.
        ValueError: an example contains more span tokens than span labels.
    """
    processor = MultiProcessor(args)
    label_list = processor.get_labels()
    logger.info("Using label list {}".format(label_list))

    label_map = {label: i for i, label in enumerate(label_list)}
    # Span ids follow get_labels_span() order ("n" -> 0, "g" -> 1, "p" -> 2).
    span_label_map = {label: i for i, label in enumerate(processor.get_labels_span())}

    labels = [label_map[example.label] for example in examples]
    batch_encoding = tokenizer.batch_encode_plus(
        [example.text_a for example in examples],
        max_length=max_length,
        padding="max_length",
        add_special_tokens=True,
        truncation=True,
    )

    features = []
    for i, example in enumerate(examples):
        inputs = {k: batch_encoding[k][i] for k in batch_encoding}
        if "token_type_ids" not in inputs:
            inputs["token_type_ids"] = [0] * len(inputs["input_ids"])  # For xlm-roberta
        input_ids = inputs["input_ids"]
        # 0 at span-marker tokens, 1 everywhere else.
        p_mask = [0 if token == span_token_id else 1 for token in input_ids]
        span_ids = _span_labels_for_ids(input_ids, span_token_id, example, span_label_map)
        features.append(InputFeatures(**inputs, label=labels[i], p_mask=p_mask, spans=span_ids))

    for i, example in enumerate(examples[:5]):
        logger.info("*** Example ***")
        logger.info("guid: {}".format(example.guid))
        logger.info("input_ids: {}".format(" ".join([str(x) for x in features[i].input_ids])))
        logger.info("attention_mask: {}".format(" ".join([str(x) for x in features[i].attention_mask])))
        logger.info("token_type_ids: {}".format(" ".join([str(x) for x in features[i].token_type_ids])))
        logger.info("label: {}".format(features[i].class_label))
        logger.info("span_label: {}".format(features[i].span_label))
        logger.info("p_mask: {}".format(features[i].p_mask))

    return features

class MultiProcessor(object):
    """Reads the train/dev/test data files and turns them into InputExamples.

    Each non-blank line is expected to be
    ``text<TAB>class_label<TAB>span_label[,span_label...]``.
    """

    # NOTE(review): the original docstring called this the "NSMC data set";
    # the five class codes below do not look like NSMC labels — confirm.

    def __init__(self, args):
        # args must provide data_dir and train_file / dev_file / test_file.
        self.args = args

    def get_labels(self):
        """Return the class-label strings; the list index is the label id."""
        return ["000001", "020121", "02051", "020811", "020819"]

    def get_labels_span(self):
        """Return the span-label strings; the list index is the span-label id."""
        return ["n", "g", "p"]

    @classmethod
    def _read_file(cls, input_file):
        """Read a UTF-8 text file and return its lines with surrounding whitespace stripped."""
        with open(input_file, "r", encoding="utf-8") as f:
            return [line.strip() for line in f]

    def _create_examples(self, lines, set_type):
        """Create InputExamples from tab-separated lines; guid is "<set_type>-<line index>"."""
        examples = []
        for (i, line) in enumerate(lines):
            if not line.strip():
                # Skip blank lines (e.g. a trailing empty line) rather than
                # crashing with IndexError; guids still use the real line index.
                continue
            fields = line.strip().split("\t")
            guid = "%s-%s" % (set_type, i)

            text_a = fields[0].strip()
            label = fields[1].strip()
            span_label = fields[2].strip().split(',')
            if i % 10000 == 0:
                logger.info(fields)
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label, span_label=span_label))
        return examples

    def get_examples(self, mode):
        """
        Load the examples for one split.

        Args:
            mode: train, dev, test

        Raises:
            ValueError: if mode is not one of train, dev, test.
        """
        file_attr = {"train": "train_file", "dev": "dev_file", "test": "test_file"}.get(mode)
        if file_attr is None:
            # Previously fell through with None and failed inside os.path.join.
            raise ValueError("mode must be one of train, dev, test; got {!r}".format(mode))
        file_to_read = getattr(self.args, file_attr)

        path = os.path.join(self.args.data_dir, file_to_read)
        logger.info("LOOKING AT {}".format(path))
        return self._create_examples(self._read_file(path), mode)



def seq_cls_load_and_cache_examples(args, tokenizer, mode,span_token_id):
    processor = MultiProcessor(args)
    output_mode ='classification'
    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}".format(
            list(filter(None, args.model_name_or_path.split("/"))).pop(), str(args.max_seq_len), mode
        ),
    )
    if os.path.exists(cached_features_file):
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        if mode == "train":
            examples = processor.get_examples("train")
        elif mode == "dev":
            examples = processor.get_examples("dev")
        elif mode == "test":
            examples = processor.get_examples("test")
        else:
            raise ValueError("For mode, only train, dev, test is avaiable")
        features = seq_cls_convert_examples_to_features(
            args, examples, tokenizer, max_length=args.max_seq_len,span_token_id=span_token_id)
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_token_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.long)
    all_token_span_label = torch.tensor([f.span_label for f in features], dtype=torch.long)
    if output_mode == "classification":
        all_labels = torch.tensor([f.class_label for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels,all_token_p_mask ,all_token_span_label)
    return dataset

