import torch
from autoencoders.base_encoder import BaseEncoder
import numpy as np
import math

class CMOWEncoder(BaseEncoder):
    """Encoder that maps every token to a square matrix and composes a
    sequence by multiplying those matrices in order.

    Each ``input_size``-dimensional embedding row is reshaped into a
    ``sqrt_size x sqrt_size`` matrix, so ``input_size`` must be a perfect
    square. ``self.input_size``, ``self.vocab_size`` and ``self.embedding``
    are expected to be set up by ``BaseEncoder.__init__`` (not visible here).
    """

    def __init__(self, config):
        super(CMOWEncoder, self).__init__(config)

        # Exact integer check. The previous ``int(sqrt(n) ** 2) == n`` test
        # suffered from float round-off and accepted non-squares such as
        # 2 and 10 (their squared roots round back up to the original value).
        sqrt_size = int(round(math.sqrt(self.input_size)))
        if sqrt_size * sqrt_size != self.input_size:
            raise ValueError("Dimension of the embeddings in CMOW must be a perfect square!")
        self.sqrt_size = sqrt_size

        # Row 0 is the flattened identity matrix, the neutral element of
        # matrix multiplication. Presumably index 0 is the padding token, so
        # padded positions leave the product unchanged -- TODO confirm against
        # the vocabulary. Rows 1..vocab_size-1 are identity + small noise.
        random_rows = torch.as_tensor(
            _init_random_identity(self.vocab_size, config.input_size),
            dtype=torch.float32)
        neutral_elem = torch.eye(int(math.sqrt(config.input_size))).reshape(1, -1).float()
        init_weights = torch.cat([neutral_elem, random_rows], dim=0)
        self.embedding.weight = torch.nn.Parameter(init_weights)

    def _to_hidden_representation(self, embedded, lengths):
        """Compose each sequence's token matrices into one flat vector.

        Uses the batched sequential product; ``_chain_matrix`` is an
        alternative per-sample implementation with the same output.
        """
        return self._step_by_step(embedded, lengths)

    def _step_by_step(self, embedded, lengths):
        """Multiply token matrices left to right, one batched step per token.

        Args:
            embedded: tensor expected to be (batch, seq_len, input_size)
                -- assumption based on the reshape below; confirm with caller.
            lengths: currently unused. Padded positions are only neutral if
                their embedding is the identity matrix -- TODO confirm that
                the padding row stays fixed during training.

        Returns:
            Tensor of shape (batch, input_size) holding the flattened product.
            A sequence length of 0 is not supported (indexing raises).
        """
        batch_size, seq_len = embedded.size(0), embedded.size(1)
        matrices = embedded.reshape(batch_size, seq_len, self.sqrt_size, self.sqrt_size)
        result = matrices[:, 0]
        for step in range(1, seq_len):
            result = torch.bmm(result, matrices[:, step])
        return result.reshape(-1, self.input_size)

    def _chain_matrix(self, embedded, lengths):
        """Per-sample alternative to ``_step_by_step`` using ``chain_matmul``.

        Same argument expectations and output shape as ``_step_by_step``;
        ``lengths`` is unused here as well.
        """
        matrices = embedded.reshape(embedded.size(0), embedded.size(1), self.sqrt_size, self.sqrt_size)
        # new_zeros keeps the input's dtype and device (torch.zeros defaulted
        # to the global default dtype).
        result = embedded.new_zeros((matrices.size(0), self.sqrt_size * self.sqrt_size))
        for sample_idx, sequence in enumerate(matrices):
            product = torch.chain_matmul(*sequence.unbind(0))
            result[sample_idx] = product.reshape(-1)
        return result

def _init_random_identity(vocab_size, embedding_size, scale=0.1):
    """Return ``vocab_size - 1`` embedding rows initialized near the identity.

    Each row is a flattened ``sqrt(embedding_size) x sqrt(embedding_size)``
    identity matrix plus i.i.d. normal noise with mean 0 and standard
    deviation ``scale``. The noise is drawn from NumPy's global random state.
    Only ``vocab_size - 1`` rows are produced; the caller is expected to
    supply the remaining row.

    Args:
        vocab_size: number of vocabulary entries (rows returned: vocab_size - 1).
        embedding_size: length of each row; must be a perfect square.
        scale: standard deviation of the additive noise (default 0.1).

    Returns:
        float32 torch tensor of shape (vocab_size - 1, embedding_size).

    Raises:
        ValueError: if ``embedding_size`` is not a perfect square.
    """
    side = int(round(np.sqrt(embedding_size)))
    if side * side != embedding_size:
        raise ValueError("embedding_size must be a perfect square, got %r" % (embedding_size,))
    init_weights = np.random.normal(size=(vocab_size - 1, embedding_size),
                                    loc=0.,
                                    scale=scale).astype(np.float32)
    # Broadcasting adds the flattened identity to every row in one step,
    # instead of rebuilding np.eye inside a per-row Python loop.
    init_weights += np.eye(side, dtype=np.float32).reshape(-1)
    return torch.from_numpy(init_weights)
