import os
import sys
from datetime import datetime, timedelta
from parser.cmds.cmd import CMD
from parser.models import CRFAE, FeatureHMM
from parser.utils.common import pad, unk
from parser.utils.data import CoNLL, Dataset, ElmoField, FeatureField, Field
from parser.utils.fn import heatmap, replace_digit_fn
from parser.utils.logging import get_logger, progress_bar
from parser.utils.metric import ManyToOneMetric, Metric

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

logger = get_logger(__name__)


class Train(CMD):
    """Command that trains the unsupervised tagging pipeline.

    ``__call__`` runs three stages:

    1. train a Feature HMM (or reuse a cached one) and keep the checkpoint
       with the lowest dev loss;
    2. pretrain the encoder of a CRF autoencoder (CRF-AE) on the tag
       sequences predicted by that Feature HMM;
    3. train the full CRF-AE, keep the checkpoint with the lowest dev loss,
       and write its predictions on the evaluation set to a CoNLL file.
    """

    def create_subparser(self, parser, mode):
        """Create the ``mode`` subparser and register its arguments.

        Args:
            parser: parent subparsers object (must provide ``add_parser``).
            mode (str): name of this subparser.

        Returns:
            The newly created subparser.
        """
        subparser = parser.add_parser(mode, help='Train model.')
        subparser.add_argument('--train',
                               default='data/ptb/total.conll',
                               help='path to train file')
        subparser.add_argument('--evaluate',
                               default='data/ptb/total.conll',
                               help='path to evaluate file')
        subparser.add_argument('--layer',
                               default=None,
                               type=int,
                               help='layer of elmo to use in train CRF')
        subparser.add_argument('--feature-hmm-model',
                               default=None,
                               help='path to a cached Feature HMM model')
        subparser.add_argument('--model-path',
                               default="data/ml_elmo/de",
                               help='path to the pretrained ELMo model')
        subparser.add_argument('--ignore-capitalized',
                               action="store_true",
                               help='lowercase words in the word field')
        subparser.add_argument('--language-specific-strip',
                               action="store_true",
                               help='apply language-specific stripping to '
                               'words in the word field')
        subparser.add_argument('--language',
                               default="en",
                               help='language of the data (used for word '
                               'replacement on UD treebanks)')
        subparser.add_argument('--hmm-epochs',
                               default=50,
                               type=int,
                               help='number of epochs to train the Feature HMM')
        subparser.add_argument('--crf-init-epochs',
                               default=5,
                               type=int,
                               help='number of epochs to pretrain the CRF-AE '
                               'encoder on Feature HMM predictions')
        subparser.add_argument('--crf-epochs',
                               default=50,
                               type=int,
                               help='number of epochs to train the CRF-AE')
        return subparser

    def __call__(self, args):
        """Run the full training pipeline.

        Args:
            args: parsed configuration. Besides the options registered in
                ``create_subparser`` it must provide the shared options read
                below (``save``, ``fields``, ``use_cached``, ``batch_size``,
                ``n_buckets``, ``min_freq``, ``seed``, ``device``,
                ``crf_model`` and the ``hmm_*`` / ``crf_init_*`` / ``crf_*``
                optimizer settings). It is updated in place with the
                vocabulary and feature sizes.
        """
        super(Train, self).__call__(args)

        # create the output dir (and any missing parents) for all artifacts
        os.makedirs(args.save, exist_ok=True)

        if (args.feature_hmm_model is None or not args.use_cached
                or not os.path.exists(args.feature_hmm_model)):
            # no usable cached Feature HMM: (re)train it into the save dir
            args.feature_hmm_model = os.path.join(args.save, 'feature_hmm')

        # UD data is detected from the train path; it takes gold tags from
        # the CPOS column and enables UD-specific word features
        ud_mode = 'ud' in args.train
        train, evaluate = self._load_data(args, ud_mode)

        # set the data loaders
        train.build(args.batch_size,
                    n_buckets=args.n_buckets,
                    shuffle=True,
                    seed=self.args.seed)
        evaluate.build(args.batch_size, n_buckets=args.n_buckets)

        logger.info(f"Train Dataset {train}")

        args.update({
            'n_words': self.word_field.vocab.n_init,
            'n_word_features': self.word_field.n_word_features,
            'n_unigram_features': self.word_field.n_unigram_features,
            'n_bigram_features': self.word_field.n_bigram_features,
            'n_trigram_features': self.word_field.n_trigram_features,
            'n_morphology_features': self.word_field.n_morphology_features,
            'n_labels': 12 if ud_mode else 45,
            'pad_index': self.word_field.pad_index,
            'unk_index': self.word_field.unk_index,
            'ud_mode': ud_mode
        })
        logger.info(f"\n{args}")
        logger.info("Create the Feature HMM model")

        total_time = timedelta()

        if not args.use_cached or not os.path.exists(args.feature_hmm_model):
            total_time += self._fit_feature_hmm(args, train, evaluate)

        # load best feature hmm model
        feature_hmm = FeatureHMM.load(args.feature_hmm_model).to(
            self.args.device)

        loss, metric = self.evaluate_feature_hmm(feature_hmm, evaluate.loader)
        self._plot_clusters(metric, f"{args.feature_hmm_model}.best.clusters")
        logger.info(f"{'Best HMM:':10} Loss: {loss:.4f} {metric}")

        if args.crf_init_epochs < 0 and args.crf_epochs < 0:
            # only the Feature HMM was requested; sys.exit (unlike the
            # site-provided ``exit``) also works under ``python -S``
            sys.exit()

        # create the crf ae model
        logger.info("Create the CRF AE model")
        crf_ae = CRFAE(args, feature_hmm).to(self.args.device)
        logger.info(crf_ae)

        self._pretrain_crf_encoder(args, crf_ae, train)

        loss, metric = self.evaluate_crf_ae(crf_ae, evaluate.loader)
        self._plot_clusters(metric, f"{args.crf_model}.init.clusters")
        logger.info(f"{'CRF init:':10} Loss: {loss:.4f} {metric}")

        total_time += self._fit_crf_ae(args, crf_ae, train, evaluate)
        logger.info(f"{total_time}s elapsed")

        # reload the best checkpoint and write its predictions
        crf_ae = CRFAE.load(args.crf_model).to(self.args.device)
        preds = self.predict_crf_ae(crf_ae, evaluate.loader)
        for name, value in preds.items():
            setattr(evaluate, name, value)
        self.fields.save(f"{args.crf_model}-{self.timestamp}.predict.conllx",
                         evaluate.sentences)

    def _load_data(self, args, ud_mode):
        """Build (or load cached) fields and read the train/evaluate sets.

        Sets ``self.fields``, ``self.word_field``, ``self.char_field``,
        ``self.label_field`` and ``self.pred_field``. When the fields are
        built from scratch, their vocabularies are built from the train set
        and the fields are saved to ``args.fields``.

        Args:
            args: parsed configuration.
            ud_mode (bool): whether the data is a UD treebank.

        Returns:
            tuple: ``(train, evaluate)`` datasets whose loaders are not built.
        """
        use_cached_fields = args.use_cached and os.path.exists(args.fields)

        if use_cached_fields:
            # NOTE: torch.load unpickles arbitrary objects, so only load
            # fields files this tool wrote itself. torch>=2.6 defaults to
            # weights_only=True, which may reject these pickled fields.
            self.fields = torch.load(args.fields)
            (self.word_field, self.char_field
             ), self.pred_field = self.fields.FORM, self.fields.FEATS
            if ud_mode:
                self.label_field = self.fields.CPOS
            else:
                self.label_field = self.fields.POS
        else:
            # word field
            self.word_field = FeatureField(
                'words',
                pad=pad,
                unk=unk,
                replace_mode=args.language if ud_mode else None,
                replace_punct=ud_mode,
                ud_feature_templates=ud_mode,
                language_specific_strip=args.language_specific_strip,
                lower=args.ignore_capitalized)
            # gold tags; in this file they are only read for evaluation
            # and the cluster heatmaps
            self.label_field = Field('labels')
            # placeholder column that receives predictions (named 'preds')
            self.pred_field = Field('preds', is_placeholder=True)
            # char field
            self.char_field = ElmoField("chars")
            if ud_mode:
                self.fields = CoNLL(FORM=(self.word_field, self.char_field),
                                    CPOS=self.label_field,
                                    FEATS=self.pred_field)
            else:
                self.fields = CoNLL(FORM=(self.word_field, self.char_field),
                                    POS=self.label_field,
                                    FEATS=self.pred_field)

        # load dataset
        train = Dataset(self.fields, args.train)
        evaluate = Dataset(self.fields, args.evaluate)

        if not use_cached_fields:
            # build vocab from the train set, then cache the fields
            self.word_field.build(train, args.min_freq)
            self.label_field.build(train)
            torch.save(self.fields, args.fields)

        return train, evaluate

    def _plot_clusters(self, metric, path):
        """Save a heatmap of ``metric.clusters`` labelled with gold tag names."""
        heatmap(metric.clusters.cpu(),
                list(self.label_field.vocab.stoi.keys()),
                path,
                match=metric.match)

    def _fit_feature_hmm(self, args, train, evaluate):
        """Train a fresh Feature HMM and save its lowest-dev-loss checkpoint.

        The checkpoint goes to ``args.feature_hmm_model``. If no epoch was
        saved (e.g. ``args.hmm_epochs <= 0`` or every dev loss is NaN), the
        final model is saved instead so the caller never loads a missing or
        stale file from an earlier run.

        Returns:
            datetime.timedelta: wall time spent in the epoch loop.
        """
        elapsed = timedelta()
        min_loss = float("inf")
        min_epoch = 0
        feature_hmm = FeatureHMM(args, self.word_field.features).to(
            self.args.device)
        # optimizer
        optimizer = Adam(feature_hmm.parameters(), args.hmm_lr,
                         (args.hmm_mu, args.hmm_nu), args.hmm_epsilon,
                         args.hmm_weight_decay)
        # scheduler: stepped once per batch, so the per-step factor multiplies
        # the lr by ``hmm_decay`` every ``hmm_decay_epochs`` epochs
        decay_steps = args.hmm_decay_epochs * len(train.loader)
        scheduler = ExponentialLR(optimizer,
                                  args.hmm_decay**(1 / decay_steps))

        for epoch in range(1, args.hmm_epochs + 1):
            logger.info(f"Epoch {epoch} / {args.hmm_epochs}:")
            start = datetime.now()
            # train
            self.train_feature_hmm(feature_hmm, train.loader, optimizer,
                                   scheduler)
            # evaluate
            loss, metric = self.evaluate_feature_hmm(feature_hmm,
                                                     evaluate.loader)
            logger.info(f"{'dev:':6} Loss: {loss:.4f} {metric}")

            time_spent = datetime.now() - start
            elapsed += time_spent
            # save the model if it is the best so far
            if loss < min_loss:
                min_loss, min_epoch = loss, epoch
                feature_hmm.save(args.feature_hmm_model)
                logger.info(f"{time_spent}s elapsed (saved)\n")
            else:
                logger.info(f"{time_spent}s elapsed\n")

        if min_epoch == 0:
            logger.info("No Feature HMM epoch was saved; "
                        "saving the current model instead")
            feature_hmm.save(args.feature_hmm_model)

        train.reset_loader(args.batch_size,
                           shuffle=True,
                           seed=self.args.seed)
        logger.info(
            f"min_loss of Feature HMM is {min_loss:.2f} at epoch {min_epoch}"
        )
        return elapsed

    def _pretrain_crf_encoder(self, args, crf_ae, train):
        """Fit the CRF-AE encoder to Feature HMM predictions (see train_crf).

        Runs ``args.crf_init_epochs`` epochs with the ``crf_init_*``
        optimizer settings. Nothing is saved and no dev evaluation is done.
        """
        logger.info("Train CRF AE encoder")
        # optimizer
        optimizer = Adam(crf_ae.parameters(), args.crf_init_lr,
                         (args.crf_init_mu, args.crf_init_nu),
                         args.crf_init_epsilon, args.crf_init_weight_decay)
        # scheduler (stepped once per batch)
        decay_steps = args.crf_init_decay_epochs * len(train.loader)
        scheduler = ExponentialLR(optimizer,
                                  args.crf_init_decay**(1 / decay_steps))

        for epoch in range(1, args.crf_init_epochs + 1):
            logger.info(f"Epoch {epoch} / {args.crf_init_epochs}:")
            start = datetime.now()
            # train
            self.train_crf(crf_ae, train.loader, optimizer, scheduler)
            time_spent = datetime.now() - start
            logger.info(f"{time_spent}s elapsed\n")

    def _fit_crf_ae(self, args, crf_ae, train, evaluate):
        """Train the full CRF-AE and save its lowest-dev-loss checkpoint.

        The checkpoint goes to ``args.crf_model``. If no epoch was saved
        (e.g. ``args.crf_epochs <= 0`` or every dev loss is NaN), the current
        model is saved instead so the caller never loads a missing or stale
        file from an earlier run.

        Returns:
            datetime.timedelta: wall time spent in the epoch loop.
        """
        logger.info("Train CRF AE")
        # the embedded Feature HMM gets its own lr (``recons_lr``); every
        # other group uses ``crf_lr``
        optimizer = Adam([{
            "params": crf_ae.feature_hmm.parameters(),
            "lr": args.recons_lr
        }, {
            "params": crf_ae.represent_ln.parameters()
        }, {
            "params": crf_ae.encoder_emit_scorer.parameters()
        }, {
            "params": crf_ae.encoder_emit_ln.parameters()
        }, {
            "params": crf_ae.start
        }, {
            "params": crf_ae.transitions
        }, {
            "params": crf_ae.end
        }], args.crf_lr, (args.crf_mu, args.crf_nu), args.crf_epsilon,
                         args.crf_weight_decay)
        # scheduler (stepped once per batch)
        decay_steps = args.crf_decay_epochs * len(train.loader)
        scheduler = ExponentialLR(optimizer, args.crf_decay**(1 / decay_steps))

        elapsed = timedelta()
        best_e, best_metric = 1, Metric()
        min_loss = float("inf")

        for epoch in range(1, args.crf_epochs + 1):
            logger.info(f"Epoch {epoch} / {args.crf_epochs}:")
            start = datetime.now()
            # train
            self.train_crf_ae(crf_ae, train.loader, optimizer, scheduler)
            # evaluate
            loss, dev_metric = self.evaluate_crf_ae(crf_ae, evaluate.loader)
            self._plot_clusters(dev_metric, f"{args.crf_model}.clusters")
            logger.info(f"{'dev:':10} Loss: {loss:.4f} {dev_metric}")

            time_spent = datetime.now() - start
            elapsed += time_spent
            # save the model if it is the best so far
            if loss < min_loss:
                best_e, best_metric, min_loss = epoch, dev_metric, loss
                crf_ae.save(args.crf_model)
                logger.info(f"{time_spent}s elapsed (saved)\n")
            else:
                logger.info(f"{time_spent}s elapsed\n")

        if min_loss == float("inf"):
            logger.info("No CRF AE epoch was saved; "
                        "saving the current model instead")
            crf_ae.save(args.crf_model)

        logger.info(f"max score of CRF is at epoch {best_e}")
        logger.info(f"{'eval:':10} Loss: {min_loss:.4f} {best_metric}")
        return elapsed

    def train_feature_hmm(self, model, loader, optimizer, scheduler):
        """Run one training epoch of the Feature HMM.

        The scheduler is stepped once per batch, right after the optimizer.

        Args:
            model: Feature HMM. ``model(words)`` returns
                ``(emits, start, transitions, end)``.
            loader: batches whose first element is the word tensor.
            optimizer: optimizer over ``model``'s parameters.
            scheduler: learning-rate scheduler attached to ``optimizer``.
        """
        model.train()
        bar = progress_bar(loader)
        for words, *_ in bar:
            optimizer.zero_grad()
            # True at real (non-padding) token positions
            mask = words.ne(self.args.pad_index)
            emits, start, transitions, end = model(words)
            # compute loss
            loss = model.loss(emits, start, transitions, end, mask)
            loss.backward()
            # clip the gradient norm before the update
            nn.utils.clip_grad_norm_(model.parameters(), self.args.hmm_clip)
            optimizer.step()
            scheduler.step()
            bar.set_postfix_str(
                f" lr: {scheduler.get_last_lr()[0]:.4e}, loss: {loss.item():.4f}"
            )

    @torch.no_grad()
    def evaluate_feature_hmm(self, model, loader):
        """Evaluate the Feature HMM without gradients.

        Args:
            model: Feature HMM.
            loader: batches whose first element is the word tensor and
                whose last element is the gold label tensor.

        Returns:
            tuple: ``(loss, metric)``. ``loss`` is the sum of batch losses
            divided by the number of sentences (``len(words)`` per batch).
            ``metric`` is a ``ManyToOneMetric`` over non-padding tokens.
        """
        model.eval()
        metric = ManyToOneMetric(self.args.n_labels, self.args.device)

        total_loss = 0
        sent_count = 0
        for words, *_, labels in progress_bar(loader):
            sent_count += len(words)
            mask = words.ne(self.args.pad_index)
            emits, start, transitions, end = model(words)
            # compute loss
            loss = model.loss(emits, start, transitions, end, mask)
            # predict
            predicts = model.predict(emits, start, transitions, end, mask)
            metric(predicts=predicts[mask], golds=labels[mask])
            total_loss += loss.item()
        total_loss /= sent_count
        return total_loss, metric

    def train_crf(self, model, loader, optimizer, scheduler):
        """Run one epoch fitting the CRF-AE encoder to Feature HMM tags.

        For each batch, the model's embedded Feature HMM predicts tag
        sequences without gradients. Those tags are the targets of
        ``model.crf_loss`` on the encoder emissions.

        Args:
            model: CRF-AE with a ``feature_hmm`` attribute.
            loader: batches of ``(words, chars, labels)``; gold labels are
                ignored.
            optimizer: optimizer over ``model``'s parameters.
            scheduler: learning-rate scheduler, stepped once per batch.
        """
        model.train()
        bar = progress_bar(loader)
        for words, chars, _ in bar:
            mask = words.ne(self.args.pad_index)
            # use feature hmm to generate labels
            with torch.no_grad():
                emits, start, transitions, end = model.feature_hmm(words)
                labels = model.feature_hmm.predict(emits, start, transitions,
                                                   end, mask)

            optimizer.zero_grad()

            encoder_emits = model(words, chars)
            # compute loss
            loss = model.crf_loss(encoder_emits, labels, mask)

            loss.backward()
            # clip the gradient norm before the update
            nn.utils.clip_grad_norm_(model.parameters(), self.args.crf_clip)
            optimizer.step()
            scheduler.step()
            bar.set_postfix_str(
                f" lr: {scheduler.get_last_lr()[0]:.4e}, loss: {loss.item():.4f}"
            )

    def train_crf_ae(self, model, loader, optimizer, scheduler):
        """Run one training epoch of the full CRF-AE using ``model.loss``.

        Args:
            model: CRF-AE.
            loader: batches of ``(words, chars, labels)``; gold labels are
                ignored.
            optimizer: optimizer over ``model``'s parameters.
            scheduler: learning-rate scheduler, stepped once per batch.
        """
        model.train()
        bar = progress_bar(loader)
        for words, chars, _ in bar:
            optimizer.zero_grad()
            mask = words.ne(self.args.pad_index)
            encoder_emits = model(words, chars)
            # compute loss
            loss = model.loss(words, encoder_emits, mask)

            loss.backward()
            # clip the gradient norm before the update
            nn.utils.clip_grad_norm_(model.parameters(), self.args.crf_clip)
            optimizer.step()
            scheduler.step()
            bar.set_postfix_str(
                f" lr: {scheduler.get_last_lr()[0]:.4e}, loss: {loss.item():.4f}"
            )

    @torch.no_grad()
    def evaluate_crf_ae(self, model, loader):
        """Evaluate the CRF-AE without gradients.

        Args:
            model: CRF-AE.
            loader: batches of ``(words, chars, labels)``.

        Returns:
            tuple: ``(loss, metric)``. ``loss`` is the sum of batch losses
            divided by the number of sentences (``len(words)`` per batch).
            ``metric`` is a ``ManyToOneMetric`` over non-padding tokens.
        """
        model.eval()

        total_loss = 0
        metric = ManyToOneMetric(self.args.n_labels, self.args.device)
        sent_count = 0
        for words, chars, labels in progress_bar(loader):
            sent_count += len(words)
            mask = words.ne(self.args.pad_index)
            encoder_emits = model(words, chars)
            # compute loss
            loss = model.loss(words, encoder_emits, mask)
            # predict
            predicts = model.predict(words, encoder_emits, mask)
            metric(predicts=predicts[mask], golds=labels[mask])
            total_loss += loss.item()
        total_loss /= sent_count
        return total_loss, metric

    @torch.no_grad()
    def predict_crf_ae(self, model, loader):
        """Predict cluster tags for every sentence in ``loader``.

        Args:
            model: CRF-AE.
            loader: batches of ``(words, chars, labels)``; gold labels are
                ignored.

        Returns:
            dict: ``{'preds': tags}``. ``tags`` has one list per sentence,
            in loader order, holding one ``"#C<id>#"`` string per
            non-padding token.
        """
        model.eval()

        sentence_preds = []
        for words, chars, _ in progress_bar(loader):
            mask = words.ne(self.args.pad_index)
            lens = mask.sum(1).tolist()
            encoder_emits = model(words, chars)
            predicts = model.predict(words, encoder_emits, mask)
            # flatten the non-padding predictions, then cut them back into
            # one tensor per sentence
            sentence_preds.extend(predicts[mask].split(lens))

        labels = [[f"#C{t}#" for t in seq.tolist()] for seq in sentence_preds]
        return {'preds': labels}
