# Copyright (c) <anonymized for review>
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
from scripts.batch_eval_KB_completion import (
    load_file,
    init_logging,
    lowercase_samples,
    filter_samples)
import logging
import logging.config
import pickle
import json
from multiprocessing.pool import ThreadPool
import multiprocessing
from tqdm import tqdm
from argparse import Namespace
import copy
import datetime

from lama.utils import create_input_text, batchify
from lama.generation import GenerationImpl
import lama.evaluation_metrics as metrics

logger = logging.getLogger(__name__)


# Dataset configuration JSON used when ``args`` does not name one.
DEFAULT_DATASET_CONFIG = "config/dataset_config.json"

# Placeholders for the subject and the object inside relation templates.
SUBJ_SYMBOL, OBJ_SYMBOL = "[X]", "[Y]"

# Moved and modified from scripts.batch_eval_KB_completion
def run_thread(arguments):
    """Compute ranking metrics for a single sample.

    Args:
        arguments: dict with at least ``filtered_log_probs``,
            ``masked_indices``, ``vocab``, ``label_index`` and ``index_list``;
            ``topk`` (falls back to 10000 when missing or falsy),
            ``ppl_reranking``, ``token_ids`` and ``model`` are optional.

    Returns:
        ``(experiment_result, sample_MRR, sample_P, msg)`` where ``msg`` is
        the message returned by ``metrics.get_ranking`` prefixed with a
        newline.
    """
    optional = arguments.get
    ranking = metrics.get_ranking(
        arguments["filtered_log_probs"],
        arguments["masked_indices"],
        arguments["vocab"],
        label_index=arguments["label_index"],
        index_list=arguments["index_list"],
        topk=optional("topk") or 10000,
        ppl_reranking=optional("ppl_reranking"),
        token_ids=optional("token_ids"),
        model=optional("model"),
    )
    mrr, precision, experiment_result, ranking_msg = ranking
    return experiment_result, mrr, precision, "\n" + ranking_msg


def get_model_name(args):
    """Return a display name for the single model type in ``args.models_names``.

    ``fairseq``, ``bert`` and ``elmo`` are combined with the matching
    ``<type>_model_name`` attribute; any other type is title-cased.

    Raises:
        ValueError: if ``args.models_names`` does not hold exactly one entry.
    """
    (model_type,) = args.models_names
    named_types = {
        "fairseq": ("fairseq_{}", "fairseq_model_name"),
        "bert": ("BERT_{}", "bert_model_name"),
        "elmo": ("ELMo_{}", "elmo_model_name"),
    }
    if model_type not in named_types:
        return model_type.title()
    pattern, attribute = named_types[model_type]
    return pattern.format(getattr(args, attribute))


def prepare_input_samples(
    dataset_filename,
    template,
    model,
    vocab_subset,
    max_sentence_length,
    lowercase=False,
    logger=None
):
    """Load, optionally lowercase, filter and templatize dataset samples.

    Args:
        dataset_filename: File read with ``load_file``.
        template: Template containing ``SUBJ_SYMBOL``/``OBJ_SYMBOL``, or
            empty/``None`` to skip ``create_input_text``.
        model: Passed through to ``filter_samples``.
        vocab_subset: Passed through to ``filter_samples``.
        max_sentence_length: Passed through to ``filter_samples``.
        lowercase: If true, lowercase the samples (``lowercase_samples``) and
            the template, keeping the placeholder symbols unchanged.
        logger: Logger for progress messages; defaults to this module's
            logger when ``None``.

    Returns:
        The filtered (and possibly templatized) samples; every sample that
        lacked a ``"uuid"`` key gets its position in that list as uuid.
    """
    # The parameter shadows the module-level logger, so fetch it by name.
    if logger is None:
        logger = logging.getLogger(__name__)

    data = load_file(dataset_filename)
    print(len(data))

    if lowercase:
        # lowercase all samples
        logger.info("lowercasing all samples...")
        all_samples = lowercase_samples(data)
        if template:
            # Lowercase the template, then restore the placeholder symbols.
            template = (
                template.lower()
                .replace(SUBJ_SYMBOL.lower(), SUBJ_SYMBOL)
                .replace(OBJ_SYMBOL.lower(), OBJ_SYMBOL)
            )
    else:
        # keep samples as they are
        all_samples = data

    # Filter the (possibly lowercased) samples rather than the raw data, so
    # lowercasing takes effect whether or not lowercase_samples works in place.
    all_samples, ret_msg = filter_samples(
        model, all_samples, vocab_subset, max_sentence_length, template
    )

    logger.info("\n" + ret_msg + "\n")

    print(len(all_samples))

    if template:
        all_samples = create_input_text(
            all_samples,
            template,
            logger=logger)

    # create uuid if not present
    for position, sample in enumerate(all_samples):
        sample.setdefault("uuid", position)

    return all_samples


