import numpy
from pathlib import Path
from typing import List

from transformers import AutoConfig, AutoModel, AutoTokenizer
from json import load, dump as save
from page.util import filter_dict_by_keys

from page.torch.text_field import TransformerTextField
from page.torch.eq_field import TokenEquationField, OperationEquationField


class ModelConfig:
    """Configuration of the model: encoder choice, decoder shape and vocabulary sizes.

    Besides storing the plain hyper-parameters, the constructor resolves the transformers
    config/tokenizer/model classes that match ``encoder_model`` and loads the encoder's
    transformers config from :attr:`encoder_path`.
    """

    def __init__(self, encoder_model: str = 'albert-large-v2', chkpt_path: str = None,
                 model_type: str = 'opat', num_decoder_layers: int = 12,
                 num_pointer_heads: int = 1, beam_size: int = 3,
                 token_vocab_size: int = None, token_nonum_size: int = None,
                 function_word_size: int = None, constant_word_size: int = None, argument_word_size: int = None):
        """
        :param encoder_model: Name (or path) of the pretrained transformers encoder.
        :param chkpt_path: Optional checkpoint directory. When set, the transformers config,
            encoder and tokenizer are loaded from it instead of ``encoder_model``.
        :param model_type: Model type tag (used as the prefix of :attr:`experiment_name`).
        :param num_decoder_layers: Number of decoder layers (coerced with ``int``).
        :param num_pointer_heads: Number of pointer heads (coerced with ``int``).
        :param beam_size: Beam size (coerced with ``int``).
        :param token_vocab_size: Vocabulary sizes; normally filled in by
            ``TrainerConfig.read_datasets`` after the datasets are loaded.
        """
        import importlib

        self.encoder_model = encoder_model
        self.chkpt_path = chkpt_path
        self.model_type = model_type

        self.num_decoder_layers = int(num_decoder_layers)
        self.num_pointer_heads = int(num_pointer_heads)
        self.beam_size = int(beam_size)

        self.token_vocab_size = token_vocab_size
        self.token_nonum_size = token_nonum_size
        self.function_word_size = function_word_size
        self.constant_word_size = constant_word_size
        self.argument_word_size = argument_word_size

        # Resolve the concrete classes matching the encoder's config class.
        config_cls = AutoConfig.from_pretrained(self.encoder_model).__class__
        self._cls = {'config': config_cls}
        try:
            # Derive sibling module/class names from the config class by naming convention,
            # e.g. ...configuration_albert.AlbertConfig ->
            #      ...tokenization_albert.AlbertTokenizer and ...modeling_albert.AlbertModel.
            module = config_cls.__module__
            clsname = config_cls.__name__
            tokenizer_module = importlib.import_module(module.replace('configuration', 'tokenization'))
            model_module = importlib.import_module(module.replace('configuration', 'modeling'))
            tokenizer_cls = getattr(tokenizer_module, clsname.replace('Config', 'Tokenizer'))
            model_cls = getattr(model_module, clsname.replace('Config', 'Model'))
        except Exception:
            # The naming convention did not hold (or an import failed): ask the Auto* factories.
            # NOTE: AutoModel.from_pretrained loads the full pretrained weights just to get the class.
            tokenizer_cls = AutoTokenizer.from_pretrained(self.encoder_model).__class__
            model_cls = AutoModel.from_pretrained(self.encoder_model).__class__
        self._cls['tokenizer'] = tokenizer_cls
        self._cls['model'] = model_cls

        self.transformer_config = self._cls['config'].from_pretrained(self.encoder_path)

    def copy(self, **kwargs):
        """Return a new config with the given constructor arguments overridden.

        Keys that are not constructor arguments (see :meth:`to_kwargs`) are ignored.
        """
        base = self.to_kwargs()
        for key in set(base.keys()).intersection(kwargs.keys()):
            base[key] = kwargs[key]

        return type(self)(**base)

    @property
    def embedding_dim(self):
        """Width of the encoder's token embeddings.

        Uses ``embedding_size`` when the transformers config defines it (ALBERT-style factorized
        embeddings); otherwise falls back to ``hidden_size``.
        """
        config = self.transformer_config
        if hasattr(config, 'embedding_size'):
            return config.embedding_size
        return config.hidden_size

    @property
    def hidden_dim(self):
        """Hidden size of the encoder's transformers config."""
        return self.transformer_config.hidden_size

    @property
    def intermediate_dim(self):
        """Feed-forward (intermediate) size of the encoder's transformers config."""
        return self.transformer_config.intermediate_size

    @property
    def num_decoder_heads(self):
        """Number of attention heads, taken from the encoder's transformers config."""
        return self.transformer_config.num_attention_heads

    @property
    def init_factor(self):
        """Weight-initialization range (``initializer_range``) of the encoder's config."""
        return self.transformer_config.initializer_range

    @property
    def layernorm_eps(self):
        """LayerNorm epsilon of the encoder's transformers config."""
        return self.transformer_config.layer_norm_eps

    @property
    def dropout_layer(self):
        """Hidden-layer dropout probability of the encoder's transformers config."""
        return self.transformer_config.hidden_dropout_prob

    @property
    def dropout_attn(self):
        """Attention-probability dropout of the encoder's transformers config."""
        return self.transformer_config.attention_probs_dropout_prob

    @property
    def experiment_name(self):
        """Short experiment identifier, e.g. ``'opat-large-12L1P'`` for ``'albert-large-v2'``.

        The middle part is the second ``-``-separated token of ``encoder_model``; names without
        a ``-`` are used whole instead of raising ``IndexError``.
        """
        parts = self.encoder_model.split('-')
        size = parts[1] if len(parts) > 1 else parts[0]
        return '%s-%s-%sL%sP' % (self.model_type, size,
                                 self.num_decoder_layers, self.num_pointer_heads)

    def save_pretrained(self, path_to_save: str):
        """Write the transformers config and ``PageConfig.json`` into directory ``path_to_save``."""
        self.transformer_config.save_pretrained(path_to_save)

        with Path(path_to_save, 'PageConfig.json').open('w', encoding='UTF-8') as fp:
            save(self.to_kwargs(), fp)

    @classmethod
    def from_pretrained(cls, path: str, enforce_path: bool = False):
        """Load a config saved by :meth:`save_pretrained`.

        :param path: Directory containing ``PageConfig.json``, or the JSON file itself when
            ``enforce_path`` is True.
        """
        path = Path(path) if enforce_path else Path(path, 'PageConfig.json')
        # Read-only mode, so configs on read-only storage can be loaded.
        with path.open('r', encoding='UTF-8') as fp:
            kwargs = load(fp)

        return cls(**kwargs)

    def load_encoder(self):
        """Load the pretrained encoder model from :attr:`encoder_path`."""
        return self._cls['model'].from_pretrained(self.encoder_path)

    def load_tokenizer(self):
        """Load the tokenizer from :attr:`encoder_path`."""
        return self._cls['tokenizer'].from_pretrained(self.encoder_path)

    def set_chkpt_path(self, path):
        """Set (or clear with ``None``) the checkpoint path; stored as a string."""
        self.chkpt_path = str(path) if path is not None else None

    @property
    def encoder_path(self):
        """Checkpoint path if set, otherwise the pretrained encoder name."""
        return self.chkpt_path if self.chkpt_path is not None else self.encoder_model

    def to_kwargs(self):
        """Return the constructor arguments as a JSON-serializable dict."""
        return {
            'encoder_model': self.encoder_model,
            'chkpt_path': self.chkpt_path,
            'model_type': self.model_type,
            'num_decoder_layers': self.num_decoder_layers,
            'num_pointer_heads': self.num_pointer_heads,
            'beam_size': self.beam_size,
            'token_vocab_size': self.token_vocab_size,
            'token_nonum_size': self.token_nonum_size,
            'function_word_size': self.function_word_size,
            'constant_word_size': self.constant_word_size,
            'argument_word_size': self.argument_word_size
        }


