up
This commit is contained in:
@@ -12,7 +12,7 @@ from typing import List, Dict, Optional
|
||||
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
|
||||
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file, process_arxiv_sync
|
||||
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
|
||||
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
|
||||
from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg
|
||||
@@ -51,6 +51,7 @@ class ArxivRagWorker:
|
||||
self.user_name = user_name
|
||||
self.llm_kwargs = llm_kwargs
|
||||
self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
|
||||
self.fragments = None
|
||||
|
||||
|
||||
# 初始化基础目录
|
||||
@@ -63,7 +64,6 @@ class ArxivRagWorker:
|
||||
self._processing_lock = ThreadLock()
|
||||
self._processed_fragments = set()
|
||||
self._processed_count = 0
|
||||
|
||||
# 优化的线程池配置
|
||||
cpu_count = os.cpu_count() or 1
|
||||
self.thread_pool = ThreadPoolExecutor(
|
||||
@@ -268,27 +268,18 @@ class ArxivRagWorker:
|
||||
f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)"
|
||||
)
|
||||
|
||||
async def process_paper(self, arxiv_id: str) -> bool:
|
||||
async def process_paper(self, fragments: List[Fragment]) -> bool:
|
||||
"""处理论文主函数"""
|
||||
try:
|
||||
arxiv_id = self._normalize_arxiv_id(arxiv_id)
|
||||
logger.info(f"Starting to process paper: {arxiv_id}")
|
||||
|
||||
if self.paper_path.exists():
|
||||
logger.info(f"Paper {arxiv_id} already processed")
|
||||
logger.info(f"Paper {self.arxiv_id} already processed")
|
||||
return True
|
||||
|
||||
task = self._create_processing_task(arxiv_id)
|
||||
|
||||
task = self._create_processing_task(self.arxiv_id)
|
||||
try:
|
||||
async with self.semaphore:
|
||||
fragments = await self.arxiv_splitter.process(arxiv_id)
|
||||
if not fragments:
|
||||
raise ValueError(f"No fragments extracted from paper {arxiv_id}")
|
||||
|
||||
logger.info(f"Extracted {len(fragments)} fragments from paper {arxiv_id}")
|
||||
await self._process_fragments(fragments)
|
||||
|
||||
self._complete_task(task, fragments, self.paper_path)
|
||||
return True
|
||||
|
||||
@@ -297,7 +288,7 @@ class ArxivRagWorker:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
|
||||
logger.error(f"Error processing paper {self.arxiv_id}: {str(e)}")
|
||||
return False
|
||||
|
||||
def _create_processing_task(self, arxiv_id: str) -> ProcessingTask:
|
||||
@@ -429,29 +420,28 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
|
||||
return
|
||||
|
||||
user_name = chatbot.get_user()
|
||||
worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
|
||||
arxiv_worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
|
||||
|
||||
# 处理新论文的情况
|
||||
if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not worker.loading:
|
||||
if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not arxiv_worker.loading:
|
||||
chatbot.append((txt, "正在处理论文,请稍等..."))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
arxiv_id = arxiv_worker.arxiv_id
|
||||
fragments, formatted_content, output_dir = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_worker.arxiv_id)
|
||||
chatbot.append(["论文下载成功,接下来将编码论文,预计等待两分钟,请耐心等待,论文内容如下:", formatted_content])
|
||||
try:
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 使用超时控制
|
||||
success = False
|
||||
try:
|
||||
# 设置超时时间为5分钟
|
||||
success = loop.run_until_complete(
|
||||
asyncio.wait_for(worker.process_paper(txt), timeout=300)
|
||||
asyncio.wait_for(arxiv_worker.process_paper(fragments), timeout=300)
|
||||
)
|
||||
if success:
|
||||
arxiv_id = worker._normalize_arxiv_id(txt)
|
||||
success = loop.run_until_complete(
|
||||
asyncio.wait_for(worker.wait_for_paper(arxiv_id), timeout=60)
|
||||
asyncio.wait_for(arxiv_worker.wait_for_paper(arxiv_id), timeout=60)
|
||||
)
|
||||
if success:
|
||||
chatbot[-1] = (txt, "论文处理完成,您现在可以开始提问。")
|
||||
@@ -515,7 +505,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 生成提示词
|
||||
prompt = worker.retrieve_and_generate(query_clip)
|
||||
prompt = arxiv_worker.retrieve_and_generate(query_clip)
|
||||
if not prompt:
|
||||
chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
Reference in New Issue
Block a user