import json
import random
import logging
import copy
import torch

from utils import ALL_SLOTS

logging.basicConfig(filename='data_preview_load.log', level=logging.INFO, filemode='w')
logger = logging.getLogger(__name__)

class PreviewExample:
    """
    One example (train or test) of the DST dataset.

    Holds the two utterance fields (``text_a`` / ``text_b``), the dialogue
    ``history`` and, optionally, a label object for each of them plus a
    ``class_label``. The constructor stores its arguments unchanged.
    """

    def __init__(self, guid, text_a, text_b, history,
                 text_a_label=None, text_b_label=None,
                 history_label=None, class_label=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.history = history
        self.text_a_label = text_a_label
        self.text_b_label = text_b_label
        self.history_label = history_label
        self.class_label = class_label

    def __repr__(self):
        """Render the example as one ``name: value`` line per field.

        The four label fields are only printed when they are truthy.
        """
        pieces = [
            "guid: %s\n" % (self.guid),
            "text_a: %s\n" % str(self.text_a),
            "text_b: %s\n" % str(self.text_b),
            "history: %s\n" % str(self.history),
        ]
        # (value, template) pairs, emitted in this order when the value is set.
        optional_fields = (
            (self.text_a_label, "text_a_label: %s\n"),
            (self.text_b_label, "text_b_label: %s\n"),
            (self.history_label, "history_label: %s\n"),
            (self.class_label, "class_label: %s\n\n"),
        )
        for value, template in optional_fields:
            if value:
                pieces.append(template % str(value))
        return "".join(pieces)

    def __str__(self):
        """Same text as ``__repr__``."""
        return self.__repr__()


class InputFeatures:
    """Model-ready features for one example.

    Stores the three model input sequences together with the optional span
    and class label ids and the id of the example they were built from.
    """

    def __init__(self, input_ids, input_mask, segment_ids,
                 span_label_id=None, class_label_id=None, guid="NONE"):
        # Identifier of the source example ("NONE" when not supplied).
        self.guid = guid
        # Model inputs, stored exactly as passed in.
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        # Optional supervision; stays None when no labels are attached.
        self.span_label_id = span_label_id
        self.class_label_id = class_label_id



def preview_convert_to_examples(slot_list, uttop,
                                swap_utterances=True,
                                onto=None):
    """
    Build shuffled PreviewExample objects from the "refer" preview data.

    Every turn stored in 'preview/refer_data_ready.json' is lexicalised twice
    through ``uttop.lex`` using slot values obtained from the ontology. Each
    pass produces one example that carries, per slot, a 0/1 label vector over
    the user, system and history tokens plus a class label.

    Args:
        slot_list: slot names. Each one receives a class label ('none' unless
            the turn says otherwise) and an all-zero label vector by default.
        uttop: object exposing ``lex(text, span_dic, onto=None,
            use_aug=False)`` returning ``(tokens, span_dic, sentence)``, where
            the returned ``span_dic`` maps '<slot>' keys to a sequence of
            ``(begin, end)`` token index pairs.
        swap_utterances: if true the user utterance is stored as ``text_a``
            and the system utterance as ``text_b``; otherwise the other way
            round.
        onto: object exposing ``recommend_value_different(key)``. When left
            as None, a module-level ``onto`` is used if one is defined, which
            is what this function relied on before the parameter existed.

    Returns:
        list of PreviewExample in random order.

    Raises:
        ValueError: if a replacement value is required but no ontology was
            passed and no module-level ``onto`` exists.
    """
    if onto is None:
        # Backward compatibility: the script entry point of this module
        # defines a global ``onto`` and calls this function without passing it.
        onto = globals().get('onto')

    def _sample_value(key):
        """Ask the ontology for a replacement value for ``key``."""
        if onto is None:
            raise ValueError(
                "preview_convert_to_examples needs an ontology: pass `onto=` "
                "or define a module-level `onto`.")
        return onto.recommend_value_different(key)

    def _span_labels(num_tokens, spans):
        """Return a 0/1 list of length ``num_tokens`` with every (b, e) span set to 1."""
        vec = [0] * num_tokens
        for b, e in spans:
            vec[b:e] = [1] * (e - b)
        return vec

    def _slot_name(key):
        """Drop the angle brackets from a '<slot>' key."""
        return key.strip('<').strip('>')

    # The "ready" files contain the delexicalised sentences. Only the "refer"
    # set is consumed; the dontcare/inform/mention files used to be parsed
    # here as well but their contents were never used.
    preview_files = ['preview/refer_data_ready.json']

    all_data = []
    for path in preview_files:
        with open(path) as f:
            dialogs = json.load(f)
        for turns in dialogs.values():
            all_data.extend(turns)

    examples = []

    # Two passes over the same turns; the ontology is queried again in each
    # pass (presumably yielding different values per pass - confirm against
    # Ontology.recommend_value_different).
    for _pass in range(2):
        for turn in all_data:
            guid = turn['dial_loc']
            delex_usr_norm = turn['delex_usr_norm']
            delex_sys_norm = turn['delex_sys_norm']
            history = turn['history']
            delex_sys_sp_dic = turn['delex_sys_sp_dic']
            delex_usr_sp_dic = turn['delex_usr_sp_dic']

            turn_label = turn['turn_label']
            class_label = turn_label['class_label']
            refer_label = turn_label['refer_label']

            # Class label per slot; slots named as values of refer_label are
            # overridden with 'refer'.
            label_ = {s: class_label.get(s, 'none') for s in slot_list}
            for s in refer_label:
                label_[refer_label[s]] = 'refer'

            # User side: slots that are keys of refer_label keep their stored
            # values, every other slot gets freshly sampled values.
            usr_values = {}
            for k, v in delex_usr_sp_dic.items():
                if _slot_name(k) not in refer_label:
                    usr_values[k] = [_sample_value(k) for _ in range(len(v))]
                else:
                    usr_values[k] = copy.deepcopy(v)
            usr_utt_tok, usr_span_dic, _usr_sent = uttop.lex(
                delex_usr_norm, usr_values, onto=None, use_aug=False)

            if delex_sys_norm:
                # System side: multi-valued slots keep their stored values,
                # single-valued ones are replaced by one sampled value.
                sys_values = {}
                for k, v in delex_sys_sp_dic.items():
                    if len(v) > 1:
                        sys_values[k] = copy.deepcopy(v)
                    else:
                        sys_values[k] = [_sample_value(k)]
                sys_utt_tok, sys_span_dic, _sys_sent = uttop.lex(
                    delex_sys_norm, sys_values, onto=None, use_aug=False)

                history_values = {'<%s>' % s: [_sample_value(s)] for s in slot_list}
                hst_utt_tok, hst_span_dic, _hst_sent = uttop.lex(
                    history, history_values, onto=None, use_aug=False)
            else:
                # Without a system utterance the history is left empty too.
                # NOTE(review): presumably this is the first turn of a
                # dialogue - confirm that dropping `history` here is intended.
                sys_utt_tok, sys_span_dic = [], {}
                hst_utt_tok, hst_span_dic = [], {}

            usr_utt_tok_label_dict = {slot: [0] * len(usr_utt_tok) for slot in slot_list}
            sys_utt_tok_label_dict = {slot: [0] * len(sys_utt_tok) for slot in slot_list}
            hst_utt_tok_label_dict = {slot: [0] * len(hst_utt_tok) for slot in slot_list}

            # User and history vectors mark only the last span of a slot,
            # the system vector marks every span.
            for k, pos in usr_span_dic.items():
                if pos:
                    usr_utt_tok_label_dict[_slot_name(k)] = _span_labels(len(usr_utt_tok), pos[-1:])
            for k, pos in sys_span_dic.items():
                if pos:
                    sys_utt_tok_label_dict[_slot_name(k)] = _span_labels(len(sys_utt_tok), pos)
            for k, pos in hst_span_dic.items():
                if pos:
                    hst_utt_tok_label_dict[_slot_name(k)] = _span_labels(len(hst_utt_tok), pos[-1:])

            if swap_utterances:
                txt_a, txt_a_lbl = usr_utt_tok, usr_utt_tok_label_dict
                txt_b, txt_b_lbl = sys_utt_tok, sys_utt_tok_label_dict
            else:
                txt_a, txt_a_lbl = sys_utt_tok, sys_utt_tok_label_dict
                txt_b, txt_b_lbl = usr_utt_tok, usr_utt_tok_label_dict

            examples.append(PreviewExample(
                guid=guid,
                text_a=copy.deepcopy(txt_a),
                text_b=copy.deepcopy(txt_b),
                history=copy.deepcopy(hst_utt_tok),
                text_a_label=copy.deepcopy(txt_a_lbl),
                text_b_label=copy.deepcopy(txt_b_lbl),
                history_label=copy.deepcopy(hst_utt_tok_label_dict),
                class_label=copy.deepcopy(label_)))
    random.shuffle(examples)
    return examples


def preview_examples_to_features(examples, slot_list, class_types, model_type, tokenizer, max_seq_length):
    """
    Convert labelled examples into padded InputFeatures.

    Each example is laid out as
    ``[CLS] text_a [SEP] text_b [SEP] history [SEP]`` and zero-padded to
    ``max_seq_length``. When the three token lists do not fit, tokens are
    dropped from the end of ``history`` first and then from the end of the
    longer of ``text_a`` / ``text_b``; the per-slot label vectors are cut to
    the same lengths. The examples themselves are left untouched.

    Args:
        examples: iterable of objects with ``guid``, ``text_a``, ``text_b``,
            ``history``, ``text_a_label``, ``text_b_label``,
            ``history_label`` and ``class_label`` attributes (the label
            attributes are indexed by slot name).
        slot_list: slot names to produce labels for.
        class_types: list of class names; a class label is stored as its
            index in this list.
        model_type: only 'bert' is supported.
        tokenizer: object exposing ``convert_tokens_to_ids(tokens)``.
        max_seq_length: length of every produced sequence, including the four
            special tokens.

    Returns:
        list of InputFeatures, one per example, in input order.

    Raises:
        SystemExit: if ``model_type`` is not 'bert'.
        ValueError: if a class label is not contained in ``class_types``.
    """
    if model_type == 'bert':
        model_specs = {'MODEL_TYPE': 'bert',
                       'CLS_TOKEN': '[CLS]',
                       'UNK_TOKEN': '[UNK]',
                       'SEP_TOKEN': '[SEP]',
                       'TOKEN_CORRECTION': 4}  # [CLS] plus three [SEP]
    else:
        logger.error("Unknown model type (%s). Aborting." % (model_type))
        exit(1)

    def _truncate(tokens_a, tokens_b, history, budget):
        """Pop tokens in place until the lists fit ``budget``; return True if they did not fit."""
        too_long = len(tokens_a) + len(tokens_b) + len(history) > budget
        while len(tokens_a) + len(tokens_b) + len(history) > budget:
            if len(history) > 0:
                history.pop()
            elif len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()
        return too_long

    def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history):
        """Lay the three label vectors out like the input, with 0 on special and padding positions."""
        token_label_ids = ([0] + list(token_labels_a)           # [CLS]
                           + [0] + list(token_labels_b)         # [SEP]
                           + [0] + list(token_labels_history)   # [SEP]
                           + [0])                               # [SEP]
        token_label_ids.extend([0] * (max_seq_length - len(token_label_ids)))
        assert len(token_label_ids) == max_seq_length
        return token_label_ids

    def _get_transformer_input(tokens_a, tokens_b, history):
        """Build tokens, ids, attention mask and segment ids for one example."""
        cls_token = model_specs['CLS_TOKEN']
        sep_token = model_specs['SEP_TOKEN']
        tokens = ([cls_token] + list(tokens_a) + [sep_token]
                  + list(tokens_b) + [sep_token]
                  + list(history) + [sep_token])
        # Segment 0 covers [CLS] text_a [SEP]; everything after it is segment 1.
        segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + len(history) + 2)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)
        # Zero-pad up to the sequence length.
        pad = max_seq_length - len(input_ids)
        if pad > 0:
            input_ids.extend([0] * pad)
            input_mask.extend([0] * pad)
            segment_ids.extend([0] * pad)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        return tokens, input_ids, input_mask, segment_ids

    budget = max_seq_length - model_specs['TOKEN_CORRECTION']
    total_cnt = 0
    too_long_cnt = 0
    features = []

    for example_index, example in enumerate(examples):
        total_cnt += 1

        # Work on copies: truncation pops tokens in place and must not shorten
        # the token lists stored on the caller's example objects.
        tokens_a = list(example.text_a)
        tokens_b = list(example.text_b)
        tokens_hst = list(example.history)
        input_text_too_long = _truncate(tokens_a, tokens_b, tokens_hst, budget)
        if input_text_too_long:
            too_long_cnt += 1

        class_label_dict = {}  # class loss
        span_label_dict = {}   # span loss
        for slot in slot_list:
            token_labels_a = example.text_a_label[slot]
            token_labels_b = example.text_b_label[slot]
            token_labels_history = example.history_label[slot]
            if input_text_too_long:
                # Report what is being cut off, for the first few examples only.
                if example_index < 10:
                    if len(token_labels_a) > len(tokens_a):
                        logger.info('    tokens_a truncated labels: %s' % str(token_labels_a[len(tokens_a):]))
                    if len(token_labels_b) > len(tokens_b):
                        logger.info('    tokens_b truncated labels: %s' % str(token_labels_b[len(tokens_b):]))
                    if len(token_labels_history) > len(tokens_hst):
                        logger.info(
                            '    tokens_history truncated labels: %s' % str(token_labels_history[len(tokens_hst):]))

                token_labels_a = token_labels_a[:len(tokens_a)]
                token_labels_b = token_labels_b[:len(tokens_b)]
                token_labels_history = token_labels_history[:len(tokens_hst)]

            assert len(token_labels_a) == len(tokens_a)
            assert len(token_labels_b) == len(tokens_b)
            assert len(token_labels_history) == len(tokens_hst)

            span_label_dict[slot] = _get_token_label_ids(
                token_labels_a, token_labels_b, token_labels_history)
            class_label_dict[slot] = class_types.index(example.class_label[slot])

        tokens, input_ids, input_mask, segment_ids = _get_transformer_input(
            tokens_a, tokens_b, tokens_hst)

        # Every example is logged (not just the first few).
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(tokens))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("span_label_id: %s" % str(span_label_dict))
        logger.info("class_label_id: %s" % str(class_label_dict))

        # Both label dicts are created fresh for this example, so they can be
        # handed over without copying.
        features.append(
            InputFeatures(
                guid=example.guid,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                span_label_id=span_label_dict,
                class_label_id=class_label_dict
            ))
    logger.info("========== %d out of %d examples have text too long" % (too_long_cnt, total_cnt))

    return features


