lbykkkk
2024-11-23 19:40:56 +08:00
parent 241c9641bb
commit 81ab9f91a4


@@ -46,28 +46,27 @@ class ArxivRagWorker:
     def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
         self.user_name = user_name
         self.llm_kwargs = llm_kwargs
-        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS  # 存储最大并发数
+        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS
         self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
-        # 初始化基础存储目录
+        # Initialize base storage directory
         self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))
-        # 如果提供了 arxiv_id,创建针对该论文的子目录
         if self.arxiv_id:
             self.checkpoint_dir = self.base_dir / self.arxiv_id
             self.vector_store_dir = self.checkpoint_dir / "vector_store"
             self.fragment_store_dir = self.checkpoint_dir / "fragments"
         else:
-            # 如果没有 arxiv_id,使用基础目录
             self.checkpoint_dir = self.base_dir
             self.vector_store_dir = self.base_dir / "vector_store"
             self.fragment_store_dir = self.base_dir / "fragments"
-        # 创建必要的目录
         if os.path.exists(self.vector_store_dir):
             self.loading = True
         else:
             self.loading = False
+        # Create necessary directories
         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
         self.vector_store_dir.mkdir(parents=True, exist_ok=True)
         self.fragment_store_dir.mkdir(parents=True, exist_ok=True)
@@ -76,11 +75,11 @@ class ArxivRagWorker:
logger.info(f"Vector store directory: {self.vector_store_dir}") logger.info(f"Vector store directory: {self.vector_store_dir}")
logger.info(f"Fragment store directory: {self.fragment_store_dir}") logger.info(f"Fragment store directory: {self.fragment_store_dir}")
# 初始化处理队列和线程池 # Initialize processing queue and thread pool
self.processing_queue = {} self.processing_queue = {}
self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS) self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
# 初始化RAG worker # Initialize RAG worker
self.rag_worker = LlamaIndexRagWorker( self.rag_worker = LlamaIndexRagWorker(
user_name=user_name, user_name=user_name,
llm_kwargs=llm_kwargs, llm_kwargs=llm_kwargs,
@@ -88,38 +87,50 @@ class ArxivRagWorker:
             auto_load_checkpoint=True
         )
-        # 初始化arxiv splitter
-        # 初始化 arxiv splitter
+        # Initialize arxiv splitter
         self.arxiv_splitter = ArxivSplitter(
             root_dir=str(self.checkpoint_dir / "arxiv_cache")
         )
-        # 初始化处理队列和线程池
-        self._semaphore = None
-        self._loop = None
-
-    @property
-    def loop(self):
-        """获取当前事件循环"""
-        if self._loop is None:
-            self._loop = asyncio.get_event_loop()
-        return self._loop
-
-    @property
-    def semaphore(self):
-        """延迟创建 semaphore"""
-        if self._semaphore is None:
-            self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
-        return self._semaphore
-
-    async def _process_fragments(self, fragments: List[Fragment]) -> None:
-        """并行处理论文片段"""
+
+    async def _async_get_fragments(self, arxiv_id: str) -> List[Fragment]:
+        """Async helper to get fragments"""
+        return await self.arxiv_splitter.process(arxiv_id)
+
+    def _get_fragments_sync(self, arxiv_id: str) -> List[Fragment]:
+        """Synchronous wrapper for async fragment retrieval"""
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(self._async_get_fragments(arxiv_id))
+        finally:
+            loop.close()
+
+    def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
+        """Process a single paper fragment"""
+        try:
+            text = (
+                f"Paper Title: {fragment.title}\n"
+                f"Abstract: {fragment.abstract}\n"
+                f"ArXiv ID: {fragment.arxiv_id}\n"
+                f"Section: {fragment.current_section}\n"
+                f"Section Tree: {fragment.section_tree}\n"
+                f"Content: {fragment.content}\n"
+                f"Bibliography: {fragment.bibliography}\n"
+                f"Type: FRAGMENT"
+            )
+            logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
+            self.rag_worker.add_text_to_vector_store(text)
+            logger.info(f"Successfully added fragment {index} to vector store")
+        except Exception as e:
+            logger.error(f"Error processing fragment {index}: {str(e)}")
+            raise
+
+    def _process_fragments(self, fragments: List[Fragment]) -> None:
+        """Process paper fragments in parallel using thread pool"""
         if not fragments:
             logger.warning("No fragments to process")
             return
-        # 首先添加论文概述
+        # First add paper overview
         overview = {
             "title": fragments[0].title,
             "abstract": fragments[0].abstract,
@@ -136,39 +147,28 @@ class ArxivRagWorker:
         )
         try:
-            # 同步添加概述
+            # Add overview synchronously
             self.rag_worker.add_text_to_vector_store(overview_text)
             logger.info(f"Added paper overview for {overview['arxiv_id']}")
-            # 创建线程池
+            # Process fragments in parallel using thread pool
             with ThreadPoolExecutor(max_workers=10) as executor:
-                # 使用 asyncio.gather 收集所有任务
-                loop = asyncio.get_event_loop()
-                tasks = [
-                    loop.run_in_executor(
-                        executor,
-                        self._process_single_fragment,
-                        fragment,
-                        i
-                    )
+                # Submit all fragments for processing
+                futures = [
+                    executor.submit(self._process_single_fragment, fragment, i)
                     for i, fragment in enumerate(fragments)
                 ]
-                # 等待所有任务完成
-                results = await asyncio.gather(*tasks, return_exceptions=True)
-                # 处理结果和异常
-                for i, result in enumerate(results):
-                    if isinstance(result, Exception):
-                        logger.error(f"Error processing fragment {i}: {result}")
-                    else:
-                        # 处理成功的结果
-                        pass
+                # Wait for all tasks to complete and handle any exceptions
+                for future in futures:
+                    try:
+                        future.result()
+                    except Exception as e:
+                        logger.error(f"Error processing fragment: {str(e)}")
             logger.info(f"Processed {len(fragments)} fragments successfully")
-            # 保存到本地文件用于调试
+            # Save to local file for debugging
             save_fragments_to_file(
                 fragments,
                 str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
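
Side note on the pattern the new hunk switches to: executor.submit returns a concurrent.futures.Future, and calling future.result() re-raises any exception thrown in the worker thread, which is what lets the added loop log per-fragment failures. A self-contained sketch of the same standard-library pattern, with purely illustrative names (not from this repository), using as_completed to observe results as they finish:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def handle(item: str) -> str:
        # stand-in for per-item work, e.g. adding one fragment to a vector store
        return item.upper()

    items = ["a", "b", "c"]
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = {executor.submit(handle, it): it for it in items}
        for future in as_completed(futures):
            try:
                print(futures[future], "->", future.result())  # re-raises worker exceptions
            except Exception as exc:
                print("failed:", futures[future], exc)

Iterating the futures in submission order, as the committed code does, reaches the same end state; as_completed only changes the order in which results are observed.
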
@@ -178,49 +178,8 @@ class ArxivRagWorker:
logger.error(f"Error processing fragments: {str(e)}") logger.error(f"Error processing fragments: {str(e)}")
raise raise
async def _process_single_fragment(self, fragment: Fragment, index: int) -> None: def process_paper(self, arxiv_id: str) -> bool:
"""处理单个论文片段(改为异步方法)""" """Process paper main function - mixed sync/async version"""
try:
text = (
f"Paper Title: {fragment.title}\n"
f"Abstract: {fragment.abstract}\n"
f"ArXiv ID: {fragment.arxiv_id}\n"
f"Section: {fragment.current_section}\n"
f"Section Tree: {fragment.section_tree}\n"
f"Content: {fragment.content}\n"
f"Bibliography: {fragment.bibliography}\n"
f"Type: FRAGMENT"
)
logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
# 如果 add_text_to_vector_store 是异步的,使用 await
self.rag_worker.add_text_to_vector_store(text)
logger.info(f"Successfully added fragment {index} to vector store")
except Exception as e:
logger.error(f"Error processing fragment {index}: {str(e)}")
raise """处理单个论文片段"""
try:
text = (
f"Paper Title: {fragment.title}\n"
f"Abstract: {fragment.abstract}\n"
f"ArXiv ID: {fragment.arxiv_id}\n"
f"Section: {fragment.current_section}\n"
f"Section Tree: {fragment.section_tree}\n"
f"Content: {fragment.content}\n"
f"Bibliography: {fragment.bibliography}\n"
f"Type: FRAGMENT"
)
logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
self.rag_worker.add_text_to_vector_store(text)
logger.info(f"Successfully added fragment {index} to vector store")
except Exception as e:
logger.error(f"Error processing fragment {index}: {str(e)}")
raise
async def process_paper(self, arxiv_id: str) -> bool:
"""处理论文主函数"""
try: try:
arxiv_id = self._normalize_arxiv_id(arxiv_id) arxiv_id = self._normalize_arxiv_id(arxiv_id)
logger.info(f"Starting to process paper: {arxiv_id}") logger.info(f"Starting to process paper: {arxiv_id}")
@@ -231,24 +190,23 @@ class ArxivRagWorker:
logger.info(f"Paper {arxiv_id} already processed") logger.info(f"Paper {arxiv_id} already processed")
return True return True
# 创建处理任务 # Create processing task
task = ProcessingTask(arxiv_id=arxiv_id) task = ProcessingTask(arxiv_id=arxiv_id)
self.processing_queue[arxiv_id] = task self.processing_queue[arxiv_id] = task
task.status = "processing" task.status = "processing"
async with self.semaphore: # Download and split paper using the sync wrapper
# 下载和分割论文 fragments = self._get_fragments_sync(arxiv_id)
fragments = await self.arxiv_splitter.process(arxiv_id)
if not fragments: if not fragments:
raise ValueError(f"No fragments extracted from paper {arxiv_id}") raise ValueError(f"No fragments extracted from paper {arxiv_id}")
logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}") logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")
# 处理片段 # Process fragments
await self._process_fragments(fragments) self._process_fragments(fragments)
# 标记完成 # Mark as completed
paper_path.touch() paper_path.touch()
task.status = "completed" task.status = "completed"
task.fragments = fragments task.fragments = fragments
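
Side note on _get_fragments_sync, which process_paper now calls: it builds a throwaway event loop with asyncio.new_event_loop() and closes it in a finally block. On Python 3.7+ the same sync-over-async bridge is usually written with asyncio.run, which handles loop creation and cleanup but must not be called from a thread that already has a running loop. A minimal sketch with an illustrative coroutine standing in for ArxivSplitter.process:

    import asyncio
    from typing import List

    async def fetch_fragments(arxiv_id: str) -> List[str]:
        # stand-in for the real async splitter call
        await asyncio.sleep(0)
        return [f"{arxiv_id}:fragment-0"]

    def fetch_fragments_sync(arxiv_id: str) -> List[str]:
        # asyncio.run creates a fresh loop, runs the coroutine, and closes the loop
        return asyncio.run(fetch_fragments(arxiv_id))

    print(fetch_fragments_sync("2401.00001"))  # illustrative ID only
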
@@ -262,19 +220,8 @@ class ArxivRagWorker:
             self.processing_queue[arxiv_id].status = "failed"
             self.processing_queue[arxiv_id].error = str(e)
             return False

-    def _normalize_arxiv_id(self, input_str: str) -> str:
-        """规范化ArXiv ID"""
-        if 'arxiv.org/' in input_str.lower():
-            if '/pdf/' in input_str:
-                arxiv_id = input_str.split('/pdf/')[-1]
-            else:
-                arxiv_id = input_str.split('/abs/')[-1]
-            return arxiv_id.split('v')[0].strip()
-        return input_str.split('v')[0].strip()
-
-    async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
-        """等待论文处理完成"""
+    def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
+        """Wait for paper processing to complete - synchronous version"""
         try:
             start_time = datetime.now()
             while True:
@@ -288,16 +235,27 @@ class ArxivRagWorker:
                 if task.status == "failed":
                     return False
-                # 检查超时
+                # Check timeout
                 if (datetime.now() - start_time).total_seconds() > timeout:
                     logger.error(f"Processing paper {arxiv_id} timed out")
                     return False
-                await asyncio.sleep(0.1)
+                time.sleep(0.1)
         except Exception as e:
             logger.error(f"Error waiting for paper {arxiv_id}: {str(e)}")
             return False

+    def _normalize_arxiv_id(self, input_str: str) -> str:
+        """Normalize ArXiv ID"""
+        if 'arxiv.org/' in input_str.lower():
+            if '/pdf/' in input_str:
+                arxiv_id = input_str.split('/pdf/')[-1]
+            else:
+                arxiv_id = input_str.split('/abs/')[-1]
+            return arxiv_id.split('v')[0].strip()
+        return input_str.split('v')[0].strip()
+
     def retrieve_and_generate(self, query: str) -> str:
         """检索相关内容并生成提示词"""
         try:
@@ -374,20 +332,10 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
     chatbot.append((txt, "正在处理论文,请稍等..."))
     yield from update_ui(chatbot=chatbot, history=history)

-    # 创建事件循环来处理异步调用
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    try:
-        # 运行异步处理函数
-        success = loop.run_until_complete(worker.process_paper(txt))
+    success = worker.process_paper(txt)
     if success:
         arxiv_id = worker._normalize_arxiv_id(txt)
-        success = loop.run_until_complete(worker.wait_for_paper(arxiv_id))
-        if success:
-            # 执行自动分析
-            yield from worker.auto_analyze_paper(chatbot, history, system_prompt)
-    finally:
-        loop.close()
+        success = worker.wait_for_paper(arxiv_id)

     if not success:
         chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")