"""Test with a specified evaluation metrics."""

import torch
from experiment.utils.tokenizer import prepare_tokenizer
from experiment.utils.data import get_testset, special_tokens_map
from experiment.utils.model import ModelTestDict
from experiment.utils.metrics import bleu_scorer, parent


# %% beam search

def tokenize_sample_test(sample, tokenizer, args, verbose=False):
    """Build single-example model inputs from a sample's source spans (test time).

    The sequence starts with the experiment's CLS id. Every span in
    sample['source'] is tokenized, converted to ids and followed by the SEP id.
    Position ids restart at 0 for each span, the separator taking the position
    right after the span's last token. Ids and positions are cut to
    args.input_maxlen and the attention mask is all ones.

    args:
        sample: mapping with a 'source' iterable of text spans
            ('table_id' and 'sub_sent_id' are read only when verbose)
        tokenizer: object exposing tokenize() and convert_tokens_to_ids()
        args: provides experiment_name, input_maxlen and device
        verbose: print which sample is being tokenized
    rets:
        dict with 'input_ids', 'attention_mask' and 'position_ids', each a
        LongTensor with a leading batch dimension of 1, moved to args.device
    """

    if verbose:
        print(f"[utils >> tknz_sample] has table {sample['table_id']} & subsent [{sample['sub_sent_id']}]")

    special_ids = special_tokens_map[args.experiment_name]
    cls_id, sep_id = special_ids['cls'], special_ids['sep']

    token_ids = [cls_id]
    positions = [0]
    for span in sample['source']:
        span_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(span))
        token_ids += [*span_ids, sep_id]
        # one position per span token plus one for the trailing separator
        positions += range(len(span_ids) + 1)

    max_len = args.input_maxlen
    token_ids = token_ids[:max_len]
    positions = positions[:max_len]
    mask = [1] * len(token_ids)

    device = args.device
    return {
        'input_ids': torch.LongTensor([token_ids]).to(device),
        'attention_mask': torch.LongTensor([mask]).to(device),
        'position_ids': torch.LongTensor([positions]).to(device),
    }


def clear_tokens(token_list, tokenizer):
    """Return the tokens of `token_list` that are not special tokens.

    Every token listed in tokenizer.all_special_tokens is dropped (padding
    included, but not only padding); the order of the remaining tokens is kept.

    args:
        token_list: iterable of token strings
        tokenizer: object exposing an `all_special_tokens` collection
    rets:
        List[str] of the non-special tokens
    """
    # Read the special-token collection once instead of once per token, and use
    # a set so each membership test is O(1) rather than a linear scan.
    special_tokens = set(tokenizer.all_special_tokens)
    return [token for token in token_list if token not in special_tokens]


def beam_generate(sample, tokenizer, model, args, verbose=False):
    """Generate candidate outputs for one sample with beam search decoding.

    args:
        sample: {'table_id', 'sub_sent_id', 'source', 'target'}
        tokenizer: used to tokenize the source and to map generated ids back
            to tokens / text
        model: object exposing generate()
        args: provides experiment_name, decode_maxlen, num_beams and
            num_return_sequences (plus whatever tokenize_sample_test reads)
        verbose: print the size of the generated id tensor
    rets:
        dict mapping each returned-sequence index to a dict with
            'ids': the generated ids of that sequence,
            'tokens': the tokens converted from those ids,
            'tokens_clear': the tokens with special tokens removed,
            'sentence': the string built from 'tokens_clear'
    """

    # generate vocab ids
    sample_features = tokenize_sample_test(sample, tokenizer, args)
    generate_kwargs = {
        'input_ids': sample_features['input_ids'],
        'attention_mask': sample_features['attention_mask'],
        'max_length': args.decode_maxlen,
        'num_beams': args.num_beams,
        'num_return_sequences': args.num_return_sequences,
    }
    # only the 'b2b' experiment forwards explicit position ids to generate()
    if args.experiment_name == 'b2b':
        generate_kwargs['position_ids'] = sample_features['position_ids']
    batch_gen_ids = model.generate(**generate_kwargs)
    if verbose:
        print(f'[beam_gen] has GEN-IDS with size {batch_gen_ids.size()}')

    gen_features = {}
    # a distinct loop name keeps the batch tensor from being shadowed
    for iret, seq_ids in enumerate(batch_gen_ids):
        gen_tokens = tokenizer.convert_ids_to_tokens(seq_ids)
        gen_tokens_clear = clear_tokens(gen_tokens, tokenizer)
        gen_sentence = tokenizer.convert_tokens_to_string(gen_tokens_clear)

        gen_features[iret] = {
            'ids': seq_ids,
            'tokens': gen_tokens,
            'tokens_clear': gen_tokens_clear,
            'sentence': gen_sentence,
        }

    return gen_features



