import logging
import re

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
)

# Registry of supported pretrained backbones, keyed by a short model name.
# Each entry holds:
#   'model_args': Hugging Face checkpoint identifier (presumably passed to
#                 ``from_pretrained`` at the call sites -- TODO confirm),
#   'config':     configuration class for the checkpoint (AutoConfig here),
#   'model':      auto-model class that matches the checkpoint architecture.
PDICT = {
    'bart': {
        'model_args': 'facebook/bart-large',
        'config': AutoConfig,
        'model': AutoModelForSeq2SeqLM,
    },
    't5_base': {
        'model_args': 't5-base',
        'config': AutoConfig,
        'model': AutoModelForSeq2SeqLM,
    },
    't5_large': {
        'model_args': 't5-large',
        'config': AutoConfig,
        'model': AutoModelForSeq2SeqLM,
    },
    # GPT-2 is decoder-only: AutoModelForSeq2SeqLM does not accept a
    # GPT2Config, so this checkpoint must go through the causal-LM auto class.
    'gpt2': {
        'model_args': 'gpt2-large',
        'config': AutoConfig,
        'model': AutoModelForCausalLM,
    },
}


# Single-character (or short) ASCII replacements for CJK / full-width marks.
# Built once at import time; none of the targets produce whitespace, "。" or
# "．", so the table cannot interact with the full-stop regexes below.
_UNICODE_PUNCT_TABLE = str.maketrans({
    "，": ",",
    "、": ",",
    "”": '"',
    "“": '"',
    "∶": ":",
    "：": ":",
    "？": "?",
    "《": '"',
    "》": '"',
    "）": ")",
    "！": "!",
    "（": "(",
    "；": ";",
    "」": '"',
    "「": '"',
    "～": "~",
    "’": "'",
    "…": "...",
    "━": "-",
    "〈": "<",
    "〉": ">",
    "【": "[",
    "】": "]",
    "％": "%",
    # Full-width digits U+FF10..U+FF19 -> "0".."9".
    **{chr(0xFF10 + digit): str(digit) for digit in range(10)},
})


def replace_unicode_punct(text):
    """Normalise CJK / full-width punctuation and digits to ASCII.

    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl

    Args:
        text: Value to normalise; it is converted with ``str()`` first, so
            non-string inputs are accepted.

    Returns:
        The normalised string. The ideographic full stop (``。``) and the
        full-width full stop (``．``), each together with any whitespace that
        follows it, become ``". "``. Every other mapped character is swapped
        for its ASCII counterpart (``…`` expands to ``...``).
    """
    normalised = str(text).translate(_UNICODE_PUNCT_TABLE)
    # Same order as Moses: ideographic full stop first, full-width one last.
    normalised = re.sub(r"。\s*", ". ", normalised)
    return re.sub(r"．\s*", ". ", normalised)


def get_logger(output_log_path=None):
    """Configure the root logger at INFO level and return it.

    A console ``StreamHandler`` is attached. When ``output_log_path`` is
    given, a ``FileHandler`` writing to that path is also attached. Both use
    the ``"%(asctime)s - %(name)s - %(levelname)s - %(message)s"`` format.

    Repeated calls do not stack handlers. A second console handler is never
    added, and no second file handler is added for a path that is already
    being logged to. The root logger is process-wide, so without this check
    each call (e.g. an import-time call followed by a call with a log path)
    would add another handler and every record would be printed several
    times.

    Args:
        output_log_path: Optional path of a log file to append to. ``None``
            (the default) logs to the console only.

    Returns:
        The root ``logging.Logger``.

    Raises:
        OSError: If the log file cannot be opened. This is raised before the
            logger is modified.
    """
    import os  # local: only needed to normalise the log-file path

    stream_name = "get_logger.stream"  # tags our console handler for dedup
    logger = logging.getLogger()
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    # Open the file first so a bad path fails before the logger is touched.
    file_handler = None
    if output_log_path is not None:
        # FileHandler records the absolute path in ``baseFilename``.
        log_path = os.path.abspath(os.fspath(output_log_path))
        already_logging_there = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not already_logging_there:
            file_handler = logging.FileHandler(output_log_path)
            # Without this the file would use the bare "%(message)s" default.
            file_handler.setFormatter(formatter)

    logger.setLevel(logging.INFO)
    if not any(h.get_name() == stream_name for h in logger.handlers):
        stream_handler = logging.StreamHandler()
        stream_handler.set_name(stream_name)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)
    if file_handler is not None:
        logger.addHandler(file_handler)
    return logger


# Configure logging at import time: get_logger() without a path returns the
# root logger with a console handler attached, ready for importers to reuse.
logger = get_logger(output_log_path=None)

if __name__ == "__main__":
    # Placeholder entry point: running the module directly only prints an
    # empty line.
    print("")

