import math
import torch
from torch import nn
import torch.nn.functional as F
import torchtext


class StyleFlow(nn.Module):
    """Top-level style-transfer model built around an attention-aware coupling layer.

    Loads GloVe vectors into the vocabulary to build the word embeddings, keeps
    one learned embedding per style, and delegates both directions of the flow
    to ``AttentionAwareCouplingLayer``.
    """

    def __init__(self, config, vocab):
        """Build the embeddings and the coupling layer.

        Args:
            config: object exposing ``num_styles``, ``d_model``, ``max_length``,
                ``embed_size``, ``pretrained_embed_path`` and ``device``.
            vocab: vocabulary exposing ``stoi``, ``set_vectors`` and ``vectors``.
        """
        super(StyleFlow, self).__init__()
        num_styles, d_model = config.num_styles, config.d_model

        self.max_length = config.max_length
        self.eos_idx = vocab.stoi['<eos>']
        self.pad_idx = vocab.stoi['<pad>']

        vectors = torchtext.vocab.GloVe('6B', dim=config.embed_size, cache=config.pretrained_embed_path)
        vocab.set_vectors(vectors.stoi, vectors.vectors, vectors.dim)
        self.word_embeddings = nn.Embedding.from_pretrained(vocab.vectors).to(config.device)

        self.style_embeddings = Embedding(num_styles, d_model)
        self.AaCL = AttentionAwareCouplingLayer(config, vocab)
        self.norm = LayerNorm(config.d_model)

    def embedding(self, x):
        """Look up the pretrained word embeddings for token ids ``x``."""
        return self.word_embeddings(x)

    def forward(self, input, gold_tokens, inp_lengths, style, reverse=False, source_content=None, temperature=1.0):
        """Run the flow in the forward (disentangle) or reverse (generate) direction.

        Args:
            input: encoder input; only its first two sizes are read here, the
                tensor itself is handed to the coupling layer.
            gold_tokens: passed through to the coupling layer.
            inp_lengths: per-example lengths used to build the padding mask.
            style: style ids, embedded with ``self.style_embeddings`` when
                ``reverse`` is true.
            reverse: ``False`` to disentangle, ``True`` to run the inverse.
            source_content: content representation used as the target content
                in the reverse direction.
            temperature: accepted for interface compatibility; unused here.

        Returns:
            Forward: the ``(source_content, source_style)`` pair from the
            coupling layer. Reverse: whatever the coupling layer returns.
        """
        batch_size, max_enc_len = input.size(0), input.size(1)

        pos_idx = torch.arange(self.max_length).unsqueeze(0).expand((batch_size, -1)).to(inp_lengths.device)

        # True at padded positions. Only the mask's device is consumed below;
        # the mask itself is not yet passed to the coupling layer.
        src_mask = (pos_idx[:, :max_enc_len] >= inp_lengths.unsqueeze(-1)).view(batch_size, 1, 1, max_enc_len)
        tgt_length = self.max_length
        tgt_mask = torch.ones((tgt_length, tgt_length)).to(src_mask.device)
        # True strictly above the diagonal, i.e. at future positions.
        tgt_mask = (tgt_mask.tril() == 0).view(1, 1, tgt_length, tgt_length)

        if not reverse:
            # The coupling layer's signature also has target_content/target_style
            # slots; they are irrelevant in this direction, so pass None.
            source_content, source_style = self.AaCL(
                input, gold_tokens, inp_lengths, tgt_mask, None, None, reverse=False
            )
            return source_content, source_style

        target_content = source_content
        # The attribute is ``style_embeddings``; ``style_embed`` never existed.
        target_style = self.style_embeddings(style)
        target_text = self.AaCL(
            input, gold_tokens, inp_lengths, tgt_mask, target_content, target_style, reverse=True
        )
        return target_text


