import pickle
import string
import random
import pickle
import numpy as np
import torch
import fnmatch
import os
import json
import re
import pandas as pd
import yaml

from tqdm import tqdm
from datetime import datetime
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
from collections import defaultdict
from torchvision import datasets, transforms
from copy import deepcopy


def get_time_stamp():
    """Return the current local time formatted as ``MMDDYY-HHMMSS``."""
    return datetime.now().strftime('%m%d%y-%H%M%S')


def print_cm(cm, labels, hide_zeroes=False, hide_diagonal=False, hide_threshold=None):
    """Pretty-print a confusion matrix to stdout.

    Adapted from https://gist.github.com/zachguo/10296432 (@rola93).

    cm: 2-D array indexed as ``cm[i, j]``.
    labels: row/column names, one per class.
    hide_zeroes: blank out cells whose value equals zero.
    hide_diagonal: blank out the ``i == j`` cells.
    hide_threshold: when truthy, blank out cells that are not strictly
        greater than this value.
    """
    # Every column is as wide as the longest label, but never narrower
    # than 5 characters (the width of a formatted value).
    width = max([len(name) for name in labels] + [5])
    blank = " " * width

    # Top-left corner holds a centred "t/p" marker, left-padded when the
    # integer division leaves it one character short.
    side = (width - 3) // 2 * " "
    corner = side + "t/p" + side
    if len(corner) < len(blank):
        corner = " " * (len(blank) - len(corner)) + corner

    label_fmt = "%{0}s".format(width)
    value_fmt = "%{0}.1f".format(width)

    # Header row
    header = ["    " + corner] + [label_fmt % name for name in labels]
    print(" ".join(header), end=" ")
    print()

    # One row per true label
    for row, row_name in enumerate(labels):
        pieces = ["    " + label_fmt % row_name]
        for col in range(len(labels)):
            value = cm[row, col]
            text = value_fmt % value
            if hide_zeroes and not float(value) != 0:
                text = blank
            if hide_diagonal and row == col:
                text = blank
            if hide_threshold and not value > hide_threshold:
                text = blank
            pieces.append(text)
        print(" ".join(pieces), end=" ")
        print()


def get_latest_checkpoint(folder, prefix):
    """Return the path of the lexicographically last entry in ``folder`` whose name matches ``prefix*``.

    Raises IndexError when nothing in ``folder`` matches.
    """
    pattern = f"{prefix}*"
    candidates = [name for name in os.listdir(folder) if fnmatch.fnmatch(name, pattern)]
    candidates.sort()
    return os.path.join(folder, candidates[-1])


def any_matching_prefix(folder, prefix):
    """Return True when at least one entry of ``folder`` matches the glob ``prefix*``."""
    return bool(fnmatch.filter(os.listdir(folder), f"{prefix}*"))


# Module-level switch read by debug(); flip to True to enable debug printing.
debugging = False


def debug(*args):
    """Print ``args`` only while the module-level ``debugging`` flag is truthy."""
    if not debugging:
        return
    print(*args)
        
    
def process_exchange(message, recvr_output):
    """Cut each message at its first EOS symbol and pick the matching receiver output.

    The symbol at every position is the argmax over the last axis of
    ``message``; symbol 0 is treated as end-of-sequence.

    Returns a tuple ``(trunc_messages, outputs)``: a list with one symbol
    tensor per row (up to and including the first EOS, or the whole row when
    there is none) and the stacked receiver outputs taken at the EOS step
    (or at the last step when there is no EOS).
    """
    symbols = message.argmax(dim=-1)
    truncated = []
    selected_outputs = []

    for row in range(symbols.size(0)):
        row_symbols = symbols[row, :]
        eos_idx = (row_symbols == 0).nonzero()

        if eos_idx.size(0) > 0:
            stop = eos_idx[0].item()
            assert symbols[row, stop] == 0
            truncated.append(symbols[row, :stop + 1])
        else:
            # no EOS emitted: keep everything, read the receiver's last step
            stop = -1
            truncated.append(row_symbols)

        selected_outputs.append(recvr_output[row, stop, ...])

    return truncated, torch.stack(selected_outputs)

        
