# TODO: FIX, CURRENTLY BROKEN

import sys
import os

sys.path.append('../')

from multiprocessing import Pool
import torch
from torch import nn
from emb2emb.utils import glove_dict, get_data, pretty_print_prediction, word_index_mapping, Namespace
import argparse
from emb2emb.encoders import tokenize
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from lstmae import LSTMAE
from lstmvae import LSTMVAE
import numpy as np

def eval(model, X, X_lens, noise, device):
    """Encode a batch, optionally perturb the encoding, and decode it twice.

    NOTE: this function shadows the builtin ``eval``; the name is kept because
    the script's entry point calls it under this name.

    Args:
        model: Object exposing ``encode(X, X_lens)``, ``beam_decode(encoded)``
            and ``greedy_decode(encoded)``.
        X: Batch handed to ``model.encode``.
        X_lens: Sequence lengths handed to ``model.encode``.
        noise: Scale applied to standard-normal noise that is added to the
            encoding. ``0.0`` disables the perturbation.
        device: Device on which the noise tensor is sampled.

    Returns:
        A tuple ``(beam_decoded, greedy_decoded)`` holding the results of
        ``model.beam_decode`` and ``model.greedy_decode`` on the same
        (possibly noised) encoding.
    """
    # This is inference only: disabling autograd avoids recording a graph
    # for every encode/decode step.
    with torch.no_grad():
        encoded = model.encode(X, X_lens)
        if noise != 0.0:
            # In-place add, so both decoders see the same perturbed encoding.
            encoded += torch.randn_like(encoded, device=device) * noise
        return (model.beam_decode(encoded), model.greedy_decode(encoded))

# turn words into indices. Parallelized
def word_to_index(tokenized, word2index, processes):
    """Convert tokenized sentences to vocabulary-index lists in parallel.

    Each sentence is converted by ``_tokens_to_index`` in a worker process.

    Args:
        tokenized: Iterable of token lists, one per sentence.
        word2index: Mapping from token to integer index; passed unchanged to
            ``_tokens_to_index``.
        processes: Number of worker processes for the pool. ``None`` lets
            ``multiprocessing.Pool`` pick its default.

    Returns:
        A tuple ``(indexed, lengths)``:
        ``indexed`` is a 1-D object array whose elements are the per-sentence
        index lists; ``lengths`` is an int64 array with the length of each
        index list.
    """
    # The context manager shuts the workers down once the (blocking) starmap
    # has returned, so no pool is leaked.
    with Pool(processes) as pool:
        indexed = pool.starmap(
            _tokens_to_index,
            [(token_list, word2index) for token_list in tokenized],
        )

    # Sentences have different lengths, so np.array(indexed) would be handed a
    # ragged nested sequence, which recent NumPy versions reject with a
    # ValueError. Fill a 1-D object array element by element instead; this
    # also keeps the result 1-D when all sentences happen to share a length.
    data = np.empty(len(indexed), dtype=object)
    for position, index_list in enumerate(indexed):
        data[position] = index_list

    lengths = np.array([len(index_list) for index_list in indexed], dtype=np.int64)
    return data, lengths

# Work for each worker
def _tokens_to_index(token_list, word2index):
    """Map one token list to indices, framed by the <SOS> and <EOS> indices.

    Tokens that are not keys of ``word2index`` are skipped.

    Args:
        token_list: Iterable of tokens for a single sentence.
        word2index: Mapping from token to index; must contain ``"<SOS>"``
            and ``"<EOS>"``.

    Returns:
        A list: the ``"<SOS>"`` index, the indices of all known tokens in
        their original order, then the ``"<EOS>"`` index.
    """
    start_index = word2index["<SOS>"]
    known_indices = [word2index[token] for token in token_list if token in word2index]
    return [start_index, *known_indices, word2index["<EOS>"]]

def prepare_batch(indexed, lengths, device):
    """Pad a batch of index lists and order it by descending length.

    Args:
        indexed: Iterable of index sequences, one per sentence.
        lengths: Length of each sequence, in the same order as ``indexed``.
        device: Device on which the tensors are created.

    Returns:
        A tuple ``(X, sorted_lengths, order)``: ``X`` is the int64 batch,
        zero-padded to the longest sequence and with rows reordered by
        descending length; ``sorted_lengths`` holds the lengths in that same
        order as int64; ``order`` holds, for each output row, the position
        of that sentence in the input.
    """
    sequences = []
    for index_list in indexed:
        sequences.append(torch.tensor(index_list, device=device))
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)

    length_tensor = torch.tensor(lengths, device=device).long()
    sorted_lengths, order = length_tensor.sort(descending=True)

    reordered = padded[order]
    return reordered.to(torch.int64), sorted_lengths.to(torch.int64), order

def write_to_file(path, index2word, data):
    """Write originals and their decodings to a text file, one per line.

    For every entry three lines are written in this order: the original
    sentence, the beam-search decoding and the greedy decoding. Each line is
    the space-joined words looked up in ``index2word``.

    Args:
        path: Destination file; it is created or truncated.
        index2word: Mapping from index to word.
        data: Iterable of ``(original, decode_beam, decode_greedy)`` triples,
            each element being an iterable of keys of ``index2word``.
    """
    print("Writing to file")
    # Use the `path` argument. The previous version read the global
    # `params.output_path`, which only exists when run as a script.
    # UTF-8 is explicit so non-ASCII words do not depend on the locale.
    with open(path, "w", encoding="utf-8") as f:
        for original, decode_beam, decode_greedy in data:
            for sequence in (original, decode_beam, decode_greedy):
                f.write(" ".join([index2word[w] for w in sequence]) + "\n")

