From cbef9a908cf66fc177067b9f27f3dc99fca099d4 Mon Sep 17 00:00:00 2001
From: lbykkkk
Date: Sun, 17 Nov 2024 23:15:34 +0800
Subject: [PATCH] up

---
 crazy_functional.py                           |  40 +-
 crazy_functions/Arxiv_论文对话.py              | 455 ++++++++++++++++--
 crazy_functions/rag_fns/llama_index_worker.py | 331 +++----------
 3 files changed, 512 insertions(+), 314 deletions(-)

diff --git a/crazy_functional.py b/crazy_functional.py
index ddac5815..c4df777c 100644
--- a/crazy_functional.py
+++ b/crazy_functional.py
@@ -15,7 +15,7 @@ def get_crazy_functions():
     from crazy_functions.SourceCode_Analyse import 解析一个Rust项目
     from crazy_functions.SourceCode_Analyse import 解析一个Java项目
     from crazy_functions.SourceCode_Analyse import 解析一个前端项目
-    from crazy_functions.Arxiv_论文对话 import Rag论文对话
+    from crazy_functions.Arxiv_论文对话 import Arxiv论文对话
     from crazy_functions.高级功能函数模板 import 高阶功能模板函数
     from crazy_functions.高级功能函数模板 import Demo_Wrap
     from crazy_functions.Latex全文润色 import Latex英文润色
@@ -31,6 +31,8 @@ def get_crazy_functions():
     from crazy_functions.Markdown_Translate import Markdown英译中
     from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
     from crazy_functions.PDF_Translate import 批量翻译PDF文档
+    from crazy_functions.批量文件询问 import 批量文件询问
+    from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
     from crazy_functions.理解PDF文档内容 import 理解PDF文档内容标准文件输入
     from crazy_functions.Latex全文润色 import Latex中文润色

@@ -74,12 +76,25 @@ def get_crazy_functions():
             "Function": HotReload(Latex翻译中文并重新编译PDF),  # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
             "Class": Arxiv_Localize,  # 新一代插件需要注册Class
         },
-        "Rag论文对话": {
+        "批量文件询问": {
             "Group": "学术",
             "Color": "stop",
             "AsButton": False,
-            "Info": "Arixv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
-            "Function": HotReload(Rag论文对话),  # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
+            "AdvancedArgs": True,
+            "Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
+            "ArgsReminder": r"1、请不要更改上方输入框中以“private_upload/...”开头的路径。 "
+                            r"2、请在下方高级参数区中输入你的prompt,文档中的内容将被添加你的prompt后。3、示例:“请总结下面的内容:”,此时,文档内容将添加在“:”后 ",
+            "Function": HotReload(批量文件询问),
+        },
+        "Arxiv论文对话": {
+            "Group": "学术",
+            "Color": "stop",
+            "AsButton": False,
+            "AdvancedArgs": True,
+            "Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
+            "ArgsReminder": r"1、请不要更改上方输入框中以“private_upload/...”开头的路径。 "
+                            r"2、请在下方高级参数区中输入你的prompt,文档中的内容将被添加你的prompt后。3、示例:“这篇文章的方法是什么:” ",
+            "Function": HotReload(Arxiv论文对话),
         },
         "翻译README或MD": {
             "Group": "编程",
@@ -604,6 +619,23 @@ def get_crazy_functions():
         logger.error(trimmed_format_exc())
         logger.error("Load function plugin failed")

+    try:
+        from crazy_functions.Arxiv_论文对话 import Arxiv论文对话
+
+        function_plugins.update(
+            {
+                "Arxiv论文对话": {
+                    "Group": "对话",
+                    "Color": "stop",
+                    "AsButton": False,
+                    "Info": "将问答数据记录到向量库中,作为长期参考。",
+                    "Function": HotReload(Arxiv论文对话),
+                },
+            }
+        )
+    except:
+        logger.error(trimmed_format_exc())
+        logger.error("Load function plugin failed")
+
diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py
index 86afe134..debc2c77 100644
--- a/crazy_functions/Arxiv_论文对话.py
+++ b/crazy_functions/Arxiv_论文对话.py
@@ -1,63 +1,416 @@
-import os.path
+import os
+import logging
+import asyncio
+from pathlib import Path
+from typing import List, Optional, Generator, Dict, Union
+from datetime import datetime
+from dataclasses import dataclass
+from concurrent.futures import ThreadPoolExecutor
+import aiohttp

-from toolbox import CatchException, update_ui
-from crazy_functions.rag_fns.arxiv_fns.paper_processing import ArxivPaperProcessor
+from shared_utils.fastapi_server import validate_path_safety
+from toolbox import CatchException, update_ui, get_conf, 
get_log_folder, update_ui_lastest_msg +from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file +from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment as Fragment +from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker +from crazy_functions.crazy_utils import input_clipping +from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive + +# 全局常量配置 +MAX_HISTORY_ROUND = 5 # 最大历史对话轮数 +MAX_CONTEXT_TOKEN_LIMIT = 4096 # 上下文最大token数 +REMEMBER_PREVIEW = 1000 # 记忆预览长度 +VECTOR_STORE_TYPE = "Simple" # 向量存储类型:Simple或Milvus +MAX_CONCURRENT_PAPERS = 5 # 最大并行处理论文数 +MAX_WORKERS = 3 # 最大工作线程数 + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +@dataclass +class ProcessingTask: + """论文处理任务数据类""" + arxiv_id: str + status: str = "pending" # pending, processing, completed, failed + error: Optional[str] = None + fragments: List[Fragment] = None + + +class ArxivRagWorker: + def __init__(self, user_name: str, llm_kwargs: Dict): + self.user_name = user_name + self.llm_kwargs = llm_kwargs + + # 初始化存储目录 + self.checkpoint_dir = Path(get_log_folder(user_name, plugin_name='rag_cache')) + self.vector_store_dir = self.checkpoint_dir / "vector_store" + self.fragment_store_dir = self.checkpoint_dir / "fragments" + + # 创建必要的目录 + self.vector_store_dir.mkdir(parents=True, exist_ok=True) + self.fragment_store_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Vector store directory: {self.vector_store_dir}") + logger.info(f"Fragment store directory: {self.fragment_store_dir}") + + # 初始化RAG worker + self.rag_worker = LlamaIndexRagWorker( + user_name=user_name, + llm_kwargs=llm_kwargs, + checkpoint_dir=str(self.vector_store_dir), + auto_load_checkpoint=True + ) + + # 初始化arxiv splitter + self.arxiv_splitter = ArxivSplitter( + char_range=(1000, 1200), + root_dir=str(self.checkpoint_dir / "arxiv_cache") + ) + + # 初始化并行处理组件 + self.processing_queue = {} + self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_PAPERS) + self.thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS) + + async def _process_fragments(self, fragments: List[Fragment]) -> None: + """并行处理论文片段""" + if not fragments: + logger.warning("No fragments to process") + return + + # 首先添加论文概述 + overview = { + "title": fragments[0].title, + "abstract": fragments[0].abstract, + "arxiv_id": fragments[0].arxiv_id, + } + + overview_text = ( + f"Paper Title: {overview['title']}\n" + f"ArXiv ID: {overview['arxiv_id']}\n" + f"Abstract: {overview['abstract']}\n" + f"Type: OVERVIEW" + ) + + try: + # 同步添加概述 + self.rag_worker.add_text_to_vector_store(overview_text) + logger.info(f"Added paper overview for {overview['arxiv_id']}") + + # 并行处理其余片段 + tasks = [] + for i, fragment in enumerate(fragments): + task = asyncio.get_event_loop().run_in_executor( + self.thread_pool, + self._process_single_fragment, + fragment, + i + ) + tasks.append(task) + + await asyncio.gather(*tasks) + logger.info(f"Processed {len(fragments)} fragments successfully") + + # 保存到本地文件用于调试 + save_fragments_to_file( + fragments, + str(self.fragment_store_dir / f"{overview['arxiv_id']}_fragments.json") + ) + + except Exception as e: + logger.error(f"Error processing fragments: {str(e)}") + raise + + def _process_single_fragment(self, fragment: Fragment, index: int) -> None: + """处理单个论文片段""" + try: + text = ( + f"Paper Title: {fragment.title}\n" + f"ArXiv ID: {fragment.arxiv_id}\n" + f"Section: 
{fragment.section}\n" + f"Fragment Index: {index}\n" + f"Content: {fragment.content}\n" + f"Type: FRAGMENT" + ) + + logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}") + self.rag_worker.add_text_to_vector_store(text) + logger.info(f"Successfully added fragment {index} to vector store") + + except Exception as e: + logger.error(f"Error processing fragment {index}: {str(e)}") + raise + + async def process_paper(self, arxiv_id: str) -> bool: + """处理论文主函数""" + try: + arxiv_id = self._normalize_arxiv_id(arxiv_id) + logger.info(f"Starting to process paper: {arxiv_id}") + + paper_path = self.checkpoint_dir / f"{arxiv_id}.processed" + + if paper_path.exists(): + logger.info(f"Paper {arxiv_id} already processed") + return True + + # 创建处理任务 + task = ProcessingTask(arxiv_id=arxiv_id) + self.processing_queue[arxiv_id] = task + task.status = "processing" + + async with self.semaphore: + # 下载和分割论文 + fragments = await self.arxiv_splitter.process(arxiv_id) + + if not fragments: + raise ValueError(f"No fragments extracted from paper {arxiv_id}") + + logger.info(f"Got {len(fragments)} fragments from paper {arxiv_id}") + + # 处理片段 + await self._process_fragments(fragments) + + # 标记完成 + paper_path.touch() + task.status = "completed" + task.fragments = fragments + + logger.info(f"Successfully processed paper {arxiv_id}") + return True + + except Exception as e: + logger.error(f"Error processing paper {arxiv_id}: {str(e)}") + if arxiv_id in self.processing_queue: + self.processing_queue[arxiv_id].status = "failed" + self.processing_queue[arxiv_id].error = str(e) + return False + + def _normalize_arxiv_id(self, input_str: str) -> str: + """规范化ArXiv ID""" + if 'arxiv.org/' in input_str.lower(): + if '/pdf/' in input_str: + arxiv_id = input_str.split('/pdf/')[-1] + else: + arxiv_id = input_str.split('/abs/')[-1] + return arxiv_id.split('v')[0].strip() + return input_str.split('v')[0].strip() + + async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool: + """等待论文处理完成""" + try: + start_time = datetime.now() + while True: + task = self.processing_queue.get(arxiv_id) + if not task: + return False + + if task.status == "completed": + return True + + if task.status == "failed": + return False + + # 检查超时 + if (datetime.now() - start_time).total_seconds() > timeout: + logger.error(f"Processing paper {arxiv_id} timed out") + return False + + await asyncio.sleep(0.1) + except Exception as e: + logger.error(f"Error waiting for paper {arxiv_id}: {str(e)}") + return False + + def retrieve_and_generate(self, query: str) -> str: + """检索相关内容并生成提示词""" + try: + nodes = self.rag_worker.retrieve_from_store_with_query(query) + return self.rag_worker.build_prompt(query=query, nodes=nodes) + except Exception as e: + logger.error(f"Error in retrieve and generate: {str(e)}") + return "" + + def remember_qa(self, question: str, answer: str) -> None: + """记忆问答对""" + try: + self.rag_worker.remember_qa(question, answer) + except Exception as e: + logger.error(f"Error remembering QA: {str(e)}") + + async def auto_analyze_paper(self, chatbot: List, history: List, system_prompt: str) -> None: + """自动分析论文的关键问题""" + key_questions = [ + "What is the main research question or problem addressed in this paper?", + "What methods or approaches did the authors use to investigate the problem?", + "What are the key findings or results presented in the paper?", + "How do the findings of this paper contribute to the broader field or topic of study?", + "What are the limitations of this study, and what future research 
directions do the authors suggest?" + ] + + results = [] + for question in key_questions: + try: + prompt = self.retrieve_and_generate(question) + if prompt: + response = await request_gpt_model_in_new_thread_with_ui_alive( + inputs=prompt, + inputs_show_user=question, + llm_kwargs=self.llm_kwargs, + chatbot=chatbot, + history=history, + sys_prompt=system_prompt + ) + results.append(f"Q: {question}\nA: {response}\n") + self.remember_qa(question, response) + except Exception as e: + logger.error(f"Error in auto analysis: {str(e)}") + + # 合并所有结果 + summary = "\n\n".join(results) + chatbot[-1] = (chatbot[-1][0], f"论文已成功加载并完成初步分析:\n\n{summary}\n\n您现在可以继续提问更多细节。") @CatchException -def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port): +def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, + history: List, system_prompt: str, web_port: str) -> Generator: """ - txt: 用户输入,通常是arxiv论文链接 - 功能:RAG论文总结和对话 + Arxiv论文对话主函数 + Args: + txt: arxiv ID/URL + llm_kwargs: LLM配置参数 + plugin_kwargs: 插件配置参数,包含 advanced_arg 字段作为用户询问指令 + chatbot: 对话历史 + history: 聊天历史 + system_prompt: 系统提示词 + web_port: Web端口 """ - if_project, if_arxiv = False, False - if os.path.exists(txt): - from crazy_functions.rag_fns.doc_fns.document_splitter import SmartDocumentSplitter - splitter = SmartDocumentSplitter( - char_range=(1000, 1200), - max_workers=32 # 可选,默认会根据CPU核心数自动设置 - ) - if_project = True + # 初始化时,提示用户需要 arxiv ID/URL + if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '1', '2')): + chatbot.append((txt, "请先提供Arxiv论文链接或ID。")) + yield from update_ui(chatbot=chatbot, history=history) + return + + user_name = chatbot.get_user() + worker = ArxivRagWorker(user_name, llm_kwargs) + + # 处理新论文的情况 + if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '1', '2')): + chatbot.append((txt, "正在处理论文,请稍等...")) + yield from update_ui(chatbot=chatbot, history=history) + + # 创建事件循环来处理异步调用 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + # 运行异步处理函数 + success = loop.run_until_complete(worker.process_paper(txt)) + if success: + arxiv_id = worker._normalize_arxiv_id(txt) + success = loop.run_until_complete(worker.wait_for_paper(arxiv_id)) + if success: + # 执行自动分析 + yield from worker.auto_analyze_paper(chatbot, history, system_prompt) + finally: + loop.close() + + if not success: + chatbot[-1] = (txt, "论文处理失败,请检查论文ID是否正确或稍后重试。") + yield from update_ui(chatbot=chatbot, history=history) + return + + yield from update_ui(chatbot=chatbot, history=history) + return + + # 处理用户询问的情况 + # 获取用户询问指令 + user_query = plugin_kwargs.get("advanced_arg", "") + if not user_query: + chatbot.append((txt, "请提供您的问题。")) + yield from update_ui(chatbot=chatbot, history=history) + return + + # 处理历史对话长度 + if len(history) > MAX_HISTORY_ROUND * 2: + history = history[-(MAX_HISTORY_ROUND * 2):] + + # 处理询问指令 + query_clip, history, flags = input_clipping( + user_query, + history, + max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, + return_clip_flags=True + ) + + if flags["original_input_len"] != flags["clipped_input_len"]: + yield from update_ui_lastest_msg('检测到长输入,正在处理...', chatbot, history, delay=0) + if len(user_query) > REMEMBER_PREVIEW: + HALF = REMEMBER_PREVIEW // 2 + query_to_remember = user_query[:HALF] + f" ...\n...(省略{len(user_query) - REMEMBER_PREVIEW}字)...\n... 
" + user_query[-HALF:] + else: + query_to_remember = query_clip else: - from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import SmartArxivSplitter - splitter = SmartArxivSplitter( - char_range=(1000, 1200), - root_dir="gpt_log/arxiv_cache" - ) - if_arxiv = True - for fragment in splitter.process(txt): - pass - # 初始化处理器 - processor = ArxivPaperProcessor() - rag_handler = RagHandler() + query_to_remember = query_clip - # Step 1: 下载和提取论文 - download_result = processor.download_and_extract(txt, chatbot, history) - project_folder, arxiv_id = None, None - - for result in download_result: - if isinstance(result, tuple) and len(result) == 2: - project_folder, arxiv_id = result - break + chatbot.append((user_query, "正在思考中...")) + yield from update_ui(chatbot=chatbot, history=history) - if not project_folder or not arxiv_id: + # 生成提示词 + prompt = worker.retrieve_and_generate(query_clip) + if not prompt: + chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。") + yield from update_ui(chatbot=chatbot, history=history) return - # Step 2: 合并TEX文件 - paper_content = processor.merge_tex_files(project_folder, chatbot, history) - if not paper_content: - return - - # Step 3: RAG处理 - chatbot.append(["正在构建知识图谱...", "处理中..."]) + # 获取回答 + response = yield from request_gpt_model_in_new_thread_with_ui_alive( + inputs=prompt, + inputs_show_user=query_clip, + llm_kwargs=llm_kwargs, + chatbot=chatbot, + history=history, + sys_prompt=system_prompt + ) + + # 记忆问答对 + worker.remember_qa(query_to_remember, response) + history.extend([user_query, response]) + yield from update_ui(chatbot=chatbot, history=history) - - # 处理论文内容 - rag_handler.process_paper_content(paper_content) - - # 生成初始摘要 - summary = rag_handler.query("请总结这篇论文的主要内容,包括研究目的、方法、结果和结论。") - chatbot.append(["论文摘要", summary]) - yield from update_ui(chatbot=chatbot, history=history) - - # 交互式问答 + +if __name__ == "__main__": + # 测试代码 + llm_kwargs = { + 'api_key': os.getenv("one_api_key"), + 'client_ip': '127.0.0.1', + 'embed_model': 'text-embedding-3-small', + 'llm_model': 'one-api-Qwen2.5-72B-Instruct', + 'max_length': 4096, + 'most_recent_uploaded': None, + 'temperature': 1, + 'top_p': 1 + } + plugin_kwargs = {} + chatbot = [] + history = [] + system_prompt = "You are a helpful assistant." + web_port = "8080" + + # 测试论文导入 + arxiv_url = "https://arxiv.org/abs/2312.12345" + for response in Arxiv论文对话( + arxiv_url, llm_kwargs, plugin_kwargs, + chatbot, history, system_prompt, web_port + ): + print(response) + + # 测试问答 + question = "这篇论文的主要贡献是什么?" 
+    for response in Arxiv论文对话(
+        question, llm_kwargs, plugin_kwargs,
+        chatbot, history, system_prompt, web_port
+    ):
+        print(response)
\ No newline at end of file

diff --git a/crazy_functions/rag_fns/llama_index_worker.py b/crazy_functions/rag_fns/llama_index_worker.py
index 15909dc0..50a23b03 100644
--- a/crazy_functions/rag_fns/llama_index_worker.py
+++ b/crazy_functions/rag_fns/llama_index_worker.py
@@ -1,15 +1,18 @@
+import llama_index
+import os
 import atexit
-from typing import List, Dict, Optional, Any, Tuple
-
-from llama_index.core import Document
-from llama_index.core.ingestion import run_transformations
-from llama_index.core.schema import TextNode, NodeWithScore
 from loguru import logger
-
-from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from typing import List
+from llama_index.core import Document
+from llama_index.core.schema import TextNode
 from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
-import json
-import numpy as np
+from shared_utils.connect_void_terminal import get_chat_default_kwargs
+from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from llama_index.core.ingestion import run_transformations
+from llama_index.core import PromptTemplate
+from llama_index.core.response_synthesizers import TreeSummarize

 DEFAULT_QUERY_GENERATION_PROMPT = """\
 Now, you have context information as below:
 ---------------------
@@ -60,7 +63,7 @@ class SaveLoad():
     def purge(self):
         import shutil
         shutil.rmtree(self.checkpoint_dir, ignore_errors=True)
-        self.vs_index = self.create_new_vs(self.checkpoint_dir)
+        self.vs_index = self.create_new_vs()


 class LlamaIndexRagWorker(SaveLoad):
@@ -69,11 +72,61 @@ class LlamaIndexRagWorker(SaveLoad):
         self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
         self.user_name = user_name
         self.checkpoint_dir = checkpoint_dir
-        if auto_load_checkpoint:
-            self.vs_index = self.load_from_checkpoint(checkpoint_dir)
+
+        # 确保checkpoint_dir存在
+        if checkpoint_dir:
+            os.makedirs(checkpoint_dir, exist_ok=True)
+
+        logger.info(f"Initializing LlamaIndexRagWorker with checkpoint_dir: {checkpoint_dir}")
+
+        # 初始化向量存储
+        if auto_load_checkpoint and self.does_checkpoint_exist():
+            logger.info("Loading existing vector store from checkpoint")
+            self.vs_index = self.load_from_checkpoint()
         else:
+            logger.info("Creating new vector store")
             self.vs_index = self.create_new_vs()
-        atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
+
+        # 注册退出时保存
+        atexit.register(self.save_to_checkpoint)
+
+    def add_text_to_vector_store(self, text: str) -> None:
+        """添加文本到向量存储"""
+        try:
+            logger.info(f"Adding text to vector store (first 100 chars): {text[:100]}...")
+            node = TextNode(text=text)
+            nodes = run_transformations(
+                [node],
+                self.vs_index._transformations,
+                show_progress=True
+            )
+            self.vs_index.insert_nodes(nodes)
+
+            # 立即保存
+            self.save_to_checkpoint()
+
+            if self.debug_mode:
+                self.inspect_vector_store()
+
+        except Exception as e:
+            logger.error(f"Error adding text to vector store: {str(e)}")
+            raise
+
+    def save_to_checkpoint(self, checkpoint_dir=None):
+        """保存向量存储到检查点"""
+        try:
+            if checkpoint_dir is None:
+                checkpoint_dir = self.checkpoint_dir
+            logger.info(f'Saving vector store to: {checkpoint_dir}')
+            if checkpoint_dir:
+                self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
+                logger.info('Vector store saved successfully')
+            else:
+                logger.warning('No checkpoint directory specified, skipping save')
+        except 
Exception as e: + logger.error(f"Error saving checkpoint: {str(e)}") + raise + def assign_embedding_model(self): pass @@ -82,44 +135,28 @@ class LlamaIndexRagWorker(SaveLoad): # This function is for debugging self.vs_index.storage_context.index_store.to_dict() docstore = self.vs_index.storage_context.docstore.docs - vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ]) + vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()]) logger.info('\n++ --------inspect_vector_store begin--------') logger.info(vector_store_preview) logger.info('oo --------inspect_vector_store end--------') return vector_store_preview - def add_documents_to_vector_store(self, document_list: List[Document]): - """ - Adds a list of Document objects to the vector store after processing. - """ - documents = document_list + def add_documents_to_vector_store(self, document_list): + documents = [Document(text=t) for t in document_list] documents_nodes = run_transformations( documents, # type: ignore self.vs_index._transformations, show_progress=True ) self.vs_index.insert_nodes(documents_nodes) - if self.debug_mode: - self.inspect_vector_store() - - def add_text_to_vector_store(self, text: str): - node = TextNode(text=text) - documents_nodes = run_transformations( - [node], - self.vs_index._transformations, - show_progress=True - ) - self.vs_index.insert_nodes(documents_nodes) - if self.debug_mode: - self.inspect_vector_store() + if self.debug_mode: self.inspect_vector_store() def remember_qa(self, question, answer): formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer) self.add_text_to_vector_store(formatted_str) def retrieve_from_store_with_query(self, query): - if self.debug_mode: - self.inspect_vector_store() + if self.debug_mode: self.inspect_vector_store() retriever = self.vs_index.as_retriever() return retriever.retrieve(query) @@ -128,230 +165,6 @@ class LlamaIndexRagWorker(SaveLoad): return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query) def generate_node_array_preview(self, nodes): - buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)])) + buf = "\n".join(([f"(No.{i + 1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)])) if self.debug_mode: logger.info(buf) return buf - - def purge_vector_store(self): - """ - Purges the current vector store and creates a new one. 
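# --- Editor's note (annotation, not part of the commit) ---------------------
# `save_to_checkpoint` above persists via
# `self.vs_index.storage_context.persist(persist_dir=...)`, but its partners
# `does_checkpoint_exist` / `load_from_checkpoint` are not shown in this
# patch. A minimal sketch using stock llama_index APIs, assuming the standard
# persist layout (docstore.json etc.); forwarding embed_model this way is an
# assumption, not something this diff confirms:

import os
from llama_index.core import StorageContext, load_index_from_storage

def does_checkpoint_exist(self, checkpoint_dir=None):
    checkpoint_dir = checkpoint_dir or self.checkpoint_dir
    # persist() writes docstore.json into the checkpoint directory
    return bool(checkpoint_dir) and os.path.exists(
        os.path.join(checkpoint_dir, "docstore.json"))

def load_from_checkpoint(self, checkpoint_dir=None):
    checkpoint_dir = checkpoint_dir or self.checkpoint_dir
    storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
    # extra kwargs (e.g. embed_model) are forwarded to the index constructor
    return load_index_from_storage(storage_context, embed_model=self.embed_model)
# -----------------------------------------------------------------------------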
- """ - self.purge() - - - - """ - 以下是添加的新方法,原有方法保持不变 - """ - - def add_text_with_metadata(self, text: str, metadata: dict) -> str: - """ - 添加带元数据的文本到向量存储 - - Args: - text: 文本内容 - metadata: 元数据字典 - - Returns: - 添加的节点ID - """ - node = TextNode(text=text, metadata=metadata) - nodes = run_transformations( - [node], - self.vs_index._transformations, - show_progress=True - ) - self.vs_index.insert_nodes(nodes) - return nodes[0].node_id if nodes else None - - def batch_add_texts_with_metadata(self, texts: List[Tuple[str, dict]]) -> List[str]: - """ - 批量添加带元数据的文本 - - Args: - texts: (text, metadata)元组列表 - - Returns: - 添加的节点ID列表 - """ - nodes = [TextNode(text=t, metadata=m) for t, m in texts] - transformed_nodes = run_transformations( - nodes, - self.vs_index._transformations, - show_progress=True - ) - if transformed_nodes: - self.vs_index.insert_nodes(transformed_nodes) - return [node.node_id for node in transformed_nodes] - return [] - - def get_node_metadata(self, node_id: str) -> Optional[dict]: - """ - 获取节点的元数据 - - Args: - node_id: 节点ID - - Returns: - 节点的元数据字典 - """ - node = self.vs_index.storage_context.docstore.docs.get(node_id) - return node.metadata if node else None - - def update_node_metadata(self, node_id: str, metadata: dict, merge: bool = True) -> bool: - """ - 更新节点的元数据 - - Args: - node_id: 节点ID - metadata: 新的元数据 - merge: 是否与现有元数据合并 - - Returns: - 是否更新成功 - """ - docstore = self.vs_index.storage_context.docstore - if node_id in docstore.docs: - node = docstore.docs[node_id] - if merge: - node.metadata.update(metadata) - else: - node.metadata = metadata - return True - return False - - def filter_nodes_by_metadata(self, filters: Dict[str, Any]) -> List[TextNode]: - """ - 按元数据过滤节点 - - Args: - filters: 元数据过滤条件 - - Returns: - 符合条件的节点列表 - """ - docstore = self.vs_index.storage_context.docstore - results = [] - for node in docstore.docs.values(): - if all(node.metadata.get(k) == v for k, v in filters.items()): - results.append(node) - return results - - def retrieve_with_metadata_filter( - self, - query: str, - metadata_filters: Dict[str, Any], - top_k: int = 5 - ) -> List[NodeWithScore]: - """ - 结合元数据过滤的检索 - - Args: - query: 查询文本 - metadata_filters: 元数据过滤条件 - top_k: 返回结果数量 - - Returns: - 检索结果节点列表 - """ - retriever = self.vs_index.as_retriever(similarity_top_k=top_k) - nodes = retriever.retrieve(query) - - # 应用元数据过滤 - filtered_nodes = [] - for node in nodes: - if all(node.metadata.get(k) == v for k, v in metadata_filters.items()): - filtered_nodes.append(node) - - return filtered_nodes - - def get_node_stats(self, node_id: str) -> dict: - """ - 获取单个节点的统计信息 - - Args: - node_id: 节点ID - - Returns: - 节点统计信息字典 - """ - node = self.vs_index.storage_context.docstore.docs.get(node_id) - if not node: - return {} - - return { - "text_length": len(node.text), - "token_count": len(node.text.split()), - "has_embedding": node.embedding is not None, - "metadata_keys": list(node.metadata.keys()), - } - - def get_nodes_by_content_pattern(self, pattern: str) -> List[TextNode]: - """ - 按内容模式查找节点 - - Args: - pattern: 正则表达式模式 - - Returns: - 匹配的节点列表 - """ - import re - docstore = self.vs_index.storage_context.docstore - matched_nodes = [] - for node in docstore.docs.values(): - if re.search(pattern, node.text): - matched_nodes.append(node) - return matched_nodes - def export_nodes( - self, - output_file: str, - format: str = "json", - include_embeddings: bool = False - ) -> None: - """ - Export nodes to file - - Args: - output_file: Output file path - format: "json" or "csv" - include_embeddings: Whether to 
include embeddings - """ - docstore = self.vs_index.storage_context.docstore - - data = [] - for node_id, node in docstore.docs.items(): - node_data = { - "node_id": node_id, - "text": node.text, - "metadata": node.metadata, - } - if include_embeddings and node.embedding is not None: - node_data["embedding"] = node.embedding.tolist() - data.append(node_data) - - if format == "json": - with open(output_file, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=2) - - elif format == "csv": - import csv - import pandas as pd - - df = pd.DataFrame(data) - df.to_csv(output_file, index=False, quoting=csv.QUOTE_NONNUMERIC) - - else: - raise ValueError(f"Unsupported format: {format}") - - def get_statistics(self) -> Dict[str, Any]: - """Get vector store statistics""" - docstore = self.vs_index.storage_context.docstore - docs = list(docstore.docs.values()) - - return { - "total_nodes": len(docs), - "total_tokens": sum(len(node.text.split()) for node in docs), - "avg_text_length": np.mean([len(node.text) for node in docs]) if docs else 0, - "embedding_dimension": len(docs[0].embedding) if docs and docs[0].embedding is not None else 0 - } \ No newline at end of file
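
Editor's note (annotation, not part of the commit): the read path this patch
introduces is retrieve -> build_prompt -> LLM. A minimal round-trip sketch
against the classes added above; the checkpoint path and query are
illustrative, and llm_kwargs has the same shape as the __main__ example in
Arxiv_论文对话.py:

    from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker

    worker = LlamaIndexRagWorker(
        user_name="test_user",
        llm_kwargs=llm_kwargs,
        checkpoint_dir="gpt_log/test_user/rag_cache/vector_store",
        auto_load_checkpoint=True,
    )
    worker.add_text_to_vector_store("Paper Title: ...\nContent: ...")
    query = "What method does the paper use?"
    nodes = worker.retrieve_from_store_with_query(query)
    prompt = worker.build_prompt(query=query, nodes=nodes)
    # `prompt` now wraps the retrieved fragments in
    # DEFAULT_QUERY_GENERATION_PROMPT and can be sent to the chat model via
    # request_gpt_model_in_new_thread_with_ui_alive, as Arxiv_论文对话.py does.

One design observation: `add_text_to_vector_store` now calls
`save_to_checkpoint` on every insert, which is durable but makes bulk
ingestion perform a full persist per fragment; the per-fragment thread-pool
writes in `_process_fragments` will each trigger such a save and may contend
on the same checkpoint files.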