from models.encoder import Encoder
from models.decoder import Decoder
from torch.nn.init import xavier_uniform_
import torch
from torch import nn
from models.model import NMTModel, SalienceModel
from models.citation import CitationModel
from utils.sparse_activations import LogSparsemax
from models.copy_generator import CopyGenerator
from models.GAT import GAT
from models.salience import Salience

class Cast(nn.Module):
    """Layer that converts whatever it receives to a fixed data type.

    The conversion is delegated to the input's own ``.to`` method, so an
    input that already has the requested type is handed back as-is.

    Args:
        dtype: the target data type passed to ``.to`` on every forward call.
    """

    def __init__(self, dtype):
        super().__init__()
        # Plain attribute (neither a parameter nor a buffer).
        self._dtype = dtype

    def forward(self, x):
        """Return ``x`` converted to the configured data type."""
        target_dtype = self._dtype
        return x.to(target_dtype)

def _get_opt(model_opt, name, default=None):
    """Read option ``name`` by key, falling back to attribute access.

    Most options in this module are read with ``model_opt[name]``; the
    attribute fallback keeps option objects that only expose attributes
    working. ``default`` is returned when neither lookup succeeds.
    """
    try:
        return model_opt[name]
    except (KeyError, TypeError):
        return getattr(model_opt, name, default)


def _is_enabled(value):
    """Interpret an option value as a boolean flag.

    String values follow this module's convention of comparing against the
    literal ``"True"``; any other value is judged by its truthiness.
    """
    if isinstance(value, str):
        return value == "True"
    return bool(value)


def _load_checkpoint(path, gpu):
    """Deserialize the checkpoint stored at ``path`` with ``torch.load``.

    When ``gpu`` is falsy every storage is remapped to the CPU so that a
    checkpoint written on a GPU machine can still be read; otherwise the
    default placement of ``torch.load`` is used.
    """
    # NOTE(review): torch.load is pickle-based; depending on the PyTorch
    # version it may execute arbitrary code from the file, so only load
    # checkpoints from trusted sources.
    map_location = None if gpu else torch.device("cpu")
    return torch.load(path, map_location=map_location)


def _build_generator(model_opt, vocab_size, decoder, gpu):
    """Build the output layer that maps decoder states to the vocabulary.

    A ``CopyGenerator`` is built unless ``model_opt["copy_attn"]`` is the
    string ``"False"``; in that case a linear projection followed by a
    log-sparsemax or log-softmax is returned instead.
    """
    if model_opt["copy_attn"] != "False":
        pad_idx = 1  # presumably the padding token index -- TODO confirm
        return CopyGenerator(int(model_opt["dec_rnn_size"]), vocab_size, pad_idx, gpu)

    if _get_opt(model_opt, "generator_function") == "sparsemax":
        gen_func = LogSparsemax(dim=-1)
    else:
        gen_func = nn.LogSoftmax(dim=-1)
    generator = nn.Sequential(
        # int(): same conversion the copy-generator branch applies to this option.
        nn.Linear(int(model_opt["dec_rnn_size"]), vocab_size),
        Cast(torch.float32),
        gen_func,
    )
    if _is_enabled(_get_opt(model_opt, "share_decoder_embeddings", False)):
        # Tie the projection weights to the decoder's embedding matrix.
        generator[0].weight = decoder.embeddings.word_lut.weight
    return generator


