import os

import torch
from prettytable import PrettyTable
from utils.utils import set_seed, \
    get_table_stat, model_prefix, get_relation_args, \
    load_file, partition, calculate_p, cal_mean_and_std, get_pair, \
    read_prompts, store_json_dic, load_json_dic
import numpy as np
from models import build_model_wrapper, build_debias_model_wrapper, model_wrapper
from utils.read_data import LamaDataset
import random
from tqdm import tqdm
from transformers import GPT2Tokenizer
import argparse
import copy


def eval_relation_subset(args):
    """Report how model rankings vary across random subsets of relations.

    Per-prompt precisions and per-relation precisions (original prompt, one
    randomly chosen prompt, all prompts together) are computed for every
    model, or loaded from the JSON caches under ``causal_doc/relation_subset``
    when those files already exist. Then, for each of the "ori", "single" and
    "multi" prompt methods, ``args.sample_times`` random subsets of
    ``args.sample_num`` relations are drawn and each model's mean precision on
    the subset is tabulated together with the best and worst model.

    Args:
        args: Parsed command-line namespace. ``sample_times``, ``sample_num``,
            ``cuda_device``, ``max_len`` and ``ignore_stop_words`` are read
            here; ``relation_file``, ``sample_dir`` and ``sample_file_type``
            are read after ``get_relation_args`` has processed it.
    """
    report_path = "causal_doc/relation_subset/relation_subset_variance_{}_{}".format(
        args.sample_times, args.sample_num
    )
    # The report is still opened before any work starts; the context manager
    # guarantees it is flushed and closed even when evaluation fails.
    with open(report_path, "w") as fout:
        _write_relation_subset_report(args, fout)


