import os
import sys
import time
import argparse
import random
from random import choices

import numpy as np

import torch
from torch.autograd import Variable
import torch.nn as nn

from torch.nn.modules.loss import MSELoss, CrossEntropyLoss
#from emb2emb.losses import FlipLoss
#from emb2emb.classifier import train_binary_classifier


# Extend the module search path with the parent directory and the autoencoder
# sources. Relative entries are resolved against the current working
# directory, not this file's location.
# NOTE(review): presumably required so the ``emb2emb`` imports below resolve,
# which is why these appends precede them — confirm before reordering.
sys.path.append('../')
sys.path.append('../autoencoders/')

# Passed as "default_config" to AEEncoder/AEDecoder in get_encoder/get_decoder.
DEFAULT_CONFIG = "../autoencoders/config/default.json"

from emb2emb.classifier import train_binary_classifier
from emb2emb.losses import CosineLoss, FlipLoss, BacktranslationLoss, CombinedBaseLoss, SumBaseLoss,\
    AlignmentLoss, MeanSimilarityLoss, HausdorffLoss, LocalBagLoss
from emb2emb.architectures import MLP, HighwayNetwork, OffsetVectorMLP, FixOffsetVectorMLP, MeanOffsetVectorMLP, ResNet, BovToBovMapping, BovIdentity, BovOracle, SimpleBovMapping
from emb2emb.encoders import RAEDecoder, RAEEncoder, AEEncoder, AEDecoder, tokenize
from emb2emb.utils import glove_dict, get_data, pretty_print_prediction,\
    Namespace
from emb2emb.trainer import Emb2EmbTrainer, MODE_EMB2EMB, MODE_FINETUNEDECODER, MODE_SEQ2SEQ, MODE_SEQ2SEQFREEZE
from emb2emb.analyze_l0drop import compute_l0drop_statistics, plot_num_words, compute_neighborhood_preservation
from emb2emb.hausdorff import get_local_hausdorff_similarities_function, get_local_classifier_loss, get_weighted_localbagloss_function,\
    get_local_regression_loss
from emb2emb.classifier import BoVBinaryClassifier, binary_clf_predict
from emb2emb.gmm import get_local_gmm_divergence, gmm_jsd, gmm_kl, gmm_symkl


def _parse_seed(value):
    """Parse the ``--seed`` argument.

    The literal string ``None`` (case-insensitive) maps to ``None``, which
    ``train`` interprets as "derive the seed from the current time". Any other
    value must be an integer.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is neither ``None`` nor an int.
    """
    if value.strip().lower() == "none":
        return None
    try:
        return int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid seed {value!r}: expected an integer or 'None'") from None


