# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from my_transformers.modeling_bert import (
    BertPreTrainedModel,
    BertConfig,
    BertModel,
)

from pytorch_transformers.tokenization_bert import BertTokenizer

from optimizer import get_bert_optimizer
from torch.autograd import Variable


def get_model_obj(model):
    """Return the underlying model object.

    If ``model`` exposes a ``module`` attribute (as wrappers such as
    ``torch.nn.DataParallel`` do), that attribute is returned; otherwise
    ``model`` itself is returned unchanged.
    """
    return getattr(model, "module", model)

class BiEncoderModule(torch.nn.Module):
    """Two BERT encoders: one for contexts and one for candidates.

    Both encoders are loaded from ``params['bert_model']`` and share one
    configuration object, which additionally carries ``epsilon`` and
    ``layer_id`` copied from ``params``.
    """

    def __init__(self, params):
        super(BiEncoderModule, self).__init__()
        bert_name = params['bert_model']

        # Extend the pretrained config with the two extra settings that are
        # handed to the (project-local) BertModel implementation.
        bert_config = BertConfig.from_pretrained(bert_name)
        bert_config.epsilon = params['epsilon']
        bert_config.layer_id = params['layer_id']

        self.context_bert = BertModel.from_pretrained(bert_name, config=bert_config)
        self.cand_bert = BertModel.from_pretrained(bert_name, config=bert_config)

        self.config = self.context_bert.config

    def forward(
        self,
        token_idx_ctxt,
        segment_idx_ctxt,
        mask_ctxt,
        token_idx_cands,
        segment_idx_cands,
        mask_cands,
        if_turbulence=False,
    ):
        """Encode contexts and/or candidates.

        Each side is encoded only when its token tensor is not ``None``;
        otherwise ``None`` is returned for that side. The representation of
        a sequence is the encoder's first output sliced at sequence
        position 0 (presumably the [CLS] token -- confirm against the
        tokenizer used by callers).

        Returns:
            ``(encoded_ctxt, encoded_cand)``.
        """
        ctxt_repr = None
        cand_repr = None

        if token_idx_ctxt is not None:
            ctxt_hidden, _ = self.context_bert(
                token_idx_ctxt,
                segment_idx_ctxt,
                mask_ctxt,
                if_turbulence=if_turbulence,
            )
            # Keep only the vector at the first sequence position.
            ctxt_repr = ctxt_hidden[:, 0, :]

        if token_idx_cands is not None:
            cand_hidden, _ = self.cand_bert(
                token_idx_cands,
                segment_idx_cands,
                mask_cands,
                if_turbulence=if_turbulence,
            )
            cand_repr = cand_hidden[:, 0, :]

        return ctxt_repr, cand_repr