def _write_relation_subset_report(args, fout):
    """Run the relation-subset evaluation and write every table to ``fout``."""
    sample_times = args.sample_times
    sample_num = args.sample_num

    set_seed(0)

    args = get_relation_args(args)

    model_names = [
        "bert-large-cased",
        "gpt2-xl",
        "roberta-large",
        "bart-large",
        "bert-base-cased",
        "roberta-base",
        "gpt2-medium",
        "gpt2-large",
        "bart-base",
    ]

    # NOTE(review): every wrapper is built up front, even when both caches
    # below exist and no model is actually evaluated.
    model_wrappers = [
        build_model_wrapper(model_name, device=args.cuda_device, args=args)
        for model_name in model_names
    ]

    lama_data = LamaDataset(relation_file=args.relation_file,
                            sample_dir=args.sample_dir,
                            sample_file_type=args.sample_file_type)
    id2relation, id2samples = lama_data.get_samples()

    results_file = "causal_doc/relation_subset/new_results.json"
    prompt2precision_file = "causal_doc/relation_subset/new_prompt2precision.json"

    # --- precision of every individual prompt: model -> relation -> prompt ---
    if os.path.isfile(prompt2precision_file):
        prompt2precision = load_json_dic(prompt2precision_file)
        print("{} loaded".format(prompt2precision_file))
    else:
        prompt2precision = {}
        for model_name, wrapper in zip(model_names, model_wrappers):
            if isinstance(wrapper.tokenizer, GPT2Tokenizer):
                batch_size = 4
            else:
                batch_size = 32

            relation2prompt_precision = {}
            for relation_id in tqdm(id2relation):
                samples = id2samples[relation_id]
                prompt_precision = {}
                for prompt in read_prompts(relation_id):
                    _, single_p = wrapper.eval_sample_with_multi_prompts(
                        [prompt], samples,
                        batch_size=batch_size,
                        ignore_stop_word=args.ignore_stop_words
                    )
                    prompt_precision[prompt] = single_p
                # Nested by relation id because cal_performance_stored looks
                # entries up as [relation_id][prompt].
                relation2prompt_precision[relation_id] = prompt_precision

            prompt2precision[model_name] = relation2prompt_precision

        store_json_dic(prompt2precision_file, prompt2precision)

    # --- per-relation precision for the original / single / multi settings ---
    if os.path.isfile(results_file):
        model2precision = load_json_dic(results_file)
        print("{} loaded".format(results_file))
    else:
        model2precision = {}
        for model_name, wrapper in zip(model_names, model_wrappers):
            table = PrettyTable(
                field_names=["id", "label", "original", "single", "multi"]
            )
            table.title = "{} overall precision".format(model_name)

            is_gpt2 = isinstance(wrapper.tokenizer, GPT2Tokenizer)
            single_batch_size = 16 if is_gpt2 else 128
            multi_batch_size = 1 if is_gpt2 else 8

            relation2precision_original_prompt = {}
            relation2precision_single_prompt = {}
            relation2precision_multiple_prompt = {}
            for relation_id in tqdm(id2relation):
                relation_prompts = read_prompts(relation_id)
                samples = id2samples[relation_id]
                relation = id2relation[relation_id]
                relation_label = relation["label"]

                with torch.no_grad():
                    _, original_p = wrapper.evaluate_samples(
                        relation, samples,
                        max_len=args.max_len,
                        batch_size=single_batch_size,
                        ignore_stop_word=args.ignore_stop_words
                    )

                    single_prompt = random.choice(relation_prompts)
                    _, single_p = wrapper.eval_sample_with_multi_prompts(
                        [single_prompt], samples,
                        batch_size=single_batch_size,
                        ignore_stop_word=args.ignore_stop_words
                    )

                    _, multi_p = wrapper.eval_sample_with_multi_prompts(
                        relation_prompts, samples,
                        batch_size=multi_batch_size,
                        ignore_stop_word=args.ignore_stop_words
                    )

                relation2precision_original_prompt[relation_id] = original_p
                relation2precision_single_prompt[relation_id] = single_p
                relation2precision_multiple_prompt[relation_id] = multi_p
                table.add_row([relation_id, relation_label, original_p, single_p, multi_p])

            table = get_table_stat(table)
            print(table)
            fout.write(table.get_string() + "\n")

            model2precision[model_name] = {
                "ori": relation2precision_original_prompt,
                "single": relation2precision_single_prompt,
                "multi": relation2precision_multiple_prompt
            }

        store_json_dic(results_file, model2precision)

    # --- mean precision of every model on random relation subsets ---
    relation_ids = list(id2relation)
    for prompt_method in ["ori", "single", "multi"]:
        main_table = PrettyTable(
            field_names=["sample_time"] + model_names + ["best", "worst"]
        )
        main_table.title = "{}  main table".format(prompt_method)

        for sample_time in range(sample_times):
            relations = random.sample(relation_ids, sample_num)
            results = []
            for model_name in model_names:
                if prompt_method == "single":
                    p = cal_performance_stored(relations, prompt2precision[model_name])
                else:
                    p = get_performance(
                        model2precision[model_name][prompt_method], relations
                    )
                results.append(p)

            # First maximum / first minimum win ties. Computing them from the
            # full list avoids treating a genuine precision of 0 as "unset".
            best_name = model_names[results.index(max(results))]
            worst_name = model_names[results.index(min(results))]
            main_table.add_row(
                [str(sample_time)] + results + [best_name, worst_name]
            )
        main_table = get_table_stat(main_table)
        print(main_table)
        fout.write(main_table.get_string() + "\n")


