Files
gpt_academic/crazy_functions/review_fns/data_sources/crossref_source.py
binary-husky 8042750d41 Master 4.0 (#2210)
* stage academic conversation

* stage document conversation

* fix buggy gradio version

* file dynamic load

* merge more academic plugins

* accelerate nltk

* feat: 为predict函数添加文件和URL读取功能
- 添加URL检测和网页内容提取功能,支持自动提取网页文本
- 添加文件路径识别和文件内容读取功能,支持private_upload路径格式
- 集成WebTextExtractor处理网页内容提取
- 集成TextContentLoader处理本地文件读取
- 支持文件路径与问题组合的智能处理

* back

* block unstable

---------

Co-authored-by: XiaoBoAI <liuboyin2019@ia.ac.cn>
2025-08-23 15:59:22 +08:00

400 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import aiohttp
from typing import List, Dict, Optional
from datetime import datetime
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import random
class CrossrefSource(DataSource):
    """Data source backed by the public Crossref REST API.

    All network calls go through ``aiohttp`` and report failures by printing
    a message and returning an empty result (``[]`` or ``None``) instead of
    raising, so callers can treat this source as best-effort.
    """

    # Pool of contact addresses advertised in the User-Agent header.
    CONTACT_EMAILS = [
        "gpt_abc_academic@163.com",
        "gpt_abc_newapi@163.com",
        "gpt_abc_academic_pwd@163.com"
    ]

    # Work fields requested through ``select``. ``published-online`` and
    # ``issued`` are requested in addition to ``published-print`` so that a
    # publication year can still be resolved for online-only works.
    _WORK_FIELDS = (
        "DOI,title,author,published-print,published-online,issued,"
        "abstract,reference,container-title,is-referenced-by-count,type,"
        "publisher,ISSN,ISBN,issue,volume,page"
    )
    _CITATION_FIELDS = (
        "DOI,title,author,published-print,published-online,issued,abstract"
    )
    # Date fields consulted, in priority order, when resolving the year.
    _DATE_FIELDS = ("published-print", "published-online", "issued")

    def _initialize(self) -> None:
        """Set the base URL and the default request headers."""
        self.base_url = "https://api.crossref.org"
        # Pick one of the contact addresses at random for the User-Agent.
        contact_email = random.choice(self.CONTACT_EMAILS)
        self.headers = {
            "Accept": "application/json",
            "User-Agent": f"Mozilla/5.0 (compatible; PythonScript/1.0; mailto:{contact_email})",
        }
        if self.api_key:
            self.headers["Crossref-Plus-API-Token"] = f"Bearer {self.api_key}"

    async def search(
        self,
        query: str,
        limit: int = 100,
        sort_by: Optional[str] = None,
        sort_order: Optional[str] = None,
        start_year: Optional[int] = None
    ) -> List[PaperMetadata]:
        """Search works by free-text query.

        Args:
            query: Free-text search terms.
            limit: Maximum number of papers to return.
            sort_by: Crossref sort field, passed through as ``sort``.
            sort_order: Sort direction, passed through as ``order``.
            start_year: If given, only works published from this year on.

        Returns:
            Up to ``limit`` papers that have a non-blank abstract; an empty
            list when the request fails or nothing matches.
        """
        filters = [f"from-pub-date:{start_year}"] if start_year else []
        return await self._search_works(
            {"query": query}, limit, sort_by, sort_order, filters
        )

    async def _search_works(
        self,
        query_params: Dict[str, str],
        limit: int,
        sort_by: Optional[str] = None,
        sort_order: Optional[str] = None,
        filters: Optional[List[str]] = None
    ) -> List[PaperMetadata]:
        """Query ``/works`` and keep only results that carry an abstract.

        Args:
            query_params: Query parameters such as ``{"query": ...}`` or
                ``{"query.author": ...}``; may be empty for filter-only
                requests.
            limit: Maximum number of papers to return.
            sort_by: Optional value for the ``sort`` parameter.
            sort_order: Optional value for the ``order`` parameter.
            filters: Optional ``name:value`` filter expressions; they are
                joined with commas into the ``filter`` parameter.

        Returns:
            A list of at most ``limit`` parsed papers.
        """
        if limit <= 0:
            return []

        params = dict(query_params)
        # Over-fetch to compensate for works dropped for lacking an abstract;
        # capped at 1000 rows to keep the request size bounded.
        params["rows"] = min(limit * 3, 1000)
        params["select"] = self._WORK_FIELDS
        if filters:
            params["filter"] = ",".join(filters)
        if sort_by:
            params["sort"] = sort_by
        if sort_order:
            params["order"] = sort_order

        async with aiohttp.ClientSession(headers=self.headers) as session:
            async with session.get(
                f"{self.base_url}/works",
                params=params
            ) as response:
                if response.status != 200:
                    print(f"API请求失败: HTTP {response.status}")
                    print(f"响应内容: {await response.text()}")
                    return []
                # Guard decoding the same way the other endpoints do, so a
                # malformed body yields an empty result instead of a crash.
                try:
                    data = await response.json()
                    items = data.get("message", {}).get("items", [])
                except Exception as e:
                    print(f"解析搜索结果时发生错误: {str(e)}")
                    return []
                if not items:
                    print("未找到相关论文")
                    return []

                papers = []
                filtered_count = 0
                for work in items:
                    paper = self._parse_work(work)
                    if paper.abstract and paper.abstract.strip():
                        papers.append(paper)
                        if len(papers) >= limit:
                            # Stop once the caller's original limit is met.
                            break
                    else:
                        filtered_count += 1
                print(f"找到 {len(items)} 篇相关论文,其中 {filtered_count} 篇因缺少摘要被过滤")
                print(f"返回 {len(papers)} 篇包含摘要的论文")
                return papers

    async def get_paper_details(self, doi: str) -> Optional[PaperMetadata]:
        """Fetch the metadata of a single work by DOI.

        Returns:
            The parsed paper, or ``None`` when the request fails or the
            response cannot be parsed.
        """
        async with aiohttp.ClientSession(headers=self.headers) as session:
            async with session.get(
                f"{self.base_url}/works/{doi}",
                params={"select": self._WORK_FIELDS}
            ) as response:
                if response.status != 200:
                    print(f"获取论文详情失败: HTTP {response.status}")
                    print(f"响应内容: {await response.text()}")
                    return None
                try:
                    data = await response.json()
                    return self._parse_work(data.get("message", {}))
                except Exception as e:
                    print(f"解析论文详情时发生错误: {str(e)}")
                    return None

    async def get_references(self, doi: str) -> List[PaperMetadata]:
        """Fetch the reference list of the work identified by ``doi``.

        Returns:
            One ``PaperMetadata`` per reference entry (abstracts are left
            empty); an empty list on failure or when there are no references.
        """
        async with aiohttp.ClientSession(headers=self.headers) as session:
            async with session.get(
                f"{self.base_url}/works/{doi}",
                params={"select": "reference"}
            ) as response:
                if response.status != 200:
                    print(f"获取参考文献失败: HTTP {response.status}")
                    return []
                try:
                    data = await response.json()
                    if not isinstance(data, dict):
                        print(f"API返回了意外的数据格式: {type(data)}")
                        return []
                    references = data.get("message", {}).get("reference", [])
                    if not references:
                        print("未找到参考文献")
                        return []
                    return [
                        PaperMetadata(
                            title=ref.get("article-title", ""),
                            # An absent author yields an empty list rather
                            # than a list holding one empty string.
                            authors=[ref["author"]] if ref.get("author") else [],
                            year=ref.get("year"),
                            doi=ref.get("DOI"),
                            url=f"https://doi.org/{ref.get('DOI')}" if ref.get("DOI") else None,
                            abstract="",
                            citations=None,
                            venue=ref.get("journal-title", ""),
                            institutions=[]
                        )
                        for ref in references
                    ]
                except Exception as e:
                    print(f"解析参考文献数据时发生错误: {str(e)}")
                    return []

    async def get_citations(self, doi: str) -> List[PaperMetadata]:
        """Fetch works that cite the work identified by ``doi``.

        Returns:
            The parsed citing works, or an empty list on failure.
        """
        # NOTE(review): confirm that ``reference.DOI`` is a filter the
        # /works endpoint accepts; an unsupported filter would surface here
        # as a non-200 status and an empty result.
        async with aiohttp.ClientSession(headers=self.headers) as session:
            async with session.get(
                f"{self.base_url}/works",
                params={
                    "filter": f"reference.DOI:{doi}",
                    "select": self._CITATION_FIELDS
                }
            ) as response:
                if response.status != 200:
                    print(f"获取引用信息失败: HTTP {response.status}")
                    print(f"响应内容: {await response.text()}")
                    return []
                try:
                    data = await response.json()
                    if not isinstance(data, dict):
                        print(f"API返回了意外的数据格式: {type(data)}")
                        return []
                    items = data.get("message", {}).get("items", [])
                    return [self._parse_work(work) for work in items]
                except Exception as e:
                    print(f"解析引用数据时发生错误: {str(e)}")
                    return []

    @staticmethod
    def _first(values, default=None):
        """Return the first element of ``values``, or ``default`` if there is none.

        A bare string is returned as-is so it is not mistaken for a sequence
        of characters.
        """
        if isinstance(values, str):
            return values or default
        if isinstance(values, (list, tuple)) and values:
            return values[0]
        return default

    @staticmethod
    def _author_name(author: Dict) -> str:
        """Build a display name from an author record.

        Uses "given family" when available and falls back to ``name`` for
        records that carry neither part.
        """
        parts = [author.get("given") or "", author.get("family") or ""]
        full_name = " ".join(part for part in parts if part)
        return full_name or (author.get("name") or "")

    @classmethod
    def _extract_year(cls, work: Dict):
        """Return the first available year among the known date fields."""
        for field in cls._DATE_FIELDS:
            date_parts = (work.get(field) or {}).get("date-parts")
            year = cls._first(cls._first(date_parts))
            if year is not None:
                return year
        return None

    def _parse_work(self, work: Dict) -> PaperMetadata:
        """Convert one Crossref work record into a ``PaperMetadata``.

        Missing or empty fields (including empty ``title`` and
        ``container-title`` lists) are tolerated and map to empty values.
        """
        title = self._first(work.get("title")) or ""

        # The abstract is handled both as a plain string and as a mapping
        # holding the text under ``value``.
        raw_abstract = work.get("abstract")
        if isinstance(raw_abstract, str):
            abstract = raw_abstract
        elif isinstance(raw_abstract, dict):
            abstract = raw_abstract.get("value") or ""
        else:
            abstract = ""
        if not abstract:
            print(f"警告: 论文 '{title}' 没有可用的摘要")

        author_records = work.get("author") or []

        # Collect affiliation names once each, in first-seen order.
        institutions = []
        for author in author_records:
            for affiliation in author.get("affiliation") or []:
                name = affiliation.get("name")
                if name and name not in institutions:
                    institutions.append(name)

        author_names = [self._author_name(author) for author in author_records]

        venue_name = self._first(work.get("container-title"))
        venue_type = work.get("type", "unknown")
        venue_info = {
            "publisher": work.get("publisher"),
            "issn": work.get("ISSN", []),
            "isbn": work.get("ISBN", []),
            "issue": work.get("issue"),
            "volume": work.get("volume"),
            "page": work.get("page")
        }
        doi = work.get("DOI")
        return PaperMetadata(
            title=title,
            authors=[name for name in author_names if name],
            institutions=institutions,
            abstract=abstract,
            year=self._extract_year(work),
            doi=doi,
            url=f"https://doi.org/{doi}" if doi else None,
            citations=work.get("is-referenced-by-count"),
            venue=venue_name,
            venue_type=venue_type,
            venue_name=venue_name,
            venue_info=venue_info,
            source='crossref'
        )

    async def search_by_authors(
        self,
        authors: List[str],
        limit: int = 100,
        sort_by: Optional[str] = None,
        start_year: Optional[int] = None
    ) -> List[PaperMetadata]:
        """Search works by author name.

        The names are sent through the ``query.author`` field query rather
        than embedded in the free-text ``query``.

        Args:
            authors: Author names to look for.
            limit: Maximum number of papers to return.
            sort_by: Optional value for the ``sort`` parameter.
            start_year: If given, only works published from this year on.

        Returns:
            Matching papers with abstracts; an empty list when no usable
            author name is supplied.
        """
        names = [name.strip() for name in authors if name and name.strip()]
        if not names:
            return []
        filters = [f"from-pub-date:{start_year}"] if start_year else []
        return await self._search_works(
            {"query.author": " ".join(names)}, limit, sort_by, None, filters
        )

    async def search_by_date_range(
        self,
        start_date: datetime,
        end_date: datetime,
        limit: int = 100,
        sort_by: Optional[str] = None,
        sort_order: Optional[str] = None
    ) -> List[PaperMetadata]:
        """Search works published between ``start_date`` and ``end_date``.

        The bounds are sent as ``from-pub-date`` / ``until-pub-date``
        filters rather than as free-text query terms.

        Args:
            start_date: Earliest publication date.
            end_date: Latest publication date.
            limit: Maximum number of papers to return.
            sort_by: Optional value for the ``sort`` parameter.
            sort_order: Optional value for the ``order`` parameter.
        """
        filters = [
            f"from-pub-date:{start_date.strftime('%Y-%m-%d')}",
            f"until-pub-date:{end_date.strftime('%Y-%m-%d')}",
        ]
        return await self._search_works({}, limit, sort_by, sort_order, filters)
async def example_usage():
    """Run a tour of the CrossrefSource API and print the results.

    Every example performs live HTTP requests. Any exception is caught,
    printed together with its traceback, and not re-raised.
    """
    crossref = CrossrefSource(api_key=None)
    try:
        # Example 1: basic search with explicit sorting.
        print("\n=== 示例1搜索最新的机器学习论文 ===")
        papers = await crossref.search(
            query="machine learning",
            limit=3,
            sort_by="published",
            sort_order="desc",
            start_year=2023
        )
        for i, found in enumerate(papers, 1):
            print(f"\n--- 论文 {i} ---")
            print(f"标题: {found.title}")
            print(f"作者: {', '.join(found.authors)}")
            print(f"发表年份: {found.year}")
            print(f"DOI: {found.doi}")
            print(f"URL: {found.url}")
            if found.abstract:
                print(f"摘要: {found.abstract[:200]}...")
            if found.institutions:
                print(f"机构: {', '.join(found.institutions)}")
            print(f"引用次数: {found.citations}")
            print(f"发表venue: {found.venue}")
            print(f"venue类型: {found.venue_type}")
            if found.venue_info:
                print("Venue详细信息:")
                for key, value in found.venue_info.items():
                    if value:
                        print(f" - {key}: {value}")

        # Example 2: fetch one paper by DOI (the BERT paper).
        print("\n=== 示例2获取特定论文详情 ===")
        doi = "10.18653/v1/N19-1423"
        # Kept under its own name: example 5 needs this exact paper, and the
        # loops in between must not overwrite it.
        bert_paper = await crossref.get_paper_details(doi)
        if bert_paper:
            print(f"标题: {bert_paper.title}")
            print(f"作者: {', '.join(bert_paper.authors)}")
            print(f"发表年份: {bert_paper.year}")
            print(f"DOI: {bert_paper.doi}")
            if bert_paper.abstract:
                print(f"摘要: {bert_paper.abstract[:200]}...")
            print(f"引用次数: {bert_paper.citations}")

        # Example 3: search by author.
        print("\n=== 示例3搜索特定作者的论文 ===")
        author_papers = await crossref.search_by_authors(
            authors=["Yoshua Bengio"],
            limit=3,
            sort_by="published",
            start_year=2020
        )
        for i, author_paper in enumerate(author_papers, 1):
            print(f"\n--- {i}. {author_paper.title} ---")
            print(f"作者: {', '.join(author_paper.authors)}")
            print(f"发表年份: {author_paper.year}")
            print(f"DOI: {author_paper.doi}")
            print(f"引用次数: {author_paper.citations}")

        # Example 4: search within a date range (the last 30 days).
        print("\n=== 示例4搜索特定日期范围的论文 ===")
        from datetime import datetime, timedelta
        end_date = datetime.now()
        start_date = end_date - timedelta(days=30)
        recent_papers = await crossref.search_by_date_range(
            start_date=start_date,
            end_date=end_date,
            limit=3,
            sort_by="published",
            sort_order="desc"
        )
        for i, recent_paper in enumerate(recent_papers, 1):
            print(f"\n--- 最近发表的论文 {i} ---")
            print(f"标题: {recent_paper.title}")
            print(f"作者: {', '.join(recent_paper.authors)}")
            print(f"发表年份: {recent_paper.year}")
            print(f"DOI: {recent_paper.doi}")

        # Example 5: citations and references of the BERT paper from example 2.
        print("\n=== 示例5获取论文引用信息 ===")
        if bert_paper:
            print("\n获取引用该论文的文献:")
            citations = await crossref.get_citations(bert_paper.doi)
            for i, citing_paper in enumerate(citations[:3], 1):
                print(f"\n--- 引用论文 {i} ---")
                print(f"标题: {citing_paper.title}")
                print(f"作者: {', '.join(citing_paper.authors)}")
                print(f"发表年份: {citing_paper.year}")
            print("\n获取该论文引用的参考文献:")
            references = await crossref.get_references(bert_paper.doi)
            for i, ref_paper in enumerate(references[:3], 1):
                print(f"\n--- 参考文献 {i} ---")
                print(f"标题: {ref_paper.title}")
                print(f"作者: {', '.join(ref_paper.authors)}")
                print(f"发表年份: {ref_paper.year if ref_paper.year else '未知'}")

        # Example 6: venue details of the first search hit from example 1.
        print("\n=== 示例6展示期刊/会议详细信息 ===")
        if papers:
            first_paper = papers[0]
            print(f"文献类型: {first_paper.venue_type}")
            print(f"发表venue: {first_paper.venue_name}")
            if first_paper.venue_info:
                print("Venue详细信息:")
                for key, value in first_paper.venue_info.items():
                    if value:
                        print(f" - {key}: {value}")
    except Exception as e:
        print(f"发生错误: {str(e)}")
        import traceback
        print(traceback.format_exc())
if __name__ == "__main__":
    # Script entry point: drive the async demo to completion on a fresh
    # event loop.
    import asyncio

    asyncio.run(example_usage())