# -*- coding: utf-8 -*-

import argparse
import os
import pickle
import random

import torch

import pyshrg
from framework.common.logger import LOGGER, open_file
from framework.data.vocab import VocabularySet
from nn_generator.feature_based.model import SHRGGenerator, build_grammar_nonterminals

# Keys of the saved network state under which the HRG and CFG embedding
# weight matrices are read, extended and written back by `main`.
KEY_HRG = 'hrg.hrg_embeddings.weight'
KEY_CFG = 'cfg.cfg_embeddings.weight'


def get_embeddings(embeddings, indices, method):
    """Derive a single embedding from the rows of ``embeddings`` named by ``indices``.

    Args:
        embeddings: tensor of embeddings, indexed along dim 0.
        indices: 1-D integer tensor of candidate row indices into
            ``embeddings``.
        method: how to combine the candidates:
            ``'average'`` - mean of the candidate rows;
            ``'random'`` - one candidate row picked with the ``random`` module;
            a callable - called with each candidate index (as a Python int),
            the row whose score is largest is returned (first one wins ties).

    Returns:
        The selected or averaged embedding row.

    Raises:
        ValueError: if ``indices`` is empty or ``method`` is not recognised.
    """
    if indices.numel() == 0:
        # There is nothing meaningful to return; falling through would yield
        # NaNs ('average') or an unrelated row (callable).
        raise ValueError('cannot derive an embedding from an empty index set')

    if method == 'average':
        return embeddings.index_select(0, indices).mean(dim=0)

    # Plain Python ints: usable as row indices and safe to hand to callbacks.
    candidates = indices.tolist()

    if method == 'random':
        # Pick among the candidate rows, not among the first len(indices) rows.
        return embeddings[random.choice(candidates)]

    if callable(method):  # count-based
        # max() always returns one of the candidates, even if every score is 0.
        return embeddings[max(candidates, key=method)]

    raise ValueError('unknown embedding method: %r' % (method,))


def compute_hrg_embeddings(shrg_relations, hrg_weights, cfg_weights, method):
    """Build embedding rows for HRG rules that are not covered by ``hrg_weights``.

    Every SHRG rule whose index is at least ``cfg_weights.size(0)`` is treated
    as new. Its HRG embedding is derived, via ``get_embeddings``, from the
    existing HRG embeddings of the SHRG rules listed in
    ``shrg_relations[shrg_index]``.

    Args:
        shrg_relations: indexable by SHRG rule index; each entry is an iterable
            of SHRG rule indices the new rule is related to (presumably the
            rules it was merged from - confirm against the producer of the
            relations file).
        hrg_weights: existing HRG embedding matrix.
        cfg_weights: existing CFG embedding matrix; only its number of rows is
            used, as the index of the first new SHRG rule.
        method: ``'average'``, ``'random'`` or ``'count-based'``.

    Returns:
        Tensor with ``manager.hrg_size - hrg_weights.size(0)`` rows, with the
        dtype and device of ``hrg_weights``. Rows never assigned stay zero.
    """
    manager = pyshrg.get_manager()
    num_old_cfg_size = cfg_weights.size(0)
    num_old_hrg_size = hrg_weights.size(0)
    # Same device as the old table so the caller can concatenate both.
    extra_hrg_weights = torch.zeros(manager.hrg_size - num_old_hrg_size,
                                    hrg_weights.size(1),
                                    dtype=hrg_weights.dtype,
                                    device=hrg_weights.device)

    if method == 'count-based':
        def method(x):
            # int(): the index may arrive as a 0-dim tensor.
            return manager.get_hrg(int(x)).num_occurences

    for shrg_index in range(num_old_cfg_size, manager.shrg_size):
        hrg_index = manager.index_of_hrg(manager.get_shrg(shrg_index))
        assert hrg_index >= num_old_hrg_size, \
            'new SHRG rule %d maps to an already known HRG rule' % shrg_index
        # Explicit integer dtype: an empty relation list would otherwise
        # produce a float tensor that cannot be used as an index.
        indices = torch.tensor([
            manager.index_of_hrg(manager.get_shrg(child_shrg_index))
            for child_shrg_index in shrg_relations[shrg_index]
        ], dtype=torch.long, device=hrg_weights.device)

        offset = hrg_index - num_old_hrg_size
        extra_hrg_weights[offset] = get_embeddings(hrg_weights, indices, method=method)
    return extra_hrg_weights


