# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import json
import copy
import pathlib
import sys
import os
import argparse
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Union
import time
from collections import defaultdict

import torch
import wandb
import numpy as np
from torch.utils.data import (
    DataLoader,
    Subset,
)
from scipy import spatial
from scipy.stats import spearmanr
from tqdm import tqdm

import egg.core as core
from egg.core.callbacks import (
    CheckpointSaver,
    Checkpoint,
    Callback,
    WandbLogger,
)
from egg.core.batch import Batch
from egg.zoo.language_bottleneck.intervention import entropy, mutual_info
from egg.core.interaction import Interaction

try:
    import editdistance  # package to install https://pypi.org/project/editdistance/0.3.1/
except ImportError:
    print(
        "Please install editdistance package: `pip install editdistance`. "
        "It is used for calculating topographic similarity."
    )
from egg.zoo.compo_vs_generalization.compo_measures.metrics.conflict_count import ConflictCount
from egg.zoo.compo_vs_generalization.compo_measures.metrics.context_independence import ContextIndependence
from egg.zoo.compo_vs_generalization.compo_measures.metrics.tre import TreeReconstructionError
from egg.zoo.compo_vs_generalization.compo_measures.metrics.tre import LinearComposition
    


def ask_sender(n_attributes, n_values, dataset, sender, device, batch_size=1):
    """Query ``sender`` on every item of ``dataset`` (in order) and collect its outputs.

    Dataset items may be either a ``(sender_input, meaning)`` pair (tuple/list)
    or a single tensor that serves as both sender input and meaning.

    :param n_attributes: number of attributes; when given, each meaning is
        viewed as ``(-1, n_attributes, n_values)`` and arg-maxed over values.
        When ``None``, attributes are zeros shaped like the meaning.
    :param n_values: number of values per attribute (used with ``n_attributes``).
    :param dataset: dataset iterated with a non-shuffling ``DataLoader``.
    :param sender: module whose first output is taken as the message.
    :param device: device the sender input and flattened meanings are moved to.
    :param batch_size: loader batch size.
    :return: ``(attributes, strings, meanings)`` each concatenated along dim 0;
        meanings are flattened from dim 1 on.
    """
    loader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=batch_size,
        drop_last=False,
    )
    was_training = sender.training
    sender.eval()

    attribute_chunks, message_chunks, meaning_chunks = [], [], []
    for batch in tqdm(loader, desc='Asking Sender'):
        if isinstance(batch, (tuple, list)):
            sender_input, meaning = batch
        else:
            # A bare tensor doubles as both the sender input and the meaning.
            sender_input = meaning = batch

        if n_attributes is None:
            attribute_chunks.append(torch.zeros_like(meaning))
        else:
            attribute_chunks.append(
                meaning.view(-1, n_attributes, n_values).argmax(dim=-1)
            )
        meaning_chunks.append(torch.flatten(meaning, 1, -1).to(device))

        with torch.no_grad():
            message_chunks.append(sender(sender_input.to(device))[0])

    all_attributes = torch.cat(attribute_chunks, dim=0)
    all_strings = torch.cat(message_chunks, dim=0)
    all_meanings = torch.cat(meaning_chunks, dim=0)

    # Restore whatever train/eval mode the sender was in before the call.
    sender.train(was_training)

    return all_attributes, all_strings, all_meanings


def ask_receiver(messages, dataset, receiver, device, batch_size, collate_fn):
    """Feed precomputed ``messages`` to ``receiver`` together with ``dataset`` items.

    ``messages`` is split into chunks of ``batch_size`` rows that must line up
    one-to-one with the batches of a non-shuffling loader over ``dataset``.
    If a loader batch is a tuple/list of length >= 2, its first element is
    dropped and the rest are passed to the receiver; a length-1 batch passes
    its only element; any other batch is passed as-is.

    :return: ``(outputs, log_probs)`` — the receiver's first two outputs,
        concatenated along dim 0.
    """
    message_chunks = torch.split(messages, batch_size, dim=0)

    was_training = receiver.training
    receiver.eval()

    loader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=batch_size,
        drop_last=False,
        collate_fn=collate_fn,
    )
    assert len(message_chunks) == len(loader)

    output_chunks = []
    log_prob_chunks = []
    for rec_input, message in tqdm(zip(loader, message_chunks), desc='Asking receiver'):
        if not isinstance(rec_input, (tuple, list)):
            receiver_args = [rec_input]
        elif len(rec_input) >= 2:
            # Skip the leading element; forward everything after it.
            receiver_args = list(rec_input[1:])
        else:
            receiver_args = [rec_input[0]]

        receiver_args = Batch(*receiver_args).to(device)

        with torch.no_grad():
            output, log_prob, *_ = receiver(message, *receiver_args)
        output_chunks.append(output)
        log_prob_chunks.append(log_prob)

    all_outputs = torch.cat(output_chunks, dim=0)
    all_log_probs = torch.cat(log_prob_chunks, dim=0)

    receiver.train(was_training)

    return all_outputs, all_log_probs


def information_gap_representation(meanings, representations):
    """Average normalised mutual-information gap over representation columns.

    For each column ``j`` of ``representations`` with positive entropy
    ``H(r_j)``, the mutual information ``I(m_i; r_j)`` with every column ``i``
    of ``meanings`` is computed. The gap between the largest and second-largest
    values, divided by ``H(r_j)``, is that column's score. The result is the
    mean score over non-constant columns.

    With a single meaning column there is no runner-up, so the second-largest
    mutual information is taken as 0.

    :param meanings: 2-D tensor, one column per meaning component.
    :param representations: 2-D tensor, one column per representation position.
    :return: float score; NaN when every representation column is constant.
    """
    n_positions = representations.size(1)
    n_meaning_cols = meanings.size(1)
    gaps = torch.zeros(n_positions)
    non_constant_positions = 0

    for j in range(n_positions):
        y = representations[:, j]
        h_j = entropy(y)
        # Constant (zero-entropy) positions carry no information; skip them
        # before paying for the mutual-information computations.
        # `not > 0` (rather than `<= 0`) also skips a NaN entropy.
        if not h_j > 0.0:
            continue

        symbol_mi = sorted(
            (mutual_info(meanings[:, i], y) for i in range(n_meaning_cols)),
            reverse=True,
        )
        runner_up = symbol_mi[1] if len(symbol_mi) > 1 else 0.0
        gaps[j] = (symbol_mi[0] - runner_up) / h_j
        non_constant_positions += 1

    if non_constant_positions == 0:
        # Undefined score: no informative position at all.
        return float('nan')
    return (gaps.sum() / non_constant_positions).item()


