diff --git a/crazy_functions/rag_essay_fns/arxiv_splitter.py b/crazy_functions/rag_essay_fns/arxiv_splitter.py
new file mode 100644
index 00000000..e978b25d
--- /dev/null
+++ b/crazy_functions/rag_essay_fns/arxiv_splitter.py
@@ -0,0 +1,534 @@
+import os
+import re
+import requests
+import tarfile
+import logging
+from dataclasses import dataclass
+from typing import Generator, List, Tuple, Optional, Dict, Set
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+@dataclass
+class ArxivFragment:
+    """Data class for a fragment of an arXiv paper."""
+    file_path: str
+    content: str
+    segment_index: int
+    total_segments: int
+    rel_path: str
+    segment_type: str
+    title: str
+    abstract: str
+    section: str  # Full section path, e.g. "Introduction" or "Methods-Data Processing"
+    section_type: str  # Fragment type, e.g. "abstract", "section", "subsection"
+    section_level: int  # Section depth: 0 for the abstract, 1 for main sections, 2 for subsections, etc.
+    is_appendix: bool
+
+
+class SmartArxivSplitter:
+    def __init__(self,
+                 char_range: Tuple[int, int],
+                 root_dir: str = "gpt_log/arxiv_cache",
+                 proxies: Optional[Dict[str, str]] = None,
+                 max_workers: int = 4):
+
+        self.min_chars, self.max_chars = char_range
+        self.root_dir = Path(root_dir)
+        self.root_dir.mkdir(parents=True, exist_ok=True)
+        self.proxies = proxies or {}
+        self.max_workers = max_workers
+
+        # Patterns for special LaTeX environments and commands
+        self._init_patterns()
+
+        # Configure logging
+        logging.basicConfig(level=logging.INFO,
+                            format='%(asctime)s - %(levelname)s - %(message)s')
+
+    def _init_patterns(self):
+        """Initialize LaTeX environment and command patterns."""
+        self.special_envs = {
+            'math': [r'\\begin{(equation|align|gather|eqnarray)\*?}.*?\\end{\1\*?}',
+                     r'\$\$.*?\$\$', r'\$[^$]+\$'],
+            'table': [r'\\begin{(table|tabular)\*?}.*?\\end{\1\*?}'],
+            'figure': [r'\\begin{figure\*?}.*?\\end{figure\*?}'],
+            'algorithm': [r'\\begin{(algorithm|algorithmic)}.*?\\end{\1}']
+        }
+
+        self.section_patterns = [
+            r'\\(sub)*section\{([^}]+)\}',
+            r'\\chapter\{([^}]+)\}'
+        ]
+
+        self.include_patterns = [
+            r'\\(input|include|subfile)\{([^}]+)\}'
+        ]
+
+    def _find_main_tex_file(self, directory: str) -> Optional[str]:
+        """Locate the main TEX file."""
+        tex_files = list(Path(directory).rglob("*.tex"))
+        if not tex_files:
+            return None
+
+        # Search priority:
+        # 1. A file containing \documentclass
+        # 2. A file named main.tex
+        # 3. The largest .tex file
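+        # For example, a source bundle containing paper.tex (with \documentclass)
+        # plus chapters/*.tex resolves to paper.tex; these file names are only
+        # illustrative, not part of any real submission.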
+        main_named = None
+        for tex_file in tex_files:
+            try:
+                content = self._read_file(str(tex_file))
+                if content and r'\documentclass' in content:
+                    return str(tex_file)
+                if tex_file.name.lower() == 'main.tex':
+                    main_named = tex_file
+            except Exception:
+                continue
+
+        if main_named is not None:
+            return str(main_named)
+        return str(max(tex_files, key=lambda x: x.stat().st_size))
+
+    def _resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]:
+        """Recursively resolve \\input/\\include/\\subfile commands in a tex file."""
+        if processed is None:
+            processed = set()
+
+        if tex_file in processed:
+            return []
+
+        processed.add(tex_file)
+        result = [tex_file]
+        content = self._read_file(tex_file)
+
+        if not content:
+            return result
+
+        base_dir = Path(tex_file).parent
+        for pattern in self.include_patterns:
+            for match in re.finditer(pattern, content):
+                included_file = match.group(2)
+                if not included_file.endswith('.tex'):
+                    included_file += '.tex'
+
+                # Build the full path relative to the including file
+                full_path = str(base_dir / included_file)
+                if os.path.exists(full_path) and full_path not in processed:
+                    result.extend(self._resolve_includes(full_path, processed))
+
+        return result
+
+    def _split_into_sentences(self, text: str) -> List[str]:
+        """Split text into sentences."""
+        return re.split(r'(?<=[.!?。!?])\s+', text)
+
+    def _split_long_sentence(self, sentence: str) -> List[str]:
+        """Split an overly long sentence at natural break points."""
+        if len(sentence) <= self.max_chars:
+            return [sentence]
+
+        result = []
+        while sentence:
+            # Look for a suitable split point within the maximum length
+            split_pos = self._find_split_position(sentence[:self.max_chars])
+            if split_pos <= 0:
+                split_pos = self.max_chars
+
+            result.append(sentence[:split_pos])
+            sentence = sentence[split_pos:].strip()
+
+        return result
+
+    def _find_split_position(self, text: str) -> int:
+        """Find a suitable split position within a sentence."""
+        # Prefer splitting at punctuation
+        punctuation_match = re.search(r'[,,;;]\s*', text[::-1])
+        if punctuation_match:
+            return len(text) - punctuation_match.end()
+
+        # Otherwise split at whitespace
+        space_match = re.search(r'\s+', text[::-1])
+        if space_match:
+            return len(text) - space_match.end()
+
+        return -1
+
+    def _protect_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str:
+        """Replace special environments with placeholders so they are never split."""
+        for env_type, patterns in self.special_envs.items():
+            for pattern in patterns:
+                content = re.sub(
+                    pattern,
+                    lambda m: self._store_protected_block(m.group(0), protected_blocks),
+                    content,
+                    flags=re.DOTALL
+                )
+        return content
+
+    def _store_protected_block(self, content: str, protected_blocks: Dict[str, str]) -> str:
+        """Store a protected block and return its placeholder."""
+        # Delimit the index on both sides so placeholder 1 is never a prefix of placeholder 10
+        placeholder = f"__PROTECTED_{len(protected_blocks)}__"
+        protected_blocks[placeholder] = content
+        return placeholder
+
+    def _restore_special_environments(self, content: str, protected_blocks: Dict[str, str]) -> str:
+        """Restore protected environments from their placeholders."""
+        for placeholder, original in protected_blocks.items():
+            content = content.replace(placeholder, original)
+        return content
+
+    def _is_special_environment(self, text: str) -> bool:
+        """Check whether the text is a special environment."""
+        for patterns in self.special_envs.values():
+            for pattern in patterns:
+                if re.search(pattern, text, re.DOTALL):
+                    return True
+        return False
+
+    def _preprocess_content(self, content: str) -> str:
+        """Preprocess TEX content."""
+        # Strip comments
+        content = re.sub(r'(?m)%.*$', '', content)
+        # Normalize whitespace, but keep line breaks so paragraph boundaries survive
+        content = re.sub(r'[ \t]+', ' ', content)
+        content = re.sub(r'\n\s*\n', '\n\n', content)
+        # Remove commands that carry no prose
+        content = re.sub(r'\\(label|ref|cite)\{[^}]*\}', '', content)
+        return content.strip()
+
+    def process_paper(self, arxiv_id_or_url: str) -> Generator[ArxivFragment, None, None]:
+        """Process a single arXiv paper."""
+        try:
+            arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
+            paper_dir = self._download_and_extract(arxiv_id)
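+            # paper_dir points at the extracted source tree (with the default
+            # root_dir this is gpt_log/arxiv_cache/<arxiv_id>/); cached copies
+            # are reused on later calls, see _download_and_extract.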
+
+            # Locate the main tex file
+            main_tex = self._find_main_tex_file(paper_dir)
+            if not main_tex:
+                raise RuntimeError(f"No main tex file found in {paper_dir}")
+
+            # Collect all related tex files
+            tex_files = self._resolve_includes(main_tex)
+
+            # Process all tex files concurrently
+            fragments = []
+            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+                future_to_file = {
+                    executor.submit(self._process_single_tex, file_path): file_path
+                    for file_path in tex_files
+                }
+
+                for future in as_completed(future_to_file):
+                    try:
+                        fragments.extend(future.result())
+                    except Exception as e:
+                        logging.error(f"Error processing file: {e}")
+
+            # Re-number the fragments across all files
+            fragments.sort(key=lambda x: (x.rel_path, x.segment_index))
+            total_fragments = len(fragments)
+
+            for i, fragment in enumerate(fragments):
+                fragment.segment_index = i
+                fragment.total_segments = total_fragments
+                yield fragment
+
+        except Exception as e:
+            logging.error(f"Error processing paper {arxiv_id_or_url}: {e}")
+            raise RuntimeError(f"Failed to process paper: {str(e)}")
+
+    def _normalize_arxiv_id(self, input_str: str) -> str:
+        """Normalize an arXiv ID or URL to a bare ID."""
+        if input_str.startswith('https://arxiv.org/'):
+            if '/pdf/' in input_str:
+                input_str = input_str.split('/pdf/')[-1]
+                input_str = re.sub(r'\.pdf$', '', input_str)
+            else:
+                input_str = input_str.split('/abs/')[-1]
+        # Strip a trailing version suffix such as "v2"
+        return re.sub(r'v\d+$', '', input_str)
+
+    def _download_and_extract(self, arxiv_id: str) -> str:
+        """Download and extract the arXiv source archive."""
+        paper_dir = self.root_dir / arxiv_id
+        tar_path = paper_dir / f"{arxiv_id}.tar.gz"
+
+        # Reuse the cache if it exists
+        if paper_dir.exists() and any(paper_dir.iterdir()):
+            logging.info(f"Using cached version for {arxiv_id}")
+            return str(paper_dir)
+
+        paper_dir.mkdir(exist_ok=True)
+
+        urls = [
+            f"https://arxiv.org/src/{arxiv_id}",
+            f"https://arxiv.org/e-print/{arxiv_id}"
+        ]
+
+        for url in urls:
+            try:
+                logging.info(f"Downloading from {url}")
+                response = requests.get(url, proxies=self.proxies, timeout=60)
+                if response.status_code == 200:
+                    tar_path.write_bytes(response.content)
+                    # 'r:*' lets tarfile detect the compression instead of assuming gzip
+                    with tarfile.open(tar_path, 'r:*') as tar:
+                        tar.extractall(path=paper_dir)
+                    return str(paper_dir)
+            except Exception as e:
+                logging.warning(f"Download failed for {url}: {e}")
+                continue
+
+        raise RuntimeError(f"Failed to download paper {arxiv_id}")
+
+    def _read_file(self, file_path: str) -> Optional[str]:
+        """Try several encodings when reading a file."""
+        encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
+        for encoding in encodings:
+            try:
+                with open(file_path, 'r', encoding=encoding) as f:
+                    return f.read()
+            except UnicodeDecodeError:
+                continue
+        logging.warning(f"Failed to read file {file_path} with all encodings")
+        return None
+
+    def _extract_metadata(self, content: str) -> Tuple[str, str]:
+        """Extract the paper title and abstract."""
+        title = ""
+        abstract = ""
+
+        # Title
+        title_patterns = [
+            r'\\title{([^}]*)}',
+            r'\\Title{([^}]*)}'
+        ]
+        for pattern in title_patterns:
+            match = re.search(pattern, content)
+            if match:
+                title = match.group(1)
+                title = re.sub(r'\\[a-zA-Z]+{([^}]*)}', r'\1', title)
+                break
+
+        # Abstract
+        abstract_patterns = [
+            r'\\begin{abstract}(.*?)\\end{abstract}',
+            r'\\abstract{([^}]*)}'
+        ]
+        for pattern in abstract_patterns:
+            match = re.search(pattern, content, re.DOTALL)
+            if match:
+                abstract = match.group(1).strip()
+                abstract = re.sub(r'\\[a-zA-Z]+{([^}]*)}', r'\1', abstract)
+                break
+
+        return title.strip(), abstract.strip()
+
+    def _get_section_info(self, para: str, content: str) -> Optional[Tuple[str, str, int, bool]]:
+        """Return (section_path, section_type, level, is_appendix) for the paragraph."""
+        section_type = "content"
+        level = 0
+        is_appendix = False
+
+        # Regex patterns for section commands mapped to their depth
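+        # Illustrative example (hypothetical input, not from a real paper): a
+        # paragraph under \section{Methods} -> \subsection{Data Processing}
+        # yields ("Methods-Data Processing", "subsection", 2, False).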
+        section_patterns = {
+            r'\\chapter\{([^}]+)\}': 1,
+            r'\\section\{([^}]+)\}': 1,
+            r'\\subsection\{([^}]+)\}': 2,
+            r'\\subsubsection\{([^}]+)\}': 3
+        }
+
+        # Find all section markers
+        all_sections = []
+        for pattern, sec_level in section_patterns.items():
+            for match in re.finditer(pattern, content):
+                all_sections.append((match.start(), match.group(1), sec_level))
+
+        # Check whether the paragraph sits inside the abstract environment
+        abstract_env = re.search(r'\\begin{abstract}(.*?)\\end{abstract}', content, re.DOTALL)
+        if abstract_env and para in abstract_env.group(1):
+            return "Abstract", "abstract", 0, False
+
+        # Locate the \appendix marker, if any
+        appendix_pos = content.find(r'\appendix')
+
+        # Determine the enclosing section
+        para_pos = content.find(para)
+        if para_pos >= 0:
+            is_appendix = appendix_pos >= 0 and para_pos > appendix_pos
+            current_sections = []
+            current_level = 0
+
+            # Walk the section markers in document order
+            for sec_pos, sec_title, sec_level in sorted(all_sections):
+                if sec_pos > para_pos:
+                    break
+                # A marker at the same or a shallower level closes all deeper sections
+                if sec_level <= current_level:
+                    current_sections = [s for s in current_sections if s[1] < sec_level]
+                current_sections.append((sec_title, sec_level))
+                current_level = sec_level
+
+            # Build the section path
+            if current_sections:
+                current_path = [s[0] for s in sorted(current_sections, key=lambda x: x[1])]
+                section_path = "-".join(current_path)
+                level = max(s[1] for s in current_sections)
+                section_type = "section" if level == 1 else "subsection"
+                return section_path, section_type, level, is_appendix
+
+        return "Unknown Section", "content", 0, is_appendix
+
+    def _smart_split(self, content: str) -> List[Tuple[str, str, str, int, bool]]:
+        """Split TEX content into segments within the character range while keeping semantic units intact."""
+        content = self._preprocess_content(content)
+        segments = []
+        current_buffer = []
+        current_length = 0
+        current_section_info = ("Unknown Section", "content", 0, False)
+
+        # Protect special environments from being split
+        protected_blocks = {}
+        content = self._protect_special_environments(content, protected_blocks)
+
+        # Split into paragraphs
+        paragraphs = re.split(r'\n\s*\n', content)
+
+        for para in paragraphs:
+            para = para.strip()
+            if not para:
+                continue
+
+            # Determine the section while the paragraph still matches the protected content
+            section_info = self._get_section_info(para, content)
+            if section_info:
+                current_section_info = section_info
+
+            # Restore special environments
+            para = self._restore_special_environments(para, protected_blocks)
+
+            # Special environments become stand-alone segments
+            if self._is_special_environment(para):
+                # Flush the current buffer first
+                if current_buffer:
+                    segments.append((
+                        '\n'.join(current_buffer),
+                        *current_section_info
+                    ))
+                    current_buffer = []
+                    current_length = 0
+
+                segments.append((para, *current_section_info))
+                continue
+
+            # Ordinary paragraphs are packed sentence by sentence
+            sentences = self._split_into_sentences(para)
+            for sentence in sentences:
+                sentence = sentence.strip()
+                if not sentence:
+                    continue
+
+                sent_length = len(sentence)
+                new_length = current_length + sent_length + (1 if current_buffer else 0)
+
+                if new_length <= self.max_chars:
+                    current_buffer.append(sentence)
+                    current_length = new_length
+                else:
+                    # Flush if the buffer already meets the minimum length
+                    if current_length >= self.min_chars:
+                        segments.append((
+                            '\n'.join(current_buffer),
+                            *current_section_info
+                        ))
+                        current_buffer = [sentence]
+                        current_length = sent_length
+                    else:
+                        # Otherwise split the overlong sentence and keep packing
+                        split_sentences = self._split_long_sentence(sentence)
+                        for split_sent in split_sentences:
+                            if current_length + len(split_sent) <= self.max_chars:
+                                current_buffer.append(split_sent)
+                                current_length += len(split_sent) + 1
+                            else:
+                                if current_buffer:
+                                    segments.append((
+                                        '\n'.join(current_buffer),
+                                        *current_section_info
+                                    ))
+                                current_buffer = [split_sent]
+                                current_length = len(split_sent)
+
+        # Flush whatever is left in the buffer
+        if current_buffer:
+            segments.append((
+                '\n'.join(current_buffer),
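+                # Tuple layout: (content, section_path, section_type, level, is_appendix)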
+                *current_section_info
+            ))
+
+        return segments
+
+    def _process_single_tex(self, file_path: str) -> List[ArxivFragment]:
+        """Process a single TEX file."""
+        try:
+            content = self._read_file(file_path)
+            if not content:
+                return []
+
+            # Extract metadata (only the main file carries \documentclass)
+            is_main = r'\documentclass' in content
+            title = ""
+            abstract = ""
+            if is_main:
+                title, abstract = self._extract_metadata(content)
+
+            # Split the content into segments
+            segments = self._smart_split(content)
+            fragments = []
+
+            for i, (segment_content, section_path, section_type, level, is_appendix) in enumerate(segments):
+                if segment_content.strip():
+                    segment_type = 'text'
+                    for env_type, patterns in self.special_envs.items():
+                        if any(re.search(pattern, segment_content, re.DOTALL)
+                               for pattern in patterns):
+                            segment_type = env_type
+                            break
+
+                    fragments.append(ArxivFragment(
+                        file_path=file_path,
+                        content=segment_content,
+                        segment_index=i,
+                        total_segments=len(segments),
+                        rel_path=os.path.relpath(file_path, str(self.root_dir)),
+                        segment_type=segment_type,
+                        title=title,
+                        abstract=abstract,
+                        section=section_path,
+                        section_type=section_type,
+                        section_level=level,
+                        is_appendix=is_appendix
+                    ))
+
+            return fragments
+
+        except Exception as e:
+            logging.error(f"Error processing file {file_path}: {e}")
+            return []
+
+
+def main():
+    """Usage example."""
+    # Create a splitter instance
+    splitter = SmartArxivSplitter(
+        char_range=(1000, 1200),
+        root_dir="gpt_log/arxiv_cache"
+    )
+
+    # Process a paper and print each fragment
+    for fragment in splitter.process_paper("2411.03663"):
+        print(f"Segment {fragment.segment_index + 1}/{fragment.total_segments}")
+        print(f"Length: {len(fragment.content)}")
+        print(f"Section: {fragment.section}")
+        print(fragment.content)
+        print("-" * 80)
+
+
+if __name__ == "__main__":
+    main()
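+
+# A minimal sketch of consuming the splitter output, grouping fragments by
+# section; the names below are illustrative and not part of this module:
+#
+#     from collections import defaultdict
+#     by_section = defaultdict(list)
+#     splitter = SmartArxivSplitter(char_range=(1000, 1200))
+#     for frag in splitter.process_paper("2411.03663"):
+#         by_section[frag.section].append(frag.content)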