From b3aef6b3936d4c5f6312f2bd8b064f85e77be247 Mon Sep 17 00:00:00 2001
From: lbykkkk
Date: Sun, 1 Dec 2024 17:35:57 +0800
Subject: [PATCH] up

---
 .../rag_fns/arxiv_fns/arxiv_downloader.py     |  32 +--
 .../rag_fns/arxiv_fns/arxiv_splitter.py       |  41 ++--
 .../rag_fns/arxiv_fns/author_extractor.py     | 177 +++++++++++++++
 .../rag_fns/arxiv_fns/essay_structure.py      |  24 +-
 .../rag_fns/arxiv_fns/latex_cleaner.py        |  10 +-
 .../rag_fns/arxiv_fns/latex_patterns.py       | 213 +++++++++---------
 .../rag_fns/arxiv_fns/section_extractor.py    |  18 +-
 .../rag_fns/arxiv_fns/section_fragment.py     |  15 +-
 .../rag_fns/arxiv_fns/tex_utils.py            |  19 +-
 crazy_functions/rag_fns/llama_index_worker.py |  19 +-
 crazy_functions/rag_fns/milvus_worker.py      |  26 +--
 crazy_functions/rag_fns/rag_file_support.py   |   8 +-
 crazy_functions/rag_fns/vector_store_index.py |  30 ++-
 13 files changed, 398 insertions(+), 234 deletions(-)
 create mode 100644 crazy_functions/rag_fns/arxiv_fns/author_extractor.py

diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_downloader.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_downloader.py
index 6d0a8646..c7ccfd49 100644
--- a/crazy_functions/rag_fns/arxiv_fns/arxiv_downloader.py
+++ b/crazy_functions/rag_fns/arxiv_fns/arxiv_downloader.py
@@ -1,12 +1,14 @@
 import logging
-import requests
 import tarfile
 from pathlib import Path
 from typing import Optional, Dict
+
+import requests
+
+
 class ArxivDownloader:
     """用于下载arXiv论文源码的下载器"""
-
+
     def __init__(self, root_dir: str = "./papers", proxies: Optional[Dict[str, str]] = None):
         """
         初始化下载器
@@ -18,13 +20,13 @@ class ArxivDownloader:
         self.root_dir = Path(root_dir)
         self.root_dir.mkdir(exist_ok=True)
         self.proxies = proxies
-
+
         # 配置日志
         logging.basicConfig(
             level=logging.INFO,
             format='%(asctime)s - %(levelname)s - %(message)s'
         )
-
+
     def _download_and_extract(self, arxiv_id: str) -> str:
         """
         下载并解压arxiv论文源码
@@ -40,19 +42,19 @@ class ArxivDownloader:
         """
         paper_dir = self.root_dir / arxiv_id
         tar_path = paper_dir / f"{arxiv_id}.tar.gz"
-
+
         # 检查缓存
         if paper_dir.exists() and any(paper_dir.iterdir()):
             logging.info(f"Using cached version for {arxiv_id}")
             return str(paper_dir)
-
+
         paper_dir.mkdir(exist_ok=True)
-
+
         urls = [
             f"https://arxiv.org/src/{arxiv_id}",
             f"https://arxiv.org/e-print/{arxiv_id}"
         ]
-
+
         for url in urls:
             try:
                 logging.info(f"Downloading from {url}")
@@ -65,9 +67,9 @@ class ArxivDownloader:
             except Exception as e:
                 logging.warning(f"Download failed for {url}: {e}")
                 continue
-
+
         raise RuntimeError(f"Failed to download paper {arxiv_id}")
-
+
     def download_paper(self, arxiv_id: str) -> str:
         """
         下载指定的arXiv论文
@@ -80,6 +82,7 @@ class ArxivDownloader:
         """
         return self._download_and_extract(arxiv_id)
 
+
 def main():
     """测试下载功能"""
     # 配置代理(如果需要)
@@ -87,16 +90,16 @@ def main():
         "http": "http://your-proxy:port",
         "https": "https://your-proxy:port"
     }
-
+
     # 创建下载器实例(如果不需要代理,可以不传入proxies参数)
     downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)
-
+
     # 测试下载一篇论文(这里使用一个示例ID)
     try:
         paper_id = "2103.00020"  # 这是一个示例ID
         paper_dir = downloader.download_paper(paper_id)
         print(f"Successfully downloaded paper to: {paper_dir}")
-
+
         # 检查下载的文件
         paper_path = Path(paper_dir)
         if paper_path.exists():
@@ -107,5 +110,6 @@ def main():
     except Exception as e:
         print(f"Error downloading paper: {e}")
 
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
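The downloader above boils down to a cache check, two mirror URLs, and a tar extraction. A condensed, self-contained sketch of that flow (the timeout value and the trusted-archive assumption on `extractall` are mine, not from the patch):

```python
import tarfile
from pathlib import Path

import requests


def fetch_arxiv_source(arxiv_id: str, root_dir: str = "./papers") -> str:
    """Download and unpack an arXiv source tarball, reusing any cached copy."""
    paper_dir = Path(root_dir) / arxiv_id
    if paper_dir.exists() and any(paper_dir.iterdir()):
        return str(paper_dir)  # cache hit: reuse the extracted source tree
    paper_dir.mkdir(parents=True, exist_ok=True)
    tar_path = paper_dir / f"{arxiv_id}.tar.gz"
    for url in (f"https://arxiv.org/src/{arxiv_id}",
                f"https://arxiv.org/e-print/{arxiv_id}"):
        try:
            resp = requests.get(url, timeout=30)  # timeout is an assumption
            resp.raise_for_status()
            tar_path.write_bytes(resp.content)
            with tarfile.open(tar_path) as tar:
                tar.extractall(path=paper_dir)  # assumes a trusted archive
            return str(paper_dir)
        except Exception:
            continue  # fall through to the next mirror URL
    raise RuntimeError(f"Failed to download paper {arxiv_id}")
```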
diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py
index 049fa4a4..c78f5ccc 100644
--- a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py
+++ b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py
@@ -1,21 +1,19 @@
-import os
-import re
-import time
-import aiohttp
 import asyncio
-import requests
-import tarfile
 import logging
-from pathlib import Path
+import re
+import tarfile
+import time
 from copy import deepcopy
+from pathlib import Path
+from typing import List, Optional, Dict, Set
-from typing import Generator, List, Tuple, Optional, Dict, Set
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
-from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
+import aiohttp
+
+from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
 from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
 from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
-from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
+from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
+from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
 
 
 def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
@@ -31,7 +29,6 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
     """
     from datetime import datetime
     from pathlib import Path
-    import re
 
     # Create output directory
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -103,9 +100,6 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
 
             # Content
             f.write("\n**Content:**\n")
-            # f.write("```tex\n")
-            # f.write(fragment.content)
-            # f.write("\n```\n")
             f.write("\n")
             f.write(fragment.content)
             f.write("\n")
@@ -140,6 +134,7 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
     print(f"Fragments saved to: {file_path}")
     return file_path
 
+
 # 定义各种引用命令的模式
 CITATION_PATTERNS = [
     # 基本的 \cite{} 格式
@@ -199,8 +194,6 @@ class ArxivSplitter:
         # 配置日志
         self._setup_logging()
 
-
-
     def _setup_logging(self):
         """配置日志"""
         logging.basicConfig(
@@ -221,7 +214,6 @@ class ArxivSplitter:
             return arxiv_id.split('v')[0].strip()
         return input_str.split('v')[0].strip()
 
-
    def _check_cache(self, paper_dir: Path) -> bool:
        """
        检查缓存是否有效,包括文件完整性检查
@@ -545,6 +537,7 @@ class ArxivSplitter:
            self.logger.error(f"Error finding citation contexts: {str(e)}")
            return contexts
 
+
    async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
        """
        Process ArXiv paper and convert to list of SectionFragments.
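For reference, this is how a pattern list like `CITATION_PATTERNS` is typically consumed: scan the TeX body and collect the cited bibliography keys. The three regexes below are only the basic natbib variants (the list in the diff is longer), and the demo keys are invented:

```python
import re

# Subset of the patterns defined above; \cite, \citep and \citet are the common cases.
CITATION_PATTERNS = [
    r'\\cite\{([^}]+)\}',
    r'\\citep\{([^}]+)\}',
    r'\\citet\{([^}]+)\}',
]


def extract_citation_keys(tex: str) -> set:
    keys = set()
    for pattern in CITATION_PATTERNS:
        for match in re.finditer(pattern, tex):
            # One command may cite several keys: \cite{a, b}
            keys.update(key.strip() for key in match.group(1).split(','))
    return keys


print(extract_citation_keys(r"See \citet{smith2020} and \cite{jones2021, doe2022}."))
# {'smith2020', 'jones2021', 'doe2022'}
```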
@@ -573,8 +566,6 @@ class ArxivSplitter:
 
         # 读取主 TeX 文件内容
         main_tex_content = read_tex_file(main_tex)
-
-
         # Get all related TeX files and references
         tex_files = self.tex_processor.resolve_includes(main_tex)
         ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
@@ -742,7 +733,6 @@ class ArxivSplitter:
 
     return content.strip()
 
-
 async def test_arxiv_splitter():
     """测试ArXiv分割器的功能"""
 
@@ -765,14 +755,13 @@ async def test_arxiv_splitter():
         root_dir="test_cache"
     )
 
-
     for case in test_cases:
         print(f"\nTesting paper: {case['arxiv_id']}")
         try:
             fragments = await splitter.process(case['arxiv_id'])
 
             # 保存fragments
-            output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
+            output_dir = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
             print(f"Output saved to: {output_dir}")
             # # 内容检查
             # for fragment in fragments:
@@ -780,7 +769,7 @@ async def test_arxiv_splitter():
             #     # print((fragment.content))
             #     print(len(fragment.content))
 
-            # 类型检查
+            # 类型检查
 
 
         except Exception as e:
@@ -789,4 +778,4 @@ async def test_arxiv_splitter():
 
 if __name__ == "__main__":
-    asyncio.run(test_arxiv_splitter())
\ No newline at end of file
+    asyncio.run(test_arxiv_splitter())
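A minimal driver for the splitter, mirroring `test_arxiv_splitter()` above. The constructor argument is the one the test uses, and the paper ID is the sample ID from the downloader's `main()`; run it from the repository root so the `crazy_functions` package is importable:

```python
import asyncio

from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter


async def main():
    splitter = ArxivSplitter(root_dir="test_cache")
    # process() downloads (or reuses) the source and returns SectionFragments.
    fragments = await splitter.process("2103.00020")
    for fragment in fragments:
        print(fragment.current_section, len(fragment.content))


if __name__ == "__main__":
    asyncio.run(main())
```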
diff --git a/crazy_functions/rag_fns/arxiv_fns/author_extractor.py b/crazy_functions/rag_fns/arxiv_fns/author_extractor.py
new file mode 100644
index 00000000..05b958f5
--- /dev/null
+++ b/crazy_functions/rag_fns/arxiv_fns/author_extractor.py
@@ -0,0 +1,177 @@
+import re
+from typing import Optional
+
+
+class LatexAuthorExtractor:
+    def __init__(self):
+        # Patterns for matching author blocks with balanced braces
+        self.author_block_patterns = [
+            # Standard LaTeX patterns with optional arguments
+            r'\\author(?:\s*\[[^\]]*\])?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\(?:title)?author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\name[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\Author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\AUTHOR[S]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            # Conference and journal specific patterns
+            r'\\addauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\IEEEauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\speaker\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\authorrunning\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            # Academic publisher specific patterns
+            r'\\alignauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\spauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\authors\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+        ]
+
+        # Cleaning patterns for LaTeX commands and formatting
+        self.cleaning_patterns = [
+            # Text formatting commands - preserve content
+            (r'\\textbf\{([^}]+)\}', r'\1'),
+            (r'\\textit\{([^}]+)\}', r'\1'),
+            (r'\\emph\{([^}]+)\}', r'\1'),
+            (r'\\texttt\{([^}]+)\}', r'\1'),
+            (r'\\textrm\{([^}]+)\}', r'\1'),
+            (r'\\text\{([^}]+)\}', r'\1'),
+
+            # Affiliation and footnote markers
+            (r'\$\^{[^}]+}\$', ''),
+            (r'\^{[^}]+}', ''),
+            (r'\\thanks\{[^}]+\}', ''),
+            (r'\\footnote\{[^}]+\}', ''),
+
+            # Email and contact formatting
+            (r'\\email\{([^}]+)\}', r'\1'),
+            (r'\\href\{[^}]+\}\{([^}]+)\}', r'\1'),
+
+            # Institution formatting
+            (r'\\inst\{[^}]+\}', ''),
+            (r'\\affil\{[^}]+\}', ''),
+
+            # Special characters and symbols
+            (r'\\&', '&'),
+            (r'\\\\\s*', ' '),
+            (r'\\,', ' '),
+            (r'\\;', ' '),
+            (r'\\quad', ' '),
+            (r'\\qquad', ' '),
+
+            # Math mode content
+            (r'\$[^$]+\$', ''),
+
+            # Common symbols
+            (r'\\dagger', '†'),
+            (r'\\ddagger', '‡'),
+            (r'\\ast', '*'),
+            (r'\\star', '★'),
+
+            # Remove remaining LaTeX commands
+            (r'\\[a-zA-Z]+', ''),
+
+            # Clean up remaining special characters
+            (r'[\\{}]', '')
+        ]
+
+    def extract_author_block(self, text: str) -> Optional[str]:
+        """
+        Extract the complete author block from LaTeX text.
+
+        Args:
+            text (str): Input LaTeX text
+
+        Returns:
+            Optional[str]: Extracted author block or None if not found
+        """
+        try:
+            if not text:
+                return None
+
+            for pattern in self.author_block_patterns:
+                match = re.search(pattern, text, re.DOTALL | re.MULTILINE)
+                if match:
+                    return match.group(1).strip()
+            return None
+
+        except (AttributeError, IndexError) as e:
+            print(f"Error extracting author block: {e}")
+            return None
+
+    def clean_tex_commands(self, text: str) -> str:
+        """
+        Remove LaTeX commands and formatting from text while preserving content.
+
+        Args:
+            text (str): Text containing LaTeX commands
+
+        Returns:
+            str: Cleaned text with commands removed
+        """
+        if not text:
+            return ""
+
+        cleaned_text = text
+
+        # Apply cleaning patterns
+        for pattern, replacement in self.cleaning_patterns:
+            cleaned_text = re.sub(pattern, replacement, cleaned_text)
+
+        # Clean up whitespace
+        cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
+        cleaned_text = cleaned_text.strip()
+
+        return cleaned_text
+
+    def extract_authors(self, text: str) -> Optional[str]:
+        """
+        Extract and clean author information from LaTeX text.
+
+        Args:
+            text (str): Input LaTeX text
+
+        Returns:
+            Optional[str]: Cleaned author information or None if extraction fails
+        """
+        try:
+            if not text:
+                return None
+
+            # Extract author block
+            author_block = self.extract_author_block(text)
+            if not author_block:
+                return None
+
+            # Clean LaTeX commands
+            cleaned_authors = self.clean_tex_commands(author_block)
+            return cleaned_authors or None
+
+        except Exception as e:
+            print(f"Error processing text: {e}")
+            return None
+
+
+def test_author_extractor():
+    """Test the LatexAuthorExtractor with sample inputs."""
+    test_cases = [
+        # Basic test case
+        (r"\author{John Doe}", "John Doe"),
+
+        # Test with multiple authors
+        (r"\author{Alice Smith \and Bob Jones}", "Alice Smith and Bob Jones"),
+
+        # Test with affiliations
+        (r"\author[1]{John Smith}\affil[1]{University}", "John Smith"),
+
+    ]
+
+    extractor = LatexAuthorExtractor()
+
+    for i, (input_tex, expected) in enumerate(test_cases, 1):
+        result = extractor.extract_authors(input_tex)
+        print(f"\nTest case {i}:")
+        print(f"Input: {input_tex[:50]}...")
+        print(f"Expected: {expected[:50]}...")
+        print(f"Got: {result[:50]}...")
+        print(f"Pass: {bool(result and result.strip() == expected.strip())}")
+
+
+if __name__ == "__main__":
+    test_author_extractor()
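A quick standalone check of the first `author_block_patterns` entry; the nested alternation tolerates two levels of inner braces, so formatting commands inside the author block survive extraction. One observation on the tests above: nothing in `cleaning_patterns` maps `\and` to the word "and" (the catch-all `(r'\\[a-zA-Z]+', '')` simply deletes it), so the multi-author test case appears to produce "Alice Smith Bob Jones" and report Pass: False; a rule such as `(r'\\and\b', ' and ')` placed before the catch-all would make it pass.

```python
import re

# First pattern from author_block_patterns above.
AUTHOR_PATTERN = r'\\author(?:\s*\[[^\]]*\])?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}'

sample = r"\author[1]{John \textbf{Doe} \and Jane Roe}"  # invented example input
match = re.search(AUTHOR_PATTERN, sample, re.DOTALL | re.MULTILINE)
print(match.group(1) if match else None)  # John \textbf{Doe} \and Jane Roe
```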
""" +import logging import re from abc import ABC, abstractmethod -import logging -from dataclasses import dataclass, field -from typing import List, Optional, Dict from copy import deepcopy +from dataclasses import dataclass, field +from typing import List, Dict + from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands -from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, SectionLevel, EnhancedSectionExtractor +from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, EnhancedSectionExtractor # Configure logging logging.basicConfig(level=logging.INFO) @@ -28,6 +29,7 @@ def read_tex_file(file_path): except UnicodeDecodeError: continue + @dataclass class DocumentStructure: title: str = '' @@ -68,7 +70,7 @@ class DocumentStructure: if other_section.title in sections_map: # Merge existing section idx = next(i for i, s in enumerate(merged.toc) - if s.title == other_section.title) + if s.title == other_section.title) merged.toc[idx] = merged.toc[idx].merge(other_section) else: # Add new section @@ -149,6 +151,8 @@ class DocumentStructure: result.extend(_format_section(section, 0) for section in self.toc) return "".join(result) + + class BaseExtractor(ABC): """Base class for LaTeX content extractors.""" @@ -157,6 +161,7 @@ class BaseExtractor(ABC): """Extract specific content from LaTeX document.""" pass + class TitleExtractor(BaseExtractor): """Extracts title from LaTeX document.""" @@ -180,6 +185,7 @@ class TitleExtractor(BaseExtractor): return clean_latex_commands(title) return '' + class AbstractExtractor(BaseExtractor): """Extracts abstract from LaTeX document.""" @@ -203,6 +209,7 @@ class AbstractExtractor(BaseExtractor): return clean_latex_commands(abstract) return '' + class EssayStructureParser: """Main class for parsing LaTeX documents.""" @@ -231,6 +238,7 @@ class EssayStructureParser: content = re.sub(r'(?= other.value + @dataclass class Section: level: SectionLevel @@ -46,6 +47,7 @@ class Section: content: str = '' bibliography: str = '' subsections: List['Section'] = field(default_factory=list) + def merge(self, other: 'Section') -> 'Section': """Merge this section with another section.""" if self.title != other.title or self.level != other.level: @@ -78,6 +80,8 @@ class Section: return content1 # Combine non-empty contents with a separator return f"{content1}\n\n{content2}" + + @dataclass class LatexEnvironment: """表示LaTeX环境的数据类""" @@ -409,4 +413,4 @@ f(x) = \int_0^x g(t) dt if __name__ == "__main__": - test_enhanced_extractor() \ No newline at end of file + test_enhanced_extractor() diff --git a/crazy_functions/rag_fns/arxiv_fns/section_fragment.py b/crazy_functions/rag_fns/arxiv_fns/section_fragment.py index f933837d..7ea33998 100644 --- a/crazy_functions/rag_fns/arxiv_fns/section_fragment.py +++ b/crazy_functions/rag_fns/arxiv_fns/section_fragment.py @@ -1,19 +1,14 @@ from dataclasses import dataclass + @dataclass class SectionFragment: """Arxiv论文片段数据类""" title: str # 论文标题 authors: str abstract: str # 论文摘要 - catalogs: str # 文章各章节的目录结构 + catalogs: str # 文章各章节的目录结构 arxiv_id: str = "" # 添加 arxiv_id 属性 - current_section: str = "Introduction" # 当前片段所属的section或者subsection或者孙subsubsection名字 - content: str = '' #当前片段的内容 - bibliography: str = '' #当前片段的参考文献 - - - - - - + current_section: str = "Introduction" # 当前片段所属的section或者subsection或者孙subsubsection名字 + content: str = '' # 当前片段的内容 + bibliography: str = '' # 当前片段的参考文献 diff --git a/crazy_functions/rag_fns/arxiv_fns/tex_utils.py b/crazy_functions/rag_fns/arxiv_fns/tex_utils.py 
diff --git a/crazy_functions/rag_fns/arxiv_fns/tex_utils.py b/crazy_functions/rag_fns/arxiv_fns/tex_utils.py
index 1fba7953..1a281d3a 100644
--- a/crazy_functions/rag_fns/arxiv_fns/tex_utils.py
+++ b/crazy_functions/rag_fns/arxiv_fns/tex_utils.py
@@ -1,10 +1,12 @@
-import re
-import os
 import logging
+import os
+import re
 from pathlib import Path
-from typing import List, Tuple, Dict, Set, Optional, Callable
+from typing import List, Set, Optional
+
 from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns
 
+
 class TexUtils:
     """TeX文档处理器类"""
 
@@ -21,9 +23,6 @@ class TexUtils:
         self._init_patterns()
         self.latex_only_patterns = LaTeXPatterns.latex_only_patterns
 
-
-
-
     def _init_patterns(self):
         """初始化LaTeX模式匹配规则"""
         # 特殊环境模式
@@ -234,6 +233,7 @@ class TexUtils:
             processed_refs.append("\n".join(ref_lines))
         return processed_refs
 
+
     def _extract_inline_references(self, content: str) -> str:
         """
         从tex文件内容中提取直接写在文件中的参考文献
@@ -255,6 +255,7 @@ class TexUtils:
                 return content[start_match.start():end_match.end()]
         return ""
 
+
     def _preprocess_content(self, content: str) -> str:
         """预处理TeX内容"""
         # 移除注释
@@ -263,9 +264,3 @@ class TexUtils:
         # content = re.sub(r'\s+', ' ', content)
         content = re.sub(r'\n\s*\n', '\n\n', content)
         return content.strip()
-
-
-
-
-
-
diff --git a/crazy_functions/rag_fns/llama_index_worker.py b/crazy_functions/rag_fns/llama_index_worker.py
index 50a23b03..08e3d50f 100644
--- a/crazy_functions/rag_fns/llama_index_worker.py
+++ b/crazy_functions/rag_fns/llama_index_worker.py
@@ -1,17 +1,13 @@
-import llama_index
-import os
 import atexit
-from loguru import logger
-from typing import List
+import os
+
 from llama_index.core import Document
-from llama_index.core.schema import TextNode
-from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
-from shared_utils.connect_void_terminal import get_chat_default_kwargs
-from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
-from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
 from llama_index.core.ingestion import run_transformations
-from llama_index.core import PromptTemplate
-from llama_index.core.response_synthesizers import TreeSummarize
+from llama_index.core.schema import TextNode
+from loguru import logger
+
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
 
 DEFAULT_QUERY_GENERATION_PROMPT = """\
 Now, you have context information as below:
@@ -127,7 +123,6 @@ class LlamaIndexRagWorker(SaveLoad):
             logger.error(f"Error saving checkpoint: {str(e)}")
             raise
 
-
     def assign_embedding_model(self):
         pass
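The `_preprocess_content` hunk in the tex_utils diff above keeps two normalisation steps: comment removal and blank-line collapsing. The actual comment-removal statement falls between the hunks shown, so the first rule below is an assumption; the `(?<!\\)` lookbehind keeps escaped percent signs such as `50\%` intact:

```python
import re


def preprocess_tex(content: str) -> str:
    content = re.sub(r'(?<!\\)%.*', '', content)   # strip comments (assumed rule)
    content = re.sub(r'\n\s*\n', '\n\n', content)  # collapse blank runs (from the diff)
    return content.strip()


print(repr(preprocess_tex("a 50\\% rise % remark\n\n\n\nnext")))
# 'a 50\\% rise \n\nnext'
```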
diff --git a/crazy_functions/rag_fns/milvus_worker.py b/crazy_functions/rag_fns/milvus_worker.py
index 6eccb6a7..680784b8 100644
--- a/crazy_functions/rag_fns/milvus_worker.py
+++ b/crazy_functions/rag_fns/milvus_worker.py
@@ -1,20 +1,14 @@
-import llama_index
-import os
 import atexit
+import os
 from typing import List
-from loguru import logger
-from llama_index.core import Document
-from llama_index.core.schema import TextNode
-from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
-from shared_utils.connect_void_terminal import get_chat_default_kwargs
-from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
-from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
-from llama_index.core.ingestion import run_transformations
-from llama_index.core import PromptTemplate
-from llama_index.core.response_synthesizers import TreeSummarize
+
 from llama_index.core import StorageContext
 from llama_index.vector_stores.milvus import MilvusVectorStore
+from loguru import logger
+
 from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
 
 DEFAULT_QUERY_GENERATION_PROMPT = """\
 Now, you have context information as below:
@@ -65,17 +59,19 @@ class MilvusSaveLoad():
 
     def create_new_vs(self, checkpoint_dir, overwrite=False):
         vector_store = MilvusVectorStore(
-            uri=os.path.join(checkpoint_dir, "milvus_demo.db"), 
+            uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
             dim=self.embed_model.embedding_dimension(),
             overwrite=overwrite
         )
         storage_context = StorageContext.from_defaults(vector_store=vector_store)
-        index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
+        index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context,
+                                                           embed_model=self.embed_model)
         return index
 
     def purge(self):
         self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
 
+
 class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
 
     def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
@@ -96,7 +92,7 @@ class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
             docstore = self.vs_index.storage_context.docstore.docs
             if not docstore.items():
                 raise ValueError("cannot inspect")
-            vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
+            vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
         except:
             dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
             vector_store_preview = "\n".join(
diff --git a/crazy_functions/rag_fns/rag_file_support.py b/crazy_functions/rag_fns/rag_file_support.py
index f826fab1..05d141b1 100644
--- a/crazy_functions/rag_fns/rag_file_support.py
+++ b/crazy_functions/rag_fns/rag_file_support.py
@@ -1,8 +1,8 @@
-import os
 from llama_index.core import SimpleDirectoryReader
 
-supports_format = ['.csv', '.docx','.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
-                   '.pptm', '.pptx','.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml' ,'.m']
+supports_format = ['.csv', '.docx', '.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
+                   '.pptm', '.pptx', '.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml', '.m']
+
 
 def read_docx_doc(file_path):
     if file_path.split(".")[-1] == "docx":
@@ -25,9 +25,11 @@ def read_docx_doc(file_path):
             raise RuntimeError('请先将.doc文档转换为.docx文档。')
     return file_content
 
+
 # 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑
 import os
 
+
 def extract_text(file_path):
     _, ext = os.path.splitext(file_path.lower())
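The shape of the dispatch `extract_text` implements, reduced to the two paths visible in this diff: .docx through a Word reader, everything else in `supports_format` through llama-index's generic reader. `read_docx_doc`'s body is elided above, so the use of python-docx here is an assumption, and error handling plus the .doc guard are omitted:

```python
import os

from docx import Document  # python-docx; assumed to be what read_docx_doc uses
from llama_index.core import SimpleDirectoryReader


def extract_text_sketch(file_path: str) -> str:
    _, ext = os.path.splitext(file_path.lower())
    if ext == ".docx":
        doc = Document(file_path)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs)
    # Fall back to the generic reader for the other supported formats.
    documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
    return "\n".join(doc.text for doc in documents)
```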
diff --git a/crazy_functions/rag_fns/vector_store_index.py b/crazy_functions/rag_fns/vector_store_index.py
index 74e8b09d..d14a4c62 100644
--- a/crazy_functions/rag_fns/vector_store_index.py
+++ b/crazy_functions/rag_fns/vector_store_index.py
@@ -1,6 +1,6 @@
-from llama_index.core import VectorStoreIndex
-from typing import Any, List, Optional
+from typing import Any, List, Optional
+
+from llama_index.core import VectorStoreIndex
 from llama_index.core.callbacks.base import CallbackManager
 from llama_index.core.schema import TransformComponent
 from llama_index.core.service_context import ServiceContext
@@ -13,18 +13,18 @@ from llama_index.core.storage.storage_context import StorageContext
 
 
 class GptacVectorStoreIndex(VectorStoreIndex):
-
+
     @classmethod
     def default_vector_store(
-        cls,
-        storage_context: Optional[StorageContext] = None,
-        show_progress: bool = False,
-        callback_manager: Optional[CallbackManager] = None,
-        transformations: Optional[List[TransformComponent]] = None,
-        # deprecated
-        service_context: Optional[ServiceContext] = None,
-        embed_model = None,
-        **kwargs: Any,
+            cls,
+            storage_context: Optional[StorageContext] = None,
+            show_progress: bool = False,
+            callback_manager: Optional[CallbackManager] = None,
+            transformations: Optional[List[TransformComponent]] = None,
+            # deprecated
+            service_context: Optional[ServiceContext] = None,
+            embed_model=None,
+            **kwargs: Any,
     ):
         """Create index from documents.
 
@@ -36,15 +36,14 @@ class GptacVectorStoreIndex(VectorStoreIndex):
         storage_context = storage_context or StorageContext.from_defaults()
         docstore = storage_context.docstore
         callback_manager = (
-            callback_manager
-            or callback_manager_from_settings_or_context(Settings, service_context)
+                callback_manager
+                or callback_manager_from_settings_or_context(Settings, service_context)
         )
         transformations = transformations or transformations_from_settings_or_context(
             Settings, service_context
         )
 
         with callback_manager.as_trace("index_construction"):
-
             return cls(
                 nodes=[],
                 storage_context=storage_context,
@@ -55,4 +54,3 @@ class GptacVectorStoreIndex(VectorStoreIndex):
                 embed_model=embed_model,
                 **kwargs,
             )
-
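How `default_vector_store` is consumed elsewhere in this patch: `milvus_worker.py` wraps a `MilvusVectorStore` in a storage context and asks the classmethod for an empty index bound to a concrete embedding model. A stripped-down sketch; the default in-memory storage context is used here only to keep it self-contained, and the caller supplies the embedding model (the real call passes an `OpenAiEmbeddingModel`):

```python
from llama_index.core import StorageContext

from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex


def build_empty_index(embed_model):
    """embed_model: e.g. the OpenAiEmbeddingModel instance used in milvus_worker.py."""
    # milvus_worker.py builds this from a MilvusVectorStore instead.
    storage_context = StorageContext.from_defaults()
    return GptacVectorStoreIndex.default_vector_store(
        storage_context=storage_context,
        embed_model=embed_model,
    )
```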