import sys, shutil
sys.path.append('../')
sys.path.append('../autoencoders/')

DEFAULT_CONFIG = "../autoencoders/config/default.json"

from train import get_decoder, get_encoder
import torch
import argparse
import os
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from utils import bleu_tokenize
import time

#for testing purposes
#LM_CONFIG = "--emsize 2 --nhid 2 --dropout 0.5 --epochs 1 --tied"
#for real experiments
LM_CONFIG = "--emsize 650 --nhid 650 --dropout 0.5 --epochs 40 --tied"

def get_params():
    """
    Builds the command line interface of this script and parses the arguments.

    Within this file only a few options are read directly (device, num_samples,
    batch_size, embedding_dim, temp_data_path, path_to_lm_code,
    path_to_real_data and the compute_* switches). The whole namespace is also
    handed to get_decoder / get_encoder imported from train.
    NOTE(review): presumably the remaining model/training options exist so that
    those functions find the attributes they expect - verify against train.py.

    --path_to_lm_code and --path_to_real_data are required.

    Returns:
        The argparse namespace holding all parsed options.

    Raises:
        ValueError: if the command line contains arguments that are not
            declared below.
    """
    parser = argparse.ArgumentParser(description='Emb2Emb')
    # paths
    parser.add_argument("--dataset_path", type=str, default='../data/nli/', help="Path to dataset")
    parser.add_argument("--outputdir", type=str, default='savedir/', help="Output directory")
    parser.add_argument("--outputmodelname", type=str, default='model.pickle')
    parser.add_argument("--glove_path", type=str, default="../resource/smallglove.840B.300d.txt", help="Path to glove embeddings")
    parser.add_argument("--modeldir", type=str, default="../autoencoders/lstmvae_nli.pt", help="Path to autoencoder dir")
    parser.add_argument("--model_name", type=str, default="model.pt", help="Name of the model file.")
    parser.add_argument("--vocab_path", type=str, default="", help="Path to vocabulary.")
    parser.add_argument("--temp_data_path", type=str, default='./.tmp_data/', help="Path to folder where to store temporary data.")
    parser.add_argument("--path_to_lm_code", type=str, required=True, help="Path to LM code")
    parser.add_argument("--path_to_real_data", type=str, required=True, help="Path to file containing reconstruction BLEU test data")
    parser.add_argument("--num_samples", type=int, default=100)
    
    # training
    parser.add_argument("--validate", action="store_true")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--n_epochs", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=64)
    
    # model
    parser.add_argument("--emb2emb", type=str, default='mlp', help="emb2emb architecture to use", choices = ["fixoffsetnet", "mlp", "identity", "highway", "offsetnet"])
    parser.add_argument("--emb2emb_noise", type=float, default=0., help="Amount of noise to add to input embeddings when training emb2emb model.")
    parser.add_argument("--dropout_p", type=float, default=0., help="Amount of dropout to have in the emb2emb model.")
    parser.add_argument("--residual_connections", action="store_true")
    parser.add_argument("--skip_connections", action="store_true")
    parser.add_argument("--outlayers", action="store_true", help="Add an outlayers to computation of offset vector in offsetnet.")
    parser.add_argument("--activate_result", action="store_true", help="Add a nonlinerity after adding the non-linearity in offsetnet")
    parser.add_argument("--loss", type=str, default='mse', help="loss", choices = ["mse", "cosine", "ce"])
    parser.add_argument("--highway_bias", type=float, default = -1.)
    parser.add_argument("--n_layers", type=int, default=1, help="Number of layers to use in the Emb2Emb model.")
    parser.add_argument("--hidden_layer_size", type=int, default=1024, help="Hidden layer size to use in the Emb2Emb model.")
    parser.add_argument("--autoencoder", type=str, default="FromFile", help="Specify the autoencoder to use.", choices = ["FromFile", "RAE"])
    parser.add_argument("--estimate_distribution", action="store_true", help="Return the estimated distribution (mu,v) instead of the noisy sample.")
    # NOTE: the three compute_* switches use store_false, i.e. the attribute is
    # True by default and passing the flag turns the respective metric OFF.
    parser.add_argument("--compute_fwdppl", action="store_false", help="If set, fwdppl is NOT computed.")
    parser.add_argument("--compute_bwdppl", action="store_false", help="If set, bwdppl is NOT computed.")
    parser.add_argument("--compute_bleu", action="store_false", help="If set, bleu is NOT computed.")
    
    # adversarial reg for emb2emb
    parser.add_argument("--adversarial_regularization", action="store_true", help="Perform adversarial regularization while training emb2emb.")
    parser.add_argument("--critic_lr", type=float, default = 0.00001, help="LR for training the critic.")
    parser.add_argument("--critic_rounds", type=int, default = 100, help="How many steps to go through each phase.")
    parser.add_argument("--adversarial_rounds", type=int, default = 100, help="How many steps to go through each phase.")
    parser.add_argument("--task_rounds", type=int, default = 100, help="How many steps to go through each phase.")
    parser.add_argument("--joint", action="store_true", help="In each iteration, train reconstruction, critic, and adversarial.")
    parser.add_argument("--joint_rec_adv", action="store_true", help="When training adversarially, do it jointly with the reconstruction task.")
    parser.add_argument("--critic_hidden_layers", type=int, default = 1, help="Number of hidden layers the critic has.")
    parser.add_argument("--critic_hidden_units", type=int, default = 300, help="Number of hidden units the critic has.")
    parser.add_argument("--adversarial_reconstruction_weight", type=float, default=1.0, help="Weight of adversarial loss. Decrease to reduce the adversarial loss term's influence.")
    parser.add_argument("--lambda_schedule", type=str, default="fixed", choices = ["fixed", "annealing", "dynamic"])
    parser.add_argument("--adversarial_delay", type=int, default=0, help="number of epochs to wait until starting to increase weight")
    parser.add_argument("--dynamic_lambda_epsilon", type=float, default=0.03, help="The maximal tolarable deviation from total confusion (log(0.5)")
    parser.add_argument("--dynamic_lambda_stepsize", type=float, default=0.01, help="The amount by which to increase lambda.")
    parser.add_argument("--dynamic_lambda_frequency", type=int, default=100, help="The frequency (# of batches) whith which to check whether lambda need to be changed.")
    parser.add_argument("--dynamic_lambda_target", type=float, default=0.5, help="The target accuracy we want the generator to achieve in fooling the discriminator.")
    
    # reproducibility
    parser.add_argument("--seed", type=int, default=1234, help="seed")
    
    # data
    parser.add_argument("--embedding_dim", type=int, default=1024, help="sentence embedding dimension")
    parser.add_argument("--data_fraction", type=float, default=1., help = "How much of the data to use.")
    parser.add_argument("--print_outputs", action="store_true", help = "Print some of the outputs at validation time for inspection.")
    parser.add_argument("--max_prints", type=int, default=5, help = "How many examples to print during validation time.")
    parser.add_argument("--log_freq", type=int,default=100, help = "How often to print the logs.")
    # parse_known_args collects undeclared arguments instead of exiting, so
    # that they can be reported through a ValueError below
    params, unknown = parser.parse_known_args()
    if len(unknown) > 0:
        raise ValueError("Got unknown parameters " + str(unknown))
    return params