def schema_features(slot_list, tokenizer, model_type, max_seq_length):
    """
    Build one padded single-segment input per slot name.

    For every slot the sequence is ``[CLS] <slot> description-tokens [SEP]``,
    where the description is the slot name after a fixed series of string
    substitutions, tokenised with ``tokenizer.tokenize``. Ids, mask and
    segment ids are zero-padded to ``max_seq_length``; the unpadded tokens
    and the three padded lists are printed for each slot.

    Args:
        slot_list: slot names to encode.
        tokenizer: object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        model_type: only 'bert' is supported; anything else calls ``exit(1)``.
        max_seq_length: required length of each produced list.

    Returns:
        dict mapping each slot to a dict with the keys 'input_ids',
        'input_mask' and 'segment_ids'.
    """
    if model_type != 'bert':
        exit(1)
    model_specs = {'MODEL_TYPE': 'bert',
                   'CLS_TOKEN': '[CLS]',
                   'UNK_TOKEN': '[UNK]',
                   'SEP_TOKEN': '[SEP]',
                   'TOKEN_CORRECTION': 4}

    # Substitutions applied to the slot name, in this exact order.
    rewrites = (
        ('-', ' '),
        ('leaveat', 'leave time'),
        ('by', ' time'),
        ('book', ' booked '),
        ('pricerange', 'price range'),
        ('internet', 'internet wifi'),
    )

    schema_feat_dict = {}
    for slot in slot_list:
        description = slot
        for old, new in rewrites:
            description = description.replace(old, new)

        slot_tokens = ['<%s>'%slot] + tokenizer.tokenize(description)
        tokens = [model_specs['CLS_TOKEN']] + slot_tokens + [model_specs['SEP_TOKEN']]
        # Everything belongs to segment 0.
        segment_ids = [0] * len(tokens)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # 1 marks a real token, 0 marks padding.
        input_mask = [1] * len(input_ids)

        # Zero-pad all three lists up to the sequence length.
        missing = max_seq_length - len(input_ids)
        if missing > 0:
            input_ids.extend([0] * missing)
            input_mask.extend([0] * missing)
            segment_ids.extend([0] * missing)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        for shown in (tokens, input_ids, input_mask, segment_ids):
            print(shown)

        schema_feat_dict[slot] = copy.deepcopy({
            'input_ids': input_ids,
            'input_mask': input_mask,
            'segment_ids': segment_ids,
        })
    return schema_feat_dict


