import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from transformers import AutoModel


class PretrainedModel(nn.Module):
    """Pretrained subword encoder that produces word-level features.

    The encoder is either a HuggingFace ``AutoModel`` (any ``model_name``) or the
    fairseq CRISS checkpoint (``model_name == 'criss'``). Hidden states from all
    layers are mixed with a learned, softmax-normalised ``layer_weight``. Each word
    is represented by the concatenation of the features of its first and last
    subword, which is why ``o_dim`` is twice the encoder's hidden size.

    Args:
        model_name: HuggingFace model identifier, or ``'criss'``.
        criss_args_path: fairseq args file for CRISS. ``None`` keeps the
            previously hard-coded location. Ignored for other models.
        criss_model_path: ``':'``-separated CRISS checkpoint path(s). ``None`` keeps
            the previously hard-coded location. Ignored for other models.
    """

    _DEFAULT_CRISS_ARGS_PATH = '../../data/criss-translation/criss_checkpoints/fairseq-args.pt'
    _DEFAULT_CRISS_MODEL_PATH = '../../data/criss-translation/criss_checkpoints/criss.3rd.pt'

    def __init__(self, model_name, criss_args_path=None, criss_model_path=None):
        super(PretrainedModel, self).__init__()
        self.model_name = model_name
        if model_name == 'criss':
            from fairseq import checkpoint_utils, tasks
            from fairseq.sequence_generator import EnsembleModel
            args_path = self._DEFAULT_CRISS_ARGS_PATH if criss_args_path is None else criss_args_path
            model_path = self._DEFAULT_CRISS_MODEL_PATH if criss_model_path is None else criss_model_path
            # NOTE(review): this file presumably holds a pickled fairseq args object.
            # Newer torch releases default torch.load to weights_only=True, which may
            # reject it. Confirm against the installed torch version.
            args = torch.load(args_path)
            task = tasks.setup_task(args)
            models, _ = checkpoint_utils.load_model_ensemble(model_path.split(':'), arg_overrides={}, task=task)
            self.args = args
            self.model = EnsembleModel(models)
            self.o_dim = args.encoder_embed_dim * 2
            # One mixing weight per encoder layer plus the embedding output.
            self.layer_weight = nn.Parameter(torch.ones(self.args.encoder_layers + 1))
        else:
            self.model = AutoModel.from_pretrained(model_name)
            self.layer_weight = nn.Parameter(torch.ones(self.model.config.num_hidden_layers + 1))
            self.o_dim = self.model.config.hidden_size * 2

    def forward(self, batch, masks):
        """Return word features of shape ``[bsz, n_words, o_dim]``.

        Args:
            batch: 3-D tensor of subword ids. It is assumed to be indexed
                ``[bsz, word, subword]``, consistent with ``masks.sum((1, 2))``
                below. TODO confirm against the collator.
            masks: boolean tensor with the same shape as ``batch``. It is True
                where ``batch`` holds a real subword.
        """
        subword_feats = self._encode_subwords(batch, masks)
        return self._word_boundary_features(subword_feats, masks)

    def _encode_subwords(self, batch, masks):
        """Run the encoder on each sentence's flattened subwords and mix its layers."""
        lengths = masks.sum((1, 2))
        split_sizes = lengths.tolist()
        # Drop the word-level padding and re-pad each sentence as one contiguous subword sequence.
        subwords = pad_sequence(batch[masks].split(split_sizes), True)
        # Track encoder gradients only when it is being fine-tuned (its parameters
        # require grad, e.g. after set_grad_requirement(True)). Always using no_grad
        # silently froze the encoder even when fine-tuning was requested.
        # torch.is_grad_enabled() keeps an outer no_grad() (e.g. evaluation) in force.
        finetune = torch.is_grad_enabled() and any(p.requires_grad for p in self.model.parameters())
        with torch.set_grad_enabled(finetune):
            if self.model_name == 'criss':
                encoder_outs = self.model.single_model.encoder(subwords, lengths, return_all_hiddens=True)
                layer_states = encoder_outs['encoder_states']
            else:
                attn_masks = pad_sequence(masks[masks].split(split_sizes), True)
                outputs = self.model(subwords, attention_mask=attn_masks, return_dict=True, output_hidden_states=True)
                layer_states = outputs['hidden_states']
        # The layer mixture is always differentiable, so layer_weight stays trainable.
        hidden_states = torch.stack(layer_states, dim=-1)
        subword_feats = torch.einsum('l,...l->...', self.layer_weight.softmax(0), hidden_states)
        if self.model_name == 'criss':
            # Swap the first two axes so the CRISS output is batch-major like the HF branch.
            subword_feats = subword_feats.transpose(0, 1)
        return subword_feats

    @staticmethod
    def _word_boundary_features(subword_feats, masks):
        """Concatenate, per word, the features of its first and last subword."""
        bsz = masks.shape[0]
        n_subwords = masks.sum(-1)
        ends = n_subwords.cumsum(-1)
        right_pos = ends - 1
        left_pos = torch.cat([n_subwords.new_zeros(bsz, 1), ends[:, :-1]], dim=-1)
        # Words with no subwords (padding) end up with left > right. Clamp left to
        # right so both indices stay within the sentence's real subwords.
        left_pos = torch.where(left_pos > right_pos, right_pos, left_pos)
        batch_idx = torch.arange(bsz, device=left_pos.device).unsqueeze(-1)
        left_feats = subword_feats[batch_idx, left_pos]
        right_feats = subword_feats[batch_idx, right_pos]
        return torch.cat([left_feats, right_feats], dim=-1)

    def set_grad_requirement(self, requirement=False):
        """Set ``requires_grad`` on every encoder parameter (not on ``layer_weight``)."""
        for x in self.model.parameters():
            x.requires_grad = requirement


