@@ -1,9 +1,11 @@
 import logging
-import requests
 import tarfile
 from pathlib import Path
 from typing import Optional, Dict
 
+import requests
+
 
 class ArxivDownloader:
     """Downloader for fetching arXiv paper source code"""
@@ -80,6 +82,7 @@ class ArxivDownloader:
         """
         return self._download_and_extract(arxiv_id)
 
+
 def main():
     """Test the download functionality"""
     # Configure proxy (if needed)
@@ -107,5 +110,6 @@ def main():
     except Exception as e:
         print(f"Error downloading paper: {e}")
 
+
 if __name__ == "__main__":
     main()
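Note: the only network logic this file needs is a GET against arXiv's e-print endpoint plus tarfile extraction, which matches the imports reorganized above. A minimal standalone sketch, independent of ArxivDownloader's internal API; the proxies argument is a placeholder for the proxy configuration mentioned in main():

    import tarfile
    from pathlib import Path

    import requests

    def fetch_arxiv_source(arxiv_id: str, dest: Path, proxies: dict = None) -> Path:
        # arXiv serves the LaTeX source bundle at /e-print/<id>
        url = f"https://arxiv.org/e-print/{arxiv_id}"
        resp = requests.get(url, proxies=proxies, timeout=60)
        resp.raise_for_status()
        dest.mkdir(parents=True, exist_ok=True)
        archive = dest / f"{arxiv_id}.tar.gz"
        archive.write_bytes(resp.content)
        # The bundle is usually a gzipped tar; tarfile auto-detects compression
        with tarfile.open(archive) as tar:
            tar.extractall(dest / arxiv_id)
        return dest / arxiv_id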
@@ -1,21 +1,19 @@
-import os
-import re
-import time
-import aiohttp
 import asyncio
-import requests
-import tarfile
 import logging
-from pathlib import Path
+import re
+import tarfile
+import time
 from copy import deepcopy
+from pathlib import Path
+from typing import List, Optional, Dict, Set
 
-from typing import Generator, List, Tuple, Optional, Dict, Set
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
-from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
+import aiohttp
+
+from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
 from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
 from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
-from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
+from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
+from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
 
 
 def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
@@ -31,7 +29,6 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
     """
     from datetime import datetime
     from pathlib import Path
-    import re
 
     # Create output directory
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -103,9 +100,6 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
 
         # Content
         f.write("\n**Content:**\n")
-        # f.write("```tex\n")
-        # f.write(fragment.content)
-        # f.write("\n```\n")
         f.write("\n")
         f.write(fragment.content)
         f.write("\n")
@@ -140,6 +134,7 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
     print(f"Fragments saved to: {file_path}")
     return file_path
 
+
 # Define patterns for the various citation commands
 CITATION_PATTERNS = [
     # Basic \cite{} form
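CITATION_PATTERNS is module-level data and only its header comment is visible in this hunk. As an illustration of the basic \cite{} form that comment refers to (this exact regex is an assumption, not the list's actual contents):

    import re

    BASIC_CITE = re.compile(r'\\cite\{([^}]+)\}')  # hypothetical; matches \cite{key1,key2}

    keys = BASIC_CITE.findall(r"as shown in \cite{smith2020,doe2021}")
    # keys == ['smith2020,doe2021']; split on ',' for individual citation keys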
@@ -199,8 +194,6 @@ class ArxivSplitter:
         # Configure logging
         self._setup_logging()
 
-
-
     def _setup_logging(self):
         """Configure logging"""
         logging.basicConfig(
@@ -221,7 +214,6 @@ class ArxivSplitter:
             return arxiv_id.split('v')[0].strip()
         return input_str.split('v')[0].strip()
 
-
     def _check_cache(self, paper_dir: Path) -> bool:
         """
         Check whether the cache is valid, including file-integrity checks
@@ -545,6 +537,7 @@ class ArxivSplitter:
             self.logger.error(f"Error finding citation contexts: {str(e)}")
+
         return contexts
 
     async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
         """
         Process ArXiv paper and convert to list of SectionFragments.
@@ -573,8 +566,6 @@ class ArxivSplitter:
         # Read the main TeX file content
         main_tex_content = read_tex_file(main_tex)
 
-
-
         # Get all related TeX files and references
         tex_files = self.tex_processor.resolve_includes(main_tex)
         ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
@@ -742,7 +733,6 @@ class ArxivSplitter:
         return content.strip()
 
 
-
 async def test_arxiv_splitter():
     """Test the ArxivSplitter functionality"""
 
@@ -765,7 +755,6 @@ async def test_arxiv_splitter():
         root_dir="test_cache"
     )
 
-
    for case in test_cases:
        print(f"\nTesting paper: {case['arxiv_id']}")
        try:
crazy_functions/rag_fns/arxiv_fns/author_extractor.py (new file, 177 lines)
@@ -0,0 +1,177 @@
+import re
+from typing import Optional
+
+
+class LatexAuthorExtractor:
+    def __init__(self):
+        # Patterns for matching author blocks with balanced braces
+        self.author_block_patterns = [
+            # Standard LaTeX patterns with optional arguments
+            r'\\author(?:\s*\[[^\]]*\])?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\(?:title)?author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\name[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\Author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\AUTHOR[S]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            # Conference and journal specific patterns
+            r'\\addauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\IEEEauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\speaker\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\authorrunning\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            # Academic publisher specific patterns
+            r'\\alignauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\spauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+            r'\\authors\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
+        ]
+
+        # Cleaning patterns for LaTeX commands and formatting
+        self.cleaning_patterns = [
+            # Text formatting commands - preserve content
+            (r'\\textbf\{([^}]+)\}', r'\1'),
+            (r'\\textit\{([^}]+)\}', r'\1'),
+            (r'\\emph\{([^}]+)\}', r'\1'),
+            (r'\\texttt\{([^}]+)\}', r'\1'),
+            (r'\\textrm\{([^}]+)\}', r'\1'),
+            (r'\\text\{([^}]+)\}', r'\1'),
+
+            # Affiliation and footnote markers
+            (r'\$\^{[^}]+}\$', ''),
+            (r'\^{[^}]+}', ''),
+            (r'\\thanks\{[^}]+\}', ''),
+            (r'\\footnote\{[^}]+\}', ''),
+
+            # Email and contact formatting
+            (r'\\email\{([^}]+)\}', r'\1'),
+            (r'\\href\{[^}]+\}\{([^}]+)\}', r'\1'),
+
+            # Institution formatting
+            (r'\\inst\{[^}]+\}', ''),
+            (r'\\affil\{[^}]+\}', ''),
+
+            # Special characters and symbols
+            (r'\\&', '&'),
+            (r'\\\\\s*', ' '),
+            (r'\\,', ' '),
+            (r'\\;', ' '),
+            (r'\\quad', ' '),
+            (r'\\qquad', ' '),
+
+            # Math mode content
+            (r'\$[^$]+\$', ''),
+
+            # Common symbols
+            (r'\\dagger', '†'),
+            (r'\\ddagger', '‡'),
+            (r'\\ast', '*'),
+            (r'\\star', '★'),
+
+            # Remove remaining LaTeX commands
+            (r'\\[a-zA-Z]+', ''),
+
+            # Clean up remaining special characters
+            (r'[\\{}]', '')
+        ]
+
+    def extract_author_block(self, text: str) -> Optional[str]:
+        """
+        Extract the complete author block from LaTeX text.
+
+        Args:
+            text (str): Input LaTeX text
+
+        Returns:
+            Optional[str]: Extracted author block or None if not found
+        """
+        try:
+            if not text:
+                return None
+
+            for pattern in self.author_block_patterns:
+                match = re.search(pattern, text, re.DOTALL | re.MULTILINE)
+                if match:
+                    return match.group(1).strip()
+            return None
+
+        except (AttributeError, IndexError) as e:
+            print(f"Error extracting author block: {e}")
+            return None
+
+    def clean_tex_commands(self, text: str) -> str:
+        """
+        Remove LaTeX commands and formatting from text while preserving content.
+
+        Args:
+            text (str): Text containing LaTeX commands
+
+        Returns:
+            str: Cleaned text with commands removed
+        """
+        if not text:
+            return ""
+
+        cleaned_text = text
+
+        # Apply cleaning patterns
+        for pattern, replacement in self.cleaning_patterns:
+            cleaned_text = re.sub(pattern, replacement, cleaned_text)
+
+        # Clean up whitespace
+        cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
+        cleaned_text = cleaned_text.strip()
+
+        return cleaned_text
+
+    def extract_authors(self, text: str) -> Optional[str]:
+        """
+        Extract and clean author information from LaTeX text.
+
+        Args:
+            text (str): Input LaTeX text
+
+        Returns:
+            Optional[str]: Cleaned author information or None if extraction fails
+        """
+        try:
+            if not text:
+                return None
+
+            # Extract author block
+            author_block = self.extract_author_block(text)
+            if not author_block:
+                return None
+
+            # Clean LaTeX commands
+            cleaned_authors = self.clean_tex_commands(author_block)
+            return cleaned_authors or None
+
+        except Exception as e:
+            print(f"Error processing text: {e}")
+            return None
+
+
+def test_author_extractor():
+    """Test the LatexAuthorExtractor with sample inputs."""
+    test_cases = [
+        # Basic test case
+        (r"\author{John Doe}", "John Doe"),
+
+        # Test with multiple authors
+        (r"\author{Alice Smith \and Bob Jones}", "Alice Smith and Bob Jones"),
+
+        # Test with affiliations
+        (r"\author[1]{John Smith}\affil[1]{University}", "John Smith"),
+
+    ]
+
+    extractor = LatexAuthorExtractor()
+
+    for i, (input_tex, expected) in enumerate(test_cases, 1):
+        result = extractor.extract_authors(input_tex)
+        print(f"\nTest case {i}:")
+        print(f"Input: {input_tex[:50]}...")
+        print(f"Expected: {expected[:50]}...")
+        print(f"Got: {result[:50]}...")
+        print(f"Pass: {bool(result and result.strip() == expected.strip())}")
+
+
+if __name__ == "__main__":
+    test_author_extractor()
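Given the file above, a short usage sketch; the expected output follows from the extraction and cleaning patterns defined in the class (the $^{1}$ affiliation marker and the \thanks{...} note are both stripped):

    from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor

    extractor = LatexAuthorExtractor()
    tex = r"\author{Alice Smith$^{1}$\thanks{Corresponding author}}"
    print(extractor.extract_authors(tex))  # "Alice Smith"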
@@ -5,14 +5,15 @@ This module provides functionality for parsing and extracting structured information,
 including metadata, document structure, and content. It uses modular design and clean architecture principles.
 """
 
+import logging
 import re
 from abc import ABC, abstractmethod
-import logging
-from dataclasses import dataclass, field
-from typing import List, Optional, Dict
 from copy import deepcopy
+from dataclasses import dataclass, field
+from typing import List, Dict
 
 from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands
-from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, SectionLevel, EnhancedSectionExtractor
+from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, EnhancedSectionExtractor
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -28,6 +29,7 @@ def read_tex_file(file_path):
         except UnicodeDecodeError:
             continue
 
+
 @dataclass
 class DocumentStructure:
     title: str = ''
@@ -149,6 +151,8 @@ class DocumentStructure:
         result.extend(_format_section(section, 0) for section in self.toc)
 
         return "".join(result)
 
+
+
 class BaseExtractor(ABC):
     """Base class for LaTeX content extractors."""
 
@@ -157,6 +161,7 @@ class BaseExtractor(ABC):
         """Extract specific content from LaTeX document."""
         pass
 
+
 class TitleExtractor(BaseExtractor):
     """Extracts title from LaTeX document."""
 
@@ -180,6 +185,7 @@ class TitleExtractor(BaseExtractor):
             return clean_latex_commands(title)
         return ''
 
+
 class AbstractExtractor(BaseExtractor):
     """Extracts abstract from LaTeX document."""
 
@@ -203,6 +209,7 @@ class AbstractExtractor(BaseExtractor):
             return clean_latex_commands(abstract)
         return ''
 
+
 class EssayStructureParser:
     """Main class for parsing LaTeX documents."""
 
@@ -231,6 +238,7 @@ class EssayStructureParser:
         content = re.sub(r'(?<!\\)%.*$', '', content, flags=re.MULTILINE)
         return content
 
+
 def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100):
     """Print document structure in a readable format."""
     print(f"Title: {doc.title}\n")
@@ -250,10 +258,10 @@ def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100):
     for section in doc.toc:
         print_section(section)
 
 
 # Example usage:
 if __name__ == "__main__":
 
     # Test with a file
     file_path = 'test_cache/2411.03663/neurips_2024.tex'
     main_tex = read_tex_file(file_path)
@@ -1,9 +1,9 @@
-from dataclasses import dataclass, field
-from typing import Set, Dict, Pattern, Optional, List, Tuple
-import re
-from enum import Enum
 import logging
+import re
+from dataclasses import dataclass, field
+from enum import Enum
 from functools import lru_cache
+from typing import Set, Dict, Pattern, Optional, List, Tuple
 
 
 class EnvType(Enum):
@@ -1,4 +1,5 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 
+
 @dataclass
 class LaTeXPatterns:
@@ -1,15 +1,15 @@
-import re
-from typing import List, Dict, Tuple, Optional, Set
-from enum import Enum
-from dataclasses import dataclass, field
-from copy import deepcopy
-
 import logging
+import re
+from copy import deepcopy
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import List, Dict, Tuple
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
 @dataclass
 class SectionLevel(Enum):
     CHAPTER = 0
@@ -39,6 +39,7 @@ class SectionLevel(Enum):
             return NotImplemented
         return self.value >= other.value
 
+
 @dataclass
 class Section:
     level: SectionLevel
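The comparison operator visible in this hunk makes SectionLevel instances orderable by their numeric value. A minimal standalone sketch of the same idiom (member names beyond CHAPTER are hypothetical here):

    from enum import Enum

    class Level(Enum):
        CHAPTER = 0
        SECTION = 1
        SUBSECTION = 2

        def __ge__(self, other):
            # Returning NotImplemented lets Python fall back for mixed-type comparisons
            if not isinstance(other, Level):
                return NotImplemented
            return self.value >= other.value

    assert Level.SUBSECTION >= Level.SECTION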
@@ -46,6 +47,7 @@ class Section:
     content: str = ''
     bibliography: str = ''
     subsections: List['Section'] = field(default_factory=list)
 
+
     def merge(self, other: 'Section') -> 'Section':
         """Merge this section with another section."""
         if self.title != other.title or self.level != other.level:
@@ -78,6 +80,8 @@ class Section:
             return content1
         # Combine non-empty contents with a separator
         return f"{content1}\n\n{content2}"
 
+
+
 @dataclass
 class LatexEnvironment:
     """Data class representing a LaTeX environment"""
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 
+
 @dataclass
 class SectionFragment:
     """Data class for an arXiv paper fragment"""
@@ -11,9 +12,3 @@ class SectionFragment:
     current_section: str = "Introduction"  # name of the section, subsection, or subsubsection this fragment belongs to
     content: str = ''  # content of this fragment
     bibliography: str = ''  # references for this fragment
-
-
-
-
-
-
@@ -1,10 +1,12 @@
-import re
-import os
 import logging
+import os
+import re
 from pathlib import Path
-from typing import List, Tuple, Dict, Set, Optional, Callable
+from typing import List, Set, Optional
 
 from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns
 
 
 class TexUtils:
     """TeX document processor class"""
 
@@ -21,9 +23,6 @@ class TexUtils:
         self._init_patterns()
         self.latex_only_patterns = LaTeXPatterns.latex_only_patterns
 
-
-
-
     def _init_patterns(self):
         """Initialize LaTeX pattern-matching rules"""
         # Special environment patterns
@@ -234,6 +233,7 @@ class TexUtils:
                 processed_refs.append("\n".join(ref_lines))
 
         return processed_refs
 
+
     def _extract_inline_references(self, content: str) -> str:
         """
         Extract references written directly in the tex file content
@@ -255,6 +255,7 @@ class TexUtils:
             return content[start_match.start():end_match.end()]
 
         return ""
 
+
     def _preprocess_content(self, content: str) -> str:
         """Preprocess TeX content"""
         # Remove comments
@@ -263,9 +264,3 @@ class TexUtils:
         # content = re.sub(r'\s+', ' ', content)
         content = re.sub(r'\n\s*\n', '\n\n', content)
         return content.strip()
-
-
-
-
-
-
@@ -1,17 +1,13 @@
-import llama_index
-import os
 import atexit
-from loguru import logger
-from typing import List
+import os
+
 from llama_index.core import Document
-from llama_index.core.schema import TextNode
-from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
-from shared_utils.connect_void_terminal import get_chat_default_kwargs
-from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
-from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
 from llama_index.core.ingestion import run_transformations
-from llama_index.core import PromptTemplate
-from llama_index.core.response_synthesizers import TreeSummarize
+from llama_index.core.schema import TextNode
+from loguru import logger
+
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
 
 DEFAULT_QUERY_GENERATION_PROMPT = """\
 Now, you have context information as below:
@@ -127,7 +123,6 @@ class LlamaIndexRagWorker(SaveLoad):
             logger.error(f"Error saving checkpoint: {str(e)}")
             raise
 
-
     def assign_embedding_model(self):
         pass
 
@@ -1,20 +1,14 @@
-import llama_index
-import os
 import atexit
+import os
 from typing import List
-from loguru import logger
-from llama_index.core import Document
-from llama_index.core.schema import TextNode
-from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
-from shared_utils.connect_void_terminal import get_chat_default_kwargs
-from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
-from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
-from llama_index.core.ingestion import run_transformations
-from llama_index.core import PromptTemplate
-from llama_index.core.response_synthesizers import TreeSummarize
+
 from llama_index.core import StorageContext
 from llama_index.vector_stores.milvus import MilvusVectorStore
+from loguru import logger
 
 from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
 
 DEFAULT_QUERY_GENERATION_PROMPT = """\
 Now, you have context information as below:
@@ -70,12 +64,14 @@ class MilvusSaveLoad():
             overwrite=overwrite
         )
         storage_context = StorageContext.from_defaults(vector_store=vector_store)
-        index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
+        index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context,
+                                                           embed_model=self.embed_model)
         return index
 
     def purge(self):
         self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
 
+
 class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
 
     def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
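For reference, the construction above pairs a MilvusVectorStore with a StorageContext before building the index. A hedged sketch of that wiring; the uri, dim, and overwrite values here are placeholders that depend on the deployment and the embedding model:

    from llama_index.core import StorageContext
    from llama_index.vector_stores.milvus import MilvusVectorStore

    # Local Milvus Lite file; dim must match the embedding model's output size
    vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1536, overwrite=False)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    # storage_context is then passed to GptacVectorStoreIndex.default_vector_store(...)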
@@ -1,9 +1,9 @@
-import os
 from llama_index.core import SimpleDirectoryReader
 
 supports_format = ['.csv', '.docx', '.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
                    '.pptm', '.pptx', '.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml', '.m']
 
+
 def read_docx_doc(file_path):
     if file_path.split(".")[-1] == "docx":
         from docx import Document
@@ -25,9 +25,11 @@ def read_docx_doc(file_path):
         raise RuntimeError('请先将.doc文档转换为.docx文档。')
     return file_content
 
+
 # Modified extract_text function, combining SimpleDirectoryReader with custom parsing logic
 import os
 
+
 def extract_text(file_path):
     _, ext = os.path.splitext(file_path.lower())
 
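read_docx_doc above rejects legacy .doc files because python-docx only parses the OOXML .docx format. A minimal sketch of that reading pattern, assuming nothing beyond the standard python-docx API:

    from docx import Document

    def read_docx(file_path: str) -> str:
        doc = Document(file_path)
        # Paragraph text only; tables and headers need separate handling
        return "\n".join(p.text for p in doc.paragraphs)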
@@ -1,6 +1,6 @@
-from llama_index.core import VectorStoreIndex
 from typing import Any, List, Optional
 
+from llama_index.core import VectorStoreIndex
 from llama_index.core.callbacks.base import CallbackManager
 from llama_index.core.schema import TransformComponent
 from llama_index.core.service_context import ServiceContext
@@ -44,7 +44,6 @@ class GptacVectorStoreIndex(VectorStoreIndex):
         )
 
         with callback_manager.as_trace("index_construction"):
-
             return cls(
                 nodes=[],
                 storage_context=storage_context,
@@ -55,4 +54,3 @@ class GptacVectorStoreIndex(VectorStoreIndex):
                 embed_model=embed_model,
                 **kwargs,
             )
-