def get_train_parser():
    """Build the command-line argument parser for emb2emb training.

    Returns:
        argparse.ArgumentParser: parser holding every training, model, loss,
        adversarial-regularization, reproducibility and data option.
    """
    parser = argparse.ArgumentParser(description='Emb2Emb')
    # paths
    parser.add_argument("--dataset_path", type=str,
                        default='../data/nli/', help="Path to dataset")
    parser.add_argument("--outputdir", type=str,
                        default='savedir/', help="Output directory")
    parser.add_argument("--outputmodelname", type=str, default='model.pickle')
    parser.add_argument("--glove_path", type=str,
                        default="../resource/smallglove.840B.300d.txt", help="Path to glove embeddings")
    parser.add_argument("--modeldir", type=str,
                        default="../autoencoders/lstmvae_nli.pt", help="Path to autoencoder dir")
    parser.add_argument("--save_dev_set_predictions", action="store_true")
    parser.add_argument("--model_name", type=str,
                        default="model.pt", help="Name of the model file.")
    parser.add_argument("--vocab_path", type=str,
                        default="", help="Path to vocabulary.")
    parser.add_argument("--real_data_path", type=str, default="input",
                        help="If 'input' is specified, we use the target sequence embeddings for adversarial regularization. Otherwise randomly sample from the data file given at the path.")
    parser.add_argument("--binary_classifier_path", type=str, default=None,
                        help="Path to the BERT SequenceClassification model and it's tokenizer.")
    parser.add_argument("--output_file", type=str, default='output.csv',
                        help="Output file for csv to store results.")
    parser.add_argument("--load_emb2emb_path", type=str, default=None,
                        help="Path to already trained emb2emb model.")
    parser.add_argument("--no_cleanup", action="store_true",
                        help="If set, the moel under 'load_emb2emb_path' is not deleted after training finishes")
    parser.add_argument("--pretrained_model", type=str, default=None,
                        help="Path to a pretrained model to initialize with. Not going to be used if load_emb2emb_path is set.")
    parser.add_argument("--dont_predict_done", action="store_true")
    parser.add_argument("--use_end_of_sequence_vector", action="store_true")
    parser.add_argument("--select_input_length", action="store_true")
    parser.add_argument("--end_of_sequence_epsilon",
                        type=float, default=0.00001)
    parser.add_argument("--learned_positional_embeddings", action="store_true")
    parser.add_argument("--train_classifier_only", action="store_true")
    parser.add_argument("--project_input_dimension", type=int, default=128)
    parser.add_argument("--discriminate_moments", action="store_true")

    # training
    parser.add_argument("--validate", action="store_true")
    parser.add_argument("--validation_frequency", type=int, default=-1)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--lr_bclf", type=float, default=0.0001)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--n_epochs", type=int, default=20)
    parser.add_argument("--n_epochs_binary", type=int, default=5)
    parser.add_argument("--load_binary_clf", action="store_true")
    parser.add_argument("--sentence_wise_evaluation", action="store_true")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--mode", type=str, default=MODE_EMB2EMB, help="The training mode to use.",
                        choices=[MODE_EMB2EMB, MODE_FINETUNEDECODER, MODE_SEQ2SEQ, MODE_SEQ2SEQFREEZE])

    # model
    parser.add_argument("--emb2emb", type=str, default='mlp', help="emb2emb architecture to use",
                        choices=["resnet", "meanoffsetvector", "fixoffsetnet", "mlp", "identity", "highway", "offsetnet", "bovtobov", "bovidentity", "bovoracle", "simplebov"])
    parser.add_argument("--heads", type=int, default=1,
                        help="Number of heads in bovtobov mapping")
    parser.add_argument("--backprop_through_outputs", action="store_true")
    parser.add_argument("--teacher_forcing", type=float, default=0.0)
    parser.add_argument("--bov_output_layer", action="store_true")
    parser.add_argument("--max_length", type=int, default=30,
                        help="Maximum number of vectors to generate in BovToBovMapping.")
    parser.add_argument("--max_input_length", type=int, default=999,
                        help="Maximum number of vectors the encoder generates.")
    parser.add_argument("--remove_sos_and_eos", action="store_true",
                        help="Remove SOS and EOS symbols from input sequences for debugging purposes.")
    parser.add_argument("--similarity_function", type=str, default='euclidean', help="Similairy function used in hausdorff loss.",
                        choices=["euclidean", "cosine"])
    parser.add_argument("--offset", action="store_true")
    parser.add_argument("--point_gen", action="store_true")
    parser.add_argument("--point_gen_offset", action="store_true")
    parser.add_argument("--point_gen_out_to_in", action="store_true")
    parser.add_argument("--point_gen_mask_first_vector", action="store_true")
    parser.add_argument("--point_gen_context_vector", action="store_true")
    parser.add_argument(
        "--point_gen_offset_no_copy_dependence", action="store_true")
    parser.add_argument("--point_gen_coverage", action="store_true")
    parser.add_argument("--emb2emb_noise", type=float, default=0.,
                        help="Amount of noise to add to input embeddings when training emb2emb model.")
    parser.add_argument("--dropout_p", type=float, default=0.,
                        help="Amount of dropout to have in the emb2emb model.")
    parser.add_argument("--dropout_binary", type=float, default=0.,
                        help="Amount of dropout in binary classifier.")
    parser.add_argument("--vector_distortion_rate", type=float, default=0.,
                        help="Amount of distortion added when training the binary classifier.")
    parser.add_argument("--vector_distortion_probability", type=float, default=0.,
                        help="Probability that any given batch is distorted when training the binary classifier.")
    parser.add_argument("--gaussian_noise_binary", type=float, default=0.,
                        help="Amount of gaussian noise in binary classifier.")
    parser.add_argument("--offset_dropout_p", type=float, default=0.,
                        help="Amount of dropout to have in the offset vectors of OffsetNetworks.")
    parser.add_argument("--meanoffsetvector_factor", type=float,
                        default=2., help="Initialization for MeanOffsetVector factor.")
    parser.add_argument("--residual_connections", action="store_true")
    parser.add_argument("--skip_connections", action="store_true")
    parser.add_argument("--outlayers", action="store_true",
                        help="Add an outlayers to computation of offset vector in offsetnet.")
    parser.add_argument("--activate_result", action="store_true",
                        help="Add a nonlinerity after adding the non-linearity in offsetnet")
    parser.add_argument("--loss", type=str, default='cosine',
                        help="loss", choices=["mse", "cosine", "ce", "fliploss", "hausdorff", "localbagloss"])
    parser.add_argument("--baseloss", type=str, default='cosine', help="loss",
                        choices=["mse", "cosine", "backtranslation", "combined", "combinedsum", "alignmentloss", "meansim"])
    parser.add_argument("--al_differentiable", action="store_true")
    parser.add_argument("--al_detach", action="store_true")
    parser.add_argument("--al_softmax_temp", type=float, default=1.0)
    parser.add_argument("--denoise", type=float, default=0.0)
    parser.add_argument("--al_alpha", type=float, default=0.5)
    parser.add_argument("--al_force_expected_gate_value",
                        default=-1.0, type=float)
    parser.add_argument("--al_magnitude_weighting", type=str, default=None, choices=["norm", "gates"],
                        help="If set, use the magnitude of a vector to weight its influence on the hausdorff similarity (either gate value or magnitude value).")
    parser.add_argument("--al_weighting", type=str, default='uniform', help="How to weight the steps in hausdorff loss.",
                        choices=["uniform", "window", "out_lens", "uniform_till_input", "sumtoone"])
    parser.add_argument("--al_input_center_factor", type=float, default=1.0)
    parser.add_argument("--al_bag_loss", nargs="+", type=str, default=['hausdorff'], help="Loss function(s) to use for local bag loss.",
                        choices=["hausdorff", "classifier", "adversarial", "length", "gmmkl", "gmmsymkl", "gmmjsd"])
    parser.add_argument("--al_classifier_loss_target", type=float, default=0.0)
    parser.add_argument("--al_classifier_loss_freebits", default=None, type=float, help="Set the free bits that are used. Set closer to 1.0 to cap the loss that the model receives from the sentiment classification stronger.")
    parser.add_argument("--al_bag_loss_weights", nargs="+", type=float,
                        default=[1.0], help="How to weight the loss functions.")
    parser.add_argument("--al_weighting_center", type=str, default='input', help="Where to put the center when using 'window' weighting.",
                        choices=["input", "optimum"])
    parser.add_argument("--al_windowsize", type=int, default=3,
                        help="Size of the window when using 'window' weighting.")
    parser.add_argument("--gmm_approximation", type=str, default="variational_lower",
                        help="Type of approximation to GMM KL divergence.")
    parser.add_argument("--gmm_sigma", type=float, default=1.0,
                        help="Fixed sigma used for KL divergence.")
    parser.add_argument("--gmm_weighting", type=str, default="uniform", choices=["uniform", "model", "magnitude"],
                        help="How to weight the components of the gmm. 'uniform': uniform weighting over all components. \
                        'model': First vector dimension determines softmax-logits of the respective component. \
                        'magnitude': Sum-normalization over the vectors magnitudes is applied.")
    parser.add_argument("--force_target_length", action="store_true")
    parser.add_argument("--force_eval_mode", action="store_true")
    parser.add_argument(
        "--al_directions", default=["input", "output"], nargs="+", choices=["input", "output"])
    parser.add_argument("--bt_baseloss", type=str,
                        default='cosine', help="loss", choices=["mse", "cosine"])
    parser.add_argument("--increase_clf_lambda_until", type=int, default=0,
                        help="Number of epochs over which to increase clf lambda to full value.")
    parser.add_argument("--lambda_clfloss", type=float, default=0.5,
                        help="Weight of the clf loss in comparison to the baseloss. Specify between 0 and 1.")
    parser.add_argument("--highway_bias", type=float, default=-1.)
    parser.add_argument("--n_layers", type=int, default=1,
                        help="Number of layers to use in the Emb2Emb model.")
    parser.add_argument("--hidden_layer_size", type=int, default=1024,
                        help="Hidden layer size to use in the Emb2Emb model.")
    parser.add_argument("--autoencoder", type=str, default="FromFile",
                        help="Specify the autoencoder to use.", choices=["FromFile", "RAE"])
    parser.add_argument("--estimate_distribution", action="store_true",
                        help="Return the estimated distribution (mu,v) instead of the noisy sample.")
    parser.add_argument("--fast_gradient_iterative_modification", action="store_true",
                        help="Follow the gradient of the binary classifier to change the label.")
    parser.add_argument("--fgim_decay", type=float, default=1.0)
    parser.add_argument("--fgim_threshold", type=float, default=0.001)
    parser.add_argument("--fgim_customloss", action="store_true")
    parser.add_argument("--fgim_start_at_y", action="store_true")
    parser.add_argument("--beam_width", type=int, default=1,
                        help="Beam width to use in AEDecoder. If 1, greedy decoding is used.")

    parser.add_argument("--binary_dense_layer_size", type=int, default=16,
                        help="Size of the final dense layer in binary classifier.")

    # adversarial reg for emb2emb
    parser.add_argument("--adversarial_regularization", action="store_true",
                        help="Perform adversarial regularization while training emb2emb.")

    parser.add_argument("--adversarial_remove_last", action="store_true",
                        help="Remove the end-of-sequence token before plugging things into the discriminator.")
    parser.add_argument("--tokenwise", action="store_true",
                        help="Perform adversarial regularization token-wise.")
    parser.add_argument("--critic_lr", type=float,
                        default=0.00001, help="LR for training the critic.")
    parser.add_argument("--critic_rounds", type=int, default=100,
                        help="How many steps to go through each phase.")
    parser.add_argument("--adversarial_rounds", type=int, default=100,
                        help="How many steps to go through each phase.")
    parser.add_argument("--task_rounds", type=int, default=100,
                        help="How many steps to go through each phase.")
    parser.add_argument("--joint", action="store_true",
                        help="In each iteration, train reconstruction, critic, and adversarial.")
    parser.add_argument("--joint_rec_adv", action="store_true",
                        help="When training adversarially, do it jointly with the reconstruction task.")
    parser.add_argument("--critic_hidden_layers", type=int,
                        default=1, help="Number of hidden layers the critic has.")
    parser.add_argument("--n_layers_binary", type=int,
                        default=1, help="Number of hidden layers the binary classifier has.")
    parser.add_argument("--hidden_size_binary", type=int,
                        default=300, help="Hidden size of the binary classifier.")
    parser.add_argument("--n_heads_binary", type=int,
                        default=4, help="Number of heads used in the BoV binary classifier.")
    parser.add_argument("--critic_hidden_units", type=int,
                        default=300, help="Number of hidden units the critic has.")
    parser.add_argument("--adversarial_reconstruction_weight", type=float, default=1.0,
                        help="Weight of adversarial loss. Decrease to reduce the adversarial loss term's influence.")
    parser.add_argument("--lambda_schedule", type=str,
                        default="fixed", choices=["fixed", "annealing", "dynamic"])
    parser.add_argument("--adversarial_delay", type=float, default=0.,
                        help="number of epochs to wait until starting to increase weight")
    parser.add_argument("--dynamic_lambda_epsilon", type=float, default=0.03,
                        help="The maximal tolarable deviation from total confusion (log(0.5)")
    parser.add_argument("--dynamic_lambda_stepsize", type=float,
                        default=0.01, help="The amount by which to increase lambda.")
    parser.add_argument("--log_adversarial_lambda", action="store_true",
                        help="When this flag is set, dynamic lambda scheduling interprets lambdas (except initial one) as exponents of 2.")
    parser.add_argument("--dynamic_lambda_frequency", type=int, default=100,
                        help="The frequency (# of batches) whith which to check whether lambda need to be changed.")
    parser.add_argument("--dynamic_lambda_target", type=float, default=0.5,
                        help="The target accuracy we want the generator to achieve in fooling the discriminator.")
    parser.add_argument("--unaligned", action="store_true",
                        help="If set, input and desired output to the basemodel are the same.")
    # reproducibility
    # _parse_seed (rather than plain int) makes the documented 'None' value
    # actually reachable from the command line.
    parser.add_argument("--seed", type=_parse_seed, default=1337,
                        help="Seed. If specified as 'None', the seed is generated automatically using the current time.")

    # data
    parser.add_argument("--embedding_dim", type=int,
                        default=1024, help="sentence embedding dimension")
    parser.add_argument("--data_fraction", type=float,
                        default=1., help="How much of the data to use.")
    parser.add_argument("--test_data_fraction", type=float,
                        default=1., help="How much of the data to use.")
    parser.add_argument("--print_outputs", action="store_true",
                        help="Print some of the outputs at validation time for inspection.")
    parser.add_argument("--max_prints", type=int, default=5,
                        help="How many examples to print during validation time.")
    parser.add_argument("--log_freq", type=int, default=100,
                        help="How often to print the logs.")
    parser.add_argument("--eval_self_bleu", action="store_true",
                        help="Whether to compute self-bleu scores on wikipedia.")
    parser.add_argument("--invert_style", action="store_true",
                        help="Whether to invert the style transfer task (yelp).")
    parser.add_argument('--track_input_output_distance', action="store_true")
    parser.add_argument('--compute_l0drop_statistics', action="store_true")
    parser.add_argument(
        '--compute_neighborhood_preservation', action="store_true")
    parser.add_argument('--plot_num_words', action="store_true")
    return parser


