from prefect import Task
from loguru import logger
from dynaconf import settings
import random
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm


# Length, in tokens, that tokenized sentences and their position lists are
# padded or truncated to.
MAX_LENGTH_SENTENCE_BERT = 256


class PreparePairsForTrainTask(Task):
    """Tokenize every sentence of a dataset and align its position lists.

    For each entry the sentence text is tokenized to a fixed length, and the
    variable / expression / combined position lists are shifted by one
    leading ``0`` and then padded or truncated to the same length as the
    token ids. The result is wrapped in a ``TorchDataset``.
    """

    @staticmethod
    def _fit_positions(positions, size):
        """Return ``[0] + positions`` truncated or zero-padded to ``size`` items.

        The leading ``0`` presumably lines the list up with the special token
        the tokenizer puts first (e.g. ``[CLS]``) -- confirm against the code
        that produces the position lists.
        """
        fitted = [0, *positions][:size]
        # ``[0] * negative`` is an empty list, so no guard is needed here.
        return fitted + [0] * (size - len(fitted))

    def run(self, dataset, tokenizer_model, max_len_sentence=MAX_LENGTH_SENTENCE_BERT):
        """Build the model-ready dataset.

        Args:
            dataset: Mapping from sentence id to a dict with the keys
                ``"sentence"``, ``"type_sentence"``, ``"bert_variables_pos"``,
                ``"bert_expressions_pos"``, ``"bert_all_pos"`` and ``"label"``.
            tokenizer_model: Name or path handed to
                ``AutoTokenizer.from_pretrained``.
            max_len_sentence: Length every sentence is padded/truncated to.
                Previously this argument was ignored and the module constant
                was always used; the default value is unchanged.

        Returns:
            A ``TorchDataset`` over the prepared entries, keyed by the same
            ids as ``dataset``.
        """
        logger.info("Prepare Pairs Task ♫꒰･‿･๑꒱")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
        output = dict()

        for id_s, content in tqdm(dataset.items()):
            tokenizer_output = tokenizer(
                content["sentence"] + content["type_sentence"],
                padding="max_length",
                truncation=True,
                max_length=max_len_sentence,
            )

            input_ids = tokenizer_output["input_ids"]
            total_size_sentence = len(input_ids)

            # Some tokenizers do not emit segment ids; fall back to all
            # zeros, sized from the actual token ids so the lengths agree.
            if "token_type_ids" in tokenizer_output:
                token_type_ids = tokenizer_output["token_type_ids"]
            else:
                token_type_ids = [0] * total_size_sentence

            output[id_s] = {
                "sentence_input_ids": input_ids,
                "sentence_token_type_ids": token_type_ids,
                "sentence_attention_mask": tokenizer_output["attention_mask"],
                # Each list is fitted using its own length, so a list that is
                # shorter or longer than the others still ends up exactly
                # ``total_size_sentence`` items long.
                "var_bert_pos": self._fit_positions(
                    content["bert_variables_pos"], total_size_sentence
                ),
                "exp_bert_pos": self._fit_positions(
                    content["bert_expressions_pos"], total_size_sentence
                ),
                "all_bert_pos": self._fit_positions(
                    content["bert_all_pos"], total_size_sentence
                ),
                "label": content["label"],
            }

        return TorchDataset(output)


class TorchDataset(Dataset):
    """Positional view over a mapping of prepared instances.

    The mapping's values become the indexable items and its keys are kept in
    the same order so an index can be translated back to its original id.
    """

    def __init__(self, dataset):
        """Snapshot the keys and values of ``dataset`` into parallel lists."""
        self.keys = list(dataset.keys())
        self.dataset = list(dataset.values())

    def __len__(self):
        """Return the number of stored instances."""
        return len(self.dataset)

    def get_id(self, index):
        """Return the original mapping key of the instance at ``index``."""
        return self.keys[index]

    def __getitem__(self, index):
        """Return the instance at ``index`` as an 8-tuple.

        Order: variable positions, expression positions, input ids, token
        type ids, attention mask, label (as stored), all positions, index.
        Every element except the label and the index is a ``LongTensor``.
        """
        record = self.dataset[index]

        def as_long(field):
            return torch.LongTensor(record[field])

        variable_positions = as_long("var_bert_pos")
        expression_positions = as_long("exp_bert_pos")
        input_ids = as_long("sentence_input_ids")
        token_type_ids = as_long("sentence_token_type_ids")
        attention_mask = as_long("sentence_attention_mask")
        label = record["label"]
        all_positions = as_long("all_bert_pos")

        return (
            variable_positions,
            expression_positions,
            input_ids,
            token_type_ids,
            attention_mask,
            label,
            all_positions,
            index,
        )