def information_gap_position(
    n_attributes=None,
    n_values=None,
    dataset=None,
    sender=None,
    device=None,
    attributes=None,
    strings=None
):
    """Positional disentanglement: information gap between attributes and message positions.

    If either ``attributes`` or ``strings`` is missing, both are obtained by
    querying ``sender`` on ``dataset`` via ``ask_sender``; otherwise the
    given tensors are used directly.
    """
    must_query_sender = attributes is None or strings is None
    if must_query_sender:
        attributes, strings, _ = ask_sender(
            n_attributes, n_values, dataset, sender, device
        )
    # Each message position is treated as one representation column.
    return information_gap_representation(attributes, strings)


def histogram(strings, vocab_size):
    """Count how often each vocabulary symbol occurs in every message.

    :param strings: tensor whose dim 0 indexes messages; occurrences are
        summed over the last dimension.
    :param vocab_size: number of symbols counted (``0 .. vocab_size - 1``);
        values outside that range are ignored.
    :return: tensor of shape ``(strings.size(0), vocab_size)`` in the default
        floating dtype, on ``strings.device``.
    """
    counts = torch.zeros(strings.size(0), vocab_size, device=strings.device)
    for symbol in range(vocab_size):
        counts[:, symbol] = (strings == symbol).sum(dim=-1)
    return counts


def information_gap_vocab(
    vocab_size,
    n_attributes=None,
    n_values=None,
    dataset=None,
    sender=None,
    device=None,
    attributes=None,
    strings=None
):
    """Bag-of-symbols disentanglement: information gap between attributes and symbol counts.

    If either ``attributes`` or ``strings`` is missing, both are obtained by
    querying ``sender`` on ``dataset`` via ``ask_sender``. Each message is
    turned into a per-symbol count vector with ``histogram``, and the count
    column for symbol 0 is discarded before scoring (presumably symbol 0 is
    EOS/padding — TODO confirm).
    """
    if attributes is None or strings is None:
        attributes, strings, _ = ask_sender(
            n_attributes, n_values, dataset, sender, device
        )

    symbol_counts = histogram(strings, vocab_size)
    return information_gap_representation(attributes, symbol_counts[:, 1:])


def edit_dist(_list):
    """Normalised Levenshtein distance for every unordered pair in ``_list``.

    Pairs are enumerated as (0, 1), (0, 2), ..., (1, 2), ... — the same order
    as ``cosine_dist`` — so the two distance lists can be correlated
    element-wise.

    Each distance is divided by the length of the longer sequence. For the
    fixed-length messages used here that equals either length. This keeps the
    measure symmetric and in [0, 1] if lengths differ. Two empty sequences are
    at distance 0.0.

    :param _list: sequence of sequences accepted by ``editdistance.eval``.
    :return: list of floats of length ``n * (n - 1) / 2``.
    """
    distances = []
    for i, first in enumerate(_list[:-1]):
        for second in _list[i + 1:]:
            longest = max(len(first), len(second))
            if longest == 0:
                distances.append(0.0)
            else:
                distances.append(editdistance.eval(first, second) / longest)
    return distances


def cosine_dist(_list):
    """Cosine distance (``scipy.spatial.distance.cosine``) for every unordered pair.

    Pairs are enumerated as (0, 1), (0, 2), ..., (1, 2), ..., matching the
    order produced by ``edit_dist``.

    :param _list: sliceable sequence of 1-D vectors (e.g. a 2-D numpy array).
    :return: list of floats of length ``n * (n - 1) / 2``.
    """
    return [
        spatial.distance.cosine(left, right)
        for pos, left in enumerate(_list[:-1])
        for right in _list[pos + 1:]
    ]


def topographic_similarity(
    n_attributes=None,
    n_values=None,
    dataset=None,
    sender=None,
    device=None,
    attributes=None,
    strings=None,
    meanings=None,
):
    """Spearman correlation between pairwise message and meaning distances.

    Messages are compared with normalised edit distance (``edit_dist``), and
    meanings with cosine distance (``cosine_dist``). If ``strings`` or
    ``meanings`` is missing, all inputs are obtained from ``ask_sender``.
    ``attributes`` is accepted for interface symmetry but not used in the score.

    :return: Spearman rank correlation coefficient.
    """
    if strings is None or meanings is None:
        attributes, strings, meanings = ask_sender(
            n_attributes, n_values, dataset, sender, device
        )

    # Convert each message row into a plain Python list of ints for editdistance.
    message_lists = [[symbol.item() for symbol in message] for message in strings]
    message_distances = edit_dist(message_lists)
    meaning_distances = cosine_dist(meanings.cpu().numpy())

    return spearmanr(message_distances, meaning_distances).correlation


