''' Attribute style evals '''

import os
import sys
import numpy as np
import torch
import argparse
import json
import math
from tqdm import tqdm

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    RobertaTokenizer, RobertaForSequenceClassification
)

from evaluate import load
from mutual_implication_score import MIS

# Lazily-initialised module-level caches. Each starts as None and is populated
# on first use so the models are only loaded once per process:
#   - MIS_MODEL is created inside get_raw_mis_score().
#   - COLA_TOKENIZER / COLA_MODEL are created inside get_cola_score().
MIS_MODEL = None 
COLA_TOKENIZER = None
COLA_MODEL = None
# Use the GPU when torch reports one is available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def batch_pairs(lists, batch_size=64):
    """Split several equal-length sequences into aligned batches.

    Args:
        lists: A sequence of sliceable sequences that must all have the same
            length (checked with an assertion).
        batch_size: Maximum number of elements per batch.

    Returns:
        A list of batches. Each batch is a list holding one slice per input
        sequence, in the same order as ``lists``; the last batch may be
        shorter than ``batch_size``.
    """
    sizes = [len(seq) for seq in lists]
    assert len(set(sizes)) == 1, sizes

    num_elements = len(lists[0])
    num_batches = math.ceil(num_elements / batch_size)
    batches = [
        [seq[k * batch_size:(k + 1) * batch_size] for seq in lists]
        for k in range(num_batches)
    ]

    # Sanity check: every element must have landed in exactly one batch.
    covered = sum([len(batch[0]) for batch in batches])
    assert covered == num_elements, (
        covered,
        num_elements,
    )
    return batches

def get_raw_mis_score(*, references, candidates, targets, aggregate=True):
    """Score reference/candidate pairs with the mutual implication score model.

    The MIS model is created on first use and cached in the module-level
    ``MIS_MODEL``.

    Args:
        references: Source texts.
        candidates: Transferred texts, aligned with ``references``.
        targets: Sequence aligned with ``references``. It is batched together
            with the other inputs (so its length must match) but is not
            passed to the scorer.
        aggregate: When True return the mean score, otherwise the list of
            per-pair scores.

    Returns:
        The mean of the scores if ``aggregate`` is True, else the list of
        values collected from ``MIS_MODEL.compute``.
    """
    global MIS_MODEL
    if MIS_MODEL is None:
        MIS_MODEL = MIS(device=DEVICE)

    pair_scores = []
    batches = batch_pairs([references, candidates, targets], batch_size=64)
    for ref_batch, cand_batch, _unused_targets in tqdm(batches):
        pair_scores.extend(MIS_MODEL.compute(ref_batch, cand_batch))

    return np.mean(pair_scores) if aggregate else pair_scores

def get_perplexity_score(candidates, aggregate=True):
    """Compute perplexity of the candidates with the ``perplexity`` metric.

    The metric is loaded via ``load("perplexity", module_type="metric")`` and
    evaluated with ``model_id='gpt2'``.

    Args:
        candidates: Texts to score.
        aggregate: When True return the metric's ``mean_perplexity`` entry,
            otherwise its ``perplexities`` entry.

    Returns:
        The selected entry of the metric result, or None if the metric raised
        an ``IndexError`` (the failure and the inputs are printed).
    """
    metric = load("perplexity", module_type="metric")

    try:
        scores = metric.compute(predictions=candidates, model_id='gpt2')
    except IndexError:
        print('Perplexity failed, returning None')
        print(candidates)
        return None

    result_key = 'mean_perplexity' if aggregate else 'perplexities'
    return scores[result_key]

