import argparse
import os

import numpy as np
import torch

from parser.cmds import Train, Predict
from parser.utils.config import Config

if __name__ == '__main__':
    # Command-line entry point: build one sub-parser per command, parse the
    # arguments, configure threads / seeds / visible GPU, merge the CLI
    # arguments into the config file, then dispatch to the chosen command.
    parser = argparse.ArgumentParser(
        description='Create the Unsupervised POS model.')
    # subparsers; the chosen command name is stored in `args.mode`
    subparsers = parser.add_subparsers(title='Commands', dest='mode')

    # subcommands, keyed by the name used on the command line
    subcommands = {'train': Train(), 'predict': Predict()}

    # Every command gets its own sub-parser (created by the command itself)
    # plus the shared options below.
    for mode, subcommand in subcommands.items():
        subparser = subcommand.create_subparser(subparsers, mode)
        subparser.add_argument('--config',
                               default='config.ini',
                               help='path to config file')
        subparser.add_argument('--save',
                               default='save/master',
                               help='path to saved files')
        subparser.add_argument('--device',
                               default='0',
                               help='ID of GPU to use')
        subparser.add_argument('--seed',
                               default=1,
                               type=int,
                               help='seed for generating random numbers')
        subparser.add_argument('--batch-size',
                               default=5000,
                               type=int,
                               help='batch size')
        subparser.add_argument('--n-bottleneck',
                               default=5,
                               type=int,
                               help='max dim of bottle neck')
        subparser.add_argument('--n-buckets',
                               default=32,
                               type=int,
                               help='max num of buckets to use')
        # NOTE: the help text here used to be a copy of the --n-buckets one.
        subparser.add_argument('--min-freq',
                               default=0,
                               type=int,
                               help='min frequency threshold')
        subparser.add_argument('--threads',
                               default=8,
                               type=int,
                               help='max num of threads')
        subparser.add_argument('--use_cached', '-u', action="store_true")

    # parse args
    args = parser.parse_args()

    # Sub-commands are optional by default in argparse; without one, none of
    # the shared options exist on `args`, so fail with a usage error instead
    # of an AttributeError further down.
    if args.mode is None:
        parser.error(
            f"a command is required (choose from {', '.join(subcommands)})")

    print(f"Set the max num of threads to {args.threads}")
    torch.set_num_threads(args.threads)

    print(f"Set the seed for generating random numbers to {args.seed}")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Must be exported before the CUDA availability query below so that the
    # query only sees the requested GPU.
    print(f"Set the device with ID {args.device} visible")
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device

    # Derived paths inside the save directory.
    args.fields = os.path.join(args.save, 'fields')
    args.crf_model = os.path.join(args.save, 'crf')
    # From here on `args.device` is a torch device string, not a GPU ID.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Merge the CLI arguments into the values loaded from the config file.
    args = Config(args.config).update(vars(args))

    print(f"Run the command in mode {args.mode}")
    cmd = subcommands[args.mode]
    cmd(args)