def build_model(model_opt, vocab_size, gpu, word_embeddings=None, checkpoint=None, param_init=0.1, param_init_glorot=True):
    """Build a model from opts.

    The kind of model depends on ``model_opt["train_mode"]``:

    * ``"2"``: a ``CitationModel`` only.
    * ``"1"``: a ``SalienceModel`` wrapping an ``Encoder`` and a ``Salience``.
    * anything else: an ``NMTModel`` (encoder, decoder, salience, citation)
      with a generator attached as ``model.generator``.

    Args:
        model_opt: mapping-like options object, read with ``model_opt[key]``
            and ``key in model_opt``.
        vocab_size: the size of the vocabulary.
        gpu (bool): whether to use gpu.
        word_embeddings: pretrained word embeddings, or None to skip loading
            them.
        checkpoint: truthy when weights should be restored instead of
            initialized. Only its truthiness is used here; the weights are
            read from the paths stored in ``model_opt``
            (``citation_checkpoint``, ``salience_checkpoint``,
            ``base_checkpoint``).
        param_init (float): when non-zero, every parameter is first drawn
            from ``U(-param_init, param_init)``.
        param_init_glorot (bool): when true, parameters with more than one
            dimension are (re-)initialized with Xavier uniform.
    Returns:
        the built model, moved to the selected device.
    """
    device = torch.device("cuda" if gpu else "cpu")
    train_mode = model_opt["train_mode"]
    # Stays None for the citation-only and salience-only modes.
    generator = None

    # Build the model for the requested training mode.
    if train_mode == "2":
        model = CitationModel(model_opt, gpu)
        if checkpoint and "citation_checkpoint" in model_opt:
            print("Loading citation model from " + model_opt['citation_checkpoint'])
            model.load_state_dict(_load_checkpoint(model_opt["citation_checkpoint"], gpu))
            model.to(device)
            return model
    else:
        encoder = Encoder(model_opt, vocab_size)
        salience = Salience(model_opt)
        if train_mode == "1":
            encoder.text_encoder.fine_tune_embeddings(True)
            model = SalienceModel(encoder, salience, gpu)
            if checkpoint and "salience_checkpoint" in model_opt:
                print("Loading salience model from " + model_opt['salience_checkpoint'])
                model.load_state_dict(_load_checkpoint(model_opt["salience_checkpoint"], gpu))
                model.to(device)
                return model
        else:
            citation = CitationModel(model_opt, gpu)
            decoder = Decoder(model_opt, vocab_size)
            citation_func = model_opt["citation_function"] == "True"
            model = NMTModel(encoder, decoder, salience, citation, gpu, is_citation_func=citation_func)
            generator = _build_generator(model_opt, vocab_size, decoder, gpu)

    # Load the model states from checkpoint or initialize them.
    if checkpoint:
        # The citation/salience modes reach this point when their own
        # checkpoint path is not configured; they have no generator to attach.
        if generator is not None:
            model.generator = generator
        print("Loading model from " + model_opt['base_checkpoint'])
        model.load_state_dict(_load_checkpoint(model_opt['base_checkpoint'], gpu))
    else:
        # The generator is not attached to the model yet, so it is
        # initialized explicitly (only in train mode "3"). All uniform draws
        # happen before all Xavier draws so the random stream is consumed in
        # a fixed order.
        to_initialize = [model]
        if train_mode == "3":
            to_initialize.append(generator)
        if param_init != 0.0:
            for module in to_initialize:
                for p in module.parameters():
                    p.data.uniform_(-param_init, param_init)
        if param_init_glorot:
            for module in to_initialize:
                for p in module.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        if train_mode != "2":
            text_encoder = model.encoder.text_encoder
            if hasattr(text_encoder, 'word_embeddings') and model_opt["train_from_salience"] == "False" \
                    and word_embeddings is not None:
                text_encoder.load_pretrained_embeddings(word_embeddings)

            if train_mode != "1":
                if hasattr(model.decoder, 'embeddings') and word_embeddings is not None:
                    model.decoder.load_pretrained_embeddings(word_embeddings)

                if model_opt["train_from_salience"] == "True":
                    # Warm-start the encoder and salience parts from a
                    # previously trained salience model.
                    salience_model = SalienceModel(encoder, salience, gpu)
                    if "salience_checkpoint" in model_opt:
                        print("Loading salience model from " + model_opt['salience_checkpoint'])
                        salience_model.load_state_dict(
                            _load_checkpoint(model_opt["salience_checkpoint"], gpu))
                        model.encoder.load_state_dict(salience_model.encoder.state_dict())
                        model.salience.load_state_dict(salience_model.salience.state_dict())

                if "citation_checkpoint" in model_opt:
                    print("Loading citation model from " + model_opt['citation_checkpoint'])
                    model.citation.load_state_dict(
                        _load_checkpoint(model_opt["citation_checkpoint"], gpu))

                model.encoder.text_encoder.fine_tune_embeddings(True)
                model.decoder.fine_tune_embeddings(True)
                model.generator = generator

    model.to(device)
    return model