import copy
import pickle as pkl
import sys
sys.path.append("/path/to/inputs/DPR")
import tqdm
from transformers import BertTokenizer
from dpr.models import init_biencoder_components
import hydra
import torch
import math
from omegaconf import DictConfig, OmegaConf
import logging
import numpy as np
import os
from dpr.options import (
    setup_cfg_gpu,
    set_seed,
    get_encoder_params_state_from_cfg,
    set_cfg_params_from_state,
    setup_logger,
)
from dpr.utils.conf_utils import BiencoderDatasetsCfg
from dpr.utils.model_utils import (
    setup_for_distributed_mode,
    move_to_device,
    get_schedule_linear,
    CheckpointState,
    get_model_file,
    get_model_obj,
    load_states_from_checkpoint,
)
from dpr.models.biencoder import BiEncoderNllLoss
import torch.nn.functional as F
import sys
import warnings
import pickle as pkl
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import logging
# Silence all Python warnings for this script.
warnings.filterwarnings("ignore")
# NOTE: ``logging`` at this point is ``transformers.logging`` (re-imported just
# above, shadowing the stdlib ``logging``); this mutes transformers' logger.
logging.set_verbosity(logging.FATAL)
# Short alias used when counting chunks per batch.
ceil = math.ceil


def setup(rank, world_size, master_addr, master_port):
    """Join the default ``gloo`` process group as ``rank`` of ``world_size``.

    The rendezvous endpoint is published through the ``MASTER_ADDR`` and
    ``MASTER_PORT`` environment variables before the group is initialised.
    Progress is printed before and after initialisation.
    """
    print(f"Setting up rank: {rank}")
    os.environ.update(MASTER_ADDR=master_addr, MASTER_PORT=str(master_port))
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    print(f"Rank {rank} is setup")

def cleanup():
    """Tear down the default ``torch.distributed`` process group."""
    dist.destroy_process_group()

def scaled_input(emb, batch_size, num_batch):
    """Sample points on the straight line from a zero baseline towards ``emb``.

    Args:
        emb: activation to interpolate towards; callers pass shape ``(1, ffn_size)``.
        batch_size: multiplied with ``num_batch`` to give the number of points.
        num_batch: see ``batch_size``.

    Returns:
        ``(points, step)`` where ``points`` stacks ``i * step`` for
        ``i = 0 .. num_points - 1`` along dim 0 (the endpoint ``emb`` itself is
        not included), and ``step`` is row 0 of ``emb / num_points``.
    """
    origin = torch.zeros_like(emb)
    total = batch_size * num_batch
    increment = (emb - origin) / total
    points = [origin + increment * idx for idx in range(total)]
    return torch.cat(points, dim=0), increment[0]


# Module-level state shared by ``specific_hook_fn`` and its caller:
# - ``saved_layer_output``: per-example position-0 activations captured by the
#   hook on a forward pass where ``replace_output`` is empty.
# - ``replace_output``: queue of tensors to write into the hooked layer's
#   position-0 output; the hook always uses the head entry and never pops it.
saved_layer_output = []
replace_output = []


def specific_hook_fn(module, input_, output):
    """Forward hook that records or overrides a layer's position-0 output.

    Recording mode (``replace_output`` is empty): stores
    ``torch.unbind(output[:, 0, :])`` -- one tensor per batch row -- in the
    global ``saved_layer_output`` and returns ``None``, leaving the module
    output unchanged.

    Replacement mode: writes ``replace_output[0]`` into ``output[:, 0, :]`` in
    place and returns ``output``. The queue is not popped here; the caller is
    responsible for dequeuing.

    Args:
        module: the hooked module (unused).
        input_: the module's positional inputs (unused).
        output: the module output; indexed as ``output[:, 0, :]``, so it must
            have at least three dimensions.

    Returns:
        ``None`` in recording mode, otherwise the modified ``output``.

    Raises:
        RuntimeError: if the queued replacement cannot be assigned into
            ``output[:, 0, :]`` (e.g. a shape mismatch).
    """
    global saved_layer_output
    if not replace_output:
        saved_layer_output = torch.unbind(output[:, 0, :])
        return None

    replacement = replace_output[0]
    try:
        output[:, 0, :] = replacement
    except RuntimeError as err:
        # Fail loudly with the shapes involved instead of opening an
        # interactive debugger, which would block a spawned worker forever.
        raise RuntimeError(
            "specific_hook_fn: cannot assign replacement of shape "
            f"{tuple(replacement.shape)} into output[:, 0, :] of shape "
            f"{tuple(output[:, 0, :].shape)}"
        ) from err
    return output