def get_cola_score(candidates, aggregate=True):
    """Score candidates with the ``textattack/roberta-base-CoLA`` classifier.

    The tokenizer and model are loaded on first use, moved to ``DEVICE`` and
    cached in the module-level ``COLA_TOKENIZER`` / ``COLA_MODEL``.

    Args:
        candidates: Texts to score.
        aggregate: When True return the mean probability, otherwise the list
            of per-text probabilities.

    Returns:
        The softmax probability of class index 1 for each text (presumably the
        "acceptable" class -- confirm against the model card), averaged when
        ``aggregate`` is True.
    """
    global COLA_TOKENIZER
    global COLA_MODEL

    if COLA_MODEL is None:
        COLA_TOKENIZER = RobertaTokenizer.from_pretrained('textattack/roberta-base-CoLA')
        COLA_MODEL = RobertaForSequenceClassification.from_pretrained(
            'textattack/roberta-base-CoLA'
        )
        COLA_MODEL.to(DEVICE)
        # Make inference mode explicit (disables dropout).
        COLA_MODEL.eval()

    prbs = []
    # Pure inference: skip autograd bookkeeping so no activation graph is
    # kept alive for each batch.
    # NOTE(review): no truncation is requested from the tokenizer, so inputs
    # longer than the model's maximum sequence length are passed through as-is.
    with torch.no_grad():
        for (texts,) in batch_pairs([candidates], batch_size=64):
            inputs = COLA_TOKENIZER(texts, return_tensors="pt", padding=True)
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            logits = COLA_MODEL(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            prbs.extend(probs[:, 1].cpu().tolist())

    if aggregate:
        return np.mean(prbs)
    return prbs

def load_internal_formality_model(path='/mnt/swordfish-pool2//emnlp_formality/models/roberta-base/roberta-base_5e-05_128_Entertainment_Music_train_formal-Entertainment_Music_train_informal/42/2024-05-19-19_53_30/checkpoint-1200'):
    """Load the in-house formality classifier.

    Args:
        path: Checkpoint location passed to
            ``AutoModelForSequenceClassification.from_pretrained``.

    Returns:
        A ``(model, tokenizer, label_map)`` tuple. The tokenizer is the
        ``roberta-base`` tokenizer and ``label_map`` maps ``'informal'`` to 0
        and ``'formal'`` to 1.
    """
    tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    classifier = AutoModelForSequenceClassification.from_pretrained(path)
    return classifier, tokenizer, {'informal': 0, 'formal': 1}

def load_external_formality_model():
    """Load the external (hold-out) formality classifier.

    Returns:
        A ``(model, tokenizer, label_map)`` tuple, with both model and
        tokenizer loaded from ``SkolkovoInstitute/xlmr_formality_classifier``.
        Note that ``label_map`` is flipped relative to the internal model:
        ``'formal'`` is 0 and ``'informal'`` is 1.
    """
    model_id = 'SkolkovoInstitute/xlmr_formality_classifier'
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_id)
    flipped_label_map = {'formal': 0, 'informal': 1}
    return classifier, tokenizer, flipped_label_map

def get_attribute_acc(
    *, model, texts, target_idx, tokenizer, device='cuda', aggregate=True
):
    """Measure how often a classifier predicts the target label.

    Each text is tokenized and classified on its own (batch of one), and the
    argmax over the logits is compared with ``target_idx``.

    Args:
        model: Callable classifier; called as ``model(**inputs)`` and expected
            to return an object with a ``logits`` attribute.
        texts: Iterable of texts to classify.
        target_idx: Label index counted as a hit.
        tokenizer: Callable tokenizer; its result must support ``.to(device)``
            and ``**`` unpacking.
        device: Device the tokenized inputs are moved to. The model is assumed
            to already live on this device.
        aggregate: When True return the mean hit rate, otherwise the list of
            per-text scores.

    Returns:
        A list with 1.0 for each text predicted as ``target_idx`` and 0.0
        otherwise, or its mean when ``aggregate`` is True.
    """
    scores = []
    # Pure inference: disable autograd so no graph is built per forward pass.
    with torch.no_grad():
        for text in texts:
            inputs = tokenizer([text], return_tensors="pt", padding=True)
            inputs.to(device)
            logits = model(**inputs).logits
            preds = torch.argmax(logits, dim=-1)
            scores.append(float(preds[0].item() == target_idx))

    if aggregate:
        return np.mean(scores)
    return scores

