import sys
from collections import namedtuple
import Levenshtein
import os
import threading
import time
import random
import math
from gensim import corpora, models, similarities

from operator import itemgetter

# Absolute path of the directory containing this script; used as the default
# output directory by deal_multi_way_list.
pwd_path = os.path.abspath(os.path.dirname(__file__))
# Number of nearest neighbours the similarity index returns for each query.
NUM_BEST = 1
# Number of equal-sized chunks (one worker thread each) a corpus is split into
# when it is searched; any remainder lines get one extra thread.
NUM_THREADS = 20

def init_simi(file_list, out_dir):
    """Build an LSI similarity index over a list of sentences.

    Each sentence is whitespace-tokenised and turned into a bag-of-words with a
    fresh gensim Dictionary. A TF-IDF model is fitted on that corpus and used to
    train an LSI model. The LSI projection of the raw bag-of-words corpus is then
    stored in a sparse similarity index that returns the NUM_BEST most similar
    documents per query.

    Parameters:
        file_list: list of str, one sentence per item.
        out_dir: unused; kept so existing callers keep working.
    Returns:
        (dictionary, lsi, index) tuple.
    """
    tokenized = [line.split() for line in file_list]
    dictionary = corpora.Dictionary(tokenized)
    corpus = [dictionary.doc2bow(tokens) for tokens in tokenized]
    tfidf = models.TfidfModel(corpus)
    lsi = models.LsiModel(tfidf[corpus], id2word=dictionary)
    # The index holds LSI vectors of the *raw* bag-of-words (not TF-IDF
    # weighted) because top_simi builds its queries as lsi[doc2bow(...)];
    # indexed documents and queries must live in the same space.
    # LSI vectors have at most lsi.num_topics dimensions, so size the index by
    # that instead of an arbitrary 50000-wide feature space.
    index = similarities.SparseMatrixSimilarity(
        lsi[corpus], num_best=NUM_BEST, num_features=lsi.num_topics)
    return dictionary, lsi, index
    

def top_simi(querys, dictionary, lsi, index):
    """Look up the nearest indexed documents for a batch of query sentences.

    Every query string is whitespace-tokenised, mapped to a bag-of-words with
    ``dictionary`` and projected with ``lsi``. The whole batch is then passed to
    ``index`` as a single list of query vectors.

    Returns:
        Whatever ``index[...]`` yields for that list. With an index built with
        ``num_best`` set (as init_simi does), gensim returns one list of
        (doc_id, score) pairs per query.
    """
    query_vectors = [lsi[dictionary.doc2bow(text.split())] for text in querys]
    return index[query_vectors]


def select_sentence(src_line, obj_line, threshold=0.4, max_len=100):
    """Decide whether two sentences are near-duplicates at the word level.

    Both sentences are whitespace-tokenised and compared with a token-level
    Levenshtein (edit) distance. A pair is accepted only if all of these hold:
    the two strings are not identical, neither side has more than ``max_len``
    tokens, and the distance is strictly below ``threshold`` times the token
    count of the shorter sentence.

    Parameters:
        src_line: first sentence.
        obj_line: second sentence.
        threshold: fraction of the shorter sentence's token count that the
            edit distance must stay below.
        max_len: sentences with more tokens than this are rejected outright,
            which bounds the O(n*m) comparison cost. Defaults to 100.
    Returns:
        True if the pair is accepted, False otherwise.
    """
    if src_line == obj_line:
        # Identical lines are never reported as a match.
        return False

    src_tokens = src_line.split()
    obj_tokens = obj_line.split()
    src_len, obj_len = len(src_tokens), len(obj_tokens)

    if src_len > max_len or obj_len > max_len:
        return False

    # Standard Levenshtein DP over tokens, keeping only the previous row:
    # prev_row[j] is the distance between src_tokens[:i-1] and obj_tokens[:j].
    prev_row = list(range(obj_len + 1))
    for i, src_tok in enumerate(src_tokens, 1):
        cur_row = [i] + [0] * obj_len
        for j, obj_tok in enumerate(obj_tokens, 1):
            cost = 0 if src_tok == obj_tok else 1
            cur_row[j] = min(cur_row[j - 1] + 1,      # insertion
                             prev_row[j] + 1,         # deletion
                             prev_row[j - 1] + cost)  # substitution / match
        prev_row = cur_row
    distance = prev_row[obj_len]

    if distance >= min(src_len, obj_len) * threshold:
        return False

    return True

