import pickle as pkl
import sys
sys.path.append("/path/to/inputs/DPR")
import tqdm
from transformers import BertTokenizer
from dpr.models import init_biencoder_components
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
import logging
import numpy as np
import os
from dpr.options import (
    setup_cfg_gpu,
    set_seed,
    get_encoder_params_state_from_cfg,
    set_cfg_params_from_state,
    setup_logger,
)
from dpr.utils.conf_utils import BiencoderDatasetsCfg
from dpr.utils.model_utils import (
    setup_for_distributed_mode,
    move_to_device,
    get_schedule_linear,
    CheckpointState,
    get_model_file,
    get_model_obj,
    load_states_from_checkpoint,
)
from dpr.models.biencoder import BiEncoderNllLoss
import torch.nn.functional as F
import sys
import warnings
import pickle as pkl
from transformers import logging
# NOTE: `logging` here is the module bound by `from transformers import logging`
# on the line above, which rebinds (shadows) the stdlib `logging` imported
# earlier in this file. Its verbosity is set to the FATAL level.
logging.set_verbosity(logging.FATAL)
# Ignore every Python warning raised for the remainder of the run.
warnings.filterwarnings("ignore")



def validate_main(cfg):
    """Score a patched DPR bi-encoder on cached batches and pickle the results.

    The function
      1. builds the bi-encoder from ``cfg`` and moves it to the GPU,
      2. replaces the feed-forward block of encoder layer 11 in both the
         question and the context encoder with ``Linear`` layers whose
         intermediate width is read from the checkpoint at ``cfg.model_file``,
      3. loads the checkpoint weights (strictly) when a model file was resolved,
      4. runs the model on the cached batches that contain one specific
         question and, for the [CLS] position of every returned hidden state
         plus the final representation, records whether the top-scored
         context is the positive one,
      5. pickles ``[all_score_agreements, all_questions]`` into
         ``cfg.output_dir``.

    Args:
        cfg: Hydra/OmegaConf configuration. Reads ``checkpoint_file_name``,
            ``model_file``, ``output_dir``, ``encoder.encoder_model_type`` and
            ``encoder.pretrained_model_cfg``.

    Returns:
        None. The results are written to
        ``<cfg.output_dir>/cls_embeddings_edited_bert_nq.pkl``.
    """
    model_file = get_model_file(cfg, cfg.checkpoint_file_name)
    ds_cfg = BiencoderDatasetsCfg(cfg).dev_datasets[0]
    saved_state = load_states_from_checkpoint(model_file) if model_file else None
    print(model_file)

    # The optimizer returned by the factory is not needed for evaluation.
    tensorizer, model, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg)
    model.cuda()
    btokenizer = BertTokenizer.from_pretrained(cfg.encoder.pretrained_model_cfg)

    # Make the Transformer-Patcher sources importable; add the entry only once
    # because this function is called once per checkpoint.
    patcher_path = "/path/to/inputs/Transformer-Patcher/"
    if patcher_path not in sys.path:
        sys.path.append(patcher_path)

    # Only a weight shape is needed here, so keep the tensors on the CPU
    # instead of materialising a second copy of the checkpoint on the GPU.
    state_dict = torch.load(cfg.model_file, map_location="cpu")
    # nn.Linear stores its weight as (out_features, in_features), so dim 1 of
    # the output projection is the (possibly enlarged) intermediate width.
    new_layer_size = state_dict["model_dict"]["ctx_model.encoder.layer.11.output.dense.weight"].shape[1]
    for encoder in (model.question_model, model.ctx_model):
        patched_layer = encoder.encoder.layer[11]
        patched_layer.intermediate.dense = torch.nn.Linear(768, new_layer_size).cuda()
        patched_layer.output.dense = torch.nn.Linear(new_layer_size, 768).cuda()

    if saved_state:
        get_model_obj(model).load_state(saved_state, strict=True)

    # NOTE: unpickling executes arbitrary code; only load trusted files here.
    with open("/path/to/inputs/all_batchesv2.pkl", "rb") as batch_file:
        batches = pkl.load(batch_file)

    encoder_type = ds_cfg.encoder_type
    all_questions = []
    # Kept only for the optional dump variant that also stores hidden states.
    all_hidden = []
    all_score_agreements = []
    # Running top-k hit counts, one row per scored representation (every
    # hidden state plus the final vector). Sized from the first processed
    # batch so the final-vector row can never alias a hidden-state row.
    # NOTE(review): these totals are accumulated but neither saved nor
    # printed -- confirm whether they should be part of the output.
    topk_performances = None

    with tqdm.tqdm(total=len(batches)) as pbar:
        for batch in batches:
            # Only batches containing this exact question are evaluated.
            if "what part of the carrot is the seed" not in batch[0]:
                pbar.update(1)  # keep the progress bar in step with skipped batches
                continue

            # truncation=True mirrors the context encoding below and keeps
            # every row at exactly max_length tokens.
            encoded_questions = btokenizer.batch_encode_plus(
                batch[0], return_tensors="pt", padding="max_length", truncation=True, max_length=256
            )
            encoded_ctx = [
                btokenizer.encode_plus(
                    ctx[0], text_pair=ctx[1], return_tensors="pt",
                    padding="max_length", truncation=True, max_length=256,
                )
                for ctx in batch[1]
            ]
            encoded_ctx_ids = torch.vstack([enc["input_ids"] for enc in encoded_ctx])
            encoded_ctx_masks = torch.vstack([enc["attention_mask"] for enc in encoded_ctx])

            # Force a [SEP] token (and an attended position) into the last slot
            # of every question and context sequence.
            encoded_questions["input_ids"][:, -1] = btokenizer.sep_token_id
            encoded_questions["attention_mask"][:, -1] = 1
            encoded_ctx_ids[:, -1] = btokenizer.sep_token_id
            encoded_ctx_masks[:, -1] = 1

            rep_positions = ds_cfg.selector.get_positions(encoded_questions["input_ids"], tensorizer)
            positive_idx_per_question = batch[3]

            with torch.no_grad():
                # batch[-2] / batch[-1] are passed as the 2nd and 5th positional
                # arguments -- presumably question/context segment ids; verify
                # against the code that produced the cached batches.
                model_out = model(
                    encoded_questions["input_ids"].cuda(),
                    batch[-2].cuda(),
                    encoded_questions["attention_mask"].cuda().bool(),
                    encoded_ctx_ids.cuda(),
                    batch[-1].cuda(),
                    encoded_ctx_masks.cuda().bool(),
                    encoder_type=encoder_type,
                    representation_token_pos=rep_positions,
                    output_attentions=True,
                )

            # TODO: compare the generated attention masks; if they do not
            # align, compare local_q_vector as well.
            local_q_vector, local_ctx_vectors, local_q_hidden, local_ctx_hidden, _, _ = model_out

            score_agreements = []
            batch_topk = []

            # Position 0 ([CLS]) of every hidden state, question vs. contexts.
            for q_hidden, ctx_hidden in zip(local_q_hidden, local_ctx_hidden):
                score_agreement, topk_performance = vector_agreements(
                    q_hidden[:, 0, :], ctx_hidden[:, 0, :], positive_idx_per_question
                )
                score_agreements.append(score_agreement)
                batch_topk.append(topk_performance)

            # Final representations returned by the model.
            score_agreement, topk_performance = vector_agreements(
                local_q_vector, local_ctx_vectors, positive_idx_per_question
            )
            score_agreements.append(score_agreement)
            batch_topk.append(topk_performance)

            if topk_performances is None:
                topk_performances = [[0] * len(hits) for hits in batch_topk]
            for totals, hits in zip(topk_performances, batch_topk):
                for j, hit_count in enumerate(hits):
                    totals[j] += hit_count

            all_hidden.append([hidden.cpu() for hidden in local_q_hidden])
            all_score_agreements.append(score_agreements)
            all_questions.extend(batch[0])
            pbar.update(1)

    os.makedirs(cfg.output_dir, exist_ok=True)
    # Alternative dump that also stores the hidden states:
    #   [all_score_agreements, all_hidden, all_questions]
    with open(cfg.output_dir + "/cls_embeddings_edited_bert_nq.pkl", "wb") as out_file:
        pkl.dump([all_score_agreements, all_questions], out_file)


