gpt_academic/crazy_functions/review_fns/handlers/base_handler.py
binary-husky 8042750d41 Master 4.0 (#2210)
* stage academic conversation

* stage document conversation

* fix buggy gradio version

* file dynamic load

* merge more academic plugins

* accelerate nltk

* feat: add file and URL reading support to the predict function
- Add URL detection and web page content extraction, with automatic extraction of page text
- Add file path recognition and file content reading, supporting the private_upload path format
- Integrate WebTextExtractor for web page content extraction
- Integrate TextContentLoader for local file reading
- Support intelligent handling of a file path combined with a question

* back

* block unstable

---------

Co-authored-by: XiaoBoAI <liuboyin2019@ia.ac.cn>
2025-08-23 15:59:22 +08:00

412 lines
17 KiB
Python

import asyncio
from datetime import datetime
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from crazy_functions.review_fns.query_analyzer import SearchCriteria
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource
from crazy_functions.review_fns.data_sources.pubmed_source import PubMedSource
from crazy_functions.review_fns.paper_processor.paper_llm_ranker import PaperLLMRanker
from crazy_functions.pdf_fns.breakdown_pdf_txt import cut_from_end_to_satisfy_token_limit
from request_llms.bridge_all import model_info
from crazy_functions.review_fns.data_sources.crossref_source import CrossrefSource
from crazy_functions.review_fns.data_sources.adsabs_source import AdsabsSource
from toolbox import get_conf
class BaseHandler(ABC):
"""处理器基类"""
def __init__(self, arxiv: ArxivSource, semantic: SemanticScholarSource, llm_kwargs: Dict = None):
self.arxiv = arxiv
self.semantic = semantic
self.pubmed = PubMedSource()
        self.crossref = CrossrefSource()  # Crossref instance
        self.adsabs = AdsabsSource()  # NASA ADS instance
        self.paper_ranker = PaperLLMRanker(llm_kwargs=llm_kwargs)
        self.ranked_papers = []  # Stores the ranked paper list
        self.llm_kwargs = llm_kwargs or {}  # Keep llm_kwargs for later use
def _get_search_params(self, plugin_kwargs: Dict) -> Dict:
"""获取搜索参数"""
return {
'max_papers': plugin_kwargs.get('max_papers', 100), # 最大论文数量
'min_year': plugin_kwargs.get('min_year', 2015), # 最早年份
'search_multiplier': plugin_kwargs.get('search_multiplier', 3), # 检索倍数
}
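    # Example (hypothetical values): plugin_kwargs = {'max_papers': 50, 'min_year': 2020,
    # 'search_multiplier': 2} makes each source over-fetch twice its per-source limit,
    # leaving room for later filtering and ranking.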
@abstractmethod
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> List[List[str]]:
"""处理查询"""
pass
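        # Concrete subclasses implement the full pipeline (search, rank, format)
        # here; see the hypothetical usage sketch at the end of this file.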
async def _search_arxiv(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
"""使用arXiv专用参数搜索"""
try:
original_limit = params.get("limit", 20)
params["limit"] = original_limit * limit_multiplier
papers = []
            # Try a basic keyword search first
query = params.get("query", "")
if query:
papers = await self.arxiv.search(
query,
limit=params["limit"],
sort_by=params.get("sort_by", "relevance"),
sort_order=params.get("sort_order", "descending"),
start_year=min_year
)
            # Fall back to category search when the basic search returns nothing
if not papers:
categories = params.get("categories", [])
for category in categories:
category_papers = await self.arxiv.search_by_category(
category,
limit=params["limit"],
sort_by=params.get("sort_by", "relevance"),
sort_order=params.get("sort_order", "descending"),
)
if category_papers:
papers.extend(category_papers)
return papers or []
except Exception as e:
print(f"arXiv搜索出错: {str(e)}")
return []
async def _search_semantic(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
"""使用Semantic Scholar专用参数搜索"""
try:
original_limit = params.get("limit", 20)
params["limit"] = original_limit * limit_multiplier
            # Use only the basic search parameters
papers = await self.semantic.search(
query=params.get("query", ""),
limit=params["limit"]
)
            # The API call above does not filter by year, so filter in memory
if papers and min_year:
papers = [p for p in papers if getattr(p, 'year', 0) and p.year >= min_year]
return papers or []
except Exception as e:
print(f"Semantic Scholar搜索出错: {str(e)}")
return []
async def _search_pubmed(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
"""使用PubMed专用参数搜索"""
try:
            # Return an empty list immediately when PubMed search is disabled
if params.get("search_type") == "none":
return []
original_limit = params.get("limit", 20)
params["limit"] = original_limit * limit_multiplier
papers = []
            # Dispatch on the requested search type
if params.get("search_type") == "basic":
papers = await self.pubmed.search(
query=params.get("query", ""),
limit=params["limit"],
start_year=min_year
)
elif params.get("search_type") == "author":
papers = await self.pubmed.search_by_author(
author=params.get("query", ""),
limit=params["limit"],
start_year=min_year
)
elif params.get("search_type") == "journal":
papers = await self.pubmed.search_by_journal(
journal=params.get("query", ""),
limit=params["limit"],
start_year=min_year
)
return papers or []
except Exception as e:
print(f"PubMed搜索出错: {str(e)}")
return []
async def _search_crossref(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
"""使用Crossref专用参数搜索"""
try:
original_limit = params.get("limit", 20)
params["limit"] = original_limit * limit_multiplier
papers = []
            # Dispatch on the requested search type
if params.get("search_type") == "basic":
papers = await self.crossref.search(
query=params.get("query", ""),
limit=params["limit"],
start_year=min_year
)
elif params.get("search_type") == "author":
papers = await self.crossref.search_by_authors(
authors=[params.get("query", "")],
limit=params["limit"],
start_year=min_year
)
elif params.get("search_type") == "journal":
                # TODO: journal search is not implemented yet
                pass
return papers or []
except Exception as e:
print(f"Crossref搜索出错: {str(e)}")
return []
async def _search_adsabs(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
"""使用ADS专用参数搜索"""
try:
original_limit = params.get("limit", 20)
params["limit"] = original_limit * limit_multiplier
papers = []
            # Run the search
if params.get("search_type") == "basic":
papers = await self.adsabs.search(
query=params.get("query", ""),
limit=params["limit"],
start_year=min_year
)
return papers or []
except Exception as e:
print(f"ADS搜索出错: {str(e)}")
return []
async def _search_all_sources(self, criteria: SearchCriteria, search_params: Dict) -> List:
"""从所有数据源搜索论文"""
search_tasks = []
        # # Check whether a PubMed search is requested
        # is_using_pubmed = criteria.pubmed_params.get("search_type") != "none" and criteria.pubmed_params.get("query") != "none"
        is_using_pubmed = False  # The open-source build no longer queries PubMed
        # When PubMed is used, only PubMed and Semantic Scholar are searched
if is_using_pubmed:
search_tasks.append(
self._search_pubmed(
criteria.pubmed_params,
limit_multiplier=search_params['search_multiplier'],
min_year=criteria.start_year
)
)
            # Semantic Scholar is always searched in this branch
search_tasks.append(
self._search_semantic(
criteria.semantic_params,
limit_multiplier=search_params['search_multiplier'],
min_year=criteria.start_year
)
)
else:
            # Run Crossref when its search parameters are enabled
if criteria.crossref_params.get("search_type") != "none" and criteria.crossref_params.get("query") != "none":
search_tasks.append(
self._search_crossref(
criteria.crossref_params,
limit_multiplier=search_params['search_multiplier'],
min_year=criteria.start_year
)
)
search_tasks.append(
self._search_arxiv(
criteria.arxiv_params,
limit_multiplier=search_params['search_multiplier'],
min_year=criteria.start_year
)
)
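            # Semantic Scholar is only queried when an API key is configured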
if get_conf("SEMANTIC_SCHOLAR_KEY"):
search_tasks.append(
self._search_semantic(
criteria.semantic_params,
limit_multiplier=search_params['search_multiplier'],
min_year=criteria.start_year
)
)
        # Run all scheduled search tasks concurrently; each _search_* helper
        # catches its own exceptions and returns [] on failure
papers = await asyncio.gather(*search_tasks)
        # Merge papers from all sources and count how many each source returned
all_papers = []
source_counts = {
'arxiv': 0,
'semantic': 0,
'pubmed': 0,
'crossref': 0,
'adsabs': 0
}
for source_papers in papers:
if source_papers:
for paper in source_papers:
source = getattr(paper, 'source', 'unknown')
if source in source_counts:
source_counts[source] += 1
all_papers.extend(source_papers)
        # Print the per-source paper counts
        print("\n=== Papers found per data source ===")
        for source, count in source_counts.items():
            if count > 0:  # Only print sources that returned papers
                print(f"{source.capitalize()}: {count}")
        print(f"Total: {len(all_papers)}")
        print("===========================\n")
return all_papers
def _format_paper_time(self, paper) -> str:
"""格式化论文时间信息"""
year = getattr(paper, 'year', None)
if not year:
return ""
        # Prefer the full publication date when it is available
if hasattr(paper, 'published') and paper.published:
return f"(发表于 {paper.published.strftime('%Y-%m')})"
# 如果只有年份,只显示年份
return f"({year})"
def _format_papers(self, papers: List) -> str:
"""格式化论文列表使用token限制控制长度"""
formatted = []
for i, paper in enumerate(papers, 1):
            # Keep only the first three authors
authors = paper.authors[:3]
if len(paper.authors) > 3:
authors.append("et al.")
            # Collect all available download links
download_links = []
            # Add arXiv links
if hasattr(paper, 'doi') and paper.doi:
if paper.doi.startswith("10.48550/arXiv."):
                    # Extract the full arXiv ID from the DOI
                    arxiv_id = paper.doi.split("arXiv.")[-1]
                    # Normalize the ID: collapse duplicate dots, then strip any
                    # leading or trailing dot
                    arxiv_id = arxiv_id.replace("..", ".")
                    if arxiv_id.startswith("."):
                        arxiv_id = arxiv_id[1:]
                    if arxiv_id.endswith("."):
                        arxiv_id = arxiv_id[:-1]
download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
elif "arxiv.org/abs/" in paper.doi:
                    # Extract the arXiv ID directly from the URL
arxiv_id = paper.doi.split("arxiv.org/abs/")[-1]
if "v" in arxiv_id: # 移除版本号
arxiv_id = arxiv_id.split("v")[0]
download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
else:
download_links.append(f"[DOI](https://doi.org/{paper.doi})")
            # Add the paper's direct URL when present and not already covered above
if hasattr(paper, 'url') and paper.url:
if not any(paper.url in link for link in download_links):
download_links.append(f"[Source]({paper.url})")
            # Join the download links into a single string
download_section = " | ".join(download_links) if download_links else "No direct download link available"
            # Build venue / source information
source_info = []
if hasattr(paper, 'venue_type') and paper.venue_type and paper.venue_type != 'preprint':
source_info.append(f"Type: {paper.venue_type}")
if hasattr(paper, 'venue_name') and paper.venue_name:
source_info.append(f"Venue: {paper.venue_name}")
            # Add impact factor and journal-ranking information
if hasattr(paper, 'if_factor') and paper.if_factor:
source_info.append(f"IF: {paper.if_factor}")
if hasattr(paper, 'cas_division') and paper.cas_division:
source_info.append(f"中科院分区: {paper.cas_division}")
if hasattr(paper, 'jcr_division') and paper.jcr_division:
source_info.append(f"JCR分区: {paper.jcr_division}")
if hasattr(paper, 'venue_info') and paper.venue_info:
if paper.venue_info.get('journal_ref'):
source_info.append(f"Journal Reference: {paper.venue_info['journal_ref']}")
if paper.venue_info.get('publisher'):
source_info.append(f"Publisher: {paper.venue_info['publisher']}")
            # Assemble the formatted text for this paper
paper_text = (
f"{i}. **{paper.title}**\n" +
f" Authors: {', '.join(authors)}\n" +
f" Year: {paper.year}\n" +
f" Citations: {paper.citations if paper.citations else 'N/A'}\n" +
(f" Source: {'; '.join(source_info)}\n" if source_info else "") +
                # Add PubMed-specific information
                (f" MeSH Terms: {'; '.join(paper.mesh_terms)}\n" if hasattr(paper, 'mesh_terms') and paper.mesh_terms else "") +
f" 📥 PDF Downloads: {download_section}\n" +
f" Abstract: {paper.abstract}\n"
)
formatted.append(paper_text)
full_text = "\n".join(formatted)
        # Set the token limit according to the model in use
model_name = getattr(self, 'llm_kwargs', {}).get('llm_model', 'gpt-3.5-turbo')
token_limit = model_info[model_name]['max_token'] * 3 // 4
        # Trim from the end so the text fits the token budget, leaving roughly a
        # quarter of the model's context window for the prompt and the response
        return cut_from_end_to_satisfy_token_limit(full_text, limit=token_limit, reserve_token=0, llm_model=model_name)
def _get_current_time(self) -> str:
"""获取当前时间信息"""
now = datetime.now()
return now.strftime("%Y年%m月%d")
def _generate_apology_prompt(self, criteria: SearchCriteria) -> str:
"""生成道歉提示"""
return f"""很抱歉,我们未能找到与"{criteria.main_topic}"相关的有效文献。
可能的原因:
1. 搜索词过于具体或专业
2. 时间范围限制过严
建议解决方案:
1. 尝试使用更通用的关键词
2. 扩大搜索时间范围
3. 使用同义词或相关术语
请根据以上建议调整后重试。"""
def get_ranked_papers(self) -> str:
"""获取排序后的论文列表的格式化字符串"""
return self._format_papers(self.ranked_papers) if self.ranked_papers else ""
def _is_pubmed_paper(self, paper) -> bool:
"""判断是否为PubMed论文"""
return (paper.url and 'pubmed.ncbi.nlm.nih.gov' in paper.url)
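
# --- Usage sketch (hypothetical; not part of the original file) ---
# A minimal concrete handler built on BaseHandler, assuming SearchCriteria
# carries `main_topic` and the per-source `*_params` dicts used above:
#
# class SimpleReviewHandler(BaseHandler):
#     async def handle(self, criteria, chatbot, history, system_prompt,
#                      llm_kwargs, plugin_kwargs):
#         search_params = self._get_search_params(plugin_kwargs)
#         papers = await self._search_all_sources(criteria, search_params)
#         if not papers:
#             chatbot.append([criteria.main_topic, self._generate_apology_prompt(criteria)])
#             return chatbot
#         self.ranked_papers = papers[:search_params['max_papers']]
#         chatbot.append([criteria.main_topic, self.get_ranked_papers()])
#         return chatbot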