import numpy as np
import torch
from loguru import logger
from overrides import overrides
from prefect import Task
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from sklearn.metrics import f1_score
import random
import os



class CreatePairsTransformersTask(Task):
    """Task that expands an annotated dataset into labelled pairs.

    Every expression of a record is paired with every distinct type of
    that record; a pair is labelled ``1`` when the type is the one mapped
    to the expression and ``0`` otherwise.
    """

    @staticmethod
    def create_pairs(dataset):
        """Build labelled (sentence, expression, type) pairs.

        Args:
            dataset: Mapping from a record id to a record. Each record is
                indexed with the keys ``"types"`` (sequence of candidate
                types, duplicates allowed), ``"variable_types"`` (mapping
                from an expression string to its type) and ``"sentence"``
                (iterable of tokens).

        Returns:
            A dict keyed by ``"<record id>_<running index>"``. Each value
            holds ``"sentence"`` (tokens joined by single spaces with every
            ``$`` removed), ``"expression"`` (with every ``$`` removed),
            ``"type"`` and ``"label"`` (1 for the mapped type, else 0).
            The running index restarts at 0 for each record.
        """
        output = {}
        positive_pairs = 0
        negative_pairs = 0

        for id_s, content in tqdm(dataset.items()):
            # dict.fromkeys drops duplicates while keeping first-occurrence
            # order, without the quadratic list.index() lookups.
            terms = list(dict.fromkeys(content["types"]))

            variable_types = content["variable_types"]
            if not variable_types:
                # Nothing to pair for this record; skip before reading the
                # sentence, which is only needed when pairs are produced.
                continue

            # The sentence is the same for every pair of the record, so it
            # is built once instead of once per expression.
            joined_sentence = " ".join(
                str(token) for token in content["sentence"]
            ).replace("$", "")

            count = 0
            for exp, mapped_term in variable_types.items():
                expression = exp.replace("$", "")
                for t in terms:
                    label = 1 if mapped_term == t else 0
                    output[f"{id_s}_{count}"] = {
                        "sentence": joined_sentence,
                        "expression": expression,
                        "type": t,
                        "label": label,
                    }
                    if label:
                        positive_pairs += 1
                    else:
                        negative_pairs += 1
                    count += 1

        logger.info(f"POSITVE: {positive_pairs}")
        logger.info(f"NEGATIVE: {negative_pairs}")
        logger.info(f"TOTAL: {positive_pairs+negative_pairs} ٩(＾◡＾)۶")

        return output

    @overrides
    def run(self, dataset):
        """Seed the random number generators and create the pairs.

        Args:
            dataset: Mapping of record id to record, passed unchanged to
                ``create_pairs``.

        Returns:
            The dict of labelled pairs produced by ``create_pairs``.
        """

        def seed_everything(seed):
            """Seed Python, NumPy and PyTorch RNGs with ``seed``."""
            random.seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        seed_everything(42)
        logger.info("*** Create Pairs For Transformers ლ(◉‿◉ ლ) ***")
        dataset = self.create_pairs(dataset)
        return dataset
