import collections
import copy
import os
import re
import tempfile
from abc import abstractmethod

import numpy as np
import regex
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from tqdm import tqdm

from algo import mst
from data import SoftSequenceLengthSampler, SequenceLengthBatchSampler, SoftDatasetCollection
from models import Parser
from utils import SoftConfigs, Logger, Meter, get_header, get_rng_state, set_random_seed, set_rng_state, \
    SoftCrossEntropyLoss, normalize_row
from main import Executor as BaseExecutor

# Small constant added to the model's arc/label probabilities before they are passed to
# SoftCrossEntropyLoss with is_logits=False; presumably keeps a log() inside that loss
# finite when a probability underflows to 0 — TODO confirm against utils.SoftCrossEntropyLoss.
EPS = 1e-10

""" base class for trainer and evaluator """ 
class Executor(object):
    def __init__(self):
        self.configs = None

    # TODO(template): define data state
    def __data_state__(self):
        return self.dataset_collection.state_dict()

    @abstractmethod
    def __dataset_setup__(self):
        pass

    # TODO(template): setup data loaders
    def __dataloader_setup__(self):
        configs = self.configs
        self.dataloader = collections.defaultdict()
        for split in self.dataset_collection.splits:
            if split == 'train':
                sampler = SoftSequenceLengthSampler(self.dataset_collection[split])
                batch_sampler = SequenceLengthBatchSampler(sampler, configs.batch_size)
                self.dataloader[split] = DataLoader(
                    self.dataset_collection[split], batch_sampler=batch_sampler, collate_fn=self.dataset_collection._collator[split]
                )
            else:
                self.dataloader[split] = DataLoader(
                    self.dataset_collection[split], batch_size=16, collate_fn=self.dataset_collection._collator[split]
                )

    # TODO(template): setup models
    def __model_setup__(self):
        configs = self.configs
        self.model = Parser(configs)
        self.best_model = None
        self.best_performance = None

    # TODO(template): load state dict for the best model
    def __load_best_model__(self):
        if self.best_model is not None:
            self.model.load_state_dict(self.best_model)
        else:
            self.logger.warning('Loading from empty best model.')
        
    # TODO(template): forward batch implementation
    def __forward_batch__(self, sentences, masks):
        return self.model(sentences.to(self.configs.device), masks.to(self.configs.device))

    # TODO(template): dev metric computation
    def __evaluate_batch__(self, batch):
        sentences, masks, arcs, labels, info = batch
        with torch.no_grad():
            arc_scores, label_scores, word_masks = self.__forward_batch__(sentences, masks)
        word_masks = word_masks.cpu()
        word_masks[:, 0] = False
        # arcs
        trees = mst(arc_scores, word_masks)
        # labels
        n_examples = arc_scores.shape[0] * arc_scores.shape[1]
        label_scores = label_scores.view(n_examples, -1, label_scores.shape[-1])
        label_preds = label_scores[torch.arange(n_examples), (trees * word_masks).reshape(-1)].reshape(*labels.shape, -1).argmax(-1).cpu()
        info = copy.deepcopy(info)
        for i, sent in enumerate(info):
            for j, _ in enumerate(sent):
                info[i][j]['head'] = '_'
                info[i][j]['deprel'] = '_'
            if masks[i].sum() == 2:     # sentence too long, use trivial chain parsing
                for j, w in enumerate(sent):
                    if isinstance(w['id'], int):
                        info[i][j]['head'] = w['id'] - 1
                        info[i][j]['deprel'] = '_'
            else:
                for j, w in enumerate(sent):
                    if isinstance(w['id'], int):
                        info[i][j]['head'] = trees[i][w['id']].item()
                        info[i][j]['deprel'] = self.dataset_collection.collator.id2label.get(label_preds[i][w['id']].item(), '_')
        return info

    # TODO(template): standard evaluation procedure
    def evaluate(self, dataloader):
        self.model.eval()
        outputs = list()
        ground_truths = list()
        bar = tqdm(dataloader)
        bar.set_description('Evaluation:')
        for batch in bar:
            ground_truths.extend([tree.serialize() for tree in batch[-1]])
            batch_info = self.__evaluate_batch__(batch)
            outputs.extend([tree.serialize() for tree in batch_info])
        tmpdir = tempfile.TemporaryDirectory(prefix='ud-eval')
        with open(f'{tmpdir.name}/gt.trees', 'w', encoding=self.configs.encoding) as fout:
            print(''.join(ground_truths).strip() + '\n', file=fout)
            fout.close()
        with open(f'{tmpdir.name}/pr.trees', 'w', encoding=self.configs.encoding) as fout:
            print(''.join(outputs).strip() + '\n', file=fout)
            fout.close()
        try:
            os.system(f'python {self.configs.eval_script} {tmpdir.name}/gt.trees {tmpdir.name}/pr.trees -v > {tmpdir.name}/eval.res')
            info = ' '.join(open(f'{tmpdir.name}/eval.res').readlines()).replace('\n', '#')
            uas = regex.match('.*UAS[^#]+[ \|]([0-9]+\.[0-9]+)[^0-9]*#', info).group(1)
            las = regex.match('.*LAS[^#]+[ \|]([0-9]+\.[0-9]+)[^0-9]*#', info).group(1)
        except:
            os.system(f'python -m snippets.partial_eval {tmpdir.name}/gt.trees {tmpdir.name}/pr.trees -v > {tmpdir.name}/eval.res')
            info = ' '.join(open(f'{tmpdir.name}/eval.res').readlines()).replace('\n', '#')
            uas = regex.match('.*UAS[^#]+[ \|]([0-9]+\.[0-9]+)[^0-9]*#', info).group(1)
            las = regex.match('.*LAS[^#]+[ \|]([0-9]+\.[0-9]+)[^0-9]*#', info).group(1)
        tmpdir.cleanup()
        return float(las), float(uas)
    
    # TODO(template): standard final test procedure
    def test(self):
        self.__load_best_model__()
        las, uas = self.evaluate(self.dataloader['test'])
        self.logger.info(f'Test: LAS={las:.1f}, UAS={uas:.1f}')
        return las, uas


