import torch
import torch.nn as nn
import random
import numpy as np
import argparse
import os
import re
import json
import shutil
import logging
import sys

from tqdm import tqdm
from torchvision import transforms
from torch.nn import functional as F
import torchvision.datasets as datasets

from ProtoPNet.settings import img_size, base_architecture, \
                               prototype_activation_function, add_on_layers_type, num_data_workers
from ProtoPNet.preprocess import preprocess_input_function
import ProtoPNet.prune as prune
import ProtoPNet.push as push
from ProtoPNet import model
from ConvNet import test_and_train as ct

import util
from util import *
from dataloader_dali import *
import agents2
from agents2 import *
from games import SignalGameGS
import losses
from losses import loss_nll, loss_xent, least_effort
import train
from community import train_and_test as tnt
from model_builder import build_complete_sender, build_complete_receiver
        


# Command-line interface: the run is driven by a JSON config file, and --gpuid
# selects which CUDA devices are exposed to this process.
parser = argparse.ArgumentParser()
parser.add_argument('--config', required=True, help='path to configuration')
# nargs='+' makes argparse collect the values into a list, so the default is a
# list as well (it used to be the bare string "0").
parser.add_argument('--gpuid', nargs='+', type=str, default=["0"],
                    help='GPU id(s) to expose, e.g. "--gpuid 0 1" or "--gpuid 0,1"')
args = vars(parser.parse_args())

# Export every requested id. Previously only the first list element was kept,
# so "--gpuid 0 1" silently dropped GPU 1. A single comma-separated value such
# as "0,1" is passed through unchanged.
# The variable is set before the first CUDA query below.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(args['gpuid'])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device={device}:{os.environ['CUDA_VISIBLE_DEVICES']}")

# Load the run configuration and record the selected device in it.
state = util.load_json(args['config'])
state['device'] = device


# Resolve the run identifier: a non-empty "run_id" from the config is used (as
# a string); an empty or missing one falls back to util.get_time_stamp().
if "run_id" not in state:
    run_id = util.get_time_stamp()
else:
    configured_run_id = state["run_id"]
    if configured_run_id != "":
        run_id = str(configured_run_id)
    else:
        run_id = str(util.get_time_stamp())

# All outputs of this run go to <save_dir>/<run_id>; create it if needed.
run_dir = os.path.join(state['save_dir'], run_id)
state['save_dir'] = run_dir
if not os.path.exists(run_dir):
    os.makedirs(run_dir)
        
# ============ start logging
# Messages go both to <save_dir>/train_full.log and to the console stream.
log_file_path = os.path.join(state['save_dir'], 'train_full.log')
log_handlers = [
    logging.FileHandler(log_file_path),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%m/%d/%Y %I:%M:%S %p',
    handlers=log_handlers,
)
log = logging.getLogger('train_full')

# Derive the perceptual approach from the sender architecture name. Markers are
# tried in this order and the first substring match wins; "cnn" is the fallback.
approach = "cnn"
for arch_marker, approach_name in (("Proto", "proto"), ("Cw", "cw"), ("CnnB", "cnnb")):
    if arch_marker in state['sender_percept_arch']:
        approach = approach_name
        break

