# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
BiEncoder component + loss function for 'all-in-batch' training
"""

import collections
import logging
import random
import math
from typing import Tuple, List

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor as T
from torch import nn

from dpr.utils.data_utils import Tensorizer
from dpr.utils.data_utils import normalize_question

logger = logging.getLogger(__name__)

# One tensorized training batch, as returned by BiEncoder.create_biencoder_input:
#   question_ids   - padded question token ids, one row per sample
#   context_ids    - padded context token ids, one row per context (all samples' contexts concatenated)
#   is_positive    - for each sample, the row index of its positive context in context_ids
#   hard_negatives - for each sample, the row indices of its hard-negative contexts in context_ids
# The typename must match the bound name: pickle locates the class as <module>.<typename>.
BiEncoderBatch = collections.namedtuple('BiEncoderBatch',
                                        ['question_ids', 'context_ids',
                                         'is_positive', 'hard_negatives'])



def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T:
    """
    Calculates q->ctx dot-product scores for every row in ctx_vectors.

    :param q_vectors: question vectors, n1 x D (a single D-dim vector, or batched B x n1 x D, also work)
    :param ctx_vectors: context vectors, n2 x D (or batched B x n2 x D)
    :return: scores n1 x n2 (n2 for a single question vector, B x n1 x n2 when batched)
    """
    # Swapping the last two dims equals transpose(0, 1) for 2-D input and also
    # supports leading batch dimensions.
    return torch.matmul(q_vectors, ctx_vectors.transpose(-2, -1))

class BiEncoder(nn.Module):
    """ Bi-Encoder model component. Encapsulates query/question and context/passage encoders.

    Each sub-model is called with ``input_ids`` / ``token_type_ids`` / ``attention_mask`` keyword
    arguments, and its result is unpacked into ``(sequence_output, pooled_output, hidden_states)``.
    The first-token vector of ``sequence_output`` is used as the question/context representation.
    """

    def __init__(self, args, question_model: nn.Module, ctx_model: nn.Module, fix_q_encoder: bool = False,
                 fix_ctx_encoder: bool = False):
        """
        :param args: run configuration object, stored as ``self.args``
        :param question_model: encoder applied to question token ids
        :param ctx_model: encoder applied to context/passage token ids
        :param fix_q_encoder: if True, the question encoder is run under ``torch.no_grad()``
        :param fix_ctx_encoder: if True, the context encoder is run under ``torch.no_grad()``
        """
        super(BiEncoder, self).__init__()
        self.question_model = question_model
        self.ctx_model = ctx_model
        self.fix_q_encoder = fix_q_encoder
        self.fix_ctx_encoder = fix_ctx_encoder
        # NOTE(review): hard-coded hidden size that is not read anywhere in this class;
        # presumably matches a BERT-base encoder — confirm against the sub-models.
        self.d_model = 768
        self.args = args

    @staticmethod
    def get_representation(sub_model: nn.Module, ids: T, segments: T, attn_mask: T,
                           fix_encoder: bool = False) -> Tuple[T, T, T]:
        """
        Runs ``sub_model`` on one batch of token ids.

        :param sub_model: encoder whose output unpacks into three values
        :param ids: token ids, or None to skip the encoder
        :param segments: token type ids
        :param attn_mask: attention mask
        :param fix_encoder: if True, run the encoder under ``torch.no_grad()``
        :return: ``(sequence_output, pooled_output, hidden_states)``; all None when ``ids`` is None
        """
        sequence_output = None
        pooled_output = None
        hidden_states = None
        if ids is not None:
            inputs = {"input_ids": ids, "token_type_ids": segments, "attention_mask": attn_mask}
            if fix_encoder:
                with torch.no_grad():
                    sequence_output, pooled_output, hidden_states = sub_model(**inputs)
                if sub_model.training:
                    # Outputs computed under no_grad do not require grad; mark them as requiring
                    # grad in training mode (presumably so downstream code treats them uniformly).
                    sequence_output.requires_grad_(requires_grad=True)
                    pooled_output.requires_grad_(requires_grad=True)
            else:
                sequence_output, pooled_output, hidden_states = sub_model(**inputs)
        return sequence_output, pooled_output, hidden_states

    def _encode_first_token(self, sub_model: nn.Module, ids: T, fix_encoder: bool) -> T:
        """Encodes ``ids`` and returns the first-token vector of the sequence output (None if ids is None)."""
        if ids is None:
            return None
        # Token id 0 is treated as padding (create_biencoder_input pads with 0); all segment ids are 0.
        mask = ids != 0
        segments = torch.zeros_like(ids)
        seq_out, _pooled, _hidden = self.get_representation(sub_model, ids, segments, mask, fix_encoder)
        return seq_out[:, 0, :] if seq_out is not None else None

    def forward(self, question_ids: T, context_ids: T) -> Tuple[T, T]:
        """
        Encodes questions and contexts.

        :param question_ids: padded question token ids, or None
        :param context_ids: padded context token ids, or None
        :return: ``(q_rep, ctx_rep)`` first-token vectors; an entry is None when its input is None
        """
        q_rep = self._encode_first_token(self.question_model, question_ids, self.fix_q_encoder)
        ctx_rep = self._encode_first_token(self.ctx_model, context_ids, self.fix_ctx_encoder)
        return q_rep, ctx_rep

    @staticmethod
    def _pad_to_length(token_ids, max_len: int) -> T:
        """Truncates ``token_ids`` to ``max_len`` and right-pads with 0 to exactly ``max_len`` (int64)."""
        ids = torch.tensor(token_ids[:max_len]).long()
        padding = torch.zeros([max_len - len(ids)]).long()
        return torch.cat([ids, padding], dim=0)

    @classmethod
    def create_biencoder_input(cls,
                               args,
                               samples: List,
                               tensorizer: Tensorizer,
                               insert_title: bool,
                               num_hard_negatives: int = 0,
                               num_other_negatives: int = 0,
                               shuffle: bool = True,
                               shuffle_positives: bool = False,
                               ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.
        :param args: configuration providing ``max_q_len`` and ``max_p_len`` (pad/truncate lengths)
        :param samples: list of data items (from json) to create the batch for; each item provides
            ``question['text_ids']``, ``positive_ctxs`` and optionally ``negative_ctxs`` /
            ``hard_negative_ctxs``, whose entries provide ``text_ids``
        :param tensorizer: not used by this implementation (kept for interface compatibility)
        :param insert_title: not used by this implementation (kept for interface compatibility)
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools (the samples themselves are not modified)
        :param shuffle_positives: shuffles positive passages pools (only together with ``shuffle``)
        :return: BiEncoderBatch tuple
        """
        q_max_len = args.max_q_len
        p_max_len = args.max_p_len

        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            positive_ctxs = sample['positive_ctxs']
            if shuffle and shuffle_positives:
                positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
            else:
                positive_ctx = positive_ctxs[0]

            # Copy the pools so shuffling does not reorder the caller's sample data in place.
            neg_ctxs = list(sample['negative_ctxs']) if 'negative_ctxs' in sample else []
            hard_neg_ctxs = list(sample['hard_negative_ctxs']) if 'hard_negative_ctxs' in sample else []
            if shuffle:
                random.shuffle(hard_neg_ctxs)
                random.shuffle(neg_ctxs)

            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]
            neg_ctxs = neg_ctxs[0:num_other_negatives]

            # Per-sample layout in the flat context list: positive, other negatives, hard negatives.
            first_ctx_idx = len(ctx_tensors)
            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
            ctx_tensors.extend(cls._pad_to_length(ctx["text_ids"], p_max_len) for ctx in all_ctxs)

            positive_ctx_indices.append(first_ctx_idx)
            # Hard negatives start after the positive AND the other negatives.
            hard_start = first_ctx_idx + 1 + len(neg_ctxs)
            hard_neg_ctx_indices.append(list(range(hard_start, hard_start + len(hard_neg_ctxs))))

        ctxs_tensor = torch.stack(ctx_tensors, dim=0)
        questions_tensor = torch.stack(
            [cls._pad_to_length(sample["question"]["text_ids"], q_max_len) for sample in samples], dim=0)

        return BiEncoderBatch(questions_tensor, ctxs_tensor,
                              positive_ctx_indices, hard_neg_ctx_indices)