# Moved and modified from scripts.batch_eval_KB_completion
def run_prediction(args):
    """Evaluate one dataset file with the configured model and pickle results.

    Steps:
    1. Initialize logging in ``args.full_logdir`` and dump ``args`` to
       ``args.json`` there.
    2. Load the samples with ``prepare_input_samples`` and batch them.
    3. Run generation per batch.
    4. Rank each sample with ``run_thread`` on a thread pool.
    5. Write the per-sample result dicts to ``result.pkl``.

    Args:
        args: Namespace-like object with the run options (``full_logdir``,
            ``dataset_filename``, ``template``, ``max_sentence_length``,
            ``lowercase``, ``batch_size``, ``threads``, ``topk``,
            ``ppl_reranking``, plus whatever ``GenerationImpl`` reads).

    Raises:
        ValueError: if a sample's ``obj_label`` is not in the model
            vocabulary.
    """
    # initialize logging
    log_directory = args.full_logdir
    logger = init_logging(log_directory)

    gen = GenerationImpl(args, logger=logger)

    init_msg = "model name: {}\n".format(get_model_name(args))
    init_msg += "args: {}\n".format(args)
    init_msg += "common vocabulary size: {}\n".format(len(gen.vocab_subset))
    logger.info("\n" + init_msg + "\n")

    # dump arguments on file for log
    with open("{}/args.json".format(log_directory), "w") as outfile:
        json.dump(vars(args), outfile)

    all_samples = prepare_input_samples(
        args.dataset_filename,
        args.template,
        gen.model,
        gen.vocab_subset,
        args.max_sentence_length,
        args.lowercase,
        logger)

    samples_batches, sentences_batches, ret_msg = batchify(
        all_samples, args.batch_size)

    logger.info("\n" + ret_msg + "\n")

    # threads <= 0 means: use all available CPUs
    num_threads = args.threads
    if num_threads <= 0:
        num_threads = multiprocessing.cpu_count()

    list_of_results = []
    pool = ThreadPool(num_threads)
    # try/finally so the pool is shut down even when a batch raises
    # (e.g. the obj_label check below).
    try:
        for samples_b, sentences_b in tqdm(
            zip(samples_batches, sentences_batches),
            total=len(samples_batches),
        ):
            genout = gen.get_batch_generation(sentences_b, logger)

            label_index_list = []
            for sample in samples_b:
                # MAKE SURE THAT obj_label IS IN VOCABULARIES
                if not gen.is_in_model_vocabulary(sample["obj_label"]):
                    raise ValueError(
                        "object label {} not in model vocabulary".format(
                            sample["obj_label"]
                        )
                    )
                label_index_list.append(gen.model.get_id(sample["obj_label"]))

            arguments = [
                {
                    "original_log_probs": original_log_probs,
                    "filtered_log_probs": filtered_log_probs,
                    "token_ids": token_ids,
                    "vocab": gen.model.vocab,
                    "label_index": label_index[0],
                    "masked_indices": masked_indices,
                    "index_list": gen.index_list,
                    "sample": sample,
                    "topk": args.topk,
                    "ppl_reranking": args.ppl_reranking,
                    # NOTE(review): gen.model is shared by all pool threads;
                    # confirm metrics.get_ranking uses it thread-safely.
                    "model": gen.model,
                }
                for (
                    original_log_probs,
                    filtered_log_probs,
                    token_ids,
                    masked_indices,
                    label_index,
                    sample
                ) in zip(
                    genout.original_log_probs_list,
                    genout.filtered_log_probs_list,
                    genout.token_ids_list,
                    genout.masked_indices_list,
                    label_index_list,
                    samples_b,
                )
            ]

            # For single-threaded debugging, replace with
            # [run_thread(a) for a in arguments].
            res = pool.map(run_thread, arguments)

            for idx, (result_masked_topk, sample_MRR, sample_P, msg) in enumerate(res):
                logger.info("\n" + msg + "\n")

                sample = samples_b[idx]
                list_of_results.append({
                    "sample": sample,
                    "uuid": sample["uuid"],
                    "token_ids": genout.token_ids_list[idx],
                    "masked_indices": genout.masked_indices_list[idx],
                    "label_index": label_index_list[idx],
                    "masked_topk": result_masked_topk,
                    "sample_MRR": sample_MRR,
                    "sample_Precision": sample_P,
                    "sample_Precision1": result_masked_topk["P_AT_1"],
                })
    finally:
        pool.close()
        pool.join()

    msg = "all_samples: {}\n".format(len(all_samples))
    msg += "list_of_results: {}\n".format(len(list_of_results))

    logger.info("\n" + msg + "\n")
    print("\n" + msg + "\n")

    # dump pickle with the result of the experiment
    with open("{}/result.pkl".format(log_directory), "wb") as f:
        pickle.dump(list_of_results, f)


