import os

from torch._C import dtype
from parser.cmds.cmd import CMD
from parser.models import CRFAE
from parser.utils.data import Dataset, Field, CoNLL
from parser.utils.logging import get_logger, progress_bar

import torch

import matplotlib.pyplot as plt
import seaborn as sns

# Module-level logger, requested from the project's logging helper under
# this module's name.
logger = get_logger(__name__)


class Predict(CMD):
    """
    Sub-command that runs a saved CRF-AE model over a CoNLL file.

    It loads previously saved fields, builds a dataset from ``args.data``,
    runs the model over it, writes a scatter plot of the first two
    components of the model's second output to ``.repr.png`` and saves the
    dataset sentences next to the model file.
    """

    def create_subparser(self, parser, mode):
        """
        Create the subparser for this command and register its arguments.

        Args:
            parser: parent object exposing ``add_parser``
            mode (str): name under which the subparser is registered

        Returns:
            the newly created subparser
        """
        # The help strings used to describe training (copy-paste leftover).
        subparser = parser.add_parser(mode,
                                      help='Predict with a trained model.')
        subparser.add_argument('--data',
                               default='data/ptb/total.conll',
                               help='path to the CoNLL file to predict on')
        return subparser

    def __call__(self, args):
        """
        Run prediction.

        Reads ``save``, ``fields``, ``without_feature``, ``data``,
        ``batch_size``, ``n_buckets`` and ``crf_model`` from ``args`` and
        updates ``args`` in place with vocabulary and feature sizes.

        Args:
            args: configuration object supporting attribute access and
                ``update``

        Raises:
            FileNotFoundError: if ``args.fields`` does not exist
        """
        super(Predict, self).__call__(args)

        # create dir for files; makedirs tolerates an existing directory and
        # also creates missing parents, without a check-then-create race
        os.makedirs(args.save, exist_ok=True)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not os.path.exists(args.fields):
            raise FileNotFoundError(f"fields file not found: {args.fields}")

        # SECURITY: torch.load unpickles arbitrary objects; only point
        # ``args.fields`` at files from a trusted source.
        self.fields = torch.load(args.fields)
        self.word_field, self.char_field = self.fields.FORM
        self.label_field = self.fields.POS
        self.pred_field = self.fields.FEATS
        # Placeholder field occupying the HEAD column of the rebuilt layout.
        self.reper_field = Field('repers', is_placeholder=True)
        self.fields = CoNLL(FORM=(self.word_field, self.char_field),
                            POS=self.label_field,
                            FEATS=self.pred_field,
                            HEAD=self.reper_field)
        args.update({
            'n_words': self.word_field.vocab.n_init,
            'n_labels': len(self.label_field.vocab),
            'pad_index': self.word_field.pad_index,
            'unk_index': self.word_field.unk_index,
        })

        if not args.without_feature:
            args.update({
                'n_word_features':
                self.word_field.n_word_features,
                'n_unigram_features':
                self.word_field.n_unigram_features,
                'n_bigram_features':
                self.word_field.n_bigram_features,
                'n_trigram_features':
                self.word_field.n_trigram_features,
                'n_morphology_features':
                self.word_field.n_morphology_features
            })

        logger.info('\n' + str(args))

        data = Dataset(self.fields, args.data)
        data.build(args.batch_size, n_buckets=args.n_buckets)

        logger.info(f"Dateset {data}")

        crf_ae = CRFAE.load(args.crf_model).to(self.args.device)
        preds = self.predict_crf_ae(crf_ae, data.loader)
        # NOTE(review): predict_crf_ae currently returns an empty dict, so
        # this loop does nothing and the saved file contains the sentences
        # without any predicted values attached -- confirm intended output.
        for name, value in preds.items():
            setattr(data, name, value)
        self.fields.save(f"{args.crf_model}-{self.timestamp}.predict.conllx",
                         data.sentences)

    @torch.no_grad()
    def predict_crf_ae(self, model, loader):
        """
        Run the model over every batch and plot its second output.

        For each batch the model is called as ``model(words, chars)``, which
        yields ``(encoder_emits, r)``, and ``model.predict`` yields one
        predicted index per position. Values at non-padding positions are
        gathered on the CPU and drawn as a scatter plot coloured by the
        predicted index. If the loader yields no batches, no plot is written.

        Args:
            model: model exposing ``eval``, ``__call__`` and ``predict``
            loader: iterable yielding ``(words, chars, tags)`` batches

        Returns:
            dict: always empty at present.
            NOTE(review): nothing is stored in it, so callers receive no
            per-sentence predictions; filling it needs the attribute names
            and value format the dataset expects -- TODO confirm.
        """
        model.eval()

        preds = {}
        points = []
        point_tags = []

        for words, chars, _ in progress_bar(loader):
            # True on every position that is not padding
            mask = words.ne(self.args.pad_index)
            encoder_emits, r = model(words, chars)
            predicts = model.predict(words, encoder_emits, mask)
            # Move to CPU right away so no device tensors are kept alive
            # across batches.
            points.append(r[mask].cpu())
            point_tags.extend(str(t) for t in predicts[mask].cpu().tolist())

        if points:
            self._plot_representations(torch.cat(points, dim=0), point_tags)
        else:
            # torch.cat rejects an empty list, so there is nothing to draw.
            logger.warning('No batches to predict on; skipping the plot.')

        return preds

    @staticmethod
    def _plot_representations(points, tags, path='.repr.png'):
        """
        Save a scatter plot of the first two components of ``points``.

        Args:
            points: tensor indexed as ``points[..., 0]`` / ``points[..., 1]``
                for the x and y coordinates
            tags (list[str]): hue value for every point
            path (str): output image path
        """
        sns.set(style="white")

        # Set up the matplotlib figure
        fig, ax = plt.subplots(figsize=(20, 20))
        try:
            sns.scatterplot(
                x=points[..., 0],
                y=points[..., 1],
                hue=tags,
                ax=ax,
                alpha=0.1,
            )
            ax.margins(0, 0)
            fig.subplots_adjust(left=0.04, bottom=0.04, right=0.96, top=0.96)
            fig.savefig(path)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
