import torch
import torch.nn as nn
import random
import numpy as np
import argparse
import os
import re
import json
import shutil
import logging
import sys
sys.path.append('/<anonymized>/disentangling_categorization')
sys.path.append('/<anonymized>/disentangling_categorization/ProtoPNet')

from tqdm import tqdm
from torchvision import transforms
from torch.nn import functional as F
from collections import defaultdict

import torchvision.datasets as datasets

import util
from util import *
from ConceptWhitening.features import construct_CW_features
from ConvNet.features import construct_Cnn_features

import models
from models import CnnWrapper, ProtoWrapper, CwWrapper, CnnBWrapper

import agents2
from agents2 import RnnSenderGS, RnnReceiverGS

# Which perceptual wrapper classes (by class name, see models.py) each agent
# class (by class name, see agents2.py) may be paired with.
agent_compat_lists = {
    'RnnSenderGS': ["ProtoWrapper", "ProtoBWrapper", 
                    "CwWrapper", 
                    "CnnWrapper", "CnnBWrapper"],
    'FLRnnSenderGS': ["ProtoWrapper", "ProtoBWrapper", 
                      "CwWrapper", 
                      "CnnWrapper", "CnnBWrapper"],
    'OLRnnSenderGS': ["ProtoWrapper", "CwWrapper"],
    'MultiHeadRnnSenderGS': ["ProtoWrapper", "CwWrapper"],
    'MultiHeadRnnSenderGS2': ["ProtoWrapper", "CwWrapper"],
    'RnnReceiverGS': ["ProtoWrapper", "ProtoBWrapper", "MultiHeadProtoWrapper",
                      "CwWrapper", 
                      "CnnWrapper", "CnnBWrapper"],
    'FLRnnReceiverGS': ["ProtoWrapper", "ProtoBWrapper", "MultiHeadProtoWrapper",
                        "CwWrapper", 
                        "CnnWrapper", "CnnBWrapper"],
}


def _invert_compat_lists(compat):
    """Invert an ``agent -> [percept, ...]`` map into ``percept -> [agent, ...]``.

    Duplicates are dropped while keeping first-seen order. The result is
    therefore identical across interpreter runs, unlike ``list(set(...))``,
    whose order depends on per-process string hash randomization.

    Args:
        compat: mapping from agent class name to a list of perceptual
            wrapper class names.

    Returns:
        A ``defaultdict(list)`` mapping each wrapper name to the agent names
        that list it, in the order they appear in *compat*.
    """
    inverted = defaultdict(list)
    for agent_name, percept_names in compat.items():
        for percept_name in percept_names:
            if agent_name not in inverted[percept_name]:
                inverted[percept_name].append(agent_name)
    return inverted


# Reverse view: perceptual wrapper name -> compatible agent names.
percept_compat_lists = _invert_compat_lists(agent_compat_lists)
# Every perceptual wrapper name mentioned above, each exactly once, in first-seen order.
percepts = list(percept_compat_lists)


def _build_proto_encoder(state, arch_key, ckpt_key, base_cnn_key):
    """Construct a ProtoPNet encoder; restore its weights if ``state[ckpt_key]`` is non-empty."""
    # Imported lazily so ProtoPNet settings are only needed for Proto* wrappers.
    from ProtoPNet.settings import img_size, prototype_activation_function, add_on_layers_type, num_data_workers  # noqa: F401

    arch_name = state[arch_key]
    if 'send' in arch_key:
        prototypes_per_class = state['sender_prototypes_per_class']
    else:
        prototypes_per_class = state['recv_prototypes_per_class']

    encoder = construct_prototype_model({
        'img_size': img_size,
        'num_classes': state['num_classes'],
        'prototypes_per_class': prototypes_per_class,
        'prototype_activation_function': prototype_activation_function,
        'add_on_layers_type': add_on_layers_type,
        'base_architecture': state[base_cnn_key],
        # False only for ProtoB* wrappers. Any weights loaded below replace
        # whatever this flag produced.
        'pretrained': 'ProtoB' not in arch_name,
    })

    checkpoint_path = state[ckpt_key]
    if checkpoint_path != '':
        logging.info(f"Load {arch_key} checkpoint from {checkpoint_path}")
        # NOTE(review): no map_location is passed, so tensors are restored to
        # the device they were saved from. Confirm this works on CPU-only hosts.
        checkpoint = torch.load(checkpoint_path)
        encoder.load_state_dict(checkpoint['state_dict'])
    return encoder


