# We assume that the inputs are Chinese text.

import argparse
import string

def readTxt(inputfile, encoding="utf-8"):
    """Read a text file and return its lines with surrounding whitespace removed.

    Args:
        inputfile: path of the file to read.
        encoding: text encoding of the file. Defaults to UTF-8 so that
            Chinese input is decoded the same way regardless of the
            platform's default locale encoding.

    Returns:
        list[str]: one entry per line, each passed through ``str.strip()``
        (so leading/trailing spaces are removed along with the newline).
    """
    with open(inputfile, "r", encoding=encoding) as fin:
        return [line.strip() for line in fin]

def saveTxt(datas, outputfile, encoding="utf-8"):
    """Write each string in ``datas`` to ``outputfile``, one per line.

    Args:
        datas: iterable of strings; a trailing newline is appended to each.
        outputfile: path of the file to create or overwrite.
        encoding: text encoding used for writing. Defaults to UTF-8 so that
            Chinese output does not depend on the platform's default locale
            encoding.
    """
    with open(outputfile, "w", encoding=encoding) as fout:
        fout.writelines(line + "\n" for line in datas)

# def filter(inputs: list or str):
#     symbols = ".,!?。，！？\'\"“”"
#     results = [char for char in inputs if char not in symbols]
#     if isinstance(inputs, str):
#         results = "".join(results)
#     return results

def B2Q(uchar):
    """Convert a single half-width (ASCII) punctuation character to full width.

    Only characters in ``string.punctuation`` (all within U+0021..U+007E) are
    converted, using the fixed offset between ASCII and the Unicode
    Halfwidth and Fullwidth Forms block: full-width = half-width + 0xFEE0.
    Every other input -- spaces, letters, digits, CJK characters, and strings
    that are not exactly one character long -- is returned unchanged.

    Args:
        uchar: the character to convert.

    Returns:
        The full-width punctuation character, or ``uchar`` unchanged.
    """
    # ``in`` on a str is a substring test, so check the length first;
    # otherwise "" or multi-char runs such as "?!" would reach ord() and raise.
    if len(uchar) != 1 or uchar not in string.punctuation:
        return uchar
    return chr(ord(uchar) + 0xFEE0)


def matchRawAndTokenized(raw_str, tokenized_str):
    """Map each character of a raw string to the subword token that covers it.

    Tokens are taken from ``tokenized_str.split()``. The SentencePiece
    word-boundary marker "▁" (U+2581) is removed from each token before
    matching. Characters are then consumed greedily: while the current raw
    character occurs anywhere in the current token, it is assigned to that
    token; otherwise the scan moves on to the next token.

    The test is per-character membership, not positional. A raw character
    that occurs in no later token (e.g. a space, since tokens are
    whitespace-split) exhausts the remaining tokens, and it and every
    following character get no entry in the result.

    Args:
        raw_str: the untokenized text.
        tokenized_str: whitespace-separated tokenization of ``raw_str``.

    Returns:
        char2spm: dict, {char position index (int): token position index (int)}.
        The key ``len(raw_str)`` always maps to the number of tokens, so an
        exclusive end offset can be converted too.
    """
    tokens = [token.replace("▁", "") for token in tokenized_str.split()]
    char2spm = {}
    char_idx = 0
    for token_idx, token in enumerate(tokens):
        if char_idx >= len(raw_str):
            break
        while char_idx < len(raw_str) and raw_str[char_idx] in token:
            char2spm[char_idx] = token_idx
            char_idx += 1
    char2spm[len(raw_str)] = len(tokens)  # sentinel for the sequence end
    return char2spm

def match(args):
    """Align the first token of each tokenized answer with its tokenized passage.

    Reads two line-aligned files (``args.tokenized_passage`` and
    ``args.tokenized_answer``) and writes one output line per input pair to
    ``args.output``. For each pair:

    * a leading standalone "▁" token is dropped from the answer;
    * if the answer's first token does not occur among the passage tokens but
      the same token without its leading "▁" does, the first token is
      replaced by that stripped form (presumably because the answer starts in
      the middle of a passage word, where no boundary marker was emitted).

    Answers that are empty, or consist only of "▁", are written as empty lines
    so the output stays line-aligned with the inputs.

    Args:
        args: namespace with ``tokenized_passage``, ``tokenized_answer`` and
            ``output`` file paths.

    Raises:
        ValueError: if the passage and answer files have different line counts.
    """
    tokenized_passages = readTxt(args.tokenized_passage)
    tokenized_answers = readTxt(args.tokenized_answer)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(tokenized_passages) != len(tokenized_answers):
        raise ValueError(
            "line count mismatch: %d passages vs %d answers"
            % (len(tokenized_passages), len(tokenized_answers))
        )

    results = []
    for tp, ta in zip(tokenized_passages, tokenized_answers):
        token_tp = tp.split()
        token_ta = ta.split()
        if token_ta and token_ta[0] == "▁":
            token_ta.pop(0)
        if not token_ta:
            # Nothing to align; emit an empty line to keep files line-aligned.
            results.append("")
            continue

        answer_first = token_ta[0]
        if answer_first not in token_tp and answer_first.startswith("▁"):
            stripped = answer_first[1:]
            if stripped in token_tp:
                token_ta[0] = stripped

        results.append(" ".join(token_ta))

    saveTxt(results, args.output)

if __name__ == "__main__":
    # Example:
    # python3 qgMatchTokenizedAnsPas.py --tokenized-passage /opt/tiger/sumtest/xqg_nospace/MSPM/zh/dev.e.zh.lc --tokenized-answer /opt/tiger/sumtest/xqg_nospace/MSPM/zh/dev.a.zh.lc --output /opt/tiger/sumtest/xqg_nospace/MSPM/zh/dev.a.zh.lc.match
    parser = argparse.ArgumentParser()

    # All three paths are needed by match(); mark them required so argparse
    # reports a clear usage error instead of open(None) failing later.
    parser.add_argument("--tokenized-passage", help="tokenized passage", type=str, required=True)
    parser.add_argument("--tokenized-answer", help="tokenized answer", type=str, required=True)
    parser.add_argument("--output", help="output file for the matched tokenized answers", type=str, required=True)

    args = parser.parse_args()

    match(args)