class AttentionAwareCouplingLayer(nn.Module):
    """Chains the attention-aware disentanglement layer with a coupling layer."""

    def __init__(self, config, vocab):
        super(AttentionAwareCouplingLayer, self).__init__()
        self.Aadl = AttentionAwareDisentanglementLayer(config, vocab)
        self.CL = CouplingLayer(config, vocab)

    def forward(self, input, gold_tokens, inp_lengths, tgt_mask, target_content=None, target_style=None, reverse=False):
        """Run the layer in the forward or reverse direction.

        Args:
            input, gold_tokens, inp_lengths, tgt_mask: forwarded to the
                disentanglement layer when ``reverse`` is false.
            target_content, target_style: inputs to the inverse coupling; only
                required when ``reverse`` is true.
            reverse: selects the direction.

        Returns:
            The two-element result of the coupling layer in the chosen direction.

        Raises:
            ValueError: if ``reverse`` is true and either target is missing.
        """
        if not reverse:
            # NOTE(review): the disentanglement layer's forward in this file is
            # declared as (x, src_mask, gold_tokens, style) and returns six
            # values; this call site predates that and needs reconciling.
            z_c, z_s = self.Aadl(input, gold_tokens, inp_lengths, tgt_mask)
            source_content, source_style = self.CL(z_c, z_s)
            return source_content, source_style

        if target_content is None or target_style is None:
            raise ValueError("target_content and target_style are required when reverse=True")
        # The inverse direction must be requested explicitly; without the flag
        # the coupling layer would run its forward transform again.
        t_c, t_s = self.CL(target_content, target_style, reverse=True)
        return t_c, t_s


class AttentionAwareDisentanglementLayer(nn.Module):
    """Splits hidden states, source mask and tokens into two parts each, guided by ``style``."""

    def __init__(self, config, vocab):
        super().__init__()
        self.style_attention = MLPAttention(hidden_dim=config.embed_size, att_dim=config.attn_size)

    def forward(self, x, src_mask, gold_tokens, style):
        """Return ``(x_id, x_change, mask_id, mask_change, tokens_id, tokens_change)``.

        Each pair comes from the matching ``AttentionAwareSplitSelectDim*``
        helper, called with the same ``style`` tensor.
        """
        # Sampling / selection is still to be implemented.
        # For now, simply split every input in two first.
        hidden_parts = AttentionAwareSplitSelectDim3(x, style)
        mask_parts = AttentionAwareSplitSelectDim4(src_mask, style)
        token_parts = AttentionAwareSplitSelectDim2(gold_tokens, style)

        x_id, x_change = hidden_parts
        src_mask_id, src_mask_change = mask_parts
        inp_tokens_id, inp_tokens_change = token_parts
        return x_id, x_change, src_mask_id, src_mask_change, inp_tokens_id, inp_tokens_change


class CouplingLayer(nn.Module):
    """Affine coupling: ``z1 = x1`` and ``z2 = beta(x1) * x2 + gamma(x1)``.

    ``block1`` produces the scale ``beta`` and ``block2`` the shift ``gamma``.
    The reverse direction undoes the affine map with ``(x2 - gamma) / beta``.
    """

    def __init__(self, config, vocab):
        super(CouplingLayer, self).__init__()
        self.dim = config.d_model

        # Two independent networks: one for the scale, one for the shift.
        self.block1 = TransformerBlock(config, vocab)
        self.block2 = TransformerBlock(config, vocab)
        self.norm = LayerNorm(embedding_dim=self.dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x1, x2, reverse=False):
        """Apply the coupling transform or its inverse.

        Args:
            x1: the part passed through unchanged and used to condition the
                scale and shift.
            x2: the part that is transformed.
            reverse: ``False`` for the forward map, ``True`` for the inverse.

        Returns:
            Forward: ``(output, 0.0)`` where ``output`` is the ReLU of the
            layer-normalised concatenation of ``x1`` and ``z2`` along dim 1.
            Reverse: ``(x1, (x2 - gamma) / beta)``.
        """
        # NOTE(review): TransformerBlock.forward in this file takes more
        # arguments than ``x1`` alone; the call sites here need reconciling
        # with that signature before this layer can run end to end.
        beta = self.block1(x1)
        gamma = self.block2(x1)

        if not reverse:
            # Elementwise scale so that the reverse branch is its exact inverse.
            z2 = beta * x2 + gamma
            output = self.relu(self.norm(torch.cat([x1, z2], dim=1)))
            return output, 0.0

        z_2 = torch.div(x2 - gamma, beta)
        return x1, z_2