def check_model(sender_wrapper, game,
                sender_interpret, receiver_interpret, 
                x_loader, device, return_metrics=False):
    """Run one evaluation pass of the signalling game over ``x_loader``.

    Puts ``game.sender``, ``game.receiver`` and ``sender_wrapper.model_multi``
    into eval mode, plays the game on every batch without gradients, and
    scores the receiver's argmax prediction against ``recv_targets``. When
    ``sender_interpret`` yields concept logits, their argmax is also scored
    against ``sender_labels``.

    sender_interpret / receiver_interpret: callables returning
        ``(concept_logits, (feats, structs), min_dists)`` for a batch of objects.
    x_loader: iterable of ``(sender_tuple, recv_targets, recv_tuple,
        sender_labels)`` exposing ``n_batches_per_epoch`` and ``reset()``.
    return_metrics: when True, return a dict with ``'accuracy'`` and, if any
        concept predictions were collected, ``'concept_accuracy'``;
        otherwise the function returns None.
    """
    game.sender.eval()
    game.receiver.eval()
    sender_wrapper.model_multi.eval()
    
    concept_actuals = []
    concept_preds = []
    
    actuals = []
    preds = []
    
    # only used to size the progress bar
    nb = x_loader.n_batches_per_epoch
            
    with tqdm(total=nb) as pb:
        with torch.no_grad():
            for i, (sender_tuple, recv_targets, recv_tuple, sender_labels) in enumerate(x_loader):
                sender_objs, _ = sender_tuple
                sender_objs = sender_objs.to(device)

                concept_logits, (sender_feats, sender_structs), min_dists = sender_interpret(sender_objs)

                if sender_structs is None:
                    # using cached passthrough version: the structure comes
                    # from the second element of the loader's sender tuple
                    _, sender_structs = sender_tuple
                    sender_structs = sender_structs.to(device)
                
                recv_targets = recv_targets.to(device)
                recv_objs, _ = recv_tuple
                recv_objs = recv_objs.to(device)

                sender_labels = sender_labels.to(device)

                # NOTE: already inside the outer no_grad(); this inner context is redundant
                with torch.no_grad():  # only message signal is allowed
                    _, (recv_feats, recv_structs), _ = receiver_interpret(recv_objs) 

                if recv_structs is None:
                    # using cached passthrough version
                    _, recv_structs = recv_tuple
                    recv_structs = recv_structs.to(device)

                sender_repr = (sender_feats, sender_structs)
                recv_repr = (recv_feats, recv_structs)
                _, sender_message, receiver_output, _ = game(sender_repr, 
                                                             recv_targets, 
                                                             receiver_input=recv_repr)

                # receiver output at the EOS step (or last step) of each message
                trunc_messages, last_recv = process_exchange(sender_message, receiver_output)
                preds_i = last_recv.argmax(dim=1) 
                preds_i = np.asarray(preds_i.cpu())

                preds.extend(list(preds_i))
                actuals.extend(list(recv_targets.detach().cpu().numpy()))
                
                if concept_logits is not None:
                    # concept prediction = argmax over dim 1 of the logits
                    _, cpreds = torch.max(concept_logits, dim=1)
                    cactuals = sender_labels
                    
                    concept_preds.extend(list(cpreds.detach().cpu().numpy()))
                    concept_actuals.extend(list(cactuals.detach().cpu().numpy()))

                pb.update(1)
                # break
                
            # rewind the loader once the pass is complete
            x_loader.reset()
    
    metrics = {}
    
    acc = accuracy_score(actuals, preds)
    metrics['accuracy'] = acc
    
    if len(concept_actuals) > 0:
        concept_acc = accuracy_score(concept_actuals, concept_preds)
        metrics['concept_accuracy'] = concept_acc
        
    # log(f"{'=' * 20}\tAccuracy:\t{acc * 100.:.1f}%")
    
    # Disabled: precision / recall / confusion matrix for the
    # single-distractor setting.
#     if x_loader.num_distractors == 1:
#         prec = precision_score(actuals, preds)
#         rec = recall_score(actuals, preds)
#         cm = confusion_matrix(actuals, preds)

#         print_cm(cm, [f'{i}' for i in range(num_classes)])
#         # log(f"{'=' * 20}\tPrecision:\t{prec}")
#         # log(f"{'=' * 20}\tRecall:\t{rec}")

#         metrics['precision'] = prec
#         metrics['recall'] = rec
#         metrics['cm'] = cm
        
    if return_metrics:
        return metrics

    

def safe_load(o_path):
    """Load ``o_path`` with ``np.load(allow_pickle=True)``.

    Prints a warning and returns None when the path does not exist.
    """
    import numpy as np

    if os.path.exists(o_path):
        return np.load(o_path, allow_pickle=True)

    print(f"Warn: {o_path} was not found!")
    return None


def report(fname, s, log=print):
    """Send ``s`` to ``log`` and append it, followed by a newline, to the file ``fname``."""
    log(s)
    line = "{}\n".format(s)
    with open(fname, 'a') as out:
        out.write(line)

        
def list_directories(directory: str):
    """Return the sub-directory names of ``directory`` in a random order.

    The sorted listing is shuffled with ``np.random.permutation`` (so the
    order follows NumPy's global RNG state) before non-directories are
    dropped.
    """
    shuffled = np.random.permutation(sorted(os.listdir(directory)))
    return list(filter(lambda name: os.path.isdir(os.path.join(directory, name)), shuffled))


def load_json(config_filepath):
    """Parse the JSON file at ``config_filepath`` and return the decoded object."""
    with open(config_filepath) as handle:
        return json.load(handle)


def save_json(state, f_path, dry_run=False):
    """Serialize ``state`` as JSON to ``f_path``, overwriting any existing file.

    state: JSON-serializable object.
    f_path: destination path.
    dry_run: when True nothing is written (previously this flag was accepted
        but ignored, so the file was always created/truncated).
    """
    if dry_run:
        return

    with open(f_path, 'w') as config_file:
        json.dump(state, config_file)
        
        
def load_yaml(filepath):
    """Parse the YAML file at ``filepath`` with ``yaml.safe_load`` and return the result."""
    with open(filepath, 'r') as handle:
        return yaml.safe_load(handle)

        
def save_yaml(state, f_path):
    """Dump ``state`` as YAML into ``f_path``, overwriting any existing file."""
    with open(f_path, 'w') as handle:
        yaml.dump(state, handle)

        
def pickle_write(fpath, obj):
    """Pickle ``obj`` into the file ``fpath`` (binary mode, overwriting)."""
    with open(fpath, 'wb') as sink:
        pickle.dump(obj, sink)


def pickle_load(fpath):
    """Unpickle and return the object stored in ``fpath``.

    NOTE: unpickling executes code embedded in the file; only use this on
    files from a trusted source.
    """
    with open(fpath, 'rb') as source:
        return pickle.load(source)


