# PYTHONPATH='.' python playground/fisher/compute_fisher_information.py 

import torch

from modules.bart import MyBart

# Differences from https://github.com/tuvuumass/task-transferability/
# - there is no [CLS] token for classification, and there are no task-specific heads.
# - this version computes Fisher information for (1) the parameters of the model; (2) the outputs after each layer

def _layer_output_scores(hidden_states, mask, num_layers, prefix):
    # Score the first `num_layers` hidden states under `mask`; keys are "<prefix>.layer.<i>.layer_output".
    token_count = mask.sum().item()  # number of non-<pad> positions, used as the normaliser
    mask = mask.float()

    scores = {}
    for i in range(num_layers):
        # Zero out <pad> positions, then reduce over the first two dimensions.
        masked = torch.einsum("ijk,ij->ijk", hidden_states[i], mask)
        summed = masked.sum(0).sum(0)
        scores["{}.layer.{}.layer_output".format(prefix, i)] = summed ** 2 / token_count
    return scores


def compute_fisher(model, batch):
    """Compute Fisher-information-style scores of ``model`` for one batch.

    Runs a single forward/backward pass and collects two groups of scores:

    1. For every parameter of ``model.model`` that requires grad and received
       a gradient: the element-wise squared gradient, keyed by parameter name.
    2. For each of the first ``model.config.encoder_layers`` encoder hidden
       states and ``model.config.decoder_layers`` decoder hidden states: the
       hidden state is masked, summed over the first two dimensions, squared,
       and divided by the number of non-<pad> tokens of the *matching* side
       (encoder mask for encoder states, decoder mask for decoder states).

    Args:
        model: a BartForConditionalGeneration-style model. It must be callable
            with ``input_ids``, ``attention_mask``, ``decoder_input_ids``,
            ``decoder_attention_mask``, ``output_hidden_states`` and
            ``is_training`` and return
            ``(loss, encoder_hidden_states, decoder_hidden_states)``.
        batch: sequence whose first four items are ``input_ids``,
            ``attention_mask``, ``decoder_input_ids`` and
            ``decoder_attention_mask``.

    Returns:
        dict mapping parameter names, ``"encoder.layer.<i>.layer_output"`` and
        ``"decoder.layer.<i>.layer_output"`` to score tensors.

    Side effects:
        Clears and then repopulates the gradients of ``model``.
    """
    # Step 1: forward and backprop on the model using the batch.
    # Follow the model's own device instead of assuming CUDA whenever it is
    # available, so a CPU model on a GPU machine does not hit a device mismatch.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch = [b.to(first_param.device) for b in batch]

    input_ids, attention_mask, decoder_input_ids, decoder_attention_mask = batch[:4]

    loss, encoder_hidden_states, decoder_hidden_states = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        output_hidden_states=True,
        is_training=True,
    )

    model.zero_grad()
    loss.backward()

    # Step 2.1: Fisher w.r.t. model parameters (squared gradients).
    outputs = {}
    for name, parameter in model.model.named_parameters():
        if not parameter.requires_grad:
            continue
        grad = parameter.grad
        # Parameters that did not take part in the loss have no gradient.
        if grad is not None and name not in outputs:
            outputs[name] = grad ** 2

    # Step 2.2: scores w.r.t. layer outputs.
    # NOTE(review): these scores are built from the hidden-state values
    # themselves, not from their gradients -- confirm this is intended.
    # NOTE(review): assumes hidden_states[i] is the output of layer i; if index 0
    # is the embedding output this is off by one -- verify against MyBart.
    outputs.update(_layer_output_scores(
        encoder_hidden_states, attention_mask, model.config.encoder_layers, "encoder"))
    # Decoder states must be masked and normalised with the decoder attention
    # mask; the encoder mask has a different sequence length in general.
    outputs.update(_layer_output_scores(
        decoder_hidden_states, decoder_attention_mask, model.config.decoder_layers, "decoder"))

    return outputs

def main():
    """Smoke-run ``compute_fisher`` on a tiny hand-written batch.

    Loads ``facebook/bart-base`` through ``MyBart``, switches it to training
    mode, moves it to the GPU when one is available, and feeds it a 2x3 dummy
    batch. The resulting scores are not used further.
    """
    bart = MyBart.from_pretrained("facebook/bart-base")
    bart.train()
    if torch.cuda.is_available():
        bart.cuda()

    # Dummy encoder side: token ids plus attention mask (last token of row 2 masked out).
    encoder_side = (
        torch.tensor([[1, 2, 3], [2, 3, 4]]),
        torch.tensor([[1, 1, 1], [1, 1, 0]]),
    )
    # Dummy decoder side: token ids plus attention mask (last token of row 1 masked out).
    decoder_side = (
        torch.tensor([[4, 5, 6], [7, 8, 9]]),
        torch.tensor([[1, 1, 0], [1, 1, 1]]),
    )

    # Order expected by compute_fisher:
    # input_ids, attention_mask, decoder_input_ids, decoder_attention_mask
    dummy_batch = encoder_side + decoder_side

    outputs = compute_fisher(bart, dummy_batch)

# Run the dummy-batch smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()