@@ -46,10 +46,10 @@ class ArxivRagWorker:
     def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
         self.user_name = user_name
         self.llm_kwargs = llm_kwargs
-        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS
+        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS  # 存储最大并发数
         self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None

-        # Initialize base storage directory
+        # 初始化基础存储目录
         self.base_dir = Path(get_log_folder(user_name, plugin_name='rag_cache'))

         if self.arxiv_id:
@@ -57,6 +57,7 @@ class ArxivRagWorker:
             self.vector_store_dir = self.checkpoint_dir / "vector_store"
             self.fragment_store_dir = self.checkpoint_dir / "fragments"
         else:
+            # 如果没有 arxiv_id,使用基础目录
             self.checkpoint_dir = self.base_dir
             self.vector_store_dir = self.base_dir / "vector_store"
             self.fragment_store_dir = self.base_dir / "fragments"
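For orientation, the directory layout these branches produce can be sketched with plain pathlib. Only base_dir, the vector_store and fragments subdirectories, and the no-arxiv_id fallback come from the code above; the per-paper checkpoint path is an assumption, since the line that sets self.checkpoint_dir in the arxiv_id branch falls outside this hunk.

from pathlib import Path

# Hypothetical stand-ins for get_log_folder(...) and the paper id.
base_dir = Path("gpt_log/default_user/rag_cache")
arxiv_id = "2301.12345"

if arxiv_id:
    checkpoint_dir = base_dir / arxiv_id      # assumed per-paper cache location
else:
    checkpoint_dir = base_dir                 # fall back to the base directory
vector_store_dir = checkpoint_dir / "vector_store"
fragment_store_dir = checkpoint_dir / "fragments"
print(vector_store_dir)  # gpt_log/default_user/rag_cache/2301.12345/vector_store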
@@ -75,11 +76,11 @@ class ArxivRagWorker:
         logger.info(f"Vector store directory: {self.vector_store_dir}")
         logger.info(f"Fragment store directory: {self.fragment_store_dir}")

-        # Initialize processing queue and thread pool
+        # 初始化处理队列和线程池
         self.processing_queue = {}
         self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)

-        # Initialize RAG worker
+        # 初始化RAG worker
         self.rag_worker = LlamaIndexRagWorker(
             user_name=user_name,
             llm_kwargs=llm_kwargs,
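ProcessingTask itself is defined elsewhere in the module and is not part of this diff. Judging from how the queue entries are used further down (task.status, task.error, task.fragments), its shape is roughly the following hypothetical dataclass:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ProcessingTask:
    """Hypothetical reconstruction; field names follow their usage below."""
    arxiv_id: str
    status: str = "pending"   # "pending" -> "processing" -> "completed" / "failed"
    error: Optional[str] = None
    fragments: List = field(default_factory=list)

processing_queue = {}
task = ProcessingTask(arxiv_id="2301.12345")
processing_queue[task.arxiv_id] = task
task.status = "processing"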
@@ -87,24 +88,97 @@ class ArxivRagWorker:
             auto_load_checkpoint=True
         )

-        # Initialize arxiv splitter
+        # 初始化 arxiv splitter
         self.arxiv_splitter = ArxivSplitter(
             root_dir=str(self.checkpoint_dir / "arxiv_cache")
         )
-    async def _async_get_fragments(self, arxiv_id: str) -> List[Fragment]:
-        """Async helper to get fragments"""
-        return await self.arxiv_splitter.process(arxiv_id)
-
-    def _get_fragments_sync(self, arxiv_id: str) -> List[Fragment]:
-        """Synchronous wrapper for async fragment retrieval"""
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        try:
-            return loop.run_until_complete(self._async_get_fragments(arxiv_id))
-        finally:
-            loop.close()
-
-    def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
-        """Process a single paper fragment"""
+        # 初始化处理队列和线程池
+        self._semaphore = None
+        self._loop = None
+
+    @property
+    def loop(self):
+        """获取当前事件循环"""
+        if self._loop is None:
+            self._loop = asyncio.get_event_loop()
+        return self._loop
+
+    @property
+    def semaphore(self):
+        """延迟创建 semaphore"""
+        if self._semaphore is None:
+            self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
+        return self._semaphore
+
+
+
+    async def _process_fragments(self, fragments: List[Fragment]) -> None:
+        """并行处理论文片段"""
+        if not fragments:
+            logger.warning("No fragments to process")
+            return
+
+        # 首先添加论文概述
+        overview = {
+            "title": fragments[0].title,
+            "abstract": fragments[0].abstract,
+            "arxiv_id": fragments[0].arxiv_id,
+            "section_tree": fragments[0].section_tree,
+        }
+
+        overview_text = (
+            f"Paper Title: {overview['title']}\n"
+            f"ArXiv ID: {overview['arxiv_id']}\n"
+            f"Abstract: {overview['abstract']}\n"
+            f"Section Tree:{overview['section_tree']}\n"
+            f"Type: OVERVIEW"
+        )
+
+        try:
+            # 同步添加概述
+            self.rag_worker.add_text_to_vector_store(overview_text)
+            logger.info(f"Added paper overview for {overview['arxiv_id']}")
+
+            # 并行处理其余片段
+            tasks = []
+            for i, fragment in enumerate(fragments):
+                tasks.append(self._process_single_fragment(fragment, i))
+            await asyncio.gather(*tasks)
+            logger.info(f"Processed {len(fragments)} fragments successfully")
+
+
+            # 保存到本地文件用于调试
+            save_fragments_to_file(
+                fragments,
+                str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
+            )
+
+        except Exception as e:
+            logger.error(f"Error processing fragments: {str(e)}")
+            raise
+
+    async def _process_single_fragment(self, fragment: Fragment, index: int) -> None:
+        """处理单个论文片段(改为异步方法)"""
+        try:
+            text = (
+                f"Paper Title: {fragment.title}\n"
+                f"Abstract: {fragment.abstract}\n"
+                f"ArXiv ID: {fragment.arxiv_id}\n"
+                f"Section: {fragment.current_section}\n"
+                f"Section Tree: {fragment.section_tree}\n"
+                f"Content: {fragment.content}\n"
+                f"Bibliography: {fragment.bibliography}\n"
+                f"Type: FRAGMENT"
+            )
+            logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
+            # 如果 add_text_to_vector_store 是异步的,使用 await
+            self.rag_worker.add_text_to_vector_store(text)
+            logger.info(f"Successfully added fragment {index} to vector store")
+
+        except Exception as e:
+            logger.error(f"Error processing fragment {index}: {str(e)}")
+            raise
+        """处理单个论文片段"""
         try:
             text = (
                 f"Paper Title: {fragment.title}\n"
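The loop and semaphore properties above are created lazily, on first access, so the semaphore ends up attached to the loop that actually runs the coroutines rather than whatever loop existed at construction time; note that asyncio.get_event_loop() is discouraged outside a running loop on recent Python versions, where asyncio.get_running_loop() inside a coroutine is the usual replacement. A standalone sketch of the lazy-semaphore-plus-gather pattern (the class name, helpers and the sleep are illustrative stand-ins, not the repository's code):

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LazySemaphoreWorker:
    """Minimal sketch of the pattern introduced in the hunk above."""

    def __init__(self, max_concurrent: int = 4):
        self.max_concurrent = max_concurrent
        self._semaphore = None  # created on first use, inside the running loop

    @property
    def semaphore(self) -> asyncio.Semaphore:
        if self._semaphore is None:
            self._semaphore = asyncio.Semaphore(self.max_concurrent)
        return self._semaphore

    async def _process_one(self, index: int, item: str) -> None:
        async with self.semaphore:        # at most max_concurrent at a time
            await asyncio.sleep(0.1)      # stand-in for real asynchronous work
            logger.info("processed fragment %d: %s", index, item)

    async def process_all(self, items: list) -> None:
        tasks = [self._process_one(i, it) for i, it in enumerate(items)]
        await asyncio.gather(*tasks)

if __name__ == "__main__":
    asyncio.run(LazySemaphoreWorker().process_all([f"frag-{i}" for i in range(10)]))

Note that asyncio.gather only overlaps work while the coroutines actually await something; if add_text_to_vector_store is synchronous, as the inline comment in the hunk suggests it may be, the fragments are still indexed one after another.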
@@ -124,62 +198,8 @@ class ArxivRagWorker:
             logger.error(f"Error processing fragment {index}: {str(e)}")
             raise

-    def _process_fragments(self, fragments: List[Fragment]) -> None:
-        """Process paper fragments in parallel using thread pool"""
-        if not fragments:
-            logger.warning("No fragments to process")
-            return
-
-        # First add paper overview
-        overview = {
-            "title": fragments[0].title,
-            "abstract": fragments[0].abstract,
-            "arxiv_id": fragments[0].arxiv_id,
-            "section_tree": fragments[0].section_tree,
-        }
-
-        overview_text = (
-            f"Paper Title: {overview['title']}\n"
-            f"ArXiv ID: {overview['arxiv_id']}\n"
-            f"Abstract: {overview['abstract']}\n"
-            f"Section Tree:{overview['section_tree']}\n"
-            f"Type: OVERVIEW"
-        )
-
-        try:
-            # Add overview synchronously
-            self.rag_worker.add_text_to_vector_store(overview_text)
-            logger.info(f"Added paper overview for {overview['arxiv_id']}")
-
-            # Process fragments in parallel using thread pool
-            with ThreadPoolExecutor(max_workers=10) as executor:
-                # Submit all fragments for processing
-                futures = [
-                    executor.submit(self._process_single_fragment, fragment, i)
-                    for i, fragment in enumerate(fragments)
-                ]
-
-                # Wait for all tasks to complete and handle any exceptions
-                for future in futures:
-                    try:
-                        future.result()
-                    except Exception as e:
-                        logger.error(f"Error processing fragment: {str(e)}")
-
-            logger.info(f"Processed {len(fragments)} fragments successfully")
-
-            # Save to local file for debugging
-            save_fragments_to_file(
-                fragments,
-                str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json")
-            )
-
-        except Exception as e:
-            logger.error(f"Error processing fragments: {str(e)}")
-            raise
-
-    def process_paper(self, arxiv_id: str) -> bool:
-        """Process paper main function - mixed sync/async version"""
+    async def process_paper(self, arxiv_id: str) -> bool:
+        """处理论文主函数"""
         try:
             arxiv_id = self._normalize_arxiv_id(arxiv_id)
             logger.info(f"Starting to process paper: {arxiv_id}")
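The removed thread-pool version obtained real overlap around the blocking add_text_to_vector_store call, whereas the async replacement awaits coroutines that still invoke it synchronously. When the call is blocking, off-loading it to a worker thread keeps the event loop responsive while preserving bounded parallelism; a hedged sketch, with the blocking function as a stand-in:

import asyncio
import time

def add_text_blocking(text: str) -> None:
    # Stand-in for a synchronous embedding / vector-store insertion.
    time.sleep(0.2)

async def add_fragment(text: str, sem: asyncio.Semaphore) -> None:
    async with sem:
        # asyncio.to_thread (Python 3.9+) runs the blocking call in a worker
        # thread, so other coroutines keep making progress meanwhile.
        await asyncio.to_thread(add_text_blocking, text)

async def main() -> None:
    sem = asyncio.Semaphore(10)
    fragments = [f"fragment {i}" for i in range(25)]
    await asyncio.gather(*(add_fragment(t, sem) for t in fragments))

if __name__ == "__main__":
    asyncio.run(main())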
@@ -190,29 +210,30 @@ class ArxivRagWorker:
                 logger.info(f"Paper {arxiv_id} already processed")
                 return True

-            # Create processing task
+            # 创建处理任务
             task = ProcessingTask(arxiv_id=arxiv_id)
             self.processing_queue[arxiv_id] = task
             task.status = "processing"

-            # Download and split paper using the sync wrapper
-            fragments = self._get_fragments_sync(arxiv_id)
+            async with self.semaphore:
+                # 下载和分割论文
+                fragments = await self.arxiv_splitter.process(arxiv_id)

             if not fragments:
                 raise ValueError(f"No fragments extracted from paper {arxiv_id}")

             logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}")

-            # Process fragments
-            self._process_fragments(fragments)
+            # 处理片段
+            await self._process_fragments(fragments)

-            # Mark as completed
+            # 标记完成
             paper_path.touch()
             task.status = "completed"
             task.fragments = fragments

             logger.info(f"Successfully processed paper {arxiv_id}")
             return True

         except Exception as e:
             logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
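The async with self.semaphore: block gates only the download-and-split step, so when several papers are processed concurrently, at most max_concurrent_papers downloads run at once while the rest of the pipeline proceeds freely. A minimal sketch of that behaviour with stand-in names:

import asyncio

MAX_CONCURRENT_PAPERS = 3  # assumed meaning of the constant used above

async def download_and_split(arxiv_id: str) -> list:
    await asyncio.sleep(0.1)  # stand-in for ArxivSplitter.process()
    return [f"{arxiv_id}-frag-{i}" for i in range(3)]

async def process_paper(arxiv_id: str, sem: asyncio.Semaphore) -> bool:
    try:
        async with sem:                    # only the download step is gated
            fragments = await download_and_split(arxiv_id)
        if not fragments:
            raise ValueError(f"No fragments extracted from paper {arxiv_id}")
        return True
    except Exception:
        return False

async def main() -> None:
    sem = asyncio.Semaphore(MAX_CONCURRENT_PAPERS)
    ids = ["2301.00001", "2301.00002", "2301.00003", "2301.00004"]
    results = await asyncio.gather(*(process_paper(i, sem) for i in ids))
    print(dict(zip(ids, results)))

if __name__ == "__main__":
    asyncio.run(main())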
@@ -220,8 +241,19 @@ class ArxivRagWorker:
                 self.processing_queue[arxiv_id].status = "failed"
                 self.processing_queue[arxiv_id].error = str(e)
             return False
-    def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
-        """Wait for paper processing to complete - synchronous version"""
+
+    def _normalize_arxiv_id(self, input_str: str) -> str:
+        """规范化ArXiv ID"""
+        if 'arxiv.org/' in input_str.lower():
+            if '/pdf/' in input_str:
+                arxiv_id = input_str.split('/pdf/')[-1]
+            else:
+                arxiv_id = input_str.split('/abs/')[-1]
+            return arxiv_id.split('v')[0].strip()
+        return input_str.split('v')[0].strip()
+
+    async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
+        """等待论文处理完成"""
         try:
             start_time = datetime.now()
             while True:
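For reference, the new _normalize_arxiv_id behaves as follows on typical inputs; the snippet restates the function from the hunk above without the self parameter. The final split('v') relies on the letter v appearing only in the version suffix, which holds for standard arXiv identifiers.

def normalize_arxiv_id(input_str: str) -> str:
    if 'arxiv.org/' in input_str.lower():
        if '/pdf/' in input_str:
            arxiv_id = input_str.split('/pdf/')[-1]
        else:
            arxiv_id = input_str.split('/abs/')[-1]
        return arxiv_id.split('v')[0].strip()
    return input_str.split('v')[0].strip()

print(normalize_arxiv_id("https://arxiv.org/abs/2301.12345v2"))      # 2301.12345
print(normalize_arxiv_id("https://arxiv.org/pdf/2301.12345v2.pdf"))  # 2301.12345
print(normalize_arxiv_id("2301.12345v1"))                            # 2301.12345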
@@ -235,27 +267,16 @@ class ArxivRagWorker:
                 if task.status == "failed":
                     return False

-                # Check timeout
+                # 检查超时
                 if (datetime.now() - start_time).total_seconds() > timeout:
                     logger.error(f"Processing paper {arxiv_id} timed out")
                     return False

-                time.sleep(0.1)
+                await asyncio.sleep(0.1)
         except Exception as e:
             logger.error(f"Error waiting for paper {arxiv_id}: {str(e)}")
             return False

-    def _normalize_arxiv_id(self, input_str: str) -> str:
-        """Normalize ArXiv ID"""
-        if 'arxiv.org/' in input_str.lower():
-            if '/pdf/' in input_str:
-                arxiv_id = input_str.split('/pdf/')[-1]
-            else:
-                arxiv_id = input_str.split('/abs/')[-1]
-            return arxiv_id.split('v')[0].strip()
-        return input_str.split('v')[0].strip()
-
-
     def retrieve_and_generate(self, query: str) -> str:
         """检索相关内容并生成提示词"""
         try:
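wait_for_paper now yields with await asyncio.sleep(0.1) between checks, so the event loop keeps serving other tasks while it polls; the old time.sleep(0.1) blocked the loop for the whole wait. The same deadline-polling idea in a self-contained form (predicate and timings are illustrative):

import asyncio
from datetime import datetime

async def wait_until(predicate, timeout: float = 300.0, interval: float = 0.1) -> bool:
    """Poll a synchronous predicate until it returns True or the timeout expires."""
    start_time = datetime.now()
    while True:
        if predicate():
            return True
        if (datetime.now() - start_time).total_seconds() > timeout:
            return False
        await asyncio.sleep(interval)  # yields control instead of blocking

async def main() -> None:
    state = {"done": False}

    async def finish_later():
        await asyncio.sleep(0.5)
        state["done"] = True

    asyncio.create_task(finish_later())
    print(await wait_until(lambda: state["done"], timeout=2.0))  # True

if __name__ == "__main__":
    asyncio.run(main())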
@@ -332,10 +353,20 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
         chatbot.append((txt, "正在处理论文,请稍等..."))
         yield from update_ui(chatbot=chatbot, history=history)

-        success = worker.process_paper(txt)
-        if success:
-            arxiv_id = worker._normalize_arxiv_id(txt)
-            success = worker.wait_for_paper(arxiv_id)
+        # 创建事件循环来处理异步调用
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            # 运行异步处理函数
+            success = loop.run_until_complete(worker.process_paper(txt))
+            if success:
+                arxiv_id = worker._normalize_arxiv_id(txt)
+                success = loop.run_until_complete(worker.wait_for_paper(arxiv_id))
+                if success:
+                    # 执行自动分析
+                    yield from worker.auto_analyze_paper(chatbot, history, system_prompt)
+        finally:
+            loop.close()

         if not success:
             chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。")
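The plugin entry point is a synchronous generator (it yields UI updates via yield from update_ui), so it cannot await the new coroutines directly; the added code therefore creates a private event loop and drives them with run_until_complete, closing the loop afterwards. asyncio.run() would be the usual one-call alternative when no loop is already running in the thread. A self-contained sketch of this sync-generator-driving-async pattern (function names and messages are placeholders):

import asyncio
from typing import Iterator

async def process_paper(arxiv_id: str) -> bool:
    await asyncio.sleep(0.1)  # stand-in for the real async pipeline
    return True

def plugin_entry(arxiv_id: str) -> Iterator[str]:
    """Synchronous generator, like the plugin above, driving async code."""
    yield "Processing paper, please wait..."
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        success = loop.run_until_complete(process_paper(arxiv_id))
    finally:
        loop.close()
    yield "Done." if success else "Paper processing failed."

for message in plugin_entry("2301.12345"):
    print(message)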