import argparse
from utils.utils import set_seed, get_relation_args, build_model, model_prefix, \
    load_roberta_vocab, filter_samples_by_vocab, load_json_dic, count_distinct_obj, get_table_stat, store_json_dic
from utils.read_data import LamaDataset
from prettytable import PrettyTable
from tqdm import tqdm
from prompt_based.prompt_utils import average_sampling, store_samples
from transformers import RobertaTokenizer


def generate_data(args):
    """Build the per-relation "wiki_uni" sample files and optionally print stats.

    For every relation returned by ``LamaDataset.get_samples()``:

    * with a RoBERTa model, the relation's LAMA samples are filtered with
      ``filter_samples_by_vocab`` and written to ``data/roberta_data/lama``;
    * the relation's wiki samples are loaded from ``data/wiki/<relation_id>``
      and filtered with ``filter_samples_by_vocab``;
    * ``average_sampling`` draws the ``wiki_uni`` subset, which is written to
      ``data/<model prefix>_data/wiki_uni``.

    One PrettyTable row (sample counts and distinct-object counts for the LAMA,
    wiki_uni and wiki sets) is added per relation. The table is passed through
    ``get_table_stat`` and printed only when ``args.stat_data`` is true.

    Args:
        args: argparse namespace providing ``model_name``, ``method``,
            ``threshold``, ``target`` and ``stat_data``, plus whatever
            ``get_relation_args`` consumes. That helper is expected to add
            ``relation_file``, ``sample_dir`` and ``sample_file_type``.

    Raises:
        RuntimeError: if the model prefix is neither "bert" nor "roberta".
    """
    set_seed(0)
    args = get_relation_args(args)

    dataset = LamaDataset(
        relation_file=args.relation_file,
        sample_dir=args.sample_dir,
        sample_file_type=args.sample_file_type,
    )
    id2relation, relation2samples = dataset.get_samples()

    columns = [
        "relation_id", "relation_label",
        "lama_num", "wiki_uni_num", "wiki_num",
        "lama_obj", "wiki_uni_obj", "wiki_obj",
    ]
    stat_table = PrettyTable(field_names=columns)
    stat_table.title = "data statistics for {}".format(args.model_name)
    tokenizer, _model = build_model(args.model_name)

    # The prefix selects both the vocabulary source and the output directory.
    prefix = model_prefix(args.model_name)
    if prefix == "bert":
        vocab = tokenizer.vocab
    elif prefix == "roberta":
        vocab = load_roberta_vocab()
    else:
        raise RuntimeError("model error")

    uni_dir = "data/{}_data/wiki_uni".format(prefix)

    for rel_id in tqdm(id2relation):
        label = id2relation[rel_id]["label"]

        # Only the RoBERTa run re-filters and re-stores the LAMA samples.
        lama_samples = relation2samples[rel_id]
        if prefix == "roberta":
            lama_samples = filter_samples_by_vocab(lama_samples, vocab)[0]
            store_samples(rel_id, "data/roberta_data/lama", lama_samples)

        # Remove wiki samples that fall outside the vocabulary.
        wiki_samples = filter_samples_by_vocab(
            load_json_dic("data/wiki/{}".format(rel_id)), vocab
        )[0]

        # "lama_sampling" targets the size of this relation's LAMA set.
        sample_target = (
            len(lama_samples) if args.method == "lama_sampling" else args.target
        )

        uni_samples = average_sampling(
            wiki_samples,
            method=args.method,
            threshold=args.threshold,
            target=sample_target,
        )
        store_samples(rel_id, uni_dir, uni_samples)

        stat_table.add_row([
            rel_id, label,
            len(lama_samples), len(uni_samples), len(wiki_samples),
            count_distinct_obj(lama_samples),
            count_distinct_obj(uni_samples),
            count_distinct_obj(wiki_samples),
        ])

    stat_table = get_table_stat(stat_table)
    if args.stat_data:
        print(stat_table)


def main(argv=None):
    """Parse command-line options and run :func:`generate_data`.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so the real command line is honoured.
            Previously a hard-coded ``["--stat-data"]`` was parsed, which ignored
            every flag the user passed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--relation-type", type=str, default="roberta_original",
                        choices=["lama_original", "lama_mine", "lama_auto",
                                 "roberta_original", "roberta_auto", "roberta_mine"])
    parser.add_argument("--model-name", type=str, default="roberta-large",
                        choices=["bert-large-cased", "bert-large-cased-wwm",
                                 "roberta-large"])
    # Sampling method used to build wiki_uni.
    parser.add_argument("--method", type=str, default="threshold_sample",
                        choices=[
                            # Percentage threshold: objects above it are
                            # down-sampled, objects below it are discarded.
                            "threshold_sample",
                            # Numeric target: keep the total at or above the
                            # target while spreading the object (tail-entity)
                            # distribution as widely as possible.
                            "target_sampling",
                            # Numeric target equal to the number of LAMA
                            # instances of each relation.
                            "lama_sampling",
                        ],
                        help="sampling method used to build wiki_uni")
    parser.add_argument("--threshold", type=float, default=0.5)
    parser.add_argument("--target", type=int, default=1000)

    # Whether to print the data statistics table.
    parser.add_argument("--stat-data", action="store_true",
                        help="print the data statistics table")

    args = parser.parse_args(argv)
    generate_data(args)


def store_roberta_vocab2idx():
    """Store a decoded-word -> token-id mapping for ``roberta-large`` as JSON.

    Every token id in the tokenizer vocabulary is decoded back to text and
    stripped of surrounding whitespace. That text becomes the key and the id
    becomes the value.

    When several ids decode to the same stripped text (for example, tokens that
    differ only in surrounding whitespace), the id visited last in
    ``get_vocab()`` order wins. The result is written to
    ``data/roberta_data/vocab2idx.json``.
    """
    tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
    token_to_id = tokenizer.get_vocab()
    # values() iterates in the same order as the keys, so later duplicates
    # overwrite earlier ones exactly as an explicit loop would.
    word2idx = {
        tokenizer.decode(token_id).strip(): token_id
        for token_id in token_to_id.values()
    }
    store_json_dic("data/roberta_data/vocab2idx.json", word2idx)


if __name__ == "__main__":
    # Script entry point; importing this module does not trigger generation.
    main()
