# Turn the readable, dict-based dialogue cache into DST examples and model input features.
import logging
import json

from utils import ALL_SLOTS, SPEC_TOKENS, REQUEST_SLOTS, UtterOp

from ontology import Ontology


# logging.basicConfig(filename='input_data.log', level=logging.INFO, filemode='w')
logger = logging.getLogger(__name__)


class DSTExample:
    """
    A single training/test example for the DST dataset.

    ``text_a``, ``text_b`` and ``history`` are token sequences. The
    ``*_label`` attributes and ``values``, ``inform_label``,
    ``inform_slot_label``, ``refer_label``, ``diag_state`` and
    ``class_label`` are per-slot dictionaries. ``score_dic``, ``avg_score``
    and ``noise`` are stored as-is and are not included in ``repr()``.
    """

    def __init__(self,
                 guid,
                 text_a,
                 text_b,
                 history,
                 text_a_label=None,
                 text_b_label=None,
                 history_label=None,
                 values=None,
                 inform_label=None,
                 inform_slot_label=None,
                 refer_label=None,
                 diag_state=None,
                 class_label=None,
                 score_dic=None,
                 avg_score=None,  # binary
                 noise=None   # binary
                 ):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.history = history
        self.text_a_label = text_a_label
        self.text_b_label = text_b_label
        self.history_label = history_label
        self.values = values
        self.inform_label = inform_label
        self.inform_slot_label = inform_slot_label
        self.refer_label = refer_label
        self.diag_state = diag_state
        self.class_label = class_label
        self.score_dic = score_dic
        self.avg_score = avg_score
        self.noise = noise

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """Return a multi-line, human-readable dump of the example.

        The text fields are always shown. A label dictionary is shown only
        when it is truthy, and it is filtered through ``simple()`` first.
        """
        # Wrap guid in a 1-tuple so a tuple-valued guid is formatted as one
        # value instead of being unpacked by the % operator.
        parts = [
            "guid: %s\n" % (self.guid,),
            "text_a: %s\n" % str(self.text_a),
            "text_b: %s\n" % str(self.text_b),
            "history: %s\n" % str(self.history),
        ]
        optional_fields = (
            ("text_a_label", self.text_a_label),
            ("text_b_label", self.text_b_label),
            ("history_label", self.history_label),
            ("values", self.values),
            ("inform_label", self.inform_label),
            ("inform_slot_label", self.inform_slot_label),
            ("refer_label", self.refer_label),
            ("diag_state", self.diag_state),
        )
        for name, value in optional_fields:
            if value:
                parts.append("%s: %s\n" % (name, str(self.simple(value))))
        # class_label closes the dump with an extra blank line.
        if self.class_label:
            parts.append("class_label: %s\n\n" % str(self.simple(self.class_label)))
        return "".join(parts)

    def simple(self, dic):
        """Return the entries of ``dic`` that carry information.

        Entries are kept as follows:

        * a list is kept when it is non-empty. If its first element is an
          int, its sum must also be positive.
        * a str is kept unless it equals ``'none'``.
        * an int is kept unless it is 0.
        * values of any other type are dropped.
        """
        kept = {}
        for key in dic:
            value = dic[key]
            if isinstance(value, list):
                if value and isinstance(value[0], int):
                    keep = sum(value) > 0
                else:
                    keep = bool(value)
            elif isinstance(value, str):
                keep = value != 'none'
            elif isinstance(value, int):
                keep = value != 0
            else:
                keep = False
            if keep:
                kept[key] = value
        return kept

class InputFeatures:
    """Model-ready features for one example.

    Holds the padded ``input_ids`` / ``input_mask`` / ``segment_ids``
    sequences together with the per-slot label dictionaries.
    """

    def __init__(self,
                 input_ids,
                 input_mask,
                 segment_ids,
                 start_pos=None,
                 end_pos=None,
                 values=None,
                 inform=None,
                 inform_slot=None,
                 refer_id=None,
                 diag_state=None,
                 class_label_id=None,
                 guid="NONE"):
        # Identifier of the example these features were built from.
        self.guid = guid
        # Encoder inputs (one entry per token position).
        self.input_ids, self.input_mask, self.segment_ids = input_ids, input_mask, segment_ids
        # Per-slot span boundaries.
        self.start_pos, self.end_pos = start_pos, end_pos
        # Remaining per-slot labels.
        self.values, self.inform, self.inform_slot = values, inform, inform_slot
        self.refer_id, self.diag_state, self.class_label_id = refer_id, diag_state, class_label_id

