Files
gpt_academic/crazy_functions/rag_essay_fns/rag_handler.py
lbykkkk 68aa846a89 up
2024-11-10 15:06:50 +08:00

164 lines
5.9 KiB
Python

from typing import Dict, List, Optional
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
import numpy as np
import os
from toolbox import get_conf
import openai
# NOTE(review): a verbatim duplicate of `class RagHandler` used to sit here.
# It was dead code -- the later `class RagHandler` definition in this module
# rebinds the name at import time -- so only that one definition is kept.
class RagHandler:
    """Index a paper's text into LightRAG and answer questions about it.

    LLM completions and embeddings are served by OpenAI models; LightRAG's
    working directory is ``<ARXIV_CACHE_DIR>/rag_cache``.
    """

    # Model settings; override in a subclass to use different OpenAI models.
    # EMBEDDING_DIM must match the vector size produced by EMBEDDING_MODEL
    # (1536 for text-embedding-ada-002).
    LLM_MODEL = "gpt-3.5-turbo"
    EMBEDDING_MODEL = "text-embedding-ada-002"
    EMBEDDING_DIM = 1536
    EMBEDDING_MAX_TOKENS = 8192

    def __init__(self):
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        self.working_dir = os.path.join(get_conf('ARXIV_CACHE_DIR'), 'rag_cache')
        os.makedirs(self.working_dir, exist_ok=True)
        self.rag = LightRAG(
            working_dir=self.working_dir,
            llm_model_func=self._llm_model_func,
            embedding_func=EmbeddingFunc(
                embedding_dim=self.EMBEDDING_DIM,
                max_token_size=self.EMBEDDING_MAX_TOKENS,
                func=self._embedding_func,
            ),
        )

    @staticmethod
    def _new_async_client():
        """Build an ``openai.AsyncOpenAI`` client (openai>=1.0 only).

        Module-level ``openai.api_key`` / ``openai.base_url`` are honoured;
        when they are None the client constructor falls back to the
        OPENAI_API_KEY / OPENAI_BASE_URL environment variables.
        """
        return openai.AsyncOpenAI(
            api_key=getattr(openai, "api_key", None),
            base_url=getattr(openai, "base_url", None),
        )

    async def _llm_model_func(self, prompt: str, system_prompt: Optional[str] = None,
                              history_messages: Optional[List] = None, **kwargs) -> str:
        """Chat-completion callback handed to LightRAG.

        Sends [system prompt (if any)] + history + the user prompt and
        returns the reply text. Only ``temperature`` (default 0) and
        ``max_tokens`` (default 1000) are read from ``kwargs``; other
        keyword arguments are ignored.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if history_messages:
            messages.extend(history_messages)
        messages.append({"role": "user", "content": prompt})
        request = dict(
            model=self.LLM_MODEL,
            messages=messages,
            temperature=kwargs.get("temperature", 0),
            max_tokens=kwargs.get("max_tokens", 1000),
        )
        if hasattr(openai, "AsyncOpenAI"):
            # openai>=1.0 removed ChatCompletion.acreate (it raises APIRemovedInV1).
            # The context manager closes the client's HTTP connection pool.
            async with self._new_async_client() as client:
                response = await client.chat.completions.create(**request)
        else:
            # Legacy openai<1.0 module-level API.
            response = await openai.ChatCompletion.acreate(**request)
        return response.choices[0].message.content

    async def _embedding_func(self, texts: List[str]) -> np.ndarray:
        """Embedding callback: stack the returned vectors into a 2-D array."""
        if hasattr(openai, "AsyncOpenAI"):
            # v1 responses are typed objects, not dicts: use attribute access.
            async with self._new_async_client() as client:
                response = await client.embeddings.create(
                    model=self.EMBEDDING_MODEL, input=texts
                )
            vectors = [item.embedding for item in response.data]
        else:
            response = await openai.Embedding.acreate(
                model=self.EMBEDDING_MODEL, input=texts
            )
            vectors = [item["embedding"] for item in response["data"]]
        return np.array(vectors)

    def process_paper_content(self, paper_content: Dict) -> None:
        """Insert a paper's title, abstract and body segments into LightRAG.

        Args:
            paper_content: mapping with optional ``'title'`` and
                ``'abstract'`` strings and an optional ``'segments'`` list of
                text chunks. Missing, empty or whitespace-only entries are
                skipped; if nothing remains, LightRAG is not called.
        """
        content_list = []
        title = paper_content.get('title')
        if title:
            content_list.append(f"Title: {title}")
        abstract = paper_content.get('abstract')
        if abstract:
            content_list.append(f"Abstract: {abstract}")
        segments = paper_content.get('segments') or []
        # Blank segments carry no content worth indexing.
        content_list.extend(seg for seg in segments if seg and seg.strip())
        if not content_list:
            return
        self.rag.insert(content_list)

    def query(self, question: str, mode: str = "hybrid", top_k: int = 5) -> str:
        """Answer ``question`` from the indexed paper content.

        Args:
            question: natural-language question.
            mode: LightRAG retrieval mode: naive / local / global / hybrid.
            top_k: how many top-ranked results LightRAG retrieves (default 5).

        Returns:
            LightRAG's answer, or an error message prefixed with "查询出错: "
            ("query failed") if anything raises.
        """
        try:
            return self.rag.query(
                question,
                param=QueryParam(
                    mode=mode,
                    top_k=top_k,
                    max_token_for_text_unit=2048,  # token budget per retrieved text unit
                    response_type="detailed",  # requested answer style
                ),
            )
        except Exception as e:
            # Deliberate catch-all: the failure is returned as text, not raised.
            return f"查询出错: {str(e)}"