import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence

from six import iteritems


def clampGrad(grad, limit=5.0):
    '''
    Gradient clip by value.

    Meant to be registered as a tensor hook (see `register_hook`): it
    returns a clamped copy of the incoming gradient, which autograd then
    uses in place of the original. The argument itself is left untouched
    because hooks are expected to return a replacement gradient rather
    than modify the tensor they are handed.

    Arguments:
        grad  : gradient tensor passed in by autograd
        limit : every element of the result lies in [-limit, limit]

    Returns:
        A new tensor with the same shape as `grad`, clamped by value.
    '''
    return grad.clamp(min=-limit, max=limit)


def _copyOverlapRows(target, source, indexPairs):
    '''
    Copy `source[srcIdx]` into `target[dstIdx]` for every
    `(dstIdx, srcIdx)` pair. `target` is modified in place and returned.
    '''
    for dstIdx, srcIdx in indexPairs:
        target[dstIdx] = source[srcIdx]
    return target


def loadModel_Fashion(params, agent='abot', overwrite=False, overlap_fashion=None, overlap_visdial=None):
    '''
    Build an Answerer ('abot') or Questioner ('qbot') and, when a
    checkpoint path is configured, initialise it from that checkpoint.

    The checkpoint path is read from `params['startFrom']` for 'abot' and
    from `params['qstartFrom']` for 'qbot'. Tensors whose state-dict key
    contains 'encoder.wordEmbed', 'decoder.wordEmbed',
    'decoder.outNet.weight' or 'decoder.outNet.bias' are not loaded
    verbatim: the model keeps its own tensor and only the entries listed
    in the overlap arguments are copied over, i.e. for each pair
    `(x, y)` from `zip(overlap_fashion, overlap_visdial)` the model
    entry `x` receives checkpoint entry `y`. (Presumably these are word
    indices shared by the two vocabularies -- confirm against caller.)

    Arguments:
        params          : dict of options; must contain 'useGPU',
                          'continue', 'numEpochs', the start-from key for
                          the chosen agent and every encoder/decoder
                          option listed below
        agent           : 'abot' or 'qbot'
        overwrite       : if False (default) a shallow copy of `params`
                          is modified instead of the caller's dict
        overlap_fashion : indices into the model's vocabulary-sized
                          tensors (destination); None means no overlap
        overlap_visdial : indices into the checkpoint's vocabulary-sized
                          tensors (source); None means no overlap

    Returns:
        (model, loadedParams, optim_state) where `loadedParams` holds the
        options that were taken from the checkpoint because they were
        missing from `params`, and `optim_state` is
        `mdict['optimizer']` or None when no checkpoint was loaded.

    Raises:
        ValueError     : `agent` is neither 'abot' nor 'qbot'
        AssertionError : continuing was requested without a checkpoint,
                         or without 'ckpt_lRate' being available
    '''
    # Fail early instead of hitting a NameError on `model` further down.
    if agent not in ('abot', 'qbot'):
        raise ValueError(
            "Unknown agent '{}': expected 'abot' or 'qbot'".format(agent))

    if overwrite is False:
        params = params.copy()
    loadedParams = {}
    # should be everything used in encoderParam, decoderParam below
    encoderOptions = [
        'encoder', 'vocabSize', 'embedSize', 'rnnHiddenSize', 'numLayers',
        'useHistory', 'useIm', 'imgEmbedSize', 'imgFeatureSize', 'numRounds',
        'dropout'
    ]
    decoderOptions = [
        'decoder', 'vocabSize', 'embedSize', 'rnnHiddenSize', 'numLayers',
        'dropout'
    ]
    modelOptions = encoderOptions + decoderOptions

    mdict = None
    # Remember the caller's values; they are restored after checkpoint
    # options have been merged into params.
    gpuFlag = params['useGPU']
    continueFlag = params['continue']
    numEpochs = params['numEpochs']
    # startArg: the path to load the model parameters
    startArg = 'startFrom' if agent == 'abot' else 'qstartFrom'
    if continueFlag:
        assert params[startArg], ("Can't continue training without a "
                                  "checkpoint")

    # load a model from disk if it is given
    if params[startArg]:
        print('Loading model (weights and config) from {}'.format(
            params[startArg]))

        # NOTE: torch.load unpickles the file, so only checkpoints from a
        # trusted source should be passed in here.
        if gpuFlag:
            mdict = torch.load(params[startArg])
        else:
            # Keep every storage on the CPU regardless of where it was
            # saved from.
            mdict = torch.load(params[startArg],
                               map_location=lambda storage, location: storage)

        ckptParams = mdict['params']
        # Model options is a union of standard model options defined
        # above and parameters loaded from checkpoint
        modelOptions = list(set(modelOptions).union(set(ckptParams)))
        for opt in modelOptions:
            if opt not in ckptParams:
                # Option is only known locally; there is nothing to
                # compare against or load from the checkpoint.
                continue

            if opt not in params:
                # Loading options from a checkpoint which are
                # necessary for continuing training, but are
                # not present in original parameter list.
                if continueFlag:
                    print("Loaded option '%s' from checkpoint" % opt)
                    params[opt] = ckptParams[opt]
                    loadedParams[opt] = ckptParams[opt]

            elif params[opt] != ckptParams[opt]:
                # When continuing training from a checkpoint, overwriting
                # parameters loaded from checkpoint is okay.
                if continueFlag:
                    print("Overwriting param '%s'" % str(opt))
                    params[opt] = ckptParams[opt]

        # These three always come from the caller, never the checkpoint.
        params['continue'] = continueFlag
        params['numEpochs'] = numEpochs
        params['useGPU'] = gpuFlag

        if params['continue']:
            assert 'ckpt_lRate' in params, (
                "Checkpoint does not have info for restoring learning "
                "rate and optimizer.")

    # Initialize model class
    encoderParam = {k: params[k] for k in encoderOptions}
    decoderParam = {k: params[k] for k in decoderOptions}

    # The last two vocabulary indices are used as start and end tokens.
    encoderParam['startToken'] = encoderParam['vocabSize'] - 2
    encoderParam['endToken'] = encoderParam['vocabSize'] - 1
    decoderParam['startToken'] = decoderParam['vocabSize'] - 2
    decoderParam['endToken'] = decoderParam['vocabSize'] - 1

    # Initialize the models here
    if agent == 'abot':
        encoderParam['type'] = params['encoder']
        decoderParam['type'] = params['decoder']
        encoderParam['isAnswerer'] = True
        from visdial.models.answerer import Answerer
        model = Answerer(encoderParam, decoderParam)

    else:  # agent == 'qbot', validated at the top
        encoderParam['type'] = params['qencoder']
        decoderParam['type'] = params['qdecoder']
        encoderParam['isAnswerer'] = False
        encoderParam['useIm'] = False
        encoderParam['imgEncodingMode'] = params[
            'imgEncodingMode']  # Added by Mingyang Zhou
        encoderParam['fuseType'] = params['fuseType']
        if params['fuseType'] == 4 and params['imgEncodingMode'] == "dual-view":
            decoderParam['numLayers'] = params['numLayers'] + 1
        from visdial.models.questioner import Questioner
        model = Questioner(
            encoderParam,
            decoderParam,
            imgFeatureSize=encoderParam['imgFeatureSize'])

    if params['useGPU']:
        model.cuda()

    for p in model.encoder.parameters():
        p.register_hook(clampGrad)
    for p in model.decoder.parameters():
        p.register_hook(clampGrad)
    # NOTE: model.parameters() should be used here, otherwise immediate
    # child modules in model will not have gradient clamping

    # copy parameters if specified
    if mdict:
        # Without overlap information no checkpoint rows are copied and
        # the vocabulary-sized tensors keep the model's own values.
        # Materialise the pairs once so one-shot iterables are not
        # exhausted after the first tensor.
        if overlap_fashion is None or overlap_visdial is None:
            indexPairs = []
        else:
            indexPairs = list(zip(overlap_fashion, overlap_visdial))

        # (key marker, accessor for the model tensor, print target size?)
        # Accessors are lazy so a model attribute is only touched when the
        # checkpoint actually contains a matching key. Order matters: the
        # first matching marker wins.
        remapTargets = (
            ('encoder.wordEmbed',
             lambda: model.encoder.wordEmbed.weight.data, True),
            ('decoder.wordEmbed',
             lambda: model.decoder.wordEmbed.weight.data, True),
            ('decoder.outNet.weight',
             lambda: model.decoder.outNet.weight.data, False),
            ('decoder.outNet.bias',
             lambda: model.decoder.outNet.bias.data, False),
        )

        filtered_model = {}
        for key, ckptTensor in mdict['model'].items():
            for marker, getTarget, showTargetSize in remapTargets:
                if marker in key:
                    print(ckptTensor.size())
                    target = getTarget()
                    if showTargetSize:
                        print(target.size())
                    filtered_model[key] = _copyOverlapRows(
                        target, ckptTensor, indexPairs)
                    break
            else:
                # Not vocabulary dependent: take the checkpoint tensor.
                filtered_model[key] = ckptTensor

        # strict=False: keys missing from either side are tolerated.
        model.load_state_dict(filtered_model, strict=False)
        optim_state = mdict['optimizer']
    else:
        optim_state = None
    return model, loadedParams, optim_state
