import re, os
from rouge_score import rouge_scorer, tokenize

class DataUtils:
    """Static helpers that split an answer into cited segments and rebuild its citations."""

    @staticmethod
    def split_segments(statement):
        """
        Split a statement into sentence-like segments and collect their citations.

        Newlines are replaced by spaces and runs of spaces are collapsed. The text
        is first cut at citation marks of the form ``[n]``; every mark is attached
        to the text that precedes it and stored zero-based (``n - 1``). Marks that
        appear before any text are dropped. Each stretch of text is then split at
        sentence boundaries; pieces are accumulated until a segment is longer than
        20 characters, and a trailing remainder shorter than 20 characters is
        merged into the previous segment. The citations of a stretch of text go to
        the last segment produced so far.

        Args:
            statement (str): The input statement.

        Returns:
            tuple (list, list): Two lists of equal length, the first holding the
                segment strings and the second holding, for each segment, the list
                of zero-based citation indices attached to it.
        """
        def _glue(left, right):
            # The sentence split below drops the characters between the closing
            # punctuation and the next capital letter, so put a single space back
            # when text is appended directly after closing punctuation. No trailing
            # space is ever added, which keeps "sentence.[1]" intact in recompose.
            if left and right and left[-1] in ".!?:" and right[0].isalnum():
                return left + " " + right
            return left + right

        min_length = 20  # segments are only closed once they exceed this many characters
        all_statements = []
        statement = re.sub(' +', ' ', statement.replace('\n', ' '))
        # Sentence boundary: '.', '?' or '!' that is not rejected by the negative
        # lookbehinds (abbreviation-like forms), followed by optional quotes,
        # whitespace and non-word characters and then a capital letter. Only the
        # capital letter is captured, so re.split returns it as its own item.
        split_pattern = r'(?<!\w\.\w.)(?<![A-Z]\.)(?<![A-Z][a-z]\.)(?<! [a-z]\.)(?<![A-Z][a-z][a-z]\.)(?<=\.|\?|\!)\"*\s*\s*(?:\W*)([A-Z])'
        tmp_statements = []

        # Pass 1: alternate between plain text and "[n]" marks.
        for chunk in re.split(r"(\[\d+\])", statement):
            if not chunk:
                continue
            cites = re.findall(r"\[(\d+)\]", chunk)
            if not cites:  # plain text
                tmp_statements.append([chunk, []])
            elif tmp_statements:  # citation mark following some text
                tmp_statements[-1][1].extend(int(item) - 1 for item in cites)
            # a citation mark with no preceding text is ignored

        # Pass 2: split every text chunk into segments of reasonable length.
        for text, cites in tmp_statements:
            prefix = ""
            for piece in re.split(split_pattern, text):
                if len(prefix) > min_length:
                    all_statements.append([prefix, []])
                    prefix = ""
                prefix = _glue(prefix, piece)
            if prefix:
                if all_statements and len(prefix) < min_length:
                    # Too short to stand alone: append to the previous segment.
                    all_statements[-1][0] = _glue(all_statements[-1][0], prefix)
                else:
                    all_statements.append([prefix, []])
            if all_statements:
                all_statements[-1][1] += cites

        return [seg[0] for seg in all_statements], [seg[1] for seg in all_statements]

    @staticmethod
    def matching_score(all_statements, references):
        """
        Score every statement against every reference with ROUGE-1 precision.

        Both statements and references are tokenized with ``tokenize.tokenize``
        and stripped of the tokens found in the module-level ``stopwords``
        collection before scoring. A statement with fewer than 5 remaining tokens
        gets a score of 0 against every reference.

        Args:
            all_statements (list[str]): Segments to score.
            references (list[str]): Reference texts.

        Returns:
            list[list[float]]: One row per statement with one score per reference,
                i.e. shape (len(all_statements), len(references)).
        """
        def remove_stopwords(stmt):
            tokens = tokenize.tokenize(stmt, None)
            return " ".join(tok for tok in tokens if tok not in stopwords)

        all_statements = [remove_stopwords(item) for item in all_statements]
        references = [remove_stopwords(item) for item in references]

        scorer = rouge_scorer.RougeScorer(['rouge1'])
        all_scores = []
        for statement in all_statements:
            if len(tokenize.tokenize(statement, None)) < 5:
                # Too little content left to compare meaningfully.
                all_scores.append([0] * len(references))
                continue
            # The reference is passed as the target and the statement as the prediction.
            all_scores.append(
                [scorer.score(ref, statement)['rouge1'].precision for ref in references]
            )
        return all_scores

    @staticmethod
    def get_ideal_citations(all_scores, raw_citations, citation_threshold, extra_bonus=0.3):
        """
        Choose, for every segment, the references it should cite.

        Args:
            all_scores (list of list of float): one row per segment and one score
                per reference, shape (num_segs, num_refs)
            raw_citations (list of list of int): zero-based reference indices that
                each segment cited originally, one list per segment
            citation_threshold (float): minimum adjusted score for a reference to
                be cited
            extra_bonus (float, optional): bonus shared equally between the
                references a segment cited originally (each one receives
                ``extra_bonus / len(raw_citations[seg])``), default 0.3

        Returns:
            list of list of int: for each segment, the zero-based indices of the
                references whose adjusted score reaches the threshold, in ascending
                order. If no reference qualifies but the segment had at least one
                original citation, the single best-scoring reference is returned
                instead (the first one on ties). Segments with an empty score row
                never receive a citation.

        Raises:
            AssertionError: if length of `all_scores` and `raw_citations` are not equal
        """
        assert len(all_scores) == len(raw_citations)

        ideal_citations = []
        for scores, cited in zip(all_scores, raw_citations):
            chosen = []
            best_idx = 0
            best_scr = 0
            for idx, score in enumerate(scores):
                if idx in cited:
                    score += extra_bonus / len(cited)
                if score >= citation_threshold:
                    chosen.append(idx)
                if score > best_scr:
                    # Track the running maximum so the fallback below really is
                    # the best-scoring reference.
                    best_idx, best_scr = idx, score
            if len(chosen) == 0 and len(cited) > 0 and len(scores) > 0:
                chosen.append(best_idx)
            ideal_citations.append(chosen)
        return ideal_citations

    @staticmethod
    def recompose(all_statements, raw_citations, references, sep=" ", citation_threshold=0.75):
        """
        Rebuild the answer text with corrected citation marks.

        Every segment is scored against the references, the citations to keep are
        selected with ``get_ideal_citations`` and the segments are concatenated,
        each followed by its citations rendered as one-based ``[n]`` marks.

        Args:
            all_statements (list[str]): Segments of the answer, in order.
            raw_citations (list[list[int]]): Zero-based reference indices cited
                originally by each segment; same length as ``all_statements``.
            references (list[str]): Reference texts.
            sep (str, optional): Separator inserted after a citation mark when the
                next segment starts with an alphanumeric character, and after a
                segment that ends with one of ``.!?:`` and has no citation.
                Default: " ".
            citation_threshold (float, optional): Minimum score a reference needs
                in order to be cited by a segment. Default: 0.75.

        Returns:
            str: The recomposed text, stripped of leading and trailing whitespace.
        """
        scores = DataUtils.matching_score(all_statements, references)
        ideal_citations = DataUtils.get_ideal_citations(scores, raw_citations, citation_threshold)
        ret = ""
        for seg, cit in zip(all_statements, ideal_citations):
            # Keep a citation mark from touching the first word of the next segment.
            if ret and ret[-1] == "]" and seg and seg[0].isalnum():
                ret += sep
            ret += seg
            ret += "".join("[%d]" % (c + 1) for c in cit)
            if ret and ret[-1] in ".!?:":
                ret += sep
        return ret.strip()

