import argparse
import os

import numpy as np


def main(input_path, output_dir, n_valid, n_test):
    """Split a slot-tagged tsv file into train/valid/test files by slot label.

    Every non-blank line of ``input_path`` is read as tab-separated columns
    ``domain``, ``intent`` and a space-separated sequence of ``token|slot``
    pairs. The distinct slot labels (excluding ``'other'``) are randomly
    partitioned into three disjoint groups. Each line is then written to
    exactly one of ``train.tsv``, ``valid.tsv`` or ``test.tsv`` inside
    ``output_dir``, with every slot that does not belong to that split's
    group rewritten to ``'other'``.

    A line goes to the training file only when strictly more of its tokens
    carry training slots than validation slots and than testing slots;
    otherwise it goes to the validation file when its validation count
    strictly exceeds its testing count; otherwise (including lines tagged
    only with ``'other'``) it goes to the testing file.

    Args:
        input_path: Path to the tsv file containing all data.
        output_dir: Existing directory that receives the three output files.
        n_valid: Number of slot labels reserved for validation.
        n_test: Number of slot labels reserved for testing.

    Raises:
        ValueError: If ``n_valid`` or ``n_test`` is negative, or if together
            they exceed the number of distinct slot labels in the input.
    """
    if n_valid < 0 or n_test < 0:
        raise ValueError(
            "n_valid and n_test must be non-negative, got {} and {}".format(n_valid, n_test))

    # First pass: collect every slot label that appears in the data.
    slotset = set()
    with open(input_path, 'rt') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            for token_slot in line.split('\t')[-1].split():
                slot = token_slot.split('|')[1]
                if slot != 'other':
                    slotset.add(slot)

    # Sort before shuffling: set iteration order varies between interpreter
    # runs, so a fixed starting order is required for a seeded np.random to
    # yield the same split every time.
    slots = sorted(slotset)
    n_train = len(slots) - n_valid - n_test
    if n_train < 0:
        # A negative count would make the slices below overlap, silently
        # leaking the same slot into several splits.
        raise ValueError(
            "n_valid + n_test ({}) exceeds the number of slots ({})".format(
                n_valid + n_test, len(slots)))
    print(("{} slots in the training data, among which\n" +
           " - {} will be used for training, \n" +
           " - {} will be used for validation, and\n" +
           " - {} will be used for testing.").format(len(slots), n_train, n_valid, n_test))

    rnd = np.random.permutation(slots).tolist()
    # Sets are built once here so the per-token membership tests are O(1).
    train_slots = set(rnd[:n_train])
    valid_slots = set(rnd[n_train:n_train + n_valid])
    test_slots = set(rnd[n_train + n_valid:])

    train_path = os.path.join(output_dir, "train.tsv")
    valid_path = os.path.join(output_dir, "valid.tsv")
    test_path = os.path.join(output_dir, "test.tsv")

    # Second pass: route each line to one split and mask foreign slots.
    with open(input_path, 'rt') as f, \
            open(train_path, 'wt') as trainf, \
            open(valid_path, 'wt') as validf, \
            open(test_path, 'wt') as testf:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines are ignored, exactly as in the first pass.
                continue
            cols = line.split('\t')
            domain = cols[0]
            intent = cols[1]
            tokens_slots = cols[2].split()
            tokens = [token_slot.split('|')[0] for token_slot in tokens_slots]
            slots = [token_slot.split('|')[1] for token_slot in tokens_slots]

            train_num = sum(slot in train_slots for slot in slots)
            valid_num = sum(slot in valid_slots for slot in slots)
            test_num = sum(slot in test_slots for slot in slots)

            if train_num > valid_num and train_num > test_num:
                allowed, outf = train_slots, trainf
            elif valid_num > test_num:
                allowed, outf = valid_slots, validf
            else:
                allowed, outf = test_slots, testf

            # Slots owned by another split must not leak into this one.
            slots = [slot if slot in allowed else 'other' for slot in slots]
            tokens_slots = ['{}|{}'.format(token, slot) for token, slot in zip(tokens, slots)]
            outf.write(domain + '\t' + intent + '\t' + ' '.join(tokens_slots) + '\n')


if __name__ == "__main__":
    # Command-line entry point: four positional arguments, forwarded to main().
    positional_args = (
        ('input', str, 'Path to a tsv file containing all data'),
        ('output_dir', str,
         'Path to a directory to write "train.tsv", "valid.tsv" and "test.tsv" tsv files'),
        ('n_valid', int, "Number of validation slots."),
        ('n_test', int, "Number of testing slots."),
    )

    cli = argparse.ArgumentParser()
    for arg_name, arg_type, arg_help in positional_args:
        cli.add_argument(arg_name, type=arg_type, help=arg_help)

    options = cli.parse_args()
    main(options.input, options.output_dir, options.n_valid, options.n_test)
