import os
import pickle
import torch


def save_model(opt, model, tokenizer, save_path):
    """Save a trained model together with its configuration and tokenizer.

    The on-disk layout depends on ``opt.save_mode``:

    * ``1`` -- pickle ``opt`` to ``<model_name>.config`` and ``tokenizer`` to
      ``<model_name>.tokenizer`` inside ``save_path``, and save the
      (unwrapped) model's ``state_dict()`` to ``<model_name>.state_dict``.
    * ``2`` -- write the same config/tokenizer files, plus the whole ``model``
      object (moved to CPU first) via ``torch.save`` to ``<model_name>.model``.
    * ``3`` -- write ``pytorch_model.bin`` (state dict) and ``config.json`` for
      the inner pretrained encoder (the ``bert4global`` or ``bert`` attribute,
      falling back to the model itself) into
      ``<save_path>/fine-tuned-pretrained-model``, and call
      ``save_pretrained`` on the tokenizer there.

    Args:
        opt: Configuration object. Reads ``save_mode``, ``device`` and, for
            modes 1 and 2, ``model_name``. Pickled as-is in modes 1 and 2.
        model: Model to save. It may be wrapped (e.g. by a data-parallel
            wrapper) so that the real network is ``model.module``.
        tokenizer: Tokenizer to save. It is pickled in modes 1 and 2. In
            mode 3, ``tokenizer.tokenizer`` (if present) or ``tokenizer``
            itself must provide ``save_pretrained``.
        save_path: Output directory. It is created if missing.

    Raises:
        ValueError: If ``opt.save_mode`` is not 1, 2 or 3.

    On success the model is moved to ``opt.device`` before returning.
    """
    # Unwrap data-parallel style wrappers, which expose the network as
    # ``.module``. Only unwrap when ``.module`` really exists: a model that
    # merely has a ``core`` attribute used to fall into this branch and crash
    # on ``model.module``.
    # NOTE(review): confirm whether ``core`` was meant to be unwrapped too.
    model_to_save = model.module if hasattr(model, 'module') else model

    if opt.save_mode in (1, 2):
        os.makedirs(save_path, exist_ok=True)
        # Join rather than concatenate, so the files end up inside save_path
        # even when it lacks a trailing separator.
        prefix = os.path.join(save_path, opt.model_name)
        # Context managers close the files even if pickling raises.
        with open(prefix + '.config', mode='wb') as f_config:
            pickle.dump(opt, f_config)
        with open(prefix + '.tokenizer', mode='wb') as f_tokenizer:
            pickle.dump(tokenizer, f_tokenizer)
        if opt.save_mode == 1:
            # Save only the weights of the unwrapped network.
            torch.save(model_to_save.state_dict(), prefix + '.state_dict')
        else:
            # Save the whole model object (not just its weights), on CPU.
            torch.save(model.cpu(), prefix + '.model')

    elif opt.save_mode == 3:
        # Save the fine-tuned pretrained encoder in a save_pretrained-style directory.
        model_output_dir = os.path.join(save_path, 'fine-tuned-pretrained-model')
        os.makedirs(model_output_dir, exist_ok=True)
        output_model_file = os.path.join(model_output_dir, 'pytorch_model.bin')
        output_config_file = os.path.join(model_output_dir, 'config.json')

        # Prefer the inner encoder; otherwise save the model itself. The chosen
        # object must expose ``config.to_json_file``.
        if hasattr(model_to_save, 'bert4global'):
            model_to_save = model_to_save.bert4global
        elif hasattr(model_to_save, 'bert'):
            model_to_save = model_to_save.bert

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        # Some tokenizer wrappers hold the real tokenizer in ``.tokenizer``.
        inner_tokenizer = getattr(tokenizer, 'tokenizer', tokenizer)
        inner_tokenizer.save_pretrained(model_output_dir)

    else:
        raise ValueError('Invalid save_mode: {}'.format(opt.save_mode))

    # Move the model back to the configured device (mode 2 moved it to CPU).
    model.to(opt.device)
