# Shuffle the training set and split it into two parts ("first half" / "second half").

import argparse
import os
import pickle
import sys
from shutil import copyfile

import numpy as np
from tqdm import tqdm

# Placeholder path; presumably the repository root that contains `common/` — TODO set.
sys.path.append("xxx")

from common.utils import seed_everything



def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``), so it cannot be used for
    a flag that takes an explicit value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


parser = argparse.ArgumentParser()

parser.add_argument("--seed", type=int, default=42,
                    help="random seed passed to seed_everything")
# CNNDM / WikiHow / XSum / RedditTIFU
parser.add_argument("--data_folder", type=str, default="../../DATASETS/RedditTIFU/data/en/",
                    help="folder holding train_summary.txt, train_text.txt, "
                         "train_top_sentences.txt")
# in [143000, 84000, 102000, 17000]; presumably CNNDM / WikiHow / XSum / RedditTIFU respectively
parser.add_argument("--thresh", type=int, default=17000,
                    help="number of shuffled examples placed in the first half")
# in ["cnndm", "wikihow", "xsum", "reddit"]
parser.add_argument("--dataset_name", type=str, default="reddit",
                    help="name used for the saved permutation file")
# True for CNNDM, False for WikiHow + XSum (and the RedditTIFU default)
parser.add_argument("--individual_files", type=_str2bool, default=False,
                    help="also copy per-example files under train/summary and train/text")

args = parser.parse_args()

print("*"*50)
print(args)



def _read_lines(path):
    """Return all lines of the file at *path* as ``bytes``, line terminators kept."""
    with open(path, "rb") as f:
        return f.readlines()


def _write_lines(path, lines):
    """Write the byte strings in *lines* to *path*, one record per line.

    A line without a trailing newline (normally only the last line of a source
    file that lacks a final newline) gets one appended. Otherwise, once it is
    shuffled into the middle of the output, it would be glued onto the next
    record and the parallel output files would fall out of alignment.
    """
    with open(path, "wb") as f:
        f.writelines(
            line if line.endswith(b"\n") else line + b"\n" for line in lines
        )


def _write_split(folder, prefix, summaries, texts, top_sents):
    """Write one split as ``<prefix>_train_shuffled_{summary,text,top_sentences}.txt``."""
    for suffix, lines in (
        ("summary", summaries),
        ("text", texts),
        ("top_sentences", top_sents),
    ):
        name = "{}_train_shuffled_{}.txt".format(prefix, suffix)
        _write_lines(os.path.join(folder, name), lines)


def _copy_individual_files(folder, p, thresh):
    """Copy per-example files into the two shuffled split directories.

    For each of ``summary`` and ``text``, the ``i``-th shuffled example is
    ``<folder>/train/<doc>/train_<doc>_<p[i]>.txt``. The first ``thresh`` go to
    ``<folder>/first_half_train_shuffled/<doc>/`` and the rest to
    ``<folder>/second_half_train_shuffled/<doc>/``, renumbered from 0 within
    each half. Destination directories are created if missing.
    """
    for doc in ("summary", "text"):
        src_dir = os.path.join(folder, "train", doc, "")
        print(src_dir)
        dst_dirs = {}
        for half in ("first_half", "second_half"):
            dst_dirs[half] = os.path.join(folder, "{}_train_shuffled".format(half), doc)
            os.makedirs(dst_dirs[half], exist_ok=True)
        next_idx = {"first_half": 0, "second_half": 0}
        for i, src_idx in enumerate(tqdm(p)):
            half = "first_half" if i < thresh else "second_half"
            src_path = os.path.join(src_dir, "train_{}_{}.txt".format(doc, src_idx))
            dst_name = "{}_train_shuffled_{}_{}.txt".format(half, doc, next_idx[half])
            next_idx[half] += 1
            copyfile(src_path, os.path.join(dst_dirs[half], dst_name))


def main(args):
    """Shuffle the training set and write it to disk as two halves.

    Reads ``train_summary.txt``, ``train_text.txt`` and
    ``train_top_sentences.txt`` from ``args.data_folder``. It applies one
    random permutation to all three and pickles that permutation to
    ``dataset_permutations/<dataset_name>_train_permutation.pkl``, relative
    to the current working directory. The first ``args.thresh`` shuffled
    examples are written to ``first_half_train_shuffled_*.txt`` and the rest
    to ``second_half_train_shuffled_*.txt``. When ``args.individual_files``
    is true, the per-example files under ``train/summary/`` and
    ``train/text/`` are also copied into the matching split directories.

    Args:
        args: namespace with ``seed``, ``data_folder``, ``thresh``,
            ``dataset_name`` and ``individual_files``.

    Raises:
        ValueError: if the three training files have different line counts.
    """
    # seed — presumably seeds numpy's global RNG used by np.random.permutation below
    seed_everything(args.seed)
    folder = args.data_folder

    # load full training
    train_summaries = _read_lines(os.path.join(folder, "train_summary.txt"))
    train_texts = _read_lines(os.path.join(folder, "train_text.txt"))
    train_top_sents = _read_lines(os.path.join(folder, "train_top_sentences.txt"))
    print(len(train_summaries), len(train_texts), len(train_top_sents))
    # One permutation is applied to all three files, so they must line up 1:1.
    if not len(train_summaries) == len(train_texts) == len(train_top_sents):
        raise ValueError(
            "training files have mismatched line counts: "
            "summary={}, text={}, top_sentences={}".format(
                len(train_summaries), len(train_texts), len(train_top_sents)
            )
        )

    # shuffle, saving the permutation to disk alongside
    p = np.random.permutation(len(train_texts))
    print(p[:10])
    perm_dir = "dataset_permutations"
    os.makedirs(perm_dir, exist_ok=True)
    perm_path = os.path.join(perm_dir, "{}_train_permutation.pkl".format(args.dataset_name))
    with open(perm_path, "wb") as f:
        pickle.dump(p, f)
    print("saved permutation!")
    train_summaries = [train_summaries[i] for i in p]
    train_texts = [train_texts[i] for i in p]
    train_top_sents = [train_top_sents[i] for i in p]
    print("permuted the training set!")

    # 1st half = first `thresh` shuffled examples, 2nd half = the rest (full files)
    for prefix, part in (
        ("first_half", slice(None, args.thresh)),
        ("second_half", slice(args.thresh, None)),
    ):
        summaries = train_summaries[part]
        texts = train_texts[part]
        top_sents = train_top_sents[part]
        print(len(summaries), len(texts), len(top_sents))
        _write_split(folder, prefix, summaries, texts, top_sents)

    # individual files
    if args.individual_files:
        _copy_individual_files(folder, p, args.thresh)



if __name__ == "__main__":
    # Script entry point: run the split with the arguments parsed at module load.
    main(args)




