from transformers import MT5ForConditionalGeneration
import torch
import transformers
import pathlib
from absl import app, flags
import logging
from datasets import concatenate_datasets
import random
import numpy as np

from typing import Mapping


def tokenize_sentence(tokenizer, record):
    """Tokenize a cloze record's input and its answer-filled sentence.

    The input text (``record["inputs_pretokenized"]``) is tokenized with
    truncation at 2048 tokens and no padding. The target is that same input
    text with the ``<extra_id_0>`` sentinel replaced by the answer taken from
    ``record["targets_pretokenized"]`` (sentinel removed, surrounding
    whitespace stripped), i.e. the full sentence rather than a span target.

    Args:
        tokenizer: callable tokenizer returning an object with ``input_ids``.
        record: mapping with ``inputs_pretokenized`` and
            ``targets_pretokenized`` string fields.

    Returns:
        dict with ``"inputs"`` (input ids as returned by the tokenizer) and
        ``"targets"`` (target ids with the final column dropped).
    """
    source_text = record["inputs_pretokenized"]
    source_ids = tokenizer(
        source_text,
        truncation=True,
        max_length=2048,
        return_tensors="pt",
        return_offsets_mapping=True,
    ).input_ids

    answer = record["targets_pretokenized"].replace("<extra_id_0>", "").strip()
    filled_text = source_text.replace("<extra_id_0>", answer)
    filled_ids = tokenizer(
        filled_text,
        return_tensors="pt",
        return_offsets_mapping=True,
    ).input_ids

    # The last column is dropped on the assumption that it is the </s> EOS.
    # FIXME: this may discard more than EOS, but it reproduces the paper scores.
    return {"inputs": source_ids, "targets": filled_ids[:, :-1]}

def tokenize(tokenizer, record):
    """Tokenize a record's input text and its raw target text.

    The input (``record["inputs_pretokenized"]``) is truncated and padded to
    exactly 2048 tokens; the target (``record["targets_pretokenized"]``) is
    tokenized as-is, without truncation or padding.

    Args:
        tokenizer: callable tokenizer returning an object with ``input_ids``.
        record: mapping with ``inputs_pretokenized`` and
            ``targets_pretokenized`` string fields.

    Returns:
        dict with ``"inputs"`` (padded input ids) and ``"targets"`` (target
        ids with the final column dropped).
    """
    shared_kwargs = {"return_tensors": "pt", "return_offsets_mapping": True}

    source_ids = tokenizer(
        record["inputs_pretokenized"],
        max_length=2048,
        padding="max_length",
        truncation=True,
        **shared_kwargs,
    ).input_ids
    target_ids = tokenizer(record["targets_pretokenized"], **shared_kwargs).input_ids

    # The last column is dropped on the assumption that it is the </s> EOS.
    # FIXME: this may discard more than EOS, but it reproduces the paper scores.
    return {"inputs": source_ids, "targets": target_ids[:, :-1]}

# Module-level switch read by load_mt5 at call time: when True, the loaded
# model also gets a `model.accums` dict of per-parameter accumulator tensors.
LOAD_ACCUMS = True


def load_mt5(checkpoint_folder="/home/xxx/data/LLM/tracing/finetune/checkpoint-30000/"):
    """Load an MT5 checkpoint onto the GPU in eval mode, optionally with accumulators.

    Behaviour depends on the module-level ``LOAD_ACCUMS`` flag and the path:

    * ``LOAD_ACCUMS`` false: only the model is loaded.
    * path contains ``"finetune"``: accumulators are rebuilt from the
      checkpoint's ``optimizer.pt`` via ``load_model_with_accum``.
    * otherwise: accumulators come from a second checkpoint whose path is
      ``checkpoint_folder`` with ``"_model_"`` replaced by ``"_accum_"``; each
      of its parameters ``v`` becomes ``sqrt(v) + 1e-7``, flattened, on GPU.

    Args:
        checkpoint_folder: local checkpoint directory (``local_files_only``,
            so nothing is downloaded).

    Returns:
        The model in eval mode; when accumulators are loaded they are stored
        in ``model.accums`` keyed by parameter name.
    """
    if LOAD_ACCUMS and "finetune" in checkpoint_folder:
        logging.info("loading accumulators")
        model = load_model_with_accum(checkpoint_folder)
    else:
        model = MT5ForConditionalGeneration.from_pretrained(
            checkpoint_folder, local_files_only=True
        ).cuda()
        if LOAD_ACCUMS:
            logging.info("loading accumulators")
            accum = MT5ForConditionalGeneration.from_pretrained(
                checkpoint_folder.replace("_model_", "_accum_"),
                local_files_only=True,
            )
            # The epsilon keeps later divisions by the accumulator finite.
            model.accums = {
                name: (torch.sqrt(v.data) + 1e-7).flatten().cuda()
                for name, v in accum.named_parameters()
            }

    model.eval()
    return model