class OptimizerConfig:
    """Name and constructor keyword arguments of the optimizer used for training.

    ``beta1``/``beta2`` are accepted separately (the form stored in JSON) and kept internally as
    a single ``betas`` tuple, which is what gets passed to the optimizer; :meth:`to_kwargs`
    splits it back out.
    """

    # Keyword arguments forwarded to each supported optimizer (matched by lower-cased name).
    _ACCEPTED_KEYS = {
        'lamb': ('lr', 'betas', 'eps', 'weight_decay', 'clamp_value', 'adam', 'debias'),
        'radam': ('lr', 'betas', 'eps', 'weight_decay'),
        'adabound': ('lr', 'betas', 'eps', 'weight_decay', 'final_lr', 'gamma', 'amsbound'),
        'yogi': ('lr', 'betas', 'eps', 'weight_decay', 'initial_accumulator'),
    }
    # Keys used for any other name, which :meth:`build` constructs as transformers' AdamW.
    _DEFAULT_KEYS = ('lr', 'betas', 'eps', 'weight_decay')

    def __init__(self, optimizer: str, **kwargs):
        """
        :param optimizer: Optimizer name, case-insensitive: 'lamb', 'radam', 'adabound', 'yogi';
            any other value means AdamW.
        :param kwargs: Optimizer hyper-parameters. ``beta1`` and ``beta2`` are combined into
            ``betas``; the result is passed through ``filter_dict_by_keys`` with the keys the
            chosen optimizer accepts (see ``_ACCEPTED_KEYS``).
        :raises KeyError: if only one of ``beta1``/``beta2`` is given.
        """
        self.optimizer = optimizer.lower()
        if 'beta1' in kwargs or 'beta2' in kwargs:
            kwargs['betas'] = kwargs['beta1'], kwargs['beta2']

        # Look up by the lower-cased name so that e.g. 'AdaBound' keeps its extra arguments.
        accepted = self._ACCEPTED_KEYS.get(self.optimizer, self._DEFAULT_KEYS)
        kwargs = filter_dict_by_keys(kwargs, *accepted)

        # NumPy scalars may not be JSON serializable; store plain Python values instead.
        self.kwargs = {key: self._to_builtin(value) for key, value in kwargs.items()}

    @staticmethod
    def _to_builtin(value):
        """Convert NumPy scalars (also inside tuples such as ``betas``) to Python builtins."""
        if isinstance(value, numpy.generic) or (isinstance(value, numpy.ndarray) and value.ndim == 0):
            return value.item()
        if isinstance(value, tuple):
            return tuple(OptimizerConfig._to_builtin(item) for item in value)
        return value

    def copy(self, **kwargs):
        """Return a new config with the given keys (as produced by :meth:`to_kwargs`) overridden.

        Keys not present in :meth:`to_kwargs` are ignored.
        """
        base = self.to_kwargs()
        for key in set(base.keys()).intersection(kwargs.keys()):
            base[key] = kwargs[key]

        return type(self)(**base)

    def build(self, params):
        """Instantiate the optimizer over ``params`` with the stored keyword arguments."""
        if self.optimizer == 'lamb':
            from torch_optimizer import Lamb
            cls = Lamb
        elif self.optimizer == 'radam':
            from torch_optimizer import RAdam
            cls = RAdam
        elif self.optimizer == 'adabound':
            from torch_optimizer import AdaBound
            cls = AdaBound
        elif self.optimizer == 'yogi':
            from torch_optimizer import Yogi
            cls = Yogi
        else:
            from transformers import AdamW
            cls = AdamW

        return cls(params, **self.kwargs)

    def adjust_learning_rate(self, factor):
        """Multiply the stored learning rate by ``factor`` (0.00176 is assumed when unset)."""
        self.kwargs['lr'] = self.kwargs.get('lr', 0.00176) * factor

    def to_kwargs(self):
        """Return constructor arguments (with ``betas`` split into ``beta1``/``beta2``)."""
        kwargs = self.kwargs.copy()
        kwargs['optimizer'] = self.optimizer
        if 'betas' in kwargs:
            kwargs['beta1'], kwargs['beta2'] = kwargs.pop('betas')

        return kwargs

    def save_pretrained(self, path_to_save: str):
        """Write ``OptConfig.json`` into directory ``path_to_save``."""
        with Path(path_to_save, 'OptConfig.json').open('w', encoding='UTF-8') as fp:
            save(self.to_kwargs(), fp)

    @classmethod
    def from_pretrained(cls, path: str, enforce_path: bool = False):
        """Load a config saved by :meth:`save_pretrained`.

        :param path: Directory containing ``OptConfig.json``, or the JSON file itself when
            ``enforce_path`` is True.
        """
        path = Path(path) if enforce_path else Path(path, 'OptConfig.json')
        # Read-only mode, so configs on read-only storage can be loaded.
        with path.open('r', encoding='UTF-8') as fp:
            return cls(**load(fp))