def compute_cfg_embeddings(shrg_relations, cfg_weights, method):
    """Build embedding rows for SHRG rules that are not covered by ``cfg_weights``.

    Every SHRG rule whose index is at least ``cfg_weights.size(0)`` is treated
    as new. Its CFG embedding is derived, via ``get_embeddings``, from the
    existing rows of ``cfg_weights`` listed in ``shrg_relations[shrg_index]``.

    Args:
        shrg_relations: indexable by SHRG rule index; each entry is an iterable
            of SHRG rule indices, used directly as row indices into
            ``cfg_weights``.
        cfg_weights: existing CFG embedding matrix.
        method: ``'average'``, ``'random'`` or ``'count-based'``.

    Returns:
        Tensor with ``manager.shrg_size - cfg_weights.size(0)`` rows, with the
        dtype and device of ``cfg_weights``.
    """
    manager = pyshrg.get_manager()
    num_old_cfg_size = cfg_weights.size(0)
    # Same device as the old table so the caller can concatenate both.
    extra_cfg_weights = torch.zeros(manager.shrg_size - num_old_cfg_size,
                                    cfg_weights.size(1),
                                    dtype=cfg_weights.dtype,
                                    device=cfg_weights.device)

    if method == 'count-based':
        def method(x):
            # int(): the index may arrive as a 0-dim tensor.
            return manager.get_shrg(int(x)).num_occurences

    for shrg_index in range(num_old_cfg_size, manager.shrg_size):
        # Explicit integer dtype and matching device so the tensor is always a
        # valid index into cfg_weights, even for an empty relation list.
        indices = torch.tensor(list(shrg_relations[shrg_index]),
                               dtype=torch.long, device=cfg_weights.device)

        offset = shrg_index - num_old_cfg_size
        extra_cfg_weights[offset] = get_embeddings(cfg_weights, indices, method=method)

    return extra_cfg_weights


def main(argv=None):
    """Extend a saved model so that it covers the rules of a merged grammar.

    Loads the model at ``model_path`` onto the CPU, loads the merged grammar
    and the rule relations from ``grammar_dir``, appends embedding rows for
    the new rules to the HRG and CFG embedding tables, rebuilds the grammar
    nonterminals, embeds the grammar file into the state, and saves the result
    next to the input with ``.anonym`` inserted before the file extension.

    Args:
        argv: command-line arguments; ``None`` lets argparse use ``sys.argv``.

    Raises:
        RuntimeError: if the grammar file cannot be loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', '-m', default='average',
                        choices=['average', 'count-based', 'random'])
    parser.add_argument('model_path')
    parser.add_argument('grammar_dir')

    cmd_options = parser.parse_args(argv)
    method = cmd_options.method

    grammar_path = os.path.join(cmd_options.grammar_dir, 'train.mapping.txt.merged')
    relations_path = os.path.join(cmd_options.grammar_dir, 'train.merged.relations.p')

    saved_state = torch.load(cmd_options.model_path, torch.device('cpu'))
    saved_state['object'] = SHRGGenerator.make_release(saved_state['object'])
    del saved_state['object']['evaluator']

    pyshrg.initialize()
    manager = pyshrg.get_manager()
    # Deliberately not an `assert`: under `python -O` the whole statement,
    # including the load itself, would be stripped.
    if not manager.load_grammars(grammar_path):
        raise RuntimeError('failed to load grammars from %s' % grammar_path)

    # compute new weights
    # NOTE: pickle can execute arbitrary code; only use trusted relation files.
    relations_file = open_file(relations_path, 'rb')
    try:
        shrg_relations = pickle.load(relations_file)
    finally:
        relations_file.close()
    hrg_weights = saved_state['object']['network'][KEY_HRG]
    cfg_weights = saved_state['object']['network'][KEY_CFG]

    extra_hrg_weights = compute_hrg_embeddings(shrg_relations, hrg_weights, cfg_weights, method)
    extra_cfg_weights = compute_cfg_embeddings(shrg_relations, cfg_weights, method)

    saved_state['object']['network'][KEY_HRG] = torch.cat([hrg_weights, extra_hrg_weights], 0)
    saved_state['object']['network'][KEY_CFG] = torch.cat([cfg_weights, extra_cfg_weights], 0)

    # compute new `grammar_nonterminals`
    vocabs = VocabularySet().load_state_dict(saved_state['object']['statistics'])
    nonterminals = build_grammar_nonterminals(manager, vocabs.get('nonterminal'))
    saved_state['object']['nonterminals'] = nonterminals

    # set new grammars
    with open(grammar_path, 'rb') as grammar_file:
        saved_state['object']['grammar'] = grammar_file.read()

    output_path, ext = os.path.splitext(cmd_options.model_path)
    output_path = output_path + '.anonym' + ext
    LOGGER.info('> %s', output_path)
    torch.save(saved_state, output_path)


# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()
