import collections
import copy
import os
import tempfile
from abc import abstractmethod

import regex
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from tqdm import tqdm

from algo import mst
from data import CoNLLDatasetCollection, SequenceLengthBatchSampler, SequenceLengthSampler
from models import Parser
from utils import Configs, Logger, Meter, get_header, get_rng_state, set_random_seed, set_rng_state

# TODO(F): 
#   1. (optional) add multi GPU support 


""" base class for trainer and evaluator """ 
class Executor(object):
    """Base class for the trainer and the evaluator.

    Bundles data-loader / model setup, batch decoding and scoring. The
    methods read ``self.configs``, ``self.dataset_collection`` and (in
    ``test``) ``self.logger``; ``__init__`` only sets ``configs`` to ``None``,
    so subclasses have to provide the rest before calling them.
    """

    # Applied to the scorer report after its newlines were replaced by '#'.
    # Raw strings: the patterns contain ``\|`` and ``\.``, which are invalid
    # escape sequences in a normal string literal.
    _UAS_PATTERN = r'.*UAS[^#]+[ \|]([0-9]+\.[0-9]+)[^0-9]*#'
    _LAS_PATTERN = r'.*LAS[^#]+[ \|]([0-9]+\.[0-9]+)[^0-9]*#'

    def __init__(self):
        self.configs = None

    # TODO(template): define data state
    def __data_state__(self):
        """Return ``self.dataset_collection.state_dict()``."""
        return self.dataset_collection.state_dict()

    @abstractmethod
    def __dataset_setup__(self):
        """Create ``self.dataset_collection``; supplied by subclasses."""
        pass

    # TODO(template): setup data loaders
    def __dataloader_setup__(self):
        """Build one ``DataLoader`` per split into ``self.dataloader``.

        The 'train' split is batched through ``SequenceLengthSampler`` /
        ``SequenceLengthBatchSampler`` with ``configs.batch_size``; every
        other split uses a fixed batch size of 16.
        """
        configs = self.configs
        collection = self.dataset_collection
        self.dataloader = {}
        for split in collection.splits:
            if split == 'train':
                sampler = SequenceLengthSampler(collection[split])
                batch_sampler = SequenceLengthBatchSampler(sampler, configs.batch_size)
                self.dataloader[split] = DataLoader(
                    collection[split], batch_sampler=batch_sampler, collate_fn=collection.collator
                )
            else:
                self.dataloader[split] = DataLoader(
                    collection[split], batch_size=16, collate_fn=collection.collator
                )

    # TODO(template): setup models
    def __model_setup__(self):
        """Instantiate the parser and reset the best-checkpoint bookkeeping."""
        configs = self.configs
        self.model = Parser(configs)
        self.best_model = None
        self.best_performance = None

    # TODO(template): load state dict for the best model
    def __load_best_model__(self):
        """Load ``self.best_model`` (a state dict) into ``self.model``."""
        self.model.load_state_dict(self.best_model)

    # TODO(template): forward batch implementation
    def __forward_batch__(self, sentences, masks):
        """Move both inputs to ``configs.device`` and return the model output."""
        return self.model(sentences.to(self.configs.device), masks.to(self.configs.device))

    # TODO(template): dev metric computation
    def __evaluate_batch__(self, batch):
        """Decode one batch and return a copy of its sentences with predictions.

        Args:
            batch: ``(sentences, masks, arcs, labels, info)``; the gold
                ``arcs`` are not used here and ``labels`` only provides the
                output shape.

        Returns:
            A deep copy of ``info`` whose integer-id words carry the predicted
            ``head`` and ``deprel``; all other words keep ``'_'`` in both.
        """
        sentences, masks, _, labels, info = batch
        with torch.no_grad():
            arc_scores, label_scores, word_masks = self.__forward_batch__(sentences, masks)
        word_masks = word_masks.cpu()
        # presumably column 0 is an artificial root position - TODO confirm
        word_masks[:, 0] = False
        # arcs
        trees = mst(arc_scores, word_masks)
        # labels: keep, for every position, the label scores at its predicted head
        n_examples = arc_scores.shape[0] * arc_scores.shape[1]
        label_scores = label_scores.view(n_examples, -1, label_scores.shape[-1])
        label_preds = label_scores[torch.arange(n_examples), (trees * word_masks).reshape(-1)].reshape(*labels.shape, -1).argmax(-1).cpu()
        id2label = self.dataset_collection.collator.id2label
        # work on a copy so the caller's batch is left untouched
        info = copy.deepcopy(info)
        for i, sent in enumerate(info):
            for j, _ in enumerate(sent):
                info[i][j]['head'] = '_'
                info[i][j]['deprel'] = '_'
            if masks[i].sum() == 2:     # sentence too long, use trivial chain parsing
                for j, w in enumerate(sent):
                    if isinstance(w['id'], int):
                        info[i][j]['head'] = w['id'] - 1
                        info[i][j]['deprel'] = '_'
            else:
                for j, w in enumerate(sent):
                    if isinstance(w['id'], int):
                        info[i][j]['head'] = trees[i][w['id']].item()
                        info[i][j]['deprel'] = id2label.get(label_preds[i][w['id']].item(), '_')
        return info

    def _run_scorer(self, command, gt_path, pr_path, res_path):
        """Run a scoring command through the shell and parse LAS / UAS.

        Args:
            command: shell prefix, e.g. ``'python <script>'``.
            gt_path: file holding the gold trees.
            pr_path: file holding the predicted trees.
            res_path: file the command's stdout is redirected to.

        Returns:
            ``(las, uas)`` as the matched strings.

        Raises:
            AttributeError: if a score cannot be found in the report.
        """
        os.system(f'{command} {gt_path} {pr_path} -v > {res_path}')
        with open(res_path) as fin:
            report = ' '.join(fin.readlines()).replace('\n', '#')
        uas = regex.match(self._UAS_PATTERN, report).group(1)
        las = regex.match(self._LAS_PATTERN, report).group(1)
        return las, uas

    # TODO(template): standard evaluation procedure
    def evaluate(self, dataloader):
        """Parse every batch of ``dataloader`` and score the predictions.

        Gold and predicted trees are serialised into a temporary directory and
        scored with ``configs.eval_script``; if that fails for any reason the
        ``snippets.partial_eval`` module is used instead.

        Returns:
            ``(las, uas)`` as floats.
        """
        self.model.eval()
        outputs = list()
        ground_truths = list()
        bar = tqdm(dataloader)
        bar.set_description('Evaluation:')
        for batch in bar:
            ground_truths.extend(tree.serialize() for tree in batch[-1])
            batch_info = self.__evaluate_batch__(batch)
            outputs.extend(tree.serialize() for tree in batch_info)
        # the context manager removes the directory even when scoring raises
        with tempfile.TemporaryDirectory(prefix='ud-eval') as tmpdir:
            gt_path = f'{tmpdir}/gt.trees'
            pr_path = f'{tmpdir}/pr.trees'
            res_path = f'{tmpdir}/eval.res'
            with open(gt_path, 'w', encoding=self.configs.encoding) as fout:
                print(''.join(ground_truths).strip() + '\n', file=fout)
            with open(pr_path, 'w', encoding=self.configs.encoding) as fout:
                print(''.join(outputs).strip() + '\n', file=fout)
            try:
                las, uas = self._run_scorer(f'python {self.configs.eval_script}', gt_path, pr_path, res_path)
            except Exception:
                # best effort: retry with the fallback scorer, but let
                # KeyboardInterrupt / SystemExit propagate
                las, uas = self._run_scorer('python -m snippets.partial_eval', gt_path, pr_path, res_path)
        return float(las), float(uas)

    # TODO(template): standard final test procedure
    def test(self, eval_script=None):
        """Evaluate the best model on the 'test' split and report the scores.

        Args:
            eval_script: if given, replaces ``configs.eval_script`` first.

        Returns:
            ``(las, uas)`` as floats.
        """
        if eval_script is not None:
            self.configs.eval_script = eval_script
        self.__load_best_model__()
        las, uas = self.evaluate(self.dataloader['test'])
        self.logger.info(f'Test: LAS={las:.1f}, UAS={uas:.1f}')
        print(f'Test: LAS={las:.1f}, UAS={uas:.1f}', flush=True)
        return las, uas