class BiAffine(nn.Module):
    """Biaffine scorer: ``s[b, x, y, o] = [x_bx; 1]^T W_o [y_by; 1]``.

    Args:
        ix_dim: feature size of the ``x`` (first) input.
        iy_dim: feature size of the ``y`` (second) input.
        o_dim: number of output channels (e.g. 1 for arcs, ``n_labels`` for labels).
    """

    def __init__(self, ix_dim, iy_dim, o_dim=1):
        super(BiAffine, self).__init__()
        self.ix_dim = ix_dim
        self.iy_dim = iy_dim
        self.o_dim = o_dim
        # +1 on each input side for the constant bias feature appended in forward().
        self.weight = nn.Parameter(torch.empty(o_dim, ix_dim + 1, iy_dim + 1))
        self.init_parameters()

    def init_parameters(self, init_range=0.1, xavier=False):
        """Fill ``weight`` from ``U(-init_range, init_range)``.

        With ``xavier=True`` the range becomes ``sqrt(3 / ix_dim)``. The previous
        code read a nonexistent ``self.i_dim`` and raised ``AttributeError``.
        """
        if xavier:
            init_range = np.sqrt(3 / self.ix_dim)
        nn.init.uniform_(self.weight, -init_range, init_range)

    def forward(self, x, y):
        """Score every (x, y) pair.

        Args:
            x: ``[bsz, xlen, ix_dim]``.
            y: ``[bsz, ylen, iy_dim]``.

        Returns:
            ``[bsz, xlen, ylen, o_dim]``. The trailing dim is squeezed away when
            ``o_dim == 1``.
        """
        x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
        y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
        return torch.einsum('bxi,oij,byj->bxyo', x, self.weight, y).squeeze(-1)