# def register_specific_hooks(model, specific_string):
#     for name, module in model.named_modules():
#         if specific_string in name:
#             module.register_forward_hook(specific_hook_fn)



def _ffn_dense_module(model, layer_idx, qmodel, intermediate):
    """Return the ``dense`` Linear of transformer layer ``layer_idx`` to attribute.

    ``qmodel`` selects the question encoder (otherwise the context encoder);
    ``intermediate`` selects ``layer.intermediate.dense`` (otherwise
    ``layer.output.dense``).
    """
    encoder = model.question_model.encoder if qmodel else model.ctx_model.encoder
    layer = encoder.layer[layer_idx]
    return layer.intermediate.dense if intermediate else layer.output.dense


def _encode_chunk(btokenizer, questions, ctx_pairs, max_length=256):
    """Tokenize one chunk of questions and their contexts into fixed-length tensors.

    Each entry of ``ctx_pairs`` is indexed as ``ctx[0]`` (first segment) and
    ``ctx[1]`` (passed as ``text_pair``). After padding, the last column of
    every row is overwritten with ``[SEP]`` and marked as attended.

    Returns:
        ``(q_ids, q_mask, ctx_ids, ctx_mask)`` CPU tensors with ``max_length``
        columns each.
    """
    encoded_questions = btokenizer.batch_encode_plus(
        questions,
        return_tensors="pt",
        padding="max_length",
        # Contexts are truncated below; without truncation here a question
        # longer than max_length would make the batch ragged and fail.
        truncation=True,
        max_length=max_length,
    )
    encoded_ctx = [
        btokenizer.encode_plus(
            ctx[0],
            text_pair=ctx[1],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )
        for ctx in ctx_pairs
    ]
    q_ids = encoded_questions["input_ids"]
    q_mask = encoded_questions["attention_mask"]
    ctx_ids = torch.vstack([enc["input_ids"] for enc in encoded_ctx])
    ctx_mask = torch.vstack([enc["attention_mask"] for enc in encoded_ctx])
    # NOTE(review): this forces [SEP] into the final column of every row, padded
    # rows included, and sets its attention mask to 1; kept as-is — confirm intent.
    q_ids[:, -1] = btokenizer.sep_token_id
    q_mask[:, -1] = 1
    ctx_ids[:, -1] = btokenizer.sep_token_id
    ctx_mask[:, -1] = 1
    return q_ids, q_mask, ctx_ids, ctx_mask


