import copy

import numpy as np
import torch
from torch import nn
from transformers import BertModel

from few_shot_ner.crf import ChainCRF


def init_seed(seed):
    """Seed the torch (CPU and every CUDA device) and NumPy global RNGs with ``seed``."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)


def euclidean_dist(x, y):
    """Return pairwise *squared* Euclidean distances between rows of ``x`` and ``y``.

    Args:
        x: tensor of shape (n, d).
        y: tensor of shape (m, d).

    Returns:
        Tensor of shape (n, m) with ``out[i, j] = sum_k (x[i, k] - y[j, k]) ** 2``.
        No square root is taken.

    Raises:
        ValueError: if ``x`` and ``y`` have different feature dimensions.
    """
    n, d = x.size(0), x.size(1)
    m = y.size(0)
    if y.size(1) != d:
        raise ValueError(
            f"euclidean_dist: feature dimension mismatch (x has {d}, y has {y.size(1)})"
        )
    # Broadcast both operands to (n, m, d); the difference tensor costs O(n*m*d) memory.
    diff = x.unsqueeze(1).expand(n, m, d) - y.unsqueeze(0).expand(n, m, d)
    return torch.pow(diff, 2).sum(2)


class ProtoNet(nn.Module):
    """Prototypical-network sequence labeller with a learned transition scorer.

    Support-token embeddings are averaged per slot label into prototypes.
    Each query token gets one emission score per prototype, equal to minus its
    squared Euclidean distance to that prototype. ``trans_nn`` scores every
    ordered pair of prototypes to produce transition scores.
    """

    def __init__(self, encoder, protonet_encoder, trans_nn, crf_decoder):
        super(ProtoNet, self).__init__()
        self.encoder = encoder
        self.protonet_encoder = protonet_encoder
        self.trans_nn = trans_nn
        # Stored on the module; not used by any method of this class.
        self.crf_decoder = crf_decoder

    def encode(self, tokens, lengths, mask, sub_word_ids, sub_word_valid_lengths):
        """Gather one sub-word representation per word and project it.

        ``lengths`` and ``sub_word_valid_lengths`` are accepted for interface
        compatibility and are not used here.

        Args:
            tokens: sub-word token ids forwarded to ``self.encoder``.
            mask: attention mask forwarded to ``self.encoder``.
            sub_word_ids: for each batch row, the sub-word positions to pick;
                presumably shape (B, L) — TODO confirm against the data loader.

        Returns:
            ``self.protonet_encoder`` applied to the gathered representations.
        """
        # NOTE(review): assumes the encoder returns a 3-tuple whose last item is the
        # sequence of per-layer hidden states — confirm for the transformers version in use.
        _, _, layers_output = self.encoder(tokens, attention_mask=mask)
        # use sum of last 4 layers output as the subword representation
        encoded_inputs = layers_output[-4] + layers_output[-3] + layers_output[-2] + layers_output[-1]

        # pick word end representations: element [b, j] comes from row b, position
        # sub_word_ids[b, j]. The row index lives on the activations' device so no
        # host-to-device copy is needed on every call.
        batch_idx = torch.arange(encoded_inputs.shape[0], device=encoded_inputs.device).unsqueeze(1)
        word_out = encoded_inputs[batch_idx, sub_word_ids]

        return self.protonet_encoder(word_out)

    def compute_prototypes(self, supports_embed, supports_intents, supports_slots):
        """Average support-token embeddings per slot label.

        Label id 0 is treated as padding and excluded. ``supports_intents`` is
        accepted for interface compatibility and is not used.

        Args:
            supports_embed: embeddings whose last dimension is the feature size;
                all leading dimensions are flattened to align with ``supports_slots``.
            supports_slots: integer label ids, one per embedded token.

        Returns:
            Tuple ``(prototypes, prototypes_slots, prototypes_sizes)``:

            - prototypes: shape (K, H), the mean embedding per label;
            - prototypes_slots: shape (K,), the non-pad label ids in ascending order;
            - prototypes_sizes: shape (K,), the number of tokens averaged per label.

            ``torch.stack`` raises if no non-pad label is present.
        """
        flattened_slots = torch.flatten(supports_slots)
        flattened_embed = supports_embed.reshape(-1, supports_embed.shape[-1])

        # Remove the pad label (id 0) by value. Dropping "the first group" instead
        # would discard a real label whenever the batch contains no padding.
        groups = torch.unique(flattened_slots, sorted=True)
        groups = groups[groups != 0]

        support_idxs = [flattened_slots.eq(c).nonzero().squeeze(1) for c in groups]
        prototypes = torch.stack([flattened_embed[idx].mean(0) for idx in support_idxs])
        prototypes_slots = torch.stack([flattened_slots[idx[0]] for idx in support_idxs])
        prototypes_sizes = torch.tensor(
            [idx.size(0) for idx in support_idxs], device=supports_embed.device
        )

        return prototypes, prototypes_slots, prototypes_sizes

    def predict_from_prototypes(self, queries, prototypes, prototypes_slots):
        """Compute CRF emission and transition scores for the queries.

        Args:
            queries: tuple of positional arguments forwarded to ``encode``.
            prototypes: tensor of shape (K, E).
            prototypes_slots: label id per prototype; returned unchanged.

        Returns:
            Tuple ``(emissions, transitions, prototypes_slots)``:

            - emissions: shape (B, L, K), minus the squared Euclidean distance from
              each query token to each prototype;
            - transitions: shape (K, K), where entry [i, j] is
              ``trans_nn(cat(prototypes[i], prototypes[j]))``.
        """
        # queries_embed of shape (B, L, H)
        queries_embed = self.encode(*queries)
        batch, length, embed_size = queries_embed.shape

        # emission scores of shape (B, L, num_prototypes)
        dists = euclidean_dist(queries_embed.reshape(-1, embed_size), prototypes)
        emissions = (-dists).reshape(batch, length, -1)

        # transition scores of shape (num_prototypes, num_prototypes):
        # row i * K + j of trans_input is [prototypes[i], prototypes[j]].
        num_prototypes, proto_dim = prototypes.shape
        src = prototypes.unsqueeze(1).expand(num_prototypes, num_prototypes, proto_dim)
        dst = prototypes.unsqueeze(0).expand(num_prototypes, num_prototypes, proto_dim)
        trans_input = torch.cat((src, dst), dim=2).reshape(-1, 2 * proto_dim)
        transitions = self.trans_nn(trans_input).reshape(num_prototypes, num_prototypes)

        return emissions, transitions, prototypes_slots

    def forward(self, queries, supports, supports_intents, supports_slots):
        """Build prototypes from the support set, then score the queries against them."""
        supports_embed = self.encode(*supports)
        prototypes, prototypes_slots, _ = self.compute_prototypes(supports_embed, supports_intents, supports_slots)
        return self.predict_from_prototypes(queries, prototypes, prototypes_slots)


class BaselineNet(nn.Module):
    """Baseline tagger: encoder representations mapped to per-word scores by ``emission_net``."""

    def __init__(self, encoder, emission_net, crf_decoder):
        super(BaselineNet, self).__init__()
        self.encoder = encoder
        self.emission_net = emission_net
        # Stored on the module; not used by any method of this class.
        self.crf_decoder = crf_decoder

    def encode(self, tokens, lengths, mask, sub_word_ids, sub_word_valid_lengths):
        """Gather one sub-word representation per word and map it to emission scores.

        ``lengths`` and ``sub_word_valid_lengths`` are accepted for interface
        compatibility and are not used here.

        Args:
            tokens: sub-word token ids forwarded to ``self.encoder``.
            mask: attention mask forwarded to ``self.encoder``.
            sub_word_ids: for each batch row, the sub-word positions to pick;
                presumably shape (B, L) — TODO confirm against the data loader.

        Returns:
            ``self.emission_net`` applied to the gathered representations.
        """
        # NOTE(review): assumes the encoder returns a 3-tuple whose last item is the
        # sequence of per-layer hidden states — confirm for the transformers version in use.
        _, _, layers_output = self.encoder(tokens, attention_mask=mask)
        # use sum of last 4 layers output as the subword representation
        encoded_inputs = layers_output[-4] + layers_output[-3] + layers_output[-2] + layers_output[-1]

        # pick word end representations; the row index is created on the activations'
        # device so no host-to-device copy is needed on every call.
        batch_idx = torch.arange(encoded_inputs.shape[0], device=encoded_inputs.device).unsqueeze(1)
        word_out = encoded_inputs[batch_idx, sub_word_ids]

        return self.emission_net(word_out)

    def forward(self, input):
        """Unpack ``input`` as the positional arguments of ``encode``."""
        return self.encode(*input)


class MLPEncoder(nn.Module):
    """Feed-forward stack built from consecutive pairs of ``layer_sizes``.

    Each pair contributes ``Dropout -> Linear -> ELU``. The final ELU is omitted,
    so the last layer is linear. With fewer than two sizes the stack is empty
    (an identity ``Sequential``).
    """

    def __init__(self, layer_sizes, dropout=0):
        super().__init__()
        self.dropout = dropout
        modules = []
        for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
            modules += [
                torch.nn.Dropout(p=dropout),
                torch.nn.Linear(in_size, out_size),
                torch.nn.ELU(),
            ]
        if modules:
            # Leave the output layer linear: drop the trailing activation.
            modules.pop()
        self.net = torch.nn.Sequential(*modules)

    def forward(self, input):
        """Apply the stacked layers to ``input``."""
        net = self.net
        return net(input)


@torch.no_grad()
def save(model):
    """Return a deep copy of ``model.state_dict()``.

    The copy does not share tensors with the live model, so later in-place
    updates to the model do not change the returned snapshot.
    """
    snapshot = model.state_dict()
    return copy.deepcopy(snapshot)


@torch.no_grad()
def load(model, state_dict):
    """Copy ``state_dict`` into ``model`` in place (strict key matching); returns None."""
    model.load_state_dict(state_dict, strict=True)
    return None


def build_baseline_model(args, device):
    """Assemble a BaselineNet on ``device`` around the BERT model at ``args.bert_model_path``.

    If ``args.encoder_learning_rate > 0``, BERT is trainable and the emission head is a
    single linear layer. Otherwise BERT is frozen and a deeper
    (hidden -> 128 -> 128 -> labels) head is used. Both the head and the CRF use
    ``args.n + 2`` labels.
    """
    encoder = BertModel.from_pretrained(args.bert_model_path).to(device=device)
    fine_tune = args.encoder_learning_rate > 0
    for param in encoder.parameters():
        param.requires_grad = fine_tune

    num_labels = args.n + 2
    hidden = encoder.config.hidden_size
    # A fine-tuned encoder only needs one linear layer; a frozen one gets a larger NN.
    head_sizes = [hidden, num_labels] if fine_tune else [hidden, 128, 128, num_labels]
    emission_net = MLPEncoder(head_sizes, args.decoder_dropout).to(device=device)

    crf_decoder = ChainCRF(num_labels=num_labels).to(device=device)  # removing padding slot

    return BaselineNet(encoder, emission_net, crf_decoder).to(device=device)


def build_protonet_model(args, device):
    """Assemble a ProtoNet on ``device`` around the BERT model at ``args.bert_model_path``.

    The BERT encoder is switched to training mode so its dropout is active. If
    ``args.encoder_learning_rate > 0``, BERT is trainable and the projection to the
    64-dim prototype space is one linear layer. Otherwise BERT is frozen and a deeper
    (hidden -> 128 -> 128 -> 64) projection is used.
    """
    proto_dim = 64     # prototype embedding size produced by the projection head
    trans_hidden = 64  # hidden width of the transition scorer

    encoder = BertModel.from_pretrained(args.bert_model_path)
    # Enables dropout
    encoder.train()
    fine_tune = args.encoder_learning_rate > 0
    for param in encoder.parameters():
        param.requires_grad = fine_tune

    hidden = encoder.config.hidden_size
    if fine_tune:
        head_sizes = [hidden, proto_dim]
    else:
        # A frozen encoder gets a larger NN.
        head_sizes = [hidden, 128, 128, proto_dim]
    protonet_encoder = MLPEncoder(head_sizes, args.decoder_dropout)

    # Transition network: maps a concatenated pair of prototypes to one score.
    trans_nn = torch.nn.Sequential(
        torch.nn.Linear(2 * proto_dim, trans_hidden),
        torch.nn.ELU(),
        torch.nn.Linear(trans_hidden, 1),
        torch.nn.ELU(),
    )
    crf_decoder = ChainCRF()

    return ProtoNet(encoder, protonet_encoder, trans_nn, crf_decoder).to(device)
