import argparse, os, csv
import pandas as pd

from project_root import join_with_root


from BARTScore import BARTScore
from DSBA import DSBA
from LocalGembaMQM import LocalGembaMQM
from XComet import XComet
from TranslationBleu import SentenceBleu

if __name__ == '__main__':
    os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
    parser = argparse.ArgumentParser(description='Pass allowed models via command line.')
    parser.add_argument('--baseline', help='List of models to be allowed')
    parser.add_argument('--model', help='Model for the DSBA or local MQM baseline', required=False)
    parser.add_argument('--dataset', help='The dataset that should be evaluated. The output will be a single file containing all the scores')
    parser.add_argument('--save_generated', help='Save the generated texts in a file as well')
    parser.add_argument('--src_lang', help='Source language for the local MQM baseline', required=False)
    parser.add_argument('--tgt_lang', help='Target language for the local MQM baseline', required=False)
    parser.add_argument('--lp', required=False)
    parser.add_argument('--to', required=False)
    args = parser.parse_args()

    df = pd.read_csv(args.dataset, sep="\t", quoting=csv.QUOTE_NONE)

    if args.to:
        df = df.head(int(args.to))

    if args.baseline == 'BARTScore':
        metric = BARTScore()
        name = args.dataset.replace("/", "_").split(".")[0] + "___" + args.baseline
        pd.DataFrame(metric.evaluate_df(df), columns=[name]).to_json(join_with_root("outputs/raw_baselines/" + name + ".json"))

    if args.baseline == 'DSBA':
        metric = DSBA(model=args.model)
        name = args.dataset.replace("/", "_").split(".")[0] + "___" + args.baseline + "___" + args.model.replace("/", "_")
        res = metric.evaluate_df(df)
        scores = res['scores']
        texts = res['texts']
        pd.DataFrame(scores, columns=[name]).to_json(join_with_root("outputs/raw_baselines/" + name + ".json"))
        if args.save_generated:
            pd.DataFrame(texts, columns=[name]).to_json(join_with_root("outputs/raw_baselines/" + name + "___generated_texts"+ ".json"))

    if args.baseline == 'localGembaMQM':
        metric = LocalGembaMQM(model=args.model)
        name = args.dataset.replace("/", "_").split(".")[0] + "___" + args.baseline + "___" + args.model.replace("/", "_")
        scores, texts = metric.evaluate_df(df, args.src_lang, args.tgt_lang)
        pd.DataFrame(scores, columns=[name]).to_json(join_with_root("outputs/raw_baselines/" + name + ".json"))
        if args.save_generated:
            pd.DataFrame(texts, columns=[name]).to_json(join_with_root("outputs/raw_baselines/" + name + "___generated_texts"+ ".json"))

    if args.baseline == 'XComet':
        metric = XComet("")
        name = args.dataset.replace("/", "_").split(".")[0] + "___" + args.baseline
        pd.DataFrame(metric.evaluate_df(df), columns=[name]).to_json(join_with_root("outputs/raw_baselines/" + name + ".json"))

    if args.baseline == 'TranslationBleu':
        metric = SentenceBleu(lp=args.lp)
        name = args.dataset.replace("/", "_").split(".")[0] + "___" + args.baseline
        pd.DataFrame(metric(df["SRC"].to_list(), df["HYP"].to_list()), columns=[name]).to_json(join_with_root("outputs/raw_baselines/" + name + ".json"))