import argparse
import torchtext, random, torch

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
from tqdm import tqdm
import wandb

from ..box.box_wrapper import DeltaBoxTensor
from ..box.modules import BoxEmbedding
from ..box.pooling import LearntPooling
from ..box.pooling_utils import PerEntityAlpha
from .BaseModule import BaseModule

# True when a CUDA device is visible to torch in this process.
use_cuda = torch.cuda.is_available()
# Integer device index derived from ``use_cuda``: 0 (first GPU) or -1.
# NOTE(review): presumably consumed by callers as a torchtext-style device
# argument where -1 means CPU — confirm against the training script.
device = -1
if use_cuda:
    device = 0


class BoxAffineTransform(BaseModule):
    """Box-embedding n-gram language model with position-specific affine maps.

    Each context word is embedded as a box, transformed by a learned affine map
    that depends on its position in the n-gram window, pooled into a single
    context box, and scored against every vocabulary box via a (log)
    intersection volume. ``forward`` returns log-probabilities over the
    vocabulary.
    """

    box_types = {
        "DeltaBoxTensor": DeltaBoxTensor,
    }

    def __init__(
        self,
        TEXT=None,
        embedding_dim=50,
        batch_size=10,
        n_gram=4,
        volume_temp=1.0,
        intersection_temp=1.0,
        box_type="BoxTensorLearntTemp",
        pooling="avg_pool",
    ):
        """Build the word boxes, frequency bias, and per-position transforms.

        Args:
            TEXT: Field-like object whose ``vocab.itos`` defines the
                vocabulary. Required in practice despite the ``None`` default,
                because ``len(TEXT.vocab.itos)`` is read unconditionally.
            embedding_dim: Dimensionality of each box.
            batch_size: Stored on the instance; not used by this class itself.
            n_gram: Number of context positions; one set of affine parameters
                is learned per position.
            volume_temp: Temperature forwarded to the box volume computation.
            intersection_temp: Gumbel intersection temperature; ``0.0``
                selects the non-Gumbel soft-volume intersection instead.
            box_type: Box parameterisation name forwarded to ``BoxEmbedding``.
            pooling: ``"avg_pool"`` (mean over positions) or ``"learnt_pool"``
                (learned per-entity pooling weights).
        """
        super(BoxAffineTransform, self).__init__()

        # Basic model descriptions.
        self.batch_size = batch_size
        self.n_gram = n_gram
        self.vocab_size = len(TEXT.vocab.itos)
        self.embedding_dim = embedding_dim

        # Box features.
        self.volume_temp = volume_temp
        self.intersection_temp = intersection_temp
        self.box_type = box_type
        self.pooling = pooling
        if self.pooling == "learnt_pool":
            self.alpha_layer = PerEntityAlpha(
                self.vocab_size, embedding_dim, embedding_dim
            )
            self.pooling_layer = LearntPooling(self.alpha_layer)

        # Create embeddings: one box per vocabulary word plus a scalar bias
        # per word, initialised to zero.
        self.embeddings_word = BoxEmbedding(
            self.vocab_size, self.embedding_dim, box_type=box_type
        )
        self.embedding_bias = nn.Embedding(self.vocab_size, 1)
        nn.init.zeros_(self.embedding_bias.weight)

        # Per-position affine parameters (weight and bias) for the box's
        # min component and its second (delta) component.
        self.position_delta_weight = nn.Embedding(
            num_embeddings=n_gram, embedding_dim=embedding_dim, sparse=False
        )
        self.position_delta_bias = nn.Embedding(
            num_embeddings=n_gram, embedding_dim=embedding_dim, sparse=False
        )
        self.position_min_weight = nn.Embedding(
            num_embeddings=n_gram, embedding_dim=embedding_dim, sparse=False
        )
        self.position_min_bias = nn.Embedding(
            num_embeddings=n_gram, embedding_dim=embedding_dim, sparse=False
        )

    def position_transformation(self, box, position):
        """Apply a position-specific affine transform to boxes, in place.

        Designed for delta boxes.
        - The min component becomes ``min * w_min + b_min``.
        - The second component becomes ``delta * w_delta + b_delta``.
        - For ``DeltaBoxTensor``, the second component is additionally passed
          through softplus so it stays positive.

        Args:
            box: Input boxes. ``box.data`` is indexed as ``[:, :, k, :]``,
                where ``k=0`` is the min and ``k=1`` is the second component
                (the delta for ``DeltaBoxTensor``). It is presumably shaped
                (batch, n_gram, 2, dim); TODO confirm.
            position: Long tensor of position indices, one per context slot.

        Returns:
            The same ``box`` object, with ``box.data`` overwritten.
        """
        box_type = type(box).__name__
        weight_delta = self.position_delta_weight(position)
        weight_min = self.position_min_weight(position)
        bias_delta = self.position_delta_bias(position)
        bias_min = self.position_min_bias(position)

        # Affine transformation on the min. The RHS is cloned so the in-place
        # slice assignment does not overwrite values autograd still needs.
        box.data[:, :, 0, :] = box.data[:, :, 0, :].clone() * weight_min + bias_min

        if box_type == "DeltaBoxTensor":
            # Affine transformation on the delta, kept positive via softplus.
            box.data[:, :, 1, :] = F.softplus(
                box.data[:, :, 1, :].clone() * weight_delta + bias_delta
            )
        else:
            box.data[:, :, 1, :] = (
                box.data[:, :, 1, :].clone() * weight_delta + bias_delta
            )
        return box

    def forward(self, x, train=True):
        """Score every vocabulary word against the pooled context box.

        Args:
            x: Context word ids, (batch_size, n_gram) per the original
                comment.
            train: Unused; kept for interface compatibility.

        Returns:
            Log-probabilities over the vocabulary (``log_softmax`` along
            dim 1).

        Raises:
            ValueError: If ``self.pooling`` is not ``"avg_pool"`` or
                ``"learnt_pool"``.
        """
        # Create index tensors on the device holding the module's parameters,
        # so the model works whether or not it has been moved to the GPU.
        # A global "CUDA is available" flag would break a CPU-resident model
        # on a CUDA machine.
        device = self.position_delta_weight.weight.device
        all_gram_idx = torch.arange(self.n_gram, device=device)
        all_vocab_idx = torch.arange(self.vocab_size, device=device)

        # Embed the context words and transform each according to its position.
        context_word_boxes = self.embeddings_word(x)
        transformed_boxes = self.position_transformation(
            context_word_boxes, all_gram_idx
        )

        # Pool the per-position boxes into one context box per example, then
        # add a singleton dim at axis 1 so it broadcasts against the vocab.
        if self.pooling == "avg_pool":
            transformed_boxes.data = torch.mean(
                transformed_boxes.data, dim=1
            ).unsqueeze_(1)
            # Previously never bound on this branch -> NameError below.
            output_box = transformed_boxes
        elif self.pooling == "learnt_pool":
            output_box = self.pooling_layer(x, transformed_boxes, dim=1)
            output_box.data.unsqueeze_(1)
        else:
            raise ValueError(
                f"Unsupported pooling {self.pooling!r}; "
                "expected 'avg_pool' or 'learnt_pool'"
            )

        # Boxes for the whole vocabulary, with a leading dim for broadcasting
        # against the batch of context boxes.
        all_word = self.embeddings_word(all_vocab_idx)
        all_word.data.unsqueeze_(0)

        # Intersection volume of the context box with every vocab box.
        # NCE could be placed here.
        if self.intersection_temp == 0.0:
            dec = all_word.intersection_log_soft_volume(
                output_box, temp=self.volume_temp
            )
        else:
            dec = all_word.gumbel_intersection_log_volume(
                output_box,
                volume_temp=self.volume_temp,
                intersection_temp=self.intersection_temp,
            )

        # Add a per-word bias to account for word frequency (Mnih et al.).
        decoded = dec + self.embedding_bias(all_vocab_idx).view(-1)
        logits = F.log_softmax(decoded, dim=1)
        return logits