class Metrics(core.Callback):
    """Callback that measures compositionality of the sender's protocol.

    Stats are dumped before the first epoch (``freq != 0``), after every
    ``freq``-th epoch (``freq > 0``) and at the end of training (prefix
    ``'final/'``). Each dump queries the sender on ``dataset`` via
    ``ask_sender``, computes the metrics below and logs them as JSON to stdout
    and to wandb:

    * ``scrambling_resistance`` - only when ``receiver`` is given;
    * ``positional_disent`` / ``bag_of_symbol_disent`` - only when
      ``n_attributes`` is given;
    * ``topographic_sim`` - always, optionally on a random subset of
      ``topsim_max_samples`` items;
    * ``conf_count`` - when ``max_len == n_attributes`` (both given);
    * ``context_independence`` - when ``n_attributes`` is given;
    * ``tre`` - when ``n_attributes == 2``;
    * ``abstractness`` - when ``abstractness_dataset`` is given (this also
      requires ``receiver``).

    When ``max_samples`` is set, ``dataset`` (and ``receiver_dataset`` /
    ``abstractness_dataset`` when given) are subsampled once in ``__init__``
    with a generator seeded by ``seed``.
    """

    def __init__(
        self,
        dataset,
        device,
        n_attributes,
        n_values,
        vocab_size,
        name='',
        freq=1,
        max_len=None,
        topsim_max_samples=None,
        seed=None,
        batch_size=1,
        receiver=None,
        receiver_dataset=None,
        save_protocol: bool = False,
        final_topsim_max_samples=None,
        max_samples: Optional[int] = None,
        exclude_eos: bool = False,
        receiver_batch_size=None,
        receiver_collate_fn=None,
        abstractness_dataset=None,
        abstractness_max_samples=None,
    ):
        self.dataset = dataset
        self.device = device
        self.n_attributes = n_attributes
        self.n_values = n_values
        self.epoch = 0
        self.vocab_size = vocab_size
        self.freq = freq
        self.name = name
        self.max_len = max_len
        self.topsim_max_samples = topsim_max_samples
        self.seed = seed
        self.batch_size = batch_size
        self.receiver = receiver
        self.receiver_dataset = receiver_dataset
        self.save_protocol = save_protocol
        self.final_topsim_max_samples = final_topsim_max_samples
        self.max_samples = max_samples
        self.exclude_eos = exclude_eos
        self.abstractness_dataset = abstractness_dataset

        if receiver_batch_size is not None:
            self.receiver_batch_size = receiver_batch_size
        else:
            self.receiver_batch_size = self.batch_size
        # NOTE: attribute keeps its historical (misspelled) name for compatibility.
        self.recever_collate_fn = receiver_collate_fn

        if self.max_samples is not None:
            # Sample datasets jointly: the receiver (and abstractness) datasets
            # must be index-aligned with the sender dataset.
            if self.receiver_dataset is not None:
                assert len(self.receiver_dataset) == len(self.dataset)
                if self.abstractness_dataset is not None:
                    assert len(self.abstractness_dataset) == len(self.receiver_dataset)
            n_samples = min(self.max_samples, len(self.dataset))
            rng = np.random.default_rng(self.seed)
            indices = rng.choice(len(self.dataset), n_samples, replace=False)
            self.dataset = Subset(self.dataset, indices=indices)
            if self.receiver_dataset is not None:
                self.receiver_dataset = Subset(self.receiver_dataset, indices=indices)
            if self.abstractness_dataset is not None:
                if abstractness_max_samples is not None:
                    # Each abstractness item expands to several comparisons,
                    # one per entry of `att_indices`.
                    n_samples = int(np.ceil(
                        abstractness_max_samples / len(self.abstractness_dataset.att_indices)
                    ))
                # Sampling without replacement cannot draw more than the dataset holds.
                n_samples = min(n_samples, len(self.abstractness_dataset))
                indices = rng.choice(
                    len(self.abstractness_dataset),
                    n_samples,
                    replace=False,
                )
                self.abstractness_dataset = Subset(self.abstractness_dataset, indices=indices)

        if self.max_len is not None and self.max_len == self.n_attributes:
            self.conflict_count = ConflictCount(self.max_len)
        else:
            self.conflict_count = None

        if self.n_attributes is not None:
            self.context_independence = ContextIndependence(num_concepts=n_attributes * n_values)
        else:
            self.context_independence = None

        if self.n_attributes == 2:
            self.tre = TreeReconstructionError(
                num_concepts=n_attributes * n_values,
                message_length=self.max_len,
                composition_fn=LinearComposition,
            )
        else:
            self.tre = None

    def dump_stats(self, prefix: str = '', topsim_max_samples: Optional[int] = None):
        """Compute all configured metrics and log them to stdout and wandb.

        :param prefix: prepended to every wandb key (e.g. ``'final/'``).
        :param topsim_max_samples: if given, topographic similarity is computed
            on a random subset of at most this many items.
        """
        start_time = time.time()
        game = self.trainer.game
        prev_state = game.training
        game.eval()

        start_time_each = time.time()
        attributes, strings, meanings = ask_sender(
            self.n_attributes,
            self.n_values,
            self.dataset,
            game.sender,
            self.device,
            batch_size=self.batch_size,
        )
        if self.exclude_eos:
            msg_only_strings = strings[:, :-1]
            eos = strings[:, -1:]
        else:
            msg_only_strings = strings
            # Empty (batch, 0) slice: concatenating it later is a no-op.
            eos = strings[:, :0]

        print(f'AskSdr. {self.name} proc. time: {time.time() - start_time_each:.1f}')

        protocol = {}
        # Remove EOS in messages.
        for attribute, string in zip(attributes, strings[:, :self.max_len]):
            # (n_att) -> ('0_<val>', '1_<val>', ...)
            attribute = tuple(f'{att}_{val}' for att, val in enumerate(attribute.tolist()))
            protocol[attribute] = ''.join(chr(s) for s in string.tolist())

        output = {'epoch': self.epoch}

        if self.receiver is not None:
            start_time_each = time.time()
            rec_outputs, log_probs = ask_receiver(
                messages=strings,
                dataset=self.receiver_dataset,
                receiver=self.receiver,
                device=self.device,
                batch_size=self.receiver_batch_size,
                collate_fn=self.recever_collate_fn,
            )

            # Shuffle the symbol order of every message independently,
            # keeping the EOS column (if any) at the end.
            permuted_msgs = []
            for message in msg_only_strings:
                idx = torch.randperm(len(message))
                permuted_msgs.append(message[idx])
            permuted_msgs = torch.stack(permuted_msgs, dim=0)
            permuted_msgs = torch.cat([permuted_msgs, eos], dim=1)

            perm_rec_outputs, perm_log_probs = ask_receiver(
                messages=permuted_msgs,
                dataset=self.receiver_dataset,
                receiver=self.receiver,
                device=self.device,
                batch_size=self.receiver_batch_size,
                collate_fn=self.recever_collate_fn,
            )
            scrambling_resistance = 0.0
            n_samples = 0
            with torch.no_grad():
                for p_log_probs, p_perm_log_probs in zip(
                    log_probs.split(self.batch_size),
                    perm_log_probs.split(self.batch_size),
                ):
                    # exp(min(lp, lp_perm) - lp) == min(1, p_perm / p).
                    p_min_log_probs = torch.min(
                        torch.stack([p_log_probs, p_perm_log_probs], dim=1), dim=1
                    )[0]
                    scrambling_resistance += torch.exp(p_min_log_probs - p_log_probs).sum()
                    n_samples += p_min_log_probs.size(0)
                output['scrambling_resistance'] = (scrambling_resistance / n_samples).item()
            print(f'ScrRes. {self.name} proc. time: {time.time() - start_time_each:.1f}')

        if self.n_attributes is not None:
            start_time_each = time.time()
            positional_disent = information_gap_position(
                attributes=attributes,
                strings=msg_only_strings,
            )
            output['positional_disent'] = positional_disent
            print(f'PosDis. {self.name} proc. time: {time.time() - start_time_each:.1f}')

        if self.n_attributes is not None:
            start_time_each = time.time()
            bos_disent = information_gap_vocab(
                vocab_size=self.vocab_size,
                attributes=attributes,
                strings=msg_only_strings,
            )
            output['bag_of_symbol_disent'] = bos_disent
            print(f'BosDis. {self.name} proc. time: {time.time() - start_time_each:.1f}')

        start_time_each = time.time()
        if topsim_max_samples is not None:
            topsim_max_samples = min(topsim_max_samples, len(attributes))
            rng = np.random.default_rng(self.seed)
            idx = rng.choice(len(attributes), topsim_max_samples, replace=False)
        else:
            idx = np.arange(len(attributes))
        topo_sim = topographic_similarity(
            attributes=attributes[idx],
            strings=msg_only_strings[idx],
            meanings=meanings[idx],
        )
        output['topographic_sim'] = topo_sim
        print(f'TopSim. {self.name} proc. time: {time.time() - start_time_each:.1f}')

        if self.conflict_count is not None:
            start_time_each = time.time()
            conf_count = self.conflict_count.measure(protocol)
            output['conf_count'] = conf_count
            print(f'CnfCnt. {self.name} proc. time: {time.time() - start_time_each:.1f}')

        if self.context_independence is not None:
            start_time_each = time.time()
            context_independence = self.context_independence.measure(protocol)
            output['context_independence'] = context_independence
            print(f'CtxIdp. {self.name} proc. time: {time.time() - start_time_each:.1f}')

        if self.tre is not None:
            start_time_each = time.time()
            tre = self.tre.measure(protocol)
            output['tre'] = tre
            print(f'TRE. {self.name} proc. time: {time.time() - start_time_each:.1f}')

        if self.abstractness_dataset is not None:
            start_time_each = time.time()

            # Switch both agents to eval mode once for the whole pass and
            # restore their previous modes afterwards.
            prev_state_sender = game.sender.training
            prev_state_receiver = self.receiver.training
            game.sender.eval()
            self.receiver.eval()

            abstractness_list = []
            with torch.no_grad():
                for item in tqdm(self.abstractness_dataset, desc='Abstractness...'):
                    # (T, 2, H)
                    rec_input = item[2].to(self.device)
                    # (T, H), (T, H)
                    targets, dists = rec_input[:, 0, :], rec_input[:, 1, :]
                    messages = game.sender(targets)[0]
                    for msg, target, dist in zip(messages, targets, dists):
                        # (1, H), (1, 1, H)
                        msg, target, dist = msg[None, :], target[None, :], dist[None, None, :]
                        label = torch.zeros(1, dtype=int).to(self.device)
                        aux = {'sender_input': target}
                        # (1, H), (1, H) -> (1, 2)
                        logits, _ = self.receiver(msg, label.unsqueeze(-1), dist, aux)
                        probs = torch.softmax(logits, dim=-1)
                        dist_prob = probs[0, -1]
                        # Scaled by 2, presumably so chance level (0.5) maps to 1 - TODO confirm.
                        abstractness_list.append(2 * dist_prob)

            game.sender.train(prev_state_sender)
            self.receiver.train(prev_state_receiver)

            if abstractness_list:
                abstractness = torch.stack(abstractness_list).mean().item()
            else:
                abstractness = float('nan')

            output['abstractness'] = abstractness
            print(f'Abstns. {self.name} proc. time: {time.time() - start_time_each:.1f}')

        output_json = json.dumps(output)
        print(output_json, flush=True)

        name = f'{self.name}_' if len(self.name) > 0 else ''
        for key in list(output.keys()):
            output[f'{prefix}{name}{key}'] = output.pop(key)
        output[f'{prefix}epoch'] = self.epoch
        wandb.log(output)

        game.train(prev_state)
        print(f'Compo. {self.name} proc. time: {time.time() - start_time:.1f}')

    def on_train_end(self):
        """Dump final stats and optionally log the full protocol table."""
        self.dump_stats('final/', topsim_max_samples=self.final_topsim_max_samples)

        if self.save_protocol:
            self._save_protocol()

    def _save_protocol(self):
        """Log a wandb table of (space-joined message, space-joined attributes) rows."""
        start_time_each = time.time()

        cols = ['message', 'attribute_values']
        game = self.trainer.game
        prev_state = game.training
        game.eval()

        attributes, strings, _meanings = ask_sender(
            self.n_attributes,
            self.n_values,
            self.dataset,
            game.sender,
            self.device,
            batch_size=self.batch_size,
        )

        pairs = []
        for string, attribute in zip(strings, attributes):
            string = ' '.join([str(s) for s in string.tolist()])
            attribute = ' '.join([str(att.item()) for att in attribute])
            pairs.append([string, attribute])

        table = wandb.Table(columns=cols, data=pairs)
        wandb.log({f'protocol_{self.name}': table})

        print(f'SavMsg. {self.name} proc. time: {time.time() - start_time_each:.1f}')
        game.train(prev_state)

    def on_epoch_end(self, *stuff):
        """Advance the epoch counter and dump stats every ``freq`` epochs (``freq > 0``)."""
        self.epoch += 1

        if self.freq <= 0 or self.epoch % self.freq != 0:
            return

        self.dump_stats(topsim_max_samples=self.topsim_max_samples)

    def on_epoch_begin(self, n_epoch):
        """Dump stats of the untrained agents once, before the first epoch."""
        if self.freq != 0 and self.epoch == 0:
            self.dump_stats(topsim_max_samples=self.topsim_max_samples)


