# coding=utf-8
# Perform self-training.
# Before running this code, predict the unlabeled sentences with the fine-tuned model.
# We recommend splitting the unlabeled sentences into several bins (e.g. 1M per bin) and predicting each bin.
# The predicted results for each bin should be in a directory with the same name as the bin.
# The output file contains only the extracted sentences, so please combine them with the original training set
# and then run mixup_data_convert.py.

import csv
import os
import numpy as np
import tensorflow as tf

# Command-line flags for this script, registered through TensorFlow's flags module.
flags = tf.flags
FLAGS = flags.FLAGS

# Input side: one folder per bin holding the unlabeled sentences.
flags.DEFINE_string(
    "unlabeled_dir",
    None,
    "The directory of unlabeled sentences, "
    "should contain only test.tsv file per each bin folder",
)

# Input side: one folder per bin holding the model predictions for that bin.
flags.DEFINE_string(
    "unlabeled_prd_dir",
    None,
    "The directory of prediction of unlabeled sentences, "
    "should contain test_results.tsv file per each bin folder",
)

# Output side: folder that receives the augmented data file.
flags.DEFINE_string(
    "output_dir",
    None,
    "The output directory.",
)

# Number of top-scoring sentences kept per label and per bin.
flags.DEFINE_integer(
    "top_k",
    200,
    "The k value for self-training",
)

# Fail early on missing flags: os.listdir(None) would silently list the
# current working directory instead of reporting the omission, and a missing
# output directory would only surface after all bins have been processed.
for _flag_name in ("unlabeled_dir", "unlabeled_prd_dir", "output_dir"):
    if getattr(FLAGS, _flag_name) is None:
        raise ValueError("Flag --%s must be specified" % _flag_name)

# os.listdir returns entries in arbitrary order, so sort both listings: the
# two lists are compared for equality here and zipped together later to pair
# each bin's sentences with that bin's predictions.
unlabeled_data_list = sorted(os.listdir(FLAGS.unlabeled_dir))
prd_res_list = sorted(os.listdir(FLAGS.unlabeled_prd_dir))

num_mil_bins = len(unlabeled_data_list)
num_mil_bins_prd = len(prd_res_list)

# Explicit raises instead of assert so the checks survive `python -O`.
if num_mil_bins != num_mil_bins_prd:
    raise ValueError("Different bin number between unlabeled data and prediction")

if unlabeled_data_list != prd_res_list:
    raise ValueError("Please use same directory name between unlabeled data and prediction")


# For every bin, pick the top_k highest-scoring sentences of each label and
# collect them per label in aug_sents_dict.
labels = ['CPR:3','CPR:4','CPR:5','CPR:6','CPR:9']
aug_sents_dict = {label: [] for label in labels}

# Bins are the outer loop so that each bin's two files are read only once and
# reused for every label; sentences are still appended per label in bin order.
for prd_dir, dat_dir in zip(prd_res_list, unlabeled_data_list):
    # Load the prediction scores of this bin as a float matrix (one row per sentence).
    # The bin folder comes from the loop variable, not from a command-line flag.
    with open(os.path.join(FLAGS.unlabeled_prd_dir, prd_dir, "test_results.tsv"),
              'r', encoding='utf-8') as f_prd:
        prd_rows = list(csv.reader(f_prd, delimiter='\t', quotechar=None))
    # Builtin float replaces the np.float alias that NumPy 1.24 removed.
    prd_probs = np.array(prd_rows).astype(float)

    # Load the unlabeled sentences of the same bin.
    with open(os.path.join(FLAGS.unlabeled_dir, dat_dir, 'test.tsv'),
              'r', encoding='utf-8') as f_sen:
        sen_rows = list(csv.reader(f_sen, delimiter='\t', quotechar=None))
    # The first row is dropped (presumably a header - TODO confirm) so that
    # sentence row n is expected to line up with prediction row n.
    sen_list = sen_rows[1:]

    for i, label in enumerate(labels):
        # Row indices ordered by descending score in column i.
        # NOTE(review): assumes column i of test_results.tsv holds the score
        # of labels[i] - confirm against the model's label order.
        ranked_idx = prd_probs[:, i].argsort()[::-1]
        for top_idx in ranked_idx[:FLAGS.top_k]:
            # Column 1 of the sentence row is the text that gets exported.
            aug_sents_dict[label].append(sen_list[top_idx][1])

# Write every selected sentence with its label as a two-column TSV file.
# Create the output directory first so the write cannot fail on a missing folder.
os.makedirs(FLAGS.output_dir, exist_ok=True)

with open(os.path.join(FLAGS.output_dir, "augmented_data.tsv"),
          'w', encoding='utf-8', newline='') as f_trn_new:
    # NOTE(review): the writer uses the csv default quoting while the input
    # files are read with quotechar=None, so a sentence containing a double
    # quote is written quoted - confirm the downstream reader expects that.
    wr = csv.writer(f_trn_new, delimiter='\t')
    for label, sents in aug_sents_dict.items():
        for sen in sents:
            wr.writerow([sen, label])