class TransformerBlock(nn.Module):
    """Encoder-decoder pair that returns the decoder's hidden states."""

    def __init__(self, config, vocab):
        """Build the token embedding, encoder and decoder.

        Args:
            config: object exposing ``num_layers``, ``d_model``, ``max_length``,
                ``h`` and ``dropout``.
            vocab: vocabulary supporting ``len()`` and exposing ``stoi``.
        """
        super(TransformerBlock, self).__init__()
        vocab_size = len(vocab)
        d_model = config.d_model

        self.input_emb = nn.Embedding(vocab_size, d_model)
        self.max_length = config.max_length
        self.eos_idx = vocab.stoi['<eos>']
        self.pad_idx = vocab.stoi['<pad>']

        self.encoder = Encoder(config.num_layers, d_model, vocab_size, config.h, config.dropout)
        self.decoder = Decoder(config.num_layers, d_model, vocab_size, config.h, config.dropout)

    def forward(self, enc_input, src_mask, tgt_mask, pos_idx, batch_size, gold_tokens, temperature=1.0):
        """Encode ``enc_input`` and decode the embedded ``gold_tokens`` against it.

        Args:
            enc_input: input handed directly to the encoder.
            src_mask: mask passed to the encoder and to the decoder's source
                attention.
            tgt_mask: target mask; cropped on its last two dims to the length
                of ``gold_tokens``.
            pos_idx: unused; kept for interface compatibility.
            batch_size: unused; kept for interface compatibility.
            gold_tokens: token ids embedded with ``self.input_emb`` and used as
                decoder input.
            temperature: forwarded to the decoder.

        Returns:
            The decoder output.
        """
        memory = self.encoder(enc_input, src_mask)
        max_dec_len = gold_tokens.size(1)

        dec_input_emb = self.input_emb(gold_tokens)
        return self.decoder(
            dec_input_emb, memory,
            src_mask, tgt_mask[:, :, :max_dec_len, :max_dec_len],
            temperature
        )


class Encoder(nn.Module):
    """Stack of ``num_layers`` encoder layers followed by a final layer norm."""

    def __init__(self, num_layers, d_model, vocab_size, h, dropout):
        super().__init__()
        stack = [EncoderLayer(d_model, h, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)
        self.norm = LayerNorm(d_model)

    def forward(self, x, mask):
        """Pass ``x`` through every layer with ``mask`` and normalise the result."""
        # The mask's last dim must match the size of x along dim 1.
        assert x.size(1) == mask.size(-1)

        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)

        return self.norm(hidden)


