From 724940a9d8f8bf14ebdb11242382c20c0ec29dfc Mon Sep 17 00:00:00 2001 From: lbykkkk Date: Sat, 23 Nov 2024 17:59:17 +0800 Subject: [PATCH] up --- crazy_functions/Arxiv_论文对话.py | 17 +- .../rag_fns/arxiv_fns/arxiv_splitter.py | 449 ------- .../rag_fns/arxiv_fns/essay_structure.py | 266 ++-- .../rag_fns/arxiv_fns/latex_cleaner.py | 36 +- .../rag_fns/arxiv_fns/tex_processor.py | 1099 ----------------- 5 files changed, 134 insertions(+), 1733 deletions(-) delete mode 100644 crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py delete mode 100644 crazy_functions/rag_fns/arxiv_fns/tex_processor.py diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py index 2c29503f..631da5ab 100644 --- a/crazy_functions/Arxiv_论文对话.py +++ b/crazy_functions/Arxiv_论文对话.py @@ -11,7 +11,7 @@ import aiohttp from shared_utils.fastapi_server import validate_path_safety from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file -from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment as Fragment +from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker from crazy_functions.crazy_utils import input_clipping @@ -124,12 +124,14 @@ class ArxivRagWorker: "title": fragments[0].title, "abstract": fragments[0].abstract, "arxiv_id": fragments[0].arxiv_id, + "section_tree": fragments[0].section_tree, } overview_text = ( f"Paper Title: {overview['title']}\n" f"ArXiv ID: {overview['arxiv_id']}\n" f"Abstract: {overview['abstract']}\n" + f"Section Tree:{overview['section_tree']}\n" f"Type: OVERVIEW" ) @@ -161,10 +163,12 @@ class ArxivRagWorker: try: text = ( f"Paper Title: {fragment.title}\n" + f"Abstract: {fragment.abstract}\n" f"ArXiv ID: {fragment.arxiv_id}\n" - f"Section: {fragment.section}\n" - f"Fragment Index: {index}\n" + f"Section: {fragment.current_section}\n" + f"Section Tree: {fragment.section_tree}\n" f"Content: {fragment.content}\n" + f"Bibliography: {fragment.bibliography}\n" f"Type: FRAGMENT" ) @@ -179,13 +183,14 @@ class ArxivRagWorker: try: text = ( f"Paper Title: {fragment.title}\n" + f"Abstract: {fragment.abstract}\n" f"ArXiv ID: {fragment.arxiv_id}\n" - f"Section: {fragment.section}\n" - f"Fragment Index: {index}\n" + f"Section: {fragment.current_section}\n" + f"Section Tree: {fragment.section_tree}\n" f"Content: {fragment.content}\n" + f"Bibliography: {fragment.bibliography}\n" f"Type: FRAGMENT" ) - logger.info(f"Processing fragment {index} for paper {fragment.arxiv_id}") self.rag_worker.add_text_to_vector_store(text) logger.info(f"Successfully added fragment {index} to vector store") diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py deleted file mode 100644 index 901bb064..00000000 --- a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py +++ /dev/null @@ -1,449 +0,0 @@ -import os -import re -import time -import aiohttp -import asyncio -import requests -import tarfile -import logging -from pathlib import Path -from typing import Generator, List, Tuple, Optional, Dict, Set -from concurrent.futures import ThreadPoolExecutor, as_completed -from crazy_functions.rag_fns.arxiv_fns.tex_processor import TexProcessor -from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment - - - -def save_fragments_to_file(fragments, output_dir: str = "fragment_outputs"): - 
""" - 将所有fragments保存为单个结构化markdown文件 - - Args: - fragments: fragment列表 - output_dir: 输出目录 - """ - from datetime import datetime - from pathlib import Path - import re - - # 创建输出目录 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = Path(output_dir) - output_path.mkdir(parents=True, exist_ok=True) - - # 生成文件名 - filename = f"fragments_{timestamp}.md" - file_path = output_path / filename - - current_section = "" - section_count = {} # 用于跟踪每个章节的片段数量 - - with open(file_path, "w", encoding="utf-8") as f: - # 写入文档头部 - f.write("# Document Fragments Analysis\n\n") - f.write("## Overview\n") - f.write(f"- Total Fragments: {len(fragments)}\n") - f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") - - # 如果有标题和摘要,添加到开头 - if fragments and (fragments[0].title or fragments[0].abstract): - f.write("\n## Paper Information\n") - if fragments[0].title: - f.write(f"### Title\n{fragments[0].title}\n") - if fragments[0].abstract: - f.write(f"\n### Abstract\n{fragments[0].abstract}\n") - - # 生成目录 - f.write("\n## Table of Contents\n") - - # 首先收集所有章节信息 - sections = {} - for fragment in fragments: - section = fragment.section or "Uncategorized" - if section not in sections: - sections[section] = [] - sections[section].append(fragment) - - # 写入目录 - for section, section_fragments in sections.items(): - clean_section = section.strip() - if not clean_section: - clean_section = "Uncategorized" - f.write( - f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) ({len(section_fragments)} fragments)\n") - - # 写入正文内容 - f.write("\n## Content\n") - - # 按章节组织内容 - for section, section_fragments in sections.items(): - clean_section = section.strip() or "Uncategorized" - f.write(f"\n### {clean_section}\n") - - # 写入每个fragment - for i, fragment in enumerate(section_fragments, 1): - f.write(f"\n#### Fragment {i} ({fragment.segment_type})\n") - - # 元数据 - f.write("**Metadata:**\n") - f.write(f"- Type: {fragment.segment_type}\n") - f.write(f"- Length: {len(fragment.content)} chars\n") - f.write(f"- Importance: {fragment.importance:.2f}\n") - f.write(f"- Is Appendix: {fragment.is_appendix}\n") - f.write(f"- File: {fragment.rel_path}\n") - - # 内容 - f.write("\n**Content:**\n") - f.write("```tex\n") - f.write(fragment.content) - f.write("\n```\n") - - # 添加分隔线 - if i < len(section_fragments): - f.write("\n---\n") - - # 添加统计信息 - f.write("\n## Statistics\n") - f.write("\n### Fragment Type Distribution\n") - type_stats = {} - for fragment in fragments: - type_stats[fragment.segment_type] = type_stats.get(fragment.segment_type, 0) + 1 - - for ftype, count in type_stats.items(): - percentage = (count / len(fragments)) * 100 - f.write(f"- {ftype}: {count} ({percentage:.1f}%)\n") - - # 长度分布 - f.write("\n### Length Distribution\n") - lengths = [len(f.content) for f in fragments] - f.write(f"- Minimum: {min(lengths)} chars\n") - f.write(f"- Maximum: {max(lengths)} chars\n") - f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n") - - print(f"Fragments saved to: {file_path}") - return file_path - - - -class ArxivSplitter: - """Arxiv论文智能分割器""" - - def __init__(self, - char_range: Tuple[int, int], - root_dir: str = "gpt_log/arxiv_cache", - proxies: Optional[Dict[str, str]] = None, - cache_ttl: int = 7 * 24 * 60 * 60): - """ - 初始化分割器 - - Args: - char_range: 字符数范围(最小值, 最大值) - root_dir: 缓存根目录 - proxies: 代理设置 - cache_ttl: 缓存过期时间(秒) - """ - self.min_chars, self.max_chars = char_range - self.root_dir = Path(root_dir) - self.root_dir.mkdir(parents=True, exist_ok=True) - self.proxies = proxies 
or {} - self.cache_ttl = cache_ttl - - # 动态计算最优线程数 - import multiprocessing - cpu_count = multiprocessing.cpu_count() - # 根据CPU核心数动态设置,但设置上限防止过度并发 - self.max_workers = min(32, cpu_count * 2) - - # 初始化TeX处理器 - self.tex_processor = TexProcessor(char_range) - - # 配置日志 - self._setup_logging() - - - - def _setup_logging(self): - """配置日志""" - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s' - ) - self.logger = logging.getLogger(__name__) - - def _normalize_arxiv_id(self, input_str: str) -> str: - """规范化ArXiv ID""" - if 'arxiv.org/' in input_str.lower(): - # 处理URL格式 - if '/pdf/' in input_str: - arxiv_id = input_str.split('/pdf/')[-1] - else: - arxiv_id = input_str.split('/abs/')[-1] - # 移除版本号和其他后缀 - return arxiv_id.split('v')[0].strip() - return input_str.split('v')[0].strip() - - - def _check_cache(self, paper_dir: Path) -> bool: - """ - 检查缓存是否有效,包括文件完整性检查 - - Args: - paper_dir: 论文目录路径 - - Returns: - bool: 如果缓存有效返回True,否则返回False - """ - if not paper_dir.exists(): - return False - - # 检查目录中是否存在必要文件 - has_tex_files = False - has_main_tex = False - - for file_path in paper_dir.rglob("*"): - if file_path.suffix == '.tex': - has_tex_files = True - content = self.tex_processor.read_file(str(file_path)) - if content and r'\documentclass' in content: - has_main_tex = True - break - - if not (has_tex_files and has_main_tex): - return False - - # 检查缓存时间 - cache_time = paper_dir.stat().st_mtime - if (time.time() - cache_time) < self.cache_ttl: - self.logger.info(f"Using valid cache for {paper_dir.name}") - return True - - return False - - async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool: - """ - 异步下载论文,包含重试机制和临时文件处理 - - Args: - arxiv_id: ArXiv论文ID - paper_dir: 目标目录路径 - - Returns: - bool: 下载成功返回True,否则返回False - """ - from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader - temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz" - final_tar_path = paper_dir / f"{arxiv_id}.tar.gz" - - # 确保目录存在 - paper_dir.mkdir(parents=True, exist_ok=True) - - # 尝试使用 ArxivDownloader 下载 - try: - downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies) - downloaded_dir = downloader.download_paper(arxiv_id) - if downloaded_dir: - self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}") - return True - except Exception as e: - self.logger.warning(f"ArxivDownloader failed: {str(e)}. 
Falling back to direct download.") - - # 如果 ArxivDownloader 失败,使用原有的下载方式作为备选 - urls = [ - f"https://arxiv.org/src/{arxiv_id}", - f"https://arxiv.org/e-print/{arxiv_id}" - ] - - max_retries = 3 - retry_delay = 1 # 初始重试延迟(秒) - - for url in urls: - for attempt in range(max_retries): - try: - self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})") - async with aiohttp.ClientSession() as session: - async with session.get(url, proxy=self.proxies.get('http')) as response: - if response.status == 200: - content = await response.read() - - # 写入临时文件 - temp_tar_path.write_bytes(content) - - try: - # 验证tar文件完整性并解压 - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir) - - # 下载成功后移动临时文件到最终位置 - temp_tar_path.rename(final_tar_path) - return True - - except Exception as e: - self.logger.warning(f"Invalid tar file: {str(e)}") - if temp_tar_path.exists(): - temp_tar_path.unlink() - - except Exception as e: - self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}") - await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 - continue - - return False - - def _process_tar_file(self, tar_path: Path, extract_path: Path): - """处理tar文件的同步操作""" - with tarfile.open(tar_path, 'r:gz') as tar: - tar.testall() # 验证文件完整性 - tar.extractall(path=extract_path) # 解压文件 - - def _process_single_tex(self, file_path: str) -> List[ArxivFragment]: - """处理单个TeX文件""" - try: - content = self.tex_processor.read_file(file_path) - if not content: - return [] - - # 提取元数据 - is_main = r'\documentclass' in content - title, abstract = "", "" - if is_main: - title, abstract = self.tex_processor.extract_metadata(content) - - # 分割内容 - segments = self.tex_processor.split_content(content) - fragments = [] - - for i, (segment_content, section, is_appendix) in enumerate(segments): - if not segment_content.strip(): - continue - - segment_type = self.tex_processor.detect_segment_type(segment_content) - importance = self.tex_processor.calculate_importance( - segment_content, segment_type, is_main - ) - fragments.append(ArxivFragment( - file_path=file_path, - content=segment_content, - segment_index=i, - total_segments=len(segments), - rel_path=str(Path(file_path).relative_to(self.root_dir)), - segment_type=segment_type, - title=title, - abstract=abstract, - section=section, - is_appendix=is_appendix, - importance=importance - )) - - return fragments - - except Exception as e: - self.logger.error(f"Error processing {file_path}: {str(e)}") - return [] - - async def process(self, arxiv_id_or_url: str) -> List[ArxivFragment]: - """处理ArXiv论文""" - try: - arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url) - paper_dir = self.root_dir / arxiv_id - - # 检查缓存 - if not self._check_cache(paper_dir): - paper_dir.mkdir(exist_ok=True) - if not await self.download_paper(arxiv_id, paper_dir): - raise RuntimeError(f"Failed to download paper {arxiv_id}") - - # 查找主TeX文件 - main_tex = self.tex_processor.find_main_tex_file(str(paper_dir)) - if not main_tex: - raise RuntimeError(f"No main TeX file found in {paper_dir}") - - # 获取所有相关TeX文件 - tex_files = self.tex_processor.resolve_includes(main_tex) - if not tex_files: - raise RuntimeError(f"No valid TeX files found for {arxiv_id}") - - # 并行处理所有TeX文件 - fragments = [] - chunk_size = max(1, len(tex_files) // self.max_workers) # 计算每个线程处理的文件数 - loop = asyncio.get_event_loop() - - async def process_chunk(chunk_files): - chunk_fragments = [] - for file_path in chunk_files: - try: - result = await 
loop.run_in_executor(None, self._process_single_tex, file_path) - chunk_fragments.extend(result) - except Exception as e: - self.logger.error(f"Error processing {file_path}: {str(e)}") - return chunk_fragments - - # 将文件分成多个块 - file_chunks = [tex_files[i:i + chunk_size] for i in range(0, len(tex_files), chunk_size)] - # 异步处理每个块 - chunk_results = await asyncio.gather(*[process_chunk(chunk) for chunk in file_chunks]) - for result in chunk_results: - fragments.extend(result) - # 重新计算片段索引并排序 - fragments.sort(key=lambda x: (x.rel_path, x.segment_index)) - total_fragments = len(fragments) - - for i, fragment in enumerate(fragments): - fragment.segment_index = i - fragment.total_segments = total_fragments - # 在返回之前添加过滤 - fragments = self.tex_processor.filter_fragments(fragments) - return fragments - - except Exception as e: - self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}") - raise - - -async def test_arxiv_splitter(): - """测试ArXiv分割器的功能""" - - # 测试配置 - test_cases = [ - { - "arxiv_id": "2411.03663", - "expected_title": "Large Language Models and Simple Scripts", - "min_fragments": 10, - }, - # { - # "arxiv_id": "1805.10988", - # "expected_title": "RAG vs Fine-tuning", - # "min_fragments": 15, - # } - ] - - # 创建分割器实例 - splitter = ArxivSplitter( - char_range=(800, 1800), - root_dir="test_cache" - ) - - - for case in test_cases: - print(f"\nTesting paper: {case['arxiv_id']}") - try: - fragments = await splitter.process(case['arxiv_id']) - - # 保存fragments - output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log") - print(f"Output saved to: {output_dir}") - # 内容检查 - for fragment in fragments: - # 长度检查 - - print((fragment.content)) - print(len(fragment.content)) - # 类型检查 - - - except Exception as e: - print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}") - raise - - -if __name__ == "__main__": - asyncio.run(test_arxiv_splitter()) \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/essay_structure.py b/crazy_functions/rag_fns/arxiv_fns/essay_structure.py index 91ea0474..053eb568 100644 --- a/crazy_functions/rag_fns/arxiv_fns/essay_structure.py +++ b/crazy_functions/rag_fns/arxiv_fns/essay_structure.py @@ -5,75 +5,28 @@ This module provides functionality for parsing and extracting structured informa including metadata, document structure, and content. It uses modular design and clean architecture principles. 
""" - import re from abc import ABC, abstractmethod -from enum import Enum import logging +from dataclasses import dataclass, field +from typing import List, Optional, Dict +from copy import deepcopy +from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands +from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, SectionLevel, EnhancedSectionExtractor # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -from dataclasses import dataclass, field -from typing import List, Optional, Dict -from enum import Enum -import logging -from copy import deepcopy -from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands - -logger = logging.getLogger(__name__) - - -class SectionLevel(Enum): - CHAPTER = 0 - SECTION = 1 - SUBSECTION = 2 - SUBSUBSECTION = 3 - PARAGRAPH = 4 - SUBPARAGRAPH = 5 - - -@dataclass -class Section: - level: SectionLevel - title: str - content: str = '' - subsections: List['Section'] = field(default_factory=list) - - def merge(self, other: 'Section') -> 'Section': - """Merge this section with another section.""" - if self.title != other.title or self.level != other.level: - raise ValueError("Can only merge sections with same title and level") - - merged = deepcopy(self) - merged.content = self._merge_content(self.content, other.content) - - # Create subsections lookup for efficient merging - subsections_map = {s.title: s for s in merged.subsections} - - for other_subsection in other.subsections: - if other_subsection.title in subsections_map: - # Merge existing subsection - idx = next(i for i, s in enumerate(merged.subsections) - if s.title == other_subsection.title) - merged.subsections[idx] = merged.subsections[idx].merge(other_subsection) - else: - # Add new subsection - merged.subsections.append(deepcopy(other_subsection)) - - return merged - - @staticmethod - def _merge_content(content1: str, content2: str) -> str: - """Merge content strings intelligently.""" - if not content1: - return content2 - if not content2: - return content1 - # Combine non-empty contents with a separator - return f"{content1}\n\n{content2}" +def read_tex_file(file_path): + encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii'] + for encoding in encodings: + try: + with open(file_path, 'r', encoding=encoding) as f: + return f.read() + except UnicodeDecodeError: + continue @dataclass class DocumentStructure: @@ -114,7 +67,7 @@ class DocumentStructure: if other_section.title in sections_map: # Merge existing section idx = next(i for i, s in enumerate(merged.toc) - if s.title == other_section.title) + if s.title == other_section.title) merged.toc[idx] = merged.toc[idx].merge(other_section) else: # Add new section @@ -132,11 +85,69 @@ class DocumentStructure: # Combine non-empty abstracts with a separator return f"{abstract1}\n\n{abstract2}" + def generate_toc_tree(self, indent_char: str = " ", abstract_preview_length: int = 0) -> str: + """ + Generate a tree-like string representation of the table of contents including abstract. + Args: + indent_char: Character(s) used for indentation. Default is two spaces. + abstract_preview_length: Maximum length of abstract preview. Default is 200 characters. 
+ Returns: + str: A formatted string showing the hierarchical document structure with abstract + """ + def _format_section(section: Section, level: int = 0) -> str: + # Create the current section line with proper indentation + current_line = f"{indent_char * level}{'•' if level > 0 else '○'} {section.title}\n" + # Recursively process subsections + subsections = "" + if section.subsections: + subsections = "".join(_format_section(subsec, level + 1) + for subsec in section.subsections) + return current_line + subsections + + result = [] + + # Add document title if it exists + if self.title: + result.append(f"《{self.title}》\n") + + # Add abstract if it exists + if self.abstract: + result.append("\n□ Abstract:") + # Format abstract content with word wrap + abstract_preview = self.abstract[:abstract_preview_length] + if len(self.abstract) > abstract_preview_length: + abstract_preview += "..." + + # Split abstract into lines and indent them + wrapped_lines = [] + current_line = "" + for word in abstract_preview.split(): + if len(current_line) + len(word) + 1 <= 80: # 80 characters per line + current_line = (current_line + " " + word).strip() + else: + wrapped_lines.append(current_line) + current_line = word + if current_line: + wrapped_lines.append(current_line) + + # Add formatted abstract lines + for line in wrapped_lines: + result.append(f"\n{indent_char}{line}") + result.append("\n") # Add extra newline after abstract + + # Add table of contents header if there are sections + if self.toc: + result.append("\n◈ Table of Contents:\n") + + # Add all top-level sections and their subsections + result.extend(_format_section(section, 0) for section in self.toc) + + return "".join(result) class BaseExtractor(ABC): """Base class for LaTeX content extractors.""" @@ -145,7 +156,6 @@ class BaseExtractor(ABC): """Extract specific content from LaTeX document.""" pass - class TitleExtractor(BaseExtractor): """Extracts title from LaTeX document.""" @@ -169,7 +179,6 @@ class TitleExtractor(BaseExtractor): return clean_latex_commands(title) return '' - class AbstractExtractor(BaseExtractor): """Extracts abstract from LaTeX document.""" @@ -193,70 +202,13 @@ class AbstractExtractor(BaseExtractor): return clean_latex_commands(abstract) return '' - -class SectionExtractor: - """Extracts document structure including sections and their content.""" - - def __init__(self): - self.section_pattern = self._compile_section_pattern() - - def _compile_section_pattern(self) -> str: - """Create pattern for matching section commands.""" - section_types = '|'.join(level.name.lower() for level in SectionLevel) - return fr'\\({section_types})\*?(?:\[.*?\])?\{{(.*?)\}}' - - def extract(self, content: str) -> List[Section]: - """Extract sections and build document hierarchy.""" - sections = [] - section_stack = [] - matches = list(re.finditer(self.section_pattern, content, re.IGNORECASE)) - - for i, match in enumerate(matches): - cmd_type = match.group(1).lower() - section_title = match.group(2) - level = SectionLevel[cmd_type.upper()] - - content = self._extract_section_content(content, match, - matches[i + 1] if i < len(matches) - 1 else None) - - new_section = Section( - level=level, - title=clean_latex_commands(section_title), - content=clean_latex_commands(content) - ) - - self._update_section_hierarchy(sections, section_stack, new_section) - - return sections - - def _extract_section_content(self, content: str, current_match: re.Match, - next_match: Optional[re.Match]) -> str: - """Extract content between current section and 
next section.""" - start_pos = current_match.end() - end_pos = next_match.start() if next_match else len(content) - return content[start_pos:end_pos].strip() - - def _update_section_hierarchy(self, sections: List[Section], - stack: List[Section], new_section: Section): - """Update section hierarchy based on section levels.""" - while stack and stack[-1].level.value >= new_section.level.value: - stack.pop() - - if stack: - stack[-1].subsections.append(new_section) - else: - sections.append(new_section) - - stack.append(new_section) - - class EssayStructureParser: """Main class for parsing LaTeX documents.""" def __init__(self): self.title_extractor = TitleExtractor() self.abstract_extractor = AbstractExtractor() - self.section_extractor = SectionExtractor() + self.section_extractor = EnhancedSectionExtractor() # Using the enhanced extractor def parse(self, content: str) -> DocumentStructure: """Parse LaTeX document and extract structured information.""" @@ -276,17 +228,8 @@ class EssayStructureParser: """Preprocess LaTeX content for parsing.""" # Remove comments content = re.sub(r'(? str: """Preserve inline math content.""" @@ -168,7 +171,7 @@ class LatexCleaner: def clean_text(self, text: str) -> str: """Clean LaTeX text while preserving meaningful content.""" if not text: - raise ValueError("Input text cannot be empty") + return "" try: # Remove comments not inside environments @@ -206,15 +209,32 @@ if __name__ == "__main__": \begin{equation} F = ma \end{equation} + \label{sec:intro} """) print(text) # Custom configuration config = LatexConfig( - preserve_envs={'equation', 'theorem'}, + preserve_envs={}, preserve_commands={'textbf', 'emph'}, latex_chars={'~': ' ', '\\&': '&'} ) + + + def read_tex_file(file_path): + try: + with open(file_path, 'r', encoding='utf-8') as file: + content = file.read() + return content + except FileNotFoundError: + return "文件未找到,请检查路径是否正确。" + except Exception as e: + return f"读取文件时发生错误: {e}" + + + # 使用函数 + file_path = 'test_cache/2411.03663/neurips_2024.tex' + content = read_tex_file(file_path) cleaner = LatexCleaner(config) - text = cleaner.clean_text(r"\textbf{Custom} cleaning") + text = cleaner.clean_text(content) print(text) \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/tex_processor.py b/crazy_functions/rag_fns/arxiv_fns/tex_processor.py deleted file mode 100644 index b5fe9c97..00000000 --- a/crazy_functions/rag_fns/arxiv_fns/tex_processor.py +++ /dev/null @@ -1,1099 +0,0 @@ -import re -import os -import logging -from pathlib import Path -from typing import List, Tuple, Dict, Set, Optional, Callable -from crazy_functions.rag_fns.arxiv_fns.arxiv_fragment import ArxivFragment -from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns - -class TexProcessor: - """TeX文档处理器类""" - - def __init__(self, char_range: Tuple[int, int]): - """ - 初始化TeX处理器 - - Args: - char_range: 字符数范围(最小值, 最大值) - """ - self.min_chars, self.max_chars = char_range - self.logger = logging.getLogger(__name__) - - # 初始化LaTeX环境和命令模式 - self._init_patterns() - self.latex_only_patterns = LaTeXPatterns.latex_only_patterns - # 初始化合并规则列表,每个规则是(priority, rule_func)元组 - self.merge_rules = [] - # 注册默认规则 - self.register_merge_rule(self._merge_short_segments, priority=90) - self.register_merge_rule(self._merge_clauses, priority=100) - - def is_latex_commands_only(self, content: str) -> bool: - """ - 检查内容是否仅包含LaTeX命令 - - Args: - content: 要检查的内容 - - Returns: - bool: 如果内容仅包含LaTeX命令返回True,否则返回False - """ - # 预处理:移除空白字符 - content = content.strip() - if not 
content: - return True - - # 移除注释 - content = re.sub(r'(?m)%.*$', '', content) - content = content.strip() - - # 移除所有已知的LaTeX命令模式 - for pattern in self.latex_only_patterns: - content = re.sub(pattern, '', content) - - # 移除常见的LaTeX控制序列 - content = re.sub(r'\\[a-zA-Z]+(\[.*?\])?(\{.*?\})?', '', content) - - # 移除剩余的空白字符 - content = re.sub(r'\s+', '', content) - - # 检查是否还有实质性内容 - # 如果长度为0或者只包含花括号、方括号等LaTeX标记,则认为是纯LaTeX命令 - remaining_chars = re.sub(r'[\{\}\[\]\(\)\,\\\s]', '', content) - return len(remaining_chars) == 0 - - def has_meaningful_content(self, content: str, min_text_ratio: float = 0.1) -> bool: - """ - 检查内容是否包含足够的有意义文本 - - Args: - content: 要检查的内容 - min_text_ratio: 最小文本比例(默认0.1,表示至少10%是文本) - - Returns: - bool: 如果内容包含足够的有意义文本返回True,否则返回False - """ - # 移除注释和空白字符 - content = re.sub(r'(?m)%.*$', '', content) - content = content.strip() - - # 计算总长度 - total_length = len(content) - if total_length == 0: - return False - - # 移除所有LaTeX命令和环境 - for pattern in self.latex_only_patterns: - content = re.sub(pattern, '', content) - content = re.sub(r'\\[a-zA-Z]+(\[.*?\])?(\{.*?\})?', '', content) - - # 计算剩余文本长度(移除剩余的LaTeX标记) - remaining_text = re.sub(r'[\{\}\[\]\(\)\,\\\s]', '', content) - text_ratio = len(remaining_text) / total_length - - return text_ratio >= min_text_ratio - - def filter_fragments(self, fragments, - min_text_ratio: float = 0.1): - """ - 过滤fragment列表,移除仅包含LaTeX命令的片段,并合并相邻的片段 - - Args: - fragments: ArxivFragment列表 - min_text_ratio: 最小文本比例 - - Returns: - List[ArxivFragment]: 过滤后的fragment列表 - """ - filtered_fragments = [] - total_count = len(fragments) - filtered_count = 0 - - for fragment in fragments: - if self.has_meaningful_content(fragment.content, min_text_ratio): - filtered_fragments.append(fragment) - else: - filtered_count += 1 - self.logger.debug(f"Filtered out latex-only fragment: {fragment.content[:100]}...") - - # 记录过滤统计 - if filtered_count > 0: - self.logger.info(f"Filtered out {filtered_count}/{total_count} latex-only fragments") - - # 重新计算索引 - for i, fragment in enumerate(filtered_fragments): - fragment.segment_index = i - fragment.total_segments = len(filtered_fragments) - - - filtered_fragments = self.merge_segments(filtered_fragments) - - # 重新计算索引 - for i, fragment in enumerate(filtered_fragments): - fragment.segment_index = i - fragment.total_segments = len(filtered_fragments) - - return filtered_fragments - def _is_special_environment(self, content: str) -> bool: - """ - 检查内容是否属于特殊环境 - - Args: - content: 要检查的内容 - - Returns: - bool: 如果内容属于特殊环境返回True,否则返回False - """ - for env_patterns in self.special_envs.values(): - for pattern in env_patterns: - if re.search(pattern, content, re.DOTALL): - return True - return False - - def _init_patterns(self): - """初始化LaTeX模式匹配规则""" - # 特殊环境模式 - self.special_envs = LaTeXPatterns.special_envs - # 章节模式 - self.section_patterns = LaTeXPatterns.section_patterns - # 包含模式 - self.include_patterns = LaTeXPatterns.include_patterns - # 元数据模式 - self.metadata_patterns = LaTeXPatterns.metadata_patterns - - def read_file(self, file_path: str) -> Optional[str]: - """ - 读取TeX文件内容,支持多种编码 - - Args: - file_path: 文件路径 - - Returns: - Optional[str]: 文件内容或None - """ - encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii'] - for encoding in encodings: - try: - with open(file_path, 'r', encoding=encoding) as f: - return f.read() - except UnicodeDecodeError: - continue - - self.logger.warning(f"Failed to read {file_path} with all encodings") - return None - - def find_main_tex_file(self, directory: str) -> Optional[str]: - """ - 查找主TeX文件 - - Args: - 
directory: 目录路径 - - Returns: - Optional[str]: 主文件路径或None - """ - tex_files = list(Path(directory).rglob("*.tex")) - if not tex_files: - return None - - # 按优先级查找 - for tex_file in tex_files: - content = self.read_file(str(tex_file)) - if content: - if r'\documentclass' in content: - return str(tex_file) - if tex_file.name.lower() == 'main.tex': - return str(tex_file) - - # 返回最大的tex文件 - return str(max(tex_files, key=lambda x: x.stat().st_size)) - - def resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]: - """ - 解析TeX文件中的include引用 - - Args: - tex_file: TeX文件路径 - processed: 已处理的文件集合 - - Returns: - List[str]: 相关文件路径列表 - """ - if processed is None: - processed = set() - - if tex_file in processed: - return [] - - processed.add(tex_file) - result = [tex_file] - content = self.read_file(tex_file) - - if not content: - return result - - base_dir = Path(tex_file).parent - for pattern in self.include_patterns: - for match in re.finditer(pattern, content): - included_file = match.group(2) - if not included_file.endswith('.tex'): - included_file += '.tex' - - full_path = str(base_dir / included_file) - if os.path.exists(full_path) and full_path not in processed: - result.extend(self.resolve_includes(full_path, processed)) - - return result - - - def _preprocess_content(self, content: str) -> str: - """预处理TeX内容""" - # 移除注释 - content = re.sub(r'(?m)%.*$', '', content) - # 规范化空白字符 - # content = re.sub(r'\s+', ' ', content) - content = re.sub(r'\n\s*\n', '\n\n', content) - return content.strip() - - def _protect_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str: - """保护特殊环境内容""" - for env_patterns in self.special_envs.values(): - for pattern in env_patterns: - content = re.sub( - pattern, - lambda m: self._store_protected_block(m.group(0), protected_blocks), - content, - flags=re.DOTALL - ) - return content - - def _store_protected_block(self, content: str, protected_blocks: Dict[str, str]) -> str: - """存储保护块""" - placeholder = f"__PROTECTED_{len(protected_blocks)}__" - protected_blocks[placeholder] = content - return placeholder - - def _restore_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str: - """恢复特殊环境内容""" - for placeholder, original in protected_blocks.items(): - content = content.replace(placeholder, original) - return content - - def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, bool]]: - """获取章节信息""" - # 检查是否是附录 - is_appendix = bool(re.search(r'\\appendix', content)) - - # 提取章节标题 - for pattern in self.section_patterns: - match = re.search(pattern, para) - if match: - section_title = match.group(1) - # 清理LaTeX命令 - section_title = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.+?)}', r'\1', section_title) - return section_title, is_appendix - - return None - - def _split_long_paragraph(self, paragraph: str) -> List[str]: - """分割长段落""" - parts = [] - current_part = [] - current_length = 0 - - sentences = re.split(r'(?<=[.!?。!?])\s+', paragraph) - for sentence in sentences: - sent_length = len(sentence) - - if current_length + sent_length <= self.max_chars: - current_part.append(sentence) - current_length += sent_length - else: - if current_part: - parts.append(' '.join(current_part)) - current_part = [sentence] - current_length = sent_length - - if current_part: - parts.append(' '.join(current_part)) - - return parts - - def extract_metadata(self, content: str) -> Tuple[str, str]: - """ - 提取文档元数据 - - Args: - content: TeX内容 - - Returns: - Tuple[str, str]: (标题, 摘要) - """ - title = "" - abstract = 
"" - - # 提取标题 - for pattern in self.metadata_patterns['title']: - match = re.search(pattern, content) - if match: - title = match.group(1) - # 清理LaTeX命令 - title = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.+?)}', r'\1', title) - break - - # 提取摘要 - for pattern in self.metadata_patterns['abstract']: - match = re.search(pattern, content, re.DOTALL) - if match: - abstract = match.group(1) - # 清理LaTeX命令 - abstract = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.+?)}', r'\1', abstract) - break - - return title.strip(), abstract.strip() - - def detect_segment_type(self, content: str) -> str: - """ - 检测片段类型 - - Args: - content: 内容片段 - - Returns: - str: 片段类型 - """ - for env_type, patterns in self.special_envs.items(): - for pattern in patterns: - if re.search(pattern, content, re.DOTALL): - return env_type - return 'text' - - def calculate_importance(self, content: str, segment_type: str, is_main: bool) -> float: - """ - 计算内容重要性得分 - - Args: - content: 内容片段 - segment_type: 片段类型 - is_main: 是否在主文件中 - - Returns: - float: 重要性得分 (0-1) - """ - score = 0.5 # 基础分 - - # 根据片段类型调整得分 - type_weights = { - 'text': 0.5, - 'math': 0.7, - 'table': 0.8, - 'figure': 0.6, - 'algorithm': 0.8 - } - score += type_weights.get(segment_type, 0) - - # 根据位置调整得分 - if is_main: - score += 0.2 - - # 根据内容特征调整得分 - if re.search(r'\\label{', content): - score += 0.1 - if re.search(r'\\cite{', content): - score += 0.1 - if re.search(r'\\ref{', content): - score += 0.1 - - # 规范化得分到0-1范围 - return min(1.0, max(0.0, score)) - - - - def split_content(self, content: str) -> List[Tuple[str, str, bool]]: - """ - 按段落分割TeX内容,对超长段落按换行符分割 - - Args: - content: TeX文档内容 - - Returns: - List[Tuple[str, str, bool]]: [(段落内容, 章节名, 是否附录)] - """ - content = self._preprocess_content(content) - segments = [] - current_section = "未命名章节" - is_appendix = False - - # 保护特殊环境 - protected_blocks = {} - # content = self._protect_special_environments(content, protected_blocks) - - # 按段落分割 - paragraphs = re.split(r'\n\s*\n', content) - - for para in paragraphs: - para = para.strip() - if not para: - continue - - # 恢复特殊环境 - para = self._restore_special_environments(para, protected_blocks) - - # 检查章节变化 - section_info = self._get_section_info(para, content) - if section_info: - current_section, is_appendix = section_info - continue - - # 处理特殊环境 - if self._is_special_environment(para): - # 特殊环境超长时分割 - if len(para) > self.max_chars: - split_parts = self._split_special_environment(para) - segments.extend((part, current_section, is_appendix) for part in split_parts) - else: - segments.append((para, current_section, is_appendix)) - continue - - # 处理普通段落 - if len(para) > self.max_chars: - # 按换行符分割超长段落 - split_parts = [p.strip() for p in para.split('\n') if p.strip()] - segments.extend((part, current_section, is_appendix) for part in split_parts) - else: - segments.append((para, current_section, is_appendix)) - - return segments - - def _is_complete_env(self, content: str) -> bool: - """ - 检查是否是完整的LaTeX环境 - - Args: - content: 要检查的内容 - - Returns: - bool: 是否是完整环境 - """ - try: - # 检查基本数学环境配对 - env_pairs = [ - (r'\\begin{(equation\*?)}', r'\\end{equation\*?}'), - (r'\\begin{(align\*?)}', r'\\end{align\*?}'), - (r'\\begin{(gather\*?)}', r'\\end{gather\*?}'), - (r'\\begin{(multline\*?)}', r'\\end{multline\*?}'), - (r'\$\$', r'\$\$'), # 行间数学 - (r'\$', r'\$'), # 行内数学 - (r'\\[', r'\\]'), # 显示数学 - (r'\\(', r'\\)'), # 行内数学 - (r'\\begin{', r'\\end{') # 通用环境 - ] - - # 检查所有环境配对 - for begin_pattern, end_pattern in env_pairs: - if isinstance(begin_pattern, tuple): - begin_pattern, end_pattern = begin_pattern - 
begin_count = len(re.findall(begin_pattern, content)) - end_count = len(re.findall(end_pattern, content)) - if begin_count != end_count: - return False - - # 检查括号配对 - brackets = {'{': '}', '[': ']', '(': ')'} - bracket_count = {k: 0 for k in brackets.keys() | brackets.values()} - - for char in content: - if char in bracket_count: - bracket_count[char] += 1 - - for open_bracket, close_bracket in brackets.items(): - if bracket_count[open_bracket] != bracket_count[close_bracket]: - return False - - return True - - except Exception as e: - self.logger.warning(f"Error checking environment completeness: {str(e)}") - return False - def _split_special_environment(self, content: str) -> List[str]: - """ - 分割特殊环境内容,确保环境的完整性 - - Args: - content: 特殊环境内容 - - Returns: - List[str]: 分割后的内容列表 - """ - env_type = self.detect_segment_type(content) - - # 如果内容已经在允许的长度范围内,且是完整的环境,直接返回 - try: - if len(content) <= self.max_chars: - if self._is_complete_env(content): - return [content] - except Exception as e: - self.logger.warning(f"Error checking environment in split_special_environment: {str(e)}") - - # 根据不同环境类型选择不同的分割策略 - if env_type == 'math': - return self._split_math_content(content) - elif env_type == 'table': - return self._split_table_content(content) - else: - # 对于其他类型的环境 - parts = [] - current_part = "" - - # 按行分割并尝试保持环境完整性 - lines = content.split('\n') - for line in lines: - line_with_newline = line + '\n' - - # 检查是否添加当前行会超出长度限制 - if len(current_part) + len(line_with_newline) <= self.max_chars: - current_part += line_with_newline - else: - # 如果当前部分不为空,进行处理 - if current_part: - try: - # 尝试找到一个完整的环境结束点 - if self._is_complete_env(current_part): - parts.append(current_part) - current_part = line_with_newline - else: - # 如果当前部分不是完整环境,继续添加 - if len(current_part) + len(line_with_newline) <= self.max_chars * 1.5: # 允许一定程度的超出 - current_part += line_with_newline - else: - # 如果实在太长,强制分割 - parts.append(current_part) - current_part = line_with_newline - except Exception as e: - self.logger.warning(f"Error processing environment part: {str(e)}") - parts.append(current_part) - current_part = line_with_newline - else: - # 如果当前行本身就超过长度限制 - parts.append(line_with_newline) - - # 处理最后剩余的部分 - if current_part: - parts.append(current_part) - - # 清理并返回非空片段 - return [p.strip() for p in parts if p.strip()] - def _split_math_content(self, content: str) -> List[str]: - """ - 分割数学公式内容,确保公式环境的完整性 - - Args: - content: 数学公式内容 - - Returns: - List[str]: 分割后的公式列表 - """ - # 首先识别完整的数学环境 - math_envs = LaTeXPatterns.math_envs - - # 提取所有完整的数学环境 - parts = [] - last_end = 0 - math_blocks = [] - - for pattern, env_type in math_envs: - for match in re.finditer(pattern, content, re.DOTALL): - math_blocks.append((match.start(), match.end(), match.group(0))) - - # 按照位置排序 - math_blocks.sort(key=lambda x: x[0]) - - # 保持数学环境的完整性 - if not math_blocks: - # 如果没有识别到完整的数学环境,作为单个块处理 - return [content] if len(content) <= self.max_chars else self._basic_content_split(content) - - current_part = "" - for start, end, block in math_blocks: - # 添加数学环境之前的文本 - if start > last_end: - text_before = content[last_end:start] - if text_before.strip(): - current_part += text_before - - # 处理数学环境 - if len(block) > self.max_chars: - # 如果当前部分已经有内容,先保存 - if current_part: - parts.append(current_part) - current_part = "" - # 将过长的数学环境作为独立部分 - parts.append(block) - else: - # 如果添加当前数学环境会导致超出长度限制 - if current_part and len(current_part) + len(block) > self.max_chars: - parts.append(current_part) - current_part = block - else: - current_part += block - - last_end = end - - # 处理最后的文本部分 
- if last_end < len(content): - remaining = content[last_end:] - if remaining.strip(): - if current_part and len(current_part) + len(remaining) > self.max_chars: - parts.append(current_part) - current_part = remaining - else: - current_part += remaining - - if current_part: - parts.append(current_part) - - return parts - - - def _split_table_content(self, content: str) -> List[str]: - """ - 分割表格内容 - - Args: - content: 表格内容 - - Returns: - List[str]: 分割后的表格部分列表 - """ - # 在表格行之间分割 - rows = re.split(r'(\\\\|\\hline)', content) - result = [] - current_part = "" - header = self._extract_table_header(content) - - for row in rows: - if len(current_part + row) <= self.max_chars: - current_part += row - else: - if current_part: - # 确保每个部分都是完整的表格结构 - result.append(self._wrap_table_content(current_part, header)) - current_part = header + row if header else row - - if current_part: - result.append(self._wrap_table_content(current_part, header)) - - return result - - def _extract_table_header(self, content: str) -> str: - """ - 提取表格头部 - - Args: - content: 表格内容 - - Returns: - str: 表格头部 - """ - # 提取表格环境声明和列格式 - header_match = re.match(r'(\\begin{(?:table|tabular|longtable)\*?}.*?\\hline)', content, re.DOTALL) - return header_match.group(1) if header_match else "" - - def _wrap_table_content(self, content: str, header: str) -> str: - """ - 包装表格内容为完整结构 - - Args: - content: 表格内容 - header: 表格头部 - - Returns: - str: 完整的表格结构 - """ - # 确保表格有正确的开始和结束标签 - env_match = re.search(r'\\begin{(table|tabular|longtable)\*?}', header or content) - if env_match: - env_type = env_match.group(1) - if not content.startswith('\\begin'): - content = f"{header}\n{content}" if header else content - if not content.endswith(f'\\end{{{env_type}}}'): - content = f"{content}\n\\end{{{env_type}}}" - return content - - def _basic_content_split(self, content: str) -> List[str]: - """ - 基本的内容分割策略 - - Args: - content: 要分割的内容 - - Returns: - List[str]: 分割后的内容列表 - """ - parts = [] - while content: - if len(content) <= self.max_chars: - parts.append(content) - break - - # 尝试在最后一个完整行处分割 - split_pos = content[:self.max_chars].rfind('\n') - if split_pos == -1: # 如果找不到换行符,则在最后一个空格处分割 - split_pos = content[:self.max_chars].rfind(' ') - if split_pos == -1: # 如果仍然找不到分割点,则强制分割 - split_pos = self.max_chars - - parts.append(content[:split_pos]) - content = content[split_pos:].strip() - - return parts - - def _ensure_segment_lengths(self, segments: List[Tuple[str, str, bool]]) -> List[Tuple[str, str, bool]]: - """ - 确保所有片段都在指定的长度范围内 - - Args: - segments: 原始片段列表 - - Returns: - List[Tuple[str, str, bool]]: 处理后的片段列表 - """ - result = [] - for content, section, is_appendix in segments: - if len(content) <= self.max_chars: - result.append((content, section, is_appendix)) - else: - # 根据内容类型选择合适的分割方法 - if self._is_special_environment(content): - split_parts = self._split_special_environment(content) - else: - split_parts = self._split_long_paragraph(content) - - result.extend((part, section, is_appendix) for part in split_parts) - - return result - - def register_merge_rule(self, rule_func: Callable[[List['ArxivFragment']], List['ArxivFragment']], - priority: int = 0) -> None: - """ - 注册新的合并规则 - - Args: - rule_func: 合并规则函数,接收fragment列表返回处理后的列表 - priority: 规则优先级,数字越大优先级越高 - """ - self.merge_rules.append((priority, rule_func)) - # 按优先级排序,保证高优先级规则先执行 - self.merge_rules.sort(reverse=True, key=lambda x: x[0]) - - - def _merge_segments(self, seg1: 'ArxivFragment', seg2: 'ArxivFragment') -> 'ArxivFragment': - """ - 合并两个片段的通用方法 - - Args: - seg1: 第一个片段 - seg2: 第二个片段 - - 
Returns: - ArxivFragment: 合并后的片段 - """ - return ArxivFragment( - file_path=seg1.file_path, - content=f"{seg1.content}\n{seg2.content}", - segment_index=seg1.segment_index, - total_segments=seg1.total_segments - 1, - rel_path=seg1.rel_path, - segment_type=self._merge_segment_type(seg1.segment_type, seg2.segment_type), - title=seg1.title, - abstract=seg1.abstract, - section=seg1.section, - is_appendix=seg1.is_appendix, - importance=max(seg1.importance, seg2.importance) - ) - - def _merge_segment_type(self, type1: str, type2: str) -> str: - """ - 确定合并后片段的类型 - - Args: - type1: 第一个片段的类型 - type2: 第二个片段的类型 - - Returns: - str: 合并后的类型 - """ - # 如果类型相同,保持不变 - if type1 == type2: - return type1 - # 如果其中之一是文本,返回非文本的类型 - if type1 == 'text': - return type2 - if type2 == 'text': - return type1 - # 如果是不同的特殊类型,返回 mixed - return 'mixed' - - def _merge_short_segments(self, fragments: List['ArxivFragment']) -> List['ArxivFragment']: - """ - 合并短片段规则 - - Args: - fragments: 片段列表 - - Returns: - List[ArxivFragment]: 处理后的片段列表 - """ - if not fragments: - return fragments - - # 持续合并直到没有可以合并的片段 - need_merge = True - current_fragments = fragments - max_iterations = len(fragments) * 2 # 设置最大迭代次数防止意外情况 - iteration_count = 0 - - while need_merge and iteration_count < max_iterations: - need_merge = False - iteration_count += 1 - result = [] - i = 0 - - while i < len(current_fragments): - current = current_fragments[i] - current_len = len(current.content) - - # 如果当前片段长度足够或是最后一个片段 - if current_len >= self.min_chars or i == len(current_fragments) - 1: - result.append(current) - i += 1 - continue - - # 查找最适合合并的相邻片段 - best_target_idx = -1 - min_combined_length = float('inf') - - # 检查前后片段,选择合并后总长度最小的 - for idx in [i - 1, i + 1]: - if 0 <= idx < len(current_fragments): - target = current_fragments[idx] - target_len = len(target.content) - combined_len = current_len + target_len - - # 更新最佳合并目标 - if combined_len < min_combined_length and ( - target_len < self.min_chars or # 目标也是短片段 - current_len < target_len # 或当前片段更短 - ): - min_combined_length = combined_len - best_target_idx = idx - - # 执行合并 - if best_target_idx != -1: - if best_target_idx < i: # 与前一个片段合并 - result.pop() # 移除之前添加的片段 - merged = self._merge_segments(current_fragments[best_target_idx], current) - result.append(merged) - else: # 与后一个片段合并 - merged = self._merge_segments(current, current_fragments[best_target_idx]) - result.append(merged) - i += 1 # 跳过下一个片段 - need_merge = True # 标记发生了合并,需要继续检查 - i += 1 - else: - # 如果没找到合适的合并目标,保留当前片段 - result.append(current) - i += 1 - - # 更新当前片段列表 - current_fragments = result - - # 检查是否还需要继续合并 - if not need_merge: - # 最后检查一遍是否还有短片段 - has_short = any(len(f.content) < self.min_chars for f in result) - need_merge = has_short and len(result) > 1 - - # 如果达到最大迭代次数,记录警告 - if iteration_count >= max_iterations: - self.logger.warning(f"Reached maximum iterations ({max_iterations}) in merge_short_segments") - - return current_fragments - - def _merge_where_clauses(self, fragments: List['ArxivFragment']) -> List['ArxivFragment']: - """ - 合并 where 子句规则 - - Args: - fragments: 片段列表 - - Returns: - List[ArxivFragment]: 处理后的片段列表 - """ - if not fragments: - return fragments - - result = [] - i = 0 - while i < len(fragments): - current = fragments[i] - - # 检查是否是 where 子句 - if current.content.strip().lower().startswith('where'): - if result: # 确保有前一个片段可以合并 - merged = self._merge_segments(result.pop(), current) - result.append(merged) - else: - result.append(current) - else: - result.append(current) - i += 1 - - return result - - def _merge_clauses(self, 
fragments: List['ArxivFragment']) -> List['ArxivFragment']: - """ - 合并从句和连接词规则,确保句子的完整性 - - 处理以下情况: - 1. where/which/that等从句 - 2. 连接词(such that, so that等) - 3. 条件句(if, when等) - 4. 其他常见的数学论文连接词 - - Args: - fragments: 片段列表 - - Returns: - List[ArxivFragment]: 处理后的片段列表 - """ - if not fragments: - return fragments - - # 需要合并的从句和连接词模式 - clause_patterns = [ - # 从句引导词 - r'^(?:where|which|that|whose|when)\b', - # 数学中的连接词 - r'^(?:such\s+that|so\s+that|in\s+which|for\s+which)\b', - # 条件引导词 - r'^(?:if|unless|provided|assuming)\b', - # 其他常见数学连接词 - r'^(?:therefore|thus|hence|consequently|furthermore|moreover)\b', - # 并列连接词 - r'^(?:and|or|but|while|whereas)\b', - # 因果关系词 - r'^(?:because|since|as)\b', - # 时序关系词 - r'^(?:after|before|until|whenever)\b', - # 让步关系词 - r'^(?:although|though|even\s+if|even\s+though)\b', - # 比较关系词 - r'^(?:than|as\s+[.\w]+\s+as)\b', - # 目的关系词 - r'^(?:in\s+order\s+to|so\s+as\s+to)\b', - # 条件关系词组 - r'^(?:on\s+condition\s+that|given\s+that|suppose\s+that)\b', - # 常见数学术语 - r'^(?:denoted\s+by|defined\s+as|written\s+as|expressed\s+as)\b' - ] - # 编译正则表达式模式 - clause_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in clause_patterns] - - def is_clause_start(text: str) -> bool: - """检查文本是否以从句或连接词开始""" - text = text.strip() - return any(pattern.search(text) for pattern in clause_patterns) - - def is_sentence_complete(text: str) -> bool: - """检查句子是否完整(基于简单的标点符号检查)""" - # 检查常见的句子结束符号 - end_markers = ['.', '。', '!', '?', '!', '?'] - # 排除可能的小数点和缩写 - text = text.strip() - if not text: - return False - last_char = text[-1] - if last_char in end_markers: - # 确保不是小数点 - if last_char == '.' and re.search(r'\d\.$', text): - return False - return True - return False - - def should_merge(prev: ArxivFragment, curr: ArxivFragment) -> bool: - """判断两个片段是否应该合并""" - # 检查当前片段是否以从句开始 - if is_clause_start(curr.content): - return True - - # 检查前一个片段是否句子完整 - if not is_sentence_complete(prev.content): - # 如果前一个片段以数学公式结束,检查当前片段是否是其补充说明 - if re.search(r'[\$\)]\\?$', prev.content.strip()): - return True - - # 检查是否存在被截断的括号对 - brackets = { - '(': ')', '[': ']', '{': '}', - r'\{': r'\}', r'\[': r'\]', r'\(': r'\)' - } - for open_b, close_b in brackets.items(): - open_count = prev.content.count(open_b) - close_count = prev.content.count(close_b) - if open_count > close_count: - return True - - return False - - result = [] - i = 0 - while i < len(fragments): - current = fragments[i] - if "which means that the graph convolution adds up all atom features" in current.content: - print("find here") - if not result: - result.append(current) - i += 1 - continue - - prev = result[-1] - if should_merge(prev, current): - # 合并片段,确保不超过最大长度限制 - merged_content = f"{prev.content}\n{current.content}" - if len(current.content) <= self.min_chars: - merged = self._merge_segments(prev, current) - result.pop() # 移除前一个片段 - result.append(merged) # 添加合并后的片段 - else: - # 如果合并后超过长度限制,保持分开 - result.append(current) - else: - result.append(current) - i += 1 - - return result - - # 在TexProcessor类中更新merge_segments方法 - def merge_segments(self, fragments: List['ArxivFragment']) -> List['ArxivFragment']: - """ - 按注册的规则合并片段 - - Args: - fragments: 要合并的片段列表 - - Returns: - List[ArxivFragment]: 合并后的片段列表 - """ - result = fragments - - # 首先处理从句和连接词 - result = self._merge_clauses(result) - - # 然后执行其他合并规则 - for _, rule_func in self.merge_rules: - if rule_func != self._merge_where_clauses: # 跳过旧的where从句处理 - result = rule_func(result) - - return result -
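
Illustration of the new FRAGMENT text layout introduced by this patch — a minimal, hedged sketch only, assuming SectionFragment exposes the fields referenced in the Arxiv_论文对话.py hunk (title, abstract, arxiv_id, current_section, section_tree, content, bibliography); the actual dataclass lives in crazy_functions/rag_fns/arxiv_fns/section_fragment.py, which is not shown in this diff, so the field set below is an inference, not the repository's definition.

# Sketch (not part of the patch): how each vector-store entry is assembled after this change.
# SectionFragment's real definition is in section_fragment.py (not included here);
# the fields below are assumed from the attributes the diff accesses.
from dataclasses import dataclass

@dataclass
class SectionFragment:  # assumed shape, for illustration only
    title: str
    abstract: str
    arxiv_id: str
    current_section: str
    section_tree: str
    content: str
    bibliography: str

def build_fragment_text(fragment: SectionFragment) -> str:
    # Mirrors the FRAGMENT string built in ArxivRagWorker in the hunk above:
    # the per-fragment index was dropped, and abstract, section tree, and
    # bibliography were added so each chunk carries its own paper context.
    return (
        f"Paper Title: {fragment.title}\n"
        f"Abstract: {fragment.abstract}\n"
        f"ArXiv ID: {fragment.arxiv_id}\n"
        f"Section: {fragment.current_section}\n"
        f"Section Tree: {fragment.section_tree}\n"
        f"Content: {fragment.content}\n"
        f"Bibliography: {fragment.bibliography}\n"
        f"Type: FRAGMENT"
    )

if __name__ == "__main__":
    demo = SectionFragment(
        title="Example Paper",
        abstract="Short abstract.",
        arxiv_id="0000.00000",
        current_section="Introduction",
        section_tree="Introduction > Motivation",
        content="Example body text.",
        bibliography="[1] Example reference.",
    )
    print(build_fragment_text(demo))

The same pattern applies to the OVERVIEW entry, which now appends the section tree produced by DocumentStructure.generate_toc_tree so retrieval can see the paper's outline alongside the title and abstract.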