class my_thread(threading.Thread):
    """Thread that runs ``func(*args)`` and keeps its return value.

    ``get_result()`` returns the value ``func`` returned. It returns None if
    ``func`` has not finished yet or raised.
    """

    def __init__(self, func, args, name=''):
        super().__init__()
        self.name = name
        self.func = func
        self.args = args
        # Pre-set so get_result() is well-defined before run() completes or
        # when func raises, without having to catch AttributeError.
        self.results = None

    def run(self):
        self.results = self.func(*self.args)

    def get_result(self):
        """Return func's result, or None if it is not available."""
        return self.results

def _match_batch(f_out, batch_src_lines, batch_tgt_lines, pre_tuple, threshold):
    """Match one batch against a previously indexed corpus; return #lines written.

    For each line in ``batch_src_lines``, its nearest neighbours in
    ``pre_tuple.index`` are fetched via top_simi. Every neighbour accepted by
    select_sentence is written to ``f_out`` as
    ``src_line||||tgt_line||||pre_src_line||||pre_tgt_line``.
    """
    sims = top_simi(batch_src_lines, pre_tuple.dictionary,
                    pre_tuple.lsi, pre_tuple.index)
    found = 0
    for sim_item, src_line, tgt_line in zip(sims, batch_src_lines, batch_tgt_lines):
        for sim in sim_item:
            if not sim:
                continue
            # sim[0] is the row of the matched sentence in the previous corpus.
            pre_src = pre_tuple.src_lines[sim[0]]
            pre_tgt = pre_tuple.tgt_lines[sim[0]]
            if select_sentence(src_line, pre_src, threshold):
                found += 1
                f_out.write('||||'.join([src_line, tgt_line, pre_src, pre_tgt]) + '\n')
    return found


def thread_func(src_lines, tgt_lines, tgt_lang, pre_tuple, idx, out_dir, threshold,
                batch_size=50):
    """Search one chunk of a corpus against a previously indexed corpus.

    ``src_lines`` (aligned with ``tgt_lines``) are processed in batches of
    ``batch_size``. Accepted matches are written to
    ``<out_dir>/<tgt_lang>-<pre_tuple.tgt_lang><idx>.out``, one
    ``||||``-separated line per match.

    Parameters:
        src_lines: lines compared against ``pre_tuple.src_lines``.
        tgt_lines: lines aligned with ``src_lines``, copied into the output.
        tgt_lang: language of ``tgt_lines`` (used in the output file name).
        pre_tuple: previously indexed corpus (dictionary/lsi/index plus lines).
        idx: chunk id, appended to the output file name.
        out_dir: output directory.
        threshold: passed to select_sentence.
        batch_size: number of queries sent to the index per call.
    Returns:
        Number of matches written.
    """
    out_filename = tgt_lang + '-' + pre_tuple.tgt_lang + str(idx) + '.out'
    out_filepath = os.path.join(out_dir, out_filename)

    found = 0
    # NOTE(review): the 10-byte buffer presumably makes results reach disk while
    # the long search is still running; it costs roughly one syscall per write.
    # Confirm it is still wanted.
    with open(out_filepath, 'w', encoding='utf8', buffering=10) as f_out:
        for batch_no, start in enumerate(range(0, len(src_lines), batch_size)):
            print('thread {0} process batch {1}, found {2}'.format(idx, batch_no, found),
                  flush=True)
            end = start + batch_size
            found += _match_batch(f_out, src_lines[start:end], tgt_lines[start:end],
                                  pre_tuple, threshold)
    return found