def vector_agreements(q_vector, ctx_vector, positive_idx_per_question, ks=(1, 5, 10, 20, 30, 40)):
    """Compare question/context similarity rankings against the gold positives.

    Args:
        q_vector: Question representations passed to
            ``BiEncoderNllLoss.get_scores``; when it has more than one
            dimension, its first dimension is used as the number of questions.
        ctx_vector: Context representations passed to
            ``BiEncoderNllLoss.get_scores``.
        positive_idx_per_question: For each question, the column index of its
            positive context in the score matrix.
        ks: Cut-offs for the top-k hit counts. A cut-off larger than the
            number of contexts is clamped to that number instead of raising.

    Returns:
        A tuple ``(score_agreement, topk_performance)`` where
        ``score_agreement`` is a boolean tensor that is True for each question
        whose best-scored context is its positive, and ``topk_performance`` is
        a list with, for every cut-off in ``ks``, the number of questions
        whose positive context appears among the top-k scored contexts.
    """
    scores = BiEncoderNllLoss.get_scores(q_vector, ctx_vector)
    if q_vector.dim() > 1:
        scores = scores.view(q_vector.size(0), -1)

    softmax_scores = F.log_softmax(scores, dim=1)
    _, max_idxs = torch.max(softmax_scores, 1)

    # as_tensor avoids the copy-construct warning when a tensor is passed in.
    positives = torch.as_tensor(positive_idx_per_question, device=max_idxs.device)
    positives_col = positives.view(-1, 1)

    num_ctx = softmax_scores.size(1)
    topk_performance = []
    for k in ks:
        # torch.topk raises when k exceeds the dimension size, so clamp it.
        _, topk_indices = torch.topk(softmax_scores, min(k, num_ctx))
        # A question is a hit when its positive index is among its top-k columns.
        hits = (topk_indices == positives_col).any(dim=1)
        topk_performance.append(int(hits.sum().item()))

    score_agreement = max_idxs == positives
    return score_agreement, topk_performance

