import os
import shutil

import json
import torch
import os.path as op
from utils.misc import mkdir
from utils.distributed_processing import is_main_process, get_world_size, synchronize, get_rank


def get_predict_file(model_dir, yaml_file, args):
    """Build the path of the JSON prediction file for a model, split and config.

    The file name has the form ``pred.vqa.<split>[.<option>...].json``, where
    ``<split>`` is the base name of ``yaml_file`` without its ``.yaml``
    extension and every ``<option>`` encodes one setting read from ``args``.
    The file lives in ``<args.output_dir>/<last component of model_dir>``.
    When that directory does not exist and ``is_main_process()`` is true, it
    is created together with any missing parent directories.

    :param model_dir: model directory; only its last path component is used.
    :param yaml_file: path of the split description, must end with ``.yaml``.
    :param args: object exposing ``output_dir``, ``add_od_labels``,
        ``output_hidden_states``, ``max_seq_length``, ``tag_entire_set``,
        ``tag_sep_token``, ``mask_c_t``, ``mask_t_i`` and ``add_prefix``.
    :return: path of the prediction file.
    :rtype: str
    :raises AssertionError: if ``yaml_file`` does not end with ``.yaml``.
    """
    data = 'vqa'
    split = op.basename(yaml_file)
    assert split.endswith('.yaml')
    split = split[:-len('.yaml')]

    cc = ['pred', data, split]

    # Truthiness flags: only the flag name is recorded (plus the value for
    # the sequence length).
    if args.add_od_labels:
        cc.append('odlabels')

    if args.output_hidden_states:
        cc.append('hidden')

    if args.max_seq_length:
        cc.append(f'seq_len{args.max_seq_length}')

    # The options below are recorded whenever they are set at all, so a
    # falsy-but-not-None value (e.g. False) still shows up in the file name.
    if args.tag_entire_set is not None:
        cc.append('tag_entire_set')

    if args.tag_sep_token is not None:
        cc.append(f'sep_token_{args.tag_sep_token}'.lower())

    if args.mask_c_t is not None:
        cc.append(f"mask_c_t_{args.mask_c_t}".lower())

    if args.mask_t_i is not None:
        cc.append(f"mask_t_i_{args.mask_t_i}".lower())

    if args.add_prefix is True:
        cc.append(f"prefix_{args.add_prefix}")

    # normpath strips a trailing separator, so the model name is found
    # whether or not model_dir ends with one.
    model_type = op.normpath(model_dir).split(os.sep)[-1]
    # Plain concatenation is kept so the returned string stays identical for
    # existing callers.
    output_dir = args.output_dir + "/" + model_type

    if not op.exists(output_dir) and is_main_process():
        # makedirs also creates a missing args.output_dir; exist_ok tolerates
        # the directory appearing between the check above and this call.
        os.makedirs(output_dir, exist_ok=True)

    return op.join(output_dir, '{}.json'.format('.'.join(cc)))

def get_evaluate_file(predict_file, num_tags):
    """Derive the evaluation-result file name from a prediction file name.

    The prediction file's extension is replaced by ``.eval.json``. When
    ``num_tags`` is given, ``tags_<num_tags>`` is appended directly to the
    stem (no separator is inserted) before the new extension.

    :param predict_file: prediction file path ending with ``.tsv`` or
        ``.json``.
    :param num_tags: value embedded in the file name, or ``None`` to omit it.
    :return: path of the evaluation file.
    :rtype: str
    :raises AssertionError: if ``predict_file`` has neither accepted suffix.
    """
    # '.json' is accepted next to '.tsv' because get_predict_file in this
    # module produces '.json' prediction files.
    assert predict_file.endswith(('.tsv', '.json'))
    stem = op.splitext(predict_file)[0]
    tag_part = '' if num_tags is None else f'tags_{num_tags}'
    return stem + tag_part + '.eval.json'

def save_best(logger, output_dir:str, best_ckpt_dir:str):
    """Copy the best checkpoint folder to ``<output_dir>/best``.

    An existing ``best`` folder is removed first (a warning is logged), so
    the destination always ends up as a copy of ``best_ckpt_dir`` only.

    :param logger: logger providing ``warning`` and ``info``.
    :param output_dir: directory that will contain the ``best`` folder.
    :param best_ckpt_dir: checkpoint folder to copy.
    """
    best_dir = output_dir + "/best"

    # copytree needs a non-existing destination, so clear any previous copy.
    if op.exists(best_dir):
        logger.warning("Overwriting the best folder: best folder already exist.")
        shutil.rmtree(best_dir)

    shutil.copytree(best_ckpt_dir, best_dir)
    logger.info("Best checkpoint saved.")

