From 12be7c16e97748c7effeaa4df5a218950ea2a545 Mon Sep 17 00:00:00 2001 From: lbykkkk Date: Sat, 23 Nov 2024 19:00:02 +0800 Subject: [PATCH] up --- crazy_functions/Arxiv_论文对话.py | 30 +- .../rag_fns/arxiv_fns/arxiv_splitter.py | 772 ++++++++++++++++++ .../rag_fns/arxiv_fns/latex_cleaner.py | 281 ++++--- .../rag_fns/arxiv_fns/section_extractor.py | 412 ++++++++++ .../rag_fns/arxiv_fns/section_fragment.py | 17 + .../rag_fns/arxiv_fns/tex_utils.py | 271 ++++++ 6 files changed, 1682 insertions(+), 101 deletions(-) create mode 100644 crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py create mode 100644 crazy_functions/rag_fns/arxiv_fns/section_extractor.py create mode 100644 crazy_functions/rag_fns/arxiv_fns/section_fragment.py create mode 100644 crazy_functions/rag_fns/arxiv_fns/tex_utils.py diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py index 631da5ab..9e6ea8a7 100644 --- a/crazy_functions/Arxiv_论文对话.py +++ b/crazy_functions/Arxiv_论文对话.py @@ -140,11 +140,31 @@ class ArxivRagWorker: self.rag_worker.add_text_to_vector_store(overview_text) logger.info(f"Added paper overview for {overview['arxiv_id']}") - # 并行处理其余片段 - tasks = [] - for i, fragment in enumerate(fragments): - tasks.append(self._process_single_fragment(fragment, i)) - await asyncio.gather(*tasks) + # 创建线程池 + with ThreadPoolExecutor(max_workers=10) as executor: + # 使用 asyncio.gather 收集所有任务 + loop = asyncio.get_event_loop() + tasks = [ + loop.run_in_executor( + executor, + self._process_single_fragment, + fragment, + i + ) + for i, fragment in enumerate(fragments) + ] + + # 等待所有任务完成 + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理结果和异常 + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"Error processing fragment {i}: {result}") + else: + # 处理成功的结果 + pass + logger.info(f"Processed {len(fragments)} fragments successfully") diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py new file mode 100644 index 00000000..9ddb9733 --- /dev/null +++ b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py @@ -0,0 +1,772 @@ +import os +import re +import time +import aiohttp +import asyncio +import requests +import tarfile +import logging +from pathlib import Path +from copy import deepcopy + +from typing import Generator, List, Tuple, Optional, Dict, Set +from concurrent.futures import ThreadPoolExecutor, as_completed +from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils +from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment +from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file +from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section + + +def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path: + """ + Save all fragments to a single structured markdown file. 
+ + Args: + fragments: List of SectionFragment objects + output_dir: Output directory path + + Returns: + Path: Path to the generated markdown file + """ + from datetime import datetime + from pathlib import Path + import re + + # Create output directory + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Generate filename + filename = f"fragments_{timestamp}.md" + file_path = output_path / filename + + # Group fragments by section + sections = {} + for fragment in fragments: + section = fragment.current_section or "Uncategorized" + if section not in sections: + sections[section] = [] + sections[section].append(fragment) + + with open(file_path, "w", encoding="utf-8") as f: + # Write document header + f.write("# Document Fragments Analysis\n\n") + f.write("## Overview\n") + f.write(f"- Total Fragments: {len(fragments)}\n") + f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + + # Add paper information if available + if fragments and (fragments[0].title or fragments[0].abstract): + f.write("\n## Paper Information\n") + if fragments[0].title: + f.write(f"### Title\n{fragments[0].title}\n") + if fragments[0].abstract: + f.write(f"\n### Abstract\n{fragments[0].abstract}\n") + + # Write section tree if available + if fragments and fragments[0].section_tree: + f.write("\n## Section Tree\n") + f.write(fragments[0].section_tree) + + # Generate table of contents + f.write("\n## Table of Contents\n") + for section in sections: + clean_section = section.strip() or "Uncategorized" + fragment_count = len(sections[section]) + f.write(f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) " + f"({fragment_count} fragments)\n") + + # Write content sections + f.write("\n## Content\n") + for section, section_fragments in sections.items(): + clean_section = section.strip() or "Uncategorized" + f.write(f"\n### {clean_section}\n") + + # Write each fragment + for i, fragment in enumerate(section_fragments, 1): + f.write(f"\n#### Fragment {i}\n") + + # Metadata + f.write("**Metadata:**\n") + metadata = [ + f"- Section: {fragment.current_section}", + f"- Length: {len(fragment.content)} chars", + f"- ArXiv ID: {fragment.arxiv_id}" if fragment.arxiv_id else None + ] + f.write("\n".join(filter(None, metadata)) + "\n") + + # Content + f.write("\n**Content:**\n") + f.write("```tex\n") + f.write(fragment.content) + f.write("\n```\n") + + # Bibliography if exists + if fragment.bibliography: + f.write("\n**Bibliography:**\n") + f.write("```bibtex\n") + f.write(fragment.bibliography) + f.write("\n```\n") + + # Add separator + if i < len(section_fragments): + f.write("\n---\n") + + # Add statistics + f.write("\n## Statistics\n") + + # Length distribution + lengths = [len(f.content) for f in fragments] + f.write("\n### Length Distribution\n") + f.write(f"- Minimum: {min(lengths)} chars\n") + f.write(f"- Maximum: {max(lengths)} chars\n") + f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n") + + # Section distribution + f.write("\n### Section Distribution\n") + for section, section_fragments in sections.items(): + percentage = (len(section_fragments) / len(fragments)) * 100 + f.write(f"- {section}: {len(section_fragments)} ({percentage:.1f}%)\n") + + print(f"Fragments saved to: {file_path}") + return file_path + +# 定义各种引用命令的模式 +CITATION_PATTERNS = [ + # 基本的 \cite{} 格式 + r'\\cite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}', + # natbib 格式 + r'\\citep(?:\*)?(?:\[[^\]]*\])?{([^}]+)}', + 
r'\\citet(?:\*)?(?:\[[^\]]*\])?{([^}]+)}', + r'\\citeauthor(?:\*)?(?:\[[^\]]*\])?{([^}]+)}', + r'\\citeyear(?:\*)?(?:\[[^\]]*\])?{([^}]+)}', + r'\\citealt(?:\*)?(?:\[[^\]]*\])?{([^}]+)}', + r'\\citealp(?:\*)?(?:\[[^\]]*\])?{([^}]+)}', + # biblatex 格式 + r'\\textcite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}', + r'\\parencite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}', + r'\\autocite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}', + # 自定义 [cite:...] 格式 + r'\[cite:([^\]]+)\]', +] + +# 编译所有模式 +COMPILED_PATTERNS = [re.compile(pattern) for pattern in CITATION_PATTERNS] + + +class ArxivSplitter: + """Arxiv论文智能分割器""" + + def __init__(self, + root_dir: str = "gpt_log/arxiv_cache", + proxies: Optional[Dict[str, str]] = None, + cache_ttl: int = 7 * 24 * 60 * 60): + """ + 初始化分割器 + + Args: + char_range: 字符数范围(最小值, 最大值) + root_dir: 缓存根目录 + proxies: 代理设置 + cache_ttl: 缓存过期时间(秒) + """ + self.root_dir = Path(root_dir) + self.root_dir.mkdir(parents=True, exist_ok=True) + self.proxies = proxies or {} + self.cache_ttl = cache_ttl + + # 动态计算最优线程数 + import multiprocessing + cpu_count = multiprocessing.cpu_count() + # 根据CPU核心数动态设置,但设置上限防止过度并发 + self.document_structure = DocumentStructure() + self.document_parser = EssayStructureParser() + + self.max_workers = min(32, cpu_count * 2) + + # 初始化TeX处理器 + self.tex_processor = TexUtils() + + # 配置日志 + self._setup_logging() + + + + def _setup_logging(self): + """配置日志""" + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' + ) + self.logger = logging.getLogger(__name__) + + def _normalize_arxiv_id(self, input_str: str) -> str: + """规范化ArXiv ID""" + if 'arxiv.org/' in input_str.lower(): + # 处理URL格式 + if '/pdf/' in input_str: + arxiv_id = input_str.split('/pdf/')[-1] + else: + arxiv_id = input_str.split('/abs/')[-1] + # 移除版本号和其他后缀 + return arxiv_id.split('v')[0].strip() + return input_str.split('v')[0].strip() + + + def _check_cache(self, paper_dir: Path) -> bool: + """ + 检查缓存是否有效,包括文件完整性检查 + + Args: + paper_dir: 论文目录路径 + + Returns: + bool: 如果缓存有效返回True,否则返回False + """ + if not paper_dir.exists(): + return False + + # 检查目录中是否存在必要文件 + has_tex_files = False + has_main_tex = False + + for file_path in paper_dir.rglob("*"): + if file_path.suffix == '.tex': + has_tex_files = True + content = self.tex_processor.read_file(str(file_path)) + if content and r'\documentclass' in content: + has_main_tex = True + break + + if not (has_tex_files and has_main_tex): + return False + + # 检查缓存时间 + cache_time = paper_dir.stat().st_mtime + if (time.time() - cache_time) < self.cache_ttl: + self.logger.info(f"Using valid cache for {paper_dir.name}") + return True + + return False + + async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool: + """ + 异步下载论文,包含重试机制和临时文件处理 + + Args: + arxiv_id: ArXiv论文ID + paper_dir: 目标目录路径 + + Returns: + bool: 下载成功返回True,否则返回False + """ + from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader + temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz" + final_tar_path = paper_dir / f"{arxiv_id}.tar.gz" + + # 确保目录存在 + paper_dir.mkdir(parents=True, exist_ok=True) + + # 尝试使用 ArxivDownloader 下载 + try: + downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies) + downloaded_dir = downloader.download_paper(arxiv_id) + if downloaded_dir: + self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}") + return True + except Exception as e: + self.logger.warning(f"ArxivDownloader failed: {str(e)}. 
Falling back to direct download.") + + # 如果 ArxivDownloader 失败,使用原有的下载方式作为备选 + urls = [ + f"https://arxiv.org/src/{arxiv_id}", + f"https://arxiv.org/e-print/{arxiv_id}" + ] + + max_retries = 3 + retry_delay = 1 # 初始重试延迟(秒) + + for url in urls: + for attempt in range(max_retries): + try: + self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})") + async with aiohttp.ClientSession() as session: + async with session.get(url, proxy=self.proxies.get('http')) as response: + if response.status == 200: + content = await response.read() + + # 写入临时文件 + temp_tar_path.write_bytes(content) + + try: + # 验证tar文件完整性并解压 + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir) + + # 下载成功后移动临时文件到最终位置 + temp_tar_path.rename(final_tar_path) + return True + + except Exception as e: + self.logger.warning(f"Invalid tar file: {str(e)}") + if temp_tar_path.exists(): + temp_tar_path.unlink() + + except Exception as e: + self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}") + await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避 + continue + + return False + + def _process_tar_file(self, tar_path: Path, extract_path: Path): + """处理tar文件的同步操作""" + with tarfile.open(tar_path, 'r:gz') as tar: + tar.testall() # 验证文件完整性 + tar.extractall(path=extract_path) # 解压文件 + + def process_references(self, doc_structure: DocumentStructure, ref_bib: str) -> DocumentStructure: + """ + Process citations in document structure and add referenced literature for each section + + Args: + doc_structure: DocumentStructure object + ref_bib: String containing references separated by newlines + + Returns: + Updated DocumentStructure object + """ + try: + # Create a copy to avoid modifying the original + doc = deepcopy(doc_structure) + + # Parse references into a mapping + ref_map = self._parse_references(ref_bib) + if not ref_map: + self.logger.warning("No valid references found in ref_bib") + return doc + + # Process all sections recursively + self._process_section_references(doc.toc, ref_map) + + return doc + + except Exception as e: + self.logger.error(f"Error processing references: {str(e)}") + return doc_structure # Return original if processing fails + + def _process_section_references(self, sections: List[Section], ref_map: Dict[str, str]) -> None: + """ + Recursively process sections to add references + + Args: + sections: List of Section objects + ref_map: Mapping of citation keys to full references + """ + for section in sections: + if section.content: + # Find citations in current section + cited_refs = self.find_citations(section.content) + + if cited_refs: + # Get full references for citations + full_refs = [] + for ref_key in cited_refs: + ref_text = ref_map.get(ref_key) + if ref_text: + full_refs.append(ref_text) + else: + self.logger.warning(f"Reference not found for citation key: {ref_key}") + + # Add references to section content + if full_refs: + section.bibliography = "\n\n".join(full_refs) + + # Process subsections recursively + if section.subsections: + self._process_section_references(section.subsections, ref_map) + + def _parse_references(self, ref_bib: str) -> Dict[str, str]: + """ + Parse reference string into a mapping of citation keys to full references + + Args: + ref_bib: Reference string with references separated by newlines + + Returns: + Dict mapping citation keys to full reference text + """ + ref_map = {} + current_ref = [] + current_key = None + + try: + for line in ref_bib.split('\n'): + line = 
line.strip() + if not line: + continue + + # New reference entry + if line.startswith('@'): + # Save previous reference if exists + if current_key and current_ref: + ref_map[current_key] = '\n'.join(current_ref) + current_ref = [] + + # Extract key from new reference + key_match = re.search(r'{(.*?),', line) + if key_match: + current_key = key_match.group(1) + current_ref.append(line) + else: + if current_ref is not None: + current_ref.append(line) + + # Save last reference + if current_key and current_ref: + ref_map[current_key] = '\n'.join(current_ref) + + except Exception as e: + self.logger.error(f"Error parsing references: {str(e)}") + + return ref_map + + # 编译一次正则表达式以提高效率 + + @staticmethod + def _clean_citation_key(key: str) -> str: + """Clean individual citation key.""" + return key.strip().strip(',').strip() + + def _extract_keys_from_group(self, keys_str: str) -> Set[str]: + """Extract and clean individual citation keys from a group.""" + try: + # 分割多个引用键(支持逗号和分号分隔) + separators = '[,;]' + keys = re.split(separators, keys_str) + # 清理并过滤空键 + return {self._clean_citation_key(k) for k in keys if self._clean_citation_key(k)} + except Exception as e: + self.logger.warning(f"Error processing citation group '{keys_str}': {e}") + return set() + + def find_citations(self, content: str) -> Set[str]: + """ + Find citation keys in text content in various formats. + + Args: + content: Text content to search for citations + + Returns: + Set of unique citation keys + + Examples: + Supported formats include: + - \cite{key1,key2} + - \cite[p. 1]{key} + - \citep{key} + - \citet{key} + - [cite:key1, key2] + - And many other variants + """ + citations = set() + + if not content: + return citations + + try: + # 对每个编译好的模式进行搜索 + for pattern in COMPILED_PATTERNS: + matches = pattern.finditer(content) + for match in matches: + # 获取捕获组中的引用键 + keys_str = match.group(1) + if keys_str: + # 提取并添加所有引用键 + new_keys = self._extract_keys_from_group(keys_str) + citations.update(new_keys) + + except Exception as e: + self.logger.error(f"Error finding citations: {str(e)}") + + # 移除明显无效的键 + citations = {key for key in citations + if key and not key.startswith(('\\', '{', '}', '[', ']'))} + + return citations + + def get_citation_contexts(self, content: str, context_chars: int = 100) -> dict: + """ + Find citations and their surrounding context. + + Args: + content: Text content to search for citations + context_chars: Number of characters of context to include before/after + + Returns: + Dict mapping citation keys to lists of context strings + """ + contexts = {} + + if not content: + return contexts + + try: + for pattern in COMPILED_PATTERNS: + matches = pattern.finditer(content) + for match in matches: + # 获取匹配的位置 + start = max(0, match.start() - context_chars) + end = min(len(content), match.end() + context_chars) + + # 获取上下文 + context = content[start:end] + + # 获取并处理引用键 + keys_str = match.group(1) + keys = self._extract_keys_from_group(keys_str) + + # 为每个键添加上下文 + for key in keys: + if key not in contexts: + contexts[key] = [] + contexts[key].append(context) + + except Exception as e: + self.logger.error(f"Error finding citation contexts: {str(e)}") + + return contexts + async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]: + """ + Process ArXiv paper and convert to list of SectionFragments. + Each fragment represents the smallest section unit. 
+ + Args: + arxiv_id_or_url: ArXiv paper ID or URL + + Returns: + List[SectionFragment]: List of processed paper fragments + """ + try: + arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url) + paper_dir = self.root_dir / arxiv_id + + # Check if paper directory exists, if not, try to download + if not paper_dir.exists(): + self.logger.info(f"Downloading paper {arxiv_id}") + await self.download_paper(arxiv_id, paper_dir) + + # Find main TeX file + main_tex = self.tex_processor.find_main_tex_file(str(paper_dir)) + if not main_tex: + raise RuntimeError(f"No main TeX file found in {paper_dir}") + + # Get all related TeX files and references + tex_files = self.tex_processor.resolve_includes(main_tex) + ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir) + + if not tex_files: + raise RuntimeError(f"No valid TeX files found for {arxiv_id}") + + # Reset document structure for new processing + self.document_structure = DocumentStructure() + + # Process each TeX file + for file_path in tex_files: + self.logger.info(f"Processing TeX file: {file_path}") + tex_content = read_tex_file(file_path) + if tex_content: + additional_doc = self.document_parser.parse(tex_content) + self.document_structure = self.document_structure.merge(additional_doc) + + # Process references if available + if ref_bib: + self.document_structure = self.process_references(self.document_structure, ref_bib) + self.logger.info("Successfully processed references") + else: + self.logger.info("No references found to process") + + # Generate table of contents once + section_tree = self.document_structure.generate_toc_tree() + + # Convert DocumentStructure to SectionFragments + fragments = self._convert_to_fragments( + doc_structure=self.document_structure, + arxiv_id=arxiv_id, + section_tree=section_tree + ) + + return fragments + + except Exception as e: + self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}") + raise + + def _convert_to_fragments(self, + doc_structure: DocumentStructure, + arxiv_id: str, + section_tree: str) -> List[SectionFragment]: + """ + Convert DocumentStructure to list of SectionFragments. + Creates a fragment for each leaf section in the document hierarchy. + + Args: + doc_structure: Source DocumentStructure + arxiv_id: ArXiv paper ID + section_tree: Pre-generated table of contents tree + + Returns: + List[SectionFragment]: List of paper fragments + """ + fragments = [] + + # Create a base template for all fragments to avoid repetitive assignments + base_fragment_template = { + 'title': doc_structure.title, + 'abstract': doc_structure.abstract, + 'section_tree': section_tree, + 'arxiv_id': arxiv_id + } + + def get_leaf_sections(section: Section, path: List[str] = None) -> None: + """ + Recursively find all leaf sections and create fragments. + A leaf section is one that has content but no subsections, or has neither. 
+ + Args: + section: Current section being processed + path: List of section titles forming the path to current section + """ + if path is None: + path = [] + + current_path = path + [section.title] + + if not section.subsections: + # This is a leaf section, create a fragment if it has content + if section.content or section.bibliography: + fragment = SectionFragment( + **base_fragment_template, + current_section="/".join(current_path), + content=self._clean_content(section.content), + bibliography=section.bibliography + ) + if self._validate_fragment(fragment): + fragments.append(fragment) + else: + # Process each subsection + for subsection in section.subsections: + get_leaf_sections(subsection, current_path) + + # Process all top-level sections + for section in doc_structure.toc: + get_leaf_sections(section) + + # Add a fragment for the abstract if it exists + if doc_structure.abstract: + abstract_fragment = SectionFragment( + **base_fragment_template, + current_section="Abstract", + content=self._clean_content(doc_structure.abstract) + ) + if self._validate_fragment(abstract_fragment): + fragments.insert(0, abstract_fragment) + + self.logger.info(f"Created {len(fragments)} fragments") + return fragments + + def _validate_fragment(self, fragment: SectionFragment) -> bool: + """ + Validate if the fragment has all required fields with meaningful content. + + Args: + fragment: SectionFragment to validate + + Returns: + bool: True if fragment is valid, False otherwise + """ + try: + return all([ + fragment.title.strip(), + fragment.section_tree.strip(), + fragment.current_section.strip(), + fragment.content.strip() or fragment.bibliography.strip() + ]) + except AttributeError: + return False + + def _clean_content(self, content: str) -> str: + """ + Clean and normalize content text. 
+ + Args: + content: Raw content text + + Returns: + str: Cleaned content text + """ + if not content: + return "" + + # Remove excessive whitespace + content = re.sub(r'\s+', ' ', content) + + # Remove remaining LaTeX artifacts + content = re.sub(r'\\item\s*', '• ', content) # Convert \item to bullet points + content = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', content) # Remove simple LaTeX commands + + # Clean special characters + content = content.replace('\\\\', '\n') # Convert LaTeX newlines to actual newlines + content = re.sub(r'\s*\n\s*', '\n', content) # Clean up newlines + + return content.strip() + + +async def test_arxiv_splitter(): + """测试ArXiv分割器的功能""" + + # 测试配置 + test_cases = [ + { + "arxiv_id": "2411.03663", + "expected_title": "Large Language Models and Simple Scripts", + "min_fragments": 10, + }, + # { + # "arxiv_id": "1805.10988", + # "expected_title": "RAG vs Fine-tuning", + # "min_fragments": 15, + # } + ] + + # 创建分割器实例 + splitter = ArxivSplitter( + root_dir="test_cache" + ) + + + for case in test_cases: + print(f"\nTesting paper: {case['arxiv_id']}") + try: + fragments = await splitter.process(case['arxiv_id']) + + # 保存fragments + output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log") + print(f"Output saved to: {output_dir}") + # 内容检查 + for fragment in fragments: + # 长度检查 + + print((fragment.content)) + print(len(fragment.content)) + # 类型检查 + + + except Exception as e: + print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}") + raise + + +if __name__ == "__main__": + asyncio.run(test_arxiv_splitter()) \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/latex_cleaner.py b/crazy_functions/rag_fns/arxiv_fns/latex_cleaner.py index bcaa9793..4574e8a8 100644 --- a/crazy_functions/rag_fns/arxiv_fns/latex_cleaner.py +++ b/crazy_functions/rag_fns/arxiv_fns/latex_cleaner.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Set, Dict, Pattern, Optional +from typing import Set, Dict, Pattern, Optional, List, Tuple import re from enum import Enum import logging @@ -8,179 +8,259 @@ from functools import lru_cache class EnvType(Enum): """Environment classification types.""" - PRESERVE = "preserve" - REMOVE = "remove" - EXTRACT = "extract" + PRESERVE = "preserve" # Preserve complete environment including commands + REMOVE = "remove" # Remove environment completely + EXTRACT = "extract" # Extract and clean content @dataclass class LatexConfig: """Configuration for LaTeX processing.""" preserve_envs: Set[str] = field(default_factory=lambda: { - # Math environments + # Math environments - preserve complete content 'equation', 'equation*', 'align', 'align*', 'displaymath', - 'math', 'eqnarray', 'gather', 'gather*', 'multline', 'multline*', - # Tables and figures + 'math', 'eqnarray', 'eqnarray*', 'gather', 'gather*', + 'multline', 'multline*', 'flalign', 'flalign*', + 'alignat', 'alignat*', 'cases', 'split', 'aligned', + # Tables and figures - preserve structure and content 'table', 'table*', 'tabular', 'tabularx', 'array', 'matrix', - 'figure', 'figure*', 'subfigure', + 'figure', 'figure*', 'subfigure', 'wrapfigure', + 'minipage', 'tabbing', 'verbatim', 'longtable', + 'sidewaystable', 'sidewaysfigure', 'floatrow', + # Arrays and matrices + 'pmatrix', 'bmatrix', 'Bmatrix', 'vmatrix', 'Vmatrix', + 'smallmatrix', 'array', 'matrix*', 'pmatrix*', 'bmatrix*', # Algorithms and code - 'algorithm', 'algorithmic', 'lstlisting', + 'algorithm', 'algorithmic', 'lstlisting', 'verbatim', + 'minted', 
'listing', 'algorithmic*', 'algorithm2e', # Theorems and proofs 'theorem', 'proof', 'definition', 'lemma', 'corollary', - 'proposition', 'example', 'remark' + 'proposition', 'example', 'remark', 'note', 'claim', + 'axiom', 'property', 'assumption', 'conjecture', 'observation', + # Bibliography + 'thebibliography', 'bibliography', 'references' + }) + + # 引用类命令的特殊处理配置 + citation_commands: Set[str] = field(default_factory=lambda: { + # Basic citations + 'cite', 'citep', 'citet', 'citeyear', 'citeauthor', + 'citeyearpar', 'citetext', 'citenum', + # Natbib citations + 'citefullauthor', 'citealp', 'citealt', 'citename', + 'citepalias', 'citetalias', 'citetext', + # Cross-references + 'ref', 'eqref', 'pageref', 'autoref', 'nameref', 'cref', + 'Cref', 'vref', 'Vref', 'fref', 'pref', + # Hyperref + 'hyperref', 'href', 'url', + # Labels + 'label', 'tag' }) preserve_commands: Set[str] = field(default_factory=lambda: { - # Citations and references - 'caption', 'label', 'ref', 'cite', 'citep', 'citet', 'eqref', # Text formatting 'emph', 'textbf', 'textit', 'underline', 'texttt', 'footnote', - 'section', 'subsection', 'subsubsection', 'paragraph', - # Math operators - 'frac', 'sum', 'int', 'prod', 'lim', 'sup', 'inf' + 'section', 'subsection', 'subsubsection', 'paragraph', 'part', + 'chapter', 'title', 'author', 'date', 'thanks', + # Math operators and symbols + 'frac', 'sum', 'int', 'prod', 'lim', 'sup', 'inf', + 'partial', 'nabla', 'implies', 'iff', 'therefore', + 'exists', 'forall', 'in', 'subset', 'subseteq', + # Greek letters and math symbols + 'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta', + 'eta', 'theta', 'iota', 'kappa', 'lambda', 'mu', + 'nu', 'xi', 'pi', 'rho', 'sigma', 'tau', + 'upsilon', 'phi', 'chi', 'psi', 'omega', + 'Gamma', 'Delta', 'Theta', 'Lambda', 'Xi', 'Pi', + 'Sigma', 'Upsilon', 'Phi', 'Psi', 'Omega', + # Math commands + 'left', 'right', 'big', 'Big', 'bigg', 'Bigg', + 'mathbf', 'mathit', 'mathsf', 'mathtt', 'mathbb', + 'mathcal', 'mathfrak', 'mathscr', 'mathrm', 'mathop', + 'operatorname', 'overline', 'underline', 'overbrace', + 'underbrace', 'overset', 'underset', 'stackrel', + # Spacing and alignment + 'quad', 'qquad', 'hspace', 'vspace', 'medskip', + 'bigskip', 'smallskip', 'hfill', 'vfill', 'centering', + 'raggedright', 'raggedleft' }) remove_commands: Set[str] = field(default_factory=lambda: { # Document setup 'documentclass', 'usepackage', 'input', 'include', 'includeonly', - 'bibliography', 'bibliographystyle', 'frontmatter', 'mainmatter', - 'newtheorem', 'theoremstyle', 'proof', 'proofname', 'qed', + 'bibliographystyle', 'frontmatter', 'mainmatter', + 'newtheorem', 'theoremstyle', 'proofname', 'newcommand', 'renewcommand', 'providecommand', 'DeclareMathOperator', 'newenvironment', # Layout and spacing - 'pagestyle', 'thispagestyle', 'vspace', 'hspace', 'vfill', 'hfill', - 'newpage', 'clearpage', 'pagebreak', 'linebreak', 'newline', - 'setlength', 'setcounter', 'addtocounter', 'renewcommand', - 'newcommand', 'makeatletter', 'makeatother', 'pagenumbering', - # Margins and columns - 'marginpar', 'marginparsep', 'columnsep', 'columnseprule', - 'twocolumn', 'onecolumn', 'minipage', 'parbox' + 'pagestyle', 'thispagestyle', 'newpage', 'clearpage', + 'pagebreak', 'linebreak', 'newline', 'setlength', + 'setcounter', 'addtocounter', 'makeatletter', + 'makeatother', 'pagenumbering' }) latex_chars: Dict[str, str] = field(default_factory=lambda: { '~': ' ', '\\&': '&', '\\%': '%', '\\_': '_', '\\$': '$', '\\#': '#', '\\{': '{', '\\}': '}', '``': '"', "''": '"', 
'\\textbackslash': '\\', '\\ldots': '...', '\\dots': '...', - '\\textasciitilde': '~', '\\textasciicircum': '^', - '\\quad': ' ', '\\qquad': ' ', '\\,': '', '\\;': '', '\\:': '', - '\\!': '', '\\space': ' ', '\\noindent': '' + '\\textasciitilde': '~', '\\textasciicircum': '^' }) - inline_math_delimiters: Set[str] = field(default_factory=lambda: { - '$', '\\(', '\\)', '\\[', '\\]' - }) + # 保留原始格式的特殊命令模式 + special_command_patterns: List[Tuple[str, str]] = field(default_factory=lambda: [ + (r'\\cite\*?(?:\[[^\]]*\])?{([^}]+)}', r'\\cite{\1}'), + (r'\\ref\*?{([^}]+)}', r'\\ref{\1}'), + (r'\\label{([^}]+)}', r'\\label{\1}'), + (r'\\eqref{([^}]+)}', r'\\eqref{\1}'), + (r'\\autoref{([^}]+)}', r'\\autoref{\1}'), + (r'\\url{([^}]+)}', r'\\url{\1}'), + (r'\\href{([^}]+)}{([^}]+)}', r'\\href{\1}{\2}') + ]) class LatexCleaner: - """Efficient and modular LaTeX text cleaner.""" + """Enhanced LaTeX text cleaner that preserves mathematical content and citations.""" def __init__(self, config: Optional[LatexConfig] = None): self.config = config or LatexConfig() self.logger = logging.getLogger(__name__) + # 初始化正则表达式缓存 + self._regex_cache = {} @lru_cache(maxsize=128) def _get_env_pattern(self, env_name: str) -> Pattern: + """Get cached regex pattern for environment matching.""" return re.compile(fr'\\begin{{{env_name}}}(.*?)\\end{{{env_name}}}', re.DOTALL) def _get_env_type(self, env_name: str) -> EnvType: """Determine environment processing type.""" if env_name.rstrip('*') in {name.rstrip('*') for name in self.config.preserve_envs}: return EnvType.PRESERVE - elif env_name in {'verbatim', 'comment'}: + elif env_name in {'comment'}: return EnvType.REMOVE return EnvType.EXTRACT + def _preserve_special_commands(self, text: str) -> str: + """Preserve special commands like citations and references with their complete structure.""" + for pattern, replacement in self.config.special_command_patterns: + if pattern not in self._regex_cache: + self._regex_cache[pattern] = re.compile(pattern) + + def replace_func(match): + # 保持原始命令格式 + return match.group(0) + + text = self._regex_cache[pattern].sub(replace_func, text) + return text + def _process_environment(self, match: re.Match) -> str: + """Process LaTeX environments while preserving complete content for special environments.""" try: env_name = match.group(1) content = match.group(2) env_type = self._get_env_type(env_name) if env_type == EnvType.PRESERVE: - # Preserve math content without markers for inline math - if env_name in {'math', 'displaymath'}: - return f" {content} " - return f" [BEGIN_{env_name}] {content} [END_{env_name}] " + # 完整保留环境内容 + complete_env = match.group(0) + return f"\n[BEGIN_{env_name}]\n{complete_env}\n[END_{env_name}]\n" elif env_type == EnvType.REMOVE: return ' ' - # Process nested environments recursively - return self._clean_nested_environments(content) + else: + # 处理嵌套环境 + return self._clean_nested_environments(content) except Exception as e: - self.logger.error(f"Error processing environment {env_name}: {e}") - return content + self.logger.error(f"Error processing environment {match.group(1) if match else 'unknown'}: {e}") + return match.group(0) + + def _preserve_inline_math(self, text: str) -> str: + """Preserve complete inline math content.""" + + def preserve_math(match): + return f" {match.group(0)} " + + patterns = [ + (r'\$[^$]+\$', preserve_math), + (r'\\[\(\[].*?\\[\)\]]', preserve_math), + (r'\\begin{math}.*?\\end{math}', preserve_math) + ] + + for pattern, handler in patterns: + if pattern not in self._regex_cache: + 
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL) + text = self._regex_cache[pattern].sub(handler, text) + + return text def _clean_nested_environments(self, text: str) -> str: """Process nested environments recursively.""" - return re.sub( - r'\\begin{(\w+)}(.*?)\\end{\1}', - self._process_environment, - text, - flags=re.DOTALL - ) + pattern = r'\\begin{(\w+)}(.*?)\\end{\1}' + if pattern not in self._regex_cache: + self._regex_cache[pattern] = re.compile(pattern, re.DOTALL) + + return self._regex_cache[pattern].sub(self._process_environment, text) def _clean_commands(self, text: str) -> str: - """Clean LaTeX commands while preserving specified content.""" - # Remove complete commands - for cmd in self.config.remove_commands: - text = re.sub(fr'\\{cmd}\*?(?:\[.*?\])?(?:{{.*?}})*', '', text) + """Clean LaTeX commands while preserving important content.""" + # 首先处理特殊命令 + text = self._preserve_special_commands(text) - # Process commands with content - def handle_command(match: re.Match) -> str: - cmd = match.group(1).rstrip('*') # Handle starred versions - content = match.group(2) - - # For these delimiters, return the original math content - if cmd in {'[', ']', '(', ')', '$'} or cmd in self.config.inline_math_delimiters: - return match.group(0) - - # For preserved commands return content, otherwise return space - return match.group(0) if cmd in self.config.preserve_commands else ' ' - # Handle commands with arguments - text = re.sub(r'\\(\w+)\*?(?:\[.*?\])?{(.*?)}', handle_command, text) - - # Handle inline math + # 保留内联数学 text = self._preserve_inline_math(text) - # Remove remaining standalone commands - return text + # 移除指定的命令 + for cmd in self.config.remove_commands: + if cmd not in self._regex_cache: + self._regex_cache[cmd] = re.compile( + fr'\\{cmd}\*?(?:\[.*?\])?(?:{{.*?}})*' + ) + text = self._regex_cache[cmd].sub('', text) - def _preserve_inline_math(self, text: str) -> str: - """Preserve inline math content.""" - # Handle $...$ math - text = re.sub(r'\$(.+?)\$', r' \1 ', text) - # Handle \(...\) math - text = re.sub(r'\\[\(\[](.+?)\\[\)\]]', r' \1 ', text) + # 处理带内容的命令 + def handle_command(match: re.Match) -> str: + cmd = match.group(1).rstrip('*') + if cmd in self.config.preserve_commands or cmd in self.config.citation_commands: + return match.group(0) # 完整保留命令和内容 + return ' ' + + if 'command_pattern' not in self._regex_cache: + self._regex_cache['command_pattern'] = re.compile( + r'\\(\w+)\*?(?:\[.*?\])?{(.*?)}' + ) + + text = self._regex_cache['command_pattern'].sub(handle_command, text) return text def _normalize_text(self, text: str) -> str: - """Normalize special characters and whitespace.""" - # Replace special characters + """Normalize text while preserving special content markers.""" + # 替换特殊字符 for char, replacement in self.config.latex_chars.items(): text = text.replace(char, replacement) - # Clean up whitespace + # 清理空白字符,同时保留环境标记 text = re.sub(r'\s+', ' ', text) - text = re.sub(r'\s*\[BEGIN_(\w+)\]\s*', r' [BEGIN_\1] ', text) - text = re.sub(r'\s*\[END_(\w+)\]\s*', r' [END_\1] ', text) + text = re.sub(r'\s*\[BEGIN_(\w+)\]\s*', r'\n[BEGIN_\1]\n', text) + text = re.sub(r'\s*\[END_(\w+)\]\s*', r'\n[END_\1]\n', text) - # Remove empty brackets and braces - text = re.sub(r'{\s*}|\[\s*\]|\(\s*\)', '', text) + # 保持块级环境之间的分隔 + text = re.sub(r'\n{3,}', '\n\n', text) return text.strip() def clean_text(self, text: str) -> str: - """Clean LaTeX text while preserving meaningful content.""" + """Clean LaTeX text while preserving mathematical content, citations, and special 
environments.""" if not text: return "" try: - # Remove comments not inside environments + # 移除注释 text = re.sub(r'(? str: """Convenience function for quick text cleaning with default config.""" - config = LatexConfig( - preserve_envs={'equation', 'theorem'}, - preserve_commands={'textbf', 'emph', "label"}, - latex_chars={'~': ' ', '\\&': '&'} - ) - return LatexCleaner(config).clean_text(text) + cleaner = LatexCleaner() + return cleaner.clean_text(text) # Example usage: if __name__ == "__main__": - # Basic usage with inline math - text = clean_latex_commands(r""" + text = r""" + \documentclass{article} + \begin{document} + + \section{Introduction} + This is a reference to \cite{smith2020} and equation \eqref{eq:main}. + + \begin{equation}\label{eq:main} + E = mc^2 \times \sum_{i=1}^{n} x_i + \end{equation} + + See Figure \ref{fig:example} for details. + + \begin{figure} + \includegraphics{image.png} + \caption{Example figure\label \textbf{Important} result: $E=mc^2$ and \begin{equation} F = ma \end{equation} \label{sec:intro} - """) - print(text) + """ # Custom configuration config = LatexConfig( @@ -236,5 +325,5 @@ if __name__ == "__main__": file_path = 'test_cache/2411.03663/neurips_2024.tex' content = read_tex_file(file_path) cleaner = LatexCleaner(config) - text = cleaner.clean_text(content) + text = cleaner.clean_text(text) print(text) \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/section_extractor.py b/crazy_functions/rag_fns/arxiv_fns/section_extractor.py new file mode 100644 index 00000000..f8fe6787 --- /dev/null +++ b/crazy_functions/rag_fns/arxiv_fns/section_extractor.py @@ -0,0 +1,412 @@ +import re +from typing import List, Dict, Tuple, Optional, Set +from enum import Enum +from dataclasses import dataclass, field +from copy import deepcopy + +import logging + +# 配置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +@dataclass +class SectionLevel(Enum): + CHAPTER = 0 + SECTION = 1 + SUBSECTION = 2 + SUBSUBSECTION = 3 + PARAGRAPH = 4 + SUBPARAGRAPH = 5 + + def __lt__(self, other): + if not isinstance(other, SectionLevel): + return NotImplemented + return self.value < other.value + + def __le__(self, other): + if not isinstance(other, SectionLevel): + return NotImplemented + return self.value <= other.value + + def __gt__(self, other): + if not isinstance(other, SectionLevel): + return NotImplemented + return self.value > other.value + + def __ge__(self, other): + if not isinstance(other, SectionLevel): + return NotImplemented + return self.value >= other.value + +@dataclass +class Section: + level: SectionLevel + title: str + content: str = '' + bibliography: str = '' + subsections: List['Section'] = field(default_factory=list) + def merge(self, other: 'Section') -> 'Section': + """Merge this section with another section.""" + if self.title != other.title or self.level != other.level: + raise ValueError("Can only merge sections with same title and level") + + merged = deepcopy(self) + merged.content = self._merge_content(self.content, other.content) + + # Create subsections lookup for efficient merging + subsections_map = {s.title: s for s in merged.subsections} + + for other_subsection in other.subsections: + if other_subsection.title in subsections_map: + # Merge existing subsection + idx = next(i for i, s in enumerate(merged.subsections) + if s.title == other_subsection.title) + merged.subsections[idx] = merged.subsections[idx].merge(other_subsection) + else: + # Add new subsection + 
merged.subsections.append(deepcopy(other_subsection)) + + return merged + + @staticmethod + def _merge_content(content1: str, content2: str) -> str: + """Merge content strings intelligently.""" + if not content1: + return content2 + if not content2: + return content1 + # Combine non-empty contents with a separator + return f"{content1}\n\n{content2}" +@dataclass +class LatexEnvironment: + """表示LaTeX环境的数据类""" + name: str + start: int + end: int + content: str + raw: str + + +class EnhancedSectionExtractor: + """Enhanced section extractor with comprehensive content handling and hierarchy management.""" + + def __init__(self, preserve_environments: bool = True): + """ + 初始化Section提取器 + + Args: + preserve_environments: 是否保留特定环境(如equation, figure等)的原始LaTeX代码 + """ + self.preserve_environments = preserve_environments + + # Section级别定义 + self.section_levels = { + 'chapter': SectionLevel.CHAPTER, + 'section': SectionLevel.SECTION, + 'subsection': SectionLevel.SUBSECTION, + 'subsubsection': SectionLevel.SUBSUBSECTION, + 'paragraph': SectionLevel.PARAGRAPH, + 'subparagraph': SectionLevel.SUBPARAGRAPH + } + + # 需要保留的环境类型 + self.important_environments = { + 'equation', 'equation*', 'align', 'align*', + 'figure', 'table', 'algorithm', 'algorithmic', + 'definition', 'theorem', 'lemma', 'proof', + 'itemize', 'enumerate', 'description' + } + + # 改进的section pattern + self.section_pattern = ( + r'\\(?Pchapter|section|subsection|subsubsection|paragraph|subparagraph)' + r'\*?' # Optional star + r'(?:\[(?P.*?)\])?' # Optional short title + r'{(?P(?:[^{}]|\{[^{}]*\})*?)}' # Main title with nested braces support + ) + + # 环境匹配模式 + self.environment_pattern = ( + r'\\begin{(?P<env_name>[^}]+)}' + r'(?P<env_content>.*?)' + r'\\end{(?P=env_name)}' + ) + + def _find_environments(self, content: str) -> List[LatexEnvironment]: + """ + 查找文档中的所有LaTeX环境。 + 支持嵌套环境的处理。 + """ + environments = [] + stack = [] + + # 使用正则表达式查找所有begin和end标记 + begin_pattern = r'\\begin{([^}]+)}' + end_pattern = r'\\end{([^}]+)}' + + # 组合模式来同时匹配begin和end + tokens = [] + for match in re.finditer(fr'({begin_pattern})|({end_pattern})', content): + if match.group(1): # begin标记 + tokens.append(('begin', match.group(1), match.start())) + else: # end标记 + tokens.append(('end', match.group(2), match.start())) + + # 处理环境嵌套 + for token_type, env_name, pos in tokens: + if token_type == 'begin': + stack.append((env_name, pos)) + elif token_type == 'end' and stack: + if stack[-1][0] == env_name: + start_env_name, start_pos = stack.pop() + env_content = content[start_pos:pos] + raw_content = content[start_pos:pos + len('\\end{' + env_name + '}')] + + if start_env_name in self.important_environments: + environments.append(LatexEnvironment( + name=start_env_name, + start=start_pos, + end=pos + len('\\end{' + env_name + '}'), + content=env_content, + raw=raw_content + )) + + return sorted(environments, key=lambda x: x.start) + + def _protect_environments(self, content: str) -> Tuple[str, Dict[str, str]]: + """ + 保护重要的LaTeX环境,用占位符替换它们。 + 返回处理后的内容和恢复映射。 + """ + environments = self._find_environments(content) + replacements = {} + + # 从后向前替换,避免位置改变的问题 + for env in reversed(environments): + if env.name in self.important_environments: + placeholder = f'__ENV_{len(replacements)}__' + replacements[placeholder] = env.raw + content = content[:env.start] + placeholder + content[env.end:] + + return content, replacements + + def _restore_environments(self, content: str, replacements: Dict[str, str]) -> str: + """ + 恢复之前保护的环境。 + """ + for placeholder, original in 
replacements.items(): + content = content.replace(placeholder, original) + return content + + def extract(self, content: str) -> List[Section]: + """ + 从LaTeX文档中提取sections及其内容。 + + Args: + content: LaTeX文档内容 + + Returns: + List[Section]: 提取的section列表,包含层次结构 + """ + try: + # 预处理:保护重要环境 + if self.preserve_environments: + content, env_replacements = self._protect_environments(content) + + # 查找所有sections + sections = self._find_all_sections(content) + if not sections: + return [] + + # 处理sections + root_sections = self._process_sections(content, sections) + + # 如果需要,恢复环境 + if self.preserve_environments: + for section in self._traverse_sections(root_sections): + section.content = self._restore_environments(section.content, env_replacements) + + return root_sections + + except Exception as e: + logger.error(f"Error extracting sections: {str(e)}") + raise + + def _find_all_sections(self, content: str) -> List[dict]: + """查找所有section命令及其位置。""" + sections = [] + + for match in re.finditer(self.section_pattern, content, re.DOTALL | re.MULTILINE): + section_type = match.group('type').lower() + if section_type not in self.section_levels: + continue + + section = { + 'type': section_type, + 'level': self.section_levels[section_type], + 'title': self._clean_title(match.group('title')), + 'start': match.start(), + 'command_end': match.end(), + } + sections.append(section) + + return sorted(sections, key=lambda x: x['start']) + + def _process_sections(self, content: str, sections: List[dict]) -> List[Section]: + """处理sections以构建层次结构和提取内容。""" + # 计算content范围 + self._calculate_content_ranges(content, sections) + + # 构建层次结构 + root_sections = [] + section_stack = [] + + for section_info in sections: + new_section = Section( + level=section_info['level'], + title=section_info['title'], + content=self._extract_clean_content(content, section_info), + subsections=[] + ) + + # 调整堆栈以找到正确的父section + while section_stack and section_stack[-1].level.value >= new_section.level.value: + section_stack.pop() + + if section_stack: + section_stack[-1].subsections.append(new_section) + else: + root_sections.append(new_section) + + section_stack.append(new_section) + + return root_sections + + def _calculate_content_ranges(self, content: str, sections: List[dict]): + for i, current in enumerate(sections): + content_start = current['command_end'] + + # 找到下一个section(无论什么级别) + content_end = len(content) + for next_section in sections[i + 1:]: + content_end = next_section['start'] + break + + current['content_range'] = (content_start, content_end) + + def _calculate_content_ranges_with_subsection_content(self, content: str, sections: List[dict]): + """为每个section计算内容范围。""" + for i, current in enumerate(sections): + content_start = current['command_end'] + + # 找到下一个同级或更高级的section + content_end = len(content) + for next_section in sections[i + 1:]: + if next_section['level'] <= current['level']: + content_end = next_section['start'] + break + + current['content_range'] = (content_start, content_end) + + def _extract_clean_content(self, content: str, section_info: dict) -> str: + """提取并清理section内容。""" + start, end = section_info['content_range'] + raw_content = content[start:end] + + # 清理内容 + clean_content = self._clean_content(raw_content) + return clean_content + + def _clean_content(self, content: str) -> str: + """清理LaTeX内容同时保留重要信息。""" + # 移除注释 + content = re.sub(r'(?<!\\)%.*?\n', '\n', content) + + # LaTeX命令处理规则 + replacements = [ + # 保留引用 + (r'\\cite(?:\[.*?\])?{(.*?)}', r'[cite:\1]'), + # 保留脚注 + (r'\\footnote{(.*?)}', 
r'[footnote:\1]'), + # 处理引用 + (r'\\ref{(.*?)}', r'[ref:\1]'), + # 保留URL + (r'\\url{(.*?)}', r'[url:\1]'), + # 保留超链接 + (r'\\href{(.*?)}{(.*?)}', r'[\2](\1)'), + # 处理文本格式命令 + (r'\\(?:textbf|textit|emph){(.*?)}', r'\1'), + # 保留特殊字符 + (r'\\([&%$#_{}])', r'\1'), + ] + + # 应用所有替换规则 + for pattern, replacement in replacements: + content = re.sub(pattern, replacement, content, flags=re.DOTALL) + + # 清理多余的空白 + content = re.sub(r'\n\s*\n', '\n\n', content) + return content.strip() + + def _clean_title(self, title: str) -> str: + """清理section标题。""" + # 处理嵌套的花括号 + while '{' in title: + title = re.sub(r'{([^{}]*)}', r'\1', title) + + # 处理LaTeX命令 + title = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.*?)}', r'\1', title) + title = re.sub(r'\\([&%$#_{}])', r'\1', title) + + return title.strip() + + def _traverse_sections(self, sections: List[Section]) -> List[Section]: + """遍历所有sections(包括子sections)。""" + result = [] + for section in sections: + result.append(section) + result.extend(self._traverse_sections(section.subsections)) + return result + + +def test_enhanced_extractor(): + """使用复杂的测试用例测试提取器。""" + test_content = r""" +\section{Complex Examples} +Here's a complex section with various environments. + +\begin{equation} +E = mc^2 +\end{equation} + +\subsection{Nested Environments} +This subsection has nested environments. + +\begin{figure} +\begin{equation*} +f(x) = \int_0^x g(t) dt +\end{equation*} +\caption{A nested equation in a figure} +\end{figure} + + """ + + extractor = EnhancedSectionExtractor() + sections = extractor.extract(test_content) + + def print_section(section, level=0): + print("\n" + " " * level + f"[{section.level.name}] {section.title}") + if section.content: + content_preview = section.content[:150] + "..." if len(section.content) > 150 else section.content + print(" " * (level + 1) + f"Content: {content_preview}") + for subsection in section.subsections: + print_section(subsection, level + 1) + + print("\nExtracted Section Structure:") + for section in sections: + print_section(section) + + +if __name__ == "__main__": + test_enhanced_extractor() \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/section_fragment.py b/crazy_functions/rag_fns/arxiv_fns/section_fragment.py new file mode 100644 index 00000000..21f79e6a --- /dev/null +++ b/crazy_functions/rag_fns/arxiv_fns/section_fragment.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass + +@dataclass +class SectionFragment: + """Arxiv论文片段数据类""" + title: str # 文件路径 + abstract: str # 论文摘要 + section_tree: str # 文章各章节的目录结构 + arxiv_id: str = "" # 添加 arxiv_id 属性 + current_section: str = "Introduction" # 当前片段所属的section或者subsection或者孙subsubsection名字 + content: str = '' #当前片段的内容 + bibliography: str = '' #当前片段的参考文献 + + + + + diff --git a/crazy_functions/rag_fns/arxiv_fns/tex_utils.py b/crazy_functions/rag_fns/arxiv_fns/tex_utils.py new file mode 100644 index 00000000..1fba7953 --- /dev/null +++ b/crazy_functions/rag_fns/arxiv_fns/tex_utils.py @@ -0,0 +1,271 @@ +import re +import os +import logging +from pathlib import Path +from typing import List, Tuple, Dict, Set, Optional, Callable +from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns + +class TexUtils: + """TeX文档处理器类""" + + def __init__(self, ): + """ + 初始化TeX处理器 + + Args: + char_range: 字符数范围(最小值, 最大值) + """ + self.logger = logging.getLogger(__name__) + + # 初始化LaTeX环境和命令模式 + self._init_patterns() + self.latex_only_patterns = LaTeXPatterns.latex_only_patterns + + + + + def _init_patterns(self): + """初始化LaTeX模式匹配规则""" + # 特殊环境模式 + self.special_envs 
= LaTeXPatterns.special_envs + # 章节模式 + self.section_patterns = LaTeXPatterns.section_patterns + # 包含模式 + self.include_patterns = LaTeXPatterns.include_patterns + # 元数据模式 + self.metadata_patterns = LaTeXPatterns.metadata_patterns + + def read_file(self, file_path: str) -> Optional[str]: + """ + 读取TeX文件内容,支持多种编码 + + Args: + file_path: 文件路径 + + Returns: + Optional[str]: 文件内容或None + """ + encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii'] + for encoding in encodings: + try: + with open(file_path, 'r', encoding=encoding) as f: + return f.read() + except UnicodeDecodeError: + continue + + self.logger.warning(f"Failed to read {file_path} with all encodings") + return None + + def find_main_tex_file(self, directory: str) -> Optional[str]: + """ + 查找主TeX文件 + + Args: + directory: 目录路径 + + Returns: + Optional[str]: 主文件路径或None + """ + tex_files = list(Path(directory).rglob("*.tex")) + if not tex_files: + return None + + # 按优先级查找 + for tex_file in tex_files: + content = self.read_file(str(tex_file)) + if content: + if r'\documentclass' in content: + return str(tex_file) + if tex_file.name.lower() == 'main.tex': + return str(tex_file) + + # 返回最大的tex文件 + return str(max(tex_files, key=lambda x: x.stat().st_size)) + + def resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]: + """ + 解析TeX文件中的include引用 + + Args: + tex_file: TeX文件路径 + processed: 已处理的文件集合 + + Returns: + List[str]: 相关文件路径列表 + """ + if processed is None: + processed = set() + + if tex_file in processed: + return [] + + processed.add(tex_file) + result = [tex_file] + content = self.read_file(tex_file) + + if not content: + return result + + base_dir = Path(tex_file).parent + for pattern in self.include_patterns: + for match in re.finditer(pattern, content): + included_file = match.group(2) + if not included_file.endswith('.tex'): + included_file += '.tex' + + full_path = str(base_dir / included_file) + if os.path.exists(full_path) and full_path not in processed: + result.extend(self.resolve_includes(full_path, processed)) + + return result + + def resolve_references(self, tex_file: str, path_dir: str = None) -> str: + """ + 解析TeX文件中的参考文献引用,返回所有引用文献的内容,只保留title、author和journal字段。 + 如果在tex_file目录下没找到bib文件,会在path_dir中查找。 + + Args: + tex_file: TeX文件路径 + path_dir: 额外的参考文献搜索路径 + + Returns: + str: 所有参考文献内容的字符串,只包含特定字段,不同参考文献之间用空行分隔 + """ + all_references = [] # 存储所有参考文献内容 + content = self.read_file(tex_file) + + if not content: + return "" + + # 扩展参考文献引用的模式 + bib_patterns = [ + r'\\bibliography\{([^}]+)\}', + r'\\addbibresource\{([^}]+)\}', + r'\\bibliographyfile\{([^}]+)\}', + r'\\begin\{thebibliography\}', + r'\\bibinput\{([^}]+)\}', + r'\\newrefsection\{([^}]+)\}' + ] + + base_dir = Path(tex_file).parent + found_in_tex_dir = False + + # 首先在tex文件目录下查找显式引用的bib文件 + for pattern in bib_patterns: + for match in re.finditer(pattern, content): + if not match.groups(): + continue + + bib_files = match.group(1).split(',') + for bib_file in bib_files: + bib_file = bib_file.strip() + if not bib_file.endswith('.bib'): + bib_file += '.bib' + + full_path = str(base_dir / bib_file) + if os.path.exists(full_path): + found_in_tex_dir = True + bib_content = self.read_file(full_path) + if bib_content: + processed_refs = self._process_bib_content(bib_content) + all_references.extend(processed_refs) + + # 如果在tex文件目录下没找到bib文件,且提供了额外搜索路径 + if not found_in_tex_dir and path_dir: + search_dir = Path(path_dir) + try: + for bib_path in search_dir.glob('**/*.bib'): + bib_content = self.read_file(str(bib_path)) + if bib_content: + processed_refs = 
self._process_bib_content(bib_content) + all_references.extend(processed_refs) + except Exception as e: + print(f"Error searching in path_dir: {e}") + + # 合并所有参考文献内容,用空行分隔 + return "\n\n".join(all_references) + + def _process_bib_content(self, content: str) -> List[str]: + """ + 处理bib文件内容,提取每个参考文献的特定字段 + + Args: + content: bib文件内容 + + Returns: + List[str]: 处理后的参考文献列表 + """ + processed_refs = [] + # 匹配完整的参考文献条目 + ref_pattern = r'@\w+\{[^@]*\}' + # 匹配参考文献类型和键值 + entry_start_pattern = r'@(\w+)\{([^,]*?),' + # 匹配字段 + field_pattern = r'(\w+)\s*=\s*\{([^}]*)\}' + + # 查找所有参考文献条目 + for ref_match in re.finditer(ref_pattern, content, re.DOTALL): + ref_content = ref_match.group(0) + + # 获取参考文献类型和键值 + entry_match = re.match(entry_start_pattern, ref_content) + if not entry_match: + continue + + entry_type, cite_key = entry_match.groups() + + # 提取需要的字段 + needed_fields = {'title': None, 'author': None, 'journal': None} + for field_match in re.finditer(field_pattern, ref_content): + field_name, field_value = field_match.groups() + field_name = field_name.lower() + if field_name in needed_fields: + needed_fields[field_name] = field_value.strip() + + # 构建新的参考文献条目 + if any(needed_fields.values()): # 如果至少有一个需要的字段 + ref_lines = [f"@{entry_type}{{{cite_key},"] + for field_name, field_value in needed_fields.items(): + if field_value: + ref_lines.append(f" {field_name}={{{field_value}}},") + ref_lines[-1] = ref_lines[-1][:-1] # 移除最后一个逗号 + ref_lines.append("}") + + processed_refs.append("\n".join(ref_lines)) + + return processed_refs + def _extract_inline_references(self, content: str) -> str: + """ + 从tex文件内容中提取直接写在文件中的参考文献 + + Args: + content: tex文件内容 + + Returns: + str: 提取的参考文献内容,如果没有找到则返回空字符串 + """ + # 查找参考文献环境 + bib_start = r'\\begin\{thebibliography\}' + bib_end = r'\\end\{thebibliography\}' + + start_match = re.search(bib_start, content) + end_match = re.search(bib_end, content) + + if start_match and end_match: + return content[start_match.start():end_match.end()] + + return "" + def _preprocess_content(self, content: str) -> str: + """预处理TeX内容""" + # 移除注释 + content = re.sub(r'(?m)%.*$', '', content) + # 规范化空白字符 + # content = re.sub(r'\s+', ' ', content) + content = re.sub(r'\n\s*\n', '\n\n', content) + return content.strip() + + + + + +
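
As a quick sanity check of the citation regexes introduced in arxiv_splitter.py, here is a minimal standalone sketch; the pattern subset is copied inline and the sample snippet and expected keys are illustrative assumptions, so it runs without the repository on the import path (the real class compiles the full CITATION_PATTERNS list and routes keys through _extract_keys_from_group, which additionally drops keys still containing LaTeX punctuation).

import re

# Subset of CITATION_PATTERNS from arxiv_splitter.py, copied here so the demo is
# self-contained (assumption: the full list also covers the other natbib/biblatex variants).
patterns = [
    r'\\cite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
    r'\\citep(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
    r'\\citet(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
    r'\[cite:([^\]]+)\]',
]

# Hypothetical sample text used only for this demo.
sample = r"Prior work \citep{smith2020, lee2021} extends \citet{kim2019}; see also [cite:zhao2022]."

keys = set()
for pattern in patterns:
    for match in re.finditer(pattern, sample):
        # Each capture group may hold several comma/semicolon-separated keys.
        keys.update(k.strip() for k in re.split(r'[,;]', match.group(1)) if k.strip())

print(sorted(keys))  # ['kim2019', 'lee2021', 'smith2020', 'zhao2022']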
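
Similarly, a small sketch of the field filtering performed by TexUtils._process_bib_content, which reduces each BibTeX entry to its title, author and journal fields; the entry below is a made-up example and the regexes mirror the ones used in the patch.

import re

# Hypothetical BibTeX entry used only for this demo.
bib_entry = """@article{smith2020,
  title={A Study of Things},
  author={Smith, Jane and Doe, John},
  journal={Journal of Examples},
  year={2020},
  pages={1--10}
}"""

# Same entry-header and field patterns as in _process_bib_content.
entry_type, cite_key = re.match(r'@(\w+)\{([^,]*?),', bib_entry).groups()

needed = {'title': None, 'author': None, 'journal': None}
for name, value in re.findall(r'(\w+)\s*=\s*\{([^}]*)\}', bib_entry):
    if name.lower() in needed:
        needed[name.lower()] = value.strip()

# Rebuild a trimmed entry containing only the kept fields.
lines = [f"@{entry_type}{{{cite_key},"]
lines += [f"  {name}={{{value}}}," for name, value in needed.items() if value]
lines[-1] = lines[-1][:-1]  # drop the trailing comma on the last kept field
lines.append("}")
print("\n".join(lines))  # reduced entry with only title/author/journal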