# %% select optimal set

def select_prediction_set_by_bleu(prediction_dicts, references, return_index=False):
    """Pick, for every sample, the candidate with the highest BLEU against its references.

    Each candidate is scored on its own with bleu_scorer. A candidate replaces
    the current choice only when its score is strictly higher, so ties keep the
    earlier candidate and index 0 is kept when no score exceeds 0.0.

    args:
        prediction_dicts: per sample, a dict {index: {'tokens_clear': ...}}
        references: per sample, the reference list handed to bleu_scorer
        return_index: also return the chosen index of every sample
    rets:
        list of the chosen 'tokens_clear' sequences, or the pair
        (that list, list of chosen indices) when return_index is truthy
    """
    chosen_tokens, chosen_indices = [], []

    for candidates, sample_refs in zip(prediction_dicts, references):
        best_idx, best_bleu = 0, 0.0

        for cand_idx, cand in candidates.items():
            bleu = bleu_scorer.compute(
                predictions=[cand['tokens_clear']],
                references=[sample_refs],
            )['bleu']
            if bleu > best_bleu:
                best_idx, best_bleu = cand_idx, bleu

        chosen_tokens.append(candidates[best_idx]['tokens_clear'])
        chosen_indices.append(best_idx)

    return (chosen_tokens, chosen_indices) if return_index else chosen_tokens


def select_prediction_set_by_parent(prediction_dicts, references, tables, return_index=False):
    """Pick, for every sample, the candidate whose PARENT f1 is highest.

    Candidates are scored one at a time with parent() against the sample's
    references and table; the third returned value is used as the score. Only
    a strictly higher score replaces the current choice, so ties keep the
    earlier candidate and index 0 is kept when no score exceeds 0.0.

    args:
        prediction_dicts: per sample, a dict {index: {'tokens_clear': ...}}
        references: per sample, the reference list handed to parent()
        tables: per sample, the table handed to parent()
        return_index: also return the chosen index of every sample
    rets:
        list of the chosen 'tokens_clear' sequences, or the pair
        (that list, list of chosen indices) when return_index is truthy
    """
    selected = []
    selected_ids = []

    for candidates, sample_refs, sample_table in zip(prediction_dicts, references, tables):
        top_idx = 0
        top_f1 = 0.0

        for cand_idx, cand in candidates.items():
            _, _, cand_f1, _ = parent(
                predictions=[cand['tokens_clear']],
                references=[sample_refs],
                tables=[sample_table],
                return_dict=False,
            )
            if cand_f1 > top_f1:
                top_idx, top_f1 = cand_idx, cand_f1

        selected.append(candidates[top_idx]['tokens_clear'])
        selected_ids.append(top_idx)

    if return_index:
        return selected, selected_ids
    return selected



# %% sort / rank multiple predictions

def rank_prediction_set_by_bleu(prediction_dicts, references):  # return_scores=True
    """Score every candidate of every sample with BLEU and sort them by score.

    args:
        prediction_dicts: per sample, a dict
            {index: {'tokens_clear': ..., 'sentence': ...}}
        references: per sample, the reference list handed to bleu_scorer
    rets:
        per sample, a list of (index, sentence, bleu) tuples in ascending
        BLEU order (the highest-scoring candidate comes last)
    """
    from experiment.utils.metrics import bleu_scorer

    ranked = []
    for candidates, sample_refs in zip(prediction_dicts, references):
        scored = []
        for cand_idx, cand in candidates.items():
            outcome = bleu_scorer.compute(
                predictions=[cand['tokens_clear']],
                references=[sample_refs],
            )
            scored.append((cand_idx, cand['sentence'], outcome['bleu']))

        # stable in-place sort on the score, lowest first
        scored.sort(key=lambda entry: entry[2])
        ranked.append(scored)

    return ranked


# %% evaluation

