import argparse
import prefect
from dynaconf import settings
from loguru import logger
import json
from prefect.engine.flow_runner import FlowRunner
from prefect import Flow, Parameter, Task, tags, task
from prefect.engine.results import LocalResult
import os
import pickle
from varslot.tasks.data_preparation import (
    PreparePairsForTrainTask,
    PreparePairsTask,
)
from transformers import AutoTokenizer, AutoModel
import random
import pickle
from varslot.tasks.model_pipeline import VariableSlotTrainer
import statistics

from sklearn.metrics import f1_score
import json


# Encoder checkpoint used throughout the script; BERT_MODEL is an alias that
# the logging code refers to.
model_name = "gsarti/scibert-nli"
BERT_MODEL = model_name
is_baseline = 0

# Dataset locations, read from the dynaconf settings object.
innoc_location = settings["innoc_location"]
test_path = settings["test"]
train_path = settings["train"]
dev_path = settings["dev"]

# Directory handed to LocalResult for checkpointed task results.
CACHE_LOCATION = settings["cache_location"]

# Keyword arguments shared by the tasks that should be checkpointed.
# "{task_name}" and "{task_tags}" are literal placeholders in the target
# string (presumably filled in by Prefect at run time -- confirm against the
# Prefect version in use).
cache_args = {
    "target": model_name + "--{task_name}-{task_tags}.pkl",
    "checkpoint": True,
    "result": LocalResult(dir=f"{CACHE_LOCATION}"),
}


class OpenFileTask(Task):
    """Prefect task that loads a JSON file and returns its parsed content."""

    def run(self, file_path):
        """Parse the JSON document stored at ``file_path``.

        Args:
            file_path: Path of the JSON file to read.

        Returns:
            The deserialized JSON value.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
            json.JSONDecodeError: If the content is not valid JSON.
        """
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # default encoding, which would mis-read non-ASCII JSON on systems
        # whose locale is not UTF-8.
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)


# Tasks configured with the shared cache_args (checkpointing + result target).
prepare_pairs_for_train = PreparePairsForTrainTask(**cache_args)
prepare_vtds_task = PreparePairsTask(**cache_args)

# Tasks created with their default configuration.
obtain_file_content = OpenFileTask()
trainer_task = VariableSlotTrainer()


@task
def pretty_print_output(scores, mode="normal"):
    """Log the model header followed by the scores of each reported split.

    Args:
        scores: Mapping indexed by split name. ``"test"`` is always read;
            ``"dev"`` is read as well unless ``mode`` is ``"subs"``.
        mode: ``"subs"`` logs only the test split under the heading
            "TEST With Renaming"; any other value logs dev and then test.
    """
    # The header is the same regardless of the mode.
    logger.info("Transformers BASELINE")
    logger.info(f"BERT MODEL: {BERT_MODEL}")

    # (heading, key into ``scores``) pairs, in the order they are logged.
    if mode == "subs":
        sections = (("TEST With Renaming", "test"),)
    else:
        sections = (("DEV", "dev"), ("TEST", "test"))

    for heading, split in sections:
        logger.info(heading)
        logger.info(scores[split])
        


@task
def save_output(data, file_path):
    """Serialize ``data`` as JSON into ``file_path``, overwriting the file.

    Args:
        data: JSON-serializable object to store.
        file_path: Destination path; opened in write mode.
    """
    with open(file_path, mode="w") as out_file:
        json.dump(data, fp=out_file)

 

def run_experiments_variable_slot_baseline(num_iters,  input_location, mode, i):
    """Build and run one train/dev/test flow and return the trainer's output.

    The train and dev splits are always read from the module-level
    ``train_path`` and ``dev_path``; only the test split varies per call.

    Args:
        num_iters: Forwarded to the trainer task as ``num_iters`` and embedded
            in the trainer's ``task_name``.
        input_location: Path of the JSON file used as the test split.
        mode: Forwarded unchanged to the trainer task.
        i: Index of the run. It is used to tag the test-split tasks as
            ``test-{i}``, which keeps their tags distinct from those of
            other runs.

    Returns:
        The value produced by the trainer task. The caller in this file reads
        ``["test"]["f1_score"]`` from it.

    Raises:
        RuntimeError: If the flow run does not finish in a successful state.
    """
    with Flow("Running Variable Slot experiment") as flow1:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        with tags("train"):
            train_input = obtain_file_content(train_path)
            data_prepared_train = prepare_vtds_task(train_input, "train", tokenizer)
            input_to_model_train = prepare_pairs_for_train(
                dataset=data_prepared_train, tokenizer_model=model_name
            )

        with tags("dev"):
            dev_input = obtain_file_content(dev_path)
            data_prepared_dev = prepare_vtds_task(dev_input, "dev", tokenizer)
            input_to_model_dev = prepare_pairs_for_train(
                dataset=data_prepared_dev, tokenizer_model=model_name
            )

        with tags(f"test-{i}"):
            test_input = obtain_file_content(input_location)
            data_prepared_test = prepare_vtds_task(test_input, "test", tokenizer)
            input_to_model_test = prepare_pairs_for_train(
                dataset=data_prepared_test, tokenizer_model=model_name
            )

        scores = trainer_task(
            train_dataset=input_to_model_train,
            dev_dataset=input_to_model_dev,
            test_dataset=input_to_model_test,
            task_name=f"{model_name}_variable_slot_{is_baseline}_iters_{num_iters}",
            output_dir=f"{CACHE_LOCATION}/models/",
            model_name=model_name,
            # Use the same module-level flag that is embedded in task_name so
            # the two can never disagree.
            is_baseline=is_baseline,
            dimension_size=768,
            num_iters=num_iters,
            mode=mode,
            eval_fn=f1_score,
            msg_report="*** NORMAL RESULTS ***",
        )

    state = flow1.run()

    # Fail loudly here instead of returning an exception object (or hitting a
    # confusing lookup error) when any task in the flow did not succeed.
    if not state.is_successful():
        raise RuntimeError(
            f"Flow run for input '{input_location}' (run {i}) did not succeed: "
            f"{state!r}"
        )

    # ``State.result`` is the public accessor for the stored task value.
    return state.result[scores].result

    

NUM_ITERS = 2

# Run 0 evaluates on the regular test set; runs 1..10 evaluate on the files
# "<innoc_location><i>.json". Every run uses the same trainer mode.
out_scores = {}
mode = "test"
for i in range(11):
    input_location = test_path if i == 0 else f"{innoc_location}{i}.json"

    scores = run_experiments_variable_slot_baseline(NUM_ITERS, input_location, mode, i)

    # Keep only the test-split F1 score of each run.
    out_scores[i] = scores["test"]["f1_score"]

# Summary: run indices followed by their F1 scores, in run order.
logger.info("************************")
logger.info(f"**** RESULTS FOR MODEL {BERT_MODEL} MASKED WITH ADDED VARS****")
logger.info("keys")
logger.info(str(list(out_scores)))
logger.info("values")
logger.info(str(list(out_scores.values())))
logger.info("************************")

    
    