def _build_checkpointed_encoder(state, arch_key, ckpt_key, constructor, **overrides):
    """Load the checkpoint dict at ``state[ckpt_key]``, patch it, and pass it to *constructor*.

    ``num_classes`` is always taken from *state*. Any *overrides* are written
    into the loaded dict afterwards.
    """
    logging.info(f"Load {arch_key} checkpoint from {state[ckpt_key]}")
    model_state = torch.load(state[ckpt_key])
    model_state["num_classes"] = state["num_classes"]
    model_state.update(overrides)
    return constructor(model_state)


def _build_torchvision_encoder(state, arch_key, base_cnn_key):
    """Construct a plain CNN encoder described only by ``state`` (no checkpoint file)."""
    # Legacy path, slated for removal.
    use_pretrained = state['cnn_pretrained']
    if use_pretrained:
        logging.info(f"Load {arch_key} checkpoint from torchvision.models")
    else:
        logging.info("Pretrained off.")
    return construct_Cnn_features({
        'architecture': state[base_cnn_key],
        'num_classes': 1000,  # ImageNet head in this configuration
        'cnn_pretrained': use_pretrained,
        'state_dict': None,
    })


def build_perceptual_wrapper(state, arch_key, ckpt_key, base_cnn_key, mean_key, std_key):
    """Build a perceptual module and wrap it in the wrapper class named by ``state[arch_key]``.

    Args:
        state: configuration mapping that the keys below index into.
        arch_key: key whose value is a wrapper class name in ``models``
            (e.g. ``"ProtoWrapper"``, ``"CwWrapper"``, ``"CnnBWrapper"``).
        ckpt_key: key whose value is the encoder checkpoint path ('' = none;
            only honoured as optional for Proto* wrappers).
        base_cnn_key: key whose value names the base CNN architecture.
        mean_key, std_key: keys of the normalization parameters handed to the wrapper.

    Returns:
        The wrapper instance, moved to ``state['device']``. It receives both
        the encoder and a ``torch.nn.DataParallel`` around that encoder.
    """
    mean, std = state[mean_key], state[std_key]
    arch_name = state[arch_key]

    # Pick the encoder family from the wrapper name.
    if "Proto" in arch_name:
        encoder = _build_proto_encoder(state, arch_key, ckpt_key, base_cnn_key)
    elif arch_name == "CwWrapper":
        encoder = _build_checkpointed_encoder(state, arch_key, ckpt_key, construct_CW_features)
    elif "CnnB" in arch_name:
        encoder = _build_checkpointed_encoder(state, arch_key, ckpt_key, construct_Cnn_features,
                                              cnn_pretrained=False)
    else:
        encoder = _build_torchvision_encoder(state, arch_key, base_cnn_key)

    device = state['device']
    encoder = encoder.to(device)
    encoder_parallel = torch.nn.DataParallel(encoder)

    wrapper_cls = models.__dict__[arch_name]
    # Only ProtoWrapper2 takes the extra top-k argument.
    extra_kwargs = {'topk': state['topk']} if arch_name == "ProtoWrapper2" else {}
    return wrapper_cls(encoder, encoder_parallel, mean=mean, std=std, **extra_kwargs).to(device)
        

def build_complete_sender(state):
    """Build the sender's perceptual wrapper and its agent.

    The agent class is looked up in ``agents2`` by ``state['sender_arch']``.
    The perceptual wrapper is built from the ``sender_*`` keys of *state*.
    When ``state['sender_ckpt']`` is non-empty, the agent weights are loaded
    from that file.

    Returns:
        tuple: ``(percept_wrapper, sender)``.

    Raises:
        AssertionError: if ``state['sender_percept_arch']`` is not listed as
            compatible with ``state['sender_arch']`` in ``agent_compat_lists``.
    """
    assert state['sender_percept_arch'] in agent_compat_lists[state['sender_arch']], \
        f"The perceptual architecture {state['sender_percept_arch']} is not compatible with {state['sender_arch']}"

    perception = build_perceptual_wrapper(
        state,
        arch_key='sender_percept_arch',
        ckpt_key='sender_percept_ckpt',
        base_cnn_key='sender_base_cnn',
        mean_key='sender_mean',
        std_key='sender_std',
    )

    agent_name = state['sender_arch']
    agent_cls = agents2.__dict__[agent_name]

    # Collect constructor arguments in the order the agent classes list them.
    agent_kwargs = {'input_size': state['sender_input_dim']}
    if "MultiHeadRnnSenderGS" in agent_name:
        # Multi-head variants also take a structure size, and heads is set to max_len.
        agent_kwargs['structure_size'] = state['sender_structure_dim']
        agent_kwargs['heads'] = state['max_len']
    agent_kwargs.update(
        vocab_size=state['vocab_size'],
        hidden_size=state['hidden_dim'],
        max_len=state['max_len'],
        embed_dim=state['embed_dim'],
        straight_through=state['gs_st'],
        cell=state['sender_cell'],
        trainable_temperature=state['learnable_temperature'],
    )
    sender = agent_cls(**agent_kwargs).to(state['device'])

    checkpoint_path = state['sender_ckpt']
    if checkpoint_path != '':
        logging.info("Loading sender agent checkpoint.")
        sender.load_state_dict(torch.load(checkpoint_path))

    return perception, sender
    
    