class Decoder(nn.Module):
    """Stack of decoder layers with a full-sequence and an incremental entry point."""

    def __init__(self, num_layers, d_model, vocab_size, h, dropout):
        super().__init__()
        stack = [DecoderLayer(d_model, h, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)
        self.norm = LayerNorm(d_model)
        self.generator = Generator(d_model, vocab_size)

    def forward(self, x, memory, src_mask, tgt_mask, temperature):
        """Run every layer over the whole sequence and return the hidden states.

        ``temperature`` is accepted but not used; neither ``self.norm`` nor
        ``self.generator`` is applied on this path.
        """
        # The target mask's last dim must match the size of x along dim 1.
        assert x.size(1) == tgt_mask.size(-1)

        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, src_mask, tgt_mask)

        return hidden

    def incremental_forward(self, x, memory, src_mask, tgt_mask, temperature, prev_states=None):
        """Run one decoding step, reusing cached states from earlier steps.

        Args:
            x: input for the current step.
            memory, src_mask, tgt_mask: forwarded to every layer.
            temperature: accepted but not used.
            prev_states: states returned by the previous call, or ``None`` /
                empty on the first step. Entry ``i`` belongs to layer ``i``;
                the last entry holds the accumulated top-layer outputs.

        Returns:
            ``(y, new_states)`` where ``y`` is the layer-normalised last
            position of the accumulated outputs. ``self.generator`` is not
            applied here.
        """
        has_cache = bool(prev_states)
        hidden = x
        states = []

        for idx, decoder_layer in enumerate(self.layers):
            layer_cache = prev_states[idx] if has_cache else None
            hidden, layer_states = decoder_layer.incremental_forward(
                hidden, memory, src_mask, tgt_mask, layer_cache
            )
            states.append(layer_states)

        # Append this step's top-layer output to the running history.
        if has_cache:
            history = torch.cat((prev_states[-1], hidden), 1)
        else:
            history = hidden
        states.append(history)

        return self.norm(history)[:, -1:], states


class Generator(nn.Module):
    """Linear projection to vocabulary size followed by a tempered log-softmax."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x, temperature):
        """Return log-probabilities over the last dim of ``proj(x) / temperature``."""
        logits = self.proj(x)
        tempered = logits / temperature
        return F.log_softmax(tempered, dim=-1)


class EmbeddingLayer(nn.Module):
    """Token embedding plus position embedding.

    ``pad_idx`` and ``learned_pos_embed`` are accepted but not used.
    """

    def __init__(self, vocab, d_model, max_length, pad_idx, learned_pos_embed, load_pretrained_embed):
        super().__init__()
        self.token_embed = nn.Embedding(len(vocab), d_model)
        self.pos_embed = nn.Embedding(max_length, d_model)
        self.vocab_size = len(vocab)
        if load_pretrained_embed:
            # Swap in the vectors already attached to the vocabulary.
            self.token_embed = nn.Embedding.from_pretrained(vocab.vectors)
            print('embed loaded.')

    def forward(self, x, pos):
        """Embed ``x`` and add the position embedding of ``pos``.

        A rank-2 ``x`` is looked up as indices; any other rank is multiplied
        with the token embedding matrix instead.
        """
        if x.dim() == 2:
            tokens = self.token_embed(x)
        else:
            tokens = torch.matmul(x, self.token_embed.weight)

        return tokens + self.pos_embed(pos)


class EncoderLayer(nn.Module):
    """Self-attention followed by a position-wise feed-forward, each in a sublayer wrapper."""

    def __init__(self, d_model, h, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, h, dropout)
        self.pw_ffn = PositionwiseFeedForward(d_model, dropout)
        wrappers = [SublayerConnection(d_model, dropout) for _ in range(2)]
        self.sublayer = nn.ModuleList(wrappers)

    def forward(self, x, mask):
        """Apply masked self-attention, then the feed-forward network."""

        def attend(normed):
            return self.self_attn(normed, normed, normed, mask)

        attended = self.sublayer[0](x, attend)
        return self.sublayer[1](attended, self.pw_ffn)


class DecoderLayer(nn.Module):
    """Self-attention, source attention and feed-forward, each in a sublayer wrapper."""

    def __init__(self, d_model, h, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, h, dropout)
        self.src_attn = MultiHeadAttention(d_model, h, dropout)
        self.pw_ffn = PositionwiseFeedForward(d_model, dropout)
        wrappers = [SublayerConnection(d_model, dropout) for _ in range(3)]
        self.sublayer = nn.ModuleList(wrappers)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Process the whole sequence: self-attend, attend over ``memory``, feed forward."""

        def self_attend(normed):
            return self.self_attn(normed, normed, normed, tgt_mask)

        def attend_memory(normed):
            return self.src_attn(normed, memory, memory, src_mask)

        after_self = self.sublayer[0](x, self_attend)
        after_src = self.sublayer[1](after_self, attend_memory)
        return self.sublayer[2](after_src, self.pw_ffn)

    def incremental_forward(self, x, memory, src_mask, tgt_mask, prev_states=None):
        """Process one step, prepending cached inputs from earlier steps.

        Args:
            x: input for the current step.
            memory, src_mask, tgt_mask: as in ``forward``.
            prev_states: three cached tensors (the inputs seen so far by each
                sublayer), or ``None`` / empty on the first step.

        Returns:
            ``(output, new_states)`` where ``new_states`` lists the three
            sublayer inputs including the current step.
        """

        def with_history(slot, step):
            # Concatenate along dim 1 when a cache for this slot exists.
            return torch.cat((prev_states[slot], step), 1) if prev_states else step

        self_attn_in = with_history(0, x)
        attended = self.sublayer[0].incremental_forward(
            self_attn_in,
            lambda seq: self.self_attn(seq[:, -1:], seq, seq, tgt_mask),
        )

        src_attn_in = with_history(1, attended)
        contextual = self.sublayer[1].incremental_forward(
            src_attn_in,
            lambda seq: self.src_attn(seq[:, -1:], memory, memory, src_mask),
        )

        ffn_in = with_history(2, contextual)
        output = self.sublayer[2].incremental_forward(
            ffn_in,
            lambda seq: self.pw_ffn(seq[:, -1:]),
        )

        return output, [self_attn_in, src_attn_in, ffn_in]


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with ``h`` heads of width ``d_model // h``."""

    def __init__(self, d_model, h, dropout):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # One projection each for query, key and value, in that order.
        self.head_projs = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.fc = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask):
        """Project the inputs, attend per head, then merge heads and project back."""
        batch_size = query.size(0)

        def split_heads(projection, tensor):
            # (batch, seq, h * d_k) -> (batch, h, seq, d_k)
            projected = projection(tensor)
            return projected.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

        heads = [
            split_heads(projection, tensor)
            for tensor, projection in zip((query, key, value), self.head_projs)
        ]
        q_heads, k_heads, v_heads = heads

        context, _ = ScaledAttention(q_heads, k_heads, v_heads, mask)

        # (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.h * self.d_k)

        return self.fc(merged)


