import torch.nn as nn
import torch.nn.functional as F
import torch
from functools import partial
#from dataclasses import dataclass


#@dataclass
class PoolerConfig:
    """Plain attribute bag holding the hyper-parameters read by ``NoNet.set_config``.

    The annotated class attributes are the defaults; every keyword passed to the
    constructor overrides (or adds) an instance attribute of the same name.
    No validation is performed.
    """
    input_len: int = -1
    pooled_len: int = -1
    depth: int = 0                      # NoNet derives the depth itself when this is <= 0
    chances_in_first_round: int = 1
    softmax_n: int = 2
    speed_up_const: float = 1.0         # default is a float, so annotate it as one
    flip_right: bool = True
    sort_back: bool = False
    pair_idx_fn: str = 'get_halving_pair_idx'
    iterative: int = 0
    base: int = 20
    hard_topk_inference: bool = False

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __repr__(self):
        """Show defaults merged with per-instance overrides, e.g. ``PoolerConfig(input_len=8, ...)``."""
        values = {name: getattr(self, name) for name in type(self).__annotations__}
        values.update(self.__dict__)
        args = ', '.join(f'{name}={value!r}' for name, value in values.items())
        return f'{type(self).__name__}({args})'


class NoNet(nn.Module):
    """Parameter-free ("no-net") differentiable top-k pooler.

    Reduces a sequence of up to ``input_len`` embeddings to (roughly)
    ``pooled_len`` embeddings, driven by one score per token. Strategies:

    * ``our_topk`` (``iterative == 0``): up to ``depth`` tournament rounds; each
      round pairs tokens (see the ``get_*_pair_idx`` methods) and replaces every
      pair by a softmax-weighted mix of its two members.
    * ``iterative_topk`` / ``vectorized_iterative_topk`` (``iterative != 0``):
      ``pooled_len`` successive soft arg-max picks; the value of ``iterative``
      doubles as the softmax sharpness. ``iterative == 2`` selects the batched
      implementation.
    * Hard top-k selection in eval mode when ``hard_topk_inference`` is set.

    ``set_config`` must be called before ``forward``.
    """

    def __init__(self):
        super(NoNet, self).__init__()
        self.iterations_performed = 0

    def set_config(self, pooler_config):
        """Copy hyper-parameters from ``pooler_config`` onto the module.

        When ``pooler_config.depth`` is not positive, the number of tournament
        rounds is derived as
        ``int(log2(input_len / pooled_len * chances_in_first_round))``.
        An unrecognised ``pair_idx_fn`` is tolerated here (the iterative
        strategies never use it). ``our_topk`` raises ``ValueError`` if it is
        needed.
        """
        self.input_len = pooler_config.input_len
        self.pooled_len = pooler_config.pooled_len
        if pooler_config.depth > 0:
            self.depth = pooler_config.depth
        else:
            self.depth = int(
                torch.log2(torch.tensor(self.input_len / self.pooled_len)
                           * pooler_config.chances_in_first_round))
        self.chances_in_first_round = pooler_config.chances_in_first_round
        self.softmax_n = pooler_config.softmax_n
        self.speed_up_const = pooler_config.speed_up_const
        self.flip_right = pooler_config.flip_right
        self.sort_back = pooler_config.sort_back

        pair_fns = {
            'get_halving_pair_idx': self.get_halving_pair_idx,
            'get_topk_pair_idx': self.get_topk_pair_idx,
        }
        # None (instead of leaving a stale/missing attribute) for unknown names.
        self.get_pair_idx = pair_fns.get(pooler_config.pair_idx_fn)
        self.pair_idx_fn = pooler_config.pair_idx_fn
        self.iterative = pooler_config.iterative
        self.hard_topk_inference = pooler_config.hard_topk_inference

        # NOTE: pooler_config.base is deliberately ignored; the base is fixed at 20.
        self.base = 20

    def forward(self, embs, scores):
        """Pool ``embs`` according to ``scores``.

        embs: batch x seq_len x emb_depth (seq_len <= input_len)
        scores: batch x seq_len x 1
        returns: pooled embeddings, batch x n_out x emb_depth. The pooled
        scores are computed by the strategies but not returned.
        """
        # After padding: embs is batch x input_len x emb_depth, scores is batch x input_len.
        embs, scores = self.pad_to_input_len(embs, scores)

        if not self.training and self.hard_topk_inference:
            emb_depth = embs.shape[2]
            top_idx = torch.topk(scores, k=self.pooled_len, dim=1).indices
            # Each batch row uses its own top-k; sort so tokens keep their original order.
            top_idx = top_idx.sort(dim=1).values
            return torch.gather(embs, 1, top_idx.unsqueeze(2).expand(-1, -1, emb_depth))

        if self.iterative == 2:
            new_embs, _ = self.vectorized_iterative_topk(embs, scores)
            return new_embs

        # The remaining strategies process one sequence at a time.
        new_embs = [self.forward_internal(embs[i:i + 1], scores[i:i + 1])[0]
                    for i in range(embs.shape[0])]
        return torch.cat(new_embs, dim=0)

    def pad_to_input_len(self, embs, scores):
        """Right-pad ``embs`` and ``scores`` along dim 1 up to ``input_len``.

        Embeddings are padded with zeros. Scores are padded with a small
        *positive* value (1e-5), not a large negative one — presumably so the
        ``log2`` branch of ``our_topk`` stays finite (TODO confirm intent).
        The trailing singleton dim of ``scores`` is squeezed away, so the
        returned scores are batch x input_len.
        """
        sh = list(embs.shape)
        sh[1] = self.input_len - sh[1]
        assert sh[1] >= 0, f"sh = {sh}, embs.shape={embs.shape}"
        emb_pad = torch.zeros(sh, dtype=embs.dtype, device=embs.device)
        embs = torch.cat((embs, emb_pad), dim=1)
        # pad scores with a tiny positive score
        sh = list(scores.shape)
        sh[1] = self.input_len - sh[1]
        score_pad = torch.zeros(sh, dtype=scores.dtype, device=scores.device) + 0.00001
        scores = torch.cat((scores, score_pad), dim=1).squeeze(2)
        return embs, scores

    def forward_internal(self, embs, scores):
        """Pool a single sequence (embs: 1 x L x E, scores: 1 x L) with the configured strategy."""
        if self.iterative:
            embs, scores = self.iterative_topk(embs, scores)
        else:
            embs, scores = self.our_topk(embs, scores)
        assert len(embs.shape) == 3 and embs.shape[0] == 1
        assert len(scores.shape) == 2 and scores.shape[0] == 1
        return embs, scores

    def iterative_topk(self, embs, scores):
        """Iterative soft arg-max baseline for one sequence (embs: 1 x L x E, scores: 1 x L).

        Each of ``pooled_len`` steps takes a softmax (sharpness ``iterative``)
        over the negative squared distance to the current max score. It emits
        the weighted embedding and score, then masks the max to -10000.
        Masking is done out of place, so neither the caller's tensor nor a
        tensor saved for autograd is mutated.
        """
        new_scores = []
        new_embs = []
        max_weights = []  # debug: shows how far from one-hot the weights are
        alpha = self.iterative
        for _ in range(self.pooled_len):
            miv = scores.max(dim=1)
            squared_dist = -(scores - miv.values) ** 2
            weights = F.softmax(squared_dist * alpha, dim=1)
            new_embs.append((weights.permute(1, 0) * embs.squeeze(0)).sum(0))
            new_scores.append((weights * scores.squeeze(0)).sum())
            max_weights.append(weights.max())
            scores = scores.index_fill(1, miv.indices, -10000.0)
        stacked_max_ith_weights = torch.stack(max_weights)         # look here to check how bad is this algo
        stacked_embs = torch.stack(new_embs).unsqueeze(0)
        stacked_scores = torch.stack(new_scores).unsqueeze(0)
        self.iterations_performed += 1
        if self.iterations_performed % 1000 == 0:
            print(f'Iterative topk\n: Cosine similarity is : '
                  f'{torch.cosine_similarity(stacked_embs[0, 0], stacked_embs[0, -1], dim=0)}')
            print(f'Maximal weight of a single vector is : {stacked_max_ith_weights.max()}\n')
        assert stacked_embs.shape[2] == embs.shape[2]
        assert stacked_embs.shape[1] == self.pooled_len
        return stacked_embs, stacked_scores

    def vectorized_iterative_topk(self, embs, scores):
        """Batched version of ``iterative_topk`` (embs: B x L x E, scores: B x L).

        Returns (B x pooled_len x E, B x pooled_len). Each row's current max is
        set to -10000 with an out-of-place scatter. That masks correctly for
        non-positive scores and for rows sharing the same arg-max column, and
        does not mutate the input.
        """
        new_scores = []
        new_embs = []
        max_weights = []  # debug: shows how far from one-hot the weights are
        alpha = self.iterative
        for _ in range(self.pooled_len):
            miv = scores.max(dim=1)
            squared_dist = -(scores - miv.values.unsqueeze(1)) ** 2
            weights = F.softmax(squared_dist * alpha, dim=1)
            new_embs.append((weights.unsqueeze(2) * embs).sum(1))
            new_scores.append((weights * scores).sum(1))
            max_weights.append(weights.max(1).values)
            # mask out each row's highest value with a big negative number
            scores = scores.scatter(1, miv.indices.unsqueeze(1), -10000.0)

        stacked_max_ith_weights = torch.stack(max_weights)         # look here to check how bad is this algo
        stacked_embs = torch.stack(new_embs).permute(1, 0, 2)
        stacked_scores = torch.stack(new_scores).permute(1, 0)

        self.iterations_performed += 1
        if self.iterations_performed % 1000 == 0:
            print(f'Iterative topk: \n \t Cosine similarity is : '
                  f'{torch.cosine_similarity(stacked_embs[0, 0], stacked_embs[0, -1], dim=0)}')
            print(f'\t Maximal weight of a single vector is : {stacked_max_ith_weights.max()}\n')
        assert stacked_embs.shape[2] == embs.shape[2]
        assert stacked_embs.shape[1] == self.pooled_len
        return stacked_embs, stacked_scores

    def our_topk(self, embs, scores):
        """Tournament-style soft top-k over one sequence (embs: 1 x L x E, scores: 1 x L).

        Each of up to ``depth`` rounds works as follows:
          1. Pair tokens via ``get_pair_idx``.
          2. Weight the two members of every pair by a softmax over the pair.
             With ``base > 0`` the softmax is taken over ``base ** score``.
             Otherwise apply ``softmax_n`` rounds of
             ``softmax(log2(x * speed_up_const))``.
          3. Replace the pair by the weighted sum of its members.

        After each round the loop stops if that round's target size is below
        ``pooled_len``.

        Raises:
            ValueError: if the configured ``pair_idx_fn`` is unknown.
        """
        if self.get_pair_idx is None:
            raise ValueError(f'Unknown pair_idx_fn: {self.pair_idx_fn!r}')
        softmax = partial(F.softmax, dim=1)
        current_size = self.input_len
        target_size = (self.input_len // 2) * self.chances_in_first_round
        for _ in range(self.depth):
            # Every pairing function takes what it needs from these keywords.
            pairs_idx = self.get_pair_idx(scores=scores, current_size=current_size,
                                          target_size=target_size)
            paired_scores = scores[:, pairs_idx]  # 1 x 2 x n_pairs
            if self.base > 0:
                weights = softmax(torch.pow(self.base, paired_scores))
            else:
                weights = paired_scores
                for _ in range(self.softmax_n):
                    weights = softmax(torch.log2(weights * self.speed_up_const))

            scores = (paired_scores * weights).sum(dim=1)
            # TODO: the interpolation here, on embs, can be performed using radial coordinates
            # Broadcasting over emb_depth; no fixed-size expand, so the pair count may differ from target_size.
            embs = (embs[:, pairs_idx] * weights.unsqueeze(3)).sum(dim=1)
            if torch.isinf(scores).any() or torch.isnan(scores).any():
                print('SCORES ZBYT POTEZNE!')
            if torch.isinf(embs).any() or torch.isnan(embs).any():
                print('EMBEDDINGI ZBYT POTEZNE!')

            # De-sort back into chunk-positions
            if self.sort_back:
                restore_idx = pairs_idx[0].sort().indices
                scores = scores[:, restore_idx]
                embs = embs[:, restore_idx]

            current_size = target_size
            target_size = embs.shape[1] // 2

            if current_size < self.pooled_len:
                break
        return embs, scores

    def get_ordered_pairs_idx(self, current_size, target_size, **kwargs):
        """
        DEPRECATED
        Pair every token (repeated ``chances`` times) with a random partner.
        Partner indices are divided by ``chances`` so they stay within
        ``[0, current_size)``; without this they go out of range when chances > 1.
        return:
         pairs_idx: 2 x (current_size * chances)"""
        chances = target_size // (current_size // 2)
        ordered_idx = torch.arange(current_size).repeat_interleave(chances)
        random_idx = torch.randperm(current_size * chances) // chances
        pairs_idx = torch.stack((ordered_idx, random_idx), dim=0)
        return pairs_idx

    def get_halving_pair_idx(self, current_size, target_size, **kwargs):
        """
        DEPRECATED
        Random pairing: each of the ``current_size`` positions appears
        ``chances`` times in a random permutation, which is split into two rows.
        return:
         pairs_idx: 2 x (current_size * chances // 2)"""
        chances = target_size // (current_size // 2)
        random_idx = torch.randperm(current_size * chances) // chances
        l_half = len(random_idx) // 2
        pairs_idx = torch.stack((random_idx[:l_half],
                                 random_idx[l_half:]),
                                dim=0)
        return pairs_idx

    def get_topk_pair_idx(self, scores, **kwargs):
        """Pair tokens so that the top-scoring half never meet each other.

        scores: 1 x n, with n assumed even.
        returns: 2 x (n // 2) index tensor.
          * Row 0 holds the top half, in descending score order.
          * Row 1 holds the bottom half. With ``flip_right`` it is reversed,
            so the best token is paired with the worst.
        Extra keyword arguments are accepted and ignored so that all pairing
        functions share one call signature.
        """
        sort_idx = scores.sort(descending=True).indices

        l_half = sort_idx.shape[-1] // 2
        left = sort_idx[:, :l_half]
        right = sort_idx[:, l_half:]
        if self.flip_right:
            # Only dim 1 needs reversing. Dim 0 has size 1, which is why the
            # previous dims=(1, 0) made no difference.
            right = torch.flip(right, dims=(1,))
        return torch.cat((left, right), dim=0)
