import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from emb2emb.utils import Namespace
import random
from autoencoder import Encoder
import os

# TensorFlow Hub locations of the supported pretrained sentence encoders.
# The keys are the values accepted for `config.pretrained_encoder_name`;
# PretrainedEncoder passes the matching URL to `tensorflow_hub.load`.
ENCODER_URLS = {
    "use": "https://tfhub.dev/google/universal-sentence-encoder/4",
    "use-large": "https://tfhub.dev/google/universal-sentence-encoder-large/5",
    "muse": "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3",
    "muse-large": "https://tfhub.dev/google/universal-sentence-encoder-multilingual-large/3"
}

class PretrainedEncoder(Encoder):
    """Sentence encoder built on a pretrained TensorFlow Hub model.

    Batches of token ids are decoded back into strings with
    ``config.tokenizer``, embedded by the TF Hub model selected through
    ``config.pretrained_encoder_name`` (a key of ``ENCODER_URLS``), and the
    result is copied into a torch tensor on ``config.device``. Because the
    embedding is copied through NumPy, no gradient reaches the TensorFlow
    model; the only torch layers created here are the optional variational
    projections ``hidden2mean`` / ``hidden2logv``.
    """

    def __init__(self, config):
        """Load the pretrained encoder and set up optional variational heads.

        Args:
            config: Configuration object. The attributes read here are
                ``gaussian_noise_std``, ``unit_sphere``, ``device``,
                ``hidden_size``, ``variational``, ``max_sequence_len``,
                ``tokenizer`` and ``pretrained_encoder_name``.

        Raises:
            ValueError: If ``config.hidden_size`` is not 512.
            KeyError: If ``config.pretrained_encoder_name`` is not a key of
                ``ENCODER_URLS``.
        """
        super(PretrainedEncoder, self).__init__(config)

        # Fail loudly on a bad configuration. The previous print + exit()
        # ended the process with status 0, hiding the error from callers
        # and from any surrounding job scheduler.
        if config.hidden_size != 512:
            raise ValueError(
                "ERROR: Pretrained encoders require hidden sizes of 512. Exiting")

        self.config = config
        self.device = config.device

        # Stored for use elsewhere; not referenced by the methods of this class.
        self.gaussian_noise_std = config.gaussian_noise_std
        self.max_sequence_len = config.max_sequence_len

        self.unit_sphere = config.unit_sphere
        self.variational = config.variational
        self.hidden_size = config.hidden_size
        self.input_size = config.hidden_size
        self.tokenizer = config.tokenizer

        # Get rid of random tensorflow warnings. The variable is set before
        # TensorFlow is imported below.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
        # TensorFlow is imported lazily so the dependency is only needed when
        # this encoder is actually constructed.
        # NOTE(review): `tf` and `tensorflow_text` are never referenced by
        # name; presumably they are imported for import-time side effects
        # (e.g. op registration for the multilingual models) -- confirm
        # before removing either import.
        import tensorflow as tf  # noqa: F401
        import tensorflow_hub as hub
        import tensorflow_text  # noqa: F401
        self.encoder = hub.load(ENCODER_URLS[config.pretrained_encoder_name])

        if config.variational:
            # Project the pretrained embedding to the mean and log-variance
            # of a diagonal Gaussian of the same dimensionality.
            self.hidden2mean = nn.Linear(self.hidden_size, self.hidden_size)
            self.hidden2logv = nn.Linear(self.hidden_size, self.hidden_size)

    def encode(self, x, lengths, train=False, reparameterize=True):
        """Embed a batch of token-id sequences with the pretrained encoder.

        Args:
            x: Tensor of token ids, one row per sentence.
            lengths: Tensor with one length per row of ``x``. For each row
                only ``x[i, 1:lengths[i] - 1]`` is decoded, i.e. the first
                token and the last token inside the given length are dropped
                (the special tokens).
            train: If true and the encoder is variational, also return the
                Gaussian parameters.
            reparameterize: Only used when the encoder is variational. If
                true, return a sample ``mean + eps * std`` with
                ``eps ~ N(0, I)``; otherwise return ``mean``.

        Returns:
            The embedding ``h`` of shape (batch, hidden_size), or the tuple
            ``(h, mean, logv)`` when ``train`` is true and the encoder is
            variational.
        """
        # Extract the sentences back to strings, and remove special tokens
        token_rows = x.cpu().detach()
        row_lengths = lengths.cpu().detach()
        sentences = self.tokenizer.decode_batch(
            [row[1:length - 1].tolist()
             for row, length in zip(token_rows, row_lengths)])

        # Encode and turn into tensors. Going through NumPy copies the values,
        # so the result carries no autograd history.
        h = torch.tensor(self.encoder(sentences).numpy(), device=self.device)

        if self.variational:
            mean = self.hidden2mean(h)
            logv = self.hidden2logv(h)
            std = torch.exp(0.5 * logv)
            if reparameterize:
                noise = torch.randn(
                    x.shape[0], self.hidden_size, device=self.device)
                h = noise * std + mean
            else:
                h = mean

        if self.unit_sphere:
            # F.normalize divides by the L2 norm clamped at a small epsilon,
            # so an all-zero embedding stays zero instead of becoming NaN
            # (which a plain division by the norm would produce).
            h = F.normalize(h, p=2, dim=-1)

        # (batch, hidden_size)
        if train and self.variational:
            return h, mean, logv
        return h