def fill(slot_list, in_dic, is_binary=False):
    """Return a dict holding one entry for every slot in ``slot_list``.

    When ``is_binary`` is true, each value is 1 if the slot is a key of
    ``in_dic`` and 0 otherwise. Otherwise the value is copied from
    ``in_dic``, and slots missing from it get the string ``'none'``.
    Keys of ``in_dic`` that are not in ``slot_list`` are dropped.
    """
    if is_binary:
        return {slot: int(slot in in_dic) for slot in slot_list}
    return {slot: (in_dic[slot] if slot in in_dic else 'none') for slot in slot_list}


def convert_readable_cache_to_examples(cache_readable_data_file, slot_list, tokenizer, uttop,
                                       swap_utterances=True,
                                       use_history_labels=True):
    """Build ``DSTExample`` objects from a readable JSON dialogue cache.

    The cache maps a dialogue id to a list of turn dicts. These contain
    delexicalised system/user utterances, the belief state, and a
    ``turn_label`` dict with the per-slot labels.

    Args:
        cache_readable_data_file: path of the JSON cache.
        slot_list: slots for which labels are produced. Every label dict is
            densified over these slots via ``fill``.
        tokenizer: used to tokenize the system utterance
            (``tokenizer.tokenize``).
        uttop: its ``lex(...)`` returns the user tokens and a mapping from
            slot (possibly wrapped in '<' '>') to a list of (begin, end)
            token spans.
        swap_utterances: if True, the user utterance is ``text_a`` and
            precedes the system utterance in the history. Otherwise the
            order is system first.
        use_history_labels: if False, all history token labels are 0.

    Returns:
        A list of ``DSTExample``, one per turn. When a turn contains the
        key ``'only_last'``, only the last turn of its dialogue is kept.
    """
    # The readable cache contains delexicalised sentences. JSON text is UTF-8
    # (RFC 8259), so decode explicitly rather than via the locale default.
    with open(cache_readable_data_file, encoding='utf-8') as f:
        dialogs = json.load(f)

    examples = []

    for dial_id, dialog in dialogs.items():
        # Tokens of the previous turn and the accumulated history
        # (the most recent turn is prepended, so it comes first).
        sys_utt_tok = []
        usr_utt_tok = []
        hst_utt_tok = []
        hst_utt_tok_label_dict = {slot: [] for slot in slot_list}

        for turn_id, turn in enumerate(dialog):
            # Move the previous turn's utterances into the history.
            if swap_utterances:
                hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
            else:
                hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok

            guid = turn['dial_loc']
            delex_usr_norm = turn['delex_usr_norm']
            delex_sys_norm = turn['delex_sys_norm']
            delex_usr_sp_dic = turn['delex_usr_sp_dic']
            belief_state = turn['belief_state']
            # TODO: turn['informed_names'] is not used yet; remember to add it
            # here (and to the example) when it is needed.

            turn_label = turn['turn_label']
            class_label = turn_label['class_label']
            refer_label = turn_label['refer_label']
            inform_label = turn_label['inform_label']
            dial_state_aux = turn_label['dial_state_aux']
            # Keys of the auxiliary inform label may be wrapped in '<' '>'.
            inform_slot_axu = {k.strip('<').strip('>'): v
                               for k, v in turn_label['inform_slot_axu'].items()}
            # Name slots whose gate is 'inform' also get the auxiliary inform flag.
            for k in class_label:
                if 'name' in k and class_label[k] == 'inform':
                    inform_slot_axu[k] = 1

            # Give every slot in slot_list an entry ('none' / 0 when missing).
            class_label = fill(slot_list, class_label)
            refer_label = fill(slot_list, refer_label)
            inform_label = fill(slot_list, inform_label)
            inform_slot_axu = fill(slot_list, inform_slot_axu, True)
            dial_state_aux = fill(slot_list, dial_state_aux, True)

            sys_utt_tok = tokenizer.tokenize(delex_sys_norm) if delex_sys_norm else []

            usr_utt_tok, usr_span_dic, _ = uttop.lex(delex_usr_norm, delex_usr_sp_dic, onto=None, use_aug=False)
            usr_utt_tok_label_dict = {slot: [0] * len(usr_utt_tok) for slot in slot_list}
            sys_utt_tok_label_dict = {slot: [0] * len(sys_utt_tok) for slot in slot_list}
            # Mark the last span of each slot value in the user utterance.
            for k, pos in usr_span_dic.items():
                if not pos:
                    continue
                b, e = pos[-1]
                vec = [0] * len(usr_utt_tok)
                vec[b: e] = [1] * (e - b)
                usr_utt_tok_label_dict[k.strip('<').strip('>')] = vec

            if swap_utterances:
                txt_a = usr_utt_tok
                txt_b = sys_utt_tok
                txt_a_lbl = usr_utt_tok_label_dict
                txt_b_lbl = sys_utt_tok_label_dict
            else:
                txt_a = sys_utt_tok
                txt_b = usr_utt_tok
                txt_a_lbl = sys_utt_tok_label_dict
                txt_b_lbl = usr_utt_tok_label_dict

            # A span past the end of the utterance would have grown the vector.
            for k in usr_utt_tok_label_dict:
                assert len(usr_utt_tok_label_dict[k]) == len(usr_utt_tok)

            # History labels for the NEXT turn: this turn's labels prepended in
            # the same order as its tokens will be prepended to the history.
            new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
            for slot in slot_list:
                if use_history_labels:
                    if swap_utterances:
                        new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label_dict[slot] + [0] * len(sys_utt_tok) + \
                                                           new_hst_utt_tok_label_dict[slot]
                    else:
                        new_hst_utt_tok_label_dict[slot] = [0] * len(sys_utt_tok) + usr_utt_tok_label_dict[slot] + \
                                                           new_hst_utt_tok_label_dict[slot]
                else:
                    new_hst_utt_tok_label_dict[slot] = [0 for _ in
                                                        sys_utt_tok + usr_utt_tok + new_hst_utt_tok_label_dict[slot]]

            # With 'only_last' present in a turn, turns before the dialogue's
            # last one are skipped.
            if not ('only_last' in turn and turn_id < len(dialog) - 1):
                examples.append(DSTExample(
                    guid=guid,
                    text_a=txt_a,
                    text_b=txt_b,
                    history=hst_utt_tok,
                    text_a_label=txt_a_lbl,
                    text_b_label=txt_b_lbl,
                    history_label=hst_utt_tok_label_dict.copy(),
                    values=belief_state,
                    inform_label=inform_label,  # sys utter  slot-value     e.g. {'hotel-name': 'finches bed and breakfast'}
                    inform_slot_label=inform_slot_axu,  # auxiliary label inform   e.g. {'attraction-area': 1}
                    refer_label=refer_label,  # DS memory slotA :slotB   e.g. {'taxi-departure': 'attraction-name'}
                    diag_state=dial_state_aux,
                    class_label=class_label,
                ))
                if 'dontcare' in dial_id:
                    # Debug trace for dialogues whose id contains 'dontcare'.
                    logger.debug('%s', examples[-1])
            hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
    return examples
    
