Files
gpt_academic/crazy_functions/review_fns/data_sources/semantic_source.py
binary-husky 8042750d41 Master 4.0 (#2210)
* stage academic conversation

* stage document conversation

* fix buggy gradio version

* file dynamic load

* merge more academic plugins

* accelerate nltk

* feat: 为predict函数添加文件和URL读取功能
- 添加URL检测和网页内容提取功能,支持自动提取网页文本
- 添加文件路径识别和文件内容读取功能,支持private_upload路径格式
- 集成WebTextExtractor处理网页内容提取
- 集成TextContentLoader处理本地文件读取
- 支持文件路径与问题组合的智能处理

* back

* block unstable

---------

Co-authored-by: XiaoBoAI <liuboyin2019@ia.ac.cn>
2025-08-23 15:59:22 +08:00

480 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List, Optional
from datetime import datetime
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import random
class SemanticScholarSource(DataSource):
"""Semantic Scholar API实现,使用官方Python包"""
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: Semantic Scholar API密钥(可选)
"""
self.api_key = api_key
self._initialize() # 调用初始化方法
def _initialize(self) -> None:
"""初始化API客户端"""
if not self.api_key:
# 默认API密钥列表
default_api_keys = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
self.api_key = random.choice(default_api_keys)
self.client = None # 延迟初始化
self.fields = [
"title",
"authors",
"abstract",
"year",
"externalIds",
"citationCount",
"venue",
"openAccessPdf",
"publicationVenue"
]
async def _ensure_client(self):
"""确保客户端已初始化"""
if self.client is None:
from semanticscholar import AsyncSemanticScholar
self.client = AsyncSemanticScholar(api_key=self.api_key)
async def search(
self,
query: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文"""
try:
await self._ensure_client()
# 如果指定了起始年份,添加到查询中
if start_year:
query = f"{query} year>={start_year}"
# 直接使用 search_paper 的结果
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/search",
f"query={query}&limit={min(limit, 100)}&fields={','.join(self.fields)}",
self.client.auth_header
)
papers = response.get('data', [])
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
return []
async def get_paper_details(self, doi: str) -> Optional[PaperMetadata]:
"""获取指定DOI的论文详情"""
try:
await self._ensure_client()
paper = await self.client.get_paper(f"DOI:{doi}", fields=self.fields)
return self._parse_paper_data(paper)
except Exception as e:
print(f"获取论文详情时发生错误: {str(e)}")
return None
async def get_citations(
self,
doi: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""获取引用指定DOI论文的文献列表"""
try:
await self._ensure_client()
# 构建查询参数
fields_param = f"fields={','.join(self.fields)}"
limit_param = f"limit={limit}"
year_param = f"year>={start_year}" if start_year else ""
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/citations",
params,
self.client.auth_header
)
citations = response.get('data', [])
return [self._parse_paper_data(citation.get('citingPaper', {})) for citation in citations]
except Exception as e:
print(f"获取引用列表时发生错误: {str(e)}")
return []
async def get_references(
self,
doi: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""获取指定DOI论文的参考文献列表"""
try:
await self._ensure_client()
# 构建查询参数
fields_param = f"fields={','.join(self.fields)}"
limit_param = f"limit={limit}"
year_param = f"year>={start_year}" if start_year else ""
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/references",
params,
self.client.auth_header
)
references = response.get('data', [])
return [self._parse_paper_data(reference.get('citedPaper', {})) for reference in references]
except Exception as e:
print(f"获取参考文献列表时发生错误: {str(e)}")
return []
async def get_recommended_papers(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
"""获取论文推荐
根据一篇论文获取相关的推荐论文
Args:
doi: 论文的DOI
limit: 返回结果数量限制最大500
Returns:
推荐论文列表
"""
try:
await self._ensure_client()
papers = await self.client.get_recommended_papers(
f"DOI:{doi}",
fields=self.fields,
limit=min(limit, 500)
)
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"获取论文推荐时发生错误: {str(e)}")
return []
async def get_recommended_papers_from_lists(
self,
positive_dois: List[str],
negative_dois: List[str] = None,
limit: int = 100
) -> List[PaperMetadata]:
"""基于正负例论文列表获取推荐
Args:
positive_dois: 正例论文DOI列表想要获取类似的论文
negative_dois: 负例论文DOI列表不想要类似的论文
limit: 返回结果数量限制最大500
Returns:
推荐论文列表
"""
try:
await self._ensure_client()
positive_ids = [f"DOI:{doi}" for doi in positive_dois]
negative_ids = [f"DOI:{doi}" for doi in negative_dois] if negative_dois else None
papers = await self.client.get_recommended_papers_from_lists(
positive_paper_ids=positive_ids,
negative_paper_ids=negative_ids,
fields=self.fields,
limit=min(limit, 500)
)
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"获取论文推荐列表时发生错误: {str(e)}")
return []
async def search_author(self, query: str, limit: int = 100) -> List[dict]:
"""搜索作者"""
try:
await self._ensure_client()
# 直接使用 API 请求而不是 search_author 方法
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/search",
f"query={query}&fields=name,paperCount,citationCount&limit={min(limit, 1000)}",
self.client.auth_header
)
authors = response.get('data', [])
return [
{
'author_id': author.get('authorId'),
'name': author.get('name'),
'paper_count': author.get('paperCount'),
'citation_count': author.get('citationCount'),
}
for author in authors
]
except Exception as e:
print(f"搜索作者时发生错误: {str(e)}")
return []
async def get_author_details(self, author_id: str) -> Optional[dict]:
"""获取作者详细信息"""
try:
await self._ensure_client()
# 直接使用 API 请求
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}",
"fields=name,paperCount,citationCount,hIndex",
self.client.auth_header
)
return {
'author_id': response.get('authorId'),
'name': response.get('name'),
'paper_count': response.get('paperCount'),
'citation_count': response.get('citationCount'),
'h_index': response.get('hIndex'),
}
except Exception as e:
print(f"获取作者详情时发生错误: {str(e)}")
return None
async def get_author_papers(self, author_id: str, limit: int = 100) -> List[PaperMetadata]:
"""获取作者的论文列表"""
try:
await self._ensure_client()
# 直接使用 API 请求
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}/papers",
f"fields={','.join(self.fields)}&limit={min(limit, 1000)}",
self.client.auth_header
)
papers = response.get('data', [])
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"获取作者论文列表时发生错误: {str(e)}")
return []
async def get_paper_autocomplete(self, query: str) -> List[dict]:
"""论文标题自动补全"""
try:
await self._ensure_client()
# 直接使用 API 请求
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/autocomplete",
f"query={query}",
self.client.auth_header
)
suggestions = response.get('matches', [])
return [
{
'title': suggestion.get('title'),
'paper_id': suggestion.get('paperId'),
'year': suggestion.get('year'),
'venue': suggestion.get('venue'),
}
for suggestion in suggestions
]
except Exception as e:
print(f"获取标题自动补全时发生错误: {str(e)}")
return []
def _parse_paper_data(self, paper) -> PaperMetadata:
"""解析论文数据"""
# 获取DOI
doi = None
external_ids = paper.get('externalIds', {}) if isinstance(paper, dict) else paper.externalIds
if external_ids:
if isinstance(external_ids, dict):
doi = external_ids.get('DOI')
if not doi and 'ArXiv' in external_ids:
doi = f"10.48550/arXiv.{external_ids['ArXiv']}"
else:
doi = external_ids.DOI if hasattr(external_ids, 'DOI') else None
if not doi and hasattr(external_ids, 'ArXiv'):
doi = f"10.48550/arXiv.{external_ids.ArXiv}"
# 获取PDF URL
pdf_url = None
pdf_info = paper.get('openAccessPdf', {}) if isinstance(paper, dict) else paper.openAccessPdf
if pdf_info:
pdf_url = pdf_info.get('url') if isinstance(pdf_info, dict) else pdf_info.url
# 获取发表场所详细信息
venue_type = None
venue_name = None
venue_info = {}
venue = paper.get('publicationVenue', {}) if isinstance(paper, dict) else paper.publicationVenue
if venue:
if isinstance(venue, dict):
venue_name = venue.get('name')
venue_type = venue.get('type')
# 提取更多venue信息
venue_info = {
'issn': venue.get('issn'),
'publisher': venue.get('publisher'),
'url': venue.get('url'),
'alternate_names': venue.get('alternate_names', [])
}
else:
venue_name = venue.name if hasattr(venue, 'name') else None
venue_type = venue.type if hasattr(venue, 'type') else None
venue_info = {
'issn': getattr(venue, 'issn', None),
'publisher': getattr(venue, 'publisher', None),
'url': getattr(venue, 'url', None),
'alternate_names': getattr(venue, 'alternate_names', [])
}
# 获取标题
title = paper.get('title', '') if isinstance(paper, dict) else getattr(paper, 'title', '')
# 获取作者
authors = paper.get('authors', []) if isinstance(paper, dict) else getattr(paper, 'authors', [])
author_names = []
for author in authors:
if isinstance(author, dict):
author_names.append(author.get('name', ''))
else:
author_names.append(author.name if hasattr(author, 'name') else str(author))
# 获取摘要
abstract = paper.get('abstract', '') if isinstance(paper, dict) else getattr(paper, 'abstract', '')
# 获取年份
year = paper.get('year') if isinstance(paper, dict) else getattr(paper, 'year', None)
# 获取引用次数
citations = paper.get('citationCount') if isinstance(paper, dict) else getattr(paper, 'citationCount', None)
return PaperMetadata(
title=title,
authors=author_names,
abstract=abstract,
year=year,
doi=doi,
url=pdf_url or (f"https://doi.org/{doi}" if doi else None),
citations=citations,
venue=venue_name,
institutions=[],
venue_type=venue_type,
venue_name=venue_name,
venue_info=venue_info,
source='semantic' # 添加来源标记
)
async def example_usage():
"""SemanticScholarSource使用示例"""
semantic = SemanticScholarSource()
try:
# 示例1使用DOI直接获取论文
print("\n=== 示例1通过DOI获取论文 ===")
doi = "10.18653/v1/N19-1423" # BERT论文
print(f"获取DOI为 {doi} 的论文信息...")
paper = await semantic.get_paper_details(doi)
if paper:
print("\n--- 论文信息 ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"URL: {paper.url}")
if paper.abstract:
print(f"\n摘要:")
print(paper.abstract)
print(f"\n引用次数: {paper.citations}")
print(f"发表venue: {paper.venue}")
# 示例2搜索论文
print("\n=== 示例2搜索论文 ===")
query = "BERT pre-training"
print(f"搜索关键词 '{query}' 相关的论文...")
papers = await semantic.search(query=query, limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- 搜索结果 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
if paper.abstract:
print(f"\n摘要:")
print(paper.abstract)
print(f"\nDOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
# 示例3获取论文推荐
print("\n=== 示例3获取论文推荐 ===")
print(f"获取与论文 {doi} 相关的推荐论文...")
recommendations = await semantic.get_recommended_papers(doi, limit=3)
for i, paper in enumerate(recommendations, 1):
print(f"\n--- 推荐论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
# 示例4基于多篇论文的推荐
print("\n=== 示例4基于多篇论文的推荐 ===")
positive_dois = ["10.18653/v1/N19-1423", "10.18653/v1/P19-1285"]
print(f"基于 {len(positive_dois)} 篇论文获取推荐...")
multi_recommendations = await semantic.get_recommended_papers_from_lists(
positive_dois=positive_dois,
limit=3
)
for i, paper in enumerate(multi_recommendations, 1):
print(f"\n--- 推荐论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
# 示例5搜索作者
print("\n=== 示例5搜索作者 ===")
author_query = "Yann LeCun"
print(f"搜索作者: '{author_query}'")
authors = await semantic.search_author(author_query, limit=3)
for i, author in enumerate(authors, 1):
print(f"\n--- 作者 {i} ---")
print(f"姓名: {author['name']}")
print(f"论文数量: {author['paper_count']}")
print(f"总引用次数: {author['citation_count']}")
# 示例6获取作者详情
print("\n=== 示例6获取作者详情 ===")
if authors: # 使用第一个搜索结果的作者ID
author_id = authors[0]['author_id']
print(f"获取作者ID {author_id} 的详细信息...")
author_details = await semantic.get_author_details(author_id)
if author_details:
print(f"姓名: {author_details['name']}")
print(f"H指数: {author_details['h_index']}")
print(f"总引用次数: {author_details['citation_count']}")
print(f"发表论文数: {author_details['paper_count']}")
# 示例7获取作者论文
print("\n=== 示例7获取作者论文 ===")
if authors: # 使用第一个搜索结果的作者ID
author_id = authors[0]['author_id']
print(f"获取作者 {authors[0]['name']} 的论文列表...")
author_papers = await semantic.get_author_papers(author_id, limit=3)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"发表年份: {paper.year}")
print(f"引用次数: {paper.citations}")
# 示例8论文标题自动补全
print("\n=== 示例8论文标题自动补全 ===")
title_query = "Attention is all"
print(f"搜索标题: '{title_query}'")
suggestions = await semantic.get_paper_autocomplete(title_query)
for i, suggestion in enumerate(suggestions[:3], 1):
print(f"\n--- 建议 {i} ---")
print(f"标题: {suggestion['title']}")
print(f"发表年份: {suggestion['year']}")
print(f"发表venue: {suggestion['venue']}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
asyncio.run(example_usage())