@hydra.main(config_path="/path/to/inputs/DPR/conf", config_name="biencoder_train_cfg")
def main(cfg: DictConfig):
    """Hydra entry point: evaluate each selected checkpoint with ``validate_main``.

    For every checkpoint path, ``cfg.model_file`` is pointed at the file and
    ``cfg.output_dir`` is set to the name of the directory that contains it
    before ``validate_main(cfg)`` is invoked.
    """
    cfg = setup_cfg_gpu(cfg)
    set_seed(cfg)

    # Checkpoints evaluated in earlier runs, kept for reference:
    #   /path/to/inputs/DPR/outputs/2023-08-29/18-44-07/checkpoints/dpr_biencoder.29  (default trained)
    #   /path/to/inputs/DPR/outputs/2023-12-27/23-44-31/checkpoint_edited_bert/dpr_biencoder.29  (knowledge_neuron_edited)
    #   /path/to/inputs/DPR/outputs/2024-01-23/00-30-41/checkpoints_transformer_patch_edited_trained/dpr_biencoder.29
    #   /path/to/inputs/DPR/outputs/2024-01-23/00-28-48/checkpoints_transformer_patch_edited_pretrained/dpr_biencoder.29
    #   /path/to/inputs/DPR/outputs/2024-01-25/12-27-46/checkpoints_malmen_edited/dpr_biencoder.29
    #   /path/to/inputs/DPR/outputs/2024-01-25/12-28-23/checkpoints_mend_edited/dpr_biencoder.29
    #   /path/to/inputs/DPR/outputs/2024-01-26/00-51-21/checkpoints_transformerpatch_bert_edited_remove_facts/dpr_biencoder.29
    #   /path/to/inputs/DPR/outputs/2024-01-26/00-53-28/checkpoints_transformerpatch_pretrain_bert_edited_remove_facts/dpr_biencoder.29
    #   /path/to/inputs/DPR/outputs/2024-02-07/12-12-14/checkpoints_malmen_pretrain_bert_edited_remove_facts/dpr_biencoder.29
    checkpoints = (
        "/path/to/inputs/DPR/outputs/2024-02-07/12-12-22/checkpoints_mend_pretrain_bert_edited_remove_facts/dpr_biencoder.29",
    )

    for checkpoint in checkpoints:
        path_parts = checkpoint.split("/")
        cfg.model_file = checkpoint
        # The second-to-last path component is the checkpoint's parent folder.
        cfg.output_dir = path_parts[-2]
        validate_main(cfg)


if __name__ == "__main__":
    # torch.distributed.launch adds its CLI params with a leading "--";
    # strip that prefix so Hydra receives them in its own format.
    # Every other argument is passed through unchanged.
    sys.argv = [
        argument[len("--"):] if argument.startswith("--") else argument
        for argument in sys.argv
    ]

    main()
