import sys
import os
import numpy as np
import math

# --- Corruption settings -------------------------------------------------
# word_pred is used in two ways below: as the per-token selection
# probability in rand_insert_delete_replace(), and as the upper bound of the
# uniform fraction that sizes the edited span in rand_insert(),
# rand_delete() and rand_replace().
word_pred = 0.6
# Probabilities of picking insert / delete / replace (in that order) for a
# token selected by rand_insert_delete_replace().
insert_delete_replace_probs = [0.4, 0.3, 0.3]

# --- Dictionary tables, filled in by read_dict() -------------------------
word_list = []   # words, in the order they were read
freq_list = []   # frequencies, parallel to word_list
word2freq = {}   # word -> frequency
dic_len = 0      # len(word_list) after read_dict() has run

def read_dict(dict_path):
    """Load a word/frequency dictionary into the module-level tables.

    Every line that consists of exactly two whitespace-separated fields is
    read as ``<word> <integer frequency>``; lines with any other number of
    fields (blank lines included) are skipped. Accepted entries are appended
    to ``word_list`` and ``freq_list`` and stored in ``word2freq``. When the
    whole file has been read, ``dic_len`` is set to ``len(word_list)``.

    Args:
        dict_path: Path of the UTF-8 encoded dictionary file.

    Raises:
        ValueError: If a word appears more than once, or if a frequency
            field cannot be parsed by ``int()``.
    """
    global dic_len

    with open(dict_path, 'r', encoding='utf8') as f_in:
        for line in f_in:
            # str.split() with no argument already discards surrounding
            # whitespace, so no separate strip() is needed.
            fields = line.split()
            if len(fields) != 2:
                continue
            word, freq = fields[0], int(fields[1])
            # Explicit check instead of `assert`: assertions are removed
            # under `python -O`, and checking before the appends keeps the
            # three tables consistent when a duplicate is rejected.
            if word in word2freq:
                raise ValueError(
                    'duplicate word in dictionary: {!r}'.format(word))
            word_list.append(word)
            freq_list.append(freq)
            word2freq[word] = freq

    dic_len = len(word_list)

def rand_insert(tgt_line):
    """Return ``tgt_line`` with one run of random dictionary words inserted.

    The line is split on whitespace. The run length is
    ``ceil(n_tokens * u)`` with ``u`` drawn uniformly from
    ``[0, word_pred)``; each inserted word is drawn uniformly from the
    module-level ``word_list``. The run is placed in front of a uniformly
    chosen existing token, so it is never appended after the last token.

    Args:
        tgt_line: Sentence whose tokens are separated by whitespace.

    Returns:
        The resulting tokens joined by single spaces, or ``''`` when
        ``tgt_line`` contains no tokens.
    """
    tgt_list = tgt_line.split()
    if not tgt_list:
        # np.random.randint(0) raises ValueError, so an empty line must be
        # handled before an insertion index is drawn.
        return ''

    tgt_len = len(tgt_list)
    insert_len = math.ceil(tgt_len * np.random.uniform(0, word_pred))
    insert_index = np.random.randint(tgt_len)

    insert_word_list = [
        word_list[np.random.randint(dic_len)] for _ in range(insert_len)
    ]

    pred_list = (
        tgt_list[:insert_index] + insert_word_list + tgt_list[insert_index:]
    )
    return ' '.join(pred_list)

def rand_delete(tgt_line):
    """Return ``tgt_line`` with one contiguous run of tokens removed.

    The line is split on whitespace. The run length is
    ``ceil(n_tokens * u)`` with ``u`` drawn uniformly from
    ``[0, word_pred)``, and the run starts at a uniformly chosen token
    index. If the run would extend past the last token, everything from the
    start index onward is removed.

    Args:
        tgt_line: Sentence whose tokens are separated by whitespace.

    Returns:
        The remaining tokens joined by single spaces (possibly ``''``), or
        ``''`` when ``tgt_line`` contains no tokens.
    """
    tgt_list = tgt_line.split()
    if not tgt_list:
        # np.random.randint(0) raises ValueError, so an empty line must be
        # handled before a start index is drawn.
        return ''

    tgt_len = len(tgt_list)
    delete_len = math.ceil(tgt_len * np.random.uniform(0, word_pred))
    delete_index = np.random.randint(tgt_len)

    # Slicing past the end of a list yields [], so no explicit bounds check
    # is needed for the tail.
    pred_list = tgt_list[:delete_index] + tgt_list[delete_index + delete_len:]
    return ' '.join(pred_list)

