This commit is contained in:
lbykkkk
2024-12-01 22:00:41 +08:00
parent b3aef6b393
commit 3beb22a347
3 changed files with 459 additions and 39 deletions

View File

@@ -12,7 +12,7 @@ from typing import List, Dict, Optional
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file, process_arxiv_sync
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg
@@ -51,6 +51,7 @@ class ArxivRagWorker:
self.user_name = user_name
self.llm_kwargs = llm_kwargs
self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
self.fragments = None
# 初始化基础目录
@@ -63,7 +64,6 @@ class ArxivRagWorker:
self._processing_lock = ThreadLock()
self._processed_fragments = set()
self._processed_count = 0
# 优化的线程池配置
cpu_count = os.cpu_count() or 1
self.thread_pool = ThreadPoolExecutor(
@@ -268,27 +268,18 @@ class ArxivRagWorker:
f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)"
)
async def process_paper(self, arxiv_id: str) -> bool:
async def process_paper(self, fragments: List[Fragment]) -> bool:
"""处理论文主函数"""
try:
arxiv_id = self._normalize_arxiv_id(arxiv_id)
logger.info(f"Starting to process paper: {arxiv_id}")
if self.paper_path.exists():
logger.info(f"Paper {arxiv_id} already processed")
logger.info(f"Paper {self.arxiv_id} already processed")
return True
task = self._create_processing_task(arxiv_id)
task = self._create_processing_task(self.arxiv_id)
try:
async with self.semaphore:
fragments = await self.arxiv_splitter.process(arxiv_id)
if not fragments:
raise ValueError(f"No fragments extracted from paper {arxiv_id}")
logger.info(f"Extracted {len(fragments)} fragments from paper {arxiv_id}")
await self._process_fragments(fragments)
self._complete_task(task, fragments, self.paper_path)
return True
@@ -297,7 +288,7 @@ class ArxivRagWorker:
raise
except Exception as e:
logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
logger.error(f"Error processing paper {self.arxiv_id}: {str(e)}")
return False
def _create_processing_task(self, arxiv_id: str) -> ProcessingTask:
@@ -429,29 +420,28 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
return
user_name = chatbot.get_user()
worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
arxiv_worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
# 处理新论文的情况
if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not worker.loading:
if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not arxiv_worker.loading:
chatbot.append((txt, "正在处理论文,请稍等..."))
yield from update_ui(chatbot=chatbot, history=history)
arxiv_id = arxiv_worker.arxiv_id
fragments, formatted_content, output_dir = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_worker.arxiv_id)
chatbot.append(["论文下载成功,接下来将编码论文,预计等待两分钟,请耐心等待,论文内容如下:", formatted_content])
try:
# 创建新的事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 使用超时控制
success = False
try:
# 设置超时时间为5分钟
success = loop.run_until_complete(
asyncio.wait_for(worker.process_paper(txt), timeout=300)
asyncio.wait_for(arxiv_worker.process_paper(fragments), timeout=300)
)
if success:
arxiv_id = worker._normalize_arxiv_id(txt)
success = loop.run_until_complete(
asyncio.wait_for(worker.wait_for_paper(arxiv_id), timeout=60)
asyncio.wait_for(arxiv_worker.wait_for_paper(arxiv_id), timeout=60)
)
if success:
chatbot[-1] = (txt, "论文处理完成,您现在可以开始提问。")
@@ -515,7 +505,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
yield from update_ui(chatbot=chatbot, history=history)
# 生成提示词
prompt = worker.retrieve_and_generate(query_clip)
prompt = arxiv_worker.retrieve_and_generate(query_clip)
if not prompt:
chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。")
yield from update_ui(chatbot=chatbot, history=history)