import argparse
import collections
import pickle
import random

import numpy as np
import torch

from data_loader.dataset import Dataset
import data_loader.data_sampler as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
from trainer import *
from data_preprocess import *
from utils import write_json

# Fix every random seed the training pipeline may draw from, for reproducibility.
SEED = 18


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.cuda.manual_seed only seeds the *current* GPU; manual_seed_all
    # covers every visible device (several indices can be passed via --device).
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(SEED)


def _matches_any(name, patterns):
    """Return True if any substring in ``patterns`` occurs in ``name``."""
    return any(pattern in name for pattern in patterns)


def _step1_param_groups(model):
    """Build optimizer parameter groups for step 1 (metric-learning pre-training).

    Parameters whose name contains 'bert' are fine-tuned at lr 1e-5; all
    other parameters train at lr 1e-3. Per-group 'lr' values take precedence
    over the lr given in the optimizer config.
    """
    bert_params, other_params = [], []
    for name, param in model.named_parameters():
        if 'bert' in name:
            bert_params.append(param)
        else:
            other_params.append(param)
    return [
        {'params': bert_params, 'lr': 1e-5},
        {'params': other_params, 'lr': 1e-3},
    ]


def _load_step1_protos(model, dataset):
    """Replace ``model.class_protos`` entries with prototypes loaded from disk.

    Reads ``./data/<data_name>/protos.pkl`` and wraps each loaded entry
    ``protos[i]`` in a new ``torch.nn.Parameter`` at index ``i``.
    """
    # NOTE(review): the original code also passed
    # config['arch_step1']['args']['encoder_type'] to .format(), but the
    # template has a single placeholder, so it never affected the path.
    # Confirm whether a per-encoder prototype file was intended.
    path = './data/{}/protos.pkl'.format(dataset.cropus.data_name)
    # pickle can execute arbitrary code on load: only point this at trusted files.
    with open(path, 'rb') as f:
        protos = pickle.load(f)
    with torch.no_grad():
        for i, _ in enumerate(model.class_protos):
            model.class_protos[i] = torch.nn.Parameter(protos[i])


def _step2_param_groups(model, config, logger):
    """Build optimizer parameter groups for step 2 (meta-learning).

    Parameters whose name contains 'bert' are frozen (requires_grad=False).
    The remaining parameters are routed by substring match against the name
    lists config['arch_step2']['fast' / 'slow' / 'normal'], tested in that
    order. Parameters matching none of them go to the last group.

    Group index -> lr:
        0 -> 1e-5  always empty because BERT is frozen. It is kept so the
                   number and order of groups is unchanged (optimizer state
                   saved in existing checkpoints may depend on it).
        1 -> 1e-2  'fast' parameters
        2 -> 1e-4  'slow' parameters
        3 -> 1e-3  'normal' parameters
        4 -> 1e-4  everything else
    """
    arch_cfg = config['arch_step2']
    fast, slow, normal = arch_cfg['fast'], arch_cfg['slow'], arch_cfg['normal']

    groups = [{'params': [], 'lr': lr} for lr in (1e-5, 1e-2, 1e-4, 1e-3, 1e-4)]
    for name, param in model.named_parameters():
        logger.debug('step 2 parameter: %s', name)
        if 'bert' in name:
            param.requires_grad = False
        elif _matches_any(name, fast):
            groups[1]['params'].append(param)
        elif _matches_any(name, slow):
            groups[2]['params'].append(param)
        elif _matches_any(name, normal):
            groups[3]['params'].append(param)
        else:
            groups[4]['params'].append(param)
    return groups


def main(config):
    """Run the training step selected by ``config['step']``.

    Step 1 pre-trains the 'arch_step1' model with StepOneTrainer. Step 2 runs
    the meta-learning procedure on the 'arch_step2' model with StepTwoTrainer.
    In step 2, class prototypes can optionally be initialised from
    ./data/<data_name>/protos.pkl.

    Raises:
        ValueError: if ``config['step']`` is neither 1 nor 2. Previously such a
            value silently did nothing after loading the data.
    """
    logger = config.get_logger('train')

    step = config['step']
    if step not in (1, 2):
        raise ValueError('config "step" must be 1 or 2, got {!r}'.format(step))

    # Data
    dataset = Dataset(config['dataset']['load_path'])
    # Sampler used by step 2 (meta-learning). It is built for both steps, as
    # before, so step 1's setup (including any RNG it consumes) is unchanged.
    train_y = [sample['y'] for sample in dataset.train_seen]
    data_sampler = config.init_obj('data_sampler', module_data, labels=train_y)

    # Config: merge the corpus config and save the result to the checkpoint dir.
    config.deep_update_config(config.config, dataset.cropus.config())
    if not config.debug:
        config_path = config.save_dir / 'config_{}.json'.format(config['dataset']['name'])
        write_json(config.config, config_path)

    loss_fn = getattr(module_loss, config['loss'])

    if step == 1:
        # Step 1: metric-learning pre-training used for initialization.
        model = config.init_obj('arch_step1', module_arch, config=config, dataset=dataset)
        trainable_params = _step1_param_groups(model)
        trainer_cls, extra_kwargs = StepOneTrainer, {}
    else:
        # Step 2: meta-learning procedure.
        model = config.init_obj('arch_step2', module_arch, config=config, dataset=dataset)
        if config['arch_step2']['ablation']['init']:
            _load_step1_protos(model, dataset)
        trainable_params = _step2_param_groups(model, config, logger)
        trainer_cls, extra_kwargs = StepTwoTrainer, {'data_sampler': data_sampler}

    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = trainer_cls(model=model,
                          optimizer=optimizer,
                          config=config,
                          dataset=dataset,
                          loss_fn=loss_fn,
                          lr_scheduler=lr_scheduler,
                          **extra_kwargs)
    trainer.train()


def str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

    argparse's ``type=bool`` is a trap: it calls ``bool(str)``, so any
    non-empty string (including 'False' and '0') becomes True.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    args = argparse.ArgumentParser(description='GZSL Research')
    # Other available configs: config_SNIPS.json, config_SMP.json,
    # config_ATIS.json, config_Quora.json, config_Samsung.json.
    args.add_argument('-c', '--config', default='config_Clinc.json', type=str,
                      help='config file path (default: config_Clinc.json)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default='2', type=str,
                      help='indices of GPUs to enable (default: 2)')
    args.add_argument('-dg', '--debug', default=True, type=str2bool,
                      help='debug mode: true/false (default: True)')

    # Custom CLI options that override values from the JSON config file.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(['-st', '--step'], type=int, target='step'),
    ]

    config = ConfigParser.from_args(args, options)
    main(config)

