up
This commit is contained in:
@@ -11,7 +11,7 @@ import aiohttp
|
||||
from shared_utils.fastapi_server import validate_path_safety
|
||||
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
|
||||
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
|
||||
from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment as Fragment
|
||||
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
|
||||
|
||||
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
|
||||
from crazy_functions.crazy_utils import input_clipping
|
||||
@@ -124,12 +124,14 @@ class ArxivRagWorker:
|
||||
"title": fragments[0].title,
|
||||
"abstract": fragments[0].abstract,
|
||||
"arxiv_id": fragments[0].arxiv_id,
|
||||
"section_tree": fragments[0].section_tree,
|
||||
}
|
||||
|
||||
overview_text = (
|
||||
f"Paper Title: {overview['title']}\n"
|
||||
f"ArXiv ID: {overview['arxiv_id']}\n"
|
||||
f"Abstract: {overview['abstract']}\n"
|
||||
f"Section Tree:{overview['section_tree']}\n"
|
||||
f"Type: OVERVIEW"
|
||||
)
|
||||
|
||||
@@ -161,10 +163,12 @@ class ArxivRagWorker:
|
||||
try:
|
||||
text = (
|
||||
f"Paper Title: {fragment.title}\n"
|
||||
f"Abstract: {fragment.abstract}\n"
|
||||
f"ArXiv ID: {fragment.arxiv_id}\n"
|
||||
f"Section: {fragment.section}\n"
|
||||
f"Fragment Index: {index}\n"
|
||||
f"Section: {fragment.current_section}\n"
|
||||
f"Section Tree: {fragment.section_tree}\n"
|
||||
f"Content: {fragment.content}\n"
|
||||
f"Bibliography: {fragment.bibliography}\n"
|
||||
f"Type: FRAGMENT"
|
||||
)
|
||||
|
||||
@@ -179,13 +183,14 @@ class ArxivRagWorker:
|
||||
try:
|
||||
text = (
|
||||
f"Paper Title: {fragment.title}\n"
|
||||
f"Abstract: {fragment.abstract}\n"
|
||||
f"ArXiv ID: {fragment.arxiv_id}\n"
|
||||
f"Section: {fragment.section}\n"
|
||||
f"Fragment Index: {index}\n"
|
||||
f"Section: {fragment.current_section}\n"
|
||||
f"Section Tree: {fragment.section_tree}\n"
|
||||
f"Content: {fragment.content}\n"
|
||||
f"Bibliography: {fragment.bibliography}\n"
|
||||
f"Type: FRAGMENT"
|
||||
)
|
||||
|
||||
logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}")
|
||||
self.rag_worker.add_text_to_vector_store(text)
|
||||
logger.info(f"Successfully added fragment {index} to vector store")
|
||||
|
||||
@@ -1,449 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import requests
|
||||
import tarfile
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Tuple, Optional, Dict, Set
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from crazy_functions.rag_fns.arxiv_fns.tex_processor import TexProcessor
|
||||
from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment
|
||||
|
||||
|
||||
|
||||
def save_fragments_to_file(fragments, output_dir: str = "fragment_outputs"):
|
||||
"""
|
||||
将所有fragments保存为单个结构化markdown文件
|
||||
|
||||
Args:
|
||||
fragments: fragment列表
|
||||
output_dir: 输出目录
|
||||
"""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
# 创建输出目录
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成文件名
|
||||
filename = f"fragments_{timestamp}.md"
|
||||
file_path = output_path / filename
|
||||
|
||||
current_section = ""
|
||||
section_count = {} # 用于跟踪每个章节的片段数量
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
# 写入文档头部
|
||||
f.write("# Document Fragments Analysis\n\n")
|
||||
f.write("## Overview\n")
|
||||
f.write(f"- Total Fragments: {len(fragments)}\n")
|
||||
f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
|
||||
# 如果有标题和摘要,添加到开头
|
||||
if fragments and (fragments[0].title or fragments[0].abstract):
|
||||
f.write("\n## Paper Information\n")
|
||||
if fragments[0].title:
|
||||
f.write(f"### Title\n{fragments[0].title}\n")
|
||||
if fragments[0].abstract:
|
||||
f.write(f"\n### Abstract\n{fragments[0].abstract}\n")
|
||||
|
||||
# 生成目录
|
||||
f.write("\n## Table of Contents\n")
|
||||
|
||||
# 首先收集所有章节信息
|
||||
sections = {}
|
||||
for fragment in fragments:
|
||||
section = fragment.section or "Uncategorized"
|
||||
if section not in sections:
|
||||
sections[section] = []
|
||||
sections[section].append(fragment)
|
||||
|
||||
# 写入目录
|
||||
for section, section_fragments in sections.items():
|
||||
clean_section = section.strip()
|
||||
if not clean_section:
|
||||
clean_section = "Uncategorized"
|
||||
f.write(
|
||||
f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) ({len(section_fragments)} fragments)\n")
|
||||
|
||||
# 写入正文内容
|
||||
f.write("\n## Content\n")
|
||||
|
||||
# 按章节组织内容
|
||||
for section, section_fragments in sections.items():
|
||||
clean_section = section.strip() or "Uncategorized"
|
||||
f.write(f"\n### {clean_section}\n")
|
||||
|
||||
# 写入每个fragment
|
||||
for i, fragment in enumerate(section_fragments, 1):
|
||||
f.write(f"\n#### Fragment {i} ({fragment.segment_type})\n")
|
||||
|
||||
# 元数据
|
||||
f.write("**Metadata:**\n")
|
||||
f.write(f"- Type: {fragment.segment_type}\n")
|
||||
f.write(f"- Length: {len(fragment.content)} chars\n")
|
||||
f.write(f"- Importance: {fragment.importance:.2f}\n")
|
||||
f.write(f"- Is Appendix: {fragment.is_appendix}\n")
|
||||
f.write(f"- File: {fragment.rel_path}\n")
|
||||
|
||||
# 内容
|
||||
f.write("\n**Content:**\n")
|
||||
f.write("```tex\n")
|
||||
f.write(fragment.content)
|
||||
f.write("\n```\n")
|
||||
|
||||
# 添加分隔线
|
||||
if i < len(section_fragments):
|
||||
f.write("\n---\n")
|
||||
|
||||
# 添加统计信息
|
||||
f.write("\n## Statistics\n")
|
||||
f.write("\n### Fragment Type Distribution\n")
|
||||
type_stats = {}
|
||||
for fragment in fragments:
|
||||
type_stats[fragment.segment_type] = type_stats.get(fragment.segment_type, 0) + 1
|
||||
|
||||
for ftype, count in type_stats.items():
|
||||
percentage = (count / len(fragments)) * 100
|
||||
f.write(f"- {ftype}: {count} ({percentage:.1f}%)\n")
|
||||
|
||||
# 长度分布
|
||||
f.write("\n### Length Distribution\n")
|
||||
lengths = [len(f.content) for f in fragments]
|
||||
f.write(f"- Minimum: {min(lengths)} chars\n")
|
||||
f.write(f"- Maximum: {max(lengths)} chars\n")
|
||||
f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
|
||||
|
||||
print(f"Fragments saved to: {file_path}")
|
||||
return file_path
|
||||
|
||||
|
||||
|
||||
class ArxivSplitter:
|
||||
"""Arxiv论文智能分割器"""
|
||||
|
||||
def __init__(self,
|
||||
char_range: Tuple[int, int],
|
||||
root_dir: str = "gpt_log/arxiv_cache",
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
cache_ttl: int = 7 * 24 * 60 * 60):
|
||||
"""
|
||||
初始化分割器
|
||||
|
||||
Args:
|
||||
char_range: 字符数范围(最小值, 最大值)
|
||||
root_dir: 缓存根目录
|
||||
proxies: 代理设置
|
||||
cache_ttl: 缓存过期时间(秒)
|
||||
"""
|
||||
self.min_chars, self.max_chars = char_range
|
||||
self.root_dir = Path(root_dir)
|
||||
self.root_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.proxies = proxies or {}
|
||||
self.cache_ttl = cache_ttl
|
||||
|
||||
# 动态计算最优线程数
|
||||
import multiprocessing
|
||||
cpu_count = multiprocessing.cpu_count()
|
||||
# 根据CPU核心数动态设置,但设置上限防止过度并发
|
||||
self.max_workers = min(32, cpu_count * 2)
|
||||
|
||||
# 初始化TeX处理器
|
||||
self.tex_processor = TexProcessor(char_range)
|
||||
|
||||
# 配置日志
|
||||
self._setup_logging()
|
||||
|
||||
|
||||
|
||||
def _setup_logging(self):
|
||||
"""配置日志"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def _normalize_arxiv_id(self, input_str: str) -> str:
|
||||
"""规范化ArXiv ID"""
|
||||
if 'arxiv.org/' in input_str.lower():
|
||||
# 处理URL格式
|
||||
if '/pdf/' in input_str:
|
||||
arxiv_id = input_str.split('/pdf/')[-1]
|
||||
else:
|
||||
arxiv_id = input_str.split('/abs/')[-1]
|
||||
# 移除版本号和其他后缀
|
||||
return arxiv_id.split('v')[0].strip()
|
||||
return input_str.split('v')[0].strip()
|
||||
|
||||
|
||||
def _check_cache(self, paper_dir: Path) -> bool:
|
||||
"""
|
||||
检查缓存是否有效,包括文件完整性检查
|
||||
|
||||
Args:
|
||||
paper_dir: 论文目录路径
|
||||
|
||||
Returns:
|
||||
bool: 如果缓存有效返回True,否则返回False
|
||||
"""
|
||||
if not paper_dir.exists():
|
||||
return False
|
||||
|
||||
# 检查目录中是否存在必要文件
|
||||
has_tex_files = False
|
||||
has_main_tex = False
|
||||
|
||||
for file_path in paper_dir.rglob("*"):
|
||||
if file_path.suffix == '.tex':
|
||||
has_tex_files = True
|
||||
content = self.tex_processor.read_file(str(file_path))
|
||||
if content and r'\documentclass' in content:
|
||||
has_main_tex = True
|
||||
break
|
||||
|
||||
if not (has_tex_files and has_main_tex):
|
||||
return False
|
||||
|
||||
# 检查缓存时间
|
||||
cache_time = paper_dir.stat().st_mtime
|
||||
if (time.time() - cache_time) < self.cache_ttl:
|
||||
self.logger.info(f"Using valid cache for {paper_dir.name}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
|
||||
"""
|
||||
异步下载论文,包含重试机制和临时文件处理
|
||||
|
||||
Args:
|
||||
arxiv_id: ArXiv论文ID
|
||||
paper_dir: 目标目录路径
|
||||
|
||||
Returns:
|
||||
bool: 下载成功返回True,否则返回False
|
||||
"""
|
||||
from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader
|
||||
temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz"
|
||||
final_tar_path = paper_dir / f"{arxiv_id}.tar.gz"
|
||||
|
||||
# 确保目录存在
|
||||
paper_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 尝试使用 ArxivDownloader 下载
|
||||
try:
|
||||
downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
|
||||
downloaded_dir = downloader.download_paper(arxiv_id)
|
||||
if downloaded_dir:
|
||||
self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")
|
||||
|
||||
# 如果 ArxivDownloader 失败,使用原有的下载方式作为备选
|
||||
urls = [
|
||||
f"https://arxiv.org/src/{arxiv_id}",
|
||||
f"https://arxiv.org/e-print/{arxiv_id}"
|
||||
]
|
||||
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 初始重试延迟(秒)
|
||||
|
||||
for url in urls:
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, proxy=self.proxies.get('http')) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
|
||||
# 写入临时文件
|
||||
temp_tar_path.write_bytes(content)
|
||||
|
||||
try:
|
||||
# 验证tar文件完整性并解压
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
|
||||
|
||||
# 下载成功后移动临时文件到最终位置
|
||||
temp_tar_path.rename(final_tar_path)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Invalid tar file: {str(e)}")
|
||||
if temp_tar_path.exists():
|
||||
temp_tar_path.unlink()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def _process_tar_file(self, tar_path: Path, extract_path: Path):
|
||||
"""处理tar文件的同步操作"""
|
||||
with tarfile.open(tar_path, 'r:gz') as tar:
|
||||
tar.testall() # 验证文件完整性
|
||||
tar.extractall(path=extract_path) # 解压文件
|
||||
|
||||
def _process_single_tex(self, file_path: str) -> List[ArxivFragment]:
|
||||
"""处理单个TeX文件"""
|
||||
try:
|
||||
content = self.tex_processor.read_file(file_path)
|
||||
if not content:
|
||||
return []
|
||||
|
||||
# 提取元数据
|
||||
is_main = r'\documentclass' in content
|
||||
title, abstract = "", ""
|
||||
if is_main:
|
||||
title, abstract = self.tex_processor.extract_metadata(content)
|
||||
|
||||
# 分割内容
|
||||
segments = self.tex_processor.split_content(content)
|
||||
fragments = []
|
||||
|
||||
for i, (segment_content, section, is_appendix) in enumerate(segments):
|
||||
if not segment_content.strip():
|
||||
continue
|
||||
|
||||
segment_type = self.tex_processor.detect_segment_type(segment_content)
|
||||
importance = self.tex_processor.calculate_importance(
|
||||
segment_content, segment_type, is_main
|
||||
)
|
||||
fragments.append(ArxivFragment(
|
||||
file_path=file_path,
|
||||
content=segment_content,
|
||||
segment_index=i,
|
||||
total_segments=len(segments),
|
||||
rel_path=str(Path(file_path).relative_to(self.root_dir)),
|
||||
segment_type=segment_type,
|
||||
title=title,
|
||||
abstract=abstract,
|
||||
section=section,
|
||||
is_appendix=is_appendix,
|
||||
importance=importance
|
||||
))
|
||||
|
||||
return fragments
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing {file_path}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def process(self, arxiv_id_or_url: str) -> List[ArxivFragment]:
|
||||
"""处理ArXiv论文"""
|
||||
try:
|
||||
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
|
||||
paper_dir = self.root_dir / arxiv_id
|
||||
|
||||
# 检查缓存
|
||||
if not self._check_cache(paper_dir):
|
||||
paper_dir.mkdir(exist_ok=True)
|
||||
if not await self.download_paper(arxiv_id, paper_dir):
|
||||
raise RuntimeError(f"Failed to download paper {arxiv_id}")
|
||||
|
||||
# 查找主TeX文件
|
||||
main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
|
||||
if not main_tex:
|
||||
raise RuntimeError(f"No main TeX file found in {paper_dir}")
|
||||
|
||||
# 获取所有相关TeX文件
|
||||
tex_files = self.tex_processor.resolve_includes(main_tex)
|
||||
if not tex_files:
|
||||
raise RuntimeError(f"No valid TeX files found for {arxiv_id}")
|
||||
|
||||
# 并行处理所有TeX文件
|
||||
fragments = []
|
||||
chunk_size = max(1, len(tex_files) // self.max_workers) # 计算每个线程处理的文件数
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async def process_chunk(chunk_files):
|
||||
chunk_fragments = []
|
||||
for file_path in chunk_files:
|
||||
try:
|
||||
result = await loop.run_in_executor(None, self._process_single_tex, file_path)
|
||||
chunk_fragments.extend(result)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing {file_path}: {str(e)}")
|
||||
return chunk_fragments
|
||||
|
||||
# 将文件分成多个块
|
||||
file_chunks = [tex_files[i:i + chunk_size] for i in range(0, len(tex_files), chunk_size)]
|
||||
# 异步处理每个块
|
||||
chunk_results = await asyncio.gather(*[process_chunk(chunk) for chunk in file_chunks])
|
||||
for result in chunk_results:
|
||||
fragments.extend(result)
|
||||
# 重新计算片段索引并排序
|
||||
fragments.sort(key=lambda x: (x.rel_path, x.segment_index))
|
||||
total_fragments = len(fragments)
|
||||
|
||||
for i, fragment in enumerate(fragments):
|
||||
fragment.segment_index = i
|
||||
fragment.total_segments = total_fragments
|
||||
# 在返回之前添加过滤
|
||||
fragments = self.tex_processor.filter_fragments(fragments)
|
||||
return fragments
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def test_arxiv_splitter():
|
||||
"""测试ArXiv分割器的功能"""
|
||||
|
||||
# 测试配置
|
||||
test_cases = [
|
||||
{
|
||||
"arxiv_id": "2411.03663",
|
||||
"expected_title": "Large Language Models and Simple Scripts",
|
||||
"min_fragments": 10,
|
||||
},
|
||||
# {
|
||||
# "arxiv_id": "1805.10988",
|
||||
# "expected_title": "RAG vs Fine-tuning",
|
||||
# "min_fragments": 15,
|
||||
# }
|
||||
]
|
||||
|
||||
# 创建分割器实例
|
||||
splitter = ArxivSplitter(
|
||||
char_range=(800, 1800),
|
||||
root_dir="test_cache"
|
||||
)
|
||||
|
||||
|
||||
for case in test_cases:
|
||||
print(f"\nTesting paper: {case['arxiv_id']}")
|
||||
try:
|
||||
fragments = await splitter.process(case['arxiv_id'])
|
||||
|
||||
# 保存fragments
|
||||
output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
|
||||
print(f"Output saved to: {output_dir}")
|
||||
# 内容检查
|
||||
for fragment in fragments:
|
||||
# 长度检查
|
||||
|
||||
print((fragment.content))
|
||||
print(len(fragment.content))
|
||||
# 类型检查
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_arxiv_splitter())
|
||||
@@ -5,75 +5,28 @@ This module provides functionality for parsing and extracting structured informa
|
||||
including metadata, document structure, and content. It uses modular design and clean architecture principles.
|
||||
"""
|
||||
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict
|
||||
from copy import deepcopy
|
||||
from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands
|
||||
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, SectionLevel, EnhancedSectionExtractor
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict
|
||||
from enum import Enum
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SectionLevel(Enum):
|
||||
CHAPTER = 0
|
||||
SECTION = 1
|
||||
SUBSECTION = 2
|
||||
SUBSUBSECTION = 3
|
||||
PARAGRAPH = 4
|
||||
SUBPARAGRAPH = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class Section:
|
||||
level: SectionLevel
|
||||
title: str
|
||||
content: str = ''
|
||||
subsections: List['Section'] = field(default_factory=list)
|
||||
|
||||
def merge(self, other: 'Section') -> 'Section':
|
||||
"""Merge this section with another section."""
|
||||
if self.title != other.title or self.level != other.level:
|
||||
raise ValueError("Can only merge sections with same title and level")
|
||||
|
||||
merged = deepcopy(self)
|
||||
merged.content = self._merge_content(self.content, other.content)
|
||||
|
||||
# Create subsections lookup for efficient merging
|
||||
subsections_map = {s.title: s for s in merged.subsections}
|
||||
|
||||
for other_subsection in other.subsections:
|
||||
if other_subsection.title in subsections_map:
|
||||
# Merge existing subsection
|
||||
idx = next(i for i, s in enumerate(merged.subsections)
|
||||
if s.title == other_subsection.title)
|
||||
merged.subsections[idx] = merged.subsections[idx].merge(other_subsection)
|
||||
else:
|
||||
# Add new subsection
|
||||
merged.subsections.append(deepcopy(other_subsection))
|
||||
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _merge_content(content1: str, content2: str) -> str:
|
||||
"""Merge content strings intelligently."""
|
||||
if not content1:
|
||||
return content2
|
||||
if not content2:
|
||||
return content1
|
||||
# Combine non-empty contents with a separator
|
||||
return f"{content1}\n\n{content2}"
|
||||
|
||||
def read_tex_file(file_path):
|
||||
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
|
||||
for encoding in encodings:
|
||||
try:
|
||||
with open(file_path, 'r', encoding=encoding) as f:
|
||||
return f.read()
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
@dataclass
|
||||
class DocumentStructure:
|
||||
@@ -114,7 +67,7 @@ class DocumentStructure:
|
||||
if other_section.title in sections_map:
|
||||
# Merge existing section
|
||||
idx = next(i for i, s in enumerate(merged.toc)
|
||||
if s.title == other_section.title)
|
||||
if s.title == other_section.title)
|
||||
merged.toc[idx] = merged.toc[idx].merge(other_section)
|
||||
else:
|
||||
# Add new section
|
||||
@@ -132,11 +85,69 @@ class DocumentStructure:
|
||||
# Combine non-empty abstracts with a separator
|
||||
return f"{abstract1}\n\n{abstract2}"
|
||||
|
||||
def generate_toc_tree(self, indent_char: str = " ", abstract_preview_length: int = 0) -> str:
|
||||
"""
|
||||
Generate a tree-like string representation of the table of contents including abstract.
|
||||
|
||||
Args:
|
||||
indent_char: Character(s) used for indentation. Default is two spaces.
|
||||
abstract_preview_length: Maximum length of abstract preview. Default is 200 characters.
|
||||
|
||||
Returns:
|
||||
str: A formatted string showing the hierarchical document structure with abstract
|
||||
"""
|
||||
|
||||
def _format_section(section: Section, level: int = 0) -> str:
|
||||
# Create the current section line with proper indentation
|
||||
current_line = f"{indent_char * level}{'•' if level > 0 else '○'} {section.title}\n"
|
||||
|
||||
# Recursively process subsections
|
||||
subsections = ""
|
||||
if section.subsections:
|
||||
subsections = "".join(_format_section(subsec, level + 1)
|
||||
for subsec in section.subsections)
|
||||
|
||||
return current_line + subsections
|
||||
|
||||
result = []
|
||||
|
||||
# Add document title if it exists
|
||||
if self.title:
|
||||
result.append(f"《{self.title}》\n")
|
||||
|
||||
# Add abstract if it exists
|
||||
if self.abstract:
|
||||
result.append("\n□ Abstract:")
|
||||
# Format abstract content with word wrap
|
||||
abstract_preview = self.abstract[:abstract_preview_length]
|
||||
if len(self.abstract) > abstract_preview_length:
|
||||
abstract_preview += "..."
|
||||
|
||||
# Split abstract into lines and indent them
|
||||
wrapped_lines = []
|
||||
current_line = ""
|
||||
for word in abstract_preview.split():
|
||||
if len(current_line) + len(word) + 1 <= 80: # 80 characters per line
|
||||
current_line = (current_line + " " + word).strip()
|
||||
else:
|
||||
wrapped_lines.append(current_line)
|
||||
current_line = word
|
||||
if current_line:
|
||||
wrapped_lines.append(current_line)
|
||||
|
||||
# Add formatted abstract lines
|
||||
for line in wrapped_lines:
|
||||
result.append(f"\n{indent_char}{line}")
|
||||
result.append("\n") # Add extra newline after abstract
|
||||
|
||||
# Add table of contents header if there are sections
|
||||
if self.toc:
|
||||
result.append("\n◈ Table of Contents:\n")
|
||||
|
||||
# Add all top-level sections and their subsections
|
||||
result.extend(_format_section(section, 0) for section in self.toc)
|
||||
|
||||
return "".join(result)
|
||||
class BaseExtractor(ABC):
|
||||
"""Base class for LaTeX content extractors."""
|
||||
|
||||
@@ -145,7 +156,6 @@ class BaseExtractor(ABC):
|
||||
"""Extract specific content from LaTeX document."""
|
||||
pass
|
||||
|
||||
|
||||
class TitleExtractor(BaseExtractor):
|
||||
"""Extracts title from LaTeX document."""
|
||||
|
||||
@@ -169,7 +179,6 @@ class TitleExtractor(BaseExtractor):
|
||||
return clean_latex_commands(title)
|
||||
return ''
|
||||
|
||||
|
||||
class AbstractExtractor(BaseExtractor):
|
||||
"""Extracts abstract from LaTeX document."""
|
||||
|
||||
@@ -193,70 +202,13 @@ class AbstractExtractor(BaseExtractor):
|
||||
return clean_latex_commands(abstract)
|
||||
return ''
|
||||
|
||||
|
||||
class SectionExtractor:
|
||||
"""Extracts document structure including sections and their content."""
|
||||
|
||||
def __init__(self):
|
||||
self.section_pattern = self._compile_section_pattern()
|
||||
|
||||
def _compile_section_pattern(self) -> str:
|
||||
"""Create pattern for matching section commands."""
|
||||
section_types = '|'.join(level.name.lower() for level in SectionLevel)
|
||||
return fr'\\({section_types})\*?(?:\[.*?\])?\{{(.*?)\}}'
|
||||
|
||||
def extract(self, content: str) -> List[Section]:
|
||||
"""Extract sections and build document hierarchy."""
|
||||
sections = []
|
||||
section_stack = []
|
||||
matches = list(re.finditer(self.section_pattern, content, re.IGNORECASE))
|
||||
|
||||
for i, match in enumerate(matches):
|
||||
cmd_type = match.group(1).lower()
|
||||
section_title = match.group(2)
|
||||
level = SectionLevel[cmd_type.upper()]
|
||||
|
||||
content = self._extract_section_content(content, match,
|
||||
matches[i + 1] if i < len(matches) - 1 else None)
|
||||
|
||||
new_section = Section(
|
||||
level=level,
|
||||
title=clean_latex_commands(section_title),
|
||||
content=clean_latex_commands(content)
|
||||
)
|
||||
|
||||
self._update_section_hierarchy(sections, section_stack, new_section)
|
||||
|
||||
return sections
|
||||
|
||||
def _extract_section_content(self, content: str, current_match: re.Match,
|
||||
next_match: Optional[re.Match]) -> str:
|
||||
"""Extract content between current section and next section."""
|
||||
start_pos = current_match.end()
|
||||
end_pos = next_match.start() if next_match else len(content)
|
||||
return content[start_pos:end_pos].strip()
|
||||
|
||||
def _update_section_hierarchy(self, sections: List[Section],
|
||||
stack: List[Section], new_section: Section):
|
||||
"""Update section hierarchy based on section levels."""
|
||||
while stack and stack[-1].level.value >= new_section.level.value:
|
||||
stack.pop()
|
||||
|
||||
if stack:
|
||||
stack[-1].subsections.append(new_section)
|
||||
else:
|
||||
sections.append(new_section)
|
||||
|
||||
stack.append(new_section)
|
||||
|
||||
|
||||
class EssayStructureParser:
|
||||
"""Main class for parsing LaTeX documents."""
|
||||
|
||||
def __init__(self):
|
||||
self.title_extractor = TitleExtractor()
|
||||
self.abstract_extractor = AbstractExtractor()
|
||||
self.section_extractor = SectionExtractor()
|
||||
self.section_extractor = EnhancedSectionExtractor() # Using the enhanced extractor
|
||||
|
||||
def parse(self, content: str) -> DocumentStructure:
|
||||
"""Parse LaTeX document and extract structured information."""
|
||||
@@ -276,17 +228,8 @@ class EssayStructureParser:
|
||||
"""Preprocess LaTeX content for parsing."""
|
||||
# Remove comments
|
||||
content = re.sub(r'(?<!\\)%.*$', '', content, flags=re.MULTILINE)
|
||||
|
||||
# # Handle input/include commands
|
||||
# content = re.sub(r'\\(?:input|include){.*?}', '', content)
|
||||
#
|
||||
# # Normalize newlines and whitespace
|
||||
# content = re.sub(r'\r\n?', '\n', content)
|
||||
# content = re.sub(r'\n\s*\n', '\n', content)
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100):
|
||||
"""Print document structure in a readable format."""
|
||||
print(f"Title: {doc.title}\n")
|
||||
@@ -306,51 +249,32 @@ def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100
|
||||
for section in doc.toc:
|
||||
print_section(section)
|
||||
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
# Sample main.tex
|
||||
main_tex = r"""
|
||||
\documentclass{article}
|
||||
\title{Research Paper}
|
||||
\begin{document}
|
||||
\begin{abstract}
|
||||
Main abstract introducing the research.
|
||||
\end{abstract}
|
||||
\section{Introduction}
|
||||
Overview of the topic...
|
||||
\section{Background}
|
||||
Part 1 of background...
|
||||
\end{document}
|
||||
"""
|
||||
|
||||
# Sample background.tex
|
||||
background_tex = r"""
|
||||
\section{Background}
|
||||
Part 2 of background...
|
||||
\subsection{Related Work}
|
||||
Discussion of related work...
|
||||
\section{Methodology}
|
||||
Research methods...
|
||||
"""
|
||||
|
||||
# Parse both files
|
||||
parser = EssayStructureParser() # Assuming LaTeXParser class from previous code
|
||||
# Test with a file
|
||||
file_path = 'test_cache/2411.03663/neurips_2024.tex'
|
||||
main_tex = read_tex_file(file_path)
|
||||
|
||||
# Parse main file
|
||||
parser = EssayStructureParser()
|
||||
main_doc = parser.parse(main_tex)
|
||||
background_doc = parser.parse(background_tex)
|
||||
|
||||
# Merge documents using smart strategy
|
||||
merged_doc = main_doc.merge(background_doc)
|
||||
# Merge other documents
|
||||
file_path_list = [
|
||||
"test_cache/2411.03663/1_intro.tex",
|
||||
"test_cache/2411.03663/0_abstract.tex",
|
||||
"test_cache/2411.03663/2_pre.tex",
|
||||
"test_cache/2411.03663/3_method.tex",
|
||||
"test_cache/2411.03663/4_experiment.tex",
|
||||
"test_cache/2411.03663/5_related_work.tex",
|
||||
"test_cache/2411.03663/6_conclu.tex"
|
||||
]
|
||||
for file_path in file_path_list:
|
||||
tex_content = read_tex_file(file_path)
|
||||
additional_doc = parser.parse(tex_content)
|
||||
main_doc = main_doc.merge(additional_doc)
|
||||
|
||||
# Example of how sections are merged:
|
||||
print("Original Background section content:",
|
||||
[s for s in main_doc.toc if s.title == "Background"][0].content)
|
||||
print("\nMerged Background section content:",
|
||||
[s for s in merged_doc.toc if s.title == "Background"][0].content)
|
||||
print("\nMerged structure:")
|
||||
pretty_print_structure(merged_doc) # Assuming pretty_print_structure from previous code
|
||||
|
||||
# Example of appending sections
|
||||
appended_doc = main_doc.merge(background_doc, strategy='append')
|
||||
print("\nAppended structure (may have duplicate sections):")
|
||||
pretty_print_structure(appended_doc)
|
||||
tree= main_doc.generate_toc_tree()
|
||||
pretty_print_structure(main_doc)
|
||||
@@ -44,6 +44,9 @@ class LatexConfig:
|
||||
# Document setup
|
||||
'documentclass', 'usepackage', 'input', 'include', 'includeonly',
|
||||
'bibliography', 'bibliographystyle', 'frontmatter', 'mainmatter',
|
||||
'newtheorem', 'theoremstyle', 'proof', 'proofname', 'qed',
|
||||
'newcommand', 'renewcommand', 'providecommand', 'DeclareMathOperator',
|
||||
'newenvironment',
|
||||
# Layout and spacing
|
||||
'pagestyle', 'thispagestyle', 'vspace', 'hspace', 'vfill', 'hfill',
|
||||
'newpage', 'clearpage', 'pagebreak', 'linebreak', 'newline',
|
||||
@@ -126,12 +129,12 @@ class LatexCleaner:
|
||||
cmd = match.group(1).rstrip('*') # Handle starred versions
|
||||
content = match.group(2)
|
||||
|
||||
# Keep math content intact
|
||||
# For these delimiters, return the original math content
|
||||
if cmd in {'[', ']', '(', ')', '$'} or cmd in self.config.inline_math_delimiters:
|
||||
return content
|
||||
|
||||
return content if cmd in self.config.preserve_commands else ' '
|
||||
return match.group(0)
|
||||
|
||||
# For preserved commands return content, otherwise return space
|
||||
return match.group(0) if cmd in self.config.preserve_commands else ' '
|
||||
# Handle commands with arguments
|
||||
text = re.sub(r'\\(\w+)\*?(?:\[.*?\])?{(.*?)}', handle_command, text)
|
||||
|
||||
@@ -139,7 +142,7 @@ class LatexCleaner:
|
||||
text = self._preserve_inline_math(text)
|
||||
|
||||
# Remove remaining standalone commands
|
||||
return re.sub(r'\\[a-zA-Z]+\*?(?:\[\])?', '', text)
|
||||
return text
|
||||
|
||||
def _preserve_inline_math(self, text: str) -> str:
|
||||
"""Preserve inline math content."""
|
||||
@@ -168,7 +171,7 @@ class LatexCleaner:
|
||||
def clean_text(self, text: str) -> str:
|
||||
"""Clean LaTeX text while preserving meaningful content."""
|
||||
if not text:
|
||||
raise ValueError("Input text cannot be empty")
|
||||
return ""
|
||||
|
||||
try:
|
||||
# Remove comments not inside environments
|
||||
@@ -206,15 +209,32 @@ if __name__ == "__main__":
|
||||
\begin{equation}
|
||||
F = ma
|
||||
\end{equation}
|
||||
\label{sec:intro}
|
||||
""")
|
||||
print(text)
|
||||
|
||||
# Custom configuration
|
||||
config = LatexConfig(
|
||||
preserve_envs={'equation', 'theorem'},
|
||||
preserve_envs={},
|
||||
preserve_commands={'textbf', 'emph'},
|
||||
latex_chars={'~': ' ', '\\&': '&'}
|
||||
)
|
||||
|
||||
|
||||
def read_tex_file(file_path):
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
content = file.read()
|
||||
return content
|
||||
except FileNotFoundError:
|
||||
return "文件未找到,请检查路径是否正确。"
|
||||
except Exception as e:
|
||||
return f"读取文件时发生错误: {e}"
|
||||
|
||||
|
||||
# 使用函数
|
||||
file_path = 'test_cache/2411.03663/neurips_2024.tex'
|
||||
content = read_tex_file(file_path)
|
||||
cleaner = LatexCleaner(config)
|
||||
text = cleaner.clean_text(r"\textbf{Custom} cleaning")
|
||||
text = cleaner.clean_text(content)
|
||||
print(text)
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user