def convert_examples_to_features(examples, slot_list, class_types, model_type, tokenizer, max_seq_length):
    """Convert ``DSTExample`` objects into fixed-length ``InputFeatures``.

    Each example is laid out as ``[CLS] text_a [SEP] text_b [SEP] history
    [SEP]`` and zero-padded to ``max_seq_length``. If the text does not
    fit, tokens are dropped from the end of the history first, then from
    the end of the longer of ``text_a`` / ``text_b``. The examples' own
    token lists are not modified.

    Args:
        examples: sized sequence of ``DSTExample``.
        slot_list: list of slots for which labels are built.
        class_types: list of gate labels. ``class_label_id`` is an index
            into it.
        model_type: only ``'bert'`` is supported.
        tokenizer: object providing ``convert_tokens_to_ids``.
        max_seq_length: length of every output sequence.

    Returns:
        A list of ``InputFeatures``, one per example.

    Side effects:
        A ``'copy_value'`` class label without any labelled token left in
        the (possibly truncated) input is rewritten to ``'none'`` in
        ``example.class_label``.

    Raises:
        SystemExit: for an unknown ``model_type``.
        ValueError: if a refer/class label is not in ``['none'] + slot_list``
            / ``class_types`` (raised by ``list.index``).
    """
    import sys

    def _truncate_seq_pair(tokens_a, tokens_b, history, max_length):
        # Pop tokens in place until everything fits. The history is cut from
        # its end first; histories built by convert_readable_cache_to_examples
        # put the most recent turn first, so the oldest context goes first.
        while len(tokens_a) + len(tokens_b) + len(history) > max_length:
            if history:
                history.pop()
            elif len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

    def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, model_specs, guid):
        # Truncate in place, leaving room for the special tokens; return
        # whether truncation was needed. ``guid`` is currently unused.
        max_length = max_seq_length - model_specs['TOKEN_CORRECTION']
        input_text_too_long = len(tokens_a) + len(tokens_b) + len(history) > max_length
        _truncate_seq_pair(tokens_a, tokens_b, history, max_length)
        return input_text_too_long

    def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs):
        # Align token labels with the [CLS] a [SEP] b [SEP] history [SEP]
        # layout; special tokens and padding are labelled 0.
        token_label_ids = ([0] + list(token_labels_a) + [0] + list(token_labels_b) + [0]
                           + list(token_labels_history) + [0])
        token_label_ids.extend([0] * (max_seq_length - len(token_label_ids)))
        assert len(token_label_ids) == max_seq_length
        return token_label_ids

    def _get_start_end_pos(class_type, token_label_ids, max_seq_length):
        # A 'copy_value' gate needs a labelled span in the (truncated) input;
        # otherwise it falls back to 'none'.
        if class_type == 'copy_value' and 1 not in token_label_ids:
            class_type = 'none'
        start_pos = 0
        end_pos = 0
        if 1 in token_label_ids:
            start_pos = token_label_ids.index(1)
            # Only the first labelled span is used: it ends right before the
            # next 0, or at the end of the sequence.
            rest = token_label_ids[start_pos:]
            if 0 not in rest:
                end_pos = len(rest) + start_pos - 1
            else:
                end_pos = rest.index(0) + start_pos - 1
            for i in range(max_seq_length):
                if start_pos <= i <= end_pos:
                    assert token_label_ids[i] == 1
        return class_type, start_pos, end_pos

    def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, tokenizer, model_specs):
        # Segment 0 covers [CLS] a [SEP]; segment 1 covers b [SEP] history [SEP].
        tokens = ([model_specs['CLS_TOKEN']] + list(tokens_a) + [model_specs['SEP_TOKEN']]
                  + list(tokens_b) + [model_specs['SEP_TOKEN']]
                  + list(history) + [model_specs['SEP_TOKEN']])
        segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + len(history) + 2)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)
        # Zero-pad up to the sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        return tokens, input_ids, input_mask, segment_ids

    if model_type == 'bert':
        # TOKEN_CORRECTION = number of special tokens: [CLS] + 3 x [SEP].
        model_specs = {'MODEL_TYPE': 'bert',
                       'CLS_TOKEN': '[CLS]',
                       'UNK_TOKEN': '[UNK]',
                       'SEP_TOKEN': '[SEP]',
                       'TOKEN_CORRECTION': 4}
    else:
        logger.error("Unknown model type (%s). Aborting.", model_type)
        # sys.exit instead of the site-provided exit() helper, which is meant
        # for the interactive interpreter and is missing under `python -S`.
        sys.exit(1)

    total_cnt = 0
    too_long_cnt = 0

    # refer_id 0 means "no reference"; otherwise it is 1 + index of the slot.
    refer_list = ['none'] + slot_list
    features = []

    for example_index, example in enumerate(examples):
        if example_index % 1000 == 0:
            logger.info("Writing example %d of %d", example_index, len(examples))
        total_cnt += 1
        inform_slot_dict = {}  # for auxiliary inform label
        refer_id_dict = {}  # DS memory label
        diag_state_dict = {}  # for auxiliary ds label
        class_label_id_dict = {}  # slot gate label
        start_pos_dict = {}
        end_pos_dict = {}

        # Truncate copies: popping from the example's own lists would leave
        # its tokens shorter than its labels, and a second conversion of the
        # same examples would then fail the length assertions below.
        tokens_a = list(example.text_a)
        tokens_b = list(example.text_b)
        tokens_history = list(example.history)
        input_text_too_long = _truncate_length_and_warn(
            tokens_a, tokens_b, tokens_history, max_seq_length, model_specs, example.guid)

        for slot in slot_list:
            token_labels_a = example.text_a_label[slot]
            token_labels_b = example.text_b_label[slot]
            token_labels_history = example.history_label[slot]
            if input_text_too_long:
                if example_index < 10:
                    if len(token_labels_a) > len(tokens_a):
                        logger.info('    tokens_a truncated labels: %s', str(token_labels_a[len(tokens_a):]))
                    if len(token_labels_b) > len(tokens_b):
                        logger.info('    tokens_b truncated labels: %s', str(token_labels_b[len(tokens_b):]))
                    if len(token_labels_history) > len(tokens_history):
                        logger.info('    tokens_history truncated labels: %s',
                                    str(token_labels_history[len(tokens_history):]))

                token_labels_a = token_labels_a[:len(tokens_a)]
                token_labels_b = token_labels_b[:len(tokens_b)]
                token_labels_history = token_labels_history[:len(tokens_history)]

            assert len(token_labels_a) == len(tokens_a)
            assert len(token_labels_b) == len(tokens_b)
            assert len(token_labels_history) == len(tokens_history)
            token_label_ids = _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history,
                                                   max_seq_length, model_specs)

            class_label_mod, start_pos_dict[slot], end_pos_dict[slot] = _get_start_end_pos(
                example.class_label[slot], token_label_ids, max_seq_length)
            if class_label_mod != example.class_label[slot]:
                example.class_label[slot] = class_label_mod
            inform_slot_dict[slot] = example.inform_slot_label[slot]
            # N+1 DS memory slotA :slotB   e.g. {'taxi-departure': 'attraction-name'}
            refer_id_dict[slot] = refer_list.index(example.refer_label[slot])
            diag_state_dict[slot] = example.diag_state[slot]
            class_label_id_dict[slot] = class_types.index(example.class_label[slot])

        if input_text_too_long:
            too_long_cnt += 1

        tokens, input_ids, input_mask, segment_ids = _get_transformer_input(tokens_a,
                                                                            tokens_b,
                                                                            tokens_history,
                                                                            max_seq_length,
                                                                            tokenizer,
                                                                            model_specs)
        if example_index < 10:
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join(tokens))
            logger.info("input_ids: %s", " ".join(str(x) for x in input_ids))
            logger.info("input_mask: %s", " ".join(str(x) for x in input_mask))
            logger.info("segment_ids: %s", " ".join(str(x) for x in segment_ids))
            logger.info("start_pos: %s", str(start_pos_dict))
            logger.info("end_pos: %s", str(end_pos_dict))
            logger.info("inform_slot: %s", str(inform_slot_dict))
            logger.info("refer_id: %s", str(refer_id_dict))
            logger.info("diag_state: %s", str(diag_state_dict))
            logger.info("class_label_id: %s", str(class_label_id_dict))

        features.append(
            InputFeatures(
                guid=example.guid,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                start_pos=start_pos_dict,
                end_pos=end_pos_dict,
                inform_slot=inform_slot_dict,  # auxiliary label inform   e.g. {'attraction-area': 1}
                refer_id=refer_id_dict,  # DS memory slotA :slotB   e.g. {'taxi-departure': 'attraction-name'}
                diag_state=diag_state_dict,  # last dialog state (for auxiliary ds label)
                class_label_id=class_label_id_dict))  # turn-level gate label ids (indices into class_types)

    logger.info("========== %d out of %d examples have text too long", too_long_cnt, total_cnt)

    return features
            
            



