# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import json
import itertools
import random
from typing import (
    List,
    Union,
    Optional,
)

import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
import wandb
from jsonargparse import ArgumentParser
from torch.utils.data import (
    Subset,
    ConcatDataset,
)

import egg.core as core
from egg.core import EarlyStopperAccuracy
from egg.core.callbacks import (
    WandbLogger,
    Callback,
)
from egg.zoo.compo_vs_generalization.archs import (
    Freezer,
    NonLinearReceiver,
    PlusOneWrapper,
    Receiver,
    Sender,
    MultiRnnReceiverDeterministic,
    SenderExpansionWrapper,
    MultiRnnReceiverReinforce,
    DiscriminationReceiver,
    MultiDiscriminationRnnReceiverReinforce,
    MultiDiscriminationRnnReceiverDeterministic,
)
from egg.zoo.compo_vs_generalization.data import (
    ScaledDataset,
    enumerate_attribute_value,
    one_hotify,
    select_subset_V1,
    select_subset_V2,
    split_holdout,
    split_train_test,
    split_train_val_test,
    ImageDiscrimiationDataset,
    ImageDiscriminationDatasetLogitWrapper,
    ImageReconstructionDataset,
)
from egg.zoo.compo_vs_generalization.intervention import (
    Evaluator, 
    Metrics,
    ValBestCheckpoint,
    RandomResetter,
    UniformResetter,
    DetailedWandbLogger,
    SimultaneousResetter,
    EpochLogEvaluator,
)
from egg.zoo.compo_vs_generalization.losses import (
    DiscriminationLoss,
    MultiDiscriminationLoss,
    ContinuousDiffLoss,
    MultiContinuousDiffLoss,

)


GLOBAL_SEED = 7