def generate_examples(params):
    """
    Decodes random latent vectors into text and stores the result on disk.

    Vectors are drawn from a standard normal distribution (torch.randn) in
    batches of params.batch_size x params.embedding_dim and passed through the
    decoder obtained from get_decoder. Each decoded text is tokenized with
    nltk's word_tokenize, joined by single spaces and written as one line to
    <params.temp_data_path>/generated_data.txt.

    Exactly params.num_samples lines are written: the last batch is decoded in
    full but only as many of its texts as are still missing are kept.
    NOTE(review): assumes decoder.prediction_to_text yields one text per latent
    vector - confirm against the decoder implementation.

    Args:
        params: parsed options; reads device, num_samples, batch_size,
            embedding_dim and temp_data_path (the directory must exist).
    """
    # set gpu device
    device = torch.device(params.device)
    print("Using device {}".format(str(device)))

    decoder = get_decoder(params, device).to(device)
    decoder.eval()

    num_samples = params.num_samples
    out_path = os.path.join(params.temp_data_path, "generated_data.txt")

    num_written = 0
    # the context manager closes the file even if decoding fails half-way
    with open(out_path, 'w') as outfile, torch.no_grad():
        # draw samples
        for _ in range(0, num_samples, params.batch_size):
            gaussian_samples = torch.randn(params.batch_size, params.embedding_dim, device=device)
            predictions = decoder.predict(gaussian_samples)
            texts = decoder.prediction_to_text(predictions)
            for t in texts:
                # do not exceed the requested number of samples when
                # num_samples is not a multiple of the batch size
                if num_written >= num_samples:
                    break
                print(" ".join(word_tokenize(t)), file=outfile)
                num_written += 1
    
