# Preprocess CQG dataset
import argparse
import pickle
import random
import os
import codecs
from urllib.request import urlretrieve
from zipfile import ZipFile
from tqdm import tqdm
from model.utils import Vocab, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN

# Root directory of the CQG data. The __main__ section below reads the
# label_{train,dev,test}.txt inputs from here and writes the per-split output
# directories (train/valid/test) and the vocabulary pickles back into it.
CQG_dir = './CQG/org/org/tiny/'

def _read_sessions(path):
    """Return one tokenized session per line of the UTF-8 label file *path*."""
    sessions = []
    with codecs.open(path, 'r', encoding='utf-8') as handle:
        for line in handle:
            # The last tab-separated field is not an utterance (get_objects_type
            # reads it separately), so it is excluded here.
            utterances = line.strip().split('\t')[:-1]
            sessions.append([utt.strip().split(' ') for utt in utterances])
    return sessions


def get_objects(trainfile, devfile, testfile):
    """Load the tokenized sessions of the train, dev and test label files.

    Each line of a label file is a tab-separated record. Its last field is
    dropped; every remaining field is one utterance, split on single spaces
    into tokens.

    Args:
        trainfile, devfile, testfile: paths to UTF-8 label files.

    Returns:
        ``(train, dev, test)``: each is a list with one session per input line.
        A session is a list of utterances, and an utterance is a list of token
        strings. A blank line produces an empty session, so every file yields
        exactly one entry per line.
    """
    return _read_sessions(trainfile), _read_sessions(devfile), _read_sessions(testfile)

def _read_types(path):
    """Return the last tab-separated field of every line of the UTF-8 file *path*."""
    with codecs.open(path, 'r', encoding='utf-8') as handle:
        return [line.strip().split('\t')[-1] for line in handle]


def get_objects_type(trainfile, devfile, testfile):
    """Load the per-line type label of the train, dev and test label files.

    The type of a line is its last tab-separated field (after stripping
    surrounding whitespace from the line).

    Args:
        trainfile, devfile, testfile: paths to UTF-8 label files.

    Returns:
        ``(train_type, dev_type, test_type)``: each is a list of strings with
        one entry per input line. A blank line yields ``''``.
    """
    return _read_types(trainfile), _read_types(devfile), _read_types(testfile)

def pad_sentences(conversations, max_sentence_length=30, max_conversation_length=10):
    """Truncate and pad every utterance of every conversation to a fixed width.

    Args:
        conversations: iterable of conversations. Each conversation is a list of
            utterances, and each utterance is a list of tokens.
        max_sentence_length: width of each padded utterance, EOS included.
        max_conversation_length: maximum number of utterances kept per
            conversation; later utterances are dropped.

    Returns:
        A pair ``(sentences, sentence_length)``:
        - ``sentences[i][j]`` is utterance j of conversation i, clipped to
          ``max_sentence_length - 1`` tokens and followed by EOS_TOKEN, then
          PAD_TOKEN filler up to ``max_sentence_length``.
        - ``sentence_length[i][j]`` is the utterance's token count plus one
          for EOS, capped at ``max_sentence_length``.
    """
    body_limit = max_sentence_length - 1  # token slots left once EOS is reserved

    def _pad(utterance):
        # A negative count yields an empty filler, i.e. no padding once clipped.
        filler = [PAD_TOKEN] * (body_limit - len(utterance))
        return utterance[:body_limit] + [EOS_TOKEN] + filler

    padded_conversations = []
    length_table = []
    for conversation in conversations:
        turns = conversation[:max_conversation_length]
        length_table.append([min(len(turn) + 1, max_sentence_length) for turn in turns])
        padded_conversations.append([_pad(turn) for turn in turns])
    return padded_conversations, length_table

if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # Maximum valid length of a sentence. pad_sentences keeps at most
    # max_sentence_length - 1 tokens, appends EOS and then pads with PAD, so
    # every saved sentence is max_sentence_length tokens long.
    # (SOS is not added by this script.)
    parser.add_argument('-s', '--max_sentence_length', type=int, default=30)
    # Maximum number of utterances kept per conversation; extra ones are dropped.
    parser.add_argument('-c', '--max_conversation_length', type=int, default=3)

    # Vocabulary
    parser.add_argument('--max_vocab_size', type=int, default=50000)
    parser.add_argument('--min_vocab_frequency', type=int, default=1)

    # Multiprocess -- accepted for CLI compatibility; this script runs serially.
    parser.add_argument('--n_workers', type=int, default=os.cpu_count())

    args = parser.parse_args()

    max_sent_len = args.max_sentence_length
    max_conv_len = args.max_conversation_length
    max_vocab_size = args.max_vocab_size
    min_freq = args.min_vocab_frequency

    def to_pickle(obj, path):
        """Serialize *obj* to the file at *path* with pickle."""
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    # Both readers parse the same files: one takes the utterance fields, the
    # other the last (type) field of each line.
    label_files = [os.path.join(CQG_dir, f'label_{name}.txt') for name in ('train', 'dev', 'test')]
    train, dev, test = get_objects(*label_files)
    train_type, dev_type, test_type = get_objects_type(*label_files)

    print(len(train), len(dev), len(test), len(train_type), len(dev_type), len(test_type))

    splits = [('train', train, train_type), ('valid', dev, dev_type), ('test', test, test_type)]
    for split_type, conv_objects, type_objects in splits:
        print(f'Processing {split_type} dataset...')
        split_data_dir = os.path.join(CQG_dir, split_type)
        os.makedirs(split_data_dir, exist_ok=True)
        conversations = conv_objects
        # pad_sentences keeps at most max_conv_len utterances per conversation,
        # so clip the recorded length the same way to stay consistent with
        # sentences.pkl / sentence_length.pkl.
        conversation_length = [min(len(conv), max_conv_len) for conv in conversations]
        types = list(type_objects)

        sentences, sentence_length = pad_sentences(
            conversations,
            max_sentence_length=max_sent_len,
            max_conversation_length=max_conv_len)

        print('Saving preprocessed data at', split_data_dir)
        to_pickle(conversation_length, os.path.join(split_data_dir, 'conversation_length.pkl'))
        to_pickle(sentences, os.path.join(split_data_dir, 'sentences.pkl'))
        to_pickle(sentence_length, os.path.join(split_data_dir, 'sentence_length.pkl'))
        to_pickle(types, os.path.join(split_data_dir, 'types.pkl'))

        if split_type == 'train':

            print('Save Vocabulary...')
            vocab = Vocab()
            # Built from the full (untruncated) training conversations.
            vocab.add_dataframe(conversations)
            vocab.update(max_size=max_vocab_size, min_freq=min_freq, is_type=False)

            print('Vocabulary size: ', len(vocab))
            vocab.pickle(os.path.join(CQG_dir, 'word2id.pkl'), os.path.join(CQG_dir, 'id2word.pkl'))

            print('Save Type Vocabulary...')
            vocab_t = Vocab()
            vocab_t.add_sentence(types)
            vocab_t.update(max_size=15, min_freq=1, is_type=True)

            print('Type Vocabulary size: ', len(vocab_t))
            vocab_t.pickle(os.path.join(CQG_dir, 'word2id_t.pkl'), os.path.join(CQG_dir, 'id2word_t.pkl'))

    print('Done!')
