import math
from typing import NamedTuple, Optional, List

import torch
from torch import nn as nn, Tensor
from torch.nn import functional as F

from fairseq.models.pooler import Pooler
#from fairseq.models.skynet_transformer import SinkhornSelfAttention


def write_to_debug_file(f_handle, dictionary, doc_tokens, selected_indices):
    """Write the tokens kept by pooling to a debug file, one document per line.

    Args:
        f_handle: object with a ``write`` method that receives each line.
        dictionary: object whose ``string`` method renders a 1-D token tensor.
        doc_tokens: token ids, indexed by document along the first dimension.
        selected_indices: per-document indices of the kept tokens; dimension 1
            of each per-document slice is squeezed away before selecting.
    """
    batch_size = doc_tokens.shape[0]
    for row in range(batch_size):
        kept_positions = selected_indices[row].squeeze(1)
        kept_tokens = doc_tokens[row].index_select(0, kept_positions)
        f_handle.write(dictionary.string(kept_tokens) + '\n')


def write_to_debug_maybe(handles, encoder, doc_tokens_padded, selected_indices):
    """Dump the selected tokens only when a debug file is open and indices exist.

    Does nothing when ``handles['pool_debug_file']`` is falsy or when
    ``selected_indices`` is ``None``.
    """
    debug_file = handles['pool_debug_file']
    if not debug_file or selected_indices is None:
        return
    write_to_debug_file(debug_file, encoder.dictionary, doc_tokens_padded, selected_indices)


def fold_back_maybe(pooler, args, encoder_out_doc, old_bs, need_attn):
    """Merge per-chunk encoder outputs back into one sequence per original document.

    Folding happens when ``pooler.args.use_sparse_attn`` is ``'only_blockwise'``,
    or when the pooler is not a lambda pooler and the sparse-attention mode is
    none of ``'pooler_no_block'``, ``'linformer'`` or
    ``'linformer_no_cross_decoder'``. Otherwise the input is returned untouched.

    Args:
        pooler: pooler exposing ``args`` and ``is_lambda``.
        args: namespace providing ``chunk_size``, ``encoder_embed_dim`` and
            ``encoder_pooling_arch``.
        encoder_out_doc: encoder output of the chunked documents.
        old_bs: batch size before the documents were split into chunks.
        need_attn: whether per-token attention scores should be produced.

    Returns:
        tuple: ``(encoder_out_doc, attn_matrix)``; ``attn_matrix`` is ``None``
        unless scores were requested and the pooling arch provides them.
    """
    sparse_mode = getattr(pooler.args, 'use_sparse_attn', 'none')
    blockwise_forced = sparse_mode == 'only_blockwise'
    # These modes switch blockwise attention off, so nothing was chunked.
    blockwise_disabled = sparse_mode in ('pooler_no_block',
                                         'linformer',
                                         'linformer_no_cross_decoder')
    if (pooler.is_lambda or blockwise_disabled) and not blockwise_forced:
        return encoder_out_doc, None

    # TODO: fix pooler, so that transpose is not needed
    dims = encoder_out_doc[0].shape
    hidden_dim = dims[2]
    assert dims[0] == args.chunk_size
    assert dims[1] % old_bs == 0
    assert dims[2] == args.encoder_embed_dim

    # Chunks of one document sit next to each other on the batch axis, so
    # going batch-first and reshaping to `old_bs` rows concatenates them
    # along the (already pooled) length axis.
    batch_first = encoder_out_doc.encoder_out.transpose(0, 1)
    per_document = batch_first.reshape([old_bs, -1, hidden_dim])
    encoder_out = per_document.transpose(0, 1)

    attn_matrix = None
    if need_attn:
        if args.encoder_pooling_arch == 'attn':
            head_to_score_from = 0
            # column-wise sum of the chosen head's attention matrix
            single_head = encoder_out_doc.encoder_last_attn_matrix[:, head_to_score_from, :, :]
            attn_matrix = single_head.sum(dim=1).reshape([old_bs, -1])
        elif args.encoder_pooling_arch == 'power_bert':
            # scores summed over dimensions 1 and 2
            summed = encoder_out_doc.encoder_last_attn_matrix.sum([1, 2])
            attn_matrix = summed.reshape([old_bs, -1])

    folded = EncoderOut(
        encoder_out=encoder_out,
        encoder_padding_mask=None,
        encoder_embedding=None,
        encoder_states=None,
        encoder_last_attn_matrix=attn_matrix,
    )
    return folded, attn_matrix


