import numpy as np
import torch
from loguru import logger
from overrides import overrides
from prefect import Task
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from sklearn.metrics import f1_score
import random
import os

class PreparePairsTransformersTask(Task):
    """Prefect task that tokenizes sentence/type-statement pairs into a ``TorchDataset``."""

    @staticmethod
    def calculate_metrics(preds, gold):
        """Return the binary F1 score of the arg-max predictions against ``gold``.

        Args:
            preds: Per-example class scores; reduced with ``np.argmax`` along
                axis 1 to obtain the predicted label of each example.
            gold: Gold labels, aligned with the rows of ``preds``.

        Returns:
            The F1 score computed with ``average="binary"``.
        """
        predicted_labels = np.argmax(preds, axis=1)
        # scikit-learn's signature is f1_score(y_true, y_pred): gold labels first.
        return f1_score(gold, predicted_labels, average="binary")

    @staticmethod
    def process_data(dataset, tokenizer, max_length=256):
        """Tokenize every example of ``dataset`` as a (sentence, type statement) pair.

        The second segment of each pair is the generated statement
        ``Then <expression> is of type <type>``. Every encoding is padded or
        truncated to ``max_length`` tokens.

        Args:
            dataset: Mapping from example id to a dict providing the keys
                ``"sentence"``, ``"expression"``, ``"type"`` and ``"label"``.
            tokenizer: Tokenizer object exposing ``encode_plus``.
            max_length: Length every encoded pair is padded/truncated to.

        Returns:
            Dict mapping each example id to a dict with ``"input_ids"``,
            ``"attention_mask"``, ``"label"`` and ``"token_type_ids"``.
        """
        processed_dataset = {}

        for id_, data in tqdm(dataset.items(), "Preparing transformer data"):
            sentence_type = f'Then {data["expression"]} is of type {data["type"]}'

            output = tokenizer.encode_plus(
                data["sentence"],
                sentence_type,
                padding="max_length",
                max_length=max_length,
                truncation=True,
            )

            input_ids = output["input_ids"]
            if "token_type_ids" in output:
                token_type_ids = output["token_type_ids"]
            else:
                # Some tokenizers emit no segment ids. Fall back to one zero per
                # token: a scalar 0 would make ``torch.LongTensor(0)`` in
                # TorchDataset build an empty tensor instead of a row of zeros.
                token_type_ids = [0] * len(input_ids)

            processed_dataset[id_] = {
                "input_ids": input_ids,
                "attention_mask": output["attention_mask"],
                "label": data["label"],
                "token_type_ids": token_type_ids,
            }

        return processed_dataset

    @overrides
    def run(self, dataset, bert_model="bert-base-uncased"):
        """Seed the random number generators, tokenize ``dataset`` and wrap it.

        Args:
            dataset: Mapping of example id to raw example, in the form
                expected by ``process_data``.
            bert_model: Name or path handed to ``AutoTokenizer.from_pretrained``.

        Returns:
            A ``TorchDataset`` built from the processed examples.
        """
        logger.info("** Preparing binary torch dataset (｡-_-｡ )人( ｡-_-｡) **")

        def seed_everything(seed):
            """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
            random.seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        # NOTE: this changes global RNG state for the whole process.
        seed_everything(42)

        processed_dataset = self.process_data(
            dataset, AutoTokenizer.from_pretrained(bert_model)
        )

        return TorchDataset(processed_dataset)


class TorchDataset(Dataset):
    """Positional ``Dataset`` view over a mapping of pre-tokenized examples."""

    def __init__(self, dataset):
        """Store the mapping's keys and values as two parallel lists."""
        self.keys = [*dataset.keys()]
        self.dataset = [*dataset.values()]

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.dataset)

    def get_id(self, index):
        """Return the original mapping key of the example at ``index``."""
        return self.keys[index]

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, token_type_ids, label, index)``.

        The first three entries are ``torch.LongTensor`` objects built from the
        stored example; the label is passed through unchanged.
        """
        example = self.dataset[index]
        ids, mask, segments = (
            torch.LongTensor(example[field])
            for field in ("input_ids", "attention_mask", "token_type_ids")
        )
        return ids, mask, segments, example["label"], index