def mlm_to_examples(uttop, tokenizer, swap_utterances=True, use_nl_sys=True, data_files=None):
    """
    Build unlabelled, shuffled PreviewExample objects for MLM training.

    Every turn of every dialogue in ``data_files`` becomes one example made
    of the current user utterance, the current system utterance and a
    history holding the tokens of all earlier turns, most recent turn first.

    Args:
        uttop: object exposing ``lex(text, span_dic, onto=None,
            use_aug=False)`` whose first return value is the token list.
        tokenizer: object exposing ``tokenize(text)``; only used for the
            system utterance when ``use_nl_sys`` is false.
        swap_utterances: if true the user utterance is stored as ``text_a``
            and placed before the system utterance inside the history;
            otherwise the system utterance comes first.
        use_nl_sys: if true the system utterance is lexicalised with
            ``uttop.lex``; otherwise the delexicalised text is tokenised
            as-is.
        data_files: paths of JSON files mapping a dialogue id to a list of
            turn dicts (the cache files contain the delexicalised sentences).
            Defaults to the cached train/val files used so far.

    Returns:
        list of PreviewExample with all label fields set to None, in random
        order.
    """
    if data_files is None:
        # NOTE(review): the train split is listed twice and no test split is
        # read. Kept as-is for backward compatibility - confirm whether the
        # third entry was meant to be a different file.
        data_files = [
            'cache_dial_data_train.json',
            'cache_dial_data_val.json',
            'cache_dial_data_train.json']

    examples = []
    for path in data_files:
        with open(path) as f:
            dialogs = json.load(f)

        for dialog in dialogs.values():
            sys_utt_tok = []
            usr_utt_tok = []
            hst_utt_tok = []

            for turn in dialog:
                # Push the previous turn onto the front of the history. A new
                # list is built each time, so examples never share a history.
                if swap_utterances:
                    hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
                else:
                    hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok

                guid = turn['dial_loc']
                delex_usr_norm = turn['delex_usr_norm']
                delex_sys_norm = turn['delex_sys_norm']
                delex_sys_sp_dic = turn['delex_sys_sp_dic']
                delex_usr_sp_dic = turn['delex_usr_sp_dic']

                if not delex_sys_norm:
                    sys_utt_tok = []
                elif use_nl_sys:
                    sys_utt_tok, _, _ = uttop.lex(delex_sys_norm, delex_sys_sp_dic, onto=None, use_aug=False)
                else:
                    sys_utt_tok = tokenizer.tokenize(delex_sys_norm)

                usr_utt_tok, _, _ = uttop.lex(delex_usr_norm, delex_usr_sp_dic, onto=None, use_aug=False)

                if swap_utterances:
                    txt_a, txt_b = usr_utt_tok, sys_utt_tok
                else:
                    txt_a, txt_b = sys_utt_tok, usr_utt_tok

                examples.append(PreviewExample(
                    guid=guid,
                    text_a=txt_a,
                    text_b=txt_b,
                    history=hst_utt_tok,
                    text_a_label=None,
                    text_b_label=None,
                    history_label=None,
                    class_label=None))
    random.shuffle(examples)
    return examples