""" trainer: training a parser """
class Trainer(Executor):
    """Trainer: trains a parser and checkpoints it after every epoch.

    On construction an existing ``{configs.model_path}/model.ckpt`` is
    resumed; otherwise datasets, loaders, model and optimizer are built from
    ``configs`` and the model is optionally initialised from
    ``configs.start_model``.
    """

    def __init__(self, configs):
        super(Trainer, self).__init__()
        self.configs = configs
        self.model_path = f'{configs.model_path}/model.ckpt'
        self.log_path = f'{configs.model_path}/train.log'
        # no shell involved, so paths containing spaces work as well
        os.makedirs(configs.model_path, exist_ok=True)
        self.logger = Logger(self.log_path)
        self.logger.info(get_header())
        self.logger.info('Configs:' + repr(configs))
        if os.path.exists(self.model_path):
            self.__load__()
        else:
            self.__dataset_setup__()
            self.__dataloader_setup__()
            self.__model_setup__()
            self.__optimizer_setup__()
            if configs.start_model is not None:
                try:
                    state_dict = torch.load(f'{configs.start_model}/model.ckpt')
                    self.model.load_state_dict(state_dict['best_state']['best_model'])
                except Exception as error:
                    # best effort: keep the freshly initialised model, but say why
                    self.logger.warning(f'Model initialization specified as {configs.start_model}, but not loaded.')
                    self.logger.warning(f'Reason: {error!r}')
                else:
                    self.logger.info(f'Loaded model from {configs.start_model}.')

    # TODO(template): define current model state, include optimizer etc.
    def __current_model_state__(self):
        """Return everything needed to resume training after the current epoch."""
        return {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
            'epoch_id': self.epoch_id,
            'rng_state': get_rng_state()
        }

    # TODO(template): define best model state
    def __best_model_state__(self):
        """Return the best model state dict and the score it achieved."""
        return {
            'best_model': self.best_model,
            'best_performance': self.best_performance
        }

    def __state_dict__(self):
        """Assemble the full checkpoint: configs, data, current and best state."""
        state_dict = {
            'configs': self.configs,
            'data_state': self.__data_state__(),
            'current_state': self.__current_model_state__(),
            'best_state': self.__best_model_state__()
        }
        return state_dict

    def __check_config__(self, configs):
        """Log a warning for every key on which ``configs`` and ``self.configs`` differ.

        Args:
            configs: the configuration to compare against; reported as
                "(old)" in the messages, ``self.configs`` as "(new)".
        """
        warn_flag = False
        for key in configs:
            if key not in self.configs or configs[key] != self.configs[key]:
                self.logger.warning(f'Different configs: {key} - (old) {configs.get(key, None)}; (new) {self.configs.get(key, None)}')
                warn_flag = True
        for key in self.configs:
            if key not in configs:
                self.logger.warning(f'Different configs: {key} - (old) {configs.get(key, None)}; (new) {self.configs.get(key, None)}')
                warn_flag = True
        if not warn_flag:
            self.logger.info('Config check passed.')

    # TODO(template): load state dict from pretrained model
    def __load__(self):
        """Restore data, model, optimizer, scheduler, best state and RNG state
        from the checkpoint at ``self.model_path``."""
        state_dict = torch.load(self.model_path)
        # data
        data_state = state_dict['data_state']
        self.dataset_collection = CoNLLDatasetCollection.from_state_dict(data_state)
        self.configs.n_labels = len(self.dataset_collection.collator.label_dict)
        self.__dataloader_setup__()
        # check configs
        self.__check_config__(state_dict['configs'])
        # current state
        current_state = state_dict['current_state']
        self.__model_setup__()
        self.__optimizer_setup__()
        self.model.load_state_dict(current_state['model'])
        self.optimizer.load_state_dict(current_state['optimizer'])
        self.lr_scheduler.load_state_dict(current_state['lr_scheduler'])
        self.epoch_id = current_state['epoch_id']
        # best state
        best_state = state_dict['best_state']
        self.best_model = best_state['best_model']
        self.best_performance = best_state['best_performance']
        # random state
        set_rng_state(current_state['rng_state'])

    def __save__(self):
        """Write the full checkpoint to ``self.model_path``."""
        torch.save(self.__state_dict__(), self.model_path)

    # TODO(template): setup datasets and data loaders
    def __dataset_setup__(self):
        """Build the dataset collection from ``configs`` and record the
        number of labels in ``configs.n_labels``."""
        configs = self.configs
        template = configs.data_path + '/*{split}.conll*'
        self.dataset_collection = CoNLLDatasetCollection(
            template, configs.pretrain_name, configs.label_dict, configs.encoding, configs.position_pad_id, configs.label_pad_id,
            n_examples=configs.n_examples, shuffle=configs.shuffle
        )
        configs.n_labels = len(self.dataset_collection.collator.label_dict)

    # TODO(template): setup optimizers, learning rate schedulers and loss funcs
    def __optimizer_setup__(self):
        """Create optimizer, LR scheduler and loss functions; reset ``epoch_id`` to 0."""
        configs = self.configs
        # Must be a list: the optimizer consumes its iterable, and the same
        # parameters are needed again for gradient clipping in ``train``. A
        # one-shot ``filter`` object would be empty by then.
        self.params = [param for param in self.model.parameters() if param.requires_grad]
        self.optimizer = getattr(torch.optim, configs.optimizer)(self.params, lr=configs.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'max', configs.lr_reduce_ratio, configs.lr_patience)
        self.arc_loss = torch.nn.CrossEntropyLoss(ignore_index=self.dataset_collection.collator.position_pad_id)
        self.label_loss = torch.nn.CrossEntropyLoss(ignore_index=self.dataset_collection.collator.label_pad_id)
        self.epoch_id = 0

    # TODO(template): loss computation
    def __loss__(self, arc_scores, label_scores, arcs, labels):
        """Return ``(arc_loss, label_loss)`` for one batch.

        The label loss is computed on the label scores selected at the gold
        head (``arcs``) of every position.
        """
        arc_loss = self.arc_loss(arc_scores.view(-1, arc_scores.shape[-1]), arcs.to(self.configs.device).view(-1).long())
        n_examples = arcs.view(-1).shape[0]
        label_scores = label_scores.view(n_examples, -1, label_scores.shape[-1])
        label_loss = self.label_loss(label_scores[torch.arange(n_examples), arcs.view(-1)], labels.to(self.configs.device).view(-1).long())
        return arc_loss, label_loss

    # TODO(template): train for one epoch
    def train(self):
        """Train for one epoch, evaluate on 'dev', update the best model and save.

        Raises:
            Exception: if the arc loss of a batch is NaN.
        """
        self.epoch_id += 1
        self.model.train()
        bar = tqdm(self.dataloader['train'])
        train_arc_loss, train_label_loss = Meter(), Meter()
        for sentences, masks, arcs, labels, _ in bar:
            arc_scores, label_scores, word_masks = self.__forward_batch__(sentences, masks)
            self.optimizer.zero_grad()
            arc_loss, label_loss = self.__loss__(arc_scores, label_scores, arcs, labels)
            # stop before the update so a NaN never reaches the weights; an
            # explicit isnan also avoids rejecting a loss of exactly zero
            if torch.isnan(arc_loss):
                raise Exception('Arc loss NaN.')
            loss = arc_loss + label_loss
            loss.backward()
            clip_grad_norm_(self.params, self.configs.clip)
            self.optimizer.step()
            current_examples = word_masks.sum().item()
            train_arc_loss.update(arc_loss.item(), current_examples)
            train_label_loss.update(label_loss.item(), current_examples)
            bar.set_description(
                f'Epoch {self.epoch_id}, accu. loss = {train_arc_loss.average + train_label_loss.average:.4f}, '
                f'accu. arc loss = {train_arc_loss.average:.4f}, accu. label loss = {train_label_loss.average:.4f}'
            )
        las, uas = self.evaluate(self.dataloader['dev'])
        # the scheduler was created in 'max' mode and is driven by dev LAS
        self.lr_scheduler.step(las)
        self.logger.info(f'Dev: Epoch {self.epoch_id}, LAS={las:.1f}, UAS={uas:.1f}')
        if (self.best_performance is None) or (las > self.best_performance):
            self.best_performance = las
            # snapshot, otherwise later updates would change the "best" weights
            self.best_model = copy.deepcopy(self.model.state_dict())
        self.__save__()


