from enum import Enum


class TextType(Enum):
    """Layout/processing modes accepted by ``process_text``.

    Each member selects one branch of ``process_text``; the comments below
    summarise what that branch does with the marker-annotated input text.
    """

    # Split at '<|endofcard|>' and return a (persona, prompt) tuple.
    prompt = 1
    # Pull out target, bedding and ending sections (target delimited by
    # '<|beginoftarget|>' ... '<|endoftarget|>') and return them joined in
    # bedding, target, ending order.
    plan_tbe = 2
    # Only strip the marker tokens; the text order is left as-is.
    bte = 3
    # Like ``prompt``, but the persona part is further split on
    # '<|sepofname|>' and only the second field is kept.
    prompt_with_name = 4
    # Only strip the marker tokens.
    cards = 5
    # Keep only what follows the last bedding-start marker.
    wo_target = 6
    # Keep the text from the first '<|beginofbedding|>' marker onward.
    bte_outline = 7
    # Like ``plan_tbe``, but the target section runs from
    # '<|beginoftarget|>' up to '<|beginofbedding|>'.
    tbe = 8


# def clean_text(f):
#     def wrapper(*args, **kwargs):
#         text = f(*args, **kwargs)
#         special_tokens = ['<|endofcard|>', '<|endofprompt|>', '<|beginofbedding|>',
#                           '<|beginoftarget|>', '<|beginofending|>', '<|endoftarget|>',
#                           '<|endofoutline|>', '<|sepofoutline|>', '<end_card>', '<|beginofoutline|>',
#                           '<|sepofoutlinesent|>', '<|sepofname|>']
#
#         def clean_str(s):
#             for t in special_tokens:
#                 s = s.replace(t, '')
#             return s.strip()
#
#         if isinstance(text, tuple):
#             res = []
#             for i in range(len(text)):
#                 res.append(clean_str(text[i]))
#             return tuple(res)
#         else:
#             text = clean_str(text)
#             return text
#
#     return wrapper


def process_text(text, mode: TextType):
    """Strip marker tokens from ``text`` and extract the parts ``mode`` selects.

    Every returned string has all known special tokens replaced by a space,
    leading/trailing whitespace removed and runs of spaces collapsed to a
    single space (tabs and newlines inside the text are left untouched).

    Args:
        text: Marker-annotated input string.
        mode: The ``TextType`` member describing how ``text`` is laid out:

            * ``prompt`` -- split at ``<|endofcard|>``; returns a
              ``(persona, prompt)`` tuple.
            * ``prompt_with_name`` -- like ``prompt``, but only the second
              ``<|sepofname|>``-separated field of the persona is kept.
            * ``wo_target`` -- keeps what follows the last
              ``<|beginofbedding|>`` marker (the whole text if absent).
            * ``plan_tbe`` -- target is ``<|beginoftarget|>`` up to
              ``<|endoftarget|>``, bedding is ``<|beginofbedding|>`` up to
              ``<|beginofending|>``, ending is the rest; returned joined in
              bedding, target, ending order.
            * ``tbe`` -- like ``plan_tbe``, but the target runs up to
              ``<|beginofbedding|>``.
            * ``bte_outline`` -- keeps the text from the first
              ``<|beginofbedding|>`` marker onward.
            * ``bte`` / ``cards`` -- only cleaned.

    Returns:
        A ``(persona, prompt)`` tuple of cleaned strings for the two prompt
        modes, a single cleaned string for the other modes, or ``None`` when
        ``mode`` is not one of the handled ``TextType`` members.

    Raises:
        ValueError: In the prompt modes, if ``text`` does not contain exactly
            one ``<|endofcard|>`` marker.
        IndexError: In ``prompt_with_name``, if the persona part contains no
            ``<|sepofname|>`` marker.
    """
    import re  # kept function-local: the module has no top-level ``re`` import

    end_card = '<|endofcard|>'
    sep_name = '<|sepofname|>'
    begin_bedding = '<|beginofbedding|>'
    begin_target = '<|beginoftarget|>'
    end_target = '<|endoftarget|>'
    begin_ending = '<|beginofending|>'
    special_tokens = (
        '<|endofcard|>', '<|endofprompt|>', '<|beginofbedding|>',
        '<|beginoftarget|>', '<|beginofending|>', '<|endoftarget|>',
        '<|endofoutline|>', '<|sepofoutline|>', '<end_card>', '<|beginofoutline|>',
        '<|sepofoutlinesent|>', '<|sepofname|>',
    )

    def final_clean(s):
        # Tokens become spaces (not '') so the words around them do not fuse.
        for token in special_tokens:
            s = s.replace(token, ' ')
        return re.sub(r' +', ' ', s.strip())

    def section(start_token, end_token=None):
        # Slice of ``text`` from ``start_token`` up to ``end_token``.
        # ``str.find`` returns -1 for a missing marker; using that directly as
        # a slice bound would silently chop or relocate the last character, so
        # a missing start yields '' and a missing end means "to the end".
        start = text.find(start_token)
        if start == -1:
            return ''
        end = -1 if end_token is None else text.find(end_token)
        return text[start:] if end == -1 else text[start:end]

    if mode == TextType.prompt:
        persona, prompt = text.split(end_card)
        return final_clean(persona), final_clean(prompt)

    if mode == TextType.prompt_with_name:
        persona, prompt = text.split(end_card)
        persona = persona.split(sep_name)[1]
        return final_clean(persona), final_clean(prompt)

    if mode == TextType.wo_target:
        # Split on the complete marker; splitting on a truncated marker would
        # leave a stray '|>' at the start of the result.
        return final_clean(text.split(begin_bedding)[-1])

    if mode == TextType.plan_tbe:
        target = section(begin_target, end_target)
        bedding = section(begin_bedding, begin_ending)
        ending = section(begin_ending)
        return final_clean(bedding + ' ' + target + ' ' + ending)

    if mode == TextType.tbe:
        target = section(begin_target, begin_bedding)
        bedding = section(begin_bedding, begin_ending)
        ending = section(begin_ending)
        return final_clean(bedding + ' ' + target + ' ' + ending)

    if mode == TextType.bte_outline:
        return final_clean(section(begin_bedding))

    if mode == TextType.bte or mode == TextType.cards:
        # Nothing to reorder: final_clean already removes every marker.
        return final_clean(text)

    # Unknown mode: preserve the historical behaviour of returning None.
    return None


if __name__ == '__main__':
    # Quick manual smoke check of the ``prompt`` mode when run as a script.
    demo_result = process_text('gg<|endofcard|>xx', TextType.prompt)
    print(demo_result)