def get_params():
    """Parse the process command line into training parameters.

    Arguments the parser does not recognise are not silently ignored: they
    cause a ``ValueError`` listing them.

    Returns:
        argparse.Namespace: the parsed parameters.

    Raises:
        ValueError: if any command-line argument is unrecognised.
    """
    known, leftovers = get_train_parser().parse_known_args()
    if leftovers:
        raise ValueError("Got unknown parameters " + str(leftovers))
    return known


def get_encoder(params, device, model_state_dict=None):
    """Construct the sentence encoder selected by ``params.autoencoder``.

    Args:
        params: parsed training arguments (see ``get_train_parser``).
        device: device passed through in the encoder config.
        model_state_dict: accepted for interface compatibility; unused here.

    Returns:
        An ``RAEEncoder`` when ``params.autoencoder == "RAE"``, otherwise an
        ``AEEncoder``.
    """
    if params.autoencoder == "RAE":
        glove = glove_dict(params.glove_path)
        # get_train_parser defines no ``--model_path`` (and get_params rejects
        # unknown args), so reading ``params.model_path`` directly always
        # raised AttributeError. Honour it if a caller set it, otherwise fall
        # back to ``--model_name``.
        model_file = getattr(params, "model_path", None)
        if model_file is None:
            model_file = params.model_name
        config = {"model_path": os.path.join(params.modeldir, model_file),
                  "glove": glove, "use_lookup": True, "device": device}
        return RAEEncoder(config)

    config = {"modeldir": params.modeldir, "use_lookup": True,
              "device": device, "default_config": DEFAULT_CONFIG,
              "max_sequence_len": params.max_input_length,
              "remove_sos_and_eos": params.remove_sos_and_eos}
    return AEEncoder(config)