class Stopwords:
    """Loader for stop-word list files (one word per line)."""

    @staticmethod
    def load(paths=None, *, encoding="utf-8"):
        """
        Load stop words from one or more text files.

        Every line of every file becomes one entry, stripped of surrounding
        whitespace (including the line break). Entries are returned in file
        order; duplicates and blank lines are not filtered.

        Args:
            paths (str | os.PathLike | iterable, optional): File(s) to read. A
                single path is accepted as well as an iterable of paths. When
                omitted, "./model/stopwords/english" and
                "./model/stopwords/explaination" are read, resolved against the
                current working directory.
            encoding (str, optional): Text encoding of the files. Default: "utf-8".

        Returns:
            list (str): The stripped lines of all files.

        Raises:
            OSError: If one of the files cannot be opened.

        Example:
            words = Stopwords.load()
            words = Stopwords.load(["./my_stopwords.txt"])
        """
        if paths is None:
            paths = (
                "./model/stopwords/english",
                "./model/stopwords/explaination",
            )
        elif isinstance(paths, (str, os.PathLike)):
            # A lone path would otherwise be iterated character by character.
            paths = (paths,)

        ret = []
        for path in paths:
            with open(path, "r", encoding=encoding) as f:
                ret.extend(line.strip() for line in f)
        return ret


# Stop words as a set, built once when the module is imported (this reads the
# stop-word files at import time). DataUtils.matching_score tests tokens
# against it.
stopwords = {*Stopwords.load()}

def citation_correction(original_answer, references):
    """
    Correct the citation marks of an answer against the given references.

    The answer is split into segments and raw citations with
    ``DataUtils.split_segments``; the result is handed, together with the
    references, to ``DataUtils.recompose`` using its default separator and
    citation threshold.

    Args:
        original_answer (str): Answer text that may contain wrong citation marks.
        references (List[str]): Reference texts the answer may cite.

    Returns:
        str: The string returned by ``DataUtils.recompose``.
    """
    return DataUtils.recompose(*DataUtils.split_segments(original_answer), references)