""" evalutor: evaluate a trained model """
class Evaluator(Executor):
    """Evaluator: scores a stored model on the data found at ``data_path``.

    Configs, data state and the best model all come from
    ``{model_path}/model.ckpt``; log messages go to
    ``{model_path}/evaluate.log``.
    """

    def __init__(self, model_path, data_path):
        super().__init__()
        self.data_path = data_path
        self.model_path = f'{model_path}/model.ckpt'
        self.log_path = f'{model_path}/evaluate.log'
        self.logger = Logger(self.log_path)
        self.logger.info(self.data_path)
        self.__load__()

    def __load__(self):
        """Read the checkpoint, then rebuild data, loaders and model from it."""
        checkpoint = torch.load(self.model_path)
        # kept on the instance: __dataset_setup__ reads the data state from it
        self.state_dict = checkpoint
        # configs
        self.configs = checkpoint['configs']
        # datasets, data loaders
        self.__dataset_setup__()
        self.__dataloader_setup__()
        # model
        self.__model_setup__()
        best_state = checkpoint['best_state']
        self.best_model = best_state['best_model']
        self.best_performance = best_state['best_performance']

    def __dataset_setup__(self):
        """Recreate the dataset collection from the stored data state,
        redirected to ``self.data_path`` and switched to 'evaluate' mode."""
        stored = self.state_dict['data_state']
        stored['path_template'] = self.data_path
        stored['mode'] = 'evaluate'
        self.dataset_collection = CoNLLDatasetCollection.from_state_dict(stored)