def get_decoder(params, device, model_state_dict=None):
    """Construct the sentence decoder selected by ``params.autoencoder``.

    Args:
        params: parsed training arguments (see ``get_train_parser``).
        device: device passed through in the decoder config.
        model_state_dict: accepted for interface compatibility; unused here.

    Returns:
        An ``RAEDecoder`` when ``params.autoencoder == "RAE"``, otherwise an
        ``AEDecoder`` (which uses ``params.beam_width`` for decoding).
    """
    if params.autoencoder == "RAE":
        glove = glove_dict(params.glove_path)
        # ``--model_path`` is not a parser option, so the attribute is usually
        # absent and the old direct access raised AttributeError. An explicitly
        # set value is still used verbatim; otherwise resolve the file the same
        # way get_encoder does (``modeldir`` joined with ``model_name``).
        # NOTE(review): when ``model_path`` *is* set, the encoder joins it with
        # ``modeldir`` while this decoder uses it as-is — confirm which is right.
        model_path = getattr(params, "model_path", None)
        if model_path is None:
            model_path = os.path.join(params.modeldir, params.model_name)
        config = {"model_path": model_path,
                  "glove": glove, "device": device}
        return RAEDecoder(config)

    config = {"modeldir": params.modeldir, "device": device,
              "default_config": DEFAULT_CONFIG, "beam_width": params.beam_width}
    return AEDecoder(config)


def get_emb2emb(params, encoder, train):
    """Instantiate the emb2emb mapping selected by ``params.emb2emb``.

    Args:
        params: parsed training arguments (see ``get_train_parser``).
        encoder: the sentence encoder; only used by "meanoffsetvector".
        train: training data dict; only its "Sx"/"Sy" entries are read, and
            only by "meanoffsetvector".

    Returns:
        The mapping object for the chosen architecture ("identity" yields an
        empty ``nn.Sequential``).

    Raises:
        ValueError: if ``params.emb2emb`` names an unknown architecture, instead
            of silently returning ``None`` and failing later.
    """
    if params.emb2emb == "mlp":
        return MLP(params.embedding_dim, params.n_layers, params.hidden_layer_size, residual_connections=params.residual_connections, skip_connections=params.skip_connections, dropout_p=params.dropout_p)
    if params.emb2emb == "identity":
        return nn.Sequential()
    if params.emb2emb == "bovidentity":
        return BovIdentity()
    if params.emb2emb == "bovoracle":
        return BovOracle()
    if params.emb2emb == "highway":
        return HighwayNetwork(params.embedding_dim, params.embedding_dim, params.n_layers, bias=params.highway_bias, final_layer=False)
    if params.emb2emb == "offsetnet":
        return OffsetVectorMLP(params.embedding_dim, params.n_layers,
                               dropout_p=params.dropout_p,
                               offset_dropout_p=params.offset_dropout_p,
                               outlayers=params.outlayers, activate_result=params.activate_result)
    if params.emb2emb == "resnet":
        return ResNet(params.embedding_dim, params.n_layers,
                      dropout_p=params.dropout_p,
                      offset_dropout_p=params.offset_dropout_p)
    if params.emb2emb == "fixoffsetnet":
        return FixOffsetVectorMLP(params.embedding_dim, params.n_layers)
    if params.emb2emb == "meanoffsetvector":
        return MeanOffsetVectorMLP(params.meanoffsetvector_factor, encoder, train["Sx"], train["Sy"])
    if params.emb2emb == "bovtobov":
        config = Namespace()
        config.layers = params.n_layers
        config.self_att_type = "full"
        config.cross_att_type = "full"
        config.heads = params.heads
        config.input_size = params.project_input_dimension
        config.ff_dimension = params.project_input_dimension
        config.positional_embeddings = True
        config.dropout = params.dropout_p
        config.att_dropout = params.dropout_p
        config.backprop_through_outputs = params.backprop_through_outputs
        config.offset = params.offset
        config.point_gen = params.point_gen
        config.point_gen_offset = params.point_gen_offset
        config.point_gen_out_to_in = params.point_gen_out_to_in
        config.point_gen_coverage = params.point_gen_coverage
        config.mask_first_vector = params.point_gen_mask_first_vector
        config.point_gen_context_vector = params.point_gen_context_vector
        config.point_gen_offset_copy_dependence = not params.point_gen_offset_no_copy_dependence
        config.output_layer = params.bov_output_layer
        config.learned_positional_embeddings = params.learned_positional_embeddings
        config.project_input_dimension = params.embedding_dim
        config.max_length = params.max_length
        # A window size of 0 with 'window' weighting lets the mapping adapt
        # its maximum output length.
        config.adaptive_max_len = params.al_weighting == "window" and params.al_windowsize == 0
        return BovToBovMapping(config)

    if params.emb2emb == "simplebov":
        config = Namespace()
        config.layers = params.n_layers
        config.self_att_type = "full"
        config.heads = params.heads
        config.input_size = params.project_input_dimension
        config.ff_dimension = params.project_input_dimension
        config.positional_embeddings = True
        config.dropout = params.dropout_p
        config.att_dropout = params.dropout_p
        config.offset = params.offset
        config.output_layer = params.bov_output_layer
        config.project_input_dimension = params.embedding_dim
        config.learned_positional_embeddings = params.learned_positional_embeddings
        return SimpleBovMapping(config)

    raise ValueError(f"Unknown emb2emb architecture {params.emb2emb}")


def _build_bag_loss_component(bag_loss_name, params, encoder, data):
    """Build one component function of the local bag loss (``--al_bag_loss``).

    Side effects on ``params``: "classifier" and "length" train a classifier
    on ``data['Sx']``/``data['Sy']`` and store it in
    ``params.latent_binary_classifier``; "adversarial" stores
    ``params.critic``, ``params.critic_loss`` and ``params.critic_optimizer``
    and, unless ``params.real_data_path == "input"``, ``params.real_data``.

    Raises:
        ValueError: if ``bag_loss_name`` is not a known component.
    """
    if bag_loss_name == "hausdorff":
        return get_local_hausdorff_similarities_function(similarity_function=params.similarity_function,
                                                         naive=False,
                                                         differentiable=params.al_differentiable,
                                                         softmax_temp=params.al_softmax_temp,
                                                         naive_local=params.al_detach,  # complex one only works without detach
                                                         alpha=params.al_alpha,
                                                         magnitude_weighting=params.al_magnitude_weighting,
                                                         force_expected_gate_value=params.al_force_expected_gate_value)
    if bag_loss_name == "gmmkl":
        return get_local_gmm_divergence(params.gmm_approximation,
                                        params.gmm_weighting, params.gmm_sigma, divergence_f=gmm_kl)
    if bag_loss_name == "gmmsymkl":
        return get_local_gmm_divergence(params.gmm_approximation,
                                        params.gmm_weighting, params.gmm_sigma, divergence_f=gmm_symkl)
    if bag_loss_name == "gmmjsd":
        return get_local_gmm_divergence(params.gmm_approximation,
                                        params.gmm_weighting, params.gmm_sigma, divergence_f=gmm_jsd)
    if bag_loss_name == "classifier":
        bclf = train_binary_classifier(
            data['Sx'], data['Sy'], encoder, params)
        params.latent_binary_classifier = bclf
        no_gates = Namespace()
        no_gates.gates = False
        return get_local_classifier_loss(
            bclf, target=params.al_classifier_loss_target, params=no_gates, free_bits=params.al_classifier_loss_freebits)  # the target style is class zero
    if bag_loss_name == "adversarial":
        config = Namespace()
        config.n_layers = params.n_layers_binary
        config.heads = params.n_heads_binary
        config.hidden_size = params.binary_dense_layer_size
        config.input_dim = params.hidden_size_binary
        config.embedding_dim = params.embedding_dim
        config.learned_positional_embeddings = params.learned_positional_embeddings
        config.dropout = 0.0
        config.gaussian_noise = 0.0
        config.vector_distortion_rate = 0.0
        config.vector_distortion_probability = 0.0
        params.critic = BoVBinaryClassifier(config).to(params.device)

        params.critic_loss = nn.BCEWithLogitsLoss()
        params.critic_optimizer = torch.optim.Adam(
            params.critic.parameters(), lr=params.critic_lr)

        adv_bag_loss = get_local_classifier_loss(
            params.critic, target=0.0, params=params)

        # "input" means the target-sequence embeddings serve as real data
        # (see --real_data_path), so only an explicit path needs loading.
        if params.real_data_path != "input":
            params.real_data = _load_real_data(params.real_data_path)
        return adv_bag_loss
    if bag_loss_name == "length":
        bclf = train_binary_classifier(
            data['Sx'], data['Sy'], encoder, params, regress=True)
        params.latent_binary_classifier = bclf
        return get_local_regression_loss(bclf)
    raise ValueError(f"Unknown bag loss {bag_loss_name}")