def get_semiotic_epoch_pairs(save_dir, by_epoch=np.inf) -> dict:
    """List the push / nopush checkpoints found in ``save_dir``.

    A checkpoint is any entry whose name starts with ``<epoch>push`` or
    ``<epoch>nopush`` (``<epoch>`` being decimal digits).

    save_dir: directory to scan.
    by_epoch: latest sender/receiver epoch to consider; checkpoints from
        later epochs are ignored.

    Returns a dict with:
        'push' / 'nopush': names of the form ``<epoch>push<ext>`` /
            ``<epoch>nopush<ext>``, ordered by ascending epoch.
        'max_push' / 'max_nopush': the largest epoch found for each kind,
            or 0 when there is none.

    ``<ext>`` is ``.pth`` when the file on disk ends with ``.pth`` and empty
    otherwise (old versions of the code saved checkpoints without the
    extension). The previous implementation tried to detect this with
    ``re.compile('.pth').match(file)``, which only tests the start of the
    name, and assigned ``".pth"`` in every branch, so extension-less
    checkpoints were reported under names that do not exist.
    """
    # "nopush" must be tried first: it is the longer alternative.
    checkpoint_pattern = re.compile(r'([0-9]+)(nopush|push)')

    found = {'push': [], 'nopush': []}
    for file in os.listdir(save_dir):
        hit = checkpoint_pattern.match(file)
        if hit is None:
            continue

        epoch = int(hit.group(1))
        if epoch > by_epoch:
            continue

        ext = ".pth" if file.endswith(".pth") else ""
        found[hit.group(2)].append((epoch, ext))

    result = {}
    for project in ('push', 'nopush'):
        entries = sorted(found[project])
        result[project] = [f"{epoch}{project}{ext}" for epoch, ext in entries]
        result[f'max_{project}'] = max((epoch for epoch, _ in entries), default=0)

    return result


def get_last_semiotic_model_file_pair(save_dir: str, by_epoch=np.inf) -> dict:
    """Return the latest push and nopush checkpoint names found in ``save_dir``.

    by_epoch: latest sender/receiver epoch that is searched.

    The result maps ``'push'`` and ``'nopush'`` to the last name reported by
    ``get_semiotic_epoch_pairs`` for that kind, or to None when there is none.
    """
    epoch_pairs = get_semiotic_epoch_pairs(save_dir, by_epoch)
    return {
        project: (epoch_pairs[project][-1] if len(epoch_pairs[project]) else None)
        for project in ('push', 'nopush')
    }


def get_last_semiotic_model_file(save_dir: str, by_epoch: int) -> str:
    """Return the name of the latest semiotic checkpoint in ``save_dir``.

    by_epoch: latest sender/receiver epoch that is searched.

    Returns the last push or nopush checkpoint name, whichever belongs to the
    later epoch (push wins a tie), or None when no checkpoint was found.

    Whether anything was found is decided from the returned file lists rather
    than from ``max_push`` / ``max_nopush`` being 0: the maxima are also 0
    when the only checkpoints are from epoch 0, which previously made this
    function return None even though files existed.
    """
    epoch_pairs = get_semiotic_epoch_pairs(save_dir, by_epoch)
    push_files = epoch_pairs['push']
    nopush_files = epoch_pairs['nopush']

    if not len(push_files) and not len(nopush_files):
        return None
    if not len(nopush_files):
        return push_files[-1]
    if not len(push_files):
        return nopush_files[-1]

    # both kinds are available, choose the latest
    if epoch_pairs['max_push'] >= epoch_pairs['max_nopush']:
        return push_files[-1]
    return nopush_files[-1]
    

class EpochHistory(object):
    """Container for everything recorded during a single epoch."""

    def __init__(self, epoch):
        self.epoch = epoch
        self.msg_lengths = defaultdict(list)
        # per-update records, appended in call order
        self.hist_main_loss = []
        self.hist_aux_info = []

        # Dict of form {'original': tensor, 'target': tensor, 'output': tensor}
        self.sender_sample_pairs = {}
        self.accuracy = 0.0
        self.concept_accuracy_map = {}
        self.concept_costs_map = {}

    def update_main_loss(self, main_loss):
        """Append one main-loss value."""
        self.hist_main_loss.append(main_loss)

    def update_aux_info(self, aux_info_i):
        """Append one auxiliary-info record."""
        self.hist_aux_info.append(aux_info_i)

    def log_accuracy(self, accuracy):
        """Store the accuracy for this epoch, replacing any earlier value."""
        self.accuracy = accuracy

    def log_concept_accuracies(self, concept_accuracy: dict):
        """Merge ``concept_accuracy`` into the stored concept-accuracy map."""
        self.concept_accuracy_map.update(concept_accuracy)

    def log_concept_costs(self, concept_costs: dict):
        """Store test-time concept costs.

        Each value of ``concept_costs`` must be a 7-item pack
        ``(xent, cluster_cost, accu, l1, p_avg_pair_dist, separation_cost,
        avg_separation_cost)``; the accuracy item is not stored. Train-time
        costs are tracked through ``update_aux_info`` inside the train.py
        functions instead.
        """
        for name, pack in concept_costs.items():
            (xent, cluster_cost, _accu, l1,
             p_avg_pair_dist, separation_cost, avg_separation_cost) = pack
            self.concept_costs_map[name] = dict(
                xent=xent,
                cluster_cost=cluster_cost,
                l1=l1,
                p_avg_pair_dist=p_avg_pair_dist,
                separation_cost=separation_cost,
                avg_separation_cost=avg_separation_cost,
            )

    def set_sample_dict(self, d):
        """Replace the stored sender sample pairs with ``d``."""
        self.sender_sample_pairs = d
        