class BiEncoderRanker(torch.nn.Module):
    """Scores contexts against candidates with a ``BiEncoderModule``.

    The wrapped module is stored in ``self.model`` and moved to
    ``self.device`` on construction.
    """

    def __init__(self, params, shared=None):
        super(BiEncoderRanker, self).__init__()
        self.params = params
        # Prefer CUDA, but fall back to CPU so the ranker can still be
        # constructed on hosts without a GPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.n_gpu = torch.cuda.device_count()

        # Token id treated as padding when building BERT inputs.
        self.NULL_IDX = 0

        # init model
        self.build_model()

        # Move the model to the selected device.
        self.model = self.model.to(self.device)

    def load_model(self, fname):
        """Load a state dict from ``fname`` into ``self.model``.

        NOTE: ``torch.load`` unpickles the file; only load checkpoints from
        trusted sources.
        """
        # Deserialize onto CPU so a checkpoint saved on a GPU can be loaded
        # on any host; ``load_state_dict`` copies the values into the
        # model's existing parameters on their current device.
        state_dict = torch.load(fname, map_location="cpu")
        self.model.load_state_dict(state_dict)

    def build_model(self):
        """Instantiate the underlying ``BiEncoderModule`` from ``self.params``."""
        self.model = BiEncoderModule(self.params)

    def encode_context(self, cands):
        """Encode token ids with the context encoder.

        Returns the context embeddings as a detached CPU tensor.
        """
        token_idx, segment_idx, mask = to_bert_input(cands, self.NULL_IDX)
        embedding_context, _ = self.model(
            token_idx, segment_idx, mask, None, None, None
        )
        return embedding_context.cpu().detach()

    def encode_candidate(self, cands):
        """Encode token ids with the candidate encoder.

        Returns the candidate embeddings as a detached CPU tensor.
        """
        token_idx_cands, segment_idx_cands, mask_cands = to_bert_input(
            cands, self.NULL_IDX
        )
        _, embedding_cands = self.model(
            None, None, None, token_idx_cands, segment_idx_cands, mask_cands
        )
        return embedding_cands.cpu().detach()

    def score_candidate(
        self,
        text_vecs,
        cand_vecs,
        cand_encs=None,  # pre-computed candidate encoding.
        if_turbulence=False,
    ):
        """Score candidates given context input and candidate input.

        Args:
            text_vecs: context token ids.
            cand_vecs: candidate token ids, either 2D or 3D. A 3D tensor is
                unpacked as ``(bs, cs, ws)`` and gives every context its
                own ``cs`` candidates. Ignored when ``cand_encs`` is given.
            cand_encs: optional pre-computed candidate encodings.
            if_turbulence: forwarded unchanged to the underlying encoders.

        Returns:
            Dot-product scores. With ``cand_encs`` or 2D ``cand_vecs`` every
            context is scored against every candidate; with 3D
            ``cand_vecs`` the result has shape ``(bs, cs)``.
        """
        # Encode contexts first
        token_idx_ctxt, segment_idx_ctxt, mask_ctxt = to_bert_input(
            text_vecs, self.NULL_IDX
        )
        embedding_ctxt, _ = self.model(
            token_idx_ctxt,
            segment_idx_ctxt,
            mask_ctxt,
            None,
            None,
            None,
            if_turbulence=if_turbulence,
        )

        # Candidate encoding is given, do not need to re-compute.
        # Directly return the score of context encoding and candidate encoding.
        if cand_encs is not None:
            return embedding_ctxt.mm(cand_encs.t())

        per_context_cands = cand_vecs.dim() == 3
        if per_context_cands:
            # Flatten to 2D so all candidates go through the encoder at once.
            bs, cs, ws = cand_vecs.size()
            cand_vecs = cand_vecs.reshape(-1, ws)

        token_idx_cands, segment_idx_cands, mask_cands = to_bert_input(
            cand_vecs, self.NULL_IDX
        )
        _, embedding_cands = self.model(
            None,
            None,
            None,
            token_idx_cands,
            segment_idx_cands,
            mask_cands,
            if_turbulence=if_turbulence,
        )

        if per_context_cands:
            # Restore the per-context grouping and score each context only
            # against its own candidates.
            embedding_cands = embedding_cands.reshape(bs, cs, -1)
            embedding_ctxt = embedding_ctxt.unsqueeze(1)
            return torch.bmm(
                embedding_ctxt, embedding_cands.transpose(1, 2)
            ).reshape(bs, cs)

        # Train time. We compare with all elements of the batch.
        return embedding_ctxt.mm(embedding_cands.t())

    def forward(self, context_input, cand_input, if_turbulence=False):
        """Return ``(loss, scores)`` for a batch.

        The loss is mean cross-entropy over the score rows, where the
        target class for row ``i`` is column ``i``.
        """
        scores = self.score_candidate(
            context_input, cand_input, if_turbulence=if_turbulence
        )
        bs = scores.size(0)
        # Build the targets directly on the scores' device rather than via a
        # Python list followed by a host-to-device copy.
        target = torch.arange(bs, dtype=torch.long, device=scores.device)
        loss = F.cross_entropy(scores, target, reduction="mean")

        return loss, scores

def to_bert_input(token_idx, null_idx):
    """ Build BERT inputs from token_idx, a 2D int tensor.

        Returns a triple (token_idx, segment_idx, mask):
          - token_idx: the input with every null_idx position set to 0,
          - segment_idx: all zeros, same shape as the input,
          - mask: True where the input differs from null_idx.
    """
    mask = token_idx.ne(null_idx)
    segment_idx = token_idx.mul(0)
    # Zero out null positions so they hold 0 even when null_idx is not 0.
    masked_token_idx = token_idx.mul(mask.long())
    return masked_token_idx, segment_idx, mask
