import argparse
import copy
import logging
import os
import random
import socket
import IPython

import numpy as np
from numpy.matrixlib.defmatrix import matrix
import torch
from dotdict import DotDict


""" search hyperparam combinations from a predefined group of hyperparams """
def search_hyperparams(dictionary, current_hparams=None):
    """Yield every hyperparameter combination described by ``dictionary``.

    ``dictionary`` maps a hyperparameter name to a spec dict that holds a
    ``'values'`` list and, optionally, a ``'flag'``.  Hyperparameters that
    share the same non-``None`` flag are varied in lockstep (the i-th value
    of each is used together); distinct groups are combined exhaustively.

    While iterating, the entries of the group being expanded are removed
    from ``dictionary`` and re-inserted afterwards, so the mapping holds the
    same entries once the generator finishes or is closed (key order may
    differ because entries are re-inserted at the end).

    Args:
        dictionary: mutable mapping ``name -> {'values': [...], 'flag': ...}``.
        current_hparams: partial assignment carried through the recursion;
            it is updated in place.  Callers normally omit it.

    Yields:
        A deep copy of each complete ``name -> value`` assignment.

    Raises:
        AssertionError: if hyperparameters sharing a flag have value lists
            of different lengths.
    """
    # A fresh dict per top-level call: a shared mutable default would leak
    # keys from one search into the results of the next one.
    if current_hparams is None:
        current_hparams = {}
    if len(dictionary) == 0:
        yield copy.deepcopy(current_hparams)
        return
    current_dictionary = dict()
    current_key = list(dictionary.keys())[0]
    current_flag = dictionary[current_key].get('flag', None)
    if current_flag is not None:
        # Snapshot the keys so entries can be deleted while looping.
        for key in list(dictionary.keys()):
            if dictionary[key].get('flag', None) == current_flag:
                current_dictionary[key] = copy.deepcopy(dictionary[key])
                del dictionary[key]
    else:
        current_dictionary[current_key] = copy.deepcopy(
            dictionary[current_key]
        )
        del dictionary[current_key]
    try:
        num_values = len(current_dictionary[current_key]['values'])
        for key in current_dictionary:
            assert num_values == len(current_dictionary[key]['values']), \
                'hparams with the same flag must have the same #values\n' \
                'check {:s} and {:s}'.format(key, current_key)
        for i in range(num_values):
            for key in current_dictionary:
                current_hparams[key] = current_dictionary[key]['values'][i]
            # ``yield from`` forwards close() to the nested generator, so
            # inner levels restore their entries before this level does.
            yield from search_hyperparams(dictionary, current_hparams)
    finally:
        # Put the removed entries back even if the caller stops early or
        # the length check above fails.
        for key in current_dictionary:
            dictionary[key] = copy.deepcopy(current_dictionary[key])


