Master 4.0 (#2210)
* stage academic conversation * stage document conversation * fix buggy gradio version * file dynamic load * merge more academic plugins * accelerate nltk * feat: 为predict函数添加文件和URL读取功能 - 添加URL检测和网页内容提取功能,支持自动提取网页文本 - 添加文件路径识别和文件内容读取功能,支持private_upload路径格式 - 集成WebTextExtractor处理网页内容提取 - 集成TextContentLoader处理本地文件读取 - 支持文件路径与问题组合的智能处理 * back * block unstable --------- Co-authored-by: XiaoBoAI <liuboyin2019@ia.ac.cn>
This commit is contained in:
0
crazy_functions/review_fns/data_sources/__init__.py
Normal file
0
crazy_functions/review_fns/data_sources/__init__.py
Normal file
279
crazy_functions/review_fns/data_sources/adsabs_source.py
Normal file
279
crazy_functions/review_fns/data_sources/adsabs_source.py
Normal file
@@ -0,0 +1,279 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
class AdsabsSource(DataSource):
|
||||
"""ADS (Astrophysics Data System) API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: ADS API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://api.adsabs.harvard.edu/v1"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str, method: str = "GET", data: dict = None) -> Optional[dict]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
method: HTTP方法
|
||||
data: POST请求数据
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
if method == "GET":
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
elif method == "POST":
|
||||
async with session.post(url, json=data) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _parse_paper(self, doc: dict) -> PaperMetadata:
|
||||
"""解析ADS文献数据
|
||||
|
||||
Args:
|
||||
doc: ADS文献数据
|
||||
|
||||
Returns:
|
||||
解析后的论文数据
|
||||
"""
|
||||
try:
|
||||
return PaperMetadata(
|
||||
title=doc.get('title', [''])[0] if doc.get('title') else '',
|
||||
authors=doc.get('author', []),
|
||||
abstract=doc.get('abstract', ''),
|
||||
year=doc.get('year'),
|
||||
doi=doc.get('doi', [''])[0] if doc.get('doi') else None,
|
||||
url=f"https://ui.adsabs.harvard.edu/abs/{doc.get('bibcode')}/abstract" if doc.get('bibcode') else None,
|
||||
citations=doc.get('citation_count'),
|
||||
venue=doc.get('pub', ''),
|
||||
institutions=doc.get('aff', []),
|
||||
venue_type="journal",
|
||||
venue_name=doc.get('pub', ''),
|
||||
venue_info={
|
||||
'volume': doc.get('volume'),
|
||||
'issue': doc.get('issue'),
|
||||
'pub_date': doc.get('pubdate', '')
|
||||
},
|
||||
source='adsabs'
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"解析文章时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = "relevance",
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||
start_year: 起始年份
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
if start_year:
|
||||
query = f"{query} year:{start_year}-"
|
||||
|
||||
# 设置排序
|
||||
sort_mapping = {
|
||||
'relevance': 'score desc',
|
||||
'date': 'date desc',
|
||||
'citations': 'citation_count desc'
|
||||
}
|
||||
sort = sort_mapping.get(sort_by, 'score desc')
|
||||
|
||||
# 构建搜索请求
|
||||
search_url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": query,
|
||||
"rows": limit,
|
||||
"sort": sort,
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
# 解析结果
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _build_query_string(self, params: dict) -> str:
|
||||
"""构建查询字符串"""
|
||||
return "&".join([f"{k}={v}" for k, v in params.items()])
|
||||
|
||||
async def get_paper_details(self, bibcode: str) -> Optional[PaperMetadata]:
|
||||
"""获取指定bibcode的论文详情"""
|
||||
search_url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"identifier:{bibcode}",
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
|
||||
if response and 'response' in response and response['response']['docs']:
|
||||
return self._parse_paper(response['response']['docs'][0])
|
||||
return None
|
||||
|
||||
async def get_related_papers(self, bibcode: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取相关论文"""
|
||||
url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"citations(identifier:{bibcode}) OR references(identifier:{bibcode})",
|
||||
"rows": limit,
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
return papers
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"author:\"{author}\""
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def search_by_journal(
|
||||
self,
|
||||
journal: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊搜索论文"""
|
||||
query = f"pub:\"{journal}\""
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
days: int = 7,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取最新论文"""
|
||||
query = f"entdate:[NOW-{days}DAYS TO NOW]"
|
||||
return await self.search(query, limit=limit, sort_by="date")
|
||||
|
||||
async def get_citations(self, bibcode: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献"""
|
||||
url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"citations(identifier:{bibcode})",
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
return papers
|
||||
|
||||
async def get_references(self, bibcode: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献"""
|
||||
url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"references(identifier:{bibcode})",
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
return papers
|
||||
|
||||
async def example_usage():
|
||||
"""AdsabsSource使用示例"""
|
||||
ads = AdsabsSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索黑洞相关论文 ===")
|
||||
papers = await ads.search("black hole", limit=3)
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
|
||||
# 其他示例...
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# python -m crazy_functions.review_fns.data_sources.adsabs_source
|
||||
asyncio.run(example_usage())
|
||||
636
crazy_functions/review_fns/data_sources/arxiv_source.py
Normal file
636
crazy_functions/review_fns/data_sources/arxiv_source.py
Normal file
@@ -0,0 +1,636 @@
|
||||
import arxiv
|
||||
from typing import List, Optional, Union, Literal, Dict
|
||||
from datetime import datetime
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
import os
|
||||
from urllib.request import urlretrieve
|
||||
import feedparser
|
||||
from tqdm import tqdm
|
||||
|
||||
class ArxivSource(DataSource):
|
||||
"""arXiv API实现"""
|
||||
|
||||
CATEGORIES = {
|
||||
# 物理学
|
||||
"Physics": {
|
||||
"astro-ph": "天体物理学",
|
||||
"cond-mat": "凝聚态物理",
|
||||
"gr-qc": "广义相对论与量子宇宙学",
|
||||
"hep-ex": "高能物理实验",
|
||||
"hep-lat": "格点场论",
|
||||
"hep-ph": "高能物理理论",
|
||||
"hep-th": "高能物理理论",
|
||||
"math-ph": "数学物理",
|
||||
"nlin": "非线性科学",
|
||||
"nucl-ex": "核实验",
|
||||
"nucl-th": "核理论",
|
||||
"physics": "物理学",
|
||||
"quant-ph": "量子物理",
|
||||
},
|
||||
|
||||
# 数学
|
||||
"Mathematics": {
|
||||
"math.AG": "代数几何",
|
||||
"math.AT": "代数拓扑",
|
||||
"math.AP": "分析与偏微分方程",
|
||||
"math.CT": "范畴论",
|
||||
"math.CA": "复分析",
|
||||
"math.CO": "组合数学",
|
||||
"math.AC": "交换代数",
|
||||
"math.CV": "复变函数",
|
||||
"math.DG": "微分几何",
|
||||
"math.DS": "动力系统",
|
||||
"math.FA": "泛函分析",
|
||||
"math.GM": "一般数学",
|
||||
"math.GN": "一般拓扑",
|
||||
"math.GT": "几何拓扑",
|
||||
"math.GR": "群论",
|
||||
"math.HO": "数学史与数学概述",
|
||||
"math.IT": "信息论",
|
||||
"math.KT": "K理论与同调",
|
||||
"math.LO": "逻辑",
|
||||
"math.MP": "数学物理",
|
||||
"math.MG": "度量几何",
|
||||
"math.NT": "数论",
|
||||
"math.NA": "数值分析",
|
||||
"math.OA": "算子代数",
|
||||
"math.OC": "最优化与控制",
|
||||
"math.PR": "概率论",
|
||||
"math.QA": "量子代数",
|
||||
"math.RT": "表示论",
|
||||
"math.RA": "环与代数",
|
||||
"math.SP": "谱理论",
|
||||
"math.ST": "统计理论",
|
||||
"math.SG": "辛几何",
|
||||
},
|
||||
|
||||
# 计算机科学
|
||||
"Computer Science": {
|
||||
"cs.AI": "人工智能",
|
||||
"cs.CL": "计算语言学",
|
||||
"cs.CC": "计算复杂性",
|
||||
"cs.CE": "计算工程",
|
||||
"cs.CG": "计算几何",
|
||||
"cs.GT": "计算机博弈论",
|
||||
"cs.CV": "计算机视觉",
|
||||
"cs.CY": "计算机与社会",
|
||||
"cs.CR": "密码学与安全",
|
||||
"cs.DS": "数据结构与算法",
|
||||
"cs.DB": "数据库",
|
||||
"cs.DL": "数字图书馆",
|
||||
"cs.DM": "离散数学",
|
||||
"cs.DC": "分布式计算",
|
||||
"cs.ET": "新兴技术",
|
||||
"cs.FL": "形式语言与自动机理论",
|
||||
"cs.GL": "一般文献",
|
||||
"cs.GR": "图形学",
|
||||
"cs.AR": "硬件架构",
|
||||
"cs.HC": "人机交互",
|
||||
"cs.IR": "信息检索",
|
||||
"cs.IT": "信息论",
|
||||
"cs.LG": "机器学习",
|
||||
"cs.LO": "逻辑与计算机",
|
||||
"cs.MS": "数学软件",
|
||||
"cs.MA": "多智能体系统",
|
||||
"cs.MM": "多媒体",
|
||||
"cs.NI": "网络与互联网架构",
|
||||
"cs.NE": "神经与进化计算",
|
||||
"cs.NA": "数值分析",
|
||||
"cs.OS": "操作系统",
|
||||
"cs.OH": "其他计算机科学",
|
||||
"cs.PF": "性能评估",
|
||||
"cs.PL": "编程语言",
|
||||
"cs.RO": "机器人学",
|
||||
"cs.SI": "社会与信息网络",
|
||||
"cs.SE": "软件工程",
|
||||
"cs.SD": "声音",
|
||||
"cs.SC": "符号计算",
|
||||
"cs.SY": "系统与控制",
|
||||
},
|
||||
|
||||
# 定量生物学
|
||||
"Quantitative Biology": {
|
||||
"q-bio.BM": "生物分子",
|
||||
"q-bio.CB": "细胞行为",
|
||||
"q-bio.GN": "基因组学",
|
||||
"q-bio.MN": "分子网络",
|
||||
"q-bio.NC": "神经计算",
|
||||
"q-bio.OT": "其他",
|
||||
"q-bio.PE": "群体与进化",
|
||||
"q-bio.QM": "定量方法",
|
||||
"q-bio.SC": "亚细胞过程",
|
||||
"q-bio.TO": "组织与器官",
|
||||
},
|
||||
|
||||
# 定量金融
|
||||
"Quantitative Finance": {
|
||||
"q-fin.CP": "计算金融",
|
||||
"q-fin.EC": "经济学",
|
||||
"q-fin.GN": "一般金融",
|
||||
"q-fin.MF": "数学金融",
|
||||
"q-fin.PM": "投资组合管理",
|
||||
"q-fin.PR": "定价理论",
|
||||
"q-fin.RM": "风险管理",
|
||||
"q-fin.ST": "统计金融",
|
||||
"q-fin.TR": "交易与市场微观结构",
|
||||
},
|
||||
|
||||
# 统计学
|
||||
"Statistics": {
|
||||
"stat.AP": "应用统计",
|
||||
"stat.CO": "计算统计",
|
||||
"stat.ML": "机器学习",
|
||||
"stat.ME": "方法论",
|
||||
"stat.OT": "其他统计",
|
||||
"stat.TH": "统计理论",
|
||||
},
|
||||
|
||||
# 电气工程与系统科学
|
||||
"Electrical Engineering and Systems Science": {
|
||||
"eess.AS": "音频与语音处理",
|
||||
"eess.IV": "图像与视频处理",
|
||||
"eess.SP": "信号处理",
|
||||
"eess.SY": "系统与控制",
|
||||
},
|
||||
|
||||
# 经济学
|
||||
"Economics": {
|
||||
"econ.EM": "计量经济学",
|
||||
"econ.GN": "一般经济学",
|
||||
"econ.TH": "理论经济学",
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""初始化"""
|
||||
self._initialize() # 调用初始化方法
|
||||
# 修改排序选项映射
|
||||
self.sort_options = {
|
||||
'relevance': arxiv.SortCriterion.Relevance, # arXiv的相关性排序
|
||||
'lastUpdatedDate': arxiv.SortCriterion.LastUpdatedDate, # 最后更新日期
|
||||
'submittedDate': arxiv.SortCriterion.SubmittedDate, # 提交日期
|
||||
}
|
||||
|
||||
self.sort_order_options = {
|
||||
'ascending': arxiv.SortOrder.Ascending,
|
||||
'descending': arxiv.SortOrder.Descending
|
||||
}
|
||||
|
||||
self.default_sort = 'lastUpdatedDate'
|
||||
self.default_order = 'descending'
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化客户端,设置默认参数"""
|
||||
self.client = arxiv.Client()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
sort_by: str = None,
|
||||
sort_order: str = None,
|
||||
start_year: int = None
|
||||
) -> List[Dict]:
|
||||
"""搜索论文"""
|
||||
try:
|
||||
# 使用默认排序如果提供的排序选项无效
|
||||
if not sort_by or sort_by not in self.sort_options:
|
||||
sort_by = self.default_sort
|
||||
|
||||
# 使用默认排序顺序如果提供的顺序无效
|
||||
if not sort_order or sort_order not in self.sort_order_options:
|
||||
sort_order = self.default_order
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||
|
||||
search = arxiv.Search(
|
||||
query=query,
|
||||
max_results=limit,
|
||||
sort_by=self.sort_options[sort_by],
|
||||
sort_order=self.sort_order_options[sort_order]
|
||||
)
|
||||
|
||||
results = list(self.client.results(search))
|
||||
return [self._parse_paper_data(result) for result in results]
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_by_id(self, paper_id: Union[str, List[str]]) -> List[PaperMetadata]:
|
||||
"""按ID搜索论文
|
||||
|
||||
Args:
|
||||
paper_id: 单个arXiv ID或ID列表,例如:'2005.14165' 或 ['2005.14165', '2103.14030']
|
||||
"""
|
||||
if isinstance(paper_id, str):
|
||||
paper_id = [paper_id]
|
||||
|
||||
search = arxiv.Search(
|
||||
id_list=paper_id,
|
||||
max_results=len(paper_id)
|
||||
)
|
||||
results = list(self.client.results(search))
|
||||
return [self._parse_paper_data(result) for result in results]
|
||||
|
||||
async def search_by_category(
|
||||
self,
|
||||
category: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = 'relevance',
|
||||
sort_order: str = 'descending',
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按类别搜索论文"""
|
||||
query = f"cat:{category}"
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order
|
||||
)
|
||||
|
||||
async def search_by_authors(
|
||||
self,
|
||||
authors: List[str],
|
||||
limit: int = 100,
|
||||
sort_by: str = 'relevance',
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = " AND ".join([f"au:\"{author}\"" for author in authors])
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by
|
||||
)
|
||||
|
||||
async def search_by_date_range(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int = 100,
|
||||
sort_by: Literal['relevance', 'updated', 'submitted'] = 'submitted',
|
||||
sort_order: Literal['ascending', 'descending'] = 'descending'
|
||||
) -> List[PaperMetadata]:
|
||||
"""按日期范围搜索论文"""
|
||||
query = f"submittedDate:[{start_date.strftime('%Y%m%d')} TO {end_date.strftime('%Y%m%d')}]"
|
||||
return await self.search(
|
||||
query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order
|
||||
)
|
||||
|
||||
async def download_pdf(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
|
||||
"""下载论文PDF
|
||||
|
||||
Args:
|
||||
paper_id: arXiv ID
|
||||
dirpath: 保存目录
|
||||
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.pdf
|
||||
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
papers = await self.search_by_id(paper_id)
|
||||
if not papers:
|
||||
raise ValueError(f"未找到ID为 {paper_id} 的论文")
|
||||
paper = papers[0]
|
||||
|
||||
if not filename:
|
||||
# 清理标题中的非法字符
|
||||
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
|
||||
filename = f"{paper_id}_{safe_title}.pdf"
|
||||
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
urlretrieve(paper.url, filepath)
|
||||
return filepath
|
||||
|
||||
async def download_source(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
|
||||
"""下载论文源文件(通常是LaTeX源码)
|
||||
|
||||
Args:
|
||||
paper_id: arXiv ID
|
||||
dirpath: 保存目录
|
||||
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.tar.gz
|
||||
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
papers = await self.search_by_id(paper_id)
|
||||
if not papers:
|
||||
raise ValueError(f"未找到ID为 {paper_id} 的论文")
|
||||
paper = papers[0]
|
||||
|
||||
if not filename:
|
||||
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
|
||||
filename = f"{paper_id}_{safe_title}.tar.gz"
|
||||
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
source_url = paper.url.replace("/pdf/", "/src/")
|
||||
urlretrieve(source_url, filepath)
|
||||
return filepath
|
||||
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
# arXiv API不直接提供引用信息
|
||||
return []
|
||||
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
# arXiv API不直接提供引用信息
|
||||
return []
|
||||
|
||||
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||
"""获取论文详情
|
||||
|
||||
Args:
|
||||
paper_id: arXiv ID 或 DOI
|
||||
|
||||
Returns:
|
||||
论文详细信息,如果未找到返回 None
|
||||
"""
|
||||
try:
|
||||
# 如果是完整的 arXiv URL,提取 ID
|
||||
if "arxiv.org" in paper_id:
|
||||
paper_id = paper_id.split("/")[-1]
|
||||
# 如果是 DOI 格式且是 arXiv 论文,提取 ID
|
||||
elif paper_id.startswith("10.48550/arXiv."):
|
||||
paper_id = paper_id.split(".")[-1]
|
||||
|
||||
papers = await self.search_by_id(paper_id)
|
||||
return papers[0] if papers else None
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _parse_paper_data(self, result: arxiv.Result) -> PaperMetadata:
|
||||
"""解析arXiv API返回的数据"""
|
||||
# 解析主要类别和次要类别
|
||||
primary_category = result.primary_category
|
||||
categories = result.categories
|
||||
|
||||
# 构建venue信息
|
||||
venue_info = {
|
||||
'primary_category': primary_category,
|
||||
'categories': categories,
|
||||
'comments': getattr(result, 'comment', None),
|
||||
'journal_ref': getattr(result, 'journal_ref', None)
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=result.title,
|
||||
authors=[author.name for author in result.authors],
|
||||
abstract=result.summary,
|
||||
year=result.published.year,
|
||||
doi=result.entry_id,
|
||||
url=result.pdf_url,
|
||||
citations=None,
|
||||
venue=f"arXiv:{primary_category}",
|
||||
institutions=[],
|
||||
venue_type='preprint', # arXiv论文都是预印本
|
||||
venue_name='arXiv',
|
||||
venue_info=venue_info,
|
||||
source='arxiv' # 添加来源标记
|
||||
)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
category: str,
|
||||
debug: bool = False,
|
||||
batch_size: int = 50
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取指定类别的最新论文
|
||||
|
||||
通过 RSS feed 获取最新发布的论文,然后批量获取详细信息
|
||||
|
||||
Args:
|
||||
category: arXiv类别,例如:
|
||||
- 整个领域: 'cs'
|
||||
- 具体方向: 'cs.AI'
|
||||
- 多个类别: 'cs.AI+q-bio.NC'
|
||||
debug: 是否为调试模式,如果为True则只返回5篇最新论文
|
||||
batch_size: 批量获取论文的数量,默认50
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果类别无效
|
||||
"""
|
||||
try:
|
||||
# 处理类别格式
|
||||
# 1. 转换为小写
|
||||
# 2. 确保多个类别之间使用+连接
|
||||
category = category.lower().replace(' ', '+')
|
||||
|
||||
# 构建RSS feed URL
|
||||
feed_url = f"https://rss.arxiv.org/rss/{category}"
|
||||
print(f"正在获取RSS feed: {feed_url}") # 添加调试信息
|
||||
|
||||
feed = feedparser.parse(feed_url)
|
||||
|
||||
# 检查feed是否有效
|
||||
if hasattr(feed, 'status') and feed.status != 200:
|
||||
raise ValueError(f"获取RSS feed失败,状态码: {feed.status}")
|
||||
|
||||
if not feed.entries:
|
||||
print(f"警告:未在feed中找到任何条目") # 添加调试信息
|
||||
print(f"Feed标题: {feed.feed.title if hasattr(feed, 'feed') else '无标题'}")
|
||||
raise ValueError(f"无效的arXiv类别或未找到论文: {category}")
|
||||
|
||||
if debug:
|
||||
# 调试模式:只获取5篇最新论文
|
||||
search = arxiv.Search(
|
||||
query=f'cat:{category}',
|
||||
sort_by=arxiv.SortCriterion.SubmittedDate,
|
||||
sort_order=arxiv.SortOrder.Descending,
|
||||
max_results=5
|
||||
)
|
||||
results = list(self.client.results(search))
|
||||
return [self._parse_paper_data(result) for result in results]
|
||||
|
||||
# 正常模式:获取所有新论文
|
||||
# 从RSS条目中提取arXiv ID
|
||||
paper_ids = []
|
||||
for entry in feed.entries:
|
||||
try:
|
||||
# RSS链接格式可能是以下几种:
|
||||
# - http://arxiv.org/abs/2403.xxxxx
|
||||
# - http://arxiv.org/pdf/2403.xxxxx
|
||||
# - https://arxiv.org/abs/2403.xxxxx
|
||||
link = entry.link or entry.id
|
||||
arxiv_id = link.split('/')[-1].replace('.pdf', '')
|
||||
if arxiv_id:
|
||||
paper_ids.append(arxiv_id)
|
||||
except Exception as e:
|
||||
print(f"警告:处理条目时出错: {str(e)}") # 添加调试信息
|
||||
continue
|
||||
|
||||
if not paper_ids:
|
||||
print("未能从feed中提取到任何论文ID") # 添加调试信息
|
||||
return []
|
||||
|
||||
print(f"成功提取到 {len(paper_ids)} 个论文ID") # 添加调试信息
|
||||
|
||||
# 批量获取论文详情
|
||||
papers = []
|
||||
with tqdm(total=len(paper_ids), desc="获取arXiv论文") as pbar:
|
||||
for i in range(0, len(paper_ids), batch_size):
|
||||
batch_ids = paper_ids[i:i + batch_size]
|
||||
search = arxiv.Search(
|
||||
id_list=batch_ids,
|
||||
max_results=len(batch_ids)
|
||||
)
|
||||
batch_results = list(self.client.results(search))
|
||||
papers.extend([self._parse_paper_data(result) for result in batch_results])
|
||||
pbar.update(len(batch_results))
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取最新论文时发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc()) # 添加完整的错误追踪
|
||||
return []
|
||||
|
||||
async def example_usage():
|
||||
"""ArxivSource使用示例"""
|
||||
arxiv_source = ArxivSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索,使用不同的排序方式
|
||||
# print("\n=== 示例1:搜索最新的机器学习论文(按提交时间排序)===")
|
||||
# papers = await arxiv_source.search(
|
||||
# "ti:\"machine learning\"",
|
||||
# limit=3,
|
||||
# sort_by='submitted',
|
||||
# sort_order='descending'
|
||||
# )
|
||||
# print(f"找到 {len(papers)} 篇论文")
|
||||
|
||||
# for i, paper in enumerate(papers, 1):
|
||||
# print(f"\n--- 论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表年份: {paper.year}")
|
||||
# print(f"arXiv ID: {paper.doi}")
|
||||
# print(f"PDF URL: {paper.url}")
|
||||
# if paper.abstract:
|
||||
# print(f"\n摘要:")
|
||||
# print(paper.abstract)
|
||||
# print(f"发表venue: {paper.venue}")
|
||||
|
||||
# # 示例2:按ID搜索
|
||||
# print("\n=== 示例2:按ID搜索论文 ===")
|
||||
# paper_id = "2005.14165" # GPT-3论文
|
||||
# papers = await arxiv_source.search_by_id(paper_id)
|
||||
# if papers:
|
||||
# paper = papers[0]
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表年份: {paper.year}")
|
||||
|
||||
# # 示例3:按类别搜索
|
||||
# print("\n=== 示例3:搜索人工智能领域最新论文 ===")
|
||||
# ai_papers = await arxiv_source.search_by_category(
|
||||
# "cs.AI",
|
||||
# limit=2,
|
||||
# sort_by='updated',
|
||||
# sort_order='descending'
|
||||
# )
|
||||
# for i, paper in enumerate(ai_papers, 1):
|
||||
# print(f"\n--- AI论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表venue: {paper.venue}")
|
||||
|
||||
# # 示例4:按作者搜索
|
||||
# print("\n=== 示例4:搜索特定作者的论文 ===")
|
||||
# author_papers = await arxiv_source.search_by_authors(
|
||||
# ["Bengio"],
|
||||
# limit=2,
|
||||
# sort_by='relevance'
|
||||
# )
|
||||
# for i, paper in enumerate(author_papers, 1):
|
||||
# print(f"\n--- Bengio的论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表venue: {paper.venue}")
|
||||
|
||||
# # 示例5:按日期范围搜索
|
||||
# print("\n=== 示例5:搜索特定日期范围的论文 ===")
|
||||
# from datetime import datetime, timedelta
|
||||
# end_date = datetime.now()
|
||||
# start_date = end_date - timedelta(days=7) # 最近一周
|
||||
# recent_papers = await arxiv_source.search_by_date_range(
|
||||
# start_date,
|
||||
# end_date,
|
||||
# limit=2
|
||||
# )
|
||||
# for i, paper in enumerate(recent_papers, 1):
|
||||
# print(f"\n--- 最近论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表年份: {paper.year}")
|
||||
|
||||
# # 示例6:下载PDF
|
||||
# print("\n=== 示例6:下载论文PDF ===")
|
||||
# if papers: # 使用之前搜索到的GPT-3论文
|
||||
# pdf_path = await arxiv_source.download_pdf(paper_id)
|
||||
# print(f"PDF已下载到: {pdf_path}")
|
||||
|
||||
# # 示例7:下载源文件
|
||||
# print("\n=== 示例7:下载论文源文件 ===")
|
||||
# if papers:
|
||||
# source_path = await arxiv_source.download_source(paper_id)
|
||||
# print(f"源文件已下载到: {source_path}")
|
||||
|
||||
# 示例6:获取最新论文
|
||||
print("\n=== 示例8:获取最新论文 ===")
|
||||
|
||||
# 获取CS.AI领域的最新论文
|
||||
print("\n--- 获取AI领域最新论文 ---")
|
||||
ai_latest = await arxiv_source.get_latest_papers("cs.AI", debug=True)
|
||||
for i, paper in enumerate(ai_latest, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 获取整个计算机科学领域的最新论文
|
||||
print("\n--- 获取整个CS领域最新论文 ---")
|
||||
cs_latest = await arxiv_source.get_latest_papers("cs", debug=True)
|
||||
for i, paper in enumerate(cs_latest, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 获取多个类别的最新论文
|
||||
print("\n--- 获取AI和机器学习领域最新论文 ---")
|
||||
multi_latest = await arxiv_source.get_latest_papers("cs.AI+cs.LG", debug=True)
|
||||
for i, paper in enumerate(multi_latest, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(example_usage())
|
||||
102
crazy_functions/review_fns/data_sources/base_source.py
Normal file
102
crazy_functions/review_fns/data_sources/base_source.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
class PaperMetadata:
|
||||
"""论文元数据"""
|
||||
def __init__(
|
||||
self,
|
||||
title: str,
|
||||
authors: List[str],
|
||||
abstract: str,
|
||||
year: int,
|
||||
doi: str = None,
|
||||
url: str = None,
|
||||
citations: int = None,
|
||||
venue: str = None,
|
||||
institutions: List[str] = None,
|
||||
venue_type: str = None, # 来源类型(journal/conference/preprint等)
|
||||
venue_name: str = None, # 具体的期刊/会议名称
|
||||
venue_info: Dict = None, # 更多来源详细信息(如影响因子、分区等)
|
||||
source: str = None # 新增: 论文来源标记
|
||||
):
|
||||
self.title = title
|
||||
self.authors = authors
|
||||
self.abstract = abstract
|
||||
self.year = year
|
||||
self.doi = doi
|
||||
self.url = url
|
||||
self.citations = citations
|
||||
self.venue = venue
|
||||
self.institutions = institutions or []
|
||||
self.venue_type = venue_type # 新增
|
||||
self.venue_name = venue_name # 新增
|
||||
self.venue_info = venue_info or {} # 新增
|
||||
self.source = source # 新增: 存储论文来源
|
||||
|
||||
# 新增:影响因子和分区信息,初始化为None
|
||||
self._if_factor = None
|
||||
self._cas_division = None
|
||||
self._jcr_division = None
|
||||
|
||||
@property
|
||||
def if_factor(self) -> Optional[float]:
|
||||
"""获取影响因子"""
|
||||
return self._if_factor
|
||||
|
||||
@if_factor.setter
|
||||
def if_factor(self, value: float):
|
||||
"""设置影响因子"""
|
||||
self._if_factor = value
|
||||
|
||||
@property
|
||||
def cas_division(self) -> Optional[str]:
|
||||
"""获取中科院分区"""
|
||||
return self._cas_division
|
||||
|
||||
@cas_division.setter
|
||||
def cas_division(self, value: str):
|
||||
"""设置中科院分区"""
|
||||
self._cas_division = value
|
||||
|
||||
@property
|
||||
def jcr_division(self) -> Optional[str]:
|
||||
"""获取JCR分区"""
|
||||
return self._jcr_division
|
||||
|
||||
@jcr_division.setter
|
||||
def jcr_division(self, value: str):
|
||||
"""设置JCR分区"""
|
||||
self._jcr_division = value
|
||||
|
||||
class DataSource(ABC):
|
||||
"""数据源基类"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key
|
||||
self._initialize()
|
||||
|
||||
@abstractmethod
|
||||
def _initialize(self) -> None:
|
||||
"""初始化数据源"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""搜索论文"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_paper_details(self, paper_id: str) -> PaperMetadata:
|
||||
"""获取论文详细信息"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献"""
|
||||
pass
|
||||
1
crazy_functions/review_fns/data_sources/cas_if.json
Normal file
1
crazy_functions/review_fns/data_sources/cas_if.json
Normal file
File diff suppressed because one or more lines are too long
400
crazy_functions/review_fns/data_sources/crossref_source.py
Normal file
400
crazy_functions/review_fns/data_sources/crossref_source.py
Normal file
@@ -0,0 +1,400 @@
|
||||
import aiohttp
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import random
|
||||
|
||||
class CrossrefSource(DataSource):
|
||||
"""Crossref API实现"""
|
||||
|
||||
CONTACT_EMAILS = [
|
||||
"gpt_abc_academic@163.com",
|
||||
"gpt_abc_newapi@163.com",
|
||||
"gpt_abc_academic_pwd@163.com"
|
||||
]
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化客户端,设置默认参数"""
|
||||
self.base_url = "https://api.crossref.org"
|
||||
# 随机选择一个邮箱
|
||||
contact_email = random.choice(self.CONTACT_EMAILS)
|
||||
self.headers = {
|
||||
"Accept": "application/json",
|
||||
"User-Agent": f"Mozilla/5.0 (compatible; PythonScript/1.0; mailto:{contact_email})",
|
||||
}
|
||||
if self.api_key:
|
||||
self.headers["Crossref-Plus-API-Token"] = f"Bearer {self.api_key}"
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
sort_order: str = None,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序字段
|
||||
sort_order: 排序顺序
|
||||
start_year: 起始年份
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
# 请求更多的结果以补偿可能被过滤掉的文章
|
||||
adjusted_limit = min(limit * 3, 1000) # 设置上限以避免请求过多
|
||||
params = {
|
||||
"query": query,
|
||||
"rows": adjusted_limit,
|
||||
"select": (
|
||||
"DOI,title,author,published-print,abstract,reference,"
|
||||
"container-title,is-referenced-by-count,type,"
|
||||
"publisher,ISSN,ISBN,issue,volume,page"
|
||||
)
|
||||
}
|
||||
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
params["filter"] = f"from-pub-date:{start_year}"
|
||||
|
||||
# 添加排序
|
||||
if sort_by:
|
||||
params["sort"] = sort_by
|
||||
if sort_order:
|
||||
params["order"] = sort_order
|
||||
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params=params
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
return []
|
||||
|
||||
data = await response.json()
|
||||
items = data.get("message", {}).get("items", [])
|
||||
if not items:
|
||||
print(f"未找到相关论文")
|
||||
return []
|
||||
|
||||
# 过滤掉没有摘要的文章
|
||||
papers = []
|
||||
filtered_count = 0
|
||||
for work in items:
|
||||
paper = self._parse_work(work)
|
||||
if paper.abstract and paper.abstract.strip():
|
||||
papers.append(paper)
|
||||
if len(papers) >= limit: # 达到原始请求的限制后停止
|
||||
break
|
||||
else:
|
||||
filtered_count += 1
|
||||
|
||||
print(f"找到 {len(items)} 篇相关论文,其中 {filtered_count} 篇因缺少摘要被过滤")
|
||||
print(f"返回 {len(papers)} 篇包含摘要的论文")
|
||||
return papers
|
||||
|
||||
async def get_paper_details(self, doi: str) -> PaperMetadata:
|
||||
"""获取指定DOI的论文详情"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/{doi}",
|
||||
params={
|
||||
"select": (
|
||||
"DOI,title,author,published-print,abstract,reference,"
|
||||
"container-title,is-referenced-by-count,type,"
|
||||
"publisher,ISSN,ISBN,issue,volume,page"
|
||||
)
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"获取论文详情失败: HTTP {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await response.json()
|
||||
return self._parse_work(data.get("message", {}))
|
||||
except Exception as e:
|
||||
print(f"解析论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取指定DOI论文的参考文献列表"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/{doi}",
|
||||
params={"select": "reference"}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"获取参考文献失败: HTTP {response.status}")
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await response.json()
|
||||
# 确保我们正确处理返回的数据结构
|
||||
if not isinstance(data, dict):
|
||||
print(f"API返回了意外的数据格式: {type(data)}")
|
||||
return []
|
||||
|
||||
references = data.get("message", {}).get("reference", [])
|
||||
if not references:
|
||||
print(f"未找到参考文献")
|
||||
return []
|
||||
|
||||
return [
|
||||
PaperMetadata(
|
||||
title=ref.get("article-title", ""),
|
||||
authors=[ref.get("author", "")],
|
||||
year=ref.get("year"),
|
||||
doi=ref.get("DOI"),
|
||||
url=f"https://doi.org/{ref.get('DOI')}" if ref.get("DOI") else None,
|
||||
abstract="",
|
||||
citations=None,
|
||||
venue=ref.get("journal-title", ""),
|
||||
institutions=[]
|
||||
)
|
||||
for ref in references
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"解析参考文献数据时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_citations(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取引用指定DOI论文的文献列表"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params={
|
||||
"filter": f"reference.DOI:{doi}",
|
||||
"select": "DOI,title,author,published-print,abstract"
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"获取引用信息失败: HTTP {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await response.json()
|
||||
# 检查返回的数据结构
|
||||
if isinstance(data, dict):
|
||||
items = data.get("message", {}).get("items", [])
|
||||
return [self._parse_work(work) for work in items]
|
||||
else:
|
||||
print(f"API返回了意外的数据格式: {type(data)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"解析引用数据时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_work(self, work: Dict) -> PaperMetadata:
|
||||
"""解析Crossref返回的数据"""
|
||||
# 获取摘要 - 处理可能的不同格式
|
||||
abstract = ""
|
||||
if isinstance(work.get("abstract"), str):
|
||||
abstract = work.get("abstract", "")
|
||||
elif isinstance(work.get("abstract"), dict):
|
||||
abstract = work.get("abstract", {}).get("value", "")
|
||||
|
||||
if not abstract:
|
||||
print(f"警告: 论文 '{work.get('title', [''])[0]}' 没有可用的摘要")
|
||||
|
||||
# 获取机构信息
|
||||
institutions = []
|
||||
for author in work.get("author", []):
|
||||
if "affiliation" in author:
|
||||
for affiliation in author["affiliation"]:
|
||||
if "name" in affiliation and affiliation["name"] not in institutions:
|
||||
institutions.append(affiliation["name"])
|
||||
|
||||
# 获取venue信息
|
||||
venue_name = work.get("container-title", [None])[0]
|
||||
venue_type = work.get("type", "unknown") # 文献类型
|
||||
venue_info = {
|
||||
"publisher": work.get("publisher"),
|
||||
"issn": work.get("ISSN", []),
|
||||
"isbn": work.get("ISBN", []),
|
||||
"issue": work.get("issue"),
|
||||
"volume": work.get("volume"),
|
||||
"page": work.get("page")
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=work.get("title", [None])[0] or "",
|
||||
authors=[
|
||||
author.get("given", "") + " " + author.get("family", "")
|
||||
for author in work.get("author", [])
|
||||
],
|
||||
institutions=institutions, # 添加机构信息
|
||||
abstract=abstract,
|
||||
year=work.get("published-print", {}).get("date-parts", [[None]])[0][0],
|
||||
doi=work.get("DOI"),
|
||||
url=f"https://doi.org/{work.get('DOI')}" if work.get("DOI") else None,
|
||||
citations=work.get("is-referenced-by-count"),
|
||||
venue=venue_name,
|
||||
venue_type=venue_type, # 添加venue类型
|
||||
venue_name=venue_name, # 添加venue名称
|
||||
venue_info=venue_info, # 添加venue详细信息
|
||||
source='crossref' # 添加来源标记
|
||||
)
|
||||
|
||||
async def search_by_authors(
|
||||
self,
|
||||
authors: List[str],
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = " ".join([f"author:\"{author}\"" for author in authors])
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
start_year=start_year
|
||||
)
|
||||
|
||||
async def search_by_date_range(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
sort_order: str = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按日期范围搜索论文"""
|
||||
query = f"from-pub-date:{start_date.strftime('%Y-%m-%d')} until-pub-date:{end_date.strftime('%Y-%m-%d')}"
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order
|
||||
)
|
||||
|
||||
async def example_usage():
|
||||
"""CrossrefSource使用示例"""
|
||||
crossref = CrossrefSource(api_key=None)
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索,使用不同的排序方式
|
||||
print("\n=== 示例1:搜索最新的机器学习论文 ===")
|
||||
papers = await crossref.search(
|
||||
query="machine learning",
|
||||
limit=3,
|
||||
sort_by="published",
|
||||
sort_order="desc",
|
||||
start_year=2023
|
||||
)
|
||||
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"URL: {paper.url}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
if paper.institutions:
|
||||
print(f"机构: {', '.join(paper.institutions)}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"发表venue: {paper.venue}")
|
||||
print(f"venue类型: {paper.venue_type}")
|
||||
if paper.venue_info:
|
||||
print("Venue详细信息:")
|
||||
for key, value in paper.venue_info.items():
|
||||
if value:
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
# 示例2:按DOI获取论文详情
|
||||
print("\n=== 示例2:获取特定论文详情 ===")
|
||||
# 使用BERT论文的DOI
|
||||
doi = "10.18653/v1/N19-1423"
|
||||
paper = await crossref.get_paper_details(doi)
|
||||
if paper:
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例3:按作者搜索
|
||||
print("\n=== 示例3:搜索特定作者的论文 ===")
|
||||
author_papers = await crossref.search_by_authors(
|
||||
authors=["Yoshua Bengio"],
|
||||
limit=3,
|
||||
sort_by="published",
|
||||
start_year=2020
|
||||
)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- {i}. {paper.title} ---")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例4:按日期范围搜索
|
||||
print("\n=== 示例4:搜索特定日期范围的论文 ===")
|
||||
from datetime import datetime, timedelta
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=30) # 最近一个月
|
||||
recent_papers = await crossref.search_by_date_range(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=3,
|
||||
sort_by="published",
|
||||
sort_order="desc"
|
||||
)
|
||||
for i, paper in enumerate(recent_papers, 1):
|
||||
print(f"\n--- 最近发表的论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
|
||||
# 示例5:获取论文引用信息
|
||||
print("\n=== 示例5:获取论文引用信息 ===")
|
||||
if paper: # 使用之前获取的BERT论文
|
||||
print("\n获取引用该论文的文献:")
|
||||
citations = await crossref.get_citations(paper.doi)
|
||||
for i, citing_paper in enumerate(citations[:3], 1):
|
||||
print(f"\n--- 引用论文 {i} ---")
|
||||
print(f"标题: {citing_paper.title}")
|
||||
print(f"作者: {', '.join(citing_paper.authors)}")
|
||||
print(f"发表年份: {citing_paper.year}")
|
||||
|
||||
print("\n获取该论文引用的参考文献:")
|
||||
references = await crossref.get_references(paper.doi)
|
||||
for i, ref_paper in enumerate(references[:3], 1):
|
||||
print(f"\n--- 参考文献 {i} ---")
|
||||
print(f"标题: {ref_paper.title}")
|
||||
print(f"作者: {', '.join(ref_paper.authors)}")
|
||||
print(f"发表年份: {ref_paper.year if ref_paper.year else '未知'}")
|
||||
|
||||
# 示例6:展示venue信息的使用
|
||||
print("\n=== 示例6:展示期刊/会议详细信息 ===")
|
||||
if papers:
|
||||
paper = papers[0]
|
||||
print(f"文献类型: {paper.venue_type}")
|
||||
print(f"发表venue: {paper.venue_name}")
|
||||
if paper.venue_info:
|
||||
print("Venue详细信息:")
|
||||
for key, value in paper.venue_info.items():
|
||||
if value:
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
449
crazy_functions/review_fns/data_sources/elsevier_source.py
Normal file
449
crazy_functions/review_fns/data_sources/elsevier_source.py
Normal file
@@ -0,0 +1,449 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
class ElsevierSource(DataSource):
|
||||
"""Elsevier (Scopus) API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: Elsevier API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS)
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://api.elsevier.com/content"
|
||||
self.headers = {
|
||||
"X-ELS-APIKey": self.api_key,
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
# 添加更多必要的头部信息
|
||||
"X-ELS-Insttoken": "", # 如果有机构令牌
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
JSON响应
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
# 添加更详细的错误信息
|
||||
error_text = await response.text()
|
||||
print(f"请求失败: {response.status}")
|
||||
print(f"错误详情: {error_text}")
|
||||
if response.status == 401:
|
||||
print(f"使用的API密钥: {self.api_key}")
|
||||
# 尝试切换到另一个API密钥
|
||||
new_key = random.choice([k for k in self.API_KEYS if k != self.api_key])
|
||||
print(f"尝试切换到新的API密钥: {new_key}")
|
||||
self.api_key = new_key
|
||||
self.headers["X-ELS-APIKey"] = new_key
|
||||
# 重试请求
|
||||
return await self._make_request(url, params)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = "relevance",
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文"""
|
||||
try:
|
||||
params = {
|
||||
"query": query,
|
||||
"count": min(limit, 100),
|
||||
"view": "STANDARD",
|
||||
# 移除dc:description字段,因为它在STANDARD视图中不可用
|
||||
"field": "dc:title,dc:creator,prism:doi,prism:coverDate,citedby-count,prism:publicationName"
|
||||
}
|
||||
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
params["date"] = f"{start_year}-present"
|
||||
|
||||
# 添加排序
|
||||
if sort_by == "date":
|
||||
params["sort"] = "-coverDate"
|
||||
elif sort_by == "cited":
|
||||
params["sort"] = "-citedby-count"
|
||||
|
||||
# 发送搜索请求
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/search/scopus",
|
||||
params=params
|
||||
)
|
||||
|
||||
if not response or "search-results" not in response:
|
||||
return []
|
||||
|
||||
# 解析搜索结果
|
||||
entries = response["search-results"].get("entry", [])
|
||||
papers = [paper for paper in (self._parse_entry(entry) for entry in entries) if paper is not None]
|
||||
|
||||
# 尝试为每篇论文获取摘要
|
||||
for paper in papers:
|
||||
if paper.doi:
|
||||
paper.abstract = await self.fetch_abstract(paper.doi) or ""
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_entry(self, entry: Dict) -> Optional[PaperMetadata]:
|
||||
"""解析Scopus API返回的条目"""
|
||||
try:
|
||||
# 获取作者列表
|
||||
authors = []
|
||||
creator = entry.get("dc:creator")
|
||||
if creator:
|
||||
authors = [creator]
|
||||
|
||||
# 获取发表年份
|
||||
year = None
|
||||
if "prism:coverDate" in entry:
|
||||
try:
|
||||
year = int(entry["prism:coverDate"][:4])
|
||||
except:
|
||||
pass
|
||||
|
||||
# 简化venue信息
|
||||
venue_info = {
|
||||
'source_id': entry.get("source-id"),
|
||||
'issn': entry.get("prism:issn")
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=entry.get("dc:title", ""),
|
||||
authors=authors,
|
||||
abstract=entry.get("dc:description", ""), # 从响应中获取摘要
|
||||
year=year,
|
||||
doi=entry.get("prism:doi"),
|
||||
url=entry.get("prism:url"),
|
||||
citations=int(entry.get("citedby-count", 0)),
|
||||
venue=entry.get("prism:publicationName"),
|
||||
institutions=[], # 移除机构信息
|
||||
venue_type="",
|
||||
venue_name=entry.get("prism:publicationName"),
|
||||
venue_info=venue_info
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析条目时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_citations(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献"""
|
||||
try:
|
||||
params = {
|
||||
"query": f"REF({doi})",
|
||||
"count": min(limit, 100),
|
||||
"view": "STANDARD"
|
||||
}
|
||||
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/search/scopus",
|
||||
params=params
|
||||
)
|
||||
|
||||
if not response or "search-results" not in response:
|
||||
return []
|
||||
|
||||
entries = response["search-results"].get("entry", [])
|
||||
return [self._parse_entry(entry) for entry in entries]
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取引用文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献"""
|
||||
try:
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/abstract/doi/{doi}/references",
|
||||
params={"view": "STANDARD"}
|
||||
)
|
||||
|
||||
if not response or "references" not in response:
|
||||
return []
|
||||
|
||||
references = response["references"].get("reference", [])
|
||||
papers = [paper for paper in (self._parse_reference(ref) for ref in references) if paper is not None]
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取参考文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_reference(self, ref: Dict) -> Optional[PaperMetadata]:
|
||||
"""解析参考文献数据"""
|
||||
try:
|
||||
authors = []
|
||||
if "author-list" in ref:
|
||||
author_list = ref["author-list"].get("author", [])
|
||||
if isinstance(author_list, list):
|
||||
authors = [f"{author.get('ce:given-name', '')} {author.get('ce:surname', '')}"
|
||||
for author in author_list]
|
||||
else:
|
||||
authors = [f"{author_list.get('ce:given-name', '')} {author_list.get('ce:surname', '')}"]
|
||||
|
||||
year = None
|
||||
if "prism:coverDate" in ref:
|
||||
try:
|
||||
year = int(ref["prism:coverDate"][:4])
|
||||
except:
|
||||
pass
|
||||
|
||||
return PaperMetadata(
|
||||
title=ref.get("ce:title", ""),
|
||||
authors=authors,
|
||||
abstract="", # 参考文献通常不包含摘要
|
||||
year=year,
|
||||
doi=ref.get("prism:doi"),
|
||||
url=None,
|
||||
citations=None,
|
||||
venue=ref.get("prism:publicationName"),
|
||||
institutions=[],
|
||||
venue_type="unknown",
|
||||
venue_name=ref.get("prism:publicationName"),
|
||||
venue_info={}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析参考文献时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"AUTHOR-NAME({author})"
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def search_by_affiliation(
|
||||
self,
|
||||
affiliation: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按机构搜索论文"""
|
||||
query = f"AF-ID({affiliation})"
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def search_by_venue(
|
||||
self,
|
||||
venue: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊/会议搜索论文"""
|
||||
query = f"SRCTITLE({venue})"
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def test_api_access(self):
|
||||
"""测试API访问权限"""
|
||||
print(f"\n测试API密钥: {self.api_key}")
|
||||
|
||||
# 测试1: 基础搜索
|
||||
basic_params = {
|
||||
"query": "test",
|
||||
"count": 1,
|
||||
"view": "STANDARD"
|
||||
}
|
||||
print("\n1. 测试基础搜索...")
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/search/scopus",
|
||||
params=basic_params
|
||||
)
|
||||
if response:
|
||||
print("基础搜索成功")
|
||||
print("可用字段:", list(response.get("search-results", {}).get("entry", [{}])[0].keys()))
|
||||
|
||||
# 测试2: 测试单篇文章访问
|
||||
print("\n2. 测试文章详情访问...")
|
||||
test_doi = "10.1016/j.artint.2021.103535" # 一个示例DOI
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/abstract/doi/{test_doi}",
|
||||
params={"view": "STANDARD"} # 改为STANDARD视图
|
||||
)
|
||||
if response:
|
||||
print("文章详情访问成功")
|
||||
else:
|
||||
print("文章详情访问失败")
|
||||
|
||||
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||
"""获取论文详细信息
|
||||
|
||||
注意:当前API权限不支持获取详细信息,返回None
|
||||
|
||||
Args:
|
||||
paper_id: 论文ID
|
||||
|
||||
Returns:
|
||||
None,因为当前API权限不支持此功能
|
||||
"""
|
||||
return None
|
||||
|
||||
async def fetch_abstract(self, doi: str) -> Optional[str]:
|
||||
"""获取论文摘要
|
||||
|
||||
使用Scopus Abstract API获取论文摘要
|
||||
|
||||
Args:
|
||||
doi: 论文的DOI
|
||||
|
||||
Returns:
|
||||
摘要文本,如果获取失败则返回None
|
||||
"""
|
||||
try:
|
||||
# 使用Abstract API而不是Search API
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/abstract/doi/{doi}",
|
||||
params={
|
||||
"view": "FULL" # 使用FULL视图
|
||||
}
|
||||
)
|
||||
|
||||
if response and "abstracts-retrieval-response" in response:
|
||||
# 从coredata中获取摘要
|
||||
coredata = response["abstracts-retrieval-response"].get("coredata", {})
|
||||
return coredata.get("dc:description", "")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取摘要时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def example_usage():
|
||||
"""ElsevierSource使用示例"""
|
||||
elsevier = ElsevierSource()
|
||||
|
||||
try:
|
||||
# 首先测试API访问权限
|
||||
print("\n=== 测试API访问权限 ===")
|
||||
await elsevier.test_api_access()
|
||||
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索机器学习相关论文 ===")
|
||||
papers = await elsevier.search("machine learning", limit=3)
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"URL: {paper.url}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
print("期刊信息:")
|
||||
for key, value in paper.venue_info.items():
|
||||
if value: # 只打印非空值
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
# 示例2:获取引用信息
|
||||
if papers and papers[0].doi:
|
||||
print("\n=== 示例2:获取引用该论文的文献 ===")
|
||||
citations = await elsevier.get_citations(papers[0].doi, limit=3)
|
||||
for i, paper in enumerate(citations, 1):
|
||||
print(f"\n--- 引用论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例3:获取参考文献
|
||||
if papers and papers[0].doi:
|
||||
print("\n=== 示例3:获取论文的参考文献 ===")
|
||||
references = await elsevier.get_references(papers[0].doi)
|
||||
for i, paper in enumerate(references[:3], 1):
|
||||
print(f"\n--- 参考文献 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例4:按作者搜索
|
||||
print("\n=== 示例4:按作者搜索 ===")
|
||||
author_papers = await elsevier.search_by_author("Hinton G", limit=3)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例5:按机构搜索
|
||||
print("\n=== 示例5:按机构搜索 ===")
|
||||
affiliation_papers = await elsevier.search_by_affiliation("60027950", limit=3) # MIT的机构ID
|
||||
for i, paper in enumerate(affiliation_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例6:获取论文摘要
|
||||
print("\n=== 示例6:获取论文摘要 ===")
|
||||
test_doi = "10.1016/j.artint.2021.103535"
|
||||
abstract = await elsevier.fetch_abstract(test_doi)
|
||||
if abstract:
|
||||
print(f"摘要: {abstract[:200]}...") # 只显示前200个字符
|
||||
else:
|
||||
print("无法获取摘要")
|
||||
|
||||
# 在搜索结果中显示摘要
|
||||
print("\n=== 示例7:搜索结果中的摘要 ===")
|
||||
papers = await elsevier.search("machine learning", limit=1)
|
||||
for paper in papers:
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"摘要: {paper.abstract[:200]}..." if paper.abstract else "摘要: 无")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(example_usage())
|
||||
698
crazy_functions/review_fns/data_sources/github_source.py
Normal file
698
crazy_functions/review_fns/data_sources/github_source.py
Normal file
@@ -0,0 +1,698 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional, Union, Any
|
||||
|
||||
class GitHubSource:
|
||||
"""GitHub API实现"""
|
||||
|
||||
# 默认API密钥列表 - 可以放置多个GitHub令牌
|
||||
API_KEYS = [
|
||||
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
]
|
||||
|
||||
def __init__(self, api_key: Optional[Union[str, List[str]]] = None):
|
||||
"""初始化GitHub API客户端
|
||||
|
||||
Args:
|
||||
api_key: GitHub个人访问令牌或令牌列表
|
||||
"""
|
||||
if api_key is None:
|
||||
self.api_keys = self.API_KEYS
|
||||
elif isinstance(api_key, str):
|
||||
self.api_keys = [api_key]
|
||||
else:
|
||||
self.api_keys = api_key
|
||||
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化客户端,设置默认参数"""
|
||||
self.base_url = "https://api.github.com"
|
||||
self.headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
"User-Agent": "GitHub-API-Python-Client"
|
||||
}
|
||||
|
||||
# 如果有可用的API密钥,随机选择一个
|
||||
if self.api_keys:
|
||||
selected_key = random.choice(self.api_keys)
|
||||
self.headers["Authorization"] = f"Bearer {selected_key}"
|
||||
print(f"已随机选择API密钥进行认证")
|
||||
else:
|
||||
print("警告: 未提供API密钥,将受到GitHub API请求限制")
|
||||
|
||||
async def _request(self, method: str, endpoint: str, params: Dict = None, data: Dict = None) -> Any:
|
||||
"""发送API请求
|
||||
|
||||
Args:
|
||||
method: HTTP方法 (GET, POST, PUT, DELETE等)
|
||||
endpoint: API端点
|
||||
params: URL参数
|
||||
data: 请求体数据
|
||||
|
||||
Returns:
|
||||
解析后的响应JSON
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
|
||||
# 为调试目的打印请求信息
|
||||
print(f"请求: {method} {url}")
|
||||
if params:
|
||||
print(f"参数: {params}")
|
||||
|
||||
# 发送请求
|
||||
request_kwargs = {}
|
||||
if params:
|
||||
request_kwargs["params"] = params
|
||||
if data:
|
||||
request_kwargs["json"] = data
|
||||
|
||||
async with session.request(method, url, **request_kwargs) as response:
|
||||
response_text = await response.text()
|
||||
|
||||
# 检查HTTP状态码
|
||||
if response.status >= 400:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
print(f"响应内容: {response_text}")
|
||||
return None
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except json.JSONDecodeError:
|
||||
print(f"JSON解析错误: {response_text}")
|
||||
return None
|
||||
|
||||
# ===== 用户相关方法 =====
|
||||
|
||||
async def get_user(self, username: Optional[str] = None) -> Dict:
|
||||
"""获取用户信息
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
|
||||
Returns:
|
||||
用户信息字典
|
||||
"""
|
||||
endpoint = "/user" if username is None else f"/users/{username}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_user_repos(self, username: Optional[str] = None, sort: str = "updated",
|
||||
direction: str = "desc", per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取用户的仓库列表
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
sort: 排序方式 (created, updated, pushed, full_name)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
endpoint = "/user/repos" if username is None else f"/users/{username}/repos"
|
||||
params = {
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_user_starred(self, username: Optional[str] = None,
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取用户星标的仓库
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
星标仓库列表
|
||||
"""
|
||||
endpoint = "/user/starred" if username is None else f"/users/{username}/starred"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 仓库相关方法 =====
|
||||
|
||||
async def get_repo(self, owner: str, repo: str) -> Dict:
|
||||
"""获取仓库信息
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
仓库信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repo_branches(self, owner: str, repo: str, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的分支列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
分支列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/branches"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_repo_commits(self, owner: str, repo: str, sha: Optional[str] = None,
|
||||
path: Optional[str] = None, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的提交历史
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
sha: 特定提交SHA或分支名
|
||||
path: 文件路径筛选
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
提交列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/commits"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
if sha:
|
||||
params["sha"] = sha
|
||||
if path:
|
||||
params["path"] = path
|
||||
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_commit_details(self, owner: str, repo: str, commit_sha: str) -> Dict:
|
||||
"""获取特定提交的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
commit_sha: 提交SHA
|
||||
|
||||
Returns:
|
||||
提交详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/commits/{commit_sha}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== 内容相关方法 =====
|
||||
|
||||
async def get_file_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> Dict:
|
||||
"""获取文件内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
path: 文件路径
|
||||
ref: 分支名、标签名或提交SHA
|
||||
|
||||
Returns:
|
||||
文件内容信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
|
||||
params = {}
|
||||
if ref:
|
||||
params["ref"] = ref
|
||||
|
||||
response = await self._request("GET", endpoint, params=params)
|
||||
if response and isinstance(response, dict) and "content" in response:
|
||||
try:
|
||||
# 解码Base64编码的文件内容
|
||||
content = base64.b64decode(response["content"].encode()).decode()
|
||||
response["decoded_content"] = content
|
||||
except Exception as e:
|
||||
print(f"解码文件内容时出错: {str(e)}")
|
||||
|
||||
return response
|
||||
|
||||
async def get_directory_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> List[Dict]:
|
||||
"""获取目录内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
path: 目录路径
|
||||
ref: 分支名、标签名或提交SHA
|
||||
|
||||
Returns:
|
||||
目录内容列表
|
||||
"""
|
||||
# 注意:此方法与get_file_content使用相同的端点,但对于目录会返回列表
|
||||
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
|
||||
params = {}
|
||||
if ref:
|
||||
params["ref"] = ref
|
||||
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== Issues相关方法 =====
|
||||
|
||||
async def get_issues(self, owner: str, repo: str, state: str = "open",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的Issues列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
state: Issue状态 (open, closed, all)
|
||||
sort: 排序方式 (created, updated, comments)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
Issues列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues"
|
||||
params = {
|
||||
"state": state,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_issue(self, owner: str, repo: str, issue_number: int) -> Dict:
|
||||
"""获取特定Issue的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
issue_number: Issue编号
|
||||
|
||||
Returns:
|
||||
Issue详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_issue_comments(self, owner: str, repo: str, issue_number: int) -> List[Dict]:
|
||||
"""获取Issue的评论
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
issue_number: Issue编号
|
||||
|
||||
Returns:
|
||||
评论列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== Pull Requests相关方法 =====
|
||||
|
||||
async def get_pull_requests(self, owner: str, repo: str, state: str = "open",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的Pull Request列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
state: PR状态 (open, closed, all)
|
||||
sort: 排序方式 (created, updated, popularity, long-running)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
Pull Request列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls"
|
||||
params = {
|
||||
"state": state,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_pull_request(self, owner: str, repo: str, pr_number: int) -> Dict:
|
||||
"""获取特定Pull Request的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
pr_number: Pull Request编号
|
||||
|
||||
Returns:
|
||||
Pull Request详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_pull_request_files(self, owner: str, repo: str, pr_number: int) -> List[Dict]:
|
||||
"""获取Pull Request中修改的文件
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
pr_number: Pull Request编号
|
||||
|
||||
Returns:
|
||||
修改文件列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/files"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== 搜索相关方法 =====
|
||||
|
||||
async def search_repositories(self, query: str, sort: str = "stars",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索仓库
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (stars, forks, updated)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/repositories"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_code(self, query: str, sort: str = "indexed",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索代码
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (indexed)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/code"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_issues(self, query: str, sort: str = "created",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索Issues和Pull Requests
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (created, updated, comments)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/issues"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_users(self, query: str, sort: str = "followers",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索用户
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (followers, repositories, joined)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/users"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 组织相关方法 =====
|
||||
|
||||
async def get_organization(self, org: str) -> Dict:
|
||||
"""获取组织信息
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
|
||||
Returns:
|
||||
组织信息
|
||||
"""
|
||||
endpoint = f"/orgs/{org}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_organization_repos(self, org: str, type: str = "all",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取组织的仓库列表
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
type: 仓库类型 (all, public, private, forks, sources, member, internal)
|
||||
sort: 排序方式 (created, updated, pushed, full_name)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
endpoint = f"/orgs/{org}/repos"
|
||||
params = {
|
||||
"type": type,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_organization_members(self, org: str, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取组织成员列表
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
成员列表
|
||||
"""
|
||||
endpoint = f"/orgs/{org}/members"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 更复杂的操作 =====
|
||||
|
||||
async def get_repository_languages(self, owner: str, repo: str) -> Dict:
|
||||
"""获取仓库使用的编程语言及其比例
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
语言使用情况
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/languages"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repository_stats_contributors(self, owner: str, repo: str) -> List[Dict]:
|
||||
"""获取仓库的贡献者统计
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
贡献者统计信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/stats/contributors"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repository_stats_commit_activity(self, owner: str, repo: str) -> List[Dict]:
|
||||
"""获取仓库的提交活动
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
提交活动统计
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/stats/commit_activity"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def example_usage():
|
||||
"""GitHubSource使用示例"""
|
||||
# 创建客户端实例(可选传入API令牌)
|
||||
# github = GitHubSource(api_key="your_github_token")
|
||||
github = GitHubSource()
|
||||
|
||||
try:
|
||||
# 示例1:搜索热门Python仓库
|
||||
print("\n=== 示例1:搜索热门Python仓库 ===")
|
||||
repos = await github.search_repositories(
|
||||
query="language:python stars:>1000",
|
||||
sort="stars",
|
||||
order="desc",
|
||||
per_page=5
|
||||
)
|
||||
|
||||
if repos and "items" in repos:
|
||||
for i, repo in enumerate(repos["items"], 1):
|
||||
print(f"\n--- 仓库 {i} ---")
|
||||
print(f"名称: {repo['full_name']}")
|
||||
print(f"描述: {repo['description']}")
|
||||
print(f"星标数: {repo['stargazers_count']}")
|
||||
print(f"Fork数: {repo['forks_count']}")
|
||||
print(f"最近更新: {repo['updated_at']}")
|
||||
print(f"URL: {repo['html_url']}")
|
||||
|
||||
# 示例2:获取特定仓库的详情
|
||||
print("\n=== 示例2:获取特定仓库的详情 ===")
|
||||
repo_details = await github.get_repo("microsoft", "vscode")
|
||||
if repo_details:
|
||||
print(f"名称: {repo_details['full_name']}")
|
||||
print(f"描述: {repo_details['description']}")
|
||||
print(f"星标数: {repo_details['stargazers_count']}")
|
||||
print(f"Fork数: {repo_details['forks_count']}")
|
||||
print(f"默认分支: {repo_details['default_branch']}")
|
||||
print(f"开源许可: {repo_details.get('license', {}).get('name', '无')}")
|
||||
print(f"语言: {repo_details['language']}")
|
||||
print(f"Open Issues数: {repo_details['open_issues_count']}")
|
||||
|
||||
# 示例3:获取仓库的提交历史
|
||||
print("\n=== 示例3:获取仓库的最近提交 ===")
|
||||
commits = await github.get_repo_commits("tensorflow", "tensorflow", per_page=5)
|
||||
if commits:
|
||||
for i, commit in enumerate(commits, 1):
|
||||
print(f"\n--- 提交 {i} ---")
|
||||
print(f"SHA: {commit['sha'][:7]}")
|
||||
print(f"作者: {commit['commit']['author']['name']}")
|
||||
print(f"日期: {commit['commit']['author']['date']}")
|
||||
print(f"消息: {commit['commit']['message'].splitlines()[0]}")
|
||||
|
||||
# 示例4:搜索代码
|
||||
print("\n=== 示例4:搜索代码 ===")
|
||||
code_results = await github.search_code(
|
||||
query="filename:README.md language:markdown pytorch in:file",
|
||||
per_page=3
|
||||
)
|
||||
if code_results and "items" in code_results:
|
||||
print(f"共找到: {code_results['total_count']} 个结果")
|
||||
for i, item in enumerate(code_results["items"], 1):
|
||||
print(f"\n--- 代码 {i} ---")
|
||||
print(f"仓库: {item['repository']['full_name']}")
|
||||
print(f"文件: {item['path']}")
|
||||
print(f"URL: {item['html_url']}")
|
||||
|
||||
# 示例5:获取文件内容
|
||||
print("\n=== 示例5:获取文件内容 ===")
|
||||
file_content = await github.get_file_content("python", "cpython", "README.rst")
|
||||
if file_content and "decoded_content" in file_content:
|
||||
content = file_content["decoded_content"]
|
||||
print(f"文件名: {file_content['name']}")
|
||||
print(f"大小: {file_content['size']} 字节")
|
||||
print(f"内容预览: {content[:200]}...")
|
||||
|
||||
# 示例6:获取仓库使用的编程语言
|
||||
print("\n=== 示例6:获取仓库使用的编程语言 ===")
|
||||
languages = await github.get_repository_languages("facebook", "react")
|
||||
if languages:
|
||||
print(f"React仓库使用的编程语言:")
|
||||
for lang, bytes_of_code in languages.items():
|
||||
print(f"- {lang}: {bytes_of_code} 字节")
|
||||
|
||||
# 示例7:获取组织信息
|
||||
print("\n=== 示例7:获取组织信息 ===")
|
||||
org_info = await github.get_organization("google")
|
||||
if org_info:
|
||||
print(f"名称: {org_info['name']}")
|
||||
print(f"描述: {org_info.get('description', '无')}")
|
||||
print(f"位置: {org_info.get('location', '未指定')}")
|
||||
print(f"公共仓库数: {org_info['public_repos']}")
|
||||
print(f"成员数: {org_info.get('public_members', 0)}")
|
||||
print(f"URL: {org_info['html_url']}")
|
||||
|
||||
# 示例8:获取用户信息
|
||||
print("\n=== 示例8:获取用户信息 ===")
|
||||
user_info = await github.get_user("torvalds")
|
||||
if user_info:
|
||||
print(f"名称: {user_info['name']}")
|
||||
print(f"公司: {user_info.get('company', '无')}")
|
||||
print(f"博客: {user_info.get('blog', '无')}")
|
||||
print(f"位置: {user_info.get('location', '未指定')}")
|
||||
print(f"公共仓库数: {user_info['public_repos']}")
|
||||
print(f"关注者数: {user_info['followers']}")
|
||||
print(f"URL: {user_info['html_url']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
142
crazy_functions/review_fns/data_sources/journal_metrics.py
Normal file
142
crazy_functions/review_fns/data_sources/journal_metrics.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
|
||||
class JournalMetrics:
|
||||
"""期刊指标管理类"""
|
||||
|
||||
def __init__(self):
|
||||
self.journal_data: Dict = {} # 期刊名称到指标的映射
|
||||
self.issn_map: Dict = {} # ISSN到指标的映射
|
||||
self.name_map: Dict = {} # 标准化名称到指标的映射
|
||||
self._load_journal_data()
|
||||
|
||||
def _normalize_journal_name(self, name: str) -> str:
|
||||
"""标准化期刊名称
|
||||
|
||||
Args:
|
||||
name: 原始期刊名称
|
||||
|
||||
Returns:
|
||||
标准化后的期刊名称
|
||||
"""
|
||||
if not name:
|
||||
return ""
|
||||
|
||||
# 转换为小写
|
||||
name = name.lower()
|
||||
|
||||
# 移除常见的前缀和后缀
|
||||
prefixes = ['the ', 'proceedings of ', 'journal of ']
|
||||
suffixes = [' journal', ' proceedings', ' magazine', ' review', ' letters']
|
||||
|
||||
for prefix in prefixes:
|
||||
if name.startswith(prefix):
|
||||
name = name[len(prefix):]
|
||||
|
||||
for suffix in suffixes:
|
||||
if name.endswith(suffix):
|
||||
name = name[:-len(suffix)]
|
||||
|
||||
# 移除特殊字符,保留字母、数字和空格
|
||||
name = ''.join(c for c in name if c.isalnum() or c.isspace())
|
||||
|
||||
# 移除多余的空格
|
||||
name = ' '.join(name.split())
|
||||
|
||||
return name
|
||||
|
||||
def _convert_if_value(self, if_str: str) -> Optional[float]:
|
||||
"""转换IF值为float,处理特殊情况"""
|
||||
try:
|
||||
if if_str.startswith('<'):
|
||||
# 对于<0.1这样的值,返回0.1
|
||||
return float(if_str.strip('<'))
|
||||
return float(if_str)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
def _load_journal_data(self):
|
||||
"""加载期刊数据"""
|
||||
try:
|
||||
file_path = os.path.join(os.path.dirname(__file__), 'cas_if.json')
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 建立期刊名称到指标的映射
|
||||
for journal in data:
|
||||
# 准备指标数据
|
||||
metrics = {
|
||||
'if_factor': self._convert_if_value(journal.get('IF')),
|
||||
'jcr_division': journal.get('Q'),
|
||||
'cas_division': journal.get('B')
|
||||
}
|
||||
|
||||
# 存储期刊名称映射(使用标准化名称)
|
||||
if journal.get('journal'):
|
||||
normalized_name = self._normalize_journal_name(journal['journal'])
|
||||
self.journal_data[normalized_name] = metrics
|
||||
self.name_map[normalized_name] = metrics
|
||||
|
||||
# 存储期刊缩写映射
|
||||
if journal.get('jabb'):
|
||||
normalized_abbr = self._normalize_journal_name(journal['jabb'])
|
||||
self.journal_data[normalized_abbr] = metrics
|
||||
self.name_map[normalized_abbr] = metrics
|
||||
|
||||
# 存储ISSN映射
|
||||
if journal.get('issn'):
|
||||
self.issn_map[journal['issn']] = metrics
|
||||
if journal.get('eissn'):
|
||||
self.issn_map[journal['eissn']] = metrics
|
||||
|
||||
except Exception as e:
|
||||
print(f"加载期刊数据时出错: {str(e)}")
|
||||
self.journal_data = {}
|
||||
self.issn_map = {}
|
||||
self.name_map = {}
|
||||
|
||||
def get_journal_metrics(self, venue_name: str, venue_info: dict) -> dict:
|
||||
"""获取期刊指标
|
||||
|
||||
Args:
|
||||
venue_name: 期刊名称
|
||||
venue_info: 期刊详细信息
|
||||
|
||||
Returns:
|
||||
包含期刊指标的字典
|
||||
"""
|
||||
try:
|
||||
metrics = {}
|
||||
|
||||
# 1. 首先尝试通过ISSN匹配
|
||||
if venue_info and 'issn' in venue_info:
|
||||
issn_value = venue_info['issn']
|
||||
# 处理ISSN可能是列表的情况
|
||||
if isinstance(issn_value, list):
|
||||
# 尝试每个ISSN
|
||||
for issn in issn_value:
|
||||
metrics = self.issn_map.get(issn, {})
|
||||
if metrics: # 如果找到匹配的指标,就停止搜索
|
||||
break
|
||||
else: # ISSN是字符串的情况
|
||||
metrics = self.issn_map.get(issn_value, {})
|
||||
|
||||
# 2. 如果ISSN匹配失败,尝试通过期刊名称匹配
|
||||
if not metrics and venue_name:
|
||||
# 标准化期刊名称
|
||||
normalized_name = self._normalize_journal_name(venue_name)
|
||||
metrics = self.name_map.get(normalized_name, {})
|
||||
|
||||
# 如果完全匹配失败,尝试部分匹配
|
||||
# if not metrics:
|
||||
# for db_name, db_metrics in self.name_map.items():
|
||||
# if normalized_name in db_name:
|
||||
# metrics = db_metrics
|
||||
# break
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取期刊指标时出错: {str(e)}")
|
||||
return {}
|
||||
163
crazy_functions/review_fns/data_sources/openalex_source.py
Normal file
163
crazy_functions/review_fns/data_sources/openalex_source.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import aiohttp
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
import os
|
||||
from urllib.parse import quote
|
||||
|
||||
class OpenAlexSource(DataSource):
|
||||
"""OpenAlex API实现"""
|
||||
|
||||
def _initialize(self) -> None:
|
||||
self.base_url = "https://api.openalex.org"
|
||||
self.mailto = "xxxxxxxxxxxxxxxxxxxxxxxx@163.com" # 直接写入邮件地址
|
||||
|
||||
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
params.update({
|
||||
"filter": f"title.search:{query}",
|
||||
"per-page": limit
|
||||
})
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params=params
|
||||
) as response:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
results = data.get("results", [])
|
||||
return [self._parse_work(work) for work in results]
|
||||
except Exception as e:
|
||||
print(f"搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_work(self, work: Dict) -> PaperMetadata:
|
||||
"""解析OpenAlex返回的数据"""
|
||||
# 获取作者信息
|
||||
raw_author_names = [
|
||||
authorship.get("raw_author_name", "")
|
||||
for authorship in work.get("authorships", [])
|
||||
if authorship
|
||||
]
|
||||
# 处理作者名字格式
|
||||
authors = [
|
||||
self._reformat_name(author)
|
||||
for author in raw_author_names
|
||||
]
|
||||
|
||||
# 获取机构信息
|
||||
institutions = [
|
||||
inst.get("display_name", "")
|
||||
for authorship in work.get("authorships", [])
|
||||
for inst in authorship.get("institutions", [])
|
||||
if inst
|
||||
]
|
||||
|
||||
# 获取主要发表位置信息
|
||||
primary_location = work.get("primary_location") or {}
|
||||
source = primary_location.get("source") or {}
|
||||
venue = source.get("display_name")
|
||||
|
||||
# 获取发表日期
|
||||
year = work.get("publication_year")
|
||||
|
||||
return PaperMetadata(
|
||||
title=work.get("title", ""),
|
||||
authors=authors,
|
||||
institutions=institutions,
|
||||
abstract=work.get("abstract", ""),
|
||||
year=year,
|
||||
doi=work.get("doi"),
|
||||
url=work.get("doi"), # OpenAlex 使用 DOI 作为 URL
|
||||
citations=work.get("cited_by_count"),
|
||||
venue=venue
|
||||
)
|
||||
|
||||
def _reformat_name(self, name: str) -> str:
|
||||
"""重新格式化作者名字"""
|
||||
if "," not in name:
|
||||
return name
|
||||
family, given_names = (x.strip() for x in name.split(",", maxsplit=1))
|
||||
return f"{given_names} {family}"
|
||||
|
||||
async def get_paper_details(self, doi: str) -> PaperMetadata:
|
||||
"""获取指定DOI的论文详情"""
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}",
|
||||
params=params
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return self._parse_work(data)
|
||||
|
||||
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取指定DOI论文的参考文献列表"""
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}/references",
|
||||
params=params
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return [self._parse_work(work) for work in data.get("results", [])]
|
||||
|
||||
async def get_citations(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取引用指定DOI论文的文献列表"""
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
params.update({
|
||||
"filter": f"cites:doi:{doi}",
|
||||
"per-page": 100
|
||||
})
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params=params
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return [self._parse_work(work) for work in data.get("results", [])]
|
||||
|
||||
async def example_usage():
|
||||
"""OpenAlexSource使用示例"""
|
||||
# 初始化OpenAlexSource
|
||||
openalex = OpenAlexSource()
|
||||
|
||||
try:
|
||||
print("正在搜索论文...")
|
||||
# 搜索与"artificial intelligence"相关的论文,限制返回5篇
|
||||
papers = await openalex.search(query="artificial intelligence", limit=5)
|
||||
|
||||
if not papers:
|
||||
print("未获取到任何论文信息")
|
||||
return
|
||||
|
||||
print(f"找到 {len(papers)} 篇论文")
|
||||
|
||||
# 打印搜索结果
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors) if paper.authors else '未知'}")
|
||||
if paper.institutions:
|
||||
print(f"机构: {', '.join(paper.institutions)}")
|
||||
print(f"发表年份: {paper.year if paper.year else '未知'}")
|
||||
print(f"DOI: {paper.doi if paper.doi else '未知'}")
|
||||
print(f"URL: {paper.url if paper.url else '未知'}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
print(f"引用次数: {paper.citations if paper.citations is not None else '未知'}")
|
||||
print(f"发表venue: {paper.venue if paper.venue else '未知'}")
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
# 如果直接运行此文件,执行示例代码
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
458
crazy_functions/review_fns/data_sources/pubmed_source.py
Normal file
458
crazy_functions/review_fns/data_sources/pubmed_source.py
Normal file
@@ -0,0 +1,458 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import xml.etree.ElementTree as ET
|
||||
from urllib.parse import quote
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
class PubMedSource(DataSource):
|
||||
"""PubMed API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: PubMed API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
|
||||
self.headers = {
|
||||
"User-Agent": "Mozilla/5.0 PubMedDataSource/1.0",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str) -> Optional[str]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
return await response.text()
|
||||
else:
|
||||
print(f"请求失败: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = "relevance",
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||
start_year: 起始年份
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
"""
|
||||
try:
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
query = f"{query} AND {start_year}:3000[dp]"
|
||||
|
||||
# 构建搜索URL
|
||||
search_url = (
|
||||
f"{self.base_url}/esearch.fcgi?"
|
||||
f"db=pubmed&term={quote(query)}&retmax={limit}"
|
||||
f"&usehistory=y&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
if sort_by == "date":
|
||||
search_url += "&sort=date"
|
||||
|
||||
# 获取搜索结果
|
||||
response = await self._make_request(search_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
id_list = root.findall(".//Id")
|
||||
pmids = [id_elem.text for id_elem in id_list]
|
||||
|
||||
if not pmids:
|
||||
return []
|
||||
|
||||
# 批量获取论文详情
|
||||
papers = []
|
||||
batch_size = 50
|
||||
for i in range(0, len(pmids), batch_size):
|
||||
batch = pmids[i:i + batch_size]
|
||||
batch_papers = await self._fetch_papers_batch(batch)
|
||||
papers.extend(batch_papers)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _fetch_papers_batch(self, pmids: List[str]) -> List[PaperMetadata]:
|
||||
"""批量获取论文详情
|
||||
|
||||
Args:
|
||||
pmids: PubMed ID列表
|
||||
|
||||
Returns:
|
||||
论文详情列表
|
||||
"""
|
||||
try:
|
||||
# 构建批量获取URL
|
||||
fetch_url = (
|
||||
f"{self.base_url}/efetch.fcgi?"
|
||||
f"db=pubmed&id={','.join(pmids)}"
|
||||
f"&retmode=xml&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
response = await self._make_request(fetch_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
articles = root.findall(".//PubmedArticle")
|
||||
|
||||
return [self._parse_article(article) for article in articles]
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取论文批次时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_article(self, article: ET.Element) -> PaperMetadata:
|
||||
"""解析PubMed文章XML
|
||||
|
||||
Args:
|
||||
article: XML元素
|
||||
|
||||
Returns:
|
||||
解析后的论文数据
|
||||
"""
|
||||
try:
|
||||
# 提取基本信息
|
||||
pmid = article.find(".//PMID").text
|
||||
article_meta = article.find(".//Article")
|
||||
|
||||
# 获取标题
|
||||
title = article_meta.find(".//ArticleTitle")
|
||||
title = title.text if title is not None else ""
|
||||
|
||||
# 获取作者列表
|
||||
authors = []
|
||||
author_list = article_meta.findall(".//Author")
|
||||
for author in author_list:
|
||||
last_name = author.find("LastName")
|
||||
fore_name = author.find("ForeName")
|
||||
if last_name is not None and fore_name is not None:
|
||||
authors.append(f"{fore_name.text} {last_name.text}")
|
||||
elif last_name is not None:
|
||||
authors.append(last_name.text)
|
||||
|
||||
# 获取摘要
|
||||
abstract = article_meta.find(".//Abstract/AbstractText")
|
||||
abstract = abstract.text if abstract is not None else ""
|
||||
|
||||
# 获取发表年份
|
||||
pub_date = article_meta.find(".//PubDate/Year")
|
||||
year = int(pub_date.text) if pub_date is not None else None
|
||||
|
||||
# 获取DOI
|
||||
doi = article.find(".//ELocationID[@EIdType='doi']")
|
||||
doi = doi.text if doi is not None else None
|
||||
|
||||
# 获取期刊信息
|
||||
journal = article_meta.find(".//Journal")
|
||||
if journal is not None:
|
||||
journal_title = journal.find(".//Title")
|
||||
venue = journal_title.text if journal_title is not None else None
|
||||
|
||||
# 获取期刊详细信息
|
||||
venue_info = {
|
||||
'issn': journal.findtext(".//ISSN"),
|
||||
'volume': journal.findtext(".//Volume"),
|
||||
'issue': journal.findtext(".//Issue"),
|
||||
'pub_date': journal.findtext(".//PubDate/MedlineDate") or
|
||||
f"{journal.findtext('.//PubDate/Year', '')}-{journal.findtext('.//PubDate/Month', '')}"
|
||||
}
|
||||
else:
|
||||
venue = None
|
||||
venue_info = {}
|
||||
|
||||
# 获取机构信息
|
||||
institutions = []
|
||||
affiliations = article_meta.findall(".//Affiliation")
|
||||
for affiliation in affiliations:
|
||||
if affiliation is not None and affiliation.text:
|
||||
institutions.append(affiliation.text)
|
||||
|
||||
return PaperMetadata(
|
||||
title=title,
|
||||
authors=authors,
|
||||
abstract=abstract,
|
||||
year=year,
|
||||
doi=doi,
|
||||
url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" if pmid else None,
|
||||
citations=None, # PubMed API不直接提供引用数据
|
||||
venue=venue,
|
||||
institutions=institutions,
|
||||
venue_type="journal",
|
||||
venue_name=venue,
|
||||
venue_info=venue_info,
|
||||
source='pubmed' # 添加来源标记
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析文章时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_paper_details(self, pmid: str) -> Optional[PaperMetadata]:
|
||||
"""获取指定PMID的论文详情"""
|
||||
papers = await self._fetch_papers_batch([pmid])
|
||||
return papers[0] if papers else None
|
||||
|
||||
async def get_related_papers(self, pmid: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取相关论文
|
||||
|
||||
使用PubMed的相关文章功能
|
||||
|
||||
Args:
|
||||
pmid: PubMed ID
|
||||
limit: 返回结果数量限制
|
||||
|
||||
Returns:
|
||||
相关论文列表
|
||||
"""
|
||||
try:
|
||||
# 构建相关文章URL
|
||||
link_url = (
|
||||
f"{self.base_url}/elink.fcgi?"
|
||||
f"db=pubmed&id={pmid}&cmd=neighbor&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
response = await self._make_request(link_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
related_ids = root.findall(".//Link/Id")
|
||||
pmids = [id_elem.text for id_elem in related_ids][:limit]
|
||||
|
||||
if not pmids:
|
||||
return []
|
||||
|
||||
# 获取相关论文详情
|
||||
return await self._fetch_papers_batch(pmids)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取相关论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"{author}[Author]"
|
||||
if start_year:
|
||||
query += f" AND {start_year}:3000[dp]"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def search_by_journal(
|
||||
self,
|
||||
journal: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊搜索论文"""
|
||||
query = f"{journal}[Journal]"
|
||||
if start_year:
|
||||
query += f" AND {start_year}:3000[dp]"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
days: int = 7,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取最新论文
|
||||
|
||||
Args:
|
||||
days: 最近几天的论文
|
||||
limit: 返回结果数量限制
|
||||
|
||||
Returns:
|
||||
最新论文列表
|
||||
"""
|
||||
query = f"last {days} days[dp]"
|
||||
return await self.search(query, limit=limit, sort_by="date")
|
||||
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献
|
||||
|
||||
注意:PubMed API本身不提供引用数据,此方法将返回空列表
|
||||
未来可以考虑集成其他数据源(如CrossRef)来获取引用信息
|
||||
|
||||
Args:
|
||||
paper_id: PubMed ID
|
||||
|
||||
Returns:
|
||||
空列表,因为PubMed不提供引用数据
|
||||
"""
|
||||
return []
|
||||
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献
|
||||
|
||||
从PubMed文章的参考文献列表获取引用的文献
|
||||
|
||||
Args:
|
||||
paper_id: PubMed ID
|
||||
|
||||
Returns:
|
||||
引用的文献列表
|
||||
"""
|
||||
try:
|
||||
# 构建获取参考文献的URL
|
||||
refs_url = (
|
||||
f"{self.base_url}/elink.fcgi?"
|
||||
f"dbfrom=pubmed&db=pubmed&id={paper_id}"
|
||||
f"&cmd=neighbor_history&linkname=pubmed_pubmed_refs"
|
||||
f"&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
response = await self._make_request(refs_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
ref_ids = root.findall(".//Link/Id")
|
||||
pmids = [id_elem.text for id_elem in ref_ids]
|
||||
|
||||
if not pmids:
|
||||
return []
|
||||
|
||||
# 获取参考文献详情
|
||||
return await self._fetch_papers_batch(pmids)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取参考文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def example_usage():
|
||||
"""PubMedSource使用示例"""
|
||||
pubmed = PubMedSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索COVID-19相关论文 ===")
|
||||
papers = await pubmed.search("COVID-19", limit=3)
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
|
||||
# 示例2:获取论文详情
|
||||
if papers:
|
||||
print("\n=== 示例2:获取论文详情 ===")
|
||||
paper_id = papers[0].url.split("/")[-2]
|
||||
paper = await pubmed.get_paper_details(paper_id)
|
||||
if paper:
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"期刊: {paper.venue}")
|
||||
print(f"机构: {', '.join(paper.institutions)}")
|
||||
|
||||
# 示例3:获取相关论文
|
||||
if papers:
|
||||
print("\n=== 示例3:获取相关论文 ===")
|
||||
related = await pubmed.get_related_papers(paper_id, limit=3)
|
||||
for i, paper in enumerate(related, 1):
|
||||
print(f"\n--- 相关论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
|
||||
# 示例4:按作者搜索
|
||||
print("\n=== 示例4:按作者搜索 ===")
|
||||
author_papers = await pubmed.search_by_author("Fauci AS", limit=3)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例5:按期刊搜索
|
||||
print("\n=== 示例5:按期刊搜索 ===")
|
||||
journal_papers = await pubmed.search_by_journal("Nature", limit=3)
|
||||
for i, paper in enumerate(journal_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例6:获取最新论文
|
||||
print("\n=== 示例6:获取最新论文 ===")
|
||||
latest = await pubmed.get_latest_papers(days=7, limit=3)
|
||||
for i, paper in enumerate(latest, 1):
|
||||
print(f"\n--- 最新论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表日期: {paper.venue_info.get('pub_date')}")
|
||||
|
||||
# 示例7:获取论文的参考文献
|
||||
if papers:
|
||||
print("\n=== 示例7:获取论文的参考文献 ===")
|
||||
paper_id = papers[0].url.split("/")[-2]
|
||||
references = await pubmed.get_references(paper_id)
|
||||
for i, paper in enumerate(references[:3], 1):
|
||||
print(f"\n--- 参考文献 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例8:尝试获取引用信息(将返回空列表)
|
||||
if papers:
|
||||
print("\n=== 示例8:获取论文的引用信息 ===")
|
||||
paper_id = papers[0].url.split("/")[-2]
|
||||
citations = await pubmed.get_citations(paper_id)
|
||||
print(f"引用数据:{len(citations)} (PubMed API不提供引用信息)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(example_usage())
|
||||
326
crazy_functions/review_fns/data_sources/scihub_source.py
Normal file
326
crazy_functions/review_fns/data_sources/scihub_source.py
Normal file
@@ -0,0 +1,326 @@
|
||||
from pathlib import Path
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
import time
|
||||
from loguru import logger
|
||||
import PyPDF2
|
||||
import io
|
||||
|
||||
|
||||
class SciHub:
|
||||
# 更新的镜像列表,包含更多可用的镜像
|
||||
MIRRORS = [
|
||||
'https://sci-hub.se/',
|
||||
'https://sci-hub.st/',
|
||||
'https://sci-hub.ru/',
|
||||
'https://sci-hub.wf/',
|
||||
'https://sci-hub.ee/',
|
||||
'https://sci-hub.ren/',
|
||||
'https://sci-hub.tf/',
|
||||
'https://sci-hub.si/',
|
||||
'https://sci-hub.do/',
|
||||
'https://sci-hub.hkvisa.net/',
|
||||
'https://sci-hub.mksa.top/',
|
||||
'https://sci-hub.shop/',
|
||||
'https://sci-hub.yncjkj.com/',
|
||||
'https://sci-hub.41610.org/',
|
||||
'https://sci-hub.automic.us/',
|
||||
'https://sci-hub.et-fine.com/',
|
||||
'https://sci-hub.pooh.mu/',
|
||||
'https://sci-hub.bban.top/',
|
||||
'https://sci-hub.usualwant.com/',
|
||||
'https://sci-hub.unblockit.kim/'
|
||||
]
|
||||
|
||||
def __init__(self, doi: str, path: Path, url=None, timeout=60, use_proxy=True):
|
||||
self.timeout = timeout
|
||||
self.path = path
|
||||
self.doi = str(doi)
|
||||
self.use_proxy = use_proxy
|
||||
self.headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
|
||||
}
|
||||
self.payload = {
|
||||
'sci-hub-plugin-check': '',
|
||||
'request': self.doi
|
||||
}
|
||||
self.url = url if url else self.MIRRORS[0]
|
||||
self.proxies = {
|
||||
"http": "socks5h://localhost:10880",
|
||||
"https": "socks5h://localhost:10880",
|
||||
} if use_proxy else None
|
||||
|
||||
def _test_proxy_connection(self):
|
||||
"""测试代理连接是否可用"""
|
||||
if not self.use_proxy:
|
||||
return True
|
||||
|
||||
try:
|
||||
# 测试代理连接
|
||||
test_response = requests.get(
|
||||
'https://httpbin.org/ip',
|
||||
proxies=self.proxies,
|
||||
timeout=10
|
||||
)
|
||||
if test_response.status_code == 200:
|
||||
logger.info("代理连接测试成功")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"代理连接测试失败: {str(e)}")
|
||||
return False
|
||||
return False
|
||||
|
||||
def _check_pdf_validity(self, content):
|
||||
"""检查PDF文件是否有效"""
|
||||
try:
|
||||
# 使用PyPDF2检查PDF是否可以正常打开和读取
|
||||
pdf = PyPDF2.PdfReader(io.BytesIO(content))
|
||||
if len(pdf.pages) > 0:
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"PDF文件无效: {str(e)}")
|
||||
return False
|
||||
|
||||
def _send_request(self):
|
||||
"""发送请求到Sci-Hub镜像站点"""
|
||||
# 首先测试代理连接
|
||||
if self.use_proxy and not self._test_proxy_connection():
|
||||
logger.warning("代理连接不可用,切换到直连模式")
|
||||
self.use_proxy = False
|
||||
self.proxies = None
|
||||
|
||||
last_exception = None
|
||||
working_mirrors = []
|
||||
|
||||
# 先测试哪些镜像可用
|
||||
logger.info("正在测试镜像站点可用性...")
|
||||
for mirror in self.MIRRORS:
|
||||
try:
|
||||
test_response = requests.get(
|
||||
mirror,
|
||||
headers=self.headers,
|
||||
proxies=self.proxies,
|
||||
timeout=10
|
||||
)
|
||||
if test_response.status_code == 200:
|
||||
working_mirrors.append(mirror)
|
||||
logger.info(f"镜像 {mirror} 可用")
|
||||
if len(working_mirrors) >= 5: # 找到5个可用镜像就够了
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f"镜像 {mirror} 不可用: {str(e)}")
|
||||
continue
|
||||
|
||||
if not working_mirrors:
|
||||
raise Exception("没有找到可用的镜像站点")
|
||||
|
||||
logger.info(f"找到 {len(working_mirrors)} 个可用镜像,开始尝试下载...")
|
||||
|
||||
# 使用可用的镜像进行下载
|
||||
for mirror in working_mirrors:
|
||||
try:
|
||||
res = requests.post(
|
||||
mirror,
|
||||
headers=self.headers,
|
||||
data=self.payload,
|
||||
proxies=self.proxies,
|
||||
timeout=self.timeout
|
||||
)
|
||||
if res.ok:
|
||||
logger.info(f"成功使用镜像站点: {mirror}")
|
||||
self.url = mirror # 更新当前使用的镜像
|
||||
time.sleep(1) # 降低等待时间以提高效率
|
||||
return res
|
||||
except Exception as e:
|
||||
logger.error(f"尝试镜像 {mirror} 失败: {str(e)}")
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
raise Exception("所有可用镜像站点均无法完成下载")
|
||||
|
||||
def _extract_url(self, response):
|
||||
"""从响应中提取PDF下载链接"""
|
||||
soup = BeautifulSoup(response.content, 'html.parser')
|
||||
try:
|
||||
# 尝试多种方式提取PDF链接
|
||||
pdf_element = soup.find(id='pdf')
|
||||
if pdf_element:
|
||||
content_url = pdf_element.get('src')
|
||||
else:
|
||||
# 尝试其他可能的选择器
|
||||
pdf_element = soup.find('iframe')
|
||||
if pdf_element:
|
||||
content_url = pdf_element.get('src')
|
||||
else:
|
||||
# 查找直接的PDF链接
|
||||
pdf_links = soup.find_all('a', href=lambda x: x and '.pdf' in x)
|
||||
if pdf_links:
|
||||
content_url = pdf_links[0].get('href')
|
||||
else:
|
||||
raise AttributeError("未找到PDF链接")
|
||||
|
||||
if content_url:
|
||||
content_url = content_url.replace('#navpanes=0&view=FitH', '').replace('//', '/')
|
||||
if not content_url.endswith('.pdf') and 'pdf' not in content_url.lower():
|
||||
raise AttributeError("找到的链接不是PDF文件")
|
||||
except AttributeError:
|
||||
logger.error(f"未找到论文 {self.doi}")
|
||||
return None
|
||||
|
||||
current_mirror = self.url.rstrip('/')
|
||||
if content_url.startswith('/'):
|
||||
return current_mirror + content_url
|
||||
elif content_url.startswith('http'):
|
||||
return content_url
|
||||
else:
|
||||
return 'https:/' + content_url
|
||||
|
||||
def _download_pdf(self, pdf_url):
|
||||
"""下载PDF文件并验证其完整性"""
|
||||
try:
|
||||
# 尝试不同的下载方式
|
||||
download_methods = [
|
||||
# 方法1:直接下载
|
||||
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout),
|
||||
# 方法2:添加 Referer 头
|
||||
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
|
||||
headers={**self.headers, 'Referer': self.url}),
|
||||
# 方法3:使用原始域名作为 Referer
|
||||
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
|
||||
headers={**self.headers, 'Referer': pdf_url.split('/downloads')[0] if '/downloads' in pdf_url else self.url})
|
||||
]
|
||||
|
||||
for i, download_method in enumerate(download_methods):
|
||||
try:
|
||||
logger.info(f"尝试下载方式 {i+1}/3...")
|
||||
response = download_method()
|
||||
if response.status_code == 200:
|
||||
content = response.content
|
||||
if len(content) > 1000 and self._check_pdf_validity(content): # 确保文件不是太小
|
||||
logger.info(f"PDF下载成功,文件大小: {len(content)} bytes")
|
||||
return content
|
||||
else:
|
||||
logger.warning("下载的文件可能不是有效的PDF")
|
||||
elif response.status_code == 403:
|
||||
logger.warning(f"访问被拒绝 (403 Forbidden),尝试其他下载方式")
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"下载失败,状态码: {response.status_code}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"下载方式 {i+1} 失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 如果所有方法都失败,尝试构造替代URL
|
||||
try:
|
||||
logger.info("尝试使用替代镜像下载...")
|
||||
# 从原始URL提取关键信息
|
||||
if '/downloads/' in pdf_url:
|
||||
file_part = pdf_url.split('/downloads/')[-1]
|
||||
alternative_mirrors = [
|
||||
f"https://sci-hub.se/downloads/{file_part}",
|
||||
f"https://sci-hub.st/downloads/{file_part}",
|
||||
f"https://sci-hub.ru/downloads/{file_part}",
|
||||
f"https://sci-hub.wf/downloads/{file_part}",
|
||||
f"https://sci-hub.ee/downloads/{file_part}",
|
||||
f"https://sci-hub.ren/downloads/{file_part}",
|
||||
f"https://sci-hub.tf/downloads/{file_part}"
|
||||
]
|
||||
|
||||
for alt_url in alternative_mirrors:
|
||||
try:
|
||||
response = requests.get(
|
||||
alt_url,
|
||||
proxies=self.proxies,
|
||||
timeout=self.timeout,
|
||||
headers={**self.headers, 'Referer': alt_url.split('/downloads')[0]}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
content = response.content
|
||||
if len(content) > 1000 and self._check_pdf_validity(content):
|
||||
logger.info(f"使用替代镜像成功下载: {alt_url}")
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug(f"替代镜像 {alt_url} 下载失败: {str(e)}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"所有下载方式都失败: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载PDF文件失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def fetch(self):
|
||||
"""获取论文PDF,包含重试和验证机制"""
|
||||
for attempt in range(2): # 最多重试3次
|
||||
try:
|
||||
logger.info(f"开始第 {attempt + 1} 次尝试下载论文: {self.doi}")
|
||||
|
||||
# 获取PDF下载链接
|
||||
response = self._send_request()
|
||||
pdf_url = self._extract_url(response)
|
||||
if pdf_url is None:
|
||||
logger.warning(f"第 {attempt + 1} 次尝试:未找到PDF下载链接")
|
||||
continue
|
||||
|
||||
logger.info(f"找到PDF下载链接: {pdf_url}")
|
||||
|
||||
# 下载并验证PDF
|
||||
pdf_content = self._download_pdf(pdf_url)
|
||||
if pdf_content is None:
|
||||
logger.warning(f"第 {attempt + 1} 次尝试:PDF下载失败")
|
||||
continue
|
||||
|
||||
# 保存PDF文件
|
||||
pdf_name = f"{self.doi.replace('/', '_').replace(':', '_')}.pdf"
|
||||
pdf_path = self.path.joinpath(pdf_name)
|
||||
pdf_path.write_bytes(pdf_content)
|
||||
|
||||
logger.info(f"成功下载论文: {pdf_name},文件大小: {len(pdf_content)} bytes")
|
||||
return str(pdf_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"第 {attempt + 1} 次尝试失败: {str(e)}")
|
||||
if attempt < 2: # 不是最后一次尝试
|
||||
wait_time = (attempt + 1) * 3 # 递增等待时间
|
||||
logger.info(f"等待 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
raise Exception(f"无法下载论文 {self.doi},所有重试都失败了")
|
||||
|
||||
# Usage Example
|
||||
if __name__ == '__main__':
|
||||
# 创建一个用于保存PDF的目录
|
||||
save_path = Path('./downloaded_papers')
|
||||
save_path.mkdir(exist_ok=True)
|
||||
|
||||
# DOI示例
|
||||
sample_doi = '10.3897/rio.7.e67379' # 这是一篇Nature的论文DOI
|
||||
|
||||
try:
|
||||
# 初始化SciHub下载器,先尝试使用代理
|
||||
logger.info("尝试使用代理模式...")
|
||||
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=True)
|
||||
|
||||
# 开始下载
|
||||
result = downloader.fetch()
|
||||
print(f"论文已保存到: {result}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"使用代理模式失败: {str(e)}")
|
||||
try:
|
||||
# 如果代理模式失败,尝试直连模式
|
||||
logger.info("尝试直连模式...")
|
||||
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=False)
|
||||
result = downloader.fetch()
|
||||
print(f"论文已保存到: {result}")
|
||||
except Exception as e2:
|
||||
print(f"直连模式也失败: {str(e2)}")
|
||||
print("建议检查网络连接或尝试其他DOI")
|
||||
400
crazy_functions/review_fns/data_sources/scopus_source.py
Normal file
400
crazy_functions/review_fns/data_sources/scopus_source.py
Normal file
@@ -0,0 +1,400 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import random
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
from tqdm import tqdm
|
||||
|
||||
class ScopusSource(DataSource):
|
||||
"""Scopus API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: Scopus API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS)
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://api.elsevier.com/content"
|
||||
self.headers = {
|
||||
"X-ELS-APIKey": self.api_key,
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
响应JSON数据
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
print(f"请求失败: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _parse_paper_data(self, data: Dict) -> PaperMetadata:
|
||||
"""解析Scopus API返回的数据
|
||||
|
||||
Args:
|
||||
data: Scopus API返回的论文数据
|
||||
|
||||
Returns:
|
||||
解析后的论文元数据
|
||||
"""
|
||||
try:
|
||||
# 提取基本信息
|
||||
title = data.get("dc:title", "")
|
||||
|
||||
# 提取作者信息
|
||||
authors = []
|
||||
if "author" in data:
|
||||
if isinstance(data["author"], list):
|
||||
for author in data["author"]:
|
||||
if "given-name" in author and "surname" in author:
|
||||
authors.append(f"{author['given-name']} {author['surname']}")
|
||||
elif "indexed-name" in author:
|
||||
authors.append(author["indexed-name"])
|
||||
elif isinstance(data["author"], dict):
|
||||
if "given-name" in data["author"] and "surname" in data["author"]:
|
||||
authors.append(f"{data['author']['given-name']} {data['author']['surname']}")
|
||||
elif "indexed-name" in data["author"]:
|
||||
authors.append(data["author"]["indexed-name"])
|
||||
|
||||
# 提取摘要
|
||||
abstract = data.get("dc:description", "")
|
||||
|
||||
# 提取年份
|
||||
year = None
|
||||
if "prism:coverDate" in data:
|
||||
try:
|
||||
year = int(data["prism:coverDate"][:4])
|
||||
except:
|
||||
pass
|
||||
|
||||
# 提取DOI
|
||||
doi = data.get("prism:doi")
|
||||
|
||||
# 提取引用次数
|
||||
citations = data.get("citedby-count")
|
||||
if citations:
|
||||
try:
|
||||
citations = int(citations)
|
||||
except:
|
||||
citations = None
|
||||
|
||||
# 提取期刊信息
|
||||
venue = data.get("prism:publicationName")
|
||||
|
||||
# 提取机构信息
|
||||
institutions = []
|
||||
if "affiliation" in data:
|
||||
if isinstance(data["affiliation"], list):
|
||||
for aff in data["affiliation"]:
|
||||
if "affilname" in aff:
|
||||
institutions.append(aff["affilname"])
|
||||
elif isinstance(data["affiliation"], dict):
|
||||
if "affilname" in data["affiliation"]:
|
||||
institutions.append(data["affiliation"]["affilname"])
|
||||
|
||||
# 构建venue信息
|
||||
venue_info = {
|
||||
"issn": data.get("prism:issn"),
|
||||
"eissn": data.get("prism:eIssn"),
|
||||
"volume": data.get("prism:volume"),
|
||||
"issue": data.get("prism:issueIdentifier"),
|
||||
"page_range": data.get("prism:pageRange"),
|
||||
"article_number": data.get("article-number"),
|
||||
"publication_date": data.get("prism:coverDate")
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=title,
|
||||
authors=authors,
|
||||
abstract=abstract,
|
||||
year=year,
|
||||
doi=doi,
|
||||
url=data.get("link", [{}])[0].get("@href"),
|
||||
citations=citations,
|
||||
venue=venue,
|
||||
institutions=institutions,
|
||||
venue_type="journal",
|
||||
venue_name=venue,
|
||||
venue_info=venue_info
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析论文数据时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||
start_year: 起始年份
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询参数
|
||||
params = {
|
||||
"query": query,
|
||||
"count": min(limit, 100), # Scopus API单次请求限制
|
||||
"start": 0
|
||||
}
|
||||
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
params["date"] = f"{start_year}-present"
|
||||
|
||||
# 添加排序
|
||||
if sort_by:
|
||||
sort_map = {
|
||||
"relevance": "-score",
|
||||
"date": "-coverDate",
|
||||
"citations": "-citedby-count"
|
||||
}
|
||||
if sort_by in sort_map:
|
||||
params["sort"] = sort_map[sort_by]
|
||||
|
||||
# 发送请求
|
||||
url = f"{self.base_url}/search/scopus"
|
||||
response = await self._make_request(url, params)
|
||||
|
||||
if not response or "search-results" not in response:
|
||||
return []
|
||||
|
||||
# 解析结果
|
||||
results = response["search-results"].get("entry", [])
|
||||
papers = []
|
||||
|
||||
for result in results:
|
||||
paper = self._parse_paper_data(result)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||
"""获取论文详情
|
||||
|
||||
Args:
|
||||
paper_id: Scopus ID或DOI
|
||||
|
||||
Returns:
|
||||
论文详情
|
||||
"""
|
||||
try:
|
||||
# 判断是否为DOI
|
||||
if "/" in paper_id:
|
||||
url = f"{self.base_url}/article/doi/{paper_id}"
|
||||
else:
|
||||
url = f"{self.base_url}/abstract/scopus_id/{paper_id}"
|
||||
|
||||
response = await self._make_request(url)
|
||||
|
||||
if not response or "abstracts-retrieval-response" not in response:
|
||||
return None
|
||||
|
||||
data = response["abstracts-retrieval-response"]
|
||||
return self._parse_paper_data(data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献
|
||||
|
||||
Args:
|
||||
paper_id: Scopus ID
|
||||
|
||||
Returns:
|
||||
引用论文列表
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}/abstract/citations/{paper_id}"
|
||||
response = await self._make_request(url)
|
||||
|
||||
if not response or "citing-papers" not in response:
|
||||
return []
|
||||
|
||||
results = response["citing-papers"].get("papers", [])
|
||||
papers = []
|
||||
|
||||
for result in results:
|
||||
paper = self._parse_paper_data(result)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取引用信息时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献
|
||||
|
||||
Args:
|
||||
paper_id: Scopus ID
|
||||
|
||||
Returns:
|
||||
参考文献列表
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}/abstract/references/{paper_id}"
|
||||
response = await self._make_request(url)
|
||||
|
||||
if not response or "references" not in response:
|
||||
return []
|
||||
|
||||
results = response["references"].get("reference", [])
|
||||
papers = []
|
||||
|
||||
for result in results:
|
||||
paper = self._parse_paper_data(result)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取参考文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"AUTHOR-NAME({author})"
|
||||
if start_year:
|
||||
query += f" AND PUBYEAR > {start_year}"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def search_by_journal(
|
||||
self,
|
||||
journal: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊搜索论文"""
|
||||
query = f"SRCTITLE({journal})"
|
||||
if start_year:
|
||||
query += f" AND PUBYEAR > {start_year}"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
days: int = 7,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取最新论文"""
|
||||
query = f"LOAD-DATE > NOW() - {days}d"
|
||||
return await self.search(query, limit=limit, sort_by="date")
|
||||
|
||||
async def example_usage():
|
||||
"""ScopusSource使用示例"""
|
||||
scopus = ScopusSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索机器学习相关论文 ===")
|
||||
papers = await scopus.search("machine learning", limit=3)
|
||||
print(f"\n找到 {len(papers)} 篇相关论文:")
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"发表期刊: {paper.venue}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要:\n{paper.abstract}")
|
||||
print("-" * 80)
|
||||
|
||||
# 示例2:按作者搜索
|
||||
print("\n=== 示例2:搜索特定作者的论文 ===")
|
||||
author_papers = await scopus.search_by_author("Hinton G.", limit=3)
|
||||
print(f"\n找到 {len(author_papers)} 篇 Hinton 的论文:")
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"发表期刊: {paper.venue}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要:\n{paper.abstract}")
|
||||
print("-" * 80)
|
||||
|
||||
# 示例3:根据关键词搜索相关论文
|
||||
print("\n=== 示例3:搜索人工智能相关论文 ===")
|
||||
keywords = "artificial intelligence AND deep learning"
|
||||
papers = await scopus.search(
|
||||
query=keywords,
|
||||
limit=5,
|
||||
sort_by="citations", # 按引用次数排序
|
||||
start_year=2020 # 只搜索2020年之后的论文
|
||||
)
|
||||
|
||||
print(f"\n找到 {len(papers)} 篇相关论文:")
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"发表期刊: {paper.venue}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要:\n{paper.abstract}")
|
||||
print("-" * 80)
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(example_usage())
|
||||
480
crazy_functions/review_fns/data_sources/semantic_source.py
Normal file
480
crazy_functions/review_fns/data_sources/semantic_source.py
Normal file
@@ -0,0 +1,480 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import random
|
||||
|
||||
class SemanticScholarSource(DataSource):
|
||||
"""Semantic Scholar API实现,使用官方Python包"""
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: Semantic Scholar API密钥(可选)
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self._initialize() # 调用初始化方法
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化API客户端"""
|
||||
if not self.api_key:
|
||||
# 默认API密钥列表
|
||||
default_api_keys = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
self.api_key = random.choice(default_api_keys)
|
||||
|
||||
self.client = None # 延迟初始化
|
||||
self.fields = [
|
||||
"title",
|
||||
"authors",
|
||||
"abstract",
|
||||
"year",
|
||||
"externalIds",
|
||||
"citationCount",
|
||||
"venue",
|
||||
"openAccessPdf",
|
||||
"publicationVenue"
|
||||
]
|
||||
|
||||
async def _ensure_client(self):
|
||||
"""确保客户端已初始化"""
|
||||
if self.client is None:
|
||||
from semanticscholar import AsyncSemanticScholar
|
||||
self.client = AsyncSemanticScholar(api_key=self.api_key)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} year>={start_year}"
|
||||
|
||||
# 直接使用 search_paper 的结果
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/search",
|
||||
f"query={query}&limit={min(limit, 100)}&fields={','.join(self.fields)}",
|
||||
self.client.auth_header
|
||||
)
|
||||
papers = response.get('data', [])
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return []
|
||||
|
||||
async def get_paper_details(self, doi: str) -> Optional[PaperMetadata]:
|
||||
"""获取指定DOI的论文详情"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
paper = await self.client.get_paper(f"DOI:{doi}", fields=self.fields)
|
||||
return self._parse_paper_data(paper)
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_citations(
|
||||
self,
|
||||
doi: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取引用指定DOI论文的文献列表"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 构建查询参数
|
||||
fields_param = f"fields={','.join(self.fields)}"
|
||||
limit_param = f"limit={limit}"
|
||||
year_param = f"year>={start_year}" if start_year else ""
|
||||
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
|
||||
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/citations",
|
||||
params,
|
||||
self.client.auth_header
|
||||
)
|
||||
citations = response.get('data', [])
|
||||
return [self._parse_paper_data(citation.get('citingPaper', {})) for citation in citations]
|
||||
except Exception as e:
|
||||
print(f"获取引用列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_references(
|
||||
self,
|
||||
doi: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取指定DOI论文的参考文献列表"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 构建查询参数
|
||||
fields_param = f"fields={','.join(self.fields)}"
|
||||
limit_param = f"limit={limit}"
|
||||
year_param = f"year>={start_year}" if start_year else ""
|
||||
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
|
||||
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/references",
|
||||
params,
|
||||
self.client.auth_header
|
||||
)
|
||||
references = response.get('data', [])
|
||||
return [self._parse_paper_data(reference.get('citedPaper', {})) for reference in references]
|
||||
except Exception as e:
|
||||
print(f"获取参考文献列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_recommended_papers(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取论文推荐
|
||||
|
||||
根据一篇论文获取相关的推荐论文
|
||||
|
||||
Args:
|
||||
doi: 论文的DOI
|
||||
limit: 返回结果数量限制,最大500
|
||||
|
||||
Returns:
|
||||
推荐论文列表
|
||||
"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
papers = await self.client.get_recommended_papers(
|
||||
f"DOI:{doi}",
|
||||
fields=self.fields,
|
||||
limit=min(limit, 500)
|
||||
)
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"获取论文推荐时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_recommended_papers_from_lists(
|
||||
self,
|
||||
positive_dois: List[str],
|
||||
negative_dois: List[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""基于正负例论文列表获取推荐
|
||||
|
||||
Args:
|
||||
positive_dois: 正例论文DOI列表(想要获取类似的论文)
|
||||
negative_dois: 负例论文DOI列表(不想要类似的论文)
|
||||
limit: 返回结果数量限制,最大500
|
||||
|
||||
Returns:
|
||||
推荐论文列表
|
||||
"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
positive_ids = [f"DOI:{doi}" for doi in positive_dois]
|
||||
negative_ids = [f"DOI:{doi}" for doi in negative_dois] if negative_dois else None
|
||||
|
||||
papers = await self.client.get_recommended_papers_from_lists(
|
||||
positive_paper_ids=positive_ids,
|
||||
negative_paper_ids=negative_ids,
|
||||
fields=self.fields,
|
||||
limit=min(limit, 500)
|
||||
)
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"获取论文推荐列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_author(self, query: str, limit: int = 100) -> List[dict]:
|
||||
"""搜索作者"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求而不是 search_author 方法
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/search",
|
||||
f"query={query}&fields=name,paperCount,citationCount&limit={min(limit, 1000)}",
|
||||
self.client.auth_header
|
||||
)
|
||||
authors = response.get('data', [])
|
||||
return [
|
||||
{
|
||||
'author_id': author.get('authorId'),
|
||||
'name': author.get('name'),
|
||||
'paper_count': author.get('paperCount'),
|
||||
'citation_count': author.get('citationCount'),
|
||||
}
|
||||
for author in authors
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"搜索作者时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_author_details(self, author_id: str) -> Optional[dict]:
|
||||
"""获取作者详细信息"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}",
|
||||
"fields=name,paperCount,citationCount,hIndex",
|
||||
self.client.auth_header
|
||||
)
|
||||
return {
|
||||
'author_id': response.get('authorId'),
|
||||
'name': response.get('name'),
|
||||
'paper_count': response.get('paperCount'),
|
||||
'citation_count': response.get('citationCount'),
|
||||
'h_index': response.get('hIndex'),
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"获取作者详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_author_papers(self, author_id: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取作者的论文列表"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}/papers",
|
||||
f"fields={','.join(self.fields)}&limit={min(limit, 1000)}",
|
||||
self.client.auth_header
|
||||
)
|
||||
papers = response.get('data', [])
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"获取作者论文列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_paper_autocomplete(self, query: str) -> List[dict]:
|
||||
"""论文标题自动补全"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/autocomplete",
|
||||
f"query={query}",
|
||||
self.client.auth_header
|
||||
)
|
||||
suggestions = response.get('matches', [])
|
||||
return [
|
||||
{
|
||||
'title': suggestion.get('title'),
|
||||
'paper_id': suggestion.get('paperId'),
|
||||
'year': suggestion.get('year'),
|
||||
'venue': suggestion.get('venue'),
|
||||
}
|
||||
for suggestion in suggestions
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"获取标题自动补全时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_paper_data(self, paper) -> PaperMetadata:
|
||||
"""解析论文数据"""
|
||||
# 获取DOI
|
||||
doi = None
|
||||
external_ids = paper.get('externalIds', {}) if isinstance(paper, dict) else paper.externalIds
|
||||
if external_ids:
|
||||
if isinstance(external_ids, dict):
|
||||
doi = external_ids.get('DOI')
|
||||
if not doi and 'ArXiv' in external_ids:
|
||||
doi = f"10.48550/arXiv.{external_ids['ArXiv']}"
|
||||
else:
|
||||
doi = external_ids.DOI if hasattr(external_ids, 'DOI') else None
|
||||
if not doi and hasattr(external_ids, 'ArXiv'):
|
||||
doi = f"10.48550/arXiv.{external_ids.ArXiv}"
|
||||
|
||||
# 获取PDF URL
|
||||
pdf_url = None
|
||||
pdf_info = paper.get('openAccessPdf', {}) if isinstance(paper, dict) else paper.openAccessPdf
|
||||
if pdf_info:
|
||||
pdf_url = pdf_info.get('url') if isinstance(pdf_info, dict) else pdf_info.url
|
||||
|
||||
# 获取发表场所详细信息
|
||||
venue_type = None
|
||||
venue_name = None
|
||||
venue_info = {}
|
||||
|
||||
venue = paper.get('publicationVenue', {}) if isinstance(paper, dict) else paper.publicationVenue
|
||||
if venue:
|
||||
if isinstance(venue, dict):
|
||||
venue_name = venue.get('name')
|
||||
venue_type = venue.get('type')
|
||||
# 提取更多venue信息
|
||||
venue_info = {
|
||||
'issn': venue.get('issn'),
|
||||
'publisher': venue.get('publisher'),
|
||||
'url': venue.get('url'),
|
||||
'alternate_names': venue.get('alternate_names', [])
|
||||
}
|
||||
else:
|
||||
venue_name = venue.name if hasattr(venue, 'name') else None
|
||||
venue_type = venue.type if hasattr(venue, 'type') else None
|
||||
venue_info = {
|
||||
'issn': getattr(venue, 'issn', None),
|
||||
'publisher': getattr(venue, 'publisher', None),
|
||||
'url': getattr(venue, 'url', None),
|
||||
'alternate_names': getattr(venue, 'alternate_names', [])
|
||||
}
|
||||
|
||||
# 获取标题
|
||||
title = paper.get('title', '') if isinstance(paper, dict) else getattr(paper, 'title', '')
|
||||
|
||||
# 获取作者
|
||||
authors = paper.get('authors', []) if isinstance(paper, dict) else getattr(paper, 'authors', [])
|
||||
author_names = []
|
||||
for author in authors:
|
||||
if isinstance(author, dict):
|
||||
author_names.append(author.get('name', ''))
|
||||
else:
|
||||
author_names.append(author.name if hasattr(author, 'name') else str(author))
|
||||
|
||||
# 获取摘要
|
||||
abstract = paper.get('abstract', '') if isinstance(paper, dict) else getattr(paper, 'abstract', '')
|
||||
|
||||
# 获取年份
|
||||
year = paper.get('year') if isinstance(paper, dict) else getattr(paper, 'year', None)
|
||||
|
||||
# 获取引用次数
|
||||
citations = paper.get('citationCount') if isinstance(paper, dict) else getattr(paper, 'citationCount', None)
|
||||
|
||||
return PaperMetadata(
|
||||
title=title,
|
||||
authors=author_names,
|
||||
abstract=abstract,
|
||||
year=year,
|
||||
doi=doi,
|
||||
url=pdf_url or (f"https://doi.org/{doi}" if doi else None),
|
||||
citations=citations,
|
||||
venue=venue_name,
|
||||
institutions=[],
|
||||
venue_type=venue_type,
|
||||
venue_name=venue_name,
|
||||
venue_info=venue_info,
|
||||
source='semantic' # 添加来源标记
|
||||
)
|
||||
|
||||
async def example_usage():
|
||||
"""SemanticScholarSource使用示例"""
|
||||
semantic = SemanticScholarSource()
|
||||
|
||||
try:
|
||||
# 示例1:使用DOI直接获取论文
|
||||
print("\n=== 示例1:通过DOI获取论文 ===")
|
||||
doi = "10.18653/v1/N19-1423" # BERT论文
|
||||
print(f"获取DOI为 {doi} 的论文信息...")
|
||||
|
||||
paper = await semantic.get_paper_details(doi)
|
||||
if paper:
|
||||
print("\n--- 论文信息 ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"URL: {paper.url}")
|
||||
if paper.abstract:
|
||||
print(f"\n摘要:")
|
||||
print(paper.abstract)
|
||||
print(f"\n引用次数: {paper.citations}")
|
||||
print(f"发表venue: {paper.venue}")
|
||||
|
||||
# 示例2:搜索论文
|
||||
print("\n=== 示例2:搜索论文 ===")
|
||||
query = "BERT pre-training"
|
||||
print(f"搜索关键词 '{query}' 相关的论文...")
|
||||
papers = await semantic.search(query=query, limit=3)
|
||||
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 搜索结果 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
if paper.abstract:
|
||||
print(f"\n摘要:")
|
||||
print(paper.abstract)
|
||||
print(f"\nDOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例3:获取论文推荐
|
||||
print("\n=== 示例3:获取论文推荐 ===")
|
||||
print(f"获取与论文 {doi} 相关的推荐论文...")
|
||||
recommendations = await semantic.get_recommended_papers(doi, limit=3)
|
||||
for i, paper in enumerate(recommendations, 1):
|
||||
print(f"\n--- 推荐论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例4:基于多篇论文的推荐
|
||||
print("\n=== 示例4:基于多篇论文的推荐 ===")
|
||||
positive_dois = ["10.18653/v1/N19-1423", "10.18653/v1/P19-1285"]
|
||||
print(f"基于 {len(positive_dois)} 篇论文获取推荐...")
|
||||
multi_recommendations = await semantic.get_recommended_papers_from_lists(
|
||||
positive_dois=positive_dois,
|
||||
limit=3
|
||||
)
|
||||
for i, paper in enumerate(multi_recommendations, 1):
|
||||
print(f"\n--- 推荐论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
|
||||
# 示例5:搜索作者
|
||||
print("\n=== 示例5:搜索作者 ===")
|
||||
author_query = "Yann LeCun"
|
||||
print(f"搜索作者: '{author_query}'")
|
||||
authors = await semantic.search_author(author_query, limit=3)
|
||||
for i, author in enumerate(authors, 1):
|
||||
print(f"\n--- 作者 {i} ---")
|
||||
print(f"姓名: {author['name']}")
|
||||
print(f"论文数量: {author['paper_count']}")
|
||||
print(f"总引用次数: {author['citation_count']}")
|
||||
|
||||
# 示例6:获取作者详情
|
||||
print("\n=== 示例6:获取作者详情 ===")
|
||||
if authors: # 使用第一个搜索结果的作者ID
|
||||
author_id = authors[0]['author_id']
|
||||
print(f"获取作者ID {author_id} 的详细信息...")
|
||||
author_details = await semantic.get_author_details(author_id)
|
||||
if author_details:
|
||||
print(f"姓名: {author_details['name']}")
|
||||
print(f"H指数: {author_details['h_index']}")
|
||||
print(f"总引用次数: {author_details['citation_count']}")
|
||||
print(f"发表论文数: {author_details['paper_count']}")
|
||||
|
||||
# 示例7:获取作者论文
|
||||
print("\n=== 示例7:获取作者论文 ===")
|
||||
if authors: # 使用第一个搜索结果的作者ID
|
||||
author_id = authors[0]['author_id']
|
||||
print(f"获取作者 {authors[0]['name']} 的论文列表...")
|
||||
author_papers = await semantic.get_author_papers(author_id, limit=3)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例8:论文标题自动补全
|
||||
print("\n=== 示例8:论文标题自动补全 ===")
|
||||
title_query = "Attention is all"
|
||||
print(f"搜索标题: '{title_query}'")
|
||||
suggestions = await semantic.get_paper_autocomplete(title_query)
|
||||
for i, suggestion in enumerate(suggestions[:3], 1):
|
||||
print(f"\n--- 建议 {i} ---")
|
||||
print(f"标题: {suggestion['title']}")
|
||||
print(f"发表年份: {suggestion['year']}")
|
||||
print(f"发表venue: {suggestion['venue']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(example_usage())
|
||||
46
crazy_functions/review_fns/data_sources/unpaywall_source.py
Normal file
46
crazy_functions/review_fns/data_sources/unpaywall_source.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import aiohttp
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
|
||||
class UnpaywallSource(DataSource):
|
||||
"""Unpaywall API实现"""
|
||||
|
||||
def _initialize(self) -> None:
|
||||
self.base_url = "https://api.unpaywall.org/v2"
|
||||
self.email = self.api_key # Unpaywall使用email作为API key
|
||||
|
||||
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/search",
|
||||
params={
|
||||
"query": query,
|
||||
"email": self.email,
|
||||
"limit": limit
|
||||
}
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return [self._parse_response(item.response)
|
||||
for item in data.get("results", [])]
|
||||
|
||||
def _parse_response(self, data: Dict) -> PaperMetadata:
|
||||
"""解析Unpaywall返回的数据"""
|
||||
return PaperMetadata(
|
||||
title=data.get("title", ""),
|
||||
authors=[
|
||||
f"{author.get('given', '')} {author.get('family', '')}"
|
||||
for author in data.get("z_authors", [])
|
||||
],
|
||||
institutions=[
|
||||
aff.get("name", "")
|
||||
for author in data.get("z_authors", [])
|
||||
for aff in author.get("affiliation", [])
|
||||
],
|
||||
abstract="", # Unpaywall不提供摘要
|
||||
year=data.get("year"),
|
||||
doi=data.get("doi"),
|
||||
url=data.get("doi_url"),
|
||||
citations=None, # Unpaywall不提供引用计数
|
||||
venue=data.get("journal_name")
|
||||
)
|
||||
Reference in New Issue
Block a user