import os
import sys
sys.path.append('.')
from os.path import join
from common.utils import read_json, dump_json, load_bin, dump_to_bin
from collections import OrderedDict
from types import SimpleNamespace
from transformers import AutoTokenizer
import torch
from calib_exp.run_tagger import load_cached_dataset
import argparse
from common.index_feature import IndexedFeature, FeatureVocab
import numpy as np
from tqdm import tqdm
import string
import re
from collections import Counter
from calib_exp.run_tagger import load_cached_dataset

def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True)
    args = parser.parse_args()
    args.split = 'addsent-dev' if args.dataset == 'squad' else 'dev'
    return args

def main(max_features=100):
    """Print the tagger output for the first few cached features.

    Loads the cached features for ``--dataset`` and the binary tag-info file
    for the derived split. For each inspected feature, prints its ``qas_id``
    followed by ``tagger_info[qas_id]['tags']``.

    Args:
        max_features: Number of leading features to inspect. The default of
            100 matches the previously hard-coded slice.
    """
    args = _parse_args()
    # Only the features are inspected here; the examples are not needed.
    features, _examples = load_cached_dataset(args.dataset)
    tagger_info = load_bin(
        'misc/{}_{}_tag_info.bin'.format(args.dataset, args.split))

    for feat in features[:max_features]:
        qas_id = feat.qas_id
        print(qas_id)
        if qas_id not in tagger_info:
            # Report the gap and keep going rather than aborting the whole
            # inspection run with a KeyError on the first missing id.
            print('(no tag info for this qas_id)')
            continue
        print(tagger_info[qas_id]['tags'])

if __name__ == '__main__':
    # Run the inspection only when executed as a script, not when imported.
    main()
    