import os
from datetime import datetime, timedelta
from parser.cmds.cmd import CMD
from parser.models import CRFAE, FeatureHMM
from parser.utils.common import pad, unk
from parser.utils.data import CoNLL, Dataset, ElmoField, FeatureField, Field
from parser.utils.fn import heatmap, replace_digit_fn
from parser.utils.logging import get_logger, progress_bar
from parser.utils.metric import UnsupervisedPOSMetric, Metric

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

# module-level logger shared by every method of the training command below
logger = get_logger(__name__)


class Train(CMD):
    def create_subparser(self, parser, mode):
        """
        create a subparser and add arguments in this subparser
        Args:
            parser: parent parser
            mode (str): mode of this subparser

        Returns:

        """
        subparser = parser.add_parser(mode, help='Train model.')
        subparser.add_argument('--train',
                               default='data/ptb/train.conll',
                               help='path to train file')
        subparser.add_argument('--evaluate',
                               default='data/ptb/dev.conll',
                               help='path to evaluate file')
        subparser.add_argument('--test',
                               default='data/ptb/test.conll',
                               help='path to evaluate file')
        subparser.add_argument('--use_cached',
                               '-u',
                               action="store_true",
                               help='use cached hmm model')
        subparser.add_argument('--rand_init',
                               action="store_true",
                               help='random init CRF-AE model')
        subparser.add_argument('--feature-hmm-model',
                               default=None,
                               help='path to evaluate file')
        subparser.add_argument('--hmm-epochs',
                               default=50,
                               type=int,
                               help='max num of threads')
        subparser.add_argument('--crf-epochs',
                               default=50,
                               type=int,
                               help='max num of threads')
        subparser.add_argument('--crf-init-epochs',
                               default=5,
                               type=int,
                               help='max num of threads')
        return subparser

    def __call__(self, args):
        """
        Run the whole training pipeline.

        Steps, in order: prepare the save directory, build the fields (or load
        the cached ones), build the data loaders, obtain a feature HMM (freshly
        trained, loaded from disk, or randomly initialised with
        ``--rand_init``), train the CRF-AE model, and finally write the
        predictions of the model saved at ``args.crf_model`` on the evaluate
        set.

        Args:
            args: parsed arguments / config; updated in place with the
                vocabulary sizes and, unless ``without_feature`` is set, the
                feature counts of the word field

        Returns:
            None

        """
        super(Train, self).__call__(args)

        # create dir for files; unlike os.mkdir, makedirs also creates missing
        # parent directories, and exist_ok tolerates a concurrent creation
        if not os.path.exists(args.save):
            os.makedirs(args.save, exist_ok=True)

        # unless an existing cached model was requested, the feature hmm is
        # stored inside the save dir
        if (args.feature_hmm_model is None or not args.use_cached
                or not os.path.exists(args.feature_hmm_model)):
            args.feature_hmm_model = os.path.join(args.save, 'feature_hmm')

        if not args.use_cached or not os.path.exists(args.fields):
            # fields
            # word field
            if args.without_feature:
                self.word_field = Field('words',
                                        pad=pad,
                                        unk=unk,
                                        fn=replace_digit_fn)
            else:
                self.word_field = FeatureField('words',
                                               pad=pad,
                                               unk=unk,
                                               fn=replace_digit_fn)
            # label field
            self.label_field = Field('labels')
            self.pred_field = Field('preds', is_placeholder=True)
            # char field
            self.char_field = ElmoField("chars")
            self.fields = CoNLL(FORM=(self.word_field, self.char_field),
                                POS=self.label_field,
                                FEATS=self.pred_field)

            # load dataset
            train = Dataset(self.fields, args.train)
            evaluate = Dataset(self.fields, args.evaluate)
            test = Dataset(self.fields, args.test)

            # build vocab (from the training data only)
            self.word_field.build(train, args.min_freq)
            self.label_field.build(train)

            # save fields
            torch.save(self.fields, args.fields)
        else:
            # reuse the cached fields and recover the individual field objects
            self.fields = torch.load(args.fields)
            self.word_field, self.char_field = self.fields.FORM
            self.label_field = self.fields.POS
            self.pred_field = self.fields.FEATS
            train = Dataset(self.fields, args.train)
            evaluate = Dataset(self.fields, args.evaluate)
            test = Dataset(self.fields, args.test)

        # set the data loaders
        train.build(args.batch_size,
                    n_buckets=args.n_buckets,
                    shuffle=True,
                    seed=self.args.seed)
        evaluate.build(args.batch_size, n_buckets=args.n_buckets)
        test.build(args.batch_size, n_buckets=args.n_buckets)

        logger.info(f"Train Dateset {train}")
        logger.info(f"Dev   Dateset {evaluate}")
        logger.info(f"Test  Dateset {test}")

        # expose the vocabulary dependent sizes to the models through args
        args.update({
            'n_words': self.word_field.vocab.n_init,
            'n_labels': len(self.label_field.vocab),
            'pad_index': self.word_field.pad_index,
            'unk_index': self.word_field.unk_index,
        })

        if not args.without_feature:
            args.update({
                'n_word_features':
                self.word_field.n_word_features,
                'n_unigram_features':
                self.word_field.n_unigram_features,
                'n_bigram_features':
                self.word_field.n_bigram_features,
                'n_trigram_features':
                self.word_field.n_trigram_features,
                'n_morphology_features':
                self.word_field.n_morphology_features
            })

        logger.info(f"\n{args}")
        logger.info(
            f"Create the {'' if args.without_feature else 'Feature'} HMM model"
        )

        total_time = timedelta()

        if not args.rand_init:
            # train a feature hmm unless a cached one can be reused
            if not args.use_cached or not os.path.exists(
                    args.feature_hmm_model):
                feature_hmm = FeatureHMM(
                    args, None if args.without_feature else
                    self.word_field.features).to(self.args.device)
                logger.info(feature_hmm)
                total_time = self.train_feature_hmm(feature_hmm, args, train,
                                                    evaluate, test, total_time)
            # NOTE(review): presumably re-seeds the training loader so the CRF
            # stages see the same batches whether or not the hmm was trained
            # in this run -- confirm against Dataset.reset_loader
            train.reset_loader(args.batch_size,
                               shuffle=True,
                               seed=self.args.seed)
            # load best feature hmm model
            feature_hmm = FeatureHMM.load(args.feature_hmm_model).to(
                self.args.device)
        else:
            # random init: use an untrained feature hmm
            feature_hmm = FeatureHMM(
                args,
                None if args.without_feature else self.word_field.features).to(
                    self.args.device)

        # create the crf ae model
        logger.info("Create the CRF AE model")
        crf_ae = CRFAE(args, feature_hmm).to(self.args.device)
        logger.info(crf_ae)

        # the encoder warm-up on hmm predictions is skipped for random init
        if not args.rand_init:
            total_time = self.train_crf(crf_ae, args, train, evaluate, test,
                                        total_time)

        total_time = self.train_crf_ae(crf_ae, args, train, evaluate, test,
                                       total_time)

        # reload the checkpoint written by train_crf_ae and dump its
        # predictions on the evaluate set
        crf_ae = CRFAE.load(args.crf_model).to(self.args.device)
        preds = self.predict_crf_ae(crf_ae, evaluate.loader)
        for name, value in preds.items():
            setattr(evaluate, name, value)
        # built once so the saved file and the logged path always agree
        predict_path = f"{args.crf_model}-{self.timestamp}.predict.conllx"
        self.fields.save(predict_path, evaluate.sentences)
        logger.info(f"prediction have saved to {predict_path}")
        logger.info(f"{total_time}s elapsed")

    def train_crf_ae(self, crf_ae, args, train, evaluate, test, total_time):
        """
        Train the CRF-AE model and save it whenever the dev loss improves.

        Args:
            crf_ae: CRF-AE model to train
            args: config; the ``crf_*`` entries and ``recons_lr`` are read
            train: training dataset, its ``loader`` is iterated every epoch
            evaluate: dev dataset used to select the saved model
            test: test dataset, scored with the matching found on dev
            total_time: elapsed time accumulated so far

        Returns:
            ``total_time`` plus the time spent in every epoch

        """
        logger.info("Train CRF AE")
        # the feature hmm gets its own learning rate, every other group uses
        # the optimizer default ``args.crf_lr``
        param_groups = [{
            "params": crf_ae.feature_hmm.parameters(),
            "lr": args.recons_lr
        }]
        param_groups += [{"params": group} for group in (
            crf_ae.represent_ln.parameters(),
            crf_ae.encoder_emit_scorer.parameters(),
            crf_ae.encoder_emit_ln.parameters(),
            crf_ae.start,
            crf_ae.transitions,
            crf_ae.end,
        )]

        # the encoder is only optimised here when the model is randomly
        # initialised
        if args.rand_init:
            param_groups.append({"params": crf_ae.encoder.parameters()})

        optimizer = Adam(param_groups,
                         lr=args.crf_lr,
                         betas=(args.crf_mu, args.crf_nu),
                         eps=args.crf_epsilon,
                         weight_decay=args.crf_weight_decay)
        # the scheduler is stepped once per batch, so gamma is the per-batch
        # factor giving a decay of ``crf_decay`` every ``crf_decay_epochs``
        n_decay_steps = args.crf_decay_epochs * len(train.loader)
        scheduler = ExponentialLR(optimizer,
                                  gamma=args.crf_decay**(1 / n_decay_steps))

        best_e = 1
        best_metric, best_test_metric = Metric(), Metric()
        min_loss = min_test_loss = float("inf")

        for epoch in range(1, args.crf_epochs + 1):
            logger.info(f"Epoch {epoch} / {args.crf_epochs}:")
            epoch_begin = datetime.now()

            # train
            self.train_crf_ae_single(crf_ae, train.loader, optimizer,
                                     scheduler)

            # evaluate on dev
            dev_loss, dev_metric = self.evaluate_crf_ae(
                crf_ae, evaluate.loader)
            logger.info(f"{'dev:':10} Loss: {dev_loss:>8.4f} {dev_metric}")
            dev_m2o_match, dev_o2o_match = dev_metric.match

            # evaluate on test with the matching found on dev, so that the
            # test labels are never used
            test_loss, test_metric = self.evaluate_crf_ae(crf_ae, test.loader)
            test_metric.set_match(dev_m2o_match, dev_o2o_match)
            logger.info(f"{'test:':10} Loss: {test_loss:>8.4f} {test_metric}")

            heatmap(dev_metric.clusters.cpu(),
                    list(self.label_field.vocab.stoi.keys()),
                    f"{args.crf_model}.clusters",
                    match=dev_metric.match[-1])

            time_spent = datetime.now() - epoch_begin
            total_time = total_time + time_spent

            # keep the current epoch only when it lowers the dev loss
            if dev_loss < min_loss:
                best_e = epoch
                best_metric, best_test_metric = dev_metric, test_metric
                min_loss, min_test_loss = dev_loss, test_loss
                crf_ae.save(args.crf_model)
                logger.info(f"{time_spent}s elapsed (saved)\n")
            else:
                logger.info(f"{time_spent}s elapsed\n")

        logger.info(f"max score of CRF is at epoch {best_e}")
        logger.info(f"{'dev:':10} Loss: {min_loss:>8.4f} {best_metric}")
        logger.info(
            f"{'test:':10} Loss: {min_test_loss:>8.4f} {best_test_metric}")
        return total_time

    def train_crf(self, crf_ae, args, train, evaluate, test, total_time):
        """
        Run the encoder warm-up stage for ``args.crf_init_epochs`` epochs via
        ``train_crf_single`` and report dev/test results after every epoch.

        No model is saved here; a heatmap of the last test metric is drawn
        once the loop is over.

        Args:
            crf_ae: CRF-AE model to train
            args: config; the ``crf_init_*`` entries are read
            train: training dataset, its ``loader`` is iterated every epoch
            evaluate: dev dataset
            test: test dataset, scored with the matching found on dev
            total_time: elapsed time accumulated so far

        Returns:
            ``total_time`` plus the time spent in every epoch

        """
        logger.info("Train CRF AE encoder")

        optimizer = Adam(crf_ae.parameters(),
                         lr=args.crf_init_lr,
                         betas=(args.crf_init_mu, args.crf_init_nu),
                         eps=args.crf_init_epsilon,
                         weight_decay=args.crf_init_weight_decay)
        # the scheduler is stepped once per batch, hence the per-batch gamma
        n_decay_steps = args.crf_init_decay_epochs * len(train.loader)
        scheduler = ExponentialLR(
            optimizer, gamma=args.crf_init_decay**(1 / n_decay_steps))

        # initial values, replaced by the results of every epoch
        test_loss, test_metric = 0, Metric()

        for epoch in range(1, args.crf_init_epochs + 1):
            logger.info(f"Epoch {epoch} / {args.crf_init_epochs}:")
            epoch_begin = datetime.now()

            # train
            self.train_crf_single(crf_ae, train.loader, optimizer, scheduler)

            # evaluate on dev
            dev_loss, dev_metric = self.evaluate_crf_ae(
                crf_ae, evaluate.loader)
            logger.info(f"{'dev:':10} Loss: {dev_loss:>8.4f} {dev_metric}")
            dev_m2o_match, dev_o2o_match = dev_metric.match

            # evaluate on test with the matching found on dev, so that the
            # test labels are never used
            test_loss, test_metric = self.evaluate_crf_ae(crf_ae, test.loader)
            test_metric.set_match(dev_m2o_match, dev_o2o_match)
            logger.info(f"{'test:':10} Loss: {test_loss:>8.4f} {test_metric}")

            time_spent = datetime.now() - epoch_begin
            total_time = total_time + time_spent
            logger.info(f"{time_spent}s elapsed\n")

        heatmap(test_metric.clusters.cpu(),
                list(self.label_field.vocab.stoi.keys()),
                f"{args.crf_model}.init.clusters",
                match=test_metric.match[-1])
        return total_time

    def train_feature_hmm(self, feature_hmm, args, train, evaluate, test,
                          total_time):
        """
        Train the feature HMM and save it to ``args.feature_hmm_model``
        whenever the dev loss improves.

        Args:
            feature_hmm: HMM model to train
            args: config; the ``hmm_*`` entries are read
            train: training dataset, its ``loader`` is iterated every epoch
            evaluate: dev dataset used to select the saved model
            test: test dataset, scored with the matching found on dev
            total_time: elapsed time accumulated so far

        Returns:
            ``total_time`` plus the time spent in every epoch

        """
        # book-keeping of the best epoch seen so far
        min_epoch = 0
        min_loss = min_test_loss = float("inf")
        best_metric, best_test_metric = Metric(), Metric()

        optimizer = Adam(feature_hmm.parameters(),
                         lr=args.hmm_lr,
                         betas=(args.hmm_mu, args.hmm_nu),
                         eps=args.hmm_epsilon,
                         weight_decay=args.hmm_weight_decay)
        # the scheduler is stepped once per batch, hence the per-batch gamma
        n_decay_steps = args.hmm_decay_epochs * len(train.loader)
        scheduler = ExponentialLR(optimizer,
                                  gamma=args.hmm_decay**(1 / n_decay_steps))

        for epoch in range(1, args.hmm_epochs + 1):
            logger.info(f"Epoch {epoch} / {args.hmm_epochs}:")
            epoch_begin = datetime.now()

            # train
            self.train_feature_hmm_single(feature_hmm, train.loader, optimizer,
                                          scheduler)

            # evaluate on dev
            dev_loss, dev_metric = self.evaluate_feature_hmm(
                feature_hmm, evaluate.loader)
            logger.info(f"{'dev:':10} Loss: {dev_loss:>8.4f} {dev_metric}")
            dev_m2o_match, dev_o2o_match = dev_metric.match

            # evaluate on test with the matching found on dev, so that the
            # test labels are never used
            test_loss, test_metric = self.evaluate_feature_hmm(
                feature_hmm, test.loader)
            test_metric.set_match(dev_m2o_match, dev_o2o_match)
            logger.info(f"{'test:':10} Loss: {test_loss:>8.4f} {test_metric}")

            time_spent = datetime.now() - epoch_begin
            total_time = total_time + time_spent

            # keep the current epoch only when it lowers the dev loss
            if dev_loss < min_loss:
                min_epoch = epoch
                min_loss, min_test_loss = dev_loss, test_loss
                best_metric, best_test_metric = dev_metric, test_metric
                feature_hmm.save(args.feature_hmm_model)
                logger.info(f"{time_spent}s elapsed (saved)\n")
            else:
                logger.info(f"{time_spent}s elapsed\n")

        logger.info(
            f"max score of {'' if args.without_feature else 'Feature'} HMM is at epoch {min_epoch}"
        )
        heatmap(best_test_metric.clusters.cpu(),
                list(self.label_field.vocab.stoi.keys()),
                f"{args.feature_hmm_model}.best.clusters",
                match=best_test_metric.match[-1])
        logger.info(f"{'Best HMM:':10}")
        logger.info(f"{'dev:':10} Loss: {min_loss:>8.4f} {best_metric}")
        logger.info(
            f"{'test:':10} Loss: {min_test_loss:>8.4f} {best_test_metric}")
        return total_time

    def train_feature_hmm_single(self, model, loader, optimizer, scheduler):
        """
        Train the feature HMM for one pass over ``loader``.

        Args:
            model: feature HMM; called on the words and providing ``loss``
            loader: iterable of batches whose first item is the word tensor
            optimizer: optimizer stepped after every batch
            scheduler: lr scheduler stepped after every batch

        Returns:
            None

        """
        model.train()
        pad_index = self.args.pad_index
        max_norm = self.args.hmm_clip
        progress = progress_bar(loader)
        for words, *_unused in progress:
            # positions holding a real (non-pad) token
            mask = words.ne(pad_index)

            optimizer.zero_grad()
            emits, start, transitions, end = model(words)
            loss = model.loss(emits, start, transitions, end, mask)
            loss.backward()

            # clip the gradient norm before the update
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            scheduler.step()

            progress.set_postfix_str(
                f" lr: {scheduler.get_last_lr()[0]:.4e}, loss: {loss.item():.4f}"
            )

    @torch.no_grad()
    def evaluate_feature_hmm(self, model, loader):
        """
        Evaluate the feature HMM on every batch of ``loader``.

        Args:
            model: feature HMM providing ``loss`` and ``predict``
            loader: iterable of batches whose first item is the word tensor
                and whose last item is the gold label tensor

        Returns:
            tuple of the summed batch losses divided by the number of
            sentences, and the filled ``UnsupervisedPOSMetric``

        """
        model.eval()
        metric = UnsupervisedPOSMetric(self.args.n_labels, self.args.device)

        loss_sum = 0
        n_sents = 0
        for words, *_unused, labels in loader:
            n_sents += len(words)
            # positions holding a real (non-pad) token
            mask = words.ne(self.args.pad_index)
            emits, start, transitions, end = model(words)
            loss = model.loss(emits, start, transitions, end, mask)
            predicts = model.predict(emits, start, transitions, end, mask)
            # only non-pad positions are scored
            metric(predicts=predicts[mask], golds=labels[mask])
            loss_sum += loss.item()

        return loss_sum / n_sents, metric

    def train_crf_single(self, model, loader, optimizer, scheduler):
        """
        Train the CRF-AE model for one pass over ``loader`` with the
        predictions of its own feature HMM as target labels.

        Args:
            model: CRF-AE model exposing ``feature_hmm`` and ``crf_loss``
            loader: iterable of ``(words, chars, _)`` batches
            optimizer: optimizer stepped after every batch
            scheduler: lr scheduler stepped after every batch

        Returns:
            None

        """
        model.train()
        pad_index = self.args.pad_index
        max_norm = self.args.crf_clip
        hmm = model.feature_hmm
        progress = progress_bar(loader)
        for words, chars, _unused in progress:
            # positions holding a real (non-pad) token
            mask = words.ne(pad_index)

            # the feature hmm only produces the targets, so no gradient is
            # tracked for it
            with torch.no_grad():
                emits, start, transitions, end = hmm(words)
                labels = hmm.predict(emits, start, transitions, end, mask)

            optimizer.zero_grad()
            encoder_emits = model(words, chars)
            loss = model.crf_loss(encoder_emits, labels, mask)
            loss.backward()

            # clip the gradient norm before the update
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            scheduler.step()

            progress.set_postfix_str(
                f" lr: {scheduler.get_last_lr()[0]:.4e}, loss: {loss.item():.4f}"
            )

    def train_crf_ae_single(self, model, loader, optimizer, scheduler):
        """
        Train the CRF-AE model for one pass over ``loader`` with its own
        ``loss``.

        Args:
            model: CRF-AE model; called on ``(words, chars)``
            loader: iterable of ``(words, chars, _)`` batches
            optimizer: optimizer stepped after every batch
            scheduler: lr scheduler stepped after every batch

        Returns:
            None

        """
        model.train()
        pad_index = self.args.pad_index
        max_norm = self.args.crf_clip
        progress = progress_bar(loader)
        for words, chars, _unused in progress:
            # positions holding a real (non-pad) token
            mask = words.ne(pad_index)

            optimizer.zero_grad()
            encoder_emits = model(words, chars)
            loss = model.loss(words, encoder_emits, mask)
            loss.backward()

            # clip the gradient norm before the update
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            scheduler.step()

            progress.set_postfix_str(
                f" lr: {scheduler.get_last_lr()[0]:.4e}, loss: {loss.item():.4f}"
            )

    @torch.no_grad()
    def evaluate_crf_ae(self, model, loader):
        """
        Evaluate the CRF-AE model on every batch of ``loader``.

        Args:
            model: CRF-AE model providing ``loss`` and ``predict``
            loader: iterable of ``(words, chars, labels)`` batches

        Returns:
            tuple of the summed batch losses divided by the number of
            sentences, and the filled ``UnsupervisedPOSMetric``

        """
        model.eval()

        loss_sum = 0
        metric = UnsupervisedPOSMetric(self.args.n_labels, self.args.device)

        n_sents = 0
        for words, chars, labels in loader:
            n_sents += len(words)
            # positions holding a real (non-pad) token
            mask = words.ne(self.args.pad_index)
            encoder_emits = model(words, chars)
            loss = model.loss(words, encoder_emits, mask)
            predicts = model.predict(words, encoder_emits, mask)
            # only non-pad positions are scored
            metric(predicts=predicts[mask], golds=labels[mask])
            loss_sum += loss.item()

        return loss_sum / n_sents, metric

    @torch.no_grad()
    def predict_crf_ae(self, model, loader):
        """
        Predict a tag sequence for every sentence of ``loader``.

        Args:
            model: CRF-AE model providing ``predict``
            loader: iterable of ``(words, chars, _)`` batches

        Returns:
            dict with the single key ``'preds'`` holding one list of
            ``"#C<index>#"`` strings per sentence, in loader order

        """
        model.eval()

        tagged = []
        for words, chars, _unused in progress_bar(loader):
            mask = words.ne(self.args.pad_index)
            # number of non-pad tokens in each row of the batch
            sent_lens = mask.sum(1).tolist()
            encoder_emits = model(words, chars)
            predicts = model.predict(words, encoder_emits, mask)
            # cut the flat run of non-pad predictions back into sentences
            for seq in predicts[mask].split(sent_lens):
                tagged.append([f"#C{t}#" for t in seq.tolist()])

        return {'preds': tagged}