def _train_lm(params):
    """
    Runs the external language model script and returns its test perplexity.

    Executes `python <params.path_to_lm_code>/main.py` with the options from
    LM_CONFIG, using params.temp_data_path both as the data directory (--data)
    and as the place to store the model (--save .../model.pt). `--cuda` is
    appended if params.device contains "cuda". The standard output of the run
    is captured in <params.temp_data_path>/results.txt and searched for lines
    containing "test ppl"; the number following that marker in the last such
    line is the result.

    Args:
        params: parsed options; reads path_to_lm_code, temp_data_path, device.

    Returns:
        The test perplexity reported by the LM script as a float.

    Raises:
        RuntimeError: if the output contains no "test ppl" value, e.g. because
            the LM script failed.
    """
    # imported locally because only this function needs them
    import shlex
    import subprocess

    results_path = os.path.join(params.temp_data_path, "results.txt")

    # An argument list (instead of a shell string) keeps paths containing
    # spaces or shell metacharacters intact.
    command = ["python", os.path.join(params.path_to_lm_code, "main.py")]
    command += shlex.split(LM_CONFIG)
    command += ["--data", params.temp_data_path,
                "--save", os.path.join(params.temp_data_path, "model.pt")]
    if "cuda" in params.device:
        command.append("--cuda")

    print(" ".join(command) + " > " + results_path)
    with open(results_path, 'w') as results:
        completed = subprocess.run(command, stdout=results)

    # retrieve result
    ppl = None
    with open(results_path, 'r') as res:
        for line in res:
            if "test ppl" in line:
                # Take the complete field after the marker: a fixed-width slice
                # of the line end would cut off leading digits of values such
                # as 127.78.
                fields = line.rsplit("test ppl", 1)[1].split()
                if fields:
                    ppl = float(fields[0].strip("|"))
    if ppl is None:
        raise RuntimeError(
            "No 'test ppl' value found in {} (LM script exit status: {})".format(
                results_path, completed.returncode))
    return ppl

def count_lines(filepath):
    """
    Returns the number of lines of the text file at `filepath`.

    A last line without a terminating newline is counted as well; an empty
    file has 0 lines.
    """
    with open(filepath, 'r') as handle:
        return sum(1 for _ in handle)


def _copy_to_temp(params, train_data_path, test_data_path):
    """
    Sets up the train/valid/test files for the LM in params.temp_data_path.

    The file at `train_data_path` is split by position: the last 5% of its
    lines (rounded down) become valid.txt, all lines before them become
    train.txt. The file at `test_data_path` is copied to test.txt unchanged.

    The split is done in Python rather than through `head`/`tail` shell
    commands, so that it also works for paths containing spaces and on systems
    without these tools.

    Args:
        params: parsed options; reads temp_data_path (the directory must exist).
        train_data_path: file to be split into train.txt and valid.txt.
        test_data_path: file to be copied to test.txt.
    """
    # setup dataset
    # Bytes are read and written so that every line is copied verbatim,
    # independent of the text encoding of the file.
    with open(train_data_path, 'rb') as source:
        lines = source.readlines()

    num_lines = len(lines)
    num_valid = int(0.05 * num_lines)
    num_train = num_lines - num_valid

    train_out = os.path.join(params.temp_data_path, "train.txt")
    valid_out = os.path.join(params.temp_data_path, "valid.txt")
    print("Writing the first {} lines of {} to {}".format(num_train, train_data_path, train_out))
    print("Writing the last {} lines of {} to {}".format(num_valid, train_data_path, valid_out))

    with open(train_out, 'wb') as train_file:
        train_file.writelines(lines[:num_train])
    with open(valid_out, 'wb') as valid_file:
        valid_file.writelines(lines[num_train:])

    shutil.copy(test_data_path, os.path.join(params.temp_data_path, "test.txt"))

