import argparse
import prefect
from dynaconf import settings
from loguru import logger
import json
from prefect.engine.flow_runner import FlowRunner
from prefect import Flow, Parameter, Task, tags, task
from prefect.engine.results import LocalResult
import os
import pickle
from varslot.tasks.data_preparation import (
    PreparePairsTransformersTask,
    CreatePairsTransformersTask,
)

from varslot.tasks.model_pipeline import SeqTransformerTrainer
import statistics

from sklearn.metrics import f1_score
import json

# Command-line interface of the script; the only option is the name of the
# transformers model to run the baseline with.
parser = argparse.ArgumentParser(
    description="Variable typing with baselines."
)


# Dataset locations, read from the dynaconf settings.
# ``innoc_location`` is the prefix of the numbered test files: the driver loop
# at the bottom of this file reads ``f"{innoc_location}{i}.json"``.
innoc_location = settings["innoc_location"]
test_path = settings["test"]
train_path = settings["train"]
dev_path = settings["dev"]

_DEFAULT_MODEL_NAME = "bert-base-uncased"

parser.add_argument(
    "--model_name",
    metavar="model_name",
    type=str,
    nargs="?",
    help="Transformers model [ex: bert-base-uncased]",
    default=_DEFAULT_MODEL_NAME,
    # With nargs="?" a bare ``--model_name`` (flag given without a value)
    # stores ``const``, which argparse defaults to None. Supplying the default
    # here keeps BERT_MODEL a real model name in that case instead of None.
    const=_DEFAULT_MODEL_NAME,
)

args = parser.parse_args()

# Model name used for cache file names, task arguments and log messages.
BERT_MODEL = args.model_name

# Directory (from the dynaconf settings) used for task results; the trainer
# output is written below it, under ``models/``.
CACHE_LOCATION = settings["cache_location"]

# Keyword arguments shared by the checkpointed data-preparation tasks.
# ``{task_name}`` and ``{task_tags}`` stay as literal placeholders in the
# target string (presumably filled in by Prefect at run time -- confirm
# against the Prefect target-templating documentation).
cache_args = {
    "target": f"{BERT_MODEL}--TRANSFORMERS--{{task_name}}-{{task_tags}}.pkl",
    "checkpoint": True,
    "result": LocalResult(dir=f"{CACHE_LOCATION}"),
}

# Task instances reused for every data split of every flow built below.
create_pairs_for_train = CreatePairsTransformersTask(**cache_args)
prepare_pairs_for_train = PreparePairsTransformersTask(**cache_args)
trainer_task = SeqTransformerTrainer()


@task
def pretty_print_output(scores, mode="normal"):
    """Log the scores of the transformer baseline.

    Args:
        scores: Mapping indexed with ``"test"`` and, unless ``mode`` is
            ``"subs"``, also with ``"dev"``.
        mode: ``"subs"`` logs only the test scores under the heading
            "TEST With Renaming"; any other value logs the dev scores
            followed by the test scores.
    """
    # Header shared by both modes.
    logger.info("Transformers BASELINE")
    logger.info(f"BERT MODEL: {BERT_MODEL}")

    if mode == "subs":
        logger.info("TEST With Renaming")
    else:
        logger.info("DEV")
        logger.info(scores["dev"])
        logger.info("TEST")

    # Both modes finish with the test scores.
    logger.info(scores["test"])
        


@task
def obtain_file_content(file_path):
    """Load and return the parsed JSON content of ``file_path``.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file.
    """
    # Read explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (the default when ``encoding`` is omitted).
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _prepare_split(file_path, tag):
    """Add the load -> create pairs -> prepare pairs tasks for one data split.

    Must be called inside an active ``Flow`` context. All three tasks are
    created under the Prefect tag ``tag``.

    Returns:
        The task result handle of the final preparation step, to be passed on
        to the trainer task.
    """
    with tags(tag):
        raw_content = obtain_file_content(file_path)
        pairs = create_pairs_for_train(raw_content)
        return prepare_pairs_for_train(dataset=pairs, bert_model=BERT_MODEL)


def run_experiments_transformer_baseline(test_path, mode, i):
    """Build and run one train/dev/test flow and return the trainer's scores.

    The train and dev splits always come from the module-level ``train_path``
    and ``dev_path``; only the test split varies between calls.

    Args:
        test_path: Path of the JSON file used as the test split.
        mode: Passed unchanged to the trainer task as ``mode``.
        i: Run index; the test-split tasks are tagged ``test_<i>``
            (presumably so each test file gets its own cache target, since
            the target template contains the task tags -- confirm).

    Returns:
        The value produced by the trainer task. The driver loop in this file
        reads it as ``scores["test"]["f1_score"]``.

    Raises:
        RuntimeError: If the flow run does not finish successfully.
    """
    with Flow("Transformers for Variable Typing") as flow1:
        input_to_model_train = _prepare_split(train_path, "train")
        input_to_model_dev = _prepare_split(dev_path, "dev")
        input_to_model_test = _prepare_split(test_path, f"test_{i}")

        scores = trainer_task(
            train_dataset=input_to_model_train,
            dev_dataset=input_to_model_dev,
            test_dataset=input_to_model_test,
            task_name=f"{BERT_MODEL}/seq_classification/",
            output_dir=f"{CACHE_LOCATION}/models/",
            bert_model=BERT_MODEL,
            mode=mode,
            eval_fn=f1_score,
        )

    state = flow1.run()

    # Fail loudly here: on a failed run the task "result" is an exception
    # object (or missing), which would otherwise surface later as an obscure
    # error when the caller indexes the scores.
    if not state.is_successful():
        raise RuntimeError(
            f"Flow run for test file {test_path!r} (run {i}) did not succeed: "
            f"{state.message}"
        )

    # Public accessor for the task's value; equivalent to the private
    # ``_result.value`` attribute chain.
    return state.result[scores].result

# Test F1 score of each run, keyed by run index.
out_scores = {}
for i in range(11):
    # Run 0 uses the regular test file with mode "train"; runs 1..10 use the
    # numbered files under ``innoc_location`` with mode "test".
    input_location, mode = (
        (test_path, "train")
        if i == 0
        else (f"{innoc_location}{i}.json", "test")
    )

    scores = run_experiments_transformer_baseline(input_location, mode, i)
    out_scores[i] = scores["test"]["f1_score"]

# Final summary: run indices on one line, their F1 scores on another.
banner = "************************"
summary_lines = (
    banner,
    f"**** RESULTS FOR MODEL {BERT_MODEL} MASKED WITH ADDED VARS****",
    "keys",
    f"{list(out_scores.keys())}",
    "values",
    f"{list(out_scores.values())}",
    banner,
)
for summary_line in summary_lines:
    logger.info(summary_line)

    
    

