This commit is contained in:
lbykkkk
2024-11-23 19:00:02 +08:00
parent 724940a9d8
commit 12be7c16e9
6 changed files with 1682 additions and 101 deletions

View File

@@ -140,11 +140,31 @@ class ArxivRagWorker:
self.rag_worker.add_text_to_vector_store(overview_text)
logger.info(f"Added paper overview for {overview['arxiv_id']}")
# Process the remaining fragments in parallel
tasks = []
for i, fragment in enumerate(fragments):
tasks.append(self._process_single_fragment(fragment, i))
await asyncio.gather(*tasks)
# Create a thread pool
with ThreadPoolExecutor(max_workers=10) as executor:
# Schedule every fragment on the executor; results are gathered with asyncio.gather below
loop = asyncio.get_event_loop()
tasks = [
loop.run_in_executor(
executor,
self._process_single_fragment,
fragment,
i
)
for i, fragment in enumerate(fragments)
]
# Wait for all tasks to complete
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle results and exceptions
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"Error processing fragment {i}: {result}")
else:
# Successful result, no further handling needed
pass
logger.info(f"Processed {len(fragments)} fragments successfully")

View File

@@ -0,0 +1,772 @@
import os
import re
import time
import aiohttp
import asyncio
import requests
import tarfile
import logging
from pathlib import Path
from copy import deepcopy
from typing import Generator, List, Tuple, Optional, Dict, Set
from concurrent.futures import ThreadPoolExecutor, as_completed
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
"""
Save all fragments to a single structured markdown file.
Args:
fragments: List of SectionFragment objects
output_dir: Output directory path
Returns:
Path: Path to the generated markdown file
"""
from datetime import datetime
from pathlib import Path
import re
# Create output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Generate filename
filename = f"fragments_{timestamp}.md"
file_path = output_path / filename
# Group fragments by section
sections = {}
for fragment in fragments:
section = fragment.current_section or "Uncategorized"
if section not in sections:
sections[section] = []
sections[section].append(fragment)
with open(file_path, "w", encoding="utf-8") as f:
# Write document header
f.write("# Document Fragments Analysis\n\n")
f.write("## Overview\n")
f.write(f"- Total Fragments: {len(fragments)}\n")
f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
# Add paper information if available
if fragments and (fragments[0].title or fragments[0].abstract):
f.write("\n## Paper Information\n")
if fragments[0].title:
f.write(f"### Title\n{fragments[0].title}\n")
if fragments[0].abstract:
f.write(f"\n### Abstract\n{fragments[0].abstract}\n")
# Write section tree if available
if fragments and fragments[0].section_tree:
f.write("\n## Section Tree\n")
f.write(fragments[0].section_tree)
# Generate table of contents
f.write("\n## Table of Contents\n")
for section in sections:
clean_section = section.strip() or "Uncategorized"
fragment_count = len(sections[section])
f.write(f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) "
f"({fragment_count} fragments)\n")
# Write content sections
f.write("\n## Content\n")
for section, section_fragments in sections.items():
clean_section = section.strip() or "Uncategorized"
f.write(f"\n### {clean_section}\n")
# Write each fragment
for i, fragment in enumerate(section_fragments, 1):
f.write(f"\n#### Fragment {i}\n")
# Metadata
f.write("**Metadata:**\n")
metadata = [
f"- Section: {fragment.current_section}",
f"- Length: {len(fragment.content)} chars",
f"- ArXiv ID: {fragment.arxiv_id}" if fragment.arxiv_id else None
]
f.write("\n".join(filter(None, metadata)) + "\n")
# Content
f.write("\n**Content:**\n")
f.write("```tex\n")
f.write(fragment.content)
f.write("\n```\n")
# Bibliography if exists
if fragment.bibliography:
f.write("\n**Bibliography:**\n")
f.write("```bibtex\n")
f.write(fragment.bibliography)
f.write("\n```\n")
# Add separator
if i < len(section_fragments):
f.write("\n---\n")
# Add statistics
f.write("\n## Statistics\n")
# Length distribution
lengths = [len(f.content) for f in fragments]
f.write("\n### Length Distribution\n")
f.write(f"- Minimum: {min(lengths)} chars\n")
f.write(f"- Maximum: {max(lengths)} chars\n")
f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
# Section distribution
f.write("\n### Section Distribution\n")
for section, section_fragments in sections.items():
percentage = (len(section_fragments) / len(fragments)) * 100
f.write(f"- {section}: {len(section_fragments)} ({percentage:.1f}%)\n")
print(f"Fragments saved to: {file_path}")
return file_path
# Patterns for the various citation commands
CITATION_PATTERNS = [
# Basic \cite{} form
r'\\cite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
# natbib 格式
r'\\citep(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\citet(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\citeauthor(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\citeyear(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\citealt(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\citealp(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
# biblatex commands
r'\\textcite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\parencite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\autocite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
# Custom [cite:...] form
r'\[cite:([^\]]+)\]',
]
# Compile all patterns
COMPILED_PATTERNS = [re.compile(pattern) for pattern in CITATION_PATTERNS]
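A quick sanity check of these patterns; the sample string and the expected keys below are illustrative only:

sample = r"Prior work \citep{smith2020, jones2019} extends \cite[p.~3]{doe2018} and [cite:lee2021]."
found = set()
for pattern in COMPILED_PATTERNS:
    for match in pattern.finditer(sample):
        found.update(k.strip() for k in re.split(r'[,;]', match.group(1)))
print(found)  # {'smith2020', 'jones2019', 'doe2018', 'lee2021'} (set order may vary)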
class ArxivSplitter:
"""Arxiv论文智能分割器"""
def __init__(self,
root_dir: str = "gpt_log/arxiv_cache",
proxies: Optional[Dict[str, str]] = None,
cache_ttl: int = 7 * 24 * 60 * 60):
"""
初始化分割器
Args:
char_range: 字符数范围(最小值, 最大值)
root_dir: 缓存根目录
proxies: 代理设置
cache_ttl: 缓存过期时间(秒)
"""
self.root_dir = Path(root_dir)
self.root_dir.mkdir(parents=True, exist_ok=True)
self.proxies = proxies or {}
self.cache_ttl = cache_ttl
# Dynamically compute an optimal number of workers
import multiprocessing
cpu_count = multiprocessing.cpu_count()
# Scale with the CPU core count, capped to avoid excessive concurrency
self.document_structure = DocumentStructure()
self.document_parser = EssayStructureParser()
self.max_workers = min(32, cpu_count * 2)
# Initialize the TeX processor
self.tex_processor = TexUtils()
# Configure logging
self._setup_logging()
def _setup_logging(self):
"""配置日志"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
def _normalize_arxiv_id(self, input_str: str) -> str:
"""规范化ArXiv ID"""
if 'arxiv.org/' in input_str.lower():
# 处理URL格式
if '/pdf/' in input_str:
arxiv_id = input_str.split('/pdf/')[-1]
else:
arxiv_id = input_str.split('/abs/')[-1]
# 移除版本号和其他后缀
return arxiv_id.split('v')[0].strip()
return input_str.split('v')[0].strip()
def _check_cache(self, paper_dir: Path) -> bool:
"""
检查缓存是否有效,包括文件完整性检查
Args:
paper_dir: 论文目录路径
Returns:
bool: 如果缓存有效返回True否则返回False
"""
if not paper_dir.exists():
return False
# Check that the required files exist in the directory
has_tex_files = False
has_main_tex = False
for file_path in paper_dir.rglob("*"):
if file_path.suffix == '.tex':
has_tex_files = True
content = self.tex_processor.read_file(str(file_path))
if content and r'\documentclass' in content:
has_main_tex = True
break
if not (has_tex_files and has_main_tex):
return False
# Check the cache age
cache_time = paper_dir.stat().st_mtime
if (time.time() - cache_time) < self.cache_ttl:
self.logger.info(f"Using valid cache for {paper_dir.name}")
return True
return False
async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
"""
异步下载论文,包含重试机制和临时文件处理
Args:
arxiv_id: ArXiv论文ID
paper_dir: 目标目录路径
Returns:
bool: 下载成功返回True否则返回False
"""
from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader
temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz"
final_tar_path = paper_dir / f"{arxiv_id}.tar.gz"
# Make sure the target directory exists
paper_dir.mkdir(parents=True, exist_ok=True)
# Try downloading with ArxivDownloader first
try:
downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
downloaded_dir = downloader.download_paper(arxiv_id)
if downloaded_dir:
self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
return True
except Exception as e:
self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")
# If ArxivDownloader fails, fall back to direct download
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
max_retries = 3
retry_delay = 1  # initial retry delay in seconds
for url in urls:
for attempt in range(max_retries):
try:
self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
async with aiohttp.ClientSession() as session:
async with session.get(url, proxy=self.proxies.get('http')) as response:
if response.status == 200:
content = await response.read()
# Write to a temporary file
temp_tar_path.write_bytes(content)
try:
# Verify the tar archive and extract it
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
# On success, move the temporary file to its final location
temp_tar_path.rename(final_tar_path)
return True
except Exception as e:
self.logger.warning(f"Invalid tar file: {str(e)}")
if temp_tar_path.exists():
temp_tar_path.unlink()
except Exception as e:
self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
await asyncio.sleep(retry_delay * (attempt + 1))  # back off, delay grows with each attempt
continue
return False
def _process_tar_file(self, tar_path: Path, extract_path: Path):
"""处理tar文件的同步操作"""
with tarfile.open(tar_path, 'r:gz') as tar:
tar.testall() # 验证文件完整性
tar.extractall(path=extract_path) # 解压文件
def process_references(self, doc_structure: DocumentStructure, ref_bib: str) -> DocumentStructure:
"""
Process citations in document structure and add referenced literature for each section
Args:
doc_structure: DocumentStructure object
ref_bib: String containing references separated by newlines
Returns:
Updated DocumentStructure object
"""
try:
# Create a copy to avoid modifying the original
doc = deepcopy(doc_structure)
# Parse references into a mapping
ref_map = self._parse_references(ref_bib)
if not ref_map:
self.logger.warning("No valid references found in ref_bib")
return doc
# Process all sections recursively
self._process_section_references(doc.toc, ref_map)
return doc
except Exception as e:
self.logger.error(f"Error processing references: {str(e)}")
return doc_structure # Return original if processing fails
def _process_section_references(self, sections: List[Section], ref_map: Dict[str, str]) -> None:
"""
Recursively process sections to add references
Args:
sections: List of Section objects
ref_map: Mapping of citation keys to full references
"""
for section in sections:
if section.content:
# Find citations in current section
cited_refs = self.find_citations(section.content)
if cited_refs:
# Get full references for citations
full_refs = []
for ref_key in cited_refs:
ref_text = ref_map.get(ref_key)
if ref_text:
full_refs.append(ref_text)
else:
self.logger.warning(f"Reference not found for citation key: {ref_key}")
# Add references to section content
if full_refs:
section.bibliography = "\n\n".join(full_refs)
# Process subsections recursively
if section.subsections:
self._process_section_references(section.subsections, ref_map)
def _parse_references(self, ref_bib: str) -> Dict[str, str]:
"""
Parse reference string into a mapping of citation keys to full references
Args:
ref_bib: Reference string with references separated by newlines
Returns:
Dict mapping citation keys to full reference text
"""
ref_map = {}
current_ref = []
current_key = None
try:
for line in ref_bib.split('\n'):
line = line.strip()
if not line:
continue
# New reference entry
if line.startswith('@'):
# Save previous reference if exists
if current_key and current_ref:
ref_map[current_key] = '\n'.join(current_ref)
current_ref = []
# Extract key from new reference
key_match = re.search(r'{(.*?),', line)
if key_match:
current_key = key_match.group(1)
current_ref.append(line)
else:
if current_ref is not None:
current_ref.append(line)
# Save last reference
if current_key and current_ref:
ref_map[current_key] = '\n'.join(current_ref)
except Exception as e:
self.logger.error(f"Error parsing references: {str(e)}")
return ref_map
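As a rough illustration of the mapping this produces, with hypothetical BibTeX data and an already constructed splitter instance:

ref_bib = """@article{smith2020,
  title={An Example},
  author={Smith, J.}
}
@inproceedings{doe2018,
  title={Another Example}
}"""
ref_map = splitter._parse_references(ref_bib)
print(sorted(ref_map))  # ['doe2018', 'smith2020']
print(ref_map['smith2020'].startswith('@article{smith2020,'))  # True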
# Citation regexes are compiled once at module level (COMPILED_PATTERNS)
@staticmethod
def _clean_citation_key(key: str) -> str:
"""Clean individual citation key."""
return key.strip().strip(',').strip()
def _extract_keys_from_group(self, keys_str: str) -> Set[str]:
"""Extract and clean individual citation keys from a group."""
try:
# Split multiple citation keys (comma- or semicolon-separated)
separators = '[,;]'
keys = re.split(separators, keys_str)
# Clean keys and drop empty ones
return {self._clean_citation_key(k) for k in keys if self._clean_citation_key(k)}
except Exception as e:
self.logger.warning(f"Error processing citation group '{keys_str}': {e}")
return set()
def find_citations(self, content: str) -> Set[str]:
"""
Find citation keys in text content in various formats.
Args:
content: Text content to search for citations
Returns:
Set of unique citation keys
Examples:
Supported formats include:
- \cite{key1,key2}
- \cite[p. 1]{key}
- \citep{key}
- \citet{key}
- [cite:key1, key2]
- And many other variants
"""
citations = set()
if not content:
return citations
try:
# Search with each pre-compiled pattern
for pattern in COMPILED_PATTERNS:
matches = pattern.finditer(content)
for match in matches:
# Citation keys from the capture group
keys_str = match.group(1)
if keys_str:
# Extract and collect all keys
new_keys = self._extract_keys_from_group(keys_str)
citations.update(new_keys)
except Exception as e:
self.logger.error(f"Error finding citations: {str(e)}")
# Drop obviously invalid keys
citations = {key for key in citations
if key and not key.startswith(('\\', '{', '}', '[', ']'))}
return citations
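For example (illustrative snippet; splitter is assumed to be an ArxivSplitter instance):

snippet = r"As shown in \citet{kim2021} and \cite{lee2020,park2019}, the effect is robust."
print(splitter.find_citations(snippet))  # {'kim2021', 'lee2020', 'park2019'} (set order may vary)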
def get_citation_contexts(self, content: str, context_chars: int = 100) -> dict:
"""
Find citations and their surrounding context.
Args:
content: Text content to search for citations
context_chars: Number of characters of context to include before/after
Returns:
Dict mapping citation keys to lists of context strings
"""
contexts = {}
if not content:
return contexts
try:
for pattern in COMPILED_PATTERNS:
matches = pattern.finditer(content)
for match in matches:
# Window around the match
start = max(0, match.start() - context_chars)
end = min(len(content), match.end() + context_chars)
# Extract the surrounding context
context = content[start:end]
# Extract and process the citation keys
keys_str = match.group(1)
keys = self._extract_keys_from_group(keys_str)
# Record the context for each key
for key in keys:
if key not in contexts:
contexts[key] = []
contexts[key].append(context)
except Exception as e:
self.logger.error(f"Error finding citation contexts: {str(e)}")
return contexts
async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
"""
Process ArXiv paper and convert to list of SectionFragments.
Each fragment represents the smallest section unit.
Args:
arxiv_id_or_url: ArXiv paper ID or URL
Returns:
List[SectionFragment]: List of processed paper fragments
"""
try:
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
paper_dir = self.root_dir / arxiv_id
# Check if paper directory exists, if not, try to download
if not paper_dir.exists():
self.logger.info(f"Downloading paper {arxiv_id}")
await self.download_paper(arxiv_id, paper_dir)
# Find main TeX file
main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
if not main_tex:
raise RuntimeError(f"No main TeX file found in {paper_dir}")
# Get all related TeX files and references
tex_files = self.tex_processor.resolve_includes(main_tex)
ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
if not tex_files:
raise RuntimeError(f"No valid TeX files found for {arxiv_id}")
# Reset document structure for new processing
self.document_structure = DocumentStructure()
# Process each TeX file
for file_path in tex_files:
self.logger.info(f"Processing TeX file: {file_path}")
tex_content = read_tex_file(file_path)
if tex_content:
additional_doc = self.document_parser.parse(tex_content)
self.document_structure = self.document_structure.merge(additional_doc)
# Process references if available
if ref_bib:
self.document_structure = self.process_references(self.document_structure, ref_bib)
self.logger.info("Successfully processed references")
else:
self.logger.info("No references found to process")
# Generate table of contents once
section_tree = self.document_structure.generate_toc_tree()
# Convert DocumentStructure to SectionFragments
fragments = self._convert_to_fragments(
doc_structure=self.document_structure,
arxiv_id=arxiv_id,
section_tree=section_tree
)
return fragments
except Exception as e:
self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
raise
def _convert_to_fragments(self,
doc_structure: DocumentStructure,
arxiv_id: str,
section_tree: str) -> List[SectionFragment]:
"""
Convert DocumentStructure to list of SectionFragments.
Creates a fragment for each leaf section in the document hierarchy.
Args:
doc_structure: Source DocumentStructure
arxiv_id: ArXiv paper ID
section_tree: Pre-generated table of contents tree
Returns:
List[SectionFragment]: List of paper fragments
"""
fragments = []
# Create a base template for all fragments to avoid repetitive assignments
base_fragment_template = {
'title': doc_structure.title,
'abstract': doc_structure.abstract,
'section_tree': section_tree,
'arxiv_id': arxiv_id
}
def get_leaf_sections(section: Section, path: List[str] = None) -> None:
"""
Recursively find all leaf sections and create fragments.
A leaf section is one that has content but no subsections, or has neither.
Args:
section: Current section being processed
path: List of section titles forming the path to current section
"""
if path is None:
path = []
current_path = path + [section.title]
if not section.subsections:
# This is a leaf section, create a fragment if it has content
if section.content or section.bibliography:
fragment = SectionFragment(
**base_fragment_template,
current_section="/".join(current_path),
content=self._clean_content(section.content),
bibliography=section.bibliography
)
if self._validate_fragment(fragment):
fragments.append(fragment)
else:
# Process each subsection
for subsection in section.subsections:
get_leaf_sections(subsection, current_path)
# Process all top-level sections
for section in doc_structure.toc:
get_leaf_sections(section)
# Add a fragment for the abstract if it exists
if doc_structure.abstract:
abstract_fragment = SectionFragment(
**base_fragment_template,
current_section="Abstract",
content=self._clean_content(doc_structure.abstract)
)
if self._validate_fragment(abstract_fragment):
fragments.insert(0, abstract_fragment)
self.logger.info(f"Created {len(fragments)} fragments")
return fragments
def _validate_fragment(self, fragment: SectionFragment) -> bool:
"""
Validate if the fragment has all required fields with meaningful content.
Args:
fragment: SectionFragment to validate
Returns:
bool: True if fragment is valid, False otherwise
"""
try:
return all([
fragment.title.strip(),
fragment.section_tree.strip(),
fragment.current_section.strip(),
fragment.content.strip() or fragment.bibliography.strip()
])
except AttributeError:
return False
def _clean_content(self, content: str) -> str:
"""
Clean and normalize content text.
Args:
content: Raw content text
Returns:
str: Cleaned content text
"""
if not content:
return ""
# Remove excessive whitespace
content = re.sub(r'\s+', ' ', content)
# Remove remaining LaTeX artifacts
content = re.sub(r'\\item\s*', '', content) # Convert \item to bullet points
content = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', content) # Remove simple LaTeX commands
# Clean special characters
content = content.replace('\\\\', '\n') # Convert LaTeX newlines to actual newlines
content = re.sub(r'\s*\n\s*', '\n', content) # Clean up newlines
return content.strip()
async def test_arxiv_splitter():
"""测试ArXiv分割器的功能"""
# 测试配置
test_cases = [
{
"arxiv_id": "2411.03663",
"expected_title": "Large Language Models and Simple Scripts",
"min_fragments": 10,
},
# {
# "arxiv_id": "1805.10988",
# "expected_title": "RAG vs Fine-tuning",
# "min_fragments": 15,
# }
]
# Create the splitter instance
splitter = ArxivSplitter(
root_dir="test_cache"
)
for case in test_cases:
print(f"\nTesting paper: {case['arxiv_id']}")
try:
fragments = await splitter.process(case['arxiv_id'])
# Save the fragments to a markdown report
output_path = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
print(f"Output saved to: {output_path}")
# Content checks
for fragment in fragments:
# Length check
print(fragment.content)
print(len(fragment.content))
# Type check
except Exception as e:
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
raise
if __name__ == "__main__":
asyncio.run(test_arxiv_splitter())
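Putting the pieces together, a typical call sequence looks roughly like this; the arXiv ID and paths are placeholders:

async def demo():
    splitter = ArxivSplitter(root_dir="gpt_log/arxiv_cache")
    fragments = await splitter.process("2411.03663")
    report_path = save_fragments_to_file(fragments, output_dir="gpt_log/fragment_outputs")
    print(f"{len(fragments)} fragments written to {report_path}")

# asyncio.run(demo())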

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Set, Dict, Pattern, Optional
from typing import Set, Dict, Pattern, Optional, List, Tuple
import re
from enum import Enum
import logging
@@ -8,179 +8,259 @@ from functools import lru_cache
class EnvType(Enum):
"""Environment classification types."""
PRESERVE = "preserve"
REMOVE = "remove"
EXTRACT = "extract"
PRESERVE = "preserve" # Preserve complete environment including commands
REMOVE = "remove" # Remove environment completely
EXTRACT = "extract" # Extract and clean content
@dataclass
class LatexConfig:
"""Configuration for LaTeX processing."""
preserve_envs: Set[str] = field(default_factory=lambda: {
# Math environments
# Math environments - preserve complete content
'equation', 'equation*', 'align', 'align*', 'displaymath',
'math', 'eqnarray', 'gather', 'gather*', 'multline', 'multline*',
# Tables and figures
'math', 'eqnarray', 'eqnarray*', 'gather', 'gather*',
'multline', 'multline*', 'flalign', 'flalign*',
'alignat', 'alignat*', 'cases', 'split', 'aligned',
# Tables and figures - preserve structure and content
'table', 'table*', 'tabular', 'tabularx', 'array', 'matrix',
'figure', 'figure*', 'subfigure',
'figure', 'figure*', 'subfigure', 'wrapfigure',
'minipage', 'tabbing', 'verbatim', 'longtable',
'sidewaystable', 'sidewaysfigure', 'floatrow',
# Arrays and matrices
'pmatrix', 'bmatrix', 'Bmatrix', 'vmatrix', 'Vmatrix',
'smallmatrix', 'array', 'matrix*', 'pmatrix*', 'bmatrix*',
# Algorithms and code
'algorithm', 'algorithmic', 'lstlisting',
'algorithm', 'algorithmic', 'lstlisting', 'verbatim',
'minted', 'listing', 'algorithmic*', 'algorithm2e',
# Theorems and proofs
'theorem', 'proof', 'definition', 'lemma', 'corollary',
'proposition', 'example', 'remark'
'proposition', 'example', 'remark', 'note', 'claim',
'axiom', 'property', 'assumption', 'conjecture', 'observation',
# Bibliography
'thebibliography', 'bibliography', 'references'
})
# Special-handling configuration for citation-style commands
citation_commands: Set[str] = field(default_factory=lambda: {
# Basic citations
'cite', 'citep', 'citet', 'citeyear', 'citeauthor',
'citeyearpar', 'citetext', 'citenum',
# Natbib citations
'citefullauthor', 'citealp', 'citealt', 'citename',
'citepalias', 'citetalias', 'citetext',
# Cross-references
'ref', 'eqref', 'pageref', 'autoref', 'nameref', 'cref',
'Cref', 'vref', 'Vref', 'fref', 'pref',
# Hyperref
'hyperref', 'href', 'url',
# Labels
'label', 'tag'
})
preserve_commands: Set[str] = field(default_factory=lambda: {
# Citations and references
'caption', 'label', 'ref', 'cite', 'citep', 'citet', 'eqref',
# Text formatting
'emph', 'textbf', 'textit', 'underline', 'texttt', 'footnote',
'section', 'subsection', 'subsubsection', 'paragraph',
# Math operators
'frac', 'sum', 'int', 'prod', 'lim', 'sup', 'inf'
'section', 'subsection', 'subsubsection', 'paragraph', 'part',
'chapter', 'title', 'author', 'date', 'thanks',
# Math operators and symbols
'frac', 'sum', 'int', 'prod', 'lim', 'sup', 'inf',
'partial', 'nabla', 'implies', 'iff', 'therefore',
'exists', 'forall', 'in', 'subset', 'subseteq',
# Greek letters and math symbols
'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta',
'eta', 'theta', 'iota', 'kappa', 'lambda', 'mu',
'nu', 'xi', 'pi', 'rho', 'sigma', 'tau',
'upsilon', 'phi', 'chi', 'psi', 'omega',
'Gamma', 'Delta', 'Theta', 'Lambda', 'Xi', 'Pi',
'Sigma', 'Upsilon', 'Phi', 'Psi', 'Omega',
# Math commands
'left', 'right', 'big', 'Big', 'bigg', 'Bigg',
'mathbf', 'mathit', 'mathsf', 'mathtt', 'mathbb',
'mathcal', 'mathfrak', 'mathscr', 'mathrm', 'mathop',
'operatorname', 'overline', 'underline', 'overbrace',
'underbrace', 'overset', 'underset', 'stackrel',
# Spacing and alignment
'quad', 'qquad', 'hspace', 'vspace', 'medskip',
'bigskip', 'smallskip', 'hfill', 'vfill', 'centering',
'raggedright', 'raggedleft'
})
remove_commands: Set[str] = field(default_factory=lambda: {
# Document setup
'documentclass', 'usepackage', 'input', 'include', 'includeonly',
'bibliography', 'bibliographystyle', 'frontmatter', 'mainmatter',
'newtheorem', 'theoremstyle', 'proof', 'proofname', 'qed',
'bibliographystyle', 'frontmatter', 'mainmatter',
'newtheorem', 'theoremstyle', 'proofname',
'newcommand', 'renewcommand', 'providecommand', 'DeclareMathOperator',
'newenvironment',
# Layout and spacing
'pagestyle', 'thispagestyle', 'vspace', 'hspace', 'vfill', 'hfill',
'newpage', 'clearpage', 'pagebreak', 'linebreak', 'newline',
'setlength', 'setcounter', 'addtocounter', 'renewcommand',
'newcommand', 'makeatletter', 'makeatother', 'pagenumbering',
# Margins and columns
'marginpar', 'marginparsep', 'columnsep', 'columnseprule',
'twocolumn', 'onecolumn', 'minipage', 'parbox'
'pagestyle', 'thispagestyle', 'newpage', 'clearpage',
'pagebreak', 'linebreak', 'newline', 'setlength',
'setcounter', 'addtocounter', 'makeatletter',
'makeatother', 'pagenumbering'
})
latex_chars: Dict[str, str] = field(default_factory=lambda: {
'~': ' ', '\\&': '&', '\\%': '%', '\\_': '_', '\\$': '$',
'\\#': '#', '\\{': '{', '\\}': '}', '``': '"', "''": '"',
'\\textbackslash': '\\', '\\ldots': '...', '\\dots': '...',
'\\textasciitilde': '~', '\\textasciicircum': '^',
'\\quad': ' ', '\\qquad': ' ', '\\,': '', '\\;': '', '\\:': '',
'\\!': '', '\\space': ' ', '\\noindent': ''
'\\textasciitilde': '~', '\\textasciicircum': '^'
})
inline_math_delimiters: Set[str] = field(default_factory=lambda: {
'$', '\\(', '\\)', '\\[', '\\]'
})
# Special command patterns preserved in their original form
special_command_patterns: List[Tuple[str, str]] = field(default_factory=lambda: [
(r'\\cite\*?(?:\[[^\]]*\])?{([^}]+)}', r'\\cite{\1}'),
(r'\\ref\*?{([^}]+)}', r'\\ref{\1}'),
(r'\\label{([^}]+)}', r'\\label{\1}'),
(r'\\eqref{([^}]+)}', r'\\eqref{\1}'),
(r'\\autoref{([^}]+)}', r'\\autoref{\1}'),
(r'\\url{([^}]+)}', r'\\url{\1}'),
(r'\\href{([^}]+)}{([^}]+)}', r'\\href{\1}{\2}')
])
class LatexCleaner:
"""Efficient and modular LaTeX text cleaner."""
"""Enhanced LaTeX text cleaner that preserves mathematical content and citations."""
def __init__(self, config: Optional[LatexConfig] = None):
self.config = config or LatexConfig()
self.logger = logging.getLogger(__name__)
# Initialize the regex cache
self._regex_cache = {}
@lru_cache(maxsize=128)
def _get_env_pattern(self, env_name: str) -> Pattern:
"""Get cached regex pattern for environment matching."""
return re.compile(fr'\\begin{{{env_name}}}(.*?)\\end{{{env_name}}}', re.DOTALL)
def _get_env_type(self, env_name: str) -> EnvType:
"""Determine environment processing type."""
if env_name.rstrip('*') in {name.rstrip('*') for name in self.config.preserve_envs}:
return EnvType.PRESERVE
elif env_name in {'verbatim', 'comment'}:
elif env_name in {'comment'}:
return EnvType.REMOVE
return EnvType.EXTRACT
def _preserve_special_commands(self, text: str) -> str:
"""Preserve special commands like citations and references with their complete structure."""
for pattern, replacement in self.config.special_command_patterns:
if pattern not in self._regex_cache:
self._regex_cache[pattern] = re.compile(pattern)
def replace_func(match):
# Keep the original command form
return match.group(0)
text = self._regex_cache[pattern].sub(replace_func, text)
return text
def _process_environment(self, match: re.Match) -> str:
"""Process LaTeX environments while preserving complete content for special environments."""
try:
env_name = match.group(1)
content = match.group(2)
env_type = self._get_env_type(env_name)
if env_type == EnvType.PRESERVE:
# Preserve math content without markers for inline math
if env_name in {'math', 'displaymath'}:
return f" {content} "
return f" [BEGIN_{env_name}] {content} [END_{env_name}] "
# Keep the full environment content
complete_env = match.group(0)
return f"\n[BEGIN_{env_name}]\n{complete_env}\n[END_{env_name}]\n"
elif env_type == EnvType.REMOVE:
return ' '
# Process nested environments recursively
else:
# Handle nested environments
return self._clean_nested_environments(content)
except Exception as e:
self.logger.error(f"Error processing environment {env_name}: {e}")
return content
self.logger.error(f"Error processing environment {match.group(1) if match else 'unknown'}: {e}")
return match.group(0)
def _preserve_inline_math(self, text: str) -> str:
"""Preserve complete inline math content."""
def preserve_math(match):
return f" {match.group(0)} "
patterns = [
(r'\$[^$]+\$', preserve_math),
(r'\\[\(\[].*?\\[\)\]]', preserve_math),
(r'\\begin{math}.*?\\end{math}', preserve_math)
]
for pattern, handler in patterns:
if pattern not in self._regex_cache:
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL)
text = self._regex_cache[pattern].sub(handler, text)
return text
def _clean_nested_environments(self, text: str) -> str:
"""Process nested environments recursively."""
return re.sub(
r'\\begin{(\w+)}(.*?)\\end{\1}',
self._process_environment,
text,
flags=re.DOTALL
)
pattern = r'\\begin{(\w+)}(.*?)\\end{\1}'
if pattern not in self._regex_cache:
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL)
return self._regex_cache[pattern].sub(self._process_environment, text)
def _clean_commands(self, text: str) -> str:
"""Clean LaTeX commands while preserving specified content."""
# Remove complete commands
for cmd in self.config.remove_commands:
text = re.sub(fr'\\{cmd}\*?(?:\[.*?\])?(?:{{.*?}})*', '', text)
"""Clean LaTeX commands while preserving important content."""
# Handle special commands first
text = self._preserve_special_commands(text)
# Process commands with content
def handle_command(match: re.Match) -> str:
cmd = match.group(1).rstrip('*') # Handle starred versions
content = match.group(2)
# For these delimiters, return the original math content
if cmd in {'[', ']', '(', ')', '$'} or cmd in self.config.inline_math_delimiters:
return match.group(0)
# For preserved commands return content, otherwise return space
return match.group(0) if cmd in self.config.preserve_commands else ' '
# Handle commands with arguments
text = re.sub(r'\\(\w+)\*?(?:\[.*?\])?{(.*?)}', handle_command, text)
# Handle inline math
# Preserve inline math
text = self._preserve_inline_math(text)
# Remove remaining standalone commands
return text
# Remove the configured commands
for cmd in self.config.remove_commands:
if cmd not in self._regex_cache:
self._regex_cache[cmd] = re.compile(
fr'\\{cmd}\*?(?:\[.*?\])?(?:{{.*?}})*'
)
text = self._regex_cache[cmd].sub('', text)
def _preserve_inline_math(self, text: str) -> str:
"""Preserve inline math content."""
# Handle $...$ math
text = re.sub(r'\$(.+?)\$', r' \1 ', text)
# Handle \(...\) math
text = re.sub(r'\\[\(\[](.+?)\\[\)\]]', r' \1 ', text)
# Handle commands that take an argument
def handle_command(match: re.Match) -> str:
cmd = match.group(1).rstrip('*')
if cmd in self.config.preserve_commands or cmd in self.config.citation_commands:
return match.group(0)  # keep the command and its content intact
return ' '
if 'command_pattern' not in self._regex_cache:
self._regex_cache['command_pattern'] = re.compile(
r'\\(\w+)\*?(?:\[.*?\])?{(.*?)}'
)
text = self._regex_cache['command_pattern'].sub(handle_command, text)
return text
def _normalize_text(self, text: str) -> str:
"""Normalize special characters and whitespace."""
# Replace special characters
"""Normalize text while preserving special content markers."""
# Replace special characters
for char, replacement in self.config.latex_chars.items():
text = text.replace(char, replacement)
# Clean up whitespace
# Clean up whitespace while keeping environment markers
text = re.sub(r'\s+', ' ', text)
text = re.sub(r'\s*\[BEGIN_(\w+)\]\s*', r' [BEGIN_\1] ', text)
text = re.sub(r'\s*\[END_(\w+)\]\s*', r' [END_\1] ', text)
text = re.sub(r'\s*\[BEGIN_(\w+)\]\s*', r'\n[BEGIN_\1]\n', text)
text = re.sub(r'\s*\[END_(\w+)\]\s*', r'\n[END_\1]\n', text)
# Remove empty brackets and braces
text = re.sub(r'{\s*}|\[\s*\]|\(\s*\)', '', text)
# Keep separation between block-level environments
text = re.sub(r'\n{3,}', '\n\n', text)
return text.strip()
def clean_text(self, text: str) -> str:
"""Clean LaTeX text while preserving meaningful content."""
"""Clean LaTeX text while preserving mathematical content, citations, and special environments."""
if not text:
return ""
try:
# Remove comments not inside environments
# Strip comments
text = re.sub(r'(?<!\\)%.*?(?=\n|$)', '', text, flags=re.MULTILINE)
# Process environments and their nested contents
# Process environments
text = self._clean_nested_environments(text)
# Clean commands and normalize
# Clean commands and normalize
text = self._clean_commands(text)
text = self._normalize_text(text)
@@ -188,30 +268,39 @@ class LatexCleaner:
except Exception as e:
self.logger.error(f"Error cleaning text: {e}")
raise
return text  # return the original text if an error occurs
def clean_latex_commands(text: str) -> str:
"""Convenience function for quick text cleaning with default config."""
config = LatexConfig(
preserve_envs={'equation', 'theorem'},
preserve_commands={'textbf', 'emph', "label"},
latex_chars={'~': ' ', '\\&': '&'}
)
return LatexCleaner(config).clean_text(text)
cleaner = LatexCleaner()
return cleaner.clean_text(text)
# Example usage:
if __name__ == "__main__":
# Basic usage with inline math
text = clean_latex_commands(r"""
text = r"""
\documentclass{article}
\begin{document}
\section{Introduction}
This is a reference to \cite{smith2020} and equation \eqref{eq:main}.
\begin{equation}\label{eq:main}
E = mc^2 \times \sum_{i=1}^{n} x_i
\end{equation}
See Figure \ref{fig:example} for details.
\begin{figure}
\includegraphics{image.png}
\caption{Example figure\label
\textbf{Important} result: $E=mc^2$ and
\begin{equation}
F = ma
\end{equation}
\label{sec:intro}
""")
print(text)
"""
# Custom configuration
config = LatexConfig(
@@ -236,5 +325,5 @@ if __name__ == "__main__":
file_path = 'test_cache/2411.03663/neurips_2024.tex'
content = read_tex_file(file_path)
cleaner = LatexCleaner(config)
text = cleaner.clean_text(content)
text = cleaner.clean_text(text)
print(text)

View File

@@ -0,0 +1,412 @@
import re
from typing import List, Dict, Tuple, Optional, Set
from enum import Enum
from dataclasses import dataclass, field
from copy import deepcopy
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SectionLevel(Enum):
CHAPTER = 0
SECTION = 1
SUBSECTION = 2
SUBSUBSECTION = 3
PARAGRAPH = 4
SUBPARAGRAPH = 5
def __lt__(self, other):
if not isinstance(other, SectionLevel):
return NotImplemented
return self.value < other.value
def __le__(self, other):
if not isinstance(other, SectionLevel):
return NotImplemented
return self.value <= other.value
def __gt__(self, other):
if not isinstance(other, SectionLevel):
return NotImplemented
return self.value > other.value
def __ge__(self, other):
if not isinstance(other, SectionLevel):
return NotImplemented
return self.value >= other.value
@dataclass
class Section:
level: SectionLevel
title: str
content: str = ''
bibliography: str = ''
subsections: List['Section'] = field(default_factory=list)
def merge(self, other: 'Section') -> 'Section':
"""Merge this section with another section."""
if self.title != other.title or self.level != other.level:
raise ValueError("Can only merge sections with same title and level")
merged = deepcopy(self)
merged.content = self._merge_content(self.content, other.content)
# Create subsections lookup for efficient merging
subsections_map = {s.title: s for s in merged.subsections}
for other_subsection in other.subsections:
if other_subsection.title in subsections_map:
# Merge existing subsection
idx = next(i for i, s in enumerate(merged.subsections)
if s.title == other_subsection.title)
merged.subsections[idx] = merged.subsections[idx].merge(other_subsection)
else:
# Add new subsection
merged.subsections.append(deepcopy(other_subsection))
return merged
@staticmethod
def _merge_content(content1: str, content2: str) -> str:
"""Merge content strings intelligently."""
if not content1:
return content2
if not content2:
return content1
# Combine non-empty contents with a separator
return f"{content1}\n\n{content2}"
@dataclass
class LatexEnvironment:
"""表示LaTeX环境的数据类"""
name: str
start: int
end: int
content: str
raw: str
class EnhancedSectionExtractor:
"""Enhanced section extractor with comprehensive content handling and hierarchy management."""
def __init__(self, preserve_environments: bool = True):
"""
初始化Section提取器
Args:
preserve_environments: 是否保留特定环境如equation, figure等的原始LaTeX代码
"""
self.preserve_environments = preserve_environments
# Section level definitions
self.section_levels = {
'chapter': SectionLevel.CHAPTER,
'section': SectionLevel.SECTION,
'subsection': SectionLevel.SUBSECTION,
'subsubsection': SectionLevel.SUBSUBSECTION,
'paragraph': SectionLevel.PARAGRAPH,
'subparagraph': SectionLevel.SUBPARAGRAPH
}
# Environment types that must be preserved
self.important_environments = {
'equation', 'equation*', 'align', 'align*',
'figure', 'table', 'algorithm', 'algorithmic',
'definition', 'theorem', 'lemma', 'proof',
'itemize', 'enumerate', 'description'
}
# Improved section pattern
self.section_pattern = (
r'\\(?P<type>chapter|section|subsection|subsubsection|paragraph|subparagraph)'
r'\*?' # Optional star
r'(?:\[(?P<short>.*?)\])?' # Optional short title
r'{(?P<title>(?:[^{}]|\{[^{}]*\})*?)}' # Main title with nested braces support
)
# Environment matching pattern
self.environment_pattern = (
r'\\begin{(?P<env_name>[^}]+)}'
r'(?P<env_content>.*?)'
r'\\end{(?P=env_name)}'
)
def _find_environments(self, content: str) -> List[LatexEnvironment]:
"""
查找文档中的所有LaTeX环境。
支持嵌套环境的处理。
"""
environments = []
stack = []
# Find all begin/end markers with a regex
begin_pattern = r'\\begin{([^}]+)}'
end_pattern = r'\\end{([^}]+)}'
# Combined pattern that matches both begin and end
tokens = []
for match in re.finditer(fr'({begin_pattern})|({end_pattern})', content):
if match.group(1):  # begin marker
tokens.append(('begin', match.group(1), match.start()))
else:  # end marker
tokens.append(('end', match.group(2), match.start()))
# Resolve environment nesting
for token_type, env_name, pos in tokens:
if token_type == 'begin':
stack.append((env_name, pos))
elif token_type == 'end' and stack:
if stack[-1][0] == env_name:
start_env_name, start_pos = stack.pop()
env_content = content[start_pos:pos]
raw_content = content[start_pos:pos + len('\\end{' + env_name + '}')]
if start_env_name in self.important_environments:
environments.append(LatexEnvironment(
name=start_env_name,
start=start_pos,
end=pos + len('\\end{' + env_name + '}'),
content=env_content,
raw=raw_content
))
return sorted(environments, key=lambda x: x.start)
def _protect_environments(self, content: str) -> Tuple[str, Dict[str, str]]:
"""
保护重要的LaTeX环境用占位符替换它们。
返回处理后的内容和恢复映射。
"""
environments = self._find_environments(content)
replacements = {}
# Replace from the end backwards so earlier offsets stay valid
for env in reversed(environments):
if env.name in self.important_environments:
placeholder = f'__ENV_{len(replacements)}__'
replacements[placeholder] = env.raw
content = content[:env.start] + placeholder + content[env.end:]
return content, replacements
def _restore_environments(self, content: str, replacements: Dict[str, str]) -> str:
"""
恢复之前保护的环境。
"""
for placeholder, original in replacements.items():
content = content.replace(placeholder, original)
return content
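To illustrate the protect/restore round trip, a minimal sketch; the input string is made up:

extractor = EnhancedSectionExtractor()
src = r"Intro text \begin{equation} E = mc^2 \end{equation} and more."
protected, mapping = extractor._protect_environments(src)
print(protected)        # Intro text __ENV_0__ and more.
restored = extractor._restore_environments(protected, mapping)
print(restored == src)  # True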
def extract(self, content: str) -> List[Section]:
"""
从LaTeX文档中提取sections及其内容。
Args:
content: LaTeX文档内容
Returns:
List[Section]: 提取的section列表包含层次结构
"""
try:
# Preprocessing: protect important environments
if self.preserve_environments:
content, env_replacements = self._protect_environments(content)
# Find all sections
sections = self._find_all_sections(content)
if not sections:
return []
# Process the sections
root_sections = self._process_sections(content, sections)
# Restore environments if required
if self.preserve_environments:
for section in self._traverse_sections(root_sections):
section.content = self._restore_environments(section.content, env_replacements)
return root_sections
except Exception as e:
logger.error(f"Error extracting sections: {str(e)}")
raise
def _find_all_sections(self, content: str) -> List[dict]:
"""查找所有section命令及其位置。"""
sections = []
for match in re.finditer(self.section_pattern, content, re.DOTALL | re.MULTILINE):
section_type = match.group('type').lower()
if section_type not in self.section_levels:
continue
section = {
'type': section_type,
'level': self.section_levels[section_type],
'title': self._clean_title(match.group('title')),
'start': match.start(),
'command_end': match.end(),
}
sections.append(section)
return sorted(sections, key=lambda x: x['start'])
def _process_sections(self, content: str, sections: List[dict]) -> List[Section]:
"""处理sections以构建层次结构和提取内容。"""
# 计算content范围
self._calculate_content_ranges(content, sections)
# Build the hierarchy
root_sections = []
section_stack = []
for section_info in sections:
new_section = Section(
level=section_info['level'],
title=section_info['title'],
content=self._extract_clean_content(content, section_info),
subsections=[]
)
# Pop the stack until the correct parent section is on top
while section_stack and section_stack[-1].level.value >= new_section.level.value:
section_stack.pop()
if section_stack:
section_stack[-1].subsections.append(new_section)
else:
root_sections.append(new_section)
section_stack.append(new_section)
return root_sections
def _calculate_content_ranges(self, content: str, sections: List[dict]):
for i, current in enumerate(sections):
content_start = current['command_end']
# Content ends at the next section, regardless of its level
content_end = len(content)
for next_section in sections[i + 1:]:
content_end = next_section['start']
break
current['content_range'] = (content_start, content_end)
def _calculate_content_ranges_with_subsection_content(self, content: str, sections: List[dict]):
"""为每个section计算内容范围。"""
for i, current in enumerate(sections):
content_start = current['command_end']
# Find the next section at the same or a higher level
content_end = len(content)
for next_section in sections[i + 1:]:
if next_section['level'] <= current['level']:
content_end = next_section['start']
break
current['content_range'] = (content_start, content_end)
def _extract_clean_content(self, content: str, section_info: dict) -> str:
"""提取并清理section内容。"""
start, end = section_info['content_range']
raw_content = content[start:end]
# Clean the content
clean_content = self._clean_content(raw_content)
return clean_content
def _clean_content(self, content: str) -> str:
"""清理LaTeX内容同时保留重要信息。"""
# 移除注释
content = re.sub(r'(?<!\\)%.*?\n', '\n', content)
# LaTeX command handling rules
replacements = [
# Keep citations
(r'\\cite(?:\[.*?\])?{(.*?)}', r'[cite:\1]'),
# Keep footnotes
(r'\\footnote{(.*?)}', r'[footnote:\1]'),
# Keep cross-references
(r'\\ref{(.*?)}', r'[ref:\1]'),
# Keep URLs
(r'\\url{(.*?)}', r'[url:\1]'),
# Keep hyperlinks
(r'\\href{(.*?)}{(.*?)}', r'[\2](\1)'),
# Unwrap text-formatting commands
(r'\\(?:textbf|textit|emph){(.*?)}', r'\1'),
# Unescape special characters
(r'\\([&%$#_{}])', r'\1'),
]
# Apply all replacement rules
for pattern, replacement in replacements:
content = re.sub(pattern, replacement, content, flags=re.DOTALL)
# Collapse redundant blank lines
content = re.sub(r'\n\s*\n', '\n\n', content)
return content.strip()
def _clean_title(self, title: str) -> str:
"""清理section标题。"""
# 处理嵌套的花括号
while '{' in title:
title = re.sub(r'{([^{}]*)}', r'\1', title)
# Strip LaTeX commands
title = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.*?)}', r'\1', title)
title = re.sub(r'\\([&%$#_{}])', r'\1', title)
return title.strip()
def _traverse_sections(self, sections: List[Section]) -> List[Section]:
"""遍历所有sections包括子sections"""
result = []
for section in sections:
result.append(section)
result.extend(self._traverse_sections(section.subsections))
return result
def test_enhanced_extractor():
"""使用复杂的测试用例测试提取器。"""
test_content = r"""
\section{Complex Examples}
Here's a complex section with various environments.
\begin{equation}
E = mc^2
\end{equation}
\subsection{Nested Environments}
This subsection has nested environments.
\begin{figure}
\begin{equation*}
f(x) = \int_0^x g(t) dt
\end{equation*}
\caption{A nested equation in a figure}
\end{figure}
"""
extractor = EnhancedSectionExtractor()
sections = extractor.extract(test_content)
def print_section(section, level=0):
print("\n" + " " * level + f"[{section.level.name}] {section.title}")
if section.content:
content_preview = section.content[:150] + "..." if len(section.content) > 150 else section.content
print(" " * (level + 1) + f"Content: {content_preview}")
for subsection in section.subsections:
print_section(subsection, level + 1)
print("\nExtracted Section Structure:")
for section in sections:
print_section(section)
if __name__ == "__main__":
test_enhanced_extractor()

View File

@@ -0,0 +1,17 @@
from dataclasses import dataclass
@dataclass
class SectionFragment:
"""Arxiv论文片段数据类"""
title: str # 文件路径
abstract: str # 论文摘要
section_tree: str # 文章各章节的目录结构
arxiv_id: str = "" # 添加 arxiv_id 属性
current_section: str = "Introduction" # 当前片段所属的section或者subsection或者孙subsubsection名字
content: str = '' #当前片段的内容
bibliography: str = '' #当前片段的参考文献
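For illustration, constructing a fragment looks like this (all field values below are placeholders):

fragment = SectionFragment(
    title="An Example Paper",
    abstract="One-sentence abstract.",
    section_tree="1 Introduction\n2 Method",
    arxiv_id="2411.03663",
    current_section="Introduction",
    content="Opening paragraph of the introduction.",
)
print(fragment.current_section, len(fragment.content))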

View File

@@ -0,0 +1,271 @@
import re
import os
import logging
from pathlib import Path
from typing import List, Tuple, Dict, Set, Optional, Callable
from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns
class TexUtils:
"""TeX文档处理器类"""
def __init__(self):
"""Initialize the TeX processor."""
self.logger = logging.getLogger(__name__)
# Initialize LaTeX environment and command patterns
self._init_patterns()
self.latex_only_patterns = LaTeXPatterns.latex_only_patterns
def _init_patterns(self):
"""初始化LaTeX模式匹配规则"""
# 特殊环境模式
self.special_envs = LaTeXPatterns.special_envs
# 章节模式
self.section_patterns = LaTeXPatterns.section_patterns
# 包含模式
self.include_patterns = LaTeXPatterns.include_patterns
# 元数据模式
self.metadata_patterns = LaTeXPatterns.metadata_patterns
def read_file(self, file_path: str) -> Optional[str]:
"""
读取TeX文件内容支持多种编码
Args:
file_path: 文件路径
Returns:
Optional[str]: 文件内容或None
"""
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
for encoding in encodings:
try:
with open(file_path, 'r', encoding=encoding) as f:
return f.read()
except UnicodeDecodeError:
continue
self.logger.warning(f"Failed to read {file_path} with all encodings")
return None
def find_main_tex_file(self, directory: str) -> Optional[str]:
"""
查找主TeX文件
Args:
directory: 目录路径
Returns:
Optional[str]: 主文件路径或None
"""
tex_files = list(Path(directory).rglob("*.tex"))
if not tex_files:
return None
# Search in priority order
for tex_file in tex_files:
content = self.read_file(str(tex_file))
if content:
if r'\documentclass' in content:
return str(tex_file)
if tex_file.name.lower() == 'main.tex':
return str(tex_file)
# Fall back to the largest .tex file
return str(max(tex_files, key=lambda x: x.stat().st_size))
def resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]:
"""
解析TeX文件中的include引用
Args:
tex_file: TeX文件路径
processed: 已处理的文件集合
Returns:
List[str]: 相关文件路径列表
"""
if processed is None:
processed = set()
if tex_file in processed:
return []
processed.add(tex_file)
result = [tex_file]
content = self.read_file(tex_file)
if not content:
return result
base_dir = Path(tex_file).parent
for pattern in self.include_patterns:
for match in re.finditer(pattern, content):
included_file = match.group(2)
if not included_file.endswith('.tex'):
included_file += '.tex'
full_path = str(base_dir / included_file)
if os.path.exists(full_path) and full_path not in processed:
result.extend(self.resolve_includes(full_path, processed))
return result
def resolve_references(self, tex_file: str, path_dir: str = None) -> str:
"""
解析TeX文件中的参考文献引用返回所有引用文献的内容只保留title、author和journal字段。
如果在tex_file目录下没找到bib文件会在path_dir中查找。
Args:
tex_file: TeX文件路径
path_dir: 额外的参考文献搜索路径
Returns:
str: 所有参考文献内容的字符串,只包含特定字段,不同参考文献之间用空行分隔
"""
all_references = []  # all collected reference entries
content = self.read_file(tex_file)
if not content:
return ""
# Patterns that reference bibliography files
bib_patterns = [
r'\\bibliography\{([^}]+)\}',
r'\\addbibresource\{([^}]+)\}',
r'\\bibliographyfile\{([^}]+)\}',
r'\\begin\{thebibliography\}',
r'\\bibinput\{([^}]+)\}',
r'\\newrefsection\{([^}]+)\}'
]
base_dir = Path(tex_file).parent
found_in_tex_dir = False
# First look for explicitly referenced .bib files next to the TeX file
for pattern in bib_patterns:
for match in re.finditer(pattern, content):
if not match.groups():
continue
bib_files = match.group(1).split(',')
for bib_file in bib_files:
bib_file = bib_file.strip()
if not bib_file.endswith('.bib'):
bib_file += '.bib'
full_path = str(base_dir / bib_file)
if os.path.exists(full_path):
found_in_tex_dir = True
bib_content = self.read_file(full_path)
if bib_content:
processed_refs = self._process_bib_content(bib_content)
all_references.extend(processed_refs)
# If no .bib file was found next to the TeX file and an extra search path was given
if not found_in_tex_dir and path_dir:
search_dir = Path(path_dir)
try:
for bib_path in search_dir.glob('**/*.bib'):
bib_content = self.read_file(str(bib_path))
if bib_content:
processed_refs = self._process_bib_content(bib_content)
all_references.extend(processed_refs)
except Exception as e:
print(f"Error searching in path_dir: {e}")
# Join all reference entries, separated by blank lines
return "\n\n".join(all_references)
def _process_bib_content(self, content: str) -> List[str]:
"""
处理bib文件内容提取每个参考文献的特定字段
Args:
content: bib文件内容
Returns:
List[str]: 处理后的参考文献列表
"""
processed_refs = []
# Match a complete reference entry
ref_pattern = r'@\w+\{[^@]*\}'
# Match the entry type and citation key
entry_start_pattern = r'@(\w+)\{([^,]*?),'
# Match individual fields
field_pattern = r'(\w+)\s*=\s*\{([^}]*)\}'
# Iterate over all reference entries
for ref_match in re.finditer(ref_pattern, content, re.DOTALL):
ref_content = ref_match.group(0)
# Extract the entry type and citation key
entry_match = re.match(entry_start_pattern, ref_content)
if not entry_match:
continue
entry_type, cite_key = entry_match.groups()
# Extract the fields we need
needed_fields = {'title': None, 'author': None, 'journal': None}
for field_match in re.finditer(field_pattern, ref_content):
field_name, field_value = field_match.groups()
field_name = field_name.lower()
if field_name in needed_fields:
needed_fields[field_name] = field_value.strip()
# Build the reduced reference entry
if any(needed_fields.values()):  # at least one of the needed fields is present
ref_lines = [f"@{entry_type}{{{cite_key},"]
for field_name, field_value in needed_fields.items():
if field_value:
ref_lines.append(f" {field_name}={{{field_value}}},")
ref_lines[-1] = ref_lines[-1][:-1]  # drop the trailing comma
ref_lines.append("}")
processed_refs.append("\n".join(ref_lines))
return processed_refs
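A rough illustration of the reduction this performs, using a made-up entry and a TexUtils instance:

utils = TexUtils()
bib = """@article{smith2020,
  title={An Example},
  author={Smith, J.},
  journal={Journal of Examples},
  year={2020},
  pages={1--10}
}"""
print(utils._process_bib_content(bib)[0])
# @article{smith2020,
#   title={An Example},
#   author={Smith, J.},
#   journal={Journal of Examples}
# }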
def _extract_inline_references(self, content: str) -> str:
"""
从tex文件内容中提取直接写在文件中的参考文献
Args:
content: tex文件内容
Returns:
str: 提取的参考文献内容,如果没有找到则返回空字符串
"""
# Find the thebibliography environment
bib_start = r'\\begin\{thebibliography\}'
bib_end = r'\\end\{thebibliography\}'
start_match = re.search(bib_start, content)
end_match = re.search(bib_end, content)
if start_match and end_match:
return content[start_match.start():end_match.end()]
return ""
def _preprocess_content(self, content: str) -> str:
"""预处理TeX内容"""
# 移除注释
content = re.sub(r'(?m)%.*$', '', content)
# 规范化空白字符
# content = re.sub(r'\s+', ' ', content)
content = re.sub(r'\n\s*\n', '\n\n', content)
return content.strip()
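Taken together, a typical use of TexUtils looks roughly like this; the directory path is a placeholder:

utils = TexUtils()
main_tex = utils.find_main_tex_file("test_cache/2411.03663")
if main_tex:
    tex_files = utils.resolve_includes(main_tex)
    references = utils.resolve_references(main_tex, "test_cache/2411.03663")
    print(f"{len(tex_files)} TeX files resolved, {len(references)} characters of references")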