""" configs for d-parsing """
class Configs(DotDict):
    """Command-line configuration for d-parsing, exposed through DotDict."""

    def __init__(self, data_dict, *args, **kwargs):
        super().__init__(data_dict=data_dict, *args, **kwargs)

    @classmethod
    def get_configs(cls):
        """Parse the process command line and return it as a ``cls`` instance.

        Three sub-commands are defined (``train``, ``evaluate``,
        ``predict``); the chosen one is stored under the ``mode`` key.
        """
        def _str2bool(value):
            # ``type=bool`` would call bool() on the raw string, so every
            # non-empty text -- including "False" -- was parsed as True.
            lowered = value.strip().lower()
            if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
                return True
            if lowered in ('false', 'f', 'no', 'n', 'off', '0', ''):
                return False
            raise argparse.ArgumentTypeError(
                f'expected a boolean value, got {value!r}'
            )

        parser = argparse.ArgumentParser()
        subparsers = parser.add_subparsers(title='commands', dest='mode')
        # train
        subparser = subparsers.add_parser('train', help='training stage')
        subparser.add_argument('--data-path', type=str, required=True, help='folder to training data, stored in [train|dev|test].conll')
        subparser.add_argument('--label-dict', type=str, default='./metadata/ud26.simp.json', help='path to preextracted label dict')
        subparser.add_argument('--model-path', type=str, default='./models/debug', help='path to save the model')
        subparser.add_argument('--pretrain-name', type=str, default='xlm-roberta-base', help='name of pretrained LM')
        subparser.add_argument('--pretrain-grad', type=_str2bool, default=False, help='whether using pretrained gradients')
        subparser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu'], help='device for model training')
        subparser.add_argument('--position-pad-id', type=int, default=-1, help='padding ID for position classification (arc)')
        subparser.add_argument('--label-pad-id', type=int, default=-1, help='padding ID for label classification (label)')
        subparser.add_argument('--encoding', type=str, default='utf-8', help='encoding for groundtruth files')
        subparser.add_argument('--hidden-size', type=int, default=256, help='hidden size of the encoder')
        subparser.add_argument('--dropout', type=float, default=0.33, help='dropout ratio for training')
        subparser.add_argument('--lstm-hidden-size', type=int, default=256, help='LSTM hidden size')
        subparser.add_argument('--lstm-dropout', type=float, default=0.33, help='LSTM dropout')
        subparser.add_argument('--lstm-n-layers', type=int, default=3, help='number of LSTM layers for feature extraction')
        subparser.add_argument('--batch-size', type=int, default=4096, help='batch size (token num) for training')
        subparser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
        subparser.add_argument('--clip', type=float, default=5.0, help='clip grad norm for training')
        subparser.add_argument('--epochs', type=int, default=40, help='number of epochs for training')
        subparser.add_argument('--lr', type=float, default=.0005, help='learning rate for training')
        subparser.add_argument('--lr-reduce-ratio', type=float, default=.2, help='reduce ratio for not improved LAS')
        subparser.add_argument('--lr-patience', type=int, default=10, help='patience for not improved LAS')
        subparser.add_argument('--seed', type=int, default=120, help='random seed for reproduction')
        subparser.add_argument('--eval-script', type=str, default='3rdparty/conll17_eval_script/conll17_ud_eval.py', help='eval script')
        subparser.add_argument('--start-model', type=str, default=None, help='path to initialization (saved in the same way)')
        subparser.add_argument('--n-examples', type=int, default=int(1e20), help='only use the first n examples for training')
        subparser.add_argument('--shuffle', action='store_true', default=False, help='whether shuffle the training data (for sample study)')
        # evaluate
        subparser = subparsers.add_parser('evaluate', help='evaluation stage')
        subparser.add_argument('--data-path', type=str, required=True, help='path to evaluation data')
        subparser.add_argument('--model-path', type=str, required=True, help='path to the saved model')
        subparser.add_argument('--eval-script', type=str, default='3rdparty/conll17_eval_script/conll17_ud_eval.py', help='eval script')
        # predict
        subparser = subparsers.add_parser('predict', help='prediction stage')
        subparser.add_argument('--data-path', type=str, required=True, help='path to sentences to be parsed')
        subparser.add_argument('--model-path', type=str, required=True, help='path to the saved model')
        subparser.add_argument('--input-format', type=str, default='conll', choices=['conll', 'txt'], help='input format')
        subparser.add_argument('--output-path', '-o', type=str, required=True, help='path to the output parses')
        subparser.add_argument('--n-examples', type=int, default=int(1e20), help='only predict the first n examples (due to memory lim)')
        subparser.add_argument('--encoding', type=str, default='utf-8', help='encoding for output prediction')
        args = parser.parse_args()
        return cls(vars(args))

    def __repr__(self):
        """Render the stored items as a two-column text table sorted by key.

        Keys are left-aligned in a 20-character column and stringified
        values are centred in a 50-character column; the result starts with
        a newline so it reads cleanly after a log prefix.
        """
        repr_list = ['-' * 20 + '-+-' + '-' * 50, f'{"Param":20} | {"Value":^50}', '-' * 20 + '-+-' + '-' * 50]
        for key, value in sorted(self.items(), key=lambda x: x[0]):
            repr_list.append(f'{key:20} | {str(value):^50}')
        repr_list.append('-' * 20 + '-+-' + '-' * 50)
        return '\n' + '\n'.join(repr_list)