""" trainer: training a parser """
class Trainer(Executor):
    def __init__(self, configs):
        super(Trainer, self).__init__()
        self.configs = configs
        self.model_path = f'{configs.model_path}/model.ckpt'
        self.log_path = f'{configs.model_path}/train.log'
        os.system(f'mkdir -p {configs.model_path}')
        self.logger = Logger(self.log_path)
        self.logger.info(get_header())
        self.logger.info('Configs:' + repr(configs))
        self.__load_base_model__(configs.base_model_path)
        if os.path.exists(self.model_path):
            # TODO(F): check and update the dataset loading process
            self.__load__()
        else:
            self.__dataset_setup__(self.base_model.collator_state_dict)
            self.__dataloader_setup__()
            self.__model_setup__()
            self.__optimizer_setup__()
            if configs.start_model is not None:
                try:
                    state_dict = torch.load(f'{configs.start_model}/model.ckpt')
                    self.model.load_state_dict(state_dict['best_state']['best_model'])
                    self.logger.info(f'Loaded model from {configs.start_model}.')
                except:
                    self.logger.warning(f'Model initialization specified as {configs.start_model}, but not loaded.')
    
    # TODO(template): define current model state, include optimizer etc. 
    def __current_model_state__(self):
        return {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
            'epoch_id': self.epoch_id, 
            'rng_state': get_rng_state()
        }
    
    # TODO(template): define best model state
    def __best_model_state__(self):
        return {
            'best_model': self.best_model,
            'best_performance': self.best_performance
        }
    
    def __state_dict__(self):
        state_dict = {
            'configs': self.configs,
            'data_state': self.__data_state__(), 
            'current_state': self.__current_model_state__(), 
            'best_state': self.__best_model_state__()
        }
        return state_dict

    def __check_config__(self, configs):
        warn_flag = False
        for key in configs:
            if key not in self.configs or configs[key] != self.configs[key]:
                self.logger.warning(f'Different configs: {key} - (old) {configs.get(key, None)}; (new) {self.configs.get(key, None)}')
                warn_flag = True
        for key in self.configs:
            if key not in configs:
                self.logger.warning(f'Different configs: {key} - (old) {configs.get(key, None)}; (new) {self.configs.get(key, None)}')
                warn_flag = True
        if not warn_flag:
            self.logger.info('Config check passed.')

    # TODO(template): load state dict from pretrained model
    def __load__(self):
        state_dict = torch.load(self.model_path)
        # data 
        data_state = state_dict['data_state']
        self.dataset_collection = SoftDatasetCollection.from_state_dict(data_state)
        self.configs.n_labels = len(self.dataset_collection.collator.label_dict)
        self.__dataloader_setup__()
        # check configs 
        self.__check_config__(state_dict['configs'])
        # current state
        current_state = state_dict['current_state']
        self.__model_setup__()
        self.__optimizer_setup__()
        self.model.load_state_dict(current_state['model'])
        self.optimizer.load_state_dict(current_state['optimizer'])
        self.lr_scheduler.load_state_dict(current_state['lr_scheduler'])
        self.epoch_id = current_state['epoch_id'] 
        # best state 
        best_state = state_dict['best_state']
        self.best_model = best_state['best_model']
        self.best_performance = best_state['best_performance']
        # random state 
        set_rng_state(current_state['rng_state'])
    
    def __load_base_model__(self, model_path):
        self.base_model = BasePredictor(model_path)

    def __save__(self):
        torch.save(self.__state_dict__(), self.model_path)

    # TODO(template): setup datasets and data loaders
    def __dataset_setup__(self, collator_state_dict):
        configs = self.configs
        template = configs.data_path + '/*{split}*'
        self.dataset_collection = SoftDatasetCollection(
            template, configs.pretrain_name, collator_state_dict, configs.encoding, reverse_lang=configs.reverse_lang, 
            n_examples_train=configs.n_examples, shuffle=configs.shuffle
        )
        configs.n_labels = len(self.dataset_collection.collator.label_dict)

    # TODO(template): setup optimizers, learning rate schedulers and loss funcs
    def __optimizer_setup__(self):
        configs = self.configs
        self.params = filter(lambda x: x.requires_grad, self.model.parameters())
        self.optimizer = getattr(torch.optim, configs.optimizer)(self.params, lr=configs.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', configs.lr_reduce_ratio, configs.lr_patience)
        self.arc_loss = SoftCrossEntropyLoss()
        self.label_loss = SoftCrossEntropyLoss()
        self.epoch_id = 0
    
    # TODO(template): loss computation 
    def __loss__(self, arc_scores, label_scores, arcs, labels, arc_masks, is_logits=True):
        arc_loss = self.arc_loss(arc_scores, arcs, arc_masks, is_logits)
        label_loss = self.label_loss(label_scores, labels, arc_masks.unsqueeze(-1), is_logits)
        return arc_loss, label_loss
        
    # TODO(template): train for one epoch 
    def train(self):
        self.epoch_id += 1
        self.model.train()
        bar = tqdm(self.dataloader['train'])
        train_arc_loss, train_label_loss, train_reg_loss = Meter(), Meter(), Meter()
        for token_ids_s, masks_s, token_ids_p, masks_p, alignments in bar:
            # TODO(F): operating on the log domain
            with torch.no_grad():
                alignments[:, 0, 0] = 1
                arc_scores_s, label_scores_s, word_masks_s = self.base_model.__forward_batch__(token_ids_s, masks_s)
                arc_probs_s = arc_scores_s.softmax(-1)
                alignments = alignments.to(arc_scores_s.device).float()
                alignments_transpose = torch.einsum('bij->bji', alignments)
                masks_s = masks_s.to(alignments.device)
                masks_p = masks_p.to(alignments.device)
                alignments_p2s, masks_p2s = normalize_row(alignments_transpose, masks_p, masks_s)
                alignments_s2p, _ = normalize_row(alignments, masks_s, masks_p)
                p2p_probs_groundtruth = torch.einsum('bpi,bij,bjk->bpk', alignments_p2s, arc_probs_s, alignments_s2p)
                label_probs_s = label_scores_s.softmax(-1)
                label_probs_groundtruth = torch.einsum('bpi,bijd,bkj->bpkd', alignments_p2s, label_probs_s, alignments_p2s)
            if configs.base_regularizer > 0:
                arc_scores, label_scores, word_masks = self.__forward_batch__(token_ids_s, masks_s)
                word_masks_2d = torch.einsum('bi,bj->bij', word_masks, word_masks)
                word_masks_2d[:, 0, :] = 0
                reg_arc_loss, reg_label_loss = self.__loss__(arc_scores, label_scores, arc_scores_s, label_scores_s, word_masks_2d)
                regularizer_loss = reg_arc_loss + reg_label_loss
            else:
                regularizer_loss = torch.zeros(1).to(label_probs_groundtruth.device)
            arc_scores_p, label_scores_p, word_masks_p = self.__forward_batch__(token_ids_p, masks_p)
            arc_probs_p = arc_scores_p.softmax(-1)
            label_probs_p = label_scores_p.softmax(-1)
            arc_masks_p = torch.einsum('bi,bj->bij', word_masks_p, word_masks_p) & masks_p2s
            arc_loss, label_loss = self.__loss__(arc_probs_p + EPS, label_probs_p + EPS, p2p_probs_groundtruth, label_probs_groundtruth, arc_masks_p, False)
            # optimization
            self.optimizer.zero_grad()
            loss = arc_loss + label_loss + configs.base_regularizer * regularizer_loss
            loss.backward()
            clip_grad_norm_(self.params, self.configs.clip)
            self.optimizer.step()
            current_examples = arc_masks_p.any(1).sum()
            train_arc_loss.update(arc_loss.item() / current_examples, current_examples)
            train_label_loss.update(label_loss.item() / current_examples, current_examples)
            current_examples_s = word_masks_s.sum()
            train_reg_loss.update(regularizer_loss.item() / current_examples_s, current_examples_s)
            if not loss < 0 and not loss > 0:
                raise Exception('Loss NaN.')
            bar.set_description(
                f'Epoch {self.epoch_id}, accu. loss = {train_arc_loss.average + train_label_loss.average:.4f}, '
                f'accu. arc loss = {train_arc_loss.average:.4f}, accu. label loss = {train_label_loss.average:.4f}, '
                f'accu. regularizer loss = {train_reg_loss.average:.4f}'
            )
        # evaluate
        dev_loss = self.dataset_loss(self.dataloader['dev'])
        self.lr_scheduler.step(dev_loss)
        self.logger.info(f'Dev: Epoch {self.epoch_id}, loss={dev_loss:.2f}')
        if (self.best_performance is None) or (dev_loss < self.best_performance):
            self.best_performance = dev_loss
            self.best_model = copy.deepcopy(self.model.state_dict())
        self.__save__()

    def dataset_loss(self, dataloader):
        self.model.eval()
        dataset_loss = Meter()
        bar = tqdm(dataloader)
        bar.set_description('Evaluation:')
        for token_ids_s, masks_s, token_ids_p, masks_p, alignments in bar:
            with torch.no_grad():
                alignments[:, 0, 0] = 1
                arc_scores_s, label_scores_s, _ = self.base_model.__forward_batch__(token_ids_s, masks_s)
                arc_probs_s = arc_scores_s.softmax(-1)
                alignments = alignments.to(arc_scores_s.device).float()
                alignments_transpose = torch.einsum('bij->bji', alignments)
                masks_s = masks_s.to(alignments.device)
                masks_p = masks_p.to(alignments.device)
                alignments_p2s, masks_p2s = normalize_row(alignments_transpose, masks_p, masks_s)
                alignments_s2p, _ = normalize_row(alignments, masks_s, masks_p)
                p2p_probs_groundtruth = torch.einsum('bpi,bij,bjk->bpk', alignments_p2s, arc_probs_s, alignments_s2p)
                label_probs_s = label_scores_s.softmax(-1)
                label_probs_groundtruth = torch.einsum('bpi,bijd,bkj->bpkd', alignments_p2s, label_probs_s, alignments_p2s)
                arc_scores_p, label_scores_p, word_masks_p = self.__forward_batch__(token_ids_p, masks_p)
                arc_probs_p = arc_scores_p.softmax(-1)
                label_probs_p = label_scores_p.softmax(-1)
                arc_masks_p = torch.einsum('bi,bj->bij', word_masks_p, word_masks_p) & masks_p2s
                arc_loss, label_loss = self.__loss__(arc_probs_p + EPS, label_probs_p + EPS, p2p_probs_groundtruth, label_probs_groundtruth, arc_masks_p, False)
                loss = arc_loss + label_loss
                current_examples = arc_masks_p.any(1).sum()
                dataset_loss.update(loss / current_examples, current_examples)

        return dataset_loss.average


""" evalutor: evaluate a trained model """
class BasePredictor(BaseExecutor):
    def __init__(self, model_path):
        super(BasePredictor, self).__init__()
        self.model_path = f'{model_path}/model.ckpt'
        self.log_path = f'{model_path}/predict-as-base-for-soft.log'
        self.logger = Logger(self.log_path)
        self.__load__()

    def __load__(self):
        self.state_dict = torch.load(self.model_path)
        # configs 
        self.collator_state_dict = self.state_dict['data_state']['collator_state_dict']
        self.configs = self.state_dict['configs']
        # model 
        self.__model_setup__()
        self.best_model = self.state_dict['best_state']['best_model']
        self.best_performance = self.state_dict['best_state']['best_performance']
        self.__load_best_model__()
        self.model.eval()


def train_eval(configs):
    """Train for ``configs.epochs`` epochs (resuming from a checkpoint if present), then test.

    Returns:
        ``(las, uas)`` from ``Trainer.test``. Returns ``None`` when the run in
        ``configs.model_path`` already stands at exactly ``configs.epochs`` epochs; such
        a run is neither retrained nor re-tested.
    """
    set_random_seed(configs.seed)
    trainer = Trainer(configs)
    target_epochs = configs.epochs
    if trainer.epoch_id == target_epochs:
        return None
    while trainer.epoch_id < target_epochs:
        trainer.train()
    return trainer.test()
    

# TODO(F): update evaluate and predict for softly-trained models
def evaluate(configs):
    """Evaluate a softly-trained model (not implemented yet).

    ``Evaluator`` is neither defined nor imported in this module, so calling it could
    only raise ``NameError``. Until it is ported, fail with an explicit message. The
    intended call is ``Evaluator(configs.model_path, configs.data_path).test()``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "mode 'evaluate' is not implemented for softly-trained models yet "
        "(Evaluator is not available in this module)"
    )


def predict(configs):
    """Run prediction with a softly-trained model (not implemented yet).

    ``Predictor`` is neither defined nor imported in this module, so calling it could
    only raise ``NameError``. Until it is ported, fail with an explicit message. The
    intended call is ``Predictor(model_path=configs.model_path,
    data_path=configs.data_path, input_format=configs.input_format,
    output_path=configs.output_path, n_examples=configs.n_examples,
    encoding=configs.encoding).predict()``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "mode 'predict' is not implemented for softly-trained models yet "
        "(Predictor is not available in this module)"
    )


if __name__ == '__main__':
    configs = SoftConfigs.get_configs()
    # dispatch table: run mode -> entry point
    mode_handlers = {
        'train': train_eval,
        'evaluate': evaluate,
        'predict': predict,
    }
    mode = configs.mode
    if mode not in mode_handlers:
        raise Exception(f'Mode {configs.mode} is not supported.')
    mode_handlers[mode](configs)
