import argparse
import prefect
from dynaconf import settings
from loguru import logger
import json
from prefect.engine.flow_runner import FlowRunner
from prefect import Flow, Parameter, Task, tags, task
from prefect.engine.results import LocalResult
import os
import pickle
from varslot.tasks.data_preparation import (
    PreparePairsForTrainTask,
    PreparePairsTask,
)
from transformers import AutoTokenizer, AutoModel
import random
import pickle
from varslot.tasks.model_pipeline import VariableSlotTrainer
import statistics

from sklearn.metrics import f1_score
import json


# Module-level default. The experiment entry point takes its own
# ``is_baseline`` argument, which shadows this value inside that function.
is_baseline = 0

# Hugging Face identifier of the encoder used for the tokenizer and trainer.
# Alternative checkpoint that can be swapped in: "gsarti/scibert-nli"
model_name = "sentence-transformers/bert-base-nli-mean-tokens"
BERT_MODEL = model_name

# Dataset locations, read from the dynaconf settings.
test_path_subs = settings["test_subs"]
test_path = settings["test"]
train_path = settings["train"]
dev_path = settings["dev"]

# Directory used for Prefect checkpoints (and, below, for model outputs).
CACHE_LOCATION = settings["cache_location"]

# Keyword arguments shared by the cached data-preparation tasks.
# ``{task_name}`` and ``{task_tags}`` are left as literal placeholders in the
# target so Prefect can template them per task run.
cache_args = {
    "target": f"{model_name}--{{task_name}}-{{task_tags}}.pkl",
    "checkpoint": True,
    "result": LocalResult(dir=f"{CACHE_LOCATION}"),
}


class OpenFileTask(Task):
    """Prefect task that loads a JSON document from disk."""

    def run(self, file_path):
        """Parse the JSON file at ``file_path`` and return its content.

        Args:
            file_path: Path of the JSON file to read.

        Returns:
            The deserialized JSON value.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
            json.JSONDecodeError: If the file does not contain valid JSON.
        """
        # JSON is UTF-8 by specification; being explicit avoids depending on
        # the platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)


# Shared task instances used when building the flows below.
# The two data-preparation tasks receive ``cache_args`` (target template,
# checkpointing and a LocalResult); the file loader and the trainer are
# created with their default arguments.
prepare_pairs_for_train = PreparePairsForTrainTask(**cache_args)
obtain_file_content = OpenFileTask()
prepare_vtds_task = PreparePairsTask(**cache_args)
trainer_task = VariableSlotTrainer()


@task
def pretty_print_output(scores, mode="normal"):
    """Log the evaluation scores of a run.

    Args:
        scores: Mapping holding the results. ``scores["test"]`` is always
            read; ``scores["dev"]`` is read only when ``mode`` is not
            ``"subs"``.
        mode: ``"subs"`` reports just the test scores under the
            "TEST With Renaming" heading. Any other value reports the dev
            scores followed by the test scores.
    """
    # Header shared by both report variants.
    logger.info("Transformers BASELINE")
    logger.info(f"BERT MODEL: {BERT_MODEL}")

    if mode == "subs":
        logger.info("TEST With Renaming")
    else:
        logger.info("DEV")
        logger.info(scores["dev"])
        logger.info("TEST")

    # Both variants end with the test scores.
    logger.info(scores["test"])
        


@task
def save_output(data, file_path):
    """Serialize ``data`` as JSON into the file at ``file_path``.

    The file is opened in write mode, so any existing content is replaced.
    """
    with open(file_path, "w") as out_file:
        json.dump(data, out_file)

 

def _prepare_split_inputs(file_path, split_name, tag, tokenizer):
    """Add the load -> prepare -> pair tasks for one data split to the active flow.

    Must be called inside a ``with Flow(...)`` block. The three tasks are
    created under ``tags(tag)``, and the tag takes part in the cache target
    template of the preparation tasks.

    Args:
        file_path: Path of the JSON file holding the split.
        split_name: Split label handed to ``prepare_vtds_task``.
        tag: Prefect tag applied to the three tasks.
        tokenizer: Tokenizer instance handed to ``prepare_vtds_task``.

    Returns:
        The task result of ``prepare_pairs_for_train`` for this split.
    """
    with tags(tag):
        raw_input = obtain_file_content(file_path)
        prepared = prepare_vtds_task(raw_input, split_name, tokenizer)
        return prepare_pairs_for_train(
            dataset=prepared, tokenizer_model=model_name
        )


def run_experiments_variable_slot_baseline(num_iters, is_baseline):
    """Build and run the two variable-slot experiment flows.

    The first flow prepares the train, dev and test splits, calls the trainer
    with ``mode="train"`` and logs the dev/test scores. The second flow
    prepares the train, dev and renamed-variable test splits, calls the
    trainer with ``mode="test"`` under the same ``task_name`` and logs the
    test scores. NOTE(review): presumably the second flow evaluates the model
    produced by the first one -- confirm against ``VariableSlotTrainer``.

    Args:
        num_iters: Forwarded to the trainer as ``num_iters`` and embedded in
            the trainer ``task_name``.
        is_baseline: Forwarded to the trainer as ``is_baseline`` and embedded
            in the trainer ``task_name``.

    Returns:
        None. Scores are reported through the logger.
    """
    # Loaded once and shared by both flows instead of once per flow.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Trainer arguments that are identical for both flows.
    trainer_kwargs = dict(
        task_name=f"{model_name}_variable_slot_{is_baseline}_iters_{num_iters}",
        output_dir=f"{CACHE_LOCATION}/models/",
        model_name=model_name,
        is_baseline=is_baseline,
        dimension_size=768,
        num_iters=num_iters,
        eval_fn=f1_score,
    )

    with Flow("Running Variable Slot experiment") as flow1:
        input_to_model_train = _prepare_split_inputs(
            train_path, "train", "train", tokenizer
        )
        input_to_model_dev = _prepare_split_inputs(
            dev_path, "dev", "dev", tokenizer
        )
        input_to_model_test = _prepare_split_inputs(
            test_path, "test", "test", tokenizer
        )

        scores_ = trainer_task(
            train_dataset=input_to_model_train,
            dev_dataset=input_to_model_dev,
            test_dataset=input_to_model_test,
            mode="train",
            msg_report="*** NORMAL RESULTS ***",
            **trainer_kwargs,
        )
        pretty_print_output(scores_)

    flow1.run()

    with Flow("Running Variable Slot experiment with Variable Renaming") as flow2:
        input_to_model_train = _prepare_split_inputs(
            train_path, "train", "train", tokenizer
        )
        input_to_model_dev = _prepare_split_inputs(
            dev_path, "dev", "dev", tokenizer
        )
        # Only the renamed-variable test split is evaluated here. The plain
        # test split was previously prepared in this flow as well, but its
        # output was never consumed, so that branch is not built anymore.
        # The split label stays "test"; only the tag differs, which keeps
        # its cache target separate from the plain test split.
        input_to_model_test_subs = _prepare_split_inputs(
            test_path_subs, "test", "test-subs", tokenizer
        )

        scores1 = trainer_task(
            train_dataset=input_to_model_train,
            dev_dataset=input_to_model_dev,
            test_dataset=input_to_model_test_subs,
            mode="test",
            **trainer_kwargs,
        )
        pretty_print_output(scores1, mode="subs")

    flow2.run()

    


# Value forwarded as ``num_iters`` to the experiment entry point.
NUM_ITERS = 6
run_experiments_variable_slot_baseline(num_iters=NUM_ITERS, is_baseline=0)