class Evaluator(core.Callback):
    """Evaluate the game on extra loaders and log averaged ``aux`` metrics.

    ``loaders_metrics`` is iterated as ``(loader_name, loader, metric)``
    triples. For each triple, ``metric`` temporarily replaces ``game.loss``,
    every batch of ``loader`` is run through the game, and each
    ``interaction.aux`` entry is averaged over the whole loader and stored
    under ``f'{loader_name}_{key}'``. Evaluation runs before the first epoch
    (``freq != 0``), after every ``freq``-th epoch (``freq > 0``) and at the
    end of training under the ``'final/'`` prefix.
    """

    def __init__(
        self,
        loaders_metrics,
        device,
        freq=1,
        global_prefix: str = '',
    ):
        self.loaders_metrics = loaders_metrics
        self.device = device
        self.epoch = 0
        self.freq = freq
        self.global_prefix = global_prefix
        self.final_prefix = 'final/'
        # Latest averaged value per metric key (plus 'epoch'); persists across calls.
        self.results = defaultdict(list)

    def on_train_begin(self, trainer_instance: "Trainer"):
        super().on_train_begin(trainer_instance)

        for loader_name, _, _ in self.loaders_metrics:
            # The step metric must be declared under exactly the name that is
            # later passed as `step_metric`.
            step_metric = f'{self.global_prefix}{loader_name}_epoch'
            wandb.define_metric(step_metric)
            wandb.define_metric(
                f'{self.global_prefix}{loader_name}*',
                step_metric,
            )
            wandb.define_metric(
                f'{self.global_prefix}{self.final_prefix}{loader_name}*',
                step_metric,
            )

    @torch.no_grad()
    def evaluate(self, prefix: str = ''):
        """Run every configured loader through the game and log the averaged metrics.

        :param prefix: inserted between ``global_prefix`` and each key in wandb.
        """
        game = self.trainer.game
        prev_state = game.training
        old_loss = game.loss
        game.eval()

        try:
            for loader_name, loader, metric in self.loaders_metrics:
                game.loss = metric
                results = defaultdict(list)

                for batch in tqdm(loader, desc=f'Evaluating {loader_name}'):
                    if not isinstance(batch, Batch):
                        batch = Batch(*batch)
                    batch = batch.to(self.device)
                    _, interaction = game(*batch)

                    for key, value in interaction.aux.items():
                        # Flatten so 0-d (scalar) entries can be concatenated too.
                        results[f'{loader_name}_{key}'].append(value.reshape(-1))

                for key, values in results.items():
                    stacked = torch.cat(values)
                    # `mean` is undefined for integer tensors.
                    if not stacked.is_floating_point():
                        stacked = stacked.float()
                    self.results[key] = stacked.mean().item()
        finally:
            # Always hand the game back in its original state.
            game.loss = old_loss
            game.train(prev_state)

        self.results['epoch'] = self.epoch
        print(json.dumps(self.results), flush=True)

        wandb.log({
            f'{self.global_prefix}{prefix}{key}': value
            for key, value in self.results.items()
        })

    def on_train_end(self):
        self.evaluate(self.final_prefix)

    def on_epoch_end(self, *stuff):
        self.epoch += 1

        if self.freq <= 0 or self.epoch % self.freq != 0:
            return
        self.evaluate()

    def on_epoch_begin(self, n_epoch):
        if self.freq != 0 and self.epoch == 0:
            self.evaluate()


