# coding=utf-8
# Create mixup training examples and convert the original data into the mixup training format.
# All input data (train.tsv, dev.tsv, test.tsv) should follow the BERT input format.
# For the training set, the augmented examples are added after the original (re-formatted) examples.
# For the dev and test sets, only re-formatting is performed.
# All converted files are written to <data_dir>/mixup/.

import csv
import os
import numpy as np
import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

# Directory holding the BERT-format input splits; the converted files are
# written to a "mixup" subdirectory created inside it.
flags.DEFINE_string(
    "data_dir", None,
    "Input data directory containing the BERT-format train.tsv, dev.tsv and "
    "test.tsv files. Converted files are written to <data_dir>/mixup/.")

# One switch per split: when false, that split's .tsv is neither read nor
# converted (use this when the file does not exist).
flags.DEFINE_bool(
    "train", True,
    "Whether to convert train.tsv and add mixup examples. "
    "Set to false if train.tsv does not exist.")

flags.DEFINE_bool(
    "dev", True,
    "Whether to convert dev.tsv. Set to false if dev.tsv does not exist.")

flags.DEFINE_bool(
    "test", True,
    "Whether to convert test.tsv. Set to false if test.tsv does not exist.")

# Candidate mixing ratios 0.0, 0.1, ..., 1.0; each mixup example draws one of
# these uniformly at random.
ratio_list = np.linspace(0.0, 1.0, 11)

# Fail with a clear message instead of a TypeError from os.path.join(None, ...).
if FLAGS.data_dir is None:
    raise ValueError(
        "--data_dir must be set to the directory containing the .tsv files")

# Output directory for all converted files. The original code referenced an
# undefined name `data_dir` here (NameError); the flag value is what is meant.
mixup_dir = os.path.join(FLAGS.data_dir, "mixup")
try:
    os.mkdir(mixup_dir)
except FileExistsError:
    # Re-running the script reuses the existing directory; its files are
    # simply rewritten. os.mkdir (not makedirs) still fails if data_dir itself
    # is missing, which surfaces a mistyped --data_dir early.
    pass

# Fixed seed so the generated mixup examples are reproducible across runs.
np.random.seed(0)

if FLAGS.train:
    # Number of mixed examples generated for every original training sentence.
    num_mixup_per_example = 3

    # Read with quoting disabled (like BERT's TSV reader, quotechar=None) so a
    # sentence starting with '"' is kept verbatim instead of being parsed as a
    # quoted CSV field (which drops quotes or merges lines). Blank lines are
    # skipped rather than raising IndexError.
    with open(os.path.join(FLAGS.data_dir, "train.tsv"), 'r', encoding='utf-8') as f_trn:
        rows = [row for row in csv.reader(f_trn, delimiter='\t', quoting=csv.QUOTE_NONE)
                if row]
    # Column 0 holds the sentence, column 1 the label.
    train_sen = [row[0] for row in rows]
    train_label = [row[1] for row in rows]

    sen_num = len(train_sen)

    with open(os.path.join(FLAGS.data_dir, "mixup/train.tsv"), 'w', encoding='utf-8', newline='') as f_mxp:
        # QUOTE_NONE + quotechar=None writes every field verbatim, so '"' in a
        # sentence is not wrapped/doubled. Fields cannot contain tabs or
        # newlines here because they came from a tab-split, quote-less read.
        wr = csv.writer(f_mxp, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar=None)

        # Original examples first, padded to the 6-column mixup layout
        # [sentence1, sentence2, label1, ratio1, label2, ratio2] with a
        # placeholder second sentence/label and ratio2 = 0.0.
        for sen, label in zip(train_sen, train_label):
            wr.writerow([sen, 'dummy', label, 1.0, 'false', 0.0])

        # Then the augmented examples: each sentence is mixed with a partner
        # drawn uniformly from the whole training set (possibly itself).
        for sen, label in zip(train_sen, train_label):
            for _ in range(num_mixup_per_example):
                second_index = np.random.choice(range(sen_num), 1)[0]
                second_sen = train_sen[second_index]
                second_label = train_label[second_index]
                first_ratio = np.random.choice(ratio_list, 1)[0]
                second_ratio = 1.0 - first_ratio

                # The sentence with the larger ratio always goes in the first
                # slot (on a 0.5/0.5 tie the current sentence stays first).
                if first_ratio >= second_ratio:
                    row = [sen, second_sen, label, first_ratio, second_label, second_ratio]
                else:
                    row = [second_sen, sen, second_label, second_ratio, label, first_ratio]
                wr.writerow(row)

if FLAGS.dev:
    # Read with quoting disabled (like BERT's TSV reader, quotechar=None) so
    # '"' characters are kept verbatim; blank lines are skipped.
    with open(os.path.join(FLAGS.data_dir, "dev.tsv"), 'r', encoding='utf-8') as f:
        rows = [row for row in csv.reader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
                if row]
    # Column 0 holds the sentence, column 1 the label.
    dev_sen = [row[0] for row in rows]
    dev_label = [row[1] for row in rows]

    with open(os.path.join(FLAGS.data_dir, "mixup/dev.tsv"), 'w', encoding='utf-8', newline='') as f_mxp:
        # Write fields verbatim (no CSV quoting) to match the quote-less read.
        wr = csv.writer(f_mxp, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar=None)
        for sen, label in zip(dev_sen, dev_label):
            # Re-format only: add a dummy second sentence/label with ratio2 = 0.0.
            wr.writerow([sen, 'dummy', label, 1.0, 'false', 0.0])

if FLAGS.test:
    # Read with quoting disabled (like BERT's TSV reader, quotechar=None) so
    # '"' characters are kept verbatim.
    with open(os.path.join(FLAGS.data_dir, "test.tsv"), 'r', encoding='utf-8') as f:
        rdr = csv.reader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
        next(rdr, None)  # skip the header row (no-op on an empty file)
        rows = [row for row in rdr if row]  # ignore blank lines
    # Column 1 holds the sentence, column 2 the label.
    tst_sen = [row[1] for row in rows]
    tst_label = [row[2] for row in rows]

    with open(os.path.join(FLAGS.data_dir, "mixup/test.tsv"), 'w', encoding='utf-8', newline='') as f_mxp:
        # Write fields verbatim (no CSV quoting) to match the quote-less read.
        wr = csv.writer(f_mxp, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar=None)
        # Test set file should contain headers in BERT
        wr.writerow(['index', 'sentence1', 'sentence2', 'label1', 'ratio1', 'label2', 'ratio2'])
        for i, (sen, label) in enumerate(zip(tst_sen, tst_label)):
            # Re-format only: prepend a 0-based index and add a dummy second
            # sentence/label with ratio2 = 0.0.
            wr.writerow([i, sen, 'dummy', label, 1.0, 'false', 0.0])