import concurrent.futures

from typing import Union, List
from utils import ArticleTextProcessing
from .dataclass import NewInformation
from src.interface import Retriever, Information
from src.rm import YouRM
from storm_wiki.modules.retriever import is_valid_wikipedia_source

class NewRetriever(Retriever):
    """Retriever that searches via ``YouRM`` and wraps each hit as ``NewInformation``.

    When several queries are given, each one is searched on its own worker
    thread; results are returned grouped by query, in the order the queries
    were supplied.
    """

    def __init__(self,
                 ydc_api_key=None,
                 search_top_k_source=3,
                 max_thread=3):
        """Create the retriever.

        Args:
            ydc_api_key: API key forwarded to ``YouRM``.
            search_top_k_source: Number of results to request per query; passed
                to the base ``Retriever`` as ``search_top_k``.
            max_thread: Maximum number of worker threads ``retrieve`` uses to
                run queries concurrently.
        """
        super().__init__(search_top_k=search_top_k_source)
        self.max_thread = max_thread
        # ``is_valid_wikipedia_source`` is passed as YouRM's ``is_valid_source``
        # filter; its exact acceptance rules live in storm_wiki (not shown here).
        self.you_rm = YouRM(ydc_api_key=ydc_api_key,
                            k=self.search_top_k,
                            is_valid_source=is_valid_wikipedia_source)

    def retrieve(self, query: Union[str, List[str]],
                 exclude_urls: Union[List[str], None] = None) -> List[Information]:
        """Search for one or more queries and return the hits as ``NewInformation``.

        Args:
            query: A single query string, or a list of query strings.
            exclude_urls: URLs forwarded to ``YouRM`` as ``exclude_urls``
                (presumably removed from the results). ``None`` means no
                exclusions.

        Returns:
            One ``NewInformation`` per retrieved item, with citation markers
            stripped from its snippets and ``meta["query"]`` set to the query
            that produced it. Items are grouped by query, in query order.

        Raises:
            Any exception raised while searching or converting a result is
            re-raised here when the worker's result is collected.
        """
        # Build a fresh list per call: a ``[]`` default would be one object
        # shared by every call and by every worker thread below.
        if exclude_urls is None:
            exclude_urls = []
        # Keep YouRM's result count in step with ``self.search_top_k``, which
        # is a public attribute and may have changed since ``__init__``.
        self.you_rm.k = self.search_top_k
        queries = query if isinstance(query, list) else [query]

        def process_query(q):
            """Search a single query and convert each hit to ``NewInformation``."""
            retrieved_data_list = self.you_rm(query_or_queries=[q], exclude_urls=exclude_urls)
            infos = []
            for data in retrieved_data_list:
                # STORM generates the article with citations. We do not consider
                # multi-hop citations, so remove citations in the source to avoid
                # confusion. Slice assignment updates the same list object in place.
                snippets = data['snippets']
                snippets[:] = [ArticleTextProcessing.remove_citations(s) for s in snippets]
                storm_info = NewInformation.from_dict(data)
                storm_info.meta["query"] = q
                infos.append(storm_info)
            return infos

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread) as executor:
            # ``map`` yields results in input order, so output stays grouped by query.
            per_query_results = list(executor.map(process_query, queries))

        return [info for infos in per_query_results for info in infos]