class ParamSet(object):
    """One experiment run: its configuration, recorded histories and plot styling.

    ``series_obj`` is a mapping-like run record (indexed by string keys such
    as ``'run_id'``, ``'save_dir'``, ``'experiments_variables'``). On
    construction the run's ``history.pkl`` is loaded from
    ``<save_dir>/<run_id>/`` and a printable label is built from the
    experiment variables.
    """

    # Display names for perception-wrapper class names.
    _PERCEPT_ARCH_NAMES = {
        "ProtoWrapper": "ProtoPNet",
        # "ProtoBWrapper":,
        "CwWrapper": "CW",
        # "CnnWrapper": "ConvNet",
        "CnnBWrapper": "ConvNet",
    }

    # Display names for sender / receiver agent class names.
    _AGENT_ARCH_NAMES = {
        "RnnSenderGS": "Vanilla RNN",
        "FLRnnSenderGS": "Vanilla RNN",
        "OLRnnSenderGS": "1-Length",
        "MultiHeadRnnSenderGS": "Self-attention RNN",
        "MultiHeadRnnSenderGS2": "Self-attention RNN",
        "ProtoSenderGS": "ProtoRNN",
        "ProtoSender2GS": "ProtoRNN",
        "ProtoSender3GS": "ProtoRNN",
        "RnnReceiverGS": "Vanilla RNN",
        "FLRnnReceiverGS": "Vanilla RNN",
        "ProtoReceiver2GS": "ProtoRNN",
    }

    def __init__(self, series_obj):
        self.state = series_obj
        try:
            self.id = int(self.state['run_id'])
        except (KeyError, TypeError, ValueError):
            # Show the offending record, then fail here. Swallowing the
            # error only postponed the failure to init_histories(), which
            # needs self.id.
            print(series_obj)
            raise

        self.epoch_histories = self.init_histories()
        self.printable_label = self.init_label()
        self.color_a = None
        self.color_b = None
        self.linestyle = None

    @staticmethod
    def proc_concept_accuracy_map(acc_map: dict):
        """Return a copy of ``acc_map`` with every value multiplied by 100."""
        return {k: val * 100 for k, val in acc_map.items()}

    def init_histories(self):
        """Load ``history.pkl`` for this run and flatten it into a dict.

        Returned keys:
            'epochs': array of the recorded epoch numbers.
            'receiver_accuracies': per-epoch accuracy times 100.
            'concept_accuracies_map': key -> 2-D array of [epoch, value] rows.
            'concept_costs_df': DataFrame with one row per
                (history index, push type), named ``"<index>-<push type>"``.
            'human_interp_epochs': epochs selected from the configured
                semiotic push / SGD epochs.
            'expected_length', 'main_loss': values concatenated over all
                epochs, one entry per recorded update.
            'least_effort': always an empty list.
            'expected_length_frequency', 'main_loss_frequency': number of
                updates recorded in the second epoch history.

        Raises IndexError when fewer than two epoch histories are stored.
        """
        state = self.state

        res = {}
        history_path = os.path.join(state['save_dir'], str(self.id), 'history.pkl')
        eh = pickle_load(history_path)

        res['epochs'] = np.asarray([h.epoch for h in eh])
        res['receiver_accuracies'] = np.asarray([h.accuracy * 100 for h in eh])

        concept_accuracies_map = defaultdict(list)
        for h in eh:
            for key, val in h.concept_accuracy_map.items():
                concept_accuracies_map[key].append([h.epoch, val])
        # convert to 2d numpy
        res['concept_accuracies_map'] = {
            key: np.asarray(val) for key, val in concept_accuracies_map.items()
        }

        # DataFrame.append was removed in pandas 2.0: collect the rows first
        # and build the frame once (the Series names become the index).
        cost_rows = []
        for i, h in enumerate(eh):
            for key, val in h.concept_costs_map.items():
                # NOTE(review): 'epoch' stores the history index, not
                # h.epoch as the accuracy map above does - confirm intent.
                cost_rows.append(pd.Series(data={
                    'push_type': key,
                    'epoch': i,
                    **val
                }, name=f"{i}-{key}"))
        res['concept_costs_df'] = pd.DataFrame(cost_rows)

        epochs = res['epochs']
        push_idxes = state['semiotic_push_epochs']
        sgd_idxes = state['semiotic_sgd_epochs']

        if len(push_idxes):
            # grab from epochs that were static after a push or on a push
            valid_epochs = list(deepcopy(push_idxes))
            valid_with_sentinel = list(valid_epochs) + [int(epochs[-1])]
            for i in range(len(valid_with_sentinel) - 1):
                start = int(valid_with_sentinel[i])
                end = int(valid_with_sentinel[i + 1])
                for k in range(start, end, 1):
                    if k not in sgd_idxes:
                        valid_epochs.append(k)

        elif len(sgd_idxes):
            # select after first sgd epoch
            valid_epochs = list(range(sgd_idxes[0], int(epochs[-1])))

        else:
            valid_epochs = deepcopy(epochs)

        res['human_interp_epochs'] = valid_epochs

        res['expected_length'] = []
        res['main_loss'] = []
        res['least_effort'] = []

        for epoch_obj in eh:
            for aux_dict in epoch_obj.hist_aux_info:
                res['expected_length'].append(float(aux_dict['expected_length']))

            res['main_loss'].extend(epoch_obj.hist_main_loss)

        # one update per dict (minibatch)
        res['expected_length_frequency'] = len(eh[1].hist_aux_info)
        res['main_loss_frequency'] = len(eh[1].hist_main_loss)

        return res

    def init_label(self):
        """Build the comma-separated, human-readable label for this run.

        Every name in ``state['experiments_variables']`` (except
        ``aux_weights``, ``aux_losses`` and ``seed``) that is present in the
        state contributes ``<pretty name><value>``; known architecture class
        names are replaced by display names, and unknown ones are kept as-is
        instead of raising KeyError. One entry per auxiliary loss is appended
        at the end.
        """
        state = self.state

        aux_losses = state['aux_losses']
        aux_weights = state['aux_weights']
        sse = state['semiotic_sgd_epochs']
        spe = state['semiotic_push_epochs']
        pretty_loss = {
            'least_effort': 'LEP'
        }
        pretty_attr = {
            "social_coef": "$\\beta$=",
            "sign_coef": "$\\alpha$=",
            "prototype_vectors_lr": "$\\eta_{P}$=",
            "add_on_layers_lr": "$\\eta_{\\theta^+}$=",
            "last_layer_lr": "$\\eta_{C}$=",
            "features_lr": "$\\eta_{\\theta}$=",
            "semiotic_sgd_epochs": f"SSGD-{len(sse)}",
            "semiotic_push_epochs": f"SP-{len(spe)}",
            "sender_arch": "$S$ Arch.=",
            "learnable_temperature": "$\\tau$-Opt.=",
            "vocab_size": "|A|=",
            "approach": "",
            "sender_percept_arch": "$S_f=$ ",
            "recv_percept_arch": "$R_f=$ ",
            "sender_prototypes_per_class": "$S_k=$ ",
            "recv_prototypes_per_class": "$R_k=$ ",
            "seed": "",
        }

        def prettify(attr):
            # an explicit empty string is a valid pretty name, so test
            # against None rather than truthiness
            pretty = pretty_attr.get(attr, None)
            if pretty is not None:
                return pretty
            return attr.replace('_', ' ').capitalize()

        s = []
        for attrb in state['experiments_variables']:
            if attrb in ('aux_weights', 'aux_losses', 'seed'):
                continue  # aux losses are handled below; seed is not shown
            try:
                val = state[attrb]
            except KeyError:
                continue

            if attrb == 'approach':
                val = {'proto': 'Semiotic', 'feats': 'End2End'}[val]

            if isinstance(val, (list, tuple)):
                # long sequences are dropped from the label
                val = "" if len(val) > 10 else f"={val}"

            # fix architecture strings
            if isinstance(val, str) and "Wrapper" in val:
                val = self._PERCEPT_ARCH_NAMES.get(val, val)

            if isinstance(val, str) and ("Sender" in val or "Receiver" in val):
                val = self._AGENT_ARCH_NAMES.get(val, val)

            s.append(f"{prettify(attrb)}{val}")

        for loss, weight in zip(aux_losses, aux_weights):
            loss_name = pretty_loss.get(loss, loss)
            if isinstance(weight, float):
                s.append(f"{loss_name} {weight:.2f}")
            else:
                # non-float weights are indexed as (prefix, value)
                s.append(f"{weight[0]}-{loss_name} {weight[1]:.2f}")

        return ", ".join(s)

    def set_color(self, color_tuple):
        """Unpack ``color_tuple`` into ``color_a`` and ``color_b``."""
        self.color_a, self.color_b = color_tuple

    def set_linestyle(self, ls):
        """Store the line style used when plotting this run."""
        self.linestyle = ls


