"""
Contains evaluation utilities for pytorch-based rewriting methods.
To use, simply call `compute_rewrite_quality_one_hop` with the
appropriate arguments; it returns a dictionary containing the evaluation results.
"""

import typing
from itertools import chain

import nltk
import numpy as np
import scipy
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoModelForCausalLM, AutoTokenizer

from dsets import AttributeSnippets
from util.generate import generate_fast
from util.perplexity import perplexity

def slice_list(matrix, start_indices, left):
    """
    Slices a sequence, or each row of a list of lists, from a start index.

    With ``left`` truthy the slice is shifted one position to the left: it
    starts at ``start - 1`` and drops the final element. Otherwise it runs from
    ``start`` to the end.

    :param matrix: Either a flat sequence or a list whose first element is a
        list (then treated as a list of rows)
    :param start_indices: One start index per row; for a flat sequence only
        the first entry is used
    :param left: Selects the shifted slice ``[start - 1:-1]`` instead of
        ``[start:]``

    :return: The sliced sequence, or a list of sliced rows
    """

    def _cut(seq, start):
        # Both variants have the same length whenever 1 <= start <= len(seq).
        return seq[start - 1:-1] if left else seq[start:]

    # Only the first element is inspected to decide between the two layouts.
    if not isinstance(matrix[0], list):
        return _cut(matrix, start_indices[0])
    return [_cut(row, start) for row, start in zip(matrix, start_indices)]

def _resolve_input_device(model):
    """Returns the device that tokenized inputs should be moved to for `model`."""
    device = getattr(model, "device", None)
    if device is None:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            # Not a module with parameters: keep the historical default.
            return "cuda:0"
    if getattr(device, "type", None) == "meta":
        # A "meta" device cannot hold real inputs; keep the historical default.
        return "cuda:0"
    return device


def test_prediction_acc(model, tok, prompts, targets, locality=False):
    """
    Measures how well the model's greedy (argmax) next-token predictions
    reproduce each target when it is appended to its prompt.

    Each prompt is joined to its target with a single space, the batch is run
    through the model once, and the predicted token ids over the target span
    are compared element-wise with the actual token ids of that span.

    :param model: Callable taking the tokenizer output as keyword arguments and
        returning either a logits tensor or an object with a ``.logits``
        attribute. Inputs are moved to the model's own device when it can be
        determined, otherwise to ``"cuda:0"``.
    :param tok: Tokenizer providing ``encode``, ``__call__`` and ``pad_token_id``
    :param prompts: A single prompt string or a list of prompt strings
    :param targets: The target string(s) matching ``prompts``
    :param locality: If true, return the predicted token ids for the target
        span instead of accuracies

    :return: With ``locality`` false, a list holding one mean token accuracy
        per example (for batches of more than one example, NaN entries are
        dropped). With ``locality`` true, a list of predicted token-id lists.
    """
    if isinstance(prompts, str):
        prompts, targets = [prompts], [targets]
    prompt_target = [prompt + ' ' + target
                     for prompt, target in zip(prompts, targets)]

    # Both tokenizer calls share one length cap: at least 40 tokens, and
    # always enough room for the longest prompt+target plus one.
    max_prompt_len = max(len(tok.encode(text)) for text in prompt_target) + 1
    max_length = max(40, max_prompt_len)

    prompt_target_tok = tok(
        prompt_target,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(_resolve_input_device(model))
    prompt_tok = tok(
        prompts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    # Start of the target span inside each padded prompt+target row: the row's
    # padding count plus the prompt's non-padding token count.
    # NOTE(review): this only locates the target if padding is on the left and
    # pad_token_id never occurs inside real text -- confirm for the tokenizer.
    num_prompt_toks = [int((ids != tok.pad_token_id).sum())
                       for ids in prompt_tok['input_ids']]
    num_pad_toks = [int((ids == tok.pad_token_id).sum())
                    for ids in prompt_target_tok['input_ids'].cpu()]
    prompt_len = [pad + real for pad, real in zip(num_pad_toks, num_prompt_toks)]

    with torch.no_grad():
        outputs = model(**prompt_target_tok)
    logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits

    # squeeze() drops the batch dimension for a single example, which is why
    # both the flat and the nested layouts are handled below.
    answers = torch.argmax(
        logits, dim=-1).squeeze().detach().cpu().numpy().tolist()
    labels = prompt_target_tok['input_ids'].squeeze(
    ).detach().cpu().numpy().tolist()

    # The logits at position i predict token i + 1, so predictions are taken
    # one step to the left of the labels they are compared with.
    answers = slice_list(answers, prompt_len, left=True)
    labels = slice_list(labels, prompt_len, left=False)

    if locality:
        return answers if isinstance(answers[0], list) else [answers]

    if not isinstance(answers[0], list):
        return [np.mean(np.equal(answers, labels))]

    res = []
    for ans, label in zip(answers, labels):
        temp_acc = np.mean(np.equal(ans, label))
        # An empty span yields NaN; such examples are left out of the result.
        if np.isnan(temp_acc):
            continue
        res.append(temp_acc)
    return res


def compute_rewrite_quality_one_hop(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    record: typing.Dict,
) -> typing.Dict:
    """
    Evaluates a rewritten model on the portability question of one record.

    A single prompt is built from ``record["portability"]["New Question"]``.
    The model generates one continuation for it via `generate_fast`, and its
    prediction of ``record["portability"]["New Answer"]`` is scored with
    `test_prediction_acc`. Progress is printed to stdout along the way.

    :param model: Rewritten model
    :param tok: Tokenizer
    :param record: Dataset record with a "portability" entry that holds the
        "New Question" and "New Answer" fields

    :return: Dictionary with the keys "portabiity_acc" (spelling kept as-is),
        "portability_prompt", "portability_target" and "text"
    """

    # Pull the question/answer pair out of the record.
    portability = record["portability"]
    question = portability["New Question"]
    prompt = ["{} The answer to this question, most simply, is".format(question)]
    target = portability["New Answer"]

    print("TESTING_PORTABILITY")
    print("PORTABILITY_PROMPTS:{}".format(prompt))

    # One free-form generation for the prompt, reported alongside the score.
    gen_texts = generate_fast(
        model,
        tok,
        prompt,
        n_gen_per_prompt=1,
        max_out_len=100,
    )
    first_generation = gen_texts[0]

    print(type(first_generation))
    print("GEN_TEXTS:{}".format(first_generation))

    accuracy = test_prediction_acc(model, tok, prompt, [target])

    # Assemble the results dictionary.
    return {
        "portabiity_acc": accuracy,
        "portability_prompt": prompt,
        "portability_target": target,
        "text": first_generation,
    }
