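"""Base handler for multi-source academic literature search.

Defines `BaseHandler`, the abstract base class that fans a query out to
several literature sources (arXiv, Semantic Scholar, PubMed, Crossref,
NASA ADS), merges the results, and formats them for display.
"""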
import asyncio
import re
from datetime import datetime
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from crazy_functions.review_fns.query_analyzer import SearchCriteria
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource
from crazy_functions.review_fns.data_sources.pubmed_source import PubMedSource
from crazy_functions.review_fns.paper_processor.paper_llm_ranker import PaperLLMRanker
from crazy_functions.pdf_fns.breakdown_pdf_txt import cut_from_end_to_satisfy_token_limit
from request_llms.bridge_all import model_info
from crazy_functions.review_fns.data_sources.crossref_source import CrossrefSource
from crazy_functions.review_fns.data_sources.adsabs_source import AdsabsSource
from toolbox import get_conf


class BaseHandler(ABC):
    """Base class for query handlers."""

    def __init__(self, arxiv: ArxivSource, semantic: SemanticScholarSource, llm_kwargs: Dict = None):
        self.arxiv = arxiv
        self.semantic = semantic
        self.pubmed = PubMedSource()
        self.crossref = CrossrefSource()  # Crossref data source
        self.adsabs = AdsabsSource()  # NASA ADS data source
        self.paper_ranker = PaperLLMRanker(llm_kwargs=llm_kwargs)
        self.ranked_papers = []  # papers after ranking
        self.llm_kwargs = llm_kwargs or {}  # kept for the token-limit lookup in _format_papers

    def _get_search_params(self, plugin_kwargs: Dict) -> Dict:
        """Read search parameters from plugin kwargs."""
        return {
            'max_papers': plugin_kwargs.get('max_papers', 100),  # maximum number of papers to keep
            'min_year': plugin_kwargs.get('min_year', 2015),  # earliest publication year
            'search_multiplier': plugin_kwargs.get('search_multiplier', 3),  # over-fetch factor per source
        }

    @abstractmethod
    async def handle(
        self,
        criteria: SearchCriteria,
        chatbot: List[List[str]],
        history: List[List[str]],
        system_prompt: str,
        llm_kwargs: Dict[str, Any],
        plugin_kwargs: Dict[str, Any],
    ) -> List[List[str]]:
        """Handle a query."""
        pass

    async def _search_arxiv(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
        """Search arXiv with its source-specific parameters."""
        try:
            original_limit = params.get("limit", 20)
            params["limit"] = original_limit * limit_multiplier
            papers = []

            # Try a basic keyword search first
            query = params.get("query", "")
            if query:
                papers = await self.arxiv.search(
                    query,
                    limit=params["limit"],
                    sort_by=params.get("sort_by", "relevance"),
                    sort_order=params.get("sort_order", "descending"),
                    start_year=min_year
                )

            # If the basic search returned nothing, fall back to category search
            if not papers:
                categories = params.get("categories", [])
                for category in categories:
                    category_papers = await self.arxiv.search_by_category(
                        category,
                        limit=params["limit"],
                        sort_by=params.get("sort_by", "relevance"),
                        sort_order=params.get("sort_order", "descending"),
                    )
                    if category_papers:
                        papers.extend(category_papers)

            return papers or []

        except Exception as e:
            print(f"arXiv search failed: {str(e)}")
            return []

    async def _search_semantic(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
        """Search Semantic Scholar with its source-specific parameters."""
        try:
            original_limit = params.get("limit", 20)
            params["limit"] = original_limit * limit_multiplier

            # Use only the basic search parameters
            papers = await self.semantic.search(
                query=params.get("query", ""),
                limit=params["limit"]
            )

            # Filter by year in memory
            if papers and min_year:
                papers = [p for p in papers if getattr(p, 'year', 0) and p.year >= min_year]

            return papers or []

        except Exception as e:
            print(f"Semantic Scholar search failed: {str(e)}")
            return []

    async def _search_pubmed(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
        """Search PubMed with its source-specific parameters."""
        try:
            # Skip PubMed entirely when no search is requested
            if params.get("search_type") == "none":
                return []

            original_limit = params.get("limit", 20)
            params["limit"] = original_limit * limit_multiplier
            papers = []

            # Pick the search method according to the search type
            if params.get("search_type") == "basic":
                papers = await self.pubmed.search(
                    query=params.get("query", ""),
                    limit=params["limit"],
                    start_year=min_year
                )
            elif params.get("search_type") == "author":
                papers = await self.pubmed.search_by_author(
                    author=params.get("query", ""),
                    limit=params["limit"],
                    start_year=min_year
                )
            elif params.get("search_type") == "journal":
                papers = await self.pubmed.search_by_journal(
                    journal=params.get("query", ""),
                    limit=params["limit"],
                    start_year=min_year
                )

            return papers or []

        except Exception as e:
            print(f"PubMed search failed: {str(e)}")
            return []

    async def _search_crossref(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
        """Search Crossref with its source-specific parameters."""
        try:
            original_limit = params.get("limit", 20)
            params["limit"] = original_limit * limit_multiplier
            papers = []

            # Pick the search method according to the search type
            if params.get("search_type") == "basic":
                papers = await self.crossref.search(
                    query=params.get("query", ""),
                    limit=params["limit"],
                    start_year=min_year
                )
            elif params.get("search_type") == "author":
                papers = await self.crossref.search_by_authors(
                    authors=[params.get("query", "")],
                    limit=params["limit"],
                    start_year=min_year
                )
            elif params.get("search_type") == "journal":
                # Journal search is not implemented yet
                pass

            return papers or []

        except Exception as e:
            print(f"Crossref search failed: {str(e)}")
            return []

    async def _search_adsabs(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
        """Search NASA ADS with its source-specific parameters."""
        try:
            original_limit = params.get("limit", 20)
            params["limit"] = original_limit * limit_multiplier
            papers = []

            # Run the search
            if params.get("search_type") == "basic":
                papers = await self.adsabs.search(
                    query=params.get("query", ""),
                    limit=params["limit"],
                    start_year=min_year
                )

            return papers or []

        except Exception as e:
            print(f"ADS search failed: {str(e)}")
            return []

    async def _search_all_sources(self, criteria: SearchCriteria, search_params: Dict) -> List:
        """Search all data sources and merge the results."""
        search_tasks = []

        # # Check whether a PubMed search is requested
        # is_using_pubmed = criteria.pubmed_params.get("search_type") != "none" and criteria.pubmed_params.get("query") != "none"
        is_using_pubmed = False  # the open-source build no longer searches PubMed

        # When PubMed is used, run only the PubMed and Semantic Scholar searches
        if is_using_pubmed:
            search_tasks.append(
                self._search_pubmed(
                    criteria.pubmed_params,
                    limit_multiplier=search_params['search_multiplier'],
                    min_year=criteria.start_year
                )
            )

            # Semantic Scholar is always searched
            search_tasks.append(
                self._search_semantic(
                    criteria.semantic_params,
                    limit_multiplier=search_params['search_multiplier'],
                    min_year=criteria.start_year
                )
            )
        else:
            # Run a Crossref search when its parameters are usable
            if criteria.crossref_params.get("search_type") != "none" and criteria.crossref_params.get("query") != "none":
                search_tasks.append(
                    self._search_crossref(
                        criteria.crossref_params,
                        limit_multiplier=search_params['search_multiplier'],
                        min_year=criteria.start_year
                    )
                )

            search_tasks.append(
                self._search_arxiv(
                    criteria.arxiv_params,
                    limit_multiplier=search_params['search_multiplier'],
                    min_year=criteria.start_year
                )
            )
            if get_conf("SEMANTIC_SCHOLAR_KEY"):
                search_tasks.append(
                    self._search_semantic(
                        criteria.semantic_params,
                        limit_multiplier=search_params['search_multiplier'],
                        min_year=criteria.start_year
                    )
                )

        # Run all scheduled search tasks concurrently
        papers = await asyncio.gather(*search_tasks)

        # Merge papers from all sources and count hits per source
        all_papers = []
        source_counts = {
            'arxiv': 0,
            'semantic': 0,
            'pubmed': 0,
            'crossref': 0,
            'adsabs': 0
        }

        for source_papers in papers:
            if source_papers:
                for paper in source_papers:
                    source = getattr(paper, 'source', 'unknown')
                    if source in source_counts:
                        source_counts[source] += 1
                all_papers.extend(source_papers)

        # Report how many papers each source returned
        print("\n=== Papers found per data source ===")
        for source, count in source_counts.items():
            if count > 0:  # only report sources that returned papers
                print(f"{source.capitalize()}: {count}")
        print(f"Total: {len(all_papers)}")
        print("===========================\n")

        return all_papers

    def _format_paper_time(self, paper) -> str:
        """Format a paper's publication time."""
        year = getattr(paper, 'year', None)
        if not year:
            return ""

        # Prefer the full publication date when it is available
        if hasattr(paper, 'published') and paper.published:
            return f"(published {paper.published.strftime('%Y-%m')})"
        # Otherwise fall back to the year alone
        return f"({year})"

    def _format_papers(self, papers: List) -> str:
        """Format the paper list, truncating to fit the model's token limit."""
        formatted = []

        for i, paper in enumerate(papers, 1):
            # Keep only the first three authors
            authors = paper.authors[:3]
            if len(paper.authors) > 3:
                authors.append("et al.")

            # Collect every available download link
            download_links = []

            # arXiv links
            if hasattr(paper, 'doi') and paper.doi:
                if paper.doi.startswith("10.48550/arXiv."):
                    # Extract the full arXiv ID from the DOI
                    arxiv_id = paper.doi.split("arXiv.")[-1]
                    # Normalize stray dots so the ID is well formed
                    arxiv_id = arxiv_id.replace("..", ".").strip(".")

                    download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
                    download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
                elif "arxiv.org/abs/" in paper.doi:
                    # Extract the arXiv ID straight from the URL
                    arxiv_id = paper.doi.split("arxiv.org/abs/")[-1]
                    # Strip a trailing version suffix such as "v2"
                    arxiv_id = re.sub(r"v\d+$", "", arxiv_id)

                    download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
                    download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
                else:
                    download_links.append(f"[DOI](https://doi.org/{paper.doi})")

            # Direct URL, if present and not already covered by a link above
            if hasattr(paper, 'url') and paper.url:
                if not any(paper.url in link for link in download_links):
                    download_links.append(f"[Source]({paper.url})")

            # Build the download-link string
            download_section = " | ".join(download_links) if download_links else "No direct download link available"

            # Venue / source information
            source_info = []
            if hasattr(paper, 'venue_type') and paper.venue_type and paper.venue_type != 'preprint':
                source_info.append(f"Type: {paper.venue_type}")
            if hasattr(paper, 'venue_name') and paper.venue_name:
                source_info.append(f"Venue: {paper.venue_name}")

            # Impact factor and journal-ranking information
            if hasattr(paper, 'if_factor') and paper.if_factor:
                source_info.append(f"IF: {paper.if_factor}")
            if hasattr(paper, 'cas_division') and paper.cas_division:
                source_info.append(f"CAS division: {paper.cas_division}")
            if hasattr(paper, 'jcr_division') and paper.jcr_division:
                source_info.append(f"JCR division: {paper.jcr_division}")

            if hasattr(paper, 'venue_info') and paper.venue_info:
                if paper.venue_info.get('journal_ref'):
                    source_info.append(f"Journal Reference: {paper.venue_info['journal_ref']}")
                if paper.venue_info.get('publisher'):
                    source_info.append(f"Publisher: {paper.venue_info['publisher']}")

            # Assemble the formatted entry for this paper
            paper_text = (
                f"{i}. **{paper.title}**\n" +
                f"   Authors: {', '.join(authors)}\n" +
                f"   Year: {paper.year}\n" +
                f"   Citations: {paper.citations if paper.citations else 'N/A'}\n" +
                (f"   Source: {'; '.join(source_info)}\n" if source_info else "") +
                # PubMed-specific information
                (f"   MeSH Terms: {'; '.join(paper.mesh_terms)}\n" if hasattr(paper, 'mesh_terms') and paper.mesh_terms else "") +
                f"   📥 PDF Downloads: {download_section}\n" +
                f"   Abstract: {paper.abstract}\n"
            )

            formatted.append(paper_text)

        full_text = "\n".join(formatted)

        # Pick a token limit based on the model in use
        model_name = getattr(self, 'llm_kwargs', {}).get('llm_model', 'gpt-3.5-turbo')

        token_limit = model_info[model_name]['max_token'] * 3 // 4
        # Cut from the end until the text fits within the token limit
        return cut_from_end_to_satisfy_token_limit(full_text, limit=token_limit, reserve_token=0, llm_model=model_name)

    def _get_current_time(self) -> str:
        """Return the current date as a string."""
        now = datetime.now()
        return now.strftime("%Y-%m-%d")

    def _generate_apology_prompt(self, criteria: SearchCriteria) -> str:
        """Generate an apology message for empty search results."""
        return f"""Sorry, we could not find any relevant literature for "{criteria.main_topic}".

Possible reasons:
1. The search terms are too specific or too technical
2. The time range is too restrictive

Suggested fixes:
1. Try more general keywords
2. Widen the time range
3. Use synonyms or related terms

Please adjust your query accordingly and try again."""

    def get_ranked_papers(self) -> str:
        """Return the ranked paper list as a formatted string."""
        return self._format_papers(self.ranked_papers) if self.ranked_papers else ""

    def _is_pubmed_paper(self, paper) -> bool:
        """Check whether a paper comes from PubMed."""
        return bool(paper.url and 'pubmed.ncbi.nlm.nih.gov' in paper.url)
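

# Usage sketch (illustrative only). `DemoHandler` below is a hypothetical
# minimal subclass showing how BaseHandler is meant to be driven; it is not
# part of this module's API, and the commented-out calls assume no-argument
# source constructors and a SearchCriteria built elsewhere (see
# crazy_functions.review_fns.query_analyzer for the real signature).
if __name__ == "__main__":

    class DemoHandler(BaseHandler):
        """Hypothetical handler: search every source, keep the top results."""

        async def handle(self, criteria, chatbot, history, system_prompt,
                         llm_kwargs, plugin_kwargs):
            search_params = self._get_search_params(plugin_kwargs)
            papers = await self._search_all_sources(criteria, search_params)
            self.ranked_papers = papers[:search_params['max_papers']]
            chatbot.append(["search", self.get_ranked_papers()])
            return chatbot

    # handler = DemoHandler(ArxivSource(), SemanticScholarSource(), llm_kwargs={})
    # criteria = ...  # build a SearchCriteria with the query analyzer first
    # chatbot = asyncio.run(handler.handle(criteria, [], [], "", {}, {}))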