""" predictor: predict the parse tree given text """
class Predictor(Evaluator):
    """Predictor: parses input text with a stored model and writes the trees.

    Reuses ``Evaluator.__load__`` but rebuilds the dataset collection in
    'predict' mode with the requested input format and example limit.
    """

    def __init__(self, model_path, data_path, input_format, output_path, n_examples, encoding):
        # NOTE: the parent constructors are not called here; every attribute
        # used by __load__ and __dataset_setup__ is assigned directly below.
        self.model_path = f'{model_path}/model.ckpt'
        self.data_path = data_path
        self.input_format = input_format
        self.output_path = output_path
        self.n_examples = n_examples
        self.encoding = encoding
        self.log_path = f'{model_path}/predict.log'
        self.logger = Logger(self.log_path)
        self.__load__()

    def __dataset_setup__(self):
        """Recreate the dataset collection from the stored data state,
        overriding path, mode, input format and number of examples."""
        data_state = self.state_dict['data_state']
        data_state['path_template'] = self.data_path
        data_state['mode'] = 'predict'
        data_state['input_format'] = self.input_format
        data_state['n_examples'] = self.n_examples
        self.dataset_collection = CoNLLDatasetCollection.from_state_dict(data_state)

    def predict(self):
        """Parse the 'test' split with the best model and write the serialised
        trees to ``self.output_path`` using ``self.encoding``."""
        self.__load_best_model__()
        self.model.eval()
        outputs = list()
        for batch in tqdm(self.dataloader['test']):
            batch_info = self.__evaluate_batch__(batch)
            outputs.extend(tree.serialize() for tree in batch_info)
        # the with-block closes the file; no explicit close() is needed
        with open(self.output_path, 'w', encoding=self.encoding) as fout:
            print(''.join(outputs).strip() + '\n', file=fout)


