# -*- coding: utf-8 -*-
import re
import os
import torch
import random
import numpy as np
import logging
from datetime import datetime

def get_sentences_and_paragraphs(article):
    """Split an article into sentences and record each paragraph's sentence span.

    The article is split on newline characters; every line is treated as a
    paragraph and split with ``split_paragraph_into_sentences``. Paragraphs
    that yield no sentences, or exactly one sentence shorter than 50
    characters, are skipped: they contribute neither sentences nor a span.

    :param article: raw article text, one paragraph per line
    :return: ``(sentences, paragraph_indices)`` -- the flat list of kept
        sentences, and one ``[start, end]`` pair (end exclusive) per kept
        paragraph indexing into that list
    """
    all_sentences = []
    spans = []
    for paragraph in article.split('\n'):
        paragraph_sentences = split_paragraph_into_sentences(paragraph)
        # NOTE: empty paragraphs and one-sentence paragraphs shorter than
        # 50 characters are not indexed.
        too_short = len(paragraph_sentences) == 1 and len(paragraph_sentences[0]) < 50
        if not paragraph_sentences or too_short:
            continue
        start = len(all_sentences)
        all_sentences.extend(paragraph_sentences)
        spans.append([start, len(all_sentences)])
    return all_sentences, spans


# Regex building blocks used by split_paragraph_into_sentences. Raw strings are
# used so that regex escapes such as "\s" are not parsed as (invalid) Python
# string escapes (SyntaxWarning since Python 3.12).

# A single ASCII letter, captured.
alphabets = r"([A-Za-z])"
# Titles whose trailing period does not end a sentence.
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
# Company / name suffixes that are commonly abbreviated with a period.
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
# Words that typically begin a sentence; an acronym or suffix followed by one of
# these is treated as a sentence end.
starters = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
# Two- or three-letter dotted acronyms such as "U.S." or "U.S.A.".
acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
# Website top-level-domain suffixes such as ".com".
websites = r"[.](com|net|org|io|gov|edu|me)"
# A single digit, captured (used to protect decimal points).
digits = r"([0-9])"
# Runs of two or more dots (ellipses).
multiple_dots = r'\.{2,}'


def split_paragraph_into_sentences(text):
    """
    Split the text into sentences.

    Periods that should not end a sentence (titles such as "Mr.", website
    suffixes such as ".com", decimal points, "Ph.D.", dotted acronyms,
    single-letter initials, company suffixes such as "Inc.") are first
    protected with a temporary "<prd>" marker. A "<stop>" marker is then
    inserted after ".", "?", "!", ":" and ";" when followed by a space, and
    after every run of two or more dots. Finally "<prd>" is turned back into
    "." and the text is split on "<stop>".

    The output is not always a verbatim slice of the input: a "." before a
    closing ``"`` or ``”``, and a "!" or "?" before a closing ``"``, is moved
    after the quote (e.g. ``."`` becomes ``".``). "et al. " is not treated
    as a sentence end.

    If the text contains substrings "<prd>" or "<stop>", they would lead
    to incorrect splitting because they are used as markers for splitting.

    :param text: text to be split into sentences
    :type text: str

    :return: list of whitespace-stripped sentences; a trailing empty piece is
        dropped, so an empty or blank input yields ``[]``
    :rtype: list[str]
    """
    # Pad with spaces so the space-anchored patterns below also match at the
    # start and end of the paragraph.
    text = " " + text + "  "
    text = re.sub(prefixes, r"\1<prd>", text)
    text = re.sub(websites, r"<prd>\1", text)
    text = re.sub(digits + "[.]" + digits, r"\1<prd>\2", text)
    # Each run of 2+ dots is kept as-is and always terminates a sentence.
    text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)) + "<stop>", text)
    if "Ph.D" in text:
        text = text.replace("Ph.D.", "Ph<prd>D<prd>")
    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12).
    text = re.sub(r"\s" + alphabets + "[.] ", r" \1<prd> ", text)
    # An acronym followed by a typical sentence starter ends the sentence.
    text = re.sub(acronyms + " " + starters, r"\1<stop> \2", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>\3<prd>", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>", text)
    # A suffix followed by a sentence starter ends the sentence; otherwise its
    # period is protected.
    text = re.sub(" " + suffixes + "[.] " + starters, r" \1<stop> \2", text)
    text = re.sub(" " + suffixes + "[.]", r" \1<prd>", text)
    text = re.sub(" " + alphabets + "[.]", r" \1<prd>", text)
    # Move terminal punctuation outside closing quotes so that the
    # "<punct> " rules below can place the split after the quote.
    if "”" in text:
        text = text.replace(".”", "”.")
    if "\"" in text:
        text = text.replace(".\"", "\".")
    if "!" in text:
        text = text.replace("!\"", "\"!")
    if "?" in text:
        text = text.replace("?\"", "\"?")
    text = text.replace(". ", ".<stop> ")
    text = text.replace("? ", "?<stop> ")
    text = text.replace("! ", "!<stop> ")
    text = text.replace(": ", ":<stop> ")
    text = text.replace("; ", ";<stop> ")
    text = text.replace("et al.<stop>", "et al.")
    text = text.replace("<prd>", ".")
    sentences = [s.strip() for s in text.split("<stop>")]
    # The trailing padding usually leaves one empty piece after the last stop.
    if sentences and not sentences[-1]:
        sentences = sentences[:-1]
    return sentences