try:
    # Seed the random number generators from the configured seed. Python's
    # `random` module is imported by this file but used to be left unseeded,
    # unlike numpy and torch.
    random.seed(state['seed'])
    np.random.seed(state['seed'])
    torch.manual_seed(state['seed'])
    torch.backends.cudnn.deterministic = True

    # Keep a copy of the configuration next to the run outputs.
    config_basename = os.path.basename(args['config'])
    shutil.copyfile(args['config'], os.path.join(state['save_dir'], config_basename))

    # Per-epoch history objects are accumulated in memory and pickled to
    # history_path; a fresh run starts at epoch 0 with an empty history.
    history_path = os.path.join(state['save_dir'], 'history.pkl')
    epoch_histories = []
    start_epoch = 0

    log.info(f"Sender percept mean: {state['sender_mean']}, std: {state['sender_std']}")
    log.info(f"Receiver percept mean: {state['recv_mean']}, std: {state['recv_std']}")
    
    ## ========================= manage restarts, invariant of perceptual wrapper
    # Resume only when the config asks for a restart AND the run directory
    # already holds checkpoints for both agents.
    restart_requested = state['restart']
    checkpoints_present = (util.any_matching_prefix(state['save_dir'], 'sender_e')
                           and util.any_matching_prefix(state['save_dir'], 'receiver_e'))
    if restart_requested and checkpoints_present:
        state['sender_ckpt'] = util.get_latest_checkpoint(state['save_dir'], 'sender_e')
        state['recv_ckpt'] = util.get_latest_checkpoint(state['save_dir'], 'receiver_e')
        # "\t" matches the "Advance to epoch" line below; the previous "\F" was
        # an invalid escape sequence that logged a literal backslash.
        log.info(f"\tFound sender at {state['sender_ckpt']}")
        log.info(f"\tFound receiver at {state['recv_ckpt']}")

        # The resume epoch is the first run of digits in the sender checkpoint
        # file name (checkpoints are written as sender_e<epoch>.pth).
        sender_ckpt_name = os.path.basename(state['sender_ckpt'])
        start_epoch = int(re.search(r'\d+', sender_ckpt_name).group(0))
        log.info(f"\tAdvance to epoch {start_epoch}")

        # update signs model to the latest available epoch (only sender is trained)
        latest_ckpt_file = get_last_semiotic_model_file(state['save_dir'], by_epoch=start_epoch)
        if latest_ckpt_file:
            # Fixed: this used to reference the undefined name `latest_ckpt`,
            # which raised NameError whenever a semiotic model file was found.
            state['sender_percept_ckpt'] = os.path.join(state['save_dir'], latest_ckpt_file)

        # Continue the history written by the interrupted run.
        epoch_histories = pickle_load(history_path)
    
    num_distractors = state['distractors']

    # Initialize models
    # ==================================
    # Each builder returns a (percept wrapper, agent) pair.
    sender_percept, _sender = build_complete_sender(state)
    recv_percept, _receiver = build_complete_receiver(state)

    # disable grad and only enable later; everything also starts in eval mode
    frozen_modules = [sender_percept, _sender, recv_percept, _receiver]
    for frozen_module in frozen_modules:
        disable_parameter_requires_grad(frozen_module)

    set_eval(frozen_modules)
    
    # Log one line per agent: base CNN name -> percept wrapper class -> agent class.
    log.info('\t== system summary ==')
    sender_stages = (
        state['sender_base_cnn'],
        sender_percept.__class__.__name__,
        _sender.__class__.__name__,
    )
    log.info('sender: ' + ' -> '.join(map(str, sender_stages)))
    receiver_stages = (
        state['recv_base_cnn'],
        recv_percept.__class__.__name__,
        _receiver.__class__.__name__,
    )
    log.info('receiver: ' + ' -> '.join(map(str, receiver_stages)))

    # go through prototype module to acquire grads (semiotic) or pass through otherwise (static proto and cached)
    # sometimes we have concept logits to deal with (semiotic case)
    # The receiver always receives cached features, e.g., pass_through
    def pass_through(reprs):
        """Return ``reprs`` untouched, packed as ``(None, (reprs, None), None)``.

        The layout mirrors ``sender_process``: logits, (feats, structures),
        min_distances, with every slot except the features left as None.
        """
        logits, structures, min_distances = None, None, None
        return logits, (reprs, structures), min_distances

    # semiotic input: images
    def sender_process(images):
        """Run ``images`` through the sender percept wrapper and return its output."""
        # for proto: logits, (feats, structures), min_distances
        # otherwise: logits, (feats, structures), None
        percept_output = sender_percept(images)
        return percept_output

    # During training, guaranteed to always get cached feats from loader, because receiver is never updated
    def receiver_process(reprs):
        """Receiver-side interpretation: delegates to ``pass_through``."""
        return pass_through(reprs)
    
    # dataloaders
    # ==================================
    # receiver and sender models are separate
    # normalization happens inside percept wrappers, so the game loaders are
    # built with mean/std passed as None
    no_normalization = {'mean': None, 'std': None}

    log.debug("Initiate training loader")
    semiotic_train_loader = separated_cached_distractor_train_loader(
        state,
        sender_percept.prelinguistic,
        recv_percept.prelinguistic,
        img_size,
        **no_normalization,
    )
    log.debug("Initiate test loader")
    semiotic_test_loader = separated_cached_distractor_test_loader(
        state,
        sender_percept.prelinguistic,
        recv_percept.prelinguistic,
        img_size,
        **no_normalization,
    )

    # loader for true conv projection operation
    # NOTE(review): not referenced again in the visible part of this script.
    train_push_loader = push_train_loader(state, img_size, seed=state['seed'])

    # loader for last layer convex optimization (of sender); these use the
    # sender's mean/std from the config
    sender_normalization = {'mean': state['sender_mean'], 'std': state['sender_std']}
    ll_train_loader = normalized_train_loader(state, img_size, state['seed'],
                                              **sender_normalization)
    ll_test_loader = normalized_test_loader(state, img_size, state['seed'],
                                            **sender_normalization)

    train_set_size = semiotic_train_loader.dataset_size
    log.info(f'training set size: {train_set_size}')
    test_set_size = semiotic_test_loader.dataset_size
    log.info(f'test set size: {test_set_size}')
    log.info(f"batch size: {state['train_batch_size']}/{state['test_batch_size']}")
    log.info(f"push batch size: {state['train_push_batch_size']}")


    # Every auxiliary loss named in the config needs a matching weight. This is
    # input validation, so raise explicitly instead of using `assert`, which is
    # skipped when Python runs with -O.
    if len(state['aux_losses']) != len(state['aux_weights']):
        raise ValueError("Each aux loss should have a weight!")
    aux_losses, system_losses = losses.unpack_losses(state['aux_losses'], state['aux_weights'])


    # Game define
    # ==================================
    # The signalling game couples sender and receiver under the cross-entropy
    # task loss (loss_xent) plus the configured length cost and extra losses.
    _game = SignalGameGS(_sender, _receiver, loss_xent,
                         length_cost=state['length_cost'],
                         aux_losses=aux_losses,
                         sys_losses=system_losses)


    # housekeeping
    # ==================================
    # Always remove: a canary left over from an earlier run must not survive
    # into this one (it is rewritten only after the training loop finishes).
    canary_path = os.path.join(state['save_dir'], "canary.txt")
    if os.path.exists(canary_path):
        os.remove(canary_path)

    epoch_log = EpochHistory(start_epoch)
    recv_percept.model_multi.eval()

    # Check Epoch 0 metrics using cached representations
    metrics = check_model(
        sender_percept, _game,
        pass_through, receiver_process,
        semiotic_test_loader, device,
        return_metrics=True,
    )
    epoch_log.log_accuracy(metrics['accuracy'])

    # Checkpoint both agents at the starting epoch (sender first, then receiver).
    for agent_name, agent in (('sender', _game.sender), ('receiver', _game.receiver)):
        checkpoint_path = os.path.join(state['save_dir'], f'{agent_name}_e{start_epoch}.pth')
        torch.save(agent.state_dict(), checkpoint_path)
    log.info(f"Checkpointed at {state['save_dir']}")

    # baseline sender classifier accuracy
    if approach != 'proto':
        accu = ct.test(state, sender_percept.model_multi, ll_test_loader, log.info)
        epoch_log.log_concept_accuracies({'push': accu})
    else:
        # For the prototype approach the test returns a pack: element 2 is
        # logged as the accuracy and the whole pack as the concept costs.
        pack = tnt.test(model=sender_percept.model_multi, dataloader=ll_test_loader,
                        class_specific=True, log=log.info)
        epoch_log.log_concept_accuracies({'push': pack[2]})
        epoch_log.log_concept_costs({'push': pack})

    epoch_histories.append(epoch_log)

    # Make sure the <save_dir>/sign-img directory exists.
    img_dir = os.path.join(state['save_dir'], 'sign-img')
    if not os.path.exists(img_dir):
        os.makedirs(img_dir)


    # optimizers define
    # ==================================
    # simulate agents condition so that requires_grad reflects what a static
    # epoch trains; only those parameters are handed to the optimizer
    util.agents_only(sender_percept, _game.sender, _game.receiver, log=log.debug)

    sender_params_to_update = [
        param for _, param in _game.sender.named_parameters() if param.requires_grad
    ]
    recv_params_to_update = [
        param for _, param in _game.receiver.named_parameters() if param.requires_grad
    ]

    # One Adam optimizer with a separate learning rate per agent.
    static_param_groups = [
        {'params': sender_params_to_update, 'lr': state['sender_lr']},
        {'params': recv_params_to_update, 'lr': state['receiver_lr']},
    ]
    static_optimizer = torch.optim.Adam(static_param_groups)

    # Build the optimizers used on semiotic epochs. The epoch loop selects
    # `semiotic_optimizer` on SGD epochs and on push epochs alike, so it has to
    # exist when either list is non-empty. Checking only the SGD list left the
    # name undefined (NameError) for configurations that list push epochs only.
    if len(state['semiotic_sgd_epochs']) or len(state.get('semiotic_push_epochs', [])):
        # The loop variable is deliberately not called `model`, which would
        # rebind the `model` module imported from ProtoPNet.
        for frozen_module in [sender_percept, _game.sender, _game.receiver, _receiver]:
            disable_parameter_requires_grad(frozen_module)

        # simulate semiosis condition so requires_grad reflects what a joint
        # semiotic epoch trains
        util.semiosis_joint(sender_percept, _game.sender, _game.receiver, log=log.debug)

        sender_params_to_update = [
            param for _, param in _game.sender.named_parameters() if param.requires_grad
        ]
        recv_params_to_update = [
            param for _, param in _game.receiver.named_parameters() if param.requires_grad
        ]

        # Agents keep their own learning rates; the percept parameter groups
        # added below depend on the approach.
        semiotic_optimizer_specs = [
            {'params': sender_params_to_update,
             'lr': state['sender_lr']},
            {'params': recv_params_to_update,
             'lr': state['receiver_lr']},
        ]
        if approach == 'proto':
            semiotic_optimizer_specs.extend([
                {'params': sender_percept.model.features.parameters(),
                 'lr': state['features_lr'],
                 'weight_decay': 1e-3},
                {'params': sender_percept.model.add_on_layers.parameters(),
                 'lr': state['add_on_layers_lr'],
                 'weight_decay': 1e-3},
                {'params': sender_percept.model.prototype_vectors,
                 'lr': state['prototype_vectors_lr']},
            ])
            classifier_optimizer_specs = [
                {'params': sender_percept.model.last_layer.parameters(),
                 'lr': state['last_layer_lr']},
            ]
        else:
            semiotic_optimizer_specs.append(
                {'params': sender_percept.model.base_model.parameters(),
                 'lr': state['features_lr'],
                 'weight_decay': 1e-3},
            )
            classifier_optimizer_specs = [
                {'params': sender_percept.model.classifier.parameters(),
                 'lr': state['last_layer_lr']},
            ]

        # NOTE(review): classifier_optimizer is built here but is not referenced
        # again in the visible part of this script - confirm whether push
        # epochs are meant to step it.
        classifier_optimizer = torch.optim.Adam(classifier_optimizer_specs)
        semiotic_optimizer = torch.optim.Adam(semiotic_optimizer_specs)
        # joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(semiotic_optimizer, step_size=10, gamma=0.1)
        # joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(semiotic_optimizer, step_size=10, gamma=0.1)

    # Learn social task T
    # ==================================
    # Epochs run from start_epoch+1 up to and including state['epochs']
    # (np.arange excludes its stop value, hence the +1).
    for epoch in np.arange(start_epoch+1, state['epochs']+1, 1):
        # decide the loader/optim to use based on this epoch's task
        # An epoch is semiotic only once semiosis_start has been reached AND
        # the epoch is listed in the corresponding config list.
        semiotic_sgd = epoch >= state['semiosis_start'] and epoch in state['semiotic_sgd_epochs']
        semiotic_push = epoch >= state['semiosis_start'] and epoch in state['semiotic_push_epochs']
        # If we do push, assume semiotic sgd scenario
        semiotic_sgd = (semiotic_sgd or semiotic_push)
        
        train_loader = semiotic_train_loader
        test_loader = semiotic_test_loader

        # semiotic_sgd already includes the push case after the line above, so
        # the `or semiotic_push` here does not change the outcome.
        if semiotic_sgd or semiotic_push:
            _optim = semiotic_optimizer

            # The sender sees images through its percept wrapper; the receiver
            # still consumes the representations handed over by the loader.
            sender_interpret = sender_process
            receiver_interpret = receiver_process  # always pass_through
            epoch_mode = 'semiotic'
        else:
            # Static epoch: both agents pass the loader's representations through.
            _optim = static_optimizer
            sender_interpret = pass_through
            receiver_interpret = pass_through
            epoch_mode = 'static'

        # disable grad and only enable next
        for model in [sender_percept, _sender, recv_percept, _receiver]:
            disable_parameter_requires_grad(model)

        set_eval([sender_percept, recv_percept])
    
        # init loader based on task (for caching)
        train_loader.start_epoch(epoch_mode)
        test_loader.start_epoch(epoch_mode)

        # enable/disable grad
        # Three cases: push epoch (only the sender percept is put in train mode,
        # grads set by util.semiosis_classifier), joint semiotic epoch (percept
        # and both agents), or static epoch (agents only).
        if semiotic_sgd:
            if semiotic_push:
                set_train([sender_percept])
                util.semiosis_classifier(sender_percept, _game.sender, _game.receiver, log=log.debug)
            else:
                set_train([sender_percept, _game.sender, _game.receiver])
                util.semiosis_joint(sender_percept, _game.sender, _game.receiver, log=log.debug)
        else:
            set_train([_game.sender, _game.receiver])
            util.agents_only(sender_percept, _game.sender, _game.receiver, log=log.debug)

        log_whole_system_params(sender_percept, recv_percept, _game, log=log.debug)  
        
        # train.one_epoch returns the history object for this epoch; the
        # metrics below are recorded on it.
        epoch_log = train.one_epoch(state, epoch, device, train_loader, 
                                    sender_interpret, receiver_interpret, 
                                    sender_percept, _game, _optim, approach, log)
        # This is the only value push_str takes in this loop; it is used as the
        # key for the logged concept metrics and as the saved model's name suffix.
        push_str = 'nopush'


        # log agents task success
        metrics = check_model(sender_percept, _game, 
                              sender_interpret, receiver_interpret, 
                              test_loader, device, return_metrics=True)
        epoch_log.log_accuracy(metrics['accuracy'])
        log.info(f"Receiver test accuracy @ Epoch {epoch}:\t{metrics['accuracy']:.2f}")

        # log signs model accuracies (semiotic prototype model only)
        if semiotic_sgd or semiotic_push:
            # joint_lr_scheduler.step()

            if approach == 'proto':
                # pack[2] is logged as the accuracy, the whole pack as costs.
                pack = tnt.test(model=sender_percept.model_multi, dataloader=ll_test_loader,
                                class_specific=True, log=log.info)
                epoch_log.log_concept_accuracies({push_str: pack[2]})
                epoch_log.log_concept_costs({push_str: pack})
                log.info(f"{push_str} sign concept test accuracy @ Epoch {epoch}:\t{pack[2]:.2f}") 
            else:
                accu = ct.test(state, sender_percept.model_multi, ll_test_loader, log.info)
                epoch_log.log_concept_accuracies({push_str: accu})  

            # The sender percept model is saved after every semiotic epoch,
            # independent of checkpoint_interval.
            util.save_enc_model(model=sender_percept.model, 
                                model_dir=state['save_dir'], 
                                model_name=str(epoch) + push_str, log=log.debug)

        train_loader.end_epoch(epoch_mode)
        test_loader.end_epoch(epoch_mode)

        epoch_histories.append(epoch_log)

        # Agent checkpoints are written only every checkpoint_interval epochs.
        if epoch % state['checkpoint_interval'] == 0:
            torch.save(_game.sender.state_dict(), os.path.join(state['save_dir'], f'sender_e{epoch}.pth'))
            torch.save(_game.receiver.state_dict(), os.path.join(state['save_dir'], f'receiver_e{epoch}.pth'))
            log.info(f"Checkpointed at {state['save_dir']}")

        # The full history list is re-pickled after every epoch.
        pickle_write(history_path, epoch_histories)

        del epoch_log  # failsafe


    # Reached only when everything above ran without raising: the canary file
    # marks a run that completed.
    with open(canary_path, 'w') as cf:
        cf.write("We made it.\n")
        
except Exception as e:
    log.exception(e)
