@@ -1,8 +1,8 @@
import os.path

from toolbox import CatchException, update_ui
from crazy_functions.rag_essay_fns.paper_processing import ArxivPaperProcessor
import asyncio
from crazy_functions.rag_fns.arxiv_fns.paper_processing import ArxivPaperProcessor


@CatchException
def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
@@ -12,14 +12,14 @@ def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
    """
    if_project, if_arxiv = False, False
    if os.path.exists(txt):
        from crazy_functions.rag_essay_fns.document_splitter import SmartDocumentSplitter
        from crazy_functions.rag_fns.doc_fns.document_splitter import SmartDocumentSplitter
        splitter = SmartDocumentSplitter(
            char_range=(1000, 1200),
            max_workers=32  # optional; by default it is derived from the CPU core count
        )
        if_project = True
    else:
        from crazy_functions.rag_essay_fns.arxiv_splitter import SmartArxivSplitter
        from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import SmartArxivSplitter
        splitter = SmartArxivSplitter(
            char_range=(1000, 1200),
            root_dir="gpt_log/arxiv_cache"
@@ -61,23 +61,3 @@ def Rag论文对话(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
    yield from update_ui(chatbot=chatbot, history=history)

    # Interactive Q&A
    chatbot.append(["知识图谱构建完成", "您可以开始提问了。支持以下类型的问题:\n1. 论文的具体内容\n2. 研究方法的细节\n3. 实验结果分析\n4. 与其他工作的比较"])
    yield from update_ui(chatbot=chatbot, history=history)

    # Wait for user questions and answer them
    while True:
        question = yield from wait_user_input()
        if not question:
            break

        # Pick a query mode based on the question type
        if "比较" in question or "关系" in question:
            mode = "global"  # global mode for comparison-style questions
        elif "具体" in question or "细节" in question:
            mode = "local"  # local mode for questions about details
        else:
            mode = "hybrid"  # hybrid mode by default

        response = rag_handler.query(question, mode=mode)
        chatbot.append([question, response])
        yield from update_ui(chatbot=chatbot, history=history)
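The keyword-based routing in the hunk above maps a user question to a LightRAG query mode before calling `rag_handler.query`. A minimal standalone sketch of that heuristic, assuming only the three mode strings used here (the helper name `choose_query_mode` is illustrative and not part of the commit):

def choose_query_mode(question: str) -> str:
    """Map simple keyword heuristics to a LightRAG query mode (illustrative sketch)."""
    if "比较" in question or "关系" in question:   # comparison / relationship questions
        return "global"
    if "具体" in question or "细节" in question:   # questions about concrete details
        return "local"
    return "hybrid"                                # default: blend local and global context

# e.g. rag_handler.query(question, mode=choose_query_mode(question))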
@@ -1,510 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import requests
|
||||
import tarfile
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Generator, List, Tuple, Optional, Dict, Set
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArxivFragment:
|
||||
"""Arxiv论文片段数据类"""
|
||||
file_path: str
|
||||
content: str
|
||||
segment_index: int
|
||||
total_segments: int
|
||||
rel_path: str
|
||||
segment_type: str
|
||||
title: str
|
||||
abstract: str
|
||||
section: str
|
||||
is_appendix: bool
|
||||
|
||||
|
||||
class SmartArxivSplitter:
|
||||
def __init__(self,
|
||||
char_range: Tuple[int, int],
|
||||
root_dir: str = "gpt_log/arxiv_cache",
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
max_workers: int = 4):
|
||||
|
||||
self.min_chars, self.max_chars = char_range
|
||||
self.root_dir = Path(root_dir)
|
||||
self.root_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.proxies = proxies or {}
|
||||
self.max_workers = max_workers
|
||||
|
||||
# 定义特殊环境模式
|
||||
self._init_patterns()
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
def _init_patterns(self):
|
||||
"""初始化LaTeX环境和命令模式"""
|
||||
self.special_envs = {
|
||||
'math': [r'\\begin{(equation|align|gather|eqnarray)\*?}.*?\\end{\1\*?}',
|
||||
r'\$\$.*?\$\$', r'\$[^$]+\$'],
|
||||
'table': [r'\\begin{(table|tabular)\*?}.*?\\end{\1\*?}'],
|
||||
'figure': [r'\\begin{figure\*?}.*?\\end{figure\*?}'],
|
||||
'algorithm': [r'\\begin{(algorithm|algorithmic)}.*?\\end{\1}']
|
||||
}
|
||||
|
||||
self.section_patterns = [
|
||||
r'\\(sub)*section\{([^}]+)\}',
|
||||
r'\\chapter\{([^}]+)\}'
|
||||
]
|
||||
|
||||
self.include_patterns = [
|
||||
r'\\(input|include|subfile)\{([^}]+)\}'
|
||||
]
|
||||
|
||||
def _find_main_tex_file(self, directory: str) -> Optional[str]:
|
||||
"""查找主TEX文件"""
|
||||
tex_files = list(Path(directory).rglob("*.tex"))
|
||||
if not tex_files:
|
||||
return None
|
||||
|
||||
# 按以下优先级查找:
|
||||
# 1. 包含documentclass的文件
|
||||
# 2. 文件名为main.tex
|
||||
# 3. 最大的tex文件
|
||||
for tex_file in tex_files:
|
||||
try:
|
||||
content = self._read_file(str(tex_file))
|
||||
if content and r'\documentclass' in content:
|
||||
return str(tex_file)
|
||||
if tex_file.name.lower() == 'main.tex':
|
||||
return str(tex_file)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return str(max(tex_files, key=lambda x: x.stat().st_size))
|
||||
|
||||
def _resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]:
|
||||
"""递归解析tex文件中的include/input命令"""
|
||||
if processed is None:
|
||||
processed = set()
|
||||
|
||||
if tex_file in processed:
|
||||
return []
|
||||
|
||||
processed.add(tex_file)
|
||||
result = [tex_file]
|
||||
content = self._read_file(tex_file)
|
||||
|
||||
if not content:
|
||||
return result
|
||||
|
||||
base_dir = Path(tex_file).parent
|
||||
for pattern in self.include_patterns:
|
||||
for match in re.finditer(pattern, content):
|
||||
included_file = match.group(2)
|
||||
if not included_file.endswith('.tex'):
|
||||
included_file += '.tex'
|
||||
|
||||
# 构建完整路径
|
||||
full_path = str(base_dir / included_file)
|
||||
if os.path.exists(full_path) and full_path not in processed:
|
||||
result.extend(self._resolve_includes(full_path, processed))
|
||||
|
||||
return result
|
||||
|
||||
def _smart_split(self, content: str) -> List[Tuple[str, str, bool]]:
|
||||
"""智能分割TEX内容,确保在字符范围内并保持语义完整性"""
|
||||
content = self._preprocess_content(content)
|
||||
segments = []
|
||||
current_buffer = []
|
||||
current_length = 0
|
||||
current_section = "Unknown Section"
|
||||
is_appendix = False
|
||||
|
||||
# 保护特殊环境
|
||||
protected_blocks = {}
|
||||
content = self._protect_special_environments(content, protected_blocks)
|
||||
|
||||
# 按段落分割
|
||||
paragraphs = re.split(r'\n\s*\n', content)
|
||||
|
||||
for para in paragraphs:
|
||||
para = para.strip()
|
||||
if not para:
|
||||
continue
|
||||
|
||||
# 恢复特殊环境
|
||||
para = self._restore_special_environments(para, protected_blocks)
|
||||
|
||||
# 更新章节信息
|
||||
section_info = self._get_section_info(para, content)
|
||||
if section_info:
|
||||
current_section, is_appendix = section_info
|
||||
|
||||
# 判断是否是特殊环境
|
||||
if self._is_special_environment(para):
|
||||
# 处理当前缓冲区
|
||||
if current_buffer:
|
||||
segments.append((
|
||||
'\n'.join(current_buffer),
|
||||
current_section,
|
||||
is_appendix
|
||||
))
|
||||
current_buffer = []
|
||||
current_length = 0
|
||||
|
||||
# 添加特殊环境作为独立片段
|
||||
segments.append((para, current_section, is_appendix))
|
||||
continue
|
||||
|
||||
# 处理普通段落
|
||||
sentences = self._split_into_sentences(para)
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
sent_length = len(sentence)
|
||||
new_length = current_length + sent_length + (1 if current_buffer else 0)
|
||||
|
||||
if new_length <= self.max_chars:
|
||||
current_buffer.append(sentence)
|
||||
current_length = new_length
|
||||
else:
|
||||
# 如果当前缓冲区达到最小长度要求
|
||||
if current_length >= self.min_chars:
|
||||
segments.append((
|
||||
'\n'.join(current_buffer),
|
||||
current_section,
|
||||
is_appendix
|
||||
))
|
||||
current_buffer = [sentence]
|
||||
current_length = sent_length
|
||||
else:
|
||||
# 尝试将过长的句子分割
|
||||
split_sentences = self._split_long_sentence(sentence)
|
||||
for split_sent in split_sentences:
|
||||
if current_length + len(split_sent) <= self.max_chars:
|
||||
current_buffer.append(split_sent)
|
||||
current_length += len(split_sent) + 1
|
||||
else:
|
||||
segments.append((
|
||||
'\n'.join(current_buffer),
|
||||
current_section,
|
||||
is_appendix
|
||||
))
|
||||
current_buffer = [split_sent]
|
||||
current_length = len(split_sent)
|
||||
|
||||
# 处理剩余的缓冲区
|
||||
if current_buffer:
|
||||
segments.append((
|
||||
'\n'.join(current_buffer),
|
||||
current_section,
|
||||
is_appendix
|
||||
))
|
||||
|
||||
return segments
|
||||
|
||||
def _split_into_sentences(self, text: str) -> List[str]:
|
||||
"""将文本分割成句子"""
|
||||
return re.split(r'(?<=[.!?。!?])\s+', text)
|
||||
|
||||
def _split_long_sentence(self, sentence: str) -> List[str]:
|
||||
"""智能分割过长的句子"""
|
||||
if len(sentence) <= self.max_chars:
|
||||
return [sentence]
|
||||
|
||||
result = []
|
||||
while sentence:
|
||||
# 在最大长度位置寻找合适的分割点
|
||||
split_pos = self._find_split_position(sentence[:self.max_chars])
|
||||
if split_pos <= 0:
|
||||
split_pos = self.max_chars
|
||||
|
||||
result.append(sentence[:split_pos])
|
||||
sentence = sentence[split_pos:].strip()
|
||||
|
||||
return result
|
||||
|
||||
def _find_split_position(self, text: str) -> int:
|
||||
"""找到合适的句子分割位置"""
|
||||
# 优先在标点符号处分割
|
||||
punctuation_match = re.search(r'[,,;;]\s*', text[::-1])
|
||||
if punctuation_match:
|
||||
return len(text) - punctuation_match.end()
|
||||
|
||||
# 其次在空白字符处分割
|
||||
space_match = re.search(r'\s+', text[::-1])
|
||||
if space_match:
|
||||
return len(text) - space_match.end()
|
||||
|
||||
return -1
|
||||
|
||||
def _protect_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str:
|
||||
"""保护特殊环境内容"""
|
||||
for env_type, patterns in self.special_envs.items():
|
||||
for pattern in patterns:
|
||||
content = re.sub(
|
||||
pattern,
|
||||
lambda m: self._store_protected_block(m.group(0), protected_blocks),
|
||||
content,
|
||||
flags=re.DOTALL
|
||||
)
|
||||
return content
|
||||
|
||||
def _store_protected_block(self, content: str, protected_blocks: Dict[str, str]) -> str:
|
||||
"""存储受保护的内容块"""
|
||||
placeholder = f"PROTECTED_{len(protected_blocks)}"
|
||||
protected_blocks[placeholder] = content
|
||||
return placeholder
|
||||
|
||||
def _restore_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str:
|
||||
"""恢复特殊环境内容"""
|
||||
for placeholder, original in protected_blocks.items():
|
||||
content = content.replace(placeholder, original)
|
||||
return content
|
||||
|
||||
def _is_special_environment(self, text: str) -> bool:
|
||||
"""判断是否是特殊环境"""
|
||||
for patterns in self.special_envs.values():
|
||||
for pattern in patterns:
|
||||
if re.search(pattern, text, re.DOTALL):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _preprocess_content(self, content: str) -> str:
|
||||
"""预处理TEX内容"""
|
||||
# 移除注释
|
||||
content = re.sub(r'(?m)%.*$', '', content)
|
||||
# 规范化空白字符
|
||||
content = re.sub(r'\s+', ' ', content)
|
||||
content = re.sub(r'\n\s*\n', '\n\n', content)
|
||||
# 移除不必要的命令
|
||||
content = re.sub(r'\\(label|ref|cite)\{[^}]*\}', '', content)
|
||||
return content.strip()
|
||||
|
||||
def process(self, arxiv_id_or_url: str) -> Generator[ArxivFragment, None, None]:
|
||||
"""处理单篇arxiv论文"""
|
||||
try:
|
||||
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
|
||||
paper_dir = self._download_and_extract(arxiv_id)
|
||||
|
||||
# 查找主tex文件
|
||||
main_tex = self._find_main_tex_file(paper_dir)
|
||||
if not main_tex:
|
||||
raise RuntimeError(f"No main tex file found in {paper_dir}")
|
||||
|
||||
# 获取所有相关tex文件
|
||||
tex_files = self._resolve_includes(main_tex)
|
||||
|
||||
# 处理所有tex文件
|
||||
fragments = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
future_to_file = {
|
||||
executor.submit(self._process_single_tex, file_path): file_path
|
||||
for file_path in tex_files
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_file):
|
||||
try:
|
||||
fragments.extend(future.result())
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file: {e}")
|
||||
|
||||
# 重新计算片段索引
|
||||
fragments.sort(key=lambda x: (x.rel_path, x.segment_index))
|
||||
total_fragments = len(fragments)
|
||||
|
||||
for i, fragment in enumerate(fragments):
|
||||
fragment.segment_index = i
|
||||
fragment.total_segments = total_fragments
|
||||
yield fragment
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing paper {arxiv_id_or_url}: {e}")
|
||||
raise RuntimeError(f"Failed to process paper: {str(e)}")
|
||||
|
||||
def _normalize_arxiv_id(self, input_str: str) -> str:
|
||||
"""规范化arxiv ID"""
|
||||
if input_str.startswith('https://arxiv.org/'):
|
||||
if '/pdf/' in input_str:
|
||||
return input_str.split('/pdf/')[-1].split('v')[0]
|
||||
return input_str.split('/abs/')[-1].split('v')[0]
|
||||
return input_str.split('v')[0]
|
||||
|
||||
def _download_and_extract(self, arxiv_id: str) -> str:
|
||||
"""下载并解压arxiv论文源码"""
|
||||
paper_dir = self.root_dir / arxiv_id
|
||||
tar_path = paper_dir / f"{arxiv_id}.tar.gz"
|
||||
|
||||
# 检查缓存
|
||||
if paper_dir.exists() and any(paper_dir.iterdir()):
|
||||
logging.info(f"Using cached version for {arxiv_id}")
|
||||
return str(paper_dir)
|
||||
|
||||
paper_dir.mkdir(exist_ok=True)
|
||||
|
||||
urls = [
|
||||
f"https://arxiv.org/src/{arxiv_id}",
|
||||
f"https://arxiv.org/e-print/{arxiv_id}"
|
||||
]
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
logging.info(f"Downloading from {url}")
|
||||
response = requests.get(url, proxies=self.proxies)
|
||||
if response.status_code == 200:
|
||||
tar_path.write_bytes(response.content)
|
||||
with tarfile.open(tar_path, 'r:gz') as tar:
|
||||
tar.extractall(path=paper_dir)
|
||||
return str(paper_dir)
|
||||
except Exception as e:
|
||||
logging.warning(f"Download failed for {url}: {e}")
|
||||
continue
|
||||
|
||||
raise RuntimeError(f"Failed to download paper {arxiv_id}")
|
||||
|
||||
def _read_file(self, file_path: str) -> Optional[str]:
|
||||
"""使用多种编码尝试读取文件"""
|
||||
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
|
||||
for encoding in encodings:
|
||||
try:
|
||||
with open(file_path, 'r', encoding=encoding) as f:
|
||||
return f.read()
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
logging.warning(f"Failed to read file {file_path} with all encodings")
|
||||
return None
|
||||
|
||||
def _extract_metadata(self, content: str) -> Tuple[str, str]:
|
||||
"""提取论文标题和摘要"""
|
||||
title = ""
|
||||
abstract = ""
|
||||
|
||||
# 提取标题
|
||||
title_patterns = [
|
||||
r'\\title{([^}]*)}',
|
||||
r'\\Title{([^}]*)}'
|
||||
]
|
||||
for pattern in title_patterns:
|
||||
match = re.search(pattern, content)
|
||||
if match:
|
||||
title = match.group(1)
|
||||
title = re.sub(r'\\[a-zA-Z]+{([^}]*)}', r'\1', title)
|
||||
break
|
||||
|
||||
# 提取摘要
|
||||
abstract_patterns = [
|
||||
r'\\begin{abstract}(.*?)\\end{abstract}',
|
||||
r'\\abstract{([^}]*)}'
|
||||
]
|
||||
for pattern in abstract_patterns:
|
||||
match = re.search(pattern, content, re.DOTALL)
|
||||
if match:
|
||||
abstract = match.group(1).strip()
|
||||
abstract = re.sub(r'\\[a-zA-Z]+{([^}]*)}', r'\1', abstract)
|
||||
break
|
||||
|
||||
return title.strip(), abstract.strip()
|
||||
|
||||
def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, bool]]:
|
||||
"""获取段落所属的章节信息"""
|
||||
section = "Unknown Section"
|
||||
is_appendix = False
|
||||
|
||||
# 查找所有章节标记
|
||||
all_sections = []
|
||||
for pattern in self.section_patterns:
|
||||
for match in re.finditer(pattern, content):
|
||||
all_sections.append((match.start(), match.group(2)))
|
||||
|
||||
# 查找appendix标记
|
||||
appendix_pos = content.find(r'\appendix')
|
||||
|
||||
# 确定当前章节
|
||||
para_pos = content.find(para)
|
||||
if para_pos >= 0:
|
||||
current_section = None
|
||||
for sec_pos, sec_title in sorted(all_sections):
|
||||
if sec_pos > para_pos:
|
||||
break
|
||||
current_section = sec_title
|
||||
|
||||
if current_section:
|
||||
section = current_section
|
||||
is_appendix = appendix_pos >= 0 and para_pos > appendix_pos
|
||||
|
||||
return section, is_appendix
|
||||
|
||||
return None
|
||||
|
||||
def _process_single_tex(self, file_path: str) -> List[ArxivFragment]:
|
||||
"""处理单个TEX文件"""
|
||||
try:
|
||||
content = self._read_file(file_path)
|
||||
if not content:
|
||||
return []
|
||||
|
||||
# 提取元数据
|
||||
is_main = r'\documentclass' in content
|
||||
title = ""
|
||||
abstract = ""
|
||||
if is_main:
|
||||
title, abstract = self._extract_metadata(content)
|
||||
|
||||
# 智能分割内容
|
||||
segments = self._smart_split(content)
|
||||
fragments = []
|
||||
|
||||
for i, (segment_content, section, is_appendix) in enumerate(segments):
|
||||
if segment_content.strip():
|
||||
segment_type = 'text'
|
||||
for env_type, patterns in self.special_envs.items():
|
||||
if any(re.search(pattern, segment_content, re.DOTALL)
|
||||
for pattern in patterns):
|
||||
segment_type = env_type
|
||||
break
|
||||
|
||||
fragments.append(ArxivFragment(
|
||||
file_path=file_path,
|
||||
content=segment_content,
|
||||
segment_index=i,
|
||||
total_segments=len(segments),
|
||||
rel_path=os.path.relpath(file_path, str(self.root_dir)),
|
||||
segment_type=segment_type,
|
||||
title=title,
|
||||
abstract=abstract,
|
||||
section=section,
|
||||
is_appendix=is_appendix
|
||||
))
|
||||
|
||||
return fragments
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file {file_path}: {e}")
|
||||
return []
|
||||
|
||||
def main():
|
||||
"""使用示例"""
|
||||
# 创建分割器实例
|
||||
splitter = SmartArxivSplitter(
|
||||
char_range=(1000, 1200),
|
||||
root_dir="gpt_log/arxiv_cache"
|
||||
)
|
||||
|
||||
# 处理论文
|
||||
for fragment in splitter.process("2411.03663"):
|
||||
print(f"Segment {fragment.segment_index + 1}/{fragment.total_segments}")
|
||||
print(f"Length: {len(fragment.content)}")
|
||||
print(f"Section: {fragment.section}")
|
||||
print(f"Title: {fragment.file_path}")
|
||||
|
||||
print(fragment.content)
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,311 +0,0 @@
|
||||
from typing import Tuple, Optional, Generator, List
|
||||
from toolbox import update_ui, update_ui_lastest_msg, get_conf
|
||||
import os, tarfile, requests, time, re
|
||||
class ArxivPaperProcessor:
|
||||
"""Arxiv论文处理器类"""
|
||||
|
||||
def __init__(self):
|
||||
self.supported_encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
|
||||
self.arxiv_cache_dir = get_conf("ARXIV_CACHE_DIR")
|
||||
|
||||
def download_and_extract(self, txt: str, chatbot, history) -> Generator[Optional[Tuple[str, str]], None, None]:
|
||||
"""
|
||||
Step 1: 下载和提取arxiv论文
|
||||
返回: 生成器: (project_folder, arxiv_id)
|
||||
"""
|
||||
try:
|
||||
if txt == "":
|
||||
chatbot.append(("", "请输入arxiv论文链接或ID"))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
project_folder, arxiv_id = self.arxiv_download(txt, chatbot, history)
|
||||
if project_folder is None or arxiv_id is None:
|
||||
return
|
||||
|
||||
if not os.path.exists(project_folder):
|
||||
chatbot.append((txt, f"找不到项目文件夹: {project_folder}"))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 期望的返回值
|
||||
yield project_folder, arxiv_id
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# yield from update_ui_lastest_msg(
|
||||
# "下载失败,请手动下载latex源码:请前往arxiv打开此论文下载页面,点other Formats,然后download source。",
|
||||
# chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
def arxiv_download(self, txt: str, chatbot, history) -> Tuple[str, str]:
|
||||
"""
|
||||
下载arxiv论文并提取
|
||||
返回: (project_folder, arxiv_id)
|
||||
"""
|
||||
def is_float(s: str) -> bool:
|
||||
try:
|
||||
float(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if txt.startswith('https://arxiv.org/pdf/'):
|
||||
arxiv_id = txt.split('/')[-1] # 2402.14207v2.pdf
|
||||
txt = arxiv_id.split('v')[0] # 2402.14207
|
||||
|
||||
if ('.' in txt) and ('/' not in txt) and is_float(txt): # is arxiv ID
|
||||
txt = 'https://arxiv.org/abs/' + txt.strip()
|
||||
if ('.' in txt) and ('/' not in txt) and is_float(txt[:10]): # is arxiv ID
|
||||
txt = 'https://arxiv.org/abs/' + txt[:10]
|
||||
|
||||
if not txt.startswith('https://arxiv.org'):
|
||||
chatbot.append((txt, "不是有效的arxiv链接或ID"))
|
||||
# yield from update_ui(chatbot=chatbot, history=history)
|
||||
return None, None # 返回两个值,即使其中一个为None
|
||||
|
||||
chatbot.append([f"检测到arxiv文档连接", '尝试下载 ...'])
|
||||
# yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
url_ = txt # https://arxiv.org/abs/1707.06690
|
||||
|
||||
if not txt.startswith('https://arxiv.org/abs/'):
|
||||
msg = f"解析arxiv网址失败, 期望格式例如: https://arxiv.org/abs/1707.06690。实际得到格式: {url_}。"
|
||||
# yield from update_ui_lastest_msg(msg, chatbot=chatbot, history=history) # 刷新界面
|
||||
return None, None # 返回两个值,即使其中一个为None
|
||||
|
||||
arxiv_id = url_.split('/')[-1].split('v')[0]
|
||||
|
||||
dst = os.path.join(self.arxiv_cache_dir, arxiv_id, f'{arxiv_id}.tar.gz')
|
||||
project_folder = os.path.join(self.arxiv_cache_dir, arxiv_id)
|
||||
|
||||
success = self.download_arxiv_paper(url_, dst, chatbot, history)
|
||||
|
||||
# if os.path.exists(dst) and get_conf('allow_cache'):
|
||||
# # yield from update_ui_lastest_msg(f"调用缓存 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||
# success = True
|
||||
# else:
|
||||
# # yield from update_ui_lastest_msg(f"开始下载 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||
# success = self.download_arxiv_paper(url_, dst, chatbot, history)
|
||||
# # yield from update_ui_lastest_msg(f"下载完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
if not success:
|
||||
# chatbot.append([f"下载失败 {arxiv_id}", ""])
|
||||
# yield from update_ui(chatbot=chatbot, history=history)
|
||||
raise tarfile.ReadError(f"论文下载失败 {arxiv_id}")
|
||||
|
||||
# yield from update_ui_lastest_msg(f"开始解压 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||
extract_dst = self.extract_tar_file(dst, project_folder, chatbot, history)
|
||||
# yield from update_ui_lastest_msg(f"解压完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
return extract_dst, arxiv_id
|
||||
|
||||
def download_arxiv_paper(self, url_: str, dst: str, chatbot, history) -> bool:
|
||||
"""下载arxiv论文"""
|
||||
try:
|
||||
proxies = get_conf('proxies')
|
||||
for url_tar in [url_.replace('/abs/', '/src/'), url_.replace('/abs/', '/e-print/')]:
|
||||
r = requests.get(url_tar, proxies=proxies)
|
||||
if r.status_code == 200:
|
||||
with open(dst, 'wb+') as f:
|
||||
f.write(r.content)
|
||||
return True
|
||||
return False
|
||||
except requests.RequestException as e:
|
||||
# chatbot.append((f"下载失败 {url_}", str(e)))
|
||||
# yield from update_ui(chatbot=chatbot, history=history)
|
||||
return False
|
||||
|
||||
def extract_tar_file(self, file_path: str, dest_dir: str, chatbot, history) -> str:
|
||||
"""解压arxiv论文"""
|
||||
try:
|
||||
with tarfile.open(file_path, 'r:gz') as tar:
|
||||
tar.extractall(path=dest_dir)
|
||||
return dest_dir
|
||||
except tarfile.ReadError as e:
|
||||
chatbot.append((f"解压失败 {file_path}", str(e)))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
raise e
|
||||
|
||||
def find_main_tex_file(self, tex_files: list) -> str:
|
||||
"""查找主TEX文件"""
|
||||
for tex_file in tex_files:
|
||||
with open(tex_file, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
if r'\documentclass' in content:
|
||||
return tex_file
|
||||
return max(tex_files, key=lambda x: os.path.getsize(x))
|
||||
|
||||
def read_file_with_encoding(self, file_path: str) -> Optional[str]:
|
||||
"""使用多种编码尝试读取文件"""
|
||||
for encoding in self.supported_encodings:
|
||||
try:
|
||||
with open(file_path, 'r', encoding=encoding) as f:
|
||||
return f.read()
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
return None
|
||||
|
||||
def process_tex_content(self, content: str, base_path: str, processed_files=None) -> str:
|
||||
"""处理TEX内容,包括递归处理包含的文件"""
|
||||
if processed_files is None:
|
||||
processed_files = set()
|
||||
|
||||
include_patterns = [
|
||||
r'\\input{([^}]+)}',
|
||||
r'\\include{([^}]+)}',
|
||||
r'\\subfile{([^}]+)}',
|
||||
r'\\input\s+([^\s{]+)',
|
||||
]
|
||||
|
||||
for pattern in include_patterns:
|
||||
matches = re.finditer(pattern, content)
|
||||
for match in matches:
|
||||
include_file = match.group(1)
|
||||
if not include_file.endswith('.tex'):
|
||||
include_file += '.tex'
|
||||
|
||||
include_path = os.path.join(base_path, include_file)
|
||||
include_path = os.path.normpath(include_path)
|
||||
|
||||
if include_path in processed_files:
|
||||
continue
|
||||
processed_files.add(include_path)
|
||||
|
||||
if os.path.exists(include_path):
|
||||
included_content = self.read_file_with_encoding(include_path)
|
||||
if included_content:
|
||||
included_content = self.process_tex_content(
|
||||
included_content,
|
||||
os.path.dirname(include_path),
|
||||
processed_files
|
||||
)
|
||||
content = content.replace(match.group(0), included_content)
|
||||
|
||||
return content
|
||||
|
||||
def merge_tex_files(self, folder_path: str, chatbot, history) -> Optional[str]:
|
||||
"""
|
||||
Step 2: 合并TEX文件
|
||||
返回: 合并后的内容
|
||||
"""
|
||||
try:
|
||||
tex_files = []
|
||||
for root, _, files in os.walk(folder_path):
|
||||
tex_files.extend([os.path.join(root, f) for f in files if f.endswith('.tex')])
|
||||
|
||||
if not tex_files:
|
||||
chatbot.append(("", "未找到任何TEX文件"))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return None
|
||||
|
||||
main_tex_file = self.find_main_tex_file(tex_files)
|
||||
chatbot.append(("", f"找到主TEX文件:{os.path.basename(main_tex_file)}"))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
tex_content = self.read_file_with_encoding(main_tex_file)
|
||||
if tex_content is None:
|
||||
chatbot.append(("", "无法读取TEX文件,可能是编码问题"))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return None
|
||||
|
||||
full_content = self.process_tex_content(
|
||||
tex_content,
|
||||
os.path.dirname(main_tex_file)
|
||||
)
|
||||
|
||||
cleaned_content = self.clean_tex_content(full_content)
|
||||
|
||||
chatbot.append(("",
|
||||
f"成功处理所有TEX文件:\n"
|
||||
f"- 原始内容大小:{len(full_content)}字符\n"
|
||||
f"- 清理后内容大小:{len(cleaned_content)}字符"
|
||||
))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 添加标题和摘要提取
|
||||
title = ""
|
||||
abstract = ""
|
||||
if tex_content:
|
||||
# 提取标题
|
||||
title_match = re.search(r'\\title{([^}]*)}', tex_content)
|
||||
if title_match:
|
||||
title = title_match.group(1)
|
||||
|
||||
# 提取摘要
|
||||
abstract_match = re.search(r'\\begin{abstract}(.*?)\\end{abstract}',
|
||||
tex_content, re.DOTALL)
|
||||
if abstract_match:
|
||||
abstract = abstract_match.group(1)
|
||||
|
||||
# 按token限制分段
|
||||
def split_by_token_limit(text: str, token_limit: int = 1024) -> List[str]:
|
||||
segments = []
|
||||
current_segment = []
|
||||
current_tokens = 0
|
||||
|
||||
for line in text.split('\n'):
|
||||
line_tokens = len(line.split())
|
||||
if current_tokens + line_tokens > token_limit:
|
||||
segments.append('\n'.join(current_segment))
|
||||
current_segment = [line]
|
||||
current_tokens = line_tokens
|
||||
else:
|
||||
current_segment.append(line)
|
||||
current_tokens += line_tokens
|
||||
|
||||
if current_segment:
|
||||
segments.append('\n'.join(current_segment))
|
||||
|
||||
return segments
|
||||
|
||||
text_segments = split_by_token_limit(cleaned_content)
|
||||
|
||||
return {
|
||||
'title': title,
|
||||
'abstract': abstract,
|
||||
'segments': text_segments
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
chatbot.append(("", f"处理TEX文件时发生错误:{str(e)}"))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def clean_tex_content(content: str) -> str:
|
||||
"""清理TEX内容"""
|
||||
content = re.sub(r'(?m)%.*$', '', content) # 移除注释
|
||||
content = re.sub(r'\\cite{[^}]*}', '', content) # 移除引用
|
||||
content = re.sub(r'\\label{[^}]*}', '', content) # 移除标签
|
||||
content = re.sub(r'\s+', ' ', content) # 规范化空白
|
||||
return content.strip()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试 arxiv_download 函数
|
||||
processor = ArxivPaperProcessor()
|
||||
chatbot = []
|
||||
history = []
|
||||
|
||||
# 测试不同格式的输入
|
||||
test_inputs = [
|
||||
"https://arxiv.org/abs/2402.14207", # 标准格式
|
||||
"https://arxiv.org/pdf/2402.14207.pdf", # PDF链接格式
|
||||
"2402.14207", # 纯ID格式
|
||||
"2402.14207v1", # 带版本号的ID格式
|
||||
"https://invalid.url", # 无效URL测试
|
||||
]
|
||||
|
||||
for input_url in test_inputs:
|
||||
print(f"\n测试输入: {input_url}")
|
||||
try:
|
||||
project_folder, arxiv_id = processor.arxiv_download(input_url, chatbot, history)
|
||||
if project_folder and arxiv_id:
|
||||
print(f"下载成功:")
|
||||
print(f"- 项目文件夹: {project_folder}")
|
||||
print(f"- Arxiv ID: {arxiv_id}")
|
||||
print(f"- 文件夹是否存在: {os.path.exists(project_folder)}")
|
||||
else:
|
||||
print("下载失败: 返回值为 None")
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
@@ -1,164 +0,0 @@
|
||||
from typing import Dict, List, Optional
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
import numpy as np
|
||||
import os
|
||||
from toolbox import get_conf
|
||||
import openai
|
||||
|
||||
class RagHandler:
|
||||
def __init__(self):
|
||||
# 初始化工作目录
|
||||
self.working_dir = os.path.join(get_conf('ARXIV_CACHE_DIR'), 'rag_cache')
|
||||
if not os.path.exists(self.working_dir):
|
||||
os.makedirs(self.working_dir)
|
||||
|
||||
# 初始化 LightRAG
|
||||
self.rag = LightRAG(
|
||||
working_dir=self.working_dir,
|
||||
llm_model_func=self._llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=1536, # OpenAI embedding 维度
|
||||
max_token_size=8192,
|
||||
func=self._embedding_func,
|
||||
),
|
||||
)
|
||||
|
||||
async def _llm_model_func(self, prompt: str, system_prompt: str = None,
|
||||
history_messages: List = None, **kwargs) -> str:
|
||||
"""LLM 模型函数"""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
if history_messages:
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
temperature=kwargs.get("temperature", 0),
|
||||
max_tokens=kwargs.get("max_tokens", 1000)
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
async def _embedding_func(self, texts: List[str]) -> np.ndarray:
|
||||
"""Embedding 函数"""
|
||||
response = await openai.Embedding.acreate(
|
||||
model="text-embedding-ada-002",
|
||||
input=texts
|
||||
)
|
||||
embeddings = [item["embedding"] for item in response["data"]]
|
||||
return np.array(embeddings)
|
||||
|
||||
def process_paper_content(self, paper_content: Dict) -> None:
|
||||
"""处理论文内容,构建知识图谱"""
|
||||
# 处理标题和摘要
|
||||
content_list = []
|
||||
if paper_content['title']:
|
||||
content_list.append(f"Title: {paper_content['title']}")
|
||||
if paper_content['abstract']:
|
||||
content_list.append(f"Abstract: {paper_content['abstract']}")
|
||||
|
||||
# 添加分段内容
|
||||
content_list.extend(paper_content['segments'])
|
||||
|
||||
# 插入到 RAG 系统
|
||||
self.rag.insert(content_list)
|
||||
|
||||
def query(self, question: str, mode: str = "hybrid") -> str:
|
||||
"""查询论文内容
|
||||
mode: 查询模式,可选 naive/local/global/hybrid
|
||||
"""
|
||||
try:
|
||||
response = self.rag.query(
|
||||
question,
|
||||
param=QueryParam(
|
||||
mode=mode,
|
||||
top_k=5, # 返回相关度最高的5个结果
|
||||
max_token_for_text_unit=2048, # 每个文本单元的最大token数
|
||||
response_type="detailed" # 返回详细回答
|
||||
)
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
return f"查询出错: {str(e)}"
|
||||
|
||||
|
||||
class RagHandler:
|
||||
def __init__(self):
|
||||
# 初始化工作目录
|
||||
self.working_dir = os.path.join(get_conf('ARXIV_CACHE_DIR'), 'rag_cache')
|
||||
if not os.path.exists(self.working_dir):
|
||||
os.makedirs(self.working_dir)
|
||||
|
||||
# 初始化 LightRAG
|
||||
self.rag = LightRAG(
|
||||
working_dir=self.working_dir,
|
||||
llm_model_func=self._llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=1536, # OpenAI embedding 维度
|
||||
max_token_size=8192,
|
||||
func=self._embedding_func,
|
||||
),
|
||||
)
|
||||
|
||||
async def _llm_model_func(self, prompt: str, system_prompt: str = None,
|
||||
history_messages: List = None, **kwargs) -> str:
|
||||
"""LLM 模型函数"""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
if history_messages:
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
temperature=kwargs.get("temperature", 0),
|
||||
max_tokens=kwargs.get("max_tokens", 1000)
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
async def _embedding_func(self, texts: List[str]) -> np.ndarray:
|
||||
"""Embedding 函数"""
|
||||
response = await openai.Embedding.acreate(
|
||||
model="text-embedding-ada-002",
|
||||
input=texts
|
||||
)
|
||||
embeddings = [item["embedding"] for item in response["data"]]
|
||||
return np.array(embeddings)
|
||||
|
||||
def process_paper_content(self, paper_content: Dict) -> None:
|
||||
"""处理论文内容,构建知识图谱"""
|
||||
# 处理标题和摘要
|
||||
content_list = []
|
||||
if paper_content['title']:
|
||||
content_list.append(f"Title: {paper_content['title']}")
|
||||
if paper_content['abstract']:
|
||||
content_list.append(f"Abstract: {paper_content['abstract']}")
|
||||
|
||||
# 添加分段内容
|
||||
content_list.extend(paper_content['segments'])
|
||||
|
||||
# 插入到 RAG 系统
|
||||
self.rag.insert(content_list)
|
||||
|
||||
def query(self, question: str, mode: str = "hybrid") -> str:
|
||||
"""查询论文内容
|
||||
mode: 查询模式,可选 naive/local/global/hybrid
|
||||
"""
|
||||
try:
|
||||
response = self.rag.query(
|
||||
question,
|
||||
param=QueryParam(
|
||||
mode=mode,
|
||||
top_k=5, # 返回相关度最高的5个结果
|
||||
max_token_for_text_unit=2048, # 每个文本单元的最大token数
|
||||
response_type="detailed" # 返回详细回答
|
||||
)
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
return f"查询出错: {str(e)}"
|
||||
0    crazy_functions/rag_fns/arxiv_fns/__init__.py    Normal file
111  crazy_functions/rag_fns/arxiv_fns/arxiv_downloader.py    Normal file
@@ -0,0 +1,111 @@
|
||||
import logging
|
||||
import requests
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
|
||||
class ArxivDownloader:
|
||||
"""用于下载arXiv论文源码的下载器"""
|
||||
|
||||
def __init__(self, root_dir: str = "./papers", proxies: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
初始化下载器
|
||||
|
||||
Args:
|
||||
root_dir: 保存下载文件的根目录
|
||||
proxies: 代理服务器设置,例如 {"http": "http://proxy:port", "https": "https://proxy:port"}
|
||||
"""
|
||||
self.root_dir = Path(root_dir)
|
||||
self.root_dir.mkdir(exist_ok=True)
|
||||
self.proxies = proxies
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
def _download_and_extract(self, arxiv_id: str) -> str:
|
||||
"""
|
||||
下载并解压arxiv论文源码
|
||||
|
||||
Args:
|
||||
arxiv_id: arXiv论文ID,例如"2103.00020"
|
||||
|
||||
Returns:
|
||||
str: 解压后的文件目录路径
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当下载失败时抛出
|
||||
"""
|
||||
paper_dir = self.root_dir / arxiv_id
|
||||
tar_path = paper_dir / f"{arxiv_id}.tar.gz"
|
||||
|
||||
# 检查缓存
|
||||
if paper_dir.exists() and any(paper_dir.iterdir()):
|
||||
logging.info(f"Using cached version for {arxiv_id}")
|
||||
return str(paper_dir)
|
||||
|
||||
paper_dir.mkdir(exist_ok=True)
|
||||
|
||||
urls = [
|
||||
f"https://arxiv.org/src/{arxiv_id}",
|
||||
f"https://arxiv.org/e-print/{arxiv_id}"
|
||||
]
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
logging.info(f"Downloading from {url}")
|
||||
response = requests.get(url, proxies=self.proxies)
|
||||
if response.status_code == 200:
|
||||
tar_path.write_bytes(response.content)
|
||||
with tarfile.open(tar_path, 'r:gz') as tar:
|
||||
tar.extractall(path=paper_dir)
|
||||
return str(paper_dir)
|
||||
except Exception as e:
|
||||
logging.warning(f"Download failed for {url}: {e}")
|
||||
continue
|
||||
|
||||
raise RuntimeError(f"Failed to download paper {arxiv_id}")
|
||||
|
||||
def download_paper(self, arxiv_id: str) -> str:
|
||||
"""
|
||||
下载指定的arXiv论文
|
||||
|
||||
Args:
|
||||
arxiv_id: arXiv论文ID
|
||||
|
||||
Returns:
|
||||
str: 论文文件所在的目录路径
|
||||
"""
|
||||
return self._download_and_extract(arxiv_id)
|
||||
|
||||
def main():
|
||||
"""测试下载功能"""
|
||||
# 配置代理(如果需要)
|
||||
proxies = {
|
||||
"http": "http://your-proxy:port",
|
||||
"https": "https://your-proxy:port"
|
||||
}
|
||||
|
||||
# 创建下载器实例(如果不需要代理,可以不传入proxies参数)
|
||||
downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)
|
||||
|
||||
# 测试下载一篇论文(这里使用一个示例ID)
|
||||
try:
|
||||
paper_id = "2103.00020" # 这是一个示例ID
|
||||
paper_dir = downloader.download_paper(paper_id)
|
||||
print(f"Successfully downloaded paper to: {paper_dir}")
|
||||
|
||||
# 检查下载的文件
|
||||
paper_path = Path(paper_dir)
|
||||
if paper_path.exists():
|
||||
print("Downloaded files:")
|
||||
for file in paper_path.rglob("*"):
|
||||
if file.is_file():
|
||||
print(f"- {file.relative_to(paper_path)}")
|
||||
except Exception as e:
|
||||
print(f"Error downloading paper: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
55   crazy_functions/rag_fns/arxiv_fns/arxiv_fragment.py    Normal file
@@ -0,0 +1,55 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class ArxivFragment:
|
||||
"""Arxiv论文片段数据类"""
|
||||
file_path: str # 文件路径
|
||||
content: str # 内容
|
||||
segment_index: int # 片段索引
|
||||
total_segments: int # 总片段数
|
||||
rel_path: str # 相对路径
|
||||
segment_type: str # 片段类型(text/math/table/figure等)
|
||||
title: str # 论文标题
|
||||
abstract: str # 论文摘要
|
||||
section: str # 所属章节
|
||||
is_appendix: bool # 是否是附录
|
||||
importance: float = 1.0 # 重要性得分
|
||||
|
||||
@staticmethod
|
||||
def merge_segments(seg1: 'ArxivFragment', seg2: 'ArxivFragment') -> 'ArxivFragment':
|
||||
"""
|
||||
合并两个片段的静态方法
|
||||
|
||||
Args:
|
||||
seg1: 第一个片段
|
||||
seg2: 第二个片段
|
||||
|
||||
Returns:
|
||||
ArxivFragment: 合并后的片段
|
||||
"""
|
||||
# 合并内容
|
||||
merged_content = f"{seg1.content}\n{seg2.content}"
|
||||
|
||||
# 确定合并后的类型
|
||||
def _merge_segment_type(type1: str, type2: str) -> str:
|
||||
if type1 == type2:
|
||||
return type1
|
||||
if type1 == 'text':
|
||||
return type2
|
||||
if type2 == 'text':
|
||||
return type1
|
||||
return 'mixed'
|
||||
|
||||
return ArxivFragment(
|
||||
file_path=seg1.file_path,
|
||||
content=merged_content,
|
||||
segment_index=seg1.segment_index,
|
||||
total_segments=seg1.total_segments - 1,
|
||||
rel_path=seg1.rel_path,
|
||||
segment_type=_merge_segment_type(seg1.segment_type, seg2.segment_type),
|
||||
title=seg1.title,
|
||||
abstract=seg1.abstract,
|
||||
section=seg1.section,
|
||||
is_appendix=seg1.is_appendix,
|
||||
importance=max(seg1.importance, seg2.importance)
|
||||
)
|
||||
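For orientation, a minimal sketch of merging two adjacent fragments with the static helper above; the field values are placeholders chosen only to match the dataclass fields in this file:

from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment

# Two adjacent fragments from the same file (placeholder values).
a = ArxivFragment("main.tex", "Intro text ...", 0, 10, "main.tex", "text",
                  "Some Title", "Some abstract", "Introduction", False, importance=1.0)
b = ArxivFragment("main.tex", "\\begin{equation}E=mc^2\\end{equation}", 1, 10, "main.tex", "math",
                  "Some Title", "Some abstract", "Introduction", False, importance=1.2)

merged = ArxivFragment.merge_segments(a, b)
# The contents are concatenated, the type falls back to the non-text side ("math"),
# and the higher importance score is kept.
print(merged.segment_type, merged.importance)   # -> math 1.2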
449  crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py    Normal file
@@ -0,0 +1,449 @@
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import requests
|
||||
import tarfile
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Tuple, Optional, Dict, Set
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from crazy_functions.rag_fns.arxiv_fns.tex_processor import TexProcessor
|
||||
from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment
|
||||
|
||||
|
||||
|
||||
def save_fragments_to_file(fragments, output_dir: str = "fragment_outputs"):
|
||||
"""
|
||||
将所有fragments保存为单个结构化markdown文件
|
||||
|
||||
Args:
|
||||
fragments: fragment列表
|
||||
output_dir: 输出目录
|
||||
"""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
# 创建输出目录
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成文件名
|
||||
filename = f"fragments_{timestamp}.md"
|
||||
file_path = output_path / filename
|
||||
|
||||
current_section = ""
|
||||
section_count = {} # 用于跟踪每个章节的片段数量
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
# 写入文档头部
|
||||
f.write("# Document Fragments Analysis\n\n")
|
||||
f.write("## Overview\n")
|
||||
f.write(f"- Total Fragments: {len(fragments)}\n")
|
||||
f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
|
||||
# 如果有标题和摘要,添加到开头
|
||||
if fragments and (fragments[0].title or fragments[0].abstract):
|
||||
f.write("\n## Paper Information\n")
|
||||
if fragments[0].title:
|
||||
f.write(f"### Title\n{fragments[0].title}\n")
|
||||
if fragments[0].abstract:
|
||||
f.write(f"\n### Abstract\n{fragments[0].abstract}\n")
|
||||
|
||||
# 生成目录
|
||||
f.write("\n## Table of Contents\n")
|
||||
|
||||
# 首先收集所有章节信息
|
||||
sections = {}
|
||||
for fragment in fragments:
|
||||
section = fragment.section or "Uncategorized"
|
||||
if section not in sections:
|
||||
sections[section] = []
|
||||
sections[section].append(fragment)
|
||||
|
||||
# 写入目录
|
||||
for section, section_fragments in sections.items():
|
||||
clean_section = section.strip()
|
||||
if not clean_section:
|
||||
clean_section = "Uncategorized"
|
||||
f.write(
|
||||
f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) ({len(section_fragments)} fragments)\n")
|
||||
|
||||
# 写入正文内容
|
||||
f.write("\n## Content\n")
|
||||
|
||||
# 按章节组织内容
|
||||
for section, section_fragments in sections.items():
|
||||
clean_section = section.strip() or "Uncategorized"
|
||||
f.write(f"\n### {clean_section}\n")
|
||||
|
||||
# 写入每个fragment
|
||||
for i, fragment in enumerate(section_fragments, 1):
|
||||
f.write(f"\n#### Fragment {i} ({fragment.segment_type})\n")
|
||||
|
||||
# 元数据
|
||||
f.write("**Metadata:**\n")
|
||||
f.write(f"- Type: {fragment.segment_type}\n")
|
||||
f.write(f"- Length: {len(fragment.content)} chars\n")
|
||||
f.write(f"- Importance: {fragment.importance:.2f}\n")
|
||||
f.write(f"- Is Appendix: {fragment.is_appendix}\n")
|
||||
f.write(f"- File: {fragment.rel_path}\n")
|
||||
|
||||
# 内容
|
||||
f.write("\n**Content:**\n")
|
||||
f.write("```tex\n")
|
||||
f.write(fragment.content)
|
||||
f.write("\n```\n")
|
||||
|
||||
# 添加分隔线
|
||||
if i < len(section_fragments):
|
||||
f.write("\n---\n")
|
||||
|
||||
# 添加统计信息
|
||||
f.write("\n## Statistics\n")
|
||||
f.write("\n### Fragment Type Distribution\n")
|
||||
type_stats = {}
|
||||
for fragment in fragments:
|
||||
type_stats[fragment.segment_type] = type_stats.get(fragment.segment_type, 0) + 1
|
||||
|
||||
for ftype, count in type_stats.items():
|
||||
percentage = (count / len(fragments)) * 100
|
||||
f.write(f"- {ftype}: {count} ({percentage:.1f}%)\n")
|
||||
|
||||
# 长度分布
|
||||
f.write("\n### Length Distribution\n")
|
||||
lengths = [len(f.content) for f in fragments]
|
||||
f.write(f"- Minimum: {min(lengths)} chars\n")
|
||||
f.write(f"- Maximum: {max(lengths)} chars\n")
|
||||
f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
|
||||
|
||||
print(f"Fragments saved to: {file_path}")
|
||||
return file_path
|
||||
|
||||
|
||||
|
||||
class ArxivSplitter:
|
||||
"""Arxiv论文智能分割器"""
|
||||
|
||||
def __init__(self,
|
||||
char_range: Tuple[int, int],
|
||||
root_dir: str = "gpt_log/arxiv_cache",
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
cache_ttl: int = 7 * 24 * 60 * 60):
|
||||
"""
|
||||
初始化分割器
|
||||
|
||||
Args:
|
||||
char_range: 字符数范围(最小值, 最大值)
|
||||
root_dir: 缓存根目录
|
||||
proxies: 代理设置
|
||||
cache_ttl: 缓存过期时间(秒)
|
||||
"""
|
||||
self.min_chars, self.max_chars = char_range
|
||||
self.root_dir = Path(root_dir)
|
||||
self.root_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.proxies = proxies or {}
|
||||
self.cache_ttl = cache_ttl
|
||||
|
||||
# 动态计算最优线程数
|
||||
import multiprocessing
|
||||
cpu_count = multiprocessing.cpu_count()
|
||||
# 根据CPU核心数动态设置,但设置上限防止过度并发
|
||||
self.max_workers = min(32, cpu_count * 2)
|
||||
|
||||
# 初始化TeX处理器
|
||||
self.tex_processor = TexProcessor(char_range)
|
||||
|
||||
# 配置日志
|
||||
self._setup_logging()
|
||||
|
||||
|
||||
|
||||
def _setup_logging(self):
|
||||
"""配置日志"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def _normalize_arxiv_id(self, input_str: str) -> str:
|
||||
"""规范化ArXiv ID"""
|
||||
if 'arxiv.org/' in input_str.lower():
|
||||
# 处理URL格式
|
||||
if '/pdf/' in input_str:
|
||||
arxiv_id = input_str.split('/pdf/')[-1]
|
||||
else:
|
||||
arxiv_id = input_str.split('/abs/')[-1]
|
||||
# 移除版本号和其他后缀
|
||||
return arxiv_id.split('v')[0].strip()
|
||||
return input_str.split('v')[0].strip()
|
||||
|
||||
|
||||
def _check_cache(self, paper_dir: Path) -> bool:
|
||||
"""
|
||||
检查缓存是否有效,包括文件完整性检查
|
||||
|
||||
Args:
|
||||
paper_dir: 论文目录路径
|
||||
|
||||
Returns:
|
||||
bool: 如果缓存有效返回True,否则返回False
|
||||
"""
|
||||
if not paper_dir.exists():
|
||||
return False
|
||||
|
||||
# 检查目录中是否存在必要文件
|
||||
has_tex_files = False
|
||||
has_main_tex = False
|
||||
|
||||
for file_path in paper_dir.rglob("*"):
|
||||
if file_path.suffix == '.tex':
|
||||
has_tex_files = True
|
||||
content = self.tex_processor.read_file(str(file_path))
|
||||
if content and r'\documentclass' in content:
|
||||
has_main_tex = True
|
||||
break
|
||||
|
||||
if not (has_tex_files and has_main_tex):
|
||||
return False
|
||||
|
||||
# 检查缓存时间
|
||||
cache_time = paper_dir.stat().st_mtime
|
||||
if (time.time() - cache_time) < self.cache_ttl:
|
||||
self.logger.info(f"Using valid cache for {paper_dir.name}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
|
||||
"""
|
||||
异步下载论文,包含重试机制和临时文件处理
|
||||
|
||||
Args:
|
||||
arxiv_id: ArXiv论文ID
|
||||
paper_dir: 目标目录路径
|
||||
|
||||
Returns:
|
||||
bool: 下载成功返回True,否则返回False
|
||||
"""
|
||||
from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader
|
||||
temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz"
|
||||
final_tar_path = paper_dir / f"{arxiv_id}.tar.gz"
|
||||
|
||||
# 确保目录存在
|
||||
paper_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 尝试使用 ArxivDownloader 下载
|
||||
try:
|
||||
downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
|
||||
downloaded_dir = downloader.download_paper(arxiv_id)
|
||||
if downloaded_dir:
|
||||
self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")
|
||||
|
||||
# 如果 ArxivDownloader 失败,使用原有的下载方式作为备选
|
||||
urls = [
|
||||
f"https://arxiv.org/src/{arxiv_id}",
|
||||
f"https://arxiv.org/e-print/{arxiv_id}"
|
||||
]
|
||||
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 初始重试延迟(秒)
|
||||
|
||||
for url in urls:
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, proxy=self.proxies.get('http')) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
|
||||
# 写入临时文件
|
||||
temp_tar_path.write_bytes(content)
|
||||
|
||||
try:
|
||||
# 验证tar文件完整性并解压
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
|
||||
|
||||
# 下载成功后移动临时文件到最终位置
|
||||
temp_tar_path.rename(final_tar_path)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Invalid tar file: {str(e)}")
|
||||
if temp_tar_path.exists():
|
||||
temp_tar_path.unlink()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def _process_tar_file(self, tar_path: Path, extract_path: Path):
|
||||
"""处理tar文件的同步操作"""
|
||||
with tarfile.open(tar_path, 'r:gz') as tar:
    tar.getmembers()  # read every member header to validate archive integrity
    tar.extractall(path=extract_path)  # extract the archive
|
||||
|
||||
def _process_single_tex(self, file_path: str) -> List[ArxivFragment]:
|
||||
"""处理单个TeX文件"""
|
||||
try:
|
||||
content = self.tex_processor.read_file(file_path)
|
||||
if not content:
|
||||
return []
|
||||
|
||||
# 提取元数据
|
||||
is_main = r'\documentclass' in content
|
||||
title, abstract = "", ""
|
||||
if is_main:
|
||||
title, abstract = self.tex_processor.extract_metadata(content)
|
||||
|
||||
# 分割内容
|
||||
segments = self.tex_processor.split_content(content)
|
||||
fragments = []
|
||||
|
||||
for i, (segment_content, section, is_appendix) in enumerate(segments):
|
||||
if not segment_content.strip():
|
||||
continue
|
||||
|
||||
segment_type = self.tex_processor.detect_segment_type(segment_content)
|
||||
importance = self.tex_processor.calculate_importance(
|
||||
segment_content, segment_type, is_main
|
||||
)
|
||||
fragments.append(ArxivFragment(
|
||||
file_path=file_path,
|
||||
content=segment_content,
|
||||
segment_index=i,
|
||||
total_segments=len(segments),
|
||||
rel_path=str(Path(file_path).relative_to(self.root_dir)),
|
||||
segment_type=segment_type,
|
||||
title=title,
|
||||
abstract=abstract,
|
||||
section=section,
|
||||
is_appendix=is_appendix,
|
||||
importance=importance
|
||||
))
|
||||
|
||||
return fragments
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing {file_path}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def process(self, arxiv_id_or_url: str) -> List[ArxivFragment]:
|
||||
"""处理ArXiv论文"""
|
||||
try:
|
||||
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
|
||||
paper_dir = self.root_dir / arxiv_id
|
||||
|
||||
# 检查缓存
|
||||
if not self._check_cache(paper_dir):
|
||||
paper_dir.mkdir(exist_ok=True)
|
||||
if not await self.download_paper(arxiv_id, paper_dir):
|
||||
raise RuntimeError(f"Failed to download paper {arxiv_id}")
|
||||
|
||||
# 查找主TeX文件
|
||||
main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
|
||||
if not main_tex:
|
||||
raise RuntimeError(f"No main TeX file found in {paper_dir}")
|
||||
|
||||
# 获取所有相关TeX文件
|
||||
tex_files = self.tex_processor.resolve_includes(main_tex)
|
||||
if not tex_files:
|
||||
raise RuntimeError(f"No valid TeX files found for {arxiv_id}")
|
||||
|
||||
# 并行处理所有TeX文件
|
||||
fragments = []
|
||||
chunk_size = max(1, len(tex_files) // self.max_workers) # 计算每个线程处理的文件数
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async def process_chunk(chunk_files):
|
||||
chunk_fragments = []
|
||||
for file_path in chunk_files:
|
||||
try:
|
||||
result = await loop.run_in_executor(None, self._process_single_tex, file_path)
|
||||
chunk_fragments.extend(result)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing {file_path}: {str(e)}")
|
||||
return chunk_fragments
|
||||
|
||||
# 将文件分成多个块
|
||||
file_chunks = [tex_files[i:i + chunk_size] for i in range(0, len(tex_files), chunk_size)]
|
||||
# 异步处理每个块
|
||||
chunk_results = await asyncio.gather(*[process_chunk(chunk) for chunk in file_chunks])
|
||||
for result in chunk_results:
|
||||
fragments.extend(result)
|
||||
# 重新计算片段索引并排序
|
||||
fragments.sort(key=lambda x: (x.rel_path, x.segment_index))
|
||||
total_fragments = len(fragments)
|
||||
|
||||
for i, fragment in enumerate(fragments):
|
||||
fragment.segment_index = i
|
||||
fragment.total_segments = total_fragments
|
||||
# 在返回之前添加过滤
|
||||
fragments = self.tex_processor.filter_fragments(fragments)
|
||||
return fragments
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def test_arxiv_splitter():
|
||||
"""测试ArXiv分割器的功能"""
|
||||
|
||||
# 测试配置
|
||||
test_cases = [
|
||||
{
|
||||
"arxiv_id": "2411.03663",
|
||||
"expected_title": "Large Language Models and Simple Scripts",
|
||||
"min_fragments": 10,
|
||||
},
|
||||
# {
|
||||
# "arxiv_id": "1805.10988",
|
||||
# "expected_title": "RAG vs Fine-tuning",
|
||||
# "min_fragments": 15,
|
||||
# }
|
||||
]
|
||||
|
||||
# 创建分割器实例
|
||||
splitter = ArxivSplitter(
|
||||
char_range=(800, 1800),
|
||||
root_dir="test_cache"
|
||||
)
|
||||
|
||||
|
||||
for case in test_cases:
|
||||
print(f"\nTesting paper: {case['arxiv_id']}")
|
||||
try:
|
||||
fragments = await splitter.process(case['arxiv_id'])
|
||||
|
||||
# 保存fragments
|
||||
output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
|
||||
print(f"Output saved to: {output_dir}")
|
||||
# 内容检查
|
||||
for fragment in fragments:
|
||||
# 长度检查
|
||||
|
||||
print((fragment.content))
|
||||
print(len(fragment.content))
|
||||
# 类型检查
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_arxiv_splitter())
|
||||
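To connect the pieces, a minimal sketch of feeding the splitter's output into a retrieval store. `MyRagStore.insert` stands in for whatever RAG backend is wired up elsewhere in this changeset; it is an assumed interface, not part of this file:

import asyncio
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter

async def build_index(arxiv_id: str, rag_store) -> int:
    """Split one paper and push its fragments into a RAG store (illustrative sketch)."""
    splitter = ArxivSplitter(char_range=(1000, 1200), root_dir="gpt_log/arxiv_cache")
    fragments = await splitter.process(arxiv_id)   # List[ArxivFragment]
    # Prefix each chunk with its section so retrieval keeps some document structure.
    rag_store.insert([f"[{frag.section}] {frag.content}" for frag in fragments])
    return len(fragments)

# asyncio.run(build_index("2411.03663", rag_store=MyRagStore()))  # MyRagStore is hypothetical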
395  crazy_functions/rag_fns/arxiv_fns/latex_patterns.py    Normal file
@@ -0,0 +1,395 @@
|
||||
from dataclasses import dataclass, field


@dataclass
class LaTeXPatterns:
    """Container for all LaTeX-related regular expression patterns used by the splitter."""
    special_envs = {
        'math': [
            # Basic math environments
            r'\\begin{(equation|align|gather|eqnarray|multline|flalign|alignat)\*?}.*?\\end{\1\*?}',
            r'\$\$.*?\$\$',
            r'\$[^$]+\$',
            # Matrix environments
            r'\\begin{(matrix|pmatrix|bmatrix|Bmatrix|vmatrix|Vmatrix|smallmatrix)\*?}.*?\\end{\1\*?}',
            # Array environments
            r'\\begin{(array|cases|aligned|gathered|split)\*?}.*?\\end{\1\*?}',
            # Other math environments
            r'\\begin{(subequations|math|displaymath)\*?}.*?\\end{\1\*?}'
        ],

        'table': [
            # Basic table environments
            r'\\begin{(table|tabular|tabularx|tabulary|longtable)\*?}.*?\\end{\1\*?}',
            # Complex table environments
            r'\\begin{(tabu|supertabular|xtabular|mpsupertabular)\*?}.*?\\end{\1\*?}',
            # Custom table environments
            r'\\begin{(threeparttable|tablefootnote)\*?}.*?\\end{\1\*?}',
            # Table note environments
            r'\\begin{(tablenotes)\*?}.*?\\end{\1\*?}'
        ],

        'figure': [
            # Figure environments
            r'\\begin{figure\*?}.*?\\end{figure\*?}',
            r'\\begin{(subfigure|wrapfigure)\*?}.*?\\end{\1\*?}',
            # Graphics inclusion commands
            r'\\includegraphics(\[.*?\])?\{.*?\}',
            # TikZ picture environments
            r'\\begin{(tikzpicture|pgfpicture)\*?}.*?\\end{\1\*?}',
            # Other picture environments
            r'\\begin{(picture|pspicture)\*?}.*?\\end{\1\*?}'
        ],

        'algorithm': [
            # Algorithm environments
            r'\\begin{(algorithm|algorithmic|algorithm2e|algorithmicx)\*?}.*?\\end{\1\*?}',
            r'\\begin{(lstlisting|verbatim|minted|listing)\*?}.*?\\end{\1\*?}',
            # Code block environments
            r'\\begin{(code|verbatimtab|verbatimwrite)\*?}.*?\\end{\1\*?}',
            # Pseudocode environments
            r'\\begin{(pseudocode|procedure)\*?}.*?\\end{\1\*?}'
        ],

        'list': [
            # List environments
            r'\\begin{(itemize|enumerate|description)\*?}.*?\\end{\1\*?}',
            r'\\begin{(list|compactlist|bulletlist)\*?}.*?\\end{\1\*?}',
            # Custom list environments
            r'\\begin{(tasks|todolist)\*?}.*?\\end{\1\*?}'
        ],

        'theorem': [
            # Theorem-like environments
            r'\\begin{(theorem|lemma|proposition|corollary)\*?}.*?\\end{\1\*?}',
            r'\\begin{(definition|example|proof|remark)\*?}.*?\\end{\1\*?}',
            # Other proof-related environments
            r'\\begin{(axiom|property|assumption|conjecture)\*?}.*?\\end{\1\*?}'
        ],

        'box': [
            # Text box environments
            r'\\begin{(tcolorbox|mdframed|framed|shaded)\*?}.*?\\end{\1\*?}',
            r'\\begin{(boxedminipage|shadowbox)\*?}.*?\\end{\1\*?}',
            # Emphasis environments
            r'\\begin{(important|warning|info|note)\*?}.*?\\end{\1\*?}'
        ],

        'quote': [
            # Quotation environments
            r'\\begin{(quote|quotation|verse|abstract)\*?}.*?\\end{\1\*?}',
            r'\\begin{(excerpt|epigraph)\*?}.*?\\end{\1\*?}'
        ],

        'bibliography': [
            # Bibliography environments
            r'\\begin{(thebibliography|bibliography)\*?}.*?\\end{\1\*?}',
            r'\\begin{(biblist|citelist)\*?}.*?\\end{\1\*?}'
        ],

        'index': [
            # Index environments
            r'\\begin{(theindex|printindex)\*?}.*?\\end{\1\*?}',
            r'\\begin{(glossary|acronym)\*?}.*?\\end{\1\*?}'
        ]
    }

    # Sectioning patterns
    section_patterns = [
        # Basic sectioning commands
        r'\\chapter\{([^}]+)\}',
        r'\\section\{([^}]+)\}',
        r'\\subsection\{([^}]+)\}',
        r'\\subsubsection\{([^}]+)\}',
        r'\\paragraph\{([^}]+)\}',
        r'\\subparagraph\{([^}]+)\}',

        # Starred variants (unnumbered)
        r'\\chapter\*\{([^}]+)\}',
        r'\\section\*\{([^}]+)\}',
        r'\\subsection\*\{([^}]+)\}',
        r'\\subsubsection\*\{([^}]+)\}',
        r'\\paragraph\*\{([^}]+)\}',
        r'\\subparagraph\*\{([^}]+)\}',

        # Special sections
        r'\\part\{([^}]+)\}',
        r'\\part\*\{([^}]+)\}',
        r'\\appendix\{([^}]+)\}',

        # Front/main/back matter
        r'\\frontmatter\{([^}]+)\}',
        r'\\mainmatter\{([^}]+)\}',
        r'\\backmatter\{([^}]+)\}',

        # Table of contents and related lists
        r'\\tableofcontents',
        r'\\listoffigures',
        r'\\listoftables',

        # Custom sectioning commands
        r'\\addchap\{([^}]+)\}',  # KOMA-Script classes
        r'\\addsec\{([^}]+)\}',   # KOMA-Script classes
        r'\\minisec\{([^}]+)\}',  # KOMA-Script classes

        # Sectioning commands with optional arguments
        r'\\chapter\[([^]]+)\]\{([^}]+)\}',
        r'\\section\[([^]]+)\]\{([^}]+)\}',
        r'\\subsection\[([^]]+)\]\{([^}]+)\}'
    ]

    # Include patterns
    include_patterns = [
        r'\\(input|include|subfile)\{([^}]+)\}'
    ]

    metadata_patterns = {
        # Title
        'title': [
            r'\\title\{([^}]+)\}',
            r'\\Title\{([^}]+)\}',
            r'\\doctitle\{([^}]+)\}',
            r'\\subtitle\{([^}]+)\}',
            r'\\chapter\*?\{([^}]+)\}',  # the first chapter may serve as the title
            r'\\maketitle\s*\\section\*?\{([^}]+)\}'  # the first section may serve as the title
        ],

        # Abstract
        'abstract': [
            r'\\begin{abstract}(.*?)\\end{abstract}',
            r'\\abstract\{([^}]+)\}',
            r'\\begin{摘要}(.*?)\\end{摘要}',
            r'\\begin{Summary}(.*?)\\end{Summary}',
            r'\\begin{synopsis}(.*?)\\end{synopsis}',
            r'\\begin{abstracten}(.*?)\\end{abstracten}'  # English abstract
        ],

        # Author information
        'author': [
            r'\\author\{([^}]+)\}',
            r'\\Author\{([^}]+)\}',
            r'\\authorinfo\{([^}]+)\}',
            r'\\authors\{([^}]+)\}',
            r'\\author\[([^]]+)\]\{([^}]+)\}',  # author with additional info
            r'\\begin{authors}(.*?)\\end{authors}'
        ],

        # Dates
        'date': [
            r'\\date\{([^}]+)\}',
            r'\\Date\{([^}]+)\}',
            r'\\submitdate\{([^}]+)\}',
            r'\\publishdate\{([^}]+)\}',
            r'\\revisiondate\{([^}]+)\}'
        ],

        # Keywords
        'keywords': [
            r'\\keywords\{([^}]+)\}',
            r'\\Keywords\{([^}]+)\}',
            r'\\begin{keywords}(.*?)\\end{keywords}',
            r'\\key\{([^}]+)\}',
            r'\\begin{关键词}(.*?)\\end{关键词}'
        ],

        # Institution / affiliation
        'institution': [
            r'\\institute\{([^}]+)\}',
            r'\\institution\{([^}]+)\}',
            r'\\affiliation\{([^}]+)\}',
            r'\\organization\{([^}]+)\}',
            r'\\department\{([^}]+)\}'
        ],

        # Subject / field
        'subject': [
            r'\\subject\{([^}]+)\}',
            r'\\Subject\{([^}]+)\}',
            r'\\field\{([^}]+)\}',
            r'\\discipline\{([^}]+)\}'
        ],

        # Version information
        'version': [
            r'\\version\{([^}]+)\}',
            r'\\revision\{([^}]+)\}',
            r'\\release\{([^}]+)\}'
        ],

        # License / copyright
        'license': [
            r'\\license\{([^}]+)\}',
            r'\\copyright\{([^}]+)\}',
            r'\\begin{license}(.*?)\\end{license}'
        ],

        # Contact information
        'contact': [
            r'\\email\{([^}]+)\}',
            r'\\phone\{([^}]+)\}',
            r'\\address\{([^}]+)\}',
            r'\\contact\{([^}]+)\}'
        ],

        # Acknowledgments
        'acknowledgments': [
            r'\\begin{acknowledgments}(.*?)\\end{acknowledgments}',
            r'\\acknowledgments\{([^}]+)\}',
            r'\\thanks\{([^}]+)\}',
            r'\\begin{致谢}(.*?)\\end{致谢}'
        ],

        # Funding / grants
        'funding': [
            r'\\funding\{([^}]+)\}',
            r'\\grant\{([^}]+)\}',
            r'\\project\{([^}]+)\}',
            r'\\support\{([^}]+)\}'
        ],

        # Classification / identifiers
        'classification': [
            r'\\classification\{([^}]+)\}',
            r'\\serialnumber\{([^}]+)\}',
            r'\\id\{([^}]+)\}',
            r'\\doi\{([^}]+)\}'
        ],

        # Language
        'language': [
            r'\\documentlanguage\{([^}]+)\}',
            r'\\lang\{([^}]+)\}',
            r'\\language\{([^}]+)\}'
        ]
    }

    latex_only_patterns = {
        # Document class and package imports
        r'\\documentclass(\[.*?\])?\{.*?\}',
        r'\\usepackage(\[.*?\])?\{.*?\}',
        # Common document setup commands
        r'\\setlength\{.*?\}\{.*?\}',
        r'\\newcommand\{.*?\}(\[.*?\])?\{.*?\}',
        r'\\renewcommand\{.*?\}(\[.*?\])?\{.*?\}',
        r'\\definecolor\{.*?\}\{.*?\}\{.*?\}',
        # Page style settings
        r'\\pagestyle\{.*?\}',
        r'\\thispagestyle\{.*?\}',
        # Other common setup commands
        r'\\bibliographystyle\{.*?\}',
        r'\\bibliography\{.*?\}',
        r'\\setcounter\{.*?\}\{.*?\}',
        # Font and text setup commands
        r'\\makeFNbottom',
        r'\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}',  # match font size settings
        r'\\renewcommand\\[A-Z]+\{\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}\}',
        r'\\renewcommand\{?\\thefootnote\}?\{\\fnsymbol\{footnote\}\}',
        r'\\renewcommand\\footnoterule\{.*?\}',
        r'\\color\{.*?\}',

        # Page and section heading settings
        r'\\setcounter\{secnumdepth\}\{.*?\}',
        r'\\renewcommand\\@biblabel\[.*?\]\{.*?\}',
        r'\\renewcommand\\@makefntext\[.*?\](\{.*?\})*',
        r'\\renewcommand\{?\\figurename\}?\{.*?\}',

        # Font style settings
        r'\\sectionfont\{.*?\}',
        r'\\subsectionfont\{.*?\}',
        r'\\subsubsectionfont\{.*?\}',

        # Spacing and layout settings
        r'\\setstretch\{.*?\}',
        r'\\setlength\{\\skip\\footins\}\{.*?\}',
        r'\\setlength\{\\footnotesep\}\{.*?\}',
        r'\\setlength\{\\jot\}\{.*?\}',
        r'\\hrule\s+width\s+.*?\s+height\s+.*?',

        # makeatletter and makeatother
        r'\\makeatletter\s*',
        r'\\makeatother\s*',
        r'\\footnotetext\{[^}]*\$\^{[^}]*}\$[^}]*\}',  # footnotes containing superscripts
        # r'\\footnotetext\{[^}]*\}',  # plain footnotes
        # r'\\footnotetext\{.*?(?:\$\^{.*?}\$)?.*?(?:email\s*:\s*[^}]*)?.*?\}',  # footnotes containing e-mail addresses
        # r'\\footnotetext\{.*?(?:ESI|DOI).*?\}',  # footnotes containing DOI or ESI references
        # Document structure commands
        r'\\begin\{document\}',
        r'\\end\{document\}',
        r'\\maketitle',
        r'\\printbibliography',
        r'\\newpage',

        # File input commands
        r'\\input\{[^}]*\}',
        r'\\input\{.*?\.tex\}',  # specifically match inputs with a .tex suffix

        # Footnotes
        # r'\\footnotetext\[\d+\]\{[^}]*\}',  # numbered footnotes

        # Acknowledgment environment
        r'\\begin\{ack\}',
        r'\\end\{ack\}',
        r'\\begin\{ack\}[^\n]*(?:\n.*?)*?\\end\{ack\}',  # match the whole acknowledgment environment and its content

        # Other document control commands
        r'\\renewcommand\{\\thefootnote\}\{\\fnsymbol\{footnote\}\}',
    }

    math_envs = [
        # Basic math environments
        (r'\\begin{equation\*?}.*?\\end{equation\*?}', 'equation'),  # single-line equation
        (r'\\begin{align\*?}.*?\\end{align\*?}', 'align'),  # multi-line aligned equations
        (r'\\begin{gather\*?}.*?\\end{gather\*?}', 'gather'),  # multi-line centered equations
        (r'\$\$.*?\$\$', 'display'),  # display math
        (r'\$.*?\$', 'inline'),  # inline math

        # Matrix environments
        (r'\\begin{matrix}.*?\\end{matrix}', 'matrix'),  # plain matrix
        (r'\\begin{pmatrix}.*?\\end{pmatrix}', 'pmatrix'),  # parenthesized matrix
        (r'\\begin{bmatrix}.*?\\end{bmatrix}', 'bmatrix'),  # bracketed matrix
        (r'\\begin{vmatrix}.*?\\end{vmatrix}', 'vmatrix'),  # single-bar matrix
        (r'\\begin{Vmatrix}.*?\\end{Vmatrix}', 'Vmatrix'),  # double-bar matrix
        (r'\\begin{smallmatrix}.*?\\end{smallmatrix}', 'smallmatrix'),  # small matrix

        # Array environments
        (r'\\begin{array}.*?\\end{array}', 'array'),  # array
        (r'\\begin{cases}.*?\\end{cases}', 'cases'),  # piecewise functions

        # Multi-line equation environments
        (r'\\begin{multline\*?}.*?\\end{multline\*?}', 'multline'),  # one equation over several lines
        (r'\\begin{split}.*?\\end{split}', 'split'),  # split long equations
        (r'\\begin{alignat\*?}.*?\\end{alignat\*?}', 'alignat'),  # alignment with spacing control
        (r'\\begin{flalign\*?}.*?\\end{flalign\*?}', 'flalign'),  # fully left-aligned

        # Special math environments
        (r'\\begin{subequations}.*?\\end{subequations}', 'subequations'),  # sub-numbered equations
        (r'\\begin{gathered}.*?\\end{gathered}', 'gathered'),  # centered group
        (r'\\begin{aligned}.*?\\end{aligned}', 'aligned'),  # internally aligned group

        # Theorem-like environments
        (r'\\begin{theorem}.*?\\end{theorem}', 'theorem'),  # theorem
        (r'\\begin{lemma}.*?\\end{lemma}', 'lemma'),  # lemma
        (r'\\begin{proof}.*?\\end{proof}', 'proof'),  # proof

        # Table environments in math mode
        (r'\\begin{tabular}.*?\\end{tabular}', 'tabular'),  # tabular
        (r'\\begin{array}.*?\\end{array}', 'array'),  # array

        # Other specialized math environments
        (r'\\begin{CD}.*?\\end{CD}', 'CD'),  # commutative diagrams
        (r'\\begin{boxed}.*?\\end{boxed}', 'boxed'),  # boxed equations
        (r'\\begin{empheq}.*?\\end{empheq}', 'empheq'),  # emphasized equations

        # Chemistry environments (require the mhchem package)
        (r'\\begin{reaction}.*?\\end{reaction}', 'reaction'),  # chemical reactions
        (r'\\ce\{.*?\}', 'chemequation'),  # chemical equations

        # Physical unit commands (require the siunitx package)
        (r'\\SI\{.*?\}\{.*?\}', 'SI'),  # quantity with unit
        (r'\\si\{.*?\}', 'si'),  # unit

        # Additional environments
        (r'\\begin{equation\+}.*?\\end{equation\+}', 'equation+'),  # breqn auto-breaking equations
        (r'\\begin{dmath\*?}.*?\\end{dmath\*?}', 'dmath'),  # breqn display math
        (r'\\begin{dgroup\*?}.*?\\end{dgroup\*?}', 'dgroup'),  # breqn equation groups
    ]


# Example usage function

# Usage example
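The new file ends at the two empty usage-comment stubs above. As a hedged illustration only (not part of the commit), scanning a LaTeX string against these pattern tables might look like the following; the sample string and the scanning loop are invented for demonstration.

import re
from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns

# Illustrative sketch: detect which special environments occur in a LaTeX snippet.
patterns = LaTeXPatterns()
sample = r"""
\section{Method}
We minimize $\mathcal{L}(\theta)$ as follows:
\begin{equation}
\theta^* = \arg\min_\theta \mathcal{L}(\theta)
\end{equation}
"""

for env_type, regexes in patterns.special_envs.items():
    for pattern in regexes:
        # re.DOTALL lets '.*?' span the line breaks inside an environment
        if re.search(pattern, sample, re.DOTALL):
            print(f"found a '{env_type}' environment")
            break

# Section headings can be pulled out the same way
for pattern in patterns.section_patterns:
    for match in re.finditer(pattern, sample):
        print("section:", match.group(match.lastindex or 0))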
1099 crazy_functions/rag_fns/arxiv_fns/tex_processor.py Normal file
File diff suppressed because it is too large.
@@ -13,6 +13,7 @@ from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker

T = TypeVar('T')


@dataclass
class StorageBase:
    """Base class for all storage implementations"""
@@ -42,7 +43,7 @@ class JsonKVStorage(StorageBase, Generic[T]):

    def __post_init__(self):
        """Initialize storage file and load data"""
        self._file_name = os.path.join(self.working_dir, f"kv_{self.namespace}.json")
        self._file_name = os.path.join(self.working_dir, f"kv_store_{self.namespace}.json")
        self._data: Dict[str, T] = {}
        self.load()

@@ -95,69 +96,138 @@ class JsonKVStorage(StorageBase, Generic[T]):
        await self.save()


@dataclass
class VectorStorage(StorageBase):
    """
    Vector storage using LlamaIndex
    Vector storage using LlamaIndexRagWorker

    Attributes:
        namespace (str): Storage namespace
        namespace (str): Storage namespace (e.g., 'entities', 'relationships', 'chunks')
        working_dir (str): Working directory for storage files
        llm_kwargs (dict): LLM configuration
        embedding_func (OpenAiEmbeddingModel): Embedding function
        meta_fields (Set[str]): Additional fields to store
        cosine_better_than_threshold (float): Similarity threshold
        meta_fields (Set[str]): Additional metadata fields to store
    """
    llm_kwargs: dict
    embedding_func: OpenAiEmbeddingModel
    meta_fields: Set[str] = field(default_factory=set)
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        """Initialize LlamaIndex worker"""
        checkpoint_dir = os.path.join(self.working_dir, f"vector_{self.namespace}")
        # Use the expected file naming convention
        self._vector_file = os.path.join(self.working_dir, f"vdb_{self.namespace}.json")

        # Set up the checkpoint directory
        checkpoint_dir = os.path.join(self.working_dir, f"vector_{self.namespace}_checkpoint")
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Initialize the vector store
        self.vector_store = LlamaIndexRagWorker(
            user_name=self.namespace,
            llm_kwargs=self.llm_kwargs,
            checkpoint_dir=checkpoint_dir,
            auto_load_checkpoint=True  # automatically load the checkpoint
            auto_load_checkpoint=True
        )
        logger.info(f"Initialized vector storage for {self.namespace}")

    async def query(self, query: str, top_k: int = 5) -> List[dict]:
    async def query(self, query: str, top_k: int = 5, metadata_filters: Optional[Dict[str, Any]] = None) -> List[dict]:
        """
        Query vectors by similarity
        Query vectors by similarity with optional metadata filtering

        Args:
            query: Query text
            top_k: Maximum number of results
            top_k: Maximum number of results to return
            metadata_filters: Optional metadata filters

        Returns:
            List of similar documents with scores
        """
        nodes = self.vector_store.retrieve_from_store_with_query(query)
        results = [{
        try:
            if metadata_filters:
                nodes = self.vector_store.retrieve_with_metadata_filter(query, metadata_filters, top_k)
            else:
                nodes = self.vector_store.retrieve_from_store_with_query(query)[:top_k]

            results = []
            for node in nodes:
                result = {
                    "id": node.node_id,
                    "text": node.text,
            "score": node.score,
            **{k: getattr(node, k) for k in self.meta_fields if hasattr(node, k)}
        } for node in nodes[:top_k]]
        return [r for r in results if r.get('score', 0) > self.cosine_better_than_threshold]
                    "score": node.score if hasattr(node, 'score') else 0.0,
                }
                # Add metadata fields if they exist and are in meta_fields
                if hasattr(node, 'metadata'):
                    result.update({
                        k: node.metadata[k]
                        for k in self.meta_fields
                        if k in node.metadata
                    })
                results.append(result)

            return results

        except Exception as e:
            logger.error(f"Error in vector query: {e}")
            raise

    async def upsert(self, data: Dict[str, dict]):
        """
        Insert or update vectors

        Args:
            data: Dictionary of documents to insert/update
            data: Dictionary of documents to insert/update with format:
                {id: {"content": text, "metadata": dict}}
        """
        for id, item in data.items():
        try:
            for doc_id, item in data.items():
                content = item["content"]
            metadata = {k: item[k] for k in self.meta_fields if k in item}
            self.vector_store.add_text_with_metadata(content, metadata=metadata)
                # Extract the metadata
                metadata = {
                    k: item[k]
                    for k in self.meta_fields
                    if k in item
                }
                # Add the document ID to the metadata
                metadata["doc_id"] = doc_id

                # Add to the vector store
                self.vector_store.add_text_with_metadata(content, metadata)

            # Export the vector data to a JSON file
            self.vector_store.export_nodes(
                self._vector_file,
                format="json",
                include_embeddings=True
            )

        except Exception as e:
            logger.error(f"Error in vector upsert: {e}")
            raise

    async def save(self):
        """Save vector store to checkpoint and export data"""
        try:
            # Save the checkpoint
            self.vector_store.save_to_checkpoint()

            # Export the vector data
            self.vector_store.export_nodes(
                self._vector_file,
                format="json",
                include_embeddings=True
            )
        except Exception as e:
            logger.error(f"Error saving vector storage: {e}")
            raise

    async def index_done_callback(self):
        """Save after indexing"""
        self.vector_store.save_to_checkpoint()
        await self.save()

    def get_statistics(self) -> Dict[str, Any]:
        """Get vector store statistics"""
        return self.vector_store.get_statistics()


@dataclass
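The hunk above extends VectorStorage.query with an optional metadata_filters argument and documents the upsert payload shape. A hedged usage sketch follows; the store instance, document IDs, content strings and filter values are placeholders, not taken from the repository.

async def vector_storage_demo(entities_vdb):
    # Insert two documents in the documented {id: {"content": ..., <meta_fields>...}} shape
    # (IDs and content are illustrative placeholders)
    await entities_vdb.upsert({
        "entity_ARXIV": {"content": "ARXIV: a preprint server for research papers",
                         "entity_name": "ARXIV", "entity_type": "ORGANIZATION"},
        "entity_RAG": {"content": "RAG: retrieval-augmented generation",
                       "entity_name": "RAG", "entity_type": "METHOD"},
    })

    # Plain similarity search over the top 5 hits
    hits = await entities_vdb.query("preprint repositories", top_k=5)

    # Same search restricted by metadata; assumes the filter key matches a stored meta field
    filtered = await entities_vdb.query(
        "preprint repositories",
        top_k=5,
        metadata_filters={"entity_type": "ORGANIZATION"},
    )
    for hit in filtered:
        print(hit["id"], hit.get("entity_name"), hit["score"])

# asyncio.run(vector_storage_demo(entities_vdb))  # entities_vdb: a VectorStorage built as in ExtractionExample below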
@@ -174,12 +244,15 @@ class NetworkStorage(StorageBase):
        """Initialize graph and storage file"""
        self._file_name = os.path.join(self.working_dir, f"graph_{self.namespace}.graphml")
        self._graph = self._load_graph() or nx.Graph()
        logger.info(f"Initialized graph storage for {self.namespace}")

    def _load_graph(self) -> Optional[nx.Graph]:
        """Load graph from GraphML file"""
        if os.path.exists(self._file_name):
            try:
                return nx.read_graphml(self._file_name)
                graph = nx.read_graphml(self._file_name)
                logger.info(f"Loaded graph with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges")
                return graph
            except Exception as e:
                logger.error(f"Error loading graph from {self._file_name}: {e}")
                return None
@@ -187,9 +260,14 @@ class NetworkStorage(StorageBase):

    async def save_graph(self):
        """Save graph to GraphML file"""
        try:
            os.makedirs(os.path.dirname(self._file_name), exist_ok=True)
            logger.info(f"Saving graph with {self._graph.number_of_nodes()} nodes, {self._graph.number_of_edges()} edges")
            logger.info(
                f"Saving graph with {self._graph.number_of_nodes()} nodes, {self._graph.number_of_edges()} edges")
            nx.write_graphml(self._graph, self._file_name)
        except Exception as e:
            logger.error(f"Error saving graph: {e}")
            raise

    async def has_node(self, node_id: str) -> bool:
        """Check if node exists"""
@@ -227,15 +305,15 @@ class NetworkStorage(StorageBase):

    async def upsert_node(self, node_id: str, node_data: Dict[str, str]):
        """Insert or update node"""
        # Clean and normalize node data
        cleaned_data = {k: html.escape(str(v).upper().strip()) for k, v in node_data.items()}
        self._graph.add_node(node_id, **cleaned_data)
        await self.save_graph()

    async def upsert_edge(self, source_id: str, target_id: str, edge_data: Dict[str, str]):
        """Insert or update edge"""
        # Clean and normalize edge data
        cleaned_data = {k: html.escape(str(v).strip()) for k, v in edge_data.items()}
        self._graph.add_edge(source_id, target_id, **cleaned_data)
        await self.save_graph()

    async def index_done_callback(self):
        """Save after indexing"""
@@ -253,39 +331,39 @@ class NetworkStorage(StorageBase):
        largest_component = max(components, key=len)
        return self._graph.subgraph(largest_component).copy()

    async def embed_nodes(self, algorithm: str, **kwargs) -> Tuple[np.ndarray, List[str]]:
        """
        Embed nodes using specified algorithm

        Args:
            algorithm: Node embedding algorithm name
            **kwargs: Additional algorithm parameters

        Returns:
            Tuple of (node embeddings, node IDs)
        """
    async def embed_nodes(
            self,
            algorithm: str = "node2vec",
            dimensions: int = 128,
            walk_length: int = 30,
            num_walks: int = 200,
            workers: int = 4,
            window: int = 10,
            min_count: int = 1,
            **kwargs
    ) -> Tuple[np.ndarray, List[str]]:
        """Generate node embeddings using specified algorithm"""
        if algorithm == "node2vec":
            from node2vec import Node2Vec

            # Create node2vec model
            node2vec = Node2Vec(
            # Create and train node2vec model
            n2v = Node2Vec(
                self._graph,
                dimensions=kwargs.get('dimensions', 128),
                walk_length=kwargs.get('walk_length', 30),
                num_walks=kwargs.get('num_walks', 200),
                workers=kwargs.get('workers', 4)
                dimensions=dimensions,
                walk_length=walk_length,
                num_walks=num_walks,
                workers=workers
            )

            # Train model
            model = node2vec.fit(
                window=kwargs.get('window', 10),
                min_count=kwargs.get('min_count', 1)
            model = n2v.fit(
                window=window,
                min_count=min_count
            )

            # Get embeddings
            # Get embeddings for all nodes
            node_ids = list(self._graph.nodes())
            embeddings = np.array([model.wv[node] for node in node_ids])

            return embeddings, node_ids
        else:

            raise ValueError(f"Unsupported embedding algorithm: {algorithm}")
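The reworked embed_nodes above exposes the node2vec hyperparameters explicitly instead of reading them from **kwargs. A hedged sketch of calling it on a populated NetworkStorage instance (hyperparameter values are illustrative, and the optional node2vec package must be installed):

import numpy as np

async def embed_demo(graph_store):
    # Assumes graph_store already holds a graph with at least two nodes
    embeddings, node_ids = await graph_store.embed_nodes(
        algorithm="node2vec",
        dimensions=64,      # smaller than the 128 default for a quick run
        walk_length=20,
        num_walks=100,
        workers=2,
    )
    print(f"{len(node_ids)} nodes embedded into {embeddings.shape[1]}-dim vectors")

    # e.g. cosine similarity between the first two embedded nodes
    a, b = embeddings[0], embeddings[1]
    print(float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9)))

# asyncio.run(embed_demo(graph_store))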
@@ -23,25 +23,29 @@ class ExtractionExample:

    def __init__(self):
        """Initialize RAG system components"""
        # Set up the working directory
        self.working_dir = f"private_upload/default_user/rag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        self.working_dir = f"crazy_functions/rag_fns/LightRAG/rag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
        os.makedirs(self.working_dir, exist_ok=True)
        logger.info(f"Working directory: {self.working_dir}")

        # Initialize the embedding model
        self.llm_kwargs = {'api_key': os.getenv("one_api_key"), 'client_ip': '127.0.0.1',
                           'embed_model': 'text-embedding-3-small', 'llm_model': 'one-api-Qwen2.5-72B-Instruct',
                           'max_length': 4096, 'most_recent_uploaded': None, 'temperature': 1, 'top_p': 1}
        self.llm_kwargs = {
            'api_key': os.getenv("one_api_key"),
            'client_ip': '127.0.0.1',
            'embed_model': 'text-embedding-3-small',
            'llm_model': 'one-api-Qwen2.5-72B-Instruct',
            'max_length': 4096,
            'most_recent_uploaded': None,
            'temperature': 1,
            'top_p': 1
        }
        self.embedding_func = OpenAiEmbeddingModel(self.llm_kwargs)

        # Initialize the prompt templates and the extractor
        self.prompt_templates = PromptTemplates()
        self.extractor = EntityRelationExtractor(
            prompt_templates=self.prompt_templates,
            required_prompts = {
                'entity_extraction'
            },
            required_prompts={'entity_extraction'},
            entity_extract_max_gleaning=1

        )

        # Initialize the storage systems
@@ -63,18 +67,33 @@ class ExtractionExample:
            working_dir=self.working_dir
        )

        # Vector storage - for similarity retrieval
        self.vector_store = VectorStorage(
            namespace="vectors",
        # Vector storage - for entity, relationship and text-chunk embeddings
        self.entities_vdb = VectorStorage(
            namespace="entities",
            working_dir=self.working_dir,
            llm_kwargs=self.llm_kwargs,
            embedding_func=self.embedding_func,
            meta_fields={"entity_name", "entity_type"}
        )

        self.relationships_vdb = VectorStorage(
            namespace="relationships",
            working_dir=self.working_dir,
            llm_kwargs=self.llm_kwargs,
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"}
        )

        self.chunks_vdb = VectorStorage(
            namespace="chunks",
            working_dir=self.working_dir,
            llm_kwargs=self.llm_kwargs,
            embedding_func=self.embedding_func
        )

        # Graph storage - for entity relations
        self.graph_store = NetworkStorage(
            namespace="graph",
            namespace="chunk_entity_relation",
            working_dir=self.working_dir
        )

@@ -152,7 +171,7 @@ class ExtractionExample:
        try:
            # Vector storage
            logger.info("Adding chunks to vector store...")
            await self.vector_store.upsert(chunks)
            await self.chunks_vdb.upsert(chunks)

            # Initialize the conversation history
            self.conversation_history = {chunk_key: [] for chunk_key in chunks.keys()}
@@ -178,14 +197,32 @@ class ExtractionExample:
            # Retrieve the results
            nodes, edges = self.extractor.get_results()

            # Store in the graph database
            logger.info("Storing extracted information in graph database...")
            # Store entities in the vector database and the graph database
            for node_name, node_instances in nodes.items():
                for node in node_instances:
                    # Store in the vector database
                    await self.entities_vdb.upsert({
                        f"entity_{node_name}": {
                            "content": f"{node_name}: {node['description']}",
                            "entity_name": node_name,
                            "entity_type": node['entity_type']
                        }
                    })
                    # Store in the graph database
                    await self.graph_store.upsert_node(node_name, node)

            # Store relationships in the vector database and the graph database
            for (src, tgt), edge_instances in edges.items():
                for edge in edge_instances:
                    # Store in the vector database
                    await self.relationships_vdb.upsert({
                        f"rel_{src}_{tgt}": {
                            "content": f"{edge['description']} | {edge['keywords']}",
                            "src_id": src,
                            "tgt_id": tgt
                        }
                    })
                    # Store in the graph database
                    await self.graph_store.upsert_edge(src, tgt, edge)

            return nodes, edges
@@ -197,26 +234,39 @@ class ExtractionExample:

    async def query_knowledge_base(self, query: str, top_k: int = 5):
        """Query the knowledge base using various methods"""
        try:
            # Vector similarity search
            vector_results = await self.vector_store.query(query, top_k=top_k)
            # Vector similarity search - text chunks
            chunk_results = await self.chunks_vdb.query(query, top_k=top_k)

            # Vector similarity search - entities
            entity_results = await self.entities_vdb.query(query, top_k=top_k)

            # Fetch the related text chunks
            chunk_ids = [r["id"] for r in vector_results]
            chunk_ids = [r["id"] for r in chunk_results]
            chunks = await self.text_chunks.get_by_ids(chunk_ids)

            # Fetch the related entities
            # (assumes the query contains entity names)
            relevant_nodes = []
            for word in query.split():
                if await self.graph_store.has_node(word.upper()):
                    node_data = await self.graph_store.get_node(word.upper())
                    if node_data:
                        relevant_nodes.append(node_data)
            # Fetch graph-structure information related to the retrieved entities
            relevant_edges = []
            for entity in entity_results:
                if "entity_name" in entity:
                    entity_name = entity["entity_name"]
                    if await self.graph_store.has_node(entity_name):
                        edges = await self.graph_store.get_node_edges(entity_name)
                        if edges:
                            edge_data = []
                            for edge in edges:
                                edge_info = await self.graph_store.get_edge(edge[0], edge[1])
                                if edge_info:
                                    edge_data.append({
                                        "source": edge[0],
                                        "target": edge[1],
                                        "data": edge_info
                                    })
                            relevant_edges.extend(edge_data)

            return {
                "vector_results": vector_results,
                "text_chunks": chunks,
                "relevant_entities": relevant_nodes
                "chunks": chunks,
                "entities": entity_results,
                "relationships": relevant_edges
            }

        except Exception as e:
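For orientation, a hedged sketch of consuming the combined result returned by query_knowledge_base above; the question text is arbitrary, and the field names on the chunk records (e.g. "content") are assumptions, since their structure is not shown in this diff.

async def ask(example):
    result = await example.query_knowledge_base("What new capabilities does the iPhone have?", top_k=3)

    for chunk in result["chunks"]:
        # chunk records come from JsonKVStorage; "content" is assumed here
        print("chunk:", (chunk or {}).get("content", "")[:80])

    for entity in result["entities"]:
        print("entity:", entity.get("entity_name"), entity.get("score"))

    for rel in result["relationships"]:
        print("relation:", rel["source"], "->", rel["target"])

# asyncio.run(ask(ExtractionExample()))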
@@ -228,30 +278,27 @@ class ExtractionExample:
        os.makedirs(export_dir, exist_ok=True)

        try:
            # Export the vector store
            self.vector_store.vector_store.export_nodes(
                os.path.join(export_dir, "vector_nodes.json"),
                include_embeddings=True
            )

            # Export graph statistics
            graph_stats = {
            # Export statistics
            storage_stats = {
                "chunks": {
                    "total": len(self.text_chunks._data),
                    "vector_stats": self.chunks_vdb.get_statistics()
                },
                "entities": {
                    "vector_stats": self.entities_vdb.get_statistics()
                },
                "relationships": {
                    "vector_stats": self.relationships_vdb.get_statistics()
                },
                "graph": {
                    "total_nodes": len(list(self.graph_store._graph.nodes())),
                    "total_edges": len(list(self.graph_store._graph.edges())),
                    "node_degrees": dict(self.graph_store._graph.degree()),
                    "largest_component_size": len(self.graph_store.get_largest_connected_component())
                }

            with open(os.path.join(export_dir, "graph_stats.json"), "w") as f:
                json.dump(graph_stats, f, indent=2)

            # Export storage statistics
            storage_stats = {
                "chunks": len(self.text_chunks._data),
                "docs": len(self.full_docs._data),
                "vector_store": self.vector_store.vector_store.get_statistics()
            }

            # Export the statistics
            with open(os.path.join(export_dir, "storage_stats.json"), "w") as f:
                json.dump(storage_stats, f, indent=2)
@@ -299,19 +346,6 @@ async def main():
        the company's commitment to innovation and sustainability. The new iPhone
        features groundbreaking AI capabilities.
        """,

        # "business_news": """
        # Microsoft and OpenAI expanded their partnership today.
        # Satya Nadella emphasized the importance of AI development while
        # Sam Altman discussed the future of large language models. The collaboration
        # aims to accelerate AI research and deployment.
        # """,
        #
        # "science_paper": """
        # Researchers at DeepMind published a breakthrough paper on quantum computing.
        # The team demonstrated novel approaches to quantum error correction.
        # Dr. Sarah Johnson led the research, collaborating with Google's quantum lab.
        # """
    }

    try: