# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math

import os
import numpy as np
import torch
import random
import logging
import itertools

from fairseq.data import LanguagePairDataset
from .noise_util import apply_span_mask, apply_random_mask, apply_entity_mask_for_mlm, apply_entity_mask_for_clm
from fairseq.data import data_utils
from fairseq.data.encoders.gpt2_bpe import get_encoder

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# GPT-2 BPE codec, built once at import time. The vocabulary files are looked
# up under `data/vocab/gpt2/` relative to the directory two levels above the
# directory containing this file (three nested `dirname` calls on `__file__`).
# Within this module `bpe` is only used to decode token ids for debug logging.
bpe_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
bpe = get_encoder(bpe_dir + '/data/vocab/gpt2/encoder.json', bpe_dir + '/data/vocab/gpt2/vocab.bpe')


def collate(
        samples,
        pad_idx,
        eos_idx,
        left_pad_source=False,
        left_pad_target=False,
        input_feeding=True,
        pad_to_length=None,
):
    """Assemble a list of CLM examples into one padded mini-batch.

    Args:
        samples: sequence of dicts carrying the keys ``id``, ``source``,
            ``prev_output_tokens``, ``prev_output_positions`` and
            ``clm_target``.
        pad_idx: padding token id used to fill shorter rows.
        eos_idx: end-of-sentence token id, forwarded to
            ``data_utils.collate_tokens``.
        left_pad_source: pad the source on the left instead of the right.
        left_pad_target: pad the decoder-side tensors on the left.
        input_feeding: accepted for signature compatibility; not used here.
        pad_to_length: accepted for signature compatibility; not used here.

    Returns:
        ``{}`` for an empty ``samples``; otherwise a dict with ``id``,
        ``nsentences``, ``ntokens``, ``net_input`` and ``clm_target``, with
        all rows ordered by descending source length.
    """
    if len(samples) == 0:
        return {}

    # Count the non-pad source tokens of each sample; the batch is ordered
    # by this length, longest first.
    raw_lengths = torch.LongTensor([s['source'].ne(pad_idx).long().sum() for s in samples])
    src_lengths, sort_order = raw_lengths.sort(descending=True)

    def padded_and_sorted(key, left_pad):
        # Pad every sample's `key` tensor to a common width, then reorder rows.
        rows = [sample[key] for sample in samples]
        stacked = data_utils.collate_tokens(
            rows,
            pad_idx, eos_idx, left_pad, False,
            pad_to_length=None,
        )
        return stacked.index_select(0, sort_order)

    sample_ids = torch.LongTensor([s['id'] for s in samples]).index_select(0, sort_order)
    src_tokens = padded_and_sorted('source', left_pad_source)

    # Decoder-side tensors for the causal language model.
    prev_output_tokens = padded_and_sorted('prev_output_tokens', left_pad_target)
    prev_output_positions = padded_and_sorted('prev_output_positions', left_pad_target)
    clm_target = padded_and_sorted('clm_target', left_pad_target)

    net_input = {
        'src_tokens': src_tokens,
        'src_lengths': src_lengths,
        'prev_output_tokens': prev_output_tokens,
        'prev_output_positions': prev_output_positions,
    }
    return {
        'id': sample_ids,
        'nsentences': len(samples),
        'ntokens': src_lengths.sum().item(),
        'net_input': net_input,
        'clm_target': clm_target,
    }