def agents_only(sender_percept, sender, receiver, log=print):
    """Configure the system so that only the two agents are trained.

    The sender's perception model is put in eval mode with its gradients
    switched 'off'; sender and receiver are put in train mode with their
    gradients switched 'on'.
    """
    agents = (sender, receiver)

    sender_percept.model_multi.eval()
    for agent in agents:
        agent.train()

    sender_percept.choose_grad('off', log=log)
    for agent in agents:
        agent.choose_grad('on')

    log('\tSystem configuration: agents only')
    

def semiosis_joint(sender_percept, sender, receiver, log=print):
    """Configure the system for joint training of perception and agents.

    The sender's perception model is put in train mode with its gradients
    set to 'joint'; sender and receiver are put in train mode with their
    gradients switched 'on'.
    """
    agents = (sender, receiver)

    sender_percept.model_multi.train()
    for agent in agents:
        agent.train()

    sender_percept.choose_grad('joint', log=log)
    for agent in agents:
        agent.choose_grad('on')

    log('\tSystem configuration: semiosis')
    
    
def semiosis_classifier(sender_percept, sender, receiver, log=print):
    """Configure the system so that only the perception model's last layer is optimized.

    The sender's perception model is put in train mode with its gradients
    set to 'last_only'; sender and receiver are put in eval mode with their
    gradients switched 'off'.
    """
    agents = (sender, receiver)

    sender_percept.model_multi.train()  # only optimize fc
    for agent in agents:
        agent.eval()

    sender_percept.choose_grad('last_only', log=log)
    for agent in agents:
        agent.choose_grad('off')

    log('\tSystem configuration: last layer optim')
    
    
def proto_losses(wrapper, class_specific, output, labels, min_distances, use_l1_mask=False):
    """Compute the individual prototype-network loss terms for one batch.

    wrapper: object whose ``model_multi.module`` exposes ``prototype_shape``,
        ``prototype_class_identity`` and ``last_layer.weight``.
    class_specific: when True, cluster and separation costs are computed
        against the prototypes of the correct / wrong class; otherwise the
        cluster cost is the mean of each sample's smallest distance and the
        separation cost is 0.0.
    output: class logits passed to ``cross_entropy`` together with ``labels``.
    min_distances: per-sample distances to every prototype, reduced over
        dim 1 (presumably shaped batch x num_prototypes - confirm against caller).
    use_l1_mask: when True (class-specific only), the L1 term ignores
        last-layer weights that connect a prototype to its own class.

    Returns ``(cross_entropy, cluster_cost, separation_cost, l1)``.

    Tensors derived from the model are moved to the device of the tensor
    they are combined with instead of calling ``.cuda()``, so the function
    also works on CPU and on non-default GPUs.
    """
    cross_entropy = torch.nn.functional.cross_entropy(output, labels)
    net = wrapper.model_multi.module

    if not class_specific:
        min_distance, _ = torch.min(min_distances, dim=1)
        cluster_cost = torch.mean(min_distance)
        l1 = net.last_layer.weight.norm(p=1)
        return cross_entropy, cluster_cost, 0.0, l1

    labels_int = labels.cpu().detach().numpy().astype(int)
    # upper bound used to turn "smallest distance" into a max() reduction
    max_dist = (net.prototype_shape[1]
                * net.prototype_shape[2]
                * net.prototype_shape[3])

    # prototypes_of_correct_class is a tensor of shape batch_size * num_prototypes
    prototypes_of_correct_class = torch.t(
        net.prototype_class_identity[:, labels_int]).to(min_distances.device)
    inverted_distances = max_dist - min_distances

    # cluster cost: distance to the closest prototype of the correct class
    closest_correct, _ = torch.max(inverted_distances * prototypes_of_correct_class, dim=1)
    cluster_cost = torch.mean(max_dist - closest_correct)

    # separation cost: distance to the closest prototype of any other class
    prototypes_of_wrong_class = 1 - prototypes_of_correct_class
    closest_wrong, _ = torch.max(inverted_distances * prototypes_of_wrong_class, dim=1)
    separation_cost = torch.mean(max_dist - closest_wrong)

    weight = net.last_layer.weight
    if use_l1_mask:
        l1_mask = 1 - torch.t(net.prototype_class_identity).to(weight.device)
        l1 = (weight * l1_mask).norm(p=1)
    else:
        l1 = weight.norm(p=1)

    return cross_entropy, cluster_cost, separation_cost, l1