def ScaledAttention(query, key, value, mask):
    """Scaled dot-product attention.

    Positions where ``mask`` is true are set to ``-inf`` before the softmax
    over the last dim. Returns ``(attended_values, attention_weights)``.
    """
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    scores.masked_fill_(mask, float('-inf'))
    attn_weight = scores.softmax(dim=-1)

    return torch.matmul(attn_weight, value), attn_weight


class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP (``d_model -> 4 * d_model -> d_model``) with ReLU and dropout."""

    def __init__(self, d_model, dropout):
        super().__init__()
        inner_dim = 4 * d_model
        stages = [
            Linear(d_model, inner_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            Linear(inner_dim, d_model),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        out = self.mlp(x)
        return out


class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(layer_norm(x)))``."""

    def __init__(self, d_model, dropout):
        super().__init__()
        self.layer_norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Normalise ``x``, apply ``sublayer`` and add the result back onto ``x``."""
        normed = self.layer_norm(x)
        update = self.dropout(sublayer(normed))
        return x + update

    def incremental_forward(self, x, sublayer):
        """Same as ``forward`` but the residual uses only the last position of ``x``."""
        normed = self.layer_norm(x)
        update = self.dropout(sublayer(normed))
        last_step = x[:, -1:]
        return last_step + update


def Linear(in_features, out_features, bias=True, uniform=True):
    """Create an ``nn.Linear`` with Xavier-initialised weights and a zero bias.

    ``uniform`` selects Xavier-uniform (default) or Xavier-normal weights.
    """
    layer = nn.Linear(in_features, out_features, bias)

    init_weight = nn.init.xavier_uniform_ if uniform else nn.init.xavier_normal_
    init_weight(layer.weight)

    if bias:
        nn.init.constant_(layer.bias, 0.)
    return layer


def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    """Create an ``nn.Embedding`` with Xavier-uniform weights.

    When ``padding_idx`` is given, that row is reset to zero after the Xavier
    initialisation. When it is ``None`` no row is zeroed.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.xavier_uniform_(m.weight)
    # Indexing with None yields a view of the whole matrix, so the zero-fill
    # must be skipped when there is no padding row.
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m