class MLP(nn.Module):
    """Stack of ``Linear -> LeakyReLU(0.1) -> Dropout`` layers.

    Every layer, including the output layer, is followed by the activation and
    dropout.

    Args:
        i_dim: input feature size.
        h_dims: hidden size(s). Accepts an ``int``, any iterable of ints (list,
            tuple, ...), or ``None`` for no hidden layers.
        o_dim: output feature size.
        dropout: dropout probability applied after every activation.
    """

    def __init__(self, i_dim, h_dims, o_dim, dropout):
        super(MLP, self).__init__()
        self.i_dim = i_dim
        if h_dims is None:
            self.h_dims = None
        elif isinstance(h_dims, int):
            self.h_dims = [h_dims]
        else:
            # Normalise to a list. Concatenating a tuple onto a list raised TypeError.
            self.h_dims = list(h_dims)
        self.o_dim = o_dim
        dims = [i_dim] + (self.h_dims or []) + [o_dim]
        layer_sequences = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layer_sequences.append(nn.Linear(d_in, d_out))
            layer_sequences.append(nn.LeakyReLU(negative_slope=0.1))
            layer_sequences.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layer_sequences)

    def forward(self, batch):
        """Apply the layer stack to the last dimension of ``batch``."""
        return self.model(batch)


class LSTMEncoder(nn.Module):
    """``nn.LSTM`` over batch-first padded input with per-row valid lengths.

    Args:
        input_size: feature size of the input.
        hidden_size: LSTM hidden size per direction.
        n_layers: number of stacked LSTM layers.
        dropout: dropout between stacked layers (``nn.LSTM`` semantics).
        bidirectional: run a bidirectional LSTM. The output size doubles.
    """

    def __init__(self, input_size, hidden_size, n_layers, dropout, bidirectional=False):
        super(LSTMEncoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.rnn = nn.LSTM(input_size, hidden_size, n_layers, dropout=dropout, bidirectional=bidirectional)

    def forward(self, x, lengths):
        """Encode ``x`` (``[bsz, seq_len, input_size]``) given ``lengths`` (``[bsz]``).

        Returns:
            ``[bsz, seq_len, hidden_size * num_directions]`` in the input's row
            order. Positions at or beyond a row's length are zero. The time
            dimension always equals ``x.size(1)``, so the output lines up with
            masks built for ``x`` even when no row uses the full padded length.
        """
        # enforce_sorted=False sorts by length internally and restores the
        # original row order when unpacking, replacing the manual sort/unsort.
        packed_x = pack_padded_sequence(x, lengths.to('cpu'), batch_first=True, enforce_sorted=False)
        outputs, _ = self.rnn(packed_x)
        # Without total_length the output is only max(lengths) long.
        outputs, _ = pad_packed_sequence(outputs, batch_first=True, total_length=x.size(1))
        return outputs


class BiAffineDependencyParser(nn.Module):
    """Projection MLP -> BiLSTM -> biaffine arc and label scorers.

    Args:
        i_dim: size of the incoming word features.
        proj_h_dim / proj_dropout: projection MLP output size and dropout.
        lstm_h_dim / lstm_n_layers / lstm_dropout: BiLSTM configuration
            (per-direction hidden size).
        head_dim / dep_dim / feat_dropout: head- and dependent-side MLP sizes and
            their dropout, shared by the arc and label branches.
        n_labels: number of dependency labels scored per (x, y) pair.
    """

    def __init__(self, i_dim=1536, proj_h_dim=512, proj_dropout=0.2,
                 lstm_h_dim=512, lstm_n_layers=2, lstm_dropout=0.3,
                 head_dim=512, dep_dim=512, feat_dropout=0.2,
                 n_labels=10):
        super().__init__()
        # The BiLSTM concatenates both directions.
        context_dim = 2 * lstm_h_dim
        # Submodules are registered in a fixed order; it determines parameter
        # order (optimizer state) and initialisation RNG order.
        self.proj = MLP(i_dim, None, proj_h_dim, proj_dropout)
        self.lstm = LSTMEncoder(proj_h_dim, lstm_h_dim, lstm_n_layers, lstm_dropout, bidirectional=True)
        self.arc_h = MLP(context_dim, None, head_dim, feat_dropout)
        self.arc_d = MLP(context_dim, None, dep_dim, feat_dropout)
        self.label_h = MLP(context_dim, None, head_dim, feat_dropout)
        self.label_d = MLP(context_dim, None, dep_dim, feat_dropout)
        self.arc_scorer = BiAffine(head_dim, dep_dim, 1)
        self.label_scorer = BiAffine(head_dim, dep_dim, n_labels)

    def forward(self, batch, batch_mask):
        """Score arcs and labels for every (x, y) word pair.

        Args:
            batch: word features fed to the projection MLP.
            batch_mask: boolean ``[bsz, n_words]`` marking words to score.

        Returns:
            ``(arc_scores, label_scores)``: ``[bsz, n_words, n_words]`` and
            ``[bsz, n_words, n_words, n_labels]``. Any pair whose x or y is masked
            out is set to ``-1e10``.
        """
        neg_fill = -1e10
        # NOTE(review): +1 presumably re-counts one trailing position that the
        # caller removed from batch_mask. Confirm against the caller.
        seq_lengths = batch_mask.sum(-1) + 1
        context = self.lstm(self.proj(batch), seq_lengths)

        head_arc = self.arc_h(context)
        dep_arc = self.arc_d(context)
        head_lbl = self.label_h(context)
        dep_lbl = self.label_d(context)

        arcs = self.arc_scorer(head_arc, dep_arc).squeeze(-1)
        labels = self.label_scorer(head_lbl, dep_lbl)

        invalid = ~batch_mask
        bad_x = invalid.unsqueeze(-1)   # [bsz, n_words, 1]
        bad_y = invalid.unsqueeze(1)    # [bsz, 1, n_words]
        arcs = arcs.masked_fill(bad_x, neg_fill).masked_fill(bad_y, neg_fill)
        labels = labels.masked_fill(bad_x.unsqueeze(-1), neg_fill).masked_fill(bad_y.unsqueeze(-1), neg_fill)
        return arcs, labels


class Parser(object):
    """Couples a ``PretrainedModel`` feature extractor with a ``BiAffineDependencyParser``.

    Mimics part of the ``nn.Module`` API (``state_dict``, ``load_state_dict``,
    ``parameters``, ``train``, ``eval``, ``__call__``). Unless
    ``configs.pretrain_grad`` is set, the pretrained encoder is frozen and only its
    ``layer_weight`` is trained and checkpointed.

    Args:
        configs: object exposing ``pretrain_name``, ``device``, ``pretrain_grad``,
            ``hidden_size``, ``dropout``, ``lstm_hidden_size``, ``lstm_n_layers``,
            ``lstm_dropout`` and ``n_labels``.
    """

    def __init__(self, configs):
        self.configs = configs
        self.pretrained_model = PretrainedModel(configs.pretrain_name).to(configs.device)
        self.pretrained_model.set_grad_requirement(configs.pretrain_grad)
        if not self.configs.pretrain_grad:
            self.pretrained_model.eval()
        self.parser = BiAffineDependencyParser(
            i_dim=self.pretrained_model.o_dim,
            proj_h_dim=configs.hidden_size,
            proj_dropout=configs.dropout,
            lstm_h_dim=configs.lstm_hidden_size,
            lstm_n_layers=configs.lstm_n_layers,
            lstm_dropout=configs.lstm_dropout,
            head_dim=configs.hidden_size,
            dep_dim=configs.hidden_size,
            feat_dropout=configs.dropout,
            n_labels=configs.n_labels
        ).to(configs.device)

    def state_dict(self):
        """Return ``(pretrained_part, parser_state_dict)``.

        ``pretrained_part`` is the full encoder state dict when fine-tuning.
        Otherwise it is the live ``layer_weight`` Parameter itself, not a copy.
        """
        if self.configs.pretrain_grad:
            return (self.pretrained_model.state_dict(), self.parser.state_dict())
        else:
            return (self.pretrained_model.layer_weight, self.parser.state_dict())

    def load_state_dict(self, state_dict):
        """Inverse of :meth:`state_dict`."""
        pretrain_state_dict, parser_state_dict = state_dict
        if self.configs.pretrain_grad:
            self.pretrained_model.load_state_dict(pretrain_state_dict)
        else:
            # Copy in place rather than rebinding the attribute. Rebinding swapped in
            # a new Parameter object, which detached it from any optimizer already
            # built from parameters() and kept the checkpoint's device. It also
            # raised TypeError when the checkpoint held a plain Tensor.
            with torch.no_grad():
                self.pretrained_model.layer_weight.copy_(pretrain_state_dict)
        self.parser.load_state_dict(parser_state_dict)

    def parameters(self):
        """Trainable parameters: the whole encoder when fine-tuning, else just ``layer_weight``."""
        if self.configs.pretrain_grad:
            return list(self.pretrained_model.parameters()) + list(self.parser.parameters())
        else:
            return [self.pretrained_model.layer_weight] + list(self.parser.parameters())

    def train(self):
        """Put the parser (and the encoder, when fine-tuning) in training mode."""
        if self.configs.pretrain_grad:
            self.pretrained_model.train()
        self.parser.train()

    def eval(self):
        """Put the parser (and the encoder, when fine-tuning) in evaluation mode."""
        if self.configs.pretrain_grad:
            self.pretrained_model.eval()
        self.parser.eval()

    def forward(self, sentences, masks):
        """Return ``(arc_scores, label_scores, word_masks)`` for a subword batch."""
        word_hidden_states = self.pretrained_model(sentences, masks)
        # A word position is valid when it has at least one subword.
        word_masks = masks.sum(-1).bool()
        lengths = word_masks.sum(-1)
        # Exclude the position at index lengths-1 (the last word, assuming words
        # are left-aligned — confirm). Build the index on the masks' device.
        batch_idx = torch.arange(masks.shape[0], device=word_masks.device)
        word_masks[batch_idx, lengths - 1] = False
        arc_scores, label_scores = self.parser(word_hidden_states, word_masks)
        return arc_scores, label_scores, word_masks

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


if __name__ == '__main__':
    # Ad-hoc debugging harness. It runs the pretrained feature extractor and a
    # freshly constructed (untrained) parser over the English 'dev' split,
    # decodes each batch with `algo.mst`, and opens an IPython shell on every
    # batch. Requires a CUDA device plus the sibling `data` / `algo` modules
    # and IPython.
    model = PretrainedModel('xlm-roberta-base').to('cuda')
    d_parser = BiAffineDependencyParser().to('cuda')
    from data import CoNLLDatasetCollection
    from torch.utils.data import DataLoader
    # NOTE(review): 'univiersal' looks like a typo but may match the released
    # dataset's file names. Confirm on disk before "fixing" it.
    datasetc = CoNLLDatasetCollection('data/ud-2.0-acl13/universal_treebanks_v2.0/std/en/en-univiersal-{split}.conll', 'xlm-roberta-base')
    dataloader = DataLoader(datasetc['dev'], batch_size=8, shuffle=False, collate_fn=datasetc.collator)
    for words, masks, arcs, labels, info in dataloader:
        words = words.to('cuda')
        masks = masks.to('cuda')
        arcs = arcs.to('cuda')
        labels = labels.to('cuda')
        word_feats = model.forward(words, masks)
        # Same word-mask construction as Parser.forward: a word is valid if it
        # has any subword, then the position at index length-1 is excluded.
        word_masks = masks.sum(-1).bool()
        word_masks[torch.arange(word_masks.shape[0]), word_masks.sum(-1) - 1] = False
        arc_scores, label_scores = d_parser.forward(word_feats, word_masks)
        from algo import mst
        # Also exclude position 0 before decoding (presumably the root/BOS
        # position — confirm against algo.mst's expectations).
        word_masks[:, 0] = False
        print(mst(arc_scores, word_masks))
        # Interactive inspection of the current batch; blocks until the shell exits.
        from IPython import embed; embed(using=False)