def _read_stripped_lines(path):
    """Return the lines of a UTF-8 text file with surrounding whitespace stripped."""
    with open(path, 'r', encoding='utf8') as f:
        return [line.strip() for line in f]


def _search_previous(pre_tuple, query_lines, other_lines, other_lang, out_dir, threshold):
    """Search ``query_lines`` against one indexed corpus with worker threads.

    The lines are cut into NUM_THREADS equal chunks (ids 1..NUM_THREADS) plus
    one remainder chunk (id NUM_THREADS + 1) when the length is not divisible.
    Each chunk is handled by thread_func in its own my_thread. Returns after all
    threads have joined.
    """
    block_size = len(query_lines) // NUM_THREADS
    chunks = [(th_idx, (th_idx - 1) * block_size, th_idx * block_size)
              for th_idx in range(1, NUM_THREADS + 1)]
    if NUM_THREADS * block_size < len(query_lines):
        chunks.append((NUM_THREADS + 1, NUM_THREADS * block_size, len(query_lines)))

    threads = [
        my_thread(thread_func, (query_lines[start:end], other_lines[start:end],
                                other_lang, pre_tuple, th_idx, out_dir, threshold))
        for th_idx, start, end in chunks
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()


def find_multi_way(srcs, tgts, src_langs, tgt_langs, accto_src, out_dir, threshold):
    """Find multi-way data across several parallel corpora.

    Corpora are processed in the given order. One side of every corpus is the
    "pivot": the source side when ``accto_src`` is true, the target side
    otherwise. The pivot language must be identical for all corpora, and the
    languages on the other side must all differ.
    - Every corpus except the first has its pivot lines searched against each
      earlier corpus (see thread_func).
    - Every corpus except the last has its pivot side indexed (see init_simi).
    Matches are written to ``out_dir``.

    Parameters:
        srcs: source files (one sentence per line).
        tgts: target files, line-aligned with ``srcs``.
        src_langs: source languages.
        tgt_langs: target languages.
        accto_src: whether to find the multi-way data according to the source
            side (True) or the target side (False).
        out_dir: directory receiving the ``*.out`` match files.
        threshold: passed through to select_sentence.
    Raises:
        AssertionError: on inconsistent language lists or file counts (the
            checks use ``assert``).
    """
    if accto_src:
        print('find multi-way data according to source')
        assert len(set(src_langs)) == 1, 'langs in source must be same'
        assert len(set(tgt_langs)) == len(tgt_langs), 'langs in target must be different'
    else:
        print('find multi-way data according to target')
        assert len(set(tgt_langs)) == 1, 'langs in target must be same'
        assert len(set(src_langs)) == len(src_langs), 'langs in source must be different'

    len_src = len(srcs)
    assert len(tgts) == len_src and len(src_langs) == len_src and \
          len(tgt_langs) == len_src, 'input error, number of the input files are not equal'

    # Field names are read by thread_func: src_* holds the pivot side and
    # tgt_* holds the other side of each indexed corpus.
    corpus_tuple = namedtuple('corpus_tuple', 'src_lang, tgt_lang, src_lines, tgt_lines, dictionary, lsi, index')

    corpus_list = []

    for idx, (src_lang, tgt_lang, src_file, tgt_file) in enumerate(zip(src_langs, tgt_langs, srcs, tgts)):
        # Read both files up front so no handles stay open during the long search.
        src_line_list = _read_stripped_lines(src_file)
        tgt_line_list = _read_stripped_lines(tgt_file)

        if len(src_line_list) != len(tgt_line_list):
            # Unequal lengths would otherwise make the other side of an indexed
            # corpus shorter than its index and crash the worker threads.
            n_pairs = min(len(src_line_list), len(tgt_line_list))
            print('warning: {0} has {1} lines but {2} has {3}; using the first {4} pairs'.format(
                src_file, len(src_line_list), tgt_file, len(tgt_line_list), n_pairs), flush=True)
            src_line_list, tgt_line_list = src_line_list[:n_pairs], tgt_line_list[:n_pairs]

        if accto_src:
            pivot_lang, pivot_lines = src_lang, src_line_list
            other_lang, other_lines = tgt_lang, tgt_line_list
        else:
            pivot_lang, pivot_lines = tgt_lang, tgt_line_list
            other_lang, other_lines = src_lang, src_line_list

        # Search if and only if some previous files have been indexed.
        if corpus_list:
            print('search for {0}-{1} begins ...'.format(src_lang, tgt_lang), flush=True)
            for pre_tuple in corpus_list:
                _search_previous(pre_tuple, pivot_lines, other_lines, other_lang,
                                 out_dir, threshold)
            print('search for {0}-{1} done'.format(src_lang, tgt_lang))

        # The last file is never searched against, so it needs no index.
        if idx != len_src - 1:
            print('init file dict for {0}-{1} begin...'.format(src_lang, tgt_lang), flush=True)
            dictionary, lsi, index = init_simi(pivot_lines, out_dir)
            corpus_list.append(corpus_tuple(pivot_lang, other_lang, pivot_lines, other_lines,
                                            dictionary, lsi, index))
            print('file dict init ok!')


def deal_multi_way_list(found_multi_way_list, output_path=None):
    """Write found multi-way pairs to one tab-separated file per language pair.

    Each item is unpacked as
    ``(lang1, lang2, src_line1, tgt_line1, src_line2, tgt_line2)``. The output
    file is ``<output_path>/<A>-<B>.multi_way``, where ``A``/``B`` are the two
    languages sorted alphabetically. Files are UTF-8 and truncated when first
    opened. Each item becomes one line of its four sentences joined by tabs, in
    the item's own order.

    Parameters:
        found_multi_way_list: sized iterable of 6-tuples as above.
        output_path: output directory; defaults to this script's directory.
    """
    if output_path is None:
        output_path = pwd_path

    print(len(found_multi_way_list))
    out_files = {}
    try:
        for item in found_multi_way_list:
            print(item)
            lang1, lang2, src_line1, tgt_line1, src_line2, tgt_line2 = item
            # NOTE(review): only the file name is canonicalised by sorting; the
            # sentence columns keep the item's order. Confirm this is intended.
            file_prefix = '-'.join(sorted([lang1, lang2]))
            f_out = out_files.get(file_prefix)
            if f_out is None:
                f_out = open(os.path.join(output_path, file_prefix + '.multi_way'), 'w', encoding='utf8')
                out_files[file_prefix] = f_out
            f_out.write('\t'.join([src_line1, tgt_line1, src_line2, tgt_line2]) + '\n')
    finally:
        # Close (and flush) every opened file even if an item is malformed.
        for f_out in out_files.values():
            f_out.close()


if __name__ == '__main__':
    # Usage: SRC1 TGT1 SRC_LANG1 TGT_LANG1 [SRC2 TGT2 SRC_LANG2 TGT_LANG2 ...] OUT_DIR THRESHOLD
    if len(sys.argv) < 3:
        sys.exit('usage: {0} (SRC TGT SRC_LANG TGT_LANG)... OUT_DIR THRESHOLD'.format(sys.argv[0]))
    threshold = float(sys.argv[-1])
    out_dir = sys.argv[-2]
    params = sys.argv[1:-2]
    assert len(params) % 4 == 0, 'bad input params'
    # Parameters come in groups of four: src file, tgt file, src lang, tgt lang.
    srcs, tgts, src_langs, tgt_langs = params[0::4], params[1::4], params[2::4], params[3::4]

    find_multi_way(srcs, tgts, src_langs, tgt_langs, True, out_dir, threshold)
