import argparse
import re
import string
import sys
from os import system

from pysbd import Segmenter
from zhon.hanzi import punctuation

from ioFn import readTxt, readJsonl, saveGeneralCSV

def calOverlap(results_a: list, results_b: list):
    """Return the smoothed mean overlap ratio between paired lists of sentence ids.

    Elements of ``results_a`` and ``results_b`` are paired up in lockstep
    (``zip``), so extra elements in the longer argument are ignored. For each
    pair ``(ids_a, ids_b)`` the ratio is::

        |set(ids_a) & set(ids_b)| / (len(ids_a) + 1e-3)

    The per-pair ratios are summed and divided by ``len(results_a) + 1e-3``.
    The ``1e-3`` term keeps empty inputs from dividing by zero. As a side
    effect, identical pairs score slightly below 1.0.

    results_a, results_b: lists whose elements are collections of sentence ids.
    """
    smoothing = 1e-3  # guards against division by zero for empty lists
    per_pair = (
        len(set(ids_a) & set(ids_b)) / (len(ids_a) + smoothing)
        for ids_a, ids_b in zip(results_a, results_b)
    )
    return sum(per_pair, 0.0) / (len(results_a) + smoothing)

def remove_punc(s):
    """Replace every Chinese and English punctuation character in ``s`` with a space.

    Chinese punctuation is taken from ``zhon.hanzi.punctuation`` and English
    punctuation from ``string.punctuation``. Leading and trailing whitespace
    is stripped from the result. Inner runs of spaces are left untouched.
    """
    # One ``translate`` pass over ``s`` replaces the original one-``replace``-
    # pass-per-punctuation-character loop. Each character maps to a single
    # space, and space is in neither set, so the output is identical to the
    # sequential replacements.
    table = str.maketrans(dict.fromkeys(punctuation + string.punctuation, " "))
    return s.translate(table).strip()
    

def _parse_args():
    """Build the command-line parser for the case-export script and parse argv."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--l-zh", help="label file",
        default="/home/tiger/wikiLingua_enzh_ext_finetune_english_mspm4_EnZh_chinese/resource/dataset/test.label.jsonl"
    )
    parser.add_argument(
        "--l-en", help="label file",
        default="/home/tiger/wikiLingua_enzh_ext_finetune_english_mspm4_EnZh_english/resource/dataset/test.label.jsonl"
    )
    parser.add_argument(
        "--system-file",
        default="/opt/tiger/sumtest/multilingual/system_abs.txt",
        help="one 'hypo_file<TAB>name<TAB>lang' line per system",
    )
    parser.add_argument('-k', type=int, default=30, help="number of examples to export")
    parser.add_argument('-o', default="abs_case.csv", help="output CSV path")
    return parser.parse_args()


def _load_systems(system_file, k):
    """Read the system list and the first ``k`` hypotheses of each system.

    Every non-blank line of ``system_file`` must be ``hypo_file<TAB>name<TAB>lang``.
    Returns a list of dicts with keys ``'hypos'``, ``'name'`` and ``'lang'``.
    """
    systems = []
    with open(system_file, "r", encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # A trailing/blank line previously crashed the 3-way unpacking.
                continue
            hypo_file, name, lang = line.split("\t")
            systems.append({
                "hypos": readTxt(hypo_file)[:k],
                "name": name,
                "lang": lang,  # recorded but not used by this script
            })
    return systems


def _build_rows(examples, systems, k):
    """Return one CSV row per example: document, reference summary, each system's hypothesis.

    The document's items are joined with single spaces. The number of rows is
    capped at the shortest of ``k``, ``examples`` and every system's
    hypotheses, so a short file no longer raises ``IndexError``.
    """
    n_rows = min([k, len(examples), *(len(s["hypos"]) for s in systems)])
    return [
        [
            " ".join(examples[i]["document"]),
            examples[i]["summary"],
            *(s["hypos"][i] for s in systems),
        ]
        for i in range(n_rows)
    ]


def main():
    """Export the first ``-k`` Chinese examples and all system outputs to ``-o`` as CSV."""
    args = _parse_args()

    # NOTE(review): --l-en is still accepted, but only the zh labels were ever
    # written to the CSV; the en labels used to be loaded and then ignored.
    examples = readJsonl(args.l_zh)[:args.k]
    systems = _load_systems(args.system_file, args.k)

    rows = _build_rows(examples, systems, args.k)
    if len(rows) < args.k:
        print(
            f"warning: only {len(rows)} of {args.k} requested rows are available "
            "in all input files; output truncated",
            file=sys.stderr,
        )

    headers = ["document", "summary"] + [s["name"] for s in systems]
    saveGeneralCSV(rows, args.o, start=headers)


if __name__ == "__main__":
    main()