def set_default_option_if_not_exist(args, key, value):
    """Give ``args`` the attribute ``key`` = ``value`` unless it already has one.

    An existing attribute is left untouched, even when its value is ``None``.
    """
    if hasattr(args, key):
        return
    setattr(args, key, value)


def run_experiments(args, dataset_params, lm_input_param):
    """Run ``run_prediction`` for every relation of one dataset with one LM.

    Args:
        args: Base options. They are copied, never mutated.
        dataset_params: Mapping with ``relations``, ``data_path_pre``,
            ``data_path_post`` and ``output_path_pre`` (the shape returned by
            ``get_dataset_parameters``).
        lm_input_param: LM-specific options set as attributes on the copied
            args. Its ``"label"`` entry names the output sub-directory.

    A relation whose data file cannot be loaded is skipped with a message.
    """
    base_args = copy.deepcopy(args)
    for k, v in lm_input_param.items():
        setattr(base_args, k, v)

    for relation in dataset_params["relations"]:
        # Fresh copy per relation: otherwise a relation without its own
        # "template" would inherit the previous relation's template.
        rel_args = copy.copy(base_args)
        rel_args.dataset_filename = "{}{}{}".format(
            dataset_params["data_path_pre"],
            relation["relation"],
            dataset_params["data_path_post"],
        )
        rel_args.full_logdir = "{}results/{}/{}".format(
            dataset_params["output_path_pre"],
            lm_input_param["label"],
            relation["relation"],
        )
        if "template" in relation:
            rel_args.template = relation["template"]

        # see if file exists (deliberately best-effort: any load failure
        # just excludes the relation)
        try:
            load_file(rel_args.dataset_filename)
        except Exception as e:
            print("Relation {} excluded.".format(relation["relation"]))
            print("Exception: {}".format(e))
            continue

        run_prediction(rel_args)


def get_dataset_parameters(
    dataset_name,
    dataset_config="config/dataset_config.json",
    data_path_pre="data/",
    output_path_pre="output_pll/"
):
    """Resolve the relations and input/output path pieces of one dataset.

    Args:
        dataset_name: Top-level key in the dataset config JSON.
        dataset_config: Path of the dataset config JSON file.
        data_path_pre: Prefix prepended to the config's ``data_path_pre``.
        output_path_pre: Prefix prepended to the config's ``output_path_pre``.

    Returns:
        dict with ``relations``, ``data_path_pre``, ``data_path_post`` and
        ``output_path_pre``. If the config's ``relations`` is a string, it is
        passed to ``load_file`` and the loaded value is returned instead.

    Raises:
        RuntimeError: if ``dataset_name`` is not defined in the config.
    """
    with open(dataset_config, encoding="utf-8") as f:
        config_json = json.load(f)

    try:
        config = config_json[dataset_name]
    except KeyError:
        # "from None": the KeyError adds nothing beyond this message.
        raise RuntimeError(f"Undefined dataset name {dataset_name}.") from None

    relations = config["relations"]
    if isinstance(relations, str):
        relations = load_file(relations)

    return {
        "relations": relations,
        "data_path_pre": data_path_pre + config["data_path_pre"],
        "data_path_post": config["data_path_post"],
        "output_path_pre": output_path_pre + config["output_path_pre"],
    }


