import torch
import torch.nn as nn

from ProtoPNet.settings import coefs

import util
from util import *


# Number of training batches between concept-whitening rotation-matrix updates.
_CW_ROTATION_UPDATE_INTERVAL = 30


def _build_concept_loaders(state):
    """Return one shuffled training ``DataLoader`` per entry of ``state['concepts']``.

    Each loader reads the image folder ``<state['concept_train_dir']>/<concept>``
    with random-resized-crop / horizontal-flip augmentation and the sender's
    normalisation statistics (``state['sender_mean']`` / ``state['sender_std']``).
    """
    import os
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    # One pipeline shared by every concept folder; these transforms keep no per-sample state.
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=state['sender_mean'], std=state['sender_std']),
    ])
    return [
        torch.utils.data.DataLoader(
            datasets.ImageFolder(os.path.join(state['concept_train_dir'], concept), transform),
            batch_size=state['train_batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=False,
        )
        for concept in state['concepts']
    ]


def _update_cw_rotation(sender_wrapper, concept_loaders, device):
    """Refresh the concept-whitening rotation matrix from one batch per concept.

    Switches ``sender_wrapper.model_multi`` to eval mode. For each concept loader it
    puts the CW module into that concept's mode and forwards the loader's first
    batch without gradients. It then calls ``update_rotation_matrix()``, restores
    the ordinary mode (-1) and switches the model back to train mode.
    """
    model = sender_wrapper.model_multi
    model.eval()
    # access CW wrapper around base model weights
    cw_model = model.module.cw_model
    with torch.no_grad():
        # update the gradient matrix G
        for concept_index, concept_loader in enumerate(concept_loaders):
            cw_model.change_mode(concept_index)
            # Only the first batch of each concept loader is used per update.
            batch = next(iter(concept_loader), None)
            if batch is not None:
                concept_images, _ = batch
                model(concept_images.to(device))
        cw_model.update_rotation_matrix()
        # change to ordinary mode
        cw_model.change_mode(-1)
    model.train()


def one_epoch(state, epoch, device, train_loader,
              sender_interpret, receiver_interpret,
              sender_wrapper, game, optim, experiment, log=print):
    """Run one training epoch of the sender/receiver game.

    Each batch goes through these steps:

    1. ``sender_interpret`` produces concept logits and the sender representation.
    2. ``receiver_interpret`` produces the receiver representation under ``torch.no_grad()``.
    3. ``game`` returns the social loss.

    When the sender yields ``concept_logits``, a supervised "sign" loss is mixed in
    as ``state['sign_coef'] * sign + state['social_coef'] * social``. The sign loss
    depends on ``experiment``:

    * ``'proto'``: ProtoPNet losses (``proto_losses`` + ``proto_weighting``).
    * ``'cw'``: cross-entropy. Every ``_CW_ROTATION_UPDATE_INTERVAL`` batches the
      concept-whitening rotation matrix is also refreshed from per-concept image
      folders.
    * anything else: cross-entropy (end-to-end CNN baseline).

    Without concept logits, only the social loss is optimized.

    Args:
        state: configuration dict. Keys read here are ``'semiotic_sgd_epochs'``,
            ``'concept_train_dir'``, ``'concepts'``, ``'sender_mean'``,
            ``'sender_std'`` and ``'train_batch_size'`` (cw loaders). Also
            ``'class_specific'`` (proto), and ``'sign_coef'`` / ``'social_coef'``
            (when concept logits exist).
        epoch: epoch index passed to ``EpochHistory``.
        device: device every batch tensor is moved to.
        train_loader: iterable of ``(sender_tuple, recv_targets, recv_tuple,
            sender_labels)``, where each ``*_tuple`` is ``(objects, cached_structs)``.
            It must expose ``n_batches_per_epoch`` and ``reset()`` (a DALI iterator,
            per the reset comment).
        sender_interpret: callable returning
            ``(concept_logits, (feats, structs), min_dists)``. ``concept_logits``
            may be ``None``. A ``None`` ``structs`` means the cached structs from
            the batch are used.
        receiver_interpret: same return shape as ``sender_interpret``; its logits
            and distances are ignored.
        sender_wrapper: passed to ``proto_losses`` ('proto'). Its
            ``model_multi.module.cw_model`` is updated ('cw').
        game: callable returning ``(social_loss, message, receiver_output, aux_dict)``.
        optim: optimizer stepped once per batch.
        experiment: ``'proto'``, ``'cw'``, or any other value for the CNN baseline.
        log: unused; kept for signature compatibility.

    Returns:
        The ``EpochHistory`` accumulated over the epoch.
    """
    epoch_log = EpochHistory(epoch)
    nb = train_loader.n_batches_per_epoch

    game.sender.train()
    game.receiver.train()

    # Per-concept loaders are only built for the concept-whitening experiment.
    # Binding ``None`` otherwise lets the cw branch skip the rotation update
    # instead of failing with a NameError.
    concept_loaders = None
    if experiment == 'cw' and len(state['semiotic_sgd_epochs']):
        concept_loaders = _build_concept_loaders(state)

    with tqdm(total=nb) as pb:
        epoch_rolling_acc = []
        epoch_rolling_concept_acc = []
        epoch_rolling_loss = []

        for i, (sender_tuple, recv_targets, recv_tuple, sender_labels) in enumerate(train_loader):
            sender_objs, _ = sender_tuple
            sender_objs = sender_objs.to(device)

            concept_logits, (sender_feats, sender_structs), min_dists = sender_interpret(sender_objs)
            if sender_structs is None:
                # using cached passthrough version
                _, sender_structs = sender_tuple
                sender_structs = sender_structs.to(device)

            recv_targets = recv_targets.to(device)
            recv_objs, _ = recv_tuple
            recv_objs = recv_objs.to(device)
            sender_labels = sender_labels.to(device)

            with torch.no_grad():  # only message signal is allowed
                _, (recv_feats, recv_structs), _ = receiver_interpret(recv_objs)
            if recv_structs is None:
                # using cached passthrough version
                _, recv_structs = recv_tuple
                recv_structs = recv_structs.to(device)

            social_loss, sender_message, receiver_output, aux_dict_i = game(
                (sender_feats, sender_structs),
                recv_targets,
                receiver_input=(recv_feats, recv_structs),
            )

            optim.zero_grad()
            if concept_logits is None:
                # static case: only the game loss is optimized
                loss_i = social_loss
            elif experiment == 'proto':
                # deal with proto losses
                cross_ent, cluster_cost, separation_cost, l1 = proto_losses(
                    sender_wrapper,
                    state['class_specific'],
                    concept_logits,
                    sender_labels,
                    min_dists,
                )
                aux_dict_i['sign_xe'] = cross_ent.item()
                aux_dict_i['sign_cluster'] = cluster_cost.item()
                aux_dict_i['sign_sep'] = separation_cost.item()
                aux_dict_i['sign_l1'] = l1.item()

                proto_loss = proto_weighting(
                    state['class_specific'],
                    cross_ent,
                    cluster_cost,
                    separation_cost,
                    l1,
                    coefs,
                )
                loss_i = state['sign_coef'] * proto_loss + state['social_coef'] * social_loss
            else:
                if (experiment == 'cw' and concept_loaders is not None
                        and (i + 1) % _CW_ROTATION_UPDATE_INTERVAL == 0):
                    # NOTE(review): this refresh runs after this batch's forward pass
                    # but before backward. The reference concept-whitening training
                    # loop appears to refresh before the forward pass. Confirm the
                    # ordering is intended.
                    _update_cw_rotation(sender_wrapper, concept_loaders, device)

                # 'cw' and the end-to-end CNN baseline share the same sign loss.
                cross_ent = torch.nn.functional.cross_entropy(concept_logits, sender_labels)
                aux_dict_i['sign_xe'] = cross_ent.item()
                loss_i = state['sign_coef'] * cross_ent + state['social_coef'] * social_loss

            loss_i.backward()
            optim.step()

            _, last_recv = process_exchange(sender_message, receiver_output)

            acc_i = torch.mean((last_recv.argmax(dim=1) == recv_targets).float()).item()
            epoch_rolling_acc.append(acc_i)

            if concept_logits is not None:
                _, c_preds = torch.max(concept_logits, dim=1)
                c_acc_i = torch.mean((c_preds == sender_labels).float()).item()
                epoch_rolling_concept_acc.append(c_acc_i)

            loss_value = loss_i.item()
            epoch_rolling_loss.append(loss_value)

            str_loss = f"{np.mean(epoch_rolling_loss):.3f}"
            str_acc = f"{np.mean(epoch_rolling_acc):.3f}"
            str_c_acc = f"{np.mean(epoch_rolling_concept_acc) if len(epoch_rolling_concept_acc) else 0.0:.3f}"

            epoch_log.update_aux_info(aux_dict_i)
            epoch_log.update_main_loss(loss_value)

            # Format a separate dict for display rather than overwriting aux_dict_i
            # with strings, in case epoch_log keeps a reference to it.
            postfix = {key: f"{value:.3f}" for key, value in aux_dict_i.items()}

            pb.update(1)
            pb.set_postfix(loss=str_loss, acc=str_acc, c_acc=str_c_acc, **postfix)

    # reset dali iterator
    train_loader.reset()

    return epoch_log


def classifier_epoch(state, epoch, device, train_loader, wrapper, optim, experiment, log=print):
    """Train a concept classifier on its own (no communication game) for one epoch.

    If ``experiment == 'proto'``, the ProtoPNet objective is optimized
    (``proto_losses`` combined by ``proto_weighting`` with ``coefs``). Any other
    value optimizes plain cross-entropy (end-to-end CNN baseline).

    Args:
        state: configuration dict; ``state['class_specific']`` is read for 'proto'.
        epoch: epoch index passed to ``EpochHistory``.
        device: device every batch is moved to.
        train_loader: iterable of ``(images, labels)`` batches. It must expose
            ``n_batches_per_epoch`` and ``reset()`` (a DALI iterator, per the
            reset comment).
        wrapper: model callable returning ``(concept_logits, _, min_dists)``. It is
            also passed to ``proto_losses``.
        optim: optimizer stepped once per batch.
        experiment: ``'proto'``, or any other value for the CNN baseline.
        log: unused; kept for signature compatibility.

    Returns:
        The ``EpochHistory`` accumulated over the epoch.
    """
    epoch_log = EpochHistory(epoch)
    nb = train_loader.n_batches_per_epoch

    with tqdm(total=nb) as pb:
        epoch_rolling_concept_acc = []
        epoch_rolling_loss = []

        for image, label in train_loader:
            # Honour ``device`` (as one_epoch does) instead of the current CUDA device.
            images = image.to(device)
            target = label.to(device)

            concept_logits, _, min_dists = wrapper(images)
            aux_dict_i = {}

            optim.zero_grad()
            if experiment == 'proto':
                # semiotic case, deal with proto losses
                cross_ent, cluster_cost, separation_cost, l1 = proto_losses(
                    wrapper,
                    state['class_specific'],
                    concept_logits,
                    target,
                    min_dists,
                )
                aux_dict_i['sign_xe'] = cross_ent.item()
                aux_dict_i['sign_cluster'] = cluster_cost.item()
                aux_dict_i['sign_sep'] = separation_cost.item()
                aux_dict_i['sign_l1'] = l1.item()

                loss_i = proto_weighting(
                    state['class_specific'],
                    cross_ent,
                    cluster_cost,
                    separation_cost,
                    l1,
                    coefs,
                )
            else:
                # end-to-end CNN baseline
                loss_i = torch.nn.functional.cross_entropy(concept_logits, target)
                aux_dict_i['sign_xe'] = loss_i.item()

            # Both objectives share the same update step.
            loss_i.backward()
            optim.step()

            _, c_preds = torch.max(concept_logits, dim=1)
            c_acc_i = torch.mean((c_preds == target).float()).item()
            epoch_rolling_concept_acc.append(c_acc_i)

            loss_value = loss_i.item()
            epoch_rolling_loss.append(loss_value)

            str_loss = f"{np.mean(epoch_rolling_loss):.3f}"
            # Non-empty here: an accuracy was appended just above.
            str_c_acc = f"{np.mean(epoch_rolling_concept_acc):.3f}"

            epoch_log.update_aux_info(aux_dict_i)
            epoch_log.update_main_loss(loss_value)

            # Format a separate dict for display rather than overwriting aux_dict_i
            # with strings, in case epoch_log keeps a reference to it.
            postfix = {key: f"{value:.3f}" for key, value in aux_dict_i.items()}

            pb.update(1)
            pb.set_postfix(loss=str_loss, c_acc=str_c_acc, **postfix)

    # reset dali iterator
    train_loader.reset()

    return epoch_log
    