def proto_weighting(class_specific, cross_entropy, cluster_cost, separation_cost, l1, coefs=None):
    """Combine the prototype loss terms into a single weighted loss.

    With ``coefs`` (keys 'crs_ent', 'clst', 'l1' and, for the class-specific
    case, 'sep') every term is multiplied by its coefficient and summed.
    Without it the defaults are 1, 0.8 and 1e-4, and the class-specific case
    additionally subtracts 0.08 times the separation cost. The separation
    cost is ignored when ``class_specific`` is false.
    """
    if coefs is None:
        total = cross_entropy + 0.8 * cluster_cost
        if class_specific:
            total = total - 0.08 * separation_cost
        return total + 1e-4 * l1

    total = coefs['crs_ent'] * cross_entropy + coefs['clst'] * cluster_cost
    if class_specific:
        total = total + coefs['sep'] * separation_cost
    return total + coefs['l1'] * l1


def save_enc_model(model, model_dir, model_name, log=print):
    """Save ``model.state_dict()`` to ``<model_dir>/<model_name>.pth`` and log the destination.

    model: this is not the multigpu model
    """
    target = os.path.join(model_dir, f"{model_name}.pth")
    torch.save(obj=model.state_dict(), f=target)
    log(f"Wrote prototype model to {model_dir}")


def construct_prototype_model(model_details: dict):
    """Build a ProtoPNet from the settings in ``model_details``.

    The prototype tensor shape is
    ``(num_classes * prototypes_per_class, 128, 1, 1)``; every other setting
    is forwarded unchanged to ``ProtoPNet.model.construct_PPNet``.
    """
    from ProtoPNet import model

    num_prototypes = model_details['num_classes'] * model_details['prototypes_per_class']
    prototype_shape = (num_prototypes, 128, 1, 1)

    return model.construct_PPNet(
        base_architecture=model_details['base_architecture'],
        pretrained=model_details['pretrained'],
        img_size=model_details['img_size'],
        prototype_shape=prototype_shape,
        num_classes=model_details['num_classes'],
        prototype_activation_function=model_details['prototype_activation_function'],
        add_on_layers_type=model_details['add_on_layers_type'],
    )


def merge_accuracies(concept_accuracies_map):
    """Placeholder - not implemented; always returns None.

    Intended to replace blank epochs in the push matrix with nopush epochs
    (if they both exist).
    """
    return None


def build_class_to_prototype_files(vocab_size, proto_per_class, epoch_folder, file_prefix):
    """Group prototype image paths by class index.

    Paths ``<epoch_folder>/<file_prefix><ix>.png`` are generated for
    ``ix`` in ``range(vocab_size)``; the class index starts at 0 and advances
    after every ``proto_per_class`` consecutive files.

    Returns a ``defaultdict(list)`` mapping class index -> list of paths.
    """
    grouped = defaultdict(list)
    class_idx = 0
    for proto_idx in range(vocab_size):
        path = os.path.join(epoch_folder, f'{file_prefix}{proto_idx}.png')
        grouped[class_idx].append(path)
        # move on to the next class once this one has all its prototypes
        class_idx += 1 if (proto_idx + 1) % proto_per_class == 0 else 0

    return grouped


def calculate_bbrf_areas(bbrf):
    '''
    Return the area (height * width) of every box in ``bbrf``, one per row.

    proto_rf_boxes and proto_bound_boxes column (ProtoPNet/push.py):
    0: image index in the entire dataset
    1: height start index
    2: height end index
    3: width start index
    4: width end index
    5: (optional) class identity
    '''
    # (the unused row-count local of the previous version was removed)
    heights = bbrf[:, 2] - bbrf[:, 1]
    widths = bbrf[:, 4] - bbrf[:, 3]
    return heights * widths


def test_message_identity(messages, k, vocab_size, prototypes_per_class):
    """Flag, per message, whether every non-EOS symbol indexes a prototype of class ``k[i]``.

    For message ``i`` a 0/1 vector of length ``vocab_size - 1`` is built that
    is 1 on ``[k[i] * prototypes_per_class, (k[i] + 1) * prototypes_per_class)``;
    the unique non-zero symbols (argmax over the last axis) are used to index
    it. The result holds 1 when none of the looked-up entries is 0, else 0.

    NOTE(review): symbols index the vector directly (the ``- 1`` EOS shift is
    disabled in the code), so symbol ``s`` is checked against slot ``s`` -
    confirm this offset is intended.
    """
    ands = []
    for row in range(messages.size(0)):
        symbols = messages[row].argmax(axis=-1)

        class_mask = torch.zeros(vocab_size - 1, dtype=int)
        first = k[row] * prototypes_per_class
        class_mask[first:first + prototypes_per_class] = 1

        used = symbols[symbols > 0].unique()  # - 1  # remove eos
        ands.append(0 if 0 in class_mask[used] else 1)

    return ands
        