def LayerNorm(embedding_dim, eps=1e-6):
    """Create an ``nn.LayerNorm`` over ``embedding_dim`` with the given ``eps``."""
    return nn.LayerNorm(embedding_dim, eps=eps)


def AttentionAwareSplitSelectDim4(input, style, k=8):
    """Split a rank-4 tensor along its last dim using the top-``k`` entries of ``style``.

    Args:
        input: tensor indexed as ``(s, x, y, b)``; cast to float before use.
        style: scores whose top-``k`` indices along the last dim pick
            positions in ``b`` (presumably shaped ``(s, b)`` - confirm with
            the caller).
        k: number of positions to select; defaults to the previous fixed 8.

    Returns:
        ``(id_, change_)``, both indexed as ``(s, x, y, k)``. ``id_`` holds the
        values at the selected positions. ``change_`` holds, per selected
        position, the sum over all the other positions.
    """
    attention_indices = style.topk(k)[1]
    # Fix the one-hot width to the size of the dim being contracted; inferring
    # it from the largest index breaks the einsum whenever the final position
    # is not among the selected ones.
    attention_one_hot = F.one_hot(attention_indices, num_classes=input.size(-1)).float()
    id_ = torch.einsum("sxyb,sib->sxyi", input.float(), attention_one_hot)

    reverse_attention_one_hot = 1 - attention_one_hot
    change_ = torch.einsum("sxyb,sib->sxyi", input.float(), reverse_attention_one_hot)
    return id_, change_


def AttentionAwareSplitSelectDim3(input, style, k=8):
    """Split a rank-3 tensor along dim 1 using the top-``k`` entries of ``style``.

    Args:
        input: float tensor indexed as ``(s, b, e)``.
        style: scores whose top-``k`` indices along the last dim pick
            positions in ``b`` (presumably shaped ``(s, b)`` - confirm with
            the caller).
        k: number of positions to select; defaults to the previous fixed 8.

    Returns:
        ``(id_, change_)``, both indexed as ``(s, k, e)``. ``id_`` holds the
        rows at the selected positions. ``change_`` holds, per selected
        position, the sum of all the other rows.
    """
    attention_indices = style.topk(k)[1]
    # Fix the one-hot width to the size of the dim being contracted; inferring
    # it from the largest index breaks the einsum whenever the final position
    # is not among the selected ones.
    attention_one_hot = F.one_hot(attention_indices, num_classes=input.size(1)).float()
    id_ = torch.einsum("sbe,sib->sie", input, attention_one_hot)

    reverse_attention_one_hot = 1 - attention_one_hot
    change_ = torch.einsum("sbe,sib->sie", input, reverse_attention_one_hot)
    return id_, change_


def AttentionAwareSplitSelectDim2(input, style, k=8):
    """Split a rank-2 tensor along dim 1 using the top-``k`` entries of ``style``.

    Args:
        input: tensor indexed as ``(s, b)``; cast to float before use.
        style: scores whose top-``k`` indices along the last dim pick
            positions in ``b`` (presumably shaped ``(s, b)`` - confirm with
            the caller).
        k: number of positions to select; defaults to the previous fixed 8.

    Returns:
        ``(id_, change_)``, both float tensors indexed as ``(s, k)``. ``id_``
        holds the values at the selected positions. ``change_`` holds, per
        selected position, the sum over all the other positions.
    """
    attention_indices = style.topk(k)[1]
    # Fix the one-hot width to the size of the dim being contracted; inferring
    # it from the largest index breaks the einsum whenever the final position
    # is not among the selected ones.
    attention_one_hot = F.one_hot(attention_indices, num_classes=input.size(1)).float()
    id_ = torch.einsum("sb,sib->si", input.float(), attention_one_hot)

    reverse_attention_one_hot = 1 - attention_one_hot
    change_ = torch.einsum("sb,sib->si", input.float(), reverse_attention_one_hot)
    return id_, change_