if __name__ == "__main__":
    # Script entry point: load a trained (variational) LSTM autoencoder and
    # reconstruct either every sentence of the dataset or one sentence given
    # with --sentence, printing the beam and greedy decodings. In dataset
    # mode the results are also written to --output_path.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default='../data/nli/', help="Path to dataset")
    parser.add_argument("--model_path", type=str, default="savedir/model.pickle", help="Path to autoencoder model")
    parser.add_argument("--output_path", type=str, default="savedir/results.txt", help="Path to results")
    parser.add_argument("--input_size", type=int, default=300, help="Input embedding size")
    parser.add_argument("--max_sequence_len", type=int, default=100)
    parser.add_argument("--hidden_size", type=int, default=1024)
    parser.add_argument("--teacher_forcing_ratio", type=float, default=0.5)
    parser.add_argument("--data_fraction", type=float, default=1., help = "How much of the data to use.")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--sentence", type=str, help="The single sentence to evaluate on.")
    parser.add_argument("--noise", type=float, default=0.0, help="Magnitude of noise to add to the embeddings")
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--variational", action="store_true")
    parser.add_argument("--n", type=int, default=1, help="Number of times to evaluate a single sentence. (--sentence)")
    params, _ = parser.parse_known_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    (train, valid, test), eval_sari = get_data(params)
    if params.sentence is None:
        sentences = np.concatenate((train["Sx"], train["Sy"], valid["Sx"], valid["Sy"], test["Sx"], test["Sy"]))
    else:
        sentences = np.array([params.sentence])

    # The vocabulary mapping is stored next to the model as a pickled dict
    # wrapped in a 0-d object array; indexing with () unwraps it.
    mapping = np.load(params.model_path + ".mapping.npy", allow_pickle=True)
    word2index = mapping[()]["word2index"]
    index2word = mapping[()]["index2word"]

    print(f"{sentences.shape[0]} sentences")

    # Sentences longer than --max_sequence_len tokens are dropped.
    tokenized = list(filter(lambda x: len(x) <= params.max_sequence_len, [tokenize(s) for s in sentences]))
    print(f"{len(tokenized)} sentences below len {params.max_sequence_len}")

    # Both model classes are configured with the same set of options.
    model_class = LSTMVAE if params.variational else LSTMAE
    model = model_class(Namespace(input_size=params.input_size,
                                  max_sequence_len=params.max_sequence_len,
                                  device=device,
                                  hidden_size=params.hidden_size,
                                  teacher_forcing_ratio=params.teacher_forcing_ratio,
                                  vocab_size=len(word2index),
                                  eos_idx=word2index["<EOS>"],
                                  sos_idx=word2index["<SOS>"]))
    model.load_state_dict(torch.load(params.model_path, map_location=device))
    model = model.to(device)
    # Evaluation only: put the module into inference mode so training-only
    # behaviour (e.g. dropout) is disabled.
    model.eval()

    print(model)

    print("Preparing data")
    data, lengths = word_to_index(tokenized, word2index, os.cpu_count())

    print("Starting evaluation")

    if params.sentence is None:
        num_sentences = len(data)
        # Ceiling division: the final, possibly smaller, batch is evaluated
        # too instead of being silently dropped.
        epoch_batches = (num_sentences + params.batch_size - 1) // params.batch_size
        print(f"Batch size: {params.batch_size}")
        print(f"Epoch batches: {epoch_batches}")

        # Checkpoint the results roughly 1000 times per run. Clamp to 1 so
        # runs with fewer than 1000 batches do not hit a modulo by zero.
        write_every = max(1, epoch_batches // 1000)

        eval_data = []
        for b_it in range(epoch_batches):
            start = b_it * params.batch_size
            # Slicing clamps at the end of the data, so the last batch may be
            # shorter than batch_size.
            s_batch = data[start:start + params.batch_size]
            l_batch = lengths[start:start + params.batch_size]

            # Eval on batch
            X, X_lens, idx = prepare_batch(s_batch, l_batch, device)
            d_batch = eval(model, X, X_lens, params.noise, device)

            # prepare_batch reorders the rows by length; idx maps each model
            # output row back to its sentence in s_batch.
            for model_output_idx, s_batch_idx in enumerate(idx.cpu().numpy().tolist()):
                eval_data.append((s_batch[s_batch_idx], d_batch[0][model_output_idx], d_batch[1][model_output_idx]))

            # Show the most recently evaluated sentence as a progress sample.
            print("Original        : "+" ".join([index2word[w] for w in eval_data[-1][0]]))
            print("Decoded (beam)  : "+" ".join([index2word[w] for w in eval_data[-1][1]]))
            print("Decoded (greedy): "+" ".join([index2word[w] for w in eval_data[-1][2]]))

            print(f"{100 * (b_it+1) / epoch_batches}% finished")
            if b_it % write_every == 0:
                write_to_file(params.output_path, index2word, eval_data)

        write_to_file(params.output_path, index2word, eval_data)
    else:
        if len(data) == 0:
            # The only sentence was removed by the length filter above.
            raise SystemExit(
                f"--sentence has more than {params.max_sequence_len} tokens "
                "(--max_sequence_len); nothing to evaluate."
            )

        print("Original        : "+" ".join([index2word[w] for w in data[0]]))
        # The batch is the same for every repetition; only the noise differs.
        X, X_lens, idx = prepare_batch(data, lengths, device)
        for _ in range(params.n):
            decode = eval(model, X, X_lens, params.noise, device)
            print("Decoded (beam)  : "+" ".join([index2word[w] for w in decode[0][0]]))
            print("Decoded (greedy): "+" ".join([index2word[w] for w in decode[1][0]]))