def split_article_into_sentences(text):
    """Return the flat list of sentences from every kept paragraph of *text*.

    *text* is split on newline characters and each line is split into
    sentences with ``split_paragraph_into_sentences``. A paragraph is dropped
    when it yields no sentences, or exactly one sentence shorter than 50
    characters.

    :param text: raw article text, one paragraph per line
    :return: list of sentences in document order
    """
    kept = []
    for paragraph in text.split('\n'):
        pieces = split_paragraph_into_sentences(paragraph)
        if not pieces:
            continue
        if len(pieces) == 1 and len(pieces[0]) < 50:
            # Short single-sentence lines (headings, captions, ...) are ignored.
            continue
        kept += pieces
    return kept

def get_informative_paragraphs(num_paragraphs, article, *, min_threshold=300, max_threshold=2000):
    """Randomly pick up to ``num_paragraphs`` paragraphs of moderate length.

    The article is split on newline characters and each line is split into
    sentences with ``split_paragraph_into_sentences``. Paragraphs with no
    sentences, or exactly one sentence shorter than 50 characters, are
    skipped and do not advance the sentence index. For every other paragraph
    a trailing sentence ending in ":" is dropped, and the paragraph is kept
    when its remaining sentences, joined by single spaces, are strictly
    longer than ``min_threshold`` and strictly shorter than ``max_threshold``
    characters.

    :param num_paragraphs: maximum number of paragraphs to return
    :param article: raw article text, one paragraph per line
    :param min_threshold: exclusive lower bound on paragraph length (chars)
    :param max_threshold: exclusive upper bound on paragraph length (chars)
    :return: list of ``[sentences, start, end]`` entries, where ``start`` and
        ``end`` (exclusive) index into the flat sentence list of all
        non-skipped paragraphs (a dropped trailing ":" sentence is excluded
        from ``end``). If at most ``num_paragraphs`` paragraphs qualify, all
        are returned in document order; otherwise ``random.sample`` picks
        ``num_paragraphs`` of them using the global ``random`` state.
    :raises ValueError: from ``random.sample`` when ``num_paragraphs`` is negative
    """
    valid_paragraphs = []  # [sentence list, start index, end index (+1)]
    cur_idx = 0

    for p in article.split("\n"):
        sent = split_paragraph_into_sentences(p)
        ### NOTE : one-sentence paragraphs with a length less than 50 are not indexed.
        if len(sent) == 0:
            continue
        if len(sent) == 1 and len(sent[0]) < 50:
            continue
        # The index must advance by the full paragraph, even if the trailing
        # ":" sentence is dropped below.
        original_len = len(sent)

        if sent[-1].strip().endswith(":"):
            sent = sent[:-1]

        if min_threshold < len(" ".join(sent).strip()) < max_threshold:
            valid_paragraphs.append([sent, cur_idx, cur_idx + len(sent)])

        cur_idx += original_len

    if len(valid_paragraphs) <= num_paragraphs:
        return valid_paragraphs
    return random.sample(valid_paragraphs, num_paragraphs)



def seed_everything(seed:int = 42):
    """Seed the Python, NumPy and PyTorch generators and make cuDNN deterministic.

    Seeds ``random`` and ``numpy.random``, stores the seed in the
    ``PYTHONHASHSEED`` environment variable, seeds PyTorch (CPU, current CUDA
    device and all CUDA devices), enables ``cudnn.deterministic`` and disables
    ``cudnn.benchmark``.

    NOTE(review): PYTHONHASHSEED is read at interpreter startup, so assigning
    it here presumably only affects subprocesses started afterwards -- confirm
    that is the intent.

    :param seed: seed value applied to every generator
    """
    # Python and NumPy global generators.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    # PyTorch: default generator, current GPU, then every GPU.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Request deterministic cuDNN algorithms and turn off benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # type: ignore
    cudnn.benchmark = False

def set_logger():
    """Configure file logging under ``logs/`` and return the ``'logger'`` logger.

    Creates the ``logs`` directory (relative to the current working
    directory) if needed and names the log file after the current local time
    (``logs/YYYYmmdd_HHMMSS.log``). Two things target that file:

    * ``logging.basicConfig`` configures the root logger at INFO level with
      truncating file mode. ``basicConfig`` does nothing if the root logger
      already has handlers.
    * A DEBUG-level ``FileHandler`` is attached to the ``'logger'`` logger,
      whose own level is set to DEBUG.

    If ``'logger'`` already has a ``FileHandler`` for the same file (e.g. a
    second call within the same second), it is reused instead of adding a
    duplicate that would write every record again.

    :return: the ``logging.getLogger('logger')`` instance
    :rtype: logging.Logger
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs("logs", exist_ok=True)

    now = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = f"logs/{now}.log"
    fmt = "%(asctime)s %(levelname)s %(message)s"
    datefmt = "%Y-%m-%d %H:%M:%S"

    logging.basicConfig(
        filename=log_path,
        filemode="w",
        level=logging.INFO,
        format=fmt,
        datefmt=datefmt,
    )

    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)

    # FileHandler stores the absolute path in ``baseFilename``; compare against
    # it so repeated calls do not stack handlers on the same file.
    target = os.path.abspath(log_path)
    has_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in logger.handlers
    )
    if not has_handler:
        fh = logging.FileHandler(filename=log_path)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter(fmt, datefmt))
        # Attach the file handler to the named logger.
        logger.addHandler(fh)

    return logger