class EpochLogEvaluator(core.Callback):
    """Log per-epoch train/validation ``aux`` means and evaluate on extra loaders.

    Train and validation interactions received in ``on_epoch_end`` /
    ``on_validation_end`` are averaged per ``aux`` key and logged under
    ``f'{global_prefix}train_<key>'`` / ``f'{global_prefix}val_<key>'``.
    The extra ``(loader_name, loader, metric)`` loaders in ``loaders_metrics``
    are evaluated once before the first epoch and once at the end of training
    (under the ``'final/'`` prefix). ``freq`` is stored but not used.
    """

    def __init__(
        self,
        loaders_metrics,
        device,
        freq=1,
        global_prefix: str = '',
    ):
        self.loaders_metrics = loaders_metrics
        self.device = device
        self.epoch = 0
        self.freq = freq
        self.global_prefix = global_prefix
        self.final_prefix = 'final/'
        # Latest averaged value per loader metric key (plus 'epoch').
        self.results = defaultdict(list)

    def on_train_begin(self, trainer_instance: "Trainer"):
        super().on_train_begin(trainer_instance)

        for loader_name, _, _ in self.loaders_metrics:
            # The step metric must be declared under exactly the name that is
            # later passed as `step_metric`.
            step_metric = f'{self.global_prefix}{loader_name}_epoch'
            wandb.define_metric(step_metric)
            wandb.define_metric(
                f'{self.global_prefix}{loader_name}*',
                step_metric,
            )
            wandb.define_metric(
                f'{self.global_prefix}{self.final_prefix}{loader_name}*',
                step_metric,
            )

    def evaluate(self, prefix: str = ''):
        """Run every configured loader through the game and log the averaged metrics.

        :param prefix: inserted between ``global_prefix`` and each key in wandb.
        """
        game = self.trainer.game
        prev_state = game.training
        old_loss = game.loss
        game.eval()

        try:
            for loader_name, loader, metric in self.loaders_metrics:
                game.loss = metric
                results = defaultdict(list)

                for batch in tqdm(loader, desc=f'Evaluating {loader_name}'):
                    if not isinstance(batch, Batch):
                        batch = Batch(*batch)
                    batch = batch.to(self.device)
                    with torch.no_grad():
                        _, interaction = game(*batch)

                    for key, value in interaction.aux.items():
                        # Flatten so 0-d (scalar) entries can be concatenated too.
                        results[f'{loader_name}_{key}'].append(value.reshape(-1))

                for key, values in results.items():
                    stacked = torch.cat(values)
                    # `mean` is undefined for integer tensors.
                    if not stacked.is_floating_point():
                        stacked = stacked.float()
                    self.results[key] = stacked.mean().item()
        finally:
            # Always hand the game back in its original state.
            game.loss = old_loss
            game.train(prev_state)

        self.results['epoch'] = self.epoch
        print(json.dumps(self.results), flush=True)

        wandb.log({
            f'{self.global_prefix}{prefix}{key}': value
            for key, value in self.results.items()
        })

    def on_train_end(self):
        self.evaluate(self.final_prefix)

    def on_epoch_begin(self, n_epoch):
        if self.epoch == 0:
            self.evaluate()

    def on_epoch_end(self, loss: float, logs: Interaction, epoch: int):
        self.epoch += 1
        self._log_epoch_results(
            loss,
            logs,
            epoch,
            'train',
        )

    def on_validation_end(self, loss: float, logs: Interaction, epoch: int):
        self._log_epoch_results(
            loss,
            logs,
            epoch,
            'val',
        )

    @torch.no_grad()
    def _log_epoch_results(self, loss: float, logs: Interaction, epoch: int, name: str = ''):
        """Log the mean of every ``logs.aux`` entry plus ``loss`` under ``name``.

        The logged epoch is this callback's own counter (``self.epoch``); the
        ``epoch`` argument is accepted for the callback signature only.
        """
        results = {}
        for key, value in logs.aux.items():
            # `mean` is undefined for integer tensors.
            if not value.is_floating_point():
                value = value.float()
            results[f'{self.global_prefix}{name}_{key}'] = value.mean().item()

        results[f'{self.global_prefix}epoch'] = self.epoch
        results[f'{self.global_prefix}{name}_loss'] = loss

        print(json.dumps(results), flush=True)

        wandb.log(results)