def loader_to_message_data(loader, sender_wrapper, agents, max_msgs=1000):
    """Collect up to ``max_msgs`` sender messages from ``loader`` and summarize them.

    loader: iterable of ``(sender_images, recv_targets, _, sender_labels)``
        exposing ``dataset_size`` and ``reset()``.
    sender_wrapper: object whose ``prelinguistic(images)`` returns
        ``(representation, structure)``.
    agents: ``(sender, receiver)`` pair; only the sender is used here.
    max_msgs: stop collecting once this many messages are gathered.

    Returns a dict with the collected inputs ('in_images', 'in_vectors',
    'in_structs', 'actuals'), the symbol-id 'messages' with everything from
    the first EOS (symbol 0) onwards set to 0, 'class_to_symbols' (label ->
    every symbol emitted for it), the mean/std of message lengths and of the
    number of distinct non-EOS symbols, plus 'preds' (always empty) and
    'percent_identified' (always None), which are kept for interface
    compatibility.

    A message's length is the position of its first EOS, or the full message
    length when it contains none. Previously an EOS-less message raised
    IndexError here.
    """
    _sender, _receiver = agents
    in_images = []
    in_vectors = []
    in_structs = []
    messages = []
    actuals = []
    preds = []
    ands = []

    class_to_symbols = defaultdict(list)

    with torch.no_grad():
        with tqdm(total=max_msgs) as pb:
            start = 0
            for sender_images, _recv_targets, _, sender_labels in loader:
                sender_repr, sender_structure = sender_wrapper.prelinguistic(sender_images)
                message = _sender((sender_repr, sender_structure))
                # never keep more samples than the loader says the dataset holds
                end = min(start + sender_repr.size(0), loader.dataset_size)
                keep = end - start

                for actual, message_am in zip(sender_labels, message.argmax(axis=-1)):
                    actual = actual.detach().cpu().item()
                    message_am = message_am.detach().cpu().numpy()
                    class_to_symbols[actual].extend(list(message_am))

                collected_before = len(messages)
                in_images.extend(sender_images[:keep].detach().cpu())
                messages.extend(message[:keep].detach().cpu().numpy())
                in_vectors.extend(sender_repr[:keep].detach().cpu().numpy())
                in_structs.extend(sender_structure[:keep].detach().cpu().numpy())
                actuals.extend(sender_labels[:keep].detach().cpu().numpy())

                # advance the bar by what was actually added, capped at max_msgs
                pb.update(max(0, min(len(messages), max_msgs) - min(collected_before, max_msgs)))
                start = end

                if len(messages) >= max_msgs:
                    break

    print(f"Reached max messages count of {max_msgs}")
    loader.reset()
    in_images = in_images[:max_msgs]
    messages = messages[:max_msgs]
    in_vectors = in_vectors[:max_msgs]
    in_structs = in_structs[:max_msgs]
    actuals = actuals[:max_msgs]

    # stack through numpy first: building a tensor from a list of arrays is slow
    messages = torch.as_tensor(np.asarray(messages)).argmax(axis=-1)
    in_structs = torch.as_tensor(np.asarray(in_structs))
    in_vectors = torch.as_tensor(np.asarray(in_vectors))
    lengths = []
    uniques = []

    for i in range(messages.size(0)):
        eos_positions = (messages[i, :] == 0).nonzero()
        if eos_positions.size(0) > 0:
            message_end = eos_positions[0].item()
            # blank out everything after the first EOS
            messages[i, message_end:] = 0
        else:
            # no EOS: the whole message counts and is left untouched
            message_end = messages.size(1)
        lengths.append(message_end)
        nz = messages[i][messages[i] > 0]
        uniques.append(len(nz.unique()))

    # `ands` is never filled in this function, so this is always None
    ident = np.mean(ands) if len(ands) else None

    return {
        'in_images': in_images,
        'in_vectors': in_vectors,
        'in_structs': in_structs,
        'messages': messages,
        'actuals': actuals,
        'preds': preds,
        'class_to_symbols': class_to_symbols,
        'length_average': np.mean(lengths),
        'length_std': np.std(lengths),
        'uniques_average': np.mean(uniques),
        'uniques_std': np.std(uniques),
        'percent_identified': ident
    }