def _build_flip_baseloss(params, encoder, data):
    """Build the base loss wrapped by ``FlipLoss`` (``--baseloss``).

    "backtranslation", "combinedsum" and "combined" also build an extra
    emb2emb model via ``get_emb2emb`` and use ``--bt_baseloss``.

    Raises:
        ValueError: if ``params.baseloss`` or the required
            ``params.bt_baseloss`` is unknown.
    """
    if params.baseloss == "cosine":
        return CosineLoss(reduction='none')
    if params.baseloss == "mse":
        return MSELoss(reduction='none')
    if params.baseloss in ["backtranslation", "combinedsum"]:
        # These variants use the losses' default reduction, unlike 'combined'.
        if params.bt_baseloss == "cosine":
            bt_base = CosineLoss()
        elif params.bt_baseloss == "mse":
            bt_base = MSELoss()
        else:
            raise ValueError(f"Unknown bt_baseloss {params.bt_baseloss}")
        e2e = get_emb2emb(params, encoder, data)
        if params.baseloss == "backtranslation":
            return BacktranslationLoss(bt_base, e2e)
        return SumBaseLoss(bt_base, e2e)
    if params.baseloss == "combined":
        e2e = get_emb2emb(params, encoder, data)
        if params.bt_baseloss == "cosine":
            bt_base = CosineLoss(reduction='none')
        elif params.bt_baseloss == "mse":
            bt_base = MSELoss(reduction='none')
        else:
            raise ValueError(f"Unknown bt_baseloss {params.bt_baseloss}")
        return CombinedBaseLoss(bt_base, e2e)
    if params.baseloss == "alignmentloss":
        return AlignmentLoss(differentiable=params.al_differentiable,
                             directions=params.al_directions,
                             force_target_length=params.force_target_length)
    if params.baseloss == "meansim":
        return MeanSimilarityLoss()
    raise ValueError(f"Unknown baseloss {params.baseloss}")


def get_lossfn(params, encoder, data):
    """Construct the training loss selected by ``params.loss``.

    Args:
        params: parsed training arguments; may be mutated (e.g.
            ``latent_binary_classifier``, critic fields) depending on the loss.
        encoder: sentence encoder used when a classifier must be trained.
        data: training data dict; "Sx"/"Sy" are read by classifier-based losses.

    Returns:
        The loss object for the selected configuration.

    Raises:
        ValueError: for an unknown ``loss``, bag-loss component, ``baseloss``
            or ``bt_baseloss`` (checked before any classifier is trained for
            "fliploss").
    """
    if params.loss == "mse":
        return MSELoss(reduction='none')
    if params.loss == "cosine":
        return CosineLoss(reduction='none')
    if params.loss == "ce":
        return CrossEntropyLoss(ignore_index=0)  # ignore padding symbol
    if params.loss == "hausdorff":
        return HausdorffLoss(differentiable=params.al_differentiable,
                             detach=params.al_detach,
                             softmax_temp=params.al_softmax_temp,
                             weighting=params.al_weighting,
                             windowsize=params.al_windowsize,
                             weighting_center=params.al_weighting_center,
                             similarity_function=params.similarity_function)
    if params.loss == "localbagloss":
        func_list = []
        for bag_loss_name in params.al_bag_loss:
            print(bag_loss_name)
            func_list.append(_build_bag_loss_component(
                bag_loss_name, params, encoder, data))

        bag_loss_f = get_weighted_localbagloss_function(
            func_list, params.al_bag_loss_weights)
        return LocalBagLoss(
            bag_loss_f,
            detach=params.al_detach,
            weighting=params.al_weighting,
            windowsize=params.al_windowsize,
            weighting_center=params.al_weighting_center,
            input_center_factor=params.al_input_center_factor)
    if params.loss == "fliploss":
        baseloss = _build_flip_baseloss(params, encoder, data)

        bclf = train_binary_classifier(data['Sx'], data['Sy'], encoder, params)
        params.latent_binary_classifier = bclf
        if params.increase_clf_lambda_until > 0:
            # number of optimisation steps spanned by the requested epochs
            inc_until = int(
                float(len(data['Sx'])) / params.batch_size) * params.increase_clf_lambda_until
        else:
            inc_until = 0
        return FlipLoss(baseloss, bclf,
                        lambda_clfloss=params.lambda_clfloss,
                        increase_until=inc_until)
    raise ValueError(f"Unknown loss {params.loss}")


def get_mode(params):
    """Return the trainer mode stored in ``params.mode``.

    The value is handed unchanged to ``Emb2EmbTrainer`` as its ``mode``.
    """
    selected_mode = params.mode
    return selected_mode


def _load_real_data(real_data_file):
    data = []
    with open(real_data_file, 'r') as f:
        for l in f:
            data.append(l.strip())

    return data