class BiEncoderNllLoss(object):
    """Negative log-likelihood loss over question->context scores (in-batch negatives).

    Each question is scored against every context vector; the loss is the mean NLL of each
    question's positive context under a log-softmax over all contexts.
    """

    def calc(self, q_vectors: T, ctx_vectors: T, positive_idx_per_question: list,
             hard_negative_idx_per_question: list = None) -> Tuple[T, T]:
        """
        Computes nll loss for the given lists of question and ctx vectors.
        Note that although hard_negative_idx_per_question is not currently in use, one can use it for the
        loss modifications. For example - weighted NLL with different factors for hard vs regular negatives.
        :param q_vectors: question vectors (rows are questions when 2-D)
        :param ctx_vectors: context vectors (rows are contexts)
        :param positive_idx_per_question: for each question, the index of its positive context row
        :param hard_negative_idx_per_question: unused
        :return: a tuple of (mean loss, tensor count of questions whose top-scoring context is the positive)
        """
        scores = self.get_scores(q_vectors, ctx_vectors)
        if len(q_vectors.size()) > 1:
            q_num = q_vectors.size(0)
            scores = scores.view(q_num, -1)
        log_probs = F.log_softmax(scores, dim=1)
        # Built once and reused for both the loss and the accuracy count.
        targets = torch.tensor(positive_idx_per_question).to(log_probs.device)
        loss = F.nll_loss(log_probs, targets, reduction='mean')
        _, max_idxs = torch.max(log_probs, 1)
        correct_predictions_count = (max_idxs == targets).sum()
        return loss, correct_predictions_count

    @staticmethod
    def get_scores(q_vector: T, ctx_vectors: T) -> T:
        """Scores questions against contexts with the configured similarity function."""
        f = BiEncoderNllLoss.get_similarity_function()
        return f(q_vector, ctx_vectors)

    @staticmethod
    def get_similarity_function():
        """Returns the question/context similarity function (dot product)."""
        return dot_product_scores



def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T:
    """
    Calculates q->ctx dot-product scores for every row in ctx_vectors.

    NOTE(review): this re-definition shadows the identical ``dot_product_scores`` defined
    earlier in this module; consider keeping only one.

    :param q_vectors: question vectors, n1 x D (a single D-dim vector, or batched B x n1 x D, also work)
    :param ctx_vectors: context vectors, n2 x D (or batched B x n2 x D)
    :return: scores n1 x n2 (n2 for a single question vector, B x n1 x n2 when batched)
    """
    # Swapping the last two dims equals transpose(0, 1) for 2-D input and also
    # supports leading batch dimensions.
    return torch.matmul(q_vectors, ctx_vectors.transpose(-2, -1))