def mlm_examples_to_features(examples, model_type, tokenizer, max_seq_length):
    """
    Convert unlabelled examples into padded InputFeatures for MLM training.

    Each example is laid out as
    ``[CLS] text_a [SEP] text_b [SEP] history [SEP]`` and zero-padded to
    ``max_seq_length``. When the three token lists do not fit, tokens are
    dropped from the end of ``history`` first and then from the end of the
    longer of ``text_a`` / ``text_b``. The examples themselves are left
    untouched.

    Args:
        examples: sequence of objects with ``guid``, ``text_a``, ``text_b``
            and ``history`` attributes; ``len(examples)`` must work.
        model_type: only 'bert' is supported.
        tokenizer: object exposing ``convert_tokens_to_ids(tokens)``.
        max_seq_length: length of every produced sequence, including the four
            special tokens.

    Returns:
        list of InputFeatures (without span or class labels), one per
        example, in input order.

    Raises:
        SystemExit: if ``model_type`` is not 'bert'.
    """
    if model_type == 'bert':
        model_specs = {'MODEL_TYPE': 'bert',
                       'CLS_TOKEN': '[CLS]',
                       'UNK_TOKEN': '[UNK]',
                       'SEP_TOKEN': '[SEP]',
                       'TOKEN_CORRECTION': 4}  # [CLS] plus three [SEP]
    else:
        logger.error("Unknown model type (%s). Aborting." % (model_type))
        exit(1)

    def _truncate(tokens_a, tokens_b, history, budget):
        """Pop tokens in place until the lists fit ``budget``; return True if they did not fit."""
        too_long = len(tokens_a) + len(tokens_b) + len(history) > budget
        while len(tokens_a) + len(tokens_b) + len(history) > budget:
            if len(history) > 0:
                history.pop()
            elif len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()
        return too_long

    def _get_transformer_input(tokens_a, tokens_b, history):
        """Build tokens, ids, attention mask and segment ids for one example."""
        cls_token = model_specs['CLS_TOKEN']
        sep_token = model_specs['SEP_TOKEN']
        tokens = ([cls_token] + list(tokens_a) + [sep_token]
                  + list(tokens_b) + [sep_token]
                  + list(history) + [sep_token])
        # Segment 0 covers [CLS] text_a [SEP]; everything after it is segment 1.
        segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + len(history) + 2)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)
        # Zero-pad up to the sequence length.
        pad = max_seq_length - len(input_ids)
        if pad > 0:
            input_ids.extend([0] * pad)
            input_mask.extend([0] * pad)
            segment_ids.extend([0] * pad)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        return tokens, input_ids, input_mask, segment_ids

    budget = max_seq_length - model_specs['TOKEN_CORRECTION']
    total_cnt = 0
    too_long_cnt = 0
    features = []

    for example_index, example in enumerate(examples):
        if example_index % 1000 == 0:
            logger.info("Writing example %d of %d" % (example_index, len(examples)))
        total_cnt += 1

        # Work on copies: truncation pops tokens in place and must not shorten
        # the token lists stored on the caller's example objects.
        tokens_a = list(example.text_a)
        tokens_b = list(example.text_b)
        tokens_history = list(example.history)
        if _truncate(tokens_a, tokens_b, tokens_history, budget):
            too_long_cnt += 1

        tokens, input_ids, input_mask, segment_ids = _get_transformer_input(
            tokens_a, tokens_b, tokens_history)

        # Only the first few examples are written to the log.
        if example_index < 10:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("tokens: %s" % " ".join(tokens))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))

        features.append(
            InputFeatures(
                guid=example.guid,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
            ))

    logger.info("========== %d out of %d examples have text too long" % (too_long_cnt, total_cnt))

    return features






