import argparse
import json
from collections import defaultdict
from typing import Dict, Tuple

import torch

from experiment_running.utils import retrieve_models, initialize_seeds
from utils import read_file

MASK_TOKEN = "[MASK] "


def get_masked_sentence(sample_sentence: str, cont: str, num_masks: int) -> Tuple[str, str]:
    """Append ``num_masks`` mask tokens and a closing period to a sentence prefix.

    Args:
        sample_sentence: sentence prefix that the continuation would follow.
        cont: continuation text; returned unchanged next to the masked sentence.
        num_masks: number of ``MASK_TOKEN`` copies to append.

    Returns:
        ``(masked_sentence, cont)``.
    """
    masked = sample_sentence + MASK_TOKEN * num_masks + '.'
    # Tidy the seams left by the concatenation: each non-overlapping pair of
    # spaces becomes one space, each '..' becomes '.', and the space that
    # MASK_TOKEN leaves before the final period is removed.
    masked = masked.replace('  ', ' ').replace('..', '.').replace(' .', '.')
    return masked, cont


def retrieve_probabilities(model, tokenizer, sentence, cont):
    """Score ``cont`` at the mask positions of ``sentence`` relative to the model's top guesses.

    The masked sentence is run through the model once. Mask position ``i`` is
    paired with token ``i`` of ``cont``. Only the first
    ``min(#masks, #cont tokens)`` pairs contribute to the numerator. The
    denominator is the product of the per-position maximum probability over
    all mask positions. The value returned is
    ``prod_i p_i(cont_i) / prod_j max_v p_j(v)``, computed in log-space.

    Args:
        model: masked LM whose call returns an object with ``.logits``
            (presumably a Hugging Face model -- confirm against retrieve_models).
        tokenizer: tokenizer matching ``model``. ``encode`` is assumed to add
            exactly one leading and one trailing special token.
        sentence: text containing the mask tokens to fill.
        cont: continuation whose tokens are scored against the masks.

    Returns:
        The ratio as a Python float. It is 1.0 when no mask token is found in
        ``sentence``. It is also 1.0 when the mask and token counts match and
        every continuation token is the top-1 prediction.
    """
    # Put inputs on the model's own device when it exposes one. Otherwise fall
    # back to the original CUDA-if-available choice.
    device = getattr(model, 'device', None)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Use the tokenizer's mask id instead of a hard-coded one. 103 is the id the
    # original code hard-coded (BERT uncased [MASK]) and is kept as a fallback.
    mask_id = getattr(tokenizer, 'mask_token_id', None)
    if mask_id is None:
        mask_id = 103

    tok_masked = tokenizer(sentence, return_tensors='pt').to(device)
    logits = model(**tok_masked).logits
    # One row of vocabulary log-probabilities per mask position. Log-space
    # avoids the float32 underflow that multiplying several small
    # probabilities can produce.
    mask_log_probs = logits[tok_masked['input_ids'] == mask_id].log_softmax(dim=-1)

    cont_tok = tokenizer.encode(cont)[1:-1]  # drop the leading/trailing special tokens
    # Pair mask i with continuation token i. Surplus masks or tokens are
    # ignored on the numerator side, which matches taking the diagonal of the
    # (masks x tokens) matrix.
    n_pairs = min(mask_log_probs.shape[0], len(cont_tok))
    rows = torch.arange(n_pairs, device=mask_log_probs.device)
    cols = torch.tensor(cont_tok[:n_pairs], dtype=torch.long, device=mask_log_probs.device)
    cont_log_prob = mask_log_probs[rows, cols].sum()
    best_log_prob = mask_log_probs.max(dim=-1).values.sum()
    # Exponentiate in float64 so a very small ratio is not flushed to zero.
    return (cont_log_prob - best_log_prob).double().exp().item()


def main(input_path: str, output_path: str, hf_cache: str, models: str, seed_num: int = 42) -> None:
    """Score each sample's continuations with every requested model and save them as JSON.

    Samples whose ``answer`` is ``'true'`` are skipped. For every other sample
    the incorrect continuation is scored. The correct continuation is also
    scored when ``"_1"`` occurs in the sample's ``sent_id``. Results are nested
    as ``{sid: {model_name: {continuation: score}}}``, where ``sid`` is the part
    of ``sent_id`` before the first underscore. The output file is rewritten
    after every model, so results for finished models are kept on disk.

    Args:
        input_path: file handed to ``read_file`` to load the samples.
        output_path: JSON destination (overwritten after each model).
        hf_cache: cache location forwarded to ``retrieve_models``.
        models: comma-separated model names.
        seed_num: seed passed to ``initialize_seeds``.
    """
    initialize_seeds(seed_num)
    model_names = models.split(',')
    samples = read_file(input_path)

    scores = defaultdict(lambda: defaultdict(dict))

    def _score(model, tokenizer, prefix, cont, mask_source):
        # One mask per word piece of ``mask_source``. The -2 drops the special
        # tokens ``encode`` adds (presumably [CLS]/[SEP] -- verify per tokenizer).
        n_masks = len(tokenizer.encode(mask_source)) - 2
        masked, key = get_masked_sentence(prefix, cont, n_masks)
        return key, retrieve_probabilities(model, tokenizer, masked, key)

    with torch.no_grad():
        for model_name in model_names:
            print(model_name)
            model, tokenizer = retrieve_models(model_name, hf_cache, pretraining=False)

            for sample in samples:
                if sample['answer'] == 'true':
                    continue

                sent_id = sample['sent_id']
                sid = sent_id.split('_')[0]
                prefix = sample['sentence_cont']

                # NOTE(review): this substring test also matches ids such as
                # "<sid>_10" or "<sid>_12"; confirm the sent_id format.
                if "_1" in sent_id:
                    correct = sample['correct_cont']
                    key, value = _score(model, tokenizer, prefix, correct, correct)
                    scores[sid][model_name][key] = value

                incorrect = sample['incorrect_cont']
                # NOTE(review): masks are counted with 'the ' removed, but the
                # scored continuation (and JSON key) keeps it. retrieve_probabilities
                # pairs mask i with token i of the full continuation, so a leading
                # "the" gets scored at the first mask -- confirm this is intended.
                key, value = _score(model, tokenizer, prefix, incorrect, incorrect.replace('the ', ''))
                scores[sid][model_name][key] = value

            # Checkpoint: rewrite the full results after every model.
            with open(output_path, 'w') as handle:
                json.dump(scores, handle)


if __name__ == "__main__":

    parser = argparse.ArgumentParser("LLM llm_pretest")
    # -i/-o/-m are mandatory: without them main() fails on None only later.
    # A missing output path, for example, surfaces at open() after the first
    # model's samples have been scored. Fail at parse time instead.
    parser.add_argument('-i', '--input_path', type=str, required=True,
                        help="Path to where our llm_data is kept")
    parser.add_argument('-o', '--output_path', type=str, required=True,
                        help="Path to where we want to keep the results")
    parser.add_argument('-m', '--models', type=str, required=True,
                        help="Comma separated list of the models we want to use")
    parser.add_argument('-c', '--hf_cache', type=str,
                        help="Hugging Face cache directory passed to retrieve_models when loading the models")
    parser.add_argument('-s', '--seed_num', type=int, help="Seed to use in our experiments", default=42)
    args = parser.parse_args()
    main(**vars(args))