lbykkkk
2024-11-16 00:35:31 +08:00
parent dd902e9519
commit 21626a44d5
12 changed files with 2385 additions and 1169 deletions

View File

@@ -1,8 +1,8 @@
 import os.path
 from toolbox import CatchException, update_ui
-from crazy_functions.rag_essay_fns.paper_processing import ArxivPaperProcessor
+from crazy_functions.rag_fns.arxiv_fns.paper_processing import ArxivPaperProcessor
+import asyncio

 @CatchException
 def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
@@ -12,14 +12,14 @@ def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
     """
     if_project, if_arxiv = False, False
     if os.path.exists(txt):
-        from crazy_functions.rag_essay_fns.document_splitter import SmartDocumentSplitter
+        from crazy_functions.rag_fns.doc_fns.document_splitter import SmartDocumentSplitter
         splitter = SmartDocumentSplitter(
             char_range=(1000, 1200),
             max_workers=32  # optional; defaults to a value based on the CPU core count
         )
         if_project = True
     else:
-        from crazy_functions.rag_essay_fns.arxiv_splitter import SmartArxivSplitter
+        from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import SmartArxivSplitter
         splitter = SmartArxivSplitter(
             char_range=(1000, 1200),
             root_dir="gpt_log/arxiv_cache"
@@ -61,23 +61,3 @@ def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
     yield from update_ui(chatbot=chatbot, history=history)
-    # Interactive Q&A
-    chatbot.append(["知识图谱构建完成", "您可以开始提问了。支持以下类型的问题:\n1. 论文的具体内容\n2. 研究方法的细节\n3. 实验结果分析\n4. 与其他工作的比较"])
-    yield from update_ui(chatbot=chatbot, history=history)
-    # Wait for the user's question and answer it
-    while True:
-        question = yield from wait_user_input()
-        if not question:
-            break
-        # Choose a query mode based on the question type
-        if "比较" in question or "关系" in question:
-            mode = "global"  # global mode for comparison-style questions
-        elif "具体" in question or "细节" in question:
-            mode = "local"  # local mode for detail questions
-        else:
-            mode = "hybrid"  # hybrid mode by default
-        response = rag_handler.query(question, mode=mode)
-        chatbot.append([question, response])
-        yield from update_ui(chatbot=chatbot, history=history)

View File

@@ -1,510 +0,0 @@
import os
import re
import requests
import tarfile
import logging
from dataclasses import dataclass
from typing import Generator, List, Tuple, Optional, Dict, Set
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
@dataclass
class ArxivFragment:
"""Arxiv论文片段数据类"""
file_path: str
content: str
segment_index: int
total_segments: int
rel_path: str
segment_type: str
title: str
abstract: str
section: str
is_appendix: bool
class SmartArxivSplitter:
def __init__(self,
char_range: Tuple[int, int],
root_dir: str = "gpt_log/arxiv_cache",
proxies: Optional[Dict[str, str]] = None,
max_workers: int = 4):
self.min_chars, self.max_chars = char_range
self.root_dir = Path(root_dir)
self.root_dir.mkdir(parents=True, exist_ok=True)
self.proxies = proxies or {}
self.max_workers = max_workers
# 定义特殊环境模式
self._init_patterns()
# 配置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s')
def _init_patterns(self):
"""初始化LaTeX环境和命令模式"""
self.special_envs = {
'math': [r'\\begin{(equation|align|gather|eqnarray)\*?}.*?\\end{\1\*?}',
r'\$\$.*?\$\$', r'\$[^$]+\$'],
'table': [r'\\begin{(table|tabular)\*?}.*?\\end{\1\*?}'],
'figure': [r'\\begin{figure\*?}.*?\\end{figure\*?}'],
'algorithm': [r'\\begin{(algorithm|algorithmic)}.*?\\end{\1}']
}
self.section_patterns = [
r'\\(sub)*section\{([^}]+)\}',
r'\\chapter\{([^}]+)\}'
]
self.include_patterns = [
r'\\(input|include|subfile)\{([^}]+)\}'
]
def _find_main_tex_file(self, directory: str) -> Optional[str]:
"""查找主TEX文件"""
tex_files = list(Path(directory).rglob("*.tex"))
if not tex_files:
return None
# 按以下优先级查找:
# 1. 包含documentclass的文件
# 2. 文件名为main.tex
# 3. 最大的tex文件
for tex_file in tex_files:
try:
content = self._read_file(str(tex_file))
if content and r'\documentclass' in content:
return str(tex_file)
if tex_file.name.lower() == 'main.tex':
return str(tex_file)
except Exception:
continue
return str(max(tex_files, key=lambda x: x.stat().st_size))
def _resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]:
"""递归解析tex文件中的include/input命令"""
if processed is None:
processed = set()
if tex_file in processed:
return []
processed.add(tex_file)
result = [tex_file]
content = self._read_file(tex_file)
if not content:
return result
base_dir = Path(tex_file).parent
for pattern in self.include_patterns:
for match in re.finditer(pattern, content):
included_file = match.group(2)
if not included_file.endswith('.tex'):
included_file += '.tex'
# 构建完整路径
full_path = str(base_dir / included_file)
if os.path.exists(full_path) and full_path not in processed:
result.extend(self._resolve_includes(full_path, processed))
return result
def _smart_split(self, content: str) -> List[Tuple[str, str, bool]]:
"""智能分割TEX内容确保在字符范围内并保持语义完整性"""
content = self._preprocess_content(content)
segments = []
current_buffer = []
current_length = 0
current_section = "Unknown Section"
is_appendix = False
# 保护特殊环境
protected_blocks = {}
content = self._protect_special_environments(content, protected_blocks)
# 按段落分割
paragraphs = re.split(r'\n\s*\n', content)
for para in paragraphs:
para = para.strip()
if not para:
continue
# 恢复特殊环境
para = self._restore_special_environments(para, protected_blocks)
# 更新章节信息
section_info = self._get_section_info(para, content)
if section_info:
current_section, is_appendix = section_info
# 判断是否是特殊环境
if self._is_special_environment(para):
# 处理当前缓冲区
if current_buffer:
segments.append((
'\n'.join(current_buffer),
current_section,
is_appendix
))
current_buffer = []
current_length = 0
# 添加特殊环境作为独立片段
segments.append((para, current_section, is_appendix))
continue
# 处理普通段落
sentences = self._split_into_sentences(para)
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
sent_length = len(sentence)
new_length = current_length + sent_length + (1 if current_buffer else 0)
if new_length <= self.max_chars:
current_buffer.append(sentence)
current_length = new_length
else:
# 如果当前缓冲区达到最小长度要求
if current_length >= self.min_chars:
segments.append((
'\n'.join(current_buffer),
current_section,
is_appendix
))
current_buffer = [sentence]
current_length = sent_length
else:
# 尝试将过长的句子分割
split_sentences = self._split_long_sentence(sentence)
for split_sent in split_sentences:
if current_length + len(split_sent) <= self.max_chars:
current_buffer.append(split_sent)
current_length += len(split_sent) + 1
else:
segments.append((
'\n'.join(current_buffer),
current_section,
is_appendix
))
current_buffer = [split_sent]
current_length = len(split_sent)
# 处理剩余的缓冲区
if current_buffer:
segments.append((
'\n'.join(current_buffer),
current_section,
is_appendix
))
return segments
def _split_into_sentences(self, text: str) -> List[str]:
"""将文本分割成句子"""
return re.split(r'(?<=[.!?。!?])\s+', text)
def _split_long_sentence(self, sentence: str) -> List[str]:
"""智能分割过长的句子"""
if len(sentence) <= self.max_chars:
return [sentence]
result = []
while sentence:
# 在最大长度位置寻找合适的分割点
split_pos = self._find_split_position(sentence[:self.max_chars])
if split_pos <= 0:
split_pos = self.max_chars
result.append(sentence[:split_pos])
sentence = sentence[split_pos:].strip()
return result
def _find_split_position(self, text: str) -> int:
"""找到合适的句子分割位置"""
# 优先在标点符号处分割
punctuation_match = re.search(r'[,;]\s*', text[::-1])
if punctuation_match:
return len(text) - punctuation_match.end()
# 其次在空白字符处分割
space_match = re.search(r'\s+', text[::-1])
if space_match:
return len(text) - space_match.end()
return -1
def _protect_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str:
"""保护特殊环境内容"""
for env_type, patterns in self.special_envs.items():
for pattern in patterns:
content = re.sub(
pattern,
lambda m: self._store_protected_block(m.group(0), protected_blocks),
content,
flags=re.DOTALL
)
return content
def _store_protected_block(self, content: str, protected_blocks: Dict[str, str]) -> str:
"""存储受保护的内容块"""
placeholder = f"PROTECTED_{len(protected_blocks)}"
protected_blocks[placeholder] = content
return placeholder
def _restore_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str:
"""恢复特殊环境内容"""
for placeholder, original in protected_blocks.items():
content = content.replace(placeholder, original)
return content
def _is_special_environment(self, text: str) -> bool:
"""判断是否是特殊环境"""
for patterns in self.special_envs.values():
for pattern in patterns:
if re.search(pattern, text, re.DOTALL):
return True
return False
def _preprocess_content(self, content: str) -> str:
"""预处理TEX内容"""
# 移除注释
content = re.sub(r'(?m)%.*$', '', content)
# 规范化空白字符
content = re.sub(r'\s+', ' ', content)
content = re.sub(r'\n\s*\n', '\n\n', content)
# 移除不必要的命令
content = re.sub(r'\\(label|ref|cite)\{[^}]*\}', '', content)
return content.strip()
def process(self, arxiv_id_or_url: str) -> Generator[ArxivFragment, None, None]:
"""处理单篇arxiv论文"""
try:
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
paper_dir = self._download_and_extract(arxiv_id)
# 查找主tex文件
main_tex = self._find_main_tex_file(paper_dir)
if not main_tex:
raise RuntimeError(f"No main tex file found in {paper_dir}")
# 获取所有相关tex文件
tex_files = self._resolve_includes(main_tex)
# 处理所有tex文件
fragments = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_to_file = {
executor.submit(self._process_single_tex, file_path): file_path
for file_path in tex_files
}
for future in as_completed(future_to_file):
try:
fragments.extend(future.result())
except Exception as e:
logging.error(f"Error processing file: {e}")
# 重新计算片段索引
fragments.sort(key=lambda x: (x.rel_path, x.segment_index))
total_fragments = len(fragments)
for i, fragment in enumerate(fragments):
fragment.segment_index = i
fragment.total_segments = total_fragments
yield fragment
except Exception as e:
logging.error(f"Error processing paper {arxiv_id_or_url}: {e}")
raise RuntimeError(f"Failed to process paper: {str(e)}")
def _normalize_arxiv_id(self, input_str: str) -> str:
"""规范化arxiv ID"""
if input_str.startswith('https://arxiv.org/'):
if '/pdf/' in input_str:
return input_str.split('/pdf/')[-1].split('v')[0]
return input_str.split('/abs/')[-1].split('v')[0]
return input_str.split('v')[0]
def _download_and_extract(self, arxiv_id: str) -> str:
"""下载并解压arxiv论文源码"""
paper_dir = self.root_dir / arxiv_id
tar_path = paper_dir / f"{arxiv_id}.tar.gz"
# 检查缓存
if paper_dir.exists() and any(paper_dir.iterdir()):
logging.info(f"Using cached version for {arxiv_id}")
return str(paper_dir)
paper_dir.mkdir(exist_ok=True)
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
for url in urls:
try:
logging.info(f"Downloading from {url}")
response = requests.get(url, proxies=self.proxies)
if response.status_code == 200:
tar_path.write_bytes(response.content)
with tarfile.open(tar_path, 'r:gz') as tar:
tar.extractall(path=paper_dir)
return str(paper_dir)
except Exception as e:
logging.warning(f"Download failed for {url}: {e}")
continue
raise RuntimeError(f"Failed to download paper {arxiv_id}")
def _read_file(self, file_path: str) -> Optional[str]:
"""使用多种编码尝试读取文件"""
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
for encoding in encodings:
try:
with open(file_path, 'r', encoding=encoding) as f:
return f.read()
except UnicodeDecodeError:
continue
logging.warning(f"Failed to read file {file_path} with all encodings")
return None
def _extract_metadata(self, content: str) -> Tuple[str, str]:
"""提取论文标题和摘要"""
title = ""
abstract = ""
# 提取标题
title_patterns = [
r'\\title{([^}]*)}',
r'\\Title{([^}]*)}'
]
for pattern in title_patterns:
match = re.search(pattern, content)
if match:
title = match.group(1)
title = re.sub(r'\\[a-zA-Z]+{([^}]*)}', r'\1', title)
break
# 提取摘要
abstract_patterns = [
r'\\begin{abstract}(.*?)\\end{abstract}',
r'\\abstract{([^}]*)}'
]
for pattern in abstract_patterns:
match = re.search(pattern, content, re.DOTALL)
if match:
abstract = match.group(1).strip()
abstract = re.sub(r'\\[a-zA-Z]+{([^}]*)}', r'\1', abstract)
break
return title.strip(), abstract.strip()
def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, bool]]:
"""获取段落所属的章节信息"""
section = "Unknown Section"
is_appendix = False
# 查找所有章节标记
all_sections = []
for pattern in self.section_patterns:
for match in re.finditer(pattern, content):
all_sections.append((match.start(), match.group(2)))
# 查找appendix标记
appendix_pos = content.find(r'\appendix')
# 确定当前章节
para_pos = content.find(para)
if para_pos >= 0:
current_section = None
for sec_pos, sec_title in sorted(all_sections):
if sec_pos > para_pos:
break
current_section = sec_title
if current_section:
section = current_section
is_appendix = appendix_pos >= 0 and para_pos > appendix_pos
return section, is_appendix
return None
def _process_single_tex(self, file_path: str) -> List[ArxivFragment]:
"""处理单个TEX文件"""
try:
content = self._read_file(file_path)
if not content:
return []
# 提取元数据
is_main = r'\documentclass' in content
title = ""
abstract = ""
if is_main:
title, abstract = self._extract_metadata(content)
# 智能分割内容
segments = self._smart_split(content)
fragments = []
for i, (segment_content, section, is_appendix) in enumerate(segments):
if segment_content.strip():
segment_type = 'text'
for env_type, patterns in self.special_envs.items():
if any(re.search(pattern, segment_content, re.DOTALL)
for pattern in patterns):
segment_type = env_type
break
fragments.append(ArxivFragment(
file_path=file_path,
content=segment_content,
segment_index=i,
total_segments=len(segments),
rel_path=os.path.relpath(file_path, str(self.root_dir)),
segment_type=segment_type,
title=title,
abstract=abstract,
section=section,
is_appendix=is_appendix
))
return fragments
except Exception as e:
logging.error(f"Error processing file {file_path}: {e}")
return []
def main():
"""使用示例"""
# 创建分割器实例
splitter = SmartArxivSplitter(
char_range=(1000, 1200),
root_dir="gpt_log/arxiv_cache"
)
# 处理论文
for fragment in splitter.process("2411.03663"):
print(f"Segment {fragment.segment_index + 1}/{fragment.total_segments}")
print(f"Length: {len(fragment.content)}")
print(f"Section: {fragment.section}")
print(f"File: {fragment.file_path}")
print(fragment.content)
print("-" * 80)
if __name__ == "__main__":
main()

View File

@@ -1,311 +0,0 @@
from typing import Tuple, Optional, Generator, List
from toolbox import update_ui, update_ui_lastest_msg, get_conf
import os, tarfile, requests, time, re
class ArxivPaperProcessor:
"""Arxiv论文处理器类"""
def __init__(self):
self.supported_encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
self.arxiv_cache_dir = get_conf("ARXIV_CACHE_DIR")
def download_and_extract(self, txt: str, chatbot, history) -> Generator[Optional[Tuple[str, str]], None, None]:
"""
Step 1: 下载和提取arxiv论文
返回: 生成器: (project_folder, arxiv_id)
"""
try:
if txt == "":
chatbot.append(("", "请输入arxiv论文链接或ID"))
yield from update_ui(chatbot=chatbot, history=history)
return
project_folder, arxiv_id = self.arxiv_download(txt, chatbot, history)
if project_folder is None or arxiv_id is None:
return
if not os.path.exists(project_folder):
chatbot.append((txt, f"找不到项目文件夹: {project_folder}"))
yield from update_ui(chatbot=chatbot, history=history)
return
# 期望的返回值
yield project_folder, arxiv_id
except Exception as e:
print(e)
# yield from update_ui_lastest_msg(
# "下载失败请手动下载latex源码请前往arxiv打开此论文下载页面点other Formats然后download source。",
# chatbot=chatbot, history=history)
return
def arxiv_download(self, txt: str, chatbot, history) -> Tuple[str, str]:
"""
下载arxiv论文并提取
返回: (project_folder, arxiv_id)
"""
def is_float(s: str) -> bool:
try:
float(s)
return True
except ValueError:
return False
if txt.startswith('https://arxiv.org/pdf/'):
arxiv_id = txt.split('/')[-1] # 2402.14207v2.pdf
txt = arxiv_id.split('v')[0] # 2402.14207
if ('.' in txt) and ('/' not in txt) and is_float(txt): # is arxiv ID
txt = 'https://arxiv.org/abs/' + txt.strip()
if ('.' in txt) and ('/' not in txt) and is_float(txt[:10]): # is arxiv ID
txt = 'https://arxiv.org/abs/' + txt[:10]
if not txt.startswith('https://arxiv.org'):
chatbot.append((txt, "不是有效的arxiv链接或ID"))
# yield from update_ui(chatbot=chatbot, history=history)
return None, None # 返回两个值即使其中一个为None
chatbot.append([f"检测到arxiv文档连接", '尝试下载 ...'])
# yield from update_ui(chatbot=chatbot, history=history)
url_ = txt # https://arxiv.org/abs/1707.06690
if not txt.startswith('https://arxiv.org/abs/'):
msg = f"解析arxiv网址失败, 期望格式例如: https://arxiv.org/abs/1707.06690。实际得到格式: {url_}"
# yield from update_ui_lastest_msg(msg, chatbot=chatbot, history=history) # 刷新界面
return None, None # 返回两个值即使其中一个为None
arxiv_id = url_.split('/')[-1].split('v')[0]
dst = os.path.join(self.arxiv_cache_dir, arxiv_id, f'{arxiv_id}.tar.gz')
project_folder = os.path.join(self.arxiv_cache_dir, arxiv_id)
success = self.download_arxiv_paper(url_, dst, chatbot, history)
# if os.path.exists(dst) and get_conf('allow_cache'):
# # yield from update_ui_lastest_msg(f"调用缓存 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
# success = True
# else:
# # yield from update_ui_lastest_msg(f"开始下载 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
# success = self.download_arxiv_paper(url_, dst, chatbot, history)
# # yield from update_ui_lastest_msg(f"下载完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
if not success:
# chatbot.append([f"下载失败 {arxiv_id}", ""])
# yield from update_ui(chatbot=chatbot, history=history)
raise tarfile.ReadError(f"论文下载失败 {arxiv_id}")
# yield from update_ui_lastest_msg(f"开始解压 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
extract_dst = self.extract_tar_file(dst, project_folder, chatbot, history)
# yield from update_ui_lastest_msg(f"解压完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
return extract_dst, arxiv_id
def download_arxiv_paper(self, url_: str, dst: str, chatbot, history) -> bool:
"""下载arxiv论文"""
try:
proxies = get_conf('proxies')
for url_tar in [url_.replace('/abs/', '/src/'), url_.replace('/abs/', '/e-print/')]:
r = requests.get(url_tar, proxies=proxies)
if r.status_code == 200:
with open(dst, 'wb+') as f:
f.write(r.content)
return True
return False
except requests.RequestException as e:
# chatbot.append((f"下载失败 {url_}", str(e)))
# yield from update_ui(chatbot=chatbot, history=history)
return False
def extract_tar_file(self, file_path: str, dest_dir: str, chatbot, history) -> str:
"""解压arxiv论文"""
try:
with tarfile.open(file_path, 'r:gz') as tar:
tar.extractall(path=dest_dir)
return dest_dir
except tarfile.ReadError as e:
chatbot.append((f"解压失败 {file_path}", str(e)))
yield from update_ui(chatbot=chatbot, history=history)
raise e
def find_main_tex_file(self, tex_files: list) -> str:
"""查找主TEX文件"""
for tex_file in tex_files:
with open(tex_file, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
if r'\documentclass' in content:
return tex_file
return max(tex_files, key=lambda x: os.path.getsize(x))
def read_file_with_encoding(self, file_path: str) -> Optional[str]:
"""使用多种编码尝试读取文件"""
for encoding in self.supported_encodings:
try:
with open(file_path, 'r', encoding=encoding) as f:
return f.read()
except UnicodeDecodeError:
continue
return None
def process_tex_content(self, content: str, base_path: str, processed_files=None) -> str:
"""处理TEX内容包括递归处理包含的文件"""
if processed_files is None:
processed_files = set()
include_patterns = [
r'\\input{([^}]+)}',
r'\\include{([^}]+)}',
r'\\subfile{([^}]+)}',
r'\\input\s+([^\s{]+)',
]
for pattern in include_patterns:
matches = re.finditer(pattern, content)
for match in matches:
include_file = match.group(1)
if not include_file.endswith('.tex'):
include_file += '.tex'
include_path = os.path.join(base_path, include_file)
include_path = os.path.normpath(include_path)
if include_path in processed_files:
continue
processed_files.add(include_path)
if os.path.exists(include_path):
included_content = self.read_file_with_encoding(include_path)
if included_content:
included_content = self.process_tex_content(
included_content,
os.path.dirname(include_path),
processed_files
)
content = content.replace(match.group(0), included_content)
return content
def merge_tex_files(self, folder_path: str, chatbot, history) -> Optional[str]:
"""
Step 2: 合并TEX文件
返回: 合并后的内容
"""
try:
tex_files = []
for root, _, files in os.walk(folder_path):
tex_files.extend([os.path.join(root, f) for f in files if f.endswith('.tex')])
if not tex_files:
chatbot.append(("", "未找到任何TEX文件"))
yield from update_ui(chatbot=chatbot, history=history)
return None
main_tex_file = self.find_main_tex_file(tex_files)
chatbot.append(("", f"找到主TEX文件{os.path.basename(main_tex_file)}"))
yield from update_ui(chatbot=chatbot, history=history)
tex_content = self.read_file_with_encoding(main_tex_file)
if tex_content is None:
chatbot.append(("", "无法读取TEX文件可能是编码问题"))
yield from update_ui(chatbot=chatbot, history=history)
return None
full_content = self.process_tex_content(
tex_content,
os.path.dirname(main_tex_file)
)
cleaned_content = self.clean_tex_content(full_content)
chatbot.append(("",
f"成功处理所有TEX文件\n"
f"- 原始内容大小:{len(full_content)}字符\n"
f"- 清理后内容大小:{len(cleaned_content)}字符"
))
yield from update_ui(chatbot=chatbot, history=history)
# 添加标题和摘要提取
title = ""
abstract = ""
if tex_content:
# 提取标题
title_match = re.search(r'\\title{([^}]*)}', tex_content)
if title_match:
title = title_match.group(1)
# 提取摘要
abstract_match = re.search(r'\\begin{abstract}(.*?)\\end{abstract}',
tex_content, re.DOTALL)
if abstract_match:
abstract = abstract_match.group(1)
# 按token限制分段
def split_by_token_limit(text: str, token_limit: int = 1024) -> List[str]:
segments = []
current_segment = []
current_tokens = 0
for line in text.split('\n'):
line_tokens = len(line.split())
if current_tokens + line_tokens > token_limit:
segments.append('\n'.join(current_segment))
current_segment = [line]
current_tokens = line_tokens
else:
current_segment.append(line)
current_tokens += line_tokens
if current_segment:
segments.append('\n'.join(current_segment))
return segments
text_segments = split_by_token_limit(cleaned_content)
return {
'title': title,
'abstract': abstract,
'segments': text_segments
}
except Exception as e:
chatbot.append(("", f"处理TEX文件时发生错误{str(e)}"))
yield from update_ui(chatbot=chatbot, history=history)
return None
@staticmethod
def clean_tex_content(content: str) -> str:
"""清理TEX内容"""
content = re.sub(r'(?m)%.*$', '', content) # 移除注释
content = re.sub(r'\\cite{[^}]*}', '', content) # 移除引用
content = re.sub(r'\\label{[^}]*}', '', content) # 移除标签
content = re.sub(r'\s+', ' ', content) # 规范化空白
return content.strip()
if __name__ == "__main__":
# 测试 arxiv_download 函数
processor = ArxivPaperProcessor()
chatbot = []
history = []
# 测试不同格式的输入
test_inputs = [
"https://arxiv.org/abs/2402.14207", # 标准格式
"https://arxiv.org/pdf/2402.14207.pdf", # PDF链接格式
"2402.14207", # 纯ID格式
"2402.14207v1", # 带版本号的ID格式
"https://invalid.url", # 无效URL测试
]
for input_url in test_inputs:
print(f"\n测试输入: {input_url}")
try:
project_folder, arxiv_id = processor.arxiv_download(input_url, chatbot, history)
if project_folder and arxiv_id:
print(f"下载成功:")
print(f"- 项目文件夹: {project_folder}")
print(f"- Arxiv ID: {arxiv_id}")
print(f"- 文件夹是否存在: {os.path.exists(project_folder)}")
else:
print("下载失败: 返回值为 None")
except Exception as e:
print(f"发生错误: {str(e)}")

View File

@@ -1,164 +0,0 @@
from typing import Dict, List, Optional
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
import numpy as np
import os
from toolbox import get_conf
import openai
class RagHandler:
def __init__(self):
# 初始化工作目录
self.working_dir = os.path.join(get_conf('ARXIV_CACHE_DIR'), 'rag_cache')
if not os.path.exists(self.working_dir):
os.makedirs(self.working_dir)
# 初始化 LightRAG
self.rag = LightRAG(
working_dir=self.working_dir,
llm_model_func=self._llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1536, # OpenAI embedding 维度
max_token_size=8192,
func=self._embedding_func,
),
)
async def _llm_model_func(self, prompt: str, system_prompt: str = None,
history_messages: List = None, **kwargs) -> str:
"""LLM 模型函数"""
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
if history_messages:
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
response = await openai.ChatCompletion.acreate(
model="gpt-3.5-turbo",
messages=messages,
temperature=kwargs.get("temperature", 0),
max_tokens=kwargs.get("max_tokens", 1000)
)
return response.choices[0].message.content
async def _embedding_func(self, texts: List[str]) -> np.ndarray:
"""Embedding 函数"""
response = await openai.Embedding.acreate(
model="text-embedding-ada-002",
input=texts
)
embeddings = [item["embedding"] for item in response["data"]]
return np.array(embeddings)
def process_paper_content(self, paper_content: Dict) -> None:
"""处理论文内容,构建知识图谱"""
# 处理标题和摘要
content_list = []
if paper_content['title']:
content_list.append(f"Title: {paper_content['title']}")
if paper_content['abstract']:
content_list.append(f"Abstract: {paper_content['abstract']}")
# 添加分段内容
content_list.extend(paper_content['segments'])
# 插入到 RAG 系统
self.rag.insert(content_list)
def query(self, question: str, mode: str = "hybrid") -> str:
"""查询论文内容
mode: 查询模式,可选 naive/local/global/hybrid
"""
try:
response = self.rag.query(
question,
param=QueryParam(
mode=mode,
top_k=5, # 返回相关度最高的5个结果
max_token_for_text_unit=2048, # 每个文本单元的最大token数
response_type="detailed" # 返回详细回答
)
)
return response
except Exception as e:
return f"查询出错: {str(e)}"
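For reference, a minimal usage sketch of the RagHandler interface removed above (the inputs are hypothetical, and it assumes ARXIV_CACHE_DIR plus OpenAI credentials are configured, since the class calls the OpenAI API through LightRAG):
handler = RagHandler()
handler.process_paper_content({
    "title": "An Example Paper",
    "abstract": "A short abstract used only for illustration.",
    "segments": ["Section 1: method description ...", "Section 2: experimental results ..."],
})
# Query modes mirror the plugin logic: "global" for comparisons, "local" for details, "hybrid" otherwise.
print(handler.query("How does the method compare to prior work?", mode="global"))
print(handler.query("Which hyperparameters were used in the experiments?", mode="local"))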

View File

@@ -0,0 +1,111 @@
import logging
import requests
import tarfile
from pathlib import Path
from typing import Optional, Dict
class ArxivDownloader:
"""用于下载arXiv论文源码的下载器"""
def __init__(self, root_dir: str = "./papers", proxies: Optional[Dict[str, str]] = None):
"""
初始化下载器
Args:
root_dir: 保存下载文件的根目录
proxies: 代理服务器设置,例如 {"http": "http://proxy:port", "https": "https://proxy:port"}
"""
self.root_dir = Path(root_dir)
self.root_dir.mkdir(exist_ok=True)
self.proxies = proxies
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def _download_and_extract(self, arxiv_id: str) -> str:
"""
下载并解压arxiv论文源码
Args:
arxiv_id: arXiv论文ID例如"2103.00020"
Returns:
str: 解压后的文件目录路径
Raises:
RuntimeError: 当下载失败时抛出
"""
paper_dir = self.root_dir / arxiv_id
tar_path = paper_dir / f"{arxiv_id}.tar.gz"
# 检查缓存
if paper_dir.exists() and any(paper_dir.iterdir()):
logging.info(f"Using cached version for {arxiv_id}")
return str(paper_dir)
paper_dir.mkdir(exist_ok=True)
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
for url in urls:
try:
logging.info(f"Downloading from {url}")
response = requests.get(url, proxies=self.proxies)
if response.status_code == 200:
tar_path.write_bytes(response.content)
with tarfile.open(tar_path, 'r:gz') as tar:
tar.extractall(path=paper_dir)
return str(paper_dir)
except Exception as e:
logging.warning(f"Download failed for {url}: {e}")
continue
raise RuntimeError(f"Failed to download paper {arxiv_id}")
def download_paper(self, arxiv_id: str) -> str:
"""
下载指定的arXiv论文
Args:
arxiv_id: arXiv论文ID
Returns:
str: 论文文件所在的目录路径
"""
return self._download_and_extract(arxiv_id)
def main():
"""测试下载功能"""
# 配置代理(如果需要)
proxies = {
"http": "http://your-proxy:port",
"https": "https://your-proxy:port"
}
# 创建下载器实例如果不需要代理可以不传入proxies参数
downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)
# 测试下载一篇论文这里使用一个示例ID
try:
paper_id = "2103.00020" # 这是一个示例ID
paper_dir = downloader.download_paper(paper_id)
print(f"Successfully downloaded paper to: {paper_dir}")
# 检查下载的文件
paper_path = Path(paper_dir)
if paper_path.exists():
print("Downloaded files:")
for file in paper_path.rglob("*"):
if file.is_file():
print(f"- {file.relative_to(paper_path)}")
except Exception as e:
print(f"Error downloading paper: {e}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,55 @@
from dataclasses import dataclass
@dataclass
class ArxivFragment:
    """Data class for a fragment of an arXiv paper"""
    file_path: str           # source file path
    content: str             # fragment content
    segment_index: int       # index of this fragment
    total_segments: int      # total number of fragments
    rel_path: str            # path relative to the cache root
    segment_type: str        # fragment type (text/math/table/figure, etc.)
    title: str               # paper title
    abstract: str            # paper abstract
    section: str             # section the fragment belongs to
    is_appendix: bool        # whether the fragment comes from an appendix
    importance: float = 1.0  # importance score

    @staticmethod
    def merge_segments(seg1: 'ArxivFragment', seg2: 'ArxivFragment') -> 'ArxivFragment':
        """
        Merge two fragments into a single fragment
        Args:
            seg1: first fragment
            seg2: second fragment
        Returns:
            ArxivFragment: the merged fragment
        """
        # Concatenate the contents
        merged_content = f"{seg1.content}\n{seg2.content}"

        # Decide the type of the merged fragment
        def _merge_segment_type(type1: str, type2: str) -> str:
            if type1 == type2:
                return type1
            if type1 == 'text':
                return type2
            if type2 == 'text':
                return type1
            return 'mixed'

        return ArxivFragment(
            file_path=seg1.file_path,
            content=merged_content,
            segment_index=seg1.segment_index,
            total_segments=seg1.total_segments - 1,
            rel_path=seg1.rel_path,
            segment_type=_merge_segment_type(seg1.segment_type, seg2.segment_type),
            title=seg1.title,
            abstract=seg1.abstract,
            section=seg1.section,
            is_appendix=seg1.is_appendix,
            importance=max(seg1.importance, seg2.importance)
        )
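A minimal sketch of how merge_segments behaves; the fragment values below are made up for illustration:
frag_a = ArxivFragment("main.tex", "Introduction text.", 0, 2, "main.tex", "text",
                       "Sample Title", "Sample abstract.", "Introduction", False, importance=0.8)
frag_b = ArxivFragment("main.tex", "$E = mc^2$", 1, 2, "main.tex", "math",
                       "Sample Title", "Sample abstract.", "Introduction", False, importance=1.0)
merged = ArxivFragment.merge_segments(frag_a, frag_b)
print(merged.segment_type)    # 'math' -- merging text with a non-text type keeps the non-text type
print(merged.total_segments)  # 1
print(merged.importance)      # 1.0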

View File

@@ -0,0 +1,449 @@
import os
import re
import time
import aiohttp
import asyncio
import requests
import tarfile
import logging
from pathlib import Path
from typing import Generator, List, Tuple, Optional, Dict, Set
from concurrent.futures import ThreadPoolExecutor, as_completed
from crazy_functions.rag_fns.arxiv_fns.tex_processor import TexProcessor
from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment
def save_fragments_to_file(fragments, output_dir: str = "fragment_outputs"):
"""
将所有fragments保存为单个结构化markdown文件
Args:
fragments: fragment列表
output_dir: 输出目录
"""
from datetime import datetime
from pathlib import Path
import re
# 创建输出目录
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# 生成文件名
filename = f"fragments_{timestamp}.md"
file_path = output_path / filename
current_section = ""
section_count = {} # 用于跟踪每个章节的片段数量
with open(file_path, "w", encoding="utf-8") as f:
# 写入文档头部
f.write("# Document Fragments Analysis\n\n")
f.write("## Overview\n")
f.write(f"- Total Fragments: {len(fragments)}\n")
f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
# 如果有标题和摘要,添加到开头
if fragments and (fragments[0].title or fragments[0].abstract):
f.write("\n## Paper Information\n")
if fragments[0].title:
f.write(f"### Title\n{fragments[0].title}\n")
if fragments[0].abstract:
f.write(f"\n### Abstract\n{fragments[0].abstract}\n")
# 生成目录
f.write("\n## Table of Contents\n")
# 首先收集所有章节信息
sections = {}
for fragment in fragments:
section = fragment.section or "Uncategorized"
if section not in sections:
sections[section] = []
sections[section].append(fragment)
# 写入目录
for section, section_fragments in sections.items():
clean_section = section.strip()
if not clean_section:
clean_section = "Uncategorized"
f.write(
f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) ({len(section_fragments)} fragments)\n")
# 写入正文内容
f.write("\n## Content\n")
# 按章节组织内容
for section, section_fragments in sections.items():
clean_section = section.strip() or "Uncategorized"
f.write(f"\n### {clean_section}\n")
# 写入每个fragment
for i, fragment in enumerate(section_fragments, 1):
f.write(f"\n#### Fragment {i} ({fragment.segment_type})\n")
# 元数据
f.write("**Metadata:**\n")
f.write(f"- Type: {fragment.segment_type}\n")
f.write(f"- Length: {len(fragment.content)} chars\n")
f.write(f"- Importance: {fragment.importance:.2f}\n")
f.write(f"- Is Appendix: {fragment.is_appendix}\n")
f.write(f"- File: {fragment.rel_path}\n")
# 内容
f.write("\n**Content:**\n")
f.write("```tex\n")
f.write(fragment.content)
f.write("\n```\n")
# 添加分隔线
if i < len(section_fragments):
f.write("\n---\n")
# 添加统计信息
f.write("\n## Statistics\n")
f.write("\n### Fragment Type Distribution\n")
type_stats = {}
for fragment in fragments:
type_stats[fragment.segment_type] = type_stats.get(fragment.segment_type, 0) + 1
for ftype, count in type_stats.items():
percentage = (count / len(fragments)) * 100
f.write(f"- {ftype}: {count} ({percentage:.1f}%)\n")
# 长度分布
f.write("\n### Length Distribution\n")
lengths = [len(f.content) for f in fragments]
f.write(f"- Minimum: {min(lengths)} chars\n")
f.write(f"- Maximum: {max(lengths)} chars\n")
f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
print(f"Fragments saved to: {file_path}")
return file_path
class ArxivSplitter:
"""Arxiv论文智能分割器"""
def __init__(self,
char_range: Tuple[int, int],
root_dir: str = "gpt_log/arxiv_cache",
proxies: Optional[Dict[str, str]] = None,
cache_ttl: int = 7 * 24 * 60 * 60):
"""
初始化分割器
Args:
char_range: 字符数范围(最小值, 最大值)
root_dir: 缓存根目录
proxies: 代理设置
cache_ttl: 缓存过期时间(秒)
"""
self.min_chars, self.max_chars = char_range
self.root_dir = Path(root_dir)
self.root_dir.mkdir(parents=True, exist_ok=True)
self.proxies = proxies or {}
self.cache_ttl = cache_ttl
# 动态计算最优线程数
import multiprocessing
cpu_count = multiprocessing.cpu_count()
# 根据CPU核心数动态设置但设置上限防止过度并发
self.max_workers = min(32, cpu_count * 2)
# 初始化TeX处理器
self.tex_processor = TexProcessor(char_range)
# 配置日志
self._setup_logging()
def _setup_logging(self):
"""配置日志"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
def _normalize_arxiv_id(self, input_str: str) -> str:
"""规范化ArXiv ID"""
if 'arxiv.org/' in input_str.lower():
# 处理URL格式
if '/pdf/' in input_str:
arxiv_id = input_str.split('/pdf/')[-1]
else:
arxiv_id = input_str.split('/abs/')[-1]
# 移除版本号和其他后缀
return arxiv_id.split('v')[0].strip()
return input_str.split('v')[0].strip()
def _check_cache(self, paper_dir: Path) -> bool:
"""
检查缓存是否有效,包括文件完整性检查
Args:
paper_dir: 论文目录路径
Returns:
bool: 如果缓存有效返回True否则返回False
"""
if not paper_dir.exists():
return False
# 检查目录中是否存在必要文件
has_tex_files = False
has_main_tex = False
for file_path in paper_dir.rglob("*"):
if file_path.suffix == '.tex':
has_tex_files = True
content = self.tex_processor.read_file(str(file_path))
if content and r'\documentclass' in content:
has_main_tex = True
break
if not (has_tex_files and has_main_tex):
return False
# 检查缓存时间
cache_time = paper_dir.stat().st_mtime
if (time.time() - cache_time) < self.cache_ttl:
self.logger.info(f"Using valid cache for {paper_dir.name}")
return True
return False
async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
"""
异步下载论文,包含重试机制和临时文件处理
Args:
arxiv_id: ArXiv论文ID
paper_dir: 目标目录路径
Returns:
bool: 下载成功返回True否则返回False
"""
from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader
temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz"
final_tar_path = paper_dir / f"{arxiv_id}.tar.gz"
# 确保目录存在
paper_dir.mkdir(parents=True, exist_ok=True)
# 尝试使用 ArxivDownloader 下载
try:
downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
downloaded_dir = downloader.download_paper(arxiv_id)
if downloaded_dir:
self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
return True
except Exception as e:
self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")
# 如果 ArxivDownloader 失败,使用原有的下载方式作为备选
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
max_retries = 3
retry_delay = 1 # 初始重试延迟(秒)
for url in urls:
for attempt in range(max_retries):
try:
self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
async with aiohttp.ClientSession() as session:
async with session.get(url, proxy=self.proxies.get('http')) as response:
if response.status == 200:
content = await response.read()
# 写入临时文件
temp_tar_path.write_bytes(content)
try:
# 验证tar文件完整性并解压
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
# 下载成功后移动临时文件到最终位置
temp_tar_path.rename(final_tar_path)
return True
except Exception as e:
self.logger.warning(f"Invalid tar file: {str(e)}")
if temp_tar_path.exists():
temp_tar_path.unlink()
except Exception as e:
self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
continue
return False
def _process_tar_file(self, tar_path: Path, extract_path: Path):
    """Synchronously validate and extract the downloaded tar archive"""
    with tarfile.open(tar_path, 'r:gz') as tar:
        tar.getmembers()  # reading the member list validates archive integrity
        tar.extractall(path=extract_path)  # extract the files
def _process_single_tex(self, file_path: str) -> List[ArxivFragment]:
"""处理单个TeX文件"""
try:
content = self.tex_processor.read_file(file_path)
if not content:
return []
# 提取元数据
is_main = r'\documentclass' in content
title, abstract = "", ""
if is_main:
title, abstract = self.tex_processor.extract_metadata(content)
# 分割内容
segments = self.tex_processor.split_content(content)
fragments = []
for i, (segment_content, section, is_appendix) in enumerate(segments):
if not segment_content.strip():
continue
segment_type = self.tex_processor.detect_segment_type(segment_content)
importance = self.tex_processor.calculate_importance(
segment_content, segment_type, is_main
)
fragments.append(ArxivFragment(
file_path=file_path,
content=segment_content,
segment_index=i,
total_segments=len(segments),
rel_path=str(Path(file_path).relative_to(self.root_dir)),
segment_type=segment_type,
title=title,
abstract=abstract,
section=section,
is_appendix=is_appendix,
importance=importance
))
return fragments
except Exception as e:
self.logger.error(f"Error processing {file_path}: {str(e)}")
return []
async def process(self, arxiv_id_or_url: str) -> List[ArxivFragment]:
"""处理ArXiv论文"""
try:
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
paper_dir = self.root_dir / arxiv_id
# 检查缓存
if not self._check_cache(paper_dir):
paper_dir.mkdir(exist_ok=True)
if not await self.download_paper(arxiv_id, paper_dir):
raise RuntimeError(f"Failed to download paper {arxiv_id}")
# 查找主TeX文件
main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
if not main_tex:
raise RuntimeError(f"No main TeX file found in {paper_dir}")
# 获取所有相关TeX文件
tex_files = self.tex_processor.resolve_includes(main_tex)
if not tex_files:
raise RuntimeError(f"No valid TeX files found for {arxiv_id}")
# 并行处理所有TeX文件
fragments = []
chunk_size = max(1, len(tex_files) // self.max_workers) # 计算每个线程处理的文件数
loop = asyncio.get_event_loop()
async def process_chunk(chunk_files):
chunk_fragments = []
for file_path in chunk_files:
try:
result = await loop.run_in_executor(None, self._process_single_tex, file_path)
chunk_fragments.extend(result)
except Exception as e:
self.logger.error(f"Error processing {file_path}: {str(e)}")
return chunk_fragments
# 将文件分成多个块
file_chunks = [tex_files[i:i + chunk_size] for i in range(0, len(tex_files), chunk_size)]
# 异步处理每个块
chunk_results = await asyncio.gather(*[process_chunk(chunk) for chunk in file_chunks])
for result in chunk_results:
fragments.extend(result)
# 重新计算片段索引并排序
fragments.sort(key=lambda x: (x.rel_path, x.segment_index))
total_fragments = len(fragments)
for i, fragment in enumerate(fragments):
fragment.segment_index = i
fragment.total_segments = total_fragments
# 在返回之前添加过滤
fragments = self.tex_processor.filter_fragments(fragments)
return fragments
except Exception as e:
self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
raise
async def test_arxiv_splitter():
"""测试ArXiv分割器的功能"""
# 测试配置
test_cases = [
{
"arxiv_id": "2411.03663",
"expected_title": "Large Language Models and Simple Scripts",
"min_fragments": 10,
},
# {
# "arxiv_id": "1805.10988",
# "expected_title": "RAG vs Fine-tuning",
# "min_fragments": 15,
# }
]
# 创建分割器实例
splitter = ArxivSplitter(
char_range=(800, 1800),
root_dir="test_cache"
)
for case in test_cases:
print(f"\nTesting paper: {case['arxiv_id']}")
try:
fragments = await splitter.process(case['arxiv_id'])
# 保存fragments
output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
print(f"Output saved to: {output_dir}")
# 内容检查
for fragment in fragments:
# 长度检查
print((fragment.content))
print(len(fragment.content))
# 类型检查
except Exception as e:
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
raise
if __name__ == "__main__":
asyncio.run(test_arxiv_splitter())

View File

@@ -0,0 +1,395 @@
from dataclasses import dataclass, field
@dataclass
class LaTeXPatterns:
"""LaTeX模式存储类用于集中管理所有LaTeX相关的正则表达式模式"""
special_envs = {
'math': [
# 基础数学环境
r'\\begin{(equation|align|gather|eqnarray|multline|flalign|alignat)\*?}.*?\\end{\1\*?}',
r'\$\$.*?\$\$',
r'\$[^$]+\$',
# 矩阵环境
r'\\begin{(matrix|pmatrix|bmatrix|Bmatrix|vmatrix|Vmatrix|smallmatrix)\*?}.*?\\end{\1\*?}',
# 数组环境
r'\\begin{(array|cases|aligned|gathered|split)\*?}.*?\\end{\1\*?}',
# 其他数学环境
r'\\begin{(subequations|math|displaymath)\*?}.*?\\end{\1\*?}'
],
'table': [
# 基础表格环境
r'\\begin{(table|tabular|tabularx|tabulary|longtable)\*?}.*?\\end{\1\*?}',
# 复杂表格环境
r'\\begin{(tabu|supertabular|xtabular|mpsupertabular)\*?}.*?\\end{\1\*?}',
# 自定义表格环境
r'\\begin{(threeparttable|tablefootnote)\*?}.*?\\end{\1\*?}',
# 表格注释环境
r'\\begin{(tablenotes)\*?}.*?\\end{\1\*?}'
],
'figure': [
# 图片环境
r'\\begin{figure\*?}.*?\\end{figure\*?}',
r'\\begin{(subfigure|wrapfigure)\*?}.*?\\end{\1\*?}',
# 图片插入命令
r'\\includegraphics(\[.*?\])?\{.*?\}',
# tikz 图形环境
r'\\begin{(tikzpicture|pgfpicture)\*?}.*?\\end{\1\*?}',
# 其他图形环境
r'\\begin{(picture|pspicture)\*?}.*?\\end{\1\*?}'
],
'algorithm': [
# 算法环境
r'\\begin{(algorithm|algorithmic|algorithm2e|algorithmicx)\*?}.*?\\end{\1\*?}',
r'\\begin{(lstlisting|verbatim|minted|listing)\*?}.*?\\end{\1\*?}',
# 代码块环境
r'\\begin{(code|verbatimtab|verbatimwrite)\*?}.*?\\end{\1\*?}',
# 伪代码环境
r'\\begin{(pseudocode|procedure)\*?}.*?\\end{\1\*?}'
],
'list': [
# 列表环境
r'\\begin{(itemize|enumerate|description)\*?}.*?\\end{\1\*?}',
r'\\begin{(list|compactlist|bulletlist)\*?}.*?\\end{\1\*?}',
# 自定义列表环境
r'\\begin{(tasks|todolist)\*?}.*?\\end{\1\*?}'
],
'theorem': [
# 定理类环境
r'\\begin{(theorem|lemma|proposition|corollary)\*?}.*?\\end{\1\*?}',
r'\\begin{(definition|example|proof|remark)\*?}.*?\\end{\1\*?}',
# 其他证明环境
r'\\begin{(axiom|property|assumption|conjecture)\*?}.*?\\end{\1\*?}'
],
'box': [
# 文本框环境
r'\\begin{(tcolorbox|mdframed|framed|shaded)\*?}.*?\\end{\1\*?}',
r'\\begin{(boxedminipage|shadowbox)\*?}.*?\\end{\1\*?}',
# 强调环境
r'\\begin{(important|warning|info|note)\*?}.*?\\end{\1\*?}'
],
'quote': [
# 引用环境
r'\\begin{(quote|quotation|verse|abstract)\*?}.*?\\end{\1\*?}',
r'\\begin{(excerpt|epigraph)\*?}.*?\\end{\1\*?}'
],
'bibliography': [
# 参考文献环境
r'\\begin{(thebibliography|bibliography)\*?}.*?\\end{\1\*?}',
r'\\begin{(biblist|citelist)\*?}.*?\\end{\1\*?}'
],
'index': [
# 索引环境
r'\\begin{(theindex|printindex)\*?}.*?\\end{\1\*?}',
r'\\begin{(glossary|acronym)\*?}.*?\\end{\1\*?}'
]
}
# 章节模式
section_patterns = [
# 基础章节命令
r'\\chapter\{([^}]+)\}',
r'\\section\{([^}]+)\}',
r'\\subsection\{([^}]+)\}',
r'\\subsubsection\{([^}]+)\}',
r'\\paragraph\{([^}]+)\}',
r'\\subparagraph\{([^}]+)\}',
# 带星号的变体(不编号)
r'\\chapter\*\{([^}]+)\}',
r'\\section\*\{([^}]+)\}',
r'\\subsection\*\{([^}]+)\}',
r'\\subsubsection\*\{([^}]+)\}',
r'\\paragraph\*\{([^}]+)\}',
r'\\subparagraph\*\{([^}]+)\}',
# 特殊章节
r'\\part\{([^}]+)\}',
r'\\part\*\{([^}]+)\}',
r'\\appendix\{([^}]+)\}',
# 前言部分
r'\\frontmatter\{([^}]+)\}',
r'\\mainmatter\{([^}]+)\}',
r'\\backmatter\{([^}]+)\}',
# 目录相关
r'\\tableofcontents',
r'\\listoffigures',
r'\\listoftables',
# 自定义章节命令
r'\\addchap\{([^}]+)\}', # KOMA-Script类
r'\\addsec\{([^}]+)\}', # KOMA-Script类
r'\\minisec\{([^}]+)\}', # KOMA-Script类
# 带可选参数的章节命令
r'\\chapter\[([^]]+)\]\{([^}]+)\}',
r'\\section\[([^]]+)\]\{([^}]+)\}',
r'\\subsection\[([^]]+)\]\{([^}]+)\}'
]
# 包含模式
include_patterns = [
r'\\(input|include|subfile)\{([^}]+)\}'
]
metadata_patterns = {
# 标题相关
'title': [
r'\\title\{([^}]+)\}',
r'\\Title\{([^}]+)\}',
r'\\doctitle\{([^}]+)\}',
r'\\subtitle\{([^}]+)\}',
r'\\chapter\*?\{([^}]+)\}', # 第一章可能作为标题
r'\\maketitle\s*\\section\*?\{([^}]+)\}' # 第一节可能作为标题
],
# 摘要相关
'abstract': [
r'\\begin{abstract}(.*?)\\end{abstract}',
r'\\abstract\{([^}]+)\}',
r'\\begin{摘要}(.*?)\\end{摘要}',
r'\\begin{Summary}(.*?)\\end{Summary}',
r'\\begin{synopsis}(.*?)\\end{synopsis}',
r'\\begin{abstracten}(.*?)\\end{abstracten}' # 英文摘要
],
# 作者信息
'author': [
r'\\author\{([^}]+)\}',
r'\\Author\{([^}]+)\}',
r'\\authorinfo\{([^}]+)\}',
r'\\authors\{([^}]+)\}',
r'\\author\[([^]]+)\]\{([^}]+)\}', # 带附加信息的作者
r'\\begin{authors}(.*?)\\end{authors}'
],
# 日期相关
'date': [
r'\\date\{([^}]+)\}',
r'\\Date\{([^}]+)\}',
r'\\submitdate\{([^}]+)\}',
r'\\publishdate\{([^}]+)\}',
r'\\revisiondate\{([^}]+)\}'
],
# 关键词
'keywords': [
r'\\keywords\{([^}]+)\}',
r'\\Keywords\{([^}]+)\}',
r'\\begin{keywords}(.*?)\\end{keywords}',
r'\\key\{([^}]+)\}',
r'\\begin{关键词}(.*?)\\end{关键词}'
],
# 机构/单位
'institution': [
r'\\institute\{([^}]+)\}',
r'\\institution\{([^}]+)\}',
r'\\affiliation\{([^}]+)\}',
r'\\organization\{([^}]+)\}',
r'\\department\{([^}]+)\}'
],
# 学科/主题
'subject': [
r'\\subject\{([^}]+)\}',
r'\\Subject\{([^}]+)\}',
r'\\field\{([^}]+)\}',
r'\\discipline\{([^}]+)\}'
],
# 版本信息
'version': [
r'\\version\{([^}]+)\}',
r'\\revision\{([^}]+)\}',
r'\\release\{([^}]+)\}'
],
# 许可证/版权
'license': [
r'\\license\{([^}]+)\}',
r'\\copyright\{([^}]+)\}',
r'\\begin{license}(.*?)\\end{license}'
],
# 联系方式
'contact': [
r'\\email\{([^}]+)\}',
r'\\phone\{([^}]+)\}',
r'\\address\{([^}]+)\}',
r'\\contact\{([^}]+)\}'
],
# 致谢
'acknowledgments': [
r'\\begin{acknowledgments}(.*?)\\end{acknowledgments}',
r'\\acknowledgments\{([^}]+)\}',
r'\\thanks\{([^}]+)\}',
r'\\begin{致谢}(.*?)\\end{致谢}'
],
# 项目/基金
'funding': [
r'\\funding\{([^}]+)\}',
r'\\grant\{([^}]+)\}',
r'\\project\{([^}]+)\}',
r'\\support\{([^}]+)\}'
],
# 分类号/编号
'classification': [
r'\\classification\{([^}]+)\}',
r'\\serialnumber\{([^}]+)\}',
r'\\id\{([^}]+)\}',
r'\\doi\{([^}]+)\}'
],
# 语言
'language': [
r'\\documentlanguage\{([^}]+)\}',
r'\\lang\{([^}]+)\}',
r'\\language\{([^}]+)\}'
]
}
latex_only_patterns = {
# 文档类和包引入
r'\\documentclass(\[.*?\])?\{.*?\}',
r'\\usepackage(\[.*?\])?\{.*?\}',
# 常见的文档设置命令
r'\\setlength\{.*?\}\{.*?\}',
r'\\newcommand\{.*?\}(\[.*?\])?\{.*?\}',
r'\\renewcommand\{.*?\}(\[.*?\])?\{.*?\}',
r'\\definecolor\{.*?\}\{.*?\}\{.*?\}',
# 页面设置相关
r'\\pagestyle\{.*?\}',
r'\\thispagestyle\{.*?\}',
# 其他常见的设置命令
r'\\bibliographystyle\{.*?\}',
r'\\bibliography\{.*?\}',
r'\\setcounter\{.*?\}\{.*?\}',
# 字体和文本设置命令
r'\\makeFNbottom',
r'\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}', # 匹配字体大小设置
r'\\renewcommand\\[A-Z]+\{\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}\}',
r'\\renewcommand\{?\\thefootnote\}?\{\\fnsymbol\{footnote\}\}',
r'\\renewcommand\\footnoterule\{.*?\}',
r'\\color\{.*?\}',
# 页面和节标题设置
r'\\setcounter\{secnumdepth\}\{.*?\}',
r'\\renewcommand\\@biblabel\[.*?\]\{.*?\}',
r'\\renewcommand\\@makefntext\[.*?\](\{.*?\})*',
r'\\renewcommand\{?\\figurename\}?\{.*?\}',
# 字体样式设置
r'\\sectionfont\{.*?\}',
r'\\subsectionfont\{.*?\}',
r'\\subsubsectionfont\{.*?\}',
# 间距和布局设置
r'\\setstretch\{.*?\}',
r'\\setlength\{\\skip\\footins\}\{.*?\}',
r'\\setlength\{\\footnotesep\}\{.*?\}',
r'\\setlength\{\\jot\}\{.*?\}',
r'\\hrule\s+width\s+.*?\s+height\s+.*?',
# makeatletter 和 makeatother
r'\\makeatletter\s*',
r'\\makeatother\s*',
r'\\footnotetext\{[^}]*\$\^{[^}]*}\$[^}]*\}', # 带有上标的脚注
# r'\\footnotetext\{[^}]*\}', # 普通脚注
# r'\\footnotetext\{.*?(?:\$\^{.*?}\$)?.*?(?:email\s*:\s*[^}]*)?.*?\}', # 带有邮箱的脚注
# r'\\footnotetext\{.*?(?:ESI|DOI).*?\}', # 带有 DOI 或 ESI 引用的脚注
# 文档结构命令
r'\\begin\{document\}',
r'\\end\{document\}',
r'\\maketitle',
r'\\printbibliography',
r'\\newpage',
# 输入文件命令
r'\\input\{[^}]*\}',
r'\\input\{.*?\.tex\}', # 特别匹配 .tex 后缀的输入
# 脚注相关
# r'\\footnotetext\[\d+\]\{[^}]*\}', # 带编号的脚注
# 致谢环境
r'\\begin\{ack\}',
r'\\end\{ack\}',
r'\\begin\{ack\}[^\n]*(?:\n.*?)*?\\end\{ack\}', # 匹配整个致谢环境及其内容
# 其他文档控制命令
r'\\renewcommand\{\\thefootnote\}\{\\fnsymbol\{footnote\}\}',
}
math_envs = [
# 基础数学环境
(r'\\begin{equation\*?}.*?\\end{equation\*?}', 'equation'), # 单行公式
(r'\\begin{align\*?}.*?\\end{align\*?}', 'align'), # 多行对齐公式
(r'\\begin{gather\*?}.*?\\end{gather\*?}', 'gather'), # 多行居中公式
(r'\$\$.*?\$\$', 'display'), # 行间公式
(r'\$.*?\$', 'inline'), # 行内公式
# 矩阵环境
(r'\\begin{matrix}.*?\\end{matrix}', 'matrix'), # 基础矩阵
(r'\\begin{pmatrix}.*?\\end{pmatrix}', 'pmatrix'), # 圆括号矩阵
(r'\\begin{bmatrix}.*?\\end{bmatrix}', 'bmatrix'), # 方括号矩阵
(r'\\begin{vmatrix}.*?\\end{vmatrix}', 'vmatrix'), # 竖线矩阵
(r'\\begin{Vmatrix}.*?\\end{Vmatrix}', 'Vmatrix'), # 双竖线矩阵
(r'\\begin{smallmatrix}.*?\\end{smallmatrix}', 'smallmatrix'), # 小号矩阵
# 数组环境
(r'\\begin{array}.*?\\end{array}', 'array'), # 数组
(r'\\begin{cases}.*?\\end{cases}', 'cases'), # 分段函数
# 多行公式环境
(r'\\begin{multline\*?}.*?\\end{multline\*?}', 'multline'), # 多行单个公式
(r'\\begin{split}.*?\\end{split}', 'split'), # 拆分长公式
(r'\\begin{alignat\*?}.*?\\end{alignat\*?}', 'alignat'), # 对齐环境带间距控制
(r'\\begin{flalign\*?}.*?\\end{flalign\*?}', 'flalign'), # 完全左对齐
# 特殊数学环境
(r'\\begin{subequations}.*?\\end{subequations}', 'subequations'), # 子公式编号
(r'\\begin{gathered}.*?\\end{gathered}', 'gathered'), # 居中对齐组
(r'\\begin{aligned}.*?\\end{aligned}', 'aligned'), # 内部对齐组
# 定理类环境
(r'\\begin{theorem}.*?\\end{theorem}', 'theorem'), # 定理
(r'\\begin{lemma}.*?\\end{lemma}', 'lemma'), # 引理
(r'\\begin{proof}.*?\\end{proof}', 'proof'), # 证明
# 数学模式中的表格环境
(r'\\begin{tabular}.*?\\end{tabular}', 'tabular'), # 表格
(r'\\begin{array}.*?\\end{array}', 'array'), # 数组
# 其他专业数学环境
(r'\\begin{CD}.*?\\end{CD}', 'CD'), # 交换图
(r'\\begin{boxed}.*?\\end{boxed}', 'boxed'), # 带框公式
(r'\\begin{empheq}.*?\\end{empheq}', 'empheq'), # 强调公式
# 化学方程式环境 (需要加载 mhchem 包)
(r'\\begin{reaction}.*?\\end{reaction}', 'reaction'), # 化学反应式
(r'\\ce\{.*?\}', 'chemequation'), # 化学方程式
# 物理单位环境 (需要加载 siunitx 包)
(r'\\SI\{.*?\}\{.*?\}', 'SI'), # 物理单位
(r'\\si\{.*?\}', 'si'), # 单位
# 补充环境
(r'\\begin{equation\+}.*?\\end{equation\+}', 'equation+'), # breqn包的自动换行公式
(r'\\begin{dmath\*?}.*?\\end{dmath\*?}', 'dmath'), # breqn包的显示数学模式
(r'\\begin{dgroup\*?}.*?\\end{dgroup\*?}', 'dgroup'), # breqn包的公式组
]
# Example usage
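# A minimal sketch (not part of the original module) of how the pattern tables above
# might be consumed; the sample LaTeX string is made up for illustration.
if __name__ == "__main__":
    import re
    patterns = LaTeXPatterns()
    sample = r"Some text $E = mc^2$ and \begin{equation} a^2 + b^2 = c^2 \end{equation}"
    for env_type, regex_list in patterns.special_envs.items():
        for regex in regex_list:
            for match in re.finditer(regex, sample, flags=re.DOTALL):
                print(f"{env_type}: {match.group(0)}")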

File diff suppressed because it is too large

View File

@@ -13,6 +13,7 @@ from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
 T = TypeVar('T')

 @dataclass
 class StorageBase:
     """Base class for all storage implementations"""
@@ -42,7 +43,7 @@ class JsonKVStorage(StorageBase, Generic[T]):
     def __post_init__(self):
         """Initialize storage file and load data"""
-        self._file_name = os.path.join(self.working_dir, f"kv_{self.namespace}.json")
+        self._file_name = os.path.join(self.working_dir, f"kv_store_{self.namespace}.json")
         self._data: Dict[str, T] = {}
         self.load()
@@ -95,69 +96,138 @@ class JsonKVStorage(StorageBase, Generic[T]):
await self.save() await self.save()
@dataclass @dataclass
class VectorStorage(StorageBase): class VectorStorage(StorageBase):
""" """
Vector storage using LlamaIndex Vector storage using LlamaIndexRagWorker
Attributes: Attributes:
namespace (str): Storage namespace namespace (str): Storage namespace (e.g., 'entities', 'relationships', 'chunks')
working_dir (str): Working directory for storage files working_dir (str): Working directory for storage files
llm_kwargs (dict): LLM configuration llm_kwargs (dict): LLM configuration
embedding_func (OpenAiEmbeddingModel): Embedding function embedding_func (OpenAiEmbeddingModel): Embedding function
meta_fields (Set[str]): Additional fields to store meta_fields (Set[str]): Additional metadata fields to store
cosine_better_than_threshold (float): Similarity threshold
""" """
llm_kwargs: dict llm_kwargs: dict
embedding_func: OpenAiEmbeddingModel embedding_func: OpenAiEmbeddingModel
meta_fields: Set[str] = field(default_factory=set) meta_fields: Set[str] = field(default_factory=set)
cosine_better_than_threshold: float = 0.2
def __post_init__(self): def __post_init__(self):
"""Initialize LlamaIndex worker""" """Initialize LlamaIndex worker"""
checkpoint_dir = os.path.join(self.working_dir, f"vector_{self.namespace}") # 使用正确的文件命名格式
self._vector_file = os.path.join(self.working_dir, f"vdb_{self.namespace}.json")
# 设置检查点目录
checkpoint_dir = os.path.join(self.working_dir, f"vector_{self.namespace}_checkpoint")
os.makedirs(checkpoint_dir, exist_ok=True)
# 初始化向量存储
self.vector_store = LlamaIndexRagWorker( self.vector_store = LlamaIndexRagWorker(
user_name=self.namespace, user_name=self.namespace,
llm_kwargs=self.llm_kwargs, llm_kwargs=self.llm_kwargs,
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
auto_load_checkpoint=True # 自动加载检查点 auto_load_checkpoint=True
) )
logger.info(f"Initialized vector storage for {self.namespace}")
async def query(self, query: str, top_k: int = 5) -> List[dict]: async def query(self, query: str, top_k: int = 5, metadata_filters: Optional[Dict[str, Any]] = None) -> List[dict]:
""" """
Query vectors by similarity Query vectors by similarity with optional metadata filtering
Args: Args:
query: Query text query: Query text
top_k: Maximum number of results top_k: Maximum number of results to return
metadata_filters: Optional metadata filters
Returns: Returns:
List of similar documents with scores List of similar documents with scores
""" """
nodes = self.vector_store.retrieve_from_store_with_query(query) try:
results = [{ if metadata_filters:
nodes = self.vector_store.retrieve_with_metadata_filter(query, metadata_filters, top_k)
else:
nodes = self.vector_store.retrieve_from_store_with_query(query)[:top_k]
results = []
for node in nodes:
result = {
"id": node.node_id, "id": node.node_id,
"text": node.text, "text": node.text,
"score": node.score, "score": node.score if hasattr(node, 'score') else 0.0,
**{k: getattr(node, k) for k in self.meta_fields if hasattr(node, k)} }
} for node in nodes[:top_k]] # Add metadata fields if they exist and are in meta_fields
return [r for r in results if r.get('score', 0) > self.cosine_better_than_threshold] if hasattr(node, 'metadata'):
result.update({
k: node.metadata[k]
for k in self.meta_fields
if k in node.metadata
})
results.append(result)
return results
except Exception as e:
logger.error(f"Error in vector query: {e}")
raise
    async def upsert(self, data: Dict[str, dict]):
        """
        Insert or update vectors

        Args:
            data: Dictionary of documents to insert/update with format:
                {id: {"content": text, "metadata": dict}}
        """
        try:
            for doc_id, item in data.items():
                content = item["content"]
                # Extract metadata
                metadata = {
                    k: item[k]
                    for k in self.meta_fields
                    if k in item
                }
                # Add the document ID to the metadata
                metadata["doc_id"] = doc_id
                # Add to the vector store
                self.vector_store.add_text_with_metadata(content, metadata)

            # Export vector data to the JSON file
            self.vector_store.export_nodes(
                self._vector_file,
                format="json",
                include_embeddings=True
            )
        except Exception as e:
            logger.error(f"Error in vector upsert: {e}")
            raise
    async def save(self):
        """Save vector store to checkpoint and export data"""
        try:
            # Save the checkpoint
            self.vector_store.save_to_checkpoint()
            # Export vector data
            self.vector_store.export_nodes(
                self._vector_file,
                format="json",
                include_embeddings=True
            )
        except Exception as e:
            logger.error(f"Error saving vector storage: {e}")
            raise

    async def index_done_callback(self):
        """Save after indexing"""
        await self.save()

    def get_statistics(self) -> Dict[str, Any]:
        """Get vector store statistics"""
        return self.vector_store.get_statistics()
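
A minimal usage sketch (not part of the commit) of the VectorStorage interface above. It assumes `namespace` and `working_dir` are fields inherited from StorageBase, that an asyncio event loop is available, and that suitable `llm_kwargs` / `embedding_func` objects already exist; `_demo_vector_storage` and `demo_store` are illustrative names only.

import asyncio

async def _demo_vector_storage(llm_kwargs, embedding_func):
    # Hypothetical store; extra keys listed in meta_fields are persisted as metadata.
    demo_store = VectorStorage(
        namespace="demo",
        working_dir="gpt_log/rag_demo",
        llm_kwargs=llm_kwargs,
        embedding_func=embedding_func,
        meta_fields={"entity_name", "entity_type"},
    )
    await demo_store.upsert({
        "doc_1": {"content": "APPLE: a consumer electronics company.",
                  "entity_name": "APPLE", "entity_type": "ORGANIZATION"},
        "doc_2": {"content": "TIM COOK: chief executive officer of Apple.",
                  "entity_name": "TIM COOK", "entity_type": "PERSON"},
    })
    # Plain similarity query, then a query restricted by a metadata filter
    # (filter semantics depend on LlamaIndexRagWorker and are assumed here).
    hits = await demo_store.query("Who leads Apple?", top_k=2)
    org_hits = await demo_store.query("electronics maker", top_k=2,
                                      metadata_filters={"entity_type": "ORGANIZATION"})
    return hits, org_hits

# asyncio.run(_demo_vector_storage(llm_kwargs, embedding_func))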
@dataclass
@@ -174,12 +244,15 @@ class NetworkStorage(StorageBase):
        """Initialize graph and storage file"""
        self._file_name = os.path.join(self.working_dir, f"graph_{self.namespace}.graphml")
        self._graph = self._load_graph() or nx.Graph()
        logger.info(f"Initialized graph storage for {self.namespace}")

    def _load_graph(self) -> Optional[nx.Graph]:
        """Load graph from GraphML file"""
        if os.path.exists(self._file_name):
            try:
                graph = nx.read_graphml(self._file_name)
                logger.info(f"Loaded graph with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges")
                return graph
            except Exception as e:
                logger.error(f"Error loading graph from {self._file_name}: {e}")
                return None
@@ -187,9 +260,14 @@ class NetworkStorage(StorageBase):
    async def save_graph(self):
        """Save graph to GraphML file"""
        try:
            os.makedirs(os.path.dirname(self._file_name), exist_ok=True)
            logger.info(
                f"Saving graph with {self._graph.number_of_nodes()} nodes, {self._graph.number_of_edges()} edges")
            nx.write_graphml(self._graph, self._file_name)
        except Exception as e:
            logger.error(f"Error saving graph: {e}")
            raise

    async def has_node(self, node_id: str) -> bool:
        """Check if node exists"""

@@ -227,15 +305,15 @@ class NetworkStorage(StorageBase):
    async def upsert_node(self, node_id: str, node_data: Dict[str, str]):
        """Insert or update node"""
        # Clean and normalize node data
        cleaned_data = {k: html.escape(str(v).upper().strip()) for k, v in node_data.items()}
        self._graph.add_node(node_id, **cleaned_data)

    async def upsert_edge(self, source_id: str, target_id: str, edge_data: Dict[str, str]):
        """Insert or update edge"""
        # Clean and normalize edge data
        cleaned_data = {k: html.escape(str(v).strip()) for k, v in edge_data.items()}
        self._graph.add_edge(source_id, target_id, **cleaned_data)

    async def index_done_callback(self):
        """Save after indexing"""

@@ -253,39 +331,39 @@ class NetworkStorage(StorageBase):
        largest_component = max(components, key=len)
        return self._graph.subgraph(largest_component).copy()

    async def embed_nodes(
            self,
            algorithm: str = "node2vec",
            dimensions: int = 128,
            walk_length: int = 30,
            num_walks: int = 200,
            workers: int = 4,
            window: int = 10,
            min_count: int = 1,
            **kwargs
    ) -> Tuple[np.ndarray, List[str]]:
        """Generate node embeddings using specified algorithm"""
        if algorithm == "node2vec":
            from node2vec import Node2Vec

            # Create and train node2vec model
            n2v = Node2Vec(
                self._graph,
                dimensions=dimensions,
                walk_length=walk_length,
                num_walks=num_walks,
                workers=workers
            )
            model = n2v.fit(
                window=window,
                min_count=min_count
            )

            # Get embeddings for all nodes
            node_ids = list(self._graph.nodes())
            embeddings = np.array([model.wv[node] for node in node_ids])
            return embeddings, node_ids
        else:
            raise ValueError(f"Unsupported embedding algorithm: {algorithm}")

View File

@@ -23,25 +23,29 @@ class ExtractionExample:
    def __init__(self):
        """Initialize RAG system components"""
        # Set up the working directory
        self.working_dir = f"crazy_functions/rag_fns/LightRAG/rag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        os.makedirs(self.working_dir, exist_ok=True)
        logger.info(f"Working directory: {self.working_dir}")

        # Initialize the embedding model
        self.llm_kwargs = {
            'api_key': os.getenv("one_api_key"),
            'client_ip': '127.0.0.1',
            'embed_model': 'text-embedding-3-small',
            'llm_model': 'one-api-Qwen2.5-72B-Instruct',
            'max_length': 4096,
            'most_recent_uploaded': None,
            'temperature': 1,
            'top_p': 1
        }
        self.embedding_func = OpenAiEmbeddingModel(self.llm_kwargs)

        # Initialize the prompt templates and the extractor
        self.prompt_templates = PromptTemplates()
        self.extractor = EntityRelationExtractor(
            prompt_templates=self.prompt_templates,
            required_prompts={'entity_extraction'},
            entity_extract_max_gleaning=1
        )

        # Initialize the storage system
@@ -63,18 +67,33 @@ class ExtractionExample:
            working_dir=self.working_dir
        )

        # Vector storage - vector representations of entities, relationships, and text chunks
        self.entities_vdb = VectorStorage(
            namespace="entities",
            working_dir=self.working_dir,
            llm_kwargs=self.llm_kwargs,
            embedding_func=self.embedding_func,
            meta_fields={"entity_name", "entity_type"}
        )
        self.relationships_vdb = VectorStorage(
            namespace="relationships",
            working_dir=self.working_dir,
            llm_kwargs=self.llm_kwargs,
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"}
        )
        self.chunks_vdb = VectorStorage(
            namespace="chunks",
            working_dir=self.working_dir,
            llm_kwargs=self.llm_kwargs,
            embedding_func=self.embedding_func
        )

        # Graph storage - for entity relationships
        self.graph_store = NetworkStorage(
            namespace="chunk_entity_relation",
            working_dir=self.working_dir
        )
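
As a sketch of what these namespaces imply on disk: each VectorStorage writes vdb_<namespace>.json plus a vector_<namespace>_checkpoint directory, and NetworkStorage writes graph_<namespace>.graphml. This is inferred from the constructors above; the KV stores initialized earlier add their own files, which are omitted here.

from pathlib import Path

def _expected_storage_files(working_dir: str):
    # Derived from the naming schemes in VectorStorage.__post_init__ and NetworkStorage.
    vector_namespaces = ["entities", "relationships", "chunks"]
    paths = []
    for ns in vector_namespaces:
        paths.append(Path(working_dir) / f"vdb_{ns}.json")
        paths.append(Path(working_dir) / f"vector_{ns}_checkpoint")
    paths.append(Path(working_dir) / "graph_chunk_entity_relation.graphml")
    return [str(p) for p in paths]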
@@ -152,7 +171,7 @@ class ExtractionExample:
        try:
            # Vector storage
            logger.info("Adding chunks to vector store...")
            await self.chunks_vdb.upsert(chunks)

            # Initialize the conversation history
            self.conversation_history = {chunk_key: [] for chunk_key in chunks.keys()}

@@ -178,14 +197,32 @@ class ExtractionExample:
        # Get the extraction results
        nodes, edges = self.extractor.get_results()

        # Store entities in the vector database and the graph database
        for node_name, node_instances in nodes.items():
            for node in node_instances:
                # Store in the vector database
                await self.entities_vdb.upsert({
                    f"entity_{node_name}": {
                        "content": f"{node_name}: {node['description']}",
                        "entity_name": node_name,
                        "entity_type": node['entity_type']
                    }
                })
                # Store in the graph database
                await self.graph_store.upsert_node(node_name, node)

        # Store relationships in the vector database and the graph database
        for (src, tgt), edge_instances in edges.items():
            for edge in edge_instances:
                # Store in the vector database
                await self.relationships_vdb.upsert({
                    f"rel_{src}_{tgt}": {
                        "content": f"{edge['description']} | {edge['keywords']}",
                        "src_id": src,
                        "tgt_id": tgt
                    }
                })
                # Store in the graph database
                await self.graph_store.upsert_edge(src, tgt, edge)

        return nodes, edges
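
To make the record shapes concrete, here is a hypothetical pair of payloads mirroring the upsert calls above; the entity and relationship values are invented, and only the key format and field names follow the code.

# Illustrative payloads (values are made up):
entity_record = {
    "entity_APPLE": {
        "content": "APPLE: a consumer electronics company headquartered in Cupertino.",
        "entity_name": "APPLE",
        "entity_type": "ORGANIZATION",
    }
}
relation_record = {
    "rel_TIM COOK_APPLE": {
        "content": "Tim Cook is the chief executive officer of Apple. | leadership, executive",
        "src_id": "TIM COOK",
        "tgt_id": "APPLE",
    }
}
# e.g. await self.entities_vdb.upsert(entity_record)
#      await self.relationships_vdb.upsert(relation_record)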
@@ -197,26 +234,39 @@ class ExtractionExample:
    async def query_knowledge_base(self, query: str, top_k: int = 5):
        """Query the knowledge base using various methods"""
        try:
            # Vector similarity search - text chunks
            chunk_results = await self.chunks_vdb.query(query, top_k=top_k)
            # Vector similarity search - entities
            entity_results = await self.entities_vdb.query(query, top_k=top_k)

            # Fetch the related text chunks
            chunk_ids = [r["id"] for r in chunk_results]
            chunks = await self.text_chunks.get_by_ids(chunk_ids)

            # Collect graph-structure information around the retrieved entities
            relevant_edges = []
            for entity in entity_results:
                if "entity_name" in entity:
                    entity_name = entity["entity_name"]
                    if await self.graph_store.has_node(entity_name):
                        edges = await self.graph_store.get_node_edges(entity_name)
                        if edges:
                            edge_data = []
                            for edge in edges:
                                edge_info = await self.graph_store.get_edge(edge[0], edge[1])
                                if edge_info:
                                    edge_data.append({
                                        "source": edge[0],
                                        "target": edge[1],
                                        "data": edge_info
                                    })
                            relevant_edges.extend(edge_data)

            return {
                "chunks": chunks,
                "entities": entity_results,
                "relationships": relevant_edges
            }

        except Exception as e:
@@ -228,30 +278,27 @@ class ExtractionExample:
        os.makedirs(export_dir, exist_ok=True)

        try:
            # Export statistics
            storage_stats = {
                "chunks": {
                    "total": len(self.text_chunks._data),
                    "vector_stats": self.chunks_vdb.get_statistics()
                },
                "entities": {
                    "vector_stats": self.entities_vdb.get_statistics()
                },
                "relationships": {
                    "vector_stats": self.relationships_vdb.get_statistics()
                },
                "graph": {
                    "total_nodes": len(list(self.graph_store._graph.nodes())),
                    "total_edges": len(list(self.graph_store._graph.edges())),
                    "node_degrees": dict(self.graph_store._graph.degree()),
                    "largest_component_size": len(self.graph_store.get_largest_connected_component())
                }
            }

            with open(os.path.join(export_dir, "storage_stats.json"), "w") as f:
                json.dump(storage_stats, f, indent=2)
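
A small sketch of reading the export back; the top-level keys mirror the dict built above, while the contents of each "vector_stats" entry depend on get_statistics() and are not specified here. The helper name and paths are illustrative.

import json
import os

def _print_export_summary(export_dir: str):
    # Top-level keys follow the structure assembled above: chunks / entities / relationships / graph.
    with open(os.path.join(export_dir, "storage_stats.json")) as f:
        stats = json.load(f)
    print("chunk count:", stats["chunks"]["total"])
    print("graph nodes:", stats["graph"]["total_nodes"])
    print("graph edges:", stats["graph"]["total_edges"])
    print("largest component:", stats["graph"]["largest_component_size"])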
@@ -299,19 +346,6 @@ async def main():
        the company's commitment to innovation and sustainability. The new iPhone
        features groundbreaking AI capabilities.
        """,
    }

    try: