# -*- coding: utf-8 -*-
import os
import pickle

import torch

from model import models_POS
from run_mrc_ssdm_POS import evaluate, MODEL_CLASSES, args_get
from utils import train_ssdm_helper


def _setup_device(args):
    """Choose the torch device for this run and store it on ``args.device`` / ``args.n_gpu``."""
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # With --no_cuda everything runs on the CPU, so report zero GPUs rather than
        # the number of visible devices (which could enable multi-GPU code paths).
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        # Distributed mode: this process drives the GPU selected by local_rank.
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        # get_result may be called several times in one process; initializing the
        # default process group a second time raises, so only do it once.
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device


def _load_syntax_model(args):
    """Build the vgvae syntax model and load its weights from ``args.syntax_model_dir``."""
    # NOTE(review): torch.load and pickle.load both unpickle arbitrary Python objects;
    # only point --syntax_model_dir / --vocab_file at files you trust.
    # The map_location lambda keeps every tensor on CPU storage while loading.
    save_dict = torch.load(args.syntax_model_dir, map_location=lambda storage, loc: storage)
    config = save_dict['config']
    checkpoint = save_dict['state_dict']
    config.debug = True
    config.embed_type = "lm"
    # The vocab pickle holds a (embedding matrix, vocabulary) pair.
    with open(args.vocab_file, "rb") as fp:
        W, vocab = pickle.load(fp)
    with train_ssdm_helper.experiment(config, config.save_prefix) as e:
        model_syntax = models_POS.vgvae(
            vocab_size=len(vocab),
            embed_dim=e.config.edim,
            embed_init=W,
            experiment=e)
        model_syntax.eval()
        model_syntax.load(checkpointed_state_dict=checkpoint)
    return model_syntax


def _language_from_path(path):
    """Return the language tag at the end of a dataset file name.

    The tag is the last ``-``/``.``-separated token of the base name without its
    extension, e.g. ``xquad.en.json`` -> ``en``,
    ``test-context-ar-question-ar.json`` -> ``ar``,
    ``tydiqa-goldp-dev-arabic.json`` -> ``arabic``.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    return stem.replace(".", "-").split("-")[-1]


def get_result(predict_file):
    """Evaluate the fine-tuned MRC model (plus syntax model) on each given file.

    Command-line options are read with ``args_get()``. The reader model and
    tokenizer are loaded from ``args.output_dir``; the syntax model from
    ``args.syntax_model_dir`` and ``args.vocab_file``. For every file the
    ``HasAns_exact`` / ``HasAns_f1`` scores returned by ``evaluate`` are
    formatted as ``"EM/F1"``, followed by their unweighted mean over all files,
    and printed on a single line.

    Args:
        predict_file: an iterable of dataset paths, or a single path string.

    Returns:
        The list of ``"EM/F1"`` strings that was printed (one per file, then the mean).

    Raises:
        ValueError: if ``predict_file`` contains no paths.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(predict_file, str):
        predict_file = [predict_file]
    else:
        predict_file = list(predict_file)
    if not predict_file:
        # Fail before loading any model; the mean below would divide by zero.
        raise ValueError("predict_file must contain at least one file")

    args = args_get()
    _setup_device(args)

    _config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    model = model_class.from_pretrained(args.output_dir)
    tokenizer = tokenizer_class.from_pretrained(args.output_dir)
    model.to(args.device)

    model_syntax = _load_syntax_model(args)

    all_result = []
    em_all = 0.0
    f1_all = 0.0
    for path in predict_file:
        # Presumably evaluate() reads the dataset path from args.predict_file
        # (it is not passed explicitly) -- confirm in run_mrc_ssdm_POS.
        args.predict_file = path
        # NOTE(review): prefix is likely used to name evaluate()'s output files, so it
        # must be a plain tag (e.g. "en"), not a path fragment -- confirm.
        language = _language_from_path(path)
        result = evaluate(args, model, model_syntax, tokenizer, prefix=language)
        em = float(result["HasAns_exact"])
        f1 = float(result["HasAns_f1"])
        all_result.append("%.2f/%.2f" % (em, f1))
        em_all += em
        f1_all += f1
    n_files = len(predict_file)
    all_result.append("%.2f/%.2f" % (em_all / n_files, f1_all / n_files))
    # Terminate the line so consecutive calls (one per benchmark) don't run together.
    print(" ".join(all_result))
    return all_result


if __name__ == '__main__':
    # Run one evaluation per benchmark, in order: XQuAD, MLQA (test split),
    # TyDiQA-GoldP (dev split). Each get_result call handles one file list.
    xquad_files = [
        "data/xquad/xquad.en.json",
        "data/xquad/xquad.ar.json",
        "data/xquad/xquad.de.json",
        "data/xquad/xquad.el.json",
        "data/xquad/xquad.es.json",
        "data/xquad/xquad.hi.json",
        "data/xquad/xquad.ro.json",
        "data/xquad/xquad.ru.json",
        "data/xquad/xquad.th.json",
        "data/xquad/xquad.tr.json",
        "data/xquad/xquad.vi.json",
        "data/xquad/xquad.zh.json",
    ]
    mlqa_files = [
        "data/mlqa/test/test-context-ar-question-ar.json",
        "data/mlqa/test/test-context-de-question-de.json",
        "data/mlqa/test/test-context-en-question-en.json",
        "data/mlqa/test/test-context-es-question-es.json",
        "data/mlqa/test/test-context-hi-question-hi.json",
        "data/mlqa/test/test-context-vi-question-vi.json",
        "data/mlqa/test/test-context-zh-question-zh.json",
    ]
    tydiqa_files = [
        "data/tydiqa-goldp/tydiqa-goldp-dev-arabic.json",
        "data/tydiqa-goldp/tydiqa-goldp-dev-bengali.json",
        "data/tydiqa-goldp/tydiqa-goldp-dev-english.json",
        "data/tydiqa-goldp/tydiqa-goldp-dev-finnish.json",
        "data/tydiqa-goldp/tydiqa-goldp-dev-indonesian.json",
        "data/tydiqa-goldp/tydiqa-goldp-dev-korean.json",
        "data/tydiqa-goldp/tydiqa-goldp-dev-russian.json",
        "data/tydiqa-goldp/tydiqa-goldp-dev-swahili.json",
        "data/tydiqa-goldp/tydiqa-goldp-dev-telugu.json",
    ]
    for benchmark_files in (xquad_files, mlqa_files, tydiqa_files):
        get_result(benchmark_files)