def validate_main(rank, machine_rank, world_size, master_addr, master_port, cfg):
    """Compute per-example, integrated-gradient style sensitivities for every layer.

    For each transformer layer of the question encoder (``qmodel``) or the
    context encoder, a forward hook (``specific_hook_fn``) is placed on that
    layer's ``intermediate.dense`` (``intermediate``) or ``output.dense``
    Linear. Then, for every chunk of ``k`` questions:

    1. a ``no_grad`` pass lets the hook record each example's position-0
       activation;
    2. ``num_steps`` further passes inject points from a zero baseline towards
       that activation (``scaled_input``) and take d(loss)/d(injected tensor);
    3. gradients are summed over steps and multiplied by the step size, giving
       one vector per example.

    The per-layer results and the questions are pickled to a per-rank file.

    Args:
        rank: local process index; also the CUDA device index and tqdm row.
        machine_rank: offset added to ``rank`` to form the global rank.
        world_size: total number of workers; batches are sharded as
            ``batches[global_rank::world_size]``.
        master_addr: only used by the commented-out ``setup`` call.
        master_port: only used by the commented-out ``setup`` call.
        cfg: Hydra/DPR biencoder configuration.
    """
    global saved_layer_output
    # setup(rank+machine_rank, world_size, master_addr, master_port)
    training_data = False
    baseline = False
    intermediate = True
    qmodel = True
    global_rank = rank + machine_rank

    model_file = get_model_file(cfg, cfg.checkpoint_file_name) if not baseline else None
    ds_cfg = BiencoderDatasetsCfg(cfg).dev_datasets[0]
    if model_file and not baseline:
        saved_state = load_states_from_checkpoint(model_file)
    else:
        saved_state = None
    if rank == 0:
        print(model_file)

    tensorizer, model, _optimizer = init_biencoder_components(cfg.encoder.encoder_model_type, cfg)
    btokenizer = BertTokenizer.from_pretrained(cfg.encoder.pretrained_model_cfg)
    model = model.to(rank)
    model.question_model.encoder.gradient_checkpointing = True
    model.ctx_model.encoder.gradient_checkpointing = True
    # NOTE(review): model.eval() is never called here; if the model is in
    # training mode, dropout stays active during attribution — confirm intended.

    if saved_state and not baseline:
        get_model_obj(model).load_state(saved_state, strict=True)

    k = 64  # questions per chunk ("full batch"); 16 was used for the context model
    ctxs_per_question = 13  # the slicing below assumes 13 contexts per question
    num_steps = 5  # interpolation points between the zero baseline and the activation

    batches_path = (
        "/path/to/inputs/DPR/all_train_batches.pkl"
        if training_data
        else "/path/to/inputs/DPR/all_batches.pkl"
    )
    with open(batches_path, "rb") as batches_file:
        batches = pkl.load(batches_file)
    # Shard by *global* rank so workers on different machines get disjoint
    # slices (output files below are also keyed by global rank).
    batches = batches[global_rank::world_size]

    encoder_name = "question_model" if qmodel else "ctx_model"
    module_name = "intermediate" if intermediate else "output"
    num_layers = len((model.question_model if qmodel else model.ctx_model).encoder.layer)

    for layer_idx in range(num_layers):
        hook = _ffn_dense_module(model, layer_idx, qmodel, intermediate).register_forward_hook(
            specific_hook_fn
        )
        target_param = f"{encoder_name}.encoder.layer.{layer_idx}.{module_name}.dense.weight"
        # Leave only this layer's weight trainable. Assigning the flag (instead
        # of only clearing it) re-enables the target after earlier layers froze it.
        for name, param in model.named_parameters():
            param.requires_grad = name == target_param

        per_example_layer_sensitivity = []
        question = []
        with tqdm.tqdm(total=len(batches), position=rank) as pbar:
            for batch in batches:
                for j in range(ceil(len(batch[0]) / k)):
                    q_slice = slice(j * k, (j + 1) * k)
                    ctx_slice = slice(j * k * ctxs_per_question, (j + 1) * k * ctxs_per_question)
                    chunk_questions = batch[0][q_slice]
                    # Encode *this chunk's* contexts (previously batch[1][0..n) was
                    # encoded for every chunk, mismatching the ctx segment slice).
                    q_ids, q_mask, ctx_ids, ctx_mask = _encode_chunk(
                        btokenizer, chunk_questions, batch[1][ctx_slice]
                    )
                    rep_positions = ds_cfg.selector.get_positions(q_ids, tensorizer)
                    # Inputs are identical for every pass over this chunk: move once.
                    model_args = (
                        q_ids.to(rank),
                        batch[-2][q_slice].to(rank),
                        q_mask.to(rank).bool(),
                        ctx_ids.to(rank),
                        batch[-1][ctx_slice].to(rank),
                        ctx_mask.to(rank).bool(),
                    )
                    model_kwargs = dict(
                        encoder_type=ds_cfg.encoder_type,
                        representation_token_pos=rep_positions,
                        output_attentions=False,
                    )

                    # Recording pass: replace_output is empty, so the hook stores
                    # each example's position-0 activation in saved_layer_output.
                    with torch.no_grad():
                        model(*model_args, **model_kwargs)
                    weight_changes = [
                        scaled_input(act.view(1, -1), num_steps, 1) for act in saved_layer_output
                    ]
                    batch_weight_steps = [step for _path, step in weight_changes]
                    unbound_paths = [path.unbind(dim=0) for path, _step in weight_changes]
                    # One (num_examples, width) leaf tensor per interpolation step.
                    batch_scaled_weights = [
                        torch.vstack([path[s] for path in unbound_paths]).requires_grad_(True)
                        for s in range(len(unbound_paths[0]))
                    ]
                    replace_output.extend(batch_scaled_weights)

                    grads_per_step = []
                    for _step in range(num_steps):
                        model_out = model(*model_args, **model_kwargs)
                        local_q_vector, local_ctx_vectors, _, _, _, _ = model_out
                        # NOTE(review): batch[3] is not sliced per chunk;
                        # vector_agreements keeps its first len(chunk_questions)
                        # entries, which assumes every chunk lays out its positive
                        # indices identically — confirm for batches with > k questions.
                        loss, _score_agreement = vector_agreements(
                            local_q_vector, local_ctx_vectors, batch[3], len(chunk_questions)
                        )
                        # Take the gradient w.r.t. the tensor the hook injected and
                        # only dequeue it afterwards, so a hook re-run during backward
                        # (e.g. activation recomputation) still sees this step's tensor.
                        injected = replace_output[0]
                        grads_per_step.append(torch.autograd.grad(loss, injected)[0].unbind(dim=0))
                        replace_output.pop(0)
                        model.zero_grad(set_to_none=True)

                    # Regroup step-major grads into one (num_steps, width) tensor per
                    # example, then sum over steps and scale by that example's step.
                    per_example_grads = [torch.stack(grads) for grads in zip(*grads_per_step)]
                    per_example_layer_sensitivity.extend(
                        (grads.sum(dim=0) * step).cpu()
                        for grads, step in zip(per_example_grads, batch_weight_steps)
                    )
                    saved_layer_output = []
                    question.extend(chunk_questions)
                pbar.update(1)
        hook.remove()
        out_path = (
            "/path/to/inputs/layer_sensitivities/"
            f"{'baseline' if baseline else 'dpr_full_train_edited_removed'}/"
            f"{'intermediate' if intermediate else 'output'}_knowledge_"
            f"{'question_model' if qmodel else 'context_model'}_{layer_idx}"
            f"{'_train_data' if training_data else ''}_rank_{global_rank}.pkl"
        )
        with open(out_path, "wb") as out_file:
            pkl.dump([per_example_layer_sensitivity, question], out_file)


