From 81ab9f91a40c3bc9df625219d207fca661e87839 Mon Sep 17 00:00:00 2001
From: lbykkkk
Date: Sat, 23 Nov 2024 19:40:56 +0800
Subject: [PATCH] Refactor Arxiv paper chat plugin from asyncio to synchronous thread-pool processing

---
 crazy_functions/Arxiv_论文对话.py | 226 ++++++++++++------------------
 1 file changed, 87 insertions(+), 139 deletions(-)

diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py
index 7e9924ce..125df8cc 100644
--- a/crazy_functions/Arxiv_论文对话.py
+++ b/crazy_functions/Arxiv_论文对话.py
@@ -46,28 +46,27 @@ class ArxivRagWorker:
     def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
         self.user_name = user_name
         self.llm_kwargs = llm_kwargs
-        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS  # 存储最大并发数
+        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS
         self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None

-        # 初始化基础存储目录
+        # Initialize base storage directory
         self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))

-        # 如果提供了 arxiv_id,创建针对该论文的子目录
         if self.arxiv_id:
             self.checkpoint_dir = self.base_dir / self.arxiv_id
             self.vector_store_dir = self.checkpoint_dir / "vector_store"
             self.fragment_store_dir = self.checkpoint_dir / "fragments"
         else:
-            # 如果没有 arxiv_id,使用基础目录
             self.checkpoint_dir = self.base_dir
             self.vector_store_dir = self.base_dir / "vector_store"
             self.fragment_store_dir = self.base_dir / "fragments"

-        # 创建必要的目录
         if os.path.exists(self.vector_store_dir):
             self.loading = True
         else:
             self.loading = False
+
+        # Create necessary directories
         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
         self.vector_store_dir.mkdir(parents=True, exist_ok=True)
         self.fragment_store_dir.mkdir(parents=True, exist_ok=True)
@@ -76,11 +75,11 @@ class ArxivRagWorker:
         logger.info(f"Vector store directory: {self.vector_store_dir}")
         logger.info(f"Fragment store directory: {self.fragment_store_dir}")

-        # 初始化处理队列和线程池
+        # Initialize processing queue and thread pool
         self.processing_queue = {}
         self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)

-        # 初始化RAG worker
+        # Initialize RAG worker
         self.rag_worker = LlamaIndexRagWorker(
             user_name=user_name,
             llm_kwargs=llm_kwargs,
@@ -88,38 +87,50 @@ class ArxivRagWorker:
             auto_load_checkpoint=True
         )

-        # 初始化arxiv splitter
-        # 初始化 arxiv splitter
+        # Initialize arxiv splitter
         self.arxiv_splitter = ArxivSplitter(
             root_dir=str(self.checkpoint_dir / "arxiv_cache")
         )

-        # 初始化处理队列和线程池
-        self._semaphore = None
-        self._loop = None
+    async def _async_get_fragments(self, arxiv_id: str) -> List[Fragment]:
+        """Async helper to get fragments"""
+        return await self.arxiv_splitter.process(arxiv_id)

-    @property
-    def loop(self):
-        """获取当前事件循环"""
-        if self._loop is None:
-            self._loop = asyncio.get_event_loop()
-        return self._loop
+    def _get_fragments_sync(self, arxiv_id: str) -> List[Fragment]:
+        """Synchronous wrapper for async fragment retrieval"""
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(self._async_get_fragments(arxiv_id))
+        finally:
+            loop.close()

+    def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
+        """Process a single paper fragment"""
+        try:
+            text = (
+                f"Paper Title: {fragment.title}\n"
+                f"Abstract: {fragment.abstract}\n"
+                f"ArXiv ID: {fragment.arxiv_id}\n"
+                f"Section: {fragment.current_section}\n"
+                f"Section Tree: {fragment.section_tree}\n"
+                f"Content: {fragment.content}\n"
+                f"Bibliography: {fragment.bibliography}\n"
+                f"Type: FRAGMENT"
+            )
+            logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
+            self.rag_worker.add_text_to_vector_store(text)
+            logger.info(f"Successfully added fragment {index} to vector store")

-    @property
-    def semaphore(self):
-        """延迟创建 semaphore"""
-        if self._semaphore is None:
-            self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
-        return self._semaphore
+        except Exception as e:
+            logger.error(f"Error processing fragment {index}: {str(e)}")
+            raise

-
-
-    async def _process_fragments(self, fragments: List[Fragment]) -> None:
-        """并行处理论文片段"""
+    def _process_fragments(self, fragments: List[Fragment]) -> None:
+        """Process paper fragments in parallel using thread pool"""
         if not fragments:
             logger.warning("No fragments to process")
             return

-        # 首先添加论文概述
+        # First add paper overview
         overview = {
             "title": fragments[0].title,
             "abstract": fragments[0].abstract,
@@ -136,39 +147,28 @@ class ArxivRagWorker:
         )

         try:
-            # 同步添加概述
+            # Add overview synchronously
             self.rag_worker.add_text_to_vector_store(overview_text)
             logger.info(f"Added paper overview for {overview['arxiv_id']}")

-            # 创建线程池
+            # Process fragments in parallel using thread pool
             with ThreadPoolExecutor(max_workers=10) as executor:
-                # 使用 asyncio.gather 收集所有任务
-                loop = asyncio.get_event_loop()
-                tasks = [
-                    loop.run_in_executor(
-                        executor,
-                        self._process_single_fragment,
-                        fragment,
-                        i
-                    )
+                # Submit all fragments for processing
+                futures = [
+                    executor.submit(self._process_single_fragment, fragment, i)
                     for i, fragment in enumerate(fragments)
                 ]

-                # 等待所有任务完成
-                results = await asyncio.gather(*tasks, return_exceptions=True)
-
-                # 处理结果和异常
-                for i, result in enumerate(results):
-                    if isinstance(result, Exception):
-                        logger.error(f"Error processing fragment {i}: {result}")
-                    else:
-                        # 处理成功的结果
-                        pass
+                # Wait for all tasks to complete and handle any exceptions
+                for future in futures:
+                    try:
+                        future.result()
+                    except Exception as e:
+                        logger.error(f"Error processing fragment: {str(e)}")

             logger.info(f"Processed {len(fragments)} fragments successfully")

-            # 保存到本地文件用于调试
+            # Save to local file for debugging
             save_fragments_to_file(
                 fragments,
                 str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
@@ -178,49 +178,8 @@ class ArxivRagWorker:
             logger.error(f"Error processing fragments: {str(e)}")
             raise

-    async def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
-        """处理单个论文片段(改为异步方法)"""
-        try:
-            text = (
-                f"Paper Title: {fragment.title}\n"
-                f"Abstract: {fragment.abstract}\n"
-                f"ArXiv ID: {fragment.arxiv_id}\n"
-                f"Section: {fragment.current_section}\n"
-                f"Section Tree: {fragment.section_tree}\n"
-                f"Content: {fragment.content}\n"
-                f"Bibliography: {fragment.bibliography}\n"
-                f"Type: FRAGMENT"
-            )
-
-            logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
-            # 如果 add_text_to_vector_store 是异步的,使用 await
-            self.rag_worker.add_text_to_vector_store(text)
-            logger.info(f"Successfully added fragment {index} to vector store")
-
-        except Exception as e:
-            logger.error(f"Error processing fragment {index}: {str(e)}")
-            raise
-        """处理单个论文片段"""
-        try:
-            text = (
-                f"Paper Title: {fragment.title}\n"
-                f"Abstract: {fragment.abstract}\n"
-                f"ArXiv ID: {fragment.arxiv_id}\n"
-                f"Section: {fragment.current_section}\n"
-                f"Section Tree: {fragment.section_tree}\n"
-                f"Content: {fragment.content}\n"
-                f"Bibliography: {fragment.bibliography}\n"
-                f"Type: FRAGMENT"
-            )
-            logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
-            self.rag_worker.add_text_to_vector_store(text)
-            logger.info(f"Successfully added fragment {index} to vector store")
-
-        except Exception as e:
-            logger.error(f"Error processing fragment {index}: {str(e)}")
-            raise
-
-    async def process_paper(self, arxiv_id: str) -> bool:
-        """处理论文主函数"""
+    def process_paper(self, arxiv_id: str) -> bool:
+        """Process a paper synchronously (async splitting runs on a private event loop)"""
         try:
             arxiv_id = self._normalize_arxiv_id(arxiv_id)
             logger.info(f"Starting to process paper: {arxiv_id}")
@@ -231,30 +190,29 @@ class ArxivRagWorker:
                 logger.info(f"Paper {arxiv_id} already processed")
                 return True

-            # 创建处理任务
+            # Create processing task
             task = ProcessingTask(arxiv_id=arxiv_id)
             self.processing_queue[arxiv_id] = task
             task.status = "processing"

-            async with self.semaphore:
-                # 下载和分割论文
-                fragments = await self.arxiv_splitter.process(arxiv_id)
+            # Download and split paper using the sync wrapper
+            fragments = self._get_fragments_sync(arxiv_id)

-                if not fragments:
-                    raise ValueError(f"No fragments extracted from paper {arxiv_id}")
+            if not fragments:
+                raise ValueError(f"No fragments extracted from paper {arxiv_id}")

-                logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")
+            logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")

-                # 处理片段
-                await self._process_fragments(fragments)
+            # Process fragments
+            self._process_fragments(fragments)

-                # 标记完成
-                paper_path.touch()
-                task.status = "completed"
-                task.fragments = fragments
+            # Mark as completed
+            paper_path.touch()
+            task.status = "completed"
+            task.fragments = fragments

-                logger.info(f"Successfully processed paper {arxiv_id}")
-                return True
+            logger.info(f"Successfully processed paper {arxiv_id}")
+            return True

         except Exception as e:
             logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
@@ -262,19 +220,8 @@ class ArxivRagWorker:
             self.processing_queue[arxiv_id].status = "failed"
             self.processing_queue[arxiv_id].error = str(e)
             return False
-
-    def _normalize_arxiv_id(self, input_str: str) -> str:
-        """规范化ArXiv ID"""
-        if 'arxiv.org/' in input_str.lower():
-            if '/pdf/' in input_str:
-                arxiv_id = input_str.split('/pdf/')[-1]
-            else:
-                arxiv_id = input_str.split('/abs/')[-1]
-            return arxiv_id.split('v')[0].strip()
-        return input_str.split('v')[0].strip()
-
-    async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
-        """等待论文处理完成"""
+    def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
+        """Wait for paper processing to complete - synchronous version"""
         try:
             start_time = datetime.now()
             while True:

                 if task.status == "failed":
                     return False

-                # 检查超时
+                # Check timeout
                 if (datetime.now() - start_time).total_seconds() > timeout:
                     logger.error(f"Processing paper {arxiv_id} timed out")
                     return False

-                await asyncio.sleep(0.1)
+                time.sleep(0.1)

         except Exception as e:
             logger.error(f"Error waiting for paper {arxiv_id}: {str(e)}")
             return False

+    def _normalize_arxiv_id(self, input_str: str) -> str:
+        """Normalize ArXiv ID"""
+        if 'arxiv.org/' in input_str.lower():
+            if '/pdf/' in input_str:
+                arxiv_id = input_str.split('/pdf/')[-1]
+            else:
+                arxiv_id = input_str.split('/abs/')[-1]
+            return arxiv_id.split('v')[0].strip()
+        return input_str.split('v')[0].strip()
+
+
     def retrieve_and_generate(self, query: str) -> str:
         """检索相关内容并生成提示词"""
         try:
@@ -374,20 +332,10 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
         chatbot.append((txt, "正在处理论文,请稍等..."))
         yield from update_ui(chatbot=chatbot, history=history)

-        # 创建事件循环来处理异步调用
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        try:
-            # 运行异步处理函数
-            success = loop.run_until_complete(worker.process_paper(txt))
-            if success:
-                arxiv_id = worker._normalize_arxiv_id(txt)
-                success = loop.run_until_complete(worker.wait_for_paper(arxiv_id))
-                if success:
-                    # 执行自动分析
-                    yield from worker.auto_analyze_paper(chatbot, history, system_prompt)
-        finally:
-            loop.close()
+        success = worker.process_paper(txt)
+        if success:
+            arxiv_id = worker._normalize_arxiv_id(txt)
+            success = worker.wait_for_paper(arxiv_id)

         if not success:
            chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")
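
Reviewer notes follow; the sketches below are illustrative and not part of the patch.

The new _get_fragments_sync replaces the old lazily created loop and semaphore with a short-lived private event loop per call. A minimal self-contained sketch of the same sync-over-async pattern, where fetch_fragments is a hypothetical stand-in for ArxivSplitter.process:

    import asyncio
    from typing import List

    async def fetch_fragments(arxiv_id: str) -> List[str]:
        # Hypothetical stand-in for ArxivSplitter.process(arxiv_id).
        await asyncio.sleep(0.1)
        return [f"fragment of {arxiv_id}"]

    def fetch_fragments_sync(arxiv_id: str) -> List[str]:
        # A private loop makes the call safe even when no event loop exists
        # in the calling thread; closing it releases its resources.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(fetch_fragments(arxiv_id))
        finally:
            loop.close()

On Python 3.7+, asyncio.run(fetch_fragments(arxiv_id)) performs the same loop bookkeeping in one call; the explicit form above matches what the patch does. Either way, the pattern must not be invoked from a thread that already has a running event loop, or run_until_complete raises RuntimeError.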
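
The reworked _process_fragments drops asyncio.gather over run_in_executor in favor of plain ThreadPoolExecutor futures. A sketch of that submit-then-collect pattern, with process_one standing in for _process_single_fragment; any exception raised in a worker resurfaces from future.result():

    from concurrent.futures import ThreadPoolExecutor

    def process_one(item: str, index: int) -> None:
        # Stand-in for _process_single_fragment.
        print(f"processing fragment {index}: {item}")

    def process_all(items: list) -> None:
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(process_one, item, i)
                       for i, item in enumerate(items)]
            # Collect in submission order; result() re-raises worker errors.
            for future in futures:
                try:
                    future.result()
                except Exception as exc:
                    print(f"fragment failed: {exc}")

    process_all(["intro", "methods", "results"])

This keeps the per-fragment error isolation that return_exceptions=True provided in the old gather-based code, without requiring a running event loop.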
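
wait_for_paper now blocks with time.sleep(0.1) instead of await asyncio.sleep(0.1). The hunks shown do not add an "import time", so the module presumably imports it already; worth verifying before merge. The generic shape of the timeout loop, for reference:

    import time
    from datetime import datetime
    from typing import Callable

    def wait_until(predicate: Callable[[], bool],
                   timeout: float = 300.0, interval: float = 0.1) -> bool:
        # Poll predicate() until it returns True or the timeout elapses.
        start_time = datetime.now()
        while True:
            if predicate():
                return True
            if (datetime.now() - start_time).total_seconds() > timeout:
                return False
            time.sleep(interval)

    assert wait_until(lambda: True, timeout=1.0)

Since process_paper is now fully synchronous and only returns True after marking the task completed, the subsequent wait_for_paper call in Arxiv论文对话 should normally find the task already finished; it survives mainly as a guard.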
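
_normalize_arxiv_id moves below wait_for_paper with unchanged logic: it accepts either a bare ID or an abs/pdf URL and strips any version suffix. Expected behavior for reference (the IDs are made up):

    def normalize_arxiv_id(input_str: str) -> str:
        # Same logic as ArxivRagWorker._normalize_arxiv_id.
        if 'arxiv.org/' in input_str.lower():
            if '/pdf/' in input_str:
                arxiv_id = input_str.split('/pdf/')[-1]
            else:
                arxiv_id = input_str.split('/abs/')[-1]
            return arxiv_id.split('v')[0].strip()
        return input_str.split('v')[0].strip()

    assert normalize_arxiv_id("2411.01234v2") == "2411.01234"
    assert normalize_arxiv_id("https://arxiv.org/abs/2411.01234v2") == "2411.01234"
    assert normalize_arxiv_id("https://arxiv.org/pdf/2411.01234") == "2411.01234"

One caveat the split('v') approach carries over unchanged: old-style IDs containing a lowercase v, such as ones from the historical solv-int archive, would be truncated at that v.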