def random_eval_relation_subset_by_mention(args):
    """Cache the prompt-ensemble precision of every model with ``mentions=-1``.

    For each model and relation, all prompts of the relation are evaluated
    together through ``eval_sample_with_multi_prompts_and_mention`` and the
    precision is stored under the ``"multi_random"`` key. The nested result
    (model -> relation id -> ``"multi_random"``) is written to
    ``causal_doc/relation_subset/random_mention_prompt2precision.json``.

    Args:
        args: Parsed command-line namespace. ``cuda_device`` and
            ``ignore_stop_words`` are read here; ``relation_file``,
            ``sample_dir`` and ``sample_file_type`` are read after
            ``get_relation_args`` has processed it.
    """
    set_seed(0)

    args = get_relation_args(args)

    model_names = [
        "bert-large-cased",
        "gpt2-xl",
        "roberta-large",
        "bart-large",
        "bert-base-cased",
        "roberta-base",
        "gpt2-medium",
        "gpt2-large",
        "bart-base",
    ]

    model_wrappers = [
        build_model_wrapper(model_name, device=args.cuda_device, args=args)
        for model_name in model_names
    ]

    lama_data = LamaDataset(relation_file=args.relation_file,
                            sample_dir=args.sample_dir,
                            sample_file_type=args.sample_file_type)
    id2relation, id2samples = lama_data.get_samples()

    prompt2precision_file = "causal_doc/relation_subset/random_mention_prompt2precision.json"
    prompt2precision = {}
    for model_name, wrapper in zip(model_names, model_wrappers):
        # Input budget per batch; GPT-2 models get a smaller one.
        budget = 16 if isinstance(wrapper.tokenizer, GPT2Tokenizer) else 96

        relation2precision = {}
        for relation_id in tqdm(id2relation):
            relation_prompts = read_prompts(relation_id)
            samples = id2samples[relation_id]

            # The budget is shared by all prompts (the factor 3 presumably is
            # the mention count - TODO confirm). Clamp to 1 for every model
            # type so relations with many prompts never get batch size 0.
            batch_size = max(1, budget // (3 * len(relation_prompts)))

            # Random mention, prompt ensemble.
            _, ensemble_p = wrapper.eval_sample_with_multi_prompts_and_mention(
                relation_prompts, samples,
                batch_size=batch_size,
                ignore_stop_word=args.ignore_stop_words, mentions=-1
            )
            relation2precision[relation_id] = {"multi_random": ensemble_p}
        prompt2precision[model_name] = relation2precision

    store_json_dic(prompt2precision_file, prompt2precision)


def eval_relation_subset_by_mention(args):
    """Report ranking variance across relation subsets for mention-based runs.

    For every model and relation the precision of each single prompt is
    computed with ``mentions=-1`` (stored under the prompt text) and with
    ``mentions=3`` (stored under ``"causal <prompt>"``), plus the precision of
    all prompts together with ``mentions=3`` (stored under ``"multi"``). The
    result is cached in
    ``causal_doc/relation_subset/mention_prompt2precision.json`` and reused
    when that file exists. Afterwards ``args.sample_times`` random subsets of
    ``args.sample_num`` relations are drawn for each of the "single",
    "causal_single" and "multi" methods and tabulated with the best and worst
    model.

    Args:
        args: Parsed command-line namespace. ``sample_times``, ``sample_num``,
            ``cuda_device`` and ``ignore_stop_words`` are read here;
            ``relation_file``, ``sample_dir`` and ``sample_file_type`` are
            read after ``get_relation_args`` has processed it.
    """
    report_path = "causal_doc/relation_subset/mention_relation_subset_variance_{}_{}".format(
        args.sample_times, args.sample_num
    )
    # The context manager closes the report even when evaluation fails.
    with open(report_path, "w") as fout:
        _write_mention_subset_report(args, fout)


def _write_mention_subset_report(args, fout):
    """Run the mention-based subset evaluation and write the tables to ``fout``."""
    sample_times = args.sample_times
    sample_num = args.sample_num

    set_seed(0)

    args = get_relation_args(args)

    model_names = [
        "bert-large-cased",
        "gpt2-xl",
        "roberta-large",
        "bart-large",
        "bert-base-cased",
        "roberta-base",
        "gpt2-medium",
        "gpt2-large",
        "bart-base",
    ]

    model_wrappers = [
        build_model_wrapper(model_name, device=args.cuda_device, args=args)
        for model_name in model_names
    ]

    lama_data = LamaDataset(relation_file=args.relation_file,
                            sample_dir=args.sample_dir,
                            sample_file_type=args.sample_file_type)
    id2relation, id2samples = lama_data.get_samples()

    prompt2precision_file = "causal_doc/relation_subset/mention_prompt2precision.json"
    if os.path.isfile(prompt2precision_file):
        prompt2precision = load_json_dic(prompt2precision_file)
        print("{} loaded".format(prompt2precision_file))
    else:
        prompt2precision = {}
        for model_name, wrapper in zip(model_names, model_wrappers):
            # Input budget per batch; GPT-2 models get a smaller one. The
            # divisor 3 presumably is the mention count - TODO confirm.
            budget = 16 if isinstance(wrapper.tokenizer, GPT2Tokenizer) else 96
            per_prompt_batch_size = budget // 3

            relation2precision = {}
            for relation_id in tqdm(id2relation):
                relation_prompts = read_prompts(relation_id)
                samples = id2samples[relation_id]

                precision_by_key = {}
                for prompt in relation_prompts:
                    _, plain_p = wrapper.eval_sample_with_multi_prompts_and_mention(
                        [prompt], samples,
                        batch_size=per_prompt_batch_size,
                        ignore_stop_word=args.ignore_stop_words, mentions=-1
                    )
                    precision_by_key[prompt] = plain_p

                    _, causal_p = wrapper.eval_sample_with_multi_prompts_and_mention(
                        [prompt], samples,
                        batch_size=per_prompt_batch_size,
                        ignore_stop_word=args.ignore_stop_words, mentions=3
                    )
                    precision_by_key["causal {}".format(prompt)] = causal_p

                # Clamp to 1 for every model type so relations with many
                # prompts never end up with a batch size of 0.
                ensemble_batch_size = max(1, budget // (3 * len(relation_prompts)))
                _, multi_p = wrapper.eval_sample_with_multi_prompts_and_mention(
                    relation_prompts, samples,
                    batch_size=ensemble_batch_size,
                    ignore_stop_word=args.ignore_stop_words, mentions=3
                )
                precision_by_key["multi"] = multi_p
                relation2precision[relation_id] = precision_by_key
            prompt2precision[model_name] = relation2precision

        store_json_dic(prompt2precision_file, prompt2precision)

    relation_ids = list(id2relation)
    for prompt_method in ["single", "causal_single", "multi"]:
        main_table = PrettyTable(
            field_names=["sample_time"] + model_names + ["best", "worst"]
        )
        main_table.title = "{}  main table".format(prompt_method)

        for sample_time in range(sample_times):
            relations = random.sample(relation_ids, sample_num)
            results = []
            for model_name in model_names:
                if prompt_method == "single":
                    p = cal_performance_stored(relations, prompt2precision[model_name])
                elif prompt_method == "causal_single":
                    p = cal_performance_stored(relations, prompt2precision[model_name], causal=True)
                else:
                    p = cal_performance_stored(relations, prompt2precision[model_name], causal=False, multi=True)
                results.append(p)

            # First maximum / first minimum win ties. Computing them from the
            # full list avoids treating a genuine precision of 0 as "unset".
            best_name = model_names[results.index(max(results))]
            worst_name = model_names[results.index(min(results))]
            main_table.add_row(
                [str(sample_time)] + results + [best_name, worst_name]
            )
        main_table = get_table_stat(main_table)
        print(main_table)
        fout.write(main_table.get_string() + "\n")


def get_performance(p_dic, relations):
    """Return the mean of ``p_dic`` over ``relations``, rounded to 2 decimals.

    Args:
        p_dic: Mapping from relation id to a precision value.
        relations: Relation ids to average over; each must be a key of
            ``p_dic``.

    Returns:
        The mean precision as a Python float rounded to two decimals.
    """
    selected = []
    for relation_id in relations:
        selected.append(p_dic[relation_id])
    average = float(np.mean(selected))
    return round(average, 2)


def cal_performance_stored(relations, promp2precision, causal=False, multi=False, random_var=False):
    """Average cached precisions over ``relations`` for one prompt setting.

    For every relation a prompt is drawn at random from ``read_prompts``; the
    draw happens in every mode, so the random state advances identically
    whether or not the prompt is used.

    Args:
        relations: Relation ids to average over.
        promp2precision: Nested mapping ``relation id -> key -> precision``.
        causal: In single-prompt mode, read the ``"causal <prompt>"`` entry
            instead of the entry keyed by the prompt itself.
        multi: Read an ensemble entry instead of a single-prompt entry.
        random_var: In multi mode, read ``"multi_random"`` rather than
            ``"multi"``.

    Returns:
        The mean precision as a Python float rounded to two decimals.
    """
    collected = []
    for relation_id in relations:
        chosen_prompt = random.choice(read_prompts(relation_id))
        if multi:
            key = "multi_random" if random_var else "multi"
        elif causal:
            key = "causal {}".format(chosen_prompt)
        else:
            key = chosen_prompt
        collected.append(promp2precision[relation_id][key])
    return round(float(np.mean(collected)), 2)


def cal_performance(relations, id2relation, id2samples, model_wrapper, args):
    """Evaluate ``model_wrapper`` on ``relations`` with one random prompt each.

    For every relation a prompt is drawn at random from ``read_prompts`` and
    used as the relation's ``"template"`` in a call to
    ``model_wrapper.evaluate_samples``.

    Args:
        relations: Relation ids to evaluate.
        id2relation: Mapping from relation id to its relation record. The
            records are not modified.
        id2samples: Mapping from relation id to its samples.
        model_wrapper: Wrapper exposing ``tokenizer`` and ``evaluate_samples``.
        args: Namespace providing ``max_len`` and ``ignore_stop_words``.

    Returns:
        The mean precision over ``relations`` as a Python float rounded to two
        decimals.
    """
    if isinstance(model_wrapper.tokenizer, GPT2Tokenizer):
        batch_size = 4
    else:
        batch_size = 32

    ps = []
    for relation_id in tqdm(relations):
        relation_prompts = read_prompts(relation_id)
        samples = id2samples[relation_id]

        # Work on a copy: overriding the template on the shared record would
        # leak the randomly drawn prompt into every later use of id2relation.
        relation = copy.copy(id2relation[relation_id])
        relation["template"] = random.choice(relation_prompts)

        _, single_p = model_wrapper.evaluate_samples(
            relation, samples,
            max_len=args.max_len,
            batch_size=batch_size,
            ignore_stop_word=args.ignore_stop_words
        )
        ps.append(single_p)
    mean_p = np.mean(ps)
    return round(float(mean_p), 2)


def rank_consis(args):
    """Report how consistently models are ranked across relation subsets.

    Cached precisions are loaded from ``mention_prompt2precision.json`` and
    ``random_mention_prompt2precision.json`` under
    ``causal_doc/relation_subset``. For each prompt method,
    ``args.sample_times`` random subsets of ``args.sample_num`` relations are
    drawn, the models are ordered by mean precision on each subset, and the
    per-model rank consistency, per-model rank standard deviation and the
    overall consistency of the full ordering are printed and written to the
    report file.

    Args:
        args: Parsed command-line namespace. ``sample_times``, ``sample_num``
            and ``seed`` are read here; ``relation_file``, ``sample_dir`` and
            ``sample_file_type`` are read after ``get_relation_args`` has
            processed it.
    """
    sample_times = args.sample_times
    sample_num = args.sample_num
    report_path = "causal_doc/relation_subset/mention_rank_consis_{}_{}".format(
        sample_times, sample_num
    )

    model_names = [
        "bert-base-cased",
        "bert-large-cased",
        "roberta-base",
        "roberta-large",
        "gpt2-medium",
        # "gpt2-large",
        "gpt2-xl",
        "bart-base",
        "bart-large",
    ]

    # The context manager closes the report even when a cache file is missing
    # or a lookup fails part-way through.
    with open(report_path, "w") as fout:
        set_seed(args.seed)

        args = get_relation_args(args)

        lama_data = LamaDataset(relation_file=args.relation_file,
                                sample_dir=args.sample_dir,
                                sample_file_type=args.sample_file_type)
        # Only the relation ids are needed here; the samples are unused.
        id2relation, _ = lama_data.get_samples()

        prompt2precision_file = "causal_doc/relation_subset/mention_prompt2precision.json"
        prompt2precision = load_json_dic(prompt2precision_file)

        random_prompt2precision_file = "causal_doc/relation_subset/random_mention_prompt2precision.json"
        random_prompt2precision = load_json_dic(random_prompt2precision_file)

        relation_ids = list(id2relation)
        # NOTE(review): "causal_multi" reads the "multi_random" entries of the
        # random-mention cache while "multi" reads the "multi" entries -
        # confirm this labelling is intended.
        for prompt_method in ["single", "causal_single", "causal_multi", "multi"]:
            main_table = PrettyTable(
                field_names=["sample_time"] + model_names + ["best", "worst"]
            )
            model_sorts = []
            for _ in range(sample_times):
                relations = random.sample(relation_ids, sample_num)
                ps = []
                for model_name in model_names:
                    if prompt_method == "single":
                        p = cal_performance_stored(relations, prompt2precision[model_name])
                    elif prompt_method == "causal_single":
                        p = cal_performance_stored(relations, prompt2precision[model_name], causal=True)
                    elif prompt_method == "causal_multi":
                        p = cal_performance_stored(relations, random_prompt2precision[model_name], causal=False, multi=True, random_var=True)
                    else:
                        p = cal_performance_stored(relations, prompt2precision[model_name], causal=False, multi=True)
                    ps.append(p)
                # Model indices ordered by ascending precision on this subset.
                model_sorts.append(np.argsort(ps))

            model_rank_consis, model_rank_std = get_rank_consis(model_sorts, model_names)
            all_model_rank = get_all_rank_consis(model_sorts)
            # The last two columns carry no data for these summary rows, so
            # the literal column names are used as placeholders.
            main_table.add_row(["consis"] + model_rank_consis + ["best", "worst"])
            main_table.add_row(["std"] + model_rank_std + ["best", "worst"])
            main_table.title = "{}  main table, all consis: {}".format(prompt_method, all_model_rank)
            print(main_table)
            fout.write(main_table.get_string() + "\n")


def get_rank_consis(model_sorts, model_names):
    """Summarise how stable each model's rank position is across samples.

    Args:
        model_sorts: One array of model indices per sample, each giving the
            models in sorted order; a model's rank in a sample is the position
            of its index in that array.
        model_names: Model names; only their count is used, and model ``k``
            is identified by index ``k``.

    Returns:
        A pair of lists with one entry per model: the percentage of samples
        in which the model sits at its most frequent rank (rounded to two
        decimals), and the standard deviation of its ranks (rounded to two
        decimals).
    """
    total_samples = len(model_sorts)
    total_models = len(model_names)
    consistency_per_model = []
    std_per_model = []
    for model_idx in range(total_models):
        positions = []
        histogram = [0] * total_models
        for ordering in model_sorts:
            # Position of this model inside the sample's ordering.
            position = np.where(ordering == model_idx)[0][0]
            histogram[position] += 1
            positions.append(position)
        top_share = max(histogram) * 100 / total_samples
        consistency_per_model.append(round(top_share, 2))
        std_per_model.append(round(np.std(positions), 2))
    return consistency_per_model, std_per_model


def get_all_rank_consis(model_sorts):
    """Return the share of samples that agree on the most common full ordering.

    Args:
        model_sorts: Non-empty sequence of equally long 1-D orderings of model
            indices, one per sample.

    Returns:
        The percentage of samples whose ordering equals the most frequent
        ordering, as a Python float rounded to two decimals.

    Raises:
        ValueError: If ``model_sorts`` is empty.
    """
    sample_num = len(model_sorts)
    # Group identical orderings in one pass instead of comparing every pair
    # of samples: stack them as rows and count each distinct row.
    _, counts = np.unique(np.stack(model_sorts), axis=0, return_counts=True)
    max_consis = int(counts.max())
    return round(max_consis * 100 / sample_num, 2)


def cal_prompt_mention_prediction(args):
    """Store per-mention predictions of one model family for every prompt.

    Every sample is expanded into one copy per entry of its
    ``"sub_mentions"``, with ``"sub_label"`` replaced by that mention. Each
    model of the family selected by ``args.model_type`` is then evaluated on
    the expanded samples with every single prompt and with all prompts
    together (stored under the key ``"multi"``). The result is laid out as
    ``mention2prediction[relation_id][prompt][sub_id][subject]`` =
    ``{"prediction": ..., "res": ...}`` and written to
    ``fact_data/mention2prediction/<model name>``.

    Args:
        args: Parsed command-line namespace. ``model_type``, ``cuda_device``
            and ``ignore_stop_words`` are read here; ``relation_file``,
            ``sample_dir`` and ``sample_file_type`` are read after
            ``get_relation_args`` has processed it.

    Raises:
        ValueError: If ``args.model_type`` is not one of "bert", "gpt2",
            "roberta" or "bart".
    """
    # model type -> (model names, batch size for single-prompt evaluation)
    model_families = {
        "bert": (["bert-large-cased", "bert-base-cased"], 96),
        "gpt2": (["gpt2-xl", "gpt2-medium"], 16),
        "roberta": (["roberta-large", "roberta-base"], 96),
        "bart": (["bart-large", "bart-base"], 96),
    }
    if args.model_type not in model_families:
        raise ValueError(
            "unsupported model type {!r}; expected one of {}".format(
                args.model_type, sorted(model_families)
            )
        )
    model_names, batch_size = model_families[args.model_type]

    model_wrappers = [
        build_model_wrapper(model_name, device=args.cuda_device, args=args)
        for model_name in model_names
    ]

    set_seed(0)

    args = get_relation_args(args)
    lama_data = LamaDataset(relation_file=args.relation_file,
                            sample_dir=args.sample_dir,
                            sample_file_type=args.sample_file_type)
    id2relation, id2samples = lama_data.get_samples()

    data_dir = "fact_data/mention2prediction"
    # Each model's predictions are stored as a file inside this directory.
    os.makedirs(data_dir, exist_ok=True)

    for model_name, wrapper in zip(model_names, model_wrappers):
        prediction_path = "{}/{}".format(data_dir, model_name)
        mention2prediction = {}
        for relation_id in id2relation:
            relation_prompts = read_prompts(relation_id)

            # One copy of every sample per subject mention. The mention is
            # written to the copy, so the loaded dataset stays untouched and
            # every mention is evaluated.
            expand_samples = []
            for sample in id2samples[relation_id]:
                for mention in sample["sub_mentions"]:
                    new_sample = copy.deepcopy(sample)
                    new_sample["sub_label"] = mention
                    expand_samples.append(new_sample)

            predictions_by_prompt = {}
            for prompt in tqdm(relation_prompts):
                _, _, tokens, pre_res = wrapper.eval_sample_with_multi_prompts(
                    [prompt], expand_samples,
                    batch_size=batch_size,
                    ignore_stop_word=args.ignore_stop_words,
                    return_tokens=True
                )
                predictions_by_prompt[prompt] = _index_predictions_by_mention(
                    tokens, pre_res, expand_samples
                )

            # All prompts together share the batch budget, with a floor of 1.
            ensemble_batch_size = max(1, batch_size // len(relation_prompts))
            _, _, tokens, pre_res = wrapper.eval_sample_with_multi_prompts(
                relation_prompts, expand_samples,
                batch_size=ensemble_batch_size,
                ignore_stop_word=args.ignore_stop_words,
                return_tokens=True
            )
            predictions_by_prompt["multi"] = _index_predictions_by_mention(
                tokens, pre_res, expand_samples
            )

            mention2prediction[relation_id] = predictions_by_prompt
        store_json_dic(prediction_path, mention2prediction)


def _index_predictions_by_mention(tokens, pre_res, samples):
    """Group predictions as ``{sub_id: {subject: {"prediction", "res"}}}``."""
    indexed = {}
    for token, res, sample in zip(tokens, pre_res, samples):
        sub, _, sub_id = get_pair(sample, return_id=True)
        indexed.setdefault(sub_id, {})[sub] = {"prediction": token, "res": res}
    return indexed
                    



def main():
    """Parse the command-line options and run the task chosen with ``--task``."""
    # Every selectable task maps to its handler, so an accepted --task value
    # can never fall through without running anything.
    task_handlers = {
        "eval_relation_subset": eval_relation_subset,
        "eval_relation_subset_by_mention": eval_relation_subset_by_mention,
        "rank_consis": rank_consis,
        "random_eval_relation_subset_by_mention": random_eval_relation_subset_by_mention,
        "cal_prompt_mention_prediction": cal_prompt_mention_prediction,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--relation-type", type=str, default="lama_filter")
    parser.add_argument("--model-name", type=str, default="gpt2-medium")
    parser.add_argument("--model-type", type=str, default="bert")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--cuda-device", type=int, default=4)
    parser.add_argument("--max-len", type=int, default=256)
    parser.add_argument("--topk", type=int, default=10)

    parser.add_argument("--gpt-method", type=str, default="next_token")
    parser.add_argument("--generate-len", type=int, default=1)

    parser.add_argument("--sample-method", type=str, default="replace",
                        choices=["replace", "no_replace"])

    parser.add_argument("--dupe", type=int, default=5)
    parser.add_argument("--lr", type=str, default="5e-5")
    parser.add_argument("--model-path", type=str, default=None)

    # type=bool would turn every non-empty string, including "False", into
    # True, so the value is parsed explicitly.
    parser.add_argument("--multi-prompt", type=_str2bool, default=True)
    parser.add_argument("--sample-times", type=int, default=5)
    parser.add_argument("--sample-num", type=int, default=10)

    parser.add_argument("--seed", type=int, default=0)

    parser.add_argument("--task", type=str,
                        default="cal_prompt_mention_prediction",
                        choices=list(task_handlers))

    # store_false: the option defaults to True and passing the flag sets
    # ignore_stop_words to False.
    parser.add_argument("--ignore-stop-words", action="store_false")

    args = parser.parse_args()

    task_handlers[args.task](args)


def _str2bool(value):
    """Convert a command-line token such as "true" or "0" into a bool."""
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