def get_params(params):
    """Parse command-line ``params`` into an options namespace.

    Declares the script-specific arguments on a jsonargparse ``ArgumentParser``.
    It then hands the parser to ``core.init``, which merges EGG's common options
    (e.g. vocab_size, max_len, lr, n_epochs, device, random_seed, as used by
    ``main``) and returns the resulting namespace.
    """
    parser = ArgumentParser()
    # NOTE: main() overwrites n_attributes with the value read from the dataset.
    parser.add_argument("--n_attributes", type=int, default=4, help="")
    parser.add_argument("--n_values", type=int, default=4, help="")
    parser.add_argument("--data_scaler", type=int, default=100)
    parser.add_argument("--stats_freq", type=int, default=0)
    parser.add_argument("--val_eval_freq", type=int, default=1)
    parser.add_argument("--test_eval_freq", type=int, default=0)
    parser.add_argument(
        "--baseline", type=str, choices=["no", "mean", "builtin"], default="mean"
    )
    parser.add_argument(
        "--density_data", type=int, default=0, help="no sampling if equal 0"
    )

    parser.add_argument(
        "--sender_hidden",
        type=int,
        default=50,
        help="Size of the hidden layer of Sender (default: 50)",
    )
    parser.add_argument(
        "--receiver_hidden",
        type=int,
        default=50,
        help="Size of the hidden layer of Receiver (default: 50)",
    )

    parser.add_argument(
        "--sender_entropy_coeff",
        type=float,
        default=1e-2,
        help="Entropy regularisation coeff for Sender (default: 1e-2)",
    )

    parser.add_argument("--sender_cell", type=str, default="rnn")
    parser.add_argument("--receiver_cell", type=str, default="rnn")
    parser.add_argument(
        "--sender_emb",
        type=int,
        default=10,
        help="Size of the embeddings of Sender (default: 10)",
    )
    parser.add_argument(
        "--receiver_emb",
        type=int,
        default=10,
        help="Size of the embeddings of Receiver (default: 10)",
    )
    parser.add_argument(
        "--early_stopping_thr",
        type=float,
        default=0.999,
        help="Early stopping threshold on accuracy (default: 0.999)",
    )
    parser.add_argument(
        "--load_val_best",
        type=bool,
        default=True,
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default='div_int',
    )
    parser.add_argument(
        "--wandb_name",
        type=str,
        default='div_int',
    )
    parser.add_argument(
        "--wandb_tags",
        type=List[str],
        # Must be a list to match the declared List[str] type; a bare string
        # would be treated as a sequence of single characters by consumers
        # that iterate the tags.
        default=['div_int'],
    )
    parser.add_argument(
        "--n_att_n_comb_n_dup",
        type=str,
        # Evaluated by main() as a Python expression; 'None' means a single
        # group containing every attribute.
        default='None',
    )
    parser.add_argument(
        "--validation_ratio",
        type=float,
        default=None,
    )
    parser.add_argument(
        "--test_ratio",
        type=float,
        default=None,
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default='checkpoints',
    )
    parser.add_argument(
        "--loss_type",
        type=str,
        default='cross_entropy',
    )
    parser.add_argument(
        "--variable_len",
        type=bool,
        default=False,
    )
    parser.add_argument(
        "--group_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--sender_reset_period",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--receiver_reset_period",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--reset_type",
        type=str,
        default='uniform',
    )
    parser.add_argument(
        "--reset_on_epoch",
        type=bool,
        default=True,
    )
    parser.add_argument(
        "--eol_n_epochs",
        type=int,
        default=100,
    )
    parser.add_argument(
        "--eol_n_attributes",
        type=List[int],
        default=[],
    )
    parser.add_argument(
        "--topsim_max_samples",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--cross_entropy_weight",
        type=float,
        default=1.0,
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
    )
    parser.add_argument(
        "--train_n_samples",
        type=Optional[int],
    )
    parser.add_argument(
        "--eval_n_samples",
        type=Optional[int],
    )
    parser.add_argument(
        "--add_layer_norm",
        type=bool,
        default=False,
    )
    parser.add_argument(
        "--n_workers",
        type=int,
    )
    parser.add_argument(
        "--val_batch_size",
        type=int,
    )
    parser.add_argument(
        "--test_batch_size",
        type=int,
    )
    parser.add_argument(
        "--train_as",
        type=str,
        default=None,
    )

    args = core.init(arg_parser=parser, params=params)
    return args


def _has_validation(opts):
    """Return True when a dedicated validation split was requested.

    ``--validation_ratio`` defaults to ``None``; treat that the same as ``0.0``
    instead of letting ``None > 0.0`` raise ``TypeError``.
    """
    return opts.validation_ratio is not None and opts.validation_ratio > 0.0