class BERTArchitecture(nn.Module):
    """Two-class classification head on top of a supplied BERT encoder.

    The second output of ``bert`` is passed through ``768 -> 512 -> 2`` dense
    layers with ReLU and dropout in between, then a log-softmax over dim 1.
    """

    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        # Regularisation and non-linearity between the two dense layers.
        self.dropout = nn.Dropout(0.2)
        self.relu = nn.ReLU()
        # Hidden dense layer followed by the output layer.
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 2)
        # Log-softmax over dim 1.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        """Return log-probabilities for the two classes.

        ``sent_id`` and ``mask`` are handed to ``self.bert`` as the input ids
        and ``attention_mask``.
        """
        _unused, pooled = self.bert(sent_id, attention_mask=mask, return_dict=False)
        hidden = self.dropout(self.relu(self.fc1(pooled)))
        logits = self.fc2(hidden)
        return self.softmax(logits)


class SentimentClassifier(nn.Module):
    """Single linear layer mapping ``hidden_size`` features to ``num_class`` logits.

    A dropout module is created but not applied in ``forward``.
    """

    def __init__(self, hidden_size, num_class=2, dropout=0.3):
        super().__init__()
        self.hidden = hidden_size
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, num_class)

    def forward(self, attn_output):
        """Project ``attn_output`` and drop every size-1 dim from the result."""
        logits = self.linear(attn_output)
        return logits.squeeze()


class MLPAttention(nn.Module):
    """Additive (MLP) attention that pools a sequence into one vector per example."""

    def __init__(self, hidden_dim, att_dim, dropout=0.4):
        super(MLPAttention, self).__init__()

        self.W_k = nn.Linear(hidden_dim, att_dim, bias=True)
        self.v = nn.Linear(att_dim, 1, bias=True)
        self.dropout = nn.Dropout(dropout)
        self.tanh = nn.Tanh()

    def forward(self, encoder_outputs, lengths, temp=1):
        """Compute attention weights over the sequence and the pooled vector.

        Args:
            encoder_outputs: tensor indexed as ``(batch, seq, hidden)``.
            lengths: per-example valid lengths; positions at or beyond the
                length are masked out.
            temp: softmax temperature.

        Returns:
            ``(scores, reverse_scores, attn_hidden, src_mask)``:
            ``scores`` are the attention weights ``(batch, seq)``, normalised
            over the sequence; ``reverse_scores`` is ``1 - scores`` with masked
            positions set to 0; ``attn_hidden`` is the ``scores``-weighted sum
            of ``encoder_outputs`` ``(batch, hidden)``; ``src_mask`` is true at
            masked positions.
        """
        V = encoder_outputs
        max_len = encoder_outputs.size(1)

        pos_idx = torch.arange(max_len, device=lengths.device).unsqueeze(0)
        src_mask = pos_idx >= lengths.unsqueeze(-1)

        projected = self.tanh(self.dropout(self.W_k(encoder_outputs)))
        scores = self.v(projected).squeeze(-1)

        # Mask out padding, then normalise across the sequence positions of
        # each example; the mask and the weighted sum below both act on that
        # dimension, not on the batch.
        scores = scores.masked_fill(src_mask, float("-inf"))
        scores = torch.softmax(scores / temp, dim=1).unsqueeze(-1)

        attn_hidden = torch.bmm(V.permute(0, 2, 1), scores)

        scores = scores.squeeze(2)
        reverse_scores = (1 - scores).masked_fill(src_mask, 0.0)

        return scores, reverse_scores, attn_hidden.squeeze(2), src_mask