def build_complete_receiver(state):
    """Build the receiver's perceptual wrapper and its agent.

    The agent class is looked up in ``agents2`` by ``state['recv_arch']``.
    The perceptual wrapper is built from the ``recv_*`` keys of *state*.
    When ``state['recv_ckpt']`` is non-empty, the agent weights are loaded
    from that file.

    Returns:
        tuple: ``(percept_wrapper, receiver)``.

    Raises:
        AssertionError: if ``state['recv_percept_arch']`` is not listed as
            compatible with ``state['recv_arch']`` in ``agent_compat_lists``.
    """
    assert state['recv_percept_arch'] in agent_compat_lists[state['recv_arch']], \
        f"The perceptual architecture {state['recv_percept_arch']} is not compatible with {state['recv_arch']}"

    perception = build_perceptual_wrapper(
        state,
        arch_key='recv_percept_arch',
        ckpt_key='recv_percept_ckpt',
        base_cnn_key='recv_base_cnn',
        mean_key='recv_mean',
        std_key='recv_std',
    )

    agent_cls = agents2.__dict__[state['recv_arch']]
    agent_kwargs = dict(
        vocab_size=state['vocab_size'],
        num_distractors=state['distractors'],
        embed_dim=state['embed_dim'],
        hidden_size=state['hidden_dim'],
        aux_size=state['recv_input_dim'],
        cell=state['receiver_cell'],
    )
    receiver = agent_cls(**agent_kwargs).to(state['device'])

    checkpoint_path = state['recv_ckpt']
    if checkpoint_path != '':
        logging.info("Loading receiver agent checkpoint.")
        receiver.load_state_dict(torch.load(checkpoint_path))

    return perception, receiver


def _agent_ckpt_path(save_dir, prefix, epoch):
    """Return ``<save_dir>/<prefix>_e<epoch>.pt`` if it exists, else the ``.pth`` variant."""
    basename = os.path.join(save_dir, f"{prefix}_e{epoch}")
    pt_path = basename + '.pt'
    return pt_path if os.path.exists(pt_path) else basename + '.pth'


def full_system_from_row(row, log=print, test=True, from_epoch=None):
    """Rebuild the sender and receiver (perception + agent) of a saved run.

    Args:
        row: run configuration indexable by string keys. It must provide
            'save_dir', 'run_id', 'device', 'best_epoch' (when *from_epoch*
            is None) and every key read by ``build_complete_sender`` /
            ``build_complete_receiver``. Presumably a dict or a pandas Series
            row from a results table (confirm against callers). It is not
            modified.
        log: unused; kept for backward compatibility.
        test: if True, put every returned module in eval mode.
        from_epoch: epoch whose checkpoints to load; defaults to
            ``int(row['best_epoch'])``.

    Returns:
        tuple: ``(sender_wrapper, sender, recv_wrapper, receiver)``, all on
        ``row['device']``.
    """
    import copy

    # Work on a shallow copy. The checkpoint paths below are specific to this
    # run/epoch; writing them into the caller's row would leak into later
    # calls, e.g. a stale 'sender_percept_ckpt' from an earlier epoch when no
    # semiotic checkpoint exists for the requested one.
    state = copy.copy(row)

    if from_epoch is None:
        from_epoch = int(state['best_epoch'])

    save_dir = os.path.join(state['save_dir'], str(state['run_id']))
    latest_ckpt_file = get_last_semiotic_model_file(save_dir, by_epoch=from_epoch)
    if latest_ckpt_file:
        # take push model instead if it is available
        state['sender_percept_ckpt'] = os.path.join(save_dir, latest_ckpt_file)

    state['sender_ckpt'] = _agent_ckpt_path(save_dir, 'sender', from_epoch)
    state['recv_ckpt'] = _agent_ckpt_path(save_dir, 'receiver', from_epoch)

    sender_wrapper, sender = build_complete_sender(state)
    recv_wrapper, receiver = build_complete_receiver(state)

    for module in (sender_wrapper.model, sender_wrapper.model_multi,
                   recv_wrapper.model, recv_wrapper.model_multi,
                   sender, receiver):
        # Module.to moves parameters in place, so its return value is not needed.
        module.to(state['device'])
        if test:
            module.eval()

    return sender_wrapper, sender, recv_wrapper, receiver