if __name__ == '__main__':
    # Script entry point: builds MLM features from the cached dialogue files
    # and stores them with torch.save. The preview and schema feature
    # pipelines further down are currently commented out.
    from transformers import BertTokenizer as tokenizer_class
    from utils import SPEC_TOKENS, REQUEST_SLOTS, UtterOp, CLASS_TYPES
    from ontology import Ontology
    model_path = '../download_models/dialoglue/bert'
    # The special tokens and the request-slot tokens are registered as
    # additional special tokens of the tokenizer.
    special_tokens_dict = {'additional_special_tokens': SPEC_TOKENS + list(REQUEST_SLOTS.values())}
    tokenizer = tokenizer_class.from_pretrained(model_path, do_lower_case=True, do_basic_tokenize=True)
    tokenizer.add_special_tokens(special_tokens_dict)  # TODO: remember model.resize (presumably resizing the model's token embeddings for the added tokens - confirm)
    uttop = UtterOp(SPEC_TOKENS, tokenizer)
    # Defined at module level on purpose: preview_convert_to_examples looks up
    # a module-level name `onto`, so do not rename this variable.
    onto = Ontology()
    
    # A= torch.tensor([[101, 12, 124, 1231, 243,1341, 102, 123,443, 42,251, 0 ,0 ,0 ,0, 0],
    #             [101, 12, 124, 1231, 243,1341, 102, 123,443, 42,251, 0 ,0 ,0 ,0, 0]])
    # mask_tokens(
    #     inputs=A,
    #     tokenizer=tokenizer,
    #     mlm_probability=0.5
    # )
    #
    
    # MLM data: the system utterance is kept delexicalised (use_nl_sys=False),
    # hence the "sysdelex" suffix of the output file.
    examples = mlm_to_examples(uttop, tokenizer, use_nl_sys=False)
    features = mlm_examples_to_features(examples, 'bert', tokenizer, max_seq_length=384)
    torch.save(features, 'mlm_data_feature_sysdelex')

    # features = torch.load('mlm_data_feature')
    # print('len:', len(features))  # 71513
    
    # examples = preview_convert_to_examples(ALL_SLOTS, uttop)
    # features = preview_examples_to_features(examples,
    #                              slot_list=ALL_SLOTS,
    #                              class_types=CLASS_TYPES,
    #                              model_type='bert',
    #                              tokenizer=tokenizer,
    #                              max_seq_length=100)
    # print('len:', len(features)) # 10000
    # torch.save(features, 'preview_data_feature')

    # schema_feat_dict = schema_features(ALL_SLOTS, tokenizer, model_type='bert', max_seq_length=7)
    # torch.save(schema_feat_dict, 'schema_feat_dict')