def vector_agreements(q_vector, ctx_vector, positive_idx_per_question, size=80):
    """Score questions against contexts; return a summed NLL loss and top-1 hits.

    Args:
        q_vector: question representations; when it has more than one
            dimension, scores are reshaped to ``(q_vector.size(0), -1)``.
        ctx_vector: context representations passed to
            ``BiEncoderNllLoss.get_scores``.
        positive_idx_per_question: sequence of positive context indices; only
            the first ``size`` entries are used.
        size: number of leading positives to use.

    Returns:
        ``(loss, score_agreement)``: the summed negative log-likelihood of the
        positives under a row-wise log-softmax, and a boolean tensor marking
        rows whose argmax equals the positive index.
    """
    raw_scores = BiEncoderNllLoss.get_scores(q_vector, ctx_vector)
    if q_vector.dim() > 1:
        raw_scores = raw_scores.view(q_vector.size(0), -1)

    log_probs = F.log_softmax(raw_scores, dim=1)
    targets = torch.tensor(positive_idx_per_question[:size]).to(log_probs.device)
    loss = F.nll_loss(log_probs, targets, reduction="sum")

    _best, predicted = torch.max(log_probs, 1)
    return loss, predicted == targets

@hydra.main(config_path="/path/to/inputs/DPR/conf", config_name="biencoder_train_cfg")
def main(cfg: DictConfig):
    """Hydra entry point: prepare the config and spawn one worker per local GPU.

    Each worker runs ``validate_main`` with the machine offset
    ``local_gpu_count * MACHINE_INDEX``; ``MACHINE_INDEX``, ``WORLD_SIZE``,
    ``MASTER_ADDRESS`` and ``PORT`` are module globals set by the script guard.
    """
    cfg = setup_cfg_gpu(cfg)
    set_seed(cfg)

    model_files = [
        # "/path/to/inputs/DPR/outputs/2023-08-29/18-44-07/checkpoints/dpr_biencoder.29",
    ]
    # Only the last listed checkpoint takes effect: a single spawn follows.
    if model_files:
        cfg.model_file = model_files[-1]

    local_gpu_count = torch.cuda.device_count()
    # NOTE(review): workers shard batches by WORLD_SIZE; if the total number of
    # spawned workers differs from WORLD_SIZE, some shards are never processed.
    worker_args = (local_gpu_count * MACHINE_INDEX, WORLD_SIZE, MASTER_ADDRESS, PORT, cfg)
    mp.spawn(validate_main, args=worker_args, nprocs=local_gpu_count, join=True)
    # Single-process debugging alternative:
    # validate_main(0, MACHINE_INDEX, WORLD_SIZE, MASTER_ADDRESS, PORT, cfg)



if __name__ == "__main__":
    # torch.distributed.launch passes options as "--key=value"; Hydra expects
    # bare "key=value" overrides, so strip one leading "--" from each argument.
    sys.argv = [arg[len("--"):] if arg.startswith("--") else arg for arg in sys.argv]

    # Cluster layout, read as module globals by main().
    NUM_MACHINES = 1
    MACHINE_INDEX = 0
    WORLD_SIZE = 8
    # MASTER_ADDRESS = "130.207.232.57"
    MASTER_ADDRESS = "localhost"
    PORT = 8106  # + random.randint(-50, 50)

    main()