class DenosingLanguagePairDataset(LanguagePairDataset):
    """Dataset that turns entity-annotated sequences into masked CLM examples.

    Every raw source item must start with ``bos`` and end with ``eos``; each
    entity span inside it is wrapped in a pair of extra ``eos`` separator
    tokens.  ``__getitem__`` strips those separators, asks
    ``apply_entity_mask_for_clm`` which positions the decoder has to predict,
    corrupts those positions on the encoder side with :meth:`replace`, and
    returns the tensors that the module-level ``collate`` batches.
    """

    @classmethod
    def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
        """Build and return a new dataset instance wrapping ``dataset``.

        All arguments are forwarded unchanged to the constructor.

        NOTE: ``__init__`` stores a boolean under the same attribute name
        (``self.apply_mask``), so on *instances* that flag shadows this
        classmethod; call it through the class.
        """
        return cls(dataset, *args, **kwargs)

    def __init__(
            self, src, src_sizes, src_dict,
            tgt=None, tgt_sizes=None, tgt_dict=None,
            left_pad_source=True, left_pad_target=False,
            max_source_positions=1024, max_target_positions=1024,
            shuffle=True,
            mask_idx=None,
            mask_prob=0.15, leave_unmasked_prob=0.1, random_token_prob=0.1,
            mask_whole_words=None,
            block_size=64,
            sub_task=None,
            apply_mask=True,
            apply_decoder_mask=False,
            only_mask_entity_in_decoder=False,
            decoder_mask_prob=0.15
    ):
        """Store the masking configuration on top of ``LanguagePairDataset``.

        Args:
            src, src_sizes, src_dict, tgt, tgt_sizes, tgt_dict,
            left_pad_source, left_pad_target, shuffle: forwarded to the
                parent constructor.
            max_source_positions, max_target_positions: stored on the
                instance (not forwarded to the parent constructor).
            mask_idx: token id written by :meth:`replace` for masked slots.
            mask_prob, block_size, sub_task, apply_mask: stored on the
                instance; not read by the methods of this class.
            leave_unmasked_prob, random_token_prob: together with the
                remaining probability mass (used for ``mask_idx``) they form
                the three-way replacement distribution of :meth:`replace`.
            mask_whole_words: accepted but unused.
            apply_decoder_mask: also corrupt part of the decoder input.
            only_mask_entity_in_decoder: when decoder masking is on, corrupt
                entity tokens only instead of random positions.
            decoder_mask_prob: ratio handed to ``apply_random_mask`` for
                random decoder masking.
        """
        super().__init__(src, src_sizes, src_dict,
                         tgt=tgt, tgt_sizes=tgt_sizes, tgt_dict=tgt_dict,
                         left_pad_source=left_pad_source, left_pad_target=left_pad_target,
                         shuffle=shuffle)
        self.mask_idx = mask_idx
        self.mask_prob = mask_prob
        self.apply_mask = apply_mask
        self.apply_decoder_mask = apply_decoder_mask
        self.only_mask_entity_in_decoder = only_mask_entity_in_decoder
        self.decoder_mask_prob = decoder_mask_prob

        self.sub_task = sub_task
        self.cls_pad = self.src_dict.pad()
        self.block_size = block_size
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        # Sampling weights, in order: 0 -> write mask_idx, 1 -> keep the
        # original token, 2 -> write a random vocabulary token.
        self.replace_probs = torch.FloatTensor(
            [1 - leave_unmasked_prob - random_token_prob, leave_unmasked_prob, random_token_prob])
        # Counters limiting how many examples get dumped to the log.
        self.debug_size_for_mlm = 0
        self.debug_size_for_clm = 0

    def _parse_kg_data(self, kg_item):
        """Strip entity separators and return ``(src_item, entity_pos)``.

        Every ``eos`` except the last one is treated as an entity separator
        and removed.  ``entity_pos`` is an ``(n_entities, 2)`` integer array
        of half-open ``[start, end)`` token ranges expressed in the
        coordinates of the stripped ``src_item``.

        Raises:
            AssertionError: if the separators do not come in pairs.
        """
        kg_item_np = np.array(kg_item)
        sep_idx = np.where(kg_item_np == self.src_dict.eos())[0]
        sep_idx = sep_idx[:-1]  # the final eos terminates the sentence; keep it
        assert len(sep_idx) % 2 == 0, 'entity separators must come in pairs'
        # The k-th separator has k separators removed before it, so its
        # position in the stripped sequence shifts left by k.
        entity_pos = (sep_idx - np.arange(len(sep_idx))).reshape(-1, 2)
        # Boolean keep-mask instead of a per-index `not in` scan (was O(n^2)).
        keep = np.ones(len(kg_item), dtype=bool)
        keep[sep_idx] = False
        src_item = kg_item[np.flatnonzero(keep)]
        return src_item, entity_pos

    def _create_dummy_data(self, task, **kwargs):
        """Return placeholder tensors for an example without prediction slots.

        For ``task == 'mlm'`` returns a pad-filled target of length
        ``kwargs['src_sz']``.  For ``task == 'clm'`` returns the one-element
        triple ``(prev_output_positions, prev_output_tokens, clm_target)``
        whose target is the pad index.  Any other ``task`` yields ``None``.
        """
        if task == 'mlm':
            mlm_target = torch.from_numpy(np.full(kwargs['src_sz'], self.src_dict.pad()))
            return mlm_target
        if task == 'clm':
            prev_output_positions = torch.LongTensor([1])
            # Literal token id 1 -- presumably the pad index; TODO confirm.
            prev_output_tokens = torch.from_numpy(np.full(1, 1))
            clm_target = torch.from_numpy(np.full(1, self.src_dict.pad()))
            return prev_output_positions, prev_output_tokens, clm_target

    def _entity_mask_pos(self, ent_pos, prev_output_positions):
        """Return the indices into ``prev_output_positions`` that lie in an entity.

        The last token of every entity range is excluded.  Example::

            ent_pos = np.array([[1, 2], [4, 5], [8, 11]])
            prev_output_positions = [1, 4, 8, 9, 10]
            -> [2, 3]
        """
        mask_pos = set()
        for ent in ent_pos:
            # Same as list(range(start, end))[:-1]: drop each entity's last token.
            mask_pos.update(range(int(ent[0]), int(ent[1]) - 1))
        # int() so NumPy / tensor scalars hash like plain ints in the set.
        return [i for i, pos in enumerate(prev_output_positions) if int(pos) in mask_pos]

    def _decode_as_token_str(self, tokens):
        """Render token ids as a space-separated, human-readable string.

        Dictionary symbols that are numeric strings are treated as GPT-2 BPE
        ids and decoded; every other symbol is printed verbatim.
        """
        pieces = []
        for tok in list(tokens):
            symbol = self.src_dict[tok]
            if symbol.isnumeric():
                pieces.append(bpe.decode([int(symbol)]).strip())
            else:
                pieces.append(symbol)
        return ' '.join(pieces)

    def _if_print_log(self, print_num, src_sz):
        """Tell whether a debug dump is due, bumping the CLM counter if so.

        Returns True while ``print_num < 8``, and additionally while
        ``print_num < 16`` for sources longer than 30 tokens.  Every True
        result increments ``self.debug_size_for_clm``.
        """
        if print_num < 8 or (print_num < 16 and src_sz > 30):
            self.debug_size_for_clm += 1
            return True
        return False

    def _get_example_for_clm(self, index, kg_item, apply_entity_mask=True):
        """Build one CLM training example from a raw, separator-annotated item.

        Args:
            index: dataset index, returned as the example ``id``.
            kg_item: raw token tensor including the entity separators.
            apply_entity_mask: when False no positions are selected and the
                dummy decoder tensors are returned.

        Returns:
            dict with ``id``, ``source`` (encoder input with the selected
            positions corrupted), ``clm_target``, ``prev_output_tokens`` and
            ``prev_output_positions``.
        """
        src_item, ent_pos = self._parse_kg_data(kg_item)
        assert src_item[0] == self.src_dict.bos()
        assert src_item[-1] == self.src_dict.eos()
        src_sz = len(src_item)

        # Positions the decoder must predict (presumably a flat sequence of
        # integer positions >= 1 -- verify against noise_util).  The dtype is
        # forced to int64 because np.array([]) defaults to float64, which can
        # neither index a tensor nor feed torch.LongTensor; without it the
        # empty-selection fallback below could never be reached.
        selected = apply_entity_mask_for_clm(src_sz, ent_pos) if apply_entity_mask else []
        clm_position_list = np.array(selected, dtype=np.int64)

        # Decoder input at step t is the token preceding the target position.
        prev_output_tokens = src_item[clm_position_list - 1].clone()
        dec_mask_pos = []
        # Decoder-side corruption is only applied to sources longer than 6 tokens.
        if self.apply_decoder_mask and len(src_item) > 6:
            if self.only_mask_entity_in_decoder:  # 1. corrupt entity tokens only
                dec_mask_pos = self._entity_mask_pos(ent_pos, clm_position_list - 1)
            else:  # 2. corrupt random decoder positions
                dec_mask_pos = apply_random_mask(len(prev_output_tokens), self.decoder_mask_prob)

            if len(dec_mask_pos) > 0:
                prev_output_tokens[dec_mask_pos] = self.replace(prev_output_tokens[dec_mask_pos])

        clm_target = src_item[clm_position_list].clone()
        prev_output_positions = torch.LongTensor(clm_position_list)

        # Decide ONCE per example whether to dump it: _if_print_log bumps the
        # counter on every call, so asking twice burned two slots per example
        # and could print only the first half of a dump at the threshold.
        should_log = self._if_print_log(self.debug_size_for_clm, src_sz)

        if should_log:
            # Logged before the encoder mask is applied, i.e. the clean source.
            logger.info('========= index: %s ==== CLM =====', str(index))
            logger.info('src: %s', self._decode_as_token_str(src_item))
            entity_strs = []
            for ent in ent_pos:
                ids = [
                    int(self.src_dict[src_item[ii]])
                    if ii < src_sz and self.src_dict[src_item[ii]].isnumeric() else ''
                    for ii in range(ent[0], ent[1])
                ]
                entity_strs.append(bpe.decode(ids))
            logger.info('src_entity: %s', ' | '.join(entity_strs))

        # Apply the encoder mask on the positions the decoder has to predict.
        src_item[clm_position_list] = self.replace(src_item[clm_position_list])

        if should_log:
            logger.info('decoder_mask_prob: %s', str(self.decoder_mask_prob))
            logger.info('encoder_input: %s', self._decode_as_token_str(src_item))
            logger.info('decoder_mask_position: %s',
                        ' '.join([str(ii) for ii in prev_output_positions[dec_mask_pos].tolist()]))
            logger.info('decoder_input_position: %s',
                        ' '.join([str(ii) for ii in prev_output_positions.tolist()]))
            logger.info('decoder_input: %s', self._decode_as_token_str(prev_output_tokens))
            logger.info('decoder_target: %s', self._decode_as_token_str(clm_target))

        # Nothing to predict: substitute one-element placeholders so that
        # collation never sees empty decoder tensors.
        if prev_output_tokens.numel() == 0:
            prev_output_positions, prev_output_tokens, clm_target = self._create_dummy_data('clm')

        example = {
            'id': index,
            'source': src_item,
            'clm_target': clm_target,
            'prev_output_tokens': prev_output_tokens,
            'prev_output_positions': prev_output_positions,
        }
        return example

    def __getitem__(self, index):
        """Return the CLM example for ``self.src[index]``.

        TODO: dynamic_span_length, dynamic_total_length
        """
        src_item = self.src[index]
        example = self._get_example_for_clm(index, src_item)  # clm
        return example

    def collater(self, samples):
        """Batch ``samples`` with the module-level ``collate``.

        Only the pad and eos indices are forwarded, so ``collate`` runs with
        its own padding defaults rather than the ``left_pad_*`` constructor
        arguments.
        """
        return collate(samples, self.src_dict.pad(), self.src_dict.eos())

    def replace(self, x):
        """Corrupt every token of ``x`` independently and return the result.

        Per element one of three outcomes is drawn from ``self.replace_probs``:
        write ``self.mask_idx``, keep the original token, or write a random
        token id drawn from ``[src_dict.nspecial, len(src_dict) - 1]``.
        An empty ``x`` is returned unchanged because ``torch.multinomial``
        cannot draw zero samples.
        """
        if len(x) == 0:
            return x
        _x_real = x
        # Random draw first, then the outcome draw: keeps the RNG consumption
        # order stable for seeded runs.
        _x_rand = _x_real.clone().random_(self.src_dict.nspecial, len(self.src_dict))
        _x_mask = _x_real.clone().fill_(self.mask_idx)
        probs = torch.multinomial(self.replace_probs, len(x), replacement=True)
        # Exactly one of the three indicator masks is 1 per element.
        _x = _x_mask * (probs == 0).long() + \
             _x_real * (probs == 1).long() + \
             _x_rand * (probs == 2).long()
        return _x
