# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import json
import itertools
import random
from typing import (
    List,
    Union,
)

import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
import wandb
from jsonargparse import ArgumentParser

import egg.core as core
from egg.core import EarlyStopperAccuracy
from egg.core.callbacks import (
    WandbLogger,
    Callback,
)
from egg.zoo.compo_vs_generalization.archs import (
    Freezer,
    NonLinearReceiver,
    PlusOneWrapper,
    Receiver,
    Sender,
    MultiRnnReceiverDeterministic,
    SenderExpansionWrapper,
    MultiRnnReceiverReinforce,
    ReceiverLogProbWrapper,
)
from egg.zoo.compo_vs_generalization.data import (
    ScaledDataset,
    enumerate_attribute_value,
    one_hotify,
    select_subset_V1,
    select_subset_V2,
    split_holdout,
    split_train_test,
    split_train_val_test,
)
from egg.zoo.compo_vs_generalization.intervention import (
    Evaluator, 
    Metrics,
    ValBestCheckpoint,
    RandomResetter,
    UniformResetter,
    DetailedWandbLogger,
    SimultaneousResetter,
)
from egg.zoo.compo_vs_generalization.losses import (
    DiffLoss,
    MultiDiffLoss,
)


def _set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU generator; when CUDA is available, all CUDA devices are seeded too.
    Uses the module-level ``random`` and ``np`` imports.

    Args:
        seed: Integer seed applied to all generators.
    """
    cuda_present = torch.cuda.is_available()

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if cuda_present:
        torch.cuda.manual_seed_all(seed)


def get_params(params):
    """Declare the script-specific command-line options and parse them.

    The parser is ``jsonargparse.ArgumentParser`` (see the file imports), not
    the standard-library ``argparse`` one. Parsing itself is delegated to
    ``core.init``.

    Args:
        params: Sequence of command-line argument strings, e.g.
            ``sys.argv[1:]``.

    Returns:
        The options object returned by ``core.init``. ``main`` also reads
        fields that are not declared here (``device``, ``random_seed``,
        ``batch_size``, ``vocab_size``, ``max_len``, ``lr``, ``n_epochs``),
        so those are expected to be contributed by ``core.init``.
    """
    parser = ArgumentParser()

    # --- Data and evaluation schedule ---
    parser.add_argument("--n_attributes", type=int, default=4, help="")
    parser.add_argument("--n_values", type=int, default=4, help="")
    parser.add_argument("--data_scaler", type=int, default=100)
    parser.add_argument("--stats_freq", type=int, default=0)
    parser.add_argument("--val_eval_freq", type=int, default=1)
    parser.add_argument("--test_eval_freq", type=int, default=0)
    parser.add_argument(
        "--baseline", type=str, choices=["no", "mean", "builtin"], default="mean"
    )
    parser.add_argument(
        "--density_data", type=int, default=0, help="no sampling if equal 0"
    )

    # --- Agent architecture ---
    parser.add_argument(
        "--sender_hidden",
        type=int,
        default=50,
        help="Size of the hidden layer of Sender (default: 50)",
    )
    parser.add_argument(
        "--receiver_hidden",
        type=int,
        default=50,
        help="Size of the hidden layer of Receiver (default: 50)",
    )
    parser.add_argument(
        "--sender_entropy_coeff",
        type=float,
        default=1e-2,
        help="Entropy regularisation coeff for Sender (default: 1e-2)",
    )
    parser.add_argument("--sender_cell", type=str, default="rnn")
    parser.add_argument("--receiver_cell", type=str, default="rnn")
    parser.add_argument(
        "--sender_emb",
        type=int,
        default=10,
        help="Size of the embeddings of Sender (default: 10)",
    )
    parser.add_argument(
        "--receiver_emb",
        type=int,
        default=10,
        help="Size of the embeddings of Receiver (default: 10)",
    )

    # --- Early stopping and checkpointing ---
    parser.add_argument(
        "--early_stopping_thr",
        type=float,
        default=0.999,
        help="Early stopping threshold on accuracy (default: 0.999)",
    )
    # Used by main() for the second phase, where new receivers are trained
    # against the frozen sender.
    parser.add_argument(
        "--eol_early_stopping_thr",
        type=float,
        default=0.999,
        help=(
            "Early stopping threshold on accuracy for the frozen-sender "
            "probing phase (default: 0.999)"
        ),
    )
    parser.add_argument("--load_val_best", type=bool, default=True)

    # --- Weights & Biases logging ---
    parser.add_argument("--wandb_project", type=str, default='div_int')
    parser.add_argument("--wandb_name", type=str, default='div_int')
    # NOTE(review): the default is a bare string although the declared type
    # is List[str]; left unchanged here -- confirm what the logger expects.
    parser.add_argument("--wandb_tags", type=List[str], default='div_int')

    # String holding a Python expression; main() evaluates it into a list of
    # (n_att, n_comb, n_dup) triples, or None to use a single default triple.
    parser.add_argument("--n_att_n_comb_n_dup", type=str, default='None')
    parser.add_argument("--validation_ratio", type=float, default=None)
    parser.add_argument("--test_ratio", type=float, default=None)
    parser.add_argument("--ckpt_path", type=str, default='checkpoints')
    # main() accepts 'cross_entropy', 'task_success' and 'mixed'.
    parser.add_argument("--loss_type", type=str, default='cross_entropy')
    # main() currently rejects a true value for this flag.
    parser.add_argument("--variable_len", type=bool, default=False)
    parser.add_argument("--group_size", type=int, default=1)

    # --- Agent resetting ---
    parser.add_argument("--sender_reset_period", type=int, default=0)
    parser.add_argument("--receiver_reset_period", type=int, default=0)
    # main() accepts 'random', 'uniform' and 'simultaneous'.
    parser.add_argument("--reset_type", type=str, default='uniform')
    parser.add_argument("--reset_on_epoch", type=bool, default=True)

    # --- Probing phase and metrics ---
    parser.add_argument("--eol_n_epochs", type=int, default=100)
    parser.add_argument("--topsim_max_samples", type=int, default=None)
    parser.add_argument("--cross_entropy_weight", type=float, default=1.0)
    parser.add_argument("--preserve_eos", type=bool, default=False)
    parser.add_argument("--metric_exclude_eos", type=bool, default=False)

    args = core.init(arg_parser=parser, params=params)
    return args


# Recurrent cell names accepted for both the sender and the receivers.
_RNN_CELLS = ("lstm", "rnn", "gru")


def _build_rnn_receiver(opts, n_dim):
    """Build one deterministic RNN receiver producing ``n_dim`` outputs."""
    if opts.receiver_cell not in _RNN_CELLS:
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")
    agent = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_dim)
    return core.RnnReceiverDeterministic(
        agent,
        # With fixed-length messages the sender is wrapped in PlusOneWrapper
        # (see main), so the receiver vocabulary gets one extra symbol.
        opts.vocab_size if opts.variable_len else opts.vocab_size + 1,
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
    )


def _combine_receivers(opts, receivers, att_indices):
    """Wrap a list of receivers in the multi-receiver matching ``loss_type``."""
    if opts.loss_type == 'cross_entropy':
        return MultiRnnReceiverDeterministic(receivers)
    if opts.loss_type in ('task_success', 'mixed'):
        return MultiRnnReceiverReinforce(
            receivers,
            att_indices=att_indices,
            # Keyword spelling kept exactly as in the original call.
            n_attribuits=opts.n_attributes,
        )
    raise ValueError(f"Unknown loss type, {opts.loss_type}")


def _build_loss(opts, att_indices, group_size):
    """Create a fresh ``MultiDiffLoss`` for the given attribute subsets."""
    return MultiDiffLoss(
        att_indices=att_indices,
        n_values=opts.n_values,
        loss_type=opts.loss_type,
        group_size=group_size,
        cross_entropy_weight=opts.cross_entropy_weight,
    )


def _build_resetter(opts):
    """Instantiate the resetter callback selected by ``opts.reset_type``."""
    resetter_cls = {
        'random': RandomResetter,
        'uniform': UniformResetter,
        'simultaneous': SimultaneousResetter,
    }.get(opts.reset_type)
    if resetter_cls is None:
        raise ValueError(f"Unknown reset type, {opts.reset_type}")
    return resetter_cls(
        receiver_period=opts.receiver_reset_period,
        sender_period=opts.sender_reset_period,
        on_epoch=opts.reset_on_epoch,
    )


def main(params):
    """Run the experiment: train a sender/receivers game, then probe the sender.

    Phase 1 builds the datasets, one receiver per sampled attribute subset and
    a single sender, and trains them jointly with the configured callbacks.
    Phase 2 freezes a copy of the trained sender and, for every subset size
    from 1 to ``n_attributes``, trains one freshly initialised receiver
    against it, logging under the ``eol_{n_att}/`` prefix.

    Args:
        params: Command-line argument strings, passed to ``get_params``.

    Raises:
        ValueError: If ``--variable_len`` is set, or if the sender cell,
            receiver cell, loss type or reset type is not recognised.
    """
    opts = get_params(params)
    # Explicit check instead of ``assert`` so it is not stripped under -O.
    if opts.variable_len:
        raise ValueError("variable_len=True is not supported by this script")
    _set_seed(opts.random_seed)

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------
    full_data = enumerate_attribute_value(opts.n_attributes, opts.n_values)
    if opts.density_data > 0:
        sampled_data = select_subset_V2(
            full_data, opts.density_data, opts.n_attributes, opts.n_values
        )
        full_data = copy.deepcopy(sampled_data)

    train, validation, test = split_train_val_test(
        full_data,
        validation_ratio=opts.validation_ratio,
        test_ratio=opts.test_ratio,
    )

    train, validation, test, full_data = [
        one_hotify(x, opts.n_attributes, opts.n_values)
        for x in [train, validation, test, full_data]
    ]

    # Fall back to the training examples when the validation split is empty.
    if len(validation) > 0:
        validation = ScaledDataset(validation, 1)
    else:
        validation = ScaledDataset(train, 1)
    train = ScaledDataset(train, opts.data_scaler)
    test = ScaledDataset(test, 1)
    full_data = ScaledDataset(full_data)

    train_loader, test_loader, full_data_loader = [
        DataLoader(x, batch_size=opts.batch_size)
        for x in [train, test, full_data]
    ]
    # The whole validation set is served as a single batch.
    validation_loader = DataLoader(validation, batch_size=len(validation))

    n_dim = opts.n_attributes * opts.n_values

    # ------------------------------------------------------------------
    # Attribute subsets and agents
    # ------------------------------------------------------------------
    # NOTE(security): eval() executes arbitrary Python taken from the command
    # line; only run this script with trusted arguments. If only literals are
    # ever passed, ast.literal_eval would be a safer replacement.
    opts.n_att_n_comb_n_dup = eval(opts.n_att_n_comb_n_dup)
    if opts.n_att_n_comb_n_dup is None:
        opts.n_att_n_comb_n_dup = [(opts.n_attributes, 1, 1)]

    # For each (n_att, n_comb, n_dup) triple: sample n_comb distinct
    # n_att-sized attribute subsets and repeat the sampled set n_dup times.
    att_indices = []
    rng = np.random.default_rng(opts.random_seed)
    for n_att, n_comb, n_dup in opts.n_att_n_comb_n_dup:
        combs = list(itertools.combinations(range(opts.n_attributes), n_att))
        att_idx = rng.choice(len(combs), replace=False, size=n_comb)
        att_idx = np.tile(att_idx, n_dup)
        att_indices.extend(combs[i] for i in att_idx)

    # One receiver per entry of att_indices (n_comb * n_dup per triple).
    receivers = []
    for _, n_comb, n_dup in opts.n_att_n_comb_n_dup:
        for _ in range(n_comb):
            for _ in range(n_dup):
                receivers.append(_build_rnn_receiver(opts, n_dim))
    receiver = _combine_receivers(opts, receivers, att_indices)

    if opts.sender_cell not in _RNN_CELLS:
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")
    sender = Sender(n_inputs=n_dim, n_hidden=opts.sender_hidden)
    sender = core.RnnSenderReinforce(
        agent=sender,
        vocab_size=opts.vocab_size,
        embed_dim=opts.sender_emb,
        hidden_size=opts.sender_hidden,
        max_len=opts.max_len,
        cell=opts.sender_cell,
    )

    if not opts.variable_len:
        sender = PlusOneWrapper(sender, preserve_eos=opts.preserve_eos)
    loss = _build_loss(opts, att_indices, opts.group_size)

    baseline = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------
    # The metrics use the receiver attached to the largest attribute subset
    # (the first such receiver when several share the maximum size).
    max_idx = np.argmax([len(att) for att in att_indices])
    metrics_evaluators = [
        Metrics(
            dataset.examples,
            opts.device,
            opts.n_attributes,
            opts.n_values,
            opts.vocab_size if opts.variable_len else opts.vocab_size + 1,
            freq=opts.stats_freq,
            name=name,
            max_len=opts.max_len,  # Not counting EOS.
            topsim_max_samples=opts.topsim_max_samples,
            seed=opts.random_seed,
            batch_size=opts.batch_size,
            receiver=ReceiverLogProbWrapper(
                receivers[max_idx],
                n_attributes=opts.n_attributes,
                n_values=opts.n_values,
                att_idx=att_indices[max_idx],
            ),
            receiver_dataset=ScaledDataset(dataset.examples),
            save_protocol=(name == 'full'),
            exclude_eos=opts.metric_exclude_eos,
        )
        for dataset, name in zip(
            [train, validation, test, full_data],
            ['train', 'val', 'test', 'full'],
        )
    ]

    # Each evaluated split gets its own loss instance, as in the original.
    train_val_loaders = [
        ("train", train_loader, _build_loss(opts, att_indices, opts.group_size)),
        ("val", validation_loader, _build_loss(opts, att_indices, opts.group_size)),
    ]
    test_full_loaders = [
        ("test", test_loader, _build_loss(opts, att_indices, opts.group_size)),
        ("full", full_data_loader, _build_loss(opts, att_indices, opts.group_size)),
    ]
    train_val_evaluator = Evaluator(
        train_val_loaders, opts.device, freq=opts.val_eval_freq
    )
    test_full_evaluator = Evaluator(
        test_full_loaders, opts.device, freq=opts.test_eval_freq
    )
    early_stopper = EarlyStopperAccuracy(opts.early_stopping_thr, validation=True)

    wandb_logger = DetailedWandbLogger(
        opts=opts,
        project=opts.wandb_project,
        name=opts.wandb_name,
        tags=opts.wandb_tags,
    )

    if opts.load_val_best:
        # Reads wandb.run, so this must come after the logger above;
        # presumably the logger is what starts the run -- TODO confirm.
        val_best_saver = ValBestCheckpoint(
            checkpoint_path=f'{opts.ckpt_path}/{wandb.run.project}/{wandb.run.id}',
            checkpoint_freq=opts.val_eval_freq,
            prefix='',
        )
    else:
        # Plain base Callback used as a placeholder in the callback list.
        val_best_saver = Callback()

    resetter = _build_resetter(opts)

    # ------------------------------------------------------------------
    # Phase 1: joint training
    # ------------------------------------------------------------------
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            val_best_saver,
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            wandb_logger,
            early_stopper,
            train_val_evaluator,
            test_full_evaluator,
            resetter,
            *metrics_evaluators,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    core.get_opts().preemptable = False
    core.get_opts().checkpoint_path = None

    # ------------------------------------------------------------------
    # Phase 2: freeze Sender and probe how fast a simple Receiver will learn
    # ------------------------------------------------------------------
    frozen_sender = Freezer(copy.deepcopy(sender))

    for n_att in range(1, opts.n_attributes + 1):
        prefix = f'eol_{n_att}/'
        wandb.define_metric(f'eol_{n_att}_step')
        wandb.define_metric(f'eol_{n_att}/*', f'eol_{n_att}_step')
        wandb_logger.global_prefix = prefix
        # Re-seed so every probe starts from the same random state.
        _set_seed(opts.random_seed)

        # Sample a single attribute subset of size n_att.
        rng = np.random.default_rng(opts.random_seed)
        combs = list(itertools.combinations(range(opts.n_attributes), n_att))
        att_idx = rng.choice(len(combs), replace=False, size=1)
        probe_att_indices = [combs[i] for i in att_idx]

        probe_receivers = [_build_rnn_receiver(opts, n_dim)]
        receiver = _combine_receivers(opts, probe_receivers, probe_att_indices)

        # A single loss instance is shared by the game and all evaluators.
        loss = _build_loss(opts, probe_att_indices, 1)

        game = core.SenderReceiverRnnReinforce(
            frozen_sender,
            receiver,
            loss,
            sender_entropy_coeff=0.0,
            receiver_entropy_coeff=0.0,
        )
        # Only the receiver's parameters are optimised in this phase.
        optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)
        early_stopper = EarlyStopperAccuracy(
            opts.eol_early_stopping_thr, validation=False
        )

        train_val_loaders = [
            ("train", train_loader, loss),
            ("val", validation_loader, loss),
        ]
        test_full_loaders = [
            ("test", test_loader, loss),
            ("full", full_data_loader, loss),
        ]
        train_val_evaluator = Evaluator(
            train_val_loaders,
            opts.device,
            freq=opts.val_eval_freq,
            global_prefix=prefix,
        )
        test_full_evaluator = Evaluator(
            test_full_loaders,
            opts.device,
            freq=opts.test_eval_freq,
            global_prefix=prefix,
        )

        trainer = core.Trainer(
            game=game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=validation_loader,
            callbacks=[
                wandb_logger,
                early_stopper,
                train_val_evaluator,
                test_full_evaluator,
            ],
        )
        trainer.train(n_epochs=opts.eol_n_epochs)

    print("---End--")

    core.close()


if __name__ == "__main__":
    # Script entry point: hand every command-line argument except the
    # program name to main().
    from sys import argv

    main(argv[1:])
