import json
import logging
import os
import shlex
import subprocess
import time

import numpy as np
import rouge
import torch
from configargparse import Namespace
from tqdm import tqdm

from fairseq.data.data_utils import batch_by_size
from fairseq.data.iterators import EpochBatchIterator
from fairseq.data.language_pair_dataset import LanguagePairDataset

logger = logging.getLogger(__name__)

def _parse_str_to_bool(args, key):
    # Assuming keys are specified in such a way that ommiting equal False
    return args.get(key) == 'True'

def _validate_summarization_with_rouge(params, out_save_file):
    """Score predictions in ``out_save_file`` against a gold file with ROUGE-1 F.

    ``params`` is a comma-separated list of ``key=value`` pairs, for example
    ``gold=/data/gold.txt,use_stemming=True,remove_bpe=True``. Keys read here:

    * ``gold`` (required): path to the reference file, one summary per line.
    * ``use_stemming``: enable stemming in the ROUGE evaluator.
    * ``remove_bpe``: remove every space from the predictions and turn each
      ``'▁'`` marker into a space before splitting into lines.
    * ``shorter_ok``: when there are fewer predictions than golds, score only
      the first ``len(predictions)`` golds.
    * ``postprocess_repeats``: drop repeated words inside each hypothesis.

    Boolean keys are enabled only by the literal value ``True``.

    Returns:
        The averaged ROUGE-1 F score, or 0.0 when a ValueError is raised while
        reading or scoring (the failure is logged and training goes on).

    Raises:
        KeyError: if ``gold`` is not given in ``params``.
    """
    # Split each entry on its first '=' only, so values such as file paths may
    # themselves contain '='; empty entries (e.g. a trailing comma) are skipped.
    parsed_args = {}
    for param in params.split(','):
        if param:
            key, _, value = param.partition('=')
            parsed_args[key] = value

    use_stemming = _parse_str_to_bool(parsed_args, 'use_stemming')
    evaluator = rouge.Rouge(metrics=['rouge-n'],
                            max_n=1,
                            limit_length=True,
                            length_limit=2000,
                            length_limit_type='words',
                            apply_avg=True,
                            apply_best=False,
                            alpha=0.5,  # Default F1_score
                            weight_factor=1.2,
                            stemming=use_stemming)
    try:
        with open(parsed_args['gold']) as gold, open(out_save_file) as predicted:
            all_golds = gold.read().strip().split('\n')
            all_preds = predicted.read()

        if _parse_str_to_bool(parsed_args, 'remove_bpe'):
            all_preds = all_preds.replace(' ', '').replace('▁', ' ')
        all_preds = all_preds.strip().split('\n')

        if len(all_preds) < len(all_golds) and _parse_str_to_bool(parsed_args, 'shorter_ok'):
            all_golds = all_golds[:len(all_preds)]

        if _parse_str_to_bool(parsed_args, 'postprocess_repeats'):
            # Naively remove repetitions. dict.fromkeys keeps first-occurrence
            # order, so the result is deterministic; a set's order depends on
            # the hash seed and would change which words survive the length limit.
            all_preds = [' '.join(dict.fromkeys(hypo.split(' '))) for hypo in all_preds]

        score = evaluator.get_scores(all_preds, all_golds)
        return score['rouge-1']['f']
    except ValueError as err:
        logger.warning('Training will be continued despite the fact that ROUGE metric failed '
                       'with this Error:\n %s\n', err)
        return 0.0


def _run_external_script_by_path(script_path, out_save_file):
    """Use external scorer called in a subprocess to get metric on already prepared files."""

    other_call_str = f'{script_path} {out_save_file}'
    try:
        process = subprocess.run(other_call_str, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                 encoding='utf-8', shell=True)
        output = float(str(process.stdout).strip())
        if process.stderr:
            logger.warning('external-validator stderr: %s', process.stderr)
    except ValueError as err:
        logger.warning(f'Subprocess cannot convert output to float when called with : '
                       f'{other_call_str}. Score will be set to 0.0 '
                       f'Raised error will be omitted: {err}')
        output = 0.0
    except Exception as err:
        logger.warning(f'Subprocess cannot run the validation script. '
                       f'Raised error will be omitted: {err}')
        output = 0.0
    return output