def run_formality_eval(*, references, candidates, target, device='cuda'):
    """Run the full formality-transfer evaluation for one set of outputs.

    Loads the internal and the external (hold-out) formality classifiers,
    then computes perplexity, CoLA probability, classifier accuracy for both
    classifiers and MIS similarity for every candidate.

    Args:
        references: Source texts.
        candidates: Transferred texts, aligned with ``references``.
        target: Target style; must be a key of both classifiers' label maps
            (``'formal'`` or ``'informal'``).
        device: Device both classifiers are moved to.

    Returns:
        A dict with the keys ``perplexity`` (mean), ``cola``, ``accuracy``,
        ``holdout_accuracy``, ``similarity`` (all means over candidates),
        ``n`` (number of references), ``holdout_joint_gm`` (mean over
        candidates of the geometric mean of hold-out accuracy, CoLA and
        similarity) and ``perplexity_median``. If the perplexity metric
        failed, ``perplexity`` and ``perplexity_median`` are None.

    Raises:
        KeyError: If ``target`` is not in the classifiers' label maps.
    """
    ctr_model, tokenizer, label_map = load_internal_formality_model()
    optimizing_label_index = label_map[target]
    ctr_model.to(device)
    ctr_model.eval()

    holdout_model, holdout_tokenizer, holdout_label_map = load_external_formality_model()
    holdout_optimizing_label_index = holdout_label_map[target]
    holdout_model.to(device)
    holdout_model.eval()

    # Collect per-sample values first; they are reduced to summary statistics
    # at the end, once the joint score has been computed from them.
    retval = {}
    retval['perplexity'] = get_perplexity_score(candidates, aggregate=False)
    retval['cola'] = get_cola_score(candidates, aggregate=False)
    retval['accuracy'] = get_attribute_acc(
        model=ctr_model,
        texts=candidates,
        target_idx=optimizing_label_index,
        tokenizer=tokenizer,
        device=device,
        aggregate=False,
    )
    retval['holdout_accuracy'] = get_attribute_acc(
        model=holdout_model,
        texts=candidates,
        target_idx=holdout_optimizing_label_index,
        tokenizer=holdout_tokenizer,
        device=device,
        aggregate=False,
    )
    retval['similarity'] = get_raw_mis_score(
        references=references,
        candidates=candidates,
        # The MIS scorer does not use targets; pass aligned placeholders.
        targets=[None for _ in range(len(references))],
        aggregate=False,
    )
    retval['n'] = len(references)
    retval['holdout_joint_gm'] = np.mean(
        [
            (a * c * m) ** (1 / 3)
            for a, c, m in zip(
                retval['holdout_accuracy'], retval['cola'], retval['similarity']
            )
        ]
    )
    retval['cola'] = np.mean(retval['cola'])
    retval['accuracy'] = np.mean(retval['accuracy'])
    retval['similarity'] = np.mean(retval['similarity'])
    retval['holdout_accuracy'] = np.mean(retval['holdout_accuracy'])

    # get_perplexity_score returns None when the metric fails. Propagate that
    # instead of crashing in np.median/np.mean after all other metrics have
    # already been computed.
    perplexities = retval['perplexity']
    if perplexities is None:
        retval['perplexity_median'] = None
    else:
        retval['perplexity_median'] = np.median(perplexities)
        retval['perplexity'] = np.mean(perplexities)

    return retval

if __name__ == '__main__':
    # Script entry point: read a JSON-lines file of model outputs, evaluate
    # them against the requested target style and write the metrics to
    # "<input_path>.<target>_eval". Does nothing if that file already exists.
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--input_path', type=str)
    argparser.add_argument('--target', type=str)
    argparser.add_argument('--is_chatgpt', action='store_true')

    args = argparser.parse_args()
    out_name = args.input_path + f'.{args.target}_eval'

    if os.path.exists(out_name):
        print(f'{out_name} already exists, skipping')
        sys.exit(0)

    with open(args.input_path, 'r') as f:
        # Skip blank lines (e.g. a trailing newline) rather than failing in
        # json.loads on an empty string.
        input_data = [json.loads(line) for line in f if line.strip()]

    if not input_data:
        raise ValueError(f'No records found in {args.input_path}')

    # Records without 'source_text' fall back to 'original_text'. As before,
    # only the first record is inspected to decide this for the whole file.
    if input_data[0].get('source_text') is None:
        for record in input_data:
            record['source_text'] = record['original_text']
    references = [record['source_text'] for record in input_data]

    if args.is_chatgpt:
        # ChatGPT outputs may be arbitrarily nested dicts/lists; drill down to
        # the first leaf value, preferring the 'text' key of multi-key dicts.
        for record in input_data:
            result = record['output']
            while isinstance(result, (dict, list)):
                if isinstance(result, dict):
                    keys = list(result.keys())
                    if len(keys) > 1:
                        assert 'text' in keys
                        result = result['text']
                    else:
                        result = result[keys[0]]
                else:
                    result = result[0]
            record['output'] = [result]

    candidates = []
    for record in input_data:
        candidate = record['output']
        # Unwrap (possibly nested) lists element by element. Deciding from the
        # first record only would index into plain strings and keep just their
        # first character. An empty list is treated as a missing output.
        while isinstance(candidate, list):
            candidate = candidate[0] if candidate else None
        candidates.append(candidate)

    # Missing or blank outputs fall back to the reference text. The None check
    # must come first, otherwise .strip() raises on None.
    for i, candidate in enumerate(candidates):
        if candidate is None or candidate.strip() == '':
            candidates[i] = references[i]

    # decoded eval
    if args.target in ['formal', 'informal']:
        eval_results = run_formality_eval(
            references=references, candidates=candidates, target=args.target
        )
    else:
        raise ValueError(f'Unknown target {args.target}')
    print('Decoded eval results:', eval_results)

    with open(out_name, 'w') as f:
        json.dump(
            {
                'input_path': args.input_path,
                'decoded': eval_results,
            },
            f,
        )