class BestCheckpoint(core.early_stopping.EarlyStopper):
    """Early stopper that halts training as soon as ``field_name`` stops improving.

    After each epoch the mean of ``interaction.aux[field_name]`` from the
    latest stats is recorded. Training stops when the newest value is not
    strictly greater than the previous one.
    """

    def __init__(
        self, field_name: str = "acc", validation: bool = True
    ) -> None:
        """
        :param field_name: the name of the metric returned by the loss function (in
            ``interaction.aux``) that is monitored (default: "acc")
        :param validation: whether the statistics on the validation (or training, if False)
            data should be checked
        """
        super(BestCheckpoint, self).__init__(validation)
        self.field_name = field_name
        # Not updated by this class; kept for backward compatibility.
        self.best_metric = float('-inf')
        # History of per-epoch metric means, one entry per recorded epoch.
        self.metrics = []

    def should_stop(self) -> bool:
        if self.validation:
            assert (
                self.validation_stats
            ), "Validation data must be provided for early stooping to work"
            stats = self.validation_stats
        else:
            assert (
                self.train_stats
            ), "Training data must be provided for early stooping to work"
            stats = self.train_stats

        _loss, last_epoch_interactions = stats[-1]
        metric_mean = last_epoch_interactions.aux[self.field_name].mean()

        # Record each epoch once, even if should_stop() is called repeatedly.
        if len(stats) > len(self.metrics):
            self.metrics.append(metric_mean)
        if len(self.metrics) < 2:
            return False

        print(self.metrics[-2], self.metrics[-1])
        # Convert the 0-d tensor comparison into a real bool, as annotated.
        return bool(self.metrics[-2] >= self.metrics[-1])


class ValBestCheckpoint(Callback):
    """Keep the game parameters with the best validation metric and restore them at the end."""

    def __init__(
        self,
        checkpoint_path: Union[str, pathlib.Path],
        checkpoint_freq: int = 1,
        metric_name: str = 'acc',
        prefix: str = "",
        higher_is_better: bool = True,
    ):
        """Save the game's ``state_dict`` whenever the validation metric improves.

        :param checkpoint_path: checkpoint directory, created if not present.
        :param checkpoint_freq: the metric is checked at epoch 1 and at every
            epoch divisible by this number.
        :param metric_name: key in ``logs.aux`` whose mean is monitored.
        :param prefix: if non-empty the file is ``{prefix}_val_best.pt`` (overwritten
            on each improvement); otherwise a new ``val_{epoch}.pt`` is written.
        :param higher_is_better: whether larger metric values are better.
        """
        super().__init__()
        self.prefix = prefix
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_path = pathlib.Path(checkpoint_path)
        self.best_ckpt_path = None
        self.metric_name = metric_name
        self.best_epoch = None
        self.higher_is_better = higher_is_better
        self.best_metric = float('-inf') if self.higher_is_better else float('inf')

    def on_validation_end(self, loss: float, logs: Interaction, epoch: int):
        if not (epoch == 1 or epoch % self.checkpoint_freq == 0):
            return

        metric_mean = logs.aux[self.metric_name].mean().item()
        if self.higher_is_better:
            improved = metric_mean > self.best_metric
        else:
            improved = metric_mean < self.best_metric
        if not improved:
            return

        print(f'Saving checkpoint for epoch {epoch}...')
        self.best_metric = metric_mean
        filename = f"{self.prefix}_val_best" if self.prefix else f'val_{epoch}'

        self.checkpoint_path.mkdir(exist_ok=True, parents=True)
        path = self.checkpoint_path / f'{filename}.pt'
        torch.save(self.trainer.game.state_dict(), path)
        self.best_ckpt_path = path
        self.best_epoch = epoch

    def on_train_end(self):
        # Nothing was ever saved (no validation run, or the metric never
        # improved, e.g. NaN): keep the final parameters instead of crashing
        # on torch.load(None).
        if self.best_ckpt_path is None:
            print('No best checkpoint was saved; keeping the final parameters.')
            return
        print(f'Loading best checkpoint (epoch: {self.best_epoch}).')
        state_dict = torch.load(self.best_ckpt_path)
        self.trainer.game.load_state_dict(state_dict)