if __name__ == '__main__':
    # Script entry: build augmented examples from cache_for_attend.json and
    # pickle them to attend/aug_examples.pkl.
    from transformers import BertTokenizer as tokenizer_class
    import torch  # only needed by the commented-out feature conversion below
    import pickle
    import os

    model_path = '../download_models/dialoglue/bert'
    special_tokens_dict = {'additional_special_tokens': SPEC_TOKENS + list(REQUEST_SLOTS.values())}
    tokenizer = tokenizer_class.from_pretrained(model_path, do_lower_case=True, do_basic_tokenize=True)
    # TODO: the model's token embeddings must be resized to the enlarged
    # vocabulary (e.g. model.resize_token_embeddings(len(tokenizer))).
    tokenizer.add_special_tokens(special_tokens_dict)
    uttop = UtterOp(SPEC_TOKENS, tokenizer)
    # NOTE(review): `onto` is not used below; kept in case constructing the
    # ontology has side effects that are relied on.
    onto = Ontology()

    examples = convert_readable_cache_to_examples('cache_for_attend.json',
                                                  slot_list=ALL_SLOTS,
                                                  tokenizer=tokenizer,
                                                  uttop=uttop)

    class_types = ["none", "dontcare", "copy_value", "true", "false", "refer", "inform"]

    # Create the output directory up front so the dump cannot fail with
    # FileNotFoundError after all the preprocessing work is done.
    out_path = os.path.join('attend', 'aug_examples.pkl')
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as f:
        pickle.dump(examples, f)

    # Optional follow-up: convert the examples to features and append them to
    # an existing feature cache.
    # feats = torch.load('attend/curriculum_data_all')
    # print(len(feats))
    # aug_feat = convert_examples_to_features(examples,
    #                              slot_list=ALL_SLOTS,
    #                              class_types=class_types,
    #                              model_type='bert',
    #                              tokenizer=tokenizer,
    #                              max_seq_length=312)
    # feats += aug_feat
    # print(len(feats))
    # torch.save(feats, 'attend/curriculum_data_all_aug')