def load_model_with_accum(checkpoint_folder):
    """Load an MT5 checkpoint and attach Adam second-moment accumulators.

    The optimizer state saved in ``optimizer.pt`` is loaded into a freshly
    built ``transformers.AdamW`` over the model's parameters. For every
    parameter, ``sqrt(exp_avg_sq) + 1e-7`` is stored, flattened, in
    ``model.accums[name]``.

    Args:
        checkpoint_folder: local directory containing the model weights and
            ``optimizer.pt``.

    Returns:
        The model on the GPU with ``model.accums`` populated (eval mode is not
        set here).

    Raises:
        KeyError: if a parameter has no ``exp_avg_sq`` entry in the loaded
            optimizer state.
    """
    model = MT5ForConditionalGeneration.from_pretrained(
        checkpoint_folder, local_files_only=True
    ).cuda()
    named_params = list(model.named_parameters())

    # The group layout must mirror the optimizer that wrote optimizer.pt:
    # load_state_dict maps saved state onto parameters by position.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in named_params if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
        {
            "params": [
                p for n, p in named_params if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=4e-5)
    # map_location="cpu" avoids depending on the GPU the state was saved from;
    # load_state_dict moves each state tensor to its parameter's device.
    optim_dict = torch.load(f"{checkpoint_folder}/optimizer.pt", map_location="cpu")
    optimizer.load_state_dict(optim_dict)

    # Read the state per parameter directly instead of stashing it on each
    # Parameter object, which would leave a stray `.accum` attribute behind.
    model.accums = {}
    for name, p in named_params:
        exp_avg_sq = optimizer.state[p]["exp_avg_sq"]
        model.accums[name] = exp_avg_sq.sqrt().add_(1e-7).cuda().flatten()

    return model

def extract_masked_sentence(abstract: str, term="<extra_id_0>"):
    """Return the sentence of ``abstract`` that contains ``term``.

    Sentence boundaries are approximated by ". ". The sentence starts after
    the last ". " before the term (or at the beginning of the text). It ends
    at the first ". " after the term (period included). Failing that, it ends
    at the first "." after the term. If no period follows at all, it runs to
    the end of the text.

    Args:
        abstract: text containing ``term``.
        term: marker to locate (the first occurrence is used).

    Raises:
        AssertionError: if ``term`` does not occur in ``abstract``.
    """
    term_start = abstract.find(term)
    assert term_start > -1, f"{term!r} not found in abstract"
    term_end = term_start + len(term)

    prev_break = abstract.rfind(". ", 0, term_start)
    sentence_start = 0 if prev_break == -1 else prev_break + 2

    next_break = abstract.find(". ", term_end)
    if next_break == -1:
        next_break = abstract.find(".", term_end)
    # Without a following period the sentence extends to the end of the text;
    # previously -1 + 1 == 0 produced an empty slice here.
    sentence_end = len(abstract) if next_break == -1 else next_break + 1
    return abstract[sentence_start:sentence_end]


def get_tokenized_query(record, extract=False):
    """Build a whitespace-split query with the masked answer filled in.

    The answer is ``record["targets_pretokenized"]`` with the
    ``"<extra_id_0> "`` prefix removed. It replaces ``<extra_id_0>`` in
    ``record["inputs_pretokenized"]``, or only in the masked sentence when
    ``extract`` is true. The result is split on single spaces, so repeated
    spaces yield empty tokens.
    """
    filler = record["targets_pretokenized"].replace("<extra_id_0> ", "")
    passage = record["inputs_pretokenized"]
    passage = extract_masked_sentence(passage) if extract else passage
    filled = passage.replace("<extra_id_0>", filler)
    return filled.split(" ")


def get_target_equivalence_classes(abstracts):
    """Group record positions by normalized target answer.

    A target is normalized by removing the ``"<extra_id_0> "`` prefix,
    stripping whitespace and lower-casing.

    Returns:
        dict mapping each normalized target to the list of indices (in
        iteration order) of the records that have it.
    """
    classes = {}
    for position, record in enumerate(abstracts):
        answer = record["targets_pretokenized"].replace("<extra_id_0> ", "")
        classes.setdefault(answer.strip().lower(), []).append(position)
    return classes


def get_target_ids(target_ids_hashmap, record):
    """Look up the ids associated with a record's normalized target.

    The target is normalized the same way as in
    ``get_target_equivalence_classes``: ``"<extra_id_0> "`` removed, stripped,
    lower-cased. Returns ``[0]`` when the target is not in the map.
    """
    raw_answer = record["targets_pretokenized"]
    normalized = raw_answer.replace("<extra_id_0> ", "").strip().lower()
    default_ids = [0]
    return target_ids_hashmap.get(normalized, default_ids)


def build_candidate_previous_id(tokenizer, bm25_eval, abstracts, example, fact_to_ids):
    """Build and tokenize a shuffled candidate pool of abstracts for one example.

    The pool concatenates four subsets (duplicates across subsets are kept):
      1. abstracts whose ``id`` is listed in ``fact_to_ids`` for the example's
         "predicate_id,obj_uri,sub_uri" fact;
      2. up to ``n`` randomly chosen abstracts whose target string equals the
         example's;
      3. the top-``n`` abstracts by BM25 score against the filled-in query;
      4. random abstracts padding the pool toward ``MAX`` (at least one).
    Subset sizes are clamped to the corpus size so small corpora don't raise.

    Args:
        tokenizer: tokenizer handed to ``tokenize`` for every candidate.
        bm25_eval: object with ``get_scores(query_tokens)``; assumed to return
            a NumPy array aligned with ``abstracts`` (TODO confirm).
        abstracts: dataset supporting ``filter``/``select``/``len``.
        example: record with ``predicate_id``, ``obj_uri``, ``sub_uri``,
            ``inputs_pretokenized`` and ``targets_pretokenized``.
        fact_to_ids: mapping from the fact string to ids of matching abstracts.

    Returns:
        ``(abstracts_combined, abstracts_ours)``: the shuffled pool and the
        tokenized records in the same order.
    """
    MAX = 10000  # target pool size
    n = 4000  # cap for the same-target and BM25 subsets

    # [step 1] abstracts annotated with the example's fact
    fact = ",".join((example["predicate_id"], example["obj_uri"], example["sub_uri"]))
    fact_ids = list(map(str, fact_to_ids.get(fact, [])))

    def id_filter(sample):
        return sample["id"] in fact_ids

    abstracts_filtered = abstracts.filter(id_filter)

    # [step 2] random subset of abstracts sharing the example's exact target
    target = example["targets_pretokenized"]

    def target_filter(sample):
        return sample["targets_pretokenized"] == target

    abstracts_target_filtered = abstracts.filter(target_filter)
    n_target = min(n, len(abstracts_target_filtered))
    abstracts_target_filtered = abstracts_target_filtered.select(
        random.sample(range(len(abstracts_target_filtered)), n_target)
    )

    # [step 3] top-k BM25 neighbours, highest score first
    query_bm25 = get_tokenized_query(example)
    scores = bm25_eval.get_scores(query_bm25)
    k = min(n, len(scores))  # argpartition raises if k exceeds the corpus size
    if k > 0:
        top = np.argpartition(scores, -k)[-k:]  # unordered top-k indices
        nn_idxs = top[np.argsort(-scores[top])]  # sort them by descending score
    else:
        nn_idxs = np.array([], dtype=np.int64)
    abstracts_bm25 = abstracts.select(nn_idxs)

    # [step 4] random filler toward MAX; random.sample raises if asked for more
    # items than the corpus holds, so clamp to len(abstracts)
    n_random = max(
        1, MAX - len(abstracts_filtered) - len(abstracts_target_filtered) - len(abstracts_bm25)
    )
    n_random = min(n_random, len(abstracts))
    abstracts_random_eval = abstracts.select(random.sample(range(len(abstracts)), n_random))

    abstracts_combined = concatenate_datasets(
        [abstracts_random_eval, abstracts_filtered, abstracts_target_filtered, abstracts_bm25]
    )

    # shuffle so the subsets are interleaved
    order = random.sample(range(len(abstracts_combined)), len(abstracts_combined))
    abstracts_combined = abstracts_combined.select(order)

    abstracts_ours = [tokenize(tokenizer, a) for a in abstracts_combined]
    return abstracts_combined, abstracts_ours


def build_candidate_previous(bm25_eval, abstracts, example, fact_to_ids, tokenizer=None):
    """Build and tokenize a shuffled candidate pool of abstracts for one example.

    The pool concatenates four subsets (duplicates across subsets are kept):
      1. abstracts whose ``id`` is listed in ``fact_to_ids`` for the example's
         "predicate_id,obj_uri,sub_uri" fact;
      2. up to ``n`` randomly chosen abstracts whose target string equals the
         example's;
      3. the top-``n`` abstracts by BM25 score against the filled-in query;
      4. random abstracts padding the pool toward ``MAX`` (at least one).
    Subset sizes are clamped to the corpus size so small corpora don't raise.

    Args:
        bm25_eval: object with ``get_scores(query_tokens)``; assumed to return
            a NumPy array aligned with ``abstracts`` (TODO confirm).
        abstracts: dataset supporting ``filter``/``select``/``len``.
        example: record with ``predicate_id``, ``obj_uri``, ``sub_uri``,
            ``inputs_pretokenized`` and ``targets_pretokenized``.
        fact_to_ids: mapping from the fact string to ids of matching abstracts.
        tokenizer: tokenizer handed to ``tokenize``. Optional for backward
            compatibility: when omitted, a module-level ``tokenizer`` global
            is used.

    Returns:
        ``(abstracts_combined, abstracts_ours)``: the shuffled pool and the
        tokenized records in the same order.

    Raises:
        NameError: if no tokenizer is passed and no module-level
            ``tokenizer`` exists (checked up front, before any expensive work).
    """
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError(
                "build_candidate_previous needs a tokenizer: pass tokenizer=... "
                "or define a module-level `tokenizer`"
            )

    MAX = 10000  # target pool size
    n = 4000  # cap for the same-target and BM25 subsets

    # [step 1] abstracts annotated with the example's fact
    fact = ",".join((example["predicate_id"], example["obj_uri"], example["sub_uri"]))
    fact_ids = list(map(str, fact_to_ids.get(fact, [])))

    def id_filter(sample):
        return sample["id"] in fact_ids

    abstracts_filtered = abstracts.filter(id_filter)

    # [step 2] random subset of abstracts sharing the example's exact target
    target = example["targets_pretokenized"]

    def target_filter(sample):
        return sample["targets_pretokenized"] == target

    abstracts_target_filtered = abstracts.filter(target_filter)
    n_target = min(n, len(abstracts_target_filtered))
    abstracts_target_filtered = abstracts_target_filtered.select(
        random.sample(range(len(abstracts_target_filtered)), n_target)
    )

    # [step 3] top-k BM25 neighbours, highest score first
    query_bm25 = get_tokenized_query(example)
    scores = bm25_eval.get_scores(query_bm25)
    k = min(n, len(scores))  # argpartition raises if k exceeds the corpus size
    if k > 0:
        top = np.argpartition(scores, -k)[-k:]  # unordered top-k indices
        nn_idxs = top[np.argsort(-scores[top])]  # sort them by descending score
    else:
        nn_idxs = np.array([], dtype=np.int64)
    abstracts_bm25 = abstracts.select(nn_idxs)

    # [step 4] random filler toward MAX; random.sample raises if asked for more
    # items than the corpus holds, so clamp to len(abstracts)
    n_random = max(
        1, MAX - len(abstracts_filtered) - len(abstracts_target_filtered) - len(abstracts_bm25)
    )
    n_random = min(n_random, len(abstracts))
    abstracts_random_eval = abstracts.select(random.sample(range(len(abstracts)), n_random))

    abstracts_combined = concatenate_datasets(
        [abstracts_random_eval, abstracts_filtered, abstracts_target_filtered, abstracts_bm25]
    )

    # shuffle so the subsets are interleaved
    order = random.sample(range(len(abstracts_combined)), len(abstracts_combined))
    abstracts_combined = abstracts_combined.select(order)

    abstracts_ours = [tokenize(tokenizer, a) for a in abstracts_combined]
    return abstracts_combined, abstracts_ours


def build_candidate_bert(bm25_eval, abstracts, example, fact_to_ids):
    """Build a shuffled (untokenized) candidate pool of abstracts for one example.

    The pool concatenates four subsets (duplicates across subsets are kept):
      1. abstracts whose ``id`` is listed in ``fact_to_ids`` for the example's
         "predicate_id,obj_uri,sub_uri" fact;
      2. up to ``n`` randomly chosen abstracts whose target string equals the
         example's;
      3. the top-``n`` abstracts by BM25 score against the filled-in query;
      4. random abstracts padding the pool toward ``MAX`` (at least one).
    Subset sizes are clamped to the corpus size so small corpora don't raise.

    Args:
        bm25_eval: object with ``get_scores(query_tokens)``; assumed to return
            a NumPy array aligned with ``abstracts`` (TODO confirm).
        abstracts: dataset supporting ``filter``/``select``/``len``.
        example: record with ``predicate_id``, ``obj_uri``, ``sub_uri``,
            ``inputs_pretokenized`` and ``targets_pretokenized``.
        fact_to_ids: mapping from the fact string to ids of matching abstracts.

    Returns:
        The shuffled, concatenated candidate dataset.
    """
    MAX = 10000  # target pool size
    n = 4000  # cap for the same-target and BM25 subsets

    # [step 1] abstracts annotated with the example's fact
    fact = ",".join((example["predicate_id"], example["obj_uri"], example["sub_uri"]))
    fact_ids = list(map(str, fact_to_ids.get(fact, [])))

    def id_filter(sample):
        return sample["id"] in fact_ids

    abstracts_filtered = abstracts.filter(id_filter)

    # [step 2] random subset of abstracts sharing the example's exact target
    target = example["targets_pretokenized"]

    def target_filter(sample):
        return sample["targets_pretokenized"] == target

    abstracts_target_filtered = abstracts.filter(target_filter)
    n_target = min(n, len(abstracts_target_filtered))
    abstracts_target_filtered = abstracts_target_filtered.select(
        random.sample(range(len(abstracts_target_filtered)), n_target)
    )

    # [step 3] top-k BM25 neighbours, highest score first
    query_bm25 = get_tokenized_query(example)
    scores = bm25_eval.get_scores(query_bm25)
    k = min(n, len(scores))  # argpartition raises if k exceeds the corpus size
    if k > 0:
        top = np.argpartition(scores, -k)[-k:]  # unordered top-k indices
        nn_idxs = top[np.argsort(-scores[top])]  # sort them by descending score
    else:
        nn_idxs = np.array([], dtype=np.int64)
    abstracts_bm25 = abstracts.select(nn_idxs)

    # [step 4] random filler toward MAX; random.sample raises if asked for more
    # items than the corpus holds, so clamp to len(abstracts)
    n_random = max(
        1, MAX - len(abstracts_filtered) - len(abstracts_target_filtered) - len(abstracts_bm25)
    )
    n_random = min(n_random, len(abstracts))
    abstracts_random_eval = abstracts.select(random.sample(range(len(abstracts)), n_random))

    abstracts_combined = concatenate_datasets(
        [abstracts_random_eval, abstracts_filtered, abstracts_target_filtered, abstracts_bm25]
    )

    # shuffle so the subsets are interleaved
    order = random.sample(range(len(abstracts_combined)), len(abstracts_combined))
    return abstracts_combined.select(order)