def rand_replace(tgt_line):
    """Return ``tgt_line`` with one run of tokens replaced by random words.

    The line is split on whitespace. The run length is
    ``ceil(n_tokens * u)`` with ``u`` drawn uniformly from
    ``[0, word_pred)``, and the run starts at a uniformly chosen token
    index. That many words, each drawn uniformly from the module-level
    ``word_list``, are written in place of the run. The full number of
    random words is written even when fewer original tokens remain after
    the start index, so the output can be longer than the input.

    Args:
        tgt_line: Sentence whose tokens are separated by whitespace.

    Returns:
        The resulting tokens joined by single spaces, or ``''`` when
        ``tgt_line`` contains no tokens.
    """
    tgt_list = tgt_line.split()
    if not tgt_list:
        # np.random.randint(0) raises ValueError, so an empty line must be
        # handled before a start index is drawn.
        return ''

    tgt_len = len(tgt_list)
    replace_len = math.ceil(tgt_len * np.random.uniform(0, word_pred))
    replace_index = np.random.randint(tgt_len)

    replace_word_list = [
        word_list[np.random.randint(dic_len)] for _ in range(replace_len)
    ]

    # Slicing past the end of a list yields [], so no explicit bounds check
    # is needed for the tail.
    pred_list = (
        tgt_list[:replace_index]
        + replace_word_list
        + tgt_list[replace_index + replace_len:]
    )
    return ' '.join(pred_list)
   
def rand_insert_delete_replace(tgt_line):
    """Corrupt ``tgt_line`` token by token with random edits.

    Each whitespace-separated token is selected independently with
    probability ``word_pred``. A selected token gets exactly one operation,
    drawn according to ``insert_delete_replace_probs``:

    * insert  -- a random dictionary word is emitted before the token;
    * delete  -- the token is dropped;
    * replace -- a random dictionary word is emitted instead of the token.

    Tokens that are not selected are copied unchanged.

    Args:
        tgt_line: Sentence whose tokens are separated by whitespace.

    Returns:
        The corrupted tokens joined by single spaces.
    """
    tokens = tgt_line.split()
    selected = np.random.rand(len(tokens)) < word_pred

    noisy = []
    for token, is_selected in zip(tokens, selected):
        if not is_selected:
            noisy.append(token)
            continue

        # One-hot draw over (insert, delete, replace).
        do_insert, do_delete, _ = np.random.multinomial(
            1, insert_delete_replace_probs)

        if do_insert == 1:
            noisy.append(word_list[np.random.randint(dic_len)])
            noisy.append(token)
        elif do_delete != 1:
            # Replace: emit a random word in place of the token.
            noisy.append(word_list[np.random.randint(dic_len)])
        # Delete: nothing is emitted for this token.

    return ' '.join(noisy)

def make_data(src_file, tgt_file, src_file_out, tgt_file_out, tgt_truth_out):
    """Write four corrupted variants of every (source, target) line pair.

    ``src_file`` and ``tgt_file`` are read in lockstep; reading stops at the
    end of the shorter file. For each pair, four output rows are produced,
    one per corruption function, in this order:
    ``rand_insert_delete_replace``, ``rand_insert``, ``rand_delete``,
    ``rand_replace``. Each row consists of the stripped source line (to
    ``src_file_out``), the corrupted target (to ``tgt_file_out``) and the
    stripped original target (to ``tgt_truth_out``).

    Args:
        src_file: Path of the UTF-8 source-side input file.
        tgt_file: Path of the UTF-8 target-side input file.
        src_file_out: Output path for the repeated source lines.
        tgt_file_out: Output path for the corrupted target lines.
        tgt_truth_out: Output path for the uncorrupted target lines.
    """
    corruptions = (
        rand_insert_delete_replace,
        rand_insert,
        rand_delete,
        rand_replace,
    )

    # All five files live in one `with` so the outputs are flushed and
    # closed even when reading or corrupting a line raises.
    with open(src_file_out, 'w', encoding='utf8') as f_src_out, \
            open(tgt_file_out, 'w', encoding='utf8') as f_tgt_out, \
            open(tgt_truth_out, 'w', encoding='utf8') as f_tru_out, \
            open(src_file, 'r', encoding='utf8') as f_s, \
            open(tgt_file, 'r', encoding='utf8') as f_t:
        for line_s, line_t in zip(f_s, f_t):
            line_s, line_t = line_s.strip(), line_t.strip()
            for corrupt in corruptions:
                # Corrupt first so the three outputs stay aligned if the
                # corruption function raises.
                noisy_tgt = corrupt(line_t)
                f_src_out.write(line_s + '\n')
                f_tgt_out.write(noisy_tgt + '\n')
                f_tru_out.write(line_t + '\n')

if __name__ == '__main__':
    # Expected command line:
    #   <script> DICT SRC TGT SRC_OUT TGT_OUT TGT_TRUTH_OUT
    # Check the argument count before any work is done, so a bad invocation
    # gets a usage message instead of an IndexError/ValueError traceback.
    if len(sys.argv) != 7:
        sys.exit(
            'usage: {} DICT SRC TGT SRC_OUT TGT_OUT TGT_TRUTH_OUT'.format(
                sys.argv[0]))

    dict_path = sys.argv[1]
    read_dict(dict_path)

    src_file, tgt_file, src_file_out, tgt_file_out, tgt_truth_out = \
        sys.argv[2:]
    make_data(src_file, tgt_file, src_file_out, tgt_file_out, tgt_truth_out)