def eval_with_bleu(args, testset, tokenizer, model):
    """Evaluate the model on the testset with BLEU and print the scores.

    Prints one corpus BLEU per returned-sequence index (all samples' candidate
    at that index), then the BLEU of the set obtained by picking, per sample,
    the candidate that scores best against its reference.

    args:
        args: provides num_return_sequences (plus whatever beam_generate reads)
        testset: iterable of samples with 'source' and 'target'
        tokenizer: used for generation and to tokenize the targets
        model: object exposing generate()
    rets:
        None; results are only printed
    """

    raw_predictions = [
        beam_generate(sample, tokenizer, model, args)
        for sample in testset
    ]

    # one single-reference list per sample
    references = [
        [tokenizer.tokenize(sample['target'])]
        for sample in testset
    ]

    pred_tokens_dict = {
        idx: [sample_preds[idx]['tokens_clear'] for sample_preds in raw_predictions]
        for idx in range(args.num_return_sequences)
    }

    for idx, predictions in pred_tokens_dict.items():
        idx_results = bleu_scorer.compute(
            predictions=predictions,
            references=references,
        )
        print(f"Idx#{idx} - BLEU: {idx_results['bleu']: .3f}")

    # The selector takes (prediction_dicts, references, return_index). Passing
    # the scorer as a third positional argument would set return_index to a
    # truthy value and yield a (predictions, indices) tuple instead of the
    # prediction list, so only the two data arguments are passed.
    best_predictions = select_prediction_set_by_bleu(raw_predictions, references)
    best_results = bleu_scorer.compute(
        predictions=best_predictions,
        references=references,
    )
    print(f"BEST BLEU: {best_results['bleu']: .3f}")

    return



def eval_with_parent(args, testset, tokenizer, model):
    """Evaluate the model on the testset with the PARENT metric and print the scores.

    Prints the first three values returned by parent() for every
    returned-sequence index, then for the set obtained by picking, per sample,
    the candidate with the best PARENT f1.

    args:
        args: provides num_return_sequences (plus whatever beam_generate reads)
        testset: iterable of samples with 'source', 'target' and
            'table_parent' (an iterable of (attr, value) pairs)
        tokenizer: used for generation and to tokenize targets / table values
        model: object exposing generate()
    rets:
        None; results are only printed
    """

    raw_predictions = [beam_generate(sample, tokenizer, model, args) for sample in testset]
    references = [[tokenizer.tokenize(sample['target'])] for sample in testset]

    # each table entry becomes ([attr], tokenized value); the attribute is
    # wrapped in a list but not tokenized
    tokenized_tables = [
        [([attr], tokenizer.tokenize(value)) for attr, value in sample['table_parent']]
        for sample in testset
    ]

    pred_tokens_dict = {
        idx: [sample_preds[idx]['tokens_clear'] for sample_preds in raw_predictions]
        for idx in range(args.num_return_sequences)
    }

    for idx, idx_predictions in pred_tokens_dict.items():
        idx_p, idx_r, idx_f1, _ = parent(
            predictions=idx_predictions,
            references=references,
            tables=tokenized_tables,
            return_dict=False,
        )
        print(f"Idx#{idx} - PARENT: {idx_p:.3f}, {idx_r:.3f}, {idx_f1:.3f}")

    best_predictions = select_prediction_set_by_parent(
        raw_predictions, references, tokenized_tables)
    avg_p, avg_r, avg_f, _ = parent(
        predictions=best_predictions,
        references=references,
        tables=tokenized_tables,
        return_dict=False,
    )
    print(f"BEST PARENT: {avg_p: .3f}, {avg_r:.3f}, {avg_f:.3f}")

    # operations = [sample['operations'] for sample in testset]
    # eval_parent_by_operation(operations, predictions, references, tokenized_tables)
    return



# Dispatch table: metric name -> evaluation function taking
# (args, testset, tokenizer, model). run_test looks it up with args.metrics[0].
EvalDict = {
    'bleu': eval_with_bleu, 
    'parent': eval_with_parent, 
}



# %% test with logging

def decode_with_bleu(args, testset, tokenizer, model):
    """Decode the testset and write BLEU-ranked candidates to args.test_decode_path.

    For every sample the output file gets a '#<index>' header, one line per
    candidate ('[<candidate index>: <bleu>] <sentence>', in the order returned
    by rank_prediction_set_by_bleu), then the tokenized reference and a blank
    line.

    args:
        args: provides test_decode_path (plus whatever beam_generate reads)
        testset: iterable of samples with 'source' and 'target'
        tokenizer: used for generation and to tokenize the targets
        model: object exposing generate()
    rets:
        None; results are written to the file and a summary is printed
    """

    raw_predictions = [
        beam_generate(sample, tokenizer, model, args)
        for sample in testset
    ]

    # one single-reference list per sample
    references = [
        [tokenizer.tokenize(sample['target'])]
        for sample in testset
    ]

    ranked_predictions = rank_prediction_set_by_bleu(
        raw_predictions, references)

    # Explicit UTF-8: tokenizer output may contain non-ASCII characters, which
    # the platform default encoding is not guaranteed to be able to write.
    with open(args.test_decode_path, 'w', encoding='utf-8') as fw:
        for idx, (pred_list, ref) in enumerate(zip(ranked_predictions, references)):
            fw.write(f"#{idx}\n")
            for ii, psent, pscore in pred_list:
                fw.write(f'[{ii}: {pscore:.4f}] {psent}\n')
            # ref[0] is the token list of the single reference
            fw.write(f'{ref[0]}\n\n')
    print(f'Wrote {len(ranked_predictions)} prediction & reference instances into target file: [{args.test_decode_path}]')

    return


