Master 4.0 (#2210)

* stage academic conversation

* stage document conversation

* fix buggy gradio version

* dynamic file loading

* merge more academic plugins

* accelerate nltk

* feat: add file and URL reading support to the predict function
- Add URL detection and web-page content extraction, with automatic extraction of page text
- Add file-path recognition and file-content reading, supporting the private_upload path format
- Integrate WebTextExtractor for web-page content extraction
- Integrate TextContentLoader for local file reading
- Support smart handling of a file path combined with a question

* back

* block unstable
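
The routing added to predict (per the feat bullet above) can be pictured as follows. This is a minimal illustrative sketch, not the actual implementation from this commit; the extract()/load() method names on WebTextExtractor and TextContentLoader are assumptions, and the import paths of those new modules are omitted.

import os
import re

URL_PATTERN = re.compile(r"https?://\S+")

def route_user_input(txt: str) -> str:
    """Illustrative only: enrich a raw prompt with web or file content before prediction."""
    url_match = URL_PATTERN.search(txt)
    if url_match:
        # URL detected: extract the page text and prepend it to the question
        page_text = WebTextExtractor().extract(url_match.group(0))  # hypothetical API
        return f"{page_text}\n\n{txt}"
    path = txt.strip()
    if path.startswith("private_upload/") and os.path.exists(path):
        # Local file in the private_upload format: read its content
        file_text = TextContentLoader().load(path)  # hypothetical API
        return f"{file_text}\n\n{txt}"
    return txt  # plain question: pass through unchanged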

---------

Co-authored-by: XiaoBoAI <liuboyin2019@ia.ac.cn>
Authored by binary-husky, committed via GitHub on 2025-08-23 15:59:22 +08:00
parent 65a4cf59c2
commit 8042750d41
79 changed files with 20850 additions and 57 deletions


@@ -0,0 +1,279 @@
from typing import List, Optional, Dict
import aiohttp
import asyncio
from urllib.parse import urlencode
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import random
class AdsabsSource(DataSource):
"""ADS (Astrophysics Data System) API实现"""
# 定义API密钥列表
API_KEYS = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: ADS API密钥如果不提供则从预定义列表中随机选择
"""
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
self._initialize()
def _initialize(self) -> None:
"""初始化基础URL和请求头"""
self.base_url = "https://api.adsabs.harvard.edu/v1"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
async def _make_request(self, url: str, method: str = "GET", data: dict = None) -> Optional[dict]:
"""发送HTTP请求
Args:
url: 请求URL
method: HTTP方法
data: POST请求数据
Returns:
响应内容
"""
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
if method == "GET":
async with session.get(url) as response:
if response.status == 200:
return await response.json()
elif method == "POST":
async with session.post(url, json=data) as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
print(f"请求发生错误: {str(e)}")
return None
def _parse_paper(self, doc: dict) -> PaperMetadata:
"""解析ADS文献数据
Args:
doc: ADS文献数据
Returns:
解析后的论文数据
"""
try:
return PaperMetadata(
title=doc.get('title', [''])[0] if doc.get('title') else '',
authors=doc.get('author', []),
abstract=doc.get('abstract', ''),
year=doc.get('year'),
doi=doc.get('doi', [''])[0] if doc.get('doi') else None,
url=f"https://ui.adsabs.harvard.edu/abs/{doc.get('bibcode')}/abstract" if doc.get('bibcode') else None,
citations=doc.get('citation_count'),
venue=doc.get('pub', ''),
institutions=doc.get('aff', []),
venue_type="journal",
venue_name=doc.get('pub', ''),
venue_info={
'volume': doc.get('volume'),
'issue': doc.get('issue'),
'pub_date': doc.get('pubdate', '')
},
source='adsabs'
)
except Exception as e:
print(f"解析文章时发生错误: {str(e)}")
return None
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = "relevance",
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文
Args:
query: 搜索关键词
limit: 返回结果数量限制
sort_by: 排序方式 ('relevance', 'date', 'citations')
start_year: 起始年份
Returns:
论文列表
"""
try:
# Build the query
if start_year:
query = f"{query} year:{start_year}-"
# Map the sort option
sort_mapping = {
'relevance': 'score desc',
'date': 'date desc',
'citations': 'citation_count desc'
}
sort = sort_mapping.get(sort_by, 'score desc')
# Build the search request
search_url = f"{self.base_url}/search/query"
params = {
"q": query,
"rows": limit,
"sort": sort,
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
# Parse the results
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
def _build_query_string(self, params: dict) -> str:
"""Build a percent-encoded query string (queries may contain spaces and quotes)"""
return urlencode(params)
async def get_paper_details(self, bibcode: str) -> Optional[PaperMetadata]:
"""获取指定bibcode的论文详情"""
search_url = f"{self.base_url}/search/query"
params = {
"q": f"identifier:{bibcode}",
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
if response and 'response' in response and response['response']['docs']:
return self._parse_paper(response['response']['docs'][0])
return None
async def get_related_papers(self, bibcode: str, limit: int = 100) -> List[PaperMetadata]:
"""获取相关论文"""
url = f"{self.base_url}/search/query"
params = {
"q": f"citations(identifier:{bibcode}) OR references(identifier:{bibcode})",
"rows": limit,
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
async def search_by_author(
self,
author: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = f"author:\"{author}\""
return await self.search(query, limit=limit, start_year=start_year)
async def search_by_journal(
self,
journal: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按期刊搜索论文"""
query = f"pub:\"{journal}\""
return await self.search(query, limit=limit, start_year=start_year)
async def get_latest_papers(
self,
days: int = 7,
limit: int = 100
) -> List[PaperMetadata]:
"""获取最新论文"""
query = f"entdate:[NOW-{days}DAYS TO NOW]"
return await self.search(query, limit=limit, sort_by="date")
async def get_citations(self, bibcode: str) -> List[PaperMetadata]:
"""获取引用该论文的文献"""
url = f"{self.base_url}/search/query"
params = {
"q": f"citations(identifier:{bibcode})",
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
async def get_references(self, bibcode: str) -> List[PaperMetadata]:
"""获取该论文引用的文献"""
url = f"{self.base_url}/search/query"
params = {
"q": f"references(identifier:{bibcode})",
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
async def example_usage():
"""AdsabsSource使用示例"""
ads = AdsabsSource()
try:
# 示例1基本搜索
print("\n=== 示例1搜索黑洞相关论文 ===")
papers = await ads.search("black hole", limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
# 其他示例...
except Exception as e:
print(f"发生错误: {str(e)}")
if __name__ == "__main__":
# python -m crazy_functions.review_fns.data_sources.adsabs_source
asyncio.run(example_usage())


@@ -0,0 +1,636 @@
import arxiv
from typing import List, Optional, Union, Literal, Dict
from datetime import datetime
from .base_source import DataSource, PaperMetadata
import os
from urllib.request import urlretrieve
import feedparser
from tqdm import tqdm
class ArxivSource(DataSource):
"""arXiv API实现"""
CATEGORIES = {
# Physics
"Physics": {
"astro-ph": "天体物理学",
"cond-mat": "凝聚态物理",
"gr-qc": "广义相对论与量子宇宙学",
"hep-ex": "高能物理实验",
"hep-lat": "格点场论",
"hep-ph": "高能物理理论",
"hep-th": "高能物理理论",
"math-ph": "数学物理",
"nlin": "非线性科学",
"nucl-ex": "核实验",
"nucl-th": "核理论",
"physics": "物理学",
"quant-ph": "量子物理",
},
# Mathematics
"Mathematics": {
"math.AG": "代数几何",
"math.AT": "代数拓扑",
"math.AP": "分析与偏微分方程",
"math.CT": "范畴论",
"math.CA": "复分析",
"math.CO": "组合数学",
"math.AC": "交换代数",
"math.CV": "复变函数",
"math.DG": "微分几何",
"math.DS": "动力系统",
"math.FA": "泛函分析",
"math.GM": "一般数学",
"math.GN": "一般拓扑",
"math.GT": "几何拓扑",
"math.GR": "群论",
"math.HO": "数学史与数学概述",
"math.IT": "信息论",
"math.KT": "K理论与同调",
"math.LO": "逻辑",
"math.MP": "数学物理",
"math.MG": "度量几何",
"math.NT": "数论",
"math.NA": "数值分析",
"math.OA": "算子代数",
"math.OC": "最优化与控制",
"math.PR": "概率论",
"math.QA": "量子代数",
"math.RT": "表示论",
"math.RA": "环与代数",
"math.SP": "谱理论",
"math.ST": "统计理论",
"math.SG": "辛几何",
},
# Computer Science
"Computer Science": {
"cs.AI": "人工智能",
"cs.CL": "计算语言学",
"cs.CC": "计算复杂性",
"cs.CE": "计算工程",
"cs.CG": "计算几何",
"cs.GT": "计算机博弈论",
"cs.CV": "计算机视觉",
"cs.CY": "计算机与社会",
"cs.CR": "密码学与安全",
"cs.DS": "数据结构与算法",
"cs.DB": "数据库",
"cs.DL": "数字图书馆",
"cs.DM": "离散数学",
"cs.DC": "分布式计算",
"cs.ET": "新兴技术",
"cs.FL": "形式语言与自动机理论",
"cs.GL": "一般文献",
"cs.GR": "图形学",
"cs.AR": "硬件架构",
"cs.HC": "人机交互",
"cs.IR": "信息检索",
"cs.IT": "信息论",
"cs.LG": "机器学习",
"cs.LO": "逻辑与计算机",
"cs.MS": "数学软件",
"cs.MA": "多智能体系统",
"cs.MM": "多媒体",
"cs.NI": "网络与互联网架构",
"cs.NE": "神经与进化计算",
"cs.NA": "数值分析",
"cs.OS": "操作系统",
"cs.OH": "其他计算机科学",
"cs.PF": "性能评估",
"cs.PL": "编程语言",
"cs.RO": "机器人学",
"cs.SI": "社会与信息网络",
"cs.SE": "软件工程",
"cs.SD": "声音",
"cs.SC": "符号计算",
"cs.SY": "系统与控制",
},
# Quantitative Biology
"Quantitative Biology": {
"q-bio.BM": "生物分子",
"q-bio.CB": "细胞行为",
"q-bio.GN": "基因组学",
"q-bio.MN": "分子网络",
"q-bio.NC": "神经计算",
"q-bio.OT": "其他",
"q-bio.PE": "群体与进化",
"q-bio.QM": "定量方法",
"q-bio.SC": "亚细胞过程",
"q-bio.TO": "组织与器官",
},
# Quantitative Finance
"Quantitative Finance": {
"q-fin.CP": "计算金融",
"q-fin.EC": "经济学",
"q-fin.GN": "一般金融",
"q-fin.MF": "数学金融",
"q-fin.PM": "投资组合管理",
"q-fin.PR": "定价理论",
"q-fin.RM": "风险管理",
"q-fin.ST": "统计金融",
"q-fin.TR": "交易与市场微观结构",
},
# Statistics
"Statistics": {
"stat.AP": "应用统计",
"stat.CO": "计算统计",
"stat.ML": "机器学习",
"stat.ME": "方法论",
"stat.OT": "其他统计",
"stat.TH": "统计理论",
},
# Electrical Engineering and Systems Science
"Electrical Engineering and Systems Science": {
"eess.AS": "音频与语音处理",
"eess.IV": "图像与视频处理",
"eess.SP": "信号处理",
"eess.SY": "系统与控制",
},
# Economics
"Economics": {
"econ.EM": "计量经济学",
"econ.GN": "一般经济学",
"econ.TH": "理论经济学",
}
}
def __init__(self):
"""初始化"""
self._initialize() # 调用初始化方法
# 修改排序选项映射
self.sort_options = {
'relevance': arxiv.SortCriterion.Relevance,  # arXiv relevance ranking
'lastUpdatedDate': arxiv.SortCriterion.LastUpdatedDate,  # last-updated date
'submittedDate': arxiv.SortCriterion.SubmittedDate,  # submission date
}
self.sort_order_options = {
'ascending': arxiv.SortOrder.Ascending,
'descending': arxiv.SortOrder.Descending
}
self.default_sort = 'lastUpdatedDate'
self.default_order = 'descending'
def _initialize(self) -> None:
"""初始化客户端,设置默认参数"""
self.client = arxiv.Client()
async def search(
self,
query: str,
limit: int = 10,
sort_by: str = None,
sort_order: str = None,
start_year: int = None
) -> List[Dict]:
"""搜索论文"""
try:
# Fall back to the default sort if the given option is invalid
if not sort_by or sort_by not in self.sort_options:
sort_by = self.default_sort
# Fall back to the default order if the given one is invalid
if not sort_order or sort_order not in self.sort_order_options:
sort_order = self.default_order
# If a start year is given, add it to the query
if start_year:
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
search = arxiv.Search(
query=query,
max_results=limit,
sort_by=self.sort_options[sort_by],
sort_order=self.sort_order_options[sort_order]
)
results = list(self.client.results(search))
return [self._parse_paper_data(result) for result in results]
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
async def search_by_id(self, paper_id: Union[str, List[str]]) -> List[PaperMetadata]:
"""按ID搜索论文
Args:
paper_id: 单个arXiv ID或ID列表例如'2005.14165' 或 ['2005.14165', '2103.14030']
"""
if isinstance(paper_id, str):
paper_id = [paper_id]
search = arxiv.Search(
id_list=paper_id,
max_results=len(paper_id)
)
results = list(self.client.results(search))
return [self._parse_paper_data(result) for result in results]
async def search_by_category(
self,
category: str,
limit: int = 100,
sort_by: str = 'relevance',
sort_order: str = 'descending',
start_year: int = None
) -> List[PaperMetadata]:
"""按类别搜索论文"""
query = f"cat:{category}"
# If a start year is given, add it to the query
if start_year:
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
return await self.search(
query=query,
limit=limit,
sort_by=sort_by,
sort_order=sort_order
)
async def search_by_authors(
self,
authors: List[str],
limit: int = 100,
sort_by: str = 'relevance',
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = " AND ".join([f"au:\"{author}\"" for author in authors])
# If a start year is given, add it to the query
if start_year:
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
return await self.search(
query=query,
limit=limit,
sort_by=sort_by
)
async def search_by_date_range(
self,
start_date: datetime,
end_date: datetime,
limit: int = 100,
sort_by: Literal['relevance', 'lastUpdatedDate', 'submittedDate'] = 'submittedDate',  # must match self.sort_options keys
sort_order: Literal['ascending', 'descending'] = 'descending'
) -> List[PaperMetadata]:
"""按日期范围搜索论文"""
query = f"submittedDate:[{start_date.strftime('%Y%m%d')} TO {end_date.strftime('%Y%m%d')}]"
return await self.search(
query,
limit=limit,
sort_by=sort_by,
sort_order=sort_order
)
async def download_pdf(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
"""下载论文PDF
Args:
paper_id: arXiv ID
dirpath: 保存目录
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.pdf
Returns:
保存的文件路径
"""
papers = await self.search_by_id(paper_id)
if not papers:
raise ValueError(f"未找到ID为 {paper_id} 的论文")
paper = papers[0]
if not filename:
# Sanitize illegal characters in the title
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
filename = f"{paper_id}_{safe_title}.pdf"
filepath = os.path.join(dirpath, filename)
urlretrieve(paper.url, filepath)
return filepath
async def download_source(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
"""下载论文源文件通常是LaTeX源码
Args:
paper_id: arXiv ID
dirpath: 保存目录
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.tar.gz
Returns:
保存的文件路径
"""
papers = await self.search_by_id(paper_id)
if not papers:
raise ValueError(f"未找到ID为 {paper_id} 的论文")
paper = papers[0]
if not filename:
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
filename = f"{paper_id}_{safe_title}.tar.gz"
filepath = os.path.join(dirpath, filename)
source_url = paper.url.replace("/pdf/", "/src/")
urlretrieve(source_url, filepath)
return filepath
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
# The arXiv API does not expose citation data
return []
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
# The arXiv API does not expose reference data
return []
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
"""获取论文详情
Args:
paper_id: arXiv ID 或 DOI
Returns:
论文详细信息,如果未找到返回 None
"""
try:
# If given a full arXiv URL, extract the ID
if "arxiv.org" in paper_id:
paper_id = paper_id.split("/")[-1]
# If given the DOI of an arXiv paper, strip the prefix to recover the full ID (e.g. '2005.14165')
elif paper_id.startswith("10.48550/arXiv."):
paper_id = paper_id[len("10.48550/arXiv."):]
papers = await self.search_by_id(paper_id)
return papers[0] if papers else None
except Exception as e:
print(f"获取论文详情时发生错误: {str(e)}")
return None
def _parse_paper_data(self, result: arxiv.Result) -> PaperMetadata:
"""解析arXiv API返回的数据"""
# 解析主要类别和次要类别
primary_category = result.primary_category
categories = result.categories
# Build venue info
venue_info = {
'primary_category': primary_category,
'categories': categories,
'comments': getattr(result, 'comment', None),
'journal_ref': getattr(result, 'journal_ref', None)
}
return PaperMetadata(
title=result.title,
authors=[author.name for author in result.authors],
abstract=result.summary,
year=result.published.year,
doi=result.entry_id,
url=result.pdf_url,
citations=None,
venue=f"arXiv:{primary_category}",
institutions=[],
venue_type='preprint',  # all arXiv papers are preprints
venue_name='arXiv',
venue_info=venue_info,
source='arxiv'  # source tag
)
async def get_latest_papers(
self,
category: str,
debug: bool = False,
batch_size: int = 50
) -> List[PaperMetadata]:
"""获取指定类别的最新论文
通过 RSS feed 获取最新发布的论文,然后批量获取详细信息
Args:
category: arXiv类别例如
- 整个领域: 'cs'
- 具体方向: 'cs.AI'
- 多个类别: 'cs.AI+q-bio.NC'
debug: 是否为调试模式如果为True则只返回5篇最新论文
batch_size: 批量获取论文的数量默认50
Returns:
论文列表
Raises:
ValueError: 如果类别无效
"""
try:
# Normalize the category format:
# 1. lowercase
# 2. join multiple categories with '+'
category = category.lower().replace(' ', '+')
# Build the RSS feed URL
feed_url = f"https://rss.arxiv.org/rss/{category}"
print(f"Fetching RSS feed: {feed_url}")  # debug info
feed = feedparser.parse(feed_url)
# Check that the feed is valid
if hasattr(feed, 'status') and feed.status != 200:
raise ValueError(f"Failed to fetch RSS feed, status code: {feed.status}")
if not feed.entries:
print("Warning: no entries found in the feed")  # debug info
print(f"Feed title: {feed.feed.title if hasattr(feed, 'feed') else 'untitled'}")
raise ValueError(f"Invalid arXiv category or no papers found: {category}")
if debug:
# Debug mode: fetch only the 5 most recent papers
search = arxiv.Search(
query=f'cat:{category}',
sort_by=arxiv.SortCriterion.SubmittedDate,
sort_order=arxiv.SortOrder.Descending,
max_results=5
)
results = list(self.client.results(search))
return [self._parse_paper_data(result) for result in results]
# Normal mode: fetch all new papers
# Extract arXiv IDs from the RSS entries
paper_ids = []
for entry in feed.entries:
try:
# RSS links may take several forms:
# - http://arxiv.org/abs/2403.xxxxx
# - http://arxiv.org/pdf/2403.xxxxx
# - https://arxiv.org/abs/2403.xxxxx
link = entry.link or entry.id
arxiv_id = link.split('/')[-1].replace('.pdf', '')
if arxiv_id:
paper_ids.append(arxiv_id)
except Exception as e:
print(f"警告:处理条目时出错: {str(e)}") # 添加调试信息
continue
if not paper_ids:
print("未能从feed中提取到任何论文ID") # 添加调试信息
return []
print(f"成功提取到 {len(paper_ids)} 个论文ID") # 添加调试信息
# 批量获取论文详情
papers = []
with tqdm(total=len(paper_ids), desc="获取arXiv论文") as pbar:
for i in range(0, len(paper_ids), batch_size):
batch_ids = paper_ids[i:i + batch_size]
search = arxiv.Search(
id_list=batch_ids,
max_results=len(batch_ids)
)
batch_results = list(self.client.results(search))
papers.extend([self._parse_paper_data(result) for result in batch_results])
pbar.update(len(batch_results))
return papers
except Exception as e:
print(f"获取最新论文时发生错误: {str(e)}")
import traceback
print(traceback.format_exc()) # 添加完整的错误追踪
return []
async def example_usage():
"""ArxivSource使用示例"""
arxiv_source = ArxivSource()
try:
# Example 1: basic search with different sort options
# print("\n=== Example 1: latest machine-learning papers (sorted by submission date) ===")
# papers = await arxiv_source.search(
# "ti:\"machine learning\"",
# limit=3,
# sort_by='submittedDate',
# sort_order='descending'
# )
# print(f"Found {len(papers)} papers")
# for i, paper in enumerate(papers, 1):
# print(f"\n--- Paper {i} ---")
# print(f"Title: {paper.title}")
# print(f"Authors: {', '.join(paper.authors)}")
# print(f"Year: {paper.year}")
# print(f"arXiv ID: {paper.doi}")
# print(f"PDF URL: {paper.url}")
# if paper.abstract:
# print(f"\nAbstract:")
# print(paper.abstract)
# print(f"Venue: {paper.venue}")
# # Example 2: search by ID
# print("\n=== Example 2: search a paper by ID ===")
# paper_id = "2005.14165"  # the GPT-3 paper
# papers = await arxiv_source.search_by_id(paper_id)
# if papers:
# paper = papers[0]
# print(f"Title: {paper.title}")
# print(f"Authors: {', '.join(paper.authors)}")
# print(f"Year: {paper.year}")
# # Example 3: search by category
# print("\n=== Example 3: latest papers in artificial intelligence ===")
# ai_papers = await arxiv_source.search_by_category(
# "cs.AI",
# limit=2,
# sort_by='lastUpdatedDate',
# sort_order='descending'
# )
# for i, paper in enumerate(ai_papers, 1):
# print(f"\n--- AI paper {i} ---")
# print(f"Title: {paper.title}")
# print(f"Authors: {', '.join(paper.authors)}")
# print(f"Venue: {paper.venue}")
# # Example 4: search by author
# print("\n=== Example 4: papers by a specific author ===")
# author_papers = await arxiv_source.search_by_authors(
# ["Bengio"],
# limit=2,
# sort_by='relevance'
# )
# for i, paper in enumerate(author_papers, 1):
# print(f"\n--- Bengio paper {i} ---")
# print(f"Title: {paper.title}")
# print(f"Authors: {', '.join(paper.authors)}")
# print(f"Venue: {paper.venue}")
# # Example 5: search within a date range
# print("\n=== Example 5: papers from a specific date range ===")
# from datetime import datetime, timedelta
# end_date = datetime.now()
# start_date = end_date - timedelta(days=7)  # the past week
# recent_papers = await arxiv_source.search_by_date_range(
# start_date,
# end_date,
# limit=2
# )
# for i, paper in enumerate(recent_papers, 1):
# print(f"\n--- Recent paper {i} ---")
# print(f"Title: {paper.title}")
# print(f"Authors: {', '.join(paper.authors)}")
# print(f"Year: {paper.year}")
# # Example 6: download a PDF
# print("\n=== Example 6: download a paper's PDF ===")
# if papers:  # reuse the GPT-3 paper found above
# pdf_path = await arxiv_source.download_pdf(paper_id)
# print(f"PDF saved to: {pdf_path}")
# # Example 7: download the source files
# print("\n=== Example 7: download a paper's source files ===")
# if papers:
# source_path = await arxiv_source.download_source(paper_id)
# print(f"Source archive saved to: {source_path}")
# Example 8: fetch the latest papers
print("\n=== Example 8: fetch the latest papers ===")
# Latest papers in cs.AI
print("\n--- Latest AI papers ---")
ai_latest = await arxiv_source.get_latest_papers("cs.AI", debug=True)
for i, paper in enumerate(ai_latest, 1):
print(f"\nPaper {i}:")
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
# Latest papers across all of computer science
print("\n--- Latest CS papers ---")
cs_latest = await arxiv_source.get_latest_papers("cs", debug=True)
for i, paper in enumerate(cs_latest, 1):
print(f"\nPaper {i}:")
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
# Latest papers across multiple categories
print("\n--- Latest AI and machine-learning papers ---")
multi_latest = await arxiv_source.get_latest_papers("cs.AI+cs.LG", debug=True)
for i, paper in enumerate(multi_latest, 1):
print(f"\nPaper {i}:")
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
except Exception as e:
print(f"Error: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
asyncio.run(example_usage())


@@ -0,0 +1,102 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Optional
class PaperMetadata:
"""论文元数据"""
def __init__(
self,
title: str,
authors: List[str],
abstract: str,
year: int,
doi: str = None,
url: str = None,
citations: int = None,
venue: str = None,
institutions: List[str] = None,
venue_type: str = None,  # venue type (journal/conference/preprint, etc.)
venue_name: str = None,  # specific journal/conference name
venue_info: Dict = None,  # extra venue details (impact factor, division, etc.)
source: str = None  # tag identifying the originating data source
):
self.title = title
self.authors = authors
self.abstract = abstract
self.year = year
self.doi = doi
self.url = url
self.citations = citations
self.venue = venue
self.institutions = institutions or []
self.venue_type = venue_type
self.venue_name = venue_name
self.venue_info = venue_info or {}
self.source = source  # originating data source
# Impact-factor and division fields, initialized to None
self._if_factor = None
self._cas_division = None
self._jcr_division = None
@property
def if_factor(self) -> Optional[float]:
"""获取影响因子"""
return self._if_factor
@if_factor.setter
def if_factor(self, value: float):
"""设置影响因子"""
self._if_factor = value
@property
def cas_division(self) -> Optional[str]:
"""获取中科院分区"""
return self._cas_division
@cas_division.setter
def cas_division(self, value: str):
"""设置中科院分区"""
self._cas_division = value
@property
def jcr_division(self) -> Optional[str]:
"""获取JCR分区"""
return self._jcr_division
@jcr_division.setter
def jcr_division(self, value: str):
"""设置JCR分区"""
self._jcr_division = value
class DataSource(ABC):
"""数据源基类"""
def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key
self._initialize()
@abstractmethod
def _initialize(self) -> None:
"""初始化数据源"""
pass
@abstractmethod
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
"""搜索论文"""
pass
@abstractmethod
async def get_paper_details(self, paper_id: str) -> PaperMetadata:
"""获取论文详细信息"""
pass
@abstractmethod
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
"""获取引用该论文的文献"""
pass
@abstractmethod
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
"""获取该论文引用的文献"""
pass
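
For orientation, here is a minimal sketch of how a concrete source is meant to subclass DataSource. It is illustrative only; the ExampleSource class and its canned data are not part of this commit:

import asyncio
from typing import List

class ExampleSource(DataSource):
    """Illustrative subclass that returns canned data instead of calling a real API."""
    def _initialize(self) -> None:
        self.base_url = "https://example.org/api"  # placeholder endpoint
    async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
        # A real source would issue an HTTP request here
        return [PaperMetadata(title=f"Stub result for {query}", authors=["A. Author"],
                              abstract="", year=2024, source="example")]
    async def get_paper_details(self, paper_id: str) -> PaperMetadata:
        return (await self.search(paper_id, limit=1))[0]
    async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
        return []
    async def get_references(self, paper_id: str) -> List[PaperMetadata]:
        return []

# papers = asyncio.run(ExampleSource().search("quantum computing"))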

File diff suppressed because one or more lines are too long


@@ -0,0 +1,400 @@
import aiohttp
from typing import List, Dict, Optional
from datetime import datetime
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import random
class CrossrefSource(DataSource):
"""Crossref API实现"""
CONTACT_EMAILS = [
"gpt_abc_academic@163.com",
"gpt_abc_newapi@163.com",
"gpt_abc_academic_pwd@163.com"
]
def _initialize(self) -> None:
"""初始化客户端,设置默认参数"""
self.base_url = "https://api.crossref.org"
# 随机选择一个邮箱
contact_email = random.choice(self.CONTACT_EMAILS)
self.headers = {
"Accept": "application/json",
"User-Agent": f"Mozilla/5.0 (compatible; PythonScript/1.0; mailto:{contact_email})",
}
if self.api_key:
self.headers["Crossref-Plus-API-Token"] = f"Bearer {self.api_key}"
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = None,
sort_order: str = None,
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文
Args:
query: 搜索关键词
limit: 返回结果数量限制
sort_by: 排序字段
sort_order: 排序顺序
start_year: 起始年份
"""
async with aiohttp.ClientSession(headers=self.headers) as session:
# Request extra results to compensate for articles filtered out below
adjusted_limit = min(limit * 3, 1000)  # cap to avoid over-fetching
params = {
"query": query,
"rows": adjusted_limit,
"select": (
"DOI,title,author,published-print,abstract,reference,"
"container-title,is-referenced-by-count,type,"
"publisher,ISSN,ISBN,issue,volume,page"
)
}
# Year filter
if start_year:
params["filter"] = f"from-pub-date:{start_year}"
# Sorting
if sort_by:
params["sort"] = sort_by
if sort_order:
params["order"] = sort_order
async with session.get(
f"{self.base_url}/works",
params=params
) as response:
if response.status != 200:
print(f"API请求失败: HTTP {response.status}")
print(f"响应内容: {await response.text()}")
return []
data = await response.json()
items = data.get("message", {}).get("items", [])
if not items:
print(f"未找到相关论文")
return []
# Filter out articles that lack an abstract
papers = []
filtered_count = 0
for work in items:
paper = self._parse_work(work)
if paper.abstract and paper.abstract.strip():
papers.append(paper)
if len(papers) >= limit:  # stop once the originally requested limit is reached
break
else:
filtered_count += 1
print(f"找到 {len(items)} 篇相关论文,其中 {filtered_count} 篇因缺少摘要被过滤")
print(f"返回 {len(papers)} 篇包含摘要的论文")
return papers
async def get_paper_details(self, doi: str) -> PaperMetadata:
"""获取指定DOI的论文详情"""
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(
f"{self.base_url}/works/{doi}",
params={
"select": (
"DOI,title,author,published-print,abstract,reference,"
"container-title,is-referenced-by-count,type,"
"publisher,ISSN,ISBN,issue,volume,page"
)
}
) as response:
if response.status != 200:
print(f"获取论文详情失败: HTTP {response.status}")
print(f"响应内容: {await response.text()}")
return None
try:
data = await response.json()
return self._parse_work(data.get("message", {}))
except Exception as e:
print(f"解析论文详情时发生错误: {str(e)}")
return None
async def get_references(self, doi: str) -> List[PaperMetadata]:
"""获取指定DOI论文的参考文献列表"""
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(
f"{self.base_url}/works/{doi}",
params={"select": "reference"}
) as response:
if response.status != 200:
print(f"获取参考文献失败: HTTP {response.status}")
return []
try:
data = await response.json()
# Make sure we handle the returned data structure correctly
if not isinstance(data, dict):
print(f"API returned an unexpected data format: {type(data)}")
return []
references = data.get("message", {}).get("reference", [])
if not references:
print(f"未找到参考文献")
return []
return [
PaperMetadata(
title=ref.get("article-title", ""),
authors=[ref.get("author", "")],
year=ref.get("year"),
doi=ref.get("DOI"),
url=f"https://doi.org/{ref.get('DOI')}" if ref.get("DOI") else None,
abstract="",
citations=None,
venue=ref.get("journal-title", ""),
institutions=[]
)
for ref in references
]
except Exception as e:
print(f"解析参考文献数据时发生错误: {str(e)}")
return []
async def get_citations(self, doi: str) -> List[PaperMetadata]:
"""获取引用指定DOI论文的文献列表"""
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(
f"{self.base_url}/works",
params={
"filter": f"reference.DOI:{doi}",
"select": "DOI,title,author,published-print,abstract"
}
) as response:
if response.status != 200:
print(f"获取引用信息失败: HTTP {response.status}")
print(f"响应内容: {await response.text()}")
return []
try:
data = await response.json()
# Check the shape of the returned data
if isinstance(data, dict):
items = data.get("message", {}).get("items", [])
return [self._parse_work(work) for work in items]
else:
print(f"API返回了意外的数据格式: {type(data)}")
return []
except Exception as e:
print(f"解析引用数据时发生错误: {str(e)}")
return []
def _parse_work(self, work: Dict) -> PaperMetadata:
"""解析Crossref返回的数据"""
# 获取摘要 - 处理可能的不同格式
abstract = ""
if isinstance(work.get("abstract"), str):
abstract = work.get("abstract", "")
elif isinstance(work.get("abstract"), dict):
abstract = work.get("abstract", {}).get("value", "")
if not abstract:
print(f"警告: 论文 '{work.get('title', [''])[0]}' 没有可用的摘要")
# 获取机构信息
institutions = []
for author in work.get("author", []):
if "affiliation" in author:
for affiliation in author["affiliation"]:
if "name" in affiliation and affiliation["name"] not in institutions:
institutions.append(affiliation["name"])
# Venue info
venue_name = work.get("container-title", [None])[0]
venue_type = work.get("type", "unknown") # 文献类型
venue_info = {
"publisher": work.get("publisher"),
"issn": work.get("ISSN", []),
"isbn": work.get("ISBN", []),
"issue": work.get("issue"),
"volume": work.get("volume"),
"page": work.get("page")
}
return PaperMetadata(
title=work.get("title", [None])[0] or "",
authors=[
author.get("given", "") + " " + author.get("family", "")
for author in work.get("author", [])
],
institutions=institutions,  # author affiliations
abstract=abstract,
year=work.get("published-print", {}).get("date-parts", [[None]])[0][0],
doi=work.get("DOI"),
url=f"https://doi.org/{work.get('DOI')}" if work.get("DOI") else None,
citations=work.get("is-referenced-by-count"),
venue=venue_name,
venue_type=venue_type,
venue_name=venue_name,
venue_info=venue_info,
source='crossref'  # source tag
)
async def search_by_authors(
self,
authors: List[str],
limit: int = 100,
sort_by: str = None,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = " ".join([f"author:\"{author}\"" for author in authors])
return await self.search(
query=query,
limit=limit,
sort_by=sort_by,
start_year=start_year
)
async def search_by_date_range(
self,
start_date: datetime,
end_date: datetime,
limit: int = 100,
sort_by: str = None,
sort_order: str = None
) -> List[PaperMetadata]:
"""按日期范围搜索论文"""
query = f"from-pub-date:{start_date.strftime('%Y-%m-%d')} until-pub-date:{end_date.strftime('%Y-%m-%d')}"
return await self.search(
query=query,
limit=limit,
sort_by=sort_by,
sort_order=sort_order
)
async def example_usage():
"""CrossrefSource使用示例"""
crossref = CrossrefSource(api_key=None)
try:
# Example 1: basic search with sorting
print("\n=== Example 1: search the latest machine-learning papers ===")
papers = await crossref.search(
query="machine learning",
limit=3,
sort_by="published",
sort_order="desc",
start_year=2023
)
for i, paper in enumerate(papers, 1):
print(f"\n--- Paper {i} ---")
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"URL: {paper.url}")
if paper.abstract:
print(f"Abstract: {paper.abstract[:200]}...")
if paper.institutions:
print(f"Institutions: {', '.join(paper.institutions)}")
print(f"Citations: {paper.citations}")
print(f"Venue: {paper.venue}")
print(f"Venue type: {paper.venue_type}")
if paper.venue_info:
print("Venue details:")
for key, value in paper.venue_info.items():
if value:
print(f" - {key}: {value}")
# Example 2: fetch details of a specific paper by DOI
print("\n=== Example 2: fetch details of a specific paper ===")
# DOI of the BERT paper
doi = "10.18653/v1/N19-1423"
paper = await crossref.get_paper_details(doi)
if paper:
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"Abstract: {paper.abstract[:200]}...")
print(f"Citations: {paper.citations}")
# Example 3: search by author
print("\n=== Example 3: papers by a specific author ===")
author_papers = await crossref.search_by_authors(
authors=["Yoshua Bengio"],
limit=3,
sort_by="published",
start_year=2020
)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- {i}. {paper.title} ---")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"Citations: {paper.citations}")
# Example 4: search within a date range
print("\n=== Example 4: papers from a specific date range ===")
from datetime import datetime, timedelta
end_date = datetime.now()
start_date = end_date - timedelta(days=30)  # the past month
recent_papers = await crossref.search_by_date_range(
start_date=start_date,
end_date=end_date,
limit=3,
sort_by="published",
sort_order="desc"
)
for i, paper in enumerate(recent_papers, 1):
print(f"\n--- Recently published paper {i} ---")
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
print(f"DOI: {paper.doi}")
# Example 5: citation information
print("\n=== Example 5: citation information ===")
if paper:  # reuse the BERT paper fetched above
print("\nPapers citing this paper:")
citations = await crossref.get_citations(paper.doi)
for i, citing_paper in enumerate(citations[:3], 1):
print(f"\n--- Citing paper {i} ---")
print(f"Title: {citing_paper.title}")
print(f"Authors: {', '.join(citing_paper.authors)}")
print(f"Year: {citing_paper.year}")
print("\nReferences cited by this paper:")
references = await crossref.get_references(paper.doi)
for i, ref_paper in enumerate(references[:3], 1):
print(f"\n--- Reference {i} ---")
print(f"Title: {ref_paper.title}")
print(f"Authors: {', '.join(ref_paper.authors)}")
print(f"Year: {ref_paper.year if ref_paper.year else 'unknown'}")
# Example 6: venue information
print("\n=== Example 6: journal/conference details ===")
if papers:
paper = papers[0]
print(f"Publication type: {paper.venue_type}")
print(f"Venue: {paper.venue_name}")
if paper.venue_info:
print("Venue details:")
for key, value in paper.venue_info.items():
if value:
print(f" - {key}: {value}")
except Exception as e:
print(f"Error: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
# Run the examples
asyncio.run(example_usage())


@@ -0,0 +1,449 @@
from typing import List, Optional, Dict
import aiohttp
import asyncio
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import random
class ElsevierSource(DataSource):
"""Elsevier (Scopus) API实现"""
# 定义API密钥列表
API_KEYS = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: Elsevier API密钥如果不提供则从预定义列表中随机选择
"""
self.api_key = api_key or random.choice(self.API_KEYS)
self._initialize()
def _initialize(self) -> None:
"""初始化基础URL和请求头"""
self.base_url = "https://api.elsevier.com/content"
self.headers = {
"X-ELS-APIKey": self.api_key,
"Accept": "application/json",
"Content-Type": "application/json",
# Additional headers as needed
"X-ELS-Insttoken": "",  # institution token, if available
}
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
"""发送HTTP请求
Args:
url: 请求URL
params: 查询参数
Returns:
JSON响应
"""
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(url, params=params) as response:
if response.status == 200:
return await response.json()
else:
# Report errors in detail
error_text = await response.text()
print(f"Request failed: {response.status}")
print(f"Error details: {error_text}")
if response.status == 401:
print(f"API key in use: {self.api_key}")
# Switch to a different API key
new_key = random.choice([k for k in self.API_KEYS if k != self.api_key])
print(f"Switching to a new API key: {new_key}")
self.api_key = new_key
self.headers["X-ELS-APIKey"] = new_key
# Retry the request (note: recurses, and may loop if every key is rejected)
return await self._make_request(url, params)
return None
except Exception as e:
print(f"请求发生错误: {str(e)}")
return None
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = "relevance",
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文"""
try:
params = {
"query": query,
"count": min(limit, 100),
"view": "STANDARD",
# dc:description is omitted because it is unavailable in the STANDARD view
"field": "dc:title,dc:creator,prism:doi,prism:coverDate,citedby-count,prism:publicationName"
}
# Year filter
if start_year:
params["date"] = f"{start_year}-present"
# Sorting
if sort_by == "date":
params["sort"] = "-coverDate"
elif sort_by == "cited":
params["sort"] = "-citedby-count"
# Send the search request
response = await self._make_request(
f"{self.base_url}/search/scopus",
params=params
)
if not response or "search-results" not in response:
return []
# Parse the search results
entries = response["search-results"].get("entry", [])
papers = [paper for paper in (self._parse_entry(entry) for entry in entries) if paper is not None]
# Try to fetch an abstract for each paper
for paper in papers:
if paper.doi:
paper.abstract = await self.fetch_abstract(paper.doi) or ""
return papers
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
def _parse_entry(self, entry: Dict) -> Optional[PaperMetadata]:
"""解析Scopus API返回的条目"""
try:
# 获取作者列表
authors = []
creator = entry.get("dc:creator")
if creator:
authors = [creator]
# Publication year
year = None
if "prism:coverDate" in entry:
try:
year = int(entry["prism:coverDate"][:4])
except (ValueError, TypeError):
pass
# Simplified venue info
venue_info = {
'source_id': entry.get("source-id"),
'issn': entry.get("prism:issn")
}
return PaperMetadata(
title=entry.get("dc:title", ""),
authors=authors,
abstract=entry.get("dc:description", ""), # 从响应中获取摘要
year=year,
doi=entry.get("prism:doi"),
url=entry.get("prism:url"),
citations=int(entry.get("citedby-count", 0)),
venue=entry.get("prism:publicationName"),
institutions=[],  # institution info not requested
venue_type="",
venue_name=entry.get("prism:publicationName"),
venue_info=venue_info
)
except Exception as e:
print(f"解析条目时发生错误: {str(e)}")
return None
async def get_citations(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
"""获取引用该论文的文献"""
try:
params = {
"query": f"REF({doi})",
"count": min(limit, 100),
"view": "STANDARD"
}
response = await self._make_request(
f"{self.base_url}/search/scopus",
params=params
)
if not response or "search-results" not in response:
return []
entries = response["search-results"].get("entry", [])
return [paper for paper in (self._parse_entry(entry) for entry in entries) if paper is not None]  # drop entries that failed to parse
except Exception as e:
print(f"获取引用文献时发生错误: {str(e)}")
return []
async def get_references(self, doi: str) -> List[PaperMetadata]:
"""获取该论文引用的文献"""
try:
response = await self._make_request(
f"{self.base_url}/abstract/doi/{doi}/references",
params={"view": "STANDARD"}
)
if not response or "references" not in response:
return []
references = response["references"].get("reference", [])
papers = [paper for paper in (self._parse_reference(ref) for ref in references) if paper is not None]
return papers
except Exception as e:
print(f"获取参考文献时发生错误: {str(e)}")
return []
def _parse_reference(self, ref: Dict) -> Optional[PaperMetadata]:
"""解析参考文献数据"""
try:
authors = []
if "author-list" in ref:
author_list = ref["author-list"].get("author", [])
if isinstance(author_list, list):
authors = [f"{author.get('ce:given-name', '')} {author.get('ce:surname', '')}"
for author in author_list]
else:
authors = [f"{author_list.get('ce:given-name', '')} {author_list.get('ce:surname', '')}"]
year = None
if "prism:coverDate" in ref:
try:
year = int(ref["prism:coverDate"][:4])
except (ValueError, TypeError):
pass
return PaperMetadata(
title=ref.get("ce:title", ""),
authors=authors,
abstract="", # 参考文献通常不包含摘要
year=year,
doi=ref.get("prism:doi"),
url=None,
citations=None,
venue=ref.get("prism:publicationName"),
institutions=[],
venue_type="unknown",
venue_name=ref.get("prism:publicationName"),
venue_info={}
)
except Exception as e:
print(f"解析参考文献时发生错误: {str(e)}")
return None
async def search_by_author(
self,
author: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = f"AUTHOR-NAME({author})"
return await self.search(query, limit=limit, start_year=start_year)
async def search_by_affiliation(
self,
affiliation: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按机构搜索论文"""
query = f"AF-ID({affiliation})"
return await self.search(query, limit=limit, start_year=start_year)
async def search_by_venue(
self,
venue: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按期刊/会议搜索论文"""
query = f"SRCTITLE({venue})"
return await self.search(query, limit=limit, start_year=start_year)
async def test_api_access(self):
"""测试API访问权限"""
print(f"\n测试API密钥: {self.api_key}")
# 测试1: 基础搜索
basic_params = {
"query": "test",
"count": 1,
"view": "STANDARD"
}
print("\n1. 测试基础搜索...")
response = await self._make_request(
f"{self.base_url}/search/scopus",
params=basic_params
)
if response:
print("基础搜索成功")
print("可用字段:", list(response.get("search-results", {}).get("entry", [{}])[0].keys()))
# 测试2: 测试单篇文章访问
print("\n2. 测试文章详情访问...")
test_doi = "10.1016/j.artint.2021.103535" # 一个示例DOI
response = await self._make_request(
f"{self.base_url}/abstract/doi/{test_doi}",
params={"view": "STANDARD"} # 改为STANDARD视图
)
if response:
print("文章详情访问成功")
else:
print("文章详情访问失败")
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
"""获取论文详细信息
注意当前API权限不支持获取详细信息返回None
Args:
paper_id: 论文ID
Returns:
None因为当前API权限不支持此功能
"""
return None
async def fetch_abstract(self, doi: str) -> Optional[str]:
"""获取论文摘要
使用Scopus Abstract API获取论文摘要
Args:
doi: 论文的DOI
Returns:
摘要文本如果获取失败则返回None
"""
try:
# Use the Abstract API rather than the Search API
response = await self._make_request(
f"{self.base_url}/abstract/doi/{doi}",
params={
"view": "FULL" # 使用FULL视图
}
)
if response and "abstracts-retrieval-response" in response:
# The abstract lives in coredata
coredata = response["abstracts-retrieval-response"].get("coredata", {})
return coredata.get("dc:description", "")
return None
except Exception as e:
print(f"获取摘要时发生错误: {str(e)}")
return None
async def example_usage():
"""ElsevierSource使用示例"""
elsevier = ElsevierSource()
try:
# First, test API access
print("\n=== Testing API access ===")
await elsevier.test_api_access()
# Example 1: basic search
print("\n=== Example 1: search machine-learning papers ===")
papers = await elsevier.search("machine learning", limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- Paper {i} ---")
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"URL: {paper.url}")
print(f"Citations: {paper.citations}")
print(f"Venue: {paper.venue}")
print("Venue details:")
for key, value in paper.venue_info.items():
if value:  # only print non-empty values
print(f" - {key}: {value}")
# Example 2: citing papers
if papers and papers[0].doi:
print("\n=== Example 2: papers citing the first result ===")
citations = await elsevier.get_citations(papers[0].doi, limit=3)
for i, paper in enumerate(citations, 1):
print(f"\n--- Citing paper {i} ---")
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"Citations: {paper.citations}")
print(f"Venue: {paper.venue}")
# Example 3: references
if papers and papers[0].doi:
print("\n=== Example 3: references of the first result ===")
references = await elsevier.get_references(papers[0].doi)
for i, paper in enumerate(references[:3], 1):
print(f"\n--- Reference {i} ---")
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"Venue: {paper.venue}")
# Example 4: search by author
print("\n=== Example 4: search by author ===")
author_papers = await elsevier.search_by_author("Hinton G", limit=3)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- Paper {i} ---")
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"Citations: {paper.citations}")
print(f"Venue: {paper.venue}")
# Example 5: search by affiliation
print("\n=== Example 5: search by affiliation ===")
affiliation_papers = await elsevier.search_by_affiliation("60027950", limit=3)  # MIT's affiliation ID
for i, paper in enumerate(affiliation_papers, 1):
print(f"\n--- Paper {i} ---")
print(f"Title: {paper.title}")
print(f"Authors: {', '.join(paper.authors)}")
print(f"Year: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"Citations: {paper.citations}")
print(f"Venue: {paper.venue}")
# Example 6: fetch an abstract
print("\n=== Example 6: fetch a paper's abstract ===")
test_doi = "10.1016/j.artint.2021.103535"
abstract = await elsevier.fetch_abstract(test_doi)
if abstract:
print(f"Abstract: {abstract[:200]}...")  # first 200 characters only
else:
print("Could not fetch the abstract")
# Example 7: abstracts in search results
print("\n=== Example 7: abstracts in search results ===")
papers = await elsevier.search("machine learning", limit=1)
for paper in papers:
print(f"Title: {paper.title}")
print(f"Abstract: {paper.abstract[:200]}..." if paper.abstract else "Abstract: none")
except Exception as e:
print(f"Error: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
asyncio.run(example_usage())


@@ -0,0 +1,698 @@
import aiohttp
import asyncio
import base64
import json
import random
from typing import List, Dict, Optional, Union, Any
class GitHubSource:
"""GitHub API实现"""
# 默认API密钥列表 - 可以放置多个GitHub令牌
API_KEYS = [
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
]
def __init__(self, api_key: Optional[Union[str, List[str]]] = None):
"""初始化GitHub API客户端
Args:
api_key: GitHub个人访问令牌或令牌列表
"""
if api_key is None:
self.api_keys = self.API_KEYS
elif isinstance(api_key, str):
self.api_keys = [api_key]
else:
self.api_keys = api_key
self._initialize()
def _initialize(self) -> None:
"""初始化客户端,设置默认参数"""
self.base_url = "https://api.github.com"
self.headers = {
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
"User-Agent": "GitHub-API-Python-Client"
}
# If API keys are available, pick one at random
if self.api_keys:
selected_key = random.choice(self.api_keys)
self.headers["Authorization"] = f"Bearer {selected_key}"
print("Using a randomly selected API key for authentication")
else:
print("Warning: no API key provided; GitHub API rate limits will apply")
async def _request(self, method: str, endpoint: str, params: Dict = None, data: Dict = None) -> Any:
"""发送API请求
Args:
method: HTTP方法 (GET, POST, PUT, DELETE等)
endpoint: API端点
params: URL参数
data: 请求体数据
Returns:
解析后的响应JSON
"""
async with aiohttp.ClientSession(headers=self.headers) as session:
url = f"{self.base_url}{endpoint}"
# Print request info for debugging
print(f"Request: {method} {url}")
if params:
print(f"Params: {params}")
# Send the request
request_kwargs = {}
if params:
request_kwargs["params"] = params
if data:
request_kwargs["json"] = data
async with session.request(method, url, **request_kwargs) as response:
response_text = await response.text()
# Check the HTTP status code
if response.status >= 400:
print(f"API request failed: HTTP {response.status}")
print(f"Response body: {response_text}")
return None
# Parse the JSON response
try:
return json.loads(response_text)
except json.JSONDecodeError:
print(f"JSON解析错误: {response_text}")
return None
# ===== User methods =====
async def get_user(self, username: Optional[str] = None) -> Dict:
"""获取用户信息
Args:
username: 指定用户名,不指定则获取当前授权用户
Returns:
用户信息字典
"""
endpoint = "/user" if username is None else f"/users/{username}"
return await self._request("GET", endpoint)
async def get_user_repos(self, username: Optional[str] = None, sort: str = "updated",
direction: str = "desc", per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取用户的仓库列表
Args:
username: 指定用户名,不指定则获取当前授权用户
sort: 排序方式 (created, updated, pushed, full_name)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
仓库列表
"""
endpoint = "/user/repos" if username is None else f"/users/{username}/repos"
params = {
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_user_starred(self, username: Optional[str] = None,
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取用户星标的仓库
Args:
username: 指定用户名,不指定则获取当前授权用户
per_page: 每页结果数量
page: 页码
Returns:
星标仓库列表
"""
endpoint = "/user/starred" if username is None else f"/users/{username}/starred"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== Repository methods =====
async def get_repo(self, owner: str, repo: str) -> Dict:
"""获取仓库信息
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
仓库信息
"""
endpoint = f"/repos/{owner}/{repo}"
return await self._request("GET", endpoint)
async def get_repo_branches(self, owner: str, repo: str, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的分支列表
Args:
owner: 仓库所有者
repo: 仓库名
per_page: 每页结果数量
page: 页码
Returns:
分支列表
"""
endpoint = f"/repos/{owner}/{repo}/branches"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_repo_commits(self, owner: str, repo: str, sha: Optional[str] = None,
path: Optional[str] = None, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的提交历史
Args:
owner: 仓库所有者
repo: 仓库名
sha: 特定提交SHA或分支名
path: 文件路径筛选
per_page: 每页结果数量
page: 页码
Returns:
提交列表
"""
endpoint = f"/repos/{owner}/{repo}/commits"
params = {
"per_page": per_page,
"page": page
}
if sha:
params["sha"] = sha
if path:
params["path"] = path
return await self._request("GET", endpoint, params=params)
async def get_commit_details(self, owner: str, repo: str, commit_sha: str) -> Dict:
"""获取特定提交的详情
Args:
owner: 仓库所有者
repo: 仓库名
commit_sha: 提交SHA
Returns:
提交详情
"""
endpoint = f"/repos/{owner}/{repo}/commits/{commit_sha}"
return await self._request("GET", endpoint)
# ===== Content methods =====
async def get_file_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> Dict:
"""获取文件内容
Args:
owner: 仓库所有者
repo: 仓库名
path: 文件路径
ref: 分支名、标签名或提交SHA
Returns:
文件内容信息
"""
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
response = await self._request("GET", endpoint, params=params)
if response and isinstance(response, dict) and "content" in response:
try:
# Decode the Base64-encoded file content
content = base64.b64decode(response["content"].encode()).decode()
response["decoded_content"] = content
except Exception as e:
print(f"解码文件内容时出错: {str(e)}")
return response
async def get_directory_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> List[Dict]:
"""获取目录内容
Args:
owner: 仓库所有者
repo: 仓库名
path: 目录路径
ref: 分支名、标签名或提交SHA
Returns:
目录内容列表
"""
# 注意此方法与get_file_content使用相同的端点但对于目录会返回列表
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
return await self._request("GET", endpoint, params=params)
# ===== Issue methods =====
async def get_issues(self, owner: str, repo: str, state: str = "open",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的Issues列表
Args:
owner: 仓库所有者
repo: 仓库名
state: Issue状态 (open, closed, all)
sort: 排序方式 (created, updated, comments)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
Issues列表
"""
endpoint = f"/repos/{owner}/{repo}/issues"
params = {
"state": state,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_issue(self, owner: str, repo: str, issue_number: int) -> Dict:
"""获取特定Issue的详情
Args:
owner: 仓库所有者
repo: 仓库名
issue_number: Issue编号
Returns:
Issue详情
"""
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}"
return await self._request("GET", endpoint)
async def get_issue_comments(self, owner: str, repo: str, issue_number: int) -> List[Dict]:
"""获取Issue的评论
Args:
owner: 仓库所有者
repo: 仓库名
issue_number: Issue编号
Returns:
评论列表
"""
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
return await self._request("GET", endpoint)
# ===== Pull-request methods =====
async def get_pull_requests(self, owner: str, repo: str, state: str = "open",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的Pull Request列表
Args:
owner: 仓库所有者
repo: 仓库名
state: PR状态 (open, closed, all)
sort: 排序方式 (created, updated, popularity, long-running)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
Pull Request列表
"""
endpoint = f"/repos/{owner}/{repo}/pulls"
params = {
"state": state,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_pull_request(self, owner: str, repo: str, pr_number: int) -> Dict:
"""获取特定Pull Request的详情
Args:
owner: 仓库所有者
repo: 仓库名
pr_number: Pull Request编号
Returns:
Pull Request详情
"""
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}"
return await self._request("GET", endpoint)
async def get_pull_request_files(self, owner: str, repo: str, pr_number: int) -> List[Dict]:
"""获取Pull Request中修改的文件
Args:
owner: 仓库所有者
repo: 仓库名
pr_number: Pull Request编号
Returns:
修改文件列表
"""
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/files"
return await self._request("GET", endpoint)
# ===== Search methods =====
async def search_repositories(self, query: str, sort: str = "stars",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索仓库
Args:
query: 搜索关键词
sort: 排序方式 (stars, forks, updated)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/repositories"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_code(self, query: str, sort: str = "indexed",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索代码
Args:
query: 搜索关键词
sort: 排序方式 (indexed)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/code"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_issues(self, query: str, sort: str = "created",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索Issues和Pull Requests
Args:
query: 搜索关键词
sort: 排序方式 (created, updated, comments)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/issues"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_users(self, query: str, sort: str = "followers",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索用户
Args:
query: 搜索关键词
sort: 排序方式 (followers, repositories, joined)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/users"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== Organization methods =====
async def get_organization(self, org: str) -> Dict:
"""获取组织信息
Args:
org: 组织名称
Returns:
组织信息
"""
endpoint = f"/orgs/{org}"
return await self._request("GET", endpoint)
async def get_organization_repos(self, org: str, type: str = "all",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取组织的仓库列表
Args:
org: 组织名称
type: 仓库类型 (all, public, private, forks, sources, member, internal)
sort: 排序方式 (created, updated, pushed, full_name)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
仓库列表
"""
endpoint = f"/orgs/{org}/repos"
params = {
"type": type,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_organization_members(self, org: str, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取组织成员列表
Args:
org: 组织名称
per_page: 每页结果数量
page: 页码
Returns:
成员列表
"""
endpoint = f"/orgs/{org}/members"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== Higher-level operations =====
async def get_repository_languages(self, owner: str, repo: str) -> Dict:
"""获取仓库使用的编程语言及其比例
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
语言使用情况
"""
endpoint = f"/repos/{owner}/{repo}/languages"
return await self._request("GET", endpoint)
async def get_repository_stats_contributors(self, owner: str, repo: str) -> List[Dict]:
"""获取仓库的贡献者统计
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
贡献者统计信息
"""
endpoint = f"/repos/{owner}/{repo}/stats/contributors"
return await self._request("GET", endpoint)
async def get_repository_stats_commit_activity(self, owner: str, repo: str) -> List[Dict]:
"""获取仓库的提交活动
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
提交活动统计
"""
endpoint = f"/repos/{owner}/{repo}/stats/commit_activity"
return await self._request("GET", endpoint)
async def example_usage():
"""GitHubSource使用示例"""
# 创建客户端实例可选传入API令牌
# github = GitHubSource(api_key="your_github_token")
github = GitHubSource()
try:
# 示例1搜索热门Python仓库
print("\n=== 示例1搜索热门Python仓库 ===")
repos = await github.search_repositories(
query="language:python stars:>1000",
sort="stars",
order="desc",
per_page=5
)
if repos and "items" in repos:
for i, repo in enumerate(repos["items"], 1):
print(f"\n--- 仓库 {i} ---")
print(f"名称: {repo['full_name']}")
print(f"描述: {repo['description']}")
print(f"星标数: {repo['stargazers_count']}")
print(f"Fork数: {repo['forks_count']}")
print(f"最近更新: {repo['updated_at']}")
print(f"URL: {repo['html_url']}")
# 示例2获取特定仓库的详情
print("\n=== 示例2获取特定仓库的详情 ===")
repo_details = await github.get_repo("microsoft", "vscode")
if repo_details:
print(f"名称: {repo_details['full_name']}")
print(f"描述: {repo_details['description']}")
print(f"星标数: {repo_details['stargazers_count']}")
print(f"Fork数: {repo_details['forks_count']}")
print(f"默认分支: {repo_details['default_branch']}")
print(f"开源许可: {repo_details.get('license', {}).get('name', '')}")
print(f"语言: {repo_details['language']}")
print(f"Open Issues数: {repo_details['open_issues_count']}")
# 示例3获取仓库的提交历史
print("\n=== 示例3获取仓库的最近提交 ===")
commits = await github.get_repo_commits("tensorflow", "tensorflow", per_page=5)
if commits:
for i, commit in enumerate(commits, 1):
print(f"\n--- 提交 {i} ---")
print(f"SHA: {commit['sha'][:7]}")
print(f"作者: {commit['commit']['author']['name']}")
print(f"日期: {commit['commit']['author']['date']}")
print(f"消息: {commit['commit']['message'].splitlines()[0]}")
# 示例4搜索代码
print("\n=== 示例4搜索代码 ===")
code_results = await github.search_code(
query="filename:README.md language:markdown pytorch in:file",
per_page=3
)
if code_results and "items" in code_results:
print(f"共找到: {code_results['total_count']} 个结果")
for i, item in enumerate(code_results["items"], 1):
print(f"\n--- 代码 {i} ---")
print(f"仓库: {item['repository']['full_name']}")
print(f"文件: {item['path']}")
print(f"URL: {item['html_url']}")
# 示例5获取文件内容
print("\n=== 示例5获取文件内容 ===")
file_content = await github.get_file_content("python", "cpython", "README.rst")
if file_content and "decoded_content" in file_content:
content = file_content["decoded_content"]
print(f"文件名: {file_content['name']}")
print(f"大小: {file_content['size']} 字节")
print(f"内容预览: {content[:200]}...")
# 示例6获取仓库使用的编程语言
print("\n=== 示例6获取仓库使用的编程语言 ===")
languages = await github.get_repository_languages("facebook", "react")
if languages:
print(f"React仓库使用的编程语言:")
for lang, bytes_of_code in languages.items():
print(f"- {lang}: {bytes_of_code} 字节")
# 示例7获取组织信息
print("\n=== 示例7获取组织信息 ===")
org_info = await github.get_organization("google")
if org_info:
print(f"名称: {org_info['name']}")
print(f"描述: {org_info.get('description', '')}")
print(f"位置: {org_info.get('location', '未指定')}")
print(f"公共仓库数: {org_info['public_repos']}")
print(f"成员数: {org_info.get('public_members', 0)}")
print(f"URL: {org_info['html_url']}")
# 示例8获取用户信息
print("\n=== 示例8获取用户信息 ===")
user_info = await github.get_user("torvalds")
if user_info:
print(f"名称: {user_info['name']}")
print(f"公司: {user_info.get('company', '')}")
print(f"博客: {user_info.get('blog', '')}")
print(f"位置: {user_info.get('location', '未指定')}")
print(f"公共仓库数: {user_info['public_repos']}")
print(f"关注者数: {user_info['followers']}")
print(f"URL: {user_info['html_url']}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
# 运行示例
asyncio.run(example_usage())
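# --- 补充草案(非本文件原有内容) ---
# 基于上文 search_repositories 的最小翻页示例;假设 GitHubSource 即本文件定义的类,
# 且 search_repositories 同 search_users 一样支持 page 参数。GitHub 搜索接口最多
# 返回约 1000 条结果,收集到该数量后继续翻页没有意义。
async def collect_repositories(github: "GitHubSource", query: str,
                               per_page: int = 100, max_items: int = 1000) -> list:
    """逐页收集搜索命中的仓库,最多 max_items 个"""
    items, page = [], 1
    while len(items) < max_items:
        result = await github.search_repositories(query=query, per_page=per_page, page=page)
        batch = (result or {}).get("items", [])
        if not batch:  # 没有更多结果时停止
            break
        items.extend(batch)
        page += 1
    return items[:max_items]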

View File

@@ -0,0 +1,142 @@
import json
import os
from typing import Dict, Optional
class JournalMetrics:
"""期刊指标管理类"""
def __init__(self):
self.journal_data: Dict = {} # 期刊名称到指标的映射
self.issn_map: Dict = {} # ISSN到指标的映射
self.name_map: Dict = {} # 标准化名称到指标的映射
self._load_journal_data()
def _normalize_journal_name(self, name: str) -> str:
"""标准化期刊名称
Args:
name: 原始期刊名称
Returns:
标准化后的期刊名称
"""
if not name:
return ""
# 转换为小写
name = name.lower()
# 移除常见的前缀和后缀
prefixes = ['the ', 'proceedings of ', 'journal of ']
suffixes = [' journal', ' proceedings', ' magazine', ' review', ' letters']
for prefix in prefixes:
if name.startswith(prefix):
name = name[len(prefix):]
for suffix in suffixes:
if name.endswith(suffix):
name = name[:-len(suffix)]
# 移除特殊字符,保留字母、数字和空格
name = ''.join(c for c in name if c.isalnum() or c.isspace())
# 移除多余的空格
name = ' '.join(name.split())
return name
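    # 补充示例(非原始代码),按上面的规则:
    #   "The Journal of Machine Learning" -> "machine learning"
    #   "Physical Review Letters"         -> "physical review"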
def _convert_if_value(self, if_str: str) -> Optional[float]:
"""转换IF值为float处理特殊情况"""
try:
if if_str.startswith('<'):
# 对于<0.1这样的值返回0.1
return float(if_str.strip('<'))
return float(if_str)
except (ValueError, AttributeError):
return None
def _load_journal_data(self):
"""加载期刊数据"""
try:
file_path = os.path.join(os.path.dirname(__file__), 'cas_if.json')
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# 建立期刊名称到指标的映射
for journal in data:
# 准备指标数据
metrics = {
'if_factor': self._convert_if_value(journal.get('IF')),
'jcr_division': journal.get('Q'),
'cas_division': journal.get('B')
}
# 存储期刊名称映射(使用标准化名称)
if journal.get('journal'):
normalized_name = self._normalize_journal_name(journal['journal'])
self.journal_data[normalized_name] = metrics
self.name_map[normalized_name] = metrics
# 存储期刊缩写映射
if journal.get('jabb'):
normalized_abbr = self._normalize_journal_name(journal['jabb'])
self.journal_data[normalized_abbr] = metrics
self.name_map[normalized_abbr] = metrics
# 存储ISSN映射
if journal.get('issn'):
self.issn_map[journal['issn']] = metrics
if journal.get('eissn'):
self.issn_map[journal['eissn']] = metrics
except Exception as e:
print(f"加载期刊数据时出错: {str(e)}")
self.journal_data = {}
self.issn_map = {}
self.name_map = {}
def get_journal_metrics(self, venue_name: str, venue_info: dict) -> dict:
"""获取期刊指标
Args:
venue_name: 期刊名称
venue_info: 期刊详细信息
Returns:
包含期刊指标的字典
"""
try:
metrics = {}
# 1. 首先尝试通过ISSN匹配
if venue_info and 'issn' in venue_info:
issn_value = venue_info['issn']
# 处理ISSN可能是列表的情况
if isinstance(issn_value, list):
# 尝试每个ISSN
for issn in issn_value:
metrics = self.issn_map.get(issn, {})
if metrics: # 如果找到匹配的指标,就停止搜索
break
else: # ISSN是字符串的情况
metrics = self.issn_map.get(issn_value, {})
# 2. 如果ISSN匹配失败尝试通过期刊名称匹配
if not metrics and venue_name:
# 标准化期刊名称
normalized_name = self._normalize_journal_name(venue_name)
metrics = self.name_map.get(normalized_name, {})
# 如果完全匹配失败,尝试部分匹配
# if not metrics:
# for db_name, db_metrics in self.name_map.items():
# if normalized_name in db_name:
# metrics = db_metrics
# break
return metrics
except Exception as e:
print(f"获取期刊指标时出错: {str(e)}")
return {}
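# --- 补充用法示例(草案,非本文件原有内容) ---
# 演示 JournalMetrics 的基本调用方式;假设同目录下存在 cas_if.json 数据文件,
# 其字段与 _load_journal_data 中读取的一致ISSN 值为示意用途。
if __name__ == "__main__":
    metrics_db = JournalMetrics()
    # 先按 ISSN 精确匹配,失败后回退到标准化名称匹配
    metrics = metrics_db.get_journal_metrics(
        venue_name="Nature Communications",
        venue_info={"issn": "2041-1723"},
    )
    print(metrics)  # 形如 {'if_factor': ..., 'jcr_division': ..., 'cas_division': ...}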

View File

@@ -0,0 +1,163 @@
import aiohttp
from typing import List, Dict, Optional
from datetime import datetime
from .base_source import DataSource, PaperMetadata
import os
from urllib.parse import quote
class OpenAlexSource(DataSource):
"""OpenAlex API实现"""
def _initialize(self) -> None:
self.base_url = "https://api.openalex.org"
self.mailto = "xxxxxxxxxxxxxxxxxxxxxxxx@163.com" # 直接写入邮件地址
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
params = {"mailto": self.mailto} if self.mailto else {}
params.update({
"filter": f"title.search:{query}",
"per-page": limit
})
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works",
params=params
) as response:
try:
response.raise_for_status()
data = await response.json()
results = data.get("results", [])
return [self._parse_work(work) for work in results]
except Exception as e:
print(f"搜索出错: {str(e)}")
return []
def _parse_work(self, work: Dict) -> PaperMetadata:
"""解析OpenAlex返回的数据"""
# 获取作者信息
raw_author_names = [
authorship.get("raw_author_name", "")
for authorship in work.get("authorships", [])
if authorship
]
# 处理作者名字格式
authors = [
self._reformat_name(author)
for author in raw_author_names
]
# 获取机构信息
institutions = [
inst.get("display_name", "")
for authorship in work.get("authorships", [])
for inst in authorship.get("institutions", [])
if inst
]
# 获取主要发表位置信息
primary_location = work.get("primary_location") or {}
source = primary_location.get("source") or {}
venue = source.get("display_name")
# 获取发表日期
year = work.get("publication_year")
return PaperMetadata(
title=work.get("title", ""),
authors=authors,
institutions=institutions,
abstract=work.get("abstract", ""),
year=year,
doi=work.get("doi"),
url=work.get("doi"), # OpenAlex 使用 DOI 作为 URL
citations=work.get("cited_by_count"),
venue=venue
)
def _reformat_name(self, name: str) -> str:
"""重新格式化作者名字"""
if "," not in name:
return name
family, given_names = (x.strip() for x in name.split(",", maxsplit=1))
return f"{given_names} {family}"
async def get_paper_details(self, doi: str) -> PaperMetadata:
"""获取指定DOI的论文详情"""
params = {"mailto": self.mailto} if self.mailto else {}
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}",
params=params
) as response:
data = await response.json()
return self._parse_work(data)
async def get_references(self, doi: str) -> List[PaperMetadata]:
"""获取指定DOI论文的参考文献列表"""
params = {"mailto": self.mailto} if self.mailto else {}
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}/references",
params=params
) as response:
data = await response.json()
return [self._parse_work(work) for work in data.get("results", [])]
async def get_citations(self, doi: str) -> List[PaperMetadata]:
"""获取引用指定DOI论文的文献列表"""
params = {"mailto": self.mailto} if self.mailto else {}
params.update({
"filter": f"cites:doi:{doi}",
"per-page": 100
})
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works",
params=params
) as response:
data = await response.json()
return [self._parse_work(work) for work in data.get("results", [])]
async def example_usage():
"""OpenAlexSource使用示例"""
# 初始化OpenAlexSource
openalex = OpenAlexSource()
try:
print("正在搜索论文...")
# 搜索与"artificial intelligence"相关的论文限制返回5篇
papers = await openalex.search(query="artificial intelligence", limit=5)
if not papers:
print("未获取到任何论文信息")
return
print(f"找到 {len(papers)} 篇论文")
# 打印搜索结果
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors) if paper.authors else '未知'}")
if paper.institutions:
print(f"机构: {', '.join(paper.institutions)}")
print(f"发表年份: {paper.year if paper.year else '未知'}")
print(f"DOI: {paper.doi if paper.doi else '未知'}")
print(f"URL: {paper.url if paper.url else '未知'}")
if paper.abstract:
print(f"摘要: {paper.abstract[:200]}...")
print(f"引用次数: {paper.citations if paper.citations is not None else '未知'}")
print(f"发表venue: {paper.venue if paper.venue else '未知'}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
# 如果直接运行此文件,执行示例代码
if __name__ == "__main__":
import asyncio
# 运行示例
asyncio.run(example_usage())
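# --- 补充草案(非本文件原有内容) ---
# OpenAlex 的 works 接口通常以 abstract_inverted_index词 -> 出现位置列表)
# 的形式返回摘要,而非纯文本;下面是把倒排索引还原成摘要文本的最小示例。
def rebuild_abstract(inverted_index: Dict) -> str:
    """把 abstract_inverted_index 还原为普通文本摘要"""
    if not inverted_index:
        return ""
    # 展开为 (位置, 词) 列表,按位置排序后拼接
    positions = [
        (pos, word)
        for word, pos_list in inverted_index.items()
        for pos in pos_list
    ]
    return " ".join(word for _, word in sorted(positions))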

View File

@@ -0,0 +1,458 @@
from typing import List, Optional, Dict, Union
from datetime import datetime
import aiohttp
import asyncio
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import xml.etree.ElementTree as ET
from urllib.parse import quote
import json
from tqdm import tqdm
import random
class PubMedSource(DataSource):
"""PubMed API实现"""
# 定义API密钥列表
API_KEYS = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: PubMed API密钥如果不提供则从预定义列表中随机选择
"""
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
self._initialize()
def _initialize(self) -> None:
"""初始化基础URL和请求头"""
self.base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
self.headers = {
"User-Agent": "Mozilla/5.0 PubMedDataSource/1.0",
"Accept": "application/json"
}
async def _make_request(self, url: str) -> Optional[str]:
"""发送HTTP请求
Args:
url: 请求URL
Returns:
响应内容
"""
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(url) as response:
if response.status == 200:
return await response.text()
else:
print(f"请求失败: {response.status}")
return None
except Exception as e:
print(f"请求发生错误: {str(e)}")
return None
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = "relevance",
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文
Args:
query: 搜索关键词
limit: 返回结果数量限制
            sort_by: 排序方式 ('relevance', 'date')esearch 不支持按引用数排序
start_year: 起始年份
Returns:
论文列表
"""
try:
# 添加年份过滤
if start_year:
query = f"{query} AND {start_year}:3000[dp]"
# 构建搜索URL
search_url = (
f"{self.base_url}/esearch.fcgi?"
f"db=pubmed&term={quote(query)}&retmax={limit}"
f"&usehistory=y&api_key={self.api_key}"
)
if sort_by == "date":
search_url += "&sort=date"
# 获取搜索结果
response = await self._make_request(search_url)
if not response:
return []
# 解析XML响应
root = ET.fromstring(response)
id_list = root.findall(".//Id")
pmids = [id_elem.text for id_elem in id_list]
if not pmids:
return []
# 批量获取论文详情
papers = []
batch_size = 50
for i in range(0, len(pmids), batch_size):
batch = pmids[i:i + batch_size]
batch_papers = await self._fetch_papers_batch(batch)
papers.extend(batch_papers)
return papers
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
async def _fetch_papers_batch(self, pmids: List[str]) -> List[PaperMetadata]:
"""批量获取论文详情
Args:
pmids: PubMed ID列表
Returns:
论文详情列表
"""
try:
# 构建批量获取URL
fetch_url = (
f"{self.base_url}/efetch.fcgi?"
f"db=pubmed&id={','.join(pmids)}"
f"&retmode=xml&api_key={self.api_key}"
)
response = await self._make_request(fetch_url)
if not response:
return []
# 解析XML响应
root = ET.fromstring(response)
articles = root.findall(".//PubmedArticle")
return [self._parse_article(article) for article in articles]
except Exception as e:
print(f"获取论文批次时发生错误: {str(e)}")
return []
def _parse_article(self, article: ET.Element) -> PaperMetadata:
"""解析PubMed文章XML
Args:
article: XML元素
Returns:
解析后的论文数据
"""
try:
# 提取基本信息
pmid = article.find(".//PMID").text
article_meta = article.find(".//Article")
# 获取标题
title = article_meta.find(".//ArticleTitle")
title = title.text if title is not None else ""
# 获取作者列表
authors = []
author_list = article_meta.findall(".//Author")
for author in author_list:
last_name = author.find("LastName")
fore_name = author.find("ForeName")
if last_name is not None and fore_name is not None:
authors.append(f"{fore_name.text} {last_name.text}")
elif last_name is not None:
authors.append(last_name.text)
# 获取摘要
abstract = article_meta.find(".//Abstract/AbstractText")
abstract = abstract.text if abstract is not None else ""
# 获取发表年份
pub_date = article_meta.find(".//PubDate/Year")
year = int(pub_date.text) if pub_date is not None else None
# 获取DOI
doi = article.find(".//ELocationID[@EIdType='doi']")
doi = doi.text if doi is not None else None
# 获取期刊信息
journal = article_meta.find(".//Journal")
if journal is not None:
journal_title = journal.find(".//Title")
venue = journal_title.text if journal_title is not None else None
# 获取期刊详细信息
venue_info = {
'issn': journal.findtext(".//ISSN"),
'volume': journal.findtext(".//Volume"),
'issue': journal.findtext(".//Issue"),
'pub_date': journal.findtext(".//PubDate/MedlineDate") or
f"{journal.findtext('.//PubDate/Year', '')}-{journal.findtext('.//PubDate/Month', '')}"
}
else:
venue = None
venue_info = {}
# 获取机构信息
institutions = []
affiliations = article_meta.findall(".//Affiliation")
for affiliation in affiliations:
if affiliation is not None and affiliation.text:
institutions.append(affiliation.text)
return PaperMetadata(
title=title,
authors=authors,
abstract=abstract,
year=year,
doi=doi,
url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" if pmid else None,
citations=None, # PubMed API不直接提供引用数据
venue=venue,
institutions=institutions,
venue_type="journal",
venue_name=venue,
venue_info=venue_info,
source='pubmed' # 添加来源标记
)
except Exception as e:
print(f"解析文章时发生错误: {str(e)}")
return None
async def get_paper_details(self, pmid: str) -> Optional[PaperMetadata]:
"""获取指定PMID的论文详情"""
papers = await self._fetch_papers_batch([pmid])
return papers[0] if papers else None
async def get_related_papers(self, pmid: str, limit: int = 100) -> List[PaperMetadata]:
"""获取相关论文
使用PubMed的相关文章功能
Args:
pmid: PubMed ID
limit: 返回结果数量限制
Returns:
相关论文列表
"""
try:
# 构建相关文章URL
link_url = (
f"{self.base_url}/elink.fcgi?"
f"db=pubmed&id={pmid}&cmd=neighbor&api_key={self.api_key}"
)
response = await self._make_request(link_url)
if not response:
return []
# 解析XML响应
root = ET.fromstring(response)
related_ids = root.findall(".//Link/Id")
pmids = [id_elem.text for id_elem in related_ids][:limit]
if not pmids:
return []
# 获取相关论文详情
return await self._fetch_papers_batch(pmids)
except Exception as e:
print(f"获取相关论文时发生错误: {str(e)}")
return []
async def search_by_author(
self,
author: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = f"{author}[Author]"
if start_year:
query += f" AND {start_year}:3000[dp]"
return await self.search(query, limit=limit)
async def search_by_journal(
self,
journal: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按期刊搜索论文"""
query = f"{journal}[Journal]"
if start_year:
query += f" AND {start_year}:3000[dp]"
return await self.search(query, limit=limit)
async def get_latest_papers(
self,
days: int = 7,
limit: int = 100
) -> List[PaperMetadata]:
"""获取最新论文
Args:
days: 最近几天的论文
limit: 返回结果数量限制
Returns:
最新论文列表
"""
query = f"last {days} days[dp]"
return await self.search(query, limit=limit, sort_by="date")
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
"""获取引用该论文的文献
注意PubMed API本身不提供引用数据此方法将返回空列表
未来可以考虑集成其他数据源(如CrossRef)来获取引用信息
Args:
paper_id: PubMed ID
Returns:
空列表因为PubMed不提供引用数据
"""
return []
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
"""获取该论文引用的文献
从PubMed文章的参考文献列表获取引用的文献
Args:
paper_id: PubMed ID
Returns:
引用的文献列表
"""
try:
# 构建获取参考文献的URL
refs_url = (
f"{self.base_url}/elink.fcgi?"
f"dbfrom=pubmed&db=pubmed&id={paper_id}"
f"&cmd=neighbor_history&linkname=pubmed_pubmed_refs"
f"&api_key={self.api_key}"
)
response = await self._make_request(refs_url)
if not response:
return []
# 解析XML响应
root = ET.fromstring(response)
ref_ids = root.findall(".//Link/Id")
pmids = [id_elem.text for id_elem in ref_ids]
if not pmids:
return []
# 获取参考文献详情
return await self._fetch_papers_batch(pmids)
except Exception as e:
print(f"获取参考文献时发生错误: {str(e)}")
return []
async def example_usage():
"""PubMedSource使用示例"""
pubmed = PubMedSource()
try:
# 示例1基本搜索
print("\n=== 示例1搜索COVID-19相关论文 ===")
papers = await pubmed.search("COVID-19", limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"摘要: {paper.abstract[:200]}...")
# 示例2获取论文详情
if papers:
print("\n=== 示例2获取论文详情 ===")
paper_id = papers[0].url.split("/")[-2]
paper = await pubmed.get_paper_details(paper_id)
if paper:
print(f"标题: {paper.title}")
print(f"期刊: {paper.venue}")
print(f"机构: {', '.join(paper.institutions)}")
# 示例3获取相关论文
if papers:
print("\n=== 示例3获取相关论文 ===")
related = await pubmed.get_related_papers(paper_id, limit=3)
for i, paper in enumerate(related, 1):
print(f"\n--- 相关论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
# 示例4按作者搜索
print("\n=== 示例4按作者搜索 ===")
author_papers = await pubmed.search_by_author("Fauci AS", limit=3)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"发表年份: {paper.year}")
# 示例5按期刊搜索
print("\n=== 示例5按期刊搜索 ===")
journal_papers = await pubmed.search_by_journal("Nature", limit=3)
for i, paper in enumerate(journal_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"发表年份: {paper.year}")
# 示例6获取最新论文
print("\n=== 示例6获取最新论文 ===")
latest = await pubmed.get_latest_papers(days=7, limit=3)
for i, paper in enumerate(latest, 1):
print(f"\n--- 最新论文 {i} ---")
print(f"标题: {paper.title}")
print(f"发表日期: {paper.venue_info.get('pub_date')}")
# 示例7获取论文的参考文献
if papers:
print("\n=== 示例7获取论文的参考文献 ===")
paper_id = papers[0].url.split("/")[-2]
references = await pubmed.get_references(paper_id)
for i, paper in enumerate(references[:3], 1):
print(f"\n--- 参考文献 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
# 示例8尝试获取引用信息将返回空列表
if papers:
print("\n=== 示例8获取论文的引用信息 ===")
paper_id = papers[0].url.split("/")[-2]
citations = await pubmed.get_citations(paper_id)
print(f"引用数据:{len(citations)} (PubMed API不提供引用信息)")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
asyncio.run(example_usage())
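# --- 补充草案(非本文件原有内容) ---
# NCBI 对携带 API key 的请求有速率上限(约每秒 10 次);下面在 search 之上
# 做一个串行限速的最小封装PubMedSource 即本文件定义的类。
async def throttled_search(source: "PubMedSource", queries: List[str],
                           interval: float = 0.15) -> Dict[str, List[PaperMetadata]]:
    """按固定间隔依次执行多个检索,避免触发 NCBI 限速"""
    results = {}
    for q in queries:
        results[q] = await source.search(q, limit=10)
        await asyncio.sleep(interval)  # 请求间隔,约每秒 6-7 次
    return results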

View File

@@ -0,0 +1,326 @@
from pathlib import Path
import requests
from bs4 import BeautifulSoup
import time
from loguru import logger
import PyPDF2
import io
class SciHub:
# 更新的镜像列表,包含更多可用的镜像
MIRRORS = [
'https://sci-hub.se/',
'https://sci-hub.st/',
'https://sci-hub.ru/',
'https://sci-hub.wf/',
'https://sci-hub.ee/',
'https://sci-hub.ren/',
'https://sci-hub.tf/',
'https://sci-hub.si/',
'https://sci-hub.do/',
'https://sci-hub.hkvisa.net/',
'https://sci-hub.mksa.top/',
'https://sci-hub.shop/',
'https://sci-hub.yncjkj.com/',
'https://sci-hub.41610.org/',
'https://sci-hub.automic.us/',
'https://sci-hub.et-fine.com/',
'https://sci-hub.pooh.mu/',
'https://sci-hub.bban.top/',
'https://sci-hub.usualwant.com/',
'https://sci-hub.unblockit.kim/'
]
def __init__(self, doi: str, path: Path, url=None, timeout=60, use_proxy=True):
self.timeout = timeout
self.path = path
self.doi = str(doi)
self.use_proxy = use_proxy
self.headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
}
self.payload = {
'sci-hub-plugin-check': '',
'request': self.doi
}
self.url = url if url else self.MIRRORS[0]
self.proxies = {
"http": "socks5h://localhost:10880",
"https": "socks5h://localhost:10880",
} if use_proxy else None
def _test_proxy_connection(self):
"""测试代理连接是否可用"""
if not self.use_proxy:
return True
try:
# 测试代理连接
test_response = requests.get(
'https://httpbin.org/ip',
proxies=self.proxies,
timeout=10
)
if test_response.status_code == 200:
logger.info("代理连接测试成功")
return True
except Exception as e:
logger.warning(f"代理连接测试失败: {str(e)}")
return False
return False
def _check_pdf_validity(self, content):
"""检查PDF文件是否有效"""
try:
# 使用PyPDF2检查PDF是否可以正常打开和读取
pdf = PyPDF2.PdfReader(io.BytesIO(content))
if len(pdf.pages) > 0:
return True
return False
except Exception as e:
logger.error(f"PDF文件无效: {str(e)}")
return False
def _send_request(self):
"""发送请求到Sci-Hub镜像站点"""
# 首先测试代理连接
if self.use_proxy and not self._test_proxy_connection():
logger.warning("代理连接不可用,切换到直连模式")
self.use_proxy = False
self.proxies = None
last_exception = None
working_mirrors = []
# 先测试哪些镜像可用
logger.info("正在测试镜像站点可用性...")
for mirror in self.MIRRORS:
try:
test_response = requests.get(
mirror,
headers=self.headers,
proxies=self.proxies,
timeout=10
)
if test_response.status_code == 200:
working_mirrors.append(mirror)
logger.info(f"镜像 {mirror} 可用")
if len(working_mirrors) >= 5: # 找到5个可用镜像就够了
break
except Exception as e:
logger.debug(f"镜像 {mirror} 不可用: {str(e)}")
continue
if not working_mirrors:
raise Exception("没有找到可用的镜像站点")
logger.info(f"找到 {len(working_mirrors)} 个可用镜像,开始尝试下载...")
# 使用可用的镜像进行下载
for mirror in working_mirrors:
try:
res = requests.post(
mirror,
headers=self.headers,
data=self.payload,
proxies=self.proxies,
timeout=self.timeout
)
if res.ok:
logger.info(f"成功使用镜像站点: {mirror}")
self.url = mirror # 更新当前使用的镜像
time.sleep(1) # 降低等待时间以提高效率
return res
except Exception as e:
logger.error(f"尝试镜像 {mirror} 失败: {str(e)}")
last_exception = e
continue
if last_exception:
raise last_exception
raise Exception("所有可用镜像站点均无法完成下载")
def _extract_url(self, response):
"""从响应中提取PDF下载链接"""
soup = BeautifulSoup(response.content, 'html.parser')
try:
# 尝试多种方式提取PDF链接
pdf_element = soup.find(id='pdf')
if pdf_element:
content_url = pdf_element.get('src')
else:
# 尝试其他可能的选择器
pdf_element = soup.find('iframe')
if pdf_element:
content_url = pdf_element.get('src')
else:
# 查找直接的PDF链接
pdf_links = soup.find_all('a', href=lambda x: x and '.pdf' in x)
if pdf_links:
content_url = pdf_links[0].get('href')
else:
raise AttributeError("未找到PDF链接")
            if content_url:
                # 只去掉查看器参数;不要把 '//' 压成 '/',否则会破坏 'https://' 协议前缀
                content_url = content_url.replace('#navpanes=0&view=FitH', '')
                if not content_url.endswith('.pdf') and 'pdf' not in content_url.lower():
                    raise AttributeError("找到的链接不是PDF文件")
except AttributeError:
logger.error(f"未找到论文 {self.doi}")
return None
        current_mirror = self.url.rstrip('/')
        if content_url.startswith('//'):
            # 协议相对链接,补全 https:
            return 'https:' + content_url
        elif content_url.startswith('/'):
            return current_mirror + content_url
        elif content_url.startswith('http'):
            return content_url
        else:
            return current_mirror + '/' + content_url
def _download_pdf(self, pdf_url):
"""下载PDF文件并验证其完整性"""
try:
# 尝试不同的下载方式
download_methods = [
# 方法1直接下载
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout),
# 方法2添加 Referer 头
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
headers={**self.headers, 'Referer': self.url}),
# 方法3使用原始域名作为 Referer
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
headers={**self.headers, 'Referer': pdf_url.split('/downloads')[0] if '/downloads' in pdf_url else self.url})
]
for i, download_method in enumerate(download_methods):
try:
logger.info(f"尝试下载方式 {i+1}/3...")
response = download_method()
if response.status_code == 200:
content = response.content
if len(content) > 1000 and self._check_pdf_validity(content): # 确保文件不是太小
logger.info(f"PDF下载成功文件大小: {len(content)} bytes")
return content
else:
logger.warning("下载的文件可能不是有效的PDF")
elif response.status_code == 403:
logger.warning(f"访问被拒绝 (403 Forbidden),尝试其他下载方式")
continue
else:
logger.warning(f"下载失败,状态码: {response.status_code}")
continue
except Exception as e:
logger.warning(f"下载方式 {i+1} 失败: {str(e)}")
continue
# 如果所有方法都失败尝试构造替代URL
try:
logger.info("尝试使用替代镜像下载...")
# 从原始URL提取关键信息
if '/downloads/' in pdf_url:
file_part = pdf_url.split('/downloads/')[-1]
alternative_mirrors = [
f"https://sci-hub.se/downloads/{file_part}",
f"https://sci-hub.st/downloads/{file_part}",
f"https://sci-hub.ru/downloads/{file_part}",
f"https://sci-hub.wf/downloads/{file_part}",
f"https://sci-hub.ee/downloads/{file_part}",
f"https://sci-hub.ren/downloads/{file_part}",
f"https://sci-hub.tf/downloads/{file_part}"
]
for alt_url in alternative_mirrors:
try:
response = requests.get(
alt_url,
proxies=self.proxies,
timeout=self.timeout,
headers={**self.headers, 'Referer': alt_url.split('/downloads')[0]}
)
if response.status_code == 200:
content = response.content
if len(content) > 1000 and self._check_pdf_validity(content):
logger.info(f"使用替代镜像成功下载: {alt_url}")
return content
except Exception as e:
logger.debug(f"替代镜像 {alt_url} 下载失败: {str(e)}")
continue
except Exception as e:
logger.error(f"所有下载方式都失败: {str(e)}")
return None
except Exception as e:
logger.error(f"下载PDF文件失败: {str(e)}")
return None
def fetch(self):
"""获取论文PDF包含重试和验证机制"""
        for attempt in range(3):  # 最多重试3次
try:
logger.info(f"开始第 {attempt + 1} 次尝试下载论文: {self.doi}")
# 获取PDF下载链接
response = self._send_request()
pdf_url = self._extract_url(response)
if pdf_url is None:
logger.warning(f"{attempt + 1} 次尝试未找到PDF下载链接")
continue
logger.info(f"找到PDF下载链接: {pdf_url}")
# 下载并验证PDF
pdf_content = self._download_pdf(pdf_url)
if pdf_content is None:
logger.warning(f"{attempt + 1} 次尝试PDF下载失败")
continue
# 保存PDF文件
pdf_name = f"{self.doi.replace('/', '_').replace(':', '_')}.pdf"
pdf_path = self.path.joinpath(pdf_name)
pdf_path.write_bytes(pdf_content)
logger.info(f"成功下载论文: {pdf_name},文件大小: {len(pdf_content)} bytes")
return str(pdf_path)
except Exception as e:
logger.error(f"{attempt + 1} 次尝试失败: {str(e)}")
if attempt < 2: # 不是最后一次尝试
wait_time = (attempt + 1) * 3 # 递增等待时间
logger.info(f"等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
continue
raise Exception(f"无法下载论文 {self.doi},所有重试都失败了")
# Usage Example
if __name__ == '__main__':
# 创建一个用于保存PDF的目录
save_path = Path('./downloaded_papers')
save_path.mkdir(exist_ok=True)
# DOI示例
    sample_doi = '10.3897/rio.7.e67379'  # 示例DOI一篇开放获取论文
try:
# 初始化SciHub下载器先尝试使用代理
logger.info("尝试使用代理模式...")
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=True)
# 开始下载
result = downloader.fetch()
print(f"论文已保存到: {result}")
except Exception as e:
print(f"使用代理模式失败: {str(e)}")
try:
# 如果代理模式失败,尝试直连模式
logger.info("尝试直连模式...")
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=False)
result = downloader.fetch()
print(f"论文已保存到: {result}")
except Exception as e2:
print(f"直连模式也失败: {str(e2)}")
print("建议检查网络连接或尝试其他DOI")

View File

@@ -0,0 +1,400 @@
from typing import List, Optional, Dict, Union
from datetime import datetime
import aiohttp
import random
from .base_source import DataSource, PaperMetadata
from tqdm import tqdm
class ScopusSource(DataSource):
"""Scopus API实现"""
# 定义API密钥列表
API_KEYS = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: Scopus API密钥如果不提供则从预定义列表中随机选择
"""
self.api_key = api_key or random.choice(self.API_KEYS)
self._initialize()
def _initialize(self) -> None:
"""初始化基础URL和请求头"""
self.base_url = "https://api.elsevier.com/content"
self.headers = {
"X-ELS-APIKey": self.api_key,
"Accept": "application/json"
}
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
"""发送HTTP请求
Args:
url: 请求URL
params: 查询参数
Returns:
响应JSON数据
"""
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(url, params=params) as response:
if response.status == 200:
return await response.json()
else:
print(f"请求失败: {response.status}")
return None
except Exception as e:
print(f"请求发生错误: {str(e)}")
return None
def _parse_paper_data(self, data: Dict) -> PaperMetadata:
"""解析Scopus API返回的数据
Args:
data: Scopus API返回的论文数据
Returns:
解析后的论文元数据
"""
try:
# 提取基本信息
title = data.get("dc:title", "")
# 提取作者信息
authors = []
if "author" in data:
if isinstance(data["author"], list):
for author in data["author"]:
if "given-name" in author and "surname" in author:
authors.append(f"{author['given-name']} {author['surname']}")
elif "indexed-name" in author:
authors.append(author["indexed-name"])
elif isinstance(data["author"], dict):
if "given-name" in data["author"] and "surname" in data["author"]:
authors.append(f"{data['author']['given-name']} {data['author']['surname']}")
elif "indexed-name" in data["author"]:
authors.append(data["author"]["indexed-name"])
# 提取摘要
abstract = data.get("dc:description", "")
# 提取年份
year = None
if "prism:coverDate" in data:
try:
year = int(data["prism:coverDate"][:4])
except:
pass
# 提取DOI
doi = data.get("prism:doi")
# 提取引用次数
citations = data.get("citedby-count")
if citations:
try:
citations = int(citations)
except:
citations = None
# 提取期刊信息
venue = data.get("prism:publicationName")
# 提取机构信息
institutions = []
if "affiliation" in data:
if isinstance(data["affiliation"], list):
for aff in data["affiliation"]:
if "affilname" in aff:
institutions.append(aff["affilname"])
elif isinstance(data["affiliation"], dict):
if "affilname" in data["affiliation"]:
institutions.append(data["affiliation"]["affilname"])
# 构建venue信息
venue_info = {
"issn": data.get("prism:issn"),
"eissn": data.get("prism:eIssn"),
"volume": data.get("prism:volume"),
"issue": data.get("prism:issueIdentifier"),
"page_range": data.get("prism:pageRange"),
"article_number": data.get("article-number"),
"publication_date": data.get("prism:coverDate")
}
return PaperMetadata(
title=title,
authors=authors,
abstract=abstract,
year=year,
doi=doi,
url=data.get("link", [{}])[0].get("@href"),
citations=citations,
venue=venue,
institutions=institutions,
venue_type="journal",
venue_name=venue,
venue_info=venue_info
)
except Exception as e:
print(f"解析论文数据时发生错误: {str(e)}")
return None
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = None,
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文
Args:
query: 搜索关键词
limit: 返回结果数量限制
sort_by: 排序方式 ('relevance', 'date', 'citations')
start_year: 起始年份
Returns:
论文列表
"""
try:
# 构建查询参数
params = {
"query": query,
"count": min(limit, 100), # Scopus API单次请求限制
"start": 0
}
# 添加年份过滤
if start_year:
params["date"] = f"{start_year}-present"
# 添加排序
if sort_by:
sort_map = {
"relevance": "-score",
"date": "-coverDate",
"citations": "-citedby-count"
}
if sort_by in sort_map:
params["sort"] = sort_map[sort_by]
# 发送请求
url = f"{self.base_url}/search/scopus"
response = await self._make_request(url, params)
if not response or "search-results" not in response:
return []
# 解析结果
results = response["search-results"].get("entry", [])
papers = []
for result in results:
paper = self._parse_paper_data(result)
if paper:
papers.append(paper)
return papers
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
"""获取论文详情
Args:
paper_id: Scopus ID或DOI
Returns:
论文详情
"""
try:
# 判断是否为DOI
if "/" in paper_id:
url = f"{self.base_url}/article/doi/{paper_id}"
else:
url = f"{self.base_url}/abstract/scopus_id/{paper_id}"
response = await self._make_request(url)
if not response or "abstracts-retrieval-response" not in response:
return None
data = response["abstracts-retrieval-response"]
return self._parse_paper_data(data)
except Exception as e:
print(f"获取论文详情时发生错误: {str(e)}")
return None
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
"""获取引用该论文的文献
Args:
paper_id: Scopus ID
Returns:
引用论文列表
"""
try:
url = f"{self.base_url}/abstract/citations/{paper_id}"
response = await self._make_request(url)
if not response or "citing-papers" not in response:
return []
results = response["citing-papers"].get("papers", [])
papers = []
for result in results:
paper = self._parse_paper_data(result)
if paper:
papers.append(paper)
return papers
except Exception as e:
print(f"获取引用信息时发生错误: {str(e)}")
return []
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
"""获取该论文引用的文献
Args:
paper_id: Scopus ID
Returns:
参考文献列表
"""
try:
url = f"{self.base_url}/abstract/references/{paper_id}"
response = await self._make_request(url)
if not response or "references" not in response:
return []
results = response["references"].get("reference", [])
papers = []
for result in results:
paper = self._parse_paper_data(result)
if paper:
papers.append(paper)
return papers
except Exception as e:
print(f"获取参考文献时发生错误: {str(e)}")
return []
async def search_by_author(
self,
author: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = f"AUTHOR-NAME({author})"
if start_year:
query += f" AND PUBYEAR > {start_year}"
return await self.search(query, limit=limit)
async def search_by_journal(
self,
journal: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按期刊搜索论文"""
query = f"SRCTITLE({journal})"
if start_year:
query += f" AND PUBYEAR > {start_year}"
return await self.search(query, limit=limit)
async def get_latest_papers(
self,
days: int = 7,
limit: int = 100
) -> List[PaperMetadata]:
"""获取最新论文"""
query = f"LOAD-DATE > NOW() - {days}d"
return await self.search(query, limit=limit, sort_by="date")
async def example_usage():
"""ScopusSource使用示例"""
scopus = ScopusSource()
try:
# 示例1基本搜索
print("\n=== 示例1搜索机器学习相关论文 ===")
papers = await scopus.search("machine learning", limit=3)
print(f"\n找到 {len(papers)} 篇相关论文:")
for i, paper in enumerate(papers, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"发表期刊: {paper.venue}")
print(f"引用次数: {paper.citations}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"摘要:\n{paper.abstract}")
print("-" * 80)
# 示例2按作者搜索
print("\n=== 示例2搜索特定作者的论文 ===")
author_papers = await scopus.search_by_author("Hinton G.", limit=3)
print(f"\n找到 {len(author_papers)} 篇 Hinton 的论文:")
for i, paper in enumerate(author_papers, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"发表期刊: {paper.venue}")
print(f"引用次数: {paper.citations}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"摘要:\n{paper.abstract}")
print("-" * 80)
# 示例3根据关键词搜索相关论文
print("\n=== 示例3搜索人工智能相关论文 ===")
keywords = "artificial intelligence AND deep learning"
papers = await scopus.search(
query=keywords,
limit=5,
sort_by="citations", # 按引用次数排序
start_year=2020 # 只搜索2020年之后的论文
)
print(f"\n找到 {len(papers)} 篇相关论文:")
for i, paper in enumerate(papers, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"发表期刊: {paper.venue}")
print(f"引用次数: {paper.citations}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"摘要:\n{paper.abstract}")
print("-" * 80)
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
asyncio.run(example_usage())
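# --- 补充草案(非本文件原有内容) ---
# 上文 search 固定 start=0单次最多取 100 条;需要更多结果时可以用 start
# 参数翻页。下面直接复用 _make_request 与 _parse_paper_data 做一个最小翻页示例。
async def scopus_search_paged(scopus: "ScopusSource", query: str, total: int = 200) -> List[PaperMetadata]:
    """用 start 参数翻页,最多收集 total 条结果"""
    papers, start = [], 0
    while len(papers) < total:
        params = {"query": query, "count": 25, "start": start}
        response = await scopus._make_request(f"{scopus.base_url}/search/scopus", params)
        if not response or "search-results" not in response:
            break
        entries = response["search-results"].get("entry", [])
        if not entries:  # 已取完所有结果
            break
        papers.extend(p for p in (scopus._parse_paper_data(e) for e in entries) if p)
        start += len(entries)
    return papers[:total]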

View File

@@ -0,0 +1,480 @@
from typing import List, Optional
from datetime import datetime
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import random
class SemanticScholarSource(DataSource):
"""Semantic Scholar API实现,使用官方Python包"""
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: Semantic Scholar API密钥(可选)
"""
self.api_key = api_key
self._initialize() # 调用初始化方法
def _initialize(self) -> None:
"""初始化API客户端"""
if not self.api_key:
# 默认API密钥列表
default_api_keys = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
self.api_key = random.choice(default_api_keys)
self.client = None # 延迟初始化
self.fields = [
"title",
"authors",
"abstract",
"year",
"externalIds",
"citationCount",
"venue",
"openAccessPdf",
"publicationVenue"
]
async def _ensure_client(self):
"""确保客户端已初始化"""
if self.client is None:
from semanticscholar import AsyncSemanticScholar
self.client = AsyncSemanticScholar(api_key=self.api_key)
async def search(
self,
query: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文"""
try:
await self._ensure_client()
            # 如果指定了起始年份,用接口的 year 参数过滤(查询串本身不解析 year>= 语法)
            year_param = f"&year={start_year}-" if start_year else ""
            # 直接调用 graph API 的 /paper/search 端点
            response = await self.client._requester.get_data_async(
                f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/search",
                f"query={query}&limit={min(limit, 100)}&fields={','.join(self.fields)}{year_param}",
                self.client.auth_header
            )
papers = response.get('data', [])
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
return []
async def get_paper_details(self, doi: str) -> Optional[PaperMetadata]:
"""获取指定DOI的论文详情"""
try:
await self._ensure_client()
paper = await self.client.get_paper(f"DOI:{doi}", fields=self.fields)
return self._parse_paper_data(paper)
except Exception as e:
print(f"获取论文详情时发生错误: {str(e)}")
return None
async def get_citations(
self,
doi: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""获取引用指定DOI论文的文献列表"""
try:
await self._ensure_client()
# 构建查询参数
fields_param = f"fields={','.join(self.fields)}"
limit_param = f"limit={limit}"
year_param = f"year>={start_year}" if start_year else ""
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/citations",
params,
self.client.auth_header
)
citations = response.get('data', [])
return [self._parse_paper_data(citation.get('citingPaper', {})) for citation in citations]
except Exception as e:
print(f"获取引用列表时发生错误: {str(e)}")
return []
async def get_references(
self,
doi: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""获取指定DOI论文的参考文献列表"""
try:
await self._ensure_client()
# 构建查询参数
fields_param = f"fields={','.join(self.fields)}"
limit_param = f"limit={limit}"
year_param = f"year>={start_year}" if start_year else ""
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/references",
params,
self.client.auth_header
)
references = response.get('data', [])
return [self._parse_paper_data(reference.get('citedPaper', {})) for reference in references]
except Exception as e:
print(f"获取参考文献列表时发生错误: {str(e)}")
return []
async def get_recommended_papers(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
"""获取论文推荐
根据一篇论文获取相关的推荐论文
Args:
doi: 论文的DOI
limit: 返回结果数量限制最大500
Returns:
推荐论文列表
"""
try:
await self._ensure_client()
papers = await self.client.get_recommended_papers(
f"DOI:{doi}",
fields=self.fields,
limit=min(limit, 500)
)
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"获取论文推荐时发生错误: {str(e)}")
return []
async def get_recommended_papers_from_lists(
self,
positive_dois: List[str],
negative_dois: List[str] = None,
limit: int = 100
) -> List[PaperMetadata]:
"""基于正负例论文列表获取推荐
Args:
positive_dois: 正例论文DOI列表想要获取类似的论文
negative_dois: 负例论文DOI列表不想要类似的论文
limit: 返回结果数量限制最大500
Returns:
推荐论文列表
"""
try:
await self._ensure_client()
positive_ids = [f"DOI:{doi}" for doi in positive_dois]
negative_ids = [f"DOI:{doi}" for doi in negative_dois] if negative_dois else None
papers = await self.client.get_recommended_papers_from_lists(
positive_paper_ids=positive_ids,
negative_paper_ids=negative_ids,
fields=self.fields,
limit=min(limit, 500)
)
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"获取论文推荐列表时发生错误: {str(e)}")
return []
async def search_author(self, query: str, limit: int = 100) -> List[dict]:
"""搜索作者"""
try:
await self._ensure_client()
# 直接使用 API 请求而不是 search_author 方法
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/search",
f"query={query}&fields=name,paperCount,citationCount&limit={min(limit, 1000)}",
self.client.auth_header
)
authors = response.get('data', [])
return [
{
'author_id': author.get('authorId'),
'name': author.get('name'),
'paper_count': author.get('paperCount'),
'citation_count': author.get('citationCount'),
}
for author in authors
]
except Exception as e:
print(f"搜索作者时发生错误: {str(e)}")
return []
async def get_author_details(self, author_id: str) -> Optional[dict]:
"""获取作者详细信息"""
try:
await self._ensure_client()
# 直接使用 API 请求
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}",
"fields=name,paperCount,citationCount,hIndex",
self.client.auth_header
)
return {
'author_id': response.get('authorId'),
'name': response.get('name'),
'paper_count': response.get('paperCount'),
'citation_count': response.get('citationCount'),
'h_index': response.get('hIndex'),
}
except Exception as e:
print(f"获取作者详情时发生错误: {str(e)}")
return None
async def get_author_papers(self, author_id: str, limit: int = 100) -> List[PaperMetadata]:
"""获取作者的论文列表"""
try:
await self._ensure_client()
# 直接使用 API 请求
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}/papers",
f"fields={','.join(self.fields)}&limit={min(limit, 1000)}",
self.client.auth_header
)
papers = response.get('data', [])
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"获取作者论文列表时发生错误: {str(e)}")
return []
async def get_paper_autocomplete(self, query: str) -> List[dict]:
"""论文标题自动补全"""
try:
await self._ensure_client()
# 直接使用 API 请求
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/autocomplete",
f"query={query}",
self.client.auth_header
)
suggestions = response.get('matches', [])
return [
{
'title': suggestion.get('title'),
'paper_id': suggestion.get('paperId'),
'year': suggestion.get('year'),
'venue': suggestion.get('venue'),
}
for suggestion in suggestions
]
except Exception as e:
print(f"获取标题自动补全时发生错误: {str(e)}")
return []
def _parse_paper_data(self, paper) -> PaperMetadata:
"""解析论文数据"""
# 获取DOI
doi = None
external_ids = paper.get('externalIds', {}) if isinstance(paper, dict) else paper.externalIds
if external_ids:
if isinstance(external_ids, dict):
doi = external_ids.get('DOI')
if not doi and 'ArXiv' in external_ids:
doi = f"10.48550/arXiv.{external_ids['ArXiv']}"
else:
doi = external_ids.DOI if hasattr(external_ids, 'DOI') else None
if not doi and hasattr(external_ids, 'ArXiv'):
doi = f"10.48550/arXiv.{external_ids.ArXiv}"
# 获取PDF URL
pdf_url = None
pdf_info = paper.get('openAccessPdf', {}) if isinstance(paper, dict) else paper.openAccessPdf
if pdf_info:
pdf_url = pdf_info.get('url') if isinstance(pdf_info, dict) else pdf_info.url
# 获取发表场所详细信息
venue_type = None
venue_name = None
venue_info = {}
venue = paper.get('publicationVenue', {}) if isinstance(paper, dict) else paper.publicationVenue
if venue:
if isinstance(venue, dict):
venue_name = venue.get('name')
venue_type = venue.get('type')
# 提取更多venue信息
venue_info = {
'issn': venue.get('issn'),
'publisher': venue.get('publisher'),
'url': venue.get('url'),
'alternate_names': venue.get('alternate_names', [])
}
else:
venue_name = venue.name if hasattr(venue, 'name') else None
venue_type = venue.type if hasattr(venue, 'type') else None
venue_info = {
'issn': getattr(venue, 'issn', None),
'publisher': getattr(venue, 'publisher', None),
'url': getattr(venue, 'url', None),
'alternate_names': getattr(venue, 'alternate_names', [])
}
# 获取标题
title = paper.get('title', '') if isinstance(paper, dict) else getattr(paper, 'title', '')
# 获取作者
authors = paper.get('authors', []) if isinstance(paper, dict) else getattr(paper, 'authors', [])
author_names = []
for author in authors:
if isinstance(author, dict):
author_names.append(author.get('name', ''))
else:
author_names.append(author.name if hasattr(author, 'name') else str(author))
# 获取摘要
abstract = paper.get('abstract', '') if isinstance(paper, dict) else getattr(paper, 'abstract', '')
# 获取年份
year = paper.get('year') if isinstance(paper, dict) else getattr(paper, 'year', None)
# 获取引用次数
citations = paper.get('citationCount') if isinstance(paper, dict) else getattr(paper, 'citationCount', None)
return PaperMetadata(
title=title,
authors=author_names,
abstract=abstract,
year=year,
doi=doi,
url=pdf_url or (f"https://doi.org/{doi}" if doi else None),
citations=citations,
venue=venue_name,
institutions=[],
venue_type=venue_type,
venue_name=venue_name,
venue_info=venue_info,
source='semantic' # 添加来源标记
)
async def example_usage():
"""SemanticScholarSource使用示例"""
semantic = SemanticScholarSource()
try:
# 示例1使用DOI直接获取论文
print("\n=== 示例1通过DOI获取论文 ===")
doi = "10.18653/v1/N19-1423" # BERT论文
print(f"获取DOI为 {doi} 的论文信息...")
paper = await semantic.get_paper_details(doi)
if paper:
print("\n--- 论文信息 ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"URL: {paper.url}")
if paper.abstract:
print(f"\n摘要:")
print(paper.abstract)
print(f"\n引用次数: {paper.citations}")
print(f"发表venue: {paper.venue}")
# 示例2搜索论文
print("\n=== 示例2搜索论文 ===")
query = "BERT pre-training"
print(f"搜索关键词 '{query}' 相关的论文...")
papers = await semantic.search(query=query, limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- 搜索结果 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
if paper.abstract:
print(f"\n摘要:")
print(paper.abstract)
print(f"\nDOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
# 示例3获取论文推荐
print("\n=== 示例3获取论文推荐 ===")
print(f"获取与论文 {doi} 相关的推荐论文...")
recommendations = await semantic.get_recommended_papers(doi, limit=3)
for i, paper in enumerate(recommendations, 1):
print(f"\n--- 推荐论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
# 示例4基于多篇论文的推荐
print("\n=== 示例4基于多篇论文的推荐 ===")
positive_dois = ["10.18653/v1/N19-1423", "10.18653/v1/P19-1285"]
print(f"基于 {len(positive_dois)} 篇论文获取推荐...")
multi_recommendations = await semantic.get_recommended_papers_from_lists(
positive_dois=positive_dois,
limit=3
)
for i, paper in enumerate(multi_recommendations, 1):
print(f"\n--- 推荐论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
# 示例5搜索作者
print("\n=== 示例5搜索作者 ===")
author_query = "Yann LeCun"
print(f"搜索作者: '{author_query}'")
authors = await semantic.search_author(author_query, limit=3)
for i, author in enumerate(authors, 1):
print(f"\n--- 作者 {i} ---")
print(f"姓名: {author['name']}")
print(f"论文数量: {author['paper_count']}")
print(f"总引用次数: {author['citation_count']}")
# 示例6获取作者详情
print("\n=== 示例6获取作者详情 ===")
if authors: # 使用第一个搜索结果的作者ID
author_id = authors[0]['author_id']
print(f"获取作者ID {author_id} 的详细信息...")
author_details = await semantic.get_author_details(author_id)
if author_details:
print(f"姓名: {author_details['name']}")
print(f"H指数: {author_details['h_index']}")
print(f"总引用次数: {author_details['citation_count']}")
print(f"发表论文数: {author_details['paper_count']}")
# 示例7获取作者论文
print("\n=== 示例7获取作者论文 ===")
if authors: # 使用第一个搜索结果的作者ID
author_id = authors[0]['author_id']
print(f"获取作者 {authors[0]['name']} 的论文列表...")
author_papers = await semantic.get_author_papers(author_id, limit=3)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"发表年份: {paper.year}")
print(f"引用次数: {paper.citations}")
# 示例8论文标题自动补全
print("\n=== 示例8论文标题自动补全 ===")
title_query = "Attention is all"
print(f"搜索标题: '{title_query}'")
suggestions = await semantic.get_paper_autocomplete(title_query)
for i, suggestion in enumerate(suggestions[:3], 1):
print(f"\n--- 建议 {i} ---")
print(f"标题: {suggestion['title']}")
print(f"发表年份: {suggestion['year']}")
print(f"发表venue: {suggestion['venue']}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
asyncio.run(example_usage())
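# --- 补充草案(非本文件原有内容) ---
# Semantic Scholar 公共接口限速较严;下面在 get_paper_details 之上做一个
# 串行批量查询示例,两次请求之间稍作等待,失败的 DOI 记为 None。
async def fetch_many_details(semantic: "SemanticScholarSource", dois: List[str],
                             delay: float = 1.0) -> dict:
    """逐个获取 DOI 对应的论文详情"""
    import asyncio
    details = {}
    for doi in dois:
        details[doi] = await semantic.get_paper_details(doi)
        await asyncio.sleep(delay)  # 限速:约每秒 1 次
    return details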

View File

@@ -0,0 +1,46 @@
import aiohttp
from typing import List, Dict, Optional
from datetime import datetime
from .base_source import DataSource, PaperMetadata
class UnpaywallSource(DataSource):
"""Unpaywall API实现"""
def _initialize(self) -> None:
self.base_url = "https://api.unpaywall.org/v2"
self.email = self.api_key # Unpaywall使用email作为API key
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/search",
params={
"query": query,
"email": self.email,
"limit": limit
}
) as response:
data = await response.json()
                return [self._parse_response(item.get("response", {}))  # 每个条目是形如 {"response": {...}, "score": ...} 的字典
                        for item in data.get("results", [])]
def _parse_response(self, data: Dict) -> PaperMetadata:
"""解析Unpaywall返回的数据"""
return PaperMetadata(
title=data.get("title", ""),
authors=[
f"{author.get('given', '')} {author.get('family', '')}"
for author in data.get("z_authors", [])
],
institutions=[
aff.get("name", "")
for author in data.get("z_authors", [])
for aff in author.get("affiliation", [])
],
abstract="", # Unpaywall不提供摘要
year=data.get("year"),
doi=data.get("doi"),
url=data.get("doi_url"),
citations=None, # Unpaywall不提供引用计数
venue=data.get("journal_name")
)