""" configs for soft d-parsing """
class SoftConfigs(Configs):
    """Command-line configuration for soft d-parsing."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def get_configs(cls):
        """Parse the process command line and return it as a ``cls`` instance.

        Two sub-commands are defined (``train`` and ``evaluate``); the
        chosen one is stored under the ``mode`` key.
        """
        def _str2bool(value):
            # ``type=bool`` would call bool() on the raw string, so every
            # non-empty text -- including "False" -- was parsed as True.
            lowered = value.strip().lower()
            if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
                return True
            if lowered in ('false', 'f', 'no', 'n', 'off', '0', ''):
                return False
            raise argparse.ArgumentTypeError(
                f'expected a boolean value, got {value!r}'
            )

        parser = argparse.ArgumentParser()
        subparsers = parser.add_subparsers(title='commands', dest='mode')
        # train
        subparser = subparsers.add_parser('train', help='training stage')
        subparser.add_argument('--data-path', type=str, required=True, help='folder to training data, stored in [train|dev|test].conll/.json[l]')
        subparser.add_argument('--reverse-lang', action='store_true', default=False, help='whether to reverse src and trg languages')
        subparser.add_argument('--base-model-path', type=str, required=True, help='path to the base model')
        subparser.add_argument('--model-path', type=str, default='./models/debug', help='path to save the model')
        subparser.add_argument('--pretrain-name', type=str, default='xlm-roberta-base', help='name of pretrained LM')
        subparser.add_argument('--pretrain-grad', type=_str2bool, default=False, help='whether using pretrained gradients')
        subparser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu'], help='device for model training')
        subparser.add_argument('--encoding', type=str, default='utf-8', help='encoding for groundtruth files')
        subparser.add_argument('--hidden-size', type=int, default=256, help='hidden size of the encoder')
        subparser.add_argument('--dropout', type=float, default=0.33, help='dropout ratio for training')
        subparser.add_argument('--lstm-hidden-size', type=int, default=256, help='LSTM hidden size')
        subparser.add_argument('--lstm-dropout', type=float, default=0.33, help='LSTM dropout')
        subparser.add_argument('--lstm-n-layers', type=int, default=3, help='number of LSTM layers for feature extraction')
        subparser.add_argument('--batch-size', type=int, default=4096, help='batch size (token num) for training')
        subparser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
        subparser.add_argument('--clip', type=float, default=5.0, help='clip grad norm for training')
        subparser.add_argument('--epochs', type=int, default=40, help='number of epochs for training')
        subparser.add_argument('--lr', type=float, default=.0005, help='learning rate for training')
        subparser.add_argument('--lr-reduce-ratio', type=float, default=.2, help='reduce ratio for not improved LAS')
        subparser.add_argument('--lr-patience', type=int, default=10, help='patience for not improved LAS')
        subparser.add_argument('--seed', type=int, default=120, help='random seed for reproduction')
        subparser.add_argument('--eval-script', type=str, default='3rdparty/conll17_eval_script/conll17_ud_eval.py', help='eval script')
        subparser.add_argument('--start-model', type=str, default=None, help='path to initialization (saved in the same way)')
        subparser.add_argument('--base-regularizer', type=float, default=0, help='add the base regularizer during training')
        subparser.add_argument('--n-examples', type=int, default=int(1e20), help='only use the first n examples for training')
        subparser.add_argument('--shuffle', action='store_true', default=False, help='whether shuffle the training data (for sample study)')
        # evaluate
        subparser = subparsers.add_parser('evaluate', help='evaluation stage')
        subparser.add_argument('--data-path', type=str, required=True, help='path to evaluation data')
        subparser.add_argument('--model-path', type=str, required=True, help='path to the saved model')
        subparser.add_argument('--eval-script', type=str, default='3rdparty/conll17_eval_script/conll17_ud_eval.py', help='eval script')
        args = parser.parse_args()
        return cls(vars(args))


""" logger for multiple destination """ 
class Logger(logging.Logger):
    """Logger that sends every record to a log file and to a stream handler.

    Args:
        log_filename: path of the log file; it is opened in append mode.
        level: threshold applied to the logger and to both handlers.
        formatter: format string shared by both handlers.
    """

    def __init__(self, log_filename, level=logging.INFO, formatter='%(asctime)s - %(levelname)s - %(message)s'):
        super().__init__(__name__, level)
        record_format = logging.Formatter(formatter)
        # File handler first, console handler second; both are configured
        # the same way.
        destinations = (
            logging.FileHandler(log_filename, 'a'),
            logging.StreamHandler(),
        )
        for handler in destinations:
            handler.setLevel(level)
            handler.setFormatter(record_format)
            self.addHandler(handler)


""" meter for result recording """
class Meter(object):
    """Running weighted average of the values passed to ``update``."""

    def __init__(self, init_sum=0, init_n_iter=0):
        self._sum, self._n_iter = init_sum, init_n_iter

    @property
    def average(self):
        """Accumulated sum divided by the accumulated count (0 if no count)."""
        if self._n_iter > 0:
            return self._sum / self._n_iter
        return 0

    def update(self, item, num=1):
        """Add ``item`` weighted by ``num`` and advance the count by ``num``."""
        weighted = item * num
        self._sum += weighted
        self._n_iter += num


""" set random seed for reproduction """
def set_random_seed(seed):
    """Seed torch (CPU and, when available, CUDA), ``random`` and NumPy."""
    torch.manual_seed(seed)
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Seed the current device, then every visible device.
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(seed)
    random.seed(seed)
    np.random.seed(seed)


""" return random state for reproduction """ 
def get_rng_state():
    """Return the current RNG states as a tuple.

    The tuple holds the ``random``, NumPy and torch CPU states, in that
    order, followed by the torch CUDA state when CUDA is available.
    """
    with_cuda = torch.cuda.is_available()
    states = (random.getstate(), np.random.get_state(), torch.get_rng_state())
    if with_cuda:
        states = states + (torch.cuda.get_rng_state(),)
    return states


""" resume random state for reproduction """
def set_rng_state(rng_state):
    """Restore RNG states from a tuple laid out like ``get_rng_state``'s.

    ``rng_state`` holds the ``random``, NumPy and torch CPU states, plus an
    optional fourth torch CUDA state.  Any length other than 3 or 4 fails
    an assertion after the first three states have been applied.
    """
    restorers = (random.setstate, np.random.set_state, torch.set_rng_state)
    for position, restore in enumerate(restorers):
        restore(rng_state[position])
    if len(rng_state) != 4:
        assert len(rng_state) == 3
        return
    # A fourth entry is a CUDA state, which needs CUDA to be present.
    assert torch.cuda.is_available()
    torch.cuda.set_rng_state(rng_state[3])


""" get header for training """
def get_header():
    """Return a newline-joined header describing the current host.

    The first line reports the hostname; a second line reports
    ``CUDA_VISIBLE_DEVICES`` when that environment variable is set.
    """
    lines = [f'hostname: {socket.gethostname()}']
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible_devices is not None:
        lines.append(f'CUDA_VISIBLE_DEVICES: {visible_devices}')
    return '\n'.join(lines)


""" soft cross entropy loss"""
class SoftCrossEntropyLoss(object):
    """Cross entropy against a soft target distribution, summed under a mask."""

    def __init__(self):
        """No state is kept."""

    def __call__(self, x, xhat, mask, is_logits=True):
        """Return the masked sum of ``-target * log(prediction)``.

        Args:
            x: predictions; softmax-normalised over the last dimension when
                ``is_logits`` is true, otherwise used as given.
            xhat: targets; treated the same way as ``x``.
            mask: boolean selector of the element-wise terms that are
                summed.
            is_logits: whether ``x`` and ``xhat`` are unnormalised scores.
        """
        if is_logits:
            target_probs = xhat.softmax(-1)
            log_probs = x.log_softmax(-1)
        else:
            target_probs = xhat
            log_probs = x.log()
        per_element = -target_probs * log_probs
        return per_element.masked_select(mask).sum()


""" normalize the alignment matrix"""
def normalize_row(matrices, masks_s, masks_t, add_uniform=False):
    """Normalise ``matrices`` along the last dimension under a validity mask.

    Args:
        matrices: values to normalise (presumably batched source-by-target
            alignment scores -- TODO confirm against the caller).
        masks_s: mask reduced with ``max`` over its last dimension and
            broadcast along the last dimension of ``matrices``.
        masks_t: mask reduced with ``max`` over its last dimension and
            broadcast along dimension 1 of ``matrices``.
        add_uniform: when true, rows whose sum is zero receive the validity
            mask before normalisation, making them uniform over valid
            entries.

    Returns:
        ``(normalized_matrices, non_zero_mask)`` where ``non_zero_mask``
        marks the rows of the input whose sum was non-zero.
    """
    valid_s = masks_s.max(-1, keepdim=True)[0]
    valid_t = masks_t.max(-1)[0].unsqueeze(1)
    matrix_masks = valid_s * valid_t
    non_zero_mask = matrices.sum(-1, keepdim=True).ne(0)
    if add_uniform:
        empty_rows = ~non_zero_mask
        matrices = matrices + empty_rows * matrix_masks
    # The small constant keeps all-zero rows from dividing by zero.
    row_totals = matrices.sum(-1, keepdim=True) + 1e-10
    normalized_matrices = matrices / row_totals * matrix_masks
    return normalized_matrices, non_zero_mask

if __name__ == '__main__':
    # Manual smoke test: parse the command line, log the resulting configs,
    # then drop into an interactive IPython shell.
    # Logger takes (log_filename, level, formatter) and has no name
    # parameter; passing a positional name together with ``log_filename=``
    # raised "got multiple values for argument 'log_filename'".
    logger = Logger(log_filename='debug.log', level=logging.INFO)
    configs = Configs.get_configs()
    logger.info(configs)
    from IPython import embed
    embed(using=False)
