import argparse
import torch
import os
import random
import numpy as np

from utils import load_file, load_dir, write_to_csv
from metrics import (
    compute_m2,
    compute_sentm2,
    compute_errant,
    compute_senterrant,
    get_plm_scorer,
)


def _strip_m2_suffix(file_name):
    """Return *file_name* without a trailing ``.m2`` extension, if it has one.

    ``str.strip(".m2")`` is deliberately not used: it treats its argument as a
    set of characters and would also eat leading/trailing ``m``, ``2`` and
    ``.`` characters that belong to the system name itself.
    """
    suffix = ".m2"
    if file_name.endswith(suffix):
        return file_name[: -len(suffix)]
    return file_name


def main():
    """Score every system output under ``data/conll14`` and write the results.

    Parses the command-line options, seeds the random number generators,
    loads the references, runs the metric selected by ``--base`` on each file
    in the system directory, and writes the system-level scores (and the
    sentence-level scores, when any were produced) via ``write_to_csv``.
    """
    parser = argparse.ArgumentParser("PT-M2")
    parser.add_argument("-b", "--base", choices=["m2", "sentm2", "errant", "senterrant"], default="m2", type=str)
    parser.add_argument("-o", "--output_file", type=str,
                        required=True, help="output_file name")
    parser.add_argument("-m", "--model_type", type=str, help="choose the plm type", default="bert-base-uncased")
    parser.add_argument("-s", "--scorer", choices=["self", "bertscore", "bartscore"],
                        default="self",
                        type=str, help="choose the plm scorer type")
    parser.add_argument("--beta", default=0.5, type=float)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--n_gpu", default=1, type=int)
    args = parser.parse_args()

    # The device is attached to ``args`` so it travels with the other options
    # handed to the scorer and metric functions.
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed every RNG in use and pin the cuDNN settings for repeatable runs.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    print(args)

    input_dir = "data/conll14"
    reference_dir = os.path.join(input_dir, "refAB")
    references = load_dir(reference_dir)

    # NOTE(review): ``sources`` is loaded but not used below; kept so that a
    # missing source file still fails the same way it did before.
    source_file = os.path.join(input_dir, "source")
    sources = load_file(source_file)

    if "m2" in args.base:  # use m2score to extract edits
        m2_file = os.path.join(input_dir, "goldAB.m2")
        system_dir = os.path.join(input_dir, "hyp")
    else:  # "errant" / "senterrant": use errant to extract edits
        m2_file = os.path.join(input_dir, "refAB.m2")
        system_dir = os.path.join(input_dir, "hyp_m2")

    scorer = None
    if args.scorer in ["bartscore", "bertscore"]:
        # Flatten the values of every reference mapping into a single list.
        flat_references = [v for kv in references for v in kv.values()]
        scorer = get_plm_scorer(references=flat_references, args=args)

    datas, sdatas = [], []
    # os.listdir returns entries in arbitrary order; sort them so the rows of
    # the output files come out in the same order on every run.
    for f_n in sorted(os.listdir(system_dir)):
        # One name serves both as ``hyp_n`` and as the reported system name.
        system_name = _strip_m2_suffix(f_n)
        hyp_file = os.path.join(system_dir, f_n)
        if args.base == "m2":
            score, score_lst = compute_m2(m2_file=m2_file, hyp_file=hyp_file, references=references,
                                          scorer=scorer, args=args)
        elif args.base == "sentm2":
            score, score_lst = compute_sentm2(m2_file=m2_file, hyp_file=hyp_file, references=references,
                                              scorer=scorer, args=args)
        elif args.base == "errant":
            score, score_lst = compute_errant(m2_file=m2_file, hyp_file=hyp_file, hyp_n=system_name,
                                              references=references, scorer=scorer, args=args)
        elif args.base == "senterrant":
            score, score_lst = compute_senterrant(m2_file=m2_file, hyp_file=hyp_file, hyp_n=system_name,
                                                  references=references, scorer=scorer, args=args)

        print(f"Computing {system_name}, score={score:.4f}")
        datas.append((args.output_file, "src-trg", "conll14", system_name, f"{score:.4f}"))
        if score_lst:
            for i, v in enumerate(score_lst):
                sdatas.append((args.output_file, "src-trg", "conll14", system_name, i, v))

    write_to_csv(f"gecmetrics/scores/conll14/system_scores_metrics/{args.output_file}.txt", datas)
    # Decide from the accumulated rows, not from the last loop iteration's
    # ``score_lst`` (which is undefined when the system directory is empty).
    if sdatas:
        write_to_csv(f"gecmetrics/scores/conll14/sentence_scores_metrics/{args.output_file}.txt", sdatas)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