def _set_seed(seed):
    """Seed the Python, NumPy and torch RNGs (all CUDA devices too, if present)."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _make_datasets(opts):
    """Build the ``(train, validation, test)`` ImageReconstructionDataset triple.

    Without a validation split, the validation dataset reads the 'train' split.
    """
    train = ImageReconstructionDataset(
        dataset_dir=opts.dataset_dir,
        split='train' if opts.train_as is None else opts.train_as,
        n_samples_per_epoch=opts.train_n_samples,
        deterministic=False,
        batch_size=opts.batch_size,
        scale=opts.data_scaler,
        load_at_init=True,
    )
    validation = ImageReconstructionDataset(
        dataset_dir=opts.dataset_dir,
        split='valid' if _has_validation(opts) else 'train',
        n_samples_per_epoch=opts.eval_n_samples,
        deterministic=True,
        seed=GLOBAL_SEED,
        batch_size=opts.val_batch_size,
        load_at_init=True,
    )
    test = ImageReconstructionDataset(
        dataset_dir=opts.dataset_dir,
        split='test',
        n_samples_per_epoch=opts.eval_n_samples,
        deterministic=True,
        seed=GLOBAL_SEED,
        batch_size=opts.test_batch_size,
        load_at_init=True,
    )
    return train, validation, test


def _make_loaders(opts, datasets):
    """Wrap ``(train, validation, test)`` in DataLoaders; only train is shuffled."""
    return [
        DataLoader(
            dataset=dataset,
            shuffle=shuffle,
            num_workers=opts.n_workers,
            # The datasets take their own batch_size; presumably each item is
            # already a full batch (TODO confirm), hence batch_size=1 here.
            batch_size=1,
            collate_fn=ImageReconstructionDataset.collate_fn,
        )
        for dataset, shuffle in zip(datasets, (True, False, False))
    ]


def _make_receiver(opts, n_outputs):
    """Build one deterministic RNN Receiver producing ``n_outputs`` values.

    Raises:
        ValueError: if ``opts.receiver_cell`` is not one of lstm/rnn/gru.
    """
    if opts.receiver_cell not in ("lstm", "rnn", "gru"):
        raise ValueError(f"Unknown receiver cell, {opts.receiver_cell}")
    agent = Receiver(n_hidden=opts.receiver_hidden, n_outputs=n_outputs)
    return core.RnnReceiverDeterministic(
        agent,
        # With fixed-length messages the sender is wrapped in PlusOneWrapper,
        # which presumably shifts symbols by one, hence the +1.
        opts.vocab_size if opts.variable_len else opts.vocab_size + 1,
        opts.receiver_emb,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
    )


def _make_eval_loaders(opts, att_indices, test_loader, validation_loader):
    """Build the ``(name, loader, loss)`` triples consumed by the evaluators.

    Returns:
        ``(loaders, val_test_loaders)``. ``loaders`` holds only the "test"
        entry. ``val_test_loaders`` additionally holds a "val" entry when a
        validation split is configured; otherwise it has the same entries
        as ``loaders``.
    """
    def make_loss():
        return MultiContinuousDiffLoss(
            att_indices=att_indices,
            loss_type=opts.loss_type,
            group_size=opts.group_size,
        )

    loaders = [("test", test_loader, make_loss())]
    val_test_loaders = list(loaders)
    if _has_validation(opts):
        val_test_loaders.append(("val", validation_loader, make_loss()))
    return loaders, val_test_loaders


def main(params):
    """Train a Sender with a group of Receivers, then probe a frozen Sender.

    Phase 1 trains the Sender jointly with one Receiver per entry of
    ``att_indices``. These attribute groups are drawn according to
    ``--n_att_n_comb_n_dup``.

    Phase 2 freezes a copy of the trained Sender. For every group size among
    the phase-1 sizes and ``--eol_n_attributes``, it trains a fresh single
    Receiver for ``--eol_n_epochs`` epochs, logging under ``eol_<size>``.

    Raises:
        ValueError: on an unsupported ``loss_type``, ``reset_type``, sender
            cell or receiver cell.
    """
    opts = get_params(params)
    assert not opts.variable_len

    # Validate enum-like options up front so a typo fails before any data is
    # loaded or a wandb run is started.
    if opts.loss_type != 'mse':
        raise ValueError(f"Unsupported loss_type, {opts.loss_type}")
    resetter_classes = {
        'random': RandomResetter,
        'uniform': UniformResetter,
        'simultaneous': SimultaneousResetter,
    }
    if opts.reset_type not in resetter_classes:
        raise ValueError(f"Unknown reset_type, {opts.reset_type}")
    has_validation = _has_validation(opts)

    opts.n_attributes = ImageReconstructionDataset.n_attributes(opts.dataset_dir)
    # NOTE(review): eval() executes arbitrary code taken from the command line
    # or config. ast.literal_eval would be safer but would reject
    # expression-style values such as "[(2, 1, 1)] * 2", so it is kept.
    opts.n_att_n_comb_n_dup = eval(opts.n_att_n_comb_n_dup)
    if opts.n_att_n_comb_n_dup is None:
        # Default: one group containing every attribute, used once.
        opts.n_att_n_comb_n_dup = [(opts.n_attributes, 1, 1)]

    # For each (n_att, n_comb, n_dup): randomly partition the attributes into
    # groups of n_att, pick n_comb distinct groups, and repeat that pick n_dup
    # times.
    att_indices = []
    rng = np.random.default_rng(opts.random_seed)
    for n_att, n_comb, n_dup in opts.n_att_n_comb_n_dup:
        perm = list(range(opts.n_attributes))
        rng.shuffle(perm)
        groups = torch.tensor(perm).split(n_att)
        for group in groups:
            # n_attributes must be divisible by n_att.
            assert len(group) == n_att
        chosen = np.tile(rng.choice(len(groups), replace=False, size=n_comb), n_dup)
        att_indices += [groups[i].tolist() for i in chosen]

    train, validation, test = _make_datasets(opts)
    train_loader, validation_loader, test_loader = _make_loaders(
        opts, (train, validation, test)
    )

    receivers = [
        _make_receiver(opts, test.visual_dim)
        for _, n_comb, n_dup in opts.n_att_n_comb_n_dup
        for _ in range(n_comb)
        for _ in range(n_dup)
    ]
    receiver = MultiRnnReceiverDeterministic(receivers)

    if opts.sender_cell not in ("lstm", "rnn", "gru"):
        raise ValueError(f"Unknown sender cell, {opts.sender_cell}")
    sender = Sender(n_inputs=test.visual_dim, n_hidden=opts.sender_hidden)
    sender = core.RnnSenderReinforce(
        agent=sender,
        vocab_size=opts.vocab_size,
        embed_dim=opts.sender_emb,
        hidden_size=opts.sender_hidden,
        max_len=opts.max_len,
        cell=opts.sender_cell,
        layer_norm=opts.add_layer_norm,
    )
    if not opts.variable_len:
        sender = PlusOneWrapper(sender)

    loss = MultiContinuousDiffLoss(
        att_indices=att_indices,
        loss_type=opts.loss_type,
        group_size=opts.group_size,
    )
    baseline = {
        "no": core.baselines.NoBaseline,
        "mean": core.baselines.MeanBaseline,
        "builtin": core.baselines.BuiltInBaseline,
    }[opts.baseline]

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=0.0,
        length_cost=0.0,
        baseline_type=baseline,
    )
    optimizer = torch.optim.Adam(game.parameters(), lr=opts.lr)

    eval_splits = ['test'] + (['valid'] if has_validation else [])
    metrics_evaluators = [
        Metrics(
            ImageDiscriminationDatasetLogitWrapper(
                ImageReconstructionDataset(
                    dataset_dir=opts.dataset_dir,
                    split=split,
                    deterministic=True,
                    seed=GLOBAL_SEED,
                    n_samples_per_epoch=opts.topsim_max_samples,
                    batch_size=1,
                )
            ),
            opts.device,
            None,
            None,
            opts.vocab_size if opts.variable_len else opts.vocab_size + 1,
            freq=opts.stats_freq,
            name=split,
            max_len=opts.max_len,  # Not counting EOS.
            topsim_max_samples=opts.topsim_max_samples,
            seed=opts.random_seed,
            batch_size=opts.batch_size,
        )
        for split in eval_splits
    ]

    loaders, val_test_loaders = _make_eval_loaders(
        opts, att_indices, test_loader, validation_loader
    )
    epoch_evaluator = EpochLogEvaluator(val_test_loaders, opts.device)
    evaluator = Evaluator(loaders, opts.device, freq=opts.test_eval_freq)

    wandb_logger = DetailedWandbLogger(
        opts=opts,
        project=opts.wandb_project,
        name=opts.wandb_name,
        tags=opts.wandb_tags,
    )

    # Must be built after the wandb logger: its path reads wandb.run.
    if opts.load_val_best:
        val_best_saver = ValBestCheckpoint(
            checkpoint_path=f'{opts.ckpt_path}/{wandb.run.project}/{wandb.run.id}',
            checkpoint_freq=opts.val_eval_freq,
            prefix='',
            metric_name='receiver_loss',
            higher_is_better=False,
        )
    else:
        val_best_saver = Callback()

    resetter = resetter_classes[opts.reset_type](
        receiver_period=opts.receiver_reset_period,
        sender_period=opts.sender_reset_period,
        on_epoch=opts.reset_on_epoch,
    )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=validation_loader,
        callbacks=[
            val_best_saver,  # This needs to come at first.
            core.ConsoleLogger(as_json=True, print_train_loss=False),
            wandb_logger,
            resetter,
            epoch_evaluator,
            evaluator,
            *metrics_evaluators,
        ],
    )
    trainer.train(n_epochs=opts.n_epochs)

    # ---- Phase 2: freeze the Sender and probe how fast a fresh Receiver
    # learns each attribute-group size. ----
    core.get_opts().preemptable = False
    core.get_opts().checkpoint_path = None

    frozen_sender = Freezer(copy.deepcopy(sender))

    prev_att_indices = att_indices
    n_atts = list(
        {len(idx) for idx in prev_att_indices} | set(opts.eol_n_attributes)
    )

    for n_att in n_atts:
        # Drop the previous training data before loading a fresh copy.
        del train_loader, train

        wandb.define_metric(f'eol_{n_att}_step')
        wandb.define_metric(f'eol_{n_att}/*', f'eol_{n_att}_step')
        wandb_logger.global_prefix = f'eol_{n_att}'
        _set_seed(opts.random_seed)

        # Reuse the last phase-1 group of this size; otherwise draw the first
        # group of a seeded random partition.
        att_indices = None
        for prev_att_idx in prev_att_indices:
            if len(prev_att_idx) == n_att:
                att_indices = [prev_att_idx]
        if att_indices is None:
            rng = np.random.default_rng(opts.random_seed)
            perm = list(range(opts.n_attributes))
            rng.shuffle(perm)
            groups = torch.tensor(perm).split(n_att)
            assert len(groups[0]) == n_att
            att_indices = [groups[0].tolist()]

        receiver = MultiRnnReceiverDeterministic(
            [_make_receiver(opts, test.visual_dim)]
        )
        # NOTE(review): training uses group_size=1 here, while the evaluation
        # losses below use opts.group_size (as in the original); confirm intent.
        loss = MultiContinuousDiffLoss(
            att_indices=att_indices,
            loss_type=opts.loss_type,
            group_size=1,
        )
        game = core.SenderReceiverRnnReinforce(
            frozen_sender,
            receiver,
            loss,
            sender_entropy_coeff=0.0,
            receiver_entropy_coeff=0.0,
        )
        # Only the fresh Receiver is optimised.
        optimizer = torch.optim.Adam(receiver.parameters(), lr=opts.lr)

        train, validation, test = _make_datasets(opts)
        train_loader, validation_loader, test_loader = _make_loaders(
            opts, (train, validation, test)
        )

        loaders, val_test_loaders = _make_eval_loaders(
            opts, att_indices, test_loader, validation_loader
        )
        epoch_evaluator = EpochLogEvaluator(
            val_test_loaders,
            opts.device,
            global_prefix=f'eol_{n_att}/',
        )
        evaluator = Evaluator(
            loaders,
            opts.device,
            freq=opts.test_eval_freq,
            global_prefix=f'eol_{n_att}/',
        )

        trainer = core.Trainer(
            game=game,
            optimizer=optimizer,
            train_data=train_loader,
            validation_data=validation_loader,
            callbacks=[
                wandb_logger,
                epoch_evaluator,
                evaluator,
            ],
        )
        trainer.train(n_epochs=opts.eol_n_epochs)

    print("---End--")

    core.close()


if __name__ == "__main__":
    # Script entry point: forward the CLI arguments (without the program
    # name) to main().
    from sys import argv

    main(argv[1:])
