import argparse
import prefect
from dynaconf import settings
from loguru import logger
import json
from prefect.engine.flow_runner import FlowRunner
from prefect import Flow, Parameter, Task, tags, task
from prefect.engine.results import LocalResult
import os
import pickle
from varslot.tasks.data_preparation import (
    PreparePairsTransformersTask,
    CreatePairsTransformersTask,
)

from varslot.tasks.model_pipeline import SeqTransformerTrainer
import statistics

from sklearn.metrics import f1_score
import json

# Command-line interface of the experiment script.
parser = argparse.ArgumentParser(description="Variable typing with baselines.")

# Dataset locations, looked up in the dynaconf settings.
test_path_subs = settings["test_subs"]
test_path = settings["test"]
train_path = settings["train"]
dev_path = settings["dev"]

parser.add_argument(
    "--model_name",
    default="bert-base-uncased",
    nargs="?",
    type=str,
    metavar="model_name",
    help="Transformers model [ex: bert-base-uncased]",
)
args = parser.parse_args()

# Name of the transformers model used throughout the script; another model
# (for instance "allenai/scibert_scivocab_uncased") can be selected with
# --model_name.
BERT_MODEL = args.model_name

CACHE_LOCATION = settings["cache_location"]

# Keyword arguments shared by the cached data-preparation tasks. Only the
# model name is substituted here: "{task_name}" and "{task_tags}" stay as
# literal placeholders in the target template (presumably filled in by
# Prefect at run time -- confirm against the Prefect target documentation).
cache_args = {
    "target": str(BERT_MODEL) + "--TRANSFORMERS--{task_name}-{task_tags}.pkl",
    "checkpoint": True,
    "result": LocalResult(dir=f"{CACHE_LOCATION}"),
}

# Task instances reused by every flow built in this module.
create_pairs_for_train = CreatePairsTransformersTask(**cache_args)
prepare_pairs_for_train = PreparePairsTransformersTask(**cache_args)
trainer_task = SeqTransformerTrainer()


@task
def pretty_print_output(scores, mode="normal"):
    """Log the evaluation scores of a run.

    Args:
        scores: Mapping of split name to score. The "test" entry is always
            read; the "dev" entry is read unless ``mode`` is "subs".
        mode: "subs" logs only the test score under the "TEST With Renaming"
            heading; any other value logs the dev and test scores.
    """
    logger.info("Transformers BASELINE")
    logger.info(f"BERT MODEL: {BERT_MODEL}")

    # (heading, key in ``scores``) pairs, in the order they are logged.
    if mode == "subs":
        sections = (("TEST With Renaming", "test"),)
    else:
        sections = (("DEV", "dev"), ("TEST", "test"))

    for heading, split in sections:
        logger.info(heading)
        logger.info(scores[split])
        


@task
def obtain_file_content(file_path):
    """Load and return the parsed contents of a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized JSON document.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 (the standard JSON interchange encoding)
    # instead of relying on the platform's locale-dependent default.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _prepare_model_input(split_tag, file_path):
    """Add the load -> create pairs -> prepare pairs tasks for one data split.

    Must be called inside an active ``Flow`` context so the tasks are
    registered on that flow. All three tasks are created under ``split_tag``.

    Args:
        split_tag: Prefect tag applied to the tasks of this split.
        file_path: Path of the JSON file holding the split.

    Returns:
        The output of ``prepare_pairs_for_train`` for this split.
    """
    with tags(split_tag):
        raw_data = obtain_file_content(file_path)
        pairs = create_pairs_for_train(raw_data)
        return prepare_pairs_for_train(dataset=pairs, bert_model=BERT_MODEL)


def _add_trainer(train_dataset, dev_dataset, test_dataset, mode):
    """Add the trainer task to the active flow and return its scores output.

    Args:
        train_dataset: Prepared training split.
        dev_dataset: Prepared development split.
        test_dataset: Prepared test split.
        mode: Value forwarded as the trainer's ``mode`` argument.
    """
    return trainer_task(
        train_dataset=train_dataset,
        dev_dataset=dev_dataset,
        test_dataset=test_dataset,
        task_name=f"{BERT_MODEL}/seq_classification/",
        output_dir=f"{CACHE_LOCATION}/models/",
        bert_model=BERT_MODEL,
        mode=mode,
        eval_fn=f1_score,
    )


def _run_flow(flow):
    """Run ``flow`` and log an error if it ends in a failed state."""
    state = flow.run()
    # flow.run() reports failures through the returned state; without this
    # check a failed run would pass unnoticed.
    if state.is_failed():
        logger.error(
            "Flow {!r} finished in a failed state: {}", flow.name, state.message
        )
    return state


def run_experiments_transformer_baseline():
    """Build and run the two transformer baseline flows.

    The first flow prepares the train, dev and basic test splits, calls the
    trainer with ``mode="train"`` and logs the dev and test scores. The
    second flow prepares the train, dev and renamed-variable test splits,
    calls the trainer with ``mode="test"`` and logs the test score.

    Both flows are always run, the first one before the second.
    NOTE(review): ``mode="test"`` presumably reuses the model produced by the
    first flow -- confirm against ``SeqTransformerTrainer``.
    """
    with Flow("Transformers for Variable Typing") as flow1:
        input_to_model_train = _prepare_model_input("train", train_path)
        input_to_model_dev = _prepare_model_input("dev", dev_path)
        input_to_model_test = _prepare_model_input("test-basic", test_path)

        scores = _add_trainer(
            input_to_model_train,
            input_to_model_dev,
            input_to_model_test,
            mode="train",
        )
        pretty_print_output(scores)

    with Flow("Transformers for Variable Typing with Variable Renaming") as flow2:
        input_to_model_train = _prepare_model_input("train", train_path)
        input_to_model_dev = _prepare_model_input("dev", dev_path)
        input_to_model_test_subs = _prepare_model_input("test-subs", test_path_subs)

        scores = _add_trainer(
            input_to_model_train,
            input_to_model_dev,
            input_to_model_test_subs,
            mode="test",
        )
        pretty_print_output(scores, mode="subs")

    _run_flow(flow1)
    _run_flow(flow2)


# Script entry point: this call is unconditional, so both flows are built and
# run whenever this module is executed (or imported).
run_experiments_transformer_baseline()