def train(params):
    """Build an emb2emb model from ``params``, train it, and evaluate it.

    Steps (all driven by attributes of ``params``):
      1. Seed ``random``/NumPy/torch (drawing a seed from the clock if
         ``params.seed`` is None).
      2. Load the data via ``get_data`` and build the encoder, decoder,
         emb2emb mapping and loss.
      3. Wrap them in an ``Emb2EmbTrainer``; resume from
         ``params.load_emb2emb_path`` if that file exists, otherwise
         optionally initialise from ``params.pretrained_model``.
      4. Train for up to ``params.n_epochs`` epochs. The best validation model
         is saved under ``params.outputdir``; when ``params.load_emb2emb_path``
         is set, a resumable checkpoint is written there after every epoch.
      5. Reload the best model (if validating) and run the final evaluation.

    Side effects: mutates ``params`` (e.g. ``gates``, ``embedding_dim``,
    ``seed``, ``run_id``, ``emb2emb_outputmodelname``, ``time_for_epoch``),
    writes model files, and calls ``sys.exit()`` for the analysis-only options
    ``plot_num_words`` / ``compute_neighborhood_preservation``.

    Returns:
        None if ``params.train_classifier_only`` is set; otherwise a dict with
        the final ``"test"`` score and, when ``params.validate`` is set, the
        final ``"dev"`` score.
    """
    # set gpu device
    device = torch.device(params.device)
    print("Using device {}".format(str(device)))
    if "cuda" in params.device:
        print(torch.cuda.get_device_properties(device))

    # print parameters passed, and all parameters
    print('\ntogrep : {0}\n'.format(sys.argv[1:]))
    print(params)

    # Gated vectors need one extra embedding dimension; an input projection
    # that matched the old embedding size is grown along with it.
    if params.al_magnitude_weighting == "gates":
        params.gates = True
        params.embedding_dim += 1
        if params.project_input_dimension == params.embedding_dim - 1:
            params.project_input_dimension += 1
    else:
        params.gates = False

    # The timestamp suffix keeps separate runs from overwriting each other's
    # best-model file.
    outputmodelname = params.outputmodelname + str(time.time())
    # save emb2emb model path for later use
    params.emb2emb_outputmodelname = outputmodelname

    # ------------------------------------------------------------------ SEED
    if params.seed is None:
        # np.random.seed only accepts values in [0, 2**32 - 1]
        params.seed = time.time_ns() % (2**32 - 1)
        print("Using seed", params.seed)
    random.seed(params.seed)
    np.random.seed(params.seed)
    torch.manual_seed(params.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(params.seed)

    # ------------------------------------------------------------------ DATA
    # The splits are named *_data so they don't shadow this function's name.
    (train_data, valid_data, test_data), eval_function = get_data(params)
    if params.plot_num_words:
        plot_num_words(train_data)
        sys.exit()

    # ----------------------------------------------------------------- MODEL
    encoder = get_encoder(params, device).to(device)
    decoder = get_decoder(params, device)
    emb2emb = get_emb2emb(params, encoder, train_data)
    # NOTE(review): for some losses this apparently also trains a latent
    # binary classifier and stores it on params.latent_binary_classifier
    # (read by the hasattr check below) -- confirm in get_lossfn.
    loss_fn = get_lossfn(params, encoder, train_data)

    if params.train_classifier_only:
        return None

    mode = get_mode(params)

    if params.compute_l0drop_statistics:
        compute_l0drop_statistics(encoder, train_data)

    if params.compute_neighborhood_preservation:
        recall_vals = compute_neighborhood_preservation(
            encoder, valid_data, params)
        print(recall_vals)
        sys.exit()

    if params.invert_style:
        # swap the source (Sx) and target (Sy) side of every split
        for split in (train_data, valid_data, test_data):
            split["Sx"], split["Sy"] = split["Sy"], split["Sx"]

    if params.unaligned:
        # set input and output of training the same
        train_data["Sy"] = train_data["Sx"]

    if params.adversarial_regularization:
        if params.loss == "localbagloss":
            raise Exception(
                "Can't do localbagloss and have adversarial_regularization turned on. Use the al_bag_losses option instead.")
        adv_reg = {"device": device, "critic_lr": params.critic_lr,
                   "critic_input_dim": params.embedding_dim,
                   "joint_reconstruction_adversarial": params.joint_rec_adv,
                   "joint": params.joint, "critic_rounds": params.critic_rounds,
                   "task_rounds": params.critic_rounds, "adversarial_rounds": params.critic_rounds,
                   "critic_hidden_units": params.critic_hidden_units,
                   "critic_hidden_layers": params.critic_hidden_layers,
                   "adversarial_reconstruction_weight": params.adversarial_reconstruction_weight,
                   "real_data": params.real_data_path if params.real_data_path == "input" else _load_real_data(params.real_data_path),
                   "adversarial_remove_last": params.adversarial_remove_last,
                   "discriminate_moments": params.discriminate_moments}
    else:
        adv_reg = {}

    model = Emb2EmbTrainer(encoder, decoder, emb2emb, loss_fn, mode,
                           gaussian_noise_std=params.emb2emb_noise,
                           embedding_dim=params.embedding_dim,
                           adversarial_regularization=adv_reg,
                           fast_gradient_iterative_modification=params.fast_gradient_iterative_modification,
                           binary_classifier=params.latent_binary_classifier if hasattr(
                               params, 'latent_binary_classifier') else None,
                           fgim_decay=params.fgim_decay, fgim_threshold=params.fgim_threshold,
                           fgim_customloss=params.fgim_customloss,
                           fgim_start_at_y=params.fgim_start_at_y,
                           predict_done=not params.dont_predict_done,
                           tokenwise=params.tokenwise,
                           use_end_of_sequence_vector=params.use_end_of_sequence_vector,
                           end_of_sequence_epsilon=params.end_of_sequence_epsilon,
                           teacher_forcing=params.teacher_forcing,
                           learned_positional_embeddings=params.learned_positional_embeddings,
                           force_eval_mode=params.force_eval_mode,
                           unaligned=params.unaligned,
                           track_input_output_distance=params.track_input_output_distance,
                           select_input_length=params.select_input_length,
                           gated_vectors=params.gates)
    print(model)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)

    def fix_optimizer_device(opt):
        """Move every tensor in ``opt``'s state onto the configured ``device``.

        ``Tensor.cuda()`` without arguments would always pick the *current*
        CUDA device, which is wrong when e.g. ``cuda:1`` was requested.
        """
        for state in opt.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    # are we continuing training from a checkpoint?
    if params.load_emb2emb_path and os.path.isfile(params.load_emb2emb_path):
        checkpoint = torch.load(params.load_emb2emb_path)
        val_acc_best = checkpoint['val_acc_best']
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint['epoch']
        params.current_epoch = checkpoint['current_epoch']
        params.run_id = checkpoint['run_id']

        # restore the optimizer and move its state onto the right device
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        fix_optimizer_device(optimizer)

        if params.adversarial_regularization:
            model._get_critic().load_state_dict(checkpoint['critic'])
            model.critic_optimizer.load_state_dict(
                checkpoint['critic_optimizer'])
            fix_optimizer_device(model.critic_optimizer)
        elif params.loss == "localbagloss" and "adversarial" in params.al_bag_loss:
            params.critic.load_state_dict(checkpoint['critic'])
            params.critic_optimizer.load_state_dict(
                checkpoint['critic_optimizer'])
            fix_optimizer_device(params.critic_optimizer)

    # if we are not continuing training from a checkpoint, are we initializing
    # our model with a pretrained one?
    elif params.pretrained_model and os.path.isfile(params.pretrained_model):
        checkpoint = torch.load(params.pretrained_model)
        missing_keys, unexpected_keys = model.load_state_dict(
            checkpoint['model_state_dict'], strict=False)

        # Only the is_done classifier may be extra and only the binary
        # classifier may be missing; anything else means an incompatible model.
        load_ok = (all(uk.startswith('is_done_clf') for uk in unexpected_keys)
                   and all(mk.startswith('binary_classifier') for mk in missing_keys))
        if not load_ok:
            raise Exception(
                f"Can't load the pretrained model because of one of these keys: {str(missing_keys)}, {str(unexpected_keys)}.")
        epoch = 1
        val_acc_best = -1e10
        print("Load pretrained_model")
        # initialize run_id as the current timestamp
        params.run_id = time.time_ns()
    else:
        epoch = 1
        val_acc_best = -1e10
        # initialize run_id as the current timestamp
        params.run_id = time.time_ns()

    # cuda by default
    model.to(device)
    loss_fn.to(device)

    # ----------------------------------------------------------------- TRAIN
    stop_training = False
    batch_counter = 0
    critic_losses = []
    previous_difference = 0
    params.time_for_epoch = 0
    model.adversarial_reconstruction_weight = params.adversarial_reconstruction_weight

    def trainepoch(epoch):
        """Run one pass over the shuffled training data.

        Returns the mean total loss over all batches of the epoch (rounded to
        5 decimals), or 0. if training is skipped or no batch was run.
        """
        nonlocal batch_counter, critic_losses, previous_difference
        model.iterations = 0

        # Training is skipped for the identity mapping (outside seq2seq mode)
        # and when the learning rate is 0.
        if (params.emb2emb == "identity" and params.mode != MODE_SEQ2SEQ) or params.lr == 0.:
            return 0.

        print('\nTRAINING : Epoch ' + str(epoch))
        model.train()
        all_costs = []  # costs since the last log line
        epoch_losses = []  # total loss of every batch in this epoch
        logs = []

        last_time = time.time()
        # shuffle the data
        order = list(range(len(train_data["Sx"])))
        random.shuffle(order)
        Sx = [train_data['Sx'][i] for i in order]
        Sy = [train_data['Sy'][i] for i in order]

        def train_discriminator(num_batches):
            """Take ``num_batches`` optimisation steps on ``params.critic``.

            Encodings of real sentences get target 1 and emb2emb outputs get
            target 0.
            """
            for _ in range(num_batches):
                sample = random.sample(range(len(Sx)), params.batch_size)
                Sx_batch = [Sx[i] for i in sample]
                output_embeddings, _, out_lens = model.compute_emb2emb(
                    Sx_batch, Y=None)
                false_predictions = binary_clf_predict(
                    params.critic, (output_embeddings, out_lens))

                if params.real_data_path == "input":
                    # real data = target sentences of the sampled examples
                    # (Sy equals Sx when params.unaligned is set)
                    real_data = [Sy[i] for i in sample]
                else:
                    real_data = choices(params.real_data, k=params.batch_size)

                true_embeddings, true_lens = model._encode(real_data)
                true_predictions = binary_clf_predict(
                    params.critic, (true_embeddings, true_lens))
                true_targets = torch.ones((params.batch_size, 1),
                                          device=output_embeddings.device)
                false_targets = torch.zeros((params.batch_size, 1),
                                            device=output_embeddings.device)
                targets = torch.cat([true_targets, false_targets], dim=0)
                predictions = torch.cat(
                    [true_predictions, false_predictions], dim=0)
                critic_batch_loss = params.critic_loss(predictions, targets)

                # optimize
                params.critic_optimizer.zero_grad()
                critic_batch_loss.backward()
                params.critic_optimizer.step()

        if params.adversarial_regularization:
            # First epoch in which the adversarial term is active: an integral
            # delay counts whole epochs, a fractional one is a share of n_epochs.
            if params.adversarial_delay.is_integer():
                adv_delay = int(params.adversarial_delay) + 1
            else:
                adv_delay = int(params.n_epochs * params.adversarial_delay)

        start_epoch_time = time.time()

        for stidx in range(0, len(Sx), params.batch_size):
            batch_counter += 1

            # prepare batch
            Sx_batch = Sx[stidx:stidx + params.batch_size]
            if params.denoise > 0.:
                # word dropout: drop each input token with probability params.denoise
                Sx_batch = [" ".join(w for w in t.split(" ") if random.random() > params.denoise)
                            for t in Sx_batch]
            Sy_batch = Sy[stidx:stidx + params.batch_size]

            # model forward
            if params.adversarial_regularization:
                # update the weight of the adversarial term
                if params.lambda_schedule == "annealing":
                    if epoch < adv_delay:
                        model.adversarial_reconstruction_weight = 0
                    elif epoch == adv_delay:
                        # linearly increase to full weight over this epoch
                        model.adversarial_reconstruction_weight = params.adversarial_reconstruction_weight * \
                            (float(stidx) / len(Sx))
                    else:
                        model.adversarial_reconstruction_weight = params.adversarial_reconstruction_weight
                elif params.lambda_schedule == "dynamic":
                    if epoch >= adv_delay:
                        if model.adversarial_reconstruction_weight == 0.:
                            model.adversarial_reconstruction_weight = params.adversarial_reconstruction_weight

                        if (batch_counter % params.dynamic_lambda_frequency) == 0:
                            avg_c_loss = np.array(critic_losses).mean()
                            critic_losses = []
                            # signed distance of the average critic loss from
                            # -log(1 - dynamic_lambda_target)
                            difference_to_confusion = avg_c_loss - \
                                (-np.log(1 - params.dynamic_lambda_target))

                            if np.abs(difference_to_confusion) > params.dynamic_lambda_epsilon:
                                if difference_to_confusion > 0:
                                    # generator is too strong.
                                    if difference_to_confusion - previous_difference <= 0:
                                        # difference has decreased, so we keep
                                        # the lambda to avoid overshooting
                                        step_size = 0
                                    else:
                                        # difference has increased, so we need
                                        # to decrease lambda more
                                        step_size = -params.dynamic_lambda_stepsize
                                elif difference_to_confusion < 0:
                                    # generator is too weak
                                    if difference_to_confusion - previous_difference >= 0:
                                        # difference has decreased, so we keep
                                        # the lambda to avoid overshooting
                                        step_size = 0
                                    else:
                                        # difference has increased, so we need
                                        # to increase lambda
                                        step_size = +params.dynamic_lambda_stepsize

                                # change the lambda, but make sure we're not
                                # making it negative
                                if params.log_adversarial_lambda:
                                    model.adversarial_reconstruction_weight = max(
                                        0., 2 ** (np.log2(model.adversarial_reconstruction_weight) + step_size))
                                else:
                                    model.adversarial_reconstruction_weight = max(
                                        0., model.adversarial_reconstruction_weight + step_size)

                            previous_difference = difference_to_confusion
                    else:
                        model.adversarial_reconstruction_weight = 0.

                # forward pass
                loss, task_loss, critic_loss, train_critic_loss = model(
                    Sx_batch, Sy_batch)
                loss_value = loss.item()
                all_costs.append(
                    [loss_value, task_loss.item(), critic_loss.item(), train_critic_loss.item()])
                critic_losses.append(critic_loss.item())
            else:
                loss = model(Sx_batch, Sy_batch)
                loss_value = loss.item()
                all_costs.append(loss_value)
            epoch_losses.append(loss_value)

            # backward and optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if params.loss == "localbagloss" and "adversarial" in params.al_bag_loss:
                train_discriminator(1)

            if len(all_costs) == params.log_freq:
                sentences_per_sec = int(
                    len(all_costs) * params.batch_size / (time.time() - last_time))
                if not params.adversarial_regularization:
                    log_string = '{0} ; loss {1} ; sentence/s {2}'.format(
                        stidx, round(np.mean(all_costs), 5), sentences_per_sec)
                else:
                    mean_losses = np.reshape(
                        np.array(all_costs).mean(axis=0), (-1))
                    mean_losses = np.round(mean_losses, decimals=5)
                    log_string = '{0} ; loss {1} ; sentence/s {2} ; t-loss {3} ; c-loss {4} ; tc-loss {5}'.format(
                        stidx, mean_losses[0], sentences_per_sec,
                        mean_losses[1], mean_losses[2], mean_losses[3])

                logs.append(log_string)
                print(logs[-1])
                last_time = time.time()
                all_costs = []

            if params.validation_frequency > 0 and (batch_counter % params.validation_frequency) == 0:
                evaluate(epoch, eval_type='valid', final_eval=False)
                model.train()

        params.time_for_epoch = time.time() - start_epoch_time
        if not epoch_losses:
            return 0.
        return round(float(np.mean(epoch_losses)), 5)

    def evaluate(epoch, eval_type='valid', final_eval=False):
        """Evaluate the model and return its score.

        With an ``eval_function`` (from ``get_data``) it is called with mode
        "test" if ``final_eval`` else "valid"; during validation a new best
        score saves the model to ``params.outputdir``. Without one, the first
        batch of the chosen split is run (and printed if
        ``params.print_outputs``) and the score is 0.
        """
        nonlocal val_acc_best
        model.eval()
        # optional extra metrics, only reported when eval_function returns them
        self_bleu = None
        b_acc = None

        if type(model.emb2emb) is FixOffsetVectorMLP:
            model.emb2emb.print_vecs()

        if eval_function is not None:
            score = eval_function(
                model, mode="test" if final_eval else "valid", params=params)
            print("Total Inference time", model.total_inference_time)
            print("Total Emb2Emb time", model.total_emb2emb_time)
            print("Total FGIM time", model.total_time_fgim)
            if isinstance(score, tuple):
                score, self_bleu, b_acc = score[0], score[1], score[2]

            if eval_type == 'valid' and score > val_acc_best:
                val_acc_best = score
                checkpoint = {"model_state_dict": model.state_dict()}
                torch.save(checkpoint, os.path.join(
                    params.outputdir, outputmodelname))
        else:
            if eval_type == 'valid':
                print('\nVALIDATION : Epoch {0}'.format(epoch))
                Sx, Sy = valid_data['Sx'], valid_data['Sy']
            else:
                Sx, Sy = test_data['Sx'], test_data['Sy']

            # only the first batch is run
            if len(Sx) > 0:
                Sx_batch = Sx[:params.batch_size]
                Sy_batch = Sy[:params.batch_size]
                with torch.no_grad():
                    outputs = model(Sx_batch, Sy_batch)

                if params.print_outputs:
                    for i in range(min(5, len(Sx_batch))):
                        pretty_print_prediction(
                            Sx_batch[i], Sy_batch[i], outputs[i])
            score = 0

        eval_string = "Validation-Score in epoch {}/{} : {}; best : {}".format(
            epoch, batch_counter, score, val_acc_best)
        if b_acc is not None:
            eval_string = eval_string + " ; b-acc : {}".format(b_acc)
        if self_bleu is not None:
            eval_string = eval_string + " ; self-bleu : {}".format(self_bleu)
        print(eval_string)
        return score

    # ----------------------------------------------------------- TRAIN LOOP
    while not stop_training and epoch <= params.n_epochs:
        train_loss = trainepoch(epoch)
        if params.adversarial_regularization:
            print('Epoch {0} ; loss {1} ; lambda {2}'.format(
                epoch, train_loss, model.adversarial_reconstruction_weight))
        else:
            print('Epoch {0} ; loss {1}'.format(
                epoch, train_loss))

        if params.validate and params.validation_frequency < 0:
            evaluate(epoch, 'valid')

        epoch += 1

        if params.load_emb2emb_path is not None:
            # checkpoint training so it can be resumed
            checkpoint_dict = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc_best': val_acc_best,
                'current_epoch': params.current_epoch,
                'run_id': params.run_id}
            if params.adversarial_regularization:
                checkpoint_dict['critic'] = model._get_critic().state_dict()
                checkpoint_dict['critic_optimizer'] = model.critic_optimizer.state_dict()
            elif params.loss == "localbagloss" and "adversarial" in params.al_bag_loss:
                checkpoint_dict['critic'] = params.critic.state_dict()
                checkpoint_dict['critic_optimizer'] = params.critic_optimizer.state_dict()
            torch.save(checkpoint_dict, params.load_emb2emb_path)

    # Run the best model on the test set.
    if params.validate:
        best_model_path = os.path.join(params.outputdir, outputmodelname)
        try:
            checkpoint = torch.load(best_model_path)
        except FileNotFoundError:
            # no model saved so far (no validation improved on the initial best)
            print("No best model found at {}; evaluating the current model.".format(
                best_model_path))
        else:
            model.load_state_dict(checkpoint["model_state_dict"])

    results = {}
    if params.validate:
        print("Doing the final evaluation on the dev set...")
        final_val_score = evaluate(1e6, 'valid', False)
        results["dev"] = final_val_score
    print("Doing the final evaluation on the test set...")
    final_test_score = evaluate(0, 'test', True)
    results["test"] = final_test_score
    return results


if __name__ == "__main__":
    params = get_params()
    train(params)
    print("<<<JOB_FINISHED>>>")

    # Once the job has finished, delete the resumable checkpoint unless the
    # user asked to keep it.
    checkpoint_path = params.load_emb2emb_path
    wants_cleanup = not params.no_cleanup and checkpoint_path is not None
    if wants_cleanup and os.path.exists(checkpoint_path):
        print("Cleaning up the model.")
        os.remove(checkpoint_path)