def get_config_from_file(config_file: str, args=None):
    """Fill ``args`` with the key/value pairs of a JSON config file.

    An attribute that ``args`` already has with a non-``None`` value wins over
    the file; a warning is logged for each such key. When ``args`` is
    ``None`` a fresh ``Namespace`` is created.

    Returns:
        The populated namespace (``args`` itself when one was given).
    """
    with open(config_file) as handle:
        file_options = json.load(handle)

    target = Namespace() if args is None else args

    for name, value in file_options.items():
        if getattr(target, name, None) is None:
            setattr(target, name, value)
        else:
            logger.warning(
                f"Argument {name} is specified by the command. "
                "The value in the configuration file will be ignored.")

    return target


def main(args):
    """Run every configured LM on every selected dataset.

    ``args`` is typically the parsed command line. If ``args.config`` is set,
    the JSON file fills in options that are missing or ``None``. After that
    ``args`` must provide ``LMs``.

    Raises:
        ValueError: if no ``LMs`` option is available.
    """
    if args.config is not None:
        get_config_from_file(args.config, args)

    # Explicit check instead of assert: asserts are stripped under -O.
    if not hasattr(args, "LMs"):
        raise ValueError(
            "No language models configured: 'LMs' must be provided "
            "(e.g. in the --config file).")

    # argparse always defines `datasets` and `output_path_pre` (as None when
    # the flag is omitted), so hasattr-based defaults would never apply to
    # them. Treat None as "not set" for these two options.
    if getattr(args, "datasets", None) is None:
        args.datasets = "all"
    if getattr(args, "output_path_pre", None) is None:
        args.output_path_pre = f"output_{datetime.date.today()}/"

    # Set default options
    set_default_option_if_not_exist(
        args,
        "common_vocab_filename",
        "pre-trained_language_models/common_vocab_cased.txt")
    set_default_option_if_not_exist(args, "template", "")
    set_default_option_if_not_exist(args, "bert_model_dir", None)
    set_default_option_if_not_exist(args, "bert_vocab_name", "vocab.txt")
    set_default_option_if_not_exist(args, "batch_size", 32)
    set_default_option_if_not_exist(args, "logdir", "output")
    set_default_option_if_not_exist(args, "lowercase", False)
    set_default_option_if_not_exist(args, "max_sentence_length", 100)
    set_default_option_if_not_exist(args, "threads", -1)
    set_default_option_if_not_exist(args, "topk", 100)
    set_default_option_if_not_exist(args, "ppl_reranking", False)

    if isinstance(args.datasets, str):
        if args.datasets == "all":
            args.datasets = ["gre", "trex", "cnet", "squad"]
        else:
            args.datasets = [args.datasets]

    params = {"output_path_pre": args.output_path_pre}
    if hasattr(args, "data_path_pre"):
        params["data_path_pre"] = args.data_path_pre

    dataset_config = getattr(args, "dataset_config", DEFAULT_DATASET_CONFIG)

    for dataset in args.datasets:
        dataset_params = get_dataset_parameters(
            dataset, dataset_config, **params)
        for ip in args.LMs:
            run_experiments(args, dataset_params, ip)


def get_parser():
    """Build the command-line parser for this script.

    Options: ``-c/--config`` (JSON config path), ``--datasets`` (one of
    all/gre/trex/cnet/squad) and ``--output_path_pre``. Omitted options are
    parsed as ``None``.
    """
    option_specs = (
        (("-c", "--config"), {"help": "Path to config file."}),
        (("--datasets",), {
            "choices": ["all", "gre", "trex", "cnet", "squad"],
            "help": "Dataset to experiment on.",
        }),
        (("--output_path_pre",), {"help": "Output path prefix."}),
    )
    parser = argparse.ArgumentParser()
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser


if __name__ == "__main__":
    # Verbose root logging for command-line runs.
    logging.basicConfig(level=logging.DEBUG)
    main(get_parser().parse_args())