def decode_with_parent(args, testset, tokenizer, model):
    """Decode the testset, print PARENT scores and write the best candidates out.

    Prints the first three values returned by parent() for every
    returned-sequence index and for the per-sample best set, then writes each
    best prediction with its own PARENT score and its tokenized reference to
    args.test_decode_path.

    args:
        args: provides num_return_sequences and test_decode_path (plus
            whatever beam_generate reads)
        testset: iterable of samples with 'source', 'target' and
            'table_parent' (an iterable of (attr, value) pairs)
        tokenizer: used for generation and to tokenize targets / table values
        model: object exposing generate()
    rets:
        None; results are printed and written to the file
    """

    raw_predictions = [
        beam_generate(sample, tokenizer, model, args)
        for sample in testset
    ]
    # one single-reference list per sample
    references = [
        [tokenizer.tokenize(sample['target'])]
        for sample in testset
    ]
    # each table entry becomes ([attr], tokenized value)
    tokenized_tables = [
        [([attr], tokenizer.tokenize(value)) for attr, value in sample['table_parent']]
        for sample in testset
    ]

    pred_tokens_dict = {
        idx: [sample_preds[idx]['tokens_clear'] for sample_preds in raw_predictions]
        for idx in range(args.num_return_sequences)
    }

    for idx, predictions in pred_tokens_dict.items():
        (idx_p, idx_r, idx_f1, _) = parent(
            predictions=predictions,
            references=references,
            tables=tokenized_tables,
            return_dict=False,
        )
        print(f"Idx#{idx} - PARENT: {idx_p:.3f}, {idx_r:.3f}, {idx_f1:.3f}")

    best_predictions = select_prediction_set_by_parent(
        raw_predictions, references, tokenized_tables)
    (avg_p, avg_r, avg_f, _) = parent(
        predictions=best_predictions,
        references=references,
        tables=tokenized_tables,
        return_dict=False,
    )
    print(f"BEST PARENT: {avg_p: .3f}, {avg_r:.3f}, {avg_f:.3f}")

    # Explicit UTF-8: tokenizer output may contain non-ASCII characters, which
    # the platform default encoding is not guaranteed to be able to write.
    with open(args.test_decode_path, 'w', encoding='utf-8') as fw:
        for idx, (pred, ref, tab) in enumerate(zip(best_predictions, references, tokenized_tables)):
            # keyword is `references`, matching every other parent() call here
            sample_parent = parent(
                predictions=[pred],
                references=[ref],
                tables=[tab],
                return_dict=True,
            )
            # NOTE(review): assumes the dict form of parent() has an
            # 'average_f1' key - confirm against experiment.utils.metrics.
            # The value is a PARENT score, so it is labelled as such.
            fw.write(f"#{idx} PARENT-F1: [{sample_parent['average_f1']:.4f}]\n")
            fw.write(f'{pred}\n{ref[0]}\n\n')
    # count the pairs actually written rather than a leftover loop variable
    print(f'Wrote {len(best_predictions)} prediction & reference pairs into target file: [{args.test_decode_path}]')

    return



# Dispatch table: metric name -> decode-and-write function taking
# (args, testset, tokenizer, model). run_test looks it up with args.metrics[0].
DecodeDict = {
    'bleu': decode_with_bleu, 
    'parent': decode_with_parent, 
}



# %% test main

def run_test(args):
    """Load the testset, tokenizer and model, then evaluate and/or decode.

    The model is built by the ModelTestDict entry for args.experiment_name.
    When args.do_test is set, the EvalDict function for the first entry of
    args.metrics is run; when args.do_decode is set, args.test_decode_path is
    set to <run_dir>/<test_decode_name> and the matching DecodeDict function
    is run.

    args:
        args: provides test_outpath, tokenizer_name, experiment_name, run_dir,
            model_path, model_name, device, do_test, do_decode, metrics and
            (for decoding) test_decode_name
    """
    testset = get_testset(data_files=args.test_outpath)
    tokenizer = prepare_tokenizer(name=args.tokenizer_name)

    build_model = ModelTestDict[args.experiment_name]
    model = build_model(
        run_dir=args.run_dir,
        path=args.model_path,
        name=args.model_name,
        device=args.device,
    )

    call_args = (args, testset, tokenizer, model)
    if args.do_test:
        evaluate = EvalDict[args.metrics[0]]
        evaluate(*call_args)
    if args.do_decode:
        args.test_decode_path = os.path.join(args.run_dir, args.test_decode_name)
        decode = DecodeDict[args.metrics[0]]
        decode(*call_args)


import os
from experiment.pointer_generator.decode import BeamSearch

def find_best_pgn_model_index(run_dir, main_metric_key='bleu-4'):
    """Find the entry with the highest value of `main_metric_key`.

    Lists the entries of <run_dir>/train/models and, for each entry name, reads
    <run_dir>/<entry>/metrics, a text file with one '<key>\\t<value>' pair per
    line (blank lines are skipped).

    args:
        run_dir: root directory of the run
        main_metric_key: metric used for the comparison
    rets:
        position, in os.listdir() order, of the entry with the strictly highest
        metric value; -1 when there are no entries or no value exceeds 0.0
    raises:
        KeyError: when a metrics file lacks `main_metric_key`
        ValueError: when a non-blank line is not a '<key>\\t<float>' pair
    """
    detailed_run_dir = os.path.join(run_dir, 'train', 'models')
    decode_dirs = os.listdir(detailed_run_dir)

    decode_metrics = []
    for dd in decode_dirs:
        # NOTE(review): entry names come from detailed_run_dir but the metrics
        # file is looked up under run_dir - confirm this layout is intended.
        mfile = os.path.join(run_dir, dd, 'metrics')
        ckpt_metrics = {}
        with open(mfile, 'r') as fr:
            for line in fr:
                line = line.strip()
                if not line:
                    # tolerate blank / trailing empty lines
                    continue
                mkey, mval = line.split('\t')
                ckpt_metrics[mkey] = float(mval)
        decode_metrics.append(ckpt_metrics)

    best_ckpt_idx = -1
    best_ckpt_mval = 0.0
    # enumerate() supplies the position; iterating the list directly would try
    # to unpack each metrics dict into (idx, mdict)
    for idx, mdict in enumerate(decode_metrics):
        mval = mdict[main_metric_key]
        if mval > best_ckpt_mval:
            best_ckpt_mval = mval
            best_ckpt_idx = idx
    return best_ckpt_idx
    
    

def run_test_pgn(args):
    """Run the final test of the pointer-generator model through BeamSearch.

    Builds the checkpoint path <run_dir>/train/models/model_<idx>.bin and hands
    it, together with args.decode_data_path, to BeamSearch. When the path
    cannot be built, None is passed as the checkpoint path instead.

    args:
        args: provides run_dir, decode_data_path and logging_steps (plus
            whatever BeamSearch reads)
    """
    try:
        # best_ckpt_idx = find_best_pgn_model_index(args.run_dir)
        # NOTE(review): the checkpoint index is hard-coded while the automatic
        # lookup above is disabled - confirm 99 is the intended checkpoint.
        best_ckpt_idx = 99
        best_ckpt_path = os.path.join(args.run_dir, 'train', 'models', f'model_{best_ckpt_idx}.bin')
    except Exception as err:
        # Best-effort fallback, but no longer swallowing KeyboardInterrupt /
        # SystemExit, and no longer silent about why the path is missing.
        print(f'[run_test_pgn] could not resolve a checkpoint path ({err!r}); using None')
        best_ckpt_path = None
    print(f'<<< Perform the Final Test ... (use model [{best_ckpt_path}]) >>>')
    tester = BeamSearch(args, best_ckpt_path, args.decode_data_path)
    tester.run(args.logging_steps)
    # tester.eval_parent(args.logging_steps)
    print('<<< Finished the Final Test ! >>>')



# %% collection

# Dispatch table: experiment name -> test entry point taking `args`.
# 't5', 'bart' and 'b2b' share run_test; 'pg' uses run_test_pgn.
TestFunctionDict = {
    't5': run_test, 
    'bart': run_test, 
    'b2b': run_test, 
    'pg': run_test_pgn, 
}
