import dspy
import os
import sys
import re
import argparse
import concurrent.futures
from tqdm import tqdm

# Absolute path of the directory that contains this script.
script_dir = os.path.dirname(os.path.abspath(__file__))
# Two directories above this script. It is appended to sys.path before the
# project imports that follow (presumably so `src.lm` and `utils` resolve —
# confirm against the repository layout).
src_root = os.path.join(script_dir, "..", "..")
sys.path.append(src_root)

from src.lm import OpenAIModel
from utils import load_api_key

# dspy signature for the paragraph-shortening task: one input field (`text`)
# and one output field (`output`). ResizeArticle.resize_node compares the
# output against the literal string "None" to detect "cannot shorten".
# NOTE(review): dspy is understood to use the class docstring and the field
# prefixes as prompt text, so they are runtime behavior — confirm against the
# dspy version in use before rewording them.
class ResizeText(dspy.Signature):
    """ You are helping to make a paragraph shorter. Keep original text as much as possible and do not change the meaning. If you think there is no room to shorten the text, output None.
    """
    # Paragraph to shorten.
    text = dspy.InputField(prefix="Paragraph:\n", format=str)
    # Shortened paragraph, or "None" when the model finds nothing to cut.
    output = dspy.OutputField(prefix="Now give the shortened paragraph (keep as much original text as possible and keep references represented by integer enclosed by square bracket intact):\n", format=str)

class ResizeArticle(dspy.Module):
    """Shorten an article section by section until it fits a word budget.

    The article text is parsed into a tree of sections. Sections are then sent
    to the language model, longest first, and their content is replaced by the
    shortened text until the total word count reaches the target.
    """

    def __init__(self, engine):
        """Store the language model and build the shortening predictor.

        Args:
            engine: Language model passed to ``dspy.settings.context`` for
                every shortening call.
        """
        super().__init__()
        self.engine = engine
        self.resize_text = dspy.Predict(ResizeText)

    def resize_node(self, article_node):
        """Ask the model to shorten one section and update it in place.

        The node is left untouched when the model answers ``None`` (with or
        without surrounding whitespace) or returns no usable text; otherwise
        an empty reply would silently erase the section.

        Args:
            article_node: Node whose ``content`` is shortened.
        """
        with dspy.settings.context(lm=self.engine, show_guideline=False):
            output = self.resize_text(text=article_node.content).output
        verdict = output.strip() if isinstance(output, str) else ""
        if not verdict or verdict == "None":
            return
        article_node.content = output

    def forward(self, article_text, max_words, resize_factor):
        """Shrink ``article_text`` towards the requested size.

        The target is ``min(max_words, original_length * resize_factor)``
        words, where lengths count section titles and section content.

        Args:
            article_text: Article with markdown-like ``#`` headers.
            max_words: Upper bound for the target word count.
            resize_factor: Multiplier applied to the original word count.

        Returns:
            The article re-rendered by ``Article.pre_order_printing``. It may
            still exceed the target if the model stops making progress.
        """
        article_dict = parse_article_into_dict(article_text)
        article = Article()
        article.insert_from_dict(article_dict)
        current_length = article.get_total_length()
        target_length = min(max_words, current_length * resize_factor)
        last_round_length = current_length
        while current_length > target_length:
            for node in article.get_all_node_sorted_by_content_length():
                self.resize_node(node)
                current_length = article.get_total_length()
                if current_length <= target_length:
                    break
            # Stop as soon as a full pass failed to shrink the article.
            # Comparing against the previous pass (not the initial length)
            # prevents an endless loop once the model stops shortening or
            # starts producing longer text.
            if current_length >= last_round_length:
                break
            last_round_length = current_length
        return article.pre_order_printing()

class ArticleNode():
    """A single article section: heading text, body text and child sections."""

    def __init__(self, name, content):
        # Heading text of the section.
        self.name = name
        # Current body text; `original_content` keeps the text given at
        # construction time even if `content` is replaced later.
        self.content = self.original_content = content
        # Child ArticleNode objects, in insertion order.
        self.children = []

    def get_length(self):
        """Return the number of whitespace-separated words in ``content``."""
        words = self.content.split()
        return len(words)
    
