from prefect import Task
from loguru import logger
from tqdm import tqdm


class PreparePairsTask(Task):
    """Prefect task that expands dataset records into labelled sentence/type pairs.

    For every record, each key of ``variable_types`` is combined with each
    distinct entry of ``types`` to form a statement `` Then <exp> has type <t>.``.
    The pair is labelled 1 when ``t`` equals the value mapped to ``exp`` and 0
    otherwise, and is stored together with three per-sub-token position
    vectors computed by :meth:`get_bert_positions`.
    """

    @staticmethod
    def get_bert_positions(content, vector_to_find, tokenizer):
        """Tag every sub-token of ``content`` with the id of its source word.

        ``$`` characters are removed from the words of both ``content`` and
        ``vector_to_find`` before they are compared or tokenized.

        Args:
            content: Iterable of word strings, in sentence order.
            vector_to_find: Iterable of word strings to be marked.
            tokenizer: Object exposing ``tokenize(str)`` whose result
                supports ``len()``.

        Returns:
            A list of ints with one entry per sub-token produced by
            ``tokenizer``. Sub-tokens of a word that is not in
            ``vector_to_find`` get 0; sub-tokens of a marked word get that
            word's 1-based id, assigned in order of first appearance in
            ``content`` (repeated words share the same id).
        """
        # A set gives constant-time membership tests for the per-word check.
        targets = {word.replace("$", "") for word in vector_to_find}

        output_positions = []
        target_ids = {}
        for raw_word in content:
            word = raw_word.replace("$", "")
            # Tokenize once: only the number of sub-tokens is needed.
            size_word = len(tokenizer.tokenize(word))
            if word in targets:
                # The default is evaluated before insertion, so a new word
                # receives the next unused id; a known word keeps its id.
                word_id = target_ids.setdefault(word, len(target_ids) + 1)
            else:
                word_id = 0
            output_positions.extend([word_id] * size_word)

        return output_positions

    def run(self, dataset, type_data, tokenizer):
        """Build the labelled pairs for every record of ``dataset``.

        Args:
            dataset: Mapping of record id to a dict with the keys
                ``sentence`` (list of words), ``types``, ``variable_types``
                (mapping of expression to its type), ``variables`` and
                ``expressions``. NOTE: ``sentence`` is replaced in place by
                its string-converted form.
            type_data: Label of the dataset split; only used in the log line.
            tokenizer: Tokenizer forwarded to :meth:`get_bert_positions`.

        Returns:
            A dict keyed by ``"<record id>_<pair index>"``. Each value holds
            ``sentence``, ``type_sentence``, ``label`` (1 or 0) and the
            position vectors ``bert_variables_pos``, ``bert_expressions_pos``
            and ``bert_all_pos``.
        """
        logger.info(f"*** (✦ ‿ ✦)  Prepare Pairs Task - DATASET {type_data} (✦ ‿ ✦) ***")
        output = {}
        positive_pairs = 0
        negative_pairs = 0

        for id_s, content in tqdm(dataset.items()):
            content["sentence"] = [str(x) for x in content["sentence"]]
            sentence_words = content["sentence"]

            # Distinct types in order of first occurrence.
            terms = list(dict.fromkeys(content["types"]))
            variable_types = content["variable_types"]
            if not terms or not variable_types:
                # No pair can be produced, so the remaining keys are not read.
                continue

            # Per-record invariants, computed once instead of once per pair.
            variables = content["variables"]
            expressions = content["expressions"]
            non_variable_expressions = [c for c in expressions if c not in variables]
            sentence_text = " ".join(sentence_words).replace("$", "")

            count = 0
            for exp, mapped_term in variable_types.items():
                for t in terms:
                    # The leading space yields an empty first word after the
                    # split; it is kept so positions match the stored text.
                    entailment_sentence = f" Then {exp} has type {t}.".split(" ")
                    pair_words = sentence_words + entailment_sentence

                    label = 1 if mapped_term == t else 0
                    output[f"{id_s}_{count}"] = {
                        "sentence": sentence_text,
                        "type_sentence": f" Then {exp.replace('$', '')} has type {t}.",
                        "label": label,
                        "bert_variables_pos": self.get_bert_positions(
                            pair_words, variables, tokenizer
                        ),
                        "bert_expressions_pos": self.get_bert_positions(
                            pair_words, non_variable_expressions, tokenizer
                        ),
                        "bert_all_pos": self.get_bert_positions(
                            pair_words, expressions, tokenizer
                        ),
                    }

                    if label:
                        positive_pairs += 1
                    else:
                        negative_pairs += 1
                    count += 1

        logger.info(f"POSITVE: {positive_pairs}")
        logger.info(f"NEGATIVE: {negative_pairs}")
        logger.info(f"TOTAL: {positive_pairs+negative_pairs}")

        return output