def save_checkpoint(logger, model, tokenizer, args, epoch, iteration, num_trial=10) -> str:
    """Save model, tokenizer and training arguments into a checkpoint folder.

    The folder is ``<args.output_dir>/checkpoint-<epoch>-<iteration>``. Only
    the process for which ``is_main_process()`` is true writes anything;
    every other process just returns the folder path. Saving is best-effort:
    it is retried up to ``num_trial`` times and a final failure is logged,
    not raised.

    :param logger: logger providing ``info`` and ``warning``.
    :param model: model exposing ``save_pretrained``; if it has a ``module``
        attribute, that inner object is saved instead.
    :param tokenizer: tokenizer exposing ``save_pretrained``.
    :param args: training arguments; ``output_dir`` is read and the whole
        object is stored as ``training_args.bin``.
    :param epoch: epoch number used in the folder name.
    :param iteration: iteration number used in the folder name.
    :param num_trial: maximum number of save attempts.
    :return: path of the checkpoint folder (returned even if saving failed).
    :rtype: str
    """
    checkpoint_dir = op.join(
        args.output_dir, 'checkpoint-{}-{}'.format(epoch, iteration))
    if not is_main_process():
        return checkpoint_dir
    mkdir(checkpoint_dir)
    # Presumably unwraps a DataParallel/DDP-style wrapper -- confirm.
    model_to_save = model.module if hasattr(model, 'module') else model
    for attempt in range(1, num_trial + 1):
        try:
            model_to_save.save_pretrained(checkpoint_dir)
            torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
            tokenizer.save_pretrained(checkpoint_dir)
        except Exception as err:
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate; the reason is logged before retrying.
            logger.warning("Attempt {}/{} to save checkpoint to {} failed: {!r}".format(
                attempt, num_trial, checkpoint_dir, err))
        else:
            logger.info("Save checkpoint to {}".format(checkpoint_dir))
            break
    else:
        # Reached only when no attempt succeeded (the loop never hit break).
        logger.warning("Failed to save checkpoint after {} trials.".format(num_trial))
    return checkpoint_dir

def save_predict(prediction, filename):
    """Write ``prediction`` to ``filename`` as JSON, replacing existing content.

    :param prediction: JSON-serializable object.
    :param filename: path of the file to write.
    """
    with open(filename, mode='w') as out_stream:
        json.dump(prediction, out_stream)

def concat_cache_files(cache_files, predict_file):
    """Merge several JSON cache files into one prediction file.

    Each cache file is loaded and its content appended, in the given order,
    to one list, which is written to ``predict_file`` via ``save_predict``.
    The cache files are deleted only after the merged file has been written,
    so a failure while loading or saving leaves them on disk.

    :param cache_files: paths of the JSON cache files to merge.
    :param predict_file: path of the merged output file.
    """
    results = []
    for cache_path in cache_files:
        # The context manager closes the handle before the file is removed.
        with open(cache_path) as fp:
            results += json.load(fp)

    save_predict(results, predict_file)

    for cache_path in cache_files:
        os.remove(cache_path)

def save_predict_ddp(results, predict_file):
    """Save predictions, merging per-rank outputs when several processes run.

    With a world size of 1 the results are written straight to
    ``predict_file``. Otherwise every rank writes its own shard named
    ``<stem>_<rank>_<world_size><ext>``, ``synchronize()`` is called, and the
    main process merges all shards into ``predict_file`` through
    ``concat_cache_files``.

    :param results: this process's predictions (JSON-serializable).
    :param predict_file: path of the final prediction file.
    """
    num_procs = get_world_size()

    if num_procs == 1:
        save_predict(results, predict_file)
        print("Inference file saved")
        return

    stem, ext = op.splitext(predict_file)

    def shard_path(rank):
        # One naming scheme for writing a shard and for collecting them all.
        return f'{stem}_{rank}_{num_procs}{ext}'

    save_predict(results, shard_path(get_rank()))
    # Presumably a barrier so every shard exists before merging -- confirm
    # against utils.distributed_processing.
    synchronize()

    if not is_main_process():
        return

    concat_cache_files([shard_path(rank) for rank in range(num_procs)], predict_file)
    print("Inference file saved")