class RandomResetter(core.Callback):
    """Re-initialise agents at random moments during training.

    At every reset opportunity, the sender is reset with probability
    ``1 / sender_period``. Each receiver in ``game.receiver.receivers`` is reset
    independently with probability ``1 / receiver_period``. A reset opportunity
    is every epoch start when ``on_epoch`` is true, and otherwise every training
    batch. A period ``<= 0`` disables resets for that agent type.
    """

    def __init__(
        self,
        receiver_period: int,
        sender_period: int,
        on_epoch: bool,
    ):
        self.sender_period = sender_period
        self.receiver_period = receiver_period
        self.on_epoch = on_epoch
        self.step = 0

    def _maybe_reset_agents(self, when: str):
        """Flip one coin for the sender, then one per receiver in order, resetting winners."""
        if self.sender_period > 0:
            sender_coin = torch.bernoulli(torch.tensor(1 / self.sender_period)).bool()
            if sender_coin:
                self.trainer.game.sender.reset_parameters()
                print(f"Resetting sender's paramerters. {when}.")

        if self.receiver_period > 0:
            for idx, agent in enumerate(self.trainer.game.receiver.receivers):
                receiver_coin = torch.bernoulli(torch.tensor(1 / self.receiver_period)).bool()
                if receiver_coin:
                    agent.reset_parameters()
                    print(f"Resetting receiver_{idx}'s paramerters. {when}.")

    def on_epoch_begin(self, epoch: int):
        if self.on_epoch:
            self._maybe_reset_agents(f'epoch: {epoch}')

    def on_batch_end(self, logs: Interaction, loss: float, batch_id: int, is_training: bool = True):
        if is_training and not self.on_epoch:
            self.step += 1
            self._maybe_reset_agents(f'step: {self.step}')


