import os
from typing import Dict, List, Optional

import numpy as np
import openai
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc

from toolbox import get_conf
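
# RagHandler wires LightRAG to OpenAI models so that arXiv paper content can be
# indexed into a knowledge graph and queried. It targets the legacy openai<1.0
# SDK (openai.ChatCompletion.acreate / openai.Embedding.acreate) and reads its
# cache location from the project's `toolbox.get_conf` helper.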
|
|
class RagHandler:

    def __init__(self):
        # Initialize the working directory for the RAG cache
        self.working_dir = os.path.join(get_conf('ARXIV_CACHE_DIR'), 'rag_cache')
        os.makedirs(self.working_dir, exist_ok=True)

        # Initialize LightRAG with the LLM and embedding callbacks defined below
        self.rag = LightRAG(
            working_dir=self.working_dir,
            llm_model_func=self._llm_model_func,
            embedding_func=EmbeddingFunc(
                embedding_dim=1536,  # dimensionality of OpenAI text-embedding-ada-002 vectors
                max_token_size=8192,
                func=self._embedding_func,
            ),
        )
|
|
    async def _llm_model_func(self, prompt: str, system_prompt: Optional[str] = None,
                              history_messages: Optional[List] = None, **kwargs) -> str:
        """LLM completion function used by LightRAG for generation."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if history_messages:
            messages.extend(history_messages)
        messages.append({"role": "user", "content": prompt})

        response = await openai.ChatCompletion.acreate(
            model="gpt-3.5-turbo",
            messages=messages,
            temperature=kwargs.get("temperature", 0),
            max_tokens=kwargs.get("max_tokens", 1000),
        )
        return response.choices[0].message.content
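
    # Note: ChatCompletion.acreate is the openai<1.0 async API. On openai>=1.0
    # the equivalent call (a sketch, not part of the original module) would be:
    #     from openai import AsyncOpenAI
    #     client = AsyncOpenAI()
    #     response = await client.chat.completions.create(
    #         model="gpt-3.5-turbo", messages=messages,
    #         temperature=kwargs.get("temperature", 0),
    #         max_tokens=kwargs.get("max_tokens", 1000))
    #     return response.choices[0].message.content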
|
|
    async def _embedding_func(self, texts: List[str]) -> np.ndarray:
        """Embedding function used by LightRAG for indexing and retrieval."""
        response = await openai.Embedding.acreate(
            model="text-embedding-ada-002",
            input=texts,
        )
        embeddings = [item["embedding"] for item in response["data"]]
        return np.array(embeddings)
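
    # Note: Embedding.acreate is likewise the openai<1.0 async API. On
    # openai>=1.0 the equivalent (a sketch) would be:
    #     client = AsyncOpenAI()
    #     response = await client.embeddings.create(
    #         model="text-embedding-ada-002", input=texts)
    #     return np.array([item.embedding for item in response.data])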
|
|
    def process_paper_content(self, paper_content: Dict) -> None:
        """Process paper content and build the knowledge graph."""
        # Collect the title and abstract
        content_list = []
        if paper_content['title']:
            content_list.append(f"Title: {paper_content['title']}")
        if paper_content['abstract']:
            content_list.append(f"Abstract: {paper_content['abstract']}")

        # Append the segmented body text
        content_list.extend(paper_content['segments'])

        # Insert everything into the RAG store
        self.rag.insert(content_list)
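
    # Note: LightRAG's insert() accepts a single string or a list of strings;
    # it chunks the text and runs LLM-based entity/relation extraction to build
    # the knowledge graph that the local/global query modes rely on.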
|
|
    def query(self, question: str, mode: str = "hybrid") -> str:
        """Query the paper content.

        mode: query mode, one of naive/local/global/hybrid.
        """
        try:
            response = self.rag.query(
                question,
                param=QueryParam(
                    mode=mode,
                    top_k=5,  # return the 5 most relevant results
                    max_token_for_text_unit=2048,  # max tokens per text unit
                    response_type="detailed",  # ask for a detailed answer
                ),
            )
            return response
        except Exception as e:
            return f"Query failed: {str(e)}"