import json
import pickle
import pprint, re
import copy, os
import random
from collections import defaultdict

import torch

from utils import  ALL_SLOTS
import logging
from fuzzywuzzy import fuzz

from input_data import DSTExample, fill, InputFeatures

# logging.basicConfig(filename='attend_data.log', level=logging.INFO, filemode='w')
# Module-level logger. No handler is configured in this file (the basicConfig
# call above is disabled), so output depends on the caller's logging setup.
logger = logging.getLogger(__name__)

class DiffcultyMeasurer:
    """Score dialogue turns by difficulty and build curriculum-ordered features.

    Constructing an instance runs the whole pipeline:

    1. every turn of ``origin_data`` is parsed into a ``DSTExample`` that
       carries difficulty scores (``parse_dialog``);
    2. the examples are sorted easy-to-hard and pickled to
       ``sorted_EXAMPLES.pkl`` (``save_sorted_examples``);
    3. the pickled examples are converted to ``InputFeatures`` and written with
       ``torch.save`` to ``attend/curriculum_data_mix``.

    Two signals are computed per turn:

    * ``rule_score``: a heuristic easiness in [-1, 1] (higher = easier) from
      turn index, number of labelled slots, number of domains and length;
    * ``model_score``: per slot, the mean of +1 (prediction correct) / -1
      (prediction wrong) over the supplied prediction files.

    Only the rule score currently drives the sort order (``avg_score``); the
    per-slot combination is kept in each example's ``score_dic``.
    """

    # Slots whose "every model was wrong" score (-1) is replaced by a random
    # value in model_score; presumably their gold labels are unreliable.
    _LENIENT_SLOTS = ('hotel-type', 'hotel-internet', 'hotel-parking')

    def __init__(self, slot_list, origin_data, pred_data_list, pred_file_list,
                 tokenizer=None, uttop=None, onto=None):
        """Run the full pipeline: score, sort, pickle and featurize all turns.

        Args:
            slot_list: Slot names (e.g. ``'hotel-area'``) to score and label.
            origin_data: Mapping of dialogue id -> sequence of turn dicts.
            pred_data_list: One list of per-turn prediction dicts per model
                checkpoint; each dict needs at least ``'guid'``.
            pred_file_list: File names parallel to ``pred_data_list``; stored on
                every prediction as ``'pred_file'``.
            tokenizer: Tokenizer providing ``tokenize`` and
                ``convert_tokens_to_ids``.
            uttop: Utterance helper providing ``lex`` (see ``parse_dialog``).
            onto: Ontology used to normalize copy_value spans. When omitted, a
                module-level ``onto`` (as created by the ``__main__`` script) is
                used if present.
        """
        # T and alpha parameterize the (currently disabled) schedule in func().
        self.T = 35490
        self.alpha = 0.9
        # Currently unused; guarded so an empty prediction list does not crash.
        self.weight = 1.0 / len(pred_data_list) if pred_data_list else 0.0

        self.slot_list = slot_list
        self.tokenizer = tokenizer
        self.uttop = uttop
        self.origin_data = origin_data
        self.onto = onto

        self.save_sorted_examples(pred_data_list, pred_file_list)

        features = self.convert_examples_to_features(load_pkl='sorted_EXAMPLES.pkl')
        torch.save(features, 'attend/curriculum_data_mix')

    def _ontology(self):
        """Return the ontology used to normalize slot values.

        Prefers ``self.onto``. Otherwise falls back to a module-level ``onto`` so
        the ``__main__`` script keeps working without passing it explicitly.

        Raises:
            RuntimeError: If no ontology is available.
        """
        onto = getattr(self, 'onto', None)
        if onto is None:
            onto = globals().get('onto')
        if onto is None:
            raise RuntimeError(
                "DiffcultyMeasurer needs an ontology to score copy_value slots; "
                "pass onto=Ontology() to the constructor.")
        return onto

    def convert_examples_to_features(self, load_pkl, model_type='bert', max_seq_length=180):
        """Load pickled ``DSTExample`` objects and convert them to ``InputFeatures``.

        Each input is laid out as ``[CLS] text_a [SEP] text_b [SEP] history [SEP]``
        (segment 0 through the first [SEP], segment 1 afterwards) and zero-padded
        to ``max_seq_length``. Overlong inputs are truncated in place on the
        example objects: history tokens are dropped first, then the longer of
        text_a/text_b. The per-slot token labels are cut to match. A copy_value
        class label with no labelled token left is rewritten to ``'none'`` on the
        example.

        Args:
            load_pkl: Path of a pickle written by ``save_sorted_examples``. Only
                load trusted files: ``pickle.load`` can execute arbitrary code.
            model_type: Only ``'bert'`` is supported.
            max_seq_length: Length every feature is padded/truncated to.

        Returns:
            A list of ``InputFeatures`` in the order of the pickled examples.

        Raises:
            ValueError: If ``model_type`` is not ``'bert'``.
        """
        def _truncate_seq_pair(tokens_a, tokens_b, history, max_length):
            # Pop from the end of history first, then from the longer utterance.
            while True:
                total_length = len(tokens_a) + len(tokens_b) + len(history)
                if total_length <= max_length:
                    break
                if len(history) > 0:
                    history.pop()
                elif len(tokens_a) > len(tokens_b):
                    tokens_a.pop()
                else:
                    tokens_b.pop()

        def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, model_specs, guid):
            # TOKEN_CORRECTION reserves room for [CLS] and the three [SEP] tokens.
            budget = max_seq_length - model_specs['TOKEN_CORRECTION']
            input_text_too_long = len(tokens_a) + len(tokens_b) + len(history) > budget
            _truncate_seq_pair(tokens_a, tokens_b, history, budget)
            return input_text_too_long

        def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs):
            # Special tokens ([CLS], each [SEP]) and padding all get label 0.
            token_label_ids = ([0] + list(token_labels_a) + [0] + list(token_labels_b)
                               + [0] + list(token_labels_history) + [0])
            token_label_ids += [0] * (max_seq_length - len(token_label_ids))
            assert len(token_label_ids) == max_seq_length
            return token_label_ids

        def _get_start_end_pos(class_type, token_label_ids, max_seq_length):
            # A copy_value label without any labelled token cannot be copied.
            if class_type == 'copy_value' and 1 not in token_label_ids:
                class_type = 'none'
            start_pos = 0
            end_pos = 0
            if 1 in token_label_ids:
                # Only the first contiguous run of 1s is used as the span.
                start_pos = token_label_ids.index(1)
                tail = token_label_ids[start_pos:]
                if 0 not in tail:
                    end_pos = len(tail) + start_pos - 1
                else:
                    end_pos = tail.index(0) + start_pos - 1
                assert all(label == 1 for label in token_label_ids[start_pos:end_pos + 1])
            return class_type, start_pos, end_pos

        def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, tokenizer, model_specs):
            cls_token = model_specs['CLS_TOKEN']
            sep_token = model_specs['SEP_TOKEN']
            tokens = [cls_token, *tokens_a, sep_token, *tokens_b, sep_token, *history, sep_token]
            # Segment 0: [CLS] text_a [SEP]; segment 1: text_b [SEP] history [SEP].
            segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + len(history) + 2)
            input_ids = list(tokenizer.convert_tokens_to_ids(tokens))
            # 1 marks real tokens, 0 marks padding; only real tokens are attended to.
            input_mask = [1] * len(input_ids)
            n_pad = max_seq_length - len(input_ids)
            input_ids += [0] * n_pad
            input_mask += [0] * n_pad
            segment_ids += [0] * n_pad
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            return tokens, input_ids, input_mask, segment_ids

        class_types = ["none", "dontcare", "copy_value", "true", "false", "refer", "inform"]

        if model_type != 'bert':
            logger.error("Unknown model type (%s). Aborting." % (model_type))
            raise ValueError("Unknown model type (%s)." % model_type)
        model_specs = {'MODEL_TYPE': 'bert',
                       'CLS_TOKEN': '[CLS]',
                       'UNK_TOKEN': '[UNK]',
                       'SEP_TOKEN': '[SEP]',
                       'TOKEN_CORRECTION': 4}

        with open(load_pkl, 'rb') as f:
            examples = pickle.load(f)

        total_cnt = 0
        too_long_cnt = 0

        # refer_id is an index into this list: 0 = no reference, else a slot.
        refer_list = ['none'] + self.slot_list
        features = []

        for example_index, example in enumerate(examples):
            if example_index % 1000 == 0:
                logger.info("Writing example %d of %d" % (example_index, len(examples)))
            total_cnt += 1
            inform_slot_dict = {}  # auxiliary inform label, e.g. {'attraction-area': 1}
            refer_id_dict = {}  # index into refer_list (DS memory label)
            diag_state_dict = {}  # auxiliary dialogue-state label
            class_label_id_dict = {}  # slot gate label, index into class_types
            start_pos_dict = {}
            end_pos_dict = {}

            # These are the example's own lists; truncation mutates them in place.
            tokens_a = example.text_a
            tokens_b = example.text_b
            tokens_history = example.history
            input_text_too_long = _truncate_length_and_warn(
                tokens_a, tokens_b, tokens_history, max_seq_length, model_specs, example.guid)

            for slot in self.slot_list:
                token_labels_a = example.text_a_label[slot]
                token_labels_b = example.text_b_label[slot]
                token_labels_history = example.history_label[slot]
                if input_text_too_long:
                    if example_index < 10:
                        if len(token_labels_a) > len(tokens_a):
                            logger.info('    tokens_a truncated labels: %s' % str(token_labels_a[len(tokens_a):]))
                        if len(token_labels_b) > len(tokens_b):
                            logger.info('    tokens_b truncated labels: %s' % str(token_labels_b[len(tokens_b):]))
                        if len(token_labels_history) > len(tokens_history):
                            logger.info('    tokens_history truncated labels: %s' % str(
                                token_labels_history[len(tokens_history):]))

                    token_labels_a = token_labels_a[:len(tokens_a)]
                    token_labels_b = token_labels_b[:len(tokens_b)]
                    token_labels_history = token_labels_history[:len(tokens_history)]

                assert len(token_labels_a) == len(tokens_a)
                assert len(token_labels_b) == len(tokens_b)
                assert len(token_labels_history) == len(tokens_history)
                token_label_ids = _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history,
                                                       max_seq_length, model_specs)

                class_label_mod, start_pos_dict[slot], end_pos_dict[slot] = _get_start_end_pos(
                    example.class_label[slot], token_label_ids, max_seq_length)
                if class_label_mod != example.class_label[slot]:
                    example.class_label[slot] = class_label_mod
                inform_slot_dict[slot] = example.inform_slot_label[slot]
                refer_id_dict[slot] = refer_list.index(example.refer_label[slot])
                diag_state_dict[slot] = example.diag_state[slot]
                class_label_id_dict[slot] = class_types.index(example.class_label[slot])

            if input_text_too_long:
                too_long_cnt += 1

            tokens, input_ids, input_mask, segment_ids = _get_transformer_input(
                tokens_a, tokens_b, tokens_history, max_seq_length, self.tokenizer, model_specs)
            if example_index < 10:
                logger.info("*** Example ***")
                logger.info("guid: %s" % (example.guid))
                logger.info("tokens: %s" % " ".join(tokens))
                logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
                logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
                logger.info("start_pos: %s" % str(start_pos_dict))
                logger.info("end_pos: %s" % str(end_pos_dict))
                logger.info("inform_slot: %s" % str(inform_slot_dict))
                logger.info("refer_id: %s" % str(refer_id_dict))
                logger.info("diag_state: %s" % str(diag_state_dict))
                logger.info("class_label_id: %s" % str(class_label_id_dict))

            features.append(
                InputFeatures(
                    guid=example.guid,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    start_pos=start_pos_dict,
                    end_pos=end_pos_dict,
                    inform_slot=inform_slot_dict,  # auxiliary inform label, e.g. {'attraction-area': 1}
                    refer_id=refer_id_dict,  # index into ['none'] + slot_list
                    diag_state=diag_state_dict,  # previous dialogue state (auxiliary DS label)
                    class_label_id=class_label_id_dict))  # turn-level gate label ids, e.g. {'taxi-leaveAt': 2}

        logger.info("========== %d out of %d examples have text too long" % (too_long_cnt, total_cnt))

        return features

    def save_sorted_examples(self, pred_data_list, pred_file_list):
        """Group predictions by turn, score every turn and pickle the sorted examples.

        Side effects: tags every input prediction dict with ``'pred_file'``, sets
        ``self.pred_data_dic`` (guid -> list of deep-copied prediction dicts) and
        writes ``sorted_EXAMPLES.pkl`` to the current working directory.
        """
        self.pred_data_dic = defaultdict(list)
        for pred_data, pred_file in zip(pred_data_list, pred_file_list):
            for item in pred_data:
                guid = item['guid']
                item['pred_file'] = pred_file
                self.pred_data_dic[guid].append(copy.deepcopy(item))
        print('pred files loaded...')

        examples = []
        for dialog in self.origin_data.values():
            examples.extend(self.parse_dialog(dialog, self.tokenizer, self.uttop, self.slot_list))
        print('EXAMPLES processed...')

        print('sorting...')
        # Highest avg_score (easiest) first; sorted() keeps ties in input order.
        examples = sorted(examples, key=lambda x: x.avg_score, reverse=True)
        print('sorted...')
        print('len:', len(examples))

        with open('sorted_EXAMPLES.pkl', 'wb') as f:
            pickle.dump(examples, f)

    def parse_dialog(self, dialog, tokenizer, uttop, slot_list,
                     swap_utterances=True, use_history_labels=True):
        """Turn one dialogue into a list of scored ``DSTExample`` objects.

        Args:
            dialog: Ordered sequence of turn dicts (keys read: ``dial_loc``,
                ``delex_usr_norm``, ``delex_sys_norm``, ``delex_usr_sp_dic``,
                ``belief_state``, ``turn_label``).
            tokenizer: Tokenizes the system utterance.
            uttop: Its ``lex`` tokenizes the user utterance and returns per-slot
                token spans.
            slot_list: Slots to build token labels for.
            swap_utterances: If True, text_a is the user utterance and text_b the
                system utterance, and history is ordered user-then-system.
            use_history_labels: If False, every history token label is 0.

        Returns:
            One ``DSTExample`` per turn. ``score_dic`` maps each slot to
            ``(rule score, model score, 0.3 * rule + 0.7 * model)``.
            ``avg_score`` is the rule score. ``noise`` flags slots on which
            every prediction was wrong.
        """
        examples = []
        sys_utt_tok = []
        usr_utt_tok = []
        hst_utt_tok = []
        hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
        # rule_score treats the final four turns as closing turns.
        last_turns_start = len(dialog) - 4

        for turn_id, turn in enumerate(dialog):
            # Prepend the previous turn's utterances, so history is newest-first.
            if swap_utterances:
                hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
            else:
                hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
            guid = turn['dial_loc']

            delex_usr_norm = turn['delex_usr_norm']
            delex_sys_norm = turn['delex_sys_norm']
            delex_usr_sp_dic = turn['delex_usr_sp_dic']
            belief_state = turn['belief_state']

            turn_label = turn['turn_label']
            class_label = turn_label['class_label']
            refer_label = turn_label['refer_label']
            inform_label = turn_label['inform_label']
            dial_state_aux = turn_label['dial_state_aux']
            inform_slot_axu = {k.strip('<').strip('>'): v for k, v in turn_label['inform_slot_axu'].items()}
            # Name slots filled via 'inform' also count for the auxiliary inform label.
            for k in class_label:
                if 'name' in k and class_label[k] == 'inform':
                    inform_slot_axu[k] = 1

            class_label = fill(slot_list, class_label)
            refer_label = fill(slot_list, refer_label)
            inform_label = fill(slot_list, inform_label)
            inform_slot_axu = fill(slot_list, inform_slot_axu, True)
            dial_state_aux = fill(slot_list, dial_state_aux, True)

            sys_utt_tok = tokenizer.tokenize(delex_sys_norm) if delex_sys_norm else []

            usr_utt_tok, usr_span_dic, _ = uttop.lex(delex_usr_norm, delex_usr_sp_dic, onto=None, use_aug=False)
            usr_utt_tok_label_dict = {slot: [0] * len(usr_utt_tok) for slot in slot_list}
            sys_utt_tok_label_dict = {slot: [0] * len(sys_utt_tok) for slot in slot_list}
            for k, pos in usr_span_dic.items():
                if not pos:
                    continue
                # Label only the last occurrence, token range [b, e).
                b, e = pos[-1]
                vec = [0] * len(usr_utt_tok)
                vec[b:e] = [1] * (e - b)
                usr_utt_tok_label_dict[k.strip('<').strip('>')] = vec

            if swap_utterances:
                txt_a, txt_b = usr_utt_tok, sys_utt_tok
                txt_a_lbl, txt_b_lbl = usr_utt_tok_label_dict, sys_utt_tok_label_dict
            else:
                txt_a, txt_b = sys_utt_tok, usr_utt_tok
                txt_a_lbl, txt_b_lbl = sys_utt_tok_label_dict, usr_utt_tok_label_dict

            for k in usr_utt_tok_label_dict:
                assert len(usr_utt_tok_label_dict[k]) == len(usr_utt_tok)

            # History labels for the *next* turn: this turn's utterances prepended.
            new_hst_utt_tok_label_dict = copy.deepcopy(hst_utt_tok_label_dict)
            for slot in slot_list:
                if use_history_labels:
                    if swap_utterances:
                        new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label_dict[slot] + [0] * len(sys_utt_tok) + \
                                                           new_hst_utt_tok_label_dict[slot]
                    else:
                        new_hst_utt_tok_label_dict[slot] = [0] * len(sys_utt_tok) + usr_utt_tok_label_dict[slot] + \
                                                           new_hst_utt_tok_label_dict[slot]
                else:
                    new_hst_utt_tok_label_dict[slot] = [0] * (
                        len(sys_utt_tok) + len(usr_utt_tok) + len(new_hst_utt_tok_label_dict[slot]))

            m_score_dic = self.model_score(turn, self.pred_data_dic[guid])
            r_score = self.rule_score(turn, turn_id >= last_turns_start)

            score_dic = {}
            noise_dic = {}
            for s in slot_list:
                score_dic[s] = (r_score, m_score_dic[s], 0.3 * r_score + 0.7 * m_score_dic[s])
                # -1 means every prediction file got this slot wrong.
                noise_dic[s] = m_score_dic[s] == -1

            # The sort key is the turn-level rule score. (Averaging a set of
            # per-slot values would silently drop duplicate scores.)
            mean_score = r_score

            examples.append(DSTExample(
                guid=guid,
                text_a=txt_a,
                text_b=txt_b,
                history=hst_utt_tok,
                text_a_label=txt_a_lbl,
                text_b_label=txt_b_lbl,
                history_label=copy.deepcopy(hst_utt_tok_label_dict),
                values=belief_state,
                inform_label=inform_label,  # system-informed slot values, e.g. {'hotel-name': 'finches bed and breakfast'}
                inform_slot_label=inform_slot_axu,  # auxiliary inform label, e.g. {'attraction-area': 1}
                refer_label=refer_label,  # DS memory reference, e.g. {'taxi-departure': 'attraction-name'}
                diag_state=dial_state_aux,
                class_label=class_label,
                score_dic=score_dic,
                avg_score=mean_score,
                noise=noise_dic))

            hst_utt_tok_label_dict = copy.deepcopy(new_hst_utt_tok_label_dict)

        return examples

    @staticmethod
    def wp_recover(value):
        """Undo word-piece splitting, e.g. ``'cam ##bridge'`` -> ``'cambridge'``.

        Also re-attaches ``:``, ``'`` and ``-`` that tokenization surrounded
        with whitespace.
        """
        value = re.sub("(^| )##", "", value)
        value = re.sub(r"\s(:|'|-)\s", r"\1", value)
        return value

    def func(self, t):
        """Weight of one prediction file; ``t`` is the step parsed from its name.

        Currently constant 1. The disabled schedule would ramp linearly from
        ``alpha`` at step 0 to 1 at step ``T``.
        """
        return 1
        # t = int(t)
        # return (1-self.alpha) * min(1.0, t / self.T) + self.alpha

    def similar(self, str1, str2):
        """Return True when the strings fuzzy-match (WRatio > 90), ignoring spaces."""
        return fuzz.WRatio(str1.replace(' ', ''), str2.replace(' ', '')) / 100 > 0.9

    def _is_correct_prediction(self, slot, pred, gt_class_label, gt_span_label):
        """Return whether one prediction dict gets ``slot`` right."""
        pred_class = pred['class_label']
        if slot not in gt_class_label:
            # No gold label: correct only if the model predicted nothing either.
            return slot not in pred_class
        gt_class = gt_class_label[slot]
        if slot not in pred_class or pred_class[slot] != gt_class:
            return False
        if gt_class != 'copy_value':
            return True
        # copy_value: compare the normalized predicted span with the gold spans.
        onto = self._ontology()
        pd_an = onto.normalize_label(slot, self.wp_recover(pred['span_label'][slot]))
        gt_ans = [onto.normalize_label(slot, item) for item in gt_span_label[slot]]
        return any(self.similar(pd_an, gt_an) for gt_an in gt_ans)

    def model_score(self, turn_data, pred_datas):
        """Per-slot agreement between model predictions and the gold labels.

        Each prediction contributes ``+w`` when correct and ``-w`` when wrong,
        where ``w = self.func(step)`` (currently 1), so each slot scores the
        mean of those values in [-1, 1]. See ``_is_correct_prediction`` for
        what counts as correct.

        A slot scoring exactly -1 (every prediction wrong) is replaced by
        ``random.uniform(0, 1)`` when it is in ``_LENIENT_SLOTS``, or when it
        is a gold copy_value slot whose name does not occur in the
        delexicalized user utterance. If the turn has no predictions, every
        slot scores 0.0.

        Returns:
            dict mapping every slot in ``self.slot_list`` to its score.
        """
        gt_class_label = turn_data['turn_label']['class_label']
        delex_usr_norm = turn_data['delex_usr_norm']

        # copy_value slots whose placeholder is absent from the user utterance.
        redundant_class_label = {k for k, v in gt_class_label.items()
                                 if v == 'copy_value' and k not in delex_usr_norm}

        gt_span_label = {k.strip('<').strip('>'): v for k, v in turn_data['delex_usr_sp_dic'].items()}

        if not pred_datas:
            logger.warning("No predictions for turn %s; model scores default to 0.0.",
                           turn_data.get('dial_loc'))

        output_dic = {}
        for s in self.slot_list:
            score_list = []
            for pred in pred_datas:
                score = self.func(pred['pred_file'].split('.')[-2])
                correct = self._is_correct_prediction(s, pred, gt_class_label, gt_span_label)
                score_list.append(score if correct else -score)

            mean_score = sum(score_list) / len(score_list) if score_list else 0.0
            if mean_score == -1 and (s in self._LENIENT_SLOTS or s in redundant_class_label):
                mean_score = random.uniform(0, 1)
            output_dic[s] = mean_score

        return output_dic

    def rule_score(self, turn_data, is_last=False):
        """Heuristic easiness of a turn in [-1, 1]; higher means easier.

        A raw difficulty is accumulated from four weighted terms, each
        saturating at 1:

        * turn index (0.5, saturating at MAXTURN);
        * number of labelled slots (1.0, at MAXSLOT);
        * number of distinct domains among those slots (0.3, at MAXDOM);
        * user + system word count (0.2, at MAXLEN).

        A turn with no slot labels that is a closing turn (``is_last``, or the
        user says 'thank') loses half of the turn-index term. The raw score is
        in [0, 2] for non-negative turn indices, and ``1 - raw`` is returned.
        """
        MAXTURN, MAXSLOT, MAXDOM, MAXLEN = 7, 6, 4, 50
        dial_loc = turn_data['dial_loc']
        # dial_loc ends with the turn index, e.g. 'train-PMUL4080.json-2'.
        turn_num = int(dial_loc.split('-')[-1])

        delex_usr_norm = turn_data['delex_usr_norm']
        # The system utterance may be empty/None (parse_dialog guards it too).
        delex_sys_norm = turn_data['delex_sys_norm'] or ''
        length = len(delex_usr_norm.split()) + len(delex_sys_norm.split())

        slots = list(turn_data['turn_label']['class_label'])
        slot_num = len(slots)
        dom_num = len({s.split('-')[0] for s in slots})

        turn_ratio = min(1.0, turn_num / MAXTURN)
        score = 0.5 * turn_ratio + \
                1.0 * min(1.0, slot_num / MAXSLOT) + \
                0.3 * min(1.0, dom_num / MAXDOM) + \
                0.2 * min(1.0, length / MAXLEN)

        if slot_num == 0 and (is_last or 'thank' in delex_usr_norm):
            score -= 0.25 * turn_ratio
        # Map difficulty [0, 2] to easiness [-1, 1]; the original wrote
        # this as `1 - 2*score / 2`, which equals `1 - score`.
        return 1 - score


