import os
import yaml
import json
import argparse
import importlib
from tqdm import tqdm
from collections import defaultdict

from utils import file_utils
from utils.eva import evaluate
from utils.data.MQuAKEDatasetHandler import MQuAKEDatasetHandler
from utils.data.CounterfactDatasetHandler import CounterfactDatasetHandler


def parse_args():
    """Parse the command-line options for a single editing experiment.

    Options:
        --editor_config / -e: config name (required); build_args reads
            ``./configs/<name>.yaml`` from it.
        --dataset / -d: dataset name (required).
        --fact_type: one of 'Struct', 'Unstruct', 'Unstruct-triplets';
            argparse leaves it as None when omitted.
        --test_index: integer run index, default 1.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (flags, add_argument keyword options), registered in this order.
    option_specs = [
        (('--editor_config', '-e'), dict(required=True)),
        (('--dataset', '-d'), dict(type=str, required=True)),
        (('--fact_type',), dict(type=str, choices=['Struct', 'Unstruct', 'Unstruct-triplets'])),
        (('--test_index',), dict(type=int, default=1)),
    ]

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()


def build_args():
    """Parse the CLI and merge in the editor's YAML configuration.

    The config file is ``./configs/<editor_config>.yaml``. It is applied via
    ``file_utils.update_args`` and then ``device`` is set to ``'cuda'``.

    Returns:
        The updated argument namespace.
    """
    args = parse_args()
    config_path = f'./configs/{args.editor_config}.yaml'
    file_utils.update_args(args, path=config_path)

    # Assigned after the config merge, so if update_args sets a `device`
    # attribute it is overridden here.
    args.device = 'cuda'
    return args


def build_dataset_handler(dataset_name, after_editing):
    """Create the dataset handler that matches ``dataset_name``.

    Names starting with 'MQuAKE' get an MQuAKEDatasetHandler. Every other name
    falls back to a CounterfactDatasetHandler. Both read from ``./data``.
    """
    data_dir = './data'
    handler_cls = MQuAKEDatasetHandler if dataset_name.startswith('MQuAKE') else CounterfactDatasetHandler
    return handler_cls(data_dir, dataset_name, after_editing)


def build_editor(args):
    """Instantiate the configured editor.

    Imports module ``editors.<args.editor_module>``, looks up the attribute of
    the same name in it, and calls that with ``args``.
    """
    editor_name = args.editor_module
    editor_module = importlib.import_module(f'editors.{editor_name}')
    editor_cls = getattr(editor_module, editor_name)
    return editor_cls(args)


def evaluate_bymode(args, editor, dataset_handler, dataset_chunk, eva_mode):
    """Evaluate the (already edited) editor on one chunk for a single evaluation mode.

    The evaluation data for ``eva_mode`` is fetched from the handler and fed to
    ``editor.batch_test`` in slices of ``args.eva_mode.<eva_mode>.batch_size_test``.
    Running accuracy is printed after each slice. The predictions are then merged
    back into the chunk via ``dataset_handler.merge_answers_pred``.

    Returns:
        A tuple ``(merged_chunk, results)``. ``merged_chunk`` is whatever
        ``merge_answers_pred`` returns. ``results`` is
        ``{'total_num_correct': int, 'total': number of evaluated answers}``.
    """
    eva_data = dataset_handler.get_eva_data(dataset_chunk, eva_mode, fact_type=args.fact_type)

    eva_params = getattr(args.eva_mode, eva_mode)
    step = eva_params.batch_size_test
    n_items = len(eva_data['answers'])

    n_correct = 0
    n_seen = 0
    predictions = []

    # NOTE: the f-strings below use `{eva_mode=}`, which prints the variable
    # name, so the parameter must keep the name `eva_mode`.
    for start in tqdm(range(0, n_items, step), desc=f'chunk eva {eva_mode=}', leave=False):
        stop = start + step
        # Slice every field of the evaluation data the same way.
        batch = {key: eva_data[key][start:stop] for key in eva_data}

        batch_preds = editor.batch_test(batch, eva_mode, eva_params)
        predictions.extend(batch_preds)

        n_correct += evaluate.compute_num_correct(batch['answers'], batch_preds, eva_mode)
        n_seen += len(batch['answers'])

        print(f"chunk acc {eva_mode=}: {(n_correct / n_seen):.5f}  {n_correct} / {n_seen}")

    merged_chunk = dataset_handler.merge_answers_pred(dataset_chunk, predictions, eva_mode=eva_mode)

    return merged_chunk, {'total_num_correct': n_correct, 'total': n_items}


def evaluate_editor(args, editor, dataset_handler, dataset_chunk, eva_mode_list):
    """Evaluate one chunk under every mode in ``eva_mode_list``, in order.

    The chunk returned by each ``evaluate_bymode`` call (after
    ``merge_answers_pred``) is passed into the next mode's call. That merged
    chunk is NOT returned, so a caller that needs the merged predictions must
    rely on ``merge_answers_pred`` mutating its input in place.

    Returns:
        Dict mapping each eva_mode to its ``{'total_num_correct', 'total'}`` results.
    """
    per_mode = {}
    chunk = dataset_chunk
    for mode in eva_mode_list:
        chunk, per_mode[mode] = evaluate_bymode(args, editor, dataset_handler, chunk, mode)
    return per_mode


def execute_editing(args, editor, dataset_handler, dataset_chunk):
    """Apply the edits described by ``dataset_chunk`` to the editor.

    The handler converts the chunk into edit data for ``args.fact_type``. That
    data is passed to ``editor.edit`` along with the same fact type.
    """
    fact_type = args.fact_type
    editor.edit(dataset_handler.get_edit(dataset_chunk, fact_type), fact_type)


def summarize_results(results):
    """Aggregate per-chunk evaluation counts into overall accuracy per mode.

    Args:
        results: mapping from eva_mode to a list of per-chunk dicts. Each dict
            has integer ``'total_num_correct'`` and ``'total'`` entries.

    Returns:
        ``defaultdict(dict)`` mapping each eva_mode to
        ``{'total_num_correct': int, 'total': int, 'acc': float}``. ``acc`` is
        rounded to 5 decimal places. A mode whose chunks sum to zero examples
        (e.g. an empty chunk list) reports ``acc == 0.0`` rather than raising
        ZeroDivisionError.
    """
    sum_results = defaultdict(dict)
    for eva_mode, chunks in results.items():
        num_correct = sum(chunk['total_num_correct'] for chunk in chunks)
        total = sum(chunk['total'] for chunk in chunks)
        sum_results[eva_mode]['total_num_correct'] = num_correct
        sum_results[eva_mode]['total'] = total
        # Guard the empty case: the unguarded division used to crash the run.
        sum_results[eva_mode]['acc'] = round(num_correct / total, 5) if total else 0.0

    return sum_results


def run(args):
    """Edit the model chunk by chunk and evaluate after every edit.

    Steps:
    1. Derive ``args.output_prefix``, create its directory and print the arguments.
    2. Pick the evaluation modes: ``['edit']`` for counterfact/WikiUpdate
       datasets, every key of ``args.eva_mode`` for MQuAKE datasets.
    3. Split the dataset into chunks of ``args.batch_size_edit`` ('full' means
       one chunk holding the whole dataset).
    4. For each chunk: apply its edits, evaluate every mode, and keep the chunk
       (with merged predictions) for saving.
    5. Write the chunks to ``<prefix>_rst.json`` and append the aggregated
       metrics to ``<prefix>_metrics.txt``.

    Raises:
        ValueError: if ``args.dataset`` has no known evaluation-mode mapping,
            or ``args.batch_size_edit`` is not positive for a non-empty dataset.
    """
    args.output_prefix = f'./output/{args.dataset}/{args.editor_config}_Fact{args.fact_type}_{args.test_index}th'

    file_utils.make_dir(os.path.dirname(args.output_prefix))

    print('\n=========== arguments ============\n' + yaml.dump(vars(args), default_flow_style=False))

    print(f"===>{args.output_prefix=}")
    print(f'fact_type: {args.fact_type}')

    if args.dataset.startswith(('counterfact', 'WikiUpdate')):
        eva_mode_list = ['edit']
    elif args.dataset.startswith('MQuAKE'):
        eva_mode_list = list(vars(args.eva_mode).keys())
        print(f"{eva_mode_list=}")
    else:
        # Without this branch eva_mode_list stays unbound and the run crashes
        # much later with an UnboundLocalError.
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}: expected a name starting with "
            f"'counterfact', 'WikiUpdate' or 'MQuAKE'"
        )

    dataset_handler = build_dataset_handler(args.dataset, args.after_editing)

    # chunk dataset
    dataset_save = []
    dataset = dataset_handler.get_dataset()

    if args.batch_size_edit == 'full':
        args.batch_size_edit = len(dataset)

    batch_size_edit = args.batch_size_edit
    if len(dataset) > 0 and batch_size_edit <= 0:
        raise ValueError(f"batch_size_edit must be a positive integer or 'full', got {batch_size_edit!r}")
    # A step of at least 1 keeps range() valid when 'full' resolved to 0 for an
    # empty dataset. In that case there are simply no chunks.
    dataset_chunks = [dataset[i:i + batch_size_edit] for i in range(0, len(dataset), max(batch_size_edit, 1))]
    print(f"===>dataset_size: {len(dataset)}")
    print(f"===>batch_size_edit: {batch_size_edit}")

    print('===>building editor...')
    editor = build_editor(args)

    results = defaultdict(list)
    for dataset_chunk in tqdm(dataset_chunks, desc='chunk iter'):
        execute_editing(args, editor, dataset_handler, dataset_chunk)

        # Evaluate each mode here instead of via evaluate_editor, which drops
        # the chunk returned by merge_answers_pred. Keeping the returned chunk
        # ensures the saved output carries the predictions even when
        # merge_answers_pred returns a new object rather than mutating in place.
        for eva_mode in eva_mode_list:
            dataset_chunk, results_mode = evaluate_bymode(args, editor, dataset_handler, dataset_chunk, eva_mode)
            results[eva_mode].append(results_mode)

        dataset_save.extend(dataset_chunk)

        sum_results = summarize_results(results)
        print(f"\nsum results: {json.dumps(dict(sum_results), indent=4)}")

    # save results
    sum_results = summarize_results(results)
    output_metric = f'---------- {args.dataset} {args.editor_config} {args.test_index}th {file_utils.get_timestamp()} ----------\n'
    output_metric += json.dumps(sum_results, indent=4) + '\n'

    file_utils.save_json(dataset_save, f'{args.output_prefix}_rst.json')
    file_utils.save_texts([output_metric], f'{args.output_prefix}_metrics.txt', mode='a')


def main():
    """Command-line entry point: build the merged arguments and run the experiment."""
    run(build_args())


if __name__ == '__main__':
    # Launch an experiment only when executed as a script, not on import.
    main()
