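"""End-to-end example of the entity/relation extraction RAG pipeline.

Wires together the prompt templates, the EntityRelationExtractor, and the
KV / vector / graph storage backends; processes a sample document, runs a
query against the result, and exports knowledge-base statistics.
"""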
from typing import Dict, List, Tuple

import asyncio
import json
import os
from datetime import datetime
from pprint import pprint

from loguru import logger
from openai import AsyncOpenAI

from core.prompt_templates import PromptTemplates
from core.extractor import EntityRelationExtractor, PromptInfo
from core.storage import JsonKVStorage, VectorStorage, NetworkStorage
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel

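# Credentials come from the environment: API_KEY and API_URL are assumed to
# point at an OpenAI-compatible chat endpoint; the embedding model separately
# reads one_api_key (see llm_kwargs below).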
# Use the async client so the concurrent calls issued via asyncio.gather()
# in process_chunk_batch() do not block the event loop.
client = AsyncOpenAI(api_key=os.getenv("API_KEY"), base_url=os.getenv("API_URL"))


class ExtractionExample:
    """Example class demonstrating comprehensive RAG system functionality"""

    def __init__(self):
        """Initialize RAG system components"""
        # Set up the working directory
        self.working_dir = f"private_upload/default_user/rag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        os.makedirs(self.working_dir, exist_ok=True)
        logger.info(f"Working directory: {self.working_dir}")

        # Initialize the embedding model
        self.llm_kwargs = {
            'api_key': os.getenv("one_api_key"),
            'client_ip': '127.0.0.1',
            'embed_model': 'text-embedding-3-small',
            'llm_model': 'one-api-Qwen2.5-72B-Instruct',
            'max_length': 4096,
            'most_recent_uploaded': None,
            'temperature': 1,
            'top_p': 1,
        }
        self.embedding_func = OpenAiEmbeddingModel(self.llm_kwargs)

        # Initialize the prompt templates and the extractor
        self.prompt_templates = PromptTemplates()
        self.extractor = EntityRelationExtractor(
            prompt_templates=self.prompt_templates,
            required_prompts={'entity_extraction'},
            entity_extract_max_gleaning=1  # at most one follow-up "gleaning" round per chunk
        )

        # Initialize the storage system
        self._init_storage_system()

        # Per-chunk conversation history
        self.conversation_history = {}

    def _init_storage_system(self):
        """Initialize storage components"""
        # KV stores for the raw documents and their chunks
        self.text_chunks = JsonKVStorage[dict](
            namespace="text_chunks",
            working_dir=self.working_dir
        )

        self.full_docs = JsonKVStorage[dict](
            namespace="full_docs",
            working_dir=self.working_dir
        )

        # Vector store for similarity search
        self.vector_store = VectorStorage(
            namespace="vectors",
            working_dir=self.working_dir,
            llm_kwargs=self.llm_kwargs,
            embedding_func=self.embedding_func,
            meta_fields={"entity_name", "entity_type"}
        )

        # Graph store for entities and relationships
        self.graph_store = NetworkStorage(
            namespace="graph",
            working_dir=self.working_dir
        )

    async def simulate_llm_call(self, prompt: str, prompt_info: PromptInfo) -> str:
        """Run one LLM call, carrying the per-chunk conversation history"""
        # Fetch the conversation history for the current chunk
        chunk_history = self.conversation_history.get(prompt_info.chunk_key, [])

        messages = [
            {"role": "system",
             "content": "You are a helpful assistant specialized in entity and relationship extraction."}
        ]

        # Replay the history, then append the current prompt
        messages.extend(chunk_history)
        messages.append({"role": "user", "content": prompt})

        try:
            # Call the LLM (awaited, so gathered calls run concurrently)
            response = await client.chat.completions.create(
                model="deepseek-chat",
                messages=messages,
                stream=False
            )

            response_content = response.choices[0].message.content

            # Record the exchange in the chunk's history
            chunk_history.extend([
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response_content}
            ])
            self.conversation_history[prompt_info.chunk_key] = chunk_history

            logger.info(f"\nProcessing chunk: {prompt_info.chunk_key}")
            logger.info(f"Phase: {prompt_info.prompt_type}")
            logger.info(f"Response: {response_content[:200]}...")

            return response_content

        except Exception as e:
            logger.error(f"Error in LLM call: {e}")
            raise

    async def process_document(self, content: str) -> Tuple[Dict, Dict]:
        """Process a single document through the RAG pipeline"""
        doc_id = f"doc_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        # Store the raw document
        await self.full_docs.upsert({
            doc_id: {"content": content}
        })

        # Split the document into chunks
        from core.chunking import chunk_document
        chunks = chunk_document(content)
        chunk_dict = {
            f"{doc_id}_chunk_{i}": {"content": chunk, "doc_id": doc_id}
            for i, chunk in enumerate(chunks)
        }

        # Store the chunks
        await self.text_chunks.upsert(chunk_dict)

        # Process the chunks and extract entities and relationships
        nodes, edges = await self.process_chunk_batch(chunk_dict)

        return nodes, edges

    async def process_chunk_batch(self, chunks: Dict[str, dict]):
        """Process text chunks and store results"""
        try:
            # Add the chunks to the vector store
            logger.info("Adding chunks to vector store...")
            await self.vector_store.upsert(chunks)

            # Reset the per-chunk conversation history
            self.conversation_history = {chunk_key: [] for chunk_key in chunks.keys()}

            # Extract entities and relationships
            logger.info("Extracting entities and relationships...")
            prompts = self.extractor.initialize_extraction(chunks)

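            # Extraction is iterative: each round sends all pending prompts
            # concurrently, and process_response() may emit follow-up
            # ("gleaning") prompts for a chunk, up to entity_extract_max_gleaning
            # extra rounds, until none remain.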
            while prompts:
                responses = await asyncio.gather(
                    *[self.simulate_llm_call(p.prompt, p) for p in prompts]
                )

                # Feed each response back to the extractor and collect any
                # follow-up prompts it produces
                next_prompts = []
                for response, prompt_info in zip(responses, prompts):
                    next_batch = self.extractor.process_response(response, prompt_info)
                    next_prompts.extend(next_batch)

                prompts = next_prompts

            # Collect the accumulated extraction results
            nodes, edges = self.extractor.get_results()

            # Persist them to the graph store
            logger.info("Storing extracted information in graph database...")
            for node_name, node_instances in nodes.items():
                for node in node_instances:
                    await self.graph_store.upsert_node(node_name, node)

            for (src, tgt), edge_instances in edges.items():
                for edge in edge_instances:
                    await self.graph_store.upsert_edge(src, tgt, edge)

            return nodes, edges

        except Exception as e:
            logger.error(f"Error in processing chunks: {e}")
            raise

    async def query_knowledge_base(self, query: str, top_k: int = 5):
        """Query the knowledge base using various methods"""
        try:
            # Vector similarity search
            vector_results = await self.vector_store.query(query, top_k=top_k)

            # Fetch the matching text chunks
            chunk_ids = [r["id"] for r in vector_results]
            chunks = await self.text_chunks.get_by_ids(chunk_ids)

            # Fetch related entities. This is a naive heuristic: it assumes
            # the query contains entity names and that node names are stored
            # in upper case.
            relevant_nodes = []
            for word in query.split():
                if await self.graph_store.has_node(word.upper()):
                    node_data = await self.graph_store.get_node(word.upper())
                    if node_data:
                        relevant_nodes.append(node_data)

            return {
                "vector_results": vector_results,
                "text_chunks": chunks,
                "relevant_entities": relevant_nodes
            }

        except Exception as e:
            logger.error(f"Error in querying knowledge base: {e}")
            raise

    def export_knowledge_base(self, export_dir: str):
        """Export the entire knowledge base"""
        os.makedirs(export_dir, exist_ok=True)

        try:
            # Export the vector store
            self.vector_store.vector_store.export_nodes(
                os.path.join(export_dir, "vector_nodes.json"),
                include_embeddings=True
            )

            # Export graph statistics
            graph_stats = {
                "total_nodes": len(list(self.graph_store._graph.nodes())),
                "total_edges": len(list(self.graph_store._graph.edges())),
                "node_degrees": dict(self.graph_store._graph.degree()),
                "largest_component_size": len(self.graph_store.get_largest_connected_component())
            }

            with open(os.path.join(export_dir, "graph_stats.json"), "w") as f:
                json.dump(graph_stats, f, indent=2)

            # Export storage statistics
            storage_stats = {
                "chunks": len(self.text_chunks._data),
                "docs": len(self.full_docs._data),
                "vector_store": self.vector_store.vector_store.get_statistics()
            }

            with open(os.path.join(export_dir, "storage_stats.json"), "w") as f:
                json.dump(storage_stats, f, indent=2)

        except Exception as e:
            logger.error(f"Error in exporting knowledge base: {e}")
            raise

    def print_extraction_results(self, nodes: Dict[str, List[dict]], edges: Dict[tuple, List[dict]]):
        """Print extraction results and statistics"""
        print("\nExtracted Entities:")
        print("-" * 50)
        for entity_name, entity_instances in nodes.items():
            print(f"\nEntity: {entity_name}")
            for inst in entity_instances:
                pprint(inst, indent=2)

        print("\nExtracted Relationships:")
        print("-" * 50)
        for (src, tgt), rel_instances in edges.items():
            print(f"\nRelationship: {src} -> {tgt}")
            for inst in rel_instances:
                pprint(inst, indent=2)

        print("\nStorage Statistics:")
        print("-" * 50)
        print(f"Working Directory: {self.working_dir}")
        print(f"Number of Documents: {len(self.full_docs._data)}")
        print(f"Number of Chunks: {len(self.text_chunks._data)}")
        print(f"Conversation Turns: {sum(len(h) // 2 for h in self.conversation_history.values())}")

        # Print graph statistics
        print("\nGraph Statistics:")
        print("-" * 50)
        print(f"Total Nodes: {len(list(self.graph_store._graph.nodes()))}")
        print(f"Total Edges: {len(list(self.graph_store._graph.edges()))}")


async def main():
    """Run comprehensive RAG example"""
    # Test documents
    documents = {
        "tech_news": """
        Apple Inc. announced new iPhone models today in Cupertino.
        Tim Cook, the CEO, presented the keynote. The presentation highlighted
        the company's commitment to innovation and sustainability. The new iPhone
        features groundbreaking AI capabilities.
        """,

        # "business_news": """
        # Microsoft and OpenAI expanded their partnership today.
        # Satya Nadella emphasized the importance of AI development while
        # Sam Altman discussed the future of large language models. The collaboration
        # aims to accelerate AI research and deployment.
        # """,
        #
        # "science_paper": """
        # Researchers at DeepMind published a breakthrough paper on quantum computing.
        # The team demonstrated novel approaches to quantum error correction.
        # Dr. Sarah Johnson led the research, collaborating with Google's quantum lab.
        # """
    }

    try:
        # Create the RAG system instance
        example = ExtractionExample()

        # Process the documents, accumulating nodes and edges
        all_nodes = {}
        all_edges = {}

        for doc_name, content in documents.items():
            logger.info(f"\nProcessing document: {doc_name}")
            nodes, edges = await example.process_document(content)
            all_nodes.update(nodes)
            all_edges.update(edges)

        # Print the results
        example.print_extraction_results(all_nodes, all_edges)

        # Test a query
        query = "What are the latest developments in AI?"
        logger.info(f"\nTesting query: {query}")
        results = await example.query_knowledge_base(query)

        print("\nQuery Results:")
        print("-" * 50)
        pprint(results)

        # Export the knowledge base
        export_dir = os.path.join(example.working_dir, "export")
        print("\nExporting knowledge base...")
        logger.info(f"\nExporting knowledge base to: {export_dir}")
        example.export_knowledge_base(export_dir)

    except Exception as e:
        logger.error(f"Error in main: {e}")
        raise


def run_example():
    """Run the example"""
    asyncio.run(main())


if __name__ == "__main__":
    run_example()