def train_eval(configs):
    """Seed the RNGs, train until ``configs.epochs`` epochs are done, then test.

    Returns:
        The result of ``trainer.test()``, or ``None`` when the trainer is
        already exactly at ``configs.epochs`` epochs.
    """
    set_random_seed(configs.seed)
    trainer = Trainer(configs)
    if trainer.epoch_id != configs.epochs:
        while trainer.epoch_id < configs.epochs:
            trainer.train()
        return trainer.test()
    return None
    

def evaluate(configs):
    """Run the final test of a stored model with ``configs.eval_script``."""
    script = configs.eval_script
    Evaluator(configs.model_path, configs.data_path).test(script)


def predict(configs):
    """Build a ``Predictor`` from ``configs`` and write its predictions."""
    options = dict(
        model_path=configs.model_path,
        data_path=configs.data_path,
        input_format=configs.input_format,
        output_path=configs.output_path,
        n_examples=configs.n_examples,
        encoding=configs.encoding,
    )
    Predictor(**options).predict()


if __name__ == '__main__':
    # Script entry point: pick the routine that matches ``configs.mode``.
    configs = Configs.get_configs()
    handlers = {'train': train_eval, 'evaluate': evaluate, 'predict': predict}
    handler = handlers.get(configs.mode)
    if handler is None:
        raise Exception(f'Mode {configs.mode} is not supported.')
    handler(configs)