class Article():
    """Tree of article sections hanging off a synthetic, empty root node."""

    def __init__(self):
        self.root = ArticleNode(name="root", content="")

    def insert_from_dict(self, data, root=None):
        """Recursively add the sections described by ``data`` below ``root``.

        ``data`` maps a section title to a dict with a ``"content"`` string
        and a ``"subsections"`` dict of the same shape. When ``root`` is
        ``None`` the sections are attached to the article root.
        """
        parent = root if root is not None else self.root
        for title, section in data.items():
            child = ArticleNode(name=title, content=section["content"])
            self.insert_from_dict(data=section["subsections"], root=child)
            parent.children.append(child)

    def _iter_sections(self):
        """Yield ``(node, level)`` for every non-root node in pre-order.

        Direct children of the root are reported at level 2.
        """
        pending = [(child, 2) for child in reversed(self.root.children)]
        while pending:
            node, level = pending.pop()
            yield node, level
            # Push children reversed so the first child is visited first.
            pending.extend((child, level + 1) for child in reversed(node.children))

    def get_all_node_sorted_by_content_length(self):
        """Return the non-empty sections, longest content first.

        Sections with equal word counts keep their pre-order position.
        """
        non_empty = [node for node, _ in self._iter_sections() if node.get_length() > 0]
        return sorted(non_empty, key=lambda node: node.get_length(), reverse=True)

    def pre_order_printing(self):
        """Render the article as text with ``#`` headers, in pre-order.

        Each section becomes ``"<hashes> <name>\\n<content>"`` and sections are
        joined with a newline. Top-level sections are emitted with two ``#``
        characters because the skipped root occupies level 1.
        """
        rendered = [
            f"{'#' * level} {node.name}\n{node.content}"
            for node, level in self._iter_sections()
        ]
        return '\n'.join(rendered)

    def get_total_length(self):
        """Return the word count of all section titles and contents."""
        return sum(
            node.get_length() + len(node.name.split())
            for node, _ in self._iter_sections()
        )

def parse_article_into_dict(input_string):
    """Parse markdown-like structured text into a nested dictionary.

    Lines starting with one or more ``#`` characters are section headers; the
    number of *leading* ``#`` characters is the nesting level. Every other
    non-blank line is appended (followed by a newline) to the content of the
    most recent section. Blank lines are ignored, and text that appears before
    the first header is discarded.

    Each section is stored under its title and maps to a dictionary with two
    keys:
    - 'content': the text of the section
    - 'subsections': a dictionary of nested sections with the same structure

    A later section with the same title under the same parent replaces the
    earlier one, because titles are used as dictionary keys.

    Args:
        input_string (str): A string containing the structured text to parse.

    Returns:
        A dictionary mapping each top-level section title to its
        ``{'content': ..., 'subsections': ...}`` dictionary.
    """
    root = {'content': '', 'subsections': {}}
    current_path = [(root, -1)]  # stack of (section_dict, header_level)

    for line in input_string.split('\n'):
        if not line.strip():
            continue

        if not line.startswith('#'):
            current_path[-1][0]['content'] += line + '\n'
            continue

        # Only the leading run of '#' defines the level; a '#' inside the
        # title (e.g. "C#") must not deepen the section.
        level = len(line) - len(line.lstrip('#'))
        title = line[level:].strip()

        # Drop an optional closing sequence ("## Title ##"), but keep a '#'
        # that is attached to the last word of the title.
        without_closing = title.rstrip('#')
        if without_closing != title and (not without_closing or without_closing[-1].isspace()):
            title = without_closing.rstrip()

        new_section = {'content': '', 'subsections': {}}

        # Pop from the stack until the nearest shallower section is on top.
        # The root sits at level -1, so it is never popped.
        while current_path and current_path[-1][1] >= level:
            current_path.pop()

        current_path[-1][0]['subsections'][title] = new_section
        current_path.append((new_section, level))

    return root['subsections']

def load_text(path):
    """Return the full contents of the text file at ``path``.

    The file is opened with the platform default encoding.

    Raises:
        FileNotFoundError: If ``path`` does not exist. The existence check is
            left to ``open`` so it still happens under ``python -O`` and the
            error message names the missing path.
    """
    with open(path) as f:
        return f.read()

def count_citations(text):
    """Count every bracketed integer citation (e.g. ``[12]``) in ``text``, repeats included."""
    return sum(1 for _ in re.finditer(r'\[\d+\]', text))

def count_unique_citations(text):
    """Count the distinct bracketed integer citations in ``text``.

    Citation numbers are compared as the digit strings found between the
    brackets, so ``[1]`` and ``[01]`` count as two different citations.
    """
    seen = {match.group(1) for match in re.finditer(r'\[(\d+)\]', text)}
    return len(seen)