def loader_to_noise_result(loader, sender_wrapper, recv_wrapper, agents, device, 
                           mean, std,
                           max_msgs=1000, epsilon=0.1):
    """Play the game on noise-perturbed sender images and collect the outcome.

    Each sender batch is un-normalized with ``mean`` / ``std``, perturbed
    with uniform noise in ``[0, epsilon)``, re-normalized and passed through
    the sender; the receiver's prediction is read at the EOS step of each
    message (see ``process_exchange``).

    loader: iterable of ``(sender_images, recv_targets, recv_images,
        sender_labels)`` exposing ``start_epoch()``, ``dataset_size`` and
        ``reset()``.
    sender_wrapper: object providing ``prelinguistic(images)``.
    recv_wrapper: callable applied to the receiver images.
    agents: ``(sender, receiver)`` pair.
    device: device the batches are moved to.
    mean, std: per-channel statistics used to normalize the images.
    max_msgs: stop collecting once this many messages are gathered.
    epsilon: upper bound of the uniform noise added in un-normalized space.

    Returns a dict with 'in_images', 'in_vectors', 'messages' (symbol ids,
    zeroed from the first EOS onwards), 'actuals', 'preds',
    'class_to_symbols', message length / unique-symbol statistics and
    'percent_identified' (always None).

    Fixes over the previous version: the receiver representation was built
    from an undefined name (``receiver_input``) and is now built from
    ``recv_images``; normalization statistics follow the input's device
    instead of hard-coded ``.cuda()``; a message without EOS no longer
    raises IndexError; the progress bar advances by the number of messages
    added.
    """
    _sender, _receiver = agents
    in_images = []
    in_vectors = []
    messages = []
    actuals = []
    preds = []
    ands = []

    def _stat(values, like):
        # statistics as a tensor living next to the images
        return torch.as_tensor(values, dtype=like.dtype, device=like.device)

    def unnormalize(x):
        # B x C x H x W  ->  B x H x W x C
        return (x.permute(0, 2, 3, 1) * _stat(std, x)) + _stat(mean, x)

    def normalize(x):
        # B x H x W x C  ->  B x C x H x W
        x = (x - _stat(mean, x)) / _stat(std, x)
        return x.permute(0, 3, 1, 2)

    class_to_symbols = defaultdict(list)

    with torch.no_grad():
        with tqdm(total=max_msgs) as pb:
            loader.start_epoch('semiotic')
            start = 0
            for sender_images, recv_targets, recv_images, sender_labels in loader:
                sender_images = sender_images.to(device)
                recv_targets = recv_targets.to(device)
                recv_images = recv_images.to(device)
                sender_labels = sender_labels.to(device)

                raw_images = unnormalize(sender_images)
                images_noisy = normalize(raw_images + torch.rand_like(raw_images) * epsilon)

                sender_repr = sender_wrapper.prelinguistic(images_noisy)
                message = _sender(sender_repr)
                # NOTE(review): assumes recv_wrapper is meant to encode the
                # receiver-side images of this batch - confirm against caller.
                recv_repr = recv_wrapper(recv_images)
                receiver_output, receiver_hiddens = _receiver(message, recv_repr)
                trunc_messages, last_recv = process_exchange(message, receiver_output)

                # prelinguistic() returns (representation, structure) in
                # loader_to_message_data; accept that form as well as a bare tensor
                if isinstance(sender_repr, (tuple, list)):
                    sender_vectors = sender_repr[0]
                else:
                    sender_vectors = sender_repr

                # never keep more samples than the loader says the dataset holds
                end = min(start + sender_vectors.size(0), loader.dataset_size)
                keep = end - start

                for actual, message_am in zip(sender_labels, message.argmax(axis=-1)):
                    actual = actual.detach().cpu().item()
                    message_am = message_am.detach().cpu().numpy()
                    class_to_symbols[actual].extend(list(message_am))

                collected_before = len(messages)
                in_images.extend(sender_images[:keep].detach().cpu())
                messages.extend(message[:keep].detach().cpu().numpy())
                in_vectors.extend(sender_vectors[:keep].detach().cpu().numpy())
                actuals.extend(recv_targets[:keep].detach().cpu().numpy())
                preds.extend(last_recv.argmax(dim=1)[:keep].detach().cpu().numpy())

                # advance the bar by what was actually added, capped at max_msgs
                pb.update(max(0, min(len(messages), max_msgs) - min(collected_before, max_msgs)))
                start = end

                if len(messages) >= max_msgs:
                    break

    print(f"Reached max messages count of {max_msgs}")
    loader.reset()
    in_images = in_images[:max_msgs]
    messages = messages[:max_msgs]
    in_vectors = in_vectors[:max_msgs]
    actuals = actuals[:max_msgs]
    preds = preds[:max_msgs]

    # stack through numpy first: building a tensor from a list of arrays is slow
    messages = torch.as_tensor(np.asarray(messages)).argmax(axis=-1)
    in_vectors = torch.as_tensor(np.asarray(in_vectors))
    lengths = []
    uniques = []

    for i in range(messages.size(0)):
        eos_positions = (messages[i, :] == 0).nonzero()
        if eos_positions.size(0) > 0:
            message_end = eos_positions[0].item()
            # blank out everything after the first EOS
            messages[i, message_end:] = 0
        else:
            # no EOS: the whole message counts and is left untouched
            message_end = messages.size(1)
        lengths.append(message_end)
        nz = messages[i][messages[i] > 0]
        uniques.append(len(nz.unique()))

    # `ands` is never filled in this function, so this is always None
    ident = np.mean(ands) if len(ands) else None

    return {
        'in_images': in_images,
        'in_vectors': in_vectors,
        'messages': messages,
        'actuals': actuals,
        'preds': preds,
        'class_to_symbols': class_to_symbols,
        'length_average': np.mean(lengths),
        'length_std': np.std(lengths),
        'uniques_average': np.mean(uniques),
        'uniques_std': np.std(uniques),
        'percent_identified': ident
    }


def set_eval(models):
    """Call ``eval()`` on every model in ``models``."""
    for model in models:
        model.eval()
        

def set_train(models):
    """Call ``train()`` on every model in ``models``."""
    for model in models:
        model.train()


def disable_parameter_requires_grad(model):
    """Freeze ``model`` by setting ``requires_grad = False`` on each of its parameters."""
    for weight in model.parameters():
        weight.requires_grad = False


def log_whole_system_params(sender_percept, recv_percept, _game, log=print):
    """Log the names of all parameters that currently require gradients.

    Parameters are listed per component, in the order sender perception,
    sender, receiver perception, receiver; each component name is logged as
    a heading followed by one tab-indented line per trainable parameter.
    """
    def _log_trainable(title, module):
        # heading first, then only the parameters that will be optimized
        log(title)
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue
            log(f"\t{name}")

    log('Optimized parameters:')
    _log_trainable('sender_percept', sender_percept)
    _log_trainable('sender', _game.sender)
    _log_trainable('recv_percept', recv_percept)
    _log_trainable('receiver', _game.receiver)