class UniformResetter(core.Callback):
    """Reset agents on a fixed schedule, cycling through receivers one at a time.

    The sender is reset every ``sender_period`` epochs (or training steps).
    Receivers are reset one at a time in round-robin order. Consecutive
    receiver resets are spaced about ``receiver_period / n_receivers`` epochs
    (or steps) apart, so each receiver is reset roughly once per
    ``receiver_period``. A non-positive period disables resets for that role.

    Args:
        receiver_period: approximate per-receiver reset period.
        sender_period: sender reset period.
        on_epoch: if True, resets are scheduled in ``on_epoch_begin`` by epoch
            number. Otherwise they are scheduled in ``on_batch_end`` by
            training-step count.
    """

    def __init__(
        self,
        receiver_period: int,
        sender_period: int,
        on_epoch: bool,
    ):
        self.sender_period = sender_period
        self.receiver_period = receiver_period
        self.on_epoch = on_epoch
        self.step = 0
        # Round-robin cursor: index of the receiver to reset next.
        self.next_idx = 0

    def _reset_next_receiver(self, when: str):
        """Reset the receiver at ``next_idx``, log it, and advance the cursor."""
        receivers = self.trainer.game.receiver.receivers
        receivers[self.next_idx].reset_parameters()
        print(f'Resetting receiver_{self.next_idx}\'s paramerters. {when}.')
        self.next_idx = (self.next_idx + 1) % len(receivers)

    def on_epoch_begin(self, epoch: int):
        if not self.on_epoch:
            return

        # Epoch 1 is never a reset point (presumably the first epoch, before any
        # training). After that, resets fall on epochs 1 + k * period.
        if self.sender_period > 0:
            if (epoch - 1) % self.sender_period == 0 and epoch != 1:
                self.trainer.game.sender.reset_parameters()
                print(f'Resetting sender\'s paramerters. epoch: {epoch}.')

        if self.receiver_period > 0:
            n_receivers = len(self.trainer.game.receiver.receivers)
            if n_receivers == 0:
                return
            # ceil keeps the spacing >= 1 epoch for any positive period.
            interval = int(np.ceil(self.receiver_period / n_receivers))
            if (epoch - 1) % interval == 0 and epoch != 1:
                self._reset_next_receiver(f'epoch: {epoch}')

    def on_batch_end(self, logs: Interaction, loss: float, batch_id: int, is_training: bool = True):
        if not is_training or self.on_epoch:
            return
        self.step += 1

        if self.sender_period > 0:
            if self.step % self.sender_period == 0:
                self.trainer.game.sender.reset_parameters()
                print(f'Resetting sender\'s paramerters. step: {self.step}.')

        if self.receiver_period > 0:
            n_receivers = len(self.trainer.game.receiver.receivers)
            if n_receivers == 0:
                return
            # Floor division is 0 when receiver_period < n_receivers, which would
            # make the modulo below raise ZeroDivisionError. Clamp the spacing to
            # at least one step; larger periods keep their previous schedule.
            interval = max(1, self.receiver_period // n_receivers)
            if self.step % interval == 0:
                self._reset_next_receiver(f'step: {self.step}')
    

class SimultaneousResetter(core.Callback):
    """Periodically re-initialise the sender and *all* receivers at once.

    In epoch mode (``on_epoch`` True), a role is reset at the start of epochs
    ``1 + k * period`` for ``k >= 1``; epoch 1 itself is never a reset point.
    In step mode, a role is reset after every ``period``-th training batch.
    A non-positive period disables resets for that role.
    """

    def __init__(
        self,
        receiver_period: int,
        sender_period: int,
        on_epoch: bool,
    ):
        self.sender_period = sender_period
        self.receiver_period = receiver_period
        self.on_epoch = on_epoch
        self.step = 0

    def _reset_sender(self, when: str):
        """Re-initialise the sender and log it with the ``when`` suffix."""
        self.trainer.game.sender.reset_parameters()
        print(f'Resetting sender\'s paramerters. {when}.')

    @staticmethod
    def _reset_receivers(receivers, when: str):
        """Re-initialise every receiver in ``receivers``, logging each one."""
        for idx, receiver in enumerate(receivers):
            receiver.reset_parameters()
            print(f'Resetting receiver_{idx}\'s paramerters. {when}.')

    def on_epoch_begin(self, epoch: int):
        # Epoch 1 never triggers a reset, whatever the periods are.
        if not self.on_epoch or epoch == 1:
            return

        elapsed = epoch - 1
        when = f'epoch: {epoch}'
        if self.sender_period > 0 and elapsed % self.sender_period == 0:
            self._reset_sender(when)
        if self.receiver_period > 0 and elapsed % self.receiver_period == 0:
            self._reset_receivers(self.trainer.game.receiver.receivers, when)

    def on_batch_end(self, logs: Interaction, loss: float, batch_id: int, is_training: bool = True):
        if is_training and not self.on_epoch:
            self.step += 1
            when = f'step: {self.step}'

            if self.sender_period > 0 and self.step % self.sender_period == 0:
                self._reset_sender(when)

            if self.receiver_period > 0:
                receivers = self.trainer.game.receiver.receivers
                if self.step % self.receiver_period == 0:
                    self._reset_receivers(receivers, when)


class DetailedWandbLogger(Callback):
    """Log per-batch training metrics and per-epoch losses to Weights & Biases.

    Every training batch logs:
    - the mean of each entry in ``logs.aux``, under ``batch_<key>``;
    - the batch loss, under ``batch_loss``;
    - a running step counter, under ``step``.

    Epoch ends and validation ends log the corresponding loss together with
    the epoch number. All keys are prefixed with ``global_prefix``, and only
    the leader process logs.

    Args:
        opts: run configuration passed to ``wandb.config.update``.
        project: wandb project name, forwarded to ``wandb.init``.
        run_id: wandb run id, forwarded to ``wandb.init`` as ``id``.
        global_prefix: string prepended to every logged key.
        **kwargs: extra keyword arguments forwarded to ``wandb.init``.
    """

    def __init__(
        self,
        opts: Union[argparse.ArgumentParser, Dict, str, None] = None,
        project: Optional[str] = None,
        run_id: Optional[str] = None,
        global_prefix: str = '',
        **kwargs,
    ):
        # Only interactions stored in the leader process are logged. When
        # interactions are not aggregated in a multi-GPU run, each process keeps
        # its own Interaction object in logs; handling that case is left to
        # subclasses, since the data to log is not known a priori.
        self.opts = opts
        self.global_prefix = global_prefix
        self.step = 0

        wandb.init(project=project, id=run_id, **kwargs)
        wandb.config.update(opts)

    @staticmethod
    def log_to_wandb(metrics: Dict[str, Any], commit: bool = False, **kwargs):
        """Thin wrapper around ``wandb.log``."""
        wandb.log(metrics, commit=commit, **kwargs)

    def on_train_begin(self, trainer_instance: "Trainer"):  # noqa: F821
        self.trainer = trainer_instance
        self.step = 0
        # wandb.watch(self.trainer.game, log="all")

    def on_batch_end(
        self, logs: Interaction, loss: float, batch_id: int, is_training: bool = True
    ):
        if not (is_training and self.trainer.distributed_context.is_leader):
            return

        batch_log = {}
        for key, value in logs.aux.items():
            # torch.mean raises on integer/bool tensors, so promote them to
            # float first. Floating and complex tensors are averaged unchanged.
            if torch.is_tensor(value) and not (
                value.is_floating_point() or value.is_complex()
            ):
                value = value.float()
            batch_log[f'{self.global_prefix}batch_{key}'] = value.mean()

        batch_log.update({
            f'{self.global_prefix}batch_loss': loss,
            f'{self.global_prefix}step': self.step,
        })

        self.log_to_wandb(batch_log, commit=True)
        self.step += 1

    def on_epoch_end(self, loss: float, logs: Interaction, epoch: int):
        if self.trainer.distributed_context.is_leader:
            self.log_to_wandb(
                {
                    f"{self.global_prefix}train_loss": loss,
                    f"{self.global_prefix}epoch": epoch,
                },
                commit=True,
            )

    def on_validation_end(self, loss: float, logs: Interaction, epoch: int):
        if self.trainer.distributed_context.is_leader:
            self.log_to_wandb(
                {
                    f"{self.global_prefix}validation_loss": loss,
                    f"{self.global_prefix}epoch": epoch,
                },
                commit=True,
            )

    

class MessageSaver(core.Callback):
    """Query the trained sender on ``dataset`` once training has finished.

    At the end of training, ``ask_sender`` is run with the game's sender and
    the trainer's device. Whatever it returns is stored in ``self.messages``.

    Args:
        dataset: dataset whose items are fed to the sender.
        name: identifier for this saver. NOTE(review): not used yet. It is
            presumably meant as an output name/path for persisting the
            messages; TODO confirm and implement.
        n_attributes: forwarded to ``ask_sender``.
        n_values: forwarded to ``ask_sender``. Optional only for backward
            compatibility; callers should pass the real value.
        batch_size: forwarded to ``ask_sender`` (its default is 1).
    """

    def __init__(self, dataset, name, n_attributes, n_values=None, batch_size=1):
        self.dataset = dataset
        self.name = name
        self.n_attributes = n_attributes
        self.n_values = n_values
        self.batch_size = batch_size
        # Filled in by on_train_end with the return value of ask_sender.
        self.messages = None

    def on_train_end(self):
        # ask_sender needs all of its positional arguments. Calling it with none
        # of them, or calling the sender with no input, fails at the end of
        # training.
        self.messages = ask_sender(
            self.n_attributes,
            self.n_values,
            self.dataset,
            self.trainer.game.sender,
            self.trainer.device,
            batch_size=self.batch_size,
        )