class TrainerConfig:
    """Training hyper-parameters together with the model and optimizer configurations.

    :meth:`to_kwargs` groups the gradient and epoch settings into nested dicts;
    :meth:`from_pretrained` flattens them back into constructor arguments.
    """

    def __init__(self, model: 'ModelConfig', optimizer: 'OptimizerConfig', batch: int = 4096,
                 gradient_accumulation_steps: int = 1, gradient_clip: float = 10.0, gradient_normalize: bool = False,
                 epoch: int = 1000, epoch_warmup: float = 25, epoch_chkpt: int = 10, epoch_report: int = 5,
                 fix_encoder_embedding: bool = True, lr_multiplier_encoder: float = 1.0,
                 seed: int = 1):
        """
        Numeric arguments are coerced with ``int``/``float`` (so numeric strings are accepted);
        flags are coerced with ``bool``.

        :param model: Model configuration.
        :param optimizer: Optimizer configuration.
        :param batch: Batch size handed to ``TokenBatchIterator`` in :meth:`read_datasets`.
        :param epoch_warmup: Warm-up length; stored as a float, so fractional values are kept.
        """
        self.model = model
        self.optimizer = optimizer

        self.batch = int(batch)
        self.fix_encoder_embedding = bool(fix_encoder_embedding)
        self.lr_multiplier_encoder = float(lr_multiplier_encoder)
        self.gradient_accumulation_steps = int(gradient_accumulation_steps)
        self.gradient_clip = float(gradient_clip)
        self.gradient_normalize = bool(gradient_normalize)
        self.epoch = int(epoch)
        self.epoch_warmup = float(epoch_warmup)
        self.epoch_chkpt = int(epoch_chkpt)
        self.epoch_report = int(epoch_report)
        self.seed = int(seed)

    def get(self, item):
        """Look ``item`` up on this config, then the model kwargs, then the optimizer kwargs.

        :raises KeyError: if ``item`` is found nowhere.
        """
        if hasattr(self, item):
            return getattr(self, item)

        modelargs = self.model.to_kwargs()
        if item in modelargs:
            return modelargs[item]
        return self.optimizer.to_kwargs()[item]

    def copy(self, **kwargs):
        """Return a new config with overridden values.

        ``kwargs`` are applied to the trainer settings and also forwarded to the model and
        optimizer ``copy`` methods, each of which picks the keys it knows.
        """
        base = dict(batch=self.batch, gradient_accumulation_steps=self.gradient_accumulation_steps,
                    gradient_clip=self.gradient_clip, gradient_normalize=self.gradient_normalize,
                    epoch=self.epoch, epoch_warmup=self.epoch_warmup, epoch_chkpt=self.epoch_chkpt,
                    epoch_report=self.epoch_report, fix_encoder_embedding=self.fix_encoder_embedding,
                    lr_multiplier_encoder=self.lr_multiplier_encoder, seed=self.seed)

        for key in set(base.keys()).intersection(kwargs.keys()):
            base[key] = kwargs[key]

        return type(self)(self.model.copy(**kwargs), self.optimizer.copy(**kwargs), **base)

    def to_kwargs(self):
        """Return a nested, JSON-serializable dict of this config (the ``TrainConfig.json`` layout)."""
        return {
            'batch': self.batch,
            'seed': self.seed,
            'fix_encoder_embedding': self.fix_encoder_embedding,
            'lr_multiplier_encoder': self.lr_multiplier_encoder,
            'gradient': {
                'accumulation_steps': self.gradient_accumulation_steps,
                'clip': self.gradient_clip,
                'normalize': self.gradient_normalize,
            },
            'epoch': {
                'total': self.epoch,
                'warmup': self.epoch_warmup,
                'chkpt': self.epoch_chkpt,
                'report': self.epoch_report,
            },
            'model': self.model.to_kwargs(),
            'optimizer': self.optimizer.to_kwargs()
        }

    def save_pretrained(self, path_to_save: str, enforce_path: bool = False):
        """Write the config as JSON to ``path_to_save/TrainConfig.json`` (or to ``path_to_save``
        itself when ``enforce_path`` is True)."""
        path_to_save = Path(path_to_save) if enforce_path else Path(path_to_save, 'TrainConfig.json')
        with path_to_save.open('w', encoding='UTF-8') as fp:
            save(self.to_kwargs(), fp)

    @classmethod
    def from_pretrained(cls, path: str):
        """Load a trainer config from a JSON file or a directory containing ``TrainConfig.json``.

        ``model``/``optimizer`` may be inline dicts or paths (relative to the file's directory)
        of separately saved JSON files. An optional ``default`` entry names a JSON file whose
        values are applied first and then overridden by this file. Nested dicts are flattened
        into ``<key>_<subkey>`` arguments, and ``epoch_total`` is mapped to ``epoch``.
        """
        path = Path(path)
        if not path.is_file():
            path = Path(path, 'TrainConfig.json')

        parent = path.parent
        # Read-only mode, so configs on read-only storage can be loaded.
        with path.open('r', encoding='UTF-8') as fp:
            config = load(fp)

        # Read model config (inline dict or path to a separate JSON file)
        model = config.pop('model')
        if isinstance(model, str):
            model = ModelConfig.from_pretrained(Path(parent, model), enforce_path=True)
        else:
            model = ModelConfig(**model)

        # Read optimizer config (inline dict or path to a separate JSON file)
        optim = config.pop('optimizer')
        if isinstance(optim, str):
            optim = OptimizerConfig.from_pretrained(Path(parent, optim), enforce_path=True)
        else:
            optim = OptimizerConfig(**optim)

        # Read default
        default = {}
        if 'default' in config:
            with Path(parent, config.pop('default')).open('r', encoding='UTF-8') as fp:
                default = load(fp)

        kwargs = {}
        # Apply default values first, and then overwrite specified keys with specified values.
        # Nested dicts are flattened per sub-key, so a partial override keeps the other defaults.
        for key, value in list(default.items()) + list(config.items()):
            if isinstance(value, dict):
                kwargs.update({key + ('_' + subkey if subkey else ''): val for subkey, val in value.items()})
            else:
                kwargs[key] = value

        if 'epoch_total' in kwargs:
            kwargs['epoch'] = kwargs.pop('epoch_total')

        return cls(model, optim, **kwargs)

    def read_datasets(self, train: str, test: str, dev: str = None):
        """Build the train/dev/test iterators and record vocabulary sizes on ``self.model``.

        :param train: Training data path.
        :param test: Test data path.
        :param dev: Optional dev data path; when missing or equal to ``test``, the test iterator
            is reused as the dev iterator.
        :return: ``(trainset, devset, evalset)``.
        """
        from page.torch.dataset import TokenBatchIterator

        # Prepare fields
        tokenizer = self.model.load_tokenizer()
        prob_field = TransformerTextField(tokenizer, is_target=False)
        token_gen_field = TokenEquationField(['X_'], ['N_'], 'C_', generate_all=True)
        token_ptr_field = TokenEquationField(['X_'], ['N_'], 'C_', generate_all=False)
        tuple_gen_field = OperationEquationField(['X_'], ['N_'], 'C_', max_arity=2, force_generation=True)
        tuple_ptr_field = OperationEquationField(['X_'], ['N_'], 'C_', max_arity=2, force_generation=False)

        # Load datasets
        trainset = TokenBatchIterator(train, prob_field, token_gen_field, token_ptr_field,
                                      tuple_gen_field, tuple_ptr_field, self.batch, testing_purpose=False)
        evalset = TokenBatchIterator(test, prob_field, token_gen_field, token_ptr_field,
                                     tuple_gen_field, tuple_ptr_field, self.batch, testing_purpose=True)
        if dev is not None and test != dev:
            devset = TokenBatchIterator(dev, prob_field, token_gen_field, token_ptr_field,
                                        tuple_gen_field, tuple_ptr_field, self.batch, testing_purpose=True)
        else:
            devset = evalset

        # Specify vocab size (the fields' vocabularies are presumably filled while loading above).
        self.model.token_vocab_size = len(token_gen_field.token_vocab)
        self.model.token_nonum_size = len(token_ptr_field.token_vocab)
        self.model.function_word_size = len(tuple_gen_field.function_word_vocab)
        # NOTE(review): argument_word_size is taken from the *generation* field's constant vocab and
        # constant_word_size from the *pointer* field's; this looks deliberate (forced generation
        # presumably treats every operand as a vocabulary word) — confirm before changing.
        self.model.argument_word_size = len(tuple_gen_field.constant_word_vocab)
        self.model.constant_word_size = len(tuple_ptr_field.constant_word_vocab)

        return trainset, devset, evalset


# Names exported by ``from <this module> import *``.
__all__ = [
    'ModelConfig',
    'OptimizerConfig',
    'TrainerConfig',
]