def process_article_directory(article_dir, original_file_name, resized_file_name, max_words, resize_factor):
    """Resize the article in ``article_dir`` and write the result next to it.

    Nothing is done when a resized file already exists and has at most
    ``max_words`` whitespace-separated words. Otherwise the original article
    is shortened with ``ResizeArticle`` and, if the result is still too long,
    cut after its ``max_words``-th word.

    Args:
        article_dir: Directory holding the article files.
        original_file_name: Name of the input file inside ``article_dir``.
        resized_file_name: Name of the output file inside ``article_dir``.
        max_words: Maximum number of words allowed in the output.
        resize_factor: Multiplier forwarded to ``ResizeArticle``.

    Raises:
        FileNotFoundError: If ``article_dir`` does not exist.
    """
    if not os.path.exists(article_dir):
        raise FileNotFoundError(f"Article directory does not exist: {article_dir}")

    resized_path = os.path.join(article_dir, resized_file_name)
    if os.path.exists(resized_path):
        article_to_check = load_text(resized_path)
        if len(article_to_check.split()) <= max_words:
            return

    openai_kwargs = {
        'api_key': os.getenv("OPENAI_API_KEY"),
        'api_provider': os.getenv('OPENAI_API_TYPE'),
        'temperature': 1.0,
        'top_p': 0.9,
        'api_base': os.getenv('AZURE_API_BASE'),
    }
    lm = OpenAIModel(model="gpt-3.5-turbo-instruct", max_tokens=1000, **openai_kwargs)
    article_resize_module = ResizeArticle(engine=lm)
    article_text = load_text(os.path.join(article_dir, original_file_name))
    resized_text = article_resize_module(article_text, max_words=max_words, resize_factor=resize_factor)

    # Hard cap: the model-driven resize may stop above the limit.
    if len(resized_text.split()) > max_words:
        resized_text = _truncate_to_word_limit(resized_text, max_words)

    with open(resized_path, "w") as f:
        f.write(resized_text)


def _truncate_to_word_limit(text, max_words):
    """Return ``text`` cut right after its ``max_words``-th word.

    Whitespace before the cut is kept as is, so headers and paragraphs stay
    on their own lines instead of being flattened into a single line.
    """
    if max_words <= 0:
        return ""
    for position, word in enumerate(re.finditer(r"\S+", text), start=1):
        if position == max_words:
            return text[:word.end()]
    return text

def main(args):
    """Resize the article in every sub-directory of ``args.method_dir``.

    API keys are loaded from ``secrets.toml`` one level above ``src_root``.
    Each sub-directory is handled by ``process_article_directory`` on a
    thread pool; a failure in one directory is reported and does not stop the
    others.

    Args:
        args: Parsed command-line arguments providing ``method_dir``,
            ``original_file_name``, ``resized_file_name``, ``max_words`` and
            ``resize_factor``.
    """
    load_api_key(toml_file_path=os.path.join(src_root, "..", "secrets.toml"))

    method_dir = args.method_dir
    original_file_name = args.original_file_name
    resized_file_name = args.resized_file_name
    max_words = args.max_words
    resize_factor = args.resize_factor

    # Collect all article directories
    candidates = (os.path.join(method_dir, d) for d in os.listdir(method_dir))
    article_dirs = [path for path in candidates if os.path.isdir(path)]

    # Multi-threaded processing of article directories with progress bar
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = {
            executor.submit(
                process_article_directory,
                article_dir,
                original_file_name,
                resized_file_name,
                max_words,
                resize_factor,
            ): article_dir
            for article_dir in article_dirs
        }

        with tqdm(total=len(futures)) as pbar:
            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    # Name the failing directory so the error is actionable.
                    print(f"Error processing article directory {futures[future]}: {e}")
                pbar.update(1)

# Command-line entry point: parse the options and run the batch resize.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process article directories with resizing.")
    parser.add_argument('--method-dir', required=True, help="Directory containing article method directories")
    parser.add_argument('--original-file-name', choices=['storm_gen_article.txt', 'report.txt'], required=True, help="Original file name to process")
    parser.add_argument('--resized-file-name', default='article_to_evaluate.txt', help="Name of the resized file to save")
    parser.add_argument('--max-words', type=int, default=2000, help="max words cap")
    # The resize target is min(max-words, original word count * resize-factor).
    parser.add_argument('--resize-factor', type=float, default=1.0, help="factor applied to the original word count to get the target length (still capped by --max-words)")

    args = parser.parse_args()
    main(args)