@@ -1,12 +1,14 @@
 import logging
-import requests
 import tarfile
 from pathlib import Path
 from typing import Optional, Dict
 
+import requests
+
+
 class ArxivDownloader:
     """Downloader for fetching the source code of arXiv papers"""
 
     def __init__(self, root_dir: str = "./papers", proxies: Optional[Dict[str, str]] = None):
         """
         Initialize the downloader
@@ -18,13 +20,13 @@ class ArxivDownloader:
         self.root_dir = Path(root_dir)
         self.root_dir.mkdir(exist_ok=True)
         self.proxies = proxies
 
         # Configure logging
         logging.basicConfig(
             level=logging.INFO,
             format='%(asctime)s - %(levelname)s - %(message)s'
         )
 
     def _download_and_extract(self, arxiv_id: str) -> str:
         """
         Download and extract the arXiv paper source
@@ -40,19 +42,19 @@ class ArxivDownloader:
         """
         paper_dir = self.root_dir / arxiv_id
         tar_path = paper_dir / f"{arxiv_id}.tar.gz"
 
         # Check the cache
         if paper_dir.exists() and any(paper_dir.iterdir()):
             logging.info(f"Using cached version for {arxiv_id}")
             return str(paper_dir)
 
         paper_dir.mkdir(exist_ok=True)
 
         urls = [
             f"https://arxiv.org/src/{arxiv_id}",
             f"https://arxiv.org/e-print/{arxiv_id}"
         ]
 
         for url in urls:
             try:
                 logging.info(f"Downloading from {url}")
@@ -65,9 +67,9 @@ class ArxivDownloader:
             except Exception as e:
                 logging.warning(f"Download failed for {url}: {e}")
                 continue
 
         raise RuntimeError(f"Failed to download paper {arxiv_id}")
 
     def download_paper(self, arxiv_id: str) -> str:
         """
         Download the specified arXiv paper
@@ -80,6 +82,7 @@ class ArxivDownloader:
         """
         return self._download_and_extract(arxiv_id)
 
+
 def main():
     """Test the download functionality"""
     # Configure a proxy (if needed)
@@ -87,16 +90,16 @@ def main():
         "http": "http://your-proxy:port",
         "https": "https://your-proxy:port"
     }
 
     # Create the downloader instance (the proxies argument can be omitted if no proxy is needed)
     downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)
 
     # Test downloading one paper (a sample ID is used here)
     try:
         paper_id = "2103.00020"  # This is a sample ID
         paper_dir = downloader.download_paper(paper_id)
         print(f"Successfully downloaded paper to: {paper_dir}")
 
         # Check the downloaded files
         paper_path = Path(paper_dir)
         if paper_path.exists():
@@ -107,5 +110,6 @@ def main():
     except Exception as e:
         print(f"Error downloading paper: {e}")
 
+
 if __name__ == "__main__":
     main()

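For orientation, the downloader above can be exercised roughly the way its own main() test code does. This is only a sketch: the proxy host and port are placeholders, and "2103.00020" is the sample ID already used in the diff.

    downloader = ArxivDownloader(root_dir="./downloaded_papers",
                                 proxies={"http": "http://your-proxy:port",
                                          "https": "https://your-proxy:port"})
    paper_dir = downloader.download_paper("2103.00020")
    print(f"Successfully downloaded paper to: {paper_dir}")
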
@@ -1,21 +1,19 @@
-import os
-import re
-import time
-import aiohttp
 import asyncio
-import requests
-import tarfile
 import logging
-from pathlib import Path
+import re
+import tarfile
+import time
 from copy import deepcopy
+from pathlib import Path
+from typing import List, Optional, Dict, Set
 
-from typing import Generator, List, Tuple, Optional, Dict, Set
-from concurrent.futures import ThreadPoolExecutor, as_completed
+import aiohttp
-from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
+from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
-from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
 from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
 from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
-from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
+from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
+from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
 
 
 def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
@@ -31,7 +29,6 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
     """
     from datetime import datetime
     from pathlib import Path
-    import re
 
     # Create output directory
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -103,9 +100,6 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
 
             # Content
             f.write("\n**Content:**\n")
-            # f.write("```tex\n")
-            # f.write(fragment.content)
-            # f.write("\n```\n")
             f.write("\n")
             f.write(fragment.content)
             f.write("\n")
@@ -140,6 +134,7 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
     print(f"Fragments saved to: {file_path}")
     return file_path
 
+
 # Patterns for the various citation commands
 CITATION_PATTERNS = [
     # Basic \cite{} format
@@ -199,8 +194,6 @@ class ArxivSplitter:
         # Configure logging
         self._setup_logging()
 
-
-
     def _setup_logging(self):
         """Configure logging"""
         logging.basicConfig(
@@ -221,7 +214,6 @@ class ArxivSplitter:
             return arxiv_id.split('v')[0].strip()
         return input_str.split('v')[0].strip()
 
-
     def _check_cache(self, paper_dir: Path) -> bool:
         """
         Check whether the cache is valid, including a file integrity check
@@ -545,6 +537,7 @@ class ArxivSplitter:
             self.logger.error(f"Error finding citation contexts: {str(e)}")
 
         return contexts
+
     async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
         """
         Process ArXiv paper and convert to list of SectionFragments.
@@ -573,8 +566,6 @@ class ArxivSplitter:
         # Read the content of the main TeX file
         main_tex_content = read_tex_file(main_tex)
 
-
-
         # Get all related TeX files and references
         tex_files = self.tex_processor.resolve_includes(main_tex)
         ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
@@ -742,7 +733,6 @@ class ArxivSplitter:
         return content.strip()
 
 
-
 async def test_arxiv_splitter():
     """Test the ArXiv splitter functionality"""
 
@@ -765,14 +755,13 @@ async def test_arxiv_splitter():
         root_dir="test_cache"
     )
 
-
     for case in test_cases:
         print(f"\nTesting paper: {case['arxiv_id']}")
         try:
             fragments = await splitter.process(case['arxiv_id'])
 
             # Save the fragments
-            output_dir = save_fragments_to_file(fragments,output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
+            output_dir = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
             print(f"Output saved to: {output_dir}")
             # # Content check
             # for fragment in fragments:
@@ -780,7 +769,7 @@ async def test_arxiv_splitter():
             #
             # print((fragment.content))
             # print(len(fragment.content))
             # Type check
 
 
         except Exception as e:
@@ -789,4 +778,4 @@ async def test_arxiv_splitter():
 
 
 if __name__ == "__main__":
     asyncio.run(test_arxiv_splitter())

crazy_functions/rag_fns/arxiv_fns/author_extractor.py (new file, 177 lines)
@@ -0,0 +1,177 @@
+import re
+from typing import Optional
+
+
+class LatexAuthorExtractor:
+    def __init__(self):
+        # Patterns for matching author blocks with balanced braces
+        self.author_block_patterns = [
+            # Standard LaTeX patterns with optional arguments
+            r'\\author(?:\s*\[[^\]]*\])?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\(?:title)?author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\name[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\Author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\AUTHOR[S]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            # Conference and journal specific patterns
+            r'\\addauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\IEEEauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\speaker\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\authorrunning\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            # Academic publisher specific patterns
+            r'\\alignauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\spauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\authors\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+        ]
+
+        # Cleaning patterns for LaTeX commands and formatting
+        self.cleaning_patterns = [
+            # Text formatting commands - preserve content
+            (r'\\textbf\{([^}]+)\}', r'\1'),
+            (r'\\textit\{([^}]+)\}', r'\1'),
+            (r'\\emph\{([^}]+)\}', r'\1'),
+            (r'\\texttt\{([^}]+)\}', r'\1'),
+            (r'\\textrm\{([^}]+)\}', r'\1'),
+            (r'\\text\{([^}]+)\}', r'\1'),
+
+            # Affiliation and footnote markers
+            (r'\$\^{[^}]+}\$', ''),
+            (r'\^{[^}]+}', ''),
+            (r'\\thanks\{[^}]+\}', ''),
+            (r'\\footnote\{[^}]+\}', ''),
+
+            # Email and contact formatting
+            (r'\\email\{([^}]+)\}', r'\1'),
+            (r'\\href\{[^}]+\}\{([^}]+)\}', r'\1'),
+
+            # Institution formatting
+            (r'\\inst\{[^}]+\}', ''),
+            (r'\\affil\{[^}]+\}', ''),
+
+            # Special characters and symbols
+            (r'\\&', '&'),
+            (r'\\\\\s*', ' '),
+            (r'\\,', ' '),
+            (r'\\;', ' '),
+            (r'\\quad', ' '),
+            (r'\\qquad', ' '),
+
+            # Math mode content
+            (r'\$[^$]+\$', ''),
+
+            # Common symbols
+            (r'\\dagger', '†'),
+            (r'\\ddagger', '‡'),
+            (r'\\ast', '*'),
+            (r'\\star', '★'),
+
+            # Remove remaining LaTeX commands
+            (r'\\[a-zA-Z]+', ''),
+
+            # Clean up remaining special characters
+            (r'[\\{}]', '')
+        ]
+
+    def extract_author_block(self, text: str) -> Optional[str]:
+        """
+        Extract the complete author block from LaTeX text.
+
+        Args:
+            text (str): Input LaTeX text
+
+        Returns:
+            Optional[str]: Extracted author block or None if not found
+        """
+        try:
+            if not text:
+                return None
+
+            for pattern in self.author_block_patterns:
+                match = re.search(pattern, text, re.DOTALL | re.MULTILINE)
+                if match:
+                    return match.group(1).strip()
+            return None
+
+        except (AttributeError, IndexError) as e:
+            print(f"Error extracting author block: {e}")
+            return None
+
+    def clean_tex_commands(self, text: str) -> str:
+        """
+        Remove LaTeX commands and formatting from text while preserving content.
+
+        Args:
+            text (str): Text containing LaTeX commands
+
+        Returns:
+            str: Cleaned text with commands removed
+        """
+        if not text:
+            return ""
+
+        cleaned_text = text
+
+        # Apply cleaning patterns
+        for pattern, replacement in self.cleaning_patterns:
+            cleaned_text = re.sub(pattern, replacement, cleaned_text)
+
+        # Clean up whitespace
+        cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
+        cleaned_text = cleaned_text.strip()
+
+        return cleaned_text
+
+    def extract_authors(self, text: str) -> Optional[str]:
+        """
+        Extract and clean author information from LaTeX text.
+
+        Args:
+            text (str): Input LaTeX text
+
+        Returns:
+            Optional[str]: Cleaned author information or None if extraction fails
+        """
+        try:
+            if not text:
+                return None
+
+            # Extract author block
+            author_block = self.extract_author_block(text)
+            if not author_block:
+                return None
+
+            # Clean LaTeX commands
+            cleaned_authors = self.clean_tex_commands(author_block)
+            return cleaned_authors or None
+
+        except Exception as e:
+            print(f"Error processing text: {e}")
+            return None
+
+
+def test_author_extractor():
+    """Test the LatexAuthorExtractor with sample inputs."""
+    test_cases = [
+        # Basic test case
+        (r"\author{John Doe}", "John Doe"),
+
+        # Test with multiple authors
+        (r"\author{Alice Smith \and Bob Jones}", "Alice Smith and Bob Jones"),
+
+        # Test with affiliations
+        (r"\author[1]{John Smith}\affil[1]{University}", "John Smith"),
+
+    ]
+
+    extractor = LatexAuthorExtractor()
+
+    for i, (input_tex, expected) in enumerate(test_cases, 1):
+        result = extractor.extract_authors(input_tex)
+        print(f"\nTest case {i}:")
+        print(f"Input: {input_tex[:50]}...")
+        print(f"Expected: {expected[:50]}...")
+        print(f"Got: {result[:50]}...")
+        print(f"Pass: {bool(result and result.strip() == expected.strip())}")
+
+
+if __name__ == "__main__":
+    test_author_extractor()

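For orientation, a minimal usage sketch of the new extractor. The input string is taken from the file's own test cases above, and the expected result follows the third test case; nothing here goes beyond the API shown in the new file.

    extractor = LatexAuthorExtractor()
    tex = r"\author[1]{John Smith}\affil[1]{University}"
    block = extractor.extract_author_block(tex)   # raw contents of the author block -> "John Smith"
    authors = extractor.extract_authors(tex)      # same block after LaTeX cleanup -> "John Smith"
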
@@ -5,14 +5,15 @@ This module provides functionality for parsing and extracting structured informa
 including metadata, document structure, and content. It uses modular design and clean architecture principles.
 """
 
+import logging
 import re
 from abc import ABC, abstractmethod
-import logging
-from dataclasses import dataclass, field
-from typing import List, Optional, Dict
 from copy import deepcopy
+from dataclasses import dataclass, field
+from typing import List, Dict
 
 from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands
-from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, SectionLevel, EnhancedSectionExtractor
+from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, EnhancedSectionExtractor
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -28,6 +29,7 @@ def read_tex_file(file_path):
         except UnicodeDecodeError:
             continue
 
+
 @dataclass
 class DocumentStructure:
     title: str = ''
@@ -68,7 +70,7 @@ class DocumentStructure:
             if other_section.title in sections_map:
                 # Merge existing section
                 idx = next(i for i, s in enumerate(merged.toc)
                            if s.title == other_section.title)
                 merged.toc[idx] = merged.toc[idx].merge(other_section)
             else:
                 # Add new section
@@ -149,6 +151,8 @@ class DocumentStructure:
         result.extend(_format_section(section, 0) for section in self.toc)
 
         return "".join(result)
 
+
 class BaseExtractor(ABC):
     """Base class for LaTeX content extractors."""
+
@@ -157,6 +161,7 @@ class BaseExtractor(ABC):
         """Extract specific content from LaTeX document."""
         pass
 
+
 class TitleExtractor(BaseExtractor):
     """Extracts title from LaTeX document."""
 
@@ -180,6 +185,7 @@ class TitleExtractor(BaseExtractor):
             return clean_latex_commands(title)
         return ''
 
+
 class AbstractExtractor(BaseExtractor):
     """Extracts abstract from LaTeX document."""
 
@@ -203,6 +209,7 @@ class AbstractExtractor(BaseExtractor):
             return clean_latex_commands(abstract)
         return ''
 
+
 class EssayStructureParser:
     """Main class for parsing LaTeX documents."""
 
@@ -231,6 +238,7 @@ class EssayStructureParser:
         content = re.sub(r'(?<!\\)%.*$', '', content, flags=re.MULTILINE)
         return content
 
+
 def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100):
     """Print document structure in a readable format."""
     print(f"Title: {doc.title}\n")
@@ -250,10 +258,10 @@ def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100
     for section in doc.toc:
         print_section(section)
 
 
 # Example usage:
 if __name__ == "__main__":
 
 
     # Test with a file
     file_path = 'test_cache/2411.03663/neurips_2024.tex'
     main_tex = read_tex_file(file_path)
@@ -278,5 +286,5 @@ if __name__ == "__main__":
         additional_doc = parser.parse(tex_content)
         main_doc = main_doc.merge(additional_doc)
 
-    tree= main_doc.generate_toc_tree()
+    tree = main_doc.generate_toc_tree()
     pretty_print_structure(main_doc)

@@ -1,9 +1,9 @@
-from dataclasses import dataclass, field
-from typing import Set, Dict, Pattern, Optional, List, Tuple
-import re
-from enum import Enum
 import logging
+import re
+from dataclasses import dataclass, field
+from enum import Enum
 from functools import lru_cache
+from typing import Set, Dict, Pattern, Optional, List, Tuple
 
 
 class EnvType(Enum):
@@ -326,4 +326,4 @@ if __name__ == "__main__":
     content = read_tex_file(file_path)
     cleaner = LatexCleaner(config)
     text = cleaner.clean_text(text)
     print(text)

@@ -1,4 +1,5 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 
+
 @dataclass
 class LaTeXPatterns:
@@ -142,124 +143,124 @@ class LaTeXPatterns:
     ]
 
     metadata_patterns = {
         # Title related
         'title': [
             r'\\title\{([^}]+)\}',
             r'\\Title\{([^}]+)\}',
             r'\\doctitle\{([^}]+)\}',
             r'\\subtitle\{([^}]+)\}',
             r'\\chapter\*?\{([^}]+)\}',  # The first chapter may serve as the title
             r'\\maketitle\s*\\section\*?\{([^}]+)\}'  # The first section may serve as the title
         ],
 
         # Abstract related
         'abstract': [
             r'\\begin{abstract}(.*?)\\end{abstract}',
             r'\\abstract\{([^}]+)\}',
             r'\\begin{摘要}(.*?)\\end{摘要}',
             r'\\begin{Summary}(.*?)\\end{Summary}',
             r'\\begin{synopsis}(.*?)\\end{synopsis}',
             r'\\begin{abstracten}(.*?)\\end{abstracten}'  # English abstract
         ],
 
         # Author information
         'author': [
             r'\\author\{([^}]+)\}',
             r'\\Author\{([^}]+)\}',
             r'\\authorinfo\{([^}]+)\}',
             r'\\authors\{([^}]+)\}',
             r'\\author\[([^]]+)\]\{([^}]+)\}',  # Author with additional information
             r'\\begin{authors}(.*?)\\end{authors}'
         ],
 
         # Date related
         'date': [
             r'\\date\{([^}]+)\}',
             r'\\Date\{([^}]+)\}',
             r'\\submitdate\{([^}]+)\}',
             r'\\publishdate\{([^}]+)\}',
             r'\\revisiondate\{([^}]+)\}'
         ],
 
         # Keywords
         'keywords': [
             r'\\keywords\{([^}]+)\}',
             r'\\Keywords\{([^}]+)\}',
             r'\\begin{keywords}(.*?)\\end{keywords}',
             r'\\key\{([^}]+)\}',
             r'\\begin{关键词}(.*?)\\end{关键词}'
         ],
 
         # Institution / affiliation
         'institution': [
             r'\\institute\{([^}]+)\}',
             r'\\institution\{([^}]+)\}',
             r'\\affiliation\{([^}]+)\}',
             r'\\organization\{([^}]+)\}',
             r'\\department\{([^}]+)\}'
         ],
 
         # Subject / topic
         'subject': [
             r'\\subject\{([^}]+)\}',
             r'\\Subject\{([^}]+)\}',
             r'\\field\{([^}]+)\}',
             r'\\discipline\{([^}]+)\}'
         ],
 
         # Version information
         'version': [
             r'\\version\{([^}]+)\}',
             r'\\revision\{([^}]+)\}',
             r'\\release\{([^}]+)\}'
         ],
 
         # License / copyright
         'license': [
             r'\\license\{([^}]+)\}',
             r'\\copyright\{([^}]+)\}',
             r'\\begin{license}(.*?)\\end{license}'
         ],
 
         # Contact information
         'contact': [
             r'\\email\{([^}]+)\}',
             r'\\phone\{([^}]+)\}',
             r'\\address\{([^}]+)\}',
             r'\\contact\{([^}]+)\}'
         ],
 
         # Acknowledgments
         'acknowledgments': [
             r'\\begin{acknowledgments}(.*?)\\end{acknowledgments}',
             r'\\acknowledgments\{([^}]+)\}',
             r'\\thanks\{([^}]+)\}',
             r'\\begin{致谢}(.*?)\\end{致谢}'
         ],
 
         # Project / funding
         'funding': [
             r'\\funding\{([^}]+)\}',
             r'\\grant\{([^}]+)\}',
             r'\\project\{([^}]+)\}',
             r'\\support\{([^}]+)\}'
         ],
 
         # Classification / identifiers
         'classification': [
             r'\\classification\{([^}]+)\}',
             r'\\serialnumber\{([^}]+)\}',
             r'\\id\{([^}]+)\}',
             r'\\doi\{([^}]+)\}'
         ],
 
         # Language
         'language': [
             r'\\documentlanguage\{([^}]+)\}',
             r'\\lang\{([^}]+)\}',
             r'\\language\{([^}]+)\}'
         ]
     }
     latex_only_patterns = {
         # Document class and package imports
         r'\\documentclass(\[.*?\])?\{.*?\}',

@@ -1,15 +1,15 @@
-import re
-from typing import List, Dict, Tuple, Optional, Set
-from enum import Enum
-from dataclasses import dataclass, field
-from copy import deepcopy
-
 import logging
+import re
+from copy import deepcopy
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import List, Dict, Tuple
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
 @dataclass
 class SectionLevel(Enum):
     CHAPTER = 0
@@ -39,6 +39,7 @@ class SectionLevel(Enum):
             return NotImplemented
         return self.value >= other.value
 
+
 @dataclass
 class Section:
     level: SectionLevel
@@ -46,6 +47,7 @@ class Section:
     content: str = ''
     bibliography: str = ''
     subsections: List['Section'] = field(default_factory=list)
+
     def merge(self, other: 'Section') -> 'Section':
         """Merge this section with another section."""
         if self.title != other.title or self.level != other.level:
@@ -78,6 +80,8 @@ class Section:
             return content1
         # Combine non-empty contents with a separator
         return f"{content1}\n\n{content2}"
+
+
 @dataclass
 class LatexEnvironment:
     """Dataclass representing a LaTeX environment"""
@@ -409,4 +413,4 @@ f(x) = \int_0^x g(t) dt
 
 
 if __name__ == "__main__":
     test_enhanced_extractor()

@@ -1,19 +1,14 @@
 from dataclasses import dataclass
 
 
 @dataclass
 class SectionFragment:
     """Dataclass for a fragment of an arXiv paper"""
     title: str  # Paper title
     authors: str
     abstract: str  # Paper abstract
     catalogs: str  # Table of contents of the paper's sections
     arxiv_id: str = ""  # arxiv_id attribute
     current_section: str = "Introduction"  # Name of the section, subsection, or subsubsection this fragment belongs to
-    content: str = '' #Content of the current fragment
-    bibliography: str = '' #References of the current fragment
+    content: str = ''  # Content of the current fragment
+    bibliography: str = ''  # References of the current fragment
-
-
-
-
-
-

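Since ArxivSplitter.process() (earlier in this commit) returns a List[SectionFragment], constructing one by hand looks roughly like the following. The field values are placeholders, shown only to illustrate the dataclass.

    frag = SectionFragment(
        title="Example Paper",
        authors="A. Author and B. Author",
        abstract="One-paragraph abstract...",
        catalogs="1 Introduction\n2 Method",
        arxiv_id="2103.00020",
        current_section="Introduction",
        content="Text of the current fragment",
        bibliography="",
    )
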
@@ -1,10 +1,12 @@
-import re
-import os
 import logging
+import os
+import re
 from pathlib import Path
-from typing import List, Tuple, Dict, Set, Optional, Callable
+from typing import List, Set, Optional
 
 from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns
 
+
 class TexUtils:
     """TeX document processor class"""
 
@@ -21,9 +23,6 @@ class TexUtils:
         self._init_patterns()
         self.latex_only_patterns = LaTeXPatterns.latex_only_patterns
 
-
-
-
     def _init_patterns(self):
         """Initialize the LaTeX pattern-matching rules"""
         # Special environment patterns
@@ -234,6 +233,7 @@ class TexUtils:
             processed_refs.append("\n".join(ref_lines))
+
         return processed_refs
 
     def _extract_inline_references(self, content: str) -> str:
         """
         Extract references written directly in the tex file content
@@ -255,6 +255,7 @@ class TexUtils:
             return content[start_match.start():end_match.end()]
+
         return ""
 
     def _preprocess_content(self, content: str) -> str:
         """Preprocess TeX content"""
         # Remove comments
@@ -263,9 +264,3 @@ class TexUtils:
         # content = re.sub(r'\s+', ' ', content)
         content = re.sub(r'\n\s*\n', '\n\n', content)
         return content.strip()
-
-
-
-
-
-

@@ -1,17 +1,13 @@
-import llama_index
-import os
 import atexit
-from loguru import logger
-from typing import List
+import os
 from llama_index.core import Document
-from llama_index.core.schema import TextNode
-from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
-from shared_utils.connect_void_terminal import get_chat_default_kwargs
-from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
-from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
 from llama_index.core.ingestion import run_transformations
-from llama_index.core import PromptTemplate
-from llama_index.core.response_synthesizers import TreeSummarize
+from llama_index.core.schema import TextNode
+from loguru import logger
 
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
 
 DEFAULT_QUERY_GENERATION_PROMPT = """\
 Now, you have context information as below:
@@ -127,7 +123,6 @@ class LlamaIndexRagWorker(SaveLoad):
             logger.error(f"Error saving checkpoint: {str(e)}")
             raise
 
-
     def assign_embedding_model(self):
         pass
 

@@ -1,20 +1,14 @@
-import llama_index
-import os
 import atexit
+import os
 from typing import List
-from loguru import logger
-from llama_index.core import Document
-from llama_index.core.schema import TextNode
-from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
-from shared_utils.connect_void_terminal import get_chat_default_kwargs
-from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
-from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
-from llama_index.core.ingestion import run_transformations
-from llama_index.core import PromptTemplate
-from llama_index.core.response_synthesizers import TreeSummarize
 from llama_index.core import StorageContext
 from llama_index.vector_stores.milvus import MilvusVectorStore
+from loguru import logger
 
 from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
 
 DEFAULT_QUERY_GENERATION_PROMPT = """\
 Now, you have context information as below:
@@ -65,17 +59,19 @@ class MilvusSaveLoad():
 
     def create_new_vs(self, checkpoint_dir, overwrite=False):
         vector_store = MilvusVectorStore(
             uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
             dim=self.embed_model.embedding_dimension(),
             overwrite=overwrite
         )
         storage_context = StorageContext.from_defaults(vector_store=vector_store)
-        index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
+        index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context,
+                                                           embed_model=self.embed_model)
         return index
 
     def purge(self):
         self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
 
+
 class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
 
     def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
@@ -96,7 +92,7 @@ class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
             docstore = self.vs_index.storage_context.docstore.docs
             if not docstore.items():
                 raise ValueError("cannot inspect")
-            vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
+            vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
         except:
             dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
             vector_store_preview = "\n".join(

@@ -1,8 +1,8 @@
-import os
 from llama_index.core import SimpleDirectoryReader
 
-supports_format = ['.csv', '.docx','.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
-                   '.pptm', '.pptx','.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml' ,'.m']
+supports_format = ['.csv', '.docx', '.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
+                   '.pptm', '.pptx', '.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml', '.m']
 
+
 def read_docx_doc(file_path):
     if file_path.split(".")[-1] == "docx":
@@ -25,9 +25,11 @@ def read_docx_doc(file_path):
         raise RuntimeError('请先将.doc文档转换为.docx文档。')
     return file_content
 
+
 # Modified extract_text function, combining SimpleDirectoryReader with custom parsing logic
 import os
 
+
 def extract_text(file_path):
     _, ext = os.path.splitext(file_path.lower())
 

@@ -1,6 +1,6 @@
-from llama_index.core import VectorStoreIndex
 from typing import Any, List, Optional
+
+from llama_index.core import VectorStoreIndex
 from llama_index.core.callbacks.base import CallbackManager
 from llama_index.core.schema import TransformComponent
 from llama_index.core.service_context import ServiceContext
@@ -13,18 +13,18 @@ from llama_index.core.storage.storage_context import StorageContext
 
 
 class GptacVectorStoreIndex(VectorStoreIndex):
 
     @classmethod
     def default_vector_store(
             cls,
             storage_context: Optional[StorageContext] = None,
             show_progress: bool = False,
             callback_manager: Optional[CallbackManager] = None,
             transformations: Optional[List[TransformComponent]] = None,
             # deprecated
             service_context: Optional[ServiceContext] = None,
-            embed_model = None,
+            embed_model=None,
             **kwargs: Any,
     ):
         """Create index from documents.
 
@@ -36,15 +36,14 @@ class GptacVectorStoreIndex(VectorStoreIndex):
         storage_context = storage_context or StorageContext.from_defaults()
         docstore = storage_context.docstore
         callback_manager = (
             callback_manager
             or callback_manager_from_settings_or_context(Settings, service_context)
         )
         transformations = transformations or transformations_from_settings_or_context(
             Settings, service_context
         )
 
         with callback_manager.as_trace("index_construction"):
-
             return cls(
                 nodes=[],
                 storage_context=storage_context,
@@ -55,4 +54,3 @@ class GptacVectorStoreIndex(VectorStoreIndex):
                 embed_model=embed_model,
                 **kwargs,
             )
-

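As a usage note, the classmethod above is what MilvusSaveLoad.create_new_vs (earlier in this commit) calls; a minimal sketch of that call, assuming a vector store and an embedding model have already been constructed elsewhere:

    storage_context = StorageContext.from_defaults(vector_store=vector_store)  # vector_store: e.g. a MilvusVectorStore
    index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context,
                                                       embed_model=embed_model)  # embed_model: e.g. OpenAiEmbeddingModel
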