def compute_backward_perplexity(params):
    """
    Backward perplexity: an LM is trained on the generated examples and then
    tested on the real data.

    Expects generated_data.txt to be present in params.temp_data_path.
    Returns the test perplexity obtained from _train_lm.
    """
    generated_path = os.path.join(params.temp_data_path, "generated_data.txt")
    real_path = params.path_to_real_data

    # generated text is the training material, real text is the test set
    _copy_to_temp(params, generated_path, real_path)
    backward_ppl = _train_lm(params)
    return backward_ppl
    
def compute_forward_perplexity(params):
    """
    Forward perplexity: an LM is trained on the real data and then tested on
    the generated examples.

    Expects generated_data.txt to be present in params.temp_data_path.
    Returns the test perplexity obtained from _train_lm.
    """
    real_path = params.path_to_real_data
    generated_path = os.path.join(params.temp_data_path, "generated_data.txt")

    # real text is the training material, generated text is the test set
    _copy_to_temp(params, real_path, generated_path)
    forward_ppl = _train_lm(params)
    return forward_ppl

def compute_reconstruction_bleu(params):
    """
    Measures how well the autoencoder reconstructs the real data in terms of BLEU.

    All lines of params.path_to_real_data are read (as returned by readlines,
    i.e. including their line break), encoded and decoded again in batches of
    params.batch_size without gradient tracking. Originals and reconstructions
    are tokenized with bleu_tokenize; every original serves as the single
    reference of its reconstruction.
    NOTE(review): assumes decoder(encoder(batch)) yields one text per input
    line, in input order - confirm against the model code.

    Returns:
        The corpus-level BLEU score computed with nltk's corpus_bleu and
        SmoothingFunction().method1.
    """
    # set gpu device
    device = torch.device(params.device)
    print("Using device {}".format(str(device)))

    decoder = get_decoder(params, device).to(device)
    decoder.eval()
    encoder = get_encoder(params, device).to(device)
    encoder.eval()

    with open(params.path_to_real_data, 'r') as data:
        real_data = data.readlines()

    step = params.batch_size
    reconstructed = []
    with torch.no_grad():
        for start in range(0, len(real_data), step):
            sentences = real_data[start:start + step]
            reconstructed.extend(decoder(encoder(sentences)))

    # corpus_bleu expects a list of references per hypothesis
    references = [[bleu_tokenize(sentence)] for sentence in real_data]
    hypotheses = [bleu_tokenize(sentence) for sentence in reconstructed]
    smoothing = SmoothingFunction().method1
    return corpus_bleu(references, hypotheses, smoothing_function=smoothing)

def clean(params):
    """
    Deletes the directory params.temp_data_path together with all its contents.
    """
    temp_dir = params.temp_data_path
    shutil.rmtree(temp_dir)
    
def main():
    """
    Entry point: computes the requested metrics and prints them.

    A fresh, timestamped subdirectory of params.temp_data_path holds all
    intermediate files. Depending on the compute_* switches, examples are
    generated and the backward / forward perplexities are computed, followed
    by the reconstruction BLEU. The temporary directory is removed afterwards,
    also if one of the computations fails. Metrics that were not requested are
    printed as " - ".
    """
    params = get_params()
    # add timestamp to avoid collisions when running multiple experiments
    params.temp_data_path = os.path.join(params.temp_data_path, str(time.time()))
    # exist_ok only tolerates an already existing directory; any other failure
    # (e.g. missing permissions) is raised here instead of causing confusing
    # follow-up errors later on
    os.makedirs(params.temp_data_path, exist_ok=True)

    # placeholders shown for metrics that are switched off
    fwd_ppl = " - "
    bwd_ppl = " - "
    reconstruction_bleu = " - "

    try:
        # first, we generate examples to use for training an LM and testing an LM
        if params.compute_fwdppl or params.compute_bwdppl:
            generate_examples(params)
            if params.compute_bwdppl:
                bwd_ppl = compute_backward_perplexity(params)
            if params.compute_fwdppl:
                fwd_ppl = compute_forward_perplexity(params)

        # we test the reconstruction performance on a hold out dataset
        if params.compute_bleu:
            reconstruction_bleu = compute_reconstruction_bleu(params)
    finally:
        # remove the temporary data (including the trained LM) even on failure
        clean(params)

    print("=======Results============")
    print("Forward ppl:", fwd_ppl)
    print("Backward ppl:", bwd_ppl)
    print("Reconstruction BLEU:", reconstruction_bleu)
    print("==========================")
    
    
    
if __name__ == "__main__":
    main()