class ExternalValidator:
    """Provide scores from external validation tools or scripts on all validation subsets.

    Generated hypotheses are kept in ``args.external_validator_out_dir``. A
    file is scored with ROUGE when ``args.task == 'translation'``; otherwise
    it is scored with the external command ``args.external_validator``.
    """

    def __init__(self, trainer, task, args):
        self.model = trainer.model
        self.criterion = trainer.criterion
        self.args = args
        self.trainer = trainer
        self.task = task
        self.datasets = self._load_datasets(self.args.valid_subset)
        self.external_validator = self.args.external_validator
        self.output_dir = self.args.external_validator_out_dir
        # JSON object whose items are passed as a Namespace to task.build_generator.
        self.gen_args = json.loads(getattr(args, 'external_validator_args', '{}') or '{}')
        os.makedirs(self.output_dir, exist_ok=True)
        self.summarization = self.args.task == 'translation'  # FIXME: temporary hack for development of rouge scorer

    def get_score(self, subset, epoch=0, num_updates=0):
        """Generate hypotheses for ``subset`` into a file and score it with the external tool.

        Best hypotheses are written one per line to
        ``out_{epoch}_{updates}.tsv`` in ``self.output_dir``, where ``updates``
        comes from ``self.trainer.get_num_updates()``; the ``num_updates``
        argument is currently unused. Every ``args.log_time_every_n_steps``
        batches, inference-time statistics are appended to
        ``<my_config stem>.log``. All per-batch times are dumped to
        ``<my_config stem>-infer.np``.

        Returns:
            The float score from ROUGE (when ``self.summarization``) or from
            the external script.
        """
        dataset = self.datasets[subset]
        config_stem = self._config_stem(self.args.my_config)
        time_log_path = config_stem + '.log'
        log_every = self.args.log_time_every_n_steps
        max_batches = self.args.max_update

        with UnsortedDataset(dataset):
            out_save_file = os.path.join(self.output_dir,
                                         f'out_{epoch}_{self.trainer.get_num_updates()}.tsv')
            logger.debug('Out save file: %s', out_save_file)

            itr = self._get_ordered_batch_iterator(dataset, epoch=epoch)
            # The generator depends only on gen_args, so build it once, not per batch.
            generator = self._build_generator()

            batch_times = []  # a list avoids the quadratic cost of np.append in the loop
            with open(out_save_file, 'w+') as out_f_handle:
                for el_i, sample in enumerate(tqdm(itr)):
                    pre_t = time.time()
                    sample_prediction = self._get_batch_predictions(sample, generator=generator)
                    post_t = time.time()
                    out_f_handle.write('\n'.join(sample_prediction) + '\n')
                    batch_times.append(post_t - pre_t)
                    # A zero/None interval disables timing logs instead of dividing by zero.
                    if log_every and el_i % log_every == 0:
                        time_mean, time_std = self._time_stats(batch_times)
                        with open(time_log_path, 'a') as tlog:
                            tlog.write(f'INFERENCE @ num_evaluated = {el_i}: '
                                       f'time_mean = {time_mean} | σ = {time_std}\n')
                    # fairseq uses max_update == 0 for "no limit", so only cap
                    # evaluation when it is positive (stop after batch index max_update).
                    if max_batches and el_i == max_batches:
                        break

            np.asarray(batch_times, dtype=float).dump(config_stem + '-infer.np')
            # evaluate with external installed tool(like e.g. arnold) or path
            if self.summarization:
                score = _validate_summarization_with_rouge(self.external_validator, out_save_file)
            else:
                score = _run_external_script_by_path(self.external_validator, out_save_file)

            logger.debug(f'Finished scoring with score {score} by ({self.external_validator}), '
                         f'output saved to {out_save_file}')

        return score

    @staticmethod
    def _config_stem(config_path):
        """Return ``config_path`` without a trailing ``.yaml``/``.yml`` suffix.

        ``str.rstrip('.yaml')`` would strip any trailing run of the characters
        ``.``, ``y``, ``a``, ``m``, ``l`` (turning ``relay.yaml`` into ``re``),
        so the suffix is removed explicitly instead.
        """
        for suffix in ('.yaml', '.yml'):
            if config_path.endswith(suffix):
                return config_path[:-len(suffix)]
        return config_path

    @staticmethod
    def _time_stats(batch_times):
        """Return ``(mean, std)`` of the recorded batch times.

        The first measurement is excluded (presumably a warm-up batch) when
        there are at least two. With a single measurement it is used as-is,
        rather than averaging an empty array, which would give NaN.
        """
        times = np.asarray(batch_times, dtype=float)
        if times.size > 1:
            times = times[1:]
        return times.mean(), times.std()

    def _build_generator(self):
        """Build the task's sequence generator from ``self.gen_args``."""
        return self.task.build_generator(Namespace(**self.gen_args))

    def _get_ordered_batch_iterator(self, dataset, epoch=0):
        """Return a single-shard iterator over mini-batches in dataset index order.

        Mini-batches are built with the validation size constraints
        (``max_tokens_valid``, ``max_sentences_valid``,
        ``required_batch_size_multiple``). Nothing is shuffled.
        """
        max_tokens = self.args.max_tokens_valid
        max_sentences = self.args.max_sentences_valid
        required_batch_size_multiple = self.args.required_batch_size_multiple

        indices = np.arange(len(dataset))
        batch_sampler = batch_by_size(
            indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        epoch_iter = EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=None,
            num_shards=1,
            shard_id=0,
            num_workers=0,
            epoch=epoch,
        )
        return epoch_iter.next_epoch_itr(shuffle=False)

    def _get_batch_predictions(self, sample, generator=None):
        """Run inference on a batch and return the best hypothesis of each item as a string.

        Args:
            sample: a collated batch from the iterator. If the trainer's
                ``_prepare_sample`` returns None for it, the trainer's dummy
                batch is used instead.
            generator: a prebuilt sequence generator. When None, one is built
                from ``self.gen_args``.
        """
        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self.trainer._prepare_sample(sample)
            if sample is None:
                sample = self.trainer._prepare_sample(self.trainer._dummy_batch)

            if generator is None:
                generator = self._build_generator()

            hypos = self.task.inference_step(generator, [self.model], sample)

            # Hypotheses are target-side tokens (the generator is built from the
            # task's target dictionary), so decode them with that dictionary.
            target_dict = self.task.target_dictionary
            best_hypo_str = [target_dict.string(item_hypos[0]['tokens']) for item_hypos in hypos]

        return best_hypo_str

    def _load_datasets(self, subsets):
        """Map each comma-separated subset name to ``self.task.dataset(name)``."""
        return {split: self.task.dataset(split) for split in subsets.split(',')}


class UnsortedDataset:
    """Context manager that turns off ``sort_examples`` and ``shuffle`` on a LanguagePairDataset.

    On entry both attributes are set to False. On exit they are restored to
    the values they had before entry; an attribute that did not exist before
    is set to True, matching the previous behaviour. Datasets of any other
    type are left untouched. The dataset itself is returned from ``__enter__``.
    """

    _FLAGS = ('sort_examples', 'shuffle')

    def __init__(self, dataset):
        self.dataset = dataset
        self._saved_flags = {}

    def __enter__(self):
        if isinstance(self.dataset, LanguagePairDataset):
            # Remember the current settings so __exit__ restores them instead
            # of unconditionally re-enabling sorting/shuffling.
            self._saved_flags = {name: getattr(self.dataset, name, True) for name in self._FLAGS}
            for name in self._FLAGS:
                setattr(self.dataset, name, False)
        return self.dataset

    def __exit__(self, exc_type, exc_value, traceback):
        if isinstance(self.dataset, LanguagePairDataset):
            for name, value in self._saved_flags.items():
                setattr(self.dataset, name, value)
        self._saved_flags = {}
        # Returning None lets any exception from the with-body propagate.