if __name__ == '__main__':
    from utils import ALL_SLOTS, SPEC_TOKENS, REQUEST_SLOTS, UtterOp
    from transformers import BertTokenizer as tokenizer_class
    from ontology import Ontology

    model_path = '../download_models/dialoglue/bert'
    special_tokens_dict = {'additional_special_tokens': SPEC_TOKENS + list(REQUEST_SLOTS.values())}
    tokenizer = tokenizer_class.from_pretrained(model_path, do_lower_case=True, do_basic_tokenize=True)
    # TODO: the model's embeddings must be resized for these added tokens.
    tokenizer.add_special_tokens(special_tokens_dict)
    uttop = UtterOp(SPEC_TOKENS, tokenizer)
    # Keep `onto` at module level: DiffcultyMeasurer.model_score may look it up
    # as a global when no ontology is passed explicitly.
    onto = Ontology()

    DIR = './attend'
    # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
    with open(os.path.join(DIR, 'cache_dial_data_train.json'), encoding='utf-8') as f:
        cache_dial_data_train = json.load(f)

    # Predictions from six model checkpoints; per-slot scores are averaged.
    pred_file_list = [
        os.path.join(DIR, 'results/bert/pred_res.train.11830.json'),
        os.path.join(DIR, 'results/bert/pred_res.train.21294.json'),
        os.path.join(DIR, 'results/preview_ep3_1/pred_res.train.30758.json'),
        os.path.join(DIR, 'results/preview_ep3_1/pred_res.train.33124.json'),
        os.path.join(DIR, 'results/preview_step12270/pred_res.train.11830.json'),
        os.path.join(DIR, 'results/preview_step12270/pred_res.train.21294.json'),
    ]

    pred_data_list = []
    for pred_path in pred_file_list:
        with open(pred_path, encoding='utf-8') as f:
            pred_data_list.append(json.load(f))

    # The constructor runs the whole pipeline and writes its outputs to disk.
    DiffcultyMeasurer(
        slot_list=ALL_SLOTS,
        origin_data=cache_dial_data_train,
        pred_data_list=pred_data_list,
        pred_file_list=pred_file_list,
        tokenizer=tokenizer,
        uttop=uttop)




        
    
        
        



