* stage academic conversation * stage document conversation * fix buggy gradio version * file dynamic load * merge more academic plugins * accelerate nltk * feat: 为predict函数添加文件和URL读取功能 - 添加URL检测和网页内容提取功能,支持自动提取网页文本 - 添加文件路径识别和文件内容读取功能,支持private_upload路径格式 - 集成WebTextExtractor处理网页内容提取 - 集成TextContentLoader处理本地文件读取 - 支持文件路径与问题组合的智能处理 * back * block unstable --------- Co-authored-by: XiaoBoAI <liuboyin2019@ia.ac.cn>
480 lines
19 KiB
Python
480 lines
19 KiB
Python
from typing import List, Optional
|
||
from datetime import datetime
|
||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||
import random
|
||
|
||
class SemanticScholarSource(DataSource):
    """Semantic Scholar data source built on the ``semanticscholar`` package.

    Paper look-ups and recommendations go through the public methods of
    ``AsyncSemanticScholar``; search, citation, reference, author and
    autocomplete requests are sent straight to the Graph API through the
    client's requester so that a single raw JSON page is returned.

    Every public coroutine is best-effort: on any exception it prints a
    message and returns an empty result (``[]`` or ``None``) instead of
    raising.
    """

    # Upper bounds applied to ``limit`` before a request is sent.
    _MAX_SEARCH_LIMIT = 100
    _MAX_LIST_LIMIT = 1000
    _MAX_RECOMMENDATION_LIMIT = 500

    def __init__(self, api_key: Optional[str] = None):
        """Store the API key and prepare the (lazily created) client.

        Args:
            api_key: Semantic Scholar API key (optional).
        """
        self.api_key = api_key
        self._initialize()

    def _initialize(self) -> None:
        """Pick an API key if none was given and set the requested fields."""
        if not self.api_key:
            # NOTE(review): these are redacted placeholders, not usable keys;
            # replace them or always pass ``api_key`` explicitly.
            default_api_keys = [
                "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
                "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
            ]
            self.api_key = random.choice(default_api_keys)

        # The client is created on first use, see ``_ensure_client``.
        self.client = None
        # Paper fields requested from the API on every paper query.
        self.fields = [
            "title",
            "authors",
            "abstract",
            "year",
            "externalIds",
            "citationCount",
            "venue",
            "openAccessPdf",
            "publicationVenue",
        ]

    async def _ensure_client(self):
        """Create the ``AsyncSemanticScholar`` client on first use."""
        if self.client is None:
            # Imported lazily so the module can be loaded without the package.
            from semanticscholar import AsyncSemanticScholar

            self.client = AsyncSemanticScholar(api_key=self.api_key)

    async def _graph_get(self, path: str, params: dict):
        """Send a GET request to a Graph API path and return the decoded JSON.

        Args:
            path: Path below the Graph API base, starting with ``/``.
            params: Query parameters; entries whose value is ``None`` are
                skipped. Values are percent-encoded so that user text
                containing spaces, ``&`` or ``=`` cannot break the query
                string.
        """
        from urllib.parse import quote, urlencode

        query_string = urlencode(
            {key: value for key, value in params.items() if value is not None},
            safe=",",  # keep comma-separated field lists readable
            quote_via=quote,
        )
        # ``_requester`` is a private attribute of the client; it is used
        # because the public helpers wrap results in paginated objects.
        return await self.client._requester.get_data_async(
            f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}{path}",
            query_string,
            self.client.auth_header,
        )

    async def _fetch_linked_papers(
        self,
        doi: str,
        relation: str,
        paper_key: str,
        limit: int,
        start_year: Optional[int],
    ) -> List[PaperMetadata]:
        """Fetch the citations or references of a paper.

        Args:
            doi: DOI of the paper whose links are requested.
            relation: Endpoint name, ``"citations"`` or ``"references"``.
            paper_key: Key holding the linked paper in each returned entry.
            limit: Maximum number of entries to request (clamped to 1000).
            start_year: When set, only papers with a known publication year
                greater than or equal to it are kept. The filter is applied
                after the request, so fewer than ``limit`` papers may remain.
        """
        await self._ensure_client()
        response = await self._graph_get(
            f"/paper/DOI:{doi}/{relation}",
            {
                "fields": ",".join(self.fields),
                "limit": min(limit, self._MAX_LIST_LIMIT),
            },
        )
        linked = [entry.get(paper_key) or {} for entry in response.get('data') or []]
        if start_year:
            linked = [
                paper for paper in linked
                if (paper.get('year') or 0) >= start_year
            ]
        return [self._parse_paper_data(paper) for paper in linked]

    async def search(
        self,
        query: str,
        limit: int = 100,
        start_year: Optional[int] = None
    ) -> List[PaperMetadata]:
        """Search papers by free-text query.

        Args:
            query: Search text.
            limit: Maximum number of results; values above 100 are clamped.
            start_year: When set, restrict results to papers published in
                that year or later (sent as the ``year`` filter ``"<year>-"``).

        Returns:
            Parsed papers, or an empty list if the request fails.
        """
        try:
            await self._ensure_client()

            params = {
                "query": query,
                "limit": min(limit, self._MAX_SEARCH_LIMIT),
                "fields": ",".join(self.fields),
            }
            if start_year:
                # Open-ended year range: "2020-" means 2020 and later.
                params["year"] = f"{start_year}-"

            response = await self._graph_get("/paper/search", params)
            papers = response.get('data') or []
            return [self._parse_paper_data(paper) for paper in papers]
        except Exception as e:
            print(f"搜索论文时发生错误: {str(e)}")
            import traceback
            print(traceback.format_exc())
            return []

    async def get_paper_details(self, doi: str) -> Optional[PaperMetadata]:
        """Return the metadata of the paper with the given DOI.

        Returns:
            The parsed paper, or ``None`` if the request fails.
        """
        try:
            await self._ensure_client()
            paper = await self.client.get_paper(f"DOI:{doi}", fields=self.fields)
            return self._parse_paper_data(paper)
        except Exception as e:
            print(f"获取论文详情时发生错误: {str(e)}")
            return None

    async def get_citations(
        self,
        doi: str,
        limit: int = 100,
        start_year: Optional[int] = None
    ) -> List[PaperMetadata]:
        """Return the papers that cite the paper with the given DOI.

        Args:
            doi: DOI of the cited paper.
            limit: Maximum number of entries to request (clamped to 1000).
            start_year: When set, keep only citing papers whose known
                publication year is greater than or equal to it.

        Returns:
            Parsed citing papers, or an empty list if the request fails.
        """
        try:
            return await self._fetch_linked_papers(
                doi, "citations", "citingPaper", limit, start_year
            )
        except Exception as e:
            print(f"获取引用列表时发生错误: {str(e)}")
            return []

    async def get_references(
        self,
        doi: str,
        limit: int = 100,
        start_year: Optional[int] = None
    ) -> List[PaperMetadata]:
        """Return the papers referenced by the paper with the given DOI.

        Args:
            doi: DOI of the citing paper.
            limit: Maximum number of entries to request (clamped to 1000).
            start_year: When set, keep only referenced papers whose known
                publication year is greater than or equal to it.

        Returns:
            Parsed referenced papers, or an empty list if the request fails.
        """
        try:
            return await self._fetch_linked_papers(
                doi, "references", "citedPaper", limit, start_year
            )
        except Exception as e:
            print(f"获取参考文献列表时发生错误: {str(e)}")
            return []

    async def get_recommended_papers(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
        """Return papers recommended on the basis of a single paper.

        Args:
            doi: DOI of the seed paper.
            limit: Maximum number of results; values above 500 are clamped.

        Returns:
            Parsed recommended papers, or an empty list if the request fails.
        """
        try:
            await self._ensure_client()
            papers = await self.client.get_recommended_papers(
                f"DOI:{doi}",
                fields=self.fields,
                limit=min(limit, self._MAX_RECOMMENDATION_LIMIT)
            )
            return [self._parse_paper_data(paper) for paper in papers]
        except Exception as e:
            print(f"获取论文推荐时发生错误: {str(e)}")
            return []

    async def get_recommended_papers_from_lists(
        self,
        positive_dois: List[str],
        negative_dois: Optional[List[str]] = None,
        limit: int = 100
    ) -> List[PaperMetadata]:
        """Return recommendations from positive and negative example papers.

        Args:
            positive_dois: DOIs of papers similar to what is wanted.
            negative_dois: DOIs of papers similar to what is not wanted.
            limit: Maximum number of results; values above 500 are clamped.

        Returns:
            Parsed recommended papers, or an empty list if the request fails.
        """
        try:
            await self._ensure_client()
            positive_ids = [f"DOI:{doi}" for doi in positive_dois]
            negative_ids = [f"DOI:{doi}" for doi in negative_dois] if negative_dois else None

            papers = await self.client.get_recommended_papers_from_lists(
                positive_paper_ids=positive_ids,
                negative_paper_ids=negative_ids,
                fields=self.fields,
                limit=min(limit, self._MAX_RECOMMENDATION_LIMIT)
            )
            return [self._parse_paper_data(paper) for paper in papers]
        except Exception as e:
            print(f"获取论文推荐列表时发生错误: {str(e)}")
            return []

    async def search_author(self, query: str, limit: int = 100) -> List[dict]:
        """Search authors by name.

        Args:
            query: Author name to search for.
            limit: Maximum number of results; values above 1000 are clamped.

        Returns:
            One dict per author with the keys ``author_id``, ``name``,
            ``paper_count`` and ``citation_count``; an empty list if the
            request fails.
        """
        try:
            await self._ensure_client()
            response = await self._graph_get(
                "/author/search",
                {
                    "query": query,
                    "fields": "name,paperCount,citationCount",
                    "limit": min(limit, self._MAX_LIST_LIMIT),
                },
            )
            return [
                {
                    'author_id': author.get('authorId'),
                    'name': author.get('name'),
                    'paper_count': author.get('paperCount'),
                    'citation_count': author.get('citationCount'),
                }
                for author in response.get('data') or []
            ]
        except Exception as e:
            print(f"搜索作者时发生错误: {str(e)}")
            return []

    async def get_author_details(self, author_id: str) -> Optional[dict]:
        """Return summary statistics for one author.

        Returns:
            A dict with the keys ``author_id``, ``name``, ``paper_count``,
            ``citation_count`` and ``h_index``, or ``None`` if the request
            fails.
        """
        try:
            await self._ensure_client()
            response = await self._graph_get(
                f"/author/{author_id}",
                {"fields": "name,paperCount,citationCount,hIndex"},
            )
            return {
                'author_id': response.get('authorId'),
                'name': response.get('name'),
                'paper_count': response.get('paperCount'),
                'citation_count': response.get('citationCount'),
                'h_index': response.get('hIndex'),
            }
        except Exception as e:
            print(f"获取作者详情时发生错误: {str(e)}")
            return None

    async def get_author_papers(self, author_id: str, limit: int = 100) -> List[PaperMetadata]:
        """Return the papers of one author.

        Args:
            author_id: Semantic Scholar author identifier.
            limit: Maximum number of results; values above 1000 are clamped.

        Returns:
            Parsed papers, or an empty list if the request fails.
        """
        try:
            await self._ensure_client()
            response = await self._graph_get(
                f"/author/{author_id}/papers",
                {
                    "fields": ",".join(self.fields),
                    "limit": min(limit, self._MAX_LIST_LIMIT),
                },
            )
            papers = response.get('data') or []
            return [self._parse_paper_data(paper) for paper in papers]
        except Exception as e:
            print(f"获取作者论文列表时发生错误: {str(e)}")
            return []

    async def get_paper_autocomplete(self, query: str) -> List[dict]:
        """Return title completion suggestions for a partial paper title.

        Returns:
            One dict per suggestion with the keys ``title``, ``paper_id``,
            ``year`` and ``venue`` (a value is ``None`` when the response
            entry lacks it); an empty list if the request fails.
        """
        try:
            await self._ensure_client()
            response = await self._graph_get("/paper/autocomplete", {"query": query})
            return [
                {
                    'title': suggestion.get('title'),
                    # NOTE(review): the autocomplete endpoint appears to name
                    # the identifier ``id`` rather than ``paperId``; accept
                    # either - confirm against the API reference.
                    'paper_id': suggestion.get('paperId') or suggestion.get('id'),
                    'year': suggestion.get('year'),
                    'venue': suggestion.get('venue'),
                }
                for suggestion in response.get('matches') or []
            ]
        except Exception as e:
            print(f"获取标题自动补全时发生错误: {str(e)}")
            return []

    @staticmethod
    def _read(record, name, default=None):
        """Read ``name`` from a dict key or an object attribute.

        A missing or ``None`` value yields ``default``.
        """
        if isinstance(record, dict):
            value = record.get(name)
        else:
            value = getattr(record, name, None)
        return default if value is None else value

    def _parse_paper_data(self, paper) -> PaperMetadata:
        """Convert one API paper record into ``PaperMetadata``.

        Args:
            paper: Either a raw JSON dict or an object exposing the same
                fields as attributes.

        Returns:
            The populated metadata. ``doi`` falls back to the arXiv DOI form
            when only an arXiv identifier is present, and ``url`` prefers the
            open-access PDF over the ``https://doi.org/`` link.
        """
        read = self._read

        # DOI, with the arXiv DOI as fallback.
        doi = None
        external_ids = read(paper, 'externalIds')
        if external_ids:
            doi = read(external_ids, 'DOI')
            if not doi:
                arxiv_id = read(external_ids, 'ArXiv')
                # Only build the arXiv DOI from a real identifier, never from
                # a null entry.
                if arxiv_id:
                    doi = f"10.48550/arXiv.{arxiv_id}"

        # Open-access PDF link.
        pdf_url = None
        pdf_info = read(paper, 'openAccessPdf')
        if pdf_info:
            pdf_url = read(pdf_info, 'url')

        # Structured publication venue.
        venue_type = None
        venue_name = None
        venue_info = {}
        venue = read(paper, 'publicationVenue')
        if venue:
            venue_name = read(venue, 'name')
            venue_type = read(venue, 'type')
            venue_info = {
                'issn': read(venue, 'issn'),
                'publisher': read(venue, 'publisher'),
                'url': read(venue, 'url'),
                'alternate_names': read(venue, 'alternate_names', []),
            }

        # Author names; null names become empty strings so that callers can
        # safely join the list.
        author_names = []
        for author in read(paper, 'authors', []):
            if isinstance(author, dict):
                author_names.append(author.get('name') or '')
            elif hasattr(author, 'name'):
                author_names.append(author.name or '')
            else:
                author_names.append(str(author))

        # The plain ``venue`` string is used when no structured venue exists.
        plain_venue = read(paper, 'venue')
        if not isinstance(plain_venue, str):
            plain_venue = None

        return PaperMetadata(
            title=read(paper, 'title', ''),
            authors=author_names,
            abstract=read(paper, 'abstract', ''),
            year=read(paper, 'year'),
            doi=doi,
            url=pdf_url or (f"https://doi.org/{doi}" if doi else None),
            citations=read(paper, 'citationCount'),
            venue=venue_name or plain_venue or None,
            institutions=[],
            venue_type=venue_type,
            venue_name=venue_name,
            venue_info=venue_info,
            source='semantic'  # marks which data source produced the record
        )
|
||
|
||
async def example_usage():
    """Walk through every public query of ``SemanticScholarSource``.

    Runs eight demo requests in sequence (paper by DOI, paper search,
    recommendations from one and from several papers, author search, author
    details, author papers, title autocomplete) and prints the results. Any
    exception is printed together with its traceback instead of propagating.
    """
    source = SemanticScholarSource()

    def show_abstract(item):
        # The abstract section is printed only when an abstract is present.
        if item.abstract:
            print("\n摘要:")
            print(item.abstract)

    def show_title_and_authors(item):
        print(f"标题: {item.title}")
        print(f"作者: {', '.join(item.authors)}")

    try:
        # Example 1: fetch one paper directly by DOI.
        print("\n=== 示例1:通过DOI获取论文 ===")
        doi = "10.18653/v1/N19-1423"  # the BERT paper
        print(f"获取DOI为 {doi} 的论文信息...")

        detail = await source.get_paper_details(doi)
        if detail:
            print("\n--- 论文信息 ---")
            show_title_and_authors(detail)
            print(f"发表年份: {detail.year}")
            print(f"DOI: {detail.doi}")
            print(f"URL: {detail.url}")
            show_abstract(detail)
            print(f"\n引用次数: {detail.citations}")
            print(f"发表venue: {detail.venue}")

        # Example 2: free-text paper search.
        print("\n=== 示例2:搜索论文 ===")
        query = "BERT pre-training"
        print(f"搜索关键词 '{query}' 相关的论文...")
        hits = await source.search(query=query, limit=3)

        for rank, hit in enumerate(hits, start=1):
            print(f"\n--- 搜索结果 {rank} ---")
            show_title_and_authors(hit)
            print(f"发表年份: {hit.year}")
            show_abstract(hit)
            print(f"\nDOI: {hit.doi}")
            print(f"引用次数: {hit.citations}")

        # Example 3: recommendations seeded by a single paper.
        print("\n=== 示例3:获取论文推荐 ===")
        print(f"获取与论文 {doi} 相关的推荐论文...")
        single_seed = await source.get_recommended_papers(doi, limit=3)
        for rank, rec in enumerate(single_seed, start=1):
            print(f"\n--- 推荐论文 {rank} ---")
            show_title_and_authors(rec)
            print(f"发表年份: {rec.year}")

        # Example 4: recommendations seeded by several papers.
        print("\n=== 示例4:基于多篇论文的推荐 ===")
        positive_dois = ["10.18653/v1/N19-1423", "10.18653/v1/P19-1285"]
        print(f"基于 {len(positive_dois)} 篇论文获取推荐...")
        multi_seed = await source.get_recommended_papers_from_lists(
            positive_dois=positive_dois,
            limit=3,
        )
        for rank, rec in enumerate(multi_seed, start=1):
            print(f"\n--- 推荐论文 {rank} ---")
            show_title_and_authors(rec)

        # Example 5: author search.
        print("\n=== 示例5:搜索作者 ===")
        author_query = "Yann LeCun"
        print(f"搜索作者: '{author_query}'")
        found_authors = await source.search_author(author_query, limit=3)
        for rank, person in enumerate(found_authors, start=1):
            print(f"\n--- 作者 {rank} ---")
            print(f"姓名: {person['name']}")
            print(f"论文数量: {person['paper_count']}")
            print(f"总引用次数: {person['citation_count']}")

        # Example 6: details of the first author found above.
        print("\n=== 示例6:获取作者详情 ===")
        if found_authors:
            author_id = found_authors[0]['author_id']
            print(f"获取作者ID {author_id} 的详细信息...")
            profile = await source.get_author_details(author_id)
            if profile:
                print(f"姓名: {profile['name']}")
                print(f"H指数: {profile['h_index']}")
                print(f"总引用次数: {profile['citation_count']}")
                print(f"发表论文数: {profile['paper_count']}")

        # Example 7: papers of the first author found above.
        print("\n=== 示例7:获取作者论文 ===")
        if found_authors:
            first_author = found_authors[0]
            author_id = first_author['author_id']
            print(f"获取作者 {first_author['name']} 的论文列表...")
            written = await source.get_author_papers(author_id, limit=3)
            for rank, work in enumerate(written, start=1):
                print(f"\n--- 论文 {rank} ---")
                print(f"标题: {work.title}")
                print(f"发表年份: {work.year}")
                print(f"引用次数: {work.citations}")

        # Example 8: title autocomplete; only the first three are shown.
        print("\n=== 示例8:论文标题自动补全 ===")
        title_query = "Attention is all"
        print(f"搜索标题: '{title_query}'")
        completions = await source.get_paper_autocomplete(title_query)
        for rank, match in enumerate(completions[:3], start=1):
            print(f"\n--- 建议 {rank} ---")
            print(f"标题: {match['title']}")
            print(f"发表年份: {match['year']}")
            print(f"发表venue: {match['venue']}")

    except Exception as err:
        print(f"发生错误: {err!s}")
        import traceback
        print(traceback.format_exc())
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the demo coroutine to completion.
    import asyncio

    demo = example_usage()
    asyncio.run(demo)