def unfold_maybe(pooler, chunk_size, encoder, doc_lengths, doc_tokens):
    """Prepare docs for encoding (split to 'chunks').

    Chunking happens when ``pooler.args.use_sparse_attn`` is
    ``'only_blockwise'``, or when the pooler is not a lambda pooler and the
    sparse-attention mode is none of ``'pooler_no_block'``, ``'linformer'`` or
    ``'linformer_no_cross_decoder'``. Otherwise the inputs pass through.

    When chunking, every document is right-padded with the dictionary's pad
    index to a multiple of ``chunk_size``, cut into consecutive chunks that
    are stacked along the batch axis, and the last non-pad token of each
    chunk is overwritten with EOS.

    Args:
        pooler: pooler exposing ``args`` and ``is_lambda``.
        chunk_size (int): number of tokens per chunk.
        encoder: object whose ``dictionary`` provides ``pad_index`` and
            ``eos_index``.
        doc_lengths: per-document lengths; returned unchanged when no
            chunking happens, otherwise replaced by per-chunk lengths.
        doc_tokens: token ids indexed as ``(batch, length)``.

    Returns:
        tuple: ``(doc_lengths, doc_tokens_padded, doc_tokens_unfolded, old_bs,
        chunks_num)``. Without chunking ``doc_tokens_padded`` is ``None``,
        ``doc_tokens_unfolded`` is ``doc_tokens`` and ``chunks_num`` is 1.
    """
    old_bs = doc_tokens.shape[0]

    sparse_mode = getattr(pooler.args, 'use_sparse_attn', 'none')
    is_blockwise = sparse_mode == 'only_blockwise'
    # These modes turn blockwise attention off.
    stop_blockwise = sparse_mode in ('pooler_no_block',
                                     'linformer',
                                     'linformer_no_cross_decoder')
    if (pooler.is_lambda or stop_blockwise) and not is_blockwise:
        return doc_lengths, None, doc_tokens, old_bs, 1

    pad_index = encoder.dictionary.pad_index
    eos_index = encoder.dictionary.eos_index

    # take current batch-document shape
    chunks_num = math.ceil(doc_tokens.shape[1] / chunk_size)
    pad_len = (chunks_num * chunk_size) - doc_tokens.shape[1]
    doc_padder = torch.full((old_bs, pad_len), pad_index,
                            dtype=doc_tokens.dtype, device=doc_tokens.device)
    doc_tokens_padded = torch.cat((doc_tokens, doc_padder), dim=1)
    doc_tokens_unfolded = doc_tokens_padded.unfold(1, chunk_size,
                                                   chunk_size).flatten(0, 1)

    # Count real tokens against the same pad index that was used for padding
    # (this used to be a hard-coded 1).
    doc_lengths = (doc_tokens_unfolded != pad_index).sum(dim=1)
    new_bsz = doc_tokens_unfolded.shape[0]
    rows = torch.arange(new_bsz, dtype=torch.long, device=doc_lengths.device)
    # NOTE(review): a chunk made only of padding has length 0, so its EOS
    # position is -1, i.e. presumably the last slot of that chunk - confirm
    # this is the intended handling of empty chunks.
    last_token_pos = doc_lengths - 1
    eos_fill = torch.full((new_bsz,), eos_index,
                          dtype=doc_tokens_unfolded.dtype,
                          device=doc_tokens_unfolded.device)
    doc_tokens_unfolded.index_put_((rows, last_token_pos), eos_fill)

    return doc_lengths, doc_tokens_padded, doc_tokens_unfolded, old_bs, chunks_num


class SkynetPooler(Pooler):
    """
    Token Pooler.

    Shortens an encoded sequence to ``pooled_length`` positions. Tokens are
    scored according to ``args.encoder_pooling_arch`` and selected according
    to ``args.encoder_pooling``; ``'lambda'`` turns pooling off entirely.

    Args:
        args (configargparse.Namespace): parsed command-line arguments

    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self._prepare_pooler()

    def _prepare_pooler(self):
        """Create scorer, selector and bias unless pooling is disabled."""
        if self.args.encoder_pooling != 'lambda':
            self._set_scorer_architecture()

            self._set_softselector_method()

            self._set_additional_positional_embeddings()
        else:
            self.scorer = None
            self.bias = None

    def _set_additional_positional_embeddings(self):
        """Create the legacy bias parameter when ``args`` has ``pool_use_bias``."""
        # NOTE(review): only the presence of the attribute is checked, not its
        # value. The parameter is created but ``forward`` never applies it.
        if hasattr(self.args, 'pool_use_bias'):
            self.bias = nn.Parameter(torch.ones((1, self.args.max_source_positions, 1)))
            print('Bias is not supported anymore (thank god!)')
        else:
            self.bias = None

    def _set_softselector_method(self):
        """Build what turns token scores into a pooled sequence.

        ``'regression'`` gets a linear regressor over positions, ``'nonet'``
        gets a ``NoNet`` selector and its ``PoolerConfig``. Any other value of
        ``args.encoder_pooling`` creates nothing here.
        """
        if self.args.encoder_pooling == 'regression':
            self.regressor = nn.Linear(self.args.max_source_positions, self.args.pooled_length)
        elif self.args.encoder_pooling == 'nonet':
            from fairseq.models.no_net import NoNet, PoolerConfig
            self.selector = NoNet()
            self.pooler_config = PoolerConfig(
                input_len=self.args.max_source_positions,
                pooled_len=self.args.pooled_length,
                chances_in_first_round=1,
                softmax_n=self.args.softmax_n,
                speed_up_const=self.args.speed_up_const,
                flip_right=self.args.flip_right,
                pair_idx_fn='get_topk_pair_idx',
                depth=self.args.apply_hierarchical_pooling,
                iterative=self.args.pooler_iterative_topk,
                base=20,
                hard_topk_inference=self.args.hard_topk_inference,
            )
            # If there is one pooling, I set it here. For hierarchical, it will be set layer-wise.
            if not self.args.apply_hierarchical_pooling:
                self.selector.set_config(self.pooler_config)

    def _set_scorer_architecture(self):
        """Set ``self.scorer`` from ``args.encoder_pooling_arch``.

        ``'linear'`` and ``'ffn'`` create trainable modules (one per encoder
        layer, or a single shared linear layer when ``pool_reuse_scorer`` is
        set). The parameter-free architectures are stored as their name, which
        ``forward`` dispatches on. Unknown values give ``None``.
        """
        arch = self.args.encoder_pooling_arch
        if arch == 'linear':
            if self.args.pool_reuse_scorer:
                self.scorer = nn.Linear(self.args.encoder_embed_dim, 1)
            else:
                self.scorer = nn.ModuleList([nn.Linear(self.args.encoder_embed_dim, 1)
                                             for _ in range(0, self.args.encoder_layers)])
        elif arch == 'ffn':
            self.scorer = nn.ModuleList([FFN(self.args.encoder_embed_dim)
                                         for _ in range(0, self.args.encoder_layers)])
        elif arch in ('attn', 'kth', 'random', 'zeroth_dim', 'power_bert',
                      'mean_pool', 'max_pool'):
            self.scorer = arch
        else:
            self.scorer = None

    @staticmethod
    def _wrap_pooled(encoder_out, encoder_padding_mask=None):
        """Pack a pooled tensor into an ``EncoderOut`` with the other fields unset."""
        return EncoderOut(encoder_out=encoder_out,
                          encoder_padding_mask=encoder_padding_mask,
                          # So-called "experimental code"
                          encoder_embedding=None,
                          encoder_states=None,
                          encoder_last_attn_matrix=None,)

    def forward(self, encoded_out, src_lengths=None, cut_embedding=None, attn_scores=None,
                pooled_length=None, layer_i=-1, **kwargs):
        """
        Pool the encoder output down to ``pooled_length`` positions.

        Args:
            encoded_out (EncoderOut): encoder output; ``encoder_out`` is
                permuted with ``(1, 0, 2)`` and then read as
                `(batch, src_len, emb_dim)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`; used by every score-based branch
            cut_embedding (Tensor, optional): added to pooled positions whose
                selected index is not exactly one greater than the preceding
                (rolled) selected index
            attn_scores (Tensor, optional): precomputed token scores, used
                when the scorer is ``'attn'``
            pooled_length (int, optional): target length, defaults to
                ``args.pooled_length``
            layer_i (int): index of the per-layer scorer when scorers are not
                reused

        Returns:
            tuple: ``(EncoderOut, selected_indices)``. ``selected_indices`` is
            only set by the ``topk`` / ``threshold`` / ``regression``
            selectors and is ``None`` otherwise. A score-based scorer combined
            with an unrecognised ``args.encoder_pooling`` yields
            ``(None, None)``.
        """
        if self.args.encoder_pooling == 'lambda':
            return encoded_out, None

        encoded_tokens = encoded_out.encoder_out.permute(1, 0, 2)
        bs, s, e = encoded_tokens.shape
        m = self.args.max_source_positions
        if pooled_length is None:
            pooled_length = self.args.pooled_length
        if pooled_length == s:
            return encoded_out, None

        regression = self.args.encoder_pooling == 'regression'

        if self.scorer == 'attn' and attn_scores is not None:
            # Previously I took column-wise sum over one attention head
            token_logits = attn_scores.unsqueeze(2)   # Batch x Pre-pooled len x 1

        elif self.scorer == 'power_bert':
            # prune here based on attention
            # NOTE(review): `pooler_config` is only created when
            # encoder_pooling == 'nonet' - confirm this branch is never used
            # with another selector.
            assert len(encoded_out.encoder_last_attn_matrix.shape) == 2
            token_logits = encoded_out.encoder_last_attn_matrix #.sum(dim=[1, 2])

            pad_size = list(token_logits.shape)
            pooler_input_len = self.pooler_config.input_len
            pad_size[-1] = pooler_input_len - token_logits.shape[-1]

            # Pad the scores with a very low value so padding is picked last.
            token_logits = torch.cat(
                (token_logits, -100 * torch.ones(pad_size).to(encoded_tokens.device)),
                dim=-1).to(encoded_tokens.device)

            topks = token_logits.topk(pooled_length)
            encoded_tokens = torch.cat(
                (encoded_tokens, torch.zeros(*pad_size, e).to(encoded_out.encoder_out.device)),
                dim=1)

            pooled_output = torch.stack(
                [encoded_tokens[i][topks.indices[i].to(encoded_tokens.device)] for i in
                 range(bs)]).to(encoded_tokens.device)

            assert pooled_output.shape[0] == bs
            assert not torch.isnan(pooled_output).any()
            return self._wrap_pooled(pooled_output.permute(1, 0, 2)), None

        elif self.scorer == 'kth':
            # take every kth embedding
            k_step_size = s / pooled_length
            # A float step can make arange emit one extra element (equal to
            # `s`, which is out of range) because of rounding, so never keep
            # more than `pooled_length` indices.
            didx = torch.arange(0, s, k_step_size,
                                device=encoded_tokens.device)[:pooled_length]

        elif self.scorer == 'random':
            # randomly generate scores
            # didx = (torch.rand(pooled_length, device=encoded_tokens.device,
            #                    dtype=encoded_tokens.dtype) * s).sort().values
            didx = torch.randperm(s, device=encoded_tokens.device)[:pooled_length]

        elif self.scorer == 'zeroth_dim':
            # get values from a specified hardcoded dim (0-th caused problems actually)
            # NaN != NaN, so this zeroes NaNs; it writes in place through the
            # permuted view of the encoder output.
            encoded_tokens[encoded_tokens != encoded_tokens] = 0
            token_logits = encoded_tokens[:, :, 200].unsqueeze(2)

        elif self.scorer == 'mean_pool' or self.scorer == 'max_pool':
            # Implements typical pooling operations
            # NOTE(review): relies on `pooler_config`, which is only created
            # when encoder_pooling == 'nonet' - confirm.
            pooler_input_len = self.pooler_config.input_len
            compress_factor = pooler_input_len//pooled_length
            shp = encoded_out.encoder_out.shape
            # Group `compress_factor` consecutive positions and reduce each group.
            pool_result = encoded_out.encoder_out.reshape(
                [shp[0] // compress_factor, compress_factor, shp[1], shp[2]])
            if self.scorer == 'mean_pool':
                pool_result = pool_result.mean(dim=1)
            else:
                pool_result = pool_result.max(dim=1).values
            assert len(pool_result.shape) == 3
            return self._wrap_pooled(pool_result), None

        else:
            if self.args.pool_reuse_scorer:
                token_logits = self.scorer(encoded_tokens)
            else:
                assert layer_i >= 0 and isinstance(self.scorer, nn.ModuleList)
                token_logits = self.scorer[layer_i](encoded_tokens)

        # I can finish faster for index-based poolers, without calling nonet
        # (nonet provides derivatives for score-based methods)
        if self.scorer == 'kth' or self.scorer == 'random':
            if pooled_length < encoded_tokens.shape[1]:
                pooled_output = encoded_tokens.index_select(1, didx.to(torch.long))
            else:
                pooled_output = encoded_tokens
            assert pooled_output.shape[0] == bs
            assert not torch.isnan(pooled_output).any()
            # if not pooled_output.shape[1] == self.pooler_config.pooled_len:
            #     assert pooled_output.shape[1] == self.pooler_config.pooled_len
            return self._wrap_pooled(pooled_output.permute(1, 0, 2)), None

        assert not torch.isnan(token_logits).any()
        assert token_logits.shape[0] == src_lengths.shape[0]
        assert len(token_logits.shape) == 3
        assert token_logits.shape[-1] == 1

        # The legacy positional bias (self.bias) is intentionally not applied.

        # Push the scores of positions past each sentence's length far down
        # (in place).
        for sent, slen in zip(token_logits, src_lengths):
            sent[slen:] = -10000

        pooled_length = min(pooled_length, s)

        if self.args.encoder_pooling == 'nonet':
            pooled_output = self.selector(encoded_tokens, torch.sigmoid(token_logits) + 0.00001)
            assert pooled_output.shape[0] == bs
            assert not torch.isnan(pooled_output).any()
            assert pooled_output.shape[1] == self.pooler_config.pooled_len
            return self._wrap_pooled(pooled_output.permute(1, 0, 2)), None

        if self.args.encoder_pooling in ('topk', 'threshold'):      # hard topk, legacy
            # Calculate indexes with top values
            selected_indices = torch.topk(token_logits, k=pooled_length, dim=1)[1]

        elif regression:
            padded_token_logits = torch.zeros((bs, m, 1), device=token_logits.device,
                                              dtype=token_logits.dtype)
            padded_token_logits[:, :s] = token_logits
            selected_indices = self.regressor(
                padded_token_logits.permute(0, 2, 1)).permute(0, 2, 1)

        else:
            return None, None

        if self.args.encoder_pooling == 'threshold':
            # Keep at most as many tokens as have sigmoid(score) > 0.5; the
            # remaining slots are pointed at the last position.
            candidate = (torch.sigmoid(token_logits) > 0.5).sum(dim=1).flatten()
            src_lengths = torch.min(src_lengths, candidate)
            m = token_logits.shape[1] - 1
            for bidx in range(bs):
                selected_indices[bidx, candidate[bidx]:] = m

        if not regression:
            # Sort selected indices (to preserve order)
            selected_indices = selected_indices.sort(1)[0]

            # fast implementation of grid_sample trick
            norm_sel_ind = (2 * selected_indices.unsqueeze(0).to(
                encoded_tokens.dtype) + 1) / s - 1
        else:
            norm_sel_ind = torch.tanh(selected_indices).unsqueeze(0)

        dummy_idx = torch.zeros_like(norm_sel_ind)

        indices = torch.cat((norm_sel_ind, dummy_idx), dim=3)
        # normalize to (-1 : 1)
        # NOTE(review): F.grid_sample reads grid[..., 0] as the coordinate
        # along the input's last dimension - confirm that the axis order of
        # `reshaped_enc_toks` and `indices` is the intended one.
        reshaped_enc_toks = encoded_tokens.unsqueeze(0).permute(0, 3, 2, 1)
        pooled_output = F.grid_sample(reshaped_enc_toks, indices, align_corners=True,
                                      mode='bilinear',
                                      padding_mode='border')

        pooled_output = pooled_output.squeeze(0).permute(2, 1, 0)
        encoder_padding_mask = torch.zeros((bs, pooled_length),
                                           device=encoded_tokens.device,
                                           dtype=torch.bool)

        for bidx in range(bs):  # This is ugly and I am sad
            # pooled_output[bidx] = encoded_tokens[bidx].index_select(0, selected_indices[
            #     bidx].squeeze(1))

            # A pooled slot is padding when it points at or past the sentence end.
            encoder_padding_mask[bidx][:selected_indices.shape[1]] = (
                    selected_indices[bidx] >= src_lengths[bidx]).squeeze()

        if cut_embedding is not None:
            cut_mask = (selected_indices - selected_indices.roll(1, dims=1)) != 1
            pooled_output += (cut_mask * cut_embedding.unsqueeze(0).unsqueeze(0))

        assert pooled_output.shape[1] == bs

        return (self._wrap_pooled(pooled_output, encoder_padding_mask),
                selected_indices)


class EncoderOut(NamedTuple):
    """Container for what an encoder (or the pooler) hands downstream."""

    # T x B x C
    encoder_out: Tensor
    # B x T
    encoder_padding_mask: Tensor
    # B x T x C
    encoder_embedding: Tensor
    # List[T x B x C]
    encoder_states: Optional[List[Tensor]]
    # B x T x T
    encoder_last_attn_matrix: Optional[Tensor]


class FFN(nn.Module):
    """Two-layer feed-forward head: dropout, linear, tanh, dropout, linear.

    Args:
        embed_dim (int): size of the input features and of the hidden layer.
        out_dim (int): size of the output features.
        dropout (float): dropout probability applied before each linear layer.
    """

    def __init__(self, embed_dim, out_dim=1, dropout=0.1):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, out_dim)

    def forward(self, x, **kwargs):
        """Apply the head to ``x`` along its last dimension.

        Extra keyword arguments are accepted and ignored.
        """
        hidden = self.dense(self.dropout(x))
        hidden = torch.tanh(hidden)
        return self.out_proj(self.dropout(hidden))
