import sys
import os
import re

def extract_para_from_file(file_in):
    """Read a ``'||||'``-separated four-field parallel file into records.

    Every non-blank line must contain exactly four fields separated by the
    literal string ``'||||'``. Each field is whitespace-stripped. Blank lines
    are skipped.

    Args:
        file_in: Path of a UTF-8 text file.

    Returns:
        A list of records, one per non-blank line, each a list of four
        stripped strings in the order they appear on the line.

    Raises:
        ValueError: if a non-blank line does not have exactly four fields.
    """
    records = []
    with open(file_in, 'r', encoding='utf8') as f_in:
        for line_no, raw_line in enumerate(f_in, start=1):
            line = raw_line.strip()
            if not line:
                # A stray empty line carries no sentence tuple; skip it
                # rather than aborting the whole directory run.
                continue
            fields = [field.strip() for field in line.split('||||')]
            if len(fields) != 4:
                # Explicit exception instead of ``assert`` so the check is
                # not removed when Python runs with ``-O``.
                raise ValueError(
                    '{}:{}: expected 4 "||||"-separated fields, got {}: {!r}'
                    .format(file_in, line_no, len(fields), line))
            records.append(fields)

    return records

def _word_edit_distance(src_tokens, tgt_tokens):
    """Return the token-level Levenshtein distance between two token lists.

    Insertion, deletion and substitution each cost 1. Only two DP rows are
    kept, so memory is O(len(tgt_tokens)) instead of a full matrix.
    """
    prev_row = list(range(len(tgt_tokens) + 1))
    for i, src_tok in enumerate(src_tokens, start=1):
        cur_row = [i]
        for j, tgt_tok in enumerate(tgt_tokens, start=1):
            cost = 0 if src_tok == tgt_tok else 1
            cur_row.append(min(cur_row[j - 1] + 1,       # insertion
                               prev_row[j] + 1,          # deletion
                               prev_row[j - 1] + cost))  # match / substitution
        prev_row = cur_row
    return prev_row[-1]


def filter_sentence(para_list, threshold=0.4):
    """Keep records whose two English sides are close at the word level.

    For each record, field 0 (``src_en``) and field 2 (``tgt_en``) are split on
    whitespace and their token-level edit distance is computed. A record is
    kept only when ``distance < len(src_en tokens) * threshold``. A record
    whose ``src_en`` side is empty is therefore always dropped.
    A progress line is printed every 1000 records.

    Args:
        para_list: Sequence of records (indexable, at least 3 fields each).
        threshold: Maximum allowed distance as a fraction of the ``src_en``
            token count.

    Returns:
        A new list containing the kept records, in input order.
    """
    out_list = []
    for idx, list_item in enumerate(para_list):
        if idx % 1000 == 0:
            print('processed {} lines'.format(idx))
        src_en_tokens = list_item[0].split()
        tgt_en_tokens = list_item[2].split()

        distance = _word_edit_distance(src_en_tokens, tgt_en_tokens)

        # The threshold is relative to the source-side length only.
        if distance >= len(src_en_tokens) * threshold:
            continue
        out_list.append(list_item)

    return out_list
    
def extract_from_dir(dir_path, threshold=0.4):
    """Collect, filter and split all ``<xx>-<yy><N>.out`` files in a directory.

    Every file in ``dir_path`` whose name is exactly two lowercase letters, a
    dash, two lowercase letters, one or more digits and ``.out`` is read with
    ``extract_para_from_file``. Files are processed in sorted name order so
    the output is reproducible. All records are then passed through
    ``filter_sentence`` and written, one field per file, into four output
    files inside ``dir_path``.

    Args:
        dir_path: Directory containing the ``.out`` files; outputs go here too.
        threshold: Passed to ``filter_sentence``. Also appended to the output
            file names.

    Returns:
        The list of records that survived filtering.

    Raises:
        ValueError: if no file in ``dir_path`` matches the naming pattern.
    """
    # fullmatch + escaped dot: only names like "en-de3.out" qualify; the old
    # start-anchored pattern also accepted e.g. "en-de3.out.bak".
    name_pattern = re.compile(r'([a-z]{2})-([a-z]{2})\d+\.out')

    dir_list = []
    src_lang, tgt_lang = None, None
    for filename in sorted(os.listdir(dir_path)):
        match = name_pattern.fullmatch(filename)
        if match is None:
            continue
        print('processing file {}'.format(filename))
        if src_lang is None or tgt_lang is None:
            # NOTE(review): the language pair comes from the first matching
            # file only; files for other pairs in the same directory are
            # merged under this pair's output names — confirm intended.
            src_lang, tgt_lang = match.group(1), match.group(2)
        file_path = os.path.join(dir_path, filename)
        dir_list.extend(extract_para_from_file(file_path))

    if src_lang is None:
        raise ValueError(
            'no files matching "<xx>-<yy><N>.out" found in {}'.format(dir_path))

    dir_list = filter_sentence(dir_list, threshold=threshold)

    prefix = 'syn_multi_way_src_len_' + src_lang + '-' + tgt_lang + '.'
    suffix = '_' + str(threshold)
    # (output file name, index of the record field written to it)
    outputs = [
        (prefix + src_lang + suffix, 1),
        (prefix + src_lang + 'en' + suffix, 0),
        (prefix + tgt_lang + 'en' + suffix, 2),
        (prefix + tgt_lang + suffix, 3),
    ]
    for out_name, field_idx in outputs:
        # One `with` per file so each handle is closed even on error.
        with open(os.path.join(dir_path, out_name), 'w', encoding='utf8') as f_out:
            f_out.writelines(record[field_idx] + '\n' for record in dir_list)

    return dir_list

if __name__ == "__main__":
    # Usage: script.py <dir_path> [threshold]
    if len(sys.argv) not in (2, 3):
        sys.exit('usage: {} <dir_path> [threshold]'.format(sys.argv[0]))
    if len(sys.argv) == 3:
        extract_from_dir(sys.argv[1], threshold=float(sys.argv[2]))
    else:
        # Fall back to extract_from_dir's own default threshold.
        extract_from_dir(sys.argv[1])
