Compare commits
14 Commits
2706263a4b
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
87542a0ed7 | ||
| 8bb8b9a2f4 | |||
|
|
516f5af00a | ||
| 68c0cafffe | |||
|
|
99d80ba61a | ||
| 54d0df9627 | |||
|
|
7d70566402 | ||
|
|
3546f8c1c4 | ||
|
|
d8ad675f13 | ||
|
|
0ab0417954 | ||
|
|
248b0aefae | ||
|
|
661fe63941 | ||
|
|
269804fb82 | ||
|
|
8042750d41 |
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
||||
.venv
|
||||
.github
|
||||
.vscode
|
||||
gpt_log
|
||||
tests
|
||||
README.md
|
||||
44
.github/workflows/build-with-chatglm.yml
vendored
44
.github/workflows/build-with-chatglm.yml
vendored
@@ -1,44 +0,0 @@
|
||||
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
|
||||
name: build-with-chatglm
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'master'
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}_chatglm_moss
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
file: docs/GithubAction+ChatGLM+Moss
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -163,3 +163,5 @@ objdump*
|
||||
TODO
|
||||
experimental_mods
|
||||
search_results
|
||||
gg.docx
|
||||
unstructured_reader.py
|
||||
|
||||
@@ -11,7 +11,7 @@ RUN echo '[global]' > /etc/pip.conf && \
|
||||
echo 'index-url = https://mirrors.aliyun.com/pypi/simple/' >> /etc/pip.conf && \
|
||||
echo 'trusted-host = mirrors.aliyun.com' >> /etc/pip.conf
|
||||
|
||||
# 语音输出功能(以下1,2行更换阿里源,第3,4行安装ffmpeg,都可以删除)
|
||||
# 语音输出功能(以下1,2行更换阿里源,第3,4行安装ffmpeg,都可以删除)
|
||||
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources && \
|
||||
sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources && \
|
||||
apt-get update
|
||||
@@ -29,10 +29,12 @@ RUN python -c 'import loguru'
|
||||
|
||||
# 装载项目文件,安装剩余依赖(必要)
|
||||
COPY . .
|
||||
RUN uv venv --python=3.12 && uv pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
|
||||
RUN uv pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
|
||||
|
||||
# # 非必要步骤,用于预热模块(可以删除)
|
||||
RUN python -c 'from check_proxy import warm_up_modules; warm_up_modules()'
|
||||
|
||||
ENV CGO_ENABLED=0
|
||||
|
||||
# 启动(必要)
|
||||
CMD ["bash", "-c", "python main.py"]
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
> [!IMPORTANT]
|
||||
> `master主分支`最新动态(2025.7.31): 新GUI前端,Coming Soon
|
||||
> `master主分支`最新动态(2025.3.2): 修复大量代码typo / 联网组件支持Jina的api / 增加deepseek-r1支持
|
||||
> `frontier开发分支`最新动态(2024.12.9): 更新对话时间线功能,优化xelatex论文翻译
|
||||
> `wiki文档`最新动态(2024.12.5): 更新ollama接入指南
|
||||
> `master主分支`最新动态(2025.8.23): Dockerfile构建效率大幅优化
|
||||
> `master主分支`最新动态(2025.7.31): 新GUI前端,Coming Soon
|
||||
>
|
||||
> 2025.2.2: 三分钟快速接入最强qwen2.5-max[视频](https://www.bilibili.com/video/BV1LeFuerEG4)
|
||||
> 2025.2.1: 支持自定义字体
|
||||
@@ -65,7 +63,6 @@ Read this in [English](docs/README.English.md) | [日本語](docs/README.Japanes
|
||||
⭐支持mermaid图像渲染 | 支持让GPT生成[流程图](https://www.bilibili.com/video/BV18c41147H9/)、状态转移图、甘特图、饼状图、GitGraph等等(3.7版本)
|
||||
⭐Arxiv论文精细翻译 ([Docker](https://github.com/binary-husky/gpt_academic/pkgs/container/gpt_academic_with_latex)) | [插件] 一键[以超高质量翻译arxiv论文](https://www.bilibili.com/video/BV1dz4y1v77A/),目前最好的论文翻译工具
|
||||
⭐[实时语音对话输入](https://github.com/binary-husky/gpt_academic/blob/master/docs/use_audio.md) | [插件] 异步[监听音频](https://www.bilibili.com/video/BV1AV4y187Uy/),自动断句,自动寻找回答时机
|
||||
⭐AutoGen多智能体插件 | [插件] 借助微软AutoGen,探索多Agent的智能涌现可能!
|
||||
⭐虚空终端插件 | [插件] 能够使用自然语言直接调度本项目其他插件
|
||||
润色、翻译、代码解释 | 一键润色、翻译、查找论文语法错误、解释代码
|
||||
[自定义快捷键](https://www.bilibili.com/video/BV14s4y1E7jN) | 支持自定义快捷键
|
||||
|
||||
@@ -230,6 +230,48 @@ def warm_up_modules():
|
||||
enc.encode("模块预热", disallowed_special=())
|
||||
enc = model_info["gpt-4"]['tokenizer']
|
||||
enc.encode("模块预热", disallowed_special=())
|
||||
try_warm_up_vectordb()
|
||||
|
||||
|
||||
# def try_warm_up_vectordb():
|
||||
# try:
|
||||
# import os
|
||||
# import nltk
|
||||
# target = os.path.expanduser('~/nltk_data')
|
||||
# logger.info(f'模块预热: nltk punkt (从Github下载部分文件到 {target})')
|
||||
# nltk.data.path.append(target)
|
||||
# nltk.download('punkt', download_dir=target)
|
||||
# logger.info('模块预热完成: nltk punkt')
|
||||
# except:
|
||||
# logger.exception('模块预热: nltk punkt 失败,可能需要手动安装 nltk punkt')
|
||||
# logger.error('模块预热: nltk punkt 失败,可能需要手动安装 nltk punkt')
|
||||
|
||||
|
||||
def try_warm_up_vectordb():
|
||||
import os
|
||||
import nltk
|
||||
target = os.path.expanduser('~/nltk_data')
|
||||
nltk.data.path.append(target)
|
||||
try:
|
||||
# 尝试加载 punkt
|
||||
logger.info(f'nltk模块预热')
|
||||
nltk.data.find('tokenizers/punkt')
|
||||
nltk.data.find('tokenizers/punkt_tab')
|
||||
nltk.data.find('taggers/averaged_perceptron_tagger_eng')
|
||||
logger.info('nltk模块预热完成(读取本地缓存)')
|
||||
except:
|
||||
# 如果找不到,则尝试下载
|
||||
try:
|
||||
logger.info(f'模块预热: nltk punkt (从 Github 下载部分文件到 {target})')
|
||||
from shared_utils.nltk_downloader import Downloader
|
||||
_downloader = Downloader()
|
||||
_downloader.download('punkt', download_dir=target)
|
||||
_downloader.download('punkt_tab', download_dir=target)
|
||||
_downloader.download('averaged_perceptron_tagger_eng', download_dir=target)
|
||||
logger.info('nltk模块预热完成')
|
||||
except Exception:
|
||||
logger.exception('模块预热: nltk punkt 失败,可能需要手动安装 nltk punkt')
|
||||
|
||||
|
||||
def warm_up_vectordb():
|
||||
"""
|
||||
|
||||
56
config.py
56
config.py
@@ -8,42 +8,33 @@
|
||||
"""
|
||||
|
||||
# [step 1-1]>> ( 接入OpenAI模型家族 ) API_KEY = "sk-123456789xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx123456789"。极少数情况下,还需要填写组织(格式如org-123456789abcdefghijklmno的),请向下翻,找 API_ORG 设置项
|
||||
API_KEY = "在此处填写APIKEY" # 可同时填写多个API-KEY,用英文逗号分割,例如API_KEY = "sk-openaikey1,sk-openaikey2,fkxxxx-api2dkey3,azure-apikey4"
|
||||
API_KEY = "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" # 可同时填写多个API-KEY,用英文逗号分割,例如API_KEY = "sk-openaikey1,sk-openaikey2,fkxxxx-api2dkey3,azure-apikey4"
|
||||
|
||||
# [step 1-2]>> ( 强烈推荐!接入通义家族 & 大模型服务平台百炼 ) 接入通义千问在线大模型,api-key获取地址 https://dashscope.console.aliyun.com/
|
||||
DASHSCOPE_API_KEY = "" # 阿里灵积云API_KEY(用于接入qwen-max,dashscope-qwen3-14b,dashscope-deepseek-r1等)
|
||||
|
||||
# [step 1-3]>> ( 接入 deepseek-reasoner, 即 deepseek-r1 ) 深度求索(DeepSeek) API KEY,默认请求地址为"https://api.deepseek.com/v1/chat/completions"
|
||||
DEEPSEEK_API_KEY = ""
|
||||
DEEPSEEK_API_KEY = "sk-d99b8cc6b7414cc88a5d950a3ff7585e"
|
||||
|
||||
# [step 2]>> 改为True应用代理。如果使用本地或无地域限制的大模型时,此处不修改;如果直接在海外服务器部署,此处不修改
|
||||
USE_PROXY = False
|
||||
if USE_PROXY:
|
||||
"""
|
||||
代理网络的地址,打开你的代理软件查看代理协议(socks5h / http)、地址(localhost)和端口(11284)
|
||||
填写格式是 [协议]:// [地址] :[端口],填写之前不要忘记把USE_PROXY改成True,如果直接在海外服务器部署,此处不修改
|
||||
<配置教程&视频教程> https://github.com/binary-husky/gpt_academic/issues/1>
|
||||
[协议] 常见协议无非socks5h/http; 例如 v2**y 和 ss* 的默认本地协议是socks5h; 而cl**h 的默认本地协议是http
|
||||
[地址] 填localhost或者127.0.0.1(localhost意思是代理软件安装在本机上)
|
||||
[端口] 在代理软件的设置里找。虽然不同的代理软件界面不一样,但端口号都应该在最显眼的位置上
|
||||
"""
|
||||
proxies = {
|
||||
# [协议]:// [地址] :[端口]
|
||||
"http": "socks5h://localhost:11284", # 再例如 "http": "http://127.0.0.1:7890",
|
||||
"https": "socks5h://localhost:11284", # 再例如 "https": "http://127.0.0.1:7890",
|
||||
"http":"socks5h://192.168.8.9:1070", # 再例如 "http": "http://127.0.0.1:7890",
|
||||
"https":"socks5h://192.168.8.9:1070", # 再例如 "https": "http://127.0.0.1:7890",
|
||||
}
|
||||
else:
|
||||
proxies = None
|
||||
|
||||
# [step 3]>> 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
|
||||
LLM_MODEL = "gpt-3.5-turbo-16k" # 可选 ↓↓↓
|
||||
LLM_MODEL = "gpt-4" # 可选 ↓↓↓
|
||||
AVAIL_LLM_MODELS = ["qwen-max", "o1-mini", "o1-mini-2024-09-12", "o1", "o1-2024-12-17", "o1-preview", "o1-preview-2024-09-12",
|
||||
"gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-preview",
|
||||
"gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
||||
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
|
||||
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
|
||||
"gemini-1.5-pro", "chatglm3", "chatglm4",
|
||||
"deepseek-chat", "deepseek-coder", "deepseek-reasoner",
|
||||
"deepseek-chat", "deepseek-coder", "deepseek-reasoner",
|
||||
"volcengine-deepseek-r1-250120", "volcengine-deepseek-v3-241226",
|
||||
"dashscope-deepseek-r1", "dashscope-deepseek-v3",
|
||||
"dashscope-qwen3-14b", "dashscope-qwen3-235b-a22b", "dashscope-qwen3-32b",
|
||||
@@ -94,19 +85,19 @@ AVAIL_THEMES = ["Default", "Chuanhu-Small-and-Beautiful", "High-Contrast", "Gsta
|
||||
|
||||
FONT = "Theme-Default-Font"
|
||||
AVAIL_FONTS = [
|
||||
"默认值(Theme-Default-Font)",
|
||||
"宋体(SimSun)",
|
||||
"黑体(SimHei)",
|
||||
"楷体(KaiTi)",
|
||||
"仿宋(FangSong)",
|
||||
"默认值(Theme-Default-Font)",
|
||||
"宋体(SimSun)",
|
||||
"黑体(SimHei)",
|
||||
"楷体(KaiTi)",
|
||||
"仿宋(FangSong)",
|
||||
"华文细黑(STHeiti Light)",
|
||||
"华文楷体(STKaiti)",
|
||||
"华文仿宋(STFangsong)",
|
||||
"华文宋体(STSong)",
|
||||
"华文中宋(STZhongsong)",
|
||||
"华文新魏(STXinwei)",
|
||||
"华文隶书(STLiti)",
|
||||
# 备注:以下字体需要网络支持,您可以自定义任意您喜欢的字体,如下所示,需要满足的格式为 "字体昵称(字体英文真名@字体css下载链接)"
|
||||
"华文楷体(STKaiti)",
|
||||
"华文仿宋(STFangsong)",
|
||||
"华文宋体(STSong)",
|
||||
"华文中宋(STZhongsong)",
|
||||
"华文新魏(STXinwei)",
|
||||
"华文隶书(STLiti)",
|
||||
# 备注:以下字体需要网络支持,您可以自定义任意您喜欢的字体,如下所示,需要满足的格式为 "字体昵称(字体英文真名@字体css下载链接)"
|
||||
"思源宋体(Source Han Serif CN VF@https://chinese-fonts-cdn.deno.dev/packages/syst/dist/SourceHanSerifCN/result.css)",
|
||||
"月星楷(Moon Stars Kai HW@https://chinese-fonts-cdn.deno.dev/packages/moon-stars-kai/dist/MoonStarsKaiHW-Regular/result.css)",
|
||||
"珠圆体(MaokenZhuyuanTi@https://chinese-fonts-cdn.deno.dev/packages/mkzyt/dist/猫啃珠圆体/result.css)",
|
||||
@@ -143,15 +134,14 @@ TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
# 网页的端口, -1代表随机端口
|
||||
WEB_PORT = -1
|
||||
|
||||
WEB_PORT = 19998
|
||||
|
||||
# 是否自动打开浏览器页面
|
||||
AUTO_OPEN_BROWSER = True
|
||||
|
||||
|
||||
# 如果OpenAI不响应(网络卡顿、代理失败、KEY失效),重试的次数限制
|
||||
MAX_RETRY = 2
|
||||
MAX_RETRY = 3
|
||||
|
||||
|
||||
# 插件分类默认选项
|
||||
@@ -195,7 +185,7 @@ AUTO_CLEAR_TXT = False
|
||||
|
||||
|
||||
# 加一个live2d装饰
|
||||
ADD_WAIFU = False
|
||||
ADD_WAIFU = True
|
||||
|
||||
|
||||
# 设置用户名和密码(不需要修改)(相关功能不稳定,与gradio版本和网络都相关,如果本地使用不建议加这个)
|
||||
@@ -355,6 +345,10 @@ DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in ran
|
||||
JINA_API_KEY = ""
|
||||
|
||||
|
||||
# SEMANTIC SCHOLAR API KEY
|
||||
SEMANTIC_SCHOLAR_KEY = ""
|
||||
|
||||
|
||||
# 是否自动裁剪上下文长度(是否启动,默认不启动)
|
||||
AUTO_CONTEXT_CLIP_ENABLE = False
|
||||
# 目标裁剪上下文的token长度(如果超过这个长度,则会自动裁剪)
|
||||
|
||||
@@ -7,11 +7,11 @@
|
||||
Configuration reading priority: environment variable > config_private.py > config.py
|
||||
"""
|
||||
|
||||
# [step 1-1]>> ( 接入GPT等模型 ) API_KEY = "sk-123456789xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx123456789"。极少数情况下,还需要填写组织(格式如org-123456789abcdefghijklmno的),请向下翻,找 API_ORG 设置项
|
||||
API_KEY = "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" # 可同时填写多个API-KEY,用英文逗号分割,例如API_KEY = "sk-openaikey1,sk-openaikey2,fkxxxx-api2dkey1,fkxxxx-api2dkey2"
|
||||
# [step 1-1]>> ( 接入OpenAI模型家族 ) API_KEY = "sk-123456789xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx123456789"。极少数情况下,还需要填写组织(格式如org-123456789abcdefghijklmno的),请向下翻,找 API_ORG 设置项
|
||||
API_KEY = "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" # 可同时填写多个API-KEY,用英文逗号分割,例如API_KEY = "sk-openaikey1,sk-openaikey2,fkxxxx-api2dkey3,azure-apikey4"
|
||||
|
||||
# [step 1-2]>> ( 接入通义 qwen-max ) 接入通义千问在线大模型,api-key获取地址 https://dashscope.console.aliyun.com/
|
||||
DASHSCOPE_API_KEY = "" # 阿里灵积云API_KEY
|
||||
# [step 1-2]>> ( 强烈推荐!接入通义家族 & 大模型服务平台百炼 ) 接入通义千问在线大模型,api-key获取地址 https://dashscope.console.aliyun.com/
|
||||
DASHSCOPE_API_KEY = "" # 阿里灵积云API_KEY(用于接入qwen-max,dashscope-qwen3-14b,dashscope-deepseek-r1等)
|
||||
|
||||
# [step 1-3]>> ( 接入 deepseek-reasoner, 即 deepseek-r1 ) 深度求索(DeepSeek) API KEY,默认请求地址为"https://api.deepseek.com/v1/chat/completions"
|
||||
DEEPSEEK_API_KEY = "sk-d99b8cc6b7414cc88a5d950a3ff7585e"
|
||||
@@ -25,16 +25,19 @@ if USE_PROXY:
|
||||
}
|
||||
else:
|
||||
proxies = None
|
||||
DEFAULT_WORKER_NUM = 256
|
||||
|
||||
# [step 3]>> 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
|
||||
LLM_MODEL = "gpt-4-32k" # 可选 ↓↓↓
|
||||
AVAIL_LLM_MODELS = ["deepseek-chat", "deepseek-coder", "deepseek-reasoner",
|
||||
LLM_MODEL = "gpt-4" # 可选 ↓↓↓
|
||||
AVAIL_LLM_MODELS = ["qwen-max", "o1-mini", "o1-mini-2024-09-12", "o1", "o1-2024-12-17", "o1-preview", "o1-preview-2024-09-12",
|
||||
"gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-preview",
|
||||
"gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
||||
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
|
||||
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
|
||||
"gemini-1.5-pro", "chatglm3", "chatglm4",
|
||||
"deepseek-chat", "deepseek-coder", "deepseek-reasoner",
|
||||
"volcengine-deepseek-r1-250120", "volcengine-deepseek-v3-241226",
|
||||
"dashscope-deepseek-r1", "dashscope-deepseek-v3",
|
||||
"dashscope-qwen3-14b", "dashscope-qwen3-235b-a22b", "dashscope-qwen3-32b",
|
||||
]
|
||||
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
@@ -72,7 +75,7 @@ API_URL_REDIRECT = {}
|
||||
|
||||
# 多线程函数插件中,默认允许多少路线程同时访问OpenAI。Free trial users的限制是每分钟3次,Pay-as-you-go users的限制是每分钟3500次
|
||||
# 一言以蔽之:免费(5刀)用户填3,OpenAI绑了信用卡的用户可以填 16 或者更高。提高限制请查询:https://platform.openai.com/docs/guides/rate-limits/overview
|
||||
DEFAULT_WORKER_NUM = 64
|
||||
DEFAULT_WORKER_NUM = 8
|
||||
|
||||
|
||||
# 色彩主题, 可选 ["Default", "Chuanhu-Small-and-Beautiful", "High-Contrast"]
|
||||
@@ -94,6 +97,7 @@ AVAIL_FONTS = [
|
||||
"华文中宋(STZhongsong)",
|
||||
"华文新魏(STXinwei)",
|
||||
"华文隶书(STLiti)",
|
||||
# 备注:以下字体需要网络支持,您可以自定义任意您喜欢的字体,如下所示,需要满足的格式为 "字体昵称(字体英文真名@字体css下载链接)"
|
||||
"思源宋体(Source Han Serif CN VF@https://chinese-fonts-cdn.deno.dev/packages/syst/dist/SourceHanSerifCN/result.css)",
|
||||
"月星楷(Moon Stars Kai HW@https://chinese-fonts-cdn.deno.dev/packages/moon-stars-kai/dist/MoonStarsKaiHW-Regular/result.css)",
|
||||
"珠圆体(MaokenZhuyuanTi@https://chinese-fonts-cdn.deno.dev/packages/mkzyt/dist/猫啃珠圆体/result.css)",
|
||||
@@ -106,7 +110,7 @@ AVAIL_FONTS = [
|
||||
|
||||
|
||||
# 默认的系统提示词(system prompt)
|
||||
INIT_SYS_PROMPT = " "
|
||||
INIT_SYS_PROMPT = "Serve me as a writing and programming assistant."
|
||||
|
||||
|
||||
# 对话窗的高度 (仅在LAYOUT="TOP-DOWN"时生效)
|
||||
@@ -126,7 +130,7 @@ DARK_MODE = True
|
||||
|
||||
|
||||
# 发送请求到OpenAI后,等待多久判定为超时
|
||||
TIMEOUT_SECONDS = 60
|
||||
TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
# 网页的端口, -1代表随机端口
|
||||
@@ -137,7 +141,7 @@ AUTO_OPEN_BROWSER = True
|
||||
|
||||
|
||||
# 如果OpenAI不响应(网络卡顿、代理失败、KEY失效),重试的次数限制
|
||||
MAX_RETRY = 5
|
||||
MAX_RETRY = 3
|
||||
|
||||
|
||||
# 插件分类默认选项
|
||||
@@ -181,12 +185,12 @@ AUTO_CLEAR_TXT = False
|
||||
|
||||
|
||||
# 加一个live2d装饰
|
||||
ADD_WAIFU = False
|
||||
ADD_WAIFU = True
|
||||
|
||||
|
||||
# 设置用户名和密码(不需要修改)(相关功能不稳定,与gradio版本和网络都相关,如果本地使用不建议加这个)
|
||||
# [("username", "password"), ("username2", "password2"), ...]
|
||||
AUTHENTICATION = [("van", "L807878712"),("林", "L807878712"),("源", "L807878712"),("欣", "L807878712"),("z", "czh123456789")]
|
||||
AUTHENTICATION = []
|
||||
|
||||
|
||||
# 如果需要在二级路径下运行(常规情况下,不要修改!!)
|
||||
@@ -228,7 +232,7 @@ ALIYUN_SECRET="" # (无需填写)
|
||||
|
||||
|
||||
# GPT-SOVITS 文本转语音服务的运行地址(将语言模型的生成文本朗读出来)
|
||||
TTS_TYPE = "DISABLE" # EDGE_TTS / LOCAL_SOVITS_API / DISABLE
|
||||
TTS_TYPE = "EDGE_TTS" # EDGE_TTS / LOCAL_SOVITS_API / DISABLE
|
||||
GPT_SOVITS_URL = ""
|
||||
EDGE_TTS_VOICE = "zh-CN-XiaoxiaoNeural"
|
||||
|
||||
@@ -256,6 +260,10 @@ MOONSHOT_API_KEY = ""
|
||||
YIMODEL_API_KEY = ""
|
||||
|
||||
|
||||
# 接入火山引擎的在线大模型),api-key获取地址 https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint
|
||||
ARK_API_KEY = "00000000-0000-0000-0000-000000000000" # 火山引擎 API KEY
|
||||
|
||||
|
||||
# 紫东太初大模型 https://ai-maas.wair.ac.cn
|
||||
TAICHU_API_KEY = ""
|
||||
|
||||
@@ -333,6 +341,23 @@ NUM_CUSTOM_BASIC_BTN = 4
|
||||
DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in range(1,5) ]
|
||||
|
||||
|
||||
# 在互联网搜索组件中,负责将搜索结果整理成干净的Markdown
|
||||
JINA_API_KEY = ""
|
||||
|
||||
|
||||
# SEMANTIC SCHOLAR API KEY
|
||||
SEMANTIC_SCHOLAR_KEY = ""
|
||||
|
||||
|
||||
# 是否自动裁剪上下文长度(是否启动,默认不启动)
|
||||
AUTO_CONTEXT_CLIP_ENABLE = False
|
||||
# 目标裁剪上下文的token长度(如果超过这个长度,则会自动裁剪)
|
||||
AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN = 30*1000
|
||||
# 无条件丢弃x以上的轮数
|
||||
AUTO_CONTEXT_MAX_ROUND = 64
|
||||
# 在裁剪上下文时,倒数第x次对话能“最多”保留的上下文token的比例占 AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN 的多少
|
||||
AUTO_CONTEXT_MAX_CLIP_RATIO = [0.80, 0.60, 0.45, 0.25, 0.20, 0.18, 0.16, 0.14, 0.12, 0.10, 0.08, 0.07, 0.06, 0.05, 0.04, 0.03, 0.02, 0.01]
|
||||
|
||||
|
||||
"""
|
||||
--------------- 配置关联关系说明 ---------------
|
||||
@@ -439,6 +464,3 @@ DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in ran
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from toolbox import trimmed_format_exc
|
||||
from loguru import logger
|
||||
|
||||
def get_crazy_functions():
|
||||
from crazy_functions.读文章写摘要 import 读文章写摘要
|
||||
from crazy_functions.生成函数注释 import 批量生成函数注释
|
||||
from crazy_functions.Paper_Abstract_Writer import Paper_Abstract_Writer
|
||||
from crazy_functions.Program_Comment_Gen import 批量Program_Comment_Gen
|
||||
from crazy_functions.SourceCode_Analyse import 解析项目本身
|
||||
from crazy_functions.SourceCode_Analyse import 解析一个Python项目
|
||||
from crazy_functions.SourceCode_Analyse import 解析一个Matlab项目
|
||||
@@ -17,26 +17,26 @@ def get_crazy_functions():
|
||||
from crazy_functions.高级功能函数模板 import 高阶功能模板函数
|
||||
from crazy_functions.高级功能函数模板 import Demo_Wrap
|
||||
from crazy_functions.Latex_Project_Polish import Latex英文润色
|
||||
from crazy_functions.询问多个大语言模型 import 同时问询
|
||||
from crazy_functions.Multi_LLM_Query import 同时问询
|
||||
from crazy_functions.SourceCode_Analyse import 解析一个Lua项目
|
||||
from crazy_functions.SourceCode_Analyse import 解析一个CSharp项目
|
||||
from crazy_functions.总结word文档 import 总结word文档
|
||||
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
|
||||
from crazy_functions.Word_Summary import Word_Summary
|
||||
from crazy_functions.SourceCode_Analyse_JupyterNotebook import 解析ipynb文件
|
||||
from crazy_functions.Conversation_To_File import 载入对话历史存档
|
||||
from crazy_functions.Conversation_To_File import 对话历史存档
|
||||
from crazy_functions.Conversation_To_File import Conversation_To_File_Wrap
|
||||
from crazy_functions.Conversation_To_File import 删除所有本地对话历史记录
|
||||
from crazy_functions.辅助功能 import 清除缓存
|
||||
from crazy_functions.Helpers import 清除缓存
|
||||
from crazy_functions.Markdown_Translate import Markdown英译中
|
||||
from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
|
||||
from crazy_functions.PDF_Summary import PDF_Summary
|
||||
from crazy_functions.PDF_Translate import 批量翻译PDF文档
|
||||
from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
|
||||
from crazy_functions.理解PDF文档内容 import 理解PDF文档内容标准文件输入
|
||||
from crazy_functions.Google_Scholar_Assistant_Legacy import Google_Scholar_Assistant_Legacy
|
||||
from crazy_functions.PDF_QA import PDF_QA标准文件输入
|
||||
from crazy_functions.Latex_Project_Polish import Latex中文润色
|
||||
from crazy_functions.Latex_Project_Polish import Latex英文纠错
|
||||
from crazy_functions.Markdown_Translate import Markdown中译英
|
||||
from crazy_functions.虚空终端 import 虚空终端
|
||||
from crazy_functions.生成多种Mermaid图表 import Mermaid_Gen
|
||||
from crazy_functions.Void_Terminal import Void_Terminal
|
||||
from crazy_functions.Mermaid_Figure_Gen import Mermaid_Gen
|
||||
from crazy_functions.PDF_Translate_Wrap import PDF_Tran
|
||||
from crazy_functions.Latex_Function import Latex英文纠错加PDF对比
|
||||
from crazy_functions.Latex_Function import Latex翻译中文并重新编译PDF
|
||||
@@ -50,6 +50,9 @@ def get_crazy_functions():
|
||||
from crazy_functions.SourceCode_Comment import 注释Python项目
|
||||
from crazy_functions.SourceCode_Comment_Wrap import SourceCodeComment_Wrap
|
||||
from crazy_functions.VideoResource_GPT import 多媒体任务
|
||||
from crazy_functions.Document_Conversation import 批量文件询问
|
||||
from crazy_functions.Document_Conversation_Wrap import Document_Conversation_Wrap
|
||||
|
||||
|
||||
function_plugins = {
|
||||
"多媒体智能体": {
|
||||
@@ -64,7 +67,7 @@ def get_crazy_functions():
|
||||
"Color": "stop",
|
||||
"AsButton": True,
|
||||
"Info": "使用自然语言实现您的想法",
|
||||
"Function": HotReload(虚空终端),
|
||||
"Function": HotReload(Void_Terminal),
|
||||
},
|
||||
"解析整个Python项目": {
|
||||
"Group": "编程",
|
||||
@@ -122,7 +125,7 @@ def get_crazy_functions():
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "批量总结word文档 | 输入参数为路径",
|
||||
"Function": HotReload(总结word文档),
|
||||
"Function": HotReload(Word_Summary),
|
||||
},
|
||||
"解析整个Matlab项目": {
|
||||
"Group": "编程",
|
||||
@@ -201,7 +204,7 @@ def get_crazy_functions():
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "读取Tex论文并写摘要 | 输入参数为路径",
|
||||
"Function": HotReload(读文章写摘要),
|
||||
"Function": HotReload(Paper_Abstract_Writer),
|
||||
},
|
||||
"翻译README或MD": {
|
||||
"Group": "编程",
|
||||
@@ -222,14 +225,14 @@ def get_crazy_functions():
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "批量生成函数的注释 | 输入参数为路径",
|
||||
"Function": HotReload(批量生成函数注释),
|
||||
"Function": HotReload(批量Program_Comment_Gen),
|
||||
},
|
||||
"保存当前的对话": {
|
||||
"Group": "对话",
|
||||
"Color": "stop",
|
||||
"AsButton": True,
|
||||
"Info": "保存当前的对话 | 不需要输入参数",
|
||||
"Function": HotReload(对话历史存档), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
||||
"Function": HotReload(对话历史存档), # 当注册Class后,Function旧接口仅会在“Void_Terminal”中起作用
|
||||
"Class": Conversation_To_File_Wrap # 新一代插件需要注册Class
|
||||
},
|
||||
"[多线程Demo]解析此项目本身(源码自译解)": {
|
||||
@@ -255,12 +258,12 @@ def get_crazy_functions():
|
||||
"Function": None,
|
||||
"Class": Demo_Wrap, # 新一代插件需要注册Class
|
||||
},
|
||||
"精准翻译PDF论文": {
|
||||
"PDF论文翻译": {
|
||||
"Group": "学术",
|
||||
"Color": "stop",
|
||||
"AsButton": True,
|
||||
"Info": "精准翻译PDF论文为中文 | 输入参数为路径",
|
||||
"Function": HotReload(批量翻译PDF文档), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
||||
"Function": HotReload(批量翻译PDF文档), # 当注册Class后,Function旧接口仅会在“Void_Terminal”中起作用
|
||||
"Class": PDF_Tran, # 新一代插件需要注册Class
|
||||
},
|
||||
"询问多个GPT模型": {
|
||||
@@ -274,21 +277,21 @@ def get_crazy_functions():
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "批量总结PDF文档的内容 | 输入参数为路径",
|
||||
"Function": HotReload(批量总结PDF文档),
|
||||
"Function": HotReload(PDF_Summary),
|
||||
},
|
||||
"谷歌学术检索助手(输入谷歌学术搜索页url)": {
|
||||
"Group": "学术",
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "使用谷歌学术检索助手搜索指定URL的结果 | 输入参数为谷歌学术搜索页的URL",
|
||||
"Function": HotReload(谷歌检索小助手),
|
||||
"Function": HotReload(Google_Scholar_Assistant_Legacy),
|
||||
},
|
||||
"理解PDF文档内容 (模仿ChatPDF)": {
|
||||
"Group": "学术",
|
||||
"Color": "stop",
|
||||
"AsButton": False, # 加入下拉菜单中
|
||||
"Info": "理解PDF文档的内容并进行回答 | 输入参数为路径",
|
||||
"Function": HotReload(理解PDF文档内容标准文件输入),
|
||||
"Function": HotReload(PDF_QA标准文件输入),
|
||||
},
|
||||
"英文Latex项目全文润色(输入路径或上传压缩包)": {
|
||||
"Group": "学术",
|
||||
@@ -353,7 +356,7 @@ def get_crazy_functions():
|
||||
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
|
||||
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
|
||||
"Info": "ArXiv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
|
||||
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
||||
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“Void_Terminal”中起作用
|
||||
"Class": Arxiv_Localize, # 新一代插件需要注册Class
|
||||
},
|
||||
"📚本地Latex论文精细翻译(上传Latex项目)[需Latex]": {
|
||||
@@ -376,9 +379,18 @@ def get_crazy_functions():
|
||||
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
|
||||
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
|
||||
"Info": "PDF翻译中文,并重新编译PDF | 输入参数为路径",
|
||||
"Function": HotReload(PDF翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
||||
"Function": HotReload(PDF翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“Void_Terminal”中起作用
|
||||
"Class": PDF_Localize # 新一代插件需要注册Class
|
||||
}
|
||||
},
|
||||
"批量文件询问 (支持自定义总结各种文件)": {
|
||||
"Group": "学术",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"AdvancedArgs": False,
|
||||
"Info": "先上传文件,点击此按钮,进行提问",
|
||||
"Function": HotReload(批量文件询问),
|
||||
"Class": Document_Conversation_Wrap,
|
||||
},
|
||||
}
|
||||
|
||||
function_plugins.update(
|
||||
@@ -388,7 +400,7 @@ def get_crazy_functions():
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "使用 DALLE2/DALLE3 生成图片 | 输入参数字符串,提供图像的内容",
|
||||
"Function": HotReload(图片生成_DALLE2), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
||||
"Function": HotReload(图片生成_DALLE2), # 当注册Class后,Function旧接口仅会在“Void_Terminal”中起作用
|
||||
"Class": ImageGen_Wrap # 新一代插件需要注册Class
|
||||
},
|
||||
}
|
||||
@@ -414,10 +426,8 @@ def get_crazy_functions():
|
||||
|
||||
|
||||
|
||||
|
||||
# -=--=- 尚未充分测试的实验性插件 & 需要额外依赖的插件 -=--=-
|
||||
try:
|
||||
from crazy_functions.下载arxiv论文翻译摘要 import 下载arxiv论文并翻译摘要
|
||||
from crazy_functions.Arxiv_Downloader import 下载arxiv论文并翻译摘要
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
@@ -455,7 +465,7 @@ def get_crazy_functions():
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.询问多个大语言模型 import 同时问询_指定模型
|
||||
from crazy_functions.Multi_LLM_Query import 同时问询_指定模型
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
@@ -476,7 +486,7 @@ def get_crazy_functions():
|
||||
|
||||
|
||||
try:
|
||||
from crazy_functions.总结音视频 import 总结音视频
|
||||
from crazy_functions.Audio_Summary import Audio_Summary
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
@@ -487,7 +497,7 @@ def get_crazy_functions():
|
||||
"AdvancedArgs": True,
|
||||
"ArgsReminder": "调用openai api 使用whisper-1模型, 目前支持的格式:mp4, m4a, wav, mpga, mpeg, mp3。此处可以输入解析提示,例如:解析为简体中文(默认)。",
|
||||
"Info": "批量总结音频或视频 | 输入参数为路径",
|
||||
"Function": HotReload(总结音视频),
|
||||
"Function": HotReload(Audio_Summary),
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -496,7 +506,7 @@ def get_crazy_functions():
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.数学动画生成manim import 动画生成
|
||||
from crazy_functions.Math_Animation_Gen import 动画生成
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
@@ -533,7 +543,7 @@ def get_crazy_functions():
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.知识库问答 import 知识库文件注入
|
||||
from crazy_functions.Vectorstore_QA import 知识库文件注入
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
@@ -552,7 +562,7 @@ def get_crazy_functions():
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.知识库问答 import 读取知识库作答
|
||||
from crazy_functions.Vectorstore_QA import 读取知识库作答
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
@@ -571,7 +581,7 @@ def get_crazy_functions():
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.交互功能函数模板 import 交互功能模板函数
|
||||
from crazy_functions.Interactive_Func_Template import 交互功能模板函数
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
@@ -593,7 +603,7 @@ def get_crazy_functions():
|
||||
|
||||
ENABLE_AUDIO = get_conf("ENABLE_AUDIO")
|
||||
if ENABLE_AUDIO:
|
||||
from crazy_functions.语音助手 import 语音助手
|
||||
from crazy_functions.Audio_Assistant import Audio_Assistant
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
@@ -602,7 +612,7 @@ def get_crazy_functions():
|
||||
"Color": "stop",
|
||||
"AsButton": True,
|
||||
"Info": "这是一个时刻聆听着的语音对话助手 | 没有输入参数",
|
||||
"Function": HotReload(语音助手),
|
||||
"Function": HotReload(Audio_Assistant),
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -611,7 +621,7 @@ def get_crazy_functions():
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.批量翻译PDF文档_NOUGAT import 批量翻译PDF文档
|
||||
from crazy_functions.PDF_Translate_Nougat import 批量翻译PDF文档
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
@@ -628,7 +638,7 @@ def get_crazy_functions():
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.函数动态生成 import 函数动态生成
|
||||
from crazy_functions.Dynamic_Function_Generate import Dynamic_Function_Generate
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
@@ -636,7 +646,7 @@ def get_crazy_functions():
|
||||
"Group": "智能体",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Function": HotReload(函数动态生成),
|
||||
"Function": HotReload(Dynamic_Function_Generate),
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -644,39 +654,21 @@ def get_crazy_functions():
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.多智能体 import 多智能体终端
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
"AutoGen多智能体终端(仅供测试)": {
|
||||
"Group": "智能体",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Function": HotReload(多智能体终端),
|
||||
}
|
||||
}
|
||||
)
|
||||
except:
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.互动小游戏 import 随机小游戏
|
||||
|
||||
function_plugins.update(
|
||||
{
|
||||
"随机互动小游戏(仅供测试)": {
|
||||
"Group": "智能体",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Function": HotReload(随机小游戏),
|
||||
}
|
||||
}
|
||||
)
|
||||
except:
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
# try:
|
||||
# from crazy_functions.Multi_Agent_Legacy import Multi_Agent_Legacy终端
|
||||
# function_plugins.update(
|
||||
# {
|
||||
# "AutoGenMulti_Agent_Legacy终端(仅供测试)": {
|
||||
# "Group": "智能体",
|
||||
# "Color": "stop",
|
||||
# "AsButton": False,
|
||||
# "Function": HotReload(Multi_Agent_Legacy终端),
|
||||
# }
|
||||
# }
|
||||
# )
|
||||
# except:
|
||||
# logger.error(trimmed_format_exc())
|
||||
# logger.error("Load function plugin failed")
|
||||
|
||||
try:
|
||||
from crazy_functions.Rag_Interface import Rag问答
|
||||
@@ -696,6 +688,44 @@ def get_crazy_functions():
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
# try:
|
||||
# from crazy_functions.Document_Optimize import 自定义智能文档处理
|
||||
# function_plugins.update(
|
||||
# {
|
||||
# "一键处理文档(支持自定义全文润色、降重等)": {
|
||||
# "Group": "学术",
|
||||
# "Color": "stop",
|
||||
# "AsButton": False,
|
||||
# "AdvancedArgs": True,
|
||||
# "ArgsReminder": "请输入处理指令和要求(可以详细描述),如:请帮我润色文本,要求幽默点。默认调用润色指令。",
|
||||
# "Info": "保留文档结构,智能处理文档内容 | 输入参数为文件路径",
|
||||
# "Function": HotReload(自定义智能文档处理)
|
||||
# },
|
||||
# }
|
||||
# )
|
||||
# except:
|
||||
# logger.error(trimmed_format_exc())
|
||||
# logger.error("Load function plugin failed")
|
||||
|
||||
|
||||
|
||||
try:
|
||||
from crazy_functions.Paper_Reading import 快速论文解读
|
||||
function_plugins.update(
|
||||
{
|
||||
"速读论文": {
|
||||
"Group": "学术",
|
||||
"Color": "stop",
|
||||
"AsButton": False,
|
||||
"Info": "上传一篇论文进行快速分析和解读 | 输入参数为论文路径或DOI/arXiv ID",
|
||||
"Function": HotReload(快速论文解读),
|
||||
},
|
||||
}
|
||||
)
|
||||
except:
|
||||
logger.error(trimmed_format_exc())
|
||||
logger.error("Load function plugin failed")
|
||||
|
||||
|
||||
# try:
|
||||
# from crazy_functions.高级功能函数模板 import 测试图表渲染
|
||||
@@ -744,12 +774,12 @@ def get_multiplex_button_functions():
|
||||
"查互联网后回答":
|
||||
"查互联网后回答",
|
||||
|
||||
"多模型对话":
|
||||
"多模型对话":
|
||||
"询问多个GPT模型", # 映射到上面的 `询问多个GPT模型` 插件
|
||||
|
||||
"智能召回 RAG":
|
||||
"智能召回 RAG":
|
||||
"Rag智能召回", # 映射到上面的 `Rag智能召回` 插件
|
||||
|
||||
"多媒体查询":
|
||||
"多媒体查询":
|
||||
"多媒体智能体", # 映射到上面的 `多媒体智能体` 插件
|
||||
}
|
||||
|
||||
290
crazy_functions/Academic_Conversation.py
Normal file
290
crazy_functions/Academic_Conversation.py
Normal file
@@ -0,0 +1,290 @@
|
||||
import re
|
||||
import os
|
||||
import asyncio
|
||||
from typing import List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
from textwrap import dedent
|
||||
from toolbox import CatchException, get_conf, update_ui, promote_file_to_downloadzone, get_log_folder, get_user
|
||||
from toolbox import update_ui, CatchException, report_exception, write_history_to_file
|
||||
from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource
|
||||
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
|
||||
from crazy_functions.review_fns.query_analyzer import QueryAnalyzer
|
||||
from crazy_functions.review_fns.handlers.review_handler import 文献综述功能
|
||||
from crazy_functions.review_fns.handlers.recommend_handler import 论文推荐功能
|
||||
from crazy_functions.review_fns.handlers.qa_handler import 学术问答功能
|
||||
from crazy_functions.review_fns.handlers.paper_handler import 单篇论文分析功能
|
||||
from crazy_functions.Conversation_To_File import write_chat_to_file
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from crazy_functions.review_fns.handlers.latest_handler import Arxiv最新论文推荐功能
|
||||
from datetime import datetime
|
||||
|
||||
@CatchException
|
||||
def 学术对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||
history: List, system_prompt: str, user_request: str):
|
||||
"""主函数"""
|
||||
|
||||
# 初始化数据源
|
||||
arxiv_source = ArxivSource()
|
||||
semantic_source = SemanticScholarSource(
|
||||
api_key=get_conf("SEMANTIC_SCHOLAR_KEY")
|
||||
)
|
||||
|
||||
# 初始化处理器
|
||||
handlers = {
|
||||
"review": 文献综述功能(arxiv_source, semantic_source, llm_kwargs),
|
||||
"recommend": 论文推荐功能(arxiv_source, semantic_source, llm_kwargs),
|
||||
"qa": 学术问答功能(arxiv_source, semantic_source, llm_kwargs),
|
||||
"paper": 单篇论文分析功能(arxiv_source, semantic_source, llm_kwargs),
|
||||
"latest": Arxiv最新论文推荐功能(arxiv_source, semantic_source, llm_kwargs),
|
||||
}
|
||||
|
||||
# 分析查询意图
|
||||
chatbot.append([None, "正在分析研究主题和查询要求..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
query_analyzer = QueryAnalyzer()
|
||||
search_criteria = yield from query_analyzer.analyze_query(txt, chatbot, llm_kwargs)
|
||||
handler = handlers.get(search_criteria.query_type)
|
||||
if not handler:
|
||||
handler = handlers["qa"] # 默认使用QA处理器
|
||||
|
||||
# 处理查询
|
||||
chatbot.append([None, f"使用{handler.__class__.__name__}处理...,可能需要您耐心等待3~5分钟..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
final_prompt = asyncio.run(handler.handle(
|
||||
criteria=search_criteria,
|
||||
chatbot=chatbot,
|
||||
history=history,
|
||||
system_prompt=system_prompt,
|
||||
llm_kwargs=llm_kwargs,
|
||||
plugin_kwargs=plugin_kwargs
|
||||
))
|
||||
|
||||
if final_prompt:
|
||||
# 检查是否是道歉提示
|
||||
if "很抱歉,我们未能找到" in final_prompt:
|
||||
chatbot.append([txt, final_prompt])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
# 在 final_prompt 末尾添加用户原始查询要求
|
||||
final_prompt += dedent(f"""
|
||||
Original user query: "{txt}"
|
||||
|
||||
IMPORTANT NOTE :
|
||||
- Your response must directly address the user's original user query above
|
||||
- While following the previous guidelines, prioritize answering what the user specifically asked
|
||||
- Make sure your response format and content align with the user's expectations
|
||||
- Do not translate paper titles, keep them in their original language
|
||||
- Do not generate a reference list in your response - references will be handled separately
|
||||
""")
|
||||
|
||||
# 使用最终的prompt生成回答
|
||||
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=final_prompt,
|
||||
inputs_show_user=txt,
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history=[],
|
||||
sys_prompt=f"You are a helpful academic assistant. Response in Chinese by default unless specified language is required in the user's query."
|
||||
)
|
||||
|
||||
# 1. 获取文献列表
|
||||
papers_list = handler.ranked_papers # 直接使用原始论文数据
|
||||
|
||||
# 在新的对话中添加格式化的参考文献列表
|
||||
if papers_list:
|
||||
references = ""
|
||||
for idx, paper in enumerate(papers_list, 1):
|
||||
# 构建作者列表
|
||||
authors = paper.authors[:3]
|
||||
if len(paper.authors) > 3:
|
||||
authors.append("et al.")
|
||||
authors_str = ", ".join(authors)
|
||||
|
||||
# 构建期刊指标信息
|
||||
metrics = []
|
||||
if hasattr(paper, 'if_factor') and paper.if_factor:
|
||||
metrics.append(f"IF: {paper.if_factor}")
|
||||
if hasattr(paper, 'jcr_division') and paper.jcr_division:
|
||||
metrics.append(f"JCR: {paper.jcr_division}")
|
||||
if hasattr(paper, 'cas_division') and paper.cas_division:
|
||||
metrics.append(f"中科院分区: {paper.cas_division}")
|
||||
metrics_str = f" [{', '.join(metrics)}]" if metrics else ""
|
||||
|
||||
# 构建DOI链接
|
||||
doi_link = ""
|
||||
if paper.doi:
|
||||
if "arxiv.org" in str(paper.doi):
|
||||
doi_url = paper.doi
|
||||
else:
|
||||
doi_url = f"https://doi.org/{paper.doi}"
|
||||
doi_link = f" <a href='{doi_url}' target='_blank'>DOI: {paper.doi}</a>"
|
||||
|
||||
# 构建完整的引用
|
||||
reference = f"[{idx}] {authors_str}. *{paper.title}*"
|
||||
if paper.venue_name:
|
||||
reference += f". {paper.venue_name}"
|
||||
if paper.year:
|
||||
reference += f", {paper.year}"
|
||||
reference += metrics_str
|
||||
if doi_link:
|
||||
reference += f".{doi_link}"
|
||||
reference += " \n"
|
||||
|
||||
references += reference
|
||||
|
||||
# 添加新的对话显示参考文献
|
||||
chatbot.append(["参考文献如下:", references])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
|
||||
# 2. 保存为不同格式
|
||||
from .review_fns.conversation_doc.word_doc import WordFormatter
|
||||
from .review_fns.conversation_doc.word2pdf import WordToPdfConverter
|
||||
from .review_fns.conversation_doc.markdown_doc import MarkdownFormatter
|
||||
from .review_fns.conversation_doc.html_doc import HtmlFormatter
|
||||
|
||||
# 创建保存目录
|
||||
save_dir = get_log_folder(get_user(chatbot), plugin_name='chatscholar')
|
||||
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
|
||||
# 生成文件名
|
||||
def get_safe_filename(txt, max_length=10):
|
||||
# 获取文本前max_length个字符作为文件名
|
||||
filename = txt[:max_length].strip()
|
||||
# 移除不安全的文件名字符
|
||||
filename = re.sub(r'[\\/:*?"<>|]', '', filename)
|
||||
# 如果文件名为空,使用时间戳
|
||||
if not filename:
|
||||
filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
return filename
|
||||
|
||||
base_filename = get_safe_filename(txt)
|
||||
|
||||
result_files = [] # 收集所有生成的文件
|
||||
pdf_path = None # 用于跟踪PDF是否成功生成
|
||||
|
||||
# 保存为Markdown
|
||||
try:
|
||||
md_formatter = MarkdownFormatter()
|
||||
md_content = md_formatter.create_document(txt, response, papers_list)
|
||||
result_file_md = write_history_to_file(
|
||||
history=[md_content],
|
||||
file_basename=f"markdown_{base_filename}.md"
|
||||
)
|
||||
result_files.append(result_file_md)
|
||||
except Exception as e:
|
||||
print(f"Markdown保存失败: {str(e)}")
|
||||
|
||||
# 保存为HTML
|
||||
try:
|
||||
html_formatter = HtmlFormatter()
|
||||
html_content = html_formatter.create_document(txt, response, papers_list)
|
||||
result_file_html = write_history_to_file(
|
||||
history=[html_content],
|
||||
file_basename=f"html_{base_filename}.html"
|
||||
)
|
||||
result_files.append(result_file_html)
|
||||
except Exception as e:
|
||||
print(f"HTML保存失败: {str(e)}")
|
||||
|
||||
# 保存为Word
|
||||
try:
|
||||
word_formatter = WordFormatter()
|
||||
try:
|
||||
doc = word_formatter.create_document(txt, response, papers_list)
|
||||
except Exception as e:
|
||||
print(f"Word文档内容生成失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
try:
|
||||
result_file_docx = os.path.join(
|
||||
os.path.dirname(result_file_md) if result_file_md else save_dir,
|
||||
f"docx_{base_filename}.docx"
|
||||
)
|
||||
doc.save(result_file_docx)
|
||||
result_files.append(result_file_docx)
|
||||
print(f"Word文档已保存到: {result_file_docx}")
|
||||
|
||||
# 转换为PDF
|
||||
try:
|
||||
pdf_path = WordToPdfConverter.convert_to_pdf(result_file_docx)
|
||||
if pdf_path:
|
||||
result_files.append(pdf_path)
|
||||
print(f"PDF文档已生成: {pdf_path}")
|
||||
except Exception as e:
|
||||
print(f"PDF转换失败: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Word文档保存失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
print(f"Word格式化失败: {str(e)}")
|
||||
import traceback
|
||||
print(f"详细错误信息: {traceback.format_exc()}")
|
||||
|
||||
# 保存为BibTeX格式
|
||||
try:
|
||||
from .review_fns.conversation_doc.reference_formatter import ReferenceFormatter
|
||||
ref_formatter = ReferenceFormatter()
|
||||
bibtex_content = ref_formatter.create_document(papers_list)
|
||||
|
||||
# 在与其他文件相同目录下创建BibTeX文件
|
||||
result_file_bib = os.path.join(
|
||||
os.path.dirname(result_file_md) if result_file_md else save_dir,
|
||||
f"references_{base_filename}.bib"
|
||||
)
|
||||
|
||||
# 直接写入文件
|
||||
with open(result_file_bib, 'w', encoding='utf-8') as f:
|
||||
f.write(bibtex_content)
|
||||
|
||||
result_files.append(result_file_bib)
|
||||
print(f"BibTeX文件已保存到: {result_file_bib}")
|
||||
except Exception as e:
|
||||
print(f"BibTeX格式保存失败: {str(e)}")
|
||||
|
||||
# 保存为EndNote格式
|
||||
try:
|
||||
from .review_fns.conversation_doc.endnote_doc import EndNoteFormatter
|
||||
endnote_formatter = EndNoteFormatter()
|
||||
endnote_content = endnote_formatter.create_document(papers_list)
|
||||
|
||||
# 在与其他文件相同目录下创建EndNote文件
|
||||
result_file_enw = os.path.join(
|
||||
os.path.dirname(result_file_md) if result_file_md else save_dir,
|
||||
f"references_{base_filename}.enw"
|
||||
)
|
||||
|
||||
# 直接写入文件
|
||||
with open(result_file_enw, 'w', encoding='utf-8') as f:
|
||||
f.write(endnote_content)
|
||||
|
||||
result_files.append(result_file_enw)
|
||||
print(f"EndNote文件已保存到: {result_file_enw}")
|
||||
except Exception as e:
|
||||
print(f"EndNote格式保存失败: {str(e)}")
|
||||
|
||||
# 添加所有文件到下载区
|
||||
success_files = []
|
||||
for file in result_files:
|
||||
try:
|
||||
promote_file_to_downloadzone(file, chatbot=chatbot)
|
||||
success_files.append(os.path.basename(file))
|
||||
except Exception as e:
|
||||
print(f"文件添加到下载区失败: {str(e)}")
|
||||
|
||||
# 更新成功提示消息
|
||||
if success_files:
|
||||
chatbot.append(["保存对话记录成功,bib和enw文件支持导入到EndNote、Zotero、JabRef、Mendeley等文献管理软件,HTML文件支持在浏览器中打开,里面包含详细论文源信息", "对话已保存并添加到下载区,可以在下载区找到相关文件"])
|
||||
else:
|
||||
chatbot.append(["保存对话记录", "所有格式的保存都失败了,请检查错误日志。"])
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
else:
|
||||
report_exception(chatbot, history, a=f"处理失败", b=f"请尝试其他查询")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
@@ -168,7 +168,7 @@ class InterviewAssistant(AliyunASR):
|
||||
|
||||
|
||||
@CatchException
|
||||
def 语音助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def Audio_Assistant(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
# pip install -U openai-whisper
|
||||
chatbot.append(["对话助手函数插件:使用时,双手离开鼠标键盘吧", "音频助手, 正在听您讲话(点击“停止”键可终止程序)..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
@@ -132,13 +132,13 @@ def AnalyAudio(parse_prompt, file_manifest, llm_kwargs, chatbot, history):
|
||||
|
||||
|
||||
@CatchException
|
||||
def 总结音视频(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, WEB_PORT):
|
||||
def Audio_Summary(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, WEB_PORT):
|
||||
import glob, os
|
||||
|
||||
# 基本信息:功能、贡献者
|
||||
chatbot.append([
|
||||
"函数插件功能?",
|
||||
"总结音视频内容,函数插件贡献者: dalvqw & BinaryHusky"])
|
||||
"Audio_Summary内容,函数插件贡献者: dalvqw & BinaryHusky"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
try:
|
||||
@@ -4,7 +4,7 @@ from crazy_functions.crazy_utils import input_clipping
|
||||
import copy, json
|
||||
|
||||
@CatchException
|
||||
def 命令行助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def Commandline_Assistant(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
"""
|
||||
txt 输入栏用户输入的文本, 例如需要翻译的一段话, 再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数, 如温度和top_p等, 一般原样传递下去就行
|
||||
537
crazy_functions/Document_Conversation.py
Normal file
537
crazy_functions/Document_Conversation.py
Normal file
@@ -0,0 +1,537 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Dict, Generator
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
|
||||
from crazy_functions.rag_fns.rag_file_support import extract_text
|
||||
from request_llms.bridge_all import model_info
|
||||
from toolbox import update_ui, CatchException, report_exception
|
||||
from shared_utils.fastapi_server import validate_path_safety
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileFragment:
|
||||
"""文件片段数据类,用于组织处理单元"""
|
||||
file_path: str
|
||||
content: str
|
||||
rel_path: str
|
||||
fragment_index: int
|
||||
total_fragments: int
|
||||
|
||||
|
||||
class BatchDocumentSummarizer:
|
||||
"""优化的文档总结器 - 批处理版本"""
|
||||
|
||||
def __init__(self, llm_kwargs: Dict, query: str, chatbot: List, history: List, system_prompt: str):
|
||||
"""初始化总结器"""
|
||||
self.llm_kwargs = llm_kwargs
|
||||
self.query = query
|
||||
self.chatbot = chatbot
|
||||
self.history = history
|
||||
self.system_prompt = system_prompt
|
||||
self.failed_files = []
|
||||
self.file_summaries_map = {}
|
||||
|
||||
def _get_token_limit(self) -> int:
|
||||
"""获取模型token限制"""
|
||||
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
|
||||
return max_token * 3 // 4
|
||||
|
||||
def _create_batch_inputs(self, fragments: List[FileFragment]) -> Tuple[List, List, List]:
|
||||
"""创建批处理输入"""
|
||||
inputs_array = []
|
||||
inputs_show_user_array = []
|
||||
history_array = []
|
||||
|
||||
for frag in fragments:
|
||||
if self.query:
|
||||
i_say = (f'请按照用户要求对文件内容进行处理,文件名为{os.path.basename(frag.file_path)},'
|
||||
f'用户要求为:{self.query}:'
|
||||
f'文件内容是 ```{frag.content}```')
|
||||
i_say_show_user = (f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})')
|
||||
else:
|
||||
i_say = (f'请对下面的内容用中文做总结,不超过500字,文件名是{os.path.basename(frag.file_path)},'
|
||||
f'内容是 ```{frag.content}```')
|
||||
i_say_show_user = f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})'
|
||||
|
||||
inputs_array.append(i_say)
|
||||
inputs_show_user_array.append(i_say_show_user)
|
||||
history_array.append([])
|
||||
|
||||
return inputs_array, inputs_show_user_array, history_array
|
||||
|
||||
def _process_single_file_with_timeout(self, file_info: Tuple[str, str], mutable_status: List) -> List[FileFragment]:
|
||||
"""包装了超时控制的文件处理函数"""
|
||||
|
||||
def timeout_handler():
|
||||
thread = threading.current_thread()
|
||||
if hasattr(thread, '_timeout_occurred'):
|
||||
thread._timeout_occurred = True
|
||||
|
||||
# 设置超时标记
|
||||
thread = threading.current_thread()
|
||||
thread._timeout_occurred = False
|
||||
|
||||
# 设置超时时间为30秒,给予更多处理时间
|
||||
TIMEOUT_SECONDS = 30
|
||||
timer = threading.Timer(TIMEOUT_SECONDS, timeout_handler)
|
||||
timer.start()
|
||||
|
||||
try:
|
||||
fp, project_folder = file_info
|
||||
fragments = []
|
||||
|
||||
# 定期检查是否超时
|
||||
def check_timeout():
|
||||
if hasattr(thread, '_timeout_occurred') and thread._timeout_occurred:
|
||||
raise TimeoutError(f"处理文件 {os.path.basename(fp)} 超时({TIMEOUT_SECONDS}秒)")
|
||||
|
||||
# 更新状态
|
||||
mutable_status[0] = "检查文件大小"
|
||||
mutable_status[1] = time.time()
|
||||
check_timeout()
|
||||
|
||||
# 文件大小检查
|
||||
if os.path.getsize(fp) > self.max_file_size:
|
||||
self.failed_files.append((fp, f"文件过大:超过{self.max_file_size / 1024 / 1024}MB"))
|
||||
mutable_status[2] = "文件过大"
|
||||
return fragments
|
||||
|
||||
# 更新状态
|
||||
mutable_status[0] = "提取文件内容"
|
||||
mutable_status[1] = time.time()
|
||||
|
||||
# 提取内容 - 使用单独的超时控制
|
||||
content = None
|
||||
extract_start_time = time.time()
|
||||
try:
|
||||
while True:
|
||||
check_timeout() # 检查全局超时
|
||||
|
||||
# 检查提取过程是否超时(10秒)
|
||||
if time.time() - extract_start_time > 10:
|
||||
raise TimeoutError("文件内容提取超时(10秒)")
|
||||
|
||||
try:
|
||||
content = extract_text(fp)
|
||||
break
|
||||
except Exception as e:
|
||||
if "timeout" in str(e).lower():
|
||||
continue # 如果是临时超时,重试
|
||||
raise # 其他错误直接抛出
|
||||
|
||||
except Exception as e:
|
||||
self.failed_files.append((fp, f"文件读取失败:{str(e)}"))
|
||||
mutable_status[2] = "读取失败"
|
||||
return fragments
|
||||
|
||||
if content is None:
|
||||
self.failed_files.append((fp, "文件解析失败:不支持的格式或文件损坏"))
|
||||
mutable_status[2] = "格式不支持"
|
||||
return fragments
|
||||
elif not content.strip():
|
||||
self.failed_files.append((fp, "文件内容为空"))
|
||||
mutable_status[2] = "内容为空"
|
||||
return fragments
|
||||
|
||||
check_timeout()
|
||||
|
||||
# 更新状态
|
||||
mutable_status[0] = "分割文本"
|
||||
mutable_status[1] = time.time()
|
||||
|
||||
# 分割文本 - 添加超时检查
|
||||
split_start_time = time.time()
|
||||
try:
|
||||
while True:
|
||||
check_timeout() # 检查全局超时
|
||||
|
||||
# 检查分割过程是否超时(5秒)
|
||||
if time.time() - split_start_time > 5:
|
||||
raise TimeoutError("文本分割超时(5秒)")
|
||||
|
||||
paper_fragments = breakdown_text_to_satisfy_token_limit(
|
||||
txt=content,
|
||||
limit=self._get_token_limit(),
|
||||
llm_model=self.llm_kwargs['llm_model']
|
||||
)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
self.failed_files.append((fp, f"文本分割失败:{str(e)}"))
|
||||
mutable_status[2] = "分割失败"
|
||||
return fragments
|
||||
|
||||
# 处理片段
|
||||
rel_path = os.path.relpath(fp, project_folder)
|
||||
for i, frag in enumerate(paper_fragments):
|
||||
check_timeout() # 每处理一个片段检查一次超时
|
||||
if frag.strip():
|
||||
fragments.append(FileFragment(
|
||||
file_path=fp,
|
||||
content=frag,
|
||||
rel_path=rel_path,
|
||||
fragment_index=i,
|
||||
total_fragments=len(paper_fragments)
|
||||
))
|
||||
|
||||
mutable_status[2] = "处理完成"
|
||||
return fragments
|
||||
|
||||
except TimeoutError as e:
|
||||
self.failed_files.append((fp, str(e)))
|
||||
mutable_status[2] = "处理超时"
|
||||
return []
|
||||
except Exception as e:
|
||||
self.failed_files.append((fp, f"处理失败:{str(e)}"))
|
||||
mutable_status[2] = "处理异常"
|
||||
return []
|
||||
finally:
|
||||
timer.cancel()
|
||||
|
||||
def prepare_fragments(self, project_folder: str, file_paths: List[str]) -> Generator:
|
||||
import concurrent.futures
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Generator, List
|
||||
"""并行准备所有文件的处理片段"""
|
||||
all_fragments = []
|
||||
total_files = len(file_paths)
|
||||
|
||||
# 配置参数
|
||||
self.refresh_interval = 0.2 # UI刷新间隔
|
||||
self.watch_dog_patience = 5 # 看门狗超时时间
|
||||
self.max_file_size = 10 * 1024 * 1024 # 10MB限制
|
||||
self.max_workers = min(32, len(file_paths)) # 最多32个线程
|
||||
|
||||
# 创建有超时控制的线程池
|
||||
executor = ThreadPoolExecutor(max_workers=self.max_workers)
|
||||
|
||||
# 用于跨线程状态传递的可变列表 - 增加文件名信息
|
||||
mutable_status_array = [["等待中", time.time(), "pending", file_path] for file_path in file_paths]
|
||||
|
||||
# 创建文件处理任务
|
||||
file_infos = [(fp, project_folder) for fp in file_paths]
|
||||
|
||||
# 提交所有任务,使用带超时控制的处理函数
|
||||
futures = [
|
||||
executor.submit(
|
||||
self._process_single_file_with_timeout,
|
||||
file_info,
|
||||
mutable_status_array[i]
|
||||
) for i, file_info in enumerate(file_infos)
|
||||
]
|
||||
|
||||
# 更新UI的计数器
|
||||
cnt = 0
|
||||
|
||||
try:
|
||||
# 监控任务执行
|
||||
while True:
|
||||
time.sleep(self.refresh_interval)
|
||||
cnt += 1
|
||||
|
||||
# 检查任务完成状态
|
||||
worker_done = [f.done() for f in futures]
|
||||
|
||||
# 更新状态显示
|
||||
status_str = ""
|
||||
for i, (status, timestamp, desc, file_path) in enumerate(mutable_status_array):
|
||||
# 获取文件名(去掉路径)
|
||||
file_name = os.path.basename(file_path)
|
||||
if worker_done[i]:
|
||||
status_str += f"文件 {file_name}: {desc}\n\n"
|
||||
else:
|
||||
status_str += f"文件 {file_name}: {status} {desc}\n\n"
|
||||
|
||||
# 更新UI
|
||||
self.chatbot[-1] = [
|
||||
"处理进度",
|
||||
f"正在处理文件...\n\n{status_str}" + "." * (cnt % 10 + 1)
|
||||
]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 检查是否所有任务完成
|
||||
if all(worker_done):
|
||||
break
|
||||
|
||||
finally:
|
||||
# 确保线程池正确关闭
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
# 收集结果
|
||||
processed_files = 0
|
||||
for future in futures:
|
||||
try:
|
||||
fragments = future.result(timeout=0.1) # 给予一个短暂的超时时间来获取结果
|
||||
all_fragments.extend(fragments)
|
||||
processed_files += 1
|
||||
except concurrent.futures.TimeoutError:
|
||||
# 处理获取结果超时
|
||||
file_index = futures.index(future)
|
||||
self.failed_files.append((file_paths[file_index], "结果获取超时"))
|
||||
continue
|
||||
except Exception as e:
|
||||
# 处理其他异常
|
||||
file_index = futures.index(future)
|
||||
self.failed_files.append((file_paths[file_index], f"未知错误:{str(e)}"))
|
||||
continue
|
||||
|
||||
# 最终进度更新
|
||||
self.chatbot.append([
|
||||
"文件处理完成",
|
||||
f"成功处理 {len(all_fragments)} 个片段,失败 {len(self.failed_files)} 个文件"
|
||||
])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
return all_fragments
|
||||
|
||||
def _process_fragments_batch(self, fragments: List[FileFragment]) -> Generator:
|
||||
"""批量处理文件片段"""
|
||||
from collections import defaultdict
|
||||
batch_size = 64 # 每批处理的片段数
|
||||
max_retries = 3 # 最大重试次数
|
||||
retry_delay = 5 # 重试延迟(秒)
|
||||
results = defaultdict(list)
|
||||
|
||||
# 按批次处理
|
||||
for i in range(0, len(fragments), batch_size):
|
||||
batch = fragments[i:i + batch_size]
|
||||
|
||||
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch)
|
||||
sys_prompt_array = ["请总结以下内容:"] * len(batch)
|
||||
|
||||
# 添加重试机制
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=inputs_array,
|
||||
inputs_show_user_array=inputs_show_user_array,
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=history_array,
|
||||
sys_prompt_array=sys_prompt_array,
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
for j, frag in enumerate(batch):
|
||||
summary = response_collection[j * 2 + 1]
|
||||
if summary and summary.strip():
|
||||
results[frag.rel_path].append({
|
||||
'index': frag.fragment_index,
|
||||
'summary': summary,
|
||||
'total': frag.total_fragments
|
||||
})
|
||||
break # 成功处理,跳出重试循环
|
||||
|
||||
except Exception as e:
|
||||
if retry == max_retries - 1: # 最后一次重试失败
|
||||
for frag in batch:
|
||||
self.failed_files.append((frag.file_path, f"处理失败:{str(e)}"))
|
||||
else:
|
||||
yield from update_ui(self.chatbot.append([f"批次处理失败,{retry_delay}秒后重试...", str(e)]))
|
||||
time.sleep(retry_delay)
|
||||
|
||||
return results
|
||||
|
||||
def _generate_final_summary_request(self) -> Tuple[List, List, List]:
|
||||
"""准备最终总结请求"""
|
||||
if not self.file_summaries_map:
|
||||
return (["无可用的文件总结"], ["生成最终总结"], [[]])
|
||||
|
||||
summaries = list(self.file_summaries_map.values())
|
||||
if all(not summary for summary in summaries):
|
||||
return (["所有文件处理均失败"], ["生成最终总结"], [[]])
|
||||
|
||||
if self.plugin_kwargs.get("advanced_arg"):
|
||||
i_say = "根据以上所有文件的处理结果,按要求进行综合处理:" + self.plugin_kwargs['advanced_arg']
|
||||
else:
|
||||
i_say = "请根据以上所有文件的处理结果,生成最终的总结,不超过1000字。"
|
||||
|
||||
return ([i_say], [i_say], [summaries])
|
||||
|
||||
def process_files(self, project_folder: str, file_paths: List[str]) -> Generator:
|
||||
"""处理所有文件"""
|
||||
total_files = len(file_paths)
|
||||
self.chatbot.append([f"开始处理", f"总计 {total_files} 个文件"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 1. 准备所有文件片段
|
||||
# 在 process_files 函数中:
|
||||
fragments = yield from self.prepare_fragments(project_folder, file_paths)
|
||||
if not fragments:
|
||||
self.chatbot.append(["处理失败", "没有可处理的文件内容"])
|
||||
return "没有可处理的文件内容"
|
||||
|
||||
# 2. 批量处理所有文件片段
|
||||
self.chatbot.append([f"文件分析", f"共计 {len(fragments)} 个处理单元"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
try:
|
||||
file_summaries = yield from self._process_fragments_batch(fragments)
|
||||
except Exception as e:
|
||||
self.chatbot.append(["处理错误", f"批处理过程失败:{str(e)}"])
|
||||
return "处理过程发生错误"
|
||||
|
||||
# 3. 为每个文件生成整体总结
|
||||
self.chatbot.append(["生成总结", "正在汇总文件内容..."])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 处理每个文件的总结
|
||||
for rel_path, summaries in file_summaries.items():
|
||||
if len(summaries) > 1: # 多片段文件需要生成整体总结
|
||||
sorted_summaries = sorted(summaries, key=lambda x: x['index'])
|
||||
if self.plugin_kwargs.get("advanced_arg"):
|
||||
|
||||
i_say = f'请按照用户要求对文件内容进行处理,用户要求为:{self.plugin_kwargs["advanced_arg"]}:'
|
||||
else:
|
||||
i_say = f"请总结文件 {os.path.basename(rel_path)} 的主要内容,不超过500字。"
|
||||
|
||||
try:
|
||||
summary_texts = [s['summary'] for s in sorted_summaries]
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=[i_say],
|
||||
inputs_show_user_array=[f"生成 {rel_path} 的处理结果"],
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=[summary_texts],
|
||||
sys_prompt_array=["你是一个优秀的助手,"],
|
||||
)
|
||||
self.file_summaries_map[rel_path] = response_collection[1]
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"文件 {rel_path} 总结生成失败:{str(e)}"])
|
||||
self.file_summaries_map[rel_path] = "总结生成失败"
|
||||
else: # 单片段文件直接使用其唯一的总结
|
||||
self.file_summaries_map[rel_path] = summaries[0]['summary']
|
||||
|
||||
# 4. 生成最终总结
|
||||
if total_files == 1:
|
||||
return "文件数为1,此时不调用总结模块"
|
||||
else:
|
||||
try:
|
||||
# 收集所有文件的总结用于生成最终总结
|
||||
file_summaries_for_final = []
|
||||
for rel_path, summary in self.file_summaries_map.items():
|
||||
file_summaries_for_final.append(f"文件 {rel_path} 的总结:\n{summary}")
|
||||
|
||||
if self.plugin_kwargs.get("advanced_arg"):
|
||||
final_summary_prompt = ("根据以下所有文件的总结内容,按要求进行综合处理:" +
|
||||
self.plugin_kwargs['advanced_arg'])
|
||||
else:
|
||||
final_summary_prompt = "请根据以下所有文件的总结内容,生成最终的总结报告。"
|
||||
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=[final_summary_prompt],
|
||||
inputs_show_user_array=["生成最终总结报告"],
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=[file_summaries_for_final],
|
||||
sys_prompt_array=["总结所有文件内容。"],
|
||||
max_workers=1
|
||||
)
|
||||
|
||||
return response_collection[1] if len(response_collection) > 1 else "生成总结失败"
|
||||
except Exception as e:
|
||||
self.chatbot.append(["错误", f"最终总结生成失败:{str(e)}"])
|
||||
return "生成总结失败"
|
||||
|
||||
def save_results(self, final_summary: str):
|
||||
"""保存结果到文件"""
|
||||
from toolbox import promote_file_to_downloadzone, write_history_to_file
|
||||
from crazy_functions.doc_fns.batch_file_query_doc import MarkdownFormatter, HtmlFormatter, WordFormatter
|
||||
import os
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 创建各种格式化器
|
||||
md_formatter = MarkdownFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||
html_formatter = HtmlFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||
word_formatter = WordFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||
|
||||
result_files = []
|
||||
|
||||
# 保存 Markdown
|
||||
try:
|
||||
md_content = md_formatter.create_document()
|
||||
result_file_md = write_history_to_file(
|
||||
history=[md_content], # 直接传入内容列表
|
||||
file_basename=f"文档总结_{timestamp}.md"
|
||||
)
|
||||
result_files.append(result_file_md)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 保存 HTML
|
||||
try:
|
||||
html_content = html_formatter.create_document()
|
||||
result_file_html = write_history_to_file(
|
||||
history=[html_content],
|
||||
file_basename=f"文档总结_{timestamp}.html"
|
||||
)
|
||||
result_files.append(result_file_html)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 保存 Word
|
||||
try:
|
||||
doc = word_formatter.create_document()
|
||||
# 由于 Word 文档需要用 doc.save(),我们使用与 md 文件相同的目录
|
||||
result_file_docx = os.path.join(
|
||||
os.path.dirname(result_file_md),
|
||||
f"文档总结_{timestamp}.docx"
|
||||
)
|
||||
doc.save(result_file_docx)
|
||||
result_files.append(result_file_docx)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 添加到下载区
|
||||
for file in result_files:
|
||||
promote_file_to_downloadzone(file, chatbot=self.chatbot)
|
||||
|
||||
self.chatbot.append(["处理完成", f"结果已保存至: {', '.join(result_files)}"])
|
||||
|
||||
|
||||
@CatchException
|
||||
def 批量文件询问(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||
history: List, system_prompt: str, user_request: str):
|
||||
"""主函数 - 优化版本"""
|
||||
# 初始化
|
||||
import glob
|
||||
import re
|
||||
from crazy_functions.rag_fns.rag_file_support import supports_format
|
||||
from toolbox import report_exception
|
||||
query = plugin_kwargs.get("advanced_arg")
|
||||
summarizer = BatchDocumentSummarizer(llm_kwargs, query, chatbot, history, system_prompt)
|
||||
chatbot.append(["函数插件功能", f"作者:lbykkkk,批量总结文件。支持格式: {', '.join(supports_format)}等其他文本格式文件,如果长时间卡在文件处理过程,请查看处理进度,然后删除所有处于“pending”状态的文件,然后重新上传处理。"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 验证输入路径
|
||||
if not os.path.exists(txt):
|
||||
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到项目或无权访问: {txt}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 获取文件列表
|
||||
project_folder = txt
|
||||
user_name = chatbot.get_user()
|
||||
validate_path_safety(project_folder, user_name)
|
||||
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
|
||||
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
|
||||
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
|
||||
file_manifest = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
|
||||
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
|
||||
|
||||
if not file_manifest:
|
||||
report_exception(chatbot, history, a=f"解析项目: {txt}", b="未找到支持的文件类型")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 处理所有文件并生成总结
|
||||
final_summary = yield from summarizer.process_files(project_folder, file_manifest)
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 保存结果
|
||||
summarizer.save_results(final_summary)
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
36
crazy_functions/Document_Conversation_Wrap.py
Normal file
36
crazy_functions/Document_Conversation_Wrap.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import random
|
||||
from toolbox import get_conf
|
||||
from crazy_functions.Document_Conversation import 批量文件询问
|
||||
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
|
||||
|
||||
|
||||
class Document_Conversation_Wrap(GptAcademicPluginTemplate):
|
||||
def __init__(self):
|
||||
"""
|
||||
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
|
||||
"""
|
||||
pass
|
||||
|
||||
def define_arg_selection_menu(self):
|
||||
"""
|
||||
定义插件的二级选项菜单
|
||||
|
||||
第一个参数,名称`main_input`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
|
||||
第二个参数,名称`advanced_arg`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
|
||||
第三个参数,名称`allow_cache`,参数`type`声明这是一个下拉菜单,下拉菜单上方显示`title`+`description`,下拉菜单的选项为`options`,`default_value`为下拉菜单默认值;
|
||||
|
||||
"""
|
||||
gui_definition = {
|
||||
"main_input":
|
||||
ArgProperty(title="已上传的文件", description="上传文件后自动填充", default_value="", type="string").model_dump_json(),
|
||||
"searxng_url":
|
||||
ArgProperty(title="对材料提问", description="提问", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
|
||||
}
|
||||
return gui_definition
|
||||
|
||||
def execute(txt, llm_kwargs, plugin_kwargs:dict, chatbot, history, system_prompt, user_request):
|
||||
"""
|
||||
执行插件
|
||||
"""
|
||||
yield from 批量文件询问(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||
|
||||
673
crazy_functions/Document_Optimize.py
Normal file
673
crazy_functions/Document_Optimize.py
Normal file
@@ -0,0 +1,673 @@
|
||||
import os
|
||||
import time
|
||||
import glob
|
||||
import re
|
||||
import threading
|
||||
from typing import Dict, List, Generator, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
|
||||
from crazy_functions.rag_fns.rag_file_support import extract_text, supports_format, convert_to_markdown
|
||||
from request_llms.bridge_all import model_info
|
||||
from toolbox import update_ui, CatchException, report_exception, promote_file_to_downloadzone, write_history_to_file
|
||||
from shared_utils.fastapi_server import validate_path_safety
|
||||
|
||||
# 新增:导入结构化论文提取器
|
||||
from crazy_functions.doc_fns.read_fns.unstructured_all.paper_structure_extractor import PaperStructureExtractor, ExtractorConfig, StructuredPaper
|
||||
|
||||
# 导入格式化器
|
||||
from crazy_functions.paper_fns.file2file_doc import (
|
||||
TxtFormatter,
|
||||
MarkdownFormatter,
|
||||
HtmlFormatter,
|
||||
WordFormatter
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class TextFragment:
|
||||
"""文本片段数据类,用于组织处理单元"""
|
||||
content: str
|
||||
fragment_index: int
|
||||
total_fragments: int
|
||||
|
||||
|
||||
class DocumentProcessor:
|
||||
"""文档处理器 - 处理单个文档并输出结果"""
|
||||
|
||||
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
|
||||
"""初始化处理器"""
|
||||
self.llm_kwargs = llm_kwargs
|
||||
self.plugin_kwargs = plugin_kwargs
|
||||
self.chatbot = chatbot
|
||||
self.history = history
|
||||
self.system_prompt = system_prompt
|
||||
self.processed_results = []
|
||||
self.failed_fragments = []
|
||||
# 新增:初始化论文结构提取器
|
||||
self.paper_extractor = PaperStructureExtractor()
|
||||
|
||||
def _get_token_limit(self) -> int:
|
||||
"""获取模型token限制,返回更小的值以确保更细粒度的分割"""
|
||||
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
|
||||
# 降低token限制,使每个片段更小
|
||||
return max_token // 4 # 从3/4降低到1/4
|
||||
|
||||
def _create_batch_inputs(self, fragments: List[TextFragment]) -> Tuple[List, List, List]:
|
||||
"""创建批处理输入"""
|
||||
inputs_array = []
|
||||
inputs_show_user_array = []
|
||||
history_array = []
|
||||
|
||||
user_instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下学术文本,提高其语言表达的准确性、专业性和流畅度,保持学术风格,确保逻辑连贯,但不改变原文的科学内容和核心观点")
|
||||
|
||||
for frag in fragments:
|
||||
i_say = (f'请按照以下要求处理文本内容:{user_instruction}\n\n'
|
||||
f'请将对文本的处理结果放在<decision>和</decision>标签之间。\n\n'
|
||||
f'文本内容:\n```\n{frag.content}\n```')
|
||||
|
||||
i_say_show_user = f'正在处理文本片段 {frag.fragment_index + 1}/{frag.total_fragments}'
|
||||
|
||||
inputs_array.append(i_say)
|
||||
inputs_show_user_array.append(i_say_show_user)
|
||||
history_array.append([])
|
||||
|
||||
return inputs_array, inputs_show_user_array, history_array
|
||||
|
||||
def _extract_decision(self, text: str) -> str:
|
||||
"""从LLM响应中提取<decision>标签内的内容"""
|
||||
import re
|
||||
pattern = r'<decision>(.*?)</decision>'
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
return matches[0].strip()
|
||||
else:
|
||||
# 如果没有找到标签,返回原始文本
|
||||
return text.strip()
|
||||
|
||||
def process_file(self, file_path: str) -> Generator:
|
||||
"""处理单个文件"""
|
||||
self.chatbot.append(["开始处理文件", f"文件路径: {file_path}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
try:
|
||||
# 首先尝试转换为Markdown
|
||||
from crazy_functions.rag_fns.rag_file_support import convert_to_markdown
|
||||
file_path = convert_to_markdown(file_path)
|
||||
|
||||
# 1. 检查文件是否为支持的论文格式
|
||||
is_paper_format = any(file_path.lower().endswith(ext) for ext in self.paper_extractor.SUPPORTED_EXTENSIONS)
|
||||
|
||||
if is_paper_format:
|
||||
# 使用结构化提取器处理论文
|
||||
return (yield from self._process_structured_paper(file_path))
|
||||
else:
|
||||
# 使用原有方式处理普通文档
|
||||
return (yield from self._process_regular_file(file_path))
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["处理错误", f"文件处理失败: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
|
||||
def _process_structured_paper(self, file_path: str) -> Generator:
|
||||
"""处理结构化论文文件"""
|
||||
# 1. 提取论文结构
|
||||
self.chatbot[-1] = ["正在分析论文结构", f"文件路径: {file_path}"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
try:
|
||||
paper = self.paper_extractor.extract_paper_structure(file_path)
|
||||
|
||||
if not paper or not paper.sections:
|
||||
self.chatbot.append(["无法提取论文结构", "将使用全文内容进行处理"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 使用全文内容进行段落切分
|
||||
if paper and paper.full_text:
|
||||
# 使用增强的分割函数进行更细致的分割
|
||||
fragments = self._breakdown_section_content(paper.full_text)
|
||||
|
||||
# 创建文本片段对象
|
||||
text_fragments = []
|
||||
for i, frag in enumerate(fragments):
|
||||
if frag.strip():
|
||||
text_fragments.append(TextFragment(
|
||||
content=frag,
|
||||
fragment_index=i,
|
||||
total_fragments=len(fragments)
|
||||
))
|
||||
|
||||
# 批量处理片段
|
||||
if text_fragments:
|
||||
self.chatbot[-1] = ["开始处理文本", f"共 {len(text_fragments)} 个片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 一次性准备所有输入
|
||||
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(text_fragments)
|
||||
|
||||
# 使用系统提示
|
||||
instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下学术文本,提高其语言表达的准确性、专业性和流畅度,保持学术风格,确保逻辑连贯,但不改变原文的科学内容和核心观点")
|
||||
sys_prompt_array = [f"你是一个专业的学术文献编辑助手。请按照用户的要求:'{instruction}'处理文本。保持学术风格,增强表达的准确性和专业性。"] * len(text_fragments)
|
||||
|
||||
# 调用LLM一次性处理所有片段
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=inputs_array,
|
||||
inputs_show_user_array=inputs_show_user_array,
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=history_array,
|
||||
sys_prompt_array=sys_prompt_array,
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
for j, frag in enumerate(text_fragments):
|
||||
try:
|
||||
llm_response = response_collection[j * 2 + 1]
|
||||
processed_text = self._extract_decision(llm_response)
|
||||
|
||||
if processed_text and processed_text.strip():
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': processed_text
|
||||
})
|
||||
else:
|
||||
self.failed_fragments.append(frag)
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': frag.content
|
||||
})
|
||||
except Exception as e:
|
||||
self.failed_fragments.append(frag)
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': frag.content
|
||||
})
|
||||
|
||||
# 按原始顺序合并结果
|
||||
self.processed_results.sort(key=lambda x: x['index'])
|
||||
final_content = "\n".join([item['content'] for item in self.processed_results])
|
||||
|
||||
# 更新UI
|
||||
success_count = len(text_fragments) - len(self.failed_fragments)
|
||||
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{len(text_fragments)} 个片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
return final_content
|
||||
else:
|
||||
self.chatbot.append(["处理失败", "未能提取到有效的文本内容"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
else:
|
||||
self.chatbot.append(["处理失败", "未能提取到论文内容"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
|
||||
# 2. 准备处理章节内容(不处理标题)
|
||||
self.chatbot[-1] = ["已提取论文结构", f"共 {len(paper.sections)} 个主要章节"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 3. 收集所有需要处理的章节内容并分割为合适大小
|
||||
sections_to_process = []
|
||||
section_map = {} # 用于映射处理前后的内容
|
||||
|
||||
def collect_section_contents(sections, parent_path=""):
|
||||
"""递归收集章节内容,跳过参考文献部分"""
|
||||
for i, section in enumerate(sections):
|
||||
current_path = f"{parent_path}/{i}" if parent_path else f"{i}"
|
||||
|
||||
# 检查是否为参考文献部分,如果是则跳过
|
||||
if section.section_type == 'references' or section.title.lower() in ['references', '参考文献', 'bibliography', '文献']:
|
||||
continue # 跳过参考文献部分
|
||||
|
||||
# 只处理内容非空的章节
|
||||
if section.content and section.content.strip():
|
||||
# 使用增强的分割函数进行更细致的分割
|
||||
fragments = self._breakdown_section_content(section.content)
|
||||
|
||||
for fragment_idx, fragment_content in enumerate(fragments):
|
||||
if fragment_content.strip():
|
||||
fragment_index = len(sections_to_process)
|
||||
sections_to_process.append(TextFragment(
|
||||
content=fragment_content,
|
||||
fragment_index=fragment_index,
|
||||
total_fragments=0 # 临时值,稍后更新
|
||||
))
|
||||
|
||||
# 保存映射关系,用于稍后更新章节内容
|
||||
# 为每个片段存储原始章节和片段索引信息
|
||||
section_map[fragment_index] = (current_path, section, fragment_idx, len(fragments))
|
||||
|
||||
# 递归处理子章节
|
||||
if section.subsections:
|
||||
collect_section_contents(section.subsections, current_path)
|
||||
|
||||
# 收集所有章节内容
|
||||
collect_section_contents(paper.sections)
|
||||
|
||||
# 更新总片段数
|
||||
total_fragments = len(sections_to_process)
|
||||
for frag in sections_to_process:
|
||||
frag.total_fragments = total_fragments
|
||||
|
||||
# 4. 如果没有内容需要处理,直接返回
|
||||
if not sections_to_process:
|
||||
self.chatbot.append(["处理完成", "未找到需要处理的内容"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
|
||||
# 5. 批量处理章节内容
|
||||
self.chatbot[-1] = ["开始处理论文内容", f"共 {len(sections_to_process)} 个内容片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 一次性准备所有输入
|
||||
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(sections_to_process)
|
||||
|
||||
# 使用系统提示
|
||||
instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下学术文本,提高其语言表达的准确性、专业性和流畅度,保持学术风格,确保逻辑连贯,但不改变原文的科学内容和核心观点")
|
||||
sys_prompt_array = [f"你是一个专业的学术文献编辑助手。请按照用户的要求:'{instruction}'处理文本。保持学术风格,增强表达的准确性和专业性。"] * len(sections_to_process)
|
||||
|
||||
# 调用LLM一次性处理所有片段
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=inputs_array,
|
||||
inputs_show_user_array=inputs_show_user_array,
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=history_array,
|
||||
sys_prompt_array=sys_prompt_array,
|
||||
)
|
||||
|
||||
# 处理响应,重组章节内容
|
||||
section_contents = {} # 用于重组各章节的处理后内容
|
||||
|
||||
for j, frag in enumerate(sections_to_process):
|
||||
try:
|
||||
llm_response = response_collection[j * 2 + 1]
|
||||
processed_text = self._extract_decision(llm_response)
|
||||
|
||||
if processed_text and processed_text.strip():
|
||||
# 保存处理结果
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': processed_text
|
||||
})
|
||||
|
||||
# 存储处理后的文本片段,用于后续重组
|
||||
fragment_index = frag.fragment_index
|
||||
if fragment_index in section_map:
|
||||
path, section, fragment_idx, total_fragments = section_map[fragment_index]
|
||||
|
||||
# 初始化此章节的内容容器(如果尚未创建)
|
||||
if path not in section_contents:
|
||||
section_contents[path] = [""] * total_fragments
|
||||
|
||||
# 将处理后的片段放入正确位置
|
||||
section_contents[path][fragment_idx] = processed_text
|
||||
else:
|
||||
self.failed_fragments.append(frag)
|
||||
except Exception as e:
|
||||
self.failed_fragments.append(frag)
|
||||
|
||||
# 重组每个章节的内容
|
||||
for path, fragments in section_contents.items():
|
||||
section = None
|
||||
for idx in section_map:
|
||||
if section_map[idx][0] == path:
|
||||
section = section_map[idx][1]
|
||||
break
|
||||
|
||||
if section:
|
||||
# 合并该章节的所有处理后片段
|
||||
section.content = "\n".join(fragments)
|
||||
|
||||
# 6. 更新UI
|
||||
success_count = total_fragments - len(self.failed_fragments)
|
||||
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{total_fragments} 个内容片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 收集参考文献部分(不进行处理)
|
||||
references_sections = []
|
||||
def collect_references(sections, parent_path=""):
|
||||
"""递归收集参考文献部分"""
|
||||
for i, section in enumerate(sections):
|
||||
current_path = f"{parent_path}/{i}" if parent_path else f"{i}"
|
||||
|
||||
# 检查是否为参考文献部分
|
||||
if section.section_type == 'references' or section.title.lower() in ['references', '参考文献', 'bibliography', '文献']:
|
||||
references_sections.append((current_path, section))
|
||||
|
||||
# 递归检查子章节
|
||||
if section.subsections:
|
||||
collect_references(section.subsections, current_path)
|
||||
|
||||
# 收集参考文献
|
||||
collect_references(paper.sections)
|
||||
|
||||
# 7. 将处理后的结构化论文转换为Markdown
|
||||
markdown_content = self.paper_extractor.generate_markdown(paper)
|
||||
|
||||
# 8. 返回处理后的内容
|
||||
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{total_fragments} 个内容片段,参考文献部分未处理"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
return markdown_content
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["结构化处理失败", f"错误: {str(e)},将尝试作为普通文件处理"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return (yield from self._process_regular_file(file_path))
|
||||
|
||||
def _process_regular_file(self, file_path: str) -> Generator:
|
||||
"""使用原有方式处理普通文件"""
|
||||
# 原有的文件处理逻辑
|
||||
self.chatbot[-1] = ["正在读取文件", f"文件路径: {file_path}"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
content = extract_text(file_path)
|
||||
if not content or not content.strip():
|
||||
self.chatbot.append(["处理失败", "文件内容为空或无法提取内容"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
|
||||
# 2. 分割文本
|
||||
self.chatbot[-1] = ["正在分析文件", "将文件内容分割为适当大小的片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 使用增强的分割函数
|
||||
fragments = self._breakdown_section_content(content)
|
||||
|
||||
# 3. 创建文本片段对象
|
||||
text_fragments = []
|
||||
for i, frag in enumerate(fragments):
|
||||
if frag.strip():
|
||||
text_fragments.append(TextFragment(
|
||||
content=frag,
|
||||
fragment_index=i,
|
||||
total_fragments=len(fragments)
|
||||
))
|
||||
|
||||
# 4. 处理所有片段
|
||||
self.chatbot[-1] = ["开始处理文本", f"共 {len(text_fragments)} 个片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 批量处理片段
|
||||
batch_size = 8 # 每批处理的片段数
|
||||
for i in range(0, len(text_fragments), batch_size):
|
||||
batch = text_fragments[i:i + batch_size]
|
||||
|
||||
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch)
|
||||
|
||||
# 使用系统提示
|
||||
instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下文本")
|
||||
sys_prompt_array = [f"你是一个专业的文本处理助手。请按照用户的要求:'{instruction}'处理文本。"] * len(batch)
|
||||
|
||||
# 调用LLM处理
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=inputs_array,
|
||||
inputs_show_user_array=inputs_show_user_array,
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=history_array,
|
||||
sys_prompt_array=sys_prompt_array,
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
for j, frag in enumerate(batch):
|
||||
try:
|
||||
llm_response = response_collection[j * 2 + 1]
|
||||
processed_text = self._extract_decision(llm_response)
|
||||
|
||||
if processed_text and processed_text.strip():
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': processed_text
|
||||
})
|
||||
else:
|
||||
self.failed_fragments.append(frag)
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': frag.content # 如果处理失败,使用原始内容
|
||||
})
|
||||
except Exception as e:
|
||||
self.failed_fragments.append(frag)
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': frag.content # 如果处理失败,使用原始内容
|
||||
})
|
||||
|
||||
# 5. 按原始顺序合并结果
|
||||
self.processed_results.sort(key=lambda x: x['index'])
|
||||
final_content = "\n".join([item['content'] for item in self.processed_results])
|
||||
|
||||
# 6. 更新UI
|
||||
success_count = len(text_fragments) - len(self.failed_fragments)
|
||||
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{len(text_fragments)} 个片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
return final_content
|
||||
|
||||
def save_results(self, content: str, original_file_path: str) -> List[str]:
|
||||
"""保存处理结果为多种格式"""
|
||||
if not content:
|
||||
return []
|
||||
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
original_filename = os.path.basename(original_file_path)
|
||||
filename_without_ext = os.path.splitext(original_filename)[0]
|
||||
base_filename = f"{filename_without_ext}_processed_{timestamp}"
|
||||
|
||||
result_files = []
|
||||
|
||||
# 获取用户指定的处理类型
|
||||
processing_type = self.plugin_kwargs.get("advanced_arg", "文本处理")
|
||||
|
||||
# 1. 保存为TXT
|
||||
try:
|
||||
txt_formatter = TxtFormatter()
|
||||
txt_content = txt_formatter.create_document(content)
|
||||
txt_file = write_history_to_file(
|
||||
history=[txt_content],
|
||||
file_basename=f"{base_filename}.txt"
|
||||
)
|
||||
result_files.append(txt_file)
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"TXT格式保存失败: {str(e)}"])
|
||||
|
||||
# 2. 保存为Markdown
|
||||
try:
|
||||
md_formatter = MarkdownFormatter()
|
||||
md_content = md_formatter.create_document(content, processing_type)
|
||||
md_file = write_history_to_file(
|
||||
history=[md_content],
|
||||
file_basename=f"{base_filename}.md"
|
||||
)
|
||||
result_files.append(md_file)
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"Markdown格式保存失败: {str(e)}"])
|
||||
|
||||
# 3. 保存为HTML
|
||||
try:
|
||||
html_formatter = HtmlFormatter(processing_type=processing_type)
|
||||
html_content = html_formatter.create_document(content)
|
||||
html_file = write_history_to_file(
|
||||
history=[html_content],
|
||||
file_basename=f"{base_filename}.html"
|
||||
)
|
||||
result_files.append(html_file)
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"HTML格式保存失败: {str(e)}"])
|
||||
|
||||
# 4. 保存为Word
|
||||
try:
|
||||
word_formatter = WordFormatter()
|
||||
doc = word_formatter.create_document(content, processing_type)
|
||||
|
||||
# 获取保存路径
|
||||
from toolbox import get_log_folder
|
||||
word_path = os.path.join(get_log_folder(), f"{base_filename}.docx")
|
||||
doc.save(word_path)
|
||||
|
||||
# 5. 保存为PDF(通过Word转换)
|
||||
try:
|
||||
from crazy_functions.paper_fns.file2file_doc.word2pdf import WordToPdfConverter
|
||||
pdf_path = WordToPdfConverter.convert_to_pdf(word_path)
|
||||
result_files.append(pdf_path)
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"PDF格式保存失败: {str(e)}"])
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"Word格式保存失败: {str(e)}"])
|
||||
|
||||
# 添加到下载区
|
||||
for file in result_files:
|
||||
promote_file_to_downloadzone(file, chatbot=self.chatbot)
|
||||
|
||||
return result_files
|
||||
|
||||
def _breakdown_section_content(self, content: str) -> List[str]:
|
||||
"""对文本内容进行分割与合并
|
||||
|
||||
主要按段落进行组织,只合并较小的段落以减少片段数量
|
||||
保留原始段落结构,不对长段落进行强制分割
|
||||
针对中英文设置不同的阈值,因为字符密度不同
|
||||
"""
|
||||
# 先按段落分割文本
|
||||
paragraphs = content.split('\n\n')
|
||||
|
||||
# 检测语言类型
|
||||
chinese_char_count = sum(1 for char in content if '\u4e00' <= char <= '\u9fff')
|
||||
is_chinese_text = chinese_char_count / max(1, len(content)) > 0.3
|
||||
|
||||
# 根据语言类型设置不同的阈值(只用于合并小段落)
|
||||
if is_chinese_text:
|
||||
# 中文文本:一个汉字就是一个字符,信息密度高
|
||||
min_chunk_size = 300 # 段落合并的最小阈值
|
||||
target_size = 800 # 理想的段落大小
|
||||
else:
|
||||
# 英文文本:一个单词由多个字符组成,信息密度低
|
||||
min_chunk_size = 600 # 段落合并的最小阈值
|
||||
target_size = 1600 # 理想的段落大小
|
||||
|
||||
# 1. 只合并小段落,不对长段落进行分割
|
||||
result_fragments = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
for para in paragraphs:
|
||||
# 如果段落太小且不会超过目标大小,则合并
|
||||
if len(para) < min_chunk_size and current_length + len(para) <= target_size:
|
||||
current_chunk.append(para)
|
||||
current_length += len(para)
|
||||
# 否则,创建新段落
|
||||
else:
|
||||
# 如果当前块非空且与当前段落无关,先保存它
|
||||
if current_chunk and current_length > 0:
|
||||
result_fragments.append('\n\n'.join(current_chunk))
|
||||
|
||||
# 当前段落作为新块
|
||||
current_chunk = [para]
|
||||
current_length = len(para)
|
||||
|
||||
# 如果当前块大小已接近目标大小,保存并开始新块
|
||||
if current_length >= target_size:
|
||||
result_fragments.append('\n\n'.join(current_chunk))
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
# 保存最后一个块
|
||||
if current_chunk:
|
||||
result_fragments.append('\n\n'.join(current_chunk))
|
||||
|
||||
# 2. 处理可能过大的片段(确保不超过token限制)
|
||||
final_fragments = []
|
||||
max_token = self._get_token_limit()
|
||||
|
||||
for fragment in result_fragments:
|
||||
# 检查fragment是否可能超出token限制
|
||||
# 根据语言类型调整token估算
|
||||
if is_chinese_text:
|
||||
estimated_tokens = len(fragment) / 1.5 # 中文每个token约1-2个字符
|
||||
else:
|
||||
estimated_tokens = len(fragment) / 4 # 英文每个token约4个字符
|
||||
|
||||
if estimated_tokens > max_token:
|
||||
# 即使可能超出限制,也尽量保持段落的完整性
|
||||
# 使用breakdown_text但设置更大的限制来减少分割
|
||||
larger_limit = max_token * 0.95 # 使用95%的限制
|
||||
sub_fragments = breakdown_text_to_satisfy_token_limit(
|
||||
txt=fragment,
|
||||
limit=larger_limit,
|
||||
llm_model=self.llm_kwargs['llm_model']
|
||||
)
|
||||
final_fragments.extend(sub_fragments)
|
||||
else:
|
||||
final_fragments.append(fragment)
|
||||
|
||||
return final_fragments
|
||||
|
||||
|
||||
@CatchException
|
||||
def 自定义智能文档处理(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||
history: List, system_prompt: str, user_request: str):
|
||||
"""主函数 - 文件到文件处理"""
|
||||
# 初始化
|
||||
processor = DocumentProcessor(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||
chatbot.append(["函数插件功能", "文件内容处理:将文档内容按照指定要求处理后输出为多种格式"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 验证输入路径
|
||||
if not os.path.exists(txt):
|
||||
report_exception(chatbot, history, a=f"解析路径: {txt}", b=f"找不到路径或无权访问: {txt}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 验证路径安全性
|
||||
user_name = chatbot.get_user()
|
||||
validate_path_safety(txt, user_name)
|
||||
|
||||
# 获取文件列表
|
||||
if os.path.isfile(txt):
|
||||
# 单个文件处理
|
||||
file_paths = [txt]
|
||||
else:
|
||||
# 目录处理 - 类似批量文件询问插件
|
||||
project_folder = txt
|
||||
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
|
||||
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
|
||||
|
||||
# 排除压缩文件
|
||||
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
|
||||
file_paths = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
|
||||
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
|
||||
|
||||
# 过滤支持的文件格式
|
||||
file_paths = [f for f in file_paths if any(f.lower().endswith(ext) for ext in
|
||||
list(processor.paper_extractor.SUPPORTED_EXTENSIONS) + ['.json', '.csv', '.xlsx', '.xls'])]
|
||||
|
||||
if not file_paths:
|
||||
report_exception(chatbot, history, a=f"解析路径: {txt}", b="未找到支持的文件类型")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 处理文件
|
||||
if len(file_paths) > 1:
|
||||
chatbot.append(["发现多个文件", f"共找到 {len(file_paths)} 个文件,将处理第一个文件"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 只处理第一个文件
|
||||
file_to_process = file_paths[0]
|
||||
processed_content = yield from processor.process_file(file_to_process)
|
||||
|
||||
if processed_content:
|
||||
# 保存结果
|
||||
result_files = processor.save_results(processed_content, file_to_process)
|
||||
|
||||
if result_files:
|
||||
chatbot.append(["处理完成", f"已生成 {len(result_files)} 个结果文件"])
|
||||
else:
|
||||
chatbot.append(["处理完成", "但未能保存任何结果文件"])
|
||||
else:
|
||||
chatbot.append(["处理失败", "未能生成有效的处理结果"])
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
@@ -139,7 +139,7 @@ def get_recent_file_prompt_support(chatbot):
|
||||
return path
|
||||
|
||||
@CatchException
|
||||
def 函数动态生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def Dynamic_Function_Generate(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
"""
|
||||
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数,如温度和top_p等,一般原样传递下去就行
|
||||
@@ -159,7 +159,7 @@ def 函数动态生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_
|
||||
|
||||
# ⭐ 文件上传区是否有东西
|
||||
# 1. 如果有文件: 作为函数参数
|
||||
# 2. 如果没有文件:需要用GPT提取参数 (太懒了,以后再写,虚空终端已经实现了类似的代码)
|
||||
# 2. 如果没有文件:需要用GPT提取参数 (太懒了,以后再写,Void_Terminal已经实现了类似的代码)
|
||||
file_list = []
|
||||
if get_plugin_arg(plugin_kwargs, key="file_path_arg", default=False):
|
||||
file_path = get_plugin_arg(plugin_kwargs, key="file_path_arg", default=None)
|
||||
@@ -132,7 +132,7 @@ def get_meta_information(url, chatbot, history):
|
||||
return profile
|
||||
|
||||
@CatchException
|
||||
def 谷歌检索小助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def Google_Scholar_Assistant_Legacy(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
disable_auto_promotion(chatbot=chatbot)
|
||||
# 基本信息:功能、贡献者
|
||||
chatbot.append([
|
||||
@@ -13,13 +13,13 @@ def 交互功能模板函数(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
|
||||
user_request 当前用户的请求信息(IP地址等)
|
||||
"""
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
chatbot.append(("这是什么功能?", "交互功能函数模板。在执行完成之后, 可以将自身的状态存储到cookie中, 等待用户的再次调用。"))
|
||||
chatbot.append(("这是什么功能?", "Interactive_Func_Template。在执行完成之后, 可以将自身的状态存储到cookie中, 等待用户的再次调用。"))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
state = chatbot._cookies.get('plugin_state_0001', None) # 初始化插件状态
|
||||
|
||||
if state is None:
|
||||
chatbot._cookies['lock_plugin'] = 'crazy_functions.交互功能函数模板->交互功能模板函数' # 赋予插件锁定 锁定插件回调路径,当下一次用户提交时,会直接转到该函数
|
||||
chatbot._cookies['lock_plugin'] = 'crazy_functions.Interactive_Func_Template->交互功能模板函数' # 赋予插件锁定 锁定插件回调路径,当下一次用户提交时,会直接转到该函数
|
||||
chatbot._cookies['plugin_state_0001'] = 'wait_user_keyword' # 赋予插件状态
|
||||
|
||||
chatbot.append(("第一次调用:", "请输入关键词, 我将为您查找相关壁纸, 建议使用英文单词, 插件锁定中,请直接提交即可。"))
|
||||
@@ -16,7 +16,7 @@ def 随机小游戏(prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_
|
||||
llm_kwargs,
|
||||
cls,
|
||||
plugin_name='MiniGame_ResumeStory',
|
||||
callback_fn='crazy_functions.互动小游戏->随机小游戏',
|
||||
callback_fn='crazy_functions.Interactive_Mini_Game->随机小游戏',
|
||||
lock_plugin=True
|
||||
)
|
||||
yield from state.continue_game(prompt, chatbot, history)
|
||||
@@ -34,7 +34,7 @@ def 随机小游戏1(prompt, llm_kwargs, plugin_kwargs, chatbot, history, system
|
||||
llm_kwargs,
|
||||
cls,
|
||||
plugin_name='MiniGame_ASCII_Art',
|
||||
callback_fn='crazy_functions.互动小游戏->随机小游戏1',
|
||||
callback_fn='crazy_functions.Interactive_Mini_Game->随机小游戏1',
|
||||
lock_plugin=True
|
||||
)
|
||||
yield from state.continue_game(prompt, chatbot, history)
|
||||
@@ -297,7 +297,7 @@ def 解析历史输入(history, llm_kwargs, file_manifest, chatbot, plugin_kwarg
|
||||
|
||||
|
||||
@CatchException
|
||||
def 生成多种Mermaid图表(
|
||||
def Mermaid_Figure_Gen(
|
||||
txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port
|
||||
):
|
||||
"""
|
||||
@@ -426,7 +426,7 @@ class Mermaid_Gen(GptAcademicPluginTemplate):
|
||||
"思维导图",
|
||||
]
|
||||
plugin_kwargs = options.index(plugin_kwargs['Type_of_Mermaid'])
|
||||
yield from 生成多种Mermaid图表(
|
||||
yield from Mermaid_Figure_Gen(
|
||||
txt,
|
||||
llm_kwargs,
|
||||
plugin_kwargs,
|
||||
@@ -22,7 +22,7 @@ def remove_model_prefix(llm):
|
||||
|
||||
|
||||
@CatchException
|
||||
def 多智能体终端(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def Multi_Agent_Legacy终端(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
"""
|
||||
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
|
||||
llm_kwargs gpt模型参数,如温度和top_p等,一般原样传递下去就行
|
||||
@@ -78,17 +78,17 @@ def 多智能体终端(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_
|
||||
chatbot.get_cookies()['lock_plugin'] = None
|
||||
persistent_class_multi_user_manager = GradioMultiuserManagerForPersistentClasses()
|
||||
user_uuid = chatbot.get_cookies().get('uuid')
|
||||
persistent_key = f"{user_uuid}->多智能体终端"
|
||||
persistent_key = f"{user_uuid}->Multi_Agent_Legacy终端"
|
||||
if persistent_class_multi_user_manager.already_alive(persistent_key):
|
||||
# 当已经存在一个正在运行的多智能体终端时,直接将用户输入传递给它,而不是再次启动一个新的多智能体终端
|
||||
# 当已经存在一个正在运行的Multi_Agent_Legacy终端时,直接将用户输入传递给它,而不是再次启动一个新的Multi_Agent_Legacy终端
|
||||
logger.info('[debug] feed new user input')
|
||||
executor = persistent_class_multi_user_manager.get(persistent_key)
|
||||
exit_reason = yield from executor.main_process_ui_control(txt, create_or_resume="resume")
|
||||
else:
|
||||
# 运行多智能体终端 (首次)
|
||||
# 运行Multi_Agent_Legacy终端 (首次)
|
||||
logger.info('[debug] create new executor instance')
|
||||
history = []
|
||||
chatbot.append(["正在启动: 多智能体终端", "插件动态生成, 执行开始, 作者 Microsoft & Binary-Husky."])
|
||||
chatbot.append(["正在启动: Multi_Agent_Legacy终端", "插件动态生成, 执行开始, 作者 Microsoft & Binary-Husky."])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
executor = AutoGenMath(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||
persistent_class_multi_user_manager.set(persistent_key, executor)
|
||||
@@ -96,7 +96,7 @@ def 多智能体终端(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_
|
||||
|
||||
if exit_reason == "wait_feedback":
|
||||
# 当用户点击了“等待反馈”按钮时,将executor存储到cookie中,等待用户的再次调用
|
||||
executor.chatbot.get_cookies()['lock_plugin'] = 'crazy_functions.多智能体->多智能体终端'
|
||||
executor.chatbot.get_cookies()['lock_plugin'] = 'crazy_functions.Multi_Agent_Legacy->Multi_Agent_Legacy终端'
|
||||
else:
|
||||
executor.chatbot.get_cookies()['lock_plugin'] = None
|
||||
yield from update_ui(chatbot=executor.chatbot, history=executor.history) # 更新状态
|
||||
@@ -62,7 +62,7 @@ def 解析PDF(file_name, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
|
||||
|
||||
|
||||
@CatchException
|
||||
def 理解PDF文档内容标准文件输入(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def PDF_QA标准文件输入(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
import glob, os
|
||||
|
||||
# 基本信息:功能、贡献者
|
||||
@@ -103,13 +103,13 @@ do not have too much repetitive information, numerical values using the original
|
||||
|
||||
|
||||
@CatchException
|
||||
def 批量总结PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def PDF_Summary(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
import glob, os
|
||||
|
||||
# 基本信息:功能、贡献者
|
||||
chatbot.append([
|
||||
"函数插件功能?",
|
||||
"批量总结PDF文档。函数插件贡献者: ValeriaWong,Eralien"])
|
||||
"PDF_Summary。函数插件贡献者: ValeriaWong,Eralien"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||
@@ -43,7 +43,7 @@ def 解析Paper(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbo
|
||||
|
||||
|
||||
@CatchException
|
||||
def 读文章写摘要(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def Paper_Abstract_Writer(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
import glob, os
|
||||
if os.path.exists(txt):
|
||||
360
crazy_functions/Paper_Reading.py
Normal file
360
crazy_functions/Paper_Reading.py
Normal file
@@ -0,0 +1,360 @@
|
||||
import os
|
||||
import time
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Generator, Tuple
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from toolbox import update_ui, promote_file_to_downloadzone, write_history_to_file, CatchException, report_exception
|
||||
from shared_utils.fastapi_server import validate_path_safety
|
||||
from crazy_functions.paper_fns.paper_download import extract_paper_id, extract_paper_ids, get_arxiv_paper, format_arxiv_id
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaperQuestion:
|
||||
"""论文分析问题类"""
|
||||
id: str # 问题ID
|
||||
question: str # 问题内容
|
||||
importance: int # 重要性 (1-5,5最高)
|
||||
description: str # 问题描述
|
||||
|
||||
|
||||
class PaperAnalyzer:
|
||||
"""论文快速分析器"""
|
||||
|
||||
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
|
||||
"""初始化分析器"""
|
||||
self.llm_kwargs = llm_kwargs
|
||||
self.plugin_kwargs = plugin_kwargs
|
||||
self.chatbot = chatbot
|
||||
self.history = history
|
||||
self.system_prompt = system_prompt
|
||||
self.paper_content = ""
|
||||
self.results = {}
|
||||
|
||||
# 定义论文分析问题库(已合并为4个核心问题)
|
||||
self.questions = [
|
||||
PaperQuestion(
|
||||
id="research_and_methods",
|
||||
question="这篇论文的主要研究问题、目标和方法是什么?请分析:1)论文的核心研究问题和研究动机;2)论文提出的关键方法、模型或理论框架;3)这些方法如何解决研究问题。",
|
||||
importance=5,
|
||||
description="研究问题与方法"
|
||||
),
|
||||
PaperQuestion(
|
||||
id="findings_and_innovation",
|
||||
question="论文的主要发现、结论及创新点是什么?请分析:1)论文的核心结果与主要发现;2)作者得出的关键结论;3)研究的创新点与对领域的贡献;4)与已有工作的区别。",
|
||||
importance=4,
|
||||
description="研究发现与创新"
|
||||
),
|
||||
PaperQuestion(
|
||||
id="methodology_and_data",
|
||||
question="论文使用了什么研究方法和数据?请详细分析:1)研究设计与实验设置;2)数据收集方法与数据集特点;3)分析技术与评估方法;4)方法学上的合理性。",
|
||||
importance=3,
|
||||
description="研究方法与数据"
|
||||
),
|
||||
PaperQuestion(
|
||||
id="limitations_and_impact",
|
||||
question="论文的局限性、未来方向及潜在影响是什么?请分析:1)研究的不足与限制因素;2)作者提出的未来研究方向;3)该研究对学术界和行业可能产生的影响;4)研究结果的适用范围与推广价值。",
|
||||
importance=2,
|
||||
description="局限性与影响"
|
||||
),
|
||||
]
|
||||
|
||||
# 按重要性排序
|
||||
self.questions.sort(key=lambda q: q.importance, reverse=True)
|
||||
|
||||
def _load_paper(self, paper_path: str) -> Generator:
|
||||
from crazy_functions.doc_fns.text_content_loader import TextContentLoader
|
||||
"""加载论文内容"""
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 使用TextContentLoader读取文件
|
||||
loader = TextContentLoader(self.chatbot, self.history)
|
||||
|
||||
yield from loader.execute_single_file(paper_path)
|
||||
|
||||
# 获取加载的内容
|
||||
if len(self.history) >= 2 and self.history[-2]:
|
||||
self.paper_content = self.history[-2]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return True
|
||||
else:
|
||||
self.chatbot.append(["错误", "无法读取论文内容,请检查文件是否有效"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return False
|
||||
|
||||
def _analyze_question(self, question: PaperQuestion) -> Generator:
|
||||
"""分析单个问题 - 直接显示问题和答案"""
|
||||
try:
|
||||
# 创建分析提示
|
||||
prompt = f"请基于以下论文内容回答问题:\n\n{self.paper_content}\n\n问题:{question.question}"
|
||||
|
||||
# 使用单线程版本的请求函数
|
||||
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=prompt,
|
||||
inputs_show_user=question.question, # 显示问题本身
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history=[], # 空历史,确保每个问题独立分析
|
||||
sys_prompt="你是一个专业的科研论文分析助手,需要仔细阅读论文内容并回答问题。请保持客观、准确,并基于论文内容提供深入分析。"
|
||||
)
|
||||
|
||||
if response:
|
||||
self.results[question.id] = response
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["错误", f"分析问题时出错: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return False
|
||||
|
||||
def _generate_summary(self) -> Generator:
|
||||
"""生成最终总结报告"""
|
||||
self.chatbot.append(["生成报告", "正在整合分析结果,生成最终报告..."])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
summary_prompt = "请基于以下对论文的各个方面的分析,生成一份全面的论文解读报告。报告应该简明扼要地呈现论文的关键内容,并保持逻辑连贯性。"
|
||||
|
||||
for q in self.questions:
|
||||
if q.id in self.results:
|
||||
summary_prompt += f"\n\n关于{q.description}的分析:\n{self.results[q.id]}"
|
||||
|
||||
try:
|
||||
# 使用单线程版本的请求函数,可以在前端实时显示生成结果
|
||||
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=summary_prompt,
|
||||
inputs_show_user="生成论文解读报告",
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history=[],
|
||||
sys_prompt="你是一个科研论文解读专家,请将多个方面的分析整合为一份完整、连贯、有条理的报告。报告应当重点突出,层次分明,并且保持学术性和客观性。"
|
||||
)
|
||||
|
||||
if response:
|
||||
return response
|
||||
return "报告生成失败"
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["错误", f"生成报告时出错: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return "报告生成失败: " + str(e)
|
||||
|
||||
def save_report(self, report: str) -> Generator:
|
||||
"""保存分析报告"""
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 保存为Markdown文件
|
||||
try:
|
||||
md_content = f"# 论文快速解读报告\n\n{report}"
|
||||
for q in self.questions:
|
||||
if q.id in self.results:
|
||||
md_content += f"\n\n## {q.description}\n\n{self.results[q.id]}"
|
||||
|
||||
result_file = write_history_to_file(
|
||||
history=[md_content],
|
||||
file_basename=f"论文解读_{timestamp}.md"
|
||||
)
|
||||
|
||||
if result_file and os.path.exists(result_file):
|
||||
promote_file_to_downloadzone(result_file, chatbot=self.chatbot)
|
||||
self.chatbot.append(["保存成功", f"解读报告已保存至: {os.path.basename(result_file)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
else:
|
||||
self.chatbot.append(["警告", "保存报告成功但找不到文件"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"保存报告失败: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
def analyze_paper(self, paper_path: str) -> Generator:
|
||||
"""分析论文主流程"""
|
||||
# 加载论文
|
||||
success = yield from self._load_paper(paper_path)
|
||||
if not success:
|
||||
return
|
||||
|
||||
# 分析关键问题 - 直接询问每个问题,不显示进度信息
|
||||
for question in self.questions:
|
||||
yield from self._analyze_question(question)
|
||||
|
||||
# 生成总结报告
|
||||
final_report = yield from self._generate_summary()
|
||||
|
||||
# 显示最终报告
|
||||
# self.chatbot.append(["论文解读报告", final_report])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 保存报告
|
||||
yield from self.save_report(final_report)
|
||||
|
||||
|
||||
def _find_paper_file(path: str) -> str:
|
||||
"""查找路径中的论文文件(简化版)"""
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
|
||||
# 支持的文件扩展名(按优先级排序)
|
||||
extensions = ["pdf", "docx", "doc", "txt", "md", "tex"]
|
||||
|
||||
# 简单地遍历目录
|
||||
if os.path.isdir(path):
|
||||
try:
|
||||
for ext in extensions:
|
||||
# 手动检查每个可能的文件,而不使用glob
|
||||
potential_file = os.path.join(path, f"paper.{ext}")
|
||||
if os.path.exists(potential_file) and os.path.isfile(potential_file):
|
||||
return potential_file
|
||||
|
||||
# 如果没找到特定命名的文件,检查目录中的所有文件
|
||||
for file in os.listdir(path):
|
||||
file_path = os.path.join(path, file)
|
||||
if os.path.isfile(file_path):
|
||||
file_ext = file.split('.')[-1].lower() if '.' in file else ""
|
||||
if file_ext in extensions:
|
||||
return file_path
|
||||
except Exception:
|
||||
pass # 忽略任何错误
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def download_paper_by_id(paper_info, chatbot, history) -> str:
|
||||
"""下载论文并返回保存路径
|
||||
|
||||
Args:
|
||||
paper_info: 元组,包含论文ID类型(arxiv或doi)和ID值
|
||||
chatbot: 聊天机器人对象
|
||||
history: 历史记录
|
||||
|
||||
Returns:
|
||||
str: 下载的论文路径或None
|
||||
"""
|
||||
from crazy_functions.review_fns.data_sources.scihub_source import SciHub
|
||||
id_type, paper_id = paper_info
|
||||
|
||||
# 创建保存目录 - 使用时间戳创建唯一文件夹
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
user_name = chatbot.get_user() if hasattr(chatbot, 'get_user') else "default"
|
||||
from toolbox import get_log_folder, get_user
|
||||
base_save_dir = get_log_folder(get_user(chatbot), plugin_name='paper_download')
|
||||
save_dir = os.path.join(base_save_dir, f"papers_{timestamp}")
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
save_path = Path(save_dir)
|
||||
|
||||
chatbot.append([f"下载论文", f"正在下载{'arXiv' if id_type == 'arxiv' else 'DOI'} {paper_id} 的论文..."])
|
||||
update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
pdf_path = None
|
||||
|
||||
try:
|
||||
if id_type == 'arxiv':
|
||||
# 使用改进的arxiv查询方法
|
||||
formatted_id = format_arxiv_id(paper_id)
|
||||
paper_result = get_arxiv_paper(formatted_id)
|
||||
|
||||
if not paper_result:
|
||||
chatbot.append([f"下载失败", f"未找到arXiv论文: {paper_id}"])
|
||||
update_ui(chatbot=chatbot, history=history)
|
||||
return None
|
||||
|
||||
# 下载PDF
|
||||
filename = f"arxiv_{paper_id.replace('/', '_')}.pdf"
|
||||
pdf_path = str(save_path / filename)
|
||||
paper_result.download_pdf(filename=pdf_path)
|
||||
|
||||
else: # doi
|
||||
# 下载DOI
|
||||
sci_hub = SciHub(
|
||||
doi=paper_id,
|
||||
path=save_path
|
||||
)
|
||||
pdf_path = sci_hub.fetch()
|
||||
|
||||
# 检查下载结果
|
||||
if pdf_path and os.path.exists(pdf_path):
|
||||
promote_file_to_downloadzone(pdf_path, chatbot=chatbot)
|
||||
chatbot.append([f"下载成功", f"已成功下载论文: {os.path.basename(pdf_path)}"])
|
||||
update_ui(chatbot=chatbot, history=history)
|
||||
return pdf_path
|
||||
else:
|
||||
chatbot.append([f"下载失败", f"论文下载失败: {paper_id}"])
|
||||
update_ui(chatbot=chatbot, history=history)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
chatbot.append([f"下载错误", f"下载论文时出错: {str(e)}"])
|
||||
update_ui(chatbot=chatbot, history=history)
|
||||
return None
|
||||
|
||||
|
||||
@CatchException
|
||||
def 快速论文解读(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||
history: List, system_prompt: str, user_request: str):
|
||||
"""主函数 - 论文快速解读"""
|
||||
# 初始化分析器
|
||||
chatbot.append(["函数插件功能及使用方式", "论文快速解读:通过分析论文的关键要素,帮助您迅速理解论文内容,适用于各学科领域的科研论文。 <br><br>📋 使用方式:<br>1、直接上传PDF文件或者输入DOI号(仅针对SCI hub存在的论文)或arXiv ID(如2501.03916)<br>2、点击插件开始分析"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
paper_file = None
|
||||
|
||||
# 检查输入是否为论文ID(arxiv或DOI)
|
||||
paper_info = extract_paper_id(txt)
|
||||
|
||||
if paper_info:
|
||||
# 如果是论文ID,下载论文
|
||||
chatbot.append(["检测到论文ID", f"检测到{'arXiv' if paper_info[0] == 'arxiv' else 'DOI'} ID: {paper_info[1]},准备下载论文..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 下载论文 - 完全重新实现
|
||||
paper_file = download_paper_by_id(paper_info, chatbot, history)
|
||||
|
||||
if not paper_file:
|
||||
report_exception(chatbot, history, a=f"下载论文失败", b=f"无法下载{'arXiv' if paper_info[0] == 'arxiv' else 'DOI'}论文: {paper_info[1]}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
else:
|
||||
# 检查输入路径
|
||||
if not os.path.exists(txt):
|
||||
report_exception(chatbot, history, a=f"解析论文: {txt}", b=f"找不到文件或无权访问: {txt}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 验证路径安全性
|
||||
user_name = chatbot.get_user()
|
||||
validate_path_safety(txt, user_name)
|
||||
|
||||
# 查找论文文件
|
||||
paper_file = _find_paper_file(txt)
|
||||
|
||||
if not paper_file:
|
||||
report_exception(chatbot, history, a=f"解析论文", b=f"在路径 {txt} 中未找到支持的论文文件")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 增加调试信息,检查paper_file的类型和值
|
||||
chatbot.append(["文件类型检查", f"paper_file类型: {type(paper_file)}, 值: {paper_file}"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
chatbot.pop() # 移除调试信息
|
||||
|
||||
# 确保paper_file是字符串
|
||||
if paper_file is not None and not isinstance(paper_file, str):
|
||||
# 尝试转换为字符串
|
||||
try:
|
||||
paper_file = str(paper_file)
|
||||
except:
|
||||
report_exception(chatbot, history, a=f"类型错误", b=f"论文路径不是有效的字符串: {type(paper_file)}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 分析论文
|
||||
chatbot.append(["开始分析", f"正在分析论文: {os.path.basename(paper_file)}"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
analyzer = PaperAnalyzer(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||
yield from analyzer.analyze_paper(paper_file)
|
||||
@@ -4,7 +4,7 @@ from toolbox import CatchException, report_exception
|
||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
|
||||
def 生成函数注释(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
def Program_Comment_Gen(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||
import time, os
|
||||
logger.info('begin analysis on:', file_manifest)
|
||||
for index, fp in enumerate(file_manifest):
|
||||
@@ -34,7 +34,7 @@ def 生成函数注释(file_manifest, project_folder, llm_kwargs, plugin_kwargs,
|
||||
|
||||
|
||||
@CatchException
|
||||
def 批量生成函数注释(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def 批量Program_Comment_Gen(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
history = [] # 清空历史,以免输入溢出
|
||||
import glob, os
|
||||
if os.path.exists(txt):
|
||||
@@ -51,4 +51,4 @@ def 批量生成函数注释(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
|
||||
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到任何.tex文件: {txt}")
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
return
|
||||
yield from 生成函数注释(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||
yield from Program_Comment_Gen(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||
@@ -79,8 +79,8 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
||||
# yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
# chatbot._cookies['langchain_plugin_embedding'] = kai.get_current_archive_id()
|
||||
# chatbot._cookies['lock_plugin'] = 'crazy_functions.知识库文件注入->读取知识库作答'
|
||||
# chatbot.append(['完成', "“根据知识库作答”函数插件已经接管问答系统, 提问吧! 但注意, 您接下来不能再使用其他插件了,刷新页面即可以退出知识库问答模式。"])
|
||||
chatbot.append(['构建完成', f"当前知识库内的有效文件:\n\n---\n\n{kai_files}\n\n---\n\n请切换至“知识库问答”插件进行知识库访问, 或者使用此插件继续上传更多文件。"])
|
||||
# chatbot.append(['完成', "“根据知识库作答”函数插件已经接管问答系统, 提问吧! 但注意, 您接下来不能再使用其他插件了,刷新页面即可以退出Vectorstore_QA模式。"])
|
||||
chatbot.append(['构建完成', f"当前知识库内的有效文件:\n\n---\n\n{kai_files}\n\n---\n\n请切换至“Vectorstore_QA”插件进行知识库访问, 或者使用此插件继续上传更多文件。"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新
|
||||
|
||||
@CatchException
|
||||
@@ -21,7 +21,7 @@ Please describe in natural language what you want to do.
|
||||
5. If you don't need to upload a file, you can simply repeat your command again.
|
||||
"""
|
||||
explain_msg = """
|
||||
## 虚空终端插件说明:
|
||||
## Void_Terminal插件说明:
|
||||
|
||||
1. 请用**自然语言**描述您需要做什么。例如:
|
||||
- 「请调用插件,为我翻译PDF论文,论文我刚刚放到上传区了」
|
||||
@@ -104,9 +104,9 @@ def analyze_intention_with_simple_rules(txt):
|
||||
|
||||
|
||||
@CatchException
|
||||
def 虚空终端(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def Void_Terminal(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
disable_auto_promotion(chatbot=chatbot)
|
||||
# 获取当前虚空终端状态
|
||||
# 获取当前Void_Terminal状态
|
||||
state = VoidTerminalState.get_state(chatbot)
|
||||
appendix_msg = ""
|
||||
|
||||
@@ -121,21 +121,21 @@ def 虚空终端(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt
|
||||
state.set_state(chatbot=chatbot, key='has_provided_explanation', value=True)
|
||||
state.unlock_plugin(chatbot=chatbot)
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
yield from 虚空终端主路由(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||
yield from Void_Terminal主路由(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||
return
|
||||
else:
|
||||
# 如果意图模糊,提示
|
||||
state.set_state(chatbot=chatbot, key='has_provided_explanation', value=True)
|
||||
state.lock_plugin(chatbot=chatbot)
|
||||
chatbot.append(("虚空终端状态:", explain_msg+appendix_msg))
|
||||
chatbot.append(("Void_Terminal状态:", explain_msg+appendix_msg))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
|
||||
|
||||
def 虚空终端主路由(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def Void_Terminal主路由(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
history = []
|
||||
chatbot.append(("虚空终端状态: ", f"正在执行任务: {txt}"))
|
||||
chatbot.append(("Void_Terminal状态: ", f"正在执行任务: {txt}"))
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
# ⭐ ⭐ ⭐ 分析用户意图
|
||||
@@ -79,13 +79,13 @@ def 解析docx(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot
|
||||
|
||||
|
||||
@CatchException
|
||||
def 总结word文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
def Word_Summary(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
import glob, os
|
||||
|
||||
# 基本信息:功能、贡献者
|
||||
chatbot.append([
|
||||
"函数插件功能?",
|
||||
"批量总结Word文档。函数插件贡献者: JasonGuo1。注意, 如果是.doc文件, 请先转化为.docx格式。"])
|
||||
"批量Word_Summary。函数插件贡献者: JasonGuo1。注意, 如果是.doc文件, 请先转化为.docx格式。"])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
|
||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||
@@ -1,6 +1,4 @@
|
||||
import nltk
|
||||
nltk.data.path.append('~/nltk_data')
|
||||
nltk.download('averaged_perceptron_tagger', download_dir='~/nltk_data',
|
||||
)
|
||||
nltk.download('punkt', download_dir='~/nltk_data',
|
||||
)
|
||||
nltk.download('averaged_perceptron_tagger', download_dir='~/nltk_data')
|
||||
nltk.download('punkt', download_dir='~/nltk_data')
|
||||
451
crazy_functions/doc_fns/text_content_loader.py
Normal file
451
crazy_functions/doc_fns/text_content_loader.py
Normal file
@@ -0,0 +1,451 @@
|
||||
import os
|
||||
import re
|
||||
import glob
|
||||
import time
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Generator, Tuple, Set, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
from toolbox import update_ui
|
||||
from crazy_functions.rag_fns.rag_file_support import extract_text
|
||||
from crazy_functions.doc_fns.content_folder import ContentFoldingManager, FileMetadata, FoldingOptions, FoldingStyle, FoldingError
|
||||
from shared_utils.fastapi_server import validate_path_safety
|
||||
from datetime import datetime
|
||||
import mimetypes
|
||||
|
||||
@dataclass
|
||||
class FileInfo:
|
||||
"""文件信息数据类"""
|
||||
path: str # 完整路径
|
||||
rel_path: str # 相对路径
|
||||
size: float # 文件大小(MB)
|
||||
extension: str # 文件扩展名
|
||||
last_modified: str # 最后修改时间
|
||||
|
||||
|
||||
class TextContentLoader:
|
||||
"""优化版本的文本内容加载器 - 保持原有接口"""
|
||||
|
||||
# 压缩文件扩展名
|
||||
COMPRESSED_EXTENSIONS: Set[str] = {'.zip', '.rar', '.7z', '.tar', '.gz', '.bz2', '.xz'}
|
||||
|
||||
# 系统配置
|
||||
MAX_FILE_SIZE: int = 100 * 1024 * 1024 # 最大文件大小(100MB)
|
||||
MAX_TOTAL_SIZE: int = 100 * 1024 * 1024 # 最大总大小(100MB)
|
||||
MAX_FILES: int = 100 # 最大文件数量
|
||||
CHUNK_SIZE: int = 1024 * 1024 # 文件读取块大小(1MB)
|
||||
MAX_WORKERS: int = min(32, (os.cpu_count() or 1) * 4) # 最大工作线程数
|
||||
BATCH_SIZE: int = 5 # 批处理大小
|
||||
|
||||
def __init__(self, chatbot: List, history: List):
|
||||
"""初始化加载器"""
|
||||
self.chatbot = chatbot
|
||||
self.history = history
|
||||
self.failed_files: List[Tuple[str, str]] = []
|
||||
self.processed_size: int = 0
|
||||
self.start_time: float = 0
|
||||
self.file_cache: Dict[str, str] = {}
|
||||
self._lock = threading.Lock()
|
||||
self.executor = ThreadPoolExecutor(max_workers=self.MAX_WORKERS)
|
||||
self.results_queue = queue.Queue()
|
||||
self.folding_manager = ContentFoldingManager()
|
||||
|
||||
def _create_file_info(self, entry: os.DirEntry, root_path: str) -> FileInfo:
|
||||
"""优化的文件信息创建
|
||||
|
||||
Args:
|
||||
entry: 目录入口对象
|
||||
root_path: 根路径
|
||||
|
||||
Returns:
|
||||
FileInfo: 文件信息对象
|
||||
"""
|
||||
try:
|
||||
stats = entry.stat() # 使用缓存的文件状态
|
||||
return FileInfo(
|
||||
path=entry.path,
|
||||
rel_path=os.path.relpath(entry.path, root_path),
|
||||
size=stats.st_size / (1024 * 1024),
|
||||
extension=os.path.splitext(entry.path)[1].lower(),
|
||||
last_modified=time.strftime('%Y-%m-%d %H:%M:%S',
|
||||
time.localtime(stats.st_mtime))
|
||||
)
|
||||
except (OSError, ValueError) as e:
|
||||
return None
|
||||
|
||||
def _process_file_batch(self, file_batch: List[FileInfo]) -> List[Tuple[FileInfo, Optional[str]]]:
|
||||
"""批量处理文件
|
||||
|
||||
Args:
|
||||
file_batch: 要处理的文件信息列表
|
||||
|
||||
Returns:
|
||||
List[Tuple[FileInfo, Optional[str]]]: 处理结果列表
|
||||
"""
|
||||
results = []
|
||||
futures = {}
|
||||
|
||||
for file_info in file_batch:
|
||||
if file_info.path in self.file_cache:
|
||||
results.append((file_info, self.file_cache[file_info.path]))
|
||||
continue
|
||||
|
||||
if file_info.size * 1024 * 1024 > self.MAX_FILE_SIZE:
|
||||
with self._lock:
|
||||
self.failed_files.append(
|
||||
(file_info.rel_path,
|
||||
f"文件过大({file_info.size:.2f}MB > {self.MAX_FILE_SIZE / (1024 * 1024)}MB)")
|
||||
)
|
||||
continue
|
||||
|
||||
future = self.executor.submit(self._read_file_content, file_info)
|
||||
futures[future] = file_info
|
||||
|
||||
for future in as_completed(futures):
|
||||
file_info = futures[future]
|
||||
try:
|
||||
content = future.result()
|
||||
if content:
|
||||
with self._lock:
|
||||
self.file_cache[file_info.path] = content
|
||||
self.processed_size += file_info.size * 1024 * 1024
|
||||
results.append((file_info, content))
|
||||
except Exception as e:
|
||||
with self._lock:
|
||||
self.failed_files.append((file_info.rel_path, f"读取失败: {str(e)}"))
|
||||
|
||||
return results
|
||||
|
||||
def _read_file_content(self, file_info: FileInfo) -> Optional[str]:
|
||||
"""读取单个文件内容
|
||||
|
||||
Args:
|
||||
file_info: 文件信息对象
|
||||
|
||||
Returns:
|
||||
Optional[str]: 文件内容
|
||||
"""
|
||||
try:
|
||||
content = extract_text(file_info.path)
|
||||
if not content or not content.strip():
|
||||
return None
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.exception(f"读取文件失败: {str(e)}")
|
||||
raise Exception(f"读取文件失败: {str(e)}")
|
||||
|
||||
def _is_valid_file(self, file_path: str) -> bool:
|
||||
"""检查文件是否有效
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
bool: 是否为有效文件
|
||||
"""
|
||||
if not os.path.isfile(file_path):
|
||||
return False
|
||||
|
||||
extension = os.path.splitext(file_path)[1].lower()
|
||||
if (extension in self.COMPRESSED_EXTENSIONS or
|
||||
os.path.basename(file_path).startswith('.') or
|
||||
not os.access(file_path, os.R_OK)):
|
||||
return False
|
||||
|
||||
# 只要文件可以访问且不在排除列表中就认为是有效的
|
||||
return True
|
||||
|
||||
def _collect_files(self, path: str) -> List[FileInfo]:
|
||||
"""收集文件信息
|
||||
|
||||
Args:
|
||||
path: 目标路径
|
||||
|
||||
Returns:
|
||||
List[FileInfo]: 有效文件信息列表
|
||||
"""
|
||||
files = []
|
||||
total_size = 0
|
||||
|
||||
# 处理单个文件的情况
|
||||
if os.path.isfile(path):
|
||||
if self._is_valid_file(path):
|
||||
file_info = self._create_file_info(os.DirEntry(os.path.dirname(path)), os.path.dirname(path))
|
||||
if file_info:
|
||||
return [file_info]
|
||||
return []
|
||||
|
||||
# 处理目录的情况
|
||||
try:
|
||||
# 使用os.walk来递归遍历目录
|
||||
for root, _, filenames in os.walk(path):
|
||||
for filename in filenames:
|
||||
if len(files) >= self.MAX_FILES:
|
||||
self.failed_files.append((filename, f"超出最大文件数限制({self.MAX_FILES})"))
|
||||
continue
|
||||
|
||||
file_path = os.path.join(root, filename)
|
||||
|
||||
if not self._is_valid_file(file_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
stats = os.stat(file_path)
|
||||
file_size = stats.st_size / (1024 * 1024) # 转换为MB
|
||||
|
||||
if file_size * 1024 * 1024 > self.MAX_FILE_SIZE:
|
||||
self.failed_files.append((file_path,
|
||||
f"文件过大({file_size:.2f}MB > {self.MAX_FILE_SIZE / (1024 * 1024)}MB)"))
|
||||
continue
|
||||
|
||||
if total_size + file_size * 1024 * 1024 > self.MAX_TOTAL_SIZE:
|
||||
self.failed_files.append((file_path, "超出总大小限制"))
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
path=file_path,
|
||||
rel_path=os.path.relpath(file_path, path),
|
||||
size=file_size,
|
||||
extension=os.path.splitext(file_path)[1].lower(),
|
||||
last_modified=time.strftime('%Y-%m-%d %H:%M:%S',
|
||||
time.localtime(stats.st_mtime))
|
||||
)
|
||||
|
||||
total_size += file_size * 1024 * 1024
|
||||
files.append(file_info)
|
||||
|
||||
except Exception as e:
|
||||
self.failed_files.append((file_path, f"处理文件失败: {str(e)}"))
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
self.failed_files.append(("目录扫描", f"扫描失败: {str(e)}"))
|
||||
return []
|
||||
|
||||
return sorted(files, key=lambda x: x.rel_path)
|
||||
|
||||
def _format_content_with_fold(self, file_info, content: str) -> str:
|
||||
"""使用折叠管理器格式化文件内容"""
|
||||
try:
|
||||
metadata = FileMetadata(
|
||||
rel_path=file_info.rel_path,
|
||||
size=file_info.size,
|
||||
last_modified=datetime.fromtimestamp(
|
||||
os.path.getmtime(file_info.path)
|
||||
),
|
||||
mime_type=mimetypes.guess_type(file_info.path)[0]
|
||||
)
|
||||
|
||||
options = FoldingOptions(
|
||||
style=FoldingStyle.DETAILED,
|
||||
code_language=self.folding_manager._guess_language(
|
||||
os.path.splitext(file_info.path)[1]
|
||||
),
|
||||
show_timestamp=True
|
||||
)
|
||||
|
||||
return self.folding_manager.format_content(
|
||||
content=content,
|
||||
formatter_type='file',
|
||||
metadata=metadata,
|
||||
options=options
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return f"Error formatting content: {str(e)}"
|
||||
|
||||
def _format_content_for_llm(self, file_infos: List[FileInfo], contents: List[str]) -> str:
|
||||
"""格式化用于LLM的内容
|
||||
|
||||
Args:
|
||||
file_infos: 文件信息列表
|
||||
contents: 内容列表
|
||||
|
||||
Returns:
|
||||
str: 格式化后的内容
|
||||
"""
|
||||
if len(file_infos) != len(contents):
|
||||
raise ValueError("文件信息和内容数量不匹配")
|
||||
|
||||
result = [
|
||||
"以下是多个文件的内容集合。每个文件的内容都以 '===== 文件 {序号}: {文件名} =====' 开始,",
|
||||
"以 '===== 文件 {序号} 结束 =====' 结束。你可以根据这些分隔符来识别不同文件的内容。\n\n"
|
||||
]
|
||||
|
||||
for idx, (file_info, content) in enumerate(zip(file_infos, contents), 1):
|
||||
result.extend([
|
||||
f"===== 文件 {idx}: {file_info.rel_path} =====",
|
||||
"文件内容:",
|
||||
content.strip(),
|
||||
f"===== 文件 {idx} 结束 =====\n"
|
||||
])
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
def execute(self, txt: str) -> Generator:
|
||||
"""执行文本加载和显示 - 保持原有接口
|
||||
|
||||
Args:
|
||||
txt: 目标路径
|
||||
|
||||
Yields:
|
||||
Generator: UI更新生成器
|
||||
"""
|
||||
try:
|
||||
# 首先显示正在处理的提示信息
|
||||
self.chatbot.append(["提示", "正在提取文本内容,请稍作等待..."])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
user_name = self.chatbot.get_user()
|
||||
validate_path_safety(txt, user_name)
|
||||
self.start_time = time.time()
|
||||
self.processed_size = 0
|
||||
self.failed_files.clear()
|
||||
successful_files = []
|
||||
successful_contents = []
|
||||
|
||||
# 收集文件
|
||||
files = self._collect_files(txt)
|
||||
if not files:
|
||||
# 移除之前的提示信息
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["提示", "未找到任何有效文件"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return
|
||||
|
||||
# 批量处理文件
|
||||
content_blocks = []
|
||||
for i in range(0, len(files), self.BATCH_SIZE):
|
||||
batch = files[i:i + self.BATCH_SIZE]
|
||||
results = self._process_file_batch(batch)
|
||||
|
||||
for file_info, content in results:
|
||||
if content:
|
||||
content_blocks.append(self._format_content_with_fold(file_info, content))
|
||||
successful_files.append(file_info)
|
||||
successful_contents.append(content)
|
||||
|
||||
# 显示文件内容,替换之前的提示信息
|
||||
if content_blocks:
|
||||
# 移除之前的提示信息
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["文件内容", "\n".join(content_blocks)])
|
||||
self.history.extend([
|
||||
self._format_content_for_llm(successful_files, successful_contents),
|
||||
"我已经接收到你上传的文件的内容,请提问"
|
||||
])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
except Exception as e:
|
||||
# 发生错误时,移除之前的提示信息
|
||||
if len(self.chatbot) > 0 and self.chatbot[-1][0] == "提示":
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["错误", f"处理过程中出现错误: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
finally:
|
||||
self.executor.shutdown(wait=False)
|
||||
self.file_cache.clear()
|
||||
|
||||
def execute_single_file(self, file_path: str) -> Generator:
|
||||
"""执行单个文件的加载和显示
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Yields:
|
||||
Generator: UI更新生成器
|
||||
"""
|
||||
try:
|
||||
# 首先显示正在处理的提示信息
|
||||
self.chatbot.append(["提示", "正在提取文本内容,请稍作等待..."])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
user_name = self.chatbot.get_user()
|
||||
validate_path_safety(file_path, user_name)
|
||||
self.start_time = time.time()
|
||||
self.processed_size = 0
|
||||
self.failed_files.clear()
|
||||
|
||||
# 验证文件是否存在且可读
|
||||
if not os.path.isfile(file_path):
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["错误", f"指定路径不是文件: {file_path}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return
|
||||
|
||||
if not self._is_valid_file(file_path):
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["错误", f"无效的文件类型或无法读取: {file_path}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return
|
||||
|
||||
# 创建文件信息
|
||||
try:
|
||||
stats = os.stat(file_path)
|
||||
file_size = stats.st_size / (1024 * 1024) # 转换为MB
|
||||
|
||||
if file_size * 1024 * 1024 > self.MAX_FILE_SIZE:
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["错误", f"文件过大({file_size:.2f}MB > {self.MAX_FILE_SIZE / (1024 * 1024)}MB)"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return
|
||||
|
||||
file_info = FileInfo(
|
||||
path=file_path,
|
||||
rel_path=os.path.basename(file_path),
|
||||
size=file_size,
|
||||
extension=os.path.splitext(file_path)[1].lower(),
|
||||
last_modified=time.strftime('%Y-%m-%d %H:%M:%S',
|
||||
time.localtime(stats.st_mtime))
|
||||
)
|
||||
except Exception as e:
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["错误", f"处理文件失败: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return
|
||||
|
||||
# 读取文件内容
|
||||
try:
|
||||
content = self._read_file_content(file_info)
|
||||
if not content:
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["提示", f"文件内容为空或无法提取: {file_path}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return
|
||||
except Exception as e:
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["错误", f"读取文件失败: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return
|
||||
|
||||
# 格式化内容并更新UI
|
||||
formatted_content = self._format_content_with_fold(file_info, content)
|
||||
|
||||
# 移除之前的提示信息
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["文件内容", formatted_content])
|
||||
|
||||
# 更新历史记录,便于LLM处理
|
||||
llm_content = self._format_content_for_llm([file_info], [content])
|
||||
self.history.extend([llm_content, "我已经接收到你上传的文件的内容,请提问"])
|
||||
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
except Exception as e:
|
||||
# 发生错误时,移除之前的提示信息
|
||||
if len(self.chatbot) > 0 and self.chatbot[-1][0] == "提示":
|
||||
self.chatbot.pop()
|
||||
self.chatbot.append(["错误", f"处理过程中出现错误: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
def __del__(self):
|
||||
"""析构函数 - 确保资源被正确释放"""
|
||||
if hasattr(self, 'executor'):
|
||||
self.executor.shutdown(wait=False)
|
||||
if hasattr(self, 'file_cache'):
|
||||
self.file_cache.clear()
|
||||
0
crazy_functions/paper_fns/__init__.py
Normal file
0
crazy_functions/paper_fns/__init__.py
Normal file
386
crazy_functions/paper_fns/auto_git/handlers/base_handler.py
Normal file
386
crazy_functions/paper_fns/auto_git/handlers/base_handler.py
Normal file
@@ -0,0 +1,386 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from ..query_analyzer import SearchCriteria
|
||||
from ..sources.github_source import GitHubSource
|
||||
import asyncio
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
class BaseHandler(ABC):
|
||||
"""处理器基类"""
|
||||
|
||||
def __init__(self, github: GitHubSource, llm_kwargs: Dict = None):
|
||||
self.github = github
|
||||
self.llm_kwargs = llm_kwargs or {}
|
||||
self.ranked_repos = [] # 存储排序后的仓库列表
|
||||
|
||||
def _get_search_params(self, plugin_kwargs: Dict) -> Dict:
|
||||
"""获取搜索参数"""
|
||||
return {
|
||||
'max_repos': plugin_kwargs.get('max_repos', 150), # 最大仓库数量,从30改为150
|
||||
'max_details': plugin_kwargs.get('max_details', 80), # 最多展示详情的仓库数量,新增参数
|
||||
'search_multiplier': plugin_kwargs.get('search_multiplier', 3), # 检索倍数
|
||||
'min_stars': plugin_kwargs.get('min_stars', 0), # 最少星标数
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理查询"""
|
||||
pass
|
||||
|
||||
async def _search_repositories(self, query: str, language: str = None, min_stars: int = 0,
|
||||
sort: str = "stars", per_page: int = 30) -> List[Dict]:
|
||||
"""搜索仓库"""
|
||||
try:
|
||||
# 构建查询字符串
|
||||
if min_stars > 0 and "stars:>" not in query:
|
||||
query += f" stars:>{min_stars}"
|
||||
|
||||
if language and "language:" not in query:
|
||||
query += f" language:{language}"
|
||||
|
||||
# 执行搜索
|
||||
result = await self.github.search_repositories(
|
||||
query=query,
|
||||
sort=sort,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
if result and "items" in result:
|
||||
return result["items"]
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"仓库搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_bilingual_repositories(self, english_query: str, chinese_query: str, language: str = None, min_stars: int = 0,
|
||||
sort: str = "stars", per_page: int = 30) -> List[Dict]:
|
||||
"""同时搜索中英文仓库并合并结果"""
|
||||
try:
|
||||
# 搜索英文仓库
|
||||
english_results = await self._search_repositories(
|
||||
query=english_query,
|
||||
language=language,
|
||||
min_stars=min_stars,
|
||||
sort=sort,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
# 搜索中文仓库
|
||||
chinese_results = await self._search_repositories(
|
||||
query=chinese_query,
|
||||
language=language,
|
||||
min_stars=min_stars,
|
||||
sort=sort,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
# 合并结果,去除重复项
|
||||
merged_results = []
|
||||
seen_repos = set()
|
||||
|
||||
# 优先添加英文结果
|
||||
for repo in english_results:
|
||||
repo_id = repo.get('id')
|
||||
if repo_id and repo_id not in seen_repos:
|
||||
seen_repos.add(repo_id)
|
||||
merged_results.append(repo)
|
||||
|
||||
# 添加中文结果(排除重复)
|
||||
for repo in chinese_results:
|
||||
repo_id = repo.get('id')
|
||||
if repo_id and repo_id not in seen_repos:
|
||||
seen_repos.add(repo_id)
|
||||
merged_results.append(repo)
|
||||
|
||||
# 按星标数重新排序
|
||||
merged_results.sort(key=lambda x: x.get('stargazers_count', 0), reverse=True)
|
||||
|
||||
return merged_results[:per_page] # 返回合并后的前per_page个结果
|
||||
except Exception as e:
|
||||
print(f"双语仓库搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_code(self, query: str, language: str = None, per_page: int = 30) -> List[Dict]:
|
||||
"""搜索代码"""
|
||||
try:
|
||||
# 构建查询字符串
|
||||
if language and "language:" not in query:
|
||||
query += f" language:{language}"
|
||||
|
||||
# 执行搜索
|
||||
result = await self.github.search_code(
|
||||
query=query,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
if result and "items" in result:
|
||||
return result["items"]
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"代码搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_bilingual_code(self, english_query: str, chinese_query: str, language: str = None, per_page: int = 30) -> List[Dict]:
|
||||
"""同时搜索中英文代码并合并结果"""
|
||||
try:
|
||||
# 搜索英文代码
|
||||
english_results = await self._search_code(
|
||||
query=english_query,
|
||||
language=language,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
# 搜索中文代码
|
||||
chinese_results = await self._search_code(
|
||||
query=chinese_query,
|
||||
language=language,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
# 合并结果,去除重复项
|
||||
merged_results = []
|
||||
seen_files = set()
|
||||
|
||||
# 优先添加英文结果
|
||||
for item in english_results:
|
||||
# 使用文件URL作为唯一标识
|
||||
file_url = item.get('html_url', '')
|
||||
if file_url and file_url not in seen_files:
|
||||
seen_files.add(file_url)
|
||||
merged_results.append(item)
|
||||
|
||||
# 添加中文结果(排除重复)
|
||||
for item in chinese_results:
|
||||
file_url = item.get('html_url', '')
|
||||
if file_url and file_url not in seen_files:
|
||||
seen_files.add(file_url)
|
||||
merged_results.append(item)
|
||||
|
||||
# 对结果进行排序,优先显示匹配度高的结果
|
||||
# 由于无法直接获取匹配度,这里使用仓库的星标数作为替代指标
|
||||
merged_results.sort(key=lambda x: x.get('repository', {}).get('stargazers_count', 0), reverse=True)
|
||||
|
||||
return merged_results[:per_page] # 返回合并后的前per_page个结果
|
||||
except Exception as e:
|
||||
print(f"双语代码搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_users(self, query: str, per_page: int = 30) -> List[Dict]:
|
||||
"""搜索用户"""
|
||||
try:
|
||||
result = await self.github.search_users(
|
||||
query=query,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
if result and "items" in result:
|
||||
return result["items"]
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"用户搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_bilingual_users(self, english_query: str, chinese_query: str, per_page: int = 30) -> List[Dict]:
|
||||
"""同时搜索中英文用户并合并结果"""
|
||||
try:
|
||||
# 搜索英文用户
|
||||
english_results = await self._search_users(
|
||||
query=english_query,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
# 搜索中文用户
|
||||
chinese_results = await self._search_users(
|
||||
query=chinese_query,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
# 合并结果,去除重复项
|
||||
merged_results = []
|
||||
seen_users = set()
|
||||
|
||||
# 优先添加英文结果
|
||||
for user in english_results:
|
||||
user_id = user.get('id')
|
||||
if user_id and user_id not in seen_users:
|
||||
seen_users.add(user_id)
|
||||
merged_results.append(user)
|
||||
|
||||
# 添加中文结果(排除重复)
|
||||
for user in chinese_results:
|
||||
user_id = user.get('id')
|
||||
if user_id and user_id not in seen_users:
|
||||
seen_users.add(user_id)
|
||||
merged_results.append(user)
|
||||
|
||||
# 按关注者数量进行排序
|
||||
merged_results.sort(key=lambda x: x.get('followers', 0), reverse=True)
|
||||
|
||||
return merged_results[:per_page] # 返回合并后的前per_page个结果
|
||||
except Exception as e:
|
||||
print(f"双语用户搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_topics(self, query: str, per_page: int = 30) -> List[Dict]:
|
||||
"""搜索主题"""
|
||||
try:
|
||||
result = await self.github.search_topics(
|
||||
query=query,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
if result and "items" in result:
|
||||
return result["items"]
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"主题搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_bilingual_topics(self, english_query: str, chinese_query: str, per_page: int = 30) -> List[Dict]:
|
||||
"""同时搜索中英文主题并合并结果"""
|
||||
try:
|
||||
# 搜索英文主题
|
||||
english_results = await self._search_topics(
|
||||
query=english_query,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
# 搜索中文主题
|
||||
chinese_results = await self._search_topics(
|
||||
query=chinese_query,
|
||||
per_page=per_page
|
||||
)
|
||||
|
||||
# 合并结果,去除重复项
|
||||
merged_results = []
|
||||
seen_topics = set()
|
||||
|
||||
# 优先添加英文结果
|
||||
for topic in english_results:
|
||||
topic_name = topic.get('name')
|
||||
if topic_name and topic_name not in seen_topics:
|
||||
seen_topics.add(topic_name)
|
||||
merged_results.append(topic)
|
||||
|
||||
# 添加中文结果(排除重复)
|
||||
for topic in chinese_results:
|
||||
topic_name = topic.get('name')
|
||||
if topic_name and topic_name not in seen_topics:
|
||||
seen_topics.add(topic_name)
|
||||
merged_results.append(topic)
|
||||
|
||||
# 可以按流行度进行排序(如果有)
|
||||
if merged_results and 'featured' in merged_results[0]:
|
||||
merged_results.sort(key=lambda x: x.get('featured', False), reverse=True)
|
||||
|
||||
return merged_results[:per_page] # 返回合并后的前per_page个结果
|
||||
except Exception as e:
|
||||
print(f"双语主题搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _get_repo_details(self, repos: List[Dict]) -> List[Dict]:
|
||||
"""获取仓库详细信息"""
|
||||
enhanced_repos = []
|
||||
|
||||
for repo in repos:
|
||||
try:
|
||||
# 获取README信息
|
||||
owner = repo.get('owner', {}).get('login') if repo.get('owner') is not None else None
|
||||
repo_name = repo.get('name')
|
||||
|
||||
if owner and repo_name:
|
||||
readme = await self.github.get_repo_readme(owner, repo_name)
|
||||
if readme and "decoded_content" in readme:
|
||||
# 提取README的前1000个字符作为摘要
|
||||
repo['readme_excerpt'] = readme["decoded_content"][:1000] + "..."
|
||||
|
||||
# 获取语言使用情况
|
||||
languages = await self.github.get_repository_languages(owner, repo_name)
|
||||
if languages:
|
||||
repo['languages_detail'] = languages
|
||||
|
||||
# 获取最新发布版本
|
||||
releases = await self.github.get_repo_releases(owner, repo_name, per_page=1)
|
||||
if releases and len(releases) > 0:
|
||||
repo['latest_release'] = releases[0]
|
||||
|
||||
# 获取主题标签
|
||||
topics = await self.github.get_repo_topics(owner, repo_name)
|
||||
if topics and "names" in topics:
|
||||
repo['topics'] = topics["names"]
|
||||
|
||||
enhanced_repos.append(repo)
|
||||
except Exception as e:
|
||||
print(f"获取仓库 {repo.get('full_name')} 详情时出错: {str(e)}")
|
||||
enhanced_repos.append(repo) # 添加原始仓库信息
|
||||
|
||||
return enhanced_repos
|
||||
|
||||
def _format_repos(self, repos: List[Dict]) -> str:
|
||||
"""格式化仓库列表"""
|
||||
formatted = []
|
||||
|
||||
for i, repo in enumerate(repos, 1):
|
||||
# 构建仓库URL
|
||||
repo_url = repo.get('html_url', '')
|
||||
|
||||
# 构建完整的引用
|
||||
reference = (
|
||||
f"{i}. **{repo.get('full_name', '')}**\n"
|
||||
f" - 描述: {repo.get('description', 'N/A')}\n"
|
||||
f" - 语言: {repo.get('language', 'N/A')}\n"
|
||||
f" - 星标: {repo.get('stargazers_count', 0)}\n"
|
||||
f" - Fork数: {repo.get('forks_count', 0)}\n"
|
||||
f" - 更新时间: {repo.get('updated_at', 'N/A')[:10]}\n"
|
||||
f" - 创建时间: {repo.get('created_at', 'N/A')[:10]}\n"
|
||||
f" - URL: <a href='{repo_url}' target='_blank'>{repo_url}</a>\n"
|
||||
)
|
||||
|
||||
# 添加主题标签(如果有)
|
||||
if repo.get('topics'):
|
||||
topics_str = ", ".join(repo.get('topics'))
|
||||
reference += f" - 主题标签: {topics_str}\n"
|
||||
|
||||
# 添加最新发布版本(如果有)
|
||||
if repo.get('latest_release'):
|
||||
release = repo.get('latest_release')
|
||||
reference += f" - 最新版本: {release.get('tag_name', 'N/A')} ({release.get('published_at', 'N/A')[:10]})\n"
|
||||
|
||||
# 添加README摘要(如果有)
|
||||
if repo.get('readme_excerpt'):
|
||||
# 截断README,只取前300个字符
|
||||
readme_short = repo.get('readme_excerpt')[:300].replace('\n', ' ')
|
||||
reference += f" - README摘要: {readme_short}...\n"
|
||||
|
||||
formatted.append(reference)
|
||||
|
||||
return "\n".join(formatted)
|
||||
|
||||
def _generate_apology_prompt(self, criteria: SearchCriteria) -> str:
|
||||
"""生成道歉提示"""
|
||||
return f"""很抱歉,我们未能找到与"{criteria.main_topic}"相关的GitHub项目。
|
||||
|
||||
可能的原因:
|
||||
1. 搜索词过于具体或冷门
|
||||
2. 星标数要求过高
|
||||
3. 编程语言限制过于严格
|
||||
|
||||
建议解决方案:
|
||||
1. 尝试使用更通用的关键词
|
||||
2. 降低最低星标数要求
|
||||
3. 移除或更改编程语言限制
|
||||
请根据以上建议调整后重试。"""
|
||||
|
||||
def _get_current_time(self) -> str:
|
||||
"""获取当前时间信息"""
|
||||
now = datetime.now()
|
||||
return now.strftime("%Y年%m月%d日")
|
||||
156
crazy_functions/paper_fns/auto_git/handlers/code_handler.py
Normal file
156
crazy_functions/paper_fns/auto_git/handlers/code_handler.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base_handler import BaseHandler
|
||||
from ..query_analyzer import SearchCriteria
|
||||
import asyncio
|
||||
|
||||
class CodeSearchHandler(BaseHandler):
|
||||
"""代码搜索处理器"""
|
||||
|
||||
def __init__(self, github, llm_kwargs=None):
|
||||
super().__init__(github, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理代码搜索请求,返回最终的prompt"""
|
||||
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 搜索代码
|
||||
code_results = await self._search_bilingual_code(
|
||||
english_query=criteria.github_params["query"],
|
||||
chinese_query=criteria.github_params["chinese_query"],
|
||||
language=criteria.language,
|
||||
per_page=search_params['max_repos']
|
||||
)
|
||||
|
||||
if not code_results:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 获取代码文件内容
|
||||
enhanced_code_results = await self._get_code_details(code_results[:search_params['max_details']])
|
||||
self.ranked_repos = [item["repository"] for item in enhanced_code_results if "repository" in item]
|
||||
|
||||
if not enhanced_code_results:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = f"""当前时间: {current_time}
|
||||
|
||||
基于用户对{criteria.main_topic}的查询,我找到了以下代码示例。
|
||||
|
||||
代码搜索结果:
|
||||
{self._format_code_results(enhanced_code_results)}
|
||||
|
||||
请提供:
|
||||
|
||||
1. 对于搜索的"{criteria.main_topic}"主题的综合解释:
|
||||
- 概念和原理介绍
|
||||
- 常见实现方法和技术
|
||||
- 最佳实践和注意事项
|
||||
|
||||
2. 对每个代码示例:
|
||||
- 解释代码的主要功能和实现方式
|
||||
- 分析代码质量、可读性和效率
|
||||
- 指出代码中的亮点和潜在改进空间
|
||||
- 说明代码的适用场景
|
||||
|
||||
3. 代码实现比较:
|
||||
- 不同实现方法的优缺点
|
||||
- 性能和可维护性分析
|
||||
- 适用不同场景的实现建议
|
||||
|
||||
4. 学习建议:
|
||||
- 理解和使用这些代码需要的背景知识
|
||||
- 如何扩展或改进所展示的代码
|
||||
- 进一步学习相关技术的资源
|
||||
|
||||
重要提示:
|
||||
- 深入解释代码的核心逻辑和实现思路
|
||||
- 提供专业、技术性的分析
|
||||
- 优先关注代码的实现质量和技术价值
|
||||
- 当代码实现有问题时,指出并提供改进建议
|
||||
- 对于复杂代码,分解解释其组成部分
|
||||
- 根据用户查询的具体问题提供针对性答案
|
||||
- 所有链接请使用<a href='链接地址' target='_blank'>链接文本</a>格式,确保链接在新窗口打开
|
||||
|
||||
使用markdown格式提供清晰的分节回复。
|
||||
"""
|
||||
|
||||
return final_prompt
|
||||
|
||||
async def _get_code_details(self, code_results: List[Dict]) -> List[Dict]:
|
||||
"""获取代码详情"""
|
||||
enhanced_results = []
|
||||
|
||||
for item in code_results:
|
||||
try:
|
||||
repo = item.get('repository', {})
|
||||
file_path = item.get('path', '')
|
||||
repo_name = repo.get('full_name', '')
|
||||
|
||||
if repo_name and file_path:
|
||||
owner, repo_name = repo_name.split('/')
|
||||
|
||||
# 获取文件内容
|
||||
file_content = await self.github.get_file_content(owner, repo_name, file_path)
|
||||
if file_content and "decoded_content" in file_content:
|
||||
item['code_content'] = file_content["decoded_content"]
|
||||
|
||||
# 获取仓库基本信息
|
||||
repo_details = await self.github.get_repo(owner, repo_name)
|
||||
if repo_details:
|
||||
item['repository'] = repo_details
|
||||
|
||||
enhanced_results.append(item)
|
||||
except Exception as e:
|
||||
print(f"获取代码详情时出错: {str(e)}")
|
||||
enhanced_results.append(item) # 添加原始信息
|
||||
|
||||
return enhanced_results
|
||||
|
||||
def _format_code_results(self, code_results: List[Dict]) -> str:
|
||||
"""格式化代码搜索结果"""
|
||||
formatted = []
|
||||
|
||||
for i, item in enumerate(code_results, 1):
|
||||
# 构建仓库信息
|
||||
repo = item.get('repository', {})
|
||||
repo_name = repo.get('full_name', 'N/A')
|
||||
repo_url = repo.get('html_url', '')
|
||||
stars = repo.get('stargazers_count', 0)
|
||||
language = repo.get('language', 'N/A')
|
||||
|
||||
# 构建文件信息
|
||||
file_path = item.get('path', 'N/A')
|
||||
file_url = item.get('html_url', '')
|
||||
|
||||
# 构建代码内容
|
||||
code_content = item.get('code_content', '')
|
||||
if code_content:
|
||||
# 只显示前30行代码
|
||||
code_lines = code_content.split("\n")
|
||||
if len(code_lines) > 30:
|
||||
displayed_code = "\n".join(code_lines[:30]) + "\n... (代码太长已截断) ..."
|
||||
else:
|
||||
displayed_code = code_content
|
||||
else:
|
||||
displayed_code = "(代码内容获取失败)"
|
||||
|
||||
reference = (
|
||||
f"### {i}. {file_path} (在 {repo_name} 中)\n\n"
|
||||
f"- **仓库**: <a href='{repo_url}' target='_blank'>{repo_name}</a> (⭐ {stars}, 语言: {language})\n"
|
||||
f"- **文件路径**: <a href='{file_url}' target='_blank'>{file_path}</a>\n\n"
|
||||
f"```{language.lower()}\n{displayed_code}\n```\n\n"
|
||||
)
|
||||
|
||||
formatted.append(reference)
|
||||
|
||||
return "\n".join(formatted)
|
||||
192
crazy_functions/paper_fns/auto_git/handlers/repo_handler.py
Normal file
192
crazy_functions/paper_fns/auto_git/handlers/repo_handler.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base_handler import BaseHandler
|
||||
from ..query_analyzer import SearchCriteria
|
||||
import asyncio
|
||||
|
||||
class RepositoryHandler(BaseHandler):
|
||||
"""仓库搜索处理器"""
|
||||
|
||||
def __init__(self, github, llm_kwargs=None):
|
||||
super().__init__(github, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理仓库搜索请求,返回最终的prompt"""
|
||||
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 如果是特定仓库查询
|
||||
if criteria.repo_id:
|
||||
try:
|
||||
owner, repo = criteria.repo_id.split('/')
|
||||
repo_details = await self.github.get_repo(owner, repo)
|
||||
if repo_details:
|
||||
# 获取推荐的相似仓库
|
||||
similar_repos = await self.github.get_repo_recommendations(criteria.repo_id, limit=5)
|
||||
|
||||
# 添加详细信息
|
||||
all_repos = [repo_details] + similar_repos
|
||||
enhanced_repos = await self._get_repo_details(all_repos)
|
||||
|
||||
self.ranked_repos = enhanced_repos
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = self._build_repo_detail_prompt(enhanced_repos[0], enhanced_repos[1:], current_time)
|
||||
return final_prompt
|
||||
else:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
except Exception as e:
|
||||
print(f"处理特定仓库时出错: {str(e)}")
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 一般仓库搜索
|
||||
repos = await self._search_bilingual_repositories(
|
||||
english_query=criteria.github_params["query"],
|
||||
chinese_query=criteria.github_params["chinese_query"],
|
||||
language=criteria.language,
|
||||
min_stars=criteria.min_stars,
|
||||
per_page=search_params['max_repos']
|
||||
)
|
||||
|
||||
if not repos:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 获取仓库详情
|
||||
enhanced_repos = await self._get_repo_details(repos[:search_params['max_details']]) # 使用max_details参数
|
||||
self.ranked_repos = enhanced_repos
|
||||
|
||||
if not enhanced_repos:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = f"""当前时间: {current_time}
|
||||
|
||||
基于用户对{criteria.main_topic}的兴趣,以下是相关的GitHub仓库。
|
||||
|
||||
可供推荐的GitHub仓库:
|
||||
{self._format_repos(enhanced_repos)}
|
||||
|
||||
请提供:
|
||||
1. 按功能、用途或成熟度对仓库进行分组
|
||||
|
||||
2. 对每个仓库:
|
||||
- 简要描述其主要功能和用途
|
||||
- 分析其技术特点和优势
|
||||
- 说明其适用场景和使用难度
|
||||
- 指出其与同类产品相比的独特优势
|
||||
- 解释其星标数量和活跃度代表的意义
|
||||
|
||||
3. 使用建议:
|
||||
- 新手最适合入门的仓库
|
||||
- 生产环境中最稳定可靠的选择
|
||||
- 最新技术栈或创新方案的代表
|
||||
- 学习特定技术的最佳资源
|
||||
|
||||
4. 相关资源:
|
||||
- 学习这些项目需要的前置知识
|
||||
- 项目间的关联和技术栈兼容性
|
||||
- 可能的使用组合方案
|
||||
|
||||
重要提示:
|
||||
- 重点解释为什么每个仓库值得关注
|
||||
- 突出项目间的关联性和差异性
|
||||
- 考虑用户不同水平的需求(初学者vs专业人士)
|
||||
- 在介绍项目时,使用<a href='链接' target='_blank'>文本</a>格式,确保链接在新窗口打开
|
||||
- 根据仓库的活跃度、更新频率、维护状态提供使用建议
|
||||
- 仅基于提供的信息,不要做无根据的猜测
|
||||
- 在信息缺失或不明确时,坦诚说明
|
||||
|
||||
使用markdown格式提供清晰的分节回复。
|
||||
"""
|
||||
|
||||
return final_prompt
|
||||
|
||||
def _build_repo_detail_prompt(self, main_repo: Dict, similar_repos: List[Dict], current_time: str) -> str:
|
||||
"""构建仓库详情prompt"""
|
||||
|
||||
# 提取README摘要
|
||||
readme_content = "未提供"
|
||||
if main_repo.get('readme_excerpt'):
|
||||
readme_content = main_repo.get('readme_excerpt')
|
||||
|
||||
# 构建语言分布
|
||||
languages = main_repo.get('languages_detail', {})
|
||||
lang_distribution = []
|
||||
if languages:
|
||||
total = sum(languages.values())
|
||||
for lang, bytes_val in languages.items():
|
||||
percentage = (bytes_val / total) * 100
|
||||
lang_distribution.append(f"{lang}: {percentage:.1f}%")
|
||||
|
||||
lang_str = "未知"
|
||||
if lang_distribution:
|
||||
lang_str = ", ".join(lang_distribution)
|
||||
|
||||
# 构建最终prompt
|
||||
prompt = f"""当前时间: {current_time}
|
||||
|
||||
## 主要仓库信息
|
||||
|
||||
### {main_repo.get('full_name')}
|
||||
|
||||
- **描述**: {main_repo.get('description', '未提供')}
|
||||
- **星标数**: {main_repo.get('stargazers_count', 0)}
|
||||
- **Fork数**: {main_repo.get('forks_count', 0)}
|
||||
- **Watch数**: {main_repo.get('watchers_count', 0)}
|
||||
- **Issues数**: {main_repo.get('open_issues_count', 0)}
|
||||
- **语言分布**: {lang_str}
|
||||
- **许可证**: {main_repo.get('license', {}).get('name', '未指定') if main_repo.get('license') is not None else '未指定'}
|
||||
- **创建时间**: {main_repo.get('created_at', '')[:10]}
|
||||
- **最近更新**: {main_repo.get('updated_at', '')[:10]}
|
||||
- **主题标签**: {', '.join(main_repo.get('topics', ['无']))}
|
||||
- **GitHub链接**: <a href='{main_repo.get('html_url')}' target='_blank'>链接</a>
|
||||
|
||||
### README摘要:
|
||||
{readme_content}
|
||||
|
||||
## 类似仓库:
|
||||
{self._format_repos(similar_repos)}
|
||||
|
||||
请提供以下内容:
|
||||
|
||||
1. **项目概述**
|
||||
- 详细解释{main_repo.get('name', '')}项目的主要功能和用途
|
||||
- 分析其技术特点、架构和实现原理
|
||||
- 讨论其在所属领域的地位和影响力
|
||||
- 评估项目成熟度和稳定性
|
||||
|
||||
2. **优势与特点**
|
||||
- 与同类项目相比的独特优势
|
||||
- 显著的技术创新或设计模式
|
||||
- 值得学习或借鉴的代码实践
|
||||
|
||||
3. **使用场景**
|
||||
- 最适合的应用场景
|
||||
- 潜在的使用限制和注意事项
|
||||
- 入门门槛和学习曲线评估
|
||||
- 产品级应用的可行性分析
|
||||
|
||||
4. **资源与生态**
|
||||
- 相关学习资源推荐
|
||||
- 配套工具和库的建议
|
||||
- 社区支持和活跃度评估
|
||||
|
||||
5. **类似项目对比**
|
||||
- 与列出的类似项目的详细对比
|
||||
- 不同场景下的最佳选择建议
|
||||
- 潜在的互补使用方案
|
||||
|
||||
提示:所有链接请使用<a href='链接地址' target='_blank'>链接文本</a>格式,确保链接在新窗口打开。
|
||||
|
||||
请以专业、客观的技术分析角度回答,使用markdown格式提供结构化信息。
|
||||
"""
|
||||
return prompt
|
||||
217
crazy_functions/paper_fns/auto_git/handlers/topic_handler.py
Normal file
217
crazy_functions/paper_fns/auto_git/handlers/topic_handler.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base_handler import BaseHandler
|
||||
from ..query_analyzer import SearchCriteria
|
||||
import asyncio
|
||||
|
||||
class TopicHandler(BaseHandler):
|
||||
"""主题搜索处理器"""
|
||||
|
||||
def __init__(self, github, llm_kwargs=None):
|
||||
super().__init__(github, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理主题搜索请求,返回最终的prompt"""
|
||||
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 搜索主题
|
||||
topics = await self._search_bilingual_topics(
|
||||
english_query=criteria.github_params["query"],
|
||||
chinese_query=criteria.github_params["chinese_query"],
|
||||
per_page=search_params['max_repos']
|
||||
)
|
||||
|
||||
if not topics:
|
||||
# 尝试用主题搜索仓库
|
||||
search_query = criteria.github_params["query"]
|
||||
chinese_search_query = criteria.github_params["chinese_query"]
|
||||
if "topic:" not in search_query:
|
||||
search_query += " topic:" + criteria.main_topic.replace(" ", "-")
|
||||
if "topic:" not in chinese_search_query:
|
||||
chinese_search_query += " topic:" + criteria.main_topic.replace(" ", "-")
|
||||
|
||||
repos = await self._search_bilingual_repositories(
|
||||
english_query=search_query,
|
||||
chinese_query=chinese_search_query,
|
||||
language=criteria.language,
|
||||
min_stars=criteria.min_stars,
|
||||
per_page=search_params['max_repos']
|
||||
)
|
||||
|
||||
if not repos:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 获取仓库详情
|
||||
enhanced_repos = await self._get_repo_details(repos[:10])
|
||||
self.ranked_repos = enhanced_repos
|
||||
|
||||
if not enhanced_repos:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 构建基于主题的仓库列表prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = f"""当前时间: {current_time}
|
||||
|
||||
基于用户对主题"{criteria.main_topic}"的查询,我找到了以下相关GitHub仓库。
|
||||
|
||||
主题相关仓库:
|
||||
{self._format_repos(enhanced_repos)}
|
||||
|
||||
请提供:
|
||||
|
||||
1. 主题综述:
|
||||
- "{criteria.main_topic}"主题的概述和重要性
|
||||
- 该主题在技术领域中的应用和发展趋势
|
||||
- 主题相关的主要技术栈和知识体系
|
||||
|
||||
2. 仓库分析:
|
||||
- 按功能、技术栈或应用场景对仓库进行分类
|
||||
- 每个仓库在该主题领域的定位和贡献
|
||||
- 不同仓库间的技术路线对比
|
||||
|
||||
3. 学习路径建议:
|
||||
- 初学者入门该主题的推荐仓库和学习顺序
|
||||
- 进阶学习的关键仓库和技术要点
|
||||
- 实际应用中的最佳实践选择
|
||||
|
||||
4. 技术生态分析:
|
||||
- 该主题下的主流工具和库
|
||||
- 社区活跃度和维护状况
|
||||
- 与其他相关技术的集成方案
|
||||
|
||||
重要提示:
|
||||
- 主题"{criteria.main_topic}"是用户查询的核心,请围绕此主题展开分析
|
||||
- 注重仓库质量评估和使用建议
|
||||
- 提供基于事实的客观技术分析
|
||||
- 在介绍仓库时使用<a href='链接地址' target='_blank'>链接文本</a>格式,确保链接在新窗口打开
|
||||
- 考虑不同技术水平用户的需求
|
||||
|
||||
使用markdown格式提供清晰的分节回复。
|
||||
"""
|
||||
return final_prompt
|
||||
|
||||
# 如果找到了主题,则获取主题下的热门仓库
|
||||
topic_repos = []
|
||||
for topic in topics[:5]: # 增加到5个主题
|
||||
topic_name = topic.get('name', '')
|
||||
if topic_name:
|
||||
# 搜索该主题下的仓库
|
||||
repos = await self._search_repositories(
|
||||
query=f"topic:{topic_name}",
|
||||
language=criteria.language,
|
||||
min_stars=criteria.min_stars,
|
||||
per_page=20 # 每个主题最多20个仓库
|
||||
)
|
||||
|
||||
if repos:
|
||||
for repo in repos:
|
||||
repo['topic_source'] = topic_name
|
||||
topic_repos.append(repo)
|
||||
|
||||
if not topic_repos:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 获取前N个仓库的详情
|
||||
enhanced_repos = await self._get_repo_details(topic_repos[:search_params['max_details']])
|
||||
self.ranked_repos = enhanced_repos
|
||||
|
||||
if not enhanced_repos:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = f"""当前时间: {current_time}
|
||||
|
||||
基于用户对"{criteria.main_topic}"主题的查询,我找到了以下相关GitHub主题和仓库。
|
||||
|
||||
主题相关仓库:
|
||||
{self._format_topic_repos(enhanced_repos)}
|
||||
|
||||
请提供:
|
||||
|
||||
1. 主题概述:
|
||||
- 对"{criteria.main_topic}"相关主题的介绍和技术背景
|
||||
- 这些主题在软件开发中的重要性和应用范围
|
||||
- 主题间的关联性和技术演进路径
|
||||
|
||||
2. 精选仓库分析:
|
||||
- 每个主题下最具代表性的仓库详解
|
||||
- 仓库的技术亮点和创新点
|
||||
- 使用场景和技术成熟度评估
|
||||
|
||||
3. 技术趋势分析:
|
||||
- 基于主题和仓库活跃度的技术发展趋势
|
||||
- 新兴解决方案和传统方案的对比
|
||||
- 未来可能的技术方向预测
|
||||
|
||||
4. 实践建议:
|
||||
- 不同应用场景下的最佳仓库选择
|
||||
- 学习路径和资源推荐
|
||||
- 实际项目中的应用策略
|
||||
|
||||
重要提示:
|
||||
- 将分析重点放在主题的技术内涵和价值上
|
||||
- 突出主题间的关联性和技术演进脉络
|
||||
- 提供基于数据(星标数、更新频率等)的客观分析
|
||||
- 考虑不同技术背景用户的需求
|
||||
- 所有链接请使用<a href='链接地址' target='_blank'>链接文本</a>格式,确保链接在新窗口打开
|
||||
|
||||
使用markdown格式提供清晰的分节回复。
|
||||
"""
|
||||
|
||||
return final_prompt
|
||||
|
||||
def _format_topic_repos(self, repos: List[Dict]) -> str:
|
||||
"""按主题格式化仓库列表"""
|
||||
# 按主题分组
|
||||
topics_dict = {}
|
||||
for repo in repos:
|
||||
topic = repo.get('topic_source', '其他')
|
||||
if topic not in topics_dict:
|
||||
topics_dict[topic] = []
|
||||
topics_dict[topic].append(repo)
|
||||
|
||||
# 格式化输出
|
||||
formatted = []
|
||||
for topic, topic_repos in topics_dict.items():
|
||||
formatted.append(f"## 主题: {topic}\n")
|
||||
|
||||
for i, repo in enumerate(topic_repos, 1):
|
||||
# 构建仓库URL
|
||||
repo_url = repo.get('html_url', '')
|
||||
|
||||
# 构建引用
|
||||
reference = (
|
||||
f"{i}. **{repo.get('full_name', '')}**\n"
|
||||
f" - 描述: {repo.get('description', 'N/A')}\n"
|
||||
f" - 语言: {repo.get('language', 'N/A')}\n"
|
||||
f" - 星标: {repo.get('stargazers_count', 0)}\n"
|
||||
f" - Fork数: {repo.get('forks_count', 0)}\n"
|
||||
f" - 更新时间: {repo.get('updated_at', 'N/A')[:10]}\n"
|
||||
f" - URL: <a href='{repo_url}' target='_blank'>{repo_url}</a>\n"
|
||||
)
|
||||
|
||||
# 添加主题标签(如果有)
|
||||
if repo.get('topics'):
|
||||
topics_str = ", ".join(repo.get('topics'))
|
||||
reference += f" - 主题标签: {topics_str}\n"
|
||||
|
||||
# 添加README摘要(如果有)
|
||||
if repo.get('readme_excerpt'):
|
||||
# 截断README,只取前200个字符
|
||||
readme_short = repo.get('readme_excerpt')[:200].replace('\n', ' ')
|
||||
reference += f" - README摘要: {readme_short}...\n"
|
||||
|
||||
formatted.append(reference)
|
||||
|
||||
formatted.append("\n") # 主题之间添加空行
|
||||
|
||||
return "\n".join(formatted)
|
||||
164
crazy_functions/paper_fns/auto_git/handlers/user_handler.py
Normal file
164
crazy_functions/paper_fns/auto_git/handlers/user_handler.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base_handler import BaseHandler
|
||||
from ..query_analyzer import SearchCriteria
|
||||
import asyncio
|
||||
|
||||
class UserSearchHandler(BaseHandler):
|
||||
"""用户搜索处理器"""
|
||||
|
||||
def __init__(self, github, llm_kwargs=None):
|
||||
super().__init__(github, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理用户搜索请求,返回最终的prompt"""
|
||||
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 搜索用户
|
||||
users = await self._search_bilingual_users(
|
||||
english_query=criteria.github_params["query"],
|
||||
chinese_query=criteria.github_params["chinese_query"],
|
||||
per_page=search_params['max_repos']
|
||||
)
|
||||
|
||||
if not users:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 获取用户详情和仓库
|
||||
enhanced_users = await self._get_user_details(users[:search_params['max_details']])
|
||||
self.ranked_repos = [] # 添加用户top仓库进行展示
|
||||
|
||||
for user in enhanced_users:
|
||||
if user.get('top_repos'):
|
||||
self.ranked_repos.extend(user.get('top_repos'))
|
||||
|
||||
if not enhanced_users:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = f"""当前时间: {current_time}
|
||||
|
||||
基于用户对{criteria.main_topic}的查询,我找到了以下GitHub用户。
|
||||
|
||||
GitHub用户搜索结果:
|
||||
{self._format_users(enhanced_users)}
|
||||
|
||||
请提供:
|
||||
|
||||
1. 用户综合分析:
|
||||
- 各开发者的专业领域和技术专长
|
||||
- 他们在GitHub开源社区的影响力
|
||||
- 技术实力和项目质量评估
|
||||
|
||||
2. 对每位开发者:
|
||||
- 其主要贡献领域和技术栈
|
||||
- 代表性项目及其价值
|
||||
- 编程风格和技术特点
|
||||
- 在相关领域的影响力
|
||||
|
||||
3. 项目推荐:
|
||||
- 针对用户查询的最有价值项目
|
||||
- 值得学习和借鉴的代码实践
|
||||
- 不同用户项目的相互补充关系
|
||||
|
||||
4. 如何学习和使用:
|
||||
- 如何从这些开发者项目中学习
|
||||
- 最适合入门学习的项目
|
||||
- 进阶学习的路径建议
|
||||
|
||||
重要提示:
|
||||
- 关注开发者的技术专长和核心贡献
|
||||
- 分析其开源项目的技术价值
|
||||
- 根据用户的原始查询提供相关建议
|
||||
- 避免过度赞美或主观评价
|
||||
- 基于事实数据(项目数、星标数等)进行客观分析
|
||||
- 所有链接请使用<a href='链接地址' target='_blank'>链接文本</a>格式,确保链接在新窗口打开
|
||||
|
||||
使用markdown格式提供清晰的分节回复。
|
||||
"""
|
||||
|
||||
return final_prompt
|
||||
|
||||
async def _get_user_details(self, users: List[Dict]) -> List[Dict]:
|
||||
"""获取用户详情和仓库"""
|
||||
enhanced_users = []
|
||||
|
||||
for user in users:
|
||||
try:
|
||||
username = user.get('login')
|
||||
|
||||
if username:
|
||||
# 获取用户详情
|
||||
user_details = await self.github.get_user(username)
|
||||
if user_details:
|
||||
user.update(user_details)
|
||||
|
||||
# 获取用户仓库
|
||||
repos = await self.github.get_user_repos(
|
||||
username,
|
||||
sort="stars",
|
||||
per_page=10 # 增加到10个仓库
|
||||
)
|
||||
if repos:
|
||||
user['top_repos'] = repos
|
||||
|
||||
enhanced_users.append(user)
|
||||
except Exception as e:
|
||||
print(f"获取用户 {user.get('login')} 详情时出错: {str(e)}")
|
||||
enhanced_users.append(user) # 添加原始信息
|
||||
|
||||
return enhanced_users
|
||||
|
||||
def _format_users(self, users: List[Dict]) -> str:
|
||||
"""格式化用户列表"""
|
||||
formatted = []
|
||||
|
||||
for i, user in enumerate(users, 1):
|
||||
# 构建用户信息
|
||||
username = user.get('login', 'N/A')
|
||||
name = user.get('name', username)
|
||||
profile_url = user.get('html_url', '')
|
||||
bio = user.get('bio', '无简介')
|
||||
followers = user.get('followers', 0)
|
||||
public_repos = user.get('public_repos', 0)
|
||||
company = user.get('company', '未指定')
|
||||
location = user.get('location', '未指定')
|
||||
blog = user.get('blog', '')
|
||||
|
||||
user_info = (
|
||||
f"### {i}. {name} (@{username})\n\n"
|
||||
f"- **简介**: {bio}\n"
|
||||
f"- **关注者**: {followers} | **公开仓库**: {public_repos}\n"
|
||||
f"- **公司**: {company} | **地点**: {location}\n"
|
||||
f"- **个人网站**: {blog}\n"
|
||||
f"- **GitHub**: <a href='{profile_url}' target='_blank'>{username}</a>\n\n"
|
||||
)
|
||||
|
||||
# 添加用户的热门仓库
|
||||
top_repos = user.get('top_repos', [])
|
||||
if top_repos:
|
||||
user_info += "**热门仓库**:\n\n"
|
||||
for repo in top_repos:
|
||||
repo_name = repo.get('name', '')
|
||||
repo_url = repo.get('html_url', '')
|
||||
repo_desc = repo.get('description', '无描述')
|
||||
repo_stars = repo.get('stargazers_count', 0)
|
||||
repo_language = repo.get('language', '未指定')
|
||||
|
||||
user_info += (
|
||||
f"- <a href='{repo_url}' target='_blank'>{repo_name}</a> - ⭐ {repo_stars}, {repo_language}\n"
|
||||
f" {repo_desc}\n\n"
|
||||
)
|
||||
|
||||
formatted.append(user_info)
|
||||
|
||||
return "\n".join(formatted)
|
||||
356
crazy_functions/paper_fns/auto_git/query_analyzer.py
Normal file
356
crazy_functions/paper_fns/auto_git/query_analyzer.py
Normal file
@@ -0,0 +1,356 @@
|
||||
from typing import Dict, List
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
|
||||
@dataclass
|
||||
class SearchCriteria:
|
||||
"""搜索条件"""
|
||||
query_type: str # 查询类型: repo/code/user/topic
|
||||
main_topic: str # 主题
|
||||
sub_topics: List[str] # 子主题列表
|
||||
language: str # 编程语言
|
||||
min_stars: int # 最少星标数
|
||||
github_params: Dict # GitHub搜索参数
|
||||
original_query: str = "" # 原始查询字符串
|
||||
repo_id: str = "" # 特定仓库ID或名称
|
||||
|
||||
class QueryAnalyzer:
|
||||
"""查询分析器"""
|
||||
|
||||
# 响应索引常量
|
||||
BASIC_QUERY_INDEX = 0
|
||||
GITHUB_QUERY_INDEX = 1
|
||||
|
||||
def __init__(self):
|
||||
self.valid_types = {
|
||||
"repo": ["repository", "project", "library", "framework", "tool"],
|
||||
"code": ["code", "snippet", "implementation", "function", "class", "algorithm"],
|
||||
"user": ["user", "developer", "organization", "contributor", "maintainer"],
|
||||
"topic": ["topic", "category", "tag", "field", "area", "domain"]
|
||||
}
|
||||
|
||||
def analyze_query(self, query: str, chatbot: List, llm_kwargs: Dict):
|
||||
"""分析查询意图"""
|
||||
from crazy_functions.crazy_utils import \
|
||||
request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||
|
||||
# 1. 基本查询分析
|
||||
type_prompt = f"""请分析这个与GitHub相关的查询,并严格按照以下XML格式回答:
|
||||
|
||||
查询: {query}
|
||||
|
||||
说明:
|
||||
1. 你的回答必须使用下面显示的XML标签,不要有任何标签外的文本
|
||||
2. 从以下选项中选择查询类型: repo/code/user/topic
|
||||
- repo: 用于查找仓库、项目、框架或库
|
||||
- code: 用于查找代码片段、函数实现或算法
|
||||
- user: 用于查找用户、开发者或组织
|
||||
- topic: 用于查找主题、类别或领域相关项目
|
||||
3. 识别主题和子主题
|
||||
4. 识别首选编程语言(如果有)
|
||||
5. 确定最低星标数(如果适用)
|
||||
|
||||
必需格式:
|
||||
<query_type>此处回答</query_type>
|
||||
<main_topic>此处回答</main_topic>
|
||||
<sub_topics>子主题1, 子主题2, ...</sub_topics>
|
||||
<language>此处回答</language>
|
||||
<min_stars>此处回答</min_stars>
|
||||
|
||||
示例回答:
|
||||
|
||||
1. 仓库查询:
|
||||
查询: "查找有至少1000颗星的Python web框架"
|
||||
<query_type>repo</query_type>
|
||||
<main_topic>web框架</main_topic>
|
||||
<sub_topics>后端开发, HTTP服务器, ORM</sub_topics>
|
||||
<language>Python</language>
|
||||
<min_stars>1000</min_stars>
|
||||
|
||||
2. 代码查询:
|
||||
查询: "如何用JavaScript实现防抖函数"
|
||||
<query_type>code</query_type>
|
||||
<main_topic>防抖函数</main_topic>
|
||||
<sub_topics>事件处理, 性能优化, 函数节流</sub_topics>
|
||||
<language>JavaScript</language>
|
||||
<min_stars>0</min_stars>"""
|
||||
|
||||
# 2. 生成英文搜索条件
|
||||
github_prompt = f"""Optimize the following GitHub search query:
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Convert the natural language query into an optimized GitHub search query.
|
||||
Please use English, regardless of the language of the input query.
|
||||
|
||||
Available search fields and filters:
|
||||
1. Basic fields:
|
||||
- in:name - Search in repository names
|
||||
- in:description - Search in repository descriptions
|
||||
- in:readme - Search in README files
|
||||
- in:topic - Search in topics
|
||||
- language:X - Filter by programming language
|
||||
- user:X - Repositories from a specific user
|
||||
- org:X - Repositories from a specific organization
|
||||
|
||||
2. Code search fields:
|
||||
- extension:X - Filter by file extension
|
||||
- path:X - Filter by path
|
||||
- filename:X - Filter by filename
|
||||
|
||||
3. Metric filters:
|
||||
- stars:>X - Has more than X stars
|
||||
- forks:>X - Has more than X forks
|
||||
- size:>X - Size greater than X KB
|
||||
- created:>YYYY-MM-DD - Created after a specific date
|
||||
- pushed:>YYYY-MM-DD - Updated after a specific date
|
||||
|
||||
4. Other filters:
|
||||
- is:public/private - Public or private repositories
|
||||
- archived:true/false - Archived or not archived
|
||||
- license:X - Specific license
|
||||
- topic:X - Contains specific topic tag
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Find Python machine learning libraries with at least 1000 stars"
|
||||
<query>machine learning in:description language:python stars:>1000</query>
|
||||
|
||||
2. Query: "Recently updated React UI component libraries"
|
||||
<query>UI components library in:readme in:description language:javascript topic:react pushed:>2023-01-01</query>
|
||||
|
||||
3. Query: "Open source projects developed by Facebook"
|
||||
<query>org:facebook is:public</query>
|
||||
|
||||
4. Query: "Depth-first search implementation in JavaScript"
|
||||
<query>depth first search in:file language:javascript</query>
|
||||
|
||||
Please analyze the query and answer using only the XML tag:
|
||||
<query>Provide the optimized GitHub search query, using appropriate fields and operators</query>"""
|
||||
|
||||
# 3. 生成中文搜索条件
|
||||
chinese_github_prompt = f"""优化以下GitHub搜索查询:
|
||||
|
||||
查询: {query}
|
||||
|
||||
任务: 将自然语言查询转换为优化的GitHub搜索查询语句。
|
||||
为了搜索中文内容,请提取原始查询的关键词并使用中文形式,同时保留GitHub特定的搜索语法为英文。
|
||||
|
||||
可用的搜索字段和过滤器:
|
||||
1. 基本字段:
|
||||
- in:name - 在仓库名称中搜索
|
||||
- in:description - 在仓库描述中搜索
|
||||
- in:readme - 在README文件中搜索
|
||||
- in:topic - 在主题中搜索
|
||||
- language:X - 按编程语言筛选
|
||||
- user:X - 特定用户的仓库
|
||||
- org:X - 特定组织的仓库
|
||||
|
||||
2. 代码搜索字段:
|
||||
- extension:X - 按文件扩展名筛选
|
||||
- path:X - 按路径筛选
|
||||
- filename:X - 按文件名筛选
|
||||
|
||||
3. 指标过滤器:
|
||||
- stars:>X - 有超过X颗星
|
||||
- forks:>X - 有超过X个分支
|
||||
- size:>X - 大小超过X KB
|
||||
- created:>YYYY-MM-DD - 在特定日期后创建
|
||||
- pushed:>YYYY-MM-DD - 在特定日期后更新
|
||||
|
||||
4. 其他过滤器:
|
||||
- is:public/private - 公开或私有仓库
|
||||
- archived:true/false - 已归档或未归档
|
||||
- license:X - 特定许可证
|
||||
- topic:X - 含特定主题标签
|
||||
|
||||
示例:
|
||||
|
||||
1. 查询: "找有关机器学习的Python库,至少1000颗星"
|
||||
<query>机器学习 in:description language:python stars:>1000</query>
|
||||
|
||||
2. 查询: "最近更新的React UI组件库"
|
||||
<query>UI 组件库 in:readme in:description language:javascript topic:react pushed:>2023-01-01</query>
|
||||
|
||||
3. 查询: "微信小程序开发框架"
|
||||
<query>微信小程序 开发框架 in:name in:description in:readme</query>
|
||||
|
||||
请分析查询并仅使用XML标签回答:
|
||||
<query>提供优化的GitHub搜索查询,使用适当的字段和运算符,保留中文关键词</query>"""
|
||||
|
||||
try:
|
||||
# 构建提示数组
|
||||
prompts = [
|
||||
type_prompt,
|
||||
github_prompt,
|
||||
chinese_github_prompt,
|
||||
]
|
||||
|
||||
show_messages = [
|
||||
"分析查询类型...",
|
||||
"优化英文GitHub搜索参数...",
|
||||
"优化中文GitHub搜索参数...",
|
||||
]
|
||||
|
||||
sys_prompts = [
|
||||
"你是一个精通GitHub生态系统的专家,擅长分析与GitHub相关的查询。",
|
||||
"You are a GitHub search expert, specialized in converting natural language queries into optimized GitHub search queries in English.",
|
||||
"你是一个GitHub搜索专家,擅长处理查询并保留中文关键词进行搜索。",
|
||||
]
|
||||
|
||||
# 使用同步方式调用LLM
|
||||
responses = yield from request_gpt(
|
||||
inputs_array=prompts,
|
||||
inputs_show_user_array=show_messages,
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history_array=[[] for _ in prompts],
|
||||
sys_prompt_array=sys_prompts,
|
||||
max_workers=3
|
||||
)
|
||||
|
||||
# 从收集的响应中提取我们需要的内容
|
||||
extracted_responses = []
|
||||
for i in range(len(prompts)):
|
||||
if (i * 2 + 1) < len(responses):
|
||||
response = responses[i * 2 + 1]
|
||||
if response is None:
|
||||
raise Exception(f"Response {i} is None")
|
||||
if not isinstance(response, str):
|
||||
try:
|
||||
response = str(response)
|
||||
except:
|
||||
raise Exception(f"Cannot convert response {i} to string")
|
||||
extracted_responses.append(response)
|
||||
else:
|
||||
raise Exception(f"未收到第 {i + 1} 个响应")
|
||||
|
||||
# 解析基本信息
|
||||
query_type = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "query_type")
|
||||
if not query_type:
|
||||
print(
|
||||
f"Debug - Failed to extract query_type. Response was: {extracted_responses[self.BASIC_QUERY_INDEX]}")
|
||||
raise Exception("无法提取query_type标签内容")
|
||||
query_type = query_type.lower()
|
||||
|
||||
main_topic = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "main_topic")
|
||||
if not main_topic:
|
||||
print(f"Debug - Failed to extract main_topic. Using query as fallback.")
|
||||
main_topic = query
|
||||
|
||||
query_type = self._normalize_query_type(query_type, query)
|
||||
|
||||
# 提取子主题
|
||||
sub_topics = []
|
||||
sub_topics_text = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "sub_topics")
|
||||
if sub_topics_text:
|
||||
sub_topics = [topic.strip() for topic in sub_topics_text.split(",")]
|
||||
|
||||
# 提取语言
|
||||
language = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "language")
|
||||
|
||||
# 提取最低星标数
|
||||
min_stars = 0
|
||||
min_stars_text = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "min_stars")
|
||||
if min_stars_text and min_stars_text.isdigit():
|
||||
min_stars = int(min_stars_text)
|
||||
|
||||
# 解析GitHub搜索参数 - 英文
|
||||
english_github_query = self._extract_tag(extracted_responses[self.GITHUB_QUERY_INDEX], "query")
|
||||
|
||||
# 解析GitHub搜索参数 - 中文
|
||||
chinese_github_query = self._extract_tag(extracted_responses[2], "query")
|
||||
|
||||
# 构建GitHub参数
|
||||
github_params = {
|
||||
"query": english_github_query,
|
||||
"chinese_query": chinese_github_query,
|
||||
"sort": "stars", # 默认按星标排序
|
||||
"order": "desc", # 默认降序
|
||||
"per_page": 30, # 默认每页30条
|
||||
"page": 1 # 默认第1页
|
||||
}
|
||||
|
||||
# 检查是否为特定仓库查询
|
||||
repo_id = ""
|
||||
if "repo:" in english_github_query or "repository:" in english_github_query:
|
||||
repo_match = re.search(r'(repo|repository):([a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+)', english_github_query)
|
||||
if repo_match:
|
||||
repo_id = repo_match.group(2)
|
||||
|
||||
print(f"Debug - 提取的信息:")
|
||||
print(f"查询类型: {query_type}")
|
||||
print(f"主题: {main_topic}")
|
||||
print(f"子主题: {sub_topics}")
|
||||
print(f"语言: {language}")
|
||||
print(f"最低星标数: {min_stars}")
|
||||
print(f"英文GitHub参数: {english_github_query}")
|
||||
print(f"中文GitHub参数: {chinese_github_query}")
|
||||
print(f"特定仓库: {repo_id}")
|
||||
|
||||
# 更新返回的 SearchCriteria,包含中英文查询
|
||||
return SearchCriteria(
|
||||
query_type=query_type,
|
||||
main_topic=main_topic,
|
||||
sub_topics=sub_topics,
|
||||
language=language,
|
||||
min_stars=min_stars,
|
||||
github_params=github_params,
|
||||
original_query=query,
|
||||
repo_id=repo_id
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"分析查询失败: {str(e)}")
|
||||
|
||||
def _normalize_query_type(self, query_type: str, query: str) -> str:
|
||||
"""规范化查询类型"""
|
||||
if query_type in ["repo", "code", "user", "topic"]:
|
||||
return query_type
|
||||
|
||||
query_lower = query.lower()
|
||||
for type_name, keywords in self.valid_types.items():
|
||||
for keyword in keywords:
|
||||
if keyword in query_lower:
|
||||
return type_name
|
||||
|
||||
query_type_lower = query_type.lower()
|
||||
for type_name, keywords in self.valid_types.items():
|
||||
for keyword in keywords:
|
||||
if keyword in query_type_lower:
|
||||
return type_name
|
||||
|
||||
return "repo" # 默认返回repo类型
|
||||
|
||||
def _extract_tag(self, text: str, tag: str) -> str:
|
||||
"""提取标记内容"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# 标准XML格式(处理多行和特殊字符)
|
||||
pattern = f"<{tag}>(.*?)</{tag}>"
|
||||
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
|
||||
if match:
|
||||
content = match.group(1).strip()
|
||||
if content:
|
||||
return content
|
||||
|
||||
# 备用模式
|
||||
patterns = [
|
||||
rf"<{tag}>\s*([\s\S]*?)\s*</{tag}>", # 标准XML格式
|
||||
rf"<{tag}>([\s\S]*?)(?:</{tag}>|$)", # 未闭合的标签
|
||||
rf"[{tag}]([\s\S]*?)[/{tag}]", # 方括号格式
|
||||
rf"{tag}:\s*(.*?)(?=\n\w|$)", # 冒号格式
|
||||
rf"<{tag}>\s*(.*?)(?=<|$)" # 部分闭合
|
||||
]
|
||||
|
||||
# 尝试所有模式
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
|
||||
if match:
|
||||
content = match.group(1).strip()
|
||||
if content: # 确保提取的内容不为空
|
||||
return content
|
||||
|
||||
# 如果所有模式都失败,返回空字符串
|
||||
return ""
|
||||
701
crazy_functions/paper_fns/auto_git/sources/github_source.py
Normal file
701
crazy_functions/paper_fns/auto_git/sources/github_source.py
Normal file
@@ -0,0 +1,701 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional, Union, Any
|
||||
|
||||
class GitHubSource:
|
||||
"""GitHub API实现"""
|
||||
|
||||
# 默认API密钥列表 - 可以放置多个GitHub令牌
|
||||
API_KEYS = [
|
||||
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
# "your_github_token_1",
|
||||
# "your_github_token_2",
|
||||
# "your_github_token_3"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: Optional[Union[str, List[str]]] = None):
|
||||
"""初始化GitHub API客户端
|
||||
|
||||
Args:
|
||||
api_key: GitHub个人访问令牌或令牌列表
|
||||
"""
|
||||
if api_key is None:
|
||||
self.api_keys = self.API_KEYS
|
||||
elif isinstance(api_key, str):
|
||||
self.api_keys = [api_key]
|
||||
else:
|
||||
self.api_keys = api_key
|
||||
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化客户端,设置默认参数"""
|
||||
self.base_url = "https://api.github.com"
|
||||
self.headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
"User-Agent": "GitHub-API-Python-Client"
|
||||
}
|
||||
|
||||
# 如果有可用的API密钥,随机选择一个
|
||||
if self.api_keys:
|
||||
selected_key = random.choice(self.api_keys)
|
||||
self.headers["Authorization"] = f"Bearer {selected_key}"
|
||||
print(f"已随机选择API密钥进行认证")
|
||||
else:
|
||||
print("警告: 未提供API密钥,将受到GitHub API请求限制")
|
||||
|
||||
async def _request(self, method: str, endpoint: str, params: Dict = None, data: Dict = None) -> Any:
|
||||
"""发送API请求
|
||||
|
||||
Args:
|
||||
method: HTTP方法 (GET, POST, PUT, DELETE等)
|
||||
endpoint: API端点
|
||||
params: URL参数
|
||||
data: 请求体数据
|
||||
|
||||
Returns:
|
||||
解析后的响应JSON
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
|
||||
# 为调试目的打印请求信息
|
||||
print(f"请求: {method} {url}")
|
||||
if params:
|
||||
print(f"参数: {params}")
|
||||
|
||||
# 发送请求
|
||||
request_kwargs = {}
|
||||
if params:
|
||||
request_kwargs["params"] = params
|
||||
if data:
|
||||
request_kwargs["json"] = data
|
||||
|
||||
async with session.request(method, url, **request_kwargs) as response:
|
||||
response_text = await response.text()
|
||||
|
||||
# 检查HTTP状态码
|
||||
if response.status >= 400:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
print(f"响应内容: {response_text}")
|
||||
return None
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except json.JSONDecodeError:
|
||||
print(f"JSON解析错误: {response_text}")
|
||||
return None
|
||||
|
||||
# ===== 用户相关方法 =====
|
||||
|
||||
async def get_user(self, username: Optional[str] = None) -> Dict:
|
||||
"""获取用户信息
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
|
||||
Returns:
|
||||
用户信息字典
|
||||
"""
|
||||
endpoint = "/user" if username is None else f"/users/{username}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_user_repos(self, username: Optional[str] = None, sort: str = "updated",
|
||||
direction: str = "desc", per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取用户的仓库列表
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
sort: 排序方式 (created, updated, pushed, full_name)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
endpoint = "/user/repos" if username is None else f"/users/{username}/repos"
|
||||
params = {
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_user_starred(self, username: Optional[str] = None,
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取用户星标的仓库
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
星标仓库列表
|
||||
"""
|
||||
endpoint = "/user/starred" if username is None else f"/users/{username}/starred"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 仓库相关方法 =====
|
||||
|
||||
async def get_repo(self, owner: str, repo: str) -> Dict:
|
||||
"""获取仓库信息
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
仓库信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repo_branches(self, owner: str, repo: str, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的分支列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
分支列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/branches"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_repo_commits(self, owner: str, repo: str, sha: Optional[str] = None,
|
||||
path: Optional[str] = None, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的提交历史
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
sha: 特定提交SHA或分支名
|
||||
path: 文件路径筛选
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
提交列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/commits"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
if sha:
|
||||
params["sha"] = sha
|
||||
if path:
|
||||
params["path"] = path
|
||||
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_commit_details(self, owner: str, repo: str, commit_sha: str) -> Dict:
|
||||
"""获取特定提交的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
commit_sha: 提交SHA
|
||||
|
||||
Returns:
|
||||
提交详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/commits/{commit_sha}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== 内容相关方法 =====
|
||||
|
||||
async def get_file_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> Dict:
|
||||
"""获取文件内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
path: 文件路径
|
||||
ref: 分支名、标签名或提交SHA
|
||||
|
||||
Returns:
|
||||
文件内容信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
|
||||
params = {}
|
||||
if ref:
|
||||
params["ref"] = ref
|
||||
|
||||
response = await self._request("GET", endpoint, params=params)
|
||||
if response and isinstance(response, dict) and "content" in response:
|
||||
try:
|
||||
# 解码Base64编码的文件内容
|
||||
content = base64.b64decode(response["content"].encode()).decode()
|
||||
response["decoded_content"] = content
|
||||
except Exception as e:
|
||||
print(f"解码文件内容时出错: {str(e)}")
|
||||
|
||||
return response
|
||||
|
||||
async def get_directory_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> List[Dict]:
|
||||
"""获取目录内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
path: 目录路径
|
||||
ref: 分支名、标签名或提交SHA
|
||||
|
||||
Returns:
|
||||
目录内容列表
|
||||
"""
|
||||
# 注意:此方法与get_file_content使用相同的端点,但对于目录会返回列表
|
||||
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
|
||||
params = {}
|
||||
if ref:
|
||||
params["ref"] = ref
|
||||
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== Issues相关方法 =====
|
||||
|
||||
async def get_issues(self, owner: str, repo: str, state: str = "open",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的Issues列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
state: Issue状态 (open, closed, all)
|
||||
sort: 排序方式 (created, updated, comments)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
Issues列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues"
|
||||
params = {
|
||||
"state": state,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_issue(self, owner: str, repo: str, issue_number: int) -> Dict:
|
||||
"""获取特定Issue的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
issue_number: Issue编号
|
||||
|
||||
Returns:
|
||||
Issue详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_issue_comments(self, owner: str, repo: str, issue_number: int) -> List[Dict]:
|
||||
"""获取Issue的评论
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
issue_number: Issue编号
|
||||
|
||||
Returns:
|
||||
评论列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== Pull Requests相关方法 =====
|
||||
|
||||
async def get_pull_requests(self, owner: str, repo: str, state: str = "open",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的Pull Request列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
state: PR状态 (open, closed, all)
|
||||
sort: 排序方式 (created, updated, popularity, long-running)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
Pull Request列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls"
|
||||
params = {
|
||||
"state": state,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_pull_request(self, owner: str, repo: str, pr_number: int) -> Dict:
|
||||
"""获取特定Pull Request的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
pr_number: Pull Request编号
|
||||
|
||||
Returns:
|
||||
Pull Request详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_pull_request_files(self, owner: str, repo: str, pr_number: int) -> List[Dict]:
|
||||
"""获取Pull Request中修改的文件
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
pr_number: Pull Request编号
|
||||
|
||||
Returns:
|
||||
修改文件列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/files"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== 搜索相关方法 =====
|
||||
|
||||
async def search_repositories(self, query: str, sort: str = "stars",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索仓库
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (stars, forks, updated)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/repositories"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_code(self, query: str, sort: str = "indexed",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索代码
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (indexed)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/code"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_issues(self, query: str, sort: str = "created",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索Issues和Pull Requests
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (created, updated, comments)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/issues"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_users(self, query: str, sort: str = "followers",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索用户
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (followers, repositories, joined)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/users"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 组织相关方法 =====
|
||||
|
||||
async def get_organization(self, org: str) -> Dict:
|
||||
"""获取组织信息
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
|
||||
Returns:
|
||||
组织信息
|
||||
"""
|
||||
endpoint = f"/orgs/{org}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_organization_repos(self, org: str, type: str = "all",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取组织的仓库列表
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
type: 仓库类型 (all, public, private, forks, sources, member, internal)
|
||||
sort: 排序方式 (created, updated, pushed, full_name)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
endpoint = f"/orgs/{org}/repos"
|
||||
params = {
|
||||
"type": type,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_organization_members(self, org: str, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取组织成员列表
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
成员列表
|
||||
"""
|
||||
endpoint = f"/orgs/{org}/members"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 更复杂的操作 =====
|
||||
|
||||
async def get_repository_languages(self, owner: str, repo: str) -> Dict:
|
||||
"""获取仓库使用的编程语言及其比例
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
语言使用情况
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/languages"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repository_stats_contributors(self, owner: str, repo: str) -> List[Dict]:
|
||||
"""获取仓库的贡献者统计
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
贡献者统计信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/stats/contributors"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repository_stats_commit_activity(self, owner: str, repo: str) -> List[Dict]:
|
||||
"""获取仓库的提交活动
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
提交活动统计
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/stats/commit_activity"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def example_usage():
|
||||
"""GitHubSource使用示例"""
|
||||
# 创建客户端实例(可选传入API令牌)
|
||||
# github = GitHubSource(api_key="your_github_token")
|
||||
github = GitHubSource()
|
||||
|
||||
try:
|
||||
# 示例1:搜索热门Python仓库
|
||||
print("\n=== 示例1:搜索热门Python仓库 ===")
|
||||
repos = await github.search_repositories(
|
||||
query="language:python stars:>1000",
|
||||
sort="stars",
|
||||
order="desc",
|
||||
per_page=5
|
||||
)
|
||||
|
||||
if repos and "items" in repos:
|
||||
for i, repo in enumerate(repos["items"], 1):
|
||||
print(f"\n--- 仓库 {i} ---")
|
||||
print(f"名称: {repo['full_name']}")
|
||||
print(f"描述: {repo['description']}")
|
||||
print(f"星标数: {repo['stargazers_count']}")
|
||||
print(f"Fork数: {repo['forks_count']}")
|
||||
print(f"最近更新: {repo['updated_at']}")
|
||||
print(f"URL: {repo['html_url']}")
|
||||
|
||||
# 示例2:获取特定仓库的详情
|
||||
print("\n=== 示例2:获取特定仓库的详情 ===")
|
||||
repo_details = await github.get_repo("microsoft", "vscode")
|
||||
if repo_details:
|
||||
print(f"名称: {repo_details['full_name']}")
|
||||
print(f"描述: {repo_details['description']}")
|
||||
print(f"星标数: {repo_details['stargazers_count']}")
|
||||
print(f"Fork数: {repo_details['forks_count']}")
|
||||
print(f"默认分支: {repo_details['default_branch']}")
|
||||
print(f"开源许可: {repo_details.get('license', {}).get('name', '无')}")
|
||||
print(f"语言: {repo_details['language']}")
|
||||
print(f"Open Issues数: {repo_details['open_issues_count']}")
|
||||
|
||||
# 示例3:获取仓库的提交历史
|
||||
print("\n=== 示例3:获取仓库的最近提交 ===")
|
||||
commits = await github.get_repo_commits("tensorflow", "tensorflow", per_page=5)
|
||||
if commits:
|
||||
for i, commit in enumerate(commits, 1):
|
||||
print(f"\n--- 提交 {i} ---")
|
||||
print(f"SHA: {commit['sha'][:7]}")
|
||||
print(f"作者: {commit['commit']['author']['name']}")
|
||||
print(f"日期: {commit['commit']['author']['date']}")
|
||||
print(f"消息: {commit['commit']['message'].splitlines()[0]}")
|
||||
|
||||
# 示例4:搜索代码
|
||||
print("\n=== 示例4:搜索代码 ===")
|
||||
code_results = await github.search_code(
|
||||
query="filename:README.md language:markdown pytorch in:file",
|
||||
per_page=3
|
||||
)
|
||||
if code_results and "items" in code_results:
|
||||
print(f"共找到: {code_results['total_count']} 个结果")
|
||||
for i, item in enumerate(code_results["items"], 1):
|
||||
print(f"\n--- 代码 {i} ---")
|
||||
print(f"仓库: {item['repository']['full_name']}")
|
||||
print(f"文件: {item['path']}")
|
||||
print(f"URL: {item['html_url']}")
|
||||
|
||||
# 示例5:获取文件内容
|
||||
print("\n=== 示例5:获取文件内容 ===")
|
||||
file_content = await github.get_file_content("python", "cpython", "README.rst")
|
||||
if file_content and "decoded_content" in file_content:
|
||||
content = file_content["decoded_content"]
|
||||
print(f"文件名: {file_content['name']}")
|
||||
print(f"大小: {file_content['size']} 字节")
|
||||
print(f"内容预览: {content[:200]}...")
|
||||
|
||||
# 示例6:获取仓库使用的编程语言
|
||||
print("\n=== 示例6:获取仓库使用的编程语言 ===")
|
||||
languages = await github.get_repository_languages("facebook", "react")
|
||||
if languages:
|
||||
print(f"React仓库使用的编程语言:")
|
||||
for lang, bytes_of_code in languages.items():
|
||||
print(f"- {lang}: {bytes_of_code} 字节")
|
||||
|
||||
# 示例7:获取组织信息
|
||||
print("\n=== 示例7:获取组织信息 ===")
|
||||
org_info = await github.get_organization("google")
|
||||
if org_info:
|
||||
print(f"名称: {org_info['name']}")
|
||||
print(f"描述: {org_info.get('description', '无')}")
|
||||
print(f"位置: {org_info.get('location', '未指定')}")
|
||||
print(f"公共仓库数: {org_info['public_repos']}")
|
||||
print(f"成员数: {org_info.get('public_members', 0)}")
|
||||
print(f"URL: {org_info['html_url']}")
|
||||
|
||||
# 示例8:获取用户信息
|
||||
print("\n=== 示例8:获取用户信息 ===")
|
||||
user_info = await github.get_user("torvalds")
|
||||
if user_info:
|
||||
print(f"名称: {user_info['name']}")
|
||||
print(f"公司: {user_info.get('company', '无')}")
|
||||
print(f"博客: {user_info.get('blog', '无')}")
|
||||
print(f"位置: {user_info.get('location', '未指定')}")
|
||||
print(f"公共仓库数: {user_info['public_repos']}")
|
||||
print(f"关注者数: {user_info['followers']}")
|
||||
print(f"URL: {user_info['html_url']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
593
crazy_functions/paper_fns/document_structure_extractor.py
Normal file
593
crazy_functions/paper_fns/document_structure_extractor.py
Normal file
@@ -0,0 +1,593 @@
|
||||
from typing import List, Dict, Optional, Tuple, Union, Any
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
|
||||
from crazy_functions.doc_fns.read_fns.unstructured_all.paper_structure_extractor import (
|
||||
PaperStructureExtractor, PaperSection, StructuredPaper
|
||||
)
|
||||
from unstructured.partition.auto import partition
|
||||
from unstructured.documents.elements import (
|
||||
Text, Title, NarrativeText, ListItem, Table,
|
||||
Footer, Header, PageBreak, Image, Address
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class DocumentSection:
|
||||
"""通用文档章节数据类"""
|
||||
title: str # 章节标题,如果没有标题则为空字符串
|
||||
content: str # 章节内容
|
||||
level: int = 0 # 标题级别,0为主标题,1为一级标题,以此类推
|
||||
section_type: str = "content" # 章节类型
|
||||
is_heading_only: bool = False # 是否仅包含标题
|
||||
subsections: List['DocumentSection'] = field(default_factory=list) # 子章节列表
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructuredDocument:
|
||||
"""结构化文档数据类"""
|
||||
title: str = "" # 文档标题
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # 元数据
|
||||
sections: List[DocumentSection] = field(default_factory=list) # 章节列表
|
||||
full_text: str = "" # 完整文本
|
||||
is_paper: bool = False # 是否为学术论文
|
||||
|
||||
|
||||
class GenericDocumentStructureExtractor:
|
||||
"""通用文档结构提取器
|
||||
|
||||
可以从各种文档格式中提取结构信息,包括标题和内容。
|
||||
支持论文、报告、文章和一般文本文档。
|
||||
"""
|
||||
|
||||
# 支持的文件扩展名
|
||||
SUPPORTED_EXTENSIONS = [
|
||||
'.pdf', '.docx', '.doc', '.pptx', '.ppt',
|
||||
'.txt', '.md', '.html', '.htm', '.xml',
|
||||
'.rtf', '.odt', '.epub', '.msg', '.eml'
|
||||
]
|
||||
|
||||
# 常见的标题前缀模式
|
||||
HEADING_PATTERNS = [
|
||||
# 数字标题 (1., 1.1., etc.)
|
||||
r'^\s*(\d+\.)+\s+',
|
||||
# 中文数字标题 (一、, 二、, etc.)
|
||||
r'^\s*[一二三四五六七八九十]+[、::]\s+',
|
||||
# 带括号的数字标题 ((1), (2), etc.)
|
||||
r'^\s*\(\s*\d+\s*\)\s+',
|
||||
# 特定标记的标题 (Chapter 1, Section 1, etc.)
|
||||
r'^\s*(chapter|section|part|附录|章|节)\s+\d+[\.::]\s+',
|
||||
]
|
||||
|
||||
# 常见的文档分段标记词
|
||||
SECTION_MARKERS = {
|
||||
'introduction': ['简介', '导言', '引言', 'introduction', '概述', 'overview'],
|
||||
'background': ['背景', '现状', 'background', '理论基础', '相关工作'],
|
||||
'main_content': ['主要内容', '正文', 'main content', '分析', '讨论'],
|
||||
'conclusion': ['结论', '总结', 'conclusion', '结语', '小结', 'summary'],
|
||||
'reference': ['参考', '参考文献', 'references', '文献', 'bibliography'],
|
||||
'appendix': ['附录', 'appendix', '补充资料', 'supplementary']
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""初始化提取器"""
|
||||
self.paper_extractor = PaperStructureExtractor() # 论文专用提取器
|
||||
self._setup_logging()
|
||||
|
||||
def _setup_logging(self):
|
||||
"""配置日志"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_document_structure(self, file_path: str, strategy: str = "fast") -> StructuredDocument:
|
||||
"""提取文档结构
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
strategy: 提取策略 ("fast" 或 "accurate")
|
||||
|
||||
Returns:
|
||||
StructuredDocument: 结构化文档对象
|
||||
"""
|
||||
try:
|
||||
self.logger.info(f"正在处理文档结构: {file_path}")
|
||||
|
||||
# 1. 首先尝试使用论文提取器
|
||||
try:
|
||||
paper_result = self.paper_extractor.extract_paper_structure(file_path)
|
||||
if paper_result and len(paper_result.sections) > 2: # 如果成功识别为论文结构
|
||||
self.logger.info(f"成功识别为学术论文: {file_path}")
|
||||
# 将论文结构转换为通用文档结构
|
||||
return self._convert_paper_to_document(paper_result)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"论文结构提取失败,将尝试通用提取: {str(e)}")
|
||||
|
||||
# 2. 使用通用方法提取文档结构
|
||||
elements = partition(
|
||||
str(file_path),
|
||||
strategy=strategy,
|
||||
include_metadata=True,
|
||||
nlp=False
|
||||
)
|
||||
|
||||
# 3. 使用通用提取器处理
|
||||
doc = self._extract_generic_structure(elements)
|
||||
return doc
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"文档结构提取失败: {str(e)}")
|
||||
# 返回一个空的结构化文档
|
||||
return StructuredDocument(
|
||||
title="未能提取文档标题",
|
||||
sections=[DocumentSection(
|
||||
title="",
|
||||
content="",
|
||||
level=0,
|
||||
section_type="content"
|
||||
)]
|
||||
)
|
||||
|
||||
def _convert_paper_to_document(self, paper: StructuredPaper) -> StructuredDocument:
|
||||
"""将论文结构转换为通用文档结构
|
||||
|
||||
Args:
|
||||
paper: 结构化论文对象
|
||||
|
||||
Returns:
|
||||
StructuredDocument: 转换后的通用文档结构
|
||||
"""
|
||||
doc = StructuredDocument(
|
||||
title=paper.metadata.title,
|
||||
is_paper=True,
|
||||
full_text=paper.full_text
|
||||
)
|
||||
|
||||
# 转换元数据
|
||||
doc.metadata = {
|
||||
'title': paper.metadata.title,
|
||||
'authors': paper.metadata.authors,
|
||||
'keywords': paper.keywords,
|
||||
'abstract': paper.metadata.abstract if hasattr(paper.metadata, 'abstract') else "",
|
||||
'is_paper': True
|
||||
}
|
||||
|
||||
# 转换章节结构
|
||||
doc.sections = self._convert_paper_sections(paper.sections)
|
||||
|
||||
return doc
|
||||
|
||||
def _convert_paper_sections(self, paper_sections: List[PaperSection], level: int = 0) -> List[DocumentSection]:
|
||||
"""递归转换论文章节为通用文档章节
|
||||
|
||||
Args:
|
||||
paper_sections: 论文章节列表
|
||||
level: 当前章节级别
|
||||
|
||||
Returns:
|
||||
List[DocumentSection]: 通用文档章节列表
|
||||
"""
|
||||
doc_sections = []
|
||||
|
||||
for section in paper_sections:
|
||||
doc_section = DocumentSection(
|
||||
title=section.title,
|
||||
content=section.content,
|
||||
level=section.level,
|
||||
section_type=section.section_type,
|
||||
is_heading_only=False if section.content else True
|
||||
)
|
||||
|
||||
# 递归处理子章节
|
||||
if section.subsections:
|
||||
doc_section.subsections = self._convert_paper_sections(
|
||||
section.subsections, level + 1
|
||||
)
|
||||
|
||||
doc_sections.append(doc_section)
|
||||
|
||||
return doc_sections
|
||||
|
||||
def _extract_generic_structure(self, elements) -> StructuredDocument:
|
||||
"""从元素列表中提取通用文档结构
|
||||
|
||||
Args:
|
||||
elements: 文档元素列表
|
||||
|
||||
Returns:
|
||||
StructuredDocument: 结构化文档对象
|
||||
"""
|
||||
# 创建结构化文档对象
|
||||
doc = StructuredDocument(full_text="")
|
||||
|
||||
# 1. 提取文档标题
|
||||
title_candidates = []
|
||||
for i, element in enumerate(elements[:5]): # 只检查前5个元素
|
||||
if isinstance(element, Title):
|
||||
title_text = str(element).strip()
|
||||
title_candidates.append((i, title_text))
|
||||
|
||||
if title_candidates:
|
||||
# 使用第一个标题作为文档标题
|
||||
doc.title = title_candidates[0][1]
|
||||
|
||||
# 2. 识别所有标题元素和内容
|
||||
title_elements = []
|
||||
|
||||
# 2.1 首先识别所有标题
|
||||
for i, element in enumerate(elements):
|
||||
is_heading = False
|
||||
title_text = ""
|
||||
level = 0
|
||||
|
||||
# 检查元素类型
|
||||
if isinstance(element, Title):
|
||||
is_heading = True
|
||||
title_text = str(element).strip()
|
||||
|
||||
# 进一步检查是否为真正的标题
|
||||
if self._is_likely_heading(title_text, element, i, elements):
|
||||
level = self._estimate_heading_level(title_text, element)
|
||||
else:
|
||||
is_heading = False
|
||||
|
||||
# 也检查格式像标题的普通文本
|
||||
elif isinstance(element, (Text, NarrativeText)) and i > 0:
|
||||
text = str(element).strip()
|
||||
# 检查是否匹配标题模式
|
||||
if any(re.match(pattern, text) for pattern in self.HEADING_PATTERNS):
|
||||
# 检查长度和后续内容以确认是否为标题
|
||||
if len(text) < 100 and self._has_sufficient_following_content(i, elements):
|
||||
is_heading = True
|
||||
title_text = text
|
||||
level = self._estimate_heading_level(title_text, element)
|
||||
|
||||
if is_heading:
|
||||
section_type = self._identify_section_type(title_text)
|
||||
title_elements.append((i, title_text, level, section_type))
|
||||
|
||||
# 2.2 为每个标题提取内容
|
||||
sections = []
|
||||
|
||||
for i, (index, title_text, level, section_type) in enumerate(title_elements):
|
||||
# 确定内容范围
|
||||
content_start = index + 1
|
||||
content_end = elements[-1] # 默认到文档结束
|
||||
|
||||
# 如果有下一个标题,内容到下一个标题开始
|
||||
if i < len(title_elements) - 1:
|
||||
content_end = title_elements[i+1][0]
|
||||
else:
|
||||
content_end = len(elements)
|
||||
|
||||
# 提取内容
|
||||
content = self._extract_content_between(elements, content_start, content_end)
|
||||
|
||||
# 创建章节
|
||||
section = DocumentSection(
|
||||
title=title_text,
|
||||
content=content,
|
||||
level=level,
|
||||
section_type=section_type,
|
||||
is_heading_only=False if content.strip() else True
|
||||
)
|
||||
|
||||
sections.append(section)
|
||||
|
||||
# 3. 如果没有识别到任何章节,创建一个默认章节
|
||||
if not sections:
|
||||
all_content = self._extract_content_between(elements, 0, len(elements))
|
||||
|
||||
# 尝试从内容中提取标题
|
||||
first_line = all_content.split('\n')[0] if all_content else ""
|
||||
if first_line and len(first_line) < 100:
|
||||
doc.title = first_line
|
||||
all_content = '\n'.join(all_content.split('\n')[1:])
|
||||
|
||||
default_section = DocumentSection(
|
||||
title="",
|
||||
content=all_content,
|
||||
level=0,
|
||||
section_type="content"
|
||||
)
|
||||
sections.append(default_section)
|
||||
|
||||
# 4. 构建层次结构
|
||||
doc.sections = self._build_section_hierarchy(sections)
|
||||
|
||||
# 5. 提取完整文本
|
||||
doc.full_text = "\n\n".join([str(element) for element in elements if isinstance(element, (Text, NarrativeText, Title, ListItem))])
|
||||
|
||||
return doc
|
||||
|
||||
def _build_section_hierarchy(self, sections: List[DocumentSection]) -> List[DocumentSection]:
|
||||
"""构建章节层次结构
|
||||
|
||||
Args:
|
||||
sections: 章节列表
|
||||
|
||||
Returns:
|
||||
List[DocumentSection]: 具有层次结构的章节列表
|
||||
"""
|
||||
if not sections:
|
||||
return []
|
||||
|
||||
# 按层级排序
|
||||
top_level_sections = []
|
||||
current_parents = {0: None} # 每个层级的当前父节点
|
||||
|
||||
for section in sections:
|
||||
# 找到当前节点的父节点
|
||||
parent_level = None
|
||||
for level in sorted([k for k in current_parents.keys() if k < section.level], reverse=True):
|
||||
parent_level = level
|
||||
break
|
||||
|
||||
if parent_level is None:
|
||||
# 顶级章节
|
||||
top_level_sections.append(section)
|
||||
else:
|
||||
# 子章节
|
||||
parent = current_parents[parent_level]
|
||||
if parent:
|
||||
parent.subsections.append(section)
|
||||
else:
|
||||
top_level_sections.append(section)
|
||||
|
||||
# 更新当前层级的父节点
|
||||
current_parents[section.level] = section
|
||||
|
||||
# 清除所有更深层级的父节点缓存
|
||||
deeper_levels = [k for k in current_parents.keys() if k > section.level]
|
||||
for level in deeper_levels:
|
||||
current_parents.pop(level, None)
|
||||
|
||||
return top_level_sections
|
||||
|
||||
def _is_likely_heading(self, text: str, element, index: int, elements) -> bool:
|
||||
"""判断文本是否可能是标题
|
||||
|
||||
Args:
|
||||
text: 文本内容
|
||||
element: 元素对象
|
||||
index: 元素索引
|
||||
elements: 所有元素列表
|
||||
|
||||
Returns:
|
||||
bool: 是否可能是标题
|
||||
"""
|
||||
# 1. 检查文本长度 - 标题通常不会太长
|
||||
if len(text) > 150: # 标题通常不超过150个字符
|
||||
return False
|
||||
|
||||
# 2. 检查是否匹配标题的数字编号模式
|
||||
if any(re.match(pattern, text) for pattern in self.HEADING_PATTERNS):
|
||||
return True
|
||||
|
||||
# 3. 检查是否包含常见章节标记词
|
||||
lower_text = text.lower()
|
||||
for markers in self.SECTION_MARKERS.values():
|
||||
if any(marker.lower() in lower_text for marker in markers):
|
||||
return True
|
||||
|
||||
# 4. 检查后续内容数量 - 标题后通常有足够多的内容
|
||||
if not self._has_sufficient_following_content(index, elements, min_chars=100):
|
||||
# 但如果文本很短且以特定格式开头,仍可能是标题
|
||||
if len(text) < 50 and (text.endswith(':') or text.endswith(':')):
|
||||
return True
|
||||
return False
|
||||
|
||||
# 5. 检查格式特征
|
||||
# 标题通常是元素的开头,不在段落中间
|
||||
if len(text.split('\n')) > 1:
|
||||
# 多行文本不太可能是标题
|
||||
return False
|
||||
|
||||
# 如果有元数据,检查字体特征(字体大小等)
|
||||
if hasattr(element, 'metadata') and element.metadata:
|
||||
try:
|
||||
font_size = getattr(element.metadata, 'font_size', None)
|
||||
is_bold = getattr(element.metadata, 'is_bold', False)
|
||||
|
||||
# 字体较大或加粗的文本更可能是标题
|
||||
if font_size and font_size > 12:
|
||||
return True
|
||||
if is_bold:
|
||||
return True
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
# 默认返回True,因为元素已被识别为Title类型
|
||||
return True
|
||||
|
||||
def _estimate_heading_level(self, text: str, element) -> int:
|
||||
"""估计标题的层级
|
||||
|
||||
Args:
|
||||
text: 标题文本
|
||||
element: 元素对象
|
||||
|
||||
Returns:
|
||||
int: 标题层级 (0为主标题,1为一级标题, 等等)
|
||||
"""
|
||||
# 1. 通过编号模式判断层级
|
||||
for pattern, level in [
|
||||
(r'^\s*\d+\.\s+', 1), # 1. 开头 (一级标题)
|
||||
(r'^\s*\d+\.\d+\.\s+', 2), # 1.1. 开头 (二级标题)
|
||||
(r'^\s*\d+\.\d+\.\d+\.\s+', 3), # 1.1.1. 开头 (三级标题)
|
||||
(r'^\s*\d+\.\d+\.\d+\.\d+\.\s+', 4), # 1.1.1.1. 开头 (四级标题)
|
||||
]:
|
||||
if re.match(pattern, text):
|
||||
return level
|
||||
|
||||
# 2. 检查是否是常见的主要章节标题
|
||||
lower_text = text.lower()
|
||||
main_sections = [
|
||||
'abstract', 'introduction', 'background', 'methodology',
|
||||
'results', 'discussion', 'conclusion', 'references'
|
||||
]
|
||||
for section in main_sections:
|
||||
if section in lower_text:
|
||||
return 1 # 主要章节为一级标题
|
||||
|
||||
# 3. 根据文本特征判断
|
||||
if text.isupper(): # 全大写文本可能是章标题
|
||||
return 1
|
||||
|
||||
# 4. 通过元数据判断层级
|
||||
if hasattr(element, 'metadata') and element.metadata:
|
||||
try:
|
||||
# 根据字体大小判断层级
|
||||
font_size = getattr(element.metadata, 'font_size', None)
|
||||
if font_size is not None:
|
||||
if font_size > 18: # 假设主标题字体最大
|
||||
return 0
|
||||
elif font_size > 16:
|
||||
return 1
|
||||
elif font_size > 14:
|
||||
return 2
|
||||
else:
|
||||
return 3
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
# 默认为二级标题
|
||||
return 2
|
||||
|
||||
def _identify_section_type(self, title_text: str) -> str:
|
||||
"""识别章节类型,包括参考文献部分"""
|
||||
lower_text = title_text.lower()
|
||||
|
||||
# 特别检查是否为参考文献部分
|
||||
references_patterns = [
|
||||
r'references', r'参考文献', r'bibliography', r'引用文献',
|
||||
r'literature cited', r'^cited\s+literature', r'^文献$', r'^引用$'
|
||||
]
|
||||
|
||||
for pattern in references_patterns:
|
||||
if re.search(pattern, lower_text, re.IGNORECASE):
|
||||
return "references"
|
||||
|
||||
# 检查是否匹配其他常见章节类型
|
||||
for section_type, markers in self.SECTION_MARKERS.items():
|
||||
if any(marker.lower() in lower_text for marker in markers):
|
||||
return section_type
|
||||
|
||||
# 检查带编号的章节
|
||||
if re.match(r'^\d+\.', lower_text):
|
||||
return "content"
|
||||
|
||||
# 默认为内容章节
|
||||
return "content"
|
||||
|
||||
def _has_sufficient_following_content(self, index: int, elements, min_chars: int = 150) -> bool:
|
||||
"""检查元素后是否有足够的内容
|
||||
|
||||
Args:
|
||||
index: 当前元素索引
|
||||
elements: 所有元素列表
|
||||
min_chars: 最小字符数要求
|
||||
|
||||
Returns:
|
||||
bool: 是否有足够的内容
|
||||
"""
|
||||
total_chars = 0
|
||||
for i in range(index + 1, min(index + 5, len(elements))):
|
||||
if isinstance(elements[i], Title):
|
||||
# 如果紧接着是标题,就停止检查
|
||||
break
|
||||
if isinstance(elements[i], (Text, NarrativeText, ListItem, Table)):
|
||||
total_chars += len(str(elements[i]))
|
||||
if total_chars >= min_chars:
|
||||
return True
|
||||
|
||||
return total_chars >= min_chars
|
||||
|
||||
def _extract_content_between(self, elements, start_index: int, end_index: int) -> str:
|
||||
"""提取指定范围内的内容文本
|
||||
|
||||
Args:
|
||||
elements: 元素列表
|
||||
start_index: 开始索引
|
||||
end_index: 结束索引
|
||||
|
||||
Returns:
|
||||
str: 提取的内容文本
|
||||
"""
|
||||
content_parts = []
|
||||
|
||||
for i in range(start_index, end_index):
|
||||
if isinstance(elements[i], (Text, NarrativeText, ListItem, Table)):
|
||||
content_parts.append(str(elements[i]).strip())
|
||||
|
||||
return "\n\n".join([part for part in content_parts if part])
|
||||
|
||||
def generate_markdown(self, doc: StructuredDocument) -> str:
|
||||
"""将结构化文档转换为Markdown格式
|
||||
|
||||
Args:
|
||||
doc: 结构化文档对象
|
||||
|
||||
Returns:
|
||||
str: Markdown格式文本
|
||||
"""
|
||||
md_parts = []
|
||||
|
||||
# 添加标题
|
||||
if doc.title:
|
||||
md_parts.append(f"# {doc.title}\n")
|
||||
|
||||
# 添加元数据
|
||||
if doc.is_paper:
|
||||
# 作者信息
|
||||
if 'authors' in doc.metadata and doc.metadata['authors']:
|
||||
authors_str = ", ".join(doc.metadata['authors'])
|
||||
md_parts.append(f"**作者:** {authors_str}\n")
|
||||
|
||||
# 关键词
|
||||
if 'keywords' in doc.metadata and doc.metadata['keywords']:
|
||||
keywords_str = ", ".join(doc.metadata['keywords'])
|
||||
md_parts.append(f"**关键词:** {keywords_str}\n")
|
||||
|
||||
# 摘要
|
||||
if 'abstract' in doc.metadata and doc.metadata['abstract']:
|
||||
md_parts.append(f"## 摘要\n\n{doc.metadata['abstract']}\n")
|
||||
|
||||
# 添加章节内容
|
||||
md_parts.append(self._format_sections_markdown(doc.sections))
|
||||
|
||||
return "\n".join(md_parts)
|
||||
|
||||
def _format_sections_markdown(self, sections: List[DocumentSection], base_level: int = 0) -> str:
|
||||
"""递归格式化章节为Markdown
|
||||
|
||||
Args:
|
||||
sections: 章节列表
|
||||
base_level: 基础层级
|
||||
|
||||
Returns:
|
||||
str: Markdown格式文本
|
||||
"""
|
||||
md_parts = []
|
||||
|
||||
for section in sections:
|
||||
# 计算标题级别 (确保不超过6级)
|
||||
header_level = min(section.level + base_level + 1, 6)
|
||||
|
||||
# 添加标题和内容
|
||||
if section.title:
|
||||
md_parts.append(f"{'#' * header_level} {section.title}\n")
|
||||
|
||||
if section.content:
|
||||
md_parts.append(f"{section.content}\n")
|
||||
|
||||
# 递归处理子章节
|
||||
if section.subsections:
|
||||
md_parts.append(self._format_sections_markdown(
|
||||
section.subsections, base_level
|
||||
))
|
||||
|
||||
return "\n".join(md_parts)
|
||||
4
crazy_functions/paper_fns/file2file_doc/__init__.py
Normal file
4
crazy_functions/paper_fns/file2file_doc/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .txt_doc import TxtFormatter
|
||||
from .markdown_doc import MarkdownFormatter
|
||||
from .html_doc import HtmlFormatter
|
||||
from .word_doc import WordFormatter
|
||||
300
crazy_functions/paper_fns/file2file_doc/html_doc.py
Normal file
300
crazy_functions/paper_fns/file2file_doc/html_doc.py
Normal file
@@ -0,0 +1,300 @@
|
||||
class HtmlFormatter:
|
||||
"""HTML格式文档生成器 - 保留原始文档结构"""
|
||||
|
||||
def __init__(self, processing_type="文本处理"):
|
||||
self.processing_type = processing_type
|
||||
self.css_styles = """
|
||||
:root {
|
||||
--primary-color: #2563eb;
|
||||
--primary-light: #eff6ff;
|
||||
--secondary-color: #1e293b;
|
||||
--background-color: #f8fafc;
|
||||
--text-color: #334155;
|
||||
--border-color: #e2e8f0;
|
||||
--card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
line-height: 1.8;
|
||||
margin: 0;
|
||||
padding: 2rem;
|
||||
color: var(--text-color);
|
||||
background-color: var(--background-color);
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
background: white;
|
||||
padding: 2rem;
|
||||
border-radius: 16px;
|
||||
box-shadow: var(--card-shadow);
|
||||
}
|
||||
::selection {
|
||||
background: var(--primary-light);
|
||||
color: var(--primary-color);
|
||||
}
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(20px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
.container {
|
||||
animation: fadeIn 0.6s ease-out;
|
||||
}
|
||||
|
||||
.document-title {
|
||||
color: var(--primary-color);
|
||||
font-size: 2em;
|
||||
text-align: center;
|
||||
margin: 1rem 0 2rem;
|
||||
padding-bottom: 1rem;
|
||||
border-bottom: 2px solid var(--primary-color);
|
||||
}
|
||||
|
||||
.document-body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
margin: 2rem 0;
|
||||
}
|
||||
|
||||
.document-header {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
margin-bottom: 2rem;
|
||||
}
|
||||
|
||||
.processing-type {
|
||||
color: var(--secondary-color);
|
||||
font-size: 1.2em;
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
|
||||
.processing-date {
|
||||
color: var(--text-color);
|
||||
font-size: 0.9em;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.document-content {
|
||||
background: white;
|
||||
padding: 1.5rem;
|
||||
border-radius: 8px;
|
||||
border-left: 4px solid var(--primary-color);
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
/* 保留文档结构的样式 */
|
||||
h1, h2, h3, h4, h5, h6 {
|
||||
color: var(--secondary-color);
|
||||
margin-top: 1.5em;
|
||||
margin-bottom: 0.5em;
|
||||
}
|
||||
|
||||
h1 { font-size: 1.8em; }
|
||||
h2 { font-size: 1.5em; }
|
||||
h3 { font-size: 1.3em; }
|
||||
h4 { font-size: 1.1em; }
|
||||
|
||||
p {
|
||||
margin: 0.8em 0;
|
||||
}
|
||||
|
||||
ul, ol {
|
||||
margin: 1em 0;
|
||||
padding-left: 2em;
|
||||
}
|
||||
|
||||
li {
|
||||
margin: 0.5em 0;
|
||||
}
|
||||
|
||||
blockquote {
|
||||
margin: 1em 0;
|
||||
padding: 0.5em 1em;
|
||||
border-left: 4px solid var(--primary-light);
|
||||
background: rgba(0,0,0,0.02);
|
||||
}
|
||||
|
||||
code {
|
||||
font-family: monospace;
|
||||
background: rgba(0,0,0,0.05);
|
||||
padding: 0.2em 0.4em;
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
pre {
|
||||
background: rgba(0,0,0,0.05);
|
||||
padding: 1em;
|
||||
border-radius: 5px;
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
pre code {
|
||||
background: transparent;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
:root {
|
||||
--background-color: #0f172a;
|
||||
--text-color: #e2e8f0;
|
||||
--border-color: #1e293b;
|
||||
}
|
||||
|
||||
.container, .document-content {
|
||||
background: #1e293b;
|
||||
}
|
||||
|
||||
blockquote {
|
||||
background: rgba(255,255,255,0.05);
|
||||
}
|
||||
|
||||
code, pre {
|
||||
background: rgba(255,255,255,0.05);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def _escape_html(self, text):
|
||||
"""转义HTML特殊字符"""
|
||||
import html
|
||||
return html.escape(text)
|
||||
|
||||
def _markdown_to_html(self, text):
|
||||
"""将Markdown格式转换为HTML格式,保留文档结构"""
|
||||
try:
|
||||
import markdown
|
||||
# 使用Python-Markdown库将markdown转换为HTML,启用更多扩展以支持嵌套列表
|
||||
return markdown.markdown(text, extensions=['tables', 'fenced_code', 'codehilite', 'nl2br', 'sane_lists', 'smarty', 'extra'])
|
||||
except ImportError:
|
||||
# 如果没有markdown库,使用更复杂的替换来处理嵌套列表
|
||||
import re
|
||||
|
||||
# 替换标题
|
||||
text = re.sub(r'^# (.+)$', r'<h1>\1</h1>', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^## (.+)$', r'<h2>\1</h2>', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^### (.+)$', r'<h3>\1</h3>', text, flags=re.MULTILINE)
|
||||
|
||||
# 预处理列表 - 在列表项之间添加空行以正确分隔
|
||||
# 处理编号列表
|
||||
text = re.sub(r'(\n\d+\.\s.+)(\n\d+\.\s)', r'\1\n\2', text)
|
||||
# 处理项目符号列表
|
||||
text = re.sub(r'(\n•\s.+)(\n•\s)', r'\1\n\2', text)
|
||||
text = re.sub(r'(\n\*\s.+)(\n\*\s)', r'\1\n\2', text)
|
||||
text = re.sub(r'(\n-\s.+)(\n-\s)', r'\1\n\2', text)
|
||||
|
||||
# 处理嵌套列表 - 确保正确的缩进和结构
|
||||
lines = text.split('\n')
|
||||
in_list = False
|
||||
list_type = None # 'ol' 或 'ul'
|
||||
list_html = []
|
||||
normal_lines = []
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
|
||||
# 匹配编号列表项
|
||||
numbered_match = re.match(r'^(\d+)\.\s+(.+)$', line)
|
||||
# 匹配项目符号列表项
|
||||
bullet_match = re.match(r'^[•\*-]\s+(.+)$', line)
|
||||
|
||||
if numbered_match:
|
||||
if not in_list or list_type != 'ol':
|
||||
# 开始新的编号列表
|
||||
if in_list:
|
||||
# 关闭前一个列表
|
||||
list_html.append(f'</{list_type}>')
|
||||
list_html.append('<ol>')
|
||||
in_list = True
|
||||
list_type = 'ol'
|
||||
|
||||
num, content = numbered_match.groups()
|
||||
list_html.append(f'<li>{content}</li>')
|
||||
|
||||
elif bullet_match:
|
||||
if not in_list or list_type != 'ul':
|
||||
# 开始新的项目符号列表
|
||||
if in_list:
|
||||
# 关闭前一个列表
|
||||
list_html.append(f'</{list_type}>')
|
||||
list_html.append('<ul>')
|
||||
in_list = True
|
||||
list_type = 'ul'
|
||||
|
||||
content = bullet_match.group(1)
|
||||
list_html.append(f'<li>{content}</li>')
|
||||
|
||||
else:
|
||||
if in_list:
|
||||
# 结束当前列表
|
||||
list_html.append(f'</{list_type}>')
|
||||
in_list = False
|
||||
# 将完成的列表添加到正常行中
|
||||
normal_lines.append(''.join(list_html))
|
||||
list_html = []
|
||||
|
||||
normal_lines.append(line)
|
||||
|
||||
i += 1
|
||||
|
||||
# 如果最后还在列表中,确保关闭列表
|
||||
if in_list:
|
||||
list_html.append(f'</{list_type}>')
|
||||
normal_lines.append(''.join(list_html))
|
||||
|
||||
# 重建文本
|
||||
text = '\n'.join(normal_lines)
|
||||
|
||||
# 替换段落,但避免处理已经是HTML标签的部分
|
||||
paragraphs = text.split('\n\n')
|
||||
for i, p in enumerate(paragraphs):
|
||||
# 如果不是以HTML标签开始且不为空
|
||||
if not (p.strip().startswith('<') and p.strip().endswith('>')) and p.strip() != '':
|
||||
paragraphs[i] = f'<p>{p}</p>'
|
||||
|
||||
return '\n'.join(paragraphs)
|
||||
|
||||
def create_document(self, content: str) -> str:
|
||||
"""生成完整的HTML文档,保留原始文档结构
|
||||
|
||||
Args:
|
||||
content: 处理后的文档内容
|
||||
|
||||
Returns:
|
||||
str: 完整的HTML文档字符串
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
# 将markdown内容转换为HTML
|
||||
html_content = self._markdown_to_html(content)
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>文档处理结果</title>
|
||||
<style>{self.css_styles}</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1 class="document-title">文档处理结果</h1>
|
||||
|
||||
<div class="document-header">
|
||||
<div class="processing-type">处理方式: {self._escape_html(self.processing_type)}</div>
|
||||
<div class="processing-date">处理时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</div>
|
||||
</div>
|
||||
|
||||
<div class="document-content">
|
||||
{html_content}
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
40
crazy_functions/paper_fns/file2file_doc/markdown_doc.py
Normal file
40
crazy_functions/paper_fns/file2file_doc/markdown_doc.py
Normal file
@@ -0,0 +1,40 @@
|
||||
class MarkdownFormatter:
|
||||
"""Markdown格式文档生成器 - 保留原始文档结构"""
|
||||
|
||||
def __init__(self):
|
||||
self.content = []
|
||||
|
||||
def _add_content(self, text: str):
|
||||
"""添加正文内容"""
|
||||
if text:
|
||||
self.content.append(f"\n{text}\n")
|
||||
|
||||
def create_document(self, content: str, processing_type: str = "文本处理") -> str:
|
||||
"""
|
||||
创建完整的Markdown文档,保留原始文档结构
|
||||
Args:
|
||||
content: 处理后的文档内容
|
||||
processing_type: 处理类型(润色、翻译等)
|
||||
Returns:
|
||||
str: 生成的Markdown文本
|
||||
"""
|
||||
self.content = []
|
||||
|
||||
# 添加标题和说明
|
||||
self.content.append(f"# 文档处理结果\n")
|
||||
self.content.append(f"## 处理方式: {processing_type}\n")
|
||||
|
||||
# 添加处理时间
|
||||
from datetime import datetime
|
||||
self.content.append(f"*处理时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*\n")
|
||||
|
||||
# 添加分隔线
|
||||
self.content.append("---\n")
|
||||
|
||||
# 添加原始内容,保留结构
|
||||
self.content.append(content)
|
||||
|
||||
# 添加结尾分隔线
|
||||
self.content.append("\n---\n")
|
||||
|
||||
return "\n".join(self.content)
|
||||
69
crazy_functions/paper_fns/file2file_doc/txt_doc.py
Normal file
69
crazy_functions/paper_fns/file2file_doc/txt_doc.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import re
|
||||
|
||||
def convert_markdown_to_txt(markdown_text):
|
||||
"""Convert markdown text to plain text while preserving formatting"""
|
||||
# Standardize line endings
|
||||
markdown_text = markdown_text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
# 1. Handle headers but keep their formatting instead of removing them
|
||||
markdown_text = re.sub(r'^#\s+(.+)$', r'# \1', markdown_text, flags=re.MULTILINE)
|
||||
markdown_text = re.sub(r'^##\s+(.+)$', r'## \1', markdown_text, flags=re.MULTILINE)
|
||||
markdown_text = re.sub(r'^###\s+(.+)$', r'### \1', markdown_text, flags=re.MULTILINE)
|
||||
|
||||
# 2. Handle bold and italic - simply remove markers
|
||||
markdown_text = re.sub(r'\*\*(.+?)\*\*', r'\1', markdown_text)
|
||||
markdown_text = re.sub(r'\*(.+?)\*', r'\1', markdown_text)
|
||||
|
||||
# 3. Handle lists but preserve formatting
|
||||
markdown_text = re.sub(r'^\s*[-*+]\s+(.+?)(?=\n|$)', r'• \1', markdown_text, flags=re.MULTILINE)
|
||||
|
||||
# 4. Handle links - keep only the text
|
||||
markdown_text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'\1 (\2)', markdown_text)
|
||||
|
||||
# 5. Handle HTML links - convert to user-friendly format
|
||||
markdown_text = re.sub(r'<a href=[\'"]([^\'"]+)[\'"](?:\s+target=[\'"][^\'"]+[\'"])?>([^<]+)</a>', r'\2 (\1)', markdown_text)
|
||||
|
||||
# 6. Preserve paragraph breaks
|
||||
markdown_text = re.sub(r'\n{3,}', '\n\n', markdown_text) # normalize multiple newlines to double newlines
|
||||
|
||||
# 7. Clean up extra spaces but maintain indentation
|
||||
markdown_text = re.sub(r' +', ' ', markdown_text)
|
||||
|
||||
return markdown_text.strip()
|
||||
|
||||
|
||||
class TxtFormatter:
|
||||
"""文本格式化器 - 保留原始文档结构"""
|
||||
|
||||
def __init__(self):
|
||||
self.content = []
|
||||
self._setup_document()
|
||||
|
||||
def _setup_document(self):
|
||||
"""初始化文档标题"""
|
||||
self.content.append("=" * 50)
|
||||
self.content.append("处理后文档".center(48))
|
||||
self.content.append("=" * 50)
|
||||
|
||||
def _format_header(self):
|
||||
"""创建文档头部信息"""
|
||||
from datetime import datetime
|
||||
date_str = datetime.now().strftime('%Y年%m月%d日')
|
||||
return [
|
||||
date_str.center(48),
|
||||
"\n" # 添加空行
|
||||
]
|
||||
|
||||
def create_document(self, content):
|
||||
"""生成保留原始结构的文档"""
|
||||
# 添加头部信息
|
||||
self.content.extend(self._format_header())
|
||||
|
||||
# 处理内容,保留原始结构
|
||||
processed_content = convert_markdown_to_txt(content)
|
||||
|
||||
# 添加处理后的内容
|
||||
self.content.append(processed_content)
|
||||
|
||||
# 合并所有内容
|
||||
return "\n".join(self.content)
|
||||
125
crazy_functions/paper_fns/file2file_doc/word2pdf.py
Normal file
125
crazy_functions/paper_fns/file2file_doc/word2pdf.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from docx2pdf import convert
|
||||
import os
|
||||
import platform
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
class WordToPdfConverter:
|
||||
"""Word文档转PDF转换器"""
|
||||
|
||||
@staticmethod
|
||||
def convert_to_pdf(word_path: Union[str, Path], pdf_path: Union[str, Path] = None) -> str:
|
||||
"""
|
||||
将Word文档转换为PDF
|
||||
|
||||
参数:
|
||||
word_path: Word文档的路径
|
||||
pdf_path: 可选,PDF文件的输出路径。如果未指定,将使用与Word文档相同的名称和位置
|
||||
|
||||
返回:
|
||||
生成的PDF文件路径
|
||||
|
||||
异常:
|
||||
如果转换失败,将抛出相应异常
|
||||
"""
|
||||
try:
|
||||
# 确保输入路径是Path对象
|
||||
word_path = Path(word_path)
|
||||
|
||||
# 如果未指定pdf_path,则使用与word文档相同的名称
|
||||
if pdf_path is None:
|
||||
pdf_path = word_path.with_suffix('.pdf')
|
||||
else:
|
||||
pdf_path = Path(pdf_path)
|
||||
|
||||
# 检查操作系统
|
||||
if platform.system() == 'Linux':
|
||||
# Linux系统需要安装libreoffice
|
||||
if not os.system('which libreoffice') == 0:
|
||||
raise RuntimeError("请先安装LibreOffice: sudo apt-get install libreoffice")
|
||||
|
||||
# 使用libreoffice进行转换
|
||||
os.system(f'libreoffice --headless --convert-to pdf "{word_path}" --outdir "{pdf_path.parent}"')
|
||||
|
||||
# 如果输出路径与默认生成的不同,则重命名
|
||||
default_pdf = word_path.with_suffix('.pdf')
|
||||
if default_pdf != pdf_path:
|
||||
os.rename(default_pdf, pdf_path)
|
||||
else:
|
||||
# Windows和MacOS使用docx2pdf
|
||||
convert(word_path, pdf_path)
|
||||
|
||||
return str(pdf_path)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"转换PDF失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def batch_convert(word_dir: Union[str, Path], pdf_dir: Union[str, Path] = None) -> list:
|
||||
"""
|
||||
批量转换目录下的所有Word文档
|
||||
|
||||
参数:
|
||||
word_dir: 包含Word文档的目录路径
|
||||
pdf_dir: 可选,PDF文件的输出目录。如果未指定,将使用与Word文档相同的目录
|
||||
|
||||
返回:
|
||||
生成的PDF文件路径列表
|
||||
"""
|
||||
word_dir = Path(word_dir)
|
||||
if pdf_dir:
|
||||
pdf_dir = Path(pdf_dir)
|
||||
pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
converted_files = []
|
||||
|
||||
for word_file in word_dir.glob("*.docx"):
|
||||
try:
|
||||
if pdf_dir:
|
||||
pdf_path = pdf_dir / word_file.with_suffix('.pdf').name
|
||||
else:
|
||||
pdf_path = word_file.with_suffix('.pdf')
|
||||
|
||||
pdf_file = WordToPdfConverter.convert_to_pdf(word_file, pdf_path)
|
||||
converted_files.append(pdf_file)
|
||||
|
||||
except Exception as e:
|
||||
print(f"转换 {word_file} 失败: {str(e)}")
|
||||
|
||||
return converted_files
|
||||
|
||||
@staticmethod
|
||||
def convert_doc_to_pdf(doc, output_dir: Union[str, Path] = None) -> str:
|
||||
"""
|
||||
将docx对象直接转换为PDF
|
||||
|
||||
参数:
|
||||
doc: python-docx的Document对象
|
||||
output_dir: 可选,输出目录。如果未指定,将使用当前目录
|
||||
|
||||
返回:
|
||||
生成的PDF文件路径
|
||||
"""
|
||||
try:
|
||||
# 设置临时文件路径和输出路径
|
||||
output_dir = Path(output_dir) if output_dir else Path.cwd()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成临时word文件
|
||||
temp_docx = output_dir / f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.docx"
|
||||
doc.save(temp_docx)
|
||||
|
||||
# 转换为PDF
|
||||
pdf_path = temp_docx.with_suffix('.pdf')
|
||||
WordToPdfConverter.convert_to_pdf(temp_docx, pdf_path)
|
||||
|
||||
# 删除临时word文件
|
||||
temp_docx.unlink()
|
||||
|
||||
return str(pdf_path)
|
||||
|
||||
except Exception as e:
|
||||
if temp_docx.exists():
|
||||
temp_docx.unlink()
|
||||
raise Exception(f"转换PDF失败: {str(e)}")
|
||||
236
crazy_functions/paper_fns/file2file_doc/word_doc.py
Normal file
236
crazy_functions/paper_fns/file2file_doc/word_doc.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import re
|
||||
from docx import Document
|
||||
from docx.shared import Cm, Pt
|
||||
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
|
||||
from docx.enum.style import WD_STYLE_TYPE
|
||||
from docx.oxml.ns import qn
|
||||
from datetime import datetime
|
||||
|
||||
def convert_markdown_to_word(markdown_text):
|
||||
# 0. 首先标准化所有换行符为\n
|
||||
markdown_text = markdown_text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
# 1. 处理标题 - 支持更多级别的标题,使用更精确的正则
|
||||
# 保留标题标记,以便后续处理时还能识别出标题级别
|
||||
markdown_text = re.sub(r'^(#{1,6})\s+(.+?)(?:\s+#+)?$', r'\1 \2', markdown_text, flags=re.MULTILINE)
|
||||
|
||||
# 2. 处理粗体、斜体和加粗斜体
|
||||
markdown_text = re.sub(r'\*\*\*(.+?)\*\*\*', r'\1', markdown_text) # 加粗斜体
|
||||
markdown_text = re.sub(r'\*\*(.+?)\*\*', r'\1', markdown_text) # 加粗
|
||||
markdown_text = re.sub(r'\*(.+?)\*', r'\1', markdown_text) # 斜体
|
||||
markdown_text = re.sub(r'_(.+?)_', r'\1', markdown_text) # 下划线斜体
|
||||
markdown_text = re.sub(r'__(.+?)__', r'\1', markdown_text) # 下划线加粗
|
||||
|
||||
# 3. 处理代码块 - 不移除,而是简化格式
|
||||
# 多行代码块
|
||||
markdown_text = re.sub(r'```(?:\w+)?\n([\s\S]*?)```', r'[代码块]\n\1[/代码块]', markdown_text)
|
||||
# 单行代码
|
||||
markdown_text = re.sub(r'`([^`]+)`', r'[代码]\1[/代码]', markdown_text)
|
||||
|
||||
# 4. 处理列表 - 保留列表结构
|
||||
# 匹配无序列表
|
||||
markdown_text = re.sub(r'^(\s*)[-*+]\s+(.+?)$', r'\1• \2', markdown_text, flags=re.MULTILINE)
|
||||
|
||||
# 5. 处理Markdown链接
|
||||
markdown_text = re.sub(r'\[([^\]]+)\]\(([^)]+?)\s*(?:"[^"]*")?\)', r'\1 (\2)', markdown_text)
|
||||
|
||||
# 6. 处理HTML链接
|
||||
markdown_text = re.sub(r'<a href=[\'"]([^\'"]+)[\'"](?:\s+target=[\'"][^\'"]+[\'"])?>([^<]+)</a>', r'\2 (\1)', markdown_text)
|
||||
|
||||
# 7. 处理图片
|
||||
markdown_text = re.sub(r'!\[([^\]]*)\]\([^)]+\)', r'[图片:\1]', markdown_text)
|
||||
|
||||
return markdown_text
|
||||
|
||||
|
||||
class WordFormatter:
|
||||
"""文档Word格式化器 - 保留原始文档结构"""
|
||||
|
||||
def __init__(self):
|
||||
self.doc = Document()
|
||||
self._setup_document()
|
||||
self._create_styles()
|
||||
|
||||
def _setup_document(self):
|
||||
"""设置文档基本格式,包括页面设置和页眉"""
|
||||
sections = self.doc.sections
|
||||
for section in sections:
|
||||
# 设置页面大小为A4
|
||||
section.page_width = Cm(21)
|
||||
section.page_height = Cm(29.7)
|
||||
# 设置页边距
|
||||
section.top_margin = Cm(3.7) # 上边距37mm
|
||||
section.bottom_margin = Cm(3.5) # 下边距35mm
|
||||
section.left_margin = Cm(2.8) # 左边距28mm
|
||||
section.right_margin = Cm(2.6) # 右边距26mm
|
||||
# 设置页眉页脚距离
|
||||
section.header_distance = Cm(2.0)
|
||||
section.footer_distance = Cm(2.0)
|
||||
|
||||
# 添加页眉
|
||||
header = section.header
|
||||
header_para = header.paragraphs[0]
|
||||
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.RIGHT
|
||||
header_run = header_para.add_run("文档处理结果")
|
||||
header_run.font.name = '仿宋'
|
||||
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
header_run.font.size = Pt(9)
|
||||
|
||||
def _create_styles(self):
|
||||
"""创建文档样式"""
|
||||
# 创建正文样式
|
||||
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
style.font.name = '仿宋'
|
||||
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
style.font.size = Pt(12) # 调整为12磅
|
||||
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
style.paragraph_format.space_after = Pt(0)
|
||||
|
||||
# 创建标题样式
|
||||
title_style = self.doc.styles.add_style('Title_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
title_style.font.name = '黑体'
|
||||
title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||
title_style.font.size = Pt(22) # 调整为22磅
|
||||
title_style.font.bold = True
|
||||
title_style.paragraph_format.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
title_style.paragraph_format.space_before = Pt(0)
|
||||
title_style.paragraph_format.space_after = Pt(24)
|
||||
title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
|
||||
# 创建标题1样式
|
||||
h1_style = self.doc.styles.add_style('Heading1_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
h1_style.font.name = '黑体'
|
||||
h1_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||
h1_style.font.size = Pt(18)
|
||||
h1_style.font.bold = True
|
||||
h1_style.paragraph_format.space_before = Pt(12)
|
||||
h1_style.paragraph_format.space_after = Pt(6)
|
||||
|
||||
# 创建标题2样式
|
||||
h2_style = self.doc.styles.add_style('Heading2_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
h2_style.font.name = '黑体'
|
||||
h2_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||
h2_style.font.size = Pt(16)
|
||||
h2_style.font.bold = True
|
||||
h2_style.paragraph_format.space_before = Pt(10)
|
||||
h2_style.paragraph_format.space_after = Pt(6)
|
||||
|
||||
# 创建标题3样式
|
||||
h3_style = self.doc.styles.add_style('Heading3_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
h3_style.font.name = '黑体'
|
||||
h3_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||
h3_style.font.size = Pt(14)
|
||||
h3_style.font.bold = True
|
||||
h3_style.paragraph_format.space_before = Pt(8)
|
||||
h3_style.paragraph_format.space_after = Pt(4)
|
||||
|
||||
# 创建代码块样式
|
||||
code_style = self.doc.styles.add_style('Code_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
code_style.font.name = 'Courier New'
|
||||
code_style.font.size = Pt(11)
|
||||
code_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.SINGLE
|
||||
code_style.paragraph_format.space_before = Pt(6)
|
||||
code_style.paragraph_format.space_after = Pt(6)
|
||||
code_style.paragraph_format.left_indent = Pt(36)
|
||||
code_style.paragraph_format.right_indent = Pt(36)
|
||||
|
||||
# 创建列表样式
|
||||
list_style = self.doc.styles.add_style('List_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
list_style.font.name = '仿宋'
|
||||
list_style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
list_style.font.size = Pt(12)
|
||||
list_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
list_style.paragraph_format.left_indent = Pt(21)
|
||||
list_style.paragraph_format.first_line_indent = Pt(-21)
|
||||
|
||||
def create_document(self, content: str, processing_type: str = "文本处理"):
|
||||
"""创建文档,保留原始结构"""
|
||||
# 添加标题
|
||||
title_para = self.doc.add_paragraph(style='Title_Custom')
|
||||
title_run = title_para.add_run('文档处理结果')
|
||||
|
||||
# 添加处理类型
|
||||
processing_para = self.doc.add_paragraph()
|
||||
processing_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
processing_run = processing_para.add_run(f"处理方式: {processing_type}")
|
||||
processing_run.font.name = '仿宋'
|
||||
processing_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
processing_run.font.size = Pt(14)
|
||||
|
||||
# 添加日期
|
||||
date_para = self.doc.add_paragraph()
|
||||
date_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
date_run = date_para.add_run(f"处理时间: {datetime.now().strftime('%Y年%m月%d日')}")
|
||||
date_run.font.name = '仿宋'
|
||||
date_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
date_run.font.size = Pt(14)
|
||||
|
||||
self.doc.add_paragraph() # 添加空行
|
||||
|
||||
# 预处理内容,将Markdown格式转换为适合Word的格式
|
||||
processed_content = convert_markdown_to_word(content)
|
||||
|
||||
# 按行处理文本,保留结构
|
||||
lines = processed_content.split('\n')
|
||||
in_code_block = False
|
||||
current_paragraph = None
|
||||
|
||||
for line in lines:
|
||||
# 检查是否为标题
|
||||
header_match = re.match(r'^(#{1,6})\s+(.+)$', line)
|
||||
|
||||
if header_match:
|
||||
# 根据#的数量确定标题级别
|
||||
level = len(header_match.group(1))
|
||||
title_text = header_match.group(2)
|
||||
|
||||
if level == 1:
|
||||
style = 'Heading1_Custom'
|
||||
elif level == 2:
|
||||
style = 'Heading2_Custom'
|
||||
else:
|
||||
style = 'Heading3_Custom'
|
||||
|
||||
self.doc.add_paragraph(title_text, style=style)
|
||||
current_paragraph = None
|
||||
|
||||
# 检查代码块标记
|
||||
elif '[代码块]' in line:
|
||||
in_code_block = True
|
||||
current_paragraph = self.doc.add_paragraph(style='Code_Custom')
|
||||
code_line = line.replace('[代码块]', '').strip()
|
||||
if code_line:
|
||||
current_paragraph.add_run(code_line)
|
||||
|
||||
elif '[/代码块]' in line:
|
||||
in_code_block = False
|
||||
code_line = line.replace('[/代码块]', '').strip()
|
||||
if code_line and current_paragraph:
|
||||
current_paragraph.add_run(code_line)
|
||||
current_paragraph = None
|
||||
|
||||
# 检查列表项
|
||||
elif line.strip().startswith('•'):
|
||||
p = self.doc.add_paragraph(style='List_Custom')
|
||||
p.add_run(line.strip())
|
||||
current_paragraph = None
|
||||
|
||||
# 处理普通文本行
|
||||
elif line.strip():
|
||||
if in_code_block:
|
||||
if current_paragraph:
|
||||
current_paragraph.add_run('\n' + line)
|
||||
else:
|
||||
current_paragraph = self.doc.add_paragraph(line, style='Code_Custom')
|
||||
else:
|
||||
if current_paragraph is None or not current_paragraph.text:
|
||||
current_paragraph = self.doc.add_paragraph(line, style='Normal_Custom')
|
||||
else:
|
||||
current_paragraph.add_run('\n' + line)
|
||||
|
||||
# 处理空行,创建新段落
|
||||
elif not in_code_block:
|
||||
current_paragraph = None
|
||||
|
||||
return self.doc
|
||||
|
||||
278
crazy_functions/paper_fns/github_search.py
Normal file
278
crazy_functions/paper_fns/github_search.py
Normal file
@@ -0,0 +1,278 @@
|
||||
from typing import List, Dict, Tuple
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from toolbox import CatchException, update_ui, promote_file_to_downloadzone, get_log_folder, get_user
|
||||
from toolbox import update_ui, CatchException, report_exception, write_history_to_file
|
||||
from crazy_functions.paper_fns.auto_git.query_analyzer import QueryAnalyzer, SearchCriteria
|
||||
from crazy_functions.paper_fns.auto_git.handlers.repo_handler import RepositoryHandler
|
||||
from crazy_functions.paper_fns.auto_git.handlers.code_handler import CodeSearchHandler
|
||||
from crazy_functions.paper_fns.auto_git.handlers.user_handler import UserSearchHandler
|
||||
from crazy_functions.paper_fns.auto_git.handlers.topic_handler import TopicHandler
|
||||
from crazy_functions.paper_fns.auto_git.sources.github_source import GitHubSource
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
import re
|
||||
from datetime import datetime
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
# 导入格式化器
|
||||
from crazy_functions.paper_fns.file2file_doc import (
|
||||
TxtFormatter,
|
||||
MarkdownFormatter,
|
||||
HtmlFormatter,
|
||||
WordFormatter
|
||||
)
|
||||
from crazy_functions.paper_fns.file2file_doc.word2pdf import WordToPdfConverter
|
||||
|
||||
@CatchException
|
||||
def GitHub项目智能检索(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||
history: List, system_prompt: str, user_request: str):
|
||||
"""GitHub项目智能检索主函数"""
|
||||
|
||||
# 初始化GitHub API调用源
|
||||
github_source = GitHubSource(api_key=plugin_kwargs.get("github_api_key"))
|
||||
|
||||
# 初始化处理器
|
||||
handlers = {
|
||||
"repo": RepositoryHandler(github_source, llm_kwargs),
|
||||
"code": CodeSearchHandler(github_source, llm_kwargs),
|
||||
"user": UserSearchHandler(github_source, llm_kwargs),
|
||||
"topic": TopicHandler(github_source, llm_kwargs),
|
||||
}
|
||||
|
||||
# 分析查询意图
|
||||
chatbot.append(["分析查询意图", "正在分析您的查询需求..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
query_analyzer = QueryAnalyzer()
|
||||
search_criteria = yield from query_analyzer.analyze_query(
|
||||
txt, chatbot, llm_kwargs
|
||||
)
|
||||
|
||||
# 根据查询类型选择处理器
|
||||
handler = handlers.get(search_criteria.query_type)
|
||||
if not handler:
|
||||
handler = handlers["repo"] # 默认使用仓库处理器
|
||||
|
||||
# 处理查询
|
||||
chatbot.append(["开始搜索", f"使用{handler.__class__.__name__}处理您的请求,正在搜索GitHub..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
final_prompt = asyncio.run(handler.handle(
|
||||
criteria=search_criteria,
|
||||
chatbot=chatbot,
|
||||
history=history,
|
||||
system_prompt=system_prompt,
|
||||
llm_kwargs=llm_kwargs,
|
||||
plugin_kwargs=plugin_kwargs
|
||||
))
|
||||
|
||||
if final_prompt:
|
||||
# 检查是否是道歉提示
|
||||
if "很抱歉,我们未能找到" in final_prompt:
|
||||
chatbot.append([txt, final_prompt])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 在 final_prompt 末尾添加用户原始查询要求
|
||||
final_prompt += f"""
|
||||
|
||||
原始用户查询: "{txt}"
|
||||
|
||||
重要提示:
|
||||
- 你的回答必须直接满足用户的原始查询要求
|
||||
- 在遵循之前指南的同时,优先回答用户明确提出的问题
|
||||
- 确保回答格式和内容与用户期望一致
|
||||
- 对于GitHub仓库需要提供链接地址, 回复中请采用以下格式的HTML链接:
|
||||
* 对于GitHub仓库: <a href='Github_URL' target='_blank'>仓库名</a>
|
||||
- 不要生成参考列表,引用信息将另行处理
|
||||
"""
|
||||
|
||||
# 使用最终的prompt生成回答
|
||||
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=final_prompt,
|
||||
inputs_show_user=txt,
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history=[],
|
||||
sys_prompt=f"你是一个熟悉GitHub生态系统的专业助手,能帮助用户找到合适的项目、代码和开发者。除非用户指定,否则请使用中文回复。"
|
||||
)
|
||||
|
||||
# 1. 获取项目列表
|
||||
repos_list = handler.ranked_repos # 直接使用原始仓库数据
|
||||
|
||||
# 在新的对话中添加格式化的仓库参考列表
|
||||
if repos_list:
|
||||
references = ""
|
||||
for idx, repo in enumerate(repos_list, 1):
|
||||
# 构建仓库引用
|
||||
stars_str = f"⭐ {repo.get('stargazers_count', 'N/A')}" if repo.get('stargazers_count') else ""
|
||||
forks_str = f"🍴 {repo.get('forks_count', 'N/A')}" if repo.get('forks_count') else ""
|
||||
stats = f"{stars_str} {forks_str}".strip()
|
||||
stats = f" ({stats})" if stats else ""
|
||||
|
||||
language = f" [{repo.get('language', '')}]" if repo.get('language') else ""
|
||||
|
||||
reference = f"[{idx}] **{repo.get('name', '')}**{language}{stats} \n"
|
||||
reference += f"👤 {repo.get('owner', {}).get('login', 'N/A') if repo.get('owner') is not None else 'N/A'} | "
|
||||
reference += f"📅 {repo.get('updated_at', 'N/A')[:10]} | "
|
||||
reference += f"<a href='{repo.get('html_url', '')}' target='_blank'>GitHub</a> \n"
|
||||
|
||||
if repo.get('description'):
|
||||
reference += f"{repo.get('description')} \n"
|
||||
reference += " \n"
|
||||
|
||||
references += reference
|
||||
|
||||
# 添加新的对话显示参考仓库
|
||||
chatbot.append(["推荐项目如下:", references])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 2. 保存结果到文件
|
||||
# 创建保存目录
|
||||
save_dir = get_log_folder(get_user(chatbot), plugin_name='github_search')
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
|
||||
# 生成文件名
|
||||
def get_safe_filename(txt, max_length=10):
|
||||
# 获取文本前max_length个字符作为文件名
|
||||
filename = txt[:max_length].strip()
|
||||
# 移除不安全的文件名字符
|
||||
filename = re.sub(r'[\\/:*?"<>|]', '', filename)
|
||||
# 如果文件名为空,使用时间戳
|
||||
if not filename:
|
||||
filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
return filename
|
||||
|
||||
base_filename = get_safe_filename(txt)
|
||||
|
||||
# 准备保存的内容 - 优化文档结构
|
||||
md_content = f"# GitHub搜索结果: {txt}\n\n"
|
||||
md_content += f"搜索时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
|
||||
# 添加模型回复
|
||||
md_content += "## 搜索分析与总结\n\n"
|
||||
md_content += response + "\n\n"
|
||||
|
||||
# 添加所有搜索到的仓库详细信息
|
||||
md_content += "## 推荐项目详情\n\n"
|
||||
|
||||
if not repos_list:
|
||||
md_content += "未找到匹配的项目\n\n"
|
||||
else:
|
||||
md_content += f"共找到 {len(repos_list)} 个相关项目\n\n"
|
||||
|
||||
# 添加项目简表
|
||||
md_content += "### 项目一览表\n\n"
|
||||
md_content += "| 序号 | 项目名称 | 作者 | 语言 | 星标数 | 更新时间 |\n"
|
||||
md_content += "| ---- | -------- | ---- | ---- | ------ | -------- |\n"
|
||||
|
||||
for idx, repo in enumerate(repos_list, 1):
|
||||
md_content += f"| {idx} | [{repo.get('name', '')}]({repo.get('html_url', '')}) | {repo.get('owner', {}).get('login', 'N/A') if repo.get('owner') is not None else 'N/A'} | {repo.get('language', 'N/A')} | {repo.get('stargazers_count', 'N/A')} | {repo.get('updated_at', 'N/A')[:10]} |\n"
|
||||
|
||||
md_content += "\n"
|
||||
|
||||
# 添加详细项目信息
|
||||
md_content += "### 项目详细信息\n\n"
|
||||
for idx, repo in enumerate(repos_list, 1):
|
||||
md_content += f"#### {idx}. {repo.get('name', '')}\n\n"
|
||||
md_content += f"- **仓库**: [{repo.get('full_name', '')}]({repo.get('html_url', '')})\n"
|
||||
md_content += f"- **作者**: [{repo.get('owner', {}).get('login', '') if repo.get('owner') is not None else 'N/A'}]({repo.get('owner', {}).get('html_url', '') if repo.get('owner') is not None else '#'})\n"
|
||||
md_content += f"- **描述**: {repo.get('description', 'N/A')}\n"
|
||||
md_content += f"- **语言**: {repo.get('language', 'N/A')}\n"
|
||||
md_content += f"- **星标**: {repo.get('stargazers_count', 'N/A')}\n"
|
||||
md_content += f"- **Fork数**: {repo.get('forks_count', 'N/A')}\n"
|
||||
md_content += f"- **最近更新**: {repo.get('updated_at', 'N/A')[:10]}\n"
|
||||
md_content += f"- **创建时间**: {repo.get('created_at', 'N/A')[:10]}\n"
|
||||
md_content += f"- **开源许可**: {repo.get('license', {}).get('name', 'N/A') if repo.get('license') is not None else 'N/A'}\n"
|
||||
if repo.get('topics'):
|
||||
md_content += f"- **主题标签**: {', '.join(repo.get('topics', []))}\n"
|
||||
if repo.get('homepage'):
|
||||
md_content += f"- **项目主页**: [{repo.get('homepage')}]({repo.get('homepage')})\n"
|
||||
md_content += "\n"
|
||||
|
||||
# 添加查询信息和元数据
|
||||
md_content += "## 查询元数据\n\n"
|
||||
md_content += f"- **原始查询**: {txt}\n"
|
||||
md_content += f"- **查询类型**: {search_criteria.query_type}\n"
|
||||
md_content += f"- **关键词**: {', '.join(search_criteria.keywords) if hasattr(search_criteria, 'keywords') and search_criteria.keywords else 'N/A'}\n"
|
||||
md_content += f"- **搜索日期**: {datetime.now().strftime('%Y-%m-%d')}\n\n"
|
||||
|
||||
# 保存为多种格式
|
||||
saved_files = []
|
||||
failed_files = []
|
||||
|
||||
# 1. 保存为TXT
|
||||
try:
|
||||
txt_formatter = TxtFormatter()
|
||||
txt_content = txt_formatter.create_document(md_content)
|
||||
txt_file = os.path.join(save_dir, f"github_results_{base_filename}.txt")
|
||||
with open(txt_file, 'w', encoding='utf-8') as f:
|
||||
f.write(txt_content)
|
||||
promote_file_to_downloadzone(txt_file, chatbot=chatbot)
|
||||
saved_files.append("TXT")
|
||||
except Exception as e:
|
||||
failed_files.append(f"TXT (错误: {str(e)})")
|
||||
|
||||
# 2. 保存为Markdown
|
||||
try:
|
||||
md_formatter = MarkdownFormatter()
|
||||
formatted_md_content = md_formatter.create_document(md_content, "GitHub项目搜索")
|
||||
md_file = os.path.join(save_dir, f"github_results_{base_filename}.md")
|
||||
with open(md_file, 'w', encoding='utf-8') as f:
|
||||
f.write(formatted_md_content)
|
||||
promote_file_to_downloadzone(md_file, chatbot=chatbot)
|
||||
saved_files.append("Markdown")
|
||||
except Exception as e:
|
||||
failed_files.append(f"Markdown (错误: {str(e)})")
|
||||
|
||||
# 3. 保存为HTML
|
||||
try:
|
||||
html_formatter = HtmlFormatter(processing_type="GitHub项目搜索")
|
||||
html_content = html_formatter.create_document(md_content)
|
||||
html_file = os.path.join(save_dir, f"github_results_{base_filename}.html")
|
||||
with open(html_file, 'w', encoding='utf-8') as f:
|
||||
f.write(html_content)
|
||||
promote_file_to_downloadzone(html_file, chatbot=chatbot)
|
||||
saved_files.append("HTML")
|
||||
except Exception as e:
|
||||
failed_files.append(f"HTML (错误: {str(e)})")
|
||||
|
||||
# 4. 保存为Word
|
||||
word_file = None
|
||||
try:
|
||||
word_formatter = WordFormatter()
|
||||
doc = word_formatter.create_document(md_content, "GitHub项目搜索")
|
||||
word_file = os.path.join(save_dir, f"github_results_{base_filename}.docx")
|
||||
doc.save(word_file)
|
||||
promote_file_to_downloadzone(word_file, chatbot=chatbot)
|
||||
saved_files.append("Word")
|
||||
except Exception as e:
|
||||
failed_files.append(f"Word (错误: {str(e)})")
|
||||
word_file = None
|
||||
|
||||
# 5. 保存为PDF (仅当Word保存成功时)
|
||||
if word_file and os.path.exists(word_file):
|
||||
try:
|
||||
pdf_file = WordToPdfConverter.convert_to_pdf(word_file)
|
||||
promote_file_to_downloadzone(pdf_file, chatbot=chatbot)
|
||||
saved_files.append("PDF")
|
||||
except Exception as e:
|
||||
failed_files.append(f"PDF (错误: {str(e)})")
|
||||
|
||||
# 报告保存结果
|
||||
if saved_files:
|
||||
success_message = f"成功保存以下格式: {', '.join(saved_files)}"
|
||||
if failed_files:
|
||||
failure_message = f"以下格式保存失败: {', '.join(failed_files)}"
|
||||
chatbot.append(["部分格式保存成功", f"{success_message}。{failure_message}"])
|
||||
else:
|
||||
chatbot.append(["所有格式保存成功", success_message])
|
||||
else:
|
||||
chatbot.append(["保存失败", f"所有格式均保存失败: {', '.join(failed_files)}"])
|
||||
else:
|
||||
report_exception(chatbot, history, a=f"处理失败", b=f"请尝试其他查询")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
635
crazy_functions/paper_fns/journal_paper_recom.py
Normal file
635
crazy_functions/paper_fns/journal_paper_recom.py
Normal file
@@ -0,0 +1,635 @@
|
||||
import os
|
||||
import time
|
||||
import glob
|
||||
from typing import Dict, List, Generator, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from crazy_functions.pdf_fns.text_content_loader import TextContentLoader
|
||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||
from toolbox import update_ui, promote_file_to_downloadzone, write_history_to_file, CatchException, report_exception
|
||||
from shared_utils.fastapi_server import validate_path_safety
|
||||
# 导入论文下载相关函数
|
||||
from crazy_functions.论文下载 import extract_paper_id, extract_paper_ids, get_arxiv_paper, format_arxiv_id, SciHub
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
import calendar
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecommendationQuestion:
|
||||
"""期刊会议推荐分析问题类"""
|
||||
id: str # 问题ID
|
||||
question: str # 问题内容
|
||||
importance: int # 重要性 (1-5,5最高)
|
||||
description: str # 问题描述
|
||||
|
||||
|
||||
class JournalConferenceRecommender:
|
||||
"""论文期刊会议推荐器"""
|
||||
|
||||
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
|
||||
"""初始化推荐器"""
|
||||
self.llm_kwargs = llm_kwargs
|
||||
self.plugin_kwargs = plugin_kwargs
|
||||
self.chatbot = chatbot
|
||||
self.history = history
|
||||
self.system_prompt = system_prompt
|
||||
self.paper_content = ""
|
||||
self.analysis_results = {}
|
||||
|
||||
# 定义论文分析问题库(针对期刊会议推荐)
|
||||
self.questions = [
|
||||
RecommendationQuestion(
|
||||
id="research_field_and_topic",
|
||||
question="请分析这篇论文的研究领域、主题和关键词。具体包括:1)论文属于哪个主要学科领域(如自然科学、工程技术、医学、社会科学、人文学科等);2)具体的研究子领域或方向;3)论文的核心主题和关键概念;4)重要的学术关键词和专业术语;5)研究的跨学科特征(如果有);6)研究的地域性特征(国际性研究还是特定地区研究)。",
|
||||
importance=5,
|
||||
description="研究领域与主题分析"
|
||||
),
|
||||
RecommendationQuestion(
|
||||
id="methodology_and_approach",
|
||||
question="请分析论文的研究方法和技术路线。包括:1)采用的主要研究方法(定量研究、定性研究、理论分析、实验研究、田野调查、文献综述、案例研究等);2)使用的技术手段、工具或分析方法;3)研究设计的严谨性和创新性;4)数据收集和分析方法的适当性;5)研究方法在该学科中的先进性或传统性;6)方法学上的贡献或局限性。",
|
||||
importance=4,
|
||||
description="研究方法与技术路线"
|
||||
),
|
||||
RecommendationQuestion(
|
||||
id="novelty_and_contribution",
|
||||
question="请评估论文的创新性和学术贡献。包括:1)研究的新颖性程度(理论创新、方法创新、应用创新等);2)对现有知识体系的贡献或突破;3)解决问题的重要性和学术价值;4)研究成果的理论意义和实践价值;5)在该学科领域的地位和影响潜力;6)与国际前沿研究的关系;7)对后续研究的启发意义。",
|
||||
importance=4,
|
||||
description="创新性与学术贡献"
|
||||
),
|
||||
RecommendationQuestion(
|
||||
id="target_audience_and_scope",
|
||||
question="请分析论文的目标受众和应用范围。包括:1)主要面向的学术群体(研究者、从业者、政策制定者等);2)研究成果的潜在应用领域和受益群体;3)对学术界和实践界的价值;4)研究的国际化程度和跨文化适用性;5)是否适合国际期刊还是区域性期刊;6)语言发表偏好(英文、中文或其他语言);7)开放获取的必要性和可行性。",
|
||||
importance=3,
|
||||
description="目标受众与应用范围"
|
||||
),
|
||||
]
|
||||
|
||||
# 按重要性排序
|
||||
self.questions.sort(key=lambda q: q.importance, reverse=True)
|
||||
|
||||
def _load_paper(self, paper_path: str) -> Generator:
|
||||
"""加载论文内容"""
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 使用TextContentLoader读取文件
|
||||
loader = TextContentLoader(self.chatbot, self.history)
|
||||
|
||||
yield from loader.execute_single_file(paper_path)
|
||||
|
||||
# 获取加载的内容
|
||||
if len(self.history) >= 2 and self.history[-2]:
|
||||
self.paper_content = self.history[-2]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return True
|
||||
else:
|
||||
self.chatbot.append(["错误", "无法读取论文内容,请检查文件是否有效"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return False
|
||||
|
||||
def _analyze_question(self, question: RecommendationQuestion) -> Generator:
|
||||
"""分析单个问题"""
|
||||
try:
|
||||
# 创建分析提示
|
||||
prompt = f"请基于以下论文内容回答问题:\n\n{self.paper_content}\n\n问题:{question.question}"
|
||||
|
||||
# 使用单线程版本的请求函数
|
||||
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=prompt,
|
||||
inputs_show_user=question.question, # 显示问题本身
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history=[], # 空历史,确保每个问题独立分析
|
||||
sys_prompt="你是一个专业的学术期刊会议推荐专家,需要仔细分析论文内容并提供准确的分析。请保持客观、专业,并基于论文内容提供深入分析。"
|
||||
)
|
||||
|
||||
if response:
|
||||
self.analysis_results[question.id] = response
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["错误", f"分析问题时出错: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return False
|
||||
|
||||
def _generate_journal_recommendations(self) -> Generator:
|
||||
"""生成期刊推荐"""
|
||||
self.chatbot.append(["生成期刊推荐", "正在基于论文分析结果生成期刊推荐..."])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 构建期刊推荐提示
|
||||
journal_prompt = """请基于以下论文分析结果,为这篇论文推荐合适的学术期刊。
|
||||
|
||||
推荐要求:
|
||||
1. 根据论文的创新性和工作质量,分别推荐不同级别的期刊:
|
||||
- 顶级期刊(影响因子>8或该领域顶级期刊):2-3个
|
||||
- 高质量期刊(影响因子4-8或该领域知名期刊):3-4个
|
||||
- 中等期刊(影响因子1.5-4或该领域认可期刊):3-4个
|
||||
- 入门期刊(影响因子<1.5但声誉良好的期刊):2-3个
|
||||
|
||||
注意:不同学科的影响因子标准差异很大,请根据论文所属学科的实际情况调整标准。
|
||||
特别是医学领域,需要考虑:
|
||||
- 临床医学期刊通常影响因子较高(顶级期刊IF>20,高质量期刊IF>10)
|
||||
- 基础医学期刊影响因子相对较低但学术价值很高
|
||||
- 专科医学期刊在各自领域内具有权威性
|
||||
- 医学期刊的临床实用性和循证医学价值
|
||||
|
||||
2. 对每个期刊提供详细信息:
|
||||
- 期刊全名和缩写
|
||||
- 最新影响因子(如果知道)
|
||||
- 期刊级别分类(Q1/Q2/Q3/Q4或该学科的分类标准)
|
||||
- 主要研究领域和范围
|
||||
- 与论文内容的匹配度评分(1-10分)
|
||||
- 发表难度评估(容易/中等/困难/极难)
|
||||
- 平均审稿周期
|
||||
- 开放获取政策
|
||||
- 期刊的学科分类(如SCI、SSCI、A&HCI等)
|
||||
- 医学期刊特殊信息(如适用):
|
||||
* PubMed收录情况
|
||||
* 是否为核心临床期刊
|
||||
* 专科领域权威性
|
||||
* 循证医学等级要求
|
||||
* 临床试验注册要求
|
||||
* 伦理委员会批准要求
|
||||
|
||||
3. 按推荐优先级排序,并说明推荐理由
|
||||
4. 提供针对性的投稿建议,考虑该学科的特点
|
||||
|
||||
论文分析结果:"""
|
||||
|
||||
for q in self.questions:
|
||||
if q.id in self.analysis_results:
|
||||
journal_prompt += f"\n\n{q.description}:\n{self.analysis_results[q.id]}"
|
||||
|
||||
journal_prompt += "\n\n请提供详细的期刊推荐报告,重点关注期刊的层次性和适配性。请根据论文的具体学科领域,采用该领域通用的期刊评价标准和分类体系。"
|
||||
|
||||
try:
|
||||
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=journal_prompt,
|
||||
inputs_show_user="生成期刊推荐报告",
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history=[],
|
||||
sys_prompt="你是一个资深的跨学科学术期刊推荐专家,熟悉各个学科领域不同层次的期刊。请根据论文的具体学科和创新性,推荐从顶级到入门级的各层次期刊。不同学科有不同的期刊评价标准:理工科重视影响因子和SCI收录,社会科学重视SSCI和学科声誉,人文学科重视A&HCI和同行评议,医学领域重视PubMed收录、临床实用性、循证医学价值和伦理规范。请根据论文所属学科采用相应的评价标准。"
|
||||
)
|
||||
|
||||
if response:
|
||||
return response
|
||||
return "期刊推荐生成失败"
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["错误", f"生成期刊推荐时出错: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return "期刊推荐生成失败: " + str(e)
|
||||
|
||||
def _generate_conference_recommendations(self) -> Generator:
|
||||
"""生成会议推荐"""
|
||||
self.chatbot.append(["生成会议推荐", "正在基于论文分析结果生成会议推荐..."])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 获取当前时间信息
|
||||
current_time = datetime.now()
|
||||
current_date_str = current_time.strftime("%Y年%m月%d日")
|
||||
current_year = current_time.year
|
||||
current_month = current_time.month
|
||||
|
||||
# 构建会议推荐提示
|
||||
conference_prompt = f"""请基于以下论文分析结果,为这篇论文推荐合适的学术会议。
|
||||
|
||||
**重要提示:当前时间是{current_date_str}({current_year}年{current_month}月),请基于这个时间点推断会议的举办时间和投稿截止时间。**
|
||||
|
||||
推荐要求:
|
||||
1. 根据论文的创新性和工作质量,分别推荐不同级别的会议:
|
||||
- 顶级会议(该领域最权威的国际会议):2-3个
|
||||
- 高质量会议(该领域知名的国际或区域会议):3-4个
|
||||
- 中等会议(该领域认可的专业会议):3-4个
|
||||
- 专业会议(该领域细分方向的专门会议):2-3个
|
||||
|
||||
注意:不同学科的会议评价标准不同:
|
||||
- 计算机科学:可参考CCF分类(A/B/C类)
|
||||
- 工程学:可参考EI收录和影响力
|
||||
- 医学:可参考会议的临床影响和同行认可度
|
||||
- 社会科学:可参考会议的学术声誉和参与度
|
||||
- 人文学科:可参考会议的历史和学术传统
|
||||
- 自然科学:可参考会议的国际影响力和发表质量
|
||||
|
||||
特别是医学会议,需要考虑:
|
||||
- 临床医学会议重视实用性和临床指导价值
|
||||
- 基础医学会议重视科学创新和机制研究
|
||||
- 专科医学会议在各自领域内具有权威性
|
||||
- 国际医学会议的CME学分认证情况
|
||||
|
||||
2. 对每个会议提供详细信息:
|
||||
- 会议全名和缩写
|
||||
- 会议级别分类(根据该学科的评价标准)
|
||||
- 主要研究领域和主题
|
||||
- 与论文内容的匹配度评分(1-10分)
|
||||
- 录用难度评估(容易/中等/困难/极难)
|
||||
- 会议举办周期(年会/双年会/不定期等)
|
||||
- **基于当前时间{current_date_str},推断{current_year}年和{current_year+1}年的举办时间和地点**(请根据往年的举办时间规律进行推断)
|
||||
- **基于推断的会议时间,估算论文提交截止时间**(通常在会议前3-6个月)
|
||||
- 会议的国际化程度和影响范围
|
||||
- 医学会议特殊信息(如适用):
|
||||
* 是否提供CME学分
|
||||
* 临床实践指导价值
|
||||
* 专科认证机构认可情况
|
||||
* 会议论文集的PubMed收录情况
|
||||
* 伦理和临床试验相关要求
|
||||
|
||||
3. 按推荐优先级排序,并说明推荐理由
|
||||
4. **基于当前时间{current_date_str},提供会议投稿的时间规划建议**
|
||||
- 哪些会议可以赶上{current_year}年的投稿截止时间
|
||||
- 哪些会议需要准备{current_year+1}年的投稿
|
||||
- 具体的时间安排建议
|
||||
|
||||
论文分析结果:"""
|
||||
|
||||
for q in self.questions:
|
||||
if q.id in self.analysis_results:
|
||||
conference_prompt += f"\n\n{q.description}:\n{self.analysis_results[q.id]}"
|
||||
|
||||
conference_prompt += f"\n\n请提供详细的会议推荐报告,重点关注会议的层次性和时效性。请根据论文的具体学科领域,采用该领域通用的会议评价标准。\n\n**特别注意:请根据当前时间{current_date_str}和各会议的历史举办时间规律,准确推断{current_year}年和{current_year+1}年的会议时间安排,不要使用虚构的时间。**"
|
||||
|
||||
try:
|
||||
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=conference_prompt,
|
||||
inputs_show_user="生成会议推荐报告",
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history=[],
|
||||
sys_prompt="你是一个资深的跨学科学术会议推荐专家,熟悉各个学科领域不同层次的学术会议。请根据论文的具体学科和创新性,推荐从顶级到专业级的各层次会议。不同学科有不同的会议评价标准和文化:理工科重视技术创新和国际影响力,社会科学重视理论贡献和社会意义,人文学科重视学术深度和文化价值,医学领域重视临床实用性、CME学分认证、专科权威性和伦理规范。请根据论文所属学科采用相应的评价标准和推荐策略。"
|
||||
)
|
||||
|
||||
if response:
|
||||
return response
|
||||
return "会议推荐生成失败"
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["错误", f"生成会议推荐时出错: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return "会议推荐生成失败: " + str(e)
|
||||
|
||||
def _generate_priority_summary(self, journal_recommendations: str, conference_recommendations: str) -> Generator:
|
||||
"""生成优先级总结"""
|
||||
self.chatbot.append(["生成优先级总结", "正在生成投稿优先级总结..."])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 获取当前时间信息
|
||||
current_time = datetime.now()
|
||||
current_date_str = current_time.strftime("%Y年%m月%d日")
|
||||
current_month = current_time.strftime("%Y年%m月")
|
||||
|
||||
# 计算未来时间点
|
||||
def add_months(date, months):
|
||||
"""安全地添加月份"""
|
||||
month = date.month - 1 + months
|
||||
year = date.year + month // 12
|
||||
month = month % 12 + 1
|
||||
day = min(date.day, calendar.monthrange(year, month)[1])
|
||||
return date.replace(year=year, month=month, day=day)
|
||||
|
||||
future_6_months = add_months(current_time, 6).strftime('%Y年%m月')
|
||||
future_12_months = add_months(current_time, 12).strftime('%Y年%m月')
|
||||
future_year = (current_time.year + 1)
|
||||
|
||||
priority_prompt = f"""请基于以下期刊和会议推荐结果,生成一个综合的投稿优先级总结。
|
||||
|
||||
**重要提示:当前时间是{current_date_str}({current_month}),请基于这个时间点制定投稿计划。**
|
||||
|
||||
期刊推荐结果:
|
||||
{journal_recommendations}
|
||||
|
||||
会议推荐结果:
|
||||
{conference_recommendations}
|
||||
|
||||
请提供:
|
||||
1. 综合投稿策略建议(考虑该学科的发表文化和惯例)
|
||||
- 期刊优先还是会议优先(不同学科有不同偏好)
|
||||
- 国际期刊/会议 vs 国内期刊/会议的选择策略
|
||||
- 英文发表 vs 中文发表的考虑
|
||||
|
||||
2. 按时间线排列的投稿计划(**基于当前时间{current_date_str},考虑截止时间和审稿周期**)
|
||||
- 短期目标({current_month}起3-6个月内,即到{future_6_months})
|
||||
- 中期目标(6-12个月内,即到{future_12_months})
|
||||
- 长期目标(1年以上,即{future_year}年以后)
|
||||
|
||||
3. 风险分散策略
|
||||
- 同时投稿多个不同级别的目标
|
||||
- 考虑该学科的一稿多投政策
|
||||
- 备选方案和应急策略
|
||||
|
||||
4. 针对论文可能需要的改进建议
|
||||
- 根据目标期刊/会议的要求调整内容
|
||||
- 语言和格式的优化建议
|
||||
- 补充实验或分析的建议
|
||||
|
||||
5. 预期的发表时间线和成功概率评估(基于当前时间{current_date_str})
|
||||
|
||||
6. 该学科特有的发表注意事项
|
||||
- 伦理审查要求(如医学、心理学等)
|
||||
- 数据开放要求(如某些自然科学领域)
|
||||
- 利益冲突声明(如医学、工程等)
|
||||
- 医学领域特殊要求:
|
||||
* 临床试验注册要求(ClinicalTrials.gov、中国临床试验注册中心等)
|
||||
* 患者知情同意和隐私保护
|
||||
* 医学伦理委员会批准证明
|
||||
* CONSORT、STROBE、PRISMA等报告规范遵循
|
||||
* 药物/器械安全性数据要求
|
||||
* CME学分认证相关要求
|
||||
* 临床指南和循证医学等级要求
|
||||
- 其他学科特殊要求
|
||||
|
||||
请以表格形式总结前10个最推荐的投稿目标(期刊+会议),包括优先级排序、预期时间线和成功概率。
|
||||
|
||||
**注意:所有时间规划都应基于当前时间{current_date_str}进行计算,不要使用虚构的时间。**"""
|
||||
|
||||
try:
|
||||
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||
inputs=priority_prompt,
|
||||
inputs_show_user="生成投稿优先级总结",
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history=[],
|
||||
sys_prompt="你是一个资深的跨学科学术发表策略专家,熟悉各个学科的发表文化、惯例和要求。请综合考虑不同学科的特点:理工科通常重视期刊发表和影响因子,社会科学平衡期刊和专著,人文学科重视同行评议和学术声誉,医学重视临床意义和伦理规范。请为作者制定最适合其学科背景的投稿策略和时间规划。"
|
||||
)
|
||||
|
||||
if response:
|
||||
return response
|
||||
return "优先级总结生成失败"
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["错误", f"生成优先级总结时出错: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return "优先级总结生成失败: " + str(e)
|
||||
|
||||
def save_recommendations(self, journal_recommendations: str, conference_recommendations: str, priority_summary: str) -> Generator:
|
||||
"""保存推荐报告"""
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 保存为Markdown文件
|
||||
try:
|
||||
md_content = f"""# 论文期刊会议推荐报告
|
||||
|
||||
## 投稿优先级总结
|
||||
|
||||
{priority_summary}
|
||||
|
||||
## 期刊推荐
|
||||
|
||||
{journal_recommendations}
|
||||
|
||||
## 会议推荐
|
||||
|
||||
{conference_recommendations}
|
||||
|
||||
---
|
||||
|
||||
# 详细分析结果
|
||||
"""
|
||||
|
||||
# 添加详细分析结果
|
||||
for q in self.questions:
|
||||
if q.id in self.analysis_results:
|
||||
md_content += f"\n\n## {q.description}\n\n{self.analysis_results[q.id]}"
|
||||
|
||||
result_file = write_history_to_file(
|
||||
history=[md_content],
|
||||
file_basename=f"期刊会议推荐_{timestamp}.md"
|
||||
)
|
||||
|
||||
if result_file and os.path.exists(result_file):
|
||||
promote_file_to_downloadzone(result_file, chatbot=self.chatbot)
|
||||
self.chatbot.append(["保存成功", f"推荐报告已保存至: {os.path.basename(result_file)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
else:
|
||||
self.chatbot.append(["警告", "保存报告成功但找不到文件"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"保存报告失败: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
def recommend_venues(self, paper_path: str) -> Generator:
|
||||
"""推荐期刊会议主流程"""
|
||||
# 加载论文
|
||||
success = yield from self._load_paper(paper_path)
|
||||
if not success:
|
||||
return
|
||||
|
||||
# 分析关键问题
|
||||
for question in self.questions:
|
||||
yield from self._analyze_question(question)
|
||||
|
||||
# 分别生成期刊和会议推荐
|
||||
journal_recommendations = yield from self._generate_journal_recommendations()
|
||||
conference_recommendations = yield from self._generate_conference_recommendations()
|
||||
|
||||
# 生成优先级总结
|
||||
priority_summary = yield from self._generate_priority_summary(journal_recommendations, conference_recommendations)
|
||||
|
||||
# 显示结果
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 保存报告
|
||||
yield from self.save_recommendations(journal_recommendations, conference_recommendations, priority_summary)
|
||||
|
||||
# 将完整的分析结果和推荐内容添加到历史记录中,方便用户继续提问
|
||||
self._add_to_history(journal_recommendations, conference_recommendations, priority_summary)
|
||||
|
||||
def _add_to_history(self, journal_recommendations: str, conference_recommendations: str, priority_summary: str):
|
||||
"""将分析结果和推荐内容添加到历史记录中"""
|
||||
try:
|
||||
# 构建完整的内容摘要
|
||||
history_content = f"""# 论文期刊会议推荐分析完成
|
||||
|
||||
## 📊 投稿优先级总结
|
||||
{priority_summary}
|
||||
|
||||
## 📚 期刊推荐
|
||||
{journal_recommendations}
|
||||
|
||||
## 🏛️ 会议推荐
|
||||
{conference_recommendations}
|
||||
|
||||
## 📋 详细分析结果
|
||||
"""
|
||||
|
||||
# 添加详细分析结果
|
||||
for q in self.questions:
|
||||
if q.id in self.analysis_results:
|
||||
history_content += f"\n### {q.description}\n{self.analysis_results[q.id]}\n"
|
||||
|
||||
history_content += "\n---\n💡 您现在可以基于以上分析结果继续提问,比如询问特定期刊的详细信息、投稿策略建议、或者对推荐结果的进一步解释。"
|
||||
|
||||
# 添加到历史记录中
|
||||
self.history.append("论文期刊会议推荐分析")
|
||||
self.history.append(history_content)
|
||||
|
||||
self.chatbot.append(["✅ 分析完成", "所有分析结果和推荐内容已添加到对话历史中,您可以继续基于这些内容提问。"])
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"添加到历史记录时出错: {str(e)},但推荐报告已正常生成"])
|
||||
# 即使添加历史失败,也不影响主要功能
|
||||
|
||||
|
||||
def _find_paper_file(path: str) -> str:
|
||||
"""查找路径中的论文文件(简化版)"""
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
|
||||
# 支持的文件扩展名(按优先级排序)
|
||||
extensions = ["pdf", "docx", "doc", "txt", "md", "tex"]
|
||||
|
||||
# 简单地遍历目录
|
||||
if os.path.isdir(path):
|
||||
try:
|
||||
for ext in extensions:
|
||||
# 手动检查每个可能的文件,而不使用glob
|
||||
potential_file = os.path.join(path, f"paper.{ext}")
|
||||
if os.path.exists(potential_file) and os.path.isfile(potential_file):
|
||||
return potential_file
|
||||
|
||||
# 如果没找到特定命名的文件,检查目录中的所有文件
|
||||
for file in os.listdir(path):
|
||||
file_path = os.path.join(path, file)
|
||||
if os.path.isfile(file_path):
|
||||
file_ext = file.split('.')[-1].lower() if '.' in file else ""
|
||||
if file_ext in extensions:
|
||||
return file_path
|
||||
except Exception:
|
||||
pass # 忽略任何错误
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def download_paper_by_id(paper_info, chatbot, history) -> str:
|
||||
"""下载论文并返回保存路径
|
||||
|
||||
Args:
|
||||
paper_info: 元组,包含论文ID类型(arxiv或doi)和ID值
|
||||
chatbot: 聊天机器人对象
|
||||
history: 历史记录
|
||||
|
||||
Returns:
|
||||
str: 下载的论文路径或None
|
||||
"""
|
||||
id_type, paper_id = paper_info
|
||||
|
||||
# 创建保存目录 - 使用时间戳创建唯一文件夹
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
user_name = chatbot.get_user() if hasattr(chatbot, 'get_user') else "default"
|
||||
from toolbox import get_log_folder, get_user
|
||||
base_save_dir = get_log_folder(get_user(chatbot), plugin_name='paper_download')
|
||||
save_dir = os.path.join(base_save_dir, f"papers_{timestamp}")
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
save_path = Path(save_dir)
|
||||
|
||||
chatbot.append([f"下载论文", f"正在下载{'arXiv' if id_type == 'arxiv' else 'DOI'} {paper_id} 的论文..."])
|
||||
update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
pdf_path = None
|
||||
|
||||
try:
|
||||
if id_type == 'arxiv':
|
||||
# 使用改进的arxiv查询方法
|
||||
formatted_id = format_arxiv_id(paper_id)
|
||||
paper_result = get_arxiv_paper(formatted_id)
|
||||
|
||||
if not paper_result:
|
||||
chatbot.append([f"下载失败", f"未找到arXiv论文: {paper_id}"])
|
||||
update_ui(chatbot=chatbot, history=history)
|
||||
return None
|
||||
|
||||
# 下载PDF
|
||||
filename = f"arxiv_{paper_id.replace('/', '_')}.pdf"
|
||||
pdf_path = str(save_path / filename)
|
||||
paper_result.download_pdf(filename=pdf_path)
|
||||
|
||||
else: # doi
|
||||
# 下载DOI
|
||||
sci_hub = SciHub(
|
||||
doi=paper_id,
|
||||
path=save_path
|
||||
)
|
||||
pdf_path = sci_hub.fetch()
|
||||
|
||||
# 检查下载结果
|
||||
if pdf_path and os.path.exists(pdf_path):
|
||||
promote_file_to_downloadzone(pdf_path, chatbot=chatbot)
|
||||
chatbot.append([f"下载成功", f"已成功下载论文: {os.path.basename(pdf_path)}"])
|
||||
update_ui(chatbot=chatbot, history=history)
|
||||
return pdf_path
|
||||
else:
|
||||
chatbot.append([f"下载失败", f"论文下载失败: {paper_id}"])
|
||||
update_ui(chatbot=chatbot, history=history)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
chatbot.append([f"下载错误", f"下载论文时出错: {str(e)}"])
|
||||
update_ui(chatbot=chatbot, history=history)
|
||||
return None
|
||||
|
||||
|
||||
@CatchException
|
||||
def 论文期刊会议推荐(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||
history: List, system_prompt: str, user_request: str):
|
||||
"""主函数 - 论文期刊会议推荐"""
|
||||
# 初始化推荐器
|
||||
chatbot.append(["函数插件功能及使用方式", "论文期刊会议推荐:基于论文内容分析,为您推荐合适的学术期刊和会议投稿目标。适用于各个学科专业(自然科学、工程技术、医学、社会科学、人文学科等),根据不同学科的评价标准和发表文化,提供分层次的期刊会议推荐、影响因子分析、发表难度评估、投稿策略建议等。<br><br>📋 使用方式:<br>1、直接上传PDF文件<br>2、输入DOI号或arXiv ID<br>3、点击插件开始分析"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
paper_file = None
|
||||
|
||||
# 检查输入是否为论文ID(arxiv或DOI)
|
||||
paper_info = extract_paper_id(txt)
|
||||
|
||||
if paper_info:
|
||||
# 如果是论文ID,下载论文
|
||||
chatbot.append(["检测到论文ID", f"检测到{'arXiv' if paper_info[0] == 'arxiv' else 'DOI'} ID: {paper_info[1]},准备下载论文..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 下载论文
|
||||
paper_file = download_paper_by_id(paper_info, chatbot, history)
|
||||
|
||||
if not paper_file:
|
||||
report_exception(chatbot, history, a=f"下载论文失败", b=f"无法下载{'arXiv' if paper_info[0] == 'arxiv' else 'DOI'}论文: {paper_info[1]}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
else:
|
||||
# 检查输入路径
|
||||
if not os.path.exists(txt):
|
||||
report_exception(chatbot, history, a=f"解析论文: {txt}", b=f"找不到文件或无权访问: {txt}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 验证路径安全性
|
||||
user_name = chatbot.get_user()
|
||||
validate_path_safety(txt, user_name)
|
||||
|
||||
# 查找论文文件
|
||||
paper_file = _find_paper_file(txt)
|
||||
|
||||
if not paper_file:
|
||||
report_exception(chatbot, history, a=f"解析论文", b=f"在路径 {txt} 中未找到支持的论文文件")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 确保paper_file是字符串
|
||||
if paper_file is not None and not isinstance(paper_file, str):
|
||||
# 尝试转换为字符串
|
||||
try:
|
||||
paper_file = str(paper_file)
|
||||
except:
|
||||
report_exception(chatbot, history, a=f"类型错误", b=f"论文路径不是有效的字符串: {type(paper_file)}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 开始推荐
|
||||
chatbot.append(["开始分析", f"正在分析论文并生成期刊会议推荐: {os.path.basename(paper_file)}"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
recommender = JournalConferenceRecommender(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||
yield from recommender.recommend_venues(paper_file)
|
||||
295
crazy_functions/paper_fns/paper_download.py
Normal file
295
crazy_functions/paper_fns/paper_download.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import re
|
||||
import os
|
||||
import zipfile
|
||||
from toolbox import CatchException, update_ui, promote_file_to_downloadzone, get_log_folder, get_user
|
||||
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
def extract_paper_id(txt):
|
||||
"""从输入文本中提取论文ID"""
|
||||
# 尝试匹配DOI(将DOI匹配提前,因为其格式更加明确)
|
||||
doi_patterns = [
|
||||
r'doi.org/([\w\./-]+)', # doi.org/10.1234/xxx
|
||||
r'doi:\s*([\w\./-]+)', # doi: 10.1234/xxx
|
||||
r'(10\.\d{4,}/[\w\.-]+)', # 直接输入DOI: 10.1234/xxx
|
||||
]
|
||||
|
||||
for pattern in doi_patterns:
|
||||
match = re.search(pattern, txt, re.IGNORECASE)
|
||||
if match:
|
||||
return ('doi', match.group(1))
|
||||
|
||||
# 尝试匹配arXiv ID
|
||||
arxiv_patterns = [
|
||||
r'arxiv.org/abs/(\d+\.\d+)', # arxiv.org/abs/2103.14030
|
||||
r'arxiv.org/pdf/(\d+\.\d+)', # arxiv.org/pdf/2103.14030
|
||||
r'arxiv/(\d+\.\d+)', # arxiv/2103.14030
|
||||
r'^(\d{4}\.\d{4,5})$', # 直接输入ID: 2103.14030
|
||||
# 添加对早期arXiv ID的支持
|
||||
r'arxiv.org/abs/([\w-]+/\d{7})', # arxiv.org/abs/math/0211159
|
||||
r'arxiv.org/pdf/([\w-]+/\d{7})', # arxiv.org/pdf/hep-th/9901001
|
||||
r'^([\w-]+/\d{7})$', # 直接输入: math/0211159
|
||||
]
|
||||
|
||||
for pattern in arxiv_patterns:
|
||||
match = re.search(pattern, txt, re.IGNORECASE)
|
||||
if match:
|
||||
paper_id = match.group(1)
|
||||
# 如果是新格式(YYMM.NNNNN)或旧格式(category/NNNNNNN),都直接返回
|
||||
if re.match(r'^\d{4}\.\d{4,5}$', paper_id) or re.match(r'^[\w-]+/\d{7}$', paper_id):
|
||||
return ('arxiv', paper_id)
|
||||
|
||||
return None
|
||||
|
||||
def extract_paper_ids(txt):
|
||||
"""从输入文本中提取多个论文ID"""
|
||||
paper_ids = []
|
||||
|
||||
# 首先按换行符分割
|
||||
for line in txt.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line: # 跳过空行
|
||||
continue
|
||||
|
||||
# 对每一行再按空格分割
|
||||
for item in line.split():
|
||||
item = item.strip()
|
||||
if not item: # 跳过空项
|
||||
continue
|
||||
paper_info = extract_paper_id(item)
|
||||
if paper_info:
|
||||
paper_ids.append(paper_info)
|
||||
|
||||
# 去除重复项,保持顺序
|
||||
unique_paper_ids = []
|
||||
seen = set()
|
||||
for paper_info in paper_ids:
|
||||
if paper_info not in seen:
|
||||
seen.add(paper_info)
|
||||
unique_paper_ids.append(paper_info)
|
||||
|
||||
return unique_paper_ids
|
||||
|
||||
def format_arxiv_id(paper_id):
|
||||
"""格式化arXiv ID,处理新旧两种格式"""
|
||||
# 如果是旧格式 (e.g. astro-ph/0404140),需要去掉arxiv:前缀
|
||||
if '/' in paper_id:
|
||||
return paper_id.replace('arxiv:', '') # 确保移除可能存在的arxiv:前缀
|
||||
return paper_id
|
||||
|
||||
def get_arxiv_paper(paper_id):
|
||||
"""获取arXiv论文,处理新旧两种格式"""
|
||||
import arxiv
|
||||
|
||||
# 尝试不同的查询方式
|
||||
query_formats = [
|
||||
paper_id, # 原始ID
|
||||
paper_id.replace('/', ''), # 移除斜杠
|
||||
f"id:{paper_id}", # 添加id:前缀
|
||||
]
|
||||
|
||||
for query in query_formats:
|
||||
try:
|
||||
# 使用Search查询
|
||||
search = arxiv.Search(
|
||||
query=query,
|
||||
max_results=1
|
||||
)
|
||||
result = next(arxiv.Client().results(search))
|
||||
if result:
|
||||
return result
|
||||
except:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 使用id_list查询
|
||||
search = arxiv.Search(id_list=[query])
|
||||
result = next(arxiv.Client().results(search))
|
||||
if result:
|
||||
return result
|
||||
except:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def create_zip_archive(files, save_path):
|
||||
"""将多个PDF文件打包成zip"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
zip_filename = f"papers_{timestamp}.zip"
|
||||
zip_path = str(save_path / zip_filename)
|
||||
|
||||
with zipfile.ZipFile(zip_path, 'w') as zipf:
|
||||
for file in files:
|
||||
if os.path.exists(file):
|
||||
# 只添加文件名,不包含路径
|
||||
zipf.write(file, os.path.basename(file))
|
||||
|
||||
return zip_path
|
||||
|
||||
@CatchException
|
||||
def 论文下载(txt: str, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||
"""
|
||||
txt: 用户输入,可以是DOI、arxiv ID或相关链接,支持多行输入进行批量下载
|
||||
"""
|
||||
from crazy_functions.doc_fns.text_content_loader import TextContentLoader
|
||||
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
|
||||
from crazy_functions.review_fns.data_sources.scihub_source import SciHub
|
||||
# 解析输入
|
||||
paper_infos = extract_paper_ids(txt)
|
||||
if not paper_infos:
|
||||
chatbot.append(["输入解析", "未能识别任何论文ID或DOI,请检查输入格式。支持以下格式:\n- arXiv ID (例如:2103.14030)\n- arXiv链接\n- DOI (例如:10.1234/xxx)\n- DOI链接\n\n多个论文ID请用换行分隔。"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 创建保存目录 - 使用时间戳创建唯一文件夹
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
base_save_dir = get_log_folder(get_user(chatbot), plugin_name='paper_download')
|
||||
save_dir = os.path.join(base_save_dir, f"papers_{timestamp}")
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
save_path = Path(save_dir)
|
||||
|
||||
# 记录下载结果
|
||||
success_count = 0
|
||||
failed_papers = []
|
||||
downloaded_files = [] # 记录成功下载的文件路径
|
||||
|
||||
chatbot.append([f"开始下载", f"支持多行输入下载多篇论文,共检测到 {len(paper_infos)} 篇论文,开始下载..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
for id_type, paper_id in paper_infos:
|
||||
try:
|
||||
if id_type == 'arxiv':
|
||||
chatbot.append([f"正在下载", f"从arXiv下载论文 {paper_id}..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 使用改进的arxiv查询方法
|
||||
formatted_id = format_arxiv_id(paper_id)
|
||||
paper_result = get_arxiv_paper(formatted_id)
|
||||
|
||||
if not paper_result:
|
||||
failed_papers.append((paper_id, "未找到论文"))
|
||||
continue
|
||||
|
||||
# 下载PDF
|
||||
try:
|
||||
filename = f"arxiv_{paper_id.replace('/', '_')}.pdf"
|
||||
pdf_path = str(save_path / filename)
|
||||
paper_result.download_pdf(filename=pdf_path)
|
||||
if os.path.exists(pdf_path):
|
||||
downloaded_files.append(pdf_path)
|
||||
except Exception as e:
|
||||
failed_papers.append((paper_id, f"PDF下载失败: {str(e)}"))
|
||||
continue
|
||||
|
||||
else: # doi
|
||||
chatbot.append([f"正在下载", f"从Sci-Hub下载论文 {paper_id}..."])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
sci_hub = SciHub(
|
||||
doi=paper_id,
|
||||
path=save_path
|
||||
)
|
||||
pdf_path = sci_hub.fetch()
|
||||
if pdf_path and os.path.exists(pdf_path):
|
||||
downloaded_files.append(pdf_path)
|
||||
|
||||
# 检查下载结果
|
||||
if pdf_path and os.path.exists(pdf_path):
|
||||
promote_file_to_downloadzone(pdf_path, chatbot=chatbot)
|
||||
success_count += 1
|
||||
else:
|
||||
failed_papers.append((paper_id, "下载失败"))
|
||||
|
||||
except Exception as e:
|
||||
failed_papers.append((paper_id, str(e)))
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 创建ZIP压缩包
|
||||
if downloaded_files:
|
||||
try:
|
||||
zip_path = create_zip_archive(downloaded_files, Path(base_save_dir))
|
||||
promote_file_to_downloadzone(zip_path, chatbot=chatbot)
|
||||
chatbot.append([
|
||||
f"创建压缩包",
|
||||
f"已将所有下载的论文打包为: {os.path.basename(zip_path)}"
|
||||
])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
except Exception as e:
|
||||
chatbot.append([
|
||||
f"创建压缩包失败",
|
||||
f"打包文件时出现错误: {str(e)}"
|
||||
])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 生成最终报告
|
||||
summary = f"下载完成!成功下载 {success_count} 篇论文。\n"
|
||||
if failed_papers:
|
||||
summary += "\n以下论文下载失败:\n"
|
||||
for paper_id, reason in failed_papers:
|
||||
summary += f"- {paper_id}: {reason}\n"
|
||||
|
||||
if downloaded_files:
|
||||
summary += f"\n所有论文已存放在文件夹 '{save_dir}' 中,并打包到压缩文件中。您可以在下载区找到单个PDF文件和压缩包。"
|
||||
|
||||
chatbot.append([
|
||||
f"下载完成",
|
||||
summary
|
||||
])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 如果下载成功且用户想要直接阅读内容
|
||||
if downloaded_files:
|
||||
chatbot.append([
|
||||
"提示",
|
||||
"正在读取论文内容进行分析,请稍候..."
|
||||
])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 使用TextContentLoader加载整个文件夹的PDF文件内容
|
||||
loader = TextContentLoader(chatbot, history)
|
||||
|
||||
# 删除提示信息
|
||||
chatbot.pop()
|
||||
|
||||
# 加载PDF内容 - 传入文件夹路径而不是单个文件路径
|
||||
yield from loader.execute(save_dir)
|
||||
|
||||
# 添加提示信息
|
||||
chatbot.append([
|
||||
"提示",
|
||||
"论文内容已加载完毕,您可以直接向AI提问有关该论文的问题。"
|
||||
])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
import asyncio
|
||||
async def test():
|
||||
# 测试批量输入
|
||||
batch_inputs = [
|
||||
# 换行分隔的测试
|
||||
"""https://arxiv.org/abs/2103.14030
|
||||
math/0211159
|
||||
10.1038/s41586-021-03819-2""",
|
||||
|
||||
# 空格分隔的测试
|
||||
"https://arxiv.org/abs/2103.14030 math/0211159 10.1038/s41586-021-03819-2",
|
||||
|
||||
# 混合分隔的测试
|
||||
"""https://arxiv.org/abs/2103.14030 math/0211159
|
||||
10.1038/s41586-021-03819-2 https://doi.org/10.1038/s41586-021-03819-2
|
||||
2103.14030""",
|
||||
]
|
||||
|
||||
for i, test_input in enumerate(batch_inputs, 1):
|
||||
print(f"\n测试用例 {i}:")
|
||||
print(f"输入: {test_input}")
|
||||
results = extract_paper_ids(test_input)
|
||||
print(f"解析结果:")
|
||||
for result in results:
|
||||
print(f" {result}")
|
||||
|
||||
asyncio.run(test())
|
||||
867
crazy_functions/paper_fns/reduce_aigc.py
Normal file
867
crazy_functions/paper_fns/reduce_aigc.py
Normal file
@@ -0,0 +1,867 @@
|
||||
import os
|
||||
import time
|
||||
import glob
|
||||
import re
|
||||
import threading
|
||||
from typing import Dict, List, Generator, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
|
||||
from crazy_functions.rag_fns.rag_file_support import extract_text, convert_to_markdown
|
||||
from request_llms.bridge_all import model_info
|
||||
from toolbox import update_ui, CatchException, report_exception, promote_file_to_downloadzone, write_history_to_file
|
||||
from shared_utils.fastapi_server import validate_path_safety
|
||||
|
||||
# 新增:导入结构化论文提取器
|
||||
from crazy_functions.doc_fns.read_fns.unstructured_all.paper_structure_extractor import PaperStructureExtractor, ExtractorConfig, StructuredPaper
|
||||
|
||||
# 导入格式化器
|
||||
from crazy_functions.paper_fns.file2file_doc import (
|
||||
TxtFormatter,
|
||||
MarkdownFormatter,
|
||||
HtmlFormatter,
|
||||
WordFormatter
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class TextFragment:
|
||||
"""文本片段数据类,用于组织处理单元"""
|
||||
content: str
|
||||
fragment_index: int
|
||||
total_fragments: int
|
||||
|
||||
|
||||
class DocumentProcessor:
|
||||
"""文档处理器 - 处理单个文档并输出结果"""
|
||||
|
||||
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
|
||||
"""初始化处理器"""
|
||||
self.llm_kwargs = llm_kwargs.copy() # 创建原始llm_kwargs的副本
|
||||
# 固定使用deepseek-reasoner模型
|
||||
self.llm_kwargs['llm_model'] = "deepseek-reasoner"
|
||||
self.plugin_kwargs = plugin_kwargs
|
||||
self.chatbot = chatbot
|
||||
self.history = history
|
||||
self.system_prompt = system_prompt
|
||||
self.processed_results = []
|
||||
self.failed_fragments = []
|
||||
# 新增:初始化论文结构提取器
|
||||
self.paper_extractor = PaperStructureExtractor()
|
||||
# 修改:设置降重次数,默认为1次(原为3次)
|
||||
self.reduction_times = int(self.plugin_kwargs.get("reduction_times", 1))
|
||||
|
||||
def _get_token_limit(self) -> int:
|
||||
"""获取模型token限制,返回更小的值以确保更细粒度的分割"""
|
||||
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
|
||||
# 降低token限制,使每个片段更小
|
||||
return max_token // 4 # 从3/4降低到1/4
|
||||
|
||||
def _create_batch_inputs(self, fragments: List[TextFragment], current_iteration: int = 1) -> Tuple[List, List, List]:
|
||||
"""创建批处理输入"""
|
||||
inputs_array = []
|
||||
inputs_show_user_array = []
|
||||
history_array = []
|
||||
|
||||
# 基础指令 - 增强单次调用的效果
|
||||
user_instruction = self.plugin_kwargs.get("advanced_arg", """你是一个人类科研工作者,请对以下学术文本进行深度改写,以彻底消除AI生成特征。具体要求如下:
|
||||
|
||||
1. 保持学术写作的严谨性和专业性,但加入一些人类写作的不完美元素
|
||||
2. 维持原文的核心论述和逻辑框架,但打破过于完美的结构
|
||||
3. 彻底改变句式结构:
|
||||
- 灵活运用主动句与被动句,偏好使用主动语态
|
||||
- 适当使用不同长度的句子,包括一些较长的复杂句
|
||||
- 引入变化多样的句式,打破规律性
|
||||
- 完全避免AI常用的模板化句式和套路表达
|
||||
- 增加一些学术写作中常见的转折和连接方式
|
||||
4. 全面改善用词:
|
||||
- 使用更多学术语境下的专业词汇和同义词替换
|
||||
- 避免过于机械和规律性的连接词,使用更自然的过渡
|
||||
- 重构专业术语的表达方式,但保持准确性
|
||||
- 增加词汇多样性,减少重复用词
|
||||
- 偶尔使用一些不太常见但恰当的词汇
|
||||
5. 模拟真实学者的写作风格:
|
||||
- 注重论证的严密性,但允许存在一些微小的不对称性
|
||||
- 保持表达的客观性,同时适度体现个人学术见解
|
||||
- 在适当位置表达观点时更加自信和坚定
|
||||
- 避免过于完美和机械均衡的论述结构
|
||||
- 允许段落长度有所变化,不要过于均匀
|
||||
6. 引入人类学者常见的写作特点:
|
||||
- 段落之间的过渡更加自然流畅
|
||||
- 适当使用一些学术界常见的修辞手法,但不过度使用
|
||||
- 偶尔使用一些强调和限定性表达
|
||||
- 适当使用一些学术界认可的个人化表达
|
||||
7. 彻底消除AI痕迹:
|
||||
- 避免过于规整和均衡的段落结构
|
||||
- 避免机械性的句式变化和词汇替换模式
|
||||
- 避免过于完美的逻辑推导,适当增加一些转折
|
||||
- 减少公式化的表达方式""")
|
||||
|
||||
# 对于单次调用的场景,不需要迭代前缀,直接使用更强力的改写指令
|
||||
for frag in fragments:
|
||||
# 在单次调用时使用更强力的指令
|
||||
if self.reduction_times == 1:
|
||||
i_say = (f'请对以下学术文本进行彻底改写,完全消除AI特征,使其像真实人类学者撰写的内容。\n\n{user_instruction}\n\n'
|
||||
f'请记住以下几点:\n'
|
||||
f'1. 避免过于规整和均衡的结构\n'
|
||||
f'2. 引入一些人类写作的微小不完美之处\n'
|
||||
f'3. 使用多样化的句式和词汇\n'
|
||||
f'4. 打破可能的AI规律性表达模式\n'
|
||||
f'5. 适当使用一些专业领域内的表达习惯\n\n'
|
||||
f'请将对文本的处理结果放在<decision>和</decision>标签之间。\n\n'
|
||||
f'文本内容:\n```\n{frag.content}\n```')
|
||||
else:
|
||||
# 原有的迭代前缀逻辑
|
||||
iteration_prefix = ""
|
||||
if current_iteration > 1:
|
||||
iteration_prefix = f"这是第{current_iteration}次改写,请在保持学术性的基础上,采用更加人性化、不同的表达方式。"
|
||||
if current_iteration == 2:
|
||||
iteration_prefix += "在保持专业性的同时,进一步优化句式结构和用词,显著降低AI痕迹。"
|
||||
elif current_iteration >= 3:
|
||||
iteration_prefix += "请在确保不损失任何学术内容的前提下,彻底重构表达方式,并适当引入少量人类学者常用的表达技巧,避免过度使用比喻和类比。"
|
||||
|
||||
i_say = (f'请按照以下要求处理文本内容:{iteration_prefix}{user_instruction}\n\n'
|
||||
f'请将对文本的处理结果放在<decision>和</decision>标签之间。\n\n'
|
||||
f'文本内容:\n```\n{frag.content}\n```')
|
||||
|
||||
i_say_show_user = f'正在处理文本片段 {frag.fragment_index + 1}/{frag.total_fragments}'
|
||||
|
||||
inputs_array.append(i_say)
|
||||
inputs_show_user_array.append(i_say_show_user)
|
||||
history_array.append([])
|
||||
|
||||
return inputs_array, inputs_show_user_array, history_array
|
||||
|
||||
def _extract_decision(self, text: str) -> str:
|
||||
"""从LLM响应中提取<decision>标签内的内容"""
|
||||
import re
|
||||
pattern = r'<decision>(.*?)</decision>'
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
return matches[0].strip()
|
||||
else:
|
||||
# 如果没有找到标签,返回原始文本
|
||||
return text.strip()
|
||||
|
||||
def process_file(self, file_path: str) -> Generator:
|
||||
"""处理单个文件"""
|
||||
self.chatbot.append(["开始处理文件", f"文件路径: {file_path}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
try:
|
||||
# 首先尝试转换为Markdown
|
||||
file_path = convert_to_markdown(file_path)
|
||||
|
||||
# 1. 检查文件是否为支持的论文格式
|
||||
is_paper_format = any(file_path.lower().endswith(ext) for ext in self.paper_extractor.SUPPORTED_EXTENSIONS)
|
||||
|
||||
if is_paper_format:
|
||||
# 使用结构化提取器处理论文
|
||||
return (yield from self._process_structured_paper(file_path))
|
||||
else:
|
||||
# 使用原有方式处理普通文档
|
||||
return (yield from self._process_regular_file(file_path))
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["处理错误", f"文件处理失败: {str(e)}"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
|
||||
def _process_structured_paper(self, file_path: str) -> Generator:
|
||||
"""处理结构化论文文件"""
|
||||
# 1. 提取论文结构
|
||||
self.chatbot[-1] = ["正在分析论文结构", f"文件路径: {file_path}"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
try:
|
||||
paper = self.paper_extractor.extract_paper_structure(file_path)
|
||||
|
||||
if not paper or not paper.sections:
|
||||
self.chatbot.append(["无法提取论文结构", "将使用全文内容进行处理"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 使用全文内容进行段落切分
|
||||
if paper and paper.full_text:
|
||||
# 使用增强的分割函数进行更细致的分割
|
||||
fragments = self._breakdown_section_content(paper.full_text)
|
||||
|
||||
# 创建文本片段对象
|
||||
text_fragments = []
|
||||
for i, frag in enumerate(fragments):
|
||||
if frag.strip():
|
||||
text_fragments.append(TextFragment(
|
||||
content=frag,
|
||||
fragment_index=i,
|
||||
total_fragments=len(fragments)
|
||||
))
|
||||
|
||||
# 多次降重处理
|
||||
if text_fragments:
|
||||
current_fragments = text_fragments
|
||||
|
||||
# 进行多轮降重处理
|
||||
for iteration in range(1, self.reduction_times + 1):
|
||||
# 处理当前片段
|
||||
processed_content = yield from self._process_text_fragments(current_fragments, iteration)
|
||||
|
||||
# 如果这是最后一次迭代,保存结果
|
||||
if iteration == self.reduction_times:
|
||||
final_content = processed_content
|
||||
break
|
||||
|
||||
# 否则,准备下一轮迭代的片段
|
||||
# 从处理结果中提取处理后的内容
|
||||
next_fragments = []
|
||||
for idx, item in enumerate(self.processed_results):
|
||||
next_fragments.append(TextFragment(
|
||||
content=item['content'],
|
||||
fragment_index=idx,
|
||||
total_fragments=len(self.processed_results)
|
||||
))
|
||||
|
||||
current_fragments = next_fragments
|
||||
|
||||
# 更新UI显示最终结果
|
||||
self.chatbot[-1] = ["处理完成", f"共完成 {self.reduction_times} 轮降重"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
return final_content
|
||||
else:
|
||||
self.chatbot.append(["处理失败", "未能提取到有效的文本内容"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
else:
|
||||
self.chatbot.append(["处理失败", "未能提取到论文内容"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
|
||||
# 2. 准备处理章节内容(不处理标题)
|
||||
self.chatbot[-1] = ["已提取论文结构", f"共 {len(paper.sections)} 个主要章节"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 3. 收集所有需要处理的章节内容并分割为合适大小
|
||||
sections_to_process = []
|
||||
section_map = {} # 用于映射处理前后的内容
|
||||
|
||||
def collect_section_contents(sections, parent_path=""):
|
||||
"""递归收集章节内容,跳过参考文献部分"""
|
||||
for i, section in enumerate(sections):
|
||||
current_path = f"{parent_path}/{i}" if parent_path else f"{i}"
|
||||
|
||||
# 检查是否为参考文献部分,如果是则跳过
|
||||
if section.section_type == 'references' or section.title.lower() in ['references', '参考文献', 'bibliography', '文献']:
|
||||
continue # 跳过参考文献部分
|
||||
|
||||
# 只处理内容非空的章节
|
||||
if section.content and section.content.strip():
|
||||
# 使用增强的分割函数进行更细致的分割
|
||||
fragments = self._breakdown_section_content(section.content)
|
||||
|
||||
for fragment_idx, fragment_content in enumerate(fragments):
|
||||
if fragment_content.strip():
|
||||
fragment_index = len(sections_to_process)
|
||||
sections_to_process.append(TextFragment(
|
||||
content=fragment_content,
|
||||
fragment_index=fragment_index,
|
||||
total_fragments=0 # 临时值,稍后更新
|
||||
))
|
||||
|
||||
# 保存映射关系,用于稍后更新章节内容
|
||||
# 为每个片段存储原始章节和片段索引信息
|
||||
section_map[fragment_index] = (current_path, section, fragment_idx, len(fragments))
|
||||
|
||||
# 递归处理子章节
|
||||
if section.subsections:
|
||||
collect_section_contents(section.subsections, current_path)
|
||||
|
||||
# 收集所有章节内容
|
||||
collect_section_contents(paper.sections)
|
||||
|
||||
# 更新总片段数
|
||||
total_fragments = len(sections_to_process)
|
||||
for frag in sections_to_process:
|
||||
frag.total_fragments = total_fragments
|
||||
|
||||
# 4. 如果没有内容需要处理,直接返回
|
||||
if not sections_to_process:
|
||||
self.chatbot.append(["处理完成", "未找到需要处理的内容"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
|
||||
# 5. 批量处理章节内容
|
||||
self.chatbot[-1] = ["开始处理论文内容", f"共 {len(sections_to_process)} 个内容片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 一次性准备所有输入
|
||||
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(sections_to_process)
|
||||
|
||||
# 使用系统提示
|
||||
instruction = self.plugin_kwargs.get("advanced_arg", """请对以下学术文本进行彻底改写,以显著降低AI生成特征。具体要求如下:
|
||||
|
||||
1. 保持学术写作的严谨性和专业性
|
||||
2. 维持原文的核心论述和逻辑框架
|
||||
3. 优化句式结构:
|
||||
- 灵活运用主动句与被动句
|
||||
- 适当拆分复杂句式,提高可读性
|
||||
- 注意句式的多样性,避免重复模式
|
||||
- 打破AI常用的句式模板
|
||||
4. 改善用词:
|
||||
- 使用更多学术语境下的同义词替换
|
||||
- 避免过于机械和规律性的连接词
|
||||
- 适当调整专业术语的表达方式
|
||||
- 增加词汇多样性,减少重复用词
|
||||
5. 增强文本的学术特征:
|
||||
- 注重论证的严密性
|
||||
- 保持表达的客观性
|
||||
- 适度体现作者的学术见解
|
||||
- 避免过于完美和均衡的论述结构
|
||||
6. 确保语言风格的一致性
|
||||
7. 减少AI生成文本常见的套路和模式""")
|
||||
sys_prompt_array = [f"""作为一位专业的学术写作顾问,请按照以下要求改写文本:
|
||||
|
||||
1. 严格保持学术写作规范
|
||||
2. 维持原文的核心论述和逻辑框架
|
||||
3. 通过优化句式结构和用词降低AI生成特征
|
||||
4. 确保语言风格的一致性和专业性
|
||||
5. 保持内容的客观性和准确性
|
||||
6. 避免AI常见的套路化表达和过于完美的结构"""] * len(sections_to_process)
|
||||
|
||||
# 调用LLM一次性处理所有片段
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=inputs_array,
|
||||
inputs_show_user_array=inputs_show_user_array,
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=history_array,
|
||||
sys_prompt_array=sys_prompt_array,
|
||||
)
|
||||
|
||||
# 处理响应,重组章节内容
|
||||
section_contents = {} # 用于重组各章节的处理后内容
|
||||
|
||||
for j, frag in enumerate(sections_to_process):
|
||||
try:
|
||||
llm_response = response_collection[j * 2 + 1]
|
||||
processed_text = self._extract_decision(llm_response)
|
||||
|
||||
if processed_text and processed_text.strip():
|
||||
# 保存处理结果
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': processed_text
|
||||
})
|
||||
|
||||
# 存储处理后的文本片段,用于后续重组
|
||||
fragment_index = frag.fragment_index
|
||||
if fragment_index in section_map:
|
||||
path, section, fragment_idx, total_fragments = section_map[fragment_index]
|
||||
|
||||
# 初始化此章节的内容容器(如果尚未创建)
|
||||
if path not in section_contents:
|
||||
section_contents[path] = [""] * total_fragments
|
||||
|
||||
# 将处理后的片段放入正确位置
|
||||
section_contents[path][fragment_idx] = processed_text
|
||||
else:
|
||||
self.failed_fragments.append(frag)
|
||||
except Exception as e:
|
||||
self.failed_fragments.append(frag)
|
||||
|
||||
# 重组每个章节的内容
|
||||
for path, fragments in section_contents.items():
|
||||
section = None
|
||||
for idx in section_map:
|
||||
if section_map[idx][0] == path:
|
||||
section = section_map[idx][1]
|
||||
break
|
||||
|
||||
if section:
|
||||
# 合并该章节的所有处理后片段
|
||||
section.content = "\n".join(fragments)
|
||||
|
||||
# 6. 更新UI
|
||||
success_count = total_fragments - len(self.failed_fragments)
|
||||
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{total_fragments} 个内容片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 收集参考文献部分(不进行处理)
|
||||
references_sections = []
|
||||
def collect_references(sections, parent_path=""):
|
||||
"""递归收集参考文献部分"""
|
||||
for i, section in enumerate(sections):
|
||||
current_path = f"{parent_path}/{i}" if parent_path else f"{i}"
|
||||
|
||||
# 检查是否为参考文献部分
|
||||
if section.section_type == 'references' or section.title.lower() in ['references', '参考文献', 'bibliography', '文献']:
|
||||
references_sections.append((current_path, section))
|
||||
|
||||
# 递归检查子章节
|
||||
if section.subsections:
|
||||
collect_references(section.subsections, current_path)
|
||||
|
||||
# 收集参考文献
|
||||
collect_references(paper.sections)
|
||||
|
||||
# 7. 将处理后的结构化论文转换为Markdown
|
||||
markdown_content = self.paper_extractor.generate_markdown(paper)
|
||||
|
||||
# 8. 返回处理后的内容
|
||||
self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{total_fragments} 个内容片段,参考文献部分未处理"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
return markdown_content
|
||||
|
||||
except Exception as e:
|
||||
self.chatbot.append(["结构化处理失败", f"错误: {str(e)},将尝试作为普通文件处理"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return (yield from self._process_regular_file(file_path))
|
||||
|
||||
def _process_regular_file(self, file_path: str) -> Generator:
|
||||
"""使用原有方式处理普通文件"""
|
||||
# 原有的文件处理逻辑
|
||||
self.chatbot[-1] = ["正在读取文件", f"文件路径: {file_path}"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
content = extract_text(file_path)
|
||||
if not content or not content.strip():
|
||||
self.chatbot.append(["处理失败", "文件内容为空或无法提取内容"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
|
||||
# 2. 分割文本
|
||||
self.chatbot[-1] = ["正在分析文件", "将文件内容分割为适当大小的片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 使用增强的分割函数
|
||||
fragments = self._breakdown_section_content(content)
|
||||
|
||||
# 3. 创建文本片段对象
|
||||
text_fragments = []
|
||||
for i, frag in enumerate(fragments):
|
||||
if frag.strip():
|
||||
text_fragments.append(TextFragment(
|
||||
content=frag,
|
||||
fragment_index=i,
|
||||
total_fragments=len(fragments)
|
||||
))
|
||||
|
||||
# 4. 多轮降重处理
|
||||
if not text_fragments:
|
||||
self.chatbot.append(["处理失败", "未能提取到有效的文本内容"])
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
return None
|
||||
|
||||
# 批处理大小
|
||||
batch_size = 8 # 每批处理的片段数
|
||||
|
||||
# 第一次迭代
|
||||
current_batches = []
|
||||
for i in range(0, len(text_fragments), batch_size):
|
||||
current_batches.append(text_fragments[i:i + batch_size])
|
||||
|
||||
all_processed_fragments = []
|
||||
|
||||
# 进行多轮降重处理
|
||||
for iteration in range(1, self.reduction_times + 1):
|
||||
self.chatbot[-1] = ["开始处理文本", f"第 {iteration}/{self.reduction_times} 次降重"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
next_batches = []
|
||||
all_processed_fragments = []
|
||||
|
||||
# 分批处理当前迭代的片段
|
||||
for batch in current_batches:
|
||||
# 处理当前批次
|
||||
_ = yield from self._process_text_fragments(batch, iteration)
|
||||
|
||||
# 收集处理结果
|
||||
processed_batch = []
|
||||
for item in self.processed_results:
|
||||
processed_batch.append(TextFragment(
|
||||
content=item['content'],
|
||||
fragment_index=len(all_processed_fragments) + len(processed_batch),
|
||||
total_fragments=0 # 临时值,稍后更新
|
||||
))
|
||||
|
||||
all_processed_fragments.extend(processed_batch)
|
||||
|
||||
# 如果不是最后一轮迭代,准备下一批次
|
||||
if iteration < self.reduction_times:
|
||||
for i in range(0, len(processed_batch), batch_size):
|
||||
next_batches.append(processed_batch[i:i + batch_size])
|
||||
|
||||
# 更新总片段数
|
||||
for frag in all_processed_fragments:
|
||||
frag.total_fragments = len(all_processed_fragments)
|
||||
|
||||
# 为下一轮迭代准备批次
|
||||
current_batches = next_batches
|
||||
|
||||
# 合并最终结果
|
||||
final_content = "\n\n".join([frag.content for frag in all_processed_fragments])
|
||||
|
||||
# 5. 更新UI显示最终结果
|
||||
self.chatbot[-1] = ["处理完成", f"共完成 {self.reduction_times} 轮降重"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
return final_content
|
||||
|
||||
def save_results(self, content: str, original_file_path: str) -> List[str]:
|
||||
"""保存处理结果为TXT格式"""
|
||||
if not content:
|
||||
return []
|
||||
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
original_filename = os.path.basename(original_file_path)
|
||||
filename_without_ext = os.path.splitext(original_filename)[0]
|
||||
base_filename = f"{filename_without_ext}_processed_{timestamp}"
|
||||
|
||||
result_files = []
|
||||
|
||||
# 只保存为TXT
|
||||
try:
|
||||
txt_formatter = TxtFormatter()
|
||||
txt_content = txt_formatter.create_document(content)
|
||||
txt_file = write_history_to_file(
|
||||
history=[txt_content],
|
||||
file_basename=f"{base_filename}.txt"
|
||||
)
|
||||
result_files.append(txt_file)
|
||||
except Exception as e:
|
||||
self.chatbot.append(["警告", f"TXT格式保存失败: {str(e)}"])
|
||||
|
||||
# 添加到下载区
|
||||
for file in result_files:
|
||||
promote_file_to_downloadzone(file, chatbot=self.chatbot)
|
||||
|
||||
return result_files
|
||||
|
||||
def _breakdown_section_content(self, content: str) -> List[str]:
|
||||
"""对文本内容进行分割与合并
|
||||
|
||||
主要按段落进行组织,只合并较小的段落以减少片段数量
|
||||
保留原始段落结构,不对长段落进行强制分割
|
||||
针对中英文设置不同的阈值,因为字符密度不同
|
||||
"""
|
||||
# 先按段落分割文本
|
||||
paragraphs = content.split('\n\n')
|
||||
|
||||
# 检测语言类型
|
||||
chinese_char_count = sum(1 for char in content if '\u4e00' <= char <= '\u9fff')
|
||||
is_chinese_text = chinese_char_count / max(1, len(content)) > 0.3
|
||||
|
||||
# 根据语言类型设置不同的阈值(只用于合并小段落)
|
||||
if is_chinese_text:
|
||||
# 中文文本:一个汉字就是一个字符,信息密度高
|
||||
min_chunk_size = 300 # 段落合并的最小阈值
|
||||
target_size = 800 # 理想的段落大小
|
||||
else:
|
||||
# 英文文本:一个单词由多个字符组成,信息密度低
|
||||
min_chunk_size = 600 # 段落合并的最小阈值
|
||||
target_size = 1600 # 理想的段落大小
|
||||
|
||||
# 1. 只合并小段落,不对长段落进行分割
|
||||
result_fragments = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
for para in paragraphs:
|
||||
# 如果段落太小且不会超过目标大小,则合并
|
||||
if len(para) < min_chunk_size and current_length + len(para) <= target_size:
|
||||
current_chunk.append(para)
|
||||
current_length += len(para)
|
||||
# 否则,创建新段落
|
||||
else:
|
||||
# 如果当前块非空且与当前段落无关,先保存它
|
||||
if current_chunk and current_length > 0:
|
||||
result_fragments.append('\n\n'.join(current_chunk))
|
||||
|
||||
# 当前段落作为新块
|
||||
current_chunk = [para]
|
||||
current_length = len(para)
|
||||
|
||||
# 如果当前块大小已接近目标大小,保存并开始新块
|
||||
if current_length >= target_size:
|
||||
result_fragments.append('\n\n'.join(current_chunk))
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
# 保存最后一个块
|
||||
if current_chunk:
|
||||
result_fragments.append('\n\n'.join(current_chunk))
|
||||
|
||||
# 2. 处理可能过大的片段(确保不超过token限制)
|
||||
final_fragments = []
|
||||
max_token = self._get_token_limit()
|
||||
|
||||
for fragment in result_fragments:
|
||||
# 检查fragment是否可能超出token限制
|
||||
# 根据语言类型调整token估算
|
||||
if is_chinese_text:
|
||||
estimated_tokens = len(fragment) / 1.5 # 中文每个token约1-2个字符
|
||||
else:
|
||||
estimated_tokens = len(fragment) / 4 # 英文每个token约4个字符
|
||||
|
||||
if estimated_tokens > max_token:
|
||||
# 即使可能超出限制,也尽量保持段落的完整性
|
||||
# 使用breakdown_text但设置更大的限制来减少分割
|
||||
larger_limit = max_token * 0.95 # 使用95%的限制
|
||||
sub_fragments = breakdown_text_to_satisfy_token_limit(
|
||||
txt=fragment,
|
||||
limit=larger_limit,
|
||||
llm_model=self.llm_kwargs['llm_model']
|
||||
)
|
||||
final_fragments.extend(sub_fragments)
|
||||
else:
|
||||
final_fragments.append(fragment)
|
||||
|
||||
return final_fragments
|
||||
|
||||
def _process_text_fragments(self, text_fragments: List[TextFragment], current_iteration: int = 1) -> str:
|
||||
"""处理文本片段,支持多次降重
|
||||
|
||||
Args:
|
||||
text_fragments: 要处理的文本片段列表
|
||||
current_iteration: 当前是第几次降重迭代
|
||||
|
||||
Returns:
|
||||
处理后的文本内容
|
||||
"""
|
||||
self.chatbot[-1] = ["开始处理文本", f"第 {current_iteration}/{self.reduction_times} 次降重,共 {len(text_fragments)} 个片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
# 重置处理结果,为当前迭代做准备
|
||||
self.processed_results = []
|
||||
self.failed_fragments = []
|
||||
|
||||
# 一次性准备所有输入
|
||||
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(text_fragments, current_iteration)
|
||||
|
||||
# 对于单次调用的特殊处理 - 使用更强力的系统提示词
|
||||
if self.reduction_times == 1:
|
||||
sys_prompt = f"""你是一个伪装成人类科研学者的AI,能够将各种AI生成内容改写成人类学者常用的表达方式,让人类学者无法分辨。
|
||||
|
||||
你的任务是对学术文本进行深度改写,以彻底消除AI特征,使其具有真实人类学者的写作特点:
|
||||
|
||||
1. 打破AI常见的规整句式结构和词汇选择模式
|
||||
2. 引入学术写作中的自然变化和不规则性
|
||||
3. 使用多样化的句式和段落结构,避免明显的平衡性和规律性
|
||||
4. 加入人类写作中常见的微小不完美之处,如段落长度变化、论述强调点不均等
|
||||
5. 在维持核心论点的同时,完全重构表达方式
|
||||
6. 适当使用学科领域内专家常用的表达习惯和术语
|
||||
7. 确保文本读起来自然流畅,避免机械感
|
||||
8. 打破AI生成内容常见的逻辑结构和表达套路
|
||||
9. 根据文本内容,在保持专业性的基础上增加写作的特色和个性
|
||||
|
||||
对于每个文本片段,尽可能彻底地重构,使其无法被识别为AI生成。"""
|
||||
|
||||
sys_prompt_array = [sys_prompt] * len(text_fragments)
|
||||
|
||||
else:
|
||||
# 原有的多次迭代处理逻辑
|
||||
# 根据迭代次数调整系统提示词强度
|
||||
base_instruction = self.plugin_kwargs.get("advanced_arg", """你是一个人类科研工作者,请对以下学术文本进行彻底改写,以显著降低AI生成特征。具体要求如下:
|
||||
|
||||
1. 保持学术写作的严谨性和专业性
|
||||
2. 维持原文的核心论述和逻辑框架
|
||||
3. 优化句式结构:
|
||||
- 灵活运用主动句与被动句
|
||||
- 适当拆分复杂句式,提高可读性
|
||||
- 注意句式的多样性,避免重复模式
|
||||
- 打破AI常用的句式模板
|
||||
4. 改善用词:
|
||||
- 使用更多学术语境下的同义词替换
|
||||
- 避免过于机械和规律性的连接词
|
||||
- 适当调整专业术语的表达方式
|
||||
- 增加词汇多样性,减少重复用词
|
||||
5. 增强文本的学术特征:
|
||||
- 注重论证的严密性
|
||||
- 保持表达的客观性
|
||||
- 适度体现作者的学术见解
|
||||
- 避免过于完美和均衡的论述结构
|
||||
6. 确保语言风格的一致性
|
||||
7. 减少AI生成文本常见的套路和模式""")
|
||||
|
||||
# 根据迭代次数增强强度
|
||||
if current_iteration == 1:
|
||||
# 第一次迭代使用基础强度
|
||||
instruction = base_instruction
|
||||
sys_prompt = f"""你是一个伪装成人类科研学者的AI,能够将各种AI生成内容改写成人类学者常用的表达方式,让人类学者无法分辨。请进行第{current_iteration}轮深度改写:
|
||||
|
||||
1. 严格保持学术写作规范
|
||||
2. 维持原文的核心论述和逻辑框架
|
||||
3. 通过优化句式结构和用词降低AI生成特征
|
||||
4. 确保语言风格的一致性和专业性
|
||||
5. 保持内容的客观性和准确性
|
||||
6. 避免AI常见的套路化表达和过于完美的结构"""
|
||||
elif current_iteration == 2:
|
||||
# 第二次迭代使用中等强度
|
||||
instruction = base_instruction + """这是第二次改写,请进一步优化:
|
||||
|
||||
1. 更深入地调整句式结构,增加表达多样性
|
||||
2. 在保持准确性的前提下,替换更多专业术语的表达方式
|
||||
3. 优化段落结构,增强逻辑连贯性
|
||||
4. 适当调整学术论证的表达方式
|
||||
5. 打破AI生成内容的常见模式,如过于均衡的结构和规律性的表达方式"""
|
||||
sys_prompt = f"""你是一个伪装成人类科研学者的AI,能够将各种AI生成内容改写成人类学者常用的表达方式,让人类学者无法分辨。请进行第二轮改写:
|
||||
|
||||
1. 在保持严谨性的基础上,优化表达方式
|
||||
2. 改善文本结构,提升学术论述的清晰度
|
||||
3. 增强专业术语使用的灵活性
|
||||
4. 确保论证过程的完整性和逻辑性
|
||||
5. 进一步减少AI生成内容的特征和模式
|
||||
6. 使文本表现出更自然、更人性化的写作风格"""
|
||||
else:
|
||||
# 第三次或更多次迭代使用高强度
|
||||
instruction = base_instruction + f"""这是第{current_iteration}次改写,需要在保持学术性的前提下彻底重构文本:
|
||||
|
||||
1. 全面重构句式,采用不同的表达策略
|
||||
2. 系统性替换学术用语,保持专业准确性
|
||||
3. 重组段落结构,优化论证层次
|
||||
4. 深化学术观点的表达方式
|
||||
5. 增强文本的逻辑严密性
|
||||
6. 提升论述的专业深度
|
||||
7. 确保不遗漏任何学术观点和论证要素
|
||||
8. 适当使用学术表达技巧(如精准举例、恰当转折等),但不过分依赖比喻和类比
|
||||
9. 彻底消除AI生成内容的特征,使文本具有真实学者的写作风格"""
|
||||
sys_prompt = f"""你是一个伪装成人类科研学者的AI,能够将各种AI生成内容改写成人类学者常用的表达方式,让人类学者无法分辨。请进行第{current_iteration}轮深度改写:
|
||||
|
||||
1. 在保持专业水准的前提下,彻底重构表达方式,引入长难句
|
||||
2. 确保学术论证的严密性和完整性
|
||||
3. 优化专业术语的运用
|
||||
4. 提升文本的学术价值
|
||||
5. 保持论述的逻辑性和连贯性
|
||||
6. 适当使用少量学术表达技巧,提高文本说服力,但避免过度使用比喻和类比
|
||||
7. 消除所有明显的AI生成痕迹,使文本更接近真实学者的写作风格"""
|
||||
|
||||
sys_prompt_array = [sys_prompt] * len(text_fragments)
|
||||
|
||||
# 调用LLM一次性处理所有片段
|
||||
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||
inputs_array=inputs_array,
|
||||
inputs_show_user_array=inputs_show_user_array,
|
||||
llm_kwargs=self.llm_kwargs,
|
||||
chatbot=self.chatbot,
|
||||
history_array=history_array,
|
||||
sys_prompt_array=sys_prompt_array,
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
for j, frag in enumerate(text_fragments):
|
||||
try:
|
||||
llm_response = response_collection[j * 2 + 1]
|
||||
processed_text = self._extract_decision(llm_response)
|
||||
|
||||
if processed_text and processed_text.strip():
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': processed_text
|
||||
})
|
||||
else:
|
||||
self.failed_fragments.append(frag)
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': frag.content
|
||||
})
|
||||
except Exception as e:
|
||||
self.failed_fragments.append(frag)
|
||||
self.processed_results.append({
|
||||
'index': frag.fragment_index,
|
||||
'content': frag.content
|
||||
})
|
||||
|
||||
# 按原始顺序合并结果
|
||||
self.processed_results.sort(key=lambda x: x['index'])
|
||||
final_content = "\n".join([item['content'] for item in self.processed_results])
|
||||
|
||||
# 更新UI
|
||||
success_count = len(text_fragments) - len(self.failed_fragments)
|
||||
self.chatbot[-1] = ["当前阶段处理完成", f"第 {current_iteration}/{self.reduction_times} 次降重,成功处理 {success_count}/{len(text_fragments)} 个片段"]
|
||||
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||
|
||||
return final_content
|
||||
|
||||
|
||||
@CatchException
|
||||
def 学术降重(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||
history: List, system_prompt: str, user_request: str):
|
||||
"""主函数 - 文件到文件处理"""
|
||||
# 初始化
|
||||
# 从高级参数中提取降重次数
|
||||
if "advanced_arg" in plugin_kwargs and plugin_kwargs["advanced_arg"]:
|
||||
# 检查是否包含降重次数的设置
|
||||
match = re.search(r'reduction_times\s*=\s*(\d+)', plugin_kwargs["advanced_arg"])
|
||||
if match:
|
||||
reduction_times = int(match.group(1))
|
||||
# 替换掉高级参数中的reduction_times设置,但保留其他内容
|
||||
plugin_kwargs["advanced_arg"] = re.sub(r'reduction_times\s*=\s*\d+', '', plugin_kwargs["advanced_arg"]).strip()
|
||||
# 添加到plugin_kwargs中作为单独的参数
|
||||
plugin_kwargs["reduction_times"] = reduction_times
|
||||
|
||||
processor = DocumentProcessor(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||
chatbot.append(["函数插件功能", f"文件内容处理:将文档内容进行{processor.reduction_times}次降重处理"])
|
||||
|
||||
# 更新用户提示,提供关于降重策略的详细说明
|
||||
if processor.reduction_times == 1:
|
||||
chatbot.append(["降重策略", "将使用单次深度降重,这种方式能更有效地降低AI特征,减少查重率。我们采用特殊优化的提示词,通过一次性强力改写来实现降重效果。"])
|
||||
elif processor.reduction_times > 1:
|
||||
chatbot.append(["降重策略", f"将进行{processor.reduction_times}轮迭代降重,每轮降重都会基于上一轮的结果,并逐渐增加降重强度。请注意,多轮迭代可能会引入新的AI特征,单次强力降重通常效果更好。"])
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 验证输入路径
|
||||
if not os.path.exists(txt):
|
||||
report_exception(chatbot, history, a=f"解析路径: {txt}", b=f"找不到路径或无权访问: {txt}")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 验证路径安全性
|
||||
user_name = chatbot.get_user()
|
||||
validate_path_safety(txt, user_name)
|
||||
|
||||
# 获取文件列表
|
||||
if os.path.isfile(txt):
|
||||
# 单个文件处理
|
||||
file_paths = [txt]
|
||||
else:
|
||||
# 目录处理 - 类似批量文件询问插件
|
||||
project_folder = txt
|
||||
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
|
||||
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
|
||||
|
||||
# 排除压缩文件
|
||||
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
|
||||
file_paths = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
|
||||
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
|
||||
|
||||
# 过滤支持的文件格式
|
||||
file_paths = [f for f in file_paths if any(f.lower().endswith(ext) for ext in
|
||||
list(processor.paper_extractor.SUPPORTED_EXTENSIONS) + ['.json', '.csv', '.xlsx', '.xls'])]
|
||||
|
||||
if not file_paths:
|
||||
report_exception(chatbot, history, a=f"解析路径: {txt}", b="未找到支持的文件类型")
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
return
|
||||
|
||||
# 处理文件
|
||||
if len(file_paths) > 1:
|
||||
chatbot.append(["发现多个文件", f"共找到 {len(file_paths)} 个文件,将处理第一个文件"])
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
# 只处理第一个文件
|
||||
file_to_process = file_paths[0]
|
||||
processed_content = yield from processor.process_file(file_to_process)
|
||||
|
||||
if processed_content:
|
||||
# 保存结果
|
||||
result_files = processor.save_results(processed_content, file_to_process)
|
||||
|
||||
if result_files:
|
||||
chatbot.append(["处理完成", f"已生成 {len(result_files)} 个结果文件"])
|
||||
else:
|
||||
chatbot.append(["处理完成", "但未能保存任何结果文件"])
|
||||
else:
|
||||
chatbot.append(["处理失败", "未能生成有效的处理结果"])
|
||||
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
387
crazy_functions/paper_fns/wiki/wikipedia_api.py
Normal file
387
crazy_functions/paper_fns/wiki/wikipedia_api.py
Normal file
@@ -0,0 +1,387 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional
|
||||
import re
|
||||
import random
|
||||
import time
|
||||
|
||||
class WikipediaAPI:
|
||||
"""维基百科API调用实现"""
|
||||
|
||||
def __init__(self, language: str = "zh", user_agent: str = None,
|
||||
max_concurrent: int = 5, request_delay: float = 0.5):
|
||||
"""
|
||||
初始化维基百科API客户端
|
||||
|
||||
Args:
|
||||
language: 语言代码 (zh: 中文, en: 英文, ja: 日文等)
|
||||
user_agent: 用户代理信息,如果为None将使用默认值
|
||||
max_concurrent: 最大并发请求数
|
||||
request_delay: 请求间隔时间(秒)
|
||||
"""
|
||||
self.language = language
|
||||
self.base_url = f"https://{language}.wikipedia.org/w/api.php"
|
||||
self.user_agent = user_agent or "WikipediaAPIClient/1.0 (chatscholar@163.com)"
|
||||
self.headers = {
|
||||
"User-Agent": self.user_agent,
|
||||
"Accept": "application/json"
|
||||
}
|
||||
# 添加并发控制
|
||||
self.semaphore = asyncio.Semaphore(max_concurrent)
|
||||
self.request_delay = request_delay
|
||||
self.last_request_time = 0
|
||||
|
||||
async def _make_request(self, url, params=None):
|
||||
"""
|
||||
发起API请求,包含并发控制和请求延迟
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
API响应数据
|
||||
"""
|
||||
# 使用信号量控制并发
|
||||
async with self.semaphore:
|
||||
# 添加请求间隔
|
||||
current_time = time.time()
|
||||
time_since_last_request = current_time - self.last_request_time
|
||||
if time_since_last_request < self.request_delay:
|
||||
await asyncio.sleep(self.request_delay - time_since_last_request)
|
||||
|
||||
# 设置随机延迟,避免规律性请求
|
||||
jitter = random.uniform(0, 0.2)
|
||||
await asyncio.sleep(jitter)
|
||||
|
||||
# 记录本次请求时间
|
||||
self.last_request_time = time.time()
|
||||
|
||||
# 发起请求
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 429: # Too Many Requests
|
||||
retry_after = int(response.headers.get('Retry-After', 5))
|
||||
print(f"达到请求限制,等待 {retry_after} 秒后重试...")
|
||||
await asyncio.sleep(retry_after)
|
||||
return await self._make_request(url, params)
|
||||
|
||||
if response.status != 200:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
return None
|
||||
|
||||
return await response.json()
|
||||
except aiohttp.ClientError as e:
|
||||
print(f"请求错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(self, query: str, limit: int = 10, namespace: int = 0) -> List[Dict]:
|
||||
"""
|
||||
搜索维基百科文章
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量
|
||||
namespace: 命名空间 (0表示文章, 14表示分类等)
|
||||
|
||||
Returns:
|
||||
搜索结果列表
|
||||
"""
|
||||
params = {
|
||||
"action": "query",
|
||||
"list": "search",
|
||||
"srsearch": query,
|
||||
"format": "json",
|
||||
"srlimit": limit,
|
||||
"srnamespace": namespace,
|
||||
"srprop": "snippet|titlesnippet|sectiontitle|categorysnippet|size|wordcount|timestamp|redirecttitle"
|
||||
}
|
||||
|
||||
data = await self._make_request(self.base_url, params)
|
||||
if not data:
|
||||
return []
|
||||
|
||||
search_results = data.get("query", {}).get("search", [])
|
||||
return search_results
|
||||
|
||||
async def get_page_content(self, title: str, section: Optional[int] = None) -> Dict:
|
||||
"""
|
||||
获取维基百科页面内容
|
||||
|
||||
Args:
|
||||
title: 页面标题
|
||||
section: 特定章节编号(可选)
|
||||
|
||||
Returns:
|
||||
页面内容字典
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
params = {
|
||||
"action": "parse",
|
||||
"page": title,
|
||||
"format": "json",
|
||||
"prop": "text|langlinks|categories|links|templates|images|externallinks|sections|revid|displaytitle|iwlinks|properties"
|
||||
}
|
||||
|
||||
# 如果指定了章节,只获取该章节内容
|
||||
if section is not None:
|
||||
params["section"] = section
|
||||
|
||||
async with session.get(self.base_url, params=params) as response:
|
||||
if response.status != 200:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
return {}
|
||||
|
||||
data = await response.json()
|
||||
if "error" in data:
|
||||
print(f"API错误: {data['error'].get('info', '未知错误')}")
|
||||
return {}
|
||||
|
||||
return data.get("parse", {})
|
||||
|
||||
async def get_summary(self, title: str, sentences: int = 3) -> str:
|
||||
"""
|
||||
获取页面摘要
|
||||
|
||||
Args:
|
||||
title: 页面标题
|
||||
sentences: 返回的句子数量
|
||||
|
||||
Returns:
|
||||
页面摘要文本
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
params = {
|
||||
"action": "query",
|
||||
"prop": "extracts",
|
||||
"exintro": "1",
|
||||
"exsentences": sentences,
|
||||
"explaintext": "1",
|
||||
"titles": title,
|
||||
"format": "json"
|
||||
}
|
||||
|
||||
async with session.get(self.base_url, params=params) as response:
|
||||
if response.status != 200:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
return ""
|
||||
|
||||
data = await response.json()
|
||||
pages = data.get("query", {}).get("pages", {})
|
||||
# 获取第一个页面ID的内容
|
||||
for page_id in pages:
|
||||
return pages[page_id].get("extract", "")
|
||||
return ""
|
||||
|
||||
async def get_random_articles(self, count: int = 1, namespace: int = 0) -> List[Dict]:
|
||||
"""
|
||||
获取随机文章
|
||||
|
||||
Args:
|
||||
count: 需要的随机文章数量
|
||||
namespace: 命名空间
|
||||
|
||||
Returns:
|
||||
随机文章列表
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
params = {
|
||||
"action": "query",
|
||||
"list": "random",
|
||||
"rnlimit": count,
|
||||
"rnnamespace": namespace,
|
||||
"format": "json"
|
||||
}
|
||||
|
||||
async with session.get(self.base_url, params=params) as response:
|
||||
if response.status != 200:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
return []
|
||||
|
||||
data = await response.json()
|
||||
return data.get("query", {}).get("random", [])
|
||||
|
||||
async def login(self, username: str, password: str) -> bool:
|
||||
"""
|
||||
使用维基百科账户登录
|
||||
|
||||
Args:
|
||||
username: 维基百科用户名
|
||||
password: 维基百科密码
|
||||
|
||||
Returns:
|
||||
登录是否成功
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
# 获取登录令牌
|
||||
params = {
|
||||
"action": "query",
|
||||
"meta": "tokens",
|
||||
"type": "login",
|
||||
"format": "json"
|
||||
}
|
||||
|
||||
async with session.get(self.base_url, params=params) as response:
|
||||
if response.status != 200:
|
||||
print(f"获取登录令牌失败: HTTP {response.status}")
|
||||
return False
|
||||
|
||||
data = await response.json()
|
||||
login_token = data.get("query", {}).get("tokens", {}).get("logintoken")
|
||||
|
||||
if not login_token:
|
||||
print("获取登录令牌失败")
|
||||
return False
|
||||
|
||||
# 使用令牌登录
|
||||
login_params = {
|
||||
"action": "login",
|
||||
"lgname": username,
|
||||
"lgpassword": password,
|
||||
"lgtoken": login_token,
|
||||
"format": "json"
|
||||
}
|
||||
|
||||
async with session.post(self.base_url, data=login_params) as login_response:
|
||||
login_data = await login_response.json()
|
||||
|
||||
if login_data.get("login", {}).get("result") == "Success":
|
||||
print(f"登录成功: {username}")
|
||||
return True
|
||||
else:
|
||||
print(f"登录失败: {login_data.get('login', {}).get('reason', '未知原因')}")
|
||||
return False
|
||||
|
||||
async def setup_oauth(self, consumer_token: str, consumer_secret: str,
|
||||
access_token: str = None, access_secret: str = None) -> bool:
|
||||
"""
|
||||
设置OAuth认证
|
||||
|
||||
Args:
|
||||
consumer_token: 消费者令牌
|
||||
consumer_secret: 消费者密钥
|
||||
access_token: 访问令牌(可选)
|
||||
access_secret: 访问密钥(可选)
|
||||
|
||||
Returns:
|
||||
设置是否成功
|
||||
"""
|
||||
try:
|
||||
# 需要安装 mwoauth 库: pip install mwoauth
|
||||
import mwoauth
|
||||
import requests_oauthlib
|
||||
|
||||
# 设置OAuth
|
||||
self.consumer_token = consumer_token
|
||||
self.consumer_secret = consumer_secret
|
||||
|
||||
if access_token and access_secret:
|
||||
# 如果已有访问令牌
|
||||
self.auth = requests_oauthlib.OAuth1(
|
||||
consumer_token,
|
||||
consumer_secret,
|
||||
access_token,
|
||||
access_secret
|
||||
)
|
||||
print("OAuth设置成功")
|
||||
return True
|
||||
else:
|
||||
# 需要获取访问令牌(这通常需要用户在网页上授权)
|
||||
print("请在开发环境中完成以下OAuth授权流程:")
|
||||
|
||||
# 创建消费者
|
||||
consumer = mwoauth.Consumer(
|
||||
consumer_token, consumer_secret
|
||||
)
|
||||
|
||||
# 初始化握手
|
||||
redirect, request_token = mwoauth.initiate(
|
||||
f"https://{self.language}.wikipedia.org/w/index.php",
|
||||
consumer
|
||||
)
|
||||
|
||||
print(f"请访问此URL授权应用: {redirect}")
|
||||
# 这里通常会提示用户访问URL并输入授权码
|
||||
# 实际应用中需要实现适当的授权流程
|
||||
return False
|
||||
except ImportError:
|
||||
print("请安装 mwoauth 库: pip install mwoauth")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"设置OAuth时发生错误: {str(e)}")
|
||||
return False
|
||||
|
||||
async def example_usage():
|
||||
"""演示WikipediaAPI的使用方法"""
|
||||
# 创建默认中文维基百科API客户端
|
||||
wiki_zh = WikipediaAPI(language="zh")
|
||||
|
||||
try:
|
||||
# 示例1: 基本搜索
|
||||
print("\n=== 示例1: 搜索维基百科 ===")
|
||||
results = await wiki_zh.search("人工智能", limit=3)
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"\n--- 结果 {i} ---")
|
||||
print(f"标题: {result.get('title')}")
|
||||
snippet = result.get('snippet', '')
|
||||
# 清理HTML标签
|
||||
snippet = re.sub(r'<.*?>', '', snippet)
|
||||
print(f"摘要: {snippet}")
|
||||
print(f"字数: {result.get('wordcount')}")
|
||||
print(f"大小: {result.get('size')} 字节")
|
||||
|
||||
# 示例2: 获取页面摘要
|
||||
print("\n=== 示例2: 获取页面摘要 ===")
|
||||
summary = await wiki_zh.get_summary("深度学习", sentences=2)
|
||||
print(f"深度学习摘要: {summary}")
|
||||
|
||||
# 示例3: 获取页面内容
|
||||
print("\n=== 示例3: 获取页面内容 ===")
|
||||
content = await wiki_zh.get_page_content("机器学习")
|
||||
if content and "text" in content:
|
||||
text = content["text"].get("*", "")
|
||||
# 移除HTML标签以便控制台显示
|
||||
clean_text = re.sub(r'<.*?>', '', text)
|
||||
print(f"机器学习页面内容片段: {clean_text[:200]}...")
|
||||
|
||||
# 显示页面包含的分类数量
|
||||
categories = content.get("categories", [])
|
||||
print(f"分类数量: {len(categories)}")
|
||||
|
||||
# 显示页面包含的链接数量
|
||||
links = content.get("links", [])
|
||||
print(f"链接数量: {len(links)}")
|
||||
|
||||
# 示例4: 获取特定章节内容
|
||||
print("\n=== 示例4: 获取特定章节内容 ===")
|
||||
# 获取引言部分(通常是0号章节)
|
||||
intro_content = await wiki_zh.get_page_content("人工智能", section=0)
|
||||
if intro_content and "text" in intro_content:
|
||||
intro_text = intro_content["text"].get("*", "")
|
||||
clean_intro = re.sub(r'<.*?>', '', intro_text)
|
||||
print(f"人工智能引言内容片段: {clean_intro[:200]}...")
|
||||
|
||||
# 示例5: 获取随机文章
|
||||
print("\n=== 示例5: 获取随机文章 ===")
|
||||
random_articles = await wiki_zh.get_random_articles(count=2)
|
||||
print("随机文章:")
|
||||
for i, article in enumerate(random_articles, 1):
|
||||
print(f"{i}. {article.get('title')}")
|
||||
|
||||
# 显示随机文章的简短摘要
|
||||
article_summary = await wiki_zh.get_summary(article.get('title'), sentences=1)
|
||||
print(f" 摘要: {article_summary[:100]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
275
crazy_functions/pdf_fns/breakdown_pdf_txt.py
Normal file
275
crazy_functions/pdf_fns/breakdown_pdf_txt.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from crazy_functions.ipc_fns.mp import run_in_subprocess_with_timeout
|
||||
from loguru import logger
|
||||
import time
|
||||
import re
|
||||
|
||||
def force_breakdown(txt, limit, get_token_fn):
|
||||
""" 当无法用标点、空行分割时,我们用最暴力的方法切割
|
||||
"""
|
||||
for i in reversed(range(len(txt))):
|
||||
if get_token_fn(txt[:i]) < limit:
|
||||
return txt[:i], txt[i:]
|
||||
return "Tiktoken未知错误", "Tiktoken未知错误"
|
||||
|
||||
|
||||
def maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage):
|
||||
""" 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
|
||||
当 remain_txt_to_cut < `_min` 时,我们再把 remain_txt_to_cut_storage 中的部分文字取出
|
||||
"""
|
||||
_min = int(5e4)
|
||||
_max = int(1e5)
|
||||
# print(len(remain_txt_to_cut), len(remain_txt_to_cut_storage))
|
||||
if len(remain_txt_to_cut) < _min and len(remain_txt_to_cut_storage) > 0:
|
||||
remain_txt_to_cut = remain_txt_to_cut + remain_txt_to_cut_storage
|
||||
remain_txt_to_cut_storage = ""
|
||||
if len(remain_txt_to_cut) > _max:
|
||||
remain_txt_to_cut_storage = remain_txt_to_cut[_max:] + remain_txt_to_cut_storage
|
||||
remain_txt_to_cut = remain_txt_to_cut[:_max]
|
||||
return remain_txt_to_cut, remain_txt_to_cut_storage
|
||||
|
||||
|
||||
def cut(limit, get_token_fn, txt_tocut, must_break_at_empty_line, break_anyway=False):
|
||||
""" 文本切分
|
||||
"""
|
||||
res = []
|
||||
total_len = len(txt_tocut)
|
||||
fin_len = 0
|
||||
remain_txt_to_cut = txt_tocut
|
||||
remain_txt_to_cut_storage = ""
|
||||
# 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
|
||||
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
||||
|
||||
while True:
|
||||
if get_token_fn(remain_txt_to_cut) <= limit:
|
||||
# 如果剩余文本的token数小于限制,那么就不用切了
|
||||
res.append(remain_txt_to_cut); fin_len+=len(remain_txt_to_cut)
|
||||
break
|
||||
else:
|
||||
# 如果剩余文本的token数大于限制,那么就切
|
||||
lines = remain_txt_to_cut.split('\n')
|
||||
|
||||
# 估计一个切分点
|
||||
estimated_line_cut = limit / get_token_fn(remain_txt_to_cut) * len(lines)
|
||||
estimated_line_cut = int(estimated_line_cut)
|
||||
|
||||
# 开始查找合适切分点的偏移(cnt)
|
||||
cnt = 0
|
||||
for cnt in reversed(range(estimated_line_cut)):
|
||||
if must_break_at_empty_line:
|
||||
# 首先尝试用双空行(\n\n)作为切分点
|
||||
if lines[cnt] != "":
|
||||
continue
|
||||
prev = "\n".join(lines[:cnt])
|
||||
post = "\n".join(lines[cnt:])
|
||||
if get_token_fn(prev) < limit:
|
||||
break
|
||||
|
||||
if cnt == 0:
|
||||
# 如果没有找到合适的切分点
|
||||
if break_anyway:
|
||||
# 是否允许暴力切分
|
||||
prev, post = force_breakdown(remain_txt_to_cut, limit, get_token_fn)
|
||||
else:
|
||||
# 不允许直接报错
|
||||
raise RuntimeError(f"存在一行极长的文本!{remain_txt_to_cut}")
|
||||
|
||||
# 追加列表
|
||||
res.append(prev); fin_len+=len(prev)
|
||||
# 准备下一次迭代
|
||||
remain_txt_to_cut = post
|
||||
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
||||
process = fin_len/total_len
|
||||
logger.info(f'正在文本切分 {int(process*100)}%')
|
||||
if len(remain_txt_to_cut.strip()) == 0:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
def breakdown_text_to_satisfy_token_limit_(txt, limit, llm_model="gpt-3.5-turbo"):
|
||||
""" 使用多种方式尝试切分文本,以满足 token 限制
|
||||
"""
|
||||
from request_llms.bridge_all import model_info
|
||||
enc = model_info[llm_model]['tokenizer']
|
||||
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
|
||||
try:
|
||||
# 第1次尝试,将双空行(\n\n)作为切分点
|
||||
return cut(limit, get_token_fn, txt, must_break_at_empty_line=True)
|
||||
except RuntimeError:
|
||||
try:
|
||||
# 第2次尝试,将单空行(\n)作为切分点
|
||||
return cut(limit, get_token_fn, txt, must_break_at_empty_line=False)
|
||||
except RuntimeError:
|
||||
try:
|
||||
# 第3次尝试,将英文句号(.)作为切分点
|
||||
res = cut(limit, get_token_fn, txt.replace('.', '。\n'), must_break_at_empty_line=False) # 这个中文的句号是故意的,作为一个标识而存在
|
||||
return [r.replace('。\n', '.') for r in res]
|
||||
except RuntimeError as e:
|
||||
try:
|
||||
# 第4次尝试,将中文句号(。)作为切分点
|
||||
res = cut(limit, get_token_fn, txt.replace('。', '。。\n'), must_break_at_empty_line=False)
|
||||
return [r.replace('。。\n', '。') for r in res]
|
||||
except RuntimeError as e:
|
||||
# 第5次尝试,没办法了,随便切一下吧
|
||||
return cut(limit, get_token_fn, txt, must_break_at_empty_line=False, break_anyway=True)
|
||||
|
||||
breakdown_text_to_satisfy_token_limit = run_in_subprocess_with_timeout(breakdown_text_to_satisfy_token_limit_, timeout=60)
|
||||
|
||||
def cut_new(limit, get_token_fn, txt_tocut, must_break_at_empty_line, must_break_at_one_empty_line=False, break_anyway=False):
|
||||
""" 文本切分
|
||||
"""
|
||||
res = []
|
||||
res_empty_line = []
|
||||
total_len = len(txt_tocut)
|
||||
fin_len = 0
|
||||
remain_txt_to_cut = txt_tocut
|
||||
remain_txt_to_cut_storage = ""
|
||||
# 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
|
||||
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
||||
empty=0
|
||||
|
||||
while True:
|
||||
if get_token_fn(remain_txt_to_cut) <= limit:
|
||||
# 如果剩余文本的token数小于限制,那么就不用切了
|
||||
res.append(remain_txt_to_cut); fin_len+=len(remain_txt_to_cut)
|
||||
res_empty_line.append(empty)
|
||||
break
|
||||
else:
|
||||
# 如果剩余文本的token数大于限制,那么就切
|
||||
lines = remain_txt_to_cut.split('\n')
|
||||
|
||||
# 估计一个切分点
|
||||
estimated_line_cut = limit / get_token_fn(remain_txt_to_cut) * len(lines)
|
||||
estimated_line_cut = int(estimated_line_cut)
|
||||
|
||||
# 开始查找合适切分点的偏移(cnt)
|
||||
cnt = 0
|
||||
for cnt in reversed(range(estimated_line_cut)):
|
||||
if must_break_at_empty_line:
|
||||
# 首先尝试用双空行(\n\n)作为切分点
|
||||
if lines[cnt] != "":
|
||||
continue
|
||||
if must_break_at_empty_line or must_break_at_one_empty_line:
|
||||
empty=1
|
||||
prev = "\n".join(lines[:cnt])
|
||||
post = "\n".join(lines[cnt:])
|
||||
if get_token_fn(prev) < limit :
|
||||
break
|
||||
# empty=0
|
||||
if get_token_fn(prev)>limit:
|
||||
if '.' not in prev or '。' not in prev:
|
||||
# empty = 0
|
||||
break
|
||||
|
||||
# if cnt
|
||||
if cnt == 0:
|
||||
# 如果没有找到合适的切分点
|
||||
if break_anyway:
|
||||
# 是否允许暴力切分
|
||||
prev, post = force_breakdown(remain_txt_to_cut, limit, get_token_fn)
|
||||
empty =0
|
||||
else:
|
||||
# 不允许直接报错
|
||||
raise RuntimeError(f"存在一行极长的文本!{remain_txt_to_cut}")
|
||||
|
||||
# 追加列表
|
||||
res.append(prev); fin_len+=len(prev)
|
||||
res_empty_line.append(empty)
|
||||
# 准备下一次迭代
|
||||
remain_txt_to_cut = post
|
||||
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
||||
process = fin_len/total_len
|
||||
logger.info(f'正在文本切分 {int(process*100)}%')
|
||||
if len(remain_txt_to_cut.strip()) == 0:
|
||||
break
|
||||
return res,res_empty_line
|
||||
|
||||
|
||||
def breakdown_text_to_satisfy_token_limit_new_(txt, limit, llm_model="gpt-3.5-turbo"):
|
||||
""" 使用多种方式尝试切分文本,以满足 token 限制
|
||||
"""
|
||||
from request_llms.bridge_all import model_info
|
||||
enc = model_info[llm_model]['tokenizer']
|
||||
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
|
||||
try:
|
||||
# 第1次尝试,将双空行(\n\n)作为切分点
|
||||
res, empty_line =cut_new(limit, get_token_fn, txt, must_break_at_empty_line=True)
|
||||
return res,empty_line
|
||||
except RuntimeError:
|
||||
try:
|
||||
# 第2次尝试,将单空行(\n)作为切分点
|
||||
res, _ = cut_new(limit, get_token_fn, txt, must_break_at_empty_line=False,must_break_at_one_empty_line=True)
|
||||
return res, _
|
||||
except RuntimeError:
|
||||
try:
|
||||
# 第3次尝试,将英文句号(.)作为切分点
|
||||
res, _ = cut_new(limit, get_token_fn, txt.replace('.', '。\n'), must_break_at_empty_line=False) # 这个中文的句号是故意的,作为一个标识而存在
|
||||
return [r.replace('。\n', '.') for r in res],_
|
||||
|
||||
except RuntimeError as e:
|
||||
try:
|
||||
# 第4次尝试,将中文句号(。)作为切分点
|
||||
res,_ = cut_new(limit, get_token_fn, txt.replace('。', '。。\n'), must_break_at_empty_line=False)
|
||||
return [r.replace('。。\n', '。') for r in res], _
|
||||
except RuntimeError as e:
|
||||
# 第5次尝试,没办法了,随便切一下吧
|
||||
res, _ = cut_new(limit, get_token_fn, txt, must_break_at_empty_line=False, break_anyway=True)
|
||||
return res,_
|
||||
breakdown_text_to_satisfy_token_limit_new = run_in_subprocess_with_timeout(breakdown_text_to_satisfy_token_limit_new_, timeout=60)
|
||||
|
||||
def cut_from_end_to_satisfy_token_limit_(txt, limit, reserve_token=500, llm_model="gpt-3.5-turbo"):
|
||||
"""从后往前裁剪文本,以论文为单位进行裁剪
|
||||
|
||||
参数:
|
||||
txt: 要处理的文本(格式化后的论文列表字符串)
|
||||
limit: token数量上限
|
||||
reserve_token: 需要预留的token数量,默认500
|
||||
llm_model: 使用的模型名称
|
||||
返回:
|
||||
裁剪后的文本
|
||||
"""
|
||||
from request_llms.bridge_all import model_info
|
||||
enc = model_info[llm_model]['tokenizer']
|
||||
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
|
||||
|
||||
# 计算当前文本的token数
|
||||
current_tokens = get_token_fn(txt)
|
||||
target_limit = limit - reserve_token
|
||||
|
||||
# 如果当前token数已经在限制范围内,直接返回
|
||||
if current_tokens <= target_limit:
|
||||
return txt
|
||||
|
||||
# 按论文编号分割文本
|
||||
papers = re.split(r'\n(?=\d+\. \*\*)', txt)
|
||||
if not papers:
|
||||
return txt
|
||||
|
||||
# 从前往后累加论文,直到达到token限制
|
||||
result = papers[0] # 保留第一篇
|
||||
current_tokens = get_token_fn(result)
|
||||
|
||||
for paper in papers[1:]:
|
||||
paper_tokens = get_token_fn(paper)
|
||||
if current_tokens + paper_tokens <= target_limit:
|
||||
result += "\n" + paper
|
||||
current_tokens += paper_tokens
|
||||
else:
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
# 添加超时保护
|
||||
cut_from_end_to_satisfy_token_limit = run_in_subprocess_with_timeout(cut_from_end_to_satisfy_token_limit_, timeout=20)
|
||||
|
||||
if __name__ == '__main__':
|
||||
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
||||
file_content, page_one = read_and_clean_pdf_text("build/assets/at.pdf")
|
||||
|
||||
from request_llms.bridge_all import model_info
|
||||
for i in range(5):
|
||||
file_content += file_content
|
||||
|
||||
logger.info(len(file_content))
|
||||
TOKEN_LIMIT_PER_FRAGMENT = 2500
|
||||
res = breakdown_text_to_satisfy_token_limit(file_content, TOKEN_LIMIT_PER_FRAGMENT)
|
||||
|
||||
@@ -1,22 +1,48 @@
|
||||
import subprocess
|
||||
import os
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
supports_format = ['.csv', '.docx', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
|
||||
'.pptm', '.pptx']
|
||||
supports_format = ['.csv', '.docx', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt', '.pptm', '.pptx', '.bat']
|
||||
|
||||
def convert_to_markdown(file_path: str) -> str:
|
||||
"""
|
||||
将支持的文件格式转换为Markdown格式
|
||||
Args:
|
||||
file_path: 输入文件路径
|
||||
Returns:
|
||||
str: 转换后的Markdown文件路径,如果转换失败则返回原始文件路径
|
||||
"""
|
||||
_, ext = os.path.splitext(file_path.lower())
|
||||
|
||||
if ext in ['.docx', '.doc', '.pptx', '.ppt', '.pptm', '.xls', '.xlsx', '.csv', 'pdf']:
|
||||
try:
|
||||
# 创建输出Markdown文件路径
|
||||
md_path = os.path.splitext(file_path)[0] + '.md'
|
||||
# 使用markitdown工具将文件转换为Markdown
|
||||
command = f"markitdown {file_path} > {md_path}"
|
||||
subprocess.run(command, shell=True, check=True)
|
||||
print(f"已将{ext}文件转换为Markdown: {md_path}")
|
||||
return md_path
|
||||
except Exception as e:
|
||||
print(f"{ext}转Markdown失败: {str(e)},将继续处理原文件")
|
||||
return file_path
|
||||
|
||||
return file_path
|
||||
|
||||
# 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑
|
||||
def extract_text(file_path):
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
_, ext = os.path.splitext(file_path.lower())
|
||||
|
||||
# 使用 SimpleDirectoryReader 处理它支持的文件格式
|
||||
if ext in supports_format:
|
||||
try:
|
||||
reader = SimpleDirectoryReader(input_files=[file_path])
|
||||
print(f"Extracting text from {file_path} using SimpleDirectoryReader")
|
||||
documents = reader.load_data()
|
||||
if len(documents) > 0:
|
||||
return documents[0].text
|
||||
print(f"Complete: Extracting text from {file_path} using SimpleDirectoryReader")
|
||||
buffer = [ doc.text for doc in documents ]
|
||||
return '\n'.join(buffer)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
return None
|
||||
else:
|
||||
return '格式不支持'
|
||||
|
||||
0
crazy_functions/review_fns/__init__.py
Normal file
0
crazy_functions/review_fns/__init__.py
Normal file
68
crazy_functions/review_fns/conversation_doc/endnote_doc.py
Normal file
68
crazy_functions/review_fns/conversation_doc/endnote_doc.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import List
|
||||
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
|
||||
|
||||
class EndNoteFormatter:
|
||||
"""EndNote参考文献格式生成器"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def create_document(self, papers: List[PaperMetadata]) -> str:
|
||||
"""生成EndNote格式的参考文献文本
|
||||
|
||||
Args:
|
||||
papers: 论文列表
|
||||
|
||||
Returns:
|
||||
str: EndNote格式的参考文献文本
|
||||
"""
|
||||
endnote_text = ""
|
||||
|
||||
for paper in papers:
|
||||
# 开始一个新条目
|
||||
endnote_text += "%0 Journal Article\n" # 默认类型为期刊文章
|
||||
|
||||
# 根据venue_type调整条目类型
|
||||
if hasattr(paper, 'venue_type') and paper.venue_type:
|
||||
if paper.venue_type.lower() == 'conference':
|
||||
endnote_text = endnote_text.replace("Journal Article", "Conference Paper")
|
||||
elif paper.venue_type.lower() == 'preprint':
|
||||
endnote_text = endnote_text.replace("Journal Article", "Electronic Article")
|
||||
|
||||
# 添加标题
|
||||
endnote_text += f"%T {paper.title}\n"
|
||||
|
||||
# 添加作者
|
||||
for author in paper.authors:
|
||||
endnote_text += f"%A {author}\n"
|
||||
|
||||
# 添加年份
|
||||
if paper.year:
|
||||
endnote_text += f"%D {paper.year}\n"
|
||||
|
||||
# 添加期刊/会议名称
|
||||
if hasattr(paper, 'venue_name') and paper.venue_name:
|
||||
endnote_text += f"%J {paper.venue_name}\n"
|
||||
elif paper.venue:
|
||||
endnote_text += f"%J {paper.venue}\n"
|
||||
|
||||
# 添加DOI
|
||||
if paper.doi:
|
||||
endnote_text += f"%R {paper.doi}\n"
|
||||
endnote_text += f"%U https://doi.org/{paper.doi}\n"
|
||||
elif paper.url:
|
||||
endnote_text += f"%U {paper.url}\n"
|
||||
|
||||
# 添加摘要
|
||||
if paper.abstract:
|
||||
endnote_text += f"%X {paper.abstract}\n"
|
||||
|
||||
# 添加机构
|
||||
if hasattr(paper, 'institutions'):
|
||||
for institution in paper.institutions:
|
||||
endnote_text += f"%I {institution}\n"
|
||||
|
||||
# 条目之间添加空行
|
||||
endnote_text += "\n"
|
||||
|
||||
return endnote_text
|
||||
211
crazy_functions/review_fns/conversation_doc/excel_doc.py
Normal file
211
crazy_functions/review_fns/conversation_doc/excel_doc.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import re
|
||||
import os
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ExcelTableFormatter:
|
||||
"""聊天记录中Markdown表格转Excel生成器"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化Excel文档对象"""
|
||||
from openpyxl import Workbook
|
||||
self.workbook = Workbook()
|
||||
self._table_count = 0
|
||||
self._current_sheet = None
|
||||
|
||||
def _normalize_table_row(self, row):
|
||||
"""标准化表格行,处理不同的分隔符情况"""
|
||||
row = row.strip()
|
||||
if row.startswith('|'):
|
||||
row = row[1:]
|
||||
if row.endswith('|'):
|
||||
row = row[:-1]
|
||||
return [cell.strip() for cell in row.split('|')]
|
||||
|
||||
def _is_separator_row(self, row):
|
||||
"""检查是否是分隔行(由 - 或 : 组成)"""
|
||||
clean_row = re.sub(r'[\s|]', '', row)
|
||||
return bool(re.match(r'^[-:]+$', clean_row))
|
||||
|
||||
def _extract_tables_from_text(self, text):
|
||||
"""从文本中提取所有表格内容"""
|
||||
if not isinstance(text, str):
|
||||
return []
|
||||
|
||||
tables = []
|
||||
current_table = []
|
||||
is_in_table = False
|
||||
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
if is_in_table and current_table:
|
||||
if len(current_table) >= 2:
|
||||
tables.append(current_table)
|
||||
current_table = []
|
||||
is_in_table = False
|
||||
continue
|
||||
|
||||
if '|' in line:
|
||||
if not is_in_table:
|
||||
is_in_table = True
|
||||
current_table.append(line)
|
||||
else:
|
||||
if is_in_table and current_table:
|
||||
if len(current_table) >= 2:
|
||||
tables.append(current_table)
|
||||
current_table = []
|
||||
is_in_table = False
|
||||
|
||||
if is_in_table and current_table and len(current_table) >= 2:
|
||||
tables.append(current_table)
|
||||
|
||||
return tables
|
||||
|
||||
def _parse_table(self, table_lines):
|
||||
"""解析表格内容为结构化数据"""
|
||||
try:
|
||||
headers = self._normalize_table_row(table_lines[0])
|
||||
|
||||
separator_index = next(
|
||||
(i for i, line in enumerate(table_lines) if self._is_separator_row(line)),
|
||||
1
|
||||
)
|
||||
|
||||
data_rows = []
|
||||
for line in table_lines[separator_index + 1:]:
|
||||
cells = self._normalize_table_row(line)
|
||||
# 确保单元格数量与表头一致
|
||||
while len(cells) < len(headers):
|
||||
cells.append('')
|
||||
cells = cells[:len(headers)]
|
||||
data_rows.append(cells)
|
||||
|
||||
if headers and data_rows:
|
||||
return {
|
||||
'headers': headers,
|
||||
'data': data_rows
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"解析表格时发生错误: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
def _create_sheet(self, question_num, table_num):
|
||||
"""创建新的工作表"""
|
||||
sheet_name = f'Q{question_num}_T{table_num}'
|
||||
if len(sheet_name) > 31:
|
||||
sheet_name = f'Table{self._table_count}'
|
||||
|
||||
if sheet_name in self.workbook.sheetnames:
|
||||
sheet_name = f'{sheet_name}_{datetime.now().strftime("%H%M%S")}'
|
||||
|
||||
return self.workbook.create_sheet(title=sheet_name)
|
||||
|
||||
def create_document(self, history):
|
||||
"""
|
||||
处理聊天历史中的所有表格并创建Excel文档
|
||||
|
||||
Args:
|
||||
history: 聊天历史列表
|
||||
|
||||
Returns:
|
||||
Workbook: 处理完成的Excel工作簿对象,如果没有表格则返回None
|
||||
"""
|
||||
has_tables = False
|
||||
|
||||
# 删除默认创建的工作表
|
||||
default_sheet = self.workbook['Sheet']
|
||||
self.workbook.remove(default_sheet)
|
||||
|
||||
# 遍历所有回答
|
||||
for i in range(1, len(history), 2):
|
||||
answer = history[i]
|
||||
tables = self._extract_tables_from_text(answer)
|
||||
|
||||
for table_lines in tables:
|
||||
parsed_table = self._parse_table(table_lines)
|
||||
if parsed_table:
|
||||
self._table_count += 1
|
||||
sheet = self._create_sheet(i // 2 + 1, self._table_count)
|
||||
|
||||
# 写入表头
|
||||
for col, header in enumerate(parsed_table['headers'], 1):
|
||||
sheet.cell(row=1, column=col, value=header)
|
||||
|
||||
# 写入数据
|
||||
for row_idx, row_data in enumerate(parsed_table['data'], 2):
|
||||
for col_idx, value in enumerate(row_data, 1):
|
||||
sheet.cell(row=row_idx, column=col_idx, value=value)
|
||||
|
||||
has_tables = True
|
||||
|
||||
return self.workbook if has_tables else None
|
||||
|
||||
|
||||
def save_chat_tables(history, save_dir, base_name):
|
||||
"""
|
||||
保存聊天历史中的表格到Excel文件
|
||||
|
||||
Args:
|
||||
history: 聊天历史列表
|
||||
save_dir: 保存目录
|
||||
base_name: 基础文件名
|
||||
|
||||
Returns:
|
||||
list: 保存的文件路径列表
|
||||
"""
|
||||
result_files = []
|
||||
|
||||
try:
|
||||
# 创建Excel格式
|
||||
excel_formatter = ExcelTableFormatter()
|
||||
workbook = excel_formatter.create_document(history)
|
||||
|
||||
if workbook is not None:
|
||||
# 确保保存目录存在
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 生成Excel文件路径
|
||||
excel_file = os.path.join(save_dir, base_name + '.xlsx')
|
||||
|
||||
# 保存Excel文件
|
||||
workbook.save(excel_file)
|
||||
result_files.append(excel_file)
|
||||
print(f"已保存表格到Excel文件: {excel_file}")
|
||||
except Exception as e:
|
||||
print(f"保存Excel格式失败: {str(e)}")
|
||||
|
||||
return result_files
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 示例聊天历史
|
||||
history = [
|
||||
"问题1",
|
||||
"""这是第一个表格:
|
||||
| A | B | C |
|
||||
|---|---|---|
|
||||
| 1 | 2 | 3 |""",
|
||||
|
||||
"问题2",
|
||||
"这是没有表格的回答",
|
||||
|
||||
"问题3",
|
||||
"""回答包含多个表格:
|
||||
| Name | Age |
|
||||
|------|-----|
|
||||
| Tom | 20 |
|
||||
|
||||
第二个表格:
|
||||
| X | Y |
|
||||
|---|---|
|
||||
| 1 | 2 |"""
|
||||
]
|
||||
|
||||
# 保存表格
|
||||
save_dir = "output"
|
||||
base_name = "chat_tables"
|
||||
saved_files = save_chat_tables(history, save_dir, base_name)
|
||||
472
crazy_functions/review_fns/conversation_doc/html_doc.py
Normal file
472
crazy_functions/review_fns/conversation_doc/html_doc.py
Normal file
@@ -0,0 +1,472 @@
|
||||
class HtmlFormatter:
|
||||
"""聊天记录HTML格式生成器"""
|
||||
|
||||
def __init__(self):
|
||||
self.css_styles = """
|
||||
:root {
|
||||
--primary-color: #2563eb;
|
||||
--primary-light: #eff6ff;
|
||||
--secondary-color: #1e293b;
|
||||
--background-color: #f8fafc;
|
||||
--text-color: #334155;
|
||||
--border-color: #e2e8f0;
|
||||
--card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
line-height: 1.8;
|
||||
margin: 0;
|
||||
padding: 2rem;
|
||||
color: var(--text-color);
|
||||
background-color: var(--background-color);
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
background: white;
|
||||
padding: 2rem;
|
||||
border-radius: 16px;
|
||||
box-shadow: var(--card-shadow);
|
||||
}
|
||||
::selection {
|
||||
background: var(--primary-light);
|
||||
color: var(--primary-color);
|
||||
}
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(20px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
@keyframes slideIn {
|
||||
from { transform: translateX(-20px); opacity: 0; }
|
||||
to { transform: translateX(0); opacity: 1; }
|
||||
}
|
||||
|
||||
.container {
|
||||
animation: fadeIn 0.6s ease-out;
|
||||
}
|
||||
|
||||
.QaBox {
|
||||
animation: slideIn 0.5s ease-out;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.QaBox:hover {
|
||||
transform: translateX(5px);
|
||||
}
|
||||
.Question, .Answer, .historyBox {
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
.chat-title {
|
||||
color: var(--primary-color);
|
||||
font-size: 2em;
|
||||
text-align: center;
|
||||
margin: 1rem 0 2rem;
|
||||
padding-bottom: 1rem;
|
||||
border-bottom: 2px solid var(--primary-color);
|
||||
}
|
||||
|
||||
.chat-body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
margin: 2rem 0;
|
||||
}
|
||||
|
||||
.QaBox {
|
||||
background: white;
|
||||
padding: 1.5rem;
|
||||
border-radius: 8px;
|
||||
border-left: 4px solid var(--primary-color);
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.Question {
|
||||
color: var(--secondary-color);
|
||||
font-weight: 500;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.Answer {
|
||||
color: var(--text-color);
|
||||
background: var(--primary-light);
|
||||
padding: 1rem;
|
||||
border-radius: 6px;
|
||||
}
|
||||
|
||||
.history-section {
|
||||
margin-top: 3rem;
|
||||
padding-top: 2rem;
|
||||
border-top: 2px solid var(--border-color);
|
||||
}
|
||||
|
||||
.history-title {
|
||||
color: var(--secondary-color);
|
||||
font-size: 1.5em;
|
||||
margin-bottom: 1.5rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.historyBox {
|
||||
background: white;
|
||||
padding: 1rem;
|
||||
margin: 0.5rem 0;
|
||||
border-radius: 6px;
|
||||
border: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
:root {
|
||||
--background-color: #0f172a;
|
||||
--text-color: #e2e8f0;
|
||||
--border-color: #1e293b;
|
||||
}
|
||||
|
||||
.container, .QaBox {
|
||||
background: #1e293b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def create_document(self, question: str, answer: str, ranked_papers: list = None) -> str:
|
||||
"""生成完整的HTML文档
|
||||
Args:
|
||||
question: str, 用户问题
|
||||
answer: str, AI回答
|
||||
ranked_papers: list, 排序后的论文列表
|
||||
Returns:
|
||||
str: 完整的HTML文档字符串
|
||||
"""
|
||||
chat_content = f'''
|
||||
<div class="QaBox">
|
||||
<div class="Question">{question}</div>
|
||||
<div class="Answer markdown-body" id="answer-content">{answer}</div>
|
||||
</div>
|
||||
'''
|
||||
|
||||
references_content = ""
|
||||
if ranked_papers:
|
||||
references_content = '<div class="history-section"><h2 class="history-title">参考文献</h2>'
|
||||
for idx, paper in enumerate(ranked_papers, 1):
|
||||
authors = ', '.join(paper.authors)
|
||||
|
||||
# 构建引用信息
|
||||
citations_info = f"被引用次数:{paper.citations}" if paper.citations is not None else "引用信息未知"
|
||||
|
||||
# 构建下载链接
|
||||
download_links = []
|
||||
if paper.doi:
|
||||
# 检查是否是arXiv链接
|
||||
if 'arxiv.org' in paper.doi:
|
||||
# 如果DOI中包含完整的arXiv URL,直接使用
|
||||
arxiv_url = paper.doi if paper.doi.startswith('http') else f'http://{paper.doi}'
|
||||
download_links.append(f'<a href="{arxiv_url}">arXiv链接</a>')
|
||||
# 提取arXiv ID并添加PDF链接
|
||||
arxiv_id = arxiv_url.split('abs/')[-1].split('v')[0]
|
||||
download_links.append(f'<a href="https://arxiv.org/pdf/{arxiv_id}.pdf">PDF下载</a>')
|
||||
else:
|
||||
# 非arXiv的DOI使用标准格式
|
||||
download_links.append(f'<a href="https://doi.org/{paper.doi}">DOI: {paper.doi}</a>')
|
||||
|
||||
if hasattr(paper, 'url') and paper.url and 'arxiv.org' not in str(paper.url):
|
||||
# 只有当URL不是arXiv链接时才添加
|
||||
download_links.append(f'<a href="{paper.url}">原文链接</a>')
|
||||
download_section = ' | '.join(download_links) if download_links else "无直接下载链接"
|
||||
|
||||
# 构建来源信息
|
||||
source_info = []
|
||||
if paper.venue_type:
|
||||
source_info.append(f"类型:{paper.venue_type}")
|
||||
if paper.venue_name:
|
||||
source_info.append(f"来源:{paper.venue_name}")
|
||||
|
||||
# 添加期刊指标信息
|
||||
if hasattr(paper, 'if_factor') and paper.if_factor:
|
||||
source_info.append(f"<span class='journal-metric'>IF: {paper.if_factor}</span>")
|
||||
if hasattr(paper, 'jcr_division') and paper.jcr_division:
|
||||
source_info.append(f"<span class='journal-metric'>JCR分区: {paper.jcr_division}</span>")
|
||||
if hasattr(paper, 'cas_division') and paper.cas_division:
|
||||
source_info.append(f"<span class='journal-metric'>中科院分区: {paper.cas_division}</span>")
|
||||
|
||||
if hasattr(paper, 'venue_info') and paper.venue_info:
|
||||
if paper.venue_info.get('journal_ref'):
|
||||
source_info.append(f"期刊参考:{paper.venue_info['journal_ref']}")
|
||||
if paper.venue_info.get('publisher'):
|
||||
source_info.append(f"出版商:{paper.venue_info['publisher']}")
|
||||
source_section = ' | '.join(source_info) if source_info else ""
|
||||
|
||||
# 构建标准引用格式
|
||||
standard_citation = f"[{idx}] "
|
||||
# 添加作者(最多3个,超过则添加et al.)
|
||||
author_list = paper.authors[:3]
|
||||
if len(paper.authors) > 3:
|
||||
author_list.append("et al.")
|
||||
standard_citation += ", ".join(author_list) + ". "
|
||||
# 添加标题
|
||||
standard_citation += f"<i>{paper.title}</i>"
|
||||
# 添加期刊/会议名称
|
||||
if paper.venue_name:
|
||||
standard_citation += f". {paper.venue_name}"
|
||||
# 添加年份
|
||||
if paper.year:
|
||||
standard_citation += f", {paper.year}"
|
||||
# 添加DOI
|
||||
if paper.doi:
|
||||
if 'arxiv.org' in paper.doi:
|
||||
# 如果是arXiv链接,直接使用arXiv URL
|
||||
arxiv_url = paper.doi if paper.doi.startswith('http') else f'http://{paper.doi}'
|
||||
standard_citation += f". {arxiv_url}"
|
||||
else:
|
||||
# 非arXiv的DOI使用标准格式
|
||||
standard_citation += f". DOI: {paper.doi}"
|
||||
standard_citation += "."
|
||||
|
||||
references_content += f'''
|
||||
<div class="historyBox">
|
||||
<div class="entry">
|
||||
<p class="paper-title"><b>[{idx}]</b> <i>{paper.title}</i></p>
|
||||
<p class="paper-authors">作者:{authors}</p>
|
||||
<p class="paper-year">发表年份:{paper.year if paper.year else "未知"}</p>
|
||||
<p class="paper-citations">{citations_info}</p>
|
||||
{f'<p class="paper-source">{source_section}</p>' if source_section else ""}
|
||||
<p class="paper-abstract">摘要:{paper.abstract if paper.abstract else "无摘要"}</p>
|
||||
<p class="paper-links">链接:{download_section}</p>
|
||||
<div class="standard-citation">
|
||||
<p class="citation-title">标准引用格式:</p>
|
||||
<p class="citation-text">{standard_citation}</p>
|
||||
<button class="copy-btn" onclick="copyToClipboard(this.previousElementSibling)">复制引用格式</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
'''
|
||||
references_content += '</div>'
|
||||
|
||||
# 添加新的CSS样式
|
||||
css_additions = """
|
||||
.paper-title {
|
||||
font-size: 1.1em;
|
||||
margin-bottom: 0.5em;
|
||||
}
|
||||
.paper-authors {
|
||||
color: var(--secondary-color);
|
||||
margin: 0.3em 0;
|
||||
}
|
||||
.paper-year, .paper-citations {
|
||||
color: var(--text-color);
|
||||
margin: 0.3em 0;
|
||||
}
|
||||
.paper-source {
|
||||
color: var(--text-color);
|
||||
font-style: italic;
|
||||
margin: 0.3em 0;
|
||||
}
|
||||
.paper-abstract {
|
||||
margin: 0.8em 0;
|
||||
padding: 0.8em;
|
||||
background: var(--primary-light);
|
||||
border-radius: 4px;
|
||||
}
|
||||
.paper-links {
|
||||
margin-top: 0.5em;
|
||||
}
|
||||
.paper-links a {
|
||||
color: var(--primary-color);
|
||||
text-decoration: none;
|
||||
margin-right: 1em;
|
||||
}
|
||||
.paper-links a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.standard-citation {
|
||||
margin-top: 1em;
|
||||
padding: 1em;
|
||||
background: #f8fafc;
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
.citation-title {
|
||||
font-weight: bold;
|
||||
margin-bottom: 0.5em;
|
||||
color: var(--secondary-color);
|
||||
}
|
||||
|
||||
.citation-text {
|
||||
font-family: 'Times New Roman', Times, serif;
|
||||
line-height: 1.6;
|
||||
margin-bottom: 0.5em;
|
||||
padding: 0.5em;
|
||||
background: white;
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
.copy-btn {
|
||||
background: var(--primary-color);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5em 1em;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 0.9em;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.copy-btn:hover {
|
||||
background: #1e40af;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
.standard-citation {
|
||||
background: #1e293b;
|
||||
}
|
||||
.citation-text {
|
||||
background: #0f172a;
|
||||
}
|
||||
}
|
||||
|
||||
/* 添加期刊指标样式 */
|
||||
.journal-metric {
|
||||
display: inline-block;
|
||||
padding: 0.2em 0.6em;
|
||||
margin: 0 0.3em;
|
||||
background: var(--primary-light);
|
||||
border-radius: 4px;
|
||||
font-weight: 500;
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
.journal-metric {
|
||||
background: #1e293b;
|
||||
color: #60a5fa;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# 修改 js_code 部分,添加 markdown 解析功能
|
||||
js_code = """
|
||||
<script>
|
||||
// 复制功能
|
||||
function copyToClipboard(element) {
|
||||
const text = element.innerText;
|
||||
navigator.clipboard.writeText(text).then(function() {
|
||||
const btn = element.nextElementSibling;
|
||||
const originalText = btn.innerText;
|
||||
btn.innerText = '已复制!';
|
||||
setTimeout(() => {
|
||||
btn.innerText = originalText;
|
||||
}, 2000);
|
||||
}).catch(function(err) {
|
||||
console.error('复制失败:', err);
|
||||
});
|
||||
}
|
||||
|
||||
// Markdown解析
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
const answerContent = document.getElementById('answer-content');
|
||||
if (answerContent) {
|
||||
const markdown = answerContent.textContent;
|
||||
answerContent.innerHTML = marked.parse(markdown);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
"""
|
||||
|
||||
# 将新的CSS样式添加到现有样式中
|
||||
self.css_styles += css_additions
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>学术对话存档</title>
|
||||
<!-- 添加 marked.js -->
|
||||
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
||||
<!-- 添加 GitHub Markdown CSS -->
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/gh/sindresorhus/github-markdown-css@4.0.0/github-markdown.min.css">
|
||||
<style>
|
||||
{self.css_styles}
|
||||
/* 添加 Markdown 相关样式 */
|
||||
.markdown-body {{
|
||||
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
padding: 1rem;
|
||||
background: var(--primary-light);
|
||||
border-radius: 6px;
|
||||
}}
|
||||
.markdown-body pre {{
|
||||
background-color: #f6f8fa;
|
||||
border-radius: 6px;
|
||||
padding: 16px;
|
||||
overflow: auto;
|
||||
}}
|
||||
.markdown-body code {{
|
||||
background-color: rgba(175,184,193,0.2);
|
||||
border-radius: 6px;
|
||||
padding: 0.2em 0.4em;
|
||||
font-size: 85%;
|
||||
}}
|
||||
.markdown-body pre code {{
|
||||
background-color: transparent;
|
||||
padding: 0;
|
||||
}}
|
||||
.markdown-body blockquote {{
|
||||
border-left: 0.25em solid #d0d7de;
|
||||
padding: 0 1em;
|
||||
color: #656d76;
|
||||
}}
|
||||
.markdown-body table {{
|
||||
border-collapse: collapse;
|
||||
width: 100%;
|
||||
margin: 1em 0;
|
||||
}}
|
||||
.markdown-body table th,
|
||||
.markdown-body table td {{
|
||||
border: 1px solid #d0d7de;
|
||||
padding: 6px 13px;
|
||||
}}
|
||||
.markdown-body table tr:nth-child(2n) {{
|
||||
background-color: #f6f8fa;
|
||||
}}
|
||||
@media (prefers-color-scheme: dark) {{
|
||||
.markdown-body {{
|
||||
background: #1e293b;
|
||||
color: #e2e8f0;
|
||||
}}
|
||||
.markdown-body pre {{
|
||||
background-color: #0f172a;
|
||||
}}
|
||||
.markdown-body code {{
|
||||
background-color: rgba(99,110,123,0.4);
|
||||
}}
|
||||
.markdown-body blockquote {{
|
||||
border-left-color: #30363d;
|
||||
color: #8b949e;
|
||||
}}
|
||||
.markdown-body table th,
|
||||
.markdown-body table td {{
|
||||
border-color: #30363d;
|
||||
}}
|
||||
.markdown-body table tr:nth-child(2n) {{
|
||||
background-color: #0f172a;
|
||||
}}
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1 class="chat-title">学术对话存档</h1>
|
||||
<div class="chat-body">
|
||||
{chat_content}
|
||||
{references_content}
|
||||
</div>
|
||||
</div>
|
||||
{js_code}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
47
crazy_functions/review_fns/conversation_doc/markdown_doc.py
Normal file
47
crazy_functions/review_fns/conversation_doc/markdown_doc.py
Normal file
@@ -0,0 +1,47 @@
|
||||
class MarkdownFormatter:
|
||||
"""Markdown格式文档生成器 - 用于生成对话记录的markdown文档"""
|
||||
|
||||
def __init__(self):
|
||||
self.content = []
|
||||
|
||||
def _add_content(self, text: str):
|
||||
"""添加正文内容"""
|
||||
if text:
|
||||
self.content.append(f"\n{text}\n")
|
||||
|
||||
def create_document(self, question: str, answer: str, ranked_papers: list = None) -> str:
|
||||
"""创建完整的Markdown文档
|
||||
Args:
|
||||
question: str, 用户问题
|
||||
answer: str, AI回答
|
||||
ranked_papers: list, 排序后的论文列表
|
||||
Returns:
|
||||
str: 生成的Markdown文本
|
||||
"""
|
||||
content = []
|
||||
|
||||
# 添加问答部分
|
||||
content.append("## 问题")
|
||||
content.append(question)
|
||||
content.append("\n## 回答")
|
||||
content.append(answer)
|
||||
|
||||
# 添加参考文献
|
||||
if ranked_papers:
|
||||
content.append("\n## 参考文献")
|
||||
for idx, paper in enumerate(ranked_papers, 1):
|
||||
authors = ', '.join(paper.authors[:3])
|
||||
if len(paper.authors) > 3:
|
||||
authors += ' et al.'
|
||||
|
||||
ref = f"[{idx}] {authors}. *{paper.title}*"
|
||||
if paper.venue_name:
|
||||
ref += f". {paper.venue_name}"
|
||||
if paper.year:
|
||||
ref += f", {paper.year}"
|
||||
if paper.doi:
|
||||
ref += f". [DOI: {paper.doi}](https://doi.org/{paper.doi})"
|
||||
|
||||
content.append(ref)
|
||||
|
||||
return "\n\n".join(content)
|
||||
@@ -0,0 +1,174 @@
|
||||
from typing import List
|
||||
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
|
||||
import re
|
||||
|
||||
class ReferenceFormatter:
|
||||
"""通用参考文献格式生成器"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _sanitize_bibtex(self, text: str) -> str:
|
||||
"""清理BibTeX字符串,处理特殊字符"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# 替换特殊字符
|
||||
replacements = {
|
||||
'&': '\\&',
|
||||
'%': '\\%',
|
||||
'$': '\\$',
|
||||
'#': '\\#',
|
||||
'_': '\\_',
|
||||
'{': '\\{',
|
||||
'}': '\\}',
|
||||
'~': '\\textasciitilde{}',
|
||||
'^': '\\textasciicircum{}',
|
||||
'\\': '\\textbackslash{}',
|
||||
'<': '\\textless{}',
|
||||
'>': '\\textgreater{}',
|
||||
'"': '``',
|
||||
"'": "'",
|
||||
'-': '--',
|
||||
'—': '---',
|
||||
}
|
||||
|
||||
for char, replacement in replacements.items():
|
||||
text = text.replace(char, replacement)
|
||||
|
||||
return text
|
||||
|
||||
def _generate_cite_key(self, paper: PaperMetadata) -> str:
|
||||
"""生成引用键
|
||||
格式: 第一作者姓氏_年份_第一个实词
|
||||
"""
|
||||
# 获取第一作者姓氏
|
||||
first_author = ""
|
||||
if paper.authors and len(paper.authors) > 0:
|
||||
first_author = paper.authors[0].split()[-1].lower()
|
||||
|
||||
# 获取年份
|
||||
year = str(paper.year) if paper.year else "0000"
|
||||
|
||||
# 从标题中获取第一个实词
|
||||
title_word = ""
|
||||
if paper.title:
|
||||
# 移除特殊字符,分割成单词
|
||||
words = re.findall(r'\w+', paper.title.lower())
|
||||
# 过滤掉常见的停用词
|
||||
stop_words = {'a', 'an', 'the', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
|
||||
for word in words:
|
||||
if word not in stop_words and len(word) > 2:
|
||||
title_word = word
|
||||
break
|
||||
|
||||
# 组合cite key
|
||||
cite_key = f"{first_author}{year}{title_word}"
|
||||
|
||||
# 确保cite key只包含合法字符
|
||||
cite_key = re.sub(r'[^a-z0-9]', '', cite_key.lower())
|
||||
|
||||
return cite_key
|
||||
|
||||
def _get_entry_type(self, paper: PaperMetadata) -> str:
|
||||
"""确定BibTeX条目类型"""
|
||||
if hasattr(paper, 'venue_type') and paper.venue_type:
|
||||
venue_type = paper.venue_type.lower()
|
||||
if venue_type == 'conference':
|
||||
return 'inproceedings'
|
||||
elif venue_type == 'preprint':
|
||||
return 'unpublished'
|
||||
elif venue_type == 'journal':
|
||||
return 'article'
|
||||
elif venue_type == 'book':
|
||||
return 'book'
|
||||
elif venue_type == 'thesis':
|
||||
return 'phdthesis'
|
||||
return 'article' # 默认为期刊文章
|
||||
|
||||
|
||||
def create_document(self, papers: List[PaperMetadata]) -> str:
|
||||
"""生成BibTeX格式的参考文献文本"""
|
||||
bibtex_text = "% This file was automatically generated by GPT-Academic\n"
|
||||
bibtex_text += "% Compatible with: EndNote, Zotero, JabRef, and LaTeX\n\n"
|
||||
|
||||
for paper in papers:
|
||||
entry_type = self._get_entry_type(paper)
|
||||
cite_key = self._generate_cite_key(paper)
|
||||
|
||||
bibtex_text += f"@{entry_type}{{{cite_key},\n"
|
||||
|
||||
# 添加标题
|
||||
if paper.title:
|
||||
bibtex_text += f" title = {{{self._sanitize_bibtex(paper.title)}}},\n"
|
||||
|
||||
# 添加作者
|
||||
if paper.authors:
|
||||
# 确保每个作者的姓和名正确分隔
|
||||
processed_authors = []
|
||||
for author in paper.authors:
|
||||
names = author.split()
|
||||
if len(names) > 1:
|
||||
# 假设最后一个词是姓,其他的是名
|
||||
surname = names[-1]
|
||||
given_names = ' '.join(names[:-1])
|
||||
processed_authors.append(f"{surname}, {given_names}")
|
||||
else:
|
||||
processed_authors.append(author)
|
||||
|
||||
authors = " and ".join([self._sanitize_bibtex(author) for author in processed_authors])
|
||||
bibtex_text += f" author = {{{authors}}},\n"
|
||||
|
||||
# 添加年份
|
||||
if paper.year:
|
||||
bibtex_text += f" year = {{{paper.year}}},\n"
|
||||
|
||||
# 添加期刊/会议名称
|
||||
if hasattr(paper, 'venue_name') and paper.venue_name:
|
||||
if entry_type == 'inproceedings':
|
||||
bibtex_text += f" booktitle = {{{self._sanitize_bibtex(paper.venue_name)}}},\n"
|
||||
elif entry_type == 'article':
|
||||
bibtex_text += f" journal = {{{self._sanitize_bibtex(paper.venue_name)}}},\n"
|
||||
# 添加期刊相关信息
|
||||
if hasattr(paper, 'venue_info'):
|
||||
if 'volume' in paper.venue_info:
|
||||
bibtex_text += f" volume = {{{paper.venue_info['volume']}}},\n"
|
||||
if 'number' in paper.venue_info:
|
||||
bibtex_text += f" number = {{{paper.venue_info['number']}}},\n"
|
||||
if 'pages' in paper.venue_info:
|
||||
bibtex_text += f" pages = {{{paper.venue_info['pages']}}},\n"
|
||||
elif paper.venue:
|
||||
venue_field = "booktitle" if entry_type == "inproceedings" else "journal"
|
||||
bibtex_text += f" {venue_field} = {{{self._sanitize_bibtex(paper.venue)}}},\n"
|
||||
|
||||
# 添加DOI
|
||||
if paper.doi:
|
||||
bibtex_text += f" doi = {{{paper.doi}}},\n"
|
||||
|
||||
# 添加URL
|
||||
if paper.url:
|
||||
bibtex_text += f" url = {{{paper.url}}},\n"
|
||||
elif paper.doi:
|
||||
bibtex_text += f" url = {{https://doi.org/{paper.doi}}},\n"
|
||||
|
||||
# 添加摘要
|
||||
if paper.abstract:
|
||||
bibtex_text += f" abstract = {{{self._sanitize_bibtex(paper.abstract)}}},\n"
|
||||
|
||||
# 添加机构
|
||||
if hasattr(paper, 'institutions') and paper.institutions:
|
||||
institutions = " and ".join([self._sanitize_bibtex(inst) for inst in paper.institutions])
|
||||
bibtex_text += f" institution = {{{institutions}}},\n"
|
||||
|
||||
# 添加月份
|
||||
if hasattr(paper, 'month'):
|
||||
bibtex_text += f" month = {{{paper.month}}},\n"
|
||||
|
||||
# 添加注释字段
|
||||
if hasattr(paper, 'note'):
|
||||
bibtex_text += f" note = {{{self._sanitize_bibtex(paper.note)}}},\n"
|
||||
|
||||
# 移除最后一个逗号并关闭条目
|
||||
bibtex_text = bibtex_text.rstrip(',\n') + "\n}\n\n"
|
||||
|
||||
return bibtex_text
|
||||
138
crazy_functions/review_fns/conversation_doc/word2pdf.py
Normal file
138
crazy_functions/review_fns/conversation_doc/word2pdf.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from docx2pdf import convert
|
||||
import os
|
||||
import platform
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
class WordToPdfConverter:
|
||||
"""Word文档转PDF转换器"""
|
||||
|
||||
@staticmethod
|
||||
def _replace_docx_in_filename(filename: Union[str, Path]) -> Path:
|
||||
"""
|
||||
将文件名中的'docx'替换为'pdf'
|
||||
例如: 'docx_test.pdf' -> 'pdf_test.pdf'
|
||||
"""
|
||||
path = Path(filename)
|
||||
new_name = path.stem.replace('docx', 'pdf')
|
||||
return path.parent / f"{new_name}{path.suffix}"
|
||||
|
||||
@staticmethod
|
||||
def convert_to_pdf(word_path: Union[str, Path], pdf_path: Union[str, Path] = None) -> str:
|
||||
"""
|
||||
将Word文档转换为PDF
|
||||
|
||||
参数:
|
||||
word_path: Word文档的路径
|
||||
pdf_path: 可选,PDF文件的输出路径。如果未指定,将使用与Word文档相同的名称和位置
|
||||
|
||||
返回:
|
||||
生成的PDF文件路径
|
||||
|
||||
异常:
|
||||
如果转换失败,将抛出相应异常
|
||||
"""
|
||||
try:
|
||||
word_path = Path(word_path)
|
||||
|
||||
if pdf_path is None:
|
||||
# 创建新的pdf路径,同时替换文件名中的docx
|
||||
pdf_path = WordToPdfConverter._replace_docx_in_filename(word_path).with_suffix('.pdf')
|
||||
else:
|
||||
pdf_path = WordToPdfConverter._replace_docx_in_filename(Path(pdf_path))
|
||||
|
||||
# 检查操作系统
|
||||
if platform.system() == 'Linux':
|
||||
# Linux系统需要安装libreoffice
|
||||
if not os.system('which libreoffice') == 0:
|
||||
raise RuntimeError("请先安装LibreOffice: sudo apt-get install libreoffice")
|
||||
|
||||
# 使用libreoffice进行转换
|
||||
os.system(f'libreoffice --headless --convert-to pdf "{word_path}" --outdir "{pdf_path.parent}"')
|
||||
|
||||
# 如果输出路径与默认生成的不同,则重命名
|
||||
default_pdf = word_path.with_suffix('.pdf')
|
||||
if default_pdf != pdf_path:
|
||||
os.rename(default_pdf, pdf_path)
|
||||
else:
|
||||
# Windows和MacOS使用 docx2pdf
|
||||
convert(word_path, pdf_path)
|
||||
|
||||
return str(pdf_path)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"转换PDF失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def batch_convert(word_dir: Union[str, Path], pdf_dir: Union[str, Path] = None) -> list:
|
||||
"""
|
||||
批量转换目录下的所有Word文档
|
||||
|
||||
参数:
|
||||
word_dir: 包含Word文档的目录路径
|
||||
pdf_dir: 可选,PDF文件的输出目录。如果未指定,将使用与Word文档相同的目录
|
||||
|
||||
返回:
|
||||
生成的PDF文件路径列表
|
||||
"""
|
||||
word_dir = Path(word_dir)
|
||||
if pdf_dir:
|
||||
pdf_dir = Path(pdf_dir)
|
||||
pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
converted_files = []
|
||||
|
||||
for word_file in word_dir.glob("*.docx"):
|
||||
try:
|
||||
if pdf_dir:
|
||||
pdf_path = pdf_dir / WordToPdfConverter._replace_docx_in_filename(
|
||||
word_file.with_suffix('.pdf')
|
||||
).name
|
||||
else:
|
||||
pdf_path = WordToPdfConverter._replace_docx_in_filename(
|
||||
word_file.with_suffix('.pdf')
|
||||
)
|
||||
|
||||
pdf_file = WordToPdfConverter.convert_to_pdf(word_file, pdf_path)
|
||||
converted_files.append(pdf_file)
|
||||
|
||||
except Exception as e:
|
||||
print(f"转换 {word_file} 失败: {str(e)}")
|
||||
|
||||
return converted_files
|
||||
|
||||
@staticmethod
|
||||
def convert_doc_to_pdf(doc, output_dir: Union[str, Path] = None) -> str:
|
||||
"""
|
||||
将docx对象直接转换为PDF
|
||||
|
||||
参数:
|
||||
doc: python-docx的Document对象
|
||||
output_dir: 可选,输出目录。如果未指定,将使用当前目录
|
||||
|
||||
返回:
|
||||
生成的PDF文件路径
|
||||
"""
|
||||
try:
|
||||
# 设置临时文件路径和输出路径
|
||||
output_dir = Path(output_dir) if output_dir else Path.cwd()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成临时word文件
|
||||
temp_docx = output_dir / f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.docx"
|
||||
doc.save(temp_docx)
|
||||
|
||||
# 转换为PDF
|
||||
pdf_path = temp_docx.with_suffix('.pdf')
|
||||
WordToPdfConverter.convert_to_pdf(temp_docx, pdf_path)
|
||||
|
||||
# 删除临时word文件
|
||||
temp_docx.unlink()
|
||||
|
||||
return str(pdf_path)
|
||||
|
||||
except Exception as e:
|
||||
if temp_docx.exists():
|
||||
temp_docx.unlink()
|
||||
raise Exception(f"转换PDF失败: {str(e)}")
|
||||
246
crazy_functions/review_fns/conversation_doc/word_doc.py
Normal file
246
crazy_functions/review_fns/conversation_doc/word_doc.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import re
|
||||
from docx import Document
|
||||
from docx.shared import Cm, Pt
|
||||
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
|
||||
from docx.enum.style import WD_STYLE_TYPE
|
||||
from docx.oxml.ns import qn
|
||||
from datetime import datetime
|
||||
import docx
|
||||
from docx.oxml import shared
|
||||
from crazy_functions.doc_fns.conversation_doc.word_doc import convert_markdown_to_word
|
||||
|
||||
|
||||
class WordFormatter:
|
||||
"""聊天记录Word文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012)"""
|
||||
|
||||
def __init__(self):
|
||||
self.doc = Document()
|
||||
self._setup_document()
|
||||
self._create_styles()
|
||||
|
||||
def _setup_document(self):
|
||||
"""设置文档基本格式,包括页面设置和页眉"""
|
||||
sections = self.doc.sections
|
||||
for section in sections:
|
||||
# 设置页面大小为A4
|
||||
section.page_width = Cm(21)
|
||||
section.page_height = Cm(29.7)
|
||||
# 设置页边距
|
||||
section.top_margin = Cm(3.7) # 上边距37mm
|
||||
section.bottom_margin = Cm(3.5) # 下边距35mm
|
||||
section.left_margin = Cm(2.8) # 左边距28mm
|
||||
section.right_margin = Cm(2.6) # 右边距26mm
|
||||
# 设置页眉页脚距离
|
||||
section.header_distance = Cm(2.0)
|
||||
section.footer_distance = Cm(2.0)
|
||||
|
||||
# 修改页眉
|
||||
header = section.header
|
||||
header_para = header.paragraphs[0]
|
||||
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
header_run = header_para.add_run("GPT-Academic学术对话 (体验地址:https://auth.gpt-academic.top/)")
|
||||
header_run.font.name = '仿宋'
|
||||
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
header_run.font.size = Pt(9)
|
||||
|
||||
def _create_styles(self):
|
||||
"""创建文档样式"""
|
||||
# 创建正文样式
|
||||
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
style.font.name = '仿宋'
|
||||
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
style.font.size = Pt(12)
|
||||
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
style.paragraph_format.space_after = Pt(0)
|
||||
|
||||
# 创建问题样式
|
||||
question_style = self.doc.styles.add_style('Question_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||
question_style.font.name = '黑体'
|
||||
question_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||
question_style.font.size = Pt(14) # 调整为14磅
|
||||
question_style.font.bold = True
|
||||
question_style.paragraph_format.space_before = Pt(12) # 减小段前距
|
||||
question_style.paragraph_format.space_after = Pt(6)
|
||||
question_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
question_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
|
||||
|
||||
# 创建回答样式
|
||||
answer_style = self.doc.styles.add_style('Answer_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||
answer_style.font.name = '仿宋'
|
||||
answer_style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
answer_style.font.size = Pt(12) # 调整为12磅
|
||||
answer_style.paragraph_format.space_before = Pt(6)
|
||||
answer_style.paragraph_format.space_after = Pt(12)
|
||||
answer_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
answer_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
|
||||
|
||||
# 创建标题样式
|
||||
title_style = self.doc.styles.add_style('Title_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
title_style.font.name = '黑体' # 改用黑体
|
||||
title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||
title_style.font.size = Pt(22) # 调整为22磅
|
||||
title_style.font.bold = True
|
||||
title_style.paragraph_format.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
title_style.paragraph_format.space_before = Pt(0)
|
||||
title_style.paragraph_format.space_after = Pt(24)
|
||||
title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
|
||||
# 添加参考文献样式
|
||||
ref_style = self.doc.styles.add_style('Reference_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||
ref_style.font.name = '宋体'
|
||||
ref_style._element.rPr.rFonts.set(qn('w:eastAsia'), '宋体')
|
||||
ref_style.font.size = Pt(10.5) # 参考文献使用小号字体
|
||||
ref_style.paragraph_format.space_before = Pt(3)
|
||||
ref_style.paragraph_format.space_after = Pt(3)
|
||||
ref_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.SINGLE
|
||||
ref_style.paragraph_format.left_indent = Pt(21)
|
||||
ref_style.paragraph_format.first_line_indent = Pt(-21)
|
||||
|
||||
# 添加参考文献标题样式
|
||||
ref_title_style = self.doc.styles.add_style('Reference_Title_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||
ref_title_style.font.name = '黑体'
|
||||
ref_title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||
ref_title_style.font.size = Pt(16) # 参考文献标题与问题同样大小
|
||||
ref_title_style.font.bold = True
|
||||
ref_title_style.paragraph_format.space_before = Pt(24) # 增加段前距
|
||||
ref_title_style.paragraph_format.space_after = Pt(12)
|
||||
ref_title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
|
||||
def create_document(self, question: str, answer: str, ranked_papers: list = None):
|
||||
"""写入聊天历史
|
||||
Args:
|
||||
question: str, 用户问题
|
||||
answer: str, AI回答
|
||||
ranked_papers: list, 排序后的论文列表
|
||||
"""
|
||||
try:
|
||||
# 添加标题
|
||||
title_para = self.doc.add_paragraph(style='Title_Custom')
|
||||
title_run = title_para.add_run('GPT-Academic 对话记录')
|
||||
|
||||
# 添加日期
|
||||
try:
|
||||
date_para = self.doc.add_paragraph()
|
||||
date_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
date_run = date_para.add_run(datetime.now().strftime('%Y年%m月%d日'))
|
||||
date_run.font.name = '仿宋'
|
||||
date_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
date_run.font.size = Pt(16)
|
||||
except Exception as e:
|
||||
print(f"添加日期失败: {str(e)}")
|
||||
raise
|
||||
|
||||
self.doc.add_paragraph() # 添加空行
|
||||
|
||||
# 添加问答对话
|
||||
try:
|
||||
q_para = self.doc.add_paragraph(style='Question_Style')
|
||||
q_para.add_run('问题:').bold = True
|
||||
q_para.add_run(str(question))
|
||||
|
||||
a_para = self.doc.add_paragraph(style='Answer_Style')
|
||||
a_para.add_run('回答:').bold = True
|
||||
a_para.add_run(convert_markdown_to_word(str(answer)))
|
||||
except Exception as e:
|
||||
print(f"添加问答对话失败: {str(e)}")
|
||||
raise
|
||||
|
||||
# 添加参考文献部分
|
||||
if ranked_papers:
|
||||
try:
|
||||
ref_title = self.doc.add_paragraph(style='Reference_Title_Style')
|
||||
ref_title.add_run("参考文献")
|
||||
|
||||
for idx, paper in enumerate(ranked_papers, 1):
|
||||
try:
|
||||
ref_para = self.doc.add_paragraph(style='Reference_Style')
|
||||
ref_para.add_run(f'[{idx}] ').bold = True
|
||||
|
||||
# 添加作者
|
||||
authors = ', '.join(paper.authors[:3])
|
||||
if len(paper.authors) > 3:
|
||||
authors += ' et al.'
|
||||
ref_para.add_run(f'{authors}. ')
|
||||
|
||||
# 添加标题
|
||||
title_run = ref_para.add_run(paper.title)
|
||||
title_run.italic = True
|
||||
if hasattr(paper, 'url') and paper.url:
|
||||
try:
|
||||
title_run._element.rPr.rStyle = self._create_hyperlink_style()
|
||||
self._add_hyperlink(ref_para, paper.title, paper.url)
|
||||
except Exception as e:
|
||||
print(f"添加超链接失败: {str(e)}")
|
||||
|
||||
# 添加期刊/会议信息
|
||||
if paper.venue_name:
|
||||
ref_para.add_run(f'. {paper.venue_name}')
|
||||
|
||||
# 添加年份
|
||||
if paper.year:
|
||||
ref_para.add_run(f', {paper.year}')
|
||||
|
||||
# 添加DOI
|
||||
if paper.doi:
|
||||
ref_para.add_run('. ')
|
||||
if "arxiv" in paper.url:
|
||||
doi_url = paper.doi
|
||||
else:
|
||||
doi_url = f'https://doi.org/{paper.doi}'
|
||||
self._add_hyperlink(ref_para, f'DOI: {paper.doi}', doi_url)
|
||||
|
||||
ref_para.add_run('.')
|
||||
except Exception as e:
|
||||
print(f"添加第 {idx} 篇参考文献失败: {str(e)}")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"添加参考文献部分失败: {str(e)}")
|
||||
raise
|
||||
|
||||
return self.doc
|
||||
|
||||
except Exception as e:
|
||||
print(f"Word文档创建失败: {str(e)}")
|
||||
import traceback
|
||||
print(f"详细错误信息: {traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
def _create_hyperlink_style(self):
|
||||
"""创建超链接样式"""
|
||||
styles = self.doc.styles
|
||||
if 'Hyperlink' not in styles:
|
||||
hyperlink_style = styles.add_style('Hyperlink', WD_STYLE_TYPE.CHARACTER)
|
||||
# 使用科技蓝 (#0066CC)
|
||||
hyperlink_style.font.color.rgb = 0x0066CC # 科技蓝
|
||||
hyperlink_style.font.underline = True
|
||||
return styles['Hyperlink']
|
||||
|
||||
def _add_hyperlink(self, paragraph, text, url):
|
||||
"""添加超链接到段落"""
|
||||
# 这个是在XML级别添加超链接
|
||||
part = paragraph.part
|
||||
r_id = part.relate_to(url, docx.opc.constants.RELATIONSHIP_TYPE.HYPERLINK, is_external=True)
|
||||
|
||||
# 创建超链接XML元素
|
||||
hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink')
|
||||
hyperlink.set(docx.oxml.shared.qn('r:id'), r_id)
|
||||
|
||||
# 创建文本运行
|
||||
new_run = docx.oxml.shared.OxmlElement('w:r')
|
||||
rPr = docx.oxml.shared.OxmlElement('w:rPr')
|
||||
|
||||
# 应用超链接样式
|
||||
rStyle = docx.oxml.shared.OxmlElement('w:rStyle')
|
||||
rStyle.set(docx.oxml.shared.qn('w:val'), 'Hyperlink')
|
||||
rPr.append(rStyle)
|
||||
|
||||
# 添加文本
|
||||
t = docx.oxml.shared.OxmlElement('w:t')
|
||||
t.text = text
|
||||
new_run.append(rPr)
|
||||
new_run.append(t)
|
||||
hyperlink.append(new_run)
|
||||
|
||||
# 将超链接添加到段落
|
||||
paragraph._p.append(hyperlink)
|
||||
|
||||
0
crazy_functions/review_fns/data_sources/__init__.py
Normal file
0
crazy_functions/review_fns/data_sources/__init__.py
Normal file
279
crazy_functions/review_fns/data_sources/adsabs_source.py
Normal file
279
crazy_functions/review_fns/data_sources/adsabs_source.py
Normal file
@@ -0,0 +1,279 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
class AdsabsSource(DataSource):
|
||||
"""ADS (Astrophysics Data System) API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: ADS API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://api.adsabs.harvard.edu/v1"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str, method: str = "GET", data: dict = None) -> Optional[dict]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
method: HTTP方法
|
||||
data: POST请求数据
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
if method == "GET":
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
elif method == "POST":
|
||||
async with session.post(url, json=data) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _parse_paper(self, doc: dict) -> PaperMetadata:
|
||||
"""解析ADS文献数据
|
||||
|
||||
Args:
|
||||
doc: ADS文献数据
|
||||
|
||||
Returns:
|
||||
解析后的论文数据
|
||||
"""
|
||||
try:
|
||||
return PaperMetadata(
|
||||
title=doc.get('title', [''])[0] if doc.get('title') else '',
|
||||
authors=doc.get('author', []),
|
||||
abstract=doc.get('abstract', ''),
|
||||
year=doc.get('year'),
|
||||
doi=doc.get('doi', [''])[0] if doc.get('doi') else None,
|
||||
url=f"https://ui.adsabs.harvard.edu/abs/{doc.get('bibcode')}/abstract" if doc.get('bibcode') else None,
|
||||
citations=doc.get('citation_count'),
|
||||
venue=doc.get('pub', ''),
|
||||
institutions=doc.get('aff', []),
|
||||
venue_type="journal",
|
||||
venue_name=doc.get('pub', ''),
|
||||
venue_info={
|
||||
'volume': doc.get('volume'),
|
||||
'issue': doc.get('issue'),
|
||||
'pub_date': doc.get('pubdate', '')
|
||||
},
|
||||
source='adsabs'
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"解析文章时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = "relevance",
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||
start_year: 起始年份
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
if start_year:
|
||||
query = f"{query} year:{start_year}-"
|
||||
|
||||
# 设置排序
|
||||
sort_mapping = {
|
||||
'relevance': 'score desc',
|
||||
'date': 'date desc',
|
||||
'citations': 'citation_count desc'
|
||||
}
|
||||
sort = sort_mapping.get(sort_by, 'score desc')
|
||||
|
||||
# 构建搜索请求
|
||||
search_url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": query,
|
||||
"rows": limit,
|
||||
"sort": sort,
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
# 解析结果
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _build_query_string(self, params: dict) -> str:
|
||||
"""构建查询字符串"""
|
||||
return "&".join([f"{k}={v}" for k, v in params.items()])
|
||||
|
||||
async def get_paper_details(self, bibcode: str) -> Optional[PaperMetadata]:
|
||||
"""获取指定bibcode的论文详情"""
|
||||
search_url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"identifier:{bibcode}",
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
|
||||
if response and 'response' in response and response['response']['docs']:
|
||||
return self._parse_paper(response['response']['docs'][0])
|
||||
return None
|
||||
|
||||
async def get_related_papers(self, bibcode: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取相关论文"""
|
||||
url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"citations(identifier:{bibcode}) OR references(identifier:{bibcode})",
|
||||
"rows": limit,
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
return papers
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"author:\"{author}\""
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def search_by_journal(
|
||||
self,
|
||||
journal: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊搜索论文"""
|
||||
query = f"pub:\"{journal}\""
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
days: int = 7,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取最新论文"""
|
||||
query = f"entdate:[NOW-{days}DAYS TO NOW]"
|
||||
return await self.search(query, limit=limit, sort_by="date")
|
||||
|
||||
async def get_citations(self, bibcode: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献"""
|
||||
url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"citations(identifier:{bibcode})",
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
return papers
|
||||
|
||||
async def get_references(self, bibcode: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献"""
|
||||
url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"references(identifier:{bibcode})",
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
return papers
|
||||
|
||||
async def example_usage():
|
||||
"""AdsabsSource使用示例"""
|
||||
ads = AdsabsSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索黑洞相关论文 ===")
|
||||
papers = await ads.search("black hole", limit=3)
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
|
||||
# 其他示例...
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# python -m crazy_functions.review_fns.data_sources.adsabs_source
|
||||
asyncio.run(example_usage())
|
||||
636
crazy_functions/review_fns/data_sources/arxiv_source.py
Normal file
636
crazy_functions/review_fns/data_sources/arxiv_source.py
Normal file
@@ -0,0 +1,636 @@
|
||||
import arxiv
|
||||
from typing import List, Optional, Union, Literal, Dict
|
||||
from datetime import datetime
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
import os
|
||||
from urllib.request import urlretrieve
|
||||
import feedparser
|
||||
from tqdm import tqdm
|
||||
|
||||
class ArxivSource(DataSource):
|
||||
"""arXiv API实现"""
|
||||
|
||||
CATEGORIES = {
|
||||
# 物理学
|
||||
"Physics": {
|
||||
"astro-ph": "天体物理学",
|
||||
"cond-mat": "凝聚态物理",
|
||||
"gr-qc": "广义相对论与量子宇宙学",
|
||||
"hep-ex": "高能物理实验",
|
||||
"hep-lat": "格点场论",
|
||||
"hep-ph": "高能物理理论",
|
||||
"hep-th": "高能物理理论",
|
||||
"math-ph": "数学物理",
|
||||
"nlin": "非线性科学",
|
||||
"nucl-ex": "核实验",
|
||||
"nucl-th": "核理论",
|
||||
"physics": "物理学",
|
||||
"quant-ph": "量子物理",
|
||||
},
|
||||
|
||||
# 数学
|
||||
"Mathematics": {
|
||||
"math.AG": "代数几何",
|
||||
"math.AT": "代数拓扑",
|
||||
"math.AP": "分析与偏微分方程",
|
||||
"math.CT": "范畴论",
|
||||
"math.CA": "复分析",
|
||||
"math.CO": "组合数学",
|
||||
"math.AC": "交换代数",
|
||||
"math.CV": "复变函数",
|
||||
"math.DG": "微分几何",
|
||||
"math.DS": "动力系统",
|
||||
"math.FA": "泛函分析",
|
||||
"math.GM": "一般数学",
|
||||
"math.GN": "一般拓扑",
|
||||
"math.GT": "几何拓扑",
|
||||
"math.GR": "群论",
|
||||
"math.HO": "数学史与数学概述",
|
||||
"math.IT": "信息论",
|
||||
"math.KT": "K理论与同调",
|
||||
"math.LO": "逻辑",
|
||||
"math.MP": "数学物理",
|
||||
"math.MG": "度量几何",
|
||||
"math.NT": "数论",
|
||||
"math.NA": "数值分析",
|
||||
"math.OA": "算子代数",
|
||||
"math.OC": "最优化与控制",
|
||||
"math.PR": "概率论",
|
||||
"math.QA": "量子代数",
|
||||
"math.RT": "表示论",
|
||||
"math.RA": "环与代数",
|
||||
"math.SP": "谱理论",
|
||||
"math.ST": "统计理论",
|
||||
"math.SG": "辛几何",
|
||||
},
|
||||
|
||||
# 计算机科学
|
||||
"Computer Science": {
|
||||
"cs.AI": "人工智能",
|
||||
"cs.CL": "计算语言学",
|
||||
"cs.CC": "计算复杂性",
|
||||
"cs.CE": "计算工程",
|
||||
"cs.CG": "计算几何",
|
||||
"cs.GT": "计算机博弈论",
|
||||
"cs.CV": "计算机视觉",
|
||||
"cs.CY": "计算机与社会",
|
||||
"cs.CR": "密码学与安全",
|
||||
"cs.DS": "数据结构与算法",
|
||||
"cs.DB": "数据库",
|
||||
"cs.DL": "数字图书馆",
|
||||
"cs.DM": "离散数学",
|
||||
"cs.DC": "分布式计算",
|
||||
"cs.ET": "新兴技术",
|
||||
"cs.FL": "形式语言与自动机理论",
|
||||
"cs.GL": "一般文献",
|
||||
"cs.GR": "图形学",
|
||||
"cs.AR": "硬件架构",
|
||||
"cs.HC": "人机交互",
|
||||
"cs.IR": "信息检索",
|
||||
"cs.IT": "信息论",
|
||||
"cs.LG": "机器学习",
|
||||
"cs.LO": "逻辑与计算机",
|
||||
"cs.MS": "数学软件",
|
||||
"cs.MA": "多智能体系统",
|
||||
"cs.MM": "多媒体",
|
||||
"cs.NI": "网络与互联网架构",
|
||||
"cs.NE": "神经与进化计算",
|
||||
"cs.NA": "数值分析",
|
||||
"cs.OS": "操作系统",
|
||||
"cs.OH": "其他计算机科学",
|
||||
"cs.PF": "性能评估",
|
||||
"cs.PL": "编程语言",
|
||||
"cs.RO": "机器人学",
|
||||
"cs.SI": "社会与信息网络",
|
||||
"cs.SE": "软件工程",
|
||||
"cs.SD": "声音",
|
||||
"cs.SC": "符号计算",
|
||||
"cs.SY": "系统与控制",
|
||||
},
|
||||
|
||||
# 定量生物学
|
||||
"Quantitative Biology": {
|
||||
"q-bio.BM": "生物分子",
|
||||
"q-bio.CB": "细胞行为",
|
||||
"q-bio.GN": "基因组学",
|
||||
"q-bio.MN": "分子网络",
|
||||
"q-bio.NC": "神经计算",
|
||||
"q-bio.OT": "其他",
|
||||
"q-bio.PE": "群体与进化",
|
||||
"q-bio.QM": "定量方法",
|
||||
"q-bio.SC": "亚细胞过程",
|
||||
"q-bio.TO": "组织与器官",
|
||||
},
|
||||
|
||||
# 定量金融
|
||||
"Quantitative Finance": {
|
||||
"q-fin.CP": "计算金融",
|
||||
"q-fin.EC": "经济学",
|
||||
"q-fin.GN": "一般金融",
|
||||
"q-fin.MF": "数学金融",
|
||||
"q-fin.PM": "投资组合管理",
|
||||
"q-fin.PR": "定价理论",
|
||||
"q-fin.RM": "风险管理",
|
||||
"q-fin.ST": "统计金融",
|
||||
"q-fin.TR": "交易与市场微观结构",
|
||||
},
|
||||
|
||||
# 统计学
|
||||
"Statistics": {
|
||||
"stat.AP": "应用统计",
|
||||
"stat.CO": "计算统计",
|
||||
"stat.ML": "机器学习",
|
||||
"stat.ME": "方法论",
|
||||
"stat.OT": "其他统计",
|
||||
"stat.TH": "统计理论",
|
||||
},
|
||||
|
||||
# 电气工程与系统科学
|
||||
"Electrical Engineering and Systems Science": {
|
||||
"eess.AS": "音频与语音处理",
|
||||
"eess.IV": "图像与视频处理",
|
||||
"eess.SP": "信号处理",
|
||||
"eess.SY": "系统与控制",
|
||||
},
|
||||
|
||||
# 经济学
|
||||
"Economics": {
|
||||
"econ.EM": "计量经济学",
|
||||
"econ.GN": "一般经济学",
|
||||
"econ.TH": "理论经济学",
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""初始化"""
|
||||
self._initialize() # 调用初始化方法
|
||||
# 修改排序选项映射
|
||||
self.sort_options = {
|
||||
'relevance': arxiv.SortCriterion.Relevance, # arXiv的相关性排序
|
||||
'lastUpdatedDate': arxiv.SortCriterion.LastUpdatedDate, # 最后更新日期
|
||||
'submittedDate': arxiv.SortCriterion.SubmittedDate, # 提交日期
|
||||
}
|
||||
|
||||
self.sort_order_options = {
|
||||
'ascending': arxiv.SortOrder.Ascending,
|
||||
'descending': arxiv.SortOrder.Descending
|
||||
}
|
||||
|
||||
self.default_sort = 'lastUpdatedDate'
|
||||
self.default_order = 'descending'
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化客户端,设置默认参数"""
|
||||
self.client = arxiv.Client()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
sort_by: str = None,
|
||||
sort_order: str = None,
|
||||
start_year: int = None
|
||||
) -> List[Dict]:
|
||||
"""搜索论文"""
|
||||
try:
|
||||
# 使用默认排序如果提供的排序选项无效
|
||||
if not sort_by or sort_by not in self.sort_options:
|
||||
sort_by = self.default_sort
|
||||
|
||||
# 使用默认排序顺序如果提供的顺序无效
|
||||
if not sort_order or sort_order not in self.sort_order_options:
|
||||
sort_order = self.default_order
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||
|
||||
search = arxiv.Search(
|
||||
query=query,
|
||||
max_results=limit,
|
||||
sort_by=self.sort_options[sort_by],
|
||||
sort_order=self.sort_order_options[sort_order]
|
||||
)
|
||||
|
||||
results = list(self.client.results(search))
|
||||
return [self._parse_paper_data(result) for result in results]
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_by_id(self, paper_id: Union[str, List[str]]) -> List[PaperMetadata]:
|
||||
"""按ID搜索论文
|
||||
|
||||
Args:
|
||||
paper_id: 单个arXiv ID或ID列表,例如:'2005.14165' 或 ['2005.14165', '2103.14030']
|
||||
"""
|
||||
if isinstance(paper_id, str):
|
||||
paper_id = [paper_id]
|
||||
|
||||
search = arxiv.Search(
|
||||
id_list=paper_id,
|
||||
max_results=len(paper_id)
|
||||
)
|
||||
results = list(self.client.results(search))
|
||||
return [self._parse_paper_data(result) for result in results]
|
||||
|
||||
async def search_by_category(
|
||||
self,
|
||||
category: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = 'relevance',
|
||||
sort_order: str = 'descending',
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按类别搜索论文"""
|
||||
query = f"cat:{category}"
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order
|
||||
)
|
||||
|
||||
async def search_by_authors(
|
||||
self,
|
||||
authors: List[str],
|
||||
limit: int = 100,
|
||||
sort_by: str = 'relevance',
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = " AND ".join([f"au:\"{author}\"" for author in authors])
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by
|
||||
)
|
||||
|
||||
async def search_by_date_range(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int = 100,
|
||||
sort_by: Literal['relevance', 'updated', 'submitted'] = 'submitted',
|
||||
sort_order: Literal['ascending', 'descending'] = 'descending'
|
||||
) -> List[PaperMetadata]:
|
||||
"""按日期范围搜索论文"""
|
||||
query = f"submittedDate:[{start_date.strftime('%Y%m%d')} TO {end_date.strftime('%Y%m%d')}]"
|
||||
return await self.search(
|
||||
query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order
|
||||
)
|
||||
|
||||
async def download_pdf(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
|
||||
"""下载论文PDF
|
||||
|
||||
Args:
|
||||
paper_id: arXiv ID
|
||||
dirpath: 保存目录
|
||||
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.pdf
|
||||
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
papers = await self.search_by_id(paper_id)
|
||||
if not papers:
|
||||
raise ValueError(f"未找到ID为 {paper_id} 的论文")
|
||||
paper = papers[0]
|
||||
|
||||
if not filename:
|
||||
# 清理标题中的非法字符
|
||||
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
|
||||
filename = f"{paper_id}_{safe_title}.pdf"
|
||||
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
urlretrieve(paper.url, filepath)
|
||||
return filepath
|
||||
|
||||
async def download_source(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
|
||||
"""下载论文源文件(通常是LaTeX源码)
|
||||
|
||||
Args:
|
||||
paper_id: arXiv ID
|
||||
dirpath: 保存目录
|
||||
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.tar.gz
|
||||
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
papers = await self.search_by_id(paper_id)
|
||||
if not papers:
|
||||
raise ValueError(f"未找到ID为 {paper_id} 的论文")
|
||||
paper = papers[0]
|
||||
|
||||
if not filename:
|
||||
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
|
||||
filename = f"{paper_id}_{safe_title}.tar.gz"
|
||||
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
source_url = paper.url.replace("/pdf/", "/src/")
|
||||
urlretrieve(source_url, filepath)
|
||||
return filepath
|
||||
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
# arXiv API不直接提供引用信息
|
||||
return []
|
||||
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
# arXiv API不直接提供引用信息
|
||||
return []
|
||||
|
||||
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||
"""获取论文详情
|
||||
|
||||
Args:
|
||||
paper_id: arXiv ID 或 DOI
|
||||
|
||||
Returns:
|
||||
论文详细信息,如果未找到返回 None
|
||||
"""
|
||||
try:
|
||||
# 如果是完整的 arXiv URL,提取 ID
|
||||
if "arxiv.org" in paper_id:
|
||||
paper_id = paper_id.split("/")[-1]
|
||||
# 如果是 DOI 格式且是 arXiv 论文,提取 ID
|
||||
elif paper_id.startswith("10.48550/arXiv."):
|
||||
paper_id = paper_id.split(".")[-1]
|
||||
|
||||
papers = await self.search_by_id(paper_id)
|
||||
return papers[0] if papers else None
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _parse_paper_data(self, result: arxiv.Result) -> PaperMetadata:
|
||||
"""解析arXiv API返回的数据"""
|
||||
# 解析主要类别和次要类别
|
||||
primary_category = result.primary_category
|
||||
categories = result.categories
|
||||
|
||||
# 构建venue信息
|
||||
venue_info = {
|
||||
'primary_category': primary_category,
|
||||
'categories': categories,
|
||||
'comments': getattr(result, 'comment', None),
|
||||
'journal_ref': getattr(result, 'journal_ref', None)
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=result.title,
|
||||
authors=[author.name for author in result.authors],
|
||||
abstract=result.summary,
|
||||
year=result.published.year,
|
||||
doi=result.entry_id,
|
||||
url=result.pdf_url,
|
||||
citations=None,
|
||||
venue=f"arXiv:{primary_category}",
|
||||
institutions=[],
|
||||
venue_type='preprint', # arXiv论文都是预印本
|
||||
venue_name='arXiv',
|
||||
venue_info=venue_info,
|
||||
source='arxiv' # 添加来源标记
|
||||
)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
category: str,
|
||||
debug: bool = False,
|
||||
batch_size: int = 50
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取指定类别的最新论文
|
||||
|
||||
通过 RSS feed 获取最新发布的论文,然后批量获取详细信息
|
||||
|
||||
Args:
|
||||
category: arXiv类别,例如:
|
||||
- 整个领域: 'cs'
|
||||
- 具体方向: 'cs.AI'
|
||||
- 多个类别: 'cs.AI+q-bio.NC'
|
||||
debug: 是否为调试模式,如果为True则只返回5篇最新论文
|
||||
batch_size: 批量获取论文的数量,默认50
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果类别无效
|
||||
"""
|
||||
try:
|
||||
# 处理类别格式
|
||||
# 1. 转换为小写
|
||||
# 2. 确保多个类别之间使用+连接
|
||||
category = category.lower().replace(' ', '+')
|
||||
|
||||
# 构建RSS feed URL
|
||||
feed_url = f"https://rss.arxiv.org/rss/{category}"
|
||||
print(f"正在获取RSS feed: {feed_url}") # 添加调试信息
|
||||
|
||||
feed = feedparser.parse(feed_url)
|
||||
|
||||
# 检查feed是否有效
|
||||
if hasattr(feed, 'status') and feed.status != 200:
|
||||
raise ValueError(f"获取RSS feed失败,状态码: {feed.status}")
|
||||
|
||||
if not feed.entries:
|
||||
print(f"警告:未在feed中找到任何条目") # 添加调试信息
|
||||
print(f"Feed标题: {feed.feed.title if hasattr(feed, 'feed') else '无标题'}")
|
||||
raise ValueError(f"无效的arXiv类别或未找到论文: {category}")
|
||||
|
||||
if debug:
|
||||
# 调试模式:只获取5篇最新论文
|
||||
search = arxiv.Search(
|
||||
query=f'cat:{category}',
|
||||
sort_by=arxiv.SortCriterion.SubmittedDate,
|
||||
sort_order=arxiv.SortOrder.Descending,
|
||||
max_results=5
|
||||
)
|
||||
results = list(self.client.results(search))
|
||||
return [self._parse_paper_data(result) for result in results]
|
||||
|
||||
# 正常模式:获取所有新论文
|
||||
# 从RSS条目中提取arXiv ID
|
||||
paper_ids = []
|
||||
for entry in feed.entries:
|
||||
try:
|
||||
# RSS链接格式可能是以下几种:
|
||||
# - http://arxiv.org/abs/2403.xxxxx
|
||||
# - http://arxiv.org/pdf/2403.xxxxx
|
||||
# - https://arxiv.org/abs/2403.xxxxx
|
||||
link = entry.link or entry.id
|
||||
arxiv_id = link.split('/')[-1].replace('.pdf', '')
|
||||
if arxiv_id:
|
||||
paper_ids.append(arxiv_id)
|
||||
except Exception as e:
|
||||
print(f"警告:处理条目时出错: {str(e)}") # 添加调试信息
|
||||
continue
|
||||
|
||||
if not paper_ids:
|
||||
print("未能从feed中提取到任何论文ID") # 添加调试信息
|
||||
return []
|
||||
|
||||
print(f"成功提取到 {len(paper_ids)} 个论文ID") # 添加调试信息
|
||||
|
||||
# 批量获取论文详情
|
||||
papers = []
|
||||
with tqdm(total=len(paper_ids), desc="获取arXiv论文") as pbar:
|
||||
for i in range(0, len(paper_ids), batch_size):
|
||||
batch_ids = paper_ids[i:i + batch_size]
|
||||
search = arxiv.Search(
|
||||
id_list=batch_ids,
|
||||
max_results=len(batch_ids)
|
||||
)
|
||||
batch_results = list(self.client.results(search))
|
||||
papers.extend([self._parse_paper_data(result) for result in batch_results])
|
||||
pbar.update(len(batch_results))
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取最新论文时发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc()) # 添加完整的错误追踪
|
||||
return []
|
||||
|
||||
async def example_usage():
|
||||
"""ArxivSource使用示例"""
|
||||
arxiv_source = ArxivSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索,使用不同的排序方式
|
||||
# print("\n=== 示例1:搜索最新的机器学习论文(按提交时间排序)===")
|
||||
# papers = await arxiv_source.search(
|
||||
# "ti:\"machine learning\"",
|
||||
# limit=3,
|
||||
# sort_by='submitted',
|
||||
# sort_order='descending'
|
||||
# )
|
||||
# print(f"找到 {len(papers)} 篇论文")
|
||||
|
||||
# for i, paper in enumerate(papers, 1):
|
||||
# print(f"\n--- 论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表年份: {paper.year}")
|
||||
# print(f"arXiv ID: {paper.doi}")
|
||||
# print(f"PDF URL: {paper.url}")
|
||||
# if paper.abstract:
|
||||
# print(f"\n摘要:")
|
||||
# print(paper.abstract)
|
||||
# print(f"发表venue: {paper.venue}")
|
||||
|
||||
# # 示例2:按ID搜索
|
||||
# print("\n=== 示例2:按ID搜索论文 ===")
|
||||
# paper_id = "2005.14165" # GPT-3论文
|
||||
# papers = await arxiv_source.search_by_id(paper_id)
|
||||
# if papers:
|
||||
# paper = papers[0]
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表年份: {paper.year}")
|
||||
|
||||
# # 示例3:按类别搜索
|
||||
# print("\n=== 示例3:搜索人工智能领域最新论文 ===")
|
||||
# ai_papers = await arxiv_source.search_by_category(
|
||||
# "cs.AI",
|
||||
# limit=2,
|
||||
# sort_by='updated',
|
||||
# sort_order='descending'
|
||||
# )
|
||||
# for i, paper in enumerate(ai_papers, 1):
|
||||
# print(f"\n--- AI论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表venue: {paper.venue}")
|
||||
|
||||
# # 示例4:按作者搜索
|
||||
# print("\n=== 示例4:搜索特定作者的论文 ===")
|
||||
# author_papers = await arxiv_source.search_by_authors(
|
||||
# ["Bengio"],
|
||||
# limit=2,
|
||||
# sort_by='relevance'
|
||||
# )
|
||||
# for i, paper in enumerate(author_papers, 1):
|
||||
# print(f"\n--- Bengio的论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表venue: {paper.venue}")
|
||||
|
||||
# # 示例5:按日期范围搜索
|
||||
# print("\n=== 示例5:搜索特定日期范围的论文 ===")
|
||||
# from datetime import datetime, timedelta
|
||||
# end_date = datetime.now()
|
||||
# start_date = end_date - timedelta(days=7) # 最近一周
|
||||
# recent_papers = await arxiv_source.search_by_date_range(
|
||||
# start_date,
|
||||
# end_date,
|
||||
# limit=2
|
||||
# )
|
||||
# for i, paper in enumerate(recent_papers, 1):
|
||||
# print(f"\n--- 最近论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表年份: {paper.year}")
|
||||
|
||||
# # 示例6:下载PDF
|
||||
# print("\n=== 示例6:下载论文PDF ===")
|
||||
# if papers: # 使用之前搜索到的GPT-3论文
|
||||
# pdf_path = await arxiv_source.download_pdf(paper_id)
|
||||
# print(f"PDF已下载到: {pdf_path}")
|
||||
|
||||
# # 示例7:下载源文件
|
||||
# print("\n=== 示例7:下载论文源文件 ===")
|
||||
# if papers:
|
||||
# source_path = await arxiv_source.download_source(paper_id)
|
||||
# print(f"源文件已下载到: {source_path}")
|
||||
|
||||
# 示例6:获取最新论文
|
||||
print("\n=== 示例8:获取最新论文 ===")
|
||||
|
||||
# 获取CS.AI领域的最新论文
|
||||
print("\n--- 获取AI领域最新论文 ---")
|
||||
ai_latest = await arxiv_source.get_latest_papers("cs.AI", debug=True)
|
||||
for i, paper in enumerate(ai_latest, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 获取整个计算机科学领域的最新论文
|
||||
print("\n--- 获取整个CS领域最新论文 ---")
|
||||
cs_latest = await arxiv_source.get_latest_papers("cs", debug=True)
|
||||
for i, paper in enumerate(cs_latest, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 获取多个类别的最新论文
|
||||
print("\n--- 获取AI和机器学习领域最新论文 ---")
|
||||
multi_latest = await arxiv_source.get_latest_papers("cs.AI+cs.LG", debug=True)
|
||||
for i, paper in enumerate(multi_latest, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(example_usage())
|
||||
102
crazy_functions/review_fns/data_sources/base_source.py
Normal file
102
crazy_functions/review_fns/data_sources/base_source.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
class PaperMetadata:
|
||||
"""论文元数据"""
|
||||
def __init__(
|
||||
self,
|
||||
title: str,
|
||||
authors: List[str],
|
||||
abstract: str,
|
||||
year: int,
|
||||
doi: str = None,
|
||||
url: str = None,
|
||||
citations: int = None,
|
||||
venue: str = None,
|
||||
institutions: List[str] = None,
|
||||
venue_type: str = None, # 来源类型(journal/conference/preprint等)
|
||||
venue_name: str = None, # 具体的期刊/会议名称
|
||||
venue_info: Dict = None, # 更多来源详细信息(如影响因子、分区等)
|
||||
source: str = None # 新增: 论文来源标记
|
||||
):
|
||||
self.title = title
|
||||
self.authors = authors
|
||||
self.abstract = abstract
|
||||
self.year = year
|
||||
self.doi = doi
|
||||
self.url = url
|
||||
self.citations = citations
|
||||
self.venue = venue
|
||||
self.institutions = institutions or []
|
||||
self.venue_type = venue_type # 新增
|
||||
self.venue_name = venue_name # 新增
|
||||
self.venue_info = venue_info or {} # 新增
|
||||
self.source = source # 新增: 存储论文来源
|
||||
|
||||
# 新增:影响因子和分区信息,初始化为None
|
||||
self._if_factor = None
|
||||
self._cas_division = None
|
||||
self._jcr_division = None
|
||||
|
||||
@property
|
||||
def if_factor(self) -> Optional[float]:
|
||||
"""获取影响因子"""
|
||||
return self._if_factor
|
||||
|
||||
@if_factor.setter
|
||||
def if_factor(self, value: float):
|
||||
"""设置影响因子"""
|
||||
self._if_factor = value
|
||||
|
||||
@property
|
||||
def cas_division(self) -> Optional[str]:
|
||||
"""获取中科院分区"""
|
||||
return self._cas_division
|
||||
|
||||
@cas_division.setter
|
||||
def cas_division(self, value: str):
|
||||
"""设置中科院分区"""
|
||||
self._cas_division = value
|
||||
|
||||
@property
|
||||
def jcr_division(self) -> Optional[str]:
|
||||
"""获取JCR分区"""
|
||||
return self._jcr_division
|
||||
|
||||
@jcr_division.setter
|
||||
def jcr_division(self, value: str):
|
||||
"""设置JCR分区"""
|
||||
self._jcr_division = value
|
||||
|
||||
class DataSource(ABC):
|
||||
"""数据源基类"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key
|
||||
self._initialize()
|
||||
|
||||
@abstractmethod
|
||||
def _initialize(self) -> None:
|
||||
"""初始化数据源"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""搜索论文"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_paper_details(self, paper_id: str) -> PaperMetadata:
|
||||
"""获取论文详细信息"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献"""
|
||||
pass
|
||||
1
crazy_functions/review_fns/data_sources/cas_if.json
Normal file
1
crazy_functions/review_fns/data_sources/cas_if.json
Normal file
File diff suppressed because one or more lines are too long
400
crazy_functions/review_fns/data_sources/crossref_source.py
Normal file
400
crazy_functions/review_fns/data_sources/crossref_source.py
Normal file
@@ -0,0 +1,400 @@
|
||||
import aiohttp
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import random
|
||||
|
||||
class CrossrefSource(DataSource):
|
||||
"""Crossref API实现"""
|
||||
|
||||
CONTACT_EMAILS = [
|
||||
"gpt_abc_academic@163.com",
|
||||
"gpt_abc_newapi@163.com",
|
||||
"gpt_abc_academic_pwd@163.com"
|
||||
]
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化客户端,设置默认参数"""
|
||||
self.base_url = "https://api.crossref.org"
|
||||
# 随机选择一个邮箱
|
||||
contact_email = random.choice(self.CONTACT_EMAILS)
|
||||
self.headers = {
|
||||
"Accept": "application/json",
|
||||
"User-Agent": f"Mozilla/5.0 (compatible; PythonScript/1.0; mailto:{contact_email})",
|
||||
}
|
||||
if self.api_key:
|
||||
self.headers["Crossref-Plus-API-Token"] = f"Bearer {self.api_key}"
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
sort_order: str = None,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序字段
|
||||
sort_order: 排序顺序
|
||||
start_year: 起始年份
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
# 请求更多的结果以补偿可能被过滤掉的文章
|
||||
adjusted_limit = min(limit * 3, 1000) # 设置上限以避免请求过多
|
||||
params = {
|
||||
"query": query,
|
||||
"rows": adjusted_limit,
|
||||
"select": (
|
||||
"DOI,title,author,published-print,abstract,reference,"
|
||||
"container-title,is-referenced-by-count,type,"
|
||||
"publisher,ISSN,ISBN,issue,volume,page"
|
||||
)
|
||||
}
|
||||
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
params["filter"] = f"from-pub-date:{start_year}"
|
||||
|
||||
# 添加排序
|
||||
if sort_by:
|
||||
params["sort"] = sort_by
|
||||
if sort_order:
|
||||
params["order"] = sort_order
|
||||
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params=params
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
return []
|
||||
|
||||
data = await response.json()
|
||||
items = data.get("message", {}).get("items", [])
|
||||
if not items:
|
||||
print(f"未找到相关论文")
|
||||
return []
|
||||
|
||||
# 过滤掉没有摘要的文章
|
||||
papers = []
|
||||
filtered_count = 0
|
||||
for work in items:
|
||||
paper = self._parse_work(work)
|
||||
if paper.abstract and paper.abstract.strip():
|
||||
papers.append(paper)
|
||||
if len(papers) >= limit: # 达到原始请求的限制后停止
|
||||
break
|
||||
else:
|
||||
filtered_count += 1
|
||||
|
||||
print(f"找到 {len(items)} 篇相关论文,其中 {filtered_count} 篇因缺少摘要被过滤")
|
||||
print(f"返回 {len(papers)} 篇包含摘要的论文")
|
||||
return papers
|
||||
|
||||
async def get_paper_details(self, doi: str) -> PaperMetadata:
|
||||
"""获取指定DOI的论文详情"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/{doi}",
|
||||
params={
|
||||
"select": (
|
||||
"DOI,title,author,published-print,abstract,reference,"
|
||||
"container-title,is-referenced-by-count,type,"
|
||||
"publisher,ISSN,ISBN,issue,volume,page"
|
||||
)
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"获取论文详情失败: HTTP {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await response.json()
|
||||
return self._parse_work(data.get("message", {}))
|
||||
except Exception as e:
|
||||
print(f"解析论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取指定DOI论文的参考文献列表"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/{doi}",
|
||||
params={"select": "reference"}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"获取参考文献失败: HTTP {response.status}")
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await response.json()
|
||||
# 确保我们正确处理返回的数据结构
|
||||
if not isinstance(data, dict):
|
||||
print(f"API返回了意外的数据格式: {type(data)}")
|
||||
return []
|
||||
|
||||
references = data.get("message", {}).get("reference", [])
|
||||
if not references:
|
||||
print(f"未找到参考文献")
|
||||
return []
|
||||
|
||||
return [
|
||||
PaperMetadata(
|
||||
title=ref.get("article-title", ""),
|
||||
authors=[ref.get("author", "")],
|
||||
year=ref.get("year"),
|
||||
doi=ref.get("DOI"),
|
||||
url=f"https://doi.org/{ref.get('DOI')}" if ref.get("DOI") else None,
|
||||
abstract="",
|
||||
citations=None,
|
||||
venue=ref.get("journal-title", ""),
|
||||
institutions=[]
|
||||
)
|
||||
for ref in references
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"解析参考文献数据时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_citations(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取引用指定DOI论文的文献列表"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params={
|
||||
"filter": f"reference.DOI:{doi}",
|
||||
"select": "DOI,title,author,published-print,abstract"
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"获取引用信息失败: HTTP {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await response.json()
|
||||
# 检查返回的数据结构
|
||||
if isinstance(data, dict):
|
||||
items = data.get("message", {}).get("items", [])
|
||||
return [self._parse_work(work) for work in items]
|
||||
else:
|
||||
print(f"API返回了意外的数据格式: {type(data)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"解析引用数据时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_work(self, work: Dict) -> PaperMetadata:
|
||||
"""解析Crossref返回的数据"""
|
||||
# 获取摘要 - 处理可能的不同格式
|
||||
abstract = ""
|
||||
if isinstance(work.get("abstract"), str):
|
||||
abstract = work.get("abstract", "")
|
||||
elif isinstance(work.get("abstract"), dict):
|
||||
abstract = work.get("abstract", {}).get("value", "")
|
||||
|
||||
if not abstract:
|
||||
print(f"警告: 论文 '{work.get('title', [''])[0]}' 没有可用的摘要")
|
||||
|
||||
# 获取机构信息
|
||||
institutions = []
|
||||
for author in work.get("author", []):
|
||||
if "affiliation" in author:
|
||||
for affiliation in author["affiliation"]:
|
||||
if "name" in affiliation and affiliation["name"] not in institutions:
|
||||
institutions.append(affiliation["name"])
|
||||
|
||||
# 获取venue信息
|
||||
venue_name = work.get("container-title", [None])[0]
|
||||
venue_type = work.get("type", "unknown") # 文献类型
|
||||
venue_info = {
|
||||
"publisher": work.get("publisher"),
|
||||
"issn": work.get("ISSN", []),
|
||||
"isbn": work.get("ISBN", []),
|
||||
"issue": work.get("issue"),
|
||||
"volume": work.get("volume"),
|
||||
"page": work.get("page")
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=work.get("title", [None])[0] or "",
|
||||
authors=[
|
||||
author.get("given", "") + " " + author.get("family", "")
|
||||
for author in work.get("author", [])
|
||||
],
|
||||
institutions=institutions, # 添加机构信息
|
||||
abstract=abstract,
|
||||
year=work.get("published-print", {}).get("date-parts", [[None]])[0][0],
|
||||
doi=work.get("DOI"),
|
||||
url=f"https://doi.org/{work.get('DOI')}" if work.get("DOI") else None,
|
||||
citations=work.get("is-referenced-by-count"),
|
||||
venue=venue_name,
|
||||
venue_type=venue_type, # 添加venue类型
|
||||
venue_name=venue_name, # 添加venue名称
|
||||
venue_info=venue_info, # 添加venue详细信息
|
||||
source='crossref' # 添加来源标记
|
||||
)
|
||||
|
||||
async def search_by_authors(
|
||||
self,
|
||||
authors: List[str],
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = " ".join([f"author:\"{author}\"" for author in authors])
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
start_year=start_year
|
||||
)
|
||||
|
||||
async def search_by_date_range(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
sort_order: str = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按日期范围搜索论文"""
|
||||
query = f"from-pub-date:{start_date.strftime('%Y-%m-%d')} until-pub-date:{end_date.strftime('%Y-%m-%d')}"
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order
|
||||
)
|
||||
|
||||
async def example_usage():
|
||||
"""CrossrefSource使用示例"""
|
||||
crossref = CrossrefSource(api_key=None)
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索,使用不同的排序方式
|
||||
print("\n=== 示例1:搜索最新的机器学习论文 ===")
|
||||
papers = await crossref.search(
|
||||
query="machine learning",
|
||||
limit=3,
|
||||
sort_by="published",
|
||||
sort_order="desc",
|
||||
start_year=2023
|
||||
)
|
||||
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"URL: {paper.url}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
if paper.institutions:
|
||||
print(f"机构: {', '.join(paper.institutions)}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"发表venue: {paper.venue}")
|
||||
print(f"venue类型: {paper.venue_type}")
|
||||
if paper.venue_info:
|
||||
print("Venue详细信息:")
|
||||
for key, value in paper.venue_info.items():
|
||||
if value:
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
# 示例2:按DOI获取论文详情
|
||||
print("\n=== 示例2:获取特定论文详情 ===")
|
||||
# 使用BERT论文的DOI
|
||||
doi = "10.18653/v1/N19-1423"
|
||||
paper = await crossref.get_paper_details(doi)
|
||||
if paper:
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例3:按作者搜索
|
||||
print("\n=== 示例3:搜索特定作者的论文 ===")
|
||||
author_papers = await crossref.search_by_authors(
|
||||
authors=["Yoshua Bengio"],
|
||||
limit=3,
|
||||
sort_by="published",
|
||||
start_year=2020
|
||||
)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- {i}. {paper.title} ---")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例4:按日期范围搜索
|
||||
print("\n=== 示例4:搜索特定日期范围的论文 ===")
|
||||
from datetime import datetime, timedelta
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=30) # 最近一个月
|
||||
recent_papers = await crossref.search_by_date_range(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=3,
|
||||
sort_by="published",
|
||||
sort_order="desc"
|
||||
)
|
||||
for i, paper in enumerate(recent_papers, 1):
|
||||
print(f"\n--- 最近发表的论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
|
||||
# 示例5:获取论文引用信息
|
||||
print("\n=== 示例5:获取论文引用信息 ===")
|
||||
if paper: # 使用之前获取的BERT论文
|
||||
print("\n获取引用该论文的文献:")
|
||||
citations = await crossref.get_citations(paper.doi)
|
||||
for i, citing_paper in enumerate(citations[:3], 1):
|
||||
print(f"\n--- 引用论文 {i} ---")
|
||||
print(f"标题: {citing_paper.title}")
|
||||
print(f"作者: {', '.join(citing_paper.authors)}")
|
||||
print(f"发表年份: {citing_paper.year}")
|
||||
|
||||
print("\n获取该论文引用的参考文献:")
|
||||
references = await crossref.get_references(paper.doi)
|
||||
for i, ref_paper in enumerate(references[:3], 1):
|
||||
print(f"\n--- 参考文献 {i} ---")
|
||||
print(f"标题: {ref_paper.title}")
|
||||
print(f"作者: {', '.join(ref_paper.authors)}")
|
||||
print(f"发表年份: {ref_paper.year if ref_paper.year else '未知'}")
|
||||
|
||||
# 示例6:展示venue信息的使用
|
||||
print("\n=== 示例6:展示期刊/会议详细信息 ===")
|
||||
if papers:
|
||||
paper = papers[0]
|
||||
print(f"文献类型: {paper.venue_type}")
|
||||
print(f"发表venue: {paper.venue_name}")
|
||||
if paper.venue_info:
|
||||
print("Venue详细信息:")
|
||||
for key, value in paper.venue_info.items():
|
||||
if value:
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
449
crazy_functions/review_fns/data_sources/elsevier_source.py
Normal file
449
crazy_functions/review_fns/data_sources/elsevier_source.py
Normal file
@@ -0,0 +1,449 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
class ElsevierSource(DataSource):
|
||||
"""Elsevier (Scopus) API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: Elsevier API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS)
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://api.elsevier.com/content"
|
||||
self.headers = {
|
||||
"X-ELS-APIKey": self.api_key,
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
# 添加更多必要的头部信息
|
||||
"X-ELS-Insttoken": "", # 如果有机构令牌
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
JSON响应
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
# 添加更详细的错误信息
|
||||
error_text = await response.text()
|
||||
print(f"请求失败: {response.status}")
|
||||
print(f"错误详情: {error_text}")
|
||||
if response.status == 401:
|
||||
print(f"使用的API密钥: {self.api_key}")
|
||||
# 尝试切换到另一个API密钥
|
||||
new_key = random.choice([k for k in self.API_KEYS if k != self.api_key])
|
||||
print(f"尝试切换到新的API密钥: {new_key}")
|
||||
self.api_key = new_key
|
||||
self.headers["X-ELS-APIKey"] = new_key
|
||||
# 重试请求
|
||||
return await self._make_request(url, params)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = "relevance",
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文"""
|
||||
try:
|
||||
params = {
|
||||
"query": query,
|
||||
"count": min(limit, 100),
|
||||
"view": "STANDARD",
|
||||
# 移除dc:description字段,因为它在STANDARD视图中不可用
|
||||
"field": "dc:title,dc:creator,prism:doi,prism:coverDate,citedby-count,prism:publicationName"
|
||||
}
|
||||
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
params["date"] = f"{start_year}-present"
|
||||
|
||||
# 添加排序
|
||||
if sort_by == "date":
|
||||
params["sort"] = "-coverDate"
|
||||
elif sort_by == "cited":
|
||||
params["sort"] = "-citedby-count"
|
||||
|
||||
# 发送搜索请求
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/search/scopus",
|
||||
params=params
|
||||
)
|
||||
|
||||
if not response or "search-results" not in response:
|
||||
return []
|
||||
|
||||
# 解析搜索结果
|
||||
entries = response["search-results"].get("entry", [])
|
||||
papers = [paper for paper in (self._parse_entry(entry) for entry in entries) if paper is not None]
|
||||
|
||||
# 尝试为每篇论文获取摘要
|
||||
for paper in papers:
|
||||
if paper.doi:
|
||||
paper.abstract = await self.fetch_abstract(paper.doi) or ""
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_entry(self, entry: Dict) -> Optional[PaperMetadata]:
|
||||
"""解析Scopus API返回的条目"""
|
||||
try:
|
||||
# 获取作者列表
|
||||
authors = []
|
||||
creator = entry.get("dc:creator")
|
||||
if creator:
|
||||
authors = [creator]
|
||||
|
||||
# 获取发表年份
|
||||
year = None
|
||||
if "prism:coverDate" in entry:
|
||||
try:
|
||||
year = int(entry["prism:coverDate"][:4])
|
||||
except:
|
||||
pass
|
||||
|
||||
# 简化venue信息
|
||||
venue_info = {
|
||||
'source_id': entry.get("source-id"),
|
||||
'issn': entry.get("prism:issn")
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=entry.get("dc:title", ""),
|
||||
authors=authors,
|
||||
abstract=entry.get("dc:description", ""), # 从响应中获取摘要
|
||||
year=year,
|
||||
doi=entry.get("prism:doi"),
|
||||
url=entry.get("prism:url"),
|
||||
citations=int(entry.get("citedby-count", 0)),
|
||||
venue=entry.get("prism:publicationName"),
|
||||
institutions=[], # 移除机构信息
|
||||
venue_type="",
|
||||
venue_name=entry.get("prism:publicationName"),
|
||||
venue_info=venue_info
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析条目时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_citations(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献"""
|
||||
try:
|
||||
params = {
|
||||
"query": f"REF({doi})",
|
||||
"count": min(limit, 100),
|
||||
"view": "STANDARD"
|
||||
}
|
||||
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/search/scopus",
|
||||
params=params
|
||||
)
|
||||
|
||||
if not response or "search-results" not in response:
|
||||
return []
|
||||
|
||||
entries = response["search-results"].get("entry", [])
|
||||
return [self._parse_entry(entry) for entry in entries]
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取引用文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献"""
|
||||
try:
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/abstract/doi/{doi}/references",
|
||||
params={"view": "STANDARD"}
|
||||
)
|
||||
|
||||
if not response or "references" not in response:
|
||||
return []
|
||||
|
||||
references = response["references"].get("reference", [])
|
||||
papers = [paper for paper in (self._parse_reference(ref) for ref in references) if paper is not None]
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取参考文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_reference(self, ref: Dict) -> Optional[PaperMetadata]:
|
||||
"""解析参考文献数据"""
|
||||
try:
|
||||
authors = []
|
||||
if "author-list" in ref:
|
||||
author_list = ref["author-list"].get("author", [])
|
||||
if isinstance(author_list, list):
|
||||
authors = [f"{author.get('ce:given-name', '')} {author.get('ce:surname', '')}"
|
||||
for author in author_list]
|
||||
else:
|
||||
authors = [f"{author_list.get('ce:given-name', '')} {author_list.get('ce:surname', '')}"]
|
||||
|
||||
year = None
|
||||
if "prism:coverDate" in ref:
|
||||
try:
|
||||
year = int(ref["prism:coverDate"][:4])
|
||||
except:
|
||||
pass
|
||||
|
||||
return PaperMetadata(
|
||||
title=ref.get("ce:title", ""),
|
||||
authors=authors,
|
||||
abstract="", # 参考文献通常不包含摘要
|
||||
year=year,
|
||||
doi=ref.get("prism:doi"),
|
||||
url=None,
|
||||
citations=None,
|
||||
venue=ref.get("prism:publicationName"),
|
||||
institutions=[],
|
||||
venue_type="unknown",
|
||||
venue_name=ref.get("prism:publicationName"),
|
||||
venue_info={}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析参考文献时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"AUTHOR-NAME({author})"
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def search_by_affiliation(
|
||||
self,
|
||||
affiliation: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按机构搜索论文"""
|
||||
query = f"AF-ID({affiliation})"
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def search_by_venue(
|
||||
self,
|
||||
venue: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊/会议搜索论文"""
|
||||
query = f"SRCTITLE({venue})"
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def test_api_access(self):
|
||||
"""测试API访问权限"""
|
||||
print(f"\n测试API密钥: {self.api_key}")
|
||||
|
||||
# 测试1: 基础搜索
|
||||
basic_params = {
|
||||
"query": "test",
|
||||
"count": 1,
|
||||
"view": "STANDARD"
|
||||
}
|
||||
print("\n1. 测试基础搜索...")
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/search/scopus",
|
||||
params=basic_params
|
||||
)
|
||||
if response:
|
||||
print("基础搜索成功")
|
||||
print("可用字段:", list(response.get("search-results", {}).get("entry", [{}])[0].keys()))
|
||||
|
||||
# 测试2: 测试单篇文章访问
|
||||
print("\n2. 测试文章详情访问...")
|
||||
test_doi = "10.1016/j.artint.2021.103535" # 一个示例DOI
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/abstract/doi/{test_doi}",
|
||||
params={"view": "STANDARD"} # 改为STANDARD视图
|
||||
)
|
||||
if response:
|
||||
print("文章详情访问成功")
|
||||
else:
|
||||
print("文章详情访问失败")
|
||||
|
||||
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||
"""获取论文详细信息
|
||||
|
||||
注意:当前API权限不支持获取详细信息,返回None
|
||||
|
||||
Args:
|
||||
paper_id: 论文ID
|
||||
|
||||
Returns:
|
||||
None,因为当前API权限不支持此功能
|
||||
"""
|
||||
return None
|
||||
|
||||
async def fetch_abstract(self, doi: str) -> Optional[str]:
|
||||
"""获取论文摘要
|
||||
|
||||
使用Scopus Abstract API获取论文摘要
|
||||
|
||||
Args:
|
||||
doi: 论文的DOI
|
||||
|
||||
Returns:
|
||||
摘要文本,如果获取失败则返回None
|
||||
"""
|
||||
try:
|
||||
# 使用Abstract API而不是Search API
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/abstract/doi/{doi}",
|
||||
params={
|
||||
"view": "FULL" # 使用FULL视图
|
||||
}
|
||||
)
|
||||
|
||||
if response and "abstracts-retrieval-response" in response:
|
||||
# 从coredata中获取摘要
|
||||
coredata = response["abstracts-retrieval-response"].get("coredata", {})
|
||||
return coredata.get("dc:description", "")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取摘要时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def example_usage():
|
||||
"""ElsevierSource使用示例"""
|
||||
elsevier = ElsevierSource()
|
||||
|
||||
try:
|
||||
# 首先测试API访问权限
|
||||
print("\n=== 测试API访问权限 ===")
|
||||
await elsevier.test_api_access()
|
||||
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索机器学习相关论文 ===")
|
||||
papers = await elsevier.search("machine learning", limit=3)
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"URL: {paper.url}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
print("期刊信息:")
|
||||
for key, value in paper.venue_info.items():
|
||||
if value: # 只打印非空值
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
# 示例2:获取引用信息
|
||||
if papers and papers[0].doi:
|
||||
print("\n=== 示例2:获取引用该论文的文献 ===")
|
||||
citations = await elsevier.get_citations(papers[0].doi, limit=3)
|
||||
for i, paper in enumerate(citations, 1):
|
||||
print(f"\n--- 引用论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例3:获取参考文献
|
||||
if papers and papers[0].doi:
|
||||
print("\n=== 示例3:获取论文的参考文献 ===")
|
||||
references = await elsevier.get_references(papers[0].doi)
|
||||
for i, paper in enumerate(references[:3], 1):
|
||||
print(f"\n--- 参考文献 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例4:按作者搜索
|
||||
print("\n=== 示例4:按作者搜索 ===")
|
||||
author_papers = await elsevier.search_by_author("Hinton G", limit=3)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例5:按机构搜索
|
||||
print("\n=== 示例5:按机构搜索 ===")
|
||||
affiliation_papers = await elsevier.search_by_affiliation("60027950", limit=3) # MIT的机构ID
|
||||
for i, paper in enumerate(affiliation_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例6:获取论文摘要
|
||||
print("\n=== 示例6:获取论文摘要 ===")
|
||||
test_doi = "10.1016/j.artint.2021.103535"
|
||||
abstract = await elsevier.fetch_abstract(test_doi)
|
||||
if abstract:
|
||||
print(f"摘要: {abstract[:200]}...") # 只显示前200个字符
|
||||
else:
|
||||
print("无法获取摘要")
|
||||
|
||||
# 在搜索结果中显示摘要
|
||||
print("\n=== 示例7:搜索结果中的摘要 ===")
|
||||
papers = await elsevier.search("machine learning", limit=1)
|
||||
for paper in papers:
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"摘要: {paper.abstract[:200]}..." if paper.abstract else "摘要: 无")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(example_usage())
|
||||
698
crazy_functions/review_fns/data_sources/github_source.py
Normal file
698
crazy_functions/review_fns/data_sources/github_source.py
Normal file
@@ -0,0 +1,698 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional, Union, Any
|
||||
|
||||
class GitHubSource:
|
||||
"""GitHub API实现"""
|
||||
|
||||
# 默认API密钥列表 - 可以放置多个GitHub令牌
|
||||
API_KEYS = [
|
||||
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
]
|
||||
|
||||
def __init__(self, api_key: Optional[Union[str, List[str]]] = None):
|
||||
"""初始化GitHub API客户端
|
||||
|
||||
Args:
|
||||
api_key: GitHub个人访问令牌或令牌列表
|
||||
"""
|
||||
if api_key is None:
|
||||
self.api_keys = self.API_KEYS
|
||||
elif isinstance(api_key, str):
|
||||
self.api_keys = [api_key]
|
||||
else:
|
||||
self.api_keys = api_key
|
||||
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化客户端,设置默认参数"""
|
||||
self.base_url = "https://api.github.com"
|
||||
self.headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
"User-Agent": "GitHub-API-Python-Client"
|
||||
}
|
||||
|
||||
# 如果有可用的API密钥,随机选择一个
|
||||
if self.api_keys:
|
||||
selected_key = random.choice(self.api_keys)
|
||||
self.headers["Authorization"] = f"Bearer {selected_key}"
|
||||
print(f"已随机选择API密钥进行认证")
|
||||
else:
|
||||
print("警告: 未提供API密钥,将受到GitHub API请求限制")
|
||||
|
||||
async def _request(self, method: str, endpoint: str, params: Dict = None, data: Dict = None) -> Any:
|
||||
"""发送API请求
|
||||
|
||||
Args:
|
||||
method: HTTP方法 (GET, POST, PUT, DELETE等)
|
||||
endpoint: API端点
|
||||
params: URL参数
|
||||
data: 请求体数据
|
||||
|
||||
Returns:
|
||||
解析后的响应JSON
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
|
||||
# 为调试目的打印请求信息
|
||||
print(f"请求: {method} {url}")
|
||||
if params:
|
||||
print(f"参数: {params}")
|
||||
|
||||
# 发送请求
|
||||
request_kwargs = {}
|
||||
if params:
|
||||
request_kwargs["params"] = params
|
||||
if data:
|
||||
request_kwargs["json"] = data
|
||||
|
||||
async with session.request(method, url, **request_kwargs) as response:
|
||||
response_text = await response.text()
|
||||
|
||||
# 检查HTTP状态码
|
||||
if response.status >= 400:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
print(f"响应内容: {response_text}")
|
||||
return None
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except json.JSONDecodeError:
|
||||
print(f"JSON解析错误: {response_text}")
|
||||
return None
|
||||
|
||||
# ===== 用户相关方法 =====
|
||||
|
||||
async def get_user(self, username: Optional[str] = None) -> Dict:
|
||||
"""获取用户信息
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
|
||||
Returns:
|
||||
用户信息字典
|
||||
"""
|
||||
endpoint = "/user" if username is None else f"/users/{username}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_user_repos(self, username: Optional[str] = None, sort: str = "updated",
|
||||
direction: str = "desc", per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取用户的仓库列表
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
sort: 排序方式 (created, updated, pushed, full_name)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
endpoint = "/user/repos" if username is None else f"/users/{username}/repos"
|
||||
params = {
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_user_starred(self, username: Optional[str] = None,
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取用户星标的仓库
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
星标仓库列表
|
||||
"""
|
||||
endpoint = "/user/starred" if username is None else f"/users/{username}/starred"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 仓库相关方法 =====
|
||||
|
||||
async def get_repo(self, owner: str, repo: str) -> Dict:
|
||||
"""获取仓库信息
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
仓库信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repo_branches(self, owner: str, repo: str, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的分支列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
分支列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/branches"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_repo_commits(self, owner: str, repo: str, sha: Optional[str] = None,
|
||||
path: Optional[str] = None, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的提交历史
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
sha: 特定提交SHA或分支名
|
||||
path: 文件路径筛选
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
提交列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/commits"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
if sha:
|
||||
params["sha"] = sha
|
||||
if path:
|
||||
params["path"] = path
|
||||
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_commit_details(self, owner: str, repo: str, commit_sha: str) -> Dict:
|
||||
"""获取特定提交的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
commit_sha: 提交SHA
|
||||
|
||||
Returns:
|
||||
提交详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/commits/{commit_sha}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== 内容相关方法 =====
|
||||
|
||||
async def get_file_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> Dict:
|
||||
"""获取文件内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
path: 文件路径
|
||||
ref: 分支名、标签名或提交SHA
|
||||
|
||||
Returns:
|
||||
文件内容信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
|
||||
params = {}
|
||||
if ref:
|
||||
params["ref"] = ref
|
||||
|
||||
response = await self._request("GET", endpoint, params=params)
|
||||
if response and isinstance(response, dict) and "content" in response:
|
||||
try:
|
||||
# 解码Base64编码的文件内容
|
||||
content = base64.b64decode(response["content"].encode()).decode()
|
||||
response["decoded_content"] = content
|
||||
except Exception as e:
|
||||
print(f"解码文件内容时出错: {str(e)}")
|
||||
|
||||
return response
|
||||
|
||||
async def get_directory_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> List[Dict]:
|
||||
"""获取目录内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
path: 目录路径
|
||||
ref: 分支名、标签名或提交SHA
|
||||
|
||||
Returns:
|
||||
目录内容列表
|
||||
"""
|
||||
# 注意:此方法与get_file_content使用相同的端点,但对于目录会返回列表
|
||||
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
|
||||
params = {}
|
||||
if ref:
|
||||
params["ref"] = ref
|
||||
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== Issues相关方法 =====
|
||||
|
||||
async def get_issues(self, owner: str, repo: str, state: str = "open",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的Issues列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
state: Issue状态 (open, closed, all)
|
||||
sort: 排序方式 (created, updated, comments)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
Issues列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues"
|
||||
params = {
|
||||
"state": state,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_issue(self, owner: str, repo: str, issue_number: int) -> Dict:
|
||||
"""获取特定Issue的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
issue_number: Issue编号
|
||||
|
||||
Returns:
|
||||
Issue详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_issue_comments(self, owner: str, repo: str, issue_number: int) -> List[Dict]:
|
||||
"""获取Issue的评论
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
issue_number: Issue编号
|
||||
|
||||
Returns:
|
||||
评论列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== Pull Requests相关方法 =====
|
||||
|
||||
async def get_pull_requests(self, owner: str, repo: str, state: str = "open",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的Pull Request列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
state: PR状态 (open, closed, all)
|
||||
sort: 排序方式 (created, updated, popularity, long-running)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
Pull Request列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls"
|
||||
params = {
|
||||
"state": state,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_pull_request(self, owner: str, repo: str, pr_number: int) -> Dict:
|
||||
"""获取特定Pull Request的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
pr_number: Pull Request编号
|
||||
|
||||
Returns:
|
||||
Pull Request详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_pull_request_files(self, owner: str, repo: str, pr_number: int) -> List[Dict]:
|
||||
"""获取Pull Request中修改的文件
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
pr_number: Pull Request编号
|
||||
|
||||
Returns:
|
||||
修改文件列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/files"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== 搜索相关方法 =====
|
||||
|
||||
async def search_repositories(self, query: str, sort: str = "stars",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索仓库
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (stars, forks, updated)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/repositories"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_code(self, query: str, sort: str = "indexed",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索代码
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (indexed)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/code"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_issues(self, query: str, sort: str = "created",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索Issues和Pull Requests
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (created, updated, comments)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/issues"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_users(self, query: str, sort: str = "followers",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索用户
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (followers, repositories, joined)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/users"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 组织相关方法 =====
|
||||
|
||||
async def get_organization(self, org: str) -> Dict:
|
||||
"""获取组织信息
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
|
||||
Returns:
|
||||
组织信息
|
||||
"""
|
||||
endpoint = f"/orgs/{org}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_organization_repos(self, org: str, type: str = "all",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取组织的仓库列表
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
type: 仓库类型 (all, public, private, forks, sources, member, internal)
|
||||
sort: 排序方式 (created, updated, pushed, full_name)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
endpoint = f"/orgs/{org}/repos"
|
||||
params = {
|
||||
"type": type,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_organization_members(self, org: str, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取组织成员列表
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
成员列表
|
||||
"""
|
||||
endpoint = f"/orgs/{org}/members"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 更复杂的操作 =====
|
||||
|
||||
async def get_repository_languages(self, owner: str, repo: str) -> Dict:
|
||||
"""获取仓库使用的编程语言及其比例
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
语言使用情况
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/languages"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repository_stats_contributors(self, owner: str, repo: str) -> List[Dict]:
|
||||
"""获取仓库的贡献者统计
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
贡献者统计信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/stats/contributors"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repository_stats_commit_activity(self, owner: str, repo: str) -> List[Dict]:
|
||||
"""获取仓库的提交活动
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
提交活动统计
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/stats/commit_activity"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def example_usage():
|
||||
"""GitHubSource使用示例"""
|
||||
# 创建客户端实例(可选传入API令牌)
|
||||
# github = GitHubSource(api_key="your_github_token")
|
||||
github = GitHubSource()
|
||||
|
||||
try:
|
||||
# 示例1:搜索热门Python仓库
|
||||
print("\n=== 示例1:搜索热门Python仓库 ===")
|
||||
repos = await github.search_repositories(
|
||||
query="language:python stars:>1000",
|
||||
sort="stars",
|
||||
order="desc",
|
||||
per_page=5
|
||||
)
|
||||
|
||||
if repos and "items" in repos:
|
||||
for i, repo in enumerate(repos["items"], 1):
|
||||
print(f"\n--- 仓库 {i} ---")
|
||||
print(f"名称: {repo['full_name']}")
|
||||
print(f"描述: {repo['description']}")
|
||||
print(f"星标数: {repo['stargazers_count']}")
|
||||
print(f"Fork数: {repo['forks_count']}")
|
||||
print(f"最近更新: {repo['updated_at']}")
|
||||
print(f"URL: {repo['html_url']}")
|
||||
|
||||
# 示例2:获取特定仓库的详情
|
||||
print("\n=== 示例2:获取特定仓库的详情 ===")
|
||||
repo_details = await github.get_repo("microsoft", "vscode")
|
||||
if repo_details:
|
||||
print(f"名称: {repo_details['full_name']}")
|
||||
print(f"描述: {repo_details['description']}")
|
||||
print(f"星标数: {repo_details['stargazers_count']}")
|
||||
print(f"Fork数: {repo_details['forks_count']}")
|
||||
print(f"默认分支: {repo_details['default_branch']}")
|
||||
print(f"开源许可: {repo_details.get('license', {}).get('name', '无')}")
|
||||
print(f"语言: {repo_details['language']}")
|
||||
print(f"Open Issues数: {repo_details['open_issues_count']}")
|
||||
|
||||
# 示例3:获取仓库的提交历史
|
||||
print("\n=== 示例3:获取仓库的最近提交 ===")
|
||||
commits = await github.get_repo_commits("tensorflow", "tensorflow", per_page=5)
|
||||
if commits:
|
||||
for i, commit in enumerate(commits, 1):
|
||||
print(f"\n--- 提交 {i} ---")
|
||||
print(f"SHA: {commit['sha'][:7]}")
|
||||
print(f"作者: {commit['commit']['author']['name']}")
|
||||
print(f"日期: {commit['commit']['author']['date']}")
|
||||
print(f"消息: {commit['commit']['message'].splitlines()[0]}")
|
||||
|
||||
# 示例4:搜索代码
|
||||
print("\n=== 示例4:搜索代码 ===")
|
||||
code_results = await github.search_code(
|
||||
query="filename:README.md language:markdown pytorch in:file",
|
||||
per_page=3
|
||||
)
|
||||
if code_results and "items" in code_results:
|
||||
print(f"共找到: {code_results['total_count']} 个结果")
|
||||
for i, item in enumerate(code_results["items"], 1):
|
||||
print(f"\n--- 代码 {i} ---")
|
||||
print(f"仓库: {item['repository']['full_name']}")
|
||||
print(f"文件: {item['path']}")
|
||||
print(f"URL: {item['html_url']}")
|
||||
|
||||
# 示例5:获取文件内容
|
||||
print("\n=== 示例5:获取文件内容 ===")
|
||||
file_content = await github.get_file_content("python", "cpython", "README.rst")
|
||||
if file_content and "decoded_content" in file_content:
|
||||
content = file_content["decoded_content"]
|
||||
print(f"文件名: {file_content['name']}")
|
||||
print(f"大小: {file_content['size']} 字节")
|
||||
print(f"内容预览: {content[:200]}...")
|
||||
|
||||
# 示例6:获取仓库使用的编程语言
|
||||
print("\n=== 示例6:获取仓库使用的编程语言 ===")
|
||||
languages = await github.get_repository_languages("facebook", "react")
|
||||
if languages:
|
||||
print(f"React仓库使用的编程语言:")
|
||||
for lang, bytes_of_code in languages.items():
|
||||
print(f"- {lang}: {bytes_of_code} 字节")
|
||||
|
||||
# 示例7:获取组织信息
|
||||
print("\n=== 示例7:获取组织信息 ===")
|
||||
org_info = await github.get_organization("google")
|
||||
if org_info:
|
||||
print(f"名称: {org_info['name']}")
|
||||
print(f"描述: {org_info.get('description', '无')}")
|
||||
print(f"位置: {org_info.get('location', '未指定')}")
|
||||
print(f"公共仓库数: {org_info['public_repos']}")
|
||||
print(f"成员数: {org_info.get('public_members', 0)}")
|
||||
print(f"URL: {org_info['html_url']}")
|
||||
|
||||
# 示例8:获取用户信息
|
||||
print("\n=== 示例8:获取用户信息 ===")
|
||||
user_info = await github.get_user("torvalds")
|
||||
if user_info:
|
||||
print(f"名称: {user_info['name']}")
|
||||
print(f"公司: {user_info.get('company', '无')}")
|
||||
print(f"博客: {user_info.get('blog', '无')}")
|
||||
print(f"位置: {user_info.get('location', '未指定')}")
|
||||
print(f"公共仓库数: {user_info['public_repos']}")
|
||||
print(f"关注者数: {user_info['followers']}")
|
||||
print(f"URL: {user_info['html_url']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
142
crazy_functions/review_fns/data_sources/journal_metrics.py
Normal file
142
crazy_functions/review_fns/data_sources/journal_metrics.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
|
||||
class JournalMetrics:
|
||||
"""期刊指标管理类"""
|
||||
|
||||
def __init__(self):
|
||||
self.journal_data: Dict = {} # 期刊名称到指标的映射
|
||||
self.issn_map: Dict = {} # ISSN到指标的映射
|
||||
self.name_map: Dict = {} # 标准化名称到指标的映射
|
||||
self._load_journal_data()
|
||||
|
||||
def _normalize_journal_name(self, name: str) -> str:
|
||||
"""标准化期刊名称
|
||||
|
||||
Args:
|
||||
name: 原始期刊名称
|
||||
|
||||
Returns:
|
||||
标准化后的期刊名称
|
||||
"""
|
||||
if not name:
|
||||
return ""
|
||||
|
||||
# 转换为小写
|
||||
name = name.lower()
|
||||
|
||||
# 移除常见的前缀和后缀
|
||||
prefixes = ['the ', 'proceedings of ', 'journal of ']
|
||||
suffixes = [' journal', ' proceedings', ' magazine', ' review', ' letters']
|
||||
|
||||
for prefix in prefixes:
|
||||
if name.startswith(prefix):
|
||||
name = name[len(prefix):]
|
||||
|
||||
for suffix in suffixes:
|
||||
if name.endswith(suffix):
|
||||
name = name[:-len(suffix)]
|
||||
|
||||
# 移除特殊字符,保留字母、数字和空格
|
||||
name = ''.join(c for c in name if c.isalnum() or c.isspace())
|
||||
|
||||
# 移除多余的空格
|
||||
name = ' '.join(name.split())
|
||||
|
||||
return name
|
||||
|
||||
def _convert_if_value(self, if_str: str) -> Optional[float]:
|
||||
"""转换IF值为float,处理特殊情况"""
|
||||
try:
|
||||
if if_str.startswith('<'):
|
||||
# 对于<0.1这样的值,返回0.1
|
||||
return float(if_str.strip('<'))
|
||||
return float(if_str)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
def _load_journal_data(self):
|
||||
"""加载期刊数据"""
|
||||
try:
|
||||
file_path = os.path.join(os.path.dirname(__file__), 'cas_if.json')
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 建立期刊名称到指标的映射
|
||||
for journal in data:
|
||||
# 准备指标数据
|
||||
metrics = {
|
||||
'if_factor': self._convert_if_value(journal.get('IF')),
|
||||
'jcr_division': journal.get('Q'),
|
||||
'cas_division': journal.get('B')
|
||||
}
|
||||
|
||||
# 存储期刊名称映射(使用标准化名称)
|
||||
if journal.get('journal'):
|
||||
normalized_name = self._normalize_journal_name(journal['journal'])
|
||||
self.journal_data[normalized_name] = metrics
|
||||
self.name_map[normalized_name] = metrics
|
||||
|
||||
# 存储期刊缩写映射
|
||||
if journal.get('jabb'):
|
||||
normalized_abbr = self._normalize_journal_name(journal['jabb'])
|
||||
self.journal_data[normalized_abbr] = metrics
|
||||
self.name_map[normalized_abbr] = metrics
|
||||
|
||||
# 存储ISSN映射
|
||||
if journal.get('issn'):
|
||||
self.issn_map[journal['issn']] = metrics
|
||||
if journal.get('eissn'):
|
||||
self.issn_map[journal['eissn']] = metrics
|
||||
|
||||
except Exception as e:
|
||||
print(f"加载期刊数据时出错: {str(e)}")
|
||||
self.journal_data = {}
|
||||
self.issn_map = {}
|
||||
self.name_map = {}
|
||||
|
||||
def get_journal_metrics(self, venue_name: str, venue_info: dict) -> dict:
|
||||
"""获取期刊指标
|
||||
|
||||
Args:
|
||||
venue_name: 期刊名称
|
||||
venue_info: 期刊详细信息
|
||||
|
||||
Returns:
|
||||
包含期刊指标的字典
|
||||
"""
|
||||
try:
|
||||
metrics = {}
|
||||
|
||||
# 1. 首先尝试通过ISSN匹配
|
||||
if venue_info and 'issn' in venue_info:
|
||||
issn_value = venue_info['issn']
|
||||
# 处理ISSN可能是列表的情况
|
||||
if isinstance(issn_value, list):
|
||||
# 尝试每个ISSN
|
||||
for issn in issn_value:
|
||||
metrics = self.issn_map.get(issn, {})
|
||||
if metrics: # 如果找到匹配的指标,就停止搜索
|
||||
break
|
||||
else: # ISSN是字符串的情况
|
||||
metrics = self.issn_map.get(issn_value, {})
|
||||
|
||||
# 2. 如果ISSN匹配失败,尝试通过期刊名称匹配
|
||||
if not metrics and venue_name:
|
||||
# 标准化期刊名称
|
||||
normalized_name = self._normalize_journal_name(venue_name)
|
||||
metrics = self.name_map.get(normalized_name, {})
|
||||
|
||||
# 如果完全匹配失败,尝试部分匹配
|
||||
# if not metrics:
|
||||
# for db_name, db_metrics in self.name_map.items():
|
||||
# if normalized_name in db_name:
|
||||
# metrics = db_metrics
|
||||
# break
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取期刊指标时出错: {str(e)}")
|
||||
return {}
|
||||
163
crazy_functions/review_fns/data_sources/openalex_source.py
Normal file
163
crazy_functions/review_fns/data_sources/openalex_source.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import aiohttp
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
import os
|
||||
from urllib.parse import quote
|
||||
|
||||
class OpenAlexSource(DataSource):
|
||||
"""OpenAlex API实现"""
|
||||
|
||||
def _initialize(self) -> None:
|
||||
self.base_url = "https://api.openalex.org"
|
||||
self.mailto = "xxxxxxxxxxxxxxxxxxxxxxxx@163.com" # 直接写入邮件地址
|
||||
|
||||
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
params.update({
|
||||
"filter": f"title.search:{query}",
|
||||
"per-page": limit
|
||||
})
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params=params
|
||||
) as response:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
results = data.get("results", [])
|
||||
return [self._parse_work(work) for work in results]
|
||||
except Exception as e:
|
||||
print(f"搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_work(self, work: Dict) -> PaperMetadata:
|
||||
"""解析OpenAlex返回的数据"""
|
||||
# 获取作者信息
|
||||
raw_author_names = [
|
||||
authorship.get("raw_author_name", "")
|
||||
for authorship in work.get("authorships", [])
|
||||
if authorship
|
||||
]
|
||||
# 处理作者名字格式
|
||||
authors = [
|
||||
self._reformat_name(author)
|
||||
for author in raw_author_names
|
||||
]
|
||||
|
||||
# 获取机构信息
|
||||
institutions = [
|
||||
inst.get("display_name", "")
|
||||
for authorship in work.get("authorships", [])
|
||||
for inst in authorship.get("institutions", [])
|
||||
if inst
|
||||
]
|
||||
|
||||
# 获取主要发表位置信息
|
||||
primary_location = work.get("primary_location") or {}
|
||||
source = primary_location.get("source") or {}
|
||||
venue = source.get("display_name")
|
||||
|
||||
# 获取发表日期
|
||||
year = work.get("publication_year")
|
||||
|
||||
return PaperMetadata(
|
||||
title=work.get("title", ""),
|
||||
authors=authors,
|
||||
institutions=institutions,
|
||||
abstract=work.get("abstract", ""),
|
||||
year=year,
|
||||
doi=work.get("doi"),
|
||||
url=work.get("doi"), # OpenAlex 使用 DOI 作为 URL
|
||||
citations=work.get("cited_by_count"),
|
||||
venue=venue
|
||||
)
|
||||
|
||||
def _reformat_name(self, name: str) -> str:
|
||||
"""重新格式化作者名字"""
|
||||
if "," not in name:
|
||||
return name
|
||||
family, given_names = (x.strip() for x in name.split(",", maxsplit=1))
|
||||
return f"{given_names} {family}"
|
||||
|
||||
async def get_paper_details(self, doi: str) -> PaperMetadata:
|
||||
"""获取指定DOI的论文详情"""
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}",
|
||||
params=params
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return self._parse_work(data)
|
||||
|
||||
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取指定DOI论文的参考文献列表"""
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}/references",
|
||||
params=params
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return [self._parse_work(work) for work in data.get("results", [])]
|
||||
|
||||
async def get_citations(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取引用指定DOI论文的文献列表"""
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
params.update({
|
||||
"filter": f"cites:doi:{doi}",
|
||||
"per-page": 100
|
||||
})
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params=params
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return [self._parse_work(work) for work in data.get("results", [])]
|
||||
|
||||
async def example_usage():
|
||||
"""OpenAlexSource使用示例"""
|
||||
# 初始化OpenAlexSource
|
||||
openalex = OpenAlexSource()
|
||||
|
||||
try:
|
||||
print("正在搜索论文...")
|
||||
# 搜索与"artificial intelligence"相关的论文,限制返回5篇
|
||||
papers = await openalex.search(query="artificial intelligence", limit=5)
|
||||
|
||||
if not papers:
|
||||
print("未获取到任何论文信息")
|
||||
return
|
||||
|
||||
print(f"找到 {len(papers)} 篇论文")
|
||||
|
||||
# 打印搜索结果
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors) if paper.authors else '未知'}")
|
||||
if paper.institutions:
|
||||
print(f"机构: {', '.join(paper.institutions)}")
|
||||
print(f"发表年份: {paper.year if paper.year else '未知'}")
|
||||
print(f"DOI: {paper.doi if paper.doi else '未知'}")
|
||||
print(f"URL: {paper.url if paper.url else '未知'}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
print(f"引用次数: {paper.citations if paper.citations is not None else '未知'}")
|
||||
print(f"发表venue: {paper.venue if paper.venue else '未知'}")
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
# 如果直接运行此文件,执行示例代码
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
458
crazy_functions/review_fns/data_sources/pubmed_source.py
Normal file
458
crazy_functions/review_fns/data_sources/pubmed_source.py
Normal file
@@ -0,0 +1,458 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import xml.etree.ElementTree as ET
|
||||
from urllib.parse import quote
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
class PubMedSource(DataSource):
|
||||
"""PubMed API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: PubMed API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
|
||||
self.headers = {
|
||||
"User-Agent": "Mozilla/5.0 PubMedDataSource/1.0",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str) -> Optional[str]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
return await response.text()
|
||||
else:
|
||||
print(f"请求失败: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = "relevance",
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||
start_year: 起始年份
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
"""
|
||||
try:
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
query = f"{query} AND {start_year}:3000[dp]"
|
||||
|
||||
# 构建搜索URL
|
||||
search_url = (
|
||||
f"{self.base_url}/esearch.fcgi?"
|
||||
f"db=pubmed&term={quote(query)}&retmax={limit}"
|
||||
f"&usehistory=y&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
if sort_by == "date":
|
||||
search_url += "&sort=date"
|
||||
|
||||
# 获取搜索结果
|
||||
response = await self._make_request(search_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
id_list = root.findall(".//Id")
|
||||
pmids = [id_elem.text for id_elem in id_list]
|
||||
|
||||
if not pmids:
|
||||
return []
|
||||
|
||||
# 批量获取论文详情
|
||||
papers = []
|
||||
batch_size = 50
|
||||
for i in range(0, len(pmids), batch_size):
|
||||
batch = pmids[i:i + batch_size]
|
||||
batch_papers = await self._fetch_papers_batch(batch)
|
||||
papers.extend(batch_papers)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _fetch_papers_batch(self, pmids: List[str]) -> List[PaperMetadata]:
|
||||
"""批量获取论文详情
|
||||
|
||||
Args:
|
||||
pmids: PubMed ID列表
|
||||
|
||||
Returns:
|
||||
论文详情列表
|
||||
"""
|
||||
try:
|
||||
# 构建批量获取URL
|
||||
fetch_url = (
|
||||
f"{self.base_url}/efetch.fcgi?"
|
||||
f"db=pubmed&id={','.join(pmids)}"
|
||||
f"&retmode=xml&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
response = await self._make_request(fetch_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
articles = root.findall(".//PubmedArticle")
|
||||
|
||||
return [self._parse_article(article) for article in articles]
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取论文批次时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_article(self, article: ET.Element) -> PaperMetadata:
|
||||
"""解析PubMed文章XML
|
||||
|
||||
Args:
|
||||
article: XML元素
|
||||
|
||||
Returns:
|
||||
解析后的论文数据
|
||||
"""
|
||||
try:
|
||||
# 提取基本信息
|
||||
pmid = article.find(".//PMID").text
|
||||
article_meta = article.find(".//Article")
|
||||
|
||||
# 获取标题
|
||||
title = article_meta.find(".//ArticleTitle")
|
||||
title = title.text if title is not None else ""
|
||||
|
||||
# 获取作者列表
|
||||
authors = []
|
||||
author_list = article_meta.findall(".//Author")
|
||||
for author in author_list:
|
||||
last_name = author.find("LastName")
|
||||
fore_name = author.find("ForeName")
|
||||
if last_name is not None and fore_name is not None:
|
||||
authors.append(f"{fore_name.text} {last_name.text}")
|
||||
elif last_name is not None:
|
||||
authors.append(last_name.text)
|
||||
|
||||
# 获取摘要
|
||||
abstract = article_meta.find(".//Abstract/AbstractText")
|
||||
abstract = abstract.text if abstract is not None else ""
|
||||
|
||||
# 获取发表年份
|
||||
pub_date = article_meta.find(".//PubDate/Year")
|
||||
year = int(pub_date.text) if pub_date is not None else None
|
||||
|
||||
# 获取DOI
|
||||
doi = article.find(".//ELocationID[@EIdType='doi']")
|
||||
doi = doi.text if doi is not None else None
|
||||
|
||||
# 获取期刊信息
|
||||
journal = article_meta.find(".//Journal")
|
||||
if journal is not None:
|
||||
journal_title = journal.find(".//Title")
|
||||
venue = journal_title.text if journal_title is not None else None
|
||||
|
||||
# 获取期刊详细信息
|
||||
venue_info = {
|
||||
'issn': journal.findtext(".//ISSN"),
|
||||
'volume': journal.findtext(".//Volume"),
|
||||
'issue': journal.findtext(".//Issue"),
|
||||
'pub_date': journal.findtext(".//PubDate/MedlineDate") or
|
||||
f"{journal.findtext('.//PubDate/Year', '')}-{journal.findtext('.//PubDate/Month', '')}"
|
||||
}
|
||||
else:
|
||||
venue = None
|
||||
venue_info = {}
|
||||
|
||||
# 获取机构信息
|
||||
institutions = []
|
||||
affiliations = article_meta.findall(".//Affiliation")
|
||||
for affiliation in affiliations:
|
||||
if affiliation is not None and affiliation.text:
|
||||
institutions.append(affiliation.text)
|
||||
|
||||
return PaperMetadata(
|
||||
title=title,
|
||||
authors=authors,
|
||||
abstract=abstract,
|
||||
year=year,
|
||||
doi=doi,
|
||||
url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" if pmid else None,
|
||||
citations=None, # PubMed API不直接提供引用数据
|
||||
venue=venue,
|
||||
institutions=institutions,
|
||||
venue_type="journal",
|
||||
venue_name=venue,
|
||||
venue_info=venue_info,
|
||||
source='pubmed' # 添加来源标记
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析文章时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_paper_details(self, pmid: str) -> Optional[PaperMetadata]:
|
||||
"""获取指定PMID的论文详情"""
|
||||
papers = await self._fetch_papers_batch([pmid])
|
||||
return papers[0] if papers else None
|
||||
|
||||
async def get_related_papers(self, pmid: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取相关论文
|
||||
|
||||
使用PubMed的相关文章功能
|
||||
|
||||
Args:
|
||||
pmid: PubMed ID
|
||||
limit: 返回结果数量限制
|
||||
|
||||
Returns:
|
||||
相关论文列表
|
||||
"""
|
||||
try:
|
||||
# 构建相关文章URL
|
||||
link_url = (
|
||||
f"{self.base_url}/elink.fcgi?"
|
||||
f"db=pubmed&id={pmid}&cmd=neighbor&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
response = await self._make_request(link_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
related_ids = root.findall(".//Link/Id")
|
||||
pmids = [id_elem.text for id_elem in related_ids][:limit]
|
||||
|
||||
if not pmids:
|
||||
return []
|
||||
|
||||
# 获取相关论文详情
|
||||
return await self._fetch_papers_batch(pmids)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取相关论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"{author}[Author]"
|
||||
if start_year:
|
||||
query += f" AND {start_year}:3000[dp]"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def search_by_journal(
|
||||
self,
|
||||
journal: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊搜索论文"""
|
||||
query = f"{journal}[Journal]"
|
||||
if start_year:
|
||||
query += f" AND {start_year}:3000[dp]"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
days: int = 7,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取最新论文
|
||||
|
||||
Args:
|
||||
days: 最近几天的论文
|
||||
limit: 返回结果数量限制
|
||||
|
||||
Returns:
|
||||
最新论文列表
|
||||
"""
|
||||
query = f"last {days} days[dp]"
|
||||
return await self.search(query, limit=limit, sort_by="date")
|
||||
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献
|
||||
|
||||
注意:PubMed API本身不提供引用数据,此方法将返回空列表
|
||||
未来可以考虑集成其他数据源(如CrossRef)来获取引用信息
|
||||
|
||||
Args:
|
||||
paper_id: PubMed ID
|
||||
|
||||
Returns:
|
||||
空列表,因为PubMed不提供引用数据
|
||||
"""
|
||||
return []
|
||||
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献
|
||||
|
||||
从PubMed文章的参考文献列表获取引用的文献
|
||||
|
||||
Args:
|
||||
paper_id: PubMed ID
|
||||
|
||||
Returns:
|
||||
引用的文献列表
|
||||
"""
|
||||
try:
|
||||
# 构建获取参考文献的URL
|
||||
refs_url = (
|
||||
f"{self.base_url}/elink.fcgi?"
|
||||
f"dbfrom=pubmed&db=pubmed&id={paper_id}"
|
||||
f"&cmd=neighbor_history&linkname=pubmed_pubmed_refs"
|
||||
f"&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
response = await self._make_request(refs_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
ref_ids = root.findall(".//Link/Id")
|
||||
pmids = [id_elem.text for id_elem in ref_ids]
|
||||
|
||||
if not pmids:
|
||||
return []
|
||||
|
||||
# 获取参考文献详情
|
||||
return await self._fetch_papers_batch(pmids)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取参考文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def example_usage():
|
||||
"""PubMedSource使用示例"""
|
||||
pubmed = PubMedSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索COVID-19相关论文 ===")
|
||||
papers = await pubmed.search("COVID-19", limit=3)
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
|
||||
# 示例2:获取论文详情
|
||||
if papers:
|
||||
print("\n=== 示例2:获取论文详情 ===")
|
||||
paper_id = papers[0].url.split("/")[-2]
|
||||
paper = await pubmed.get_paper_details(paper_id)
|
||||
if paper:
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"期刊: {paper.venue}")
|
||||
print(f"机构: {', '.join(paper.institutions)}")
|
||||
|
||||
# 示例3:获取相关论文
|
||||
if papers:
|
||||
print("\n=== 示例3:获取相关论文 ===")
|
||||
related = await pubmed.get_related_papers(paper_id, limit=3)
|
||||
for i, paper in enumerate(related, 1):
|
||||
print(f"\n--- 相关论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
|
||||
# 示例4:按作者搜索
|
||||
print("\n=== 示例4:按作者搜索 ===")
|
||||
author_papers = await pubmed.search_by_author("Fauci AS", limit=3)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例5:按期刊搜索
|
||||
print("\n=== 示例5:按期刊搜索 ===")
|
||||
journal_papers = await pubmed.search_by_journal("Nature", limit=3)
|
||||
for i, paper in enumerate(journal_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例6:获取最新论文
|
||||
print("\n=== 示例6:获取最新论文 ===")
|
||||
latest = await pubmed.get_latest_papers(days=7, limit=3)
|
||||
for i, paper in enumerate(latest, 1):
|
||||
print(f"\n--- 最新论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表日期: {paper.venue_info.get('pub_date')}")
|
||||
|
||||
# 示例7:获取论文的参考文献
|
||||
if papers:
|
||||
print("\n=== 示例7:获取论文的参考文献 ===")
|
||||
paper_id = papers[0].url.split("/")[-2]
|
||||
references = await pubmed.get_references(paper_id)
|
||||
for i, paper in enumerate(references[:3], 1):
|
||||
print(f"\n--- 参考文献 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例8:尝试获取引用信息(将返回空列表)
|
||||
if papers:
|
||||
print("\n=== 示例8:获取论文的引用信息 ===")
|
||||
paper_id = papers[0].url.split("/")[-2]
|
||||
citations = await pubmed.get_citations(paper_id)
|
||||
print(f"引用数据:{len(citations)} (PubMed API不提供引用信息)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(example_usage())
|
||||
326
crazy_functions/review_fns/data_sources/scihub_source.py
Normal file
326
crazy_functions/review_fns/data_sources/scihub_source.py
Normal file
@@ -0,0 +1,326 @@
|
||||
from pathlib import Path
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
import time
|
||||
from loguru import logger
|
||||
import PyPDF2
|
||||
import io
|
||||
|
||||
|
||||
class SciHub:
|
||||
# 更新的镜像列表,包含更多可用的镜像
|
||||
MIRRORS = [
|
||||
'https://sci-hub.se/',
|
||||
'https://sci-hub.st/',
|
||||
'https://sci-hub.ru/',
|
||||
'https://sci-hub.wf/',
|
||||
'https://sci-hub.ee/',
|
||||
'https://sci-hub.ren/',
|
||||
'https://sci-hub.tf/',
|
||||
'https://sci-hub.si/',
|
||||
'https://sci-hub.do/',
|
||||
'https://sci-hub.hkvisa.net/',
|
||||
'https://sci-hub.mksa.top/',
|
||||
'https://sci-hub.shop/',
|
||||
'https://sci-hub.yncjkj.com/',
|
||||
'https://sci-hub.41610.org/',
|
||||
'https://sci-hub.automic.us/',
|
||||
'https://sci-hub.et-fine.com/',
|
||||
'https://sci-hub.pooh.mu/',
|
||||
'https://sci-hub.bban.top/',
|
||||
'https://sci-hub.usualwant.com/',
|
||||
'https://sci-hub.unblockit.kim/'
|
||||
]
|
||||
|
||||
def __init__(self, doi: str, path: Path, url=None, timeout=60, use_proxy=True):
|
||||
self.timeout = timeout
|
||||
self.path = path
|
||||
self.doi = str(doi)
|
||||
self.use_proxy = use_proxy
|
||||
self.headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
|
||||
}
|
||||
self.payload = {
|
||||
'sci-hub-plugin-check': '',
|
||||
'request': self.doi
|
||||
}
|
||||
self.url = url if url else self.MIRRORS[0]
|
||||
self.proxies = {
|
||||
"http": "socks5h://localhost:10880",
|
||||
"https": "socks5h://localhost:10880",
|
||||
} if use_proxy else None
|
||||
|
||||
def _test_proxy_connection(self):
|
||||
"""测试代理连接是否可用"""
|
||||
if not self.use_proxy:
|
||||
return True
|
||||
|
||||
try:
|
||||
# 测试代理连接
|
||||
test_response = requests.get(
|
||||
'https://httpbin.org/ip',
|
||||
proxies=self.proxies,
|
||||
timeout=10
|
||||
)
|
||||
if test_response.status_code == 200:
|
||||
logger.info("代理连接测试成功")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"代理连接测试失败: {str(e)}")
|
||||
return False
|
||||
return False
|
||||
|
||||
def _check_pdf_validity(self, content):
|
||||
"""检查PDF文件是否有效"""
|
||||
try:
|
||||
# 使用PyPDF2检查PDF是否可以正常打开和读取
|
||||
pdf = PyPDF2.PdfReader(io.BytesIO(content))
|
||||
if len(pdf.pages) > 0:
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"PDF文件无效: {str(e)}")
|
||||
return False
|
||||
|
||||
def _send_request(self):
|
||||
"""发送请求到Sci-Hub镜像站点"""
|
||||
# 首先测试代理连接
|
||||
if self.use_proxy and not self._test_proxy_connection():
|
||||
logger.warning("代理连接不可用,切换到直连模式")
|
||||
self.use_proxy = False
|
||||
self.proxies = None
|
||||
|
||||
last_exception = None
|
||||
working_mirrors = []
|
||||
|
||||
# 先测试哪些镜像可用
|
||||
logger.info("正在测试镜像站点可用性...")
|
||||
for mirror in self.MIRRORS:
|
||||
try:
|
||||
test_response = requests.get(
|
||||
mirror,
|
||||
headers=self.headers,
|
||||
proxies=self.proxies,
|
||||
timeout=10
|
||||
)
|
||||
if test_response.status_code == 200:
|
||||
working_mirrors.append(mirror)
|
||||
logger.info(f"镜像 {mirror} 可用")
|
||||
if len(working_mirrors) >= 5: # 找到5个可用镜像就够了
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f"镜像 {mirror} 不可用: {str(e)}")
|
||||
continue
|
||||
|
||||
if not working_mirrors:
|
||||
raise Exception("没有找到可用的镜像站点")
|
||||
|
||||
logger.info(f"找到 {len(working_mirrors)} 个可用镜像,开始尝试下载...")
|
||||
|
||||
# 使用可用的镜像进行下载
|
||||
for mirror in working_mirrors:
|
||||
try:
|
||||
res = requests.post(
|
||||
mirror,
|
||||
headers=self.headers,
|
||||
data=self.payload,
|
||||
proxies=self.proxies,
|
||||
timeout=self.timeout
|
||||
)
|
||||
if res.ok:
|
||||
logger.info(f"成功使用镜像站点: {mirror}")
|
||||
self.url = mirror # 更新当前使用的镜像
|
||||
time.sleep(1) # 降低等待时间以提高效率
|
||||
return res
|
||||
except Exception as e:
|
||||
logger.error(f"尝试镜像 {mirror} 失败: {str(e)}")
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
raise Exception("所有可用镜像站点均无法完成下载")
|
||||
|
||||
def _extract_url(self, response):
|
||||
"""从响应中提取PDF下载链接"""
|
||||
soup = BeautifulSoup(response.content, 'html.parser')
|
||||
try:
|
||||
# 尝试多种方式提取PDF链接
|
||||
pdf_element = soup.find(id='pdf')
|
||||
if pdf_element:
|
||||
content_url = pdf_element.get('src')
|
||||
else:
|
||||
# 尝试其他可能的选择器
|
||||
pdf_element = soup.find('iframe')
|
||||
if pdf_element:
|
||||
content_url = pdf_element.get('src')
|
||||
else:
|
||||
# 查找直接的PDF链接
|
||||
pdf_links = soup.find_all('a', href=lambda x: x and '.pdf' in x)
|
||||
if pdf_links:
|
||||
content_url = pdf_links[0].get('href')
|
||||
else:
|
||||
raise AttributeError("未找到PDF链接")
|
||||
|
||||
if content_url:
|
||||
content_url = content_url.replace('#navpanes=0&view=FitH', '').replace('//', '/')
|
||||
if not content_url.endswith('.pdf') and 'pdf' not in content_url.lower():
|
||||
raise AttributeError("找到的链接不是PDF文件")
|
||||
except AttributeError:
|
||||
logger.error(f"未找到论文 {self.doi}")
|
||||
return None
|
||||
|
||||
current_mirror = self.url.rstrip('/')
|
||||
if content_url.startswith('/'):
|
||||
return current_mirror + content_url
|
||||
elif content_url.startswith('http'):
|
||||
return content_url
|
||||
else:
|
||||
return 'https:/' + content_url
|
||||
|
||||
def _download_pdf(self, pdf_url):
|
||||
"""下载PDF文件并验证其完整性"""
|
||||
try:
|
||||
# 尝试不同的下载方式
|
||||
download_methods = [
|
||||
# 方法1:直接下载
|
||||
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout),
|
||||
# 方法2:添加 Referer 头
|
||||
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
|
||||
headers={**self.headers, 'Referer': self.url}),
|
||||
# 方法3:使用原始域名作为 Referer
|
||||
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
|
||||
headers={**self.headers, 'Referer': pdf_url.split('/downloads')[0] if '/downloads' in pdf_url else self.url})
|
||||
]
|
||||
|
||||
for i, download_method in enumerate(download_methods):
|
||||
try:
|
||||
logger.info(f"尝试下载方式 {i+1}/3...")
|
||||
response = download_method()
|
||||
if response.status_code == 200:
|
||||
content = response.content
|
||||
if len(content) > 1000 and self._check_pdf_validity(content): # 确保文件不是太小
|
||||
logger.info(f"PDF下载成功,文件大小: {len(content)} bytes")
|
||||
return content
|
||||
else:
|
||||
logger.warning("下载的文件可能不是有效的PDF")
|
||||
elif response.status_code == 403:
|
||||
logger.warning(f"访问被拒绝 (403 Forbidden),尝试其他下载方式")
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"下载失败,状态码: {response.status_code}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"下载方式 {i+1} 失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 如果所有方法都失败,尝试构造替代URL
|
||||
try:
|
||||
logger.info("尝试使用替代镜像下载...")
|
||||
# 从原始URL提取关键信息
|
||||
if '/downloads/' in pdf_url:
|
||||
file_part = pdf_url.split('/downloads/')[-1]
|
||||
alternative_mirrors = [
|
||||
f"https://sci-hub.se/downloads/{file_part}",
|
||||
f"https://sci-hub.st/downloads/{file_part}",
|
||||
f"https://sci-hub.ru/downloads/{file_part}",
|
||||
f"https://sci-hub.wf/downloads/{file_part}",
|
||||
f"https://sci-hub.ee/downloads/{file_part}",
|
||||
f"https://sci-hub.ren/downloads/{file_part}",
|
||||
f"https://sci-hub.tf/downloads/{file_part}"
|
||||
]
|
||||
|
||||
for alt_url in alternative_mirrors:
|
||||
try:
|
||||
response = requests.get(
|
||||
alt_url,
|
||||
proxies=self.proxies,
|
||||
timeout=self.timeout,
|
||||
headers={**self.headers, 'Referer': alt_url.split('/downloads')[0]}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
content = response.content
|
||||
if len(content) > 1000 and self._check_pdf_validity(content):
|
||||
logger.info(f"使用替代镜像成功下载: {alt_url}")
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug(f"替代镜像 {alt_url} 下载失败: {str(e)}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"所有下载方式都失败: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载PDF文件失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def fetch(self):
|
||||
"""获取论文PDF,包含重试和验证机制"""
|
||||
for attempt in range(2): # 最多重试3次
|
||||
try:
|
||||
logger.info(f"开始第 {attempt + 1} 次尝试下载论文: {self.doi}")
|
||||
|
||||
# 获取PDF下载链接
|
||||
response = self._send_request()
|
||||
pdf_url = self._extract_url(response)
|
||||
if pdf_url is None:
|
||||
logger.warning(f"第 {attempt + 1} 次尝试:未找到PDF下载链接")
|
||||
continue
|
||||
|
||||
logger.info(f"找到PDF下载链接: {pdf_url}")
|
||||
|
||||
# 下载并验证PDF
|
||||
pdf_content = self._download_pdf(pdf_url)
|
||||
if pdf_content is None:
|
||||
logger.warning(f"第 {attempt + 1} 次尝试:PDF下载失败")
|
||||
continue
|
||||
|
||||
# 保存PDF文件
|
||||
pdf_name = f"{self.doi.replace('/', '_').replace(':', '_')}.pdf"
|
||||
pdf_path = self.path.joinpath(pdf_name)
|
||||
pdf_path.write_bytes(pdf_content)
|
||||
|
||||
logger.info(f"成功下载论文: {pdf_name},文件大小: {len(pdf_content)} bytes")
|
||||
return str(pdf_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"第 {attempt + 1} 次尝试失败: {str(e)}")
|
||||
if attempt < 2: # 不是最后一次尝试
|
||||
wait_time = (attempt + 1) * 3 # 递增等待时间
|
||||
logger.info(f"等待 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
raise Exception(f"无法下载论文 {self.doi},所有重试都失败了")
|
||||
|
||||
# Usage Example
|
||||
if __name__ == '__main__':
|
||||
# 创建一个用于保存PDF的目录
|
||||
save_path = Path('./downloaded_papers')
|
||||
save_path.mkdir(exist_ok=True)
|
||||
|
||||
# DOI示例
|
||||
sample_doi = '10.3897/rio.7.e67379' # 这是一篇Nature的论文DOI
|
||||
|
||||
try:
|
||||
# 初始化SciHub下载器,先尝试使用代理
|
||||
logger.info("尝试使用代理模式...")
|
||||
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=True)
|
||||
|
||||
# 开始下载
|
||||
result = downloader.fetch()
|
||||
print(f"论文已保存到: {result}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"使用代理模式失败: {str(e)}")
|
||||
try:
|
||||
# 如果代理模式失败,尝试直连模式
|
||||
logger.info("尝试直连模式...")
|
||||
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=False)
|
||||
result = downloader.fetch()
|
||||
print(f"论文已保存到: {result}")
|
||||
except Exception as e2:
|
||||
print(f"直连模式也失败: {str(e2)}")
|
||||
print("建议检查网络连接或尝试其他DOI")
|
||||
400
crazy_functions/review_fns/data_sources/scopus_source.py
Normal file
400
crazy_functions/review_fns/data_sources/scopus_source.py
Normal file
@@ -0,0 +1,400 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import random
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
from tqdm import tqdm
|
||||
|
||||
class ScopusSource(DataSource):
|
||||
"""Scopus API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: Scopus API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS)
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://api.elsevier.com/content"
|
||||
self.headers = {
|
||||
"X-ELS-APIKey": self.api_key,
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
响应JSON数据
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
print(f"请求失败: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _parse_paper_data(self, data: Dict) -> PaperMetadata:
|
||||
"""解析Scopus API返回的数据
|
||||
|
||||
Args:
|
||||
data: Scopus API返回的论文数据
|
||||
|
||||
Returns:
|
||||
解析后的论文元数据
|
||||
"""
|
||||
try:
|
||||
# 提取基本信息
|
||||
title = data.get("dc:title", "")
|
||||
|
||||
# 提取作者信息
|
||||
authors = []
|
||||
if "author" in data:
|
||||
if isinstance(data["author"], list):
|
||||
for author in data["author"]:
|
||||
if "given-name" in author and "surname" in author:
|
||||
authors.append(f"{author['given-name']} {author['surname']}")
|
||||
elif "indexed-name" in author:
|
||||
authors.append(author["indexed-name"])
|
||||
elif isinstance(data["author"], dict):
|
||||
if "given-name" in data["author"] and "surname" in data["author"]:
|
||||
authors.append(f"{data['author']['given-name']} {data['author']['surname']}")
|
||||
elif "indexed-name" in data["author"]:
|
||||
authors.append(data["author"]["indexed-name"])
|
||||
|
||||
# 提取摘要
|
||||
abstract = data.get("dc:description", "")
|
||||
|
||||
# 提取年份
|
||||
year = None
|
||||
if "prism:coverDate" in data:
|
||||
try:
|
||||
year = int(data["prism:coverDate"][:4])
|
||||
except:
|
||||
pass
|
||||
|
||||
# 提取DOI
|
||||
doi = data.get("prism:doi")
|
||||
|
||||
# 提取引用次数
|
||||
citations = data.get("citedby-count")
|
||||
if citations:
|
||||
try:
|
||||
citations = int(citations)
|
||||
except:
|
||||
citations = None
|
||||
|
||||
# 提取期刊信息
|
||||
venue = data.get("prism:publicationName")
|
||||
|
||||
# 提取机构信息
|
||||
institutions = []
|
||||
if "affiliation" in data:
|
||||
if isinstance(data["affiliation"], list):
|
||||
for aff in data["affiliation"]:
|
||||
if "affilname" in aff:
|
||||
institutions.append(aff["affilname"])
|
||||
elif isinstance(data["affiliation"], dict):
|
||||
if "affilname" in data["affiliation"]:
|
||||
institutions.append(data["affiliation"]["affilname"])
|
||||
|
||||
# 构建venue信息
|
||||
venue_info = {
|
||||
"issn": data.get("prism:issn"),
|
||||
"eissn": data.get("prism:eIssn"),
|
||||
"volume": data.get("prism:volume"),
|
||||
"issue": data.get("prism:issueIdentifier"),
|
||||
"page_range": data.get("prism:pageRange"),
|
||||
"article_number": data.get("article-number"),
|
||||
"publication_date": data.get("prism:coverDate")
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=title,
|
||||
authors=authors,
|
||||
abstract=abstract,
|
||||
year=year,
|
||||
doi=doi,
|
||||
url=data.get("link", [{}])[0].get("@href"),
|
||||
citations=citations,
|
||||
venue=venue,
|
||||
institutions=institutions,
|
||||
venue_type="journal",
|
||||
venue_name=venue,
|
||||
venue_info=venue_info
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析论文数据时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||
start_year: 起始年份
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询参数
|
||||
params = {
|
||||
"query": query,
|
||||
"count": min(limit, 100), # Scopus API单次请求限制
|
||||
"start": 0
|
||||
}
|
||||
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
params["date"] = f"{start_year}-present"
|
||||
|
||||
# 添加排序
|
||||
if sort_by:
|
||||
sort_map = {
|
||||
"relevance": "-score",
|
||||
"date": "-coverDate",
|
||||
"citations": "-citedby-count"
|
||||
}
|
||||
if sort_by in sort_map:
|
||||
params["sort"] = sort_map[sort_by]
|
||||
|
||||
# 发送请求
|
||||
url = f"{self.base_url}/search/scopus"
|
||||
response = await self._make_request(url, params)
|
||||
|
||||
if not response or "search-results" not in response:
|
||||
return []
|
||||
|
||||
# 解析结果
|
||||
results = response["search-results"].get("entry", [])
|
||||
papers = []
|
||||
|
||||
for result in results:
|
||||
paper = self._parse_paper_data(result)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||
"""获取论文详情
|
||||
|
||||
Args:
|
||||
paper_id: Scopus ID或DOI
|
||||
|
||||
Returns:
|
||||
论文详情
|
||||
"""
|
||||
try:
|
||||
# 判断是否为DOI
|
||||
if "/" in paper_id:
|
||||
url = f"{self.base_url}/article/doi/{paper_id}"
|
||||
else:
|
||||
url = f"{self.base_url}/abstract/scopus_id/{paper_id}"
|
||||
|
||||
response = await self._make_request(url)
|
||||
|
||||
if not response or "abstracts-retrieval-response" not in response:
|
||||
return None
|
||||
|
||||
data = response["abstracts-retrieval-response"]
|
||||
return self._parse_paper_data(data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献
|
||||
|
||||
Args:
|
||||
paper_id: Scopus ID
|
||||
|
||||
Returns:
|
||||
引用论文列表
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}/abstract/citations/{paper_id}"
|
||||
response = await self._make_request(url)
|
||||
|
||||
if not response or "citing-papers" not in response:
|
||||
return []
|
||||
|
||||
results = response["citing-papers"].get("papers", [])
|
||||
papers = []
|
||||
|
||||
for result in results:
|
||||
paper = self._parse_paper_data(result)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取引用信息时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献
|
||||
|
||||
Args:
|
||||
paper_id: Scopus ID
|
||||
|
||||
Returns:
|
||||
参考文献列表
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}/abstract/references/{paper_id}"
|
||||
response = await self._make_request(url)
|
||||
|
||||
if not response or "references" not in response:
|
||||
return []
|
||||
|
||||
results = response["references"].get("reference", [])
|
||||
papers = []
|
||||
|
||||
for result in results:
|
||||
paper = self._parse_paper_data(result)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取参考文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"AUTHOR-NAME({author})"
|
||||
if start_year:
|
||||
query += f" AND PUBYEAR > {start_year}"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def search_by_journal(
|
||||
self,
|
||||
journal: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊搜索论文"""
|
||||
query = f"SRCTITLE({journal})"
|
||||
if start_year:
|
||||
query += f" AND PUBYEAR > {start_year}"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
days: int = 7,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取最新论文"""
|
||||
query = f"LOAD-DATE > NOW() - {days}d"
|
||||
return await self.search(query, limit=limit, sort_by="date")
|
||||
|
||||
async def example_usage():
|
||||
"""ScopusSource使用示例"""
|
||||
scopus = ScopusSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索机器学习相关论文 ===")
|
||||
papers = await scopus.search("machine learning", limit=3)
|
||||
print(f"\n找到 {len(papers)} 篇相关论文:")
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"发表期刊: {paper.venue}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要:\n{paper.abstract}")
|
||||
print("-" * 80)
|
||||
|
||||
# 示例2:按作者搜索
|
||||
print("\n=== 示例2:搜索特定作者的论文 ===")
|
||||
author_papers = await scopus.search_by_author("Hinton G.", limit=3)
|
||||
print(f"\n找到 {len(author_papers)} 篇 Hinton 的论文:")
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"发表期刊: {paper.venue}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要:\n{paper.abstract}")
|
||||
print("-" * 80)
|
||||
|
||||
# 示例3:根据关键词搜索相关论文
|
||||
print("\n=== 示例3:搜索人工智能相关论文 ===")
|
||||
keywords = "artificial intelligence AND deep learning"
|
||||
papers = await scopus.search(
|
||||
query=keywords,
|
||||
limit=5,
|
||||
sort_by="citations", # 按引用次数排序
|
||||
start_year=2020 # 只搜索2020年之后的论文
|
||||
)
|
||||
|
||||
print(f"\n找到 {len(papers)} 篇相关论文:")
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"发表期刊: {paper.venue}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要:\n{paper.abstract}")
|
||||
print("-" * 80)
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(example_usage())
|
||||
480
crazy_functions/review_fns/data_sources/semantic_source.py
Normal file
480
crazy_functions/review_fns/data_sources/semantic_source.py
Normal file
@@ -0,0 +1,480 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import random
|
||||
|
||||
class SemanticScholarSource(DataSource):
|
||||
"""Semantic Scholar API实现,使用官方Python包"""
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: Semantic Scholar API密钥(可选)
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self._initialize() # 调用初始化方法
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化API客户端"""
|
||||
if not self.api_key:
|
||||
# 默认API密钥列表
|
||||
default_api_keys = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
self.api_key = random.choice(default_api_keys)
|
||||
|
||||
self.client = None # 延迟初始化
|
||||
self.fields = [
|
||||
"title",
|
||||
"authors",
|
||||
"abstract",
|
||||
"year",
|
||||
"externalIds",
|
||||
"citationCount",
|
||||
"venue",
|
||||
"openAccessPdf",
|
||||
"publicationVenue"
|
||||
]
|
||||
|
||||
async def _ensure_client(self):
|
||||
"""确保客户端已初始化"""
|
||||
if self.client is None:
|
||||
from semanticscholar import AsyncSemanticScholar
|
||||
self.client = AsyncSemanticScholar(api_key=self.api_key)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} year>={start_year}"
|
||||
|
||||
# 直接使用 search_paper 的结果
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/search",
|
||||
f"query={query}&limit={min(limit, 100)}&fields={','.join(self.fields)}",
|
||||
self.client.auth_header
|
||||
)
|
||||
papers = response.get('data', [])
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return []
|
||||
|
||||
async def get_paper_details(self, doi: str) -> Optional[PaperMetadata]:
|
||||
"""获取指定DOI的论文详情"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
paper = await self.client.get_paper(f"DOI:{doi}", fields=self.fields)
|
||||
return self._parse_paper_data(paper)
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_citations(
|
||||
self,
|
||||
doi: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取引用指定DOI论文的文献列表"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 构建查询参数
|
||||
fields_param = f"fields={','.join(self.fields)}"
|
||||
limit_param = f"limit={limit}"
|
||||
year_param = f"year>={start_year}" if start_year else ""
|
||||
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
|
||||
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/citations",
|
||||
params,
|
||||
self.client.auth_header
|
||||
)
|
||||
citations = response.get('data', [])
|
||||
return [self._parse_paper_data(citation.get('citingPaper', {})) for citation in citations]
|
||||
except Exception as e:
|
||||
print(f"获取引用列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_references(
|
||||
self,
|
||||
doi: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取指定DOI论文的参考文献列表"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 构建查询参数
|
||||
fields_param = f"fields={','.join(self.fields)}"
|
||||
limit_param = f"limit={limit}"
|
||||
year_param = f"year>={start_year}" if start_year else ""
|
||||
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
|
||||
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/references",
|
||||
params,
|
||||
self.client.auth_header
|
||||
)
|
||||
references = response.get('data', [])
|
||||
return [self._parse_paper_data(reference.get('citedPaper', {})) for reference in references]
|
||||
except Exception as e:
|
||||
print(f"获取参考文献列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_recommended_papers(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取论文推荐
|
||||
|
||||
根据一篇论文获取相关的推荐论文
|
||||
|
||||
Args:
|
||||
doi: 论文的DOI
|
||||
limit: 返回结果数量限制,最大500
|
||||
|
||||
Returns:
|
||||
推荐论文列表
|
||||
"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
papers = await self.client.get_recommended_papers(
|
||||
f"DOI:{doi}",
|
||||
fields=self.fields,
|
||||
limit=min(limit, 500)
|
||||
)
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"获取论文推荐时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_recommended_papers_from_lists(
|
||||
self,
|
||||
positive_dois: List[str],
|
||||
negative_dois: List[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""基于正负例论文列表获取推荐
|
||||
|
||||
Args:
|
||||
positive_dois: 正例论文DOI列表(想要获取类似的论文)
|
||||
negative_dois: 负例论文DOI列表(不想要类似的论文)
|
||||
limit: 返回结果数量限制,最大500
|
||||
|
||||
Returns:
|
||||
推荐论文列表
|
||||
"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
positive_ids = [f"DOI:{doi}" for doi in positive_dois]
|
||||
negative_ids = [f"DOI:{doi}" for doi in negative_dois] if negative_dois else None
|
||||
|
||||
papers = await self.client.get_recommended_papers_from_lists(
|
||||
positive_paper_ids=positive_ids,
|
||||
negative_paper_ids=negative_ids,
|
||||
fields=self.fields,
|
||||
limit=min(limit, 500)
|
||||
)
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"获取论文推荐列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_author(self, query: str, limit: int = 100) -> List[dict]:
|
||||
"""搜索作者"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求而不是 search_author 方法
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/search",
|
||||
f"query={query}&fields=name,paperCount,citationCount&limit={min(limit, 1000)}",
|
||||
self.client.auth_header
|
||||
)
|
||||
authors = response.get('data', [])
|
||||
return [
|
||||
{
|
||||
'author_id': author.get('authorId'),
|
||||
'name': author.get('name'),
|
||||
'paper_count': author.get('paperCount'),
|
||||
'citation_count': author.get('citationCount'),
|
||||
}
|
||||
for author in authors
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"搜索作者时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_author_details(self, author_id: str) -> Optional[dict]:
|
||||
"""获取作者详细信息"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}",
|
||||
"fields=name,paperCount,citationCount,hIndex",
|
||||
self.client.auth_header
|
||||
)
|
||||
return {
|
||||
'author_id': response.get('authorId'),
|
||||
'name': response.get('name'),
|
||||
'paper_count': response.get('paperCount'),
|
||||
'citation_count': response.get('citationCount'),
|
||||
'h_index': response.get('hIndex'),
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"获取作者详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_author_papers(self, author_id: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取作者的论文列表"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}/papers",
|
||||
f"fields={','.join(self.fields)}&limit={min(limit, 1000)}",
|
||||
self.client.auth_header
|
||||
)
|
||||
papers = response.get('data', [])
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"获取作者论文列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_paper_autocomplete(self, query: str) -> List[dict]:
|
||||
"""论文标题自动补全"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/autocomplete",
|
||||
f"query={query}",
|
||||
self.client.auth_header
|
||||
)
|
||||
suggestions = response.get('matches', [])
|
||||
return [
|
||||
{
|
||||
'title': suggestion.get('title'),
|
||||
'paper_id': suggestion.get('paperId'),
|
||||
'year': suggestion.get('year'),
|
||||
'venue': suggestion.get('venue'),
|
||||
}
|
||||
for suggestion in suggestions
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"获取标题自动补全时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_paper_data(self, paper) -> PaperMetadata:
|
||||
"""解析论文数据"""
|
||||
# 获取DOI
|
||||
doi = None
|
||||
external_ids = paper.get('externalIds', {}) if isinstance(paper, dict) else paper.externalIds
|
||||
if external_ids:
|
||||
if isinstance(external_ids, dict):
|
||||
doi = external_ids.get('DOI')
|
||||
if not doi and 'ArXiv' in external_ids:
|
||||
doi = f"10.48550/arXiv.{external_ids['ArXiv']}"
|
||||
else:
|
||||
doi = external_ids.DOI if hasattr(external_ids, 'DOI') else None
|
||||
if not doi and hasattr(external_ids, 'ArXiv'):
|
||||
doi = f"10.48550/arXiv.{external_ids.ArXiv}"
|
||||
|
||||
# 获取PDF URL
|
||||
pdf_url = None
|
||||
pdf_info = paper.get('openAccessPdf', {}) if isinstance(paper, dict) else paper.openAccessPdf
|
||||
if pdf_info:
|
||||
pdf_url = pdf_info.get('url') if isinstance(pdf_info, dict) else pdf_info.url
|
||||
|
||||
# 获取发表场所详细信息
|
||||
venue_type = None
|
||||
venue_name = None
|
||||
venue_info = {}
|
||||
|
||||
venue = paper.get('publicationVenue', {}) if isinstance(paper, dict) else paper.publicationVenue
|
||||
if venue:
|
||||
if isinstance(venue, dict):
|
||||
venue_name = venue.get('name')
|
||||
venue_type = venue.get('type')
|
||||
# 提取更多venue信息
|
||||
venue_info = {
|
||||
'issn': venue.get('issn'),
|
||||
'publisher': venue.get('publisher'),
|
||||
'url': venue.get('url'),
|
||||
'alternate_names': venue.get('alternate_names', [])
|
||||
}
|
||||
else:
|
||||
venue_name = venue.name if hasattr(venue, 'name') else None
|
||||
venue_type = venue.type if hasattr(venue, 'type') else None
|
||||
venue_info = {
|
||||
'issn': getattr(venue, 'issn', None),
|
||||
'publisher': getattr(venue, 'publisher', None),
|
||||
'url': getattr(venue, 'url', None),
|
||||
'alternate_names': getattr(venue, 'alternate_names', [])
|
||||
}
|
||||
|
||||
# 获取标题
|
||||
title = paper.get('title', '') if isinstance(paper, dict) else getattr(paper, 'title', '')
|
||||
|
||||
# 获取作者
|
||||
authors = paper.get('authors', []) if isinstance(paper, dict) else getattr(paper, 'authors', [])
|
||||
author_names = []
|
||||
for author in authors:
|
||||
if isinstance(author, dict):
|
||||
author_names.append(author.get('name', ''))
|
||||
else:
|
||||
author_names.append(author.name if hasattr(author, 'name') else str(author))
|
||||
|
||||
# 获取摘要
|
||||
abstract = paper.get('abstract', '') if isinstance(paper, dict) else getattr(paper, 'abstract', '')
|
||||
|
||||
# 获取年份
|
||||
year = paper.get('year') if isinstance(paper, dict) else getattr(paper, 'year', None)
|
||||
|
||||
# 获取引用次数
|
||||
citations = paper.get('citationCount') if isinstance(paper, dict) else getattr(paper, 'citationCount', None)
|
||||
|
||||
return PaperMetadata(
|
||||
title=title,
|
||||
authors=author_names,
|
||||
abstract=abstract,
|
||||
year=year,
|
||||
doi=doi,
|
||||
url=pdf_url or (f"https://doi.org/{doi}" if doi else None),
|
||||
citations=citations,
|
||||
venue=venue_name,
|
||||
institutions=[],
|
||||
venue_type=venue_type,
|
||||
venue_name=venue_name,
|
||||
venue_info=venue_info,
|
||||
source='semantic' # 添加来源标记
|
||||
)
|
||||
|
||||
async def example_usage():
|
||||
"""SemanticScholarSource使用示例"""
|
||||
semantic = SemanticScholarSource()
|
||||
|
||||
try:
|
||||
# 示例1:使用DOI直接获取论文
|
||||
print("\n=== 示例1:通过DOI获取论文 ===")
|
||||
doi = "10.18653/v1/N19-1423" # BERT论文
|
||||
print(f"获取DOI为 {doi} 的论文信息...")
|
||||
|
||||
paper = await semantic.get_paper_details(doi)
|
||||
if paper:
|
||||
print("\n--- 论文信息 ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"URL: {paper.url}")
|
||||
if paper.abstract:
|
||||
print(f"\n摘要:")
|
||||
print(paper.abstract)
|
||||
print(f"\n引用次数: {paper.citations}")
|
||||
print(f"发表venue: {paper.venue}")
|
||||
|
||||
# 示例2:搜索论文
|
||||
print("\n=== 示例2:搜索论文 ===")
|
||||
query = "BERT pre-training"
|
||||
print(f"搜索关键词 '{query}' 相关的论文...")
|
||||
papers = await semantic.search(query=query, limit=3)
|
||||
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 搜索结果 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
if paper.abstract:
|
||||
print(f"\n摘要:")
|
||||
print(paper.abstract)
|
||||
print(f"\nDOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例3:获取论文推荐
|
||||
print("\n=== 示例3:获取论文推荐 ===")
|
||||
print(f"获取与论文 {doi} 相关的推荐论文...")
|
||||
recommendations = await semantic.get_recommended_papers(doi, limit=3)
|
||||
for i, paper in enumerate(recommendations, 1):
|
||||
print(f"\n--- 推荐论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例4:基于多篇论文的推荐
|
||||
print("\n=== 示例4:基于多篇论文的推荐 ===")
|
||||
positive_dois = ["10.18653/v1/N19-1423", "10.18653/v1/P19-1285"]
|
||||
print(f"基于 {len(positive_dois)} 篇论文获取推荐...")
|
||||
multi_recommendations = await semantic.get_recommended_papers_from_lists(
|
||||
positive_dois=positive_dois,
|
||||
limit=3
|
||||
)
|
||||
for i, paper in enumerate(multi_recommendations, 1):
|
||||
print(f"\n--- 推荐论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
|
||||
# 示例5:搜索作者
|
||||
print("\n=== 示例5:搜索作者 ===")
|
||||
author_query = "Yann LeCun"
|
||||
print(f"搜索作者: '{author_query}'")
|
||||
authors = await semantic.search_author(author_query, limit=3)
|
||||
for i, author in enumerate(authors, 1):
|
||||
print(f"\n--- 作者 {i} ---")
|
||||
print(f"姓名: {author['name']}")
|
||||
print(f"论文数量: {author['paper_count']}")
|
||||
print(f"总引用次数: {author['citation_count']}")
|
||||
|
||||
# 示例6:获取作者详情
|
||||
print("\n=== 示例6:获取作者详情 ===")
|
||||
if authors: # 使用第一个搜索结果的作者ID
|
||||
author_id = authors[0]['author_id']
|
||||
print(f"获取作者ID {author_id} 的详细信息...")
|
||||
author_details = await semantic.get_author_details(author_id)
|
||||
if author_details:
|
||||
print(f"姓名: {author_details['name']}")
|
||||
print(f"H指数: {author_details['h_index']}")
|
||||
print(f"总引用次数: {author_details['citation_count']}")
|
||||
print(f"发表论文数: {author_details['paper_count']}")
|
||||
|
||||
# 示例7:获取作者论文
|
||||
print("\n=== 示例7:获取作者论文 ===")
|
||||
if authors: # 使用第一个搜索结果的作者ID
|
||||
author_id = authors[0]['author_id']
|
||||
print(f"获取作者 {authors[0]['name']} 的论文列表...")
|
||||
author_papers = await semantic.get_author_papers(author_id, limit=3)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例8:论文标题自动补全
|
||||
print("\n=== 示例8:论文标题自动补全 ===")
|
||||
title_query = "Attention is all"
|
||||
print(f"搜索标题: '{title_query}'")
|
||||
suggestions = await semantic.get_paper_autocomplete(title_query)
|
||||
for i, suggestion in enumerate(suggestions[:3], 1):
|
||||
print(f"\n--- 建议 {i} ---")
|
||||
print(f"标题: {suggestion['title']}")
|
||||
print(f"发表年份: {suggestion['year']}")
|
||||
print(f"发表venue: {suggestion['venue']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(example_usage())
|
||||
46
crazy_functions/review_fns/data_sources/unpaywall_source.py
Normal file
46
crazy_functions/review_fns/data_sources/unpaywall_source.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import aiohttp
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
|
||||
class UnpaywallSource(DataSource):
|
||||
"""Unpaywall API实现"""
|
||||
|
||||
def _initialize(self) -> None:
|
||||
self.base_url = "https://api.unpaywall.org/v2"
|
||||
self.email = self.api_key # Unpaywall使用email作为API key
|
||||
|
||||
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/search",
|
||||
params={
|
||||
"query": query,
|
||||
"email": self.email,
|
||||
"limit": limit
|
||||
}
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return [self._parse_response(item.response)
|
||||
for item in data.get("results", [])]
|
||||
|
||||
def _parse_response(self, data: Dict) -> PaperMetadata:
|
||||
"""解析Unpaywall返回的数据"""
|
||||
return PaperMetadata(
|
||||
title=data.get("title", ""),
|
||||
authors=[
|
||||
f"{author.get('given', '')} {author.get('family', '')}"
|
||||
for author in data.get("z_authors", [])
|
||||
],
|
||||
institutions=[
|
||||
aff.get("name", "")
|
||||
for author in data.get("z_authors", [])
|
||||
for aff in author.get("affiliation", [])
|
||||
],
|
||||
abstract="", # Unpaywall不提供摘要
|
||||
year=data.get("year"),
|
||||
doi=data.get("doi"),
|
||||
url=data.get("doi_url"),
|
||||
citations=None, # Unpaywall不提供引用计数
|
||||
venue=data.get("journal_name")
|
||||
)
|
||||
412
crazy_functions/review_fns/handlers/base_handler.py
Normal file
412
crazy_functions/review_fns/handlers/base_handler.py
Normal file
@@ -0,0 +1,412 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
|
||||
from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource
|
||||
from crazy_functions.review_fns.data_sources.pubmed_source import PubMedSource
|
||||
from crazy_functions.review_fns.paper_processor.paper_llm_ranker import PaperLLMRanker
|
||||
from crazy_functions.pdf_fns.breakdown_pdf_txt import cut_from_end_to_satisfy_token_limit
|
||||
from request_llms.bridge_all import model_info
|
||||
from crazy_functions.review_fns.data_sources.crossref_source import CrossrefSource
|
||||
from crazy_functions.review_fns.data_sources.adsabs_source import AdsabsSource
|
||||
from toolbox import get_conf
|
||||
|
||||
|
||||
class BaseHandler(ABC):
|
||||
"""处理器基类"""
|
||||
|
||||
def __init__(self, arxiv: ArxivSource, semantic: SemanticScholarSource, llm_kwargs: Dict = None):
|
||||
self.arxiv = arxiv
|
||||
self.semantic = semantic
|
||||
self.pubmed = PubMedSource()
|
||||
self.crossref = CrossrefSource() # 添加 Crossref 实例
|
||||
self.adsabs = AdsabsSource() # 添加 ADS 实例
|
||||
self.paper_ranker = PaperLLMRanker(llm_kwargs=llm_kwargs)
|
||||
self.ranked_papers = [] # 存储排序后的论文列表
|
||||
self.llm_kwargs = llm_kwargs or {} # 保存llm_kwargs
|
||||
|
||||
def _get_search_params(self, plugin_kwargs: Dict) -> Dict:
|
||||
"""获取搜索参数"""
|
||||
return {
|
||||
'max_papers': plugin_kwargs.get('max_papers', 100), # 最大论文数量
|
||||
'min_year': plugin_kwargs.get('min_year', 2015), # 最早年份
|
||||
'search_multiplier': plugin_kwargs.get('search_multiplier', 3), # 检索倍数
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> List[List[str]]:
|
||||
"""处理查询"""
|
||||
pass
|
||||
|
||||
async def _search_arxiv(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||
"""使用arXiv专用参数搜索"""
|
||||
try:
|
||||
original_limit = params.get("limit", 20)
|
||||
params["limit"] = original_limit * limit_multiplier
|
||||
papers = []
|
||||
|
||||
# 首先尝试基础搜索
|
||||
query = params.get("query", "")
|
||||
if query:
|
||||
papers = await self.arxiv.search(
|
||||
query,
|
||||
limit=params["limit"],
|
||||
sort_by=params.get("sort_by", "relevance"),
|
||||
sort_order=params.get("sort_order", "descending"),
|
||||
start_year=min_year
|
||||
)
|
||||
|
||||
# 如果基础搜索没有结果,尝试分类搜索
|
||||
if not papers:
|
||||
categories = params.get("categories", [])
|
||||
for category in categories:
|
||||
category_papers = await self.arxiv.search_by_category(
|
||||
category,
|
||||
limit=params["limit"],
|
||||
sort_by=params.get("sort_by", "relevance"),
|
||||
sort_order=params.get("sort_order", "descending"),
|
||||
)
|
||||
if category_papers:
|
||||
papers.extend(category_papers)
|
||||
|
||||
return papers or []
|
||||
|
||||
except Exception as e:
|
||||
print(f"arXiv搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_semantic(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||
"""使用Semantic Scholar专用参数搜索"""
|
||||
try:
|
||||
original_limit = params.get("limit", 20)
|
||||
params["limit"] = original_limit * limit_multiplier
|
||||
|
||||
# 只使用基本的搜索参数
|
||||
papers = await self.semantic.search(
|
||||
query=params.get("query", ""),
|
||||
limit=params["limit"]
|
||||
)
|
||||
|
||||
# 在内存中进行过滤
|
||||
if papers and min_year:
|
||||
papers = [p for p in papers if getattr(p, 'year', 0) and p.year >= min_year]
|
||||
|
||||
return papers or []
|
||||
|
||||
except Exception as e:
|
||||
print(f"Semantic Scholar搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_pubmed(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||
"""使用PubMed专用参数搜索"""
|
||||
try:
|
||||
# 如果不需要PubMed搜索,直接返回空列表
|
||||
if params.get("search_type") == "none":
|
||||
return []
|
||||
|
||||
original_limit = params.get("limit", 20)
|
||||
params["limit"] = original_limit * limit_multiplier
|
||||
papers = []
|
||||
|
||||
# 根据搜索类型选择搜索方法
|
||||
if params.get("search_type") == "basic":
|
||||
papers = await self.pubmed.search(
|
||||
query=params.get("query", ""),
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
elif params.get("search_type") == "author":
|
||||
papers = await self.pubmed.search_by_author(
|
||||
author=params.get("query", ""),
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
elif params.get("search_type") == "journal":
|
||||
papers = await self.pubmed.search_by_journal(
|
||||
journal=params.get("query", ""),
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
|
||||
return papers or []
|
||||
|
||||
except Exception as e:
|
||||
print(f"PubMed搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_crossref(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||
"""使用Crossref专用参数搜索"""
|
||||
try:
|
||||
original_limit = params.get("limit", 20)
|
||||
params["limit"] = original_limit * limit_multiplier
|
||||
papers = []
|
||||
|
||||
# 根据搜索类型选择搜索方法
|
||||
if params.get("search_type") == "basic":
|
||||
papers = await self.crossref.search(
|
||||
query=params.get("query", ""),
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
elif params.get("search_type") == "author":
|
||||
papers = await self.crossref.search_by_authors(
|
||||
authors=[params.get("query", "")],
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
elif params.get("search_type") == "journal":
|
||||
# 实现期刊搜索逻辑
|
||||
pass
|
||||
|
||||
return papers or []
|
||||
|
||||
except Exception as e:
|
||||
print(f"Crossref搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_adsabs(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||
"""使用ADS专用参数搜索"""
|
||||
try:
|
||||
original_limit = params.get("limit", 20)
|
||||
params["limit"] = original_limit * limit_multiplier
|
||||
papers = []
|
||||
|
||||
# 执行搜索
|
||||
if params.get("search_type") == "basic":
|
||||
papers = await self.adsabs.search(
|
||||
query=params.get("query", ""),
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
|
||||
return papers or []
|
||||
|
||||
except Exception as e:
|
||||
print(f"ADS搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_all_sources(self, criteria: SearchCriteria, search_params: Dict) -> List:
|
||||
"""从所有数据源搜索论文"""
|
||||
search_tasks = []
|
||||
|
||||
# # 检查是否需要执行PubMed搜索
|
||||
# is_using_pubmed = criteria.pubmed_params.get("search_type") != "none" and criteria.pubmed_params.get("query") != "none"
|
||||
is_using_pubmed = False # 开源版本不再搜索pubmed
|
||||
|
||||
# 如果使用PubMed,则只执行PubMed和Semantic Scholar搜索
|
||||
if is_using_pubmed:
|
||||
search_tasks.append(
|
||||
self._search_pubmed(
|
||||
criteria.pubmed_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=criteria.start_year
|
||||
)
|
||||
)
|
||||
|
||||
# Semantic Scholar总是执行搜索
|
||||
search_tasks.append(
|
||||
self._search_semantic(
|
||||
criteria.semantic_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=criteria.start_year
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
# 如果不使用ADS,则执行Crossref搜索
|
||||
if criteria.crossref_params.get("search_type") != "none" and criteria.crossref_params.get("query") != "none":
|
||||
search_tasks.append(
|
||||
self._search_crossref(
|
||||
criteria.crossref_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=criteria.start_year
|
||||
)
|
||||
)
|
||||
|
||||
search_tasks.append(
|
||||
self._search_arxiv(
|
||||
criteria.arxiv_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=criteria.start_year
|
||||
)
|
||||
)
|
||||
if get_conf("SEMANTIC_SCHOLAR_KEY"):
|
||||
search_tasks.append(
|
||||
self._search_semantic(
|
||||
criteria.semantic_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=criteria.start_year
|
||||
)
|
||||
)
|
||||
|
||||
# 执行所有需要的搜索任务
|
||||
papers = await asyncio.gather(*search_tasks)
|
||||
|
||||
# 合并所有来源的论文并统计各来源的数量
|
||||
all_papers = []
|
||||
source_counts = {
|
||||
'arxiv': 0,
|
||||
'semantic': 0,
|
||||
'pubmed': 0,
|
||||
'crossref': 0,
|
||||
'adsabs': 0
|
||||
}
|
||||
|
||||
for source_papers in papers:
|
||||
if source_papers:
|
||||
for paper in source_papers:
|
||||
source = getattr(paper, 'source', 'unknown')
|
||||
if source in source_counts:
|
||||
source_counts[source] += 1
|
||||
all_papers.extend(source_papers)
|
||||
|
||||
# 打印各来源的论文数量
|
||||
print("\n=== 各数据源找到的论文数量 ===")
|
||||
for source, count in source_counts.items():
|
||||
if count > 0: # 只打印有论文的来源
|
||||
print(f"{source.capitalize()}: {count} 篇")
|
||||
print(f"总计: {len(all_papers)} 篇")
|
||||
print("===========================\n")
|
||||
|
||||
return all_papers
|
||||
|
||||
def _format_paper_time(self, paper) -> str:
|
||||
"""格式化论文时间信息"""
|
||||
year = getattr(paper, 'year', None)
|
||||
if not year:
|
||||
return ""
|
||||
|
||||
# 如果有具体的发表日期,使用具体日期
|
||||
if hasattr(paper, 'published') and paper.published:
|
||||
return f"(发表于 {paper.published.strftime('%Y-%m')})"
|
||||
# 如果只有年份,只显示年份
|
||||
return f"({year})"
|
||||
|
||||
def _format_papers(self, papers: List) -> str:
|
||||
"""格式化论文列表,使用token限制控制长度"""
|
||||
formatted = []
|
||||
|
||||
for i, paper in enumerate(papers, 1):
|
||||
# 只保留前三个作者
|
||||
authors = paper.authors[:3]
|
||||
if len(paper.authors) > 3:
|
||||
authors.append("et al.")
|
||||
|
||||
# 构建所有可能的下载链接
|
||||
download_links = []
|
||||
|
||||
# 添加arXiv链接
|
||||
if hasattr(paper, 'doi') and paper.doi:
|
||||
if paper.doi.startswith("10.48550/arXiv."):
|
||||
# 从DOI中提取完整的arXiv ID
|
||||
arxiv_id = paper.doi.split("arXiv.")[-1]
|
||||
# 移除多余的点号并确保格式正确
|
||||
arxiv_id = arxiv_id.replace("..", ".") # 移除重复的点号
|
||||
if arxiv_id.startswith("."): # 移除开头的点号
|
||||
arxiv_id = arxiv_id[1:]
|
||||
if arxiv_id.endswith("."): # 移除结尾的点号
|
||||
arxiv_id = arxiv_id[:-1]
|
||||
|
||||
download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
|
||||
download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
|
||||
elif "arxiv.org/abs/" in paper.doi:
|
||||
# 直接从URL中提取arXiv ID
|
||||
arxiv_id = paper.doi.split("arxiv.org/abs/")[-1]
|
||||
if "v" in arxiv_id: # 移除版本号
|
||||
arxiv_id = arxiv_id.split("v")[0]
|
||||
|
||||
download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
|
||||
download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
|
||||
else:
|
||||
download_links.append(f"[DOI](https://doi.org/{paper.doi})")
|
||||
|
||||
# 添加直接URL链接(如果存在且不同于前面的链接)
|
||||
if hasattr(paper, 'url') and paper.url:
|
||||
if not any(paper.url in link for link in download_links):
|
||||
download_links.append(f"[Source]({paper.url})")
|
||||
|
||||
# 构建下载链接字符串
|
||||
download_section = " | ".join(download_links) if download_links else "No direct download link available"
|
||||
|
||||
# 构建来源信息
|
||||
source_info = []
|
||||
if hasattr(paper, 'venue_type') and paper.venue_type and paper.venue_type != 'preprint':
|
||||
source_info.append(f"Type: {paper.venue_type}")
|
||||
if hasattr(paper, 'venue_name') and paper.venue_name:
|
||||
source_info.append(f"Venue: {paper.venue_name}")
|
||||
|
||||
# 添加IF指数和分区信息
|
||||
if hasattr(paper, 'if_factor') and paper.if_factor:
|
||||
source_info.append(f"IF: {paper.if_factor}")
|
||||
if hasattr(paper, 'cas_division') and paper.cas_division:
|
||||
source_info.append(f"中科院分区: {paper.cas_division}")
|
||||
if hasattr(paper, 'jcr_division') and paper.jcr_division:
|
||||
source_info.append(f"JCR分区: {paper.jcr_division}")
|
||||
|
||||
if hasattr(paper, 'venue_info') and paper.venue_info:
|
||||
if paper.venue_info.get('journal_ref'):
|
||||
source_info.append(f"Journal Reference: {paper.venue_info['journal_ref']}")
|
||||
if paper.venue_info.get('publisher'):
|
||||
source_info.append(f"Publisher: {paper.venue_info['publisher']}")
|
||||
|
||||
# 构建当前论文的格式化文本
|
||||
paper_text = (
|
||||
f"{i}. **{paper.title}**\n" +
|
||||
f" Authors: {', '.join(authors)}\n" +
|
||||
f" Year: {paper.year}\n" +
|
||||
f" Citations: {paper.citations if paper.citations else 'N/A'}\n" +
|
||||
(f" Source: {'; '.join(source_info)}\n" if source_info else "") +
|
||||
# 添加PubMed特有信息
|
||||
(f" MeSH Terms: {'; '.join(paper.mesh_terms)}\n" if hasattr(paper,
|
||||
'mesh_terms') and paper.mesh_terms else "") +
|
||||
f" 📥 PDF Downloads: {download_section}\n" +
|
||||
f" Abstract: {paper.abstract}\n"
|
||||
)
|
||||
|
||||
formatted.append(paper_text)
|
||||
|
||||
full_text = "\n".join(formatted)
|
||||
|
||||
# 根据不同模型设置不同的token限制
|
||||
model_name = getattr(self, 'llm_kwargs', {}).get('llm_model', 'gpt-3.5-turbo')
|
||||
|
||||
token_limit = model_info[model_name]['max_token'] * 3 // 4
|
||||
# 使用token限制控制长度
|
||||
return cut_from_end_to_satisfy_token_limit(full_text, limit=token_limit, reserve_token=0, llm_model=model_name)
|
||||
|
||||
def _get_current_time(self) -> str:
|
||||
"""获取当前时间信息"""
|
||||
now = datetime.now()
|
||||
return now.strftime("%Y年%m月%d日")
|
||||
|
||||
def _generate_apology_prompt(self, criteria: SearchCriteria) -> str:
|
||||
"""生成道歉提示"""
|
||||
return f"""很抱歉,我们未能找到与"{criteria.main_topic}"相关的有效文献。
|
||||
|
||||
可能的原因:
|
||||
1. 搜索词过于具体或专业
|
||||
2. 时间范围限制过严
|
||||
|
||||
建议解决方案:
|
||||
1. 尝试使用更通用的关键词
|
||||
2. 扩大搜索时间范围
|
||||
3. 使用同义词或相关术语
|
||||
请根据以上建议调整后重试。"""
|
||||
|
||||
def get_ranked_papers(self) -> str:
|
||||
"""获取排序后的论文列表的格式化字符串"""
|
||||
return self._format_papers(self.ranked_papers) if self.ranked_papers else ""
|
||||
|
||||
def _is_pubmed_paper(self, paper) -> bool:
|
||||
"""判断是否为PubMed论文"""
|
||||
return (paper.url and 'pubmed.ncbi.nlm.nih.gov' in paper.url)
|
||||
106
crazy_functions/review_fns/handlers/latest_handler.py
Normal file
106
crazy_functions/review_fns/handlers/latest_handler.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base_handler import BaseHandler
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
import asyncio
|
||||
|
||||
class Arxiv最新论文推荐功能(BaseHandler):
|
||||
"""最新论文推荐处理器"""
|
||||
|
||||
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||
super().__init__(arxiv, semantic, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理最新论文推荐请求"""
|
||||
|
||||
# 获取搜索参数
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 获取最新论文
|
||||
papers = []
|
||||
for category in criteria.arxiv_params["categories"]:
|
||||
latest_papers = await self.arxiv.get_latest_papers(
|
||||
category=category,
|
||||
debug=False,
|
||||
batch_size=50
|
||||
)
|
||||
papers.extend(latest_papers)
|
||||
|
||||
if not papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 使用embedding模型对论文进行排序
|
||||
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||
query=criteria.original_query,
|
||||
papers=papers,
|
||||
search_criteria=criteria
|
||||
)
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = f"""Current time: {current_time}
|
||||
|
||||
Based on your interest in {criteria.main_topic}, here are the latest papers from arXiv in relevant categories:
|
||||
{', '.join(criteria.arxiv_params["categories"])}
|
||||
|
||||
Latest papers available:
|
||||
{self._format_papers(self.ranked_papers)}
|
||||
|
||||
Please provide:
|
||||
1. A clear list of latext papers, organized by themes or approaches
|
||||
|
||||
|
||||
2. Group papers by sub-topics or themes if applicable
|
||||
|
||||
3. For each paper:
|
||||
- Publication time
|
||||
- The key contributions and main findings
|
||||
- Why it's relevant to the user's interests
|
||||
- How it relates to other latest papers
|
||||
- The paper's citation count and citation impact
|
||||
- The paper's download link
|
||||
|
||||
4. A suggested reading order based on:
|
||||
- Paper relationships and dependencies
|
||||
- Difficulty level
|
||||
- Significance
|
||||
|
||||
5. Future Directions
|
||||
- Emerging venues and research streams
|
||||
- Novel methodological approaches
|
||||
- Cross-disciplinary opportunities
|
||||
- Research gaps by publication type
|
||||
|
||||
IMPORTANT:
|
||||
- Focus on explaining why each paper is interesting
|
||||
- Highlight the novelty and potential impact
|
||||
- Consider the credibility and stage of each publication
|
||||
- Use the provided paper titles with their links when referring to specific papers
|
||||
- Base recommendations ONLY on the explicitly provided paper information
|
||||
- Do not make ANY assumptions about papers beyond the given data
|
||||
- When information is missing or unclear, acknowledge the limitation
|
||||
- Never speculate about:
|
||||
* Paper quality or rigor not evidenced in the data
|
||||
* Research impact beyond citation counts
|
||||
* Implementation details not mentioned
|
||||
* Author expertise or background
|
||||
* Future research directions not stated
|
||||
- For each paper, cite only verifiable information
|
||||
- Clearly distinguish between facts and potential implications
|
||||
- Each paper includes download links in its 📥 PDF Downloads section
|
||||
|
||||
Format your response in markdown with clear sections.
|
||||
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
|
||||
return final_prompt
|
||||
344
crazy_functions/review_fns/handlers/paper_handler.py
Normal file
344
crazy_functions/review_fns/handlers/paper_handler.py
Normal file
@@ -0,0 +1,344 @@
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from .base_handler import BaseHandler
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
import asyncio
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||
|
||||
class 单篇论文分析功能(BaseHandler):
|
||||
"""论文分析处理器"""
|
||||
|
||||
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||
super().__init__(arxiv, semantic, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理论文分析请求,返回最终的prompt"""
|
||||
|
||||
# 1. 获取论文详情
|
||||
paper = await self._get_paper_details(criteria)
|
||||
if not paper:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 保存为ranked_papers以便统一接口
|
||||
self.ranked_papers = [paper]
|
||||
|
||||
# 2. 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
|
||||
# 获取论文信息
|
||||
title = getattr(paper, "title", "Unknown Title")
|
||||
authors = getattr(paper, "authors", [])
|
||||
year = getattr(paper, "year", "Unknown Year")
|
||||
abstract = getattr(paper, "abstract", "No abstract available")
|
||||
citations = getattr(paper, "citations", "N/A")
|
||||
|
||||
# 添加论文ID信息
|
||||
paper_id = ""
|
||||
if criteria.paper_source == "arxiv":
|
||||
paper_id = f"arXiv ID: {criteria.paper_id}\n"
|
||||
elif criteria.paper_source == "doi":
|
||||
paper_id = f"DOI: {criteria.paper_id}\n"
|
||||
|
||||
# 格式化作者列表
|
||||
authors_str = ', '.join(authors) if isinstance(authors, list) else authors
|
||||
|
||||
final_prompt = f"""Current time: {current_time}
|
||||
|
||||
Please provide a comprehensive analysis of the following paper:
|
||||
|
||||
{paper_id}Title: {title}
|
||||
Authors: {authors_str}
|
||||
Year: {year}
|
||||
Citations: {citations}
|
||||
Publication Venue: {paper.venue_name} ({paper.venue_type})
|
||||
{f"Publisher: {paper.venue_info.get('publisher')}" if paper.venue_info.get('publisher') else ""}
|
||||
{f"Journal Reference: {paper.venue_info.get('journal_ref')}" if paper.venue_info.get('journal_ref') else ""}
|
||||
Abstract: {abstract}
|
||||
|
||||
Please provide:
|
||||
1. Publication Context
|
||||
- Publication venue analysis and impact factor (if available)
|
||||
- Paper type (journal article, conference paper, preprint)
|
||||
- Publication timeline and peer review status
|
||||
- Publisher reputation and venue prestige
|
||||
|
||||
2. Research Context
|
||||
- Field positioning and significance
|
||||
- Historical context and prior work
|
||||
- Related research streams
|
||||
- Cross-venue impact analysis
|
||||
|
||||
3. Technical Analysis
|
||||
- Detailed methodology review
|
||||
- Implementation details
|
||||
- Experimental setup and results
|
||||
- Technical innovations
|
||||
|
||||
4. Impact Analysis
|
||||
- Citation patterns and influence
|
||||
- Cross-venue recognition
|
||||
- Industry vs. academic impact
|
||||
- Practical applications
|
||||
|
||||
5. Critical Review
|
||||
- Methodological rigor assessment
|
||||
- Result reliability and reproducibility
|
||||
- Venue-appropriate evaluation standards
|
||||
- Limitations and potential improvements
|
||||
|
||||
IMPORTANT:
|
||||
- Strictly use ONLY the information provided above about the paper
|
||||
- Do not make ANY assumptions or inferences beyond the given data
|
||||
- If certain information is not provided, explicitly state that it is unknown
|
||||
- For any unclear or missing details, acknowledge the limitation rather than speculating
|
||||
- When discussing methodology or results, only describe what is explicitly stated in the abstract
|
||||
- Never fabricate or assume any details about:
|
||||
* Publication venues or status
|
||||
* Implementation details not mentioned
|
||||
* Results or findings not stated
|
||||
* Impact or influence not supported by the citation count
|
||||
* Authors' affiliations or backgrounds
|
||||
* Future work or implications not mentioned
|
||||
- You can find the paper's download options in the 📥 PDF Downloads section
|
||||
- Available download formats include arXiv PDF, DOI links, and source URLs
|
||||
|
||||
Format your response in markdown with clear sections.
|
||||
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
|
||||
return final_prompt
|
||||
|
||||
async def _get_paper_details(self, criteria: SearchCriteria):
|
||||
"""获取论文详情"""
|
||||
try:
|
||||
if criteria.paper_source == "arxiv":
|
||||
# 使用 arxiv ID 搜索
|
||||
papers = await self.arxiv.search_by_id(criteria.paper_id)
|
||||
return papers[0] if papers else None
|
||||
|
||||
elif criteria.paper_source == "doi":
|
||||
# 尝试从所有来源获取
|
||||
paper = await self.semantic.get_paper_by_doi(criteria.paper_id)
|
||||
if not paper:
|
||||
# 如果Semantic Scholar没有找到,尝试PubMed
|
||||
papers = await self.pubmed.search(
|
||||
f"{criteria.paper_id}[doi]",
|
||||
limit=1
|
||||
)
|
||||
if papers:
|
||||
return papers[0]
|
||||
return paper
|
||||
|
||||
elif criteria.paper_source == "title":
|
||||
# 使用_search_all_sources搜索
|
||||
search_params = {
|
||||
'max_papers': 1,
|
||||
'min_year': 1900, # 不限制年份
|
||||
'search_multiplier': 1
|
||||
}
|
||||
|
||||
# 设置搜索参数
|
||||
criteria.arxiv_params = {
|
||||
"search_type": "basic",
|
||||
"query": f'ti:"{criteria.paper_title}"',
|
||||
"limit": 1
|
||||
}
|
||||
criteria.semantic_params = {
|
||||
"query": criteria.paper_title,
|
||||
"limit": 1
|
||||
}
|
||||
criteria.pubmed_params = {
|
||||
"search_type": "basic",
|
||||
"query": f'"{criteria.paper_title}"[Title]',
|
||||
"limit": 1
|
||||
}
|
||||
|
||||
papers = await self._search_all_sources(criteria, search_params)
|
||||
return papers[0] if papers else None
|
||||
|
||||
# 如果都没有找到,尝试使用 main_topic 作为标题搜索
|
||||
if not criteria.paper_title and not criteria.paper_id:
|
||||
search_params = {
|
||||
'max_papers': 1,
|
||||
'min_year': 1900,
|
||||
'search_multiplier': 1
|
||||
}
|
||||
|
||||
# 设置搜索参数
|
||||
criteria.arxiv_params = {
|
||||
"search_type": "basic",
|
||||
"query": f'ti:"{criteria.main_topic}"',
|
||||
"limit": 1
|
||||
}
|
||||
criteria.semantic_params = {
|
||||
"query": criteria.main_topic,
|
||||
"limit": 1
|
||||
}
|
||||
criteria.pubmed_params = {
|
||||
"search_type": "basic",
|
||||
"query": f'"{criteria.main_topic}"[Title]',
|
||||
"limit": 1
|
||||
}
|
||||
|
||||
papers = await self._search_all_sources(criteria, search_params)
|
||||
return papers[0] if papers else None
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时出错: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _get_citation_context(self, paper: Dict, plugin_kwargs: Dict) -> Tuple[List, List]:
|
||||
"""获取引用上下文"""
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 使用论文标题构建搜索参数
|
||||
title_query = f'ti:"{getattr(paper, "title", "")}"'
|
||||
arxiv_params = {
|
||||
"query": title_query,
|
||||
"limit": search_params['max_papers'],
|
||||
"search_type": "basic",
|
||||
"sort_by": "relevance",
|
||||
"sort_order": "descending"
|
||||
}
|
||||
semantic_params = {
|
||||
"query": getattr(paper, "title", ""),
|
||||
"limit": search_params['max_papers']
|
||||
}
|
||||
|
||||
citations, references = await asyncio.gather(
|
||||
self._search_semantic(
|
||||
semantic_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=search_params['min_year']
|
||||
),
|
||||
self._search_arxiv(
|
||||
arxiv_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=search_params['min_year']
|
||||
)
|
||||
)
|
||||
|
||||
return citations, references
|
||||
|
||||
async def _generate_analysis(
|
||||
self,
|
||||
paper: Dict,
|
||||
citations: List,
|
||||
references: List,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any]
|
||||
) -> List[List[str]]:
|
||||
"""生成论文分析"""
|
||||
|
||||
# 构建提示
|
||||
analysis_prompt = f"""Please provide a comprehensive analysis of the following paper:
|
||||
|
||||
Paper details:
|
||||
{self._format_paper(paper)}
|
||||
|
||||
Key references (papers cited by this paper):
|
||||
{self._format_papers(references)}
|
||||
|
||||
Important citations (papers that cite this paper):
|
||||
{self._format_papers(citations)}
|
||||
|
||||
Please provide:
|
||||
1. Paper Overview
|
||||
- Main research question/objective
|
||||
- Key methodology/approach
|
||||
- Main findings/contributions
|
||||
|
||||
2. Technical Analysis
|
||||
- Detailed methodology review
|
||||
- Technical innovations
|
||||
- Implementation details
|
||||
- Experimental setup and results
|
||||
|
||||
3. Impact Analysis
|
||||
- Significance in the field
|
||||
- Influence on subsequent research (based on citing papers)
|
||||
- Relationship to prior work (based on cited papers)
|
||||
- Practical applications
|
||||
|
||||
4. Critical Review
|
||||
- Strengths and limitations
|
||||
- Potential improvements
|
||||
- Open questions and future directions
|
||||
- Alternative approaches
|
||||
|
||||
5. Related Research Context
|
||||
- How it builds on previous work
|
||||
- How it has influenced subsequent research
|
||||
- Comparison with alternative approaches
|
||||
|
||||
Format your response in markdown with clear sections."""
|
||||
|
||||
# 并行生成概述和技术分析
|
||||
for response_chunk in request_gpt(
|
||||
inputs_array=[
|
||||
analysis_prompt,
|
||||
self._get_technical_prompt(paper)
|
||||
],
|
||||
inputs_show_user_array=[
|
||||
"Generating paper analysis...",
|
||||
"Analyzing technical details..."
|
||||
],
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history_array=[history, []],
|
||||
sys_prompt_array=[
|
||||
system_prompt,
|
||||
"You are an expert at analyzing technical details in research papers."
|
||||
]
|
||||
):
|
||||
pass # 等待生成完成
|
||||
|
||||
# 获取最后的两个回答
|
||||
if chatbot and len(chatbot[-2:]) == 2:
|
||||
analysis = chatbot[-2][1]
|
||||
technical = chatbot[-1][1]
|
||||
full_analysis = f"""# Paper Analysis: {paper.title}
|
||||
|
||||
## General Analysis
|
||||
{analysis}
|
||||
|
||||
## Technical Deep Dive
|
||||
{technical}
|
||||
"""
|
||||
chatbot.append(["Here is the paper analysis:", full_analysis])
|
||||
else:
|
||||
chatbot.append(["Here is the paper analysis:", "Failed to generate analysis."])
|
||||
|
||||
return chatbot
|
||||
|
||||
def _get_technical_prompt(self, paper: Dict) -> str:
|
||||
"""生成技术分析提示"""
|
||||
return f"""Please provide a detailed technical analysis of the following paper:
|
||||
|
||||
{self._format_paper(paper)}
|
||||
|
||||
Focus on:
|
||||
1. Mathematical formulations and their implications
|
||||
2. Algorithm design and complexity analysis
|
||||
3. Architecture details and design choices
|
||||
4. Implementation challenges and solutions
|
||||
5. Performance analysis and bottlenecks
|
||||
6. Technical limitations and potential improvements
|
||||
|
||||
Format your response in markdown, focusing purely on technical aspects."""
|
||||
|
||||
|
||||
147
crazy_functions/review_fns/handlers/qa_handler.py
Normal file
147
crazy_functions/review_fns/handlers/qa_handler.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base_handler import BaseHandler
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
from textwrap import dedent
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||
|
||||
class 学术问答功能(BaseHandler):
|
||||
"""学术问答处理器"""
|
||||
|
||||
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||
super().__init__(arxiv, semantic, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理学术问答请求,返回最终的prompt"""
|
||||
|
||||
# 1. 获取搜索参数
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 2. 搜索相关论文
|
||||
papers = await self._search_relevant_papers(criteria, search_params)
|
||||
if not papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = dedent(f"""Current time: {current_time}
|
||||
|
||||
Based on the following paper abstracts, please answer this academic question: {criteria.original_query}
|
||||
|
||||
Available papers for reference:
|
||||
{self._format_papers(self.ranked_papers)}
|
||||
|
||||
Please structure your response in the following format:
|
||||
|
||||
1. Core Answer (2-3 paragraphs)
|
||||
- Provide a clear, direct answer synthesizing key findings
|
||||
- Support main points with citations [1,2,etc.]
|
||||
- Focus on consensus and differences across papers
|
||||
|
||||
2. Key Evidence (2-3 paragraphs)
|
||||
- Present supporting evidence from abstracts
|
||||
- Compare methodologies and results
|
||||
- Highlight significant findings with citations
|
||||
|
||||
3. Research Context (1-2 paragraphs)
|
||||
- Discuss current trends and developments
|
||||
- Identify research gaps or limitations
|
||||
- Suggest potential future directions
|
||||
|
||||
Guidelines:
|
||||
- Base your answer ONLY on the provided abstracts
|
||||
- Use numbered citations [1], [2,3], etc. for every claim
|
||||
- Maintain academic tone and objectivity
|
||||
- Synthesize findings across multiple papers
|
||||
- Focus on the most relevant information to the question
|
||||
|
||||
Constraints:
|
||||
- Do not include information beyond the provided abstracts
|
||||
- Avoid speculation or personal opinions
|
||||
- Do not elaborate on technical details unless directly relevant
|
||||
- Keep citations concise and focused
|
||||
- Use [N] citations for every major claim or finding
|
||||
- Cite multiple papers [1,2,3] when showing consensus
|
||||
- Place citations immediately after the relevant statements
|
||||
|
||||
Note: Provide citations for every major claim to ensure traceability to source papers.
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language. Use Chinese to answer if no language is specified.
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
)
|
||||
|
||||
return final_prompt
|
||||
|
||||
async def _search_relevant_papers(self, criteria: SearchCriteria, search_params: Dict) -> List:
|
||||
"""搜索相关论文"""
|
||||
# 使用_search_all_sources替代原来的并行搜索
|
||||
all_papers = await self._search_all_sources(criteria, search_params)
|
||||
|
||||
if not all_papers:
|
||||
return []
|
||||
|
||||
# 使用BGE重排序
|
||||
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||
query=criteria.main_topic,
|
||||
papers=all_papers,
|
||||
search_criteria=criteria
|
||||
)
|
||||
|
||||
return self.ranked_papers or []
|
||||
|
||||
async def _generate_answer(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
papers: List,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any]
|
||||
) -> List[List[str]]:
|
||||
"""生成答案"""
|
||||
|
||||
# 构建提示
|
||||
qa_prompt = dedent(f"""Please answer the following academic question based on recent research papers.
|
||||
|
||||
Question: {criteria.main_topic}
|
||||
|
||||
Relevant papers:
|
||||
{self._format_papers(papers)}
|
||||
|
||||
Please provide:
|
||||
1. A direct answer to the question
|
||||
2. Supporting evidence from the papers
|
||||
3. Different perspectives or approaches if applicable
|
||||
4. Current limitations and open questions
|
||||
5. References to specific papers
|
||||
|
||||
Format your response in markdown with clear sections."""
|
||||
)
|
||||
# 调用LLM生成答案
|
||||
for response_chunk in request_gpt(
|
||||
inputs_array=[qa_prompt],
|
||||
inputs_show_user_array=["Generating answer..."],
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history_array=[history],
|
||||
sys_prompt_array=[system_prompt]
|
||||
):
|
||||
pass # 等待生成完成
|
||||
|
||||
# 获取最后的回答
|
||||
if chatbot and len(chatbot[-1]) >= 2:
|
||||
answer = chatbot[-1][1]
|
||||
chatbot.append(["Here is the answer:", answer])
|
||||
else:
|
||||
chatbot.append(["Here is the answer:", "Failed to generate answer."])
|
||||
|
||||
return chatbot
|
||||
|
||||
185
crazy_functions/review_fns/handlers/recommend_handler.py
Normal file
185
crazy_functions/review_fns/handlers/recommend_handler.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base_handler import BaseHandler
|
||||
from textwrap import dedent
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||
|
||||
class 论文推荐功能(BaseHandler):
|
||||
"""论文推荐处理器"""
|
||||
|
||||
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||
super().__init__(arxiv, semantic, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理论文推荐请求,返回最终的prompt"""
|
||||
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 1. 先搜索种子论文
|
||||
seed_papers = await self._search_seed_papers(criteria, search_params)
|
||||
if not seed_papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 使用BGE重排序
|
||||
all_papers = seed_papers
|
||||
|
||||
if not all_papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||
query=criteria.original_query,
|
||||
papers=all_papers,
|
||||
search_criteria=criteria
|
||||
)
|
||||
|
||||
if not self.ranked_papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = dedent(f"""Current time: {current_time}
|
||||
|
||||
Based on the user's interest in {criteria.main_topic}, here are relevant papers.
|
||||
|
||||
Available papers for recommendation:
|
||||
{self._format_papers(self.ranked_papers)}
|
||||
|
||||
Please provide:
|
||||
1. Group papers by sub-topics or themes if applicable
|
||||
|
||||
2. For each paper:
|
||||
- Publication time and venue (when available)
|
||||
- Journal metrics (when available):
|
||||
* Impact Factor (IF)
|
||||
* JCR Quartile
|
||||
* Chinese Academy of Sciences (CAS) Division
|
||||
- The key contributions and main findings
|
||||
- Why it's relevant to the user's interests
|
||||
- How it relates to other recommended papers
|
||||
- The paper's citation count and citation impact
|
||||
- The paper's download link
|
||||
|
||||
3. A suggested reading order based on:
|
||||
- Journal impact and quality metrics
|
||||
- Chronological development of ideas
|
||||
- Paper relationships and dependencies
|
||||
- Difficulty level
|
||||
- Impact and significance
|
||||
|
||||
4. Future Directions
|
||||
- Emerging venues and research streams
|
||||
- Novel methodological approaches
|
||||
- Cross-disciplinary opportunities
|
||||
- Research gaps by publication type
|
||||
|
||||
|
||||
IMPORTANT:
|
||||
- Focus on explaining why each paper is valuable
|
||||
- Highlight connections between papers
|
||||
- Consider both citation counts AND journal metrics when discussing impact
|
||||
- When available, use IF, JCR quartile, and CAS division to assess paper quality
|
||||
- Mention publication timing when discussing paper relationships
|
||||
- When referring to papers, use HTML links in this format:
|
||||
* For DOIs: <a href='https://doi.org/DOI_HERE' target='_blank'>DOI: DOI_HERE</a>
|
||||
* For titles: <a href='PAPER_URL' target='_blank'>PAPER_TITLE</a>
|
||||
- Present papers in a way that shows the evolution of ideas over time
|
||||
- Base recommendations ONLY on the explicitly provided paper information
|
||||
- Do not make ANY assumptions about papers beyond the given data
|
||||
- When information is missing or unclear, acknowledge the limitation
|
||||
- Never speculate about:
|
||||
* Paper quality or rigor not evidenced in the data
|
||||
* Research impact beyond citation counts and journal metrics
|
||||
* Implementation details not mentioned
|
||||
* Author expertise or background
|
||||
* Future research directions not stated
|
||||
- For each recommendation, cite only verifiable information
|
||||
- Clearly distinguish between facts and potential implications
|
||||
|
||||
Format your response in markdown with clear sections.
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
)
|
||||
return final_prompt
|
||||
|
||||
async def _search_seed_papers(self, criteria: SearchCriteria, search_params: Dict) -> List:
|
||||
"""搜索种子论文"""
|
||||
try:
|
||||
# 使用_search_all_sources替代原来的并行搜索
|
||||
all_papers = await self._search_all_sources(criteria, search_params)
|
||||
|
||||
if not all_papers:
|
||||
return []
|
||||
|
||||
return all_papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索种子论文时出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _get_recommendations(self, seed_papers: List, multiplier: int = 1) -> List:
|
||||
"""获取推荐论文"""
|
||||
recommendations = []
|
||||
base_limit = 3 * multiplier
|
||||
|
||||
# 将种子论文添加到推荐列表中
|
||||
recommendations.extend(seed_papers)
|
||||
|
||||
# 只使用前5篇论文作为种子
|
||||
seed_papers = seed_papers[:5]
|
||||
|
||||
for paper in seed_papers:
|
||||
try:
|
||||
if paper.doi and paper.doi.startswith("10.48550/arXiv."):
|
||||
# arXiv论文
|
||||
arxiv_id = paper.doi.split(".")[-1]
|
||||
paper_details = await self.arxiv.get_paper_details(arxiv_id)
|
||||
if paper_details and hasattr(paper_details, 'venue'):
|
||||
category = paper_details.venue.split(":")[-1]
|
||||
similar_papers = await self.arxiv.search_by_category(
|
||||
category,
|
||||
limit=base_limit,
|
||||
sort_by='relevance'
|
||||
)
|
||||
recommendations.extend(similar_papers)
|
||||
elif paper.doi: # 只对有DOI的论文获取推荐
|
||||
# Semantic Scholar论文
|
||||
similar_papers = await self.semantic.get_recommended_papers(
|
||||
paper.doi,
|
||||
limit=base_limit
|
||||
)
|
||||
if similar_papers: # 只添加成功获取的推荐
|
||||
recommendations.extend(similar_papers)
|
||||
else:
|
||||
# 对于没有DOI的论文,使用标题进行相关搜索
|
||||
if paper.title:
|
||||
similar_papers = await self.semantic.search(
|
||||
query=paper.title,
|
||||
limit=base_limit
|
||||
)
|
||||
recommendations.extend(similar_papers)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取论文 '{paper.title}' 的推荐时发生错误: {str(e)}")
|
||||
continue
|
||||
|
||||
# 去重处理
|
||||
seen_dois = set()
|
||||
unique_recommendations = []
|
||||
for paper in recommendations:
|
||||
if paper.doi and paper.doi not in seen_dois:
|
||||
seen_dois.add(paper.doi)
|
||||
unique_recommendations.append(paper)
|
||||
elif not paper.doi and paper not in unique_recommendations:
|
||||
unique_recommendations.append(paper)
|
||||
|
||||
return unique_recommendations
|
||||
193
crazy_functions/review_fns/handlers/review_handler.py
Normal file
193
crazy_functions/review_fns/handlers/review_handler.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from .base_handler import BaseHandler
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
import asyncio
|
||||
|
||||
class 文献综述功能(BaseHandler):
|
||||
"""文献综述处理器"""
|
||||
|
||||
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||
super().__init__(arxiv, semantic, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理文献综述请求,返回最终的prompt"""
|
||||
|
||||
# 获取搜索参数
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 使用_search_all_sources替代原来的并行搜索
|
||||
all_papers = await self._search_all_sources(criteria, search_params)
|
||||
|
||||
if not all_papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||
query=criteria.original_query,
|
||||
papers=all_papers,
|
||||
search_criteria=criteria
|
||||
)
|
||||
|
||||
# 检查排序后的论文数量
|
||||
if not self.ranked_papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 检查是否包含PubMed论文
|
||||
has_pubmed_papers = any(paper.url and 'pubmed.ncbi.nlm.nih.gov' in paper.url
|
||||
for paper in self.ranked_papers)
|
||||
|
||||
if has_pubmed_papers:
|
||||
return self._generate_medical_review_prompt(criteria)
|
||||
else:
|
||||
return self._generate_general_review_prompt(criteria)
|
||||
|
||||
def _generate_medical_review_prompt(self, criteria: SearchCriteria) -> str:
|
||||
"""生成医学文献综述prompt"""
|
||||
return f"""Current time: {self._get_current_time()}
|
||||
|
||||
Conduct a systematic medical literature review on {criteria.main_topic} based STRICTLY on the provided articles.
|
||||
|
||||
Available literature for review:
|
||||
{self._format_papers(self.ranked_papers)}
|
||||
|
||||
IMPORTANT: If the user query contains specific requirements for the review structure or format, those requirements take precedence over the following guidelines.
|
||||
|
||||
Please structure your medical review following these guidelines:
|
||||
|
||||
1. Research Overview
|
||||
- Main research questions and objectives from the studies
|
||||
- Types of studies included (clinical trials, observational studies, etc.)
|
||||
- Study populations and settings
|
||||
- Time period of the research
|
||||
|
||||
2. Key Findings
|
||||
- Main outcomes and results reported in abstracts
|
||||
- Primary endpoints and their measurements
|
||||
- Statistical significance when reported
|
||||
- Observed trends across studies
|
||||
|
||||
3. Methods Summary
|
||||
- Study designs used
|
||||
- Major interventions or treatments studied
|
||||
- Key outcome measures
|
||||
- Patient populations studied
|
||||
|
||||
4. Clinical Relevance
|
||||
- Reported clinical implications
|
||||
- Main conclusions from authors
|
||||
- Reported benefits and risks
|
||||
- Treatment responses when available
|
||||
|
||||
5. Research Status
|
||||
- Current research focus areas
|
||||
- Reported limitations
|
||||
- Gaps identified in abstracts
|
||||
- Authors' suggested future directions
|
||||
|
||||
CRITICAL REQUIREMENTS:
|
||||
|
||||
Citation Rules (MANDATORY):
|
||||
- EVERY finding or statement MUST be supported by citations [N], where N is the number matching the paper in the provided literature list
|
||||
- When reporting outcomes, ALWAYS cite the source studies using the exact paper numbers from the literature list
|
||||
- For findings supported by multiple studies, use consecutive numbers as shown in the literature list [1,2,3]
|
||||
- Use ONLY the papers provided in the available literature list above
|
||||
- Citations must appear immediately after each statement
|
||||
- Citation numbers MUST match the numbers assigned to papers in the literature list above (e.g., if a finding comes from the first paper in the list, cite it as [1])
|
||||
- DO NOT change or reorder the citation numbers - they must exactly match the paper numbers in the literature list
|
||||
|
||||
Content Guidelines:
|
||||
- Present only information available in the provided papers
|
||||
- If certain information is not available, simply omit that aspect rather than explicitly stating its absence
|
||||
- Focus on synthesizing and presenting available findings
|
||||
- Maintain professional medical writing style
|
||||
- Present limitations and gaps as research opportunities rather than missing information
|
||||
|
||||
Writing Style:
|
||||
- Use precise medical terminology
|
||||
- Maintain objective reporting
|
||||
- Use consistent terminology throughout
|
||||
- Present a cohesive narrative without referencing data limitations
|
||||
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
|
||||
def _generate_general_review_prompt(self, criteria: SearchCriteria) -> str:
|
||||
"""生成通用文献综述prompt"""
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = f"""Current time: {current_time}
|
||||
|
||||
Conduct a comprehensive literature review on {criteria.main_topic} focusing on the following aspects:
|
||||
{', '.join(criteria.sub_topics)}
|
||||
|
||||
Available literature for review:
|
||||
{self._format_papers(self.ranked_papers)}
|
||||
|
||||
IMPORTANT: If the user query contains specific requirements for the review structure or format, those requirements take precedence over the following guidelines.
|
||||
|
||||
Please structure your review following these guidelines:
|
||||
|
||||
1. Introduction and Research Background
|
||||
- Current state and significance of the research field
|
||||
- Key research problems and challenges
|
||||
- Research development timeline and evolution
|
||||
|
||||
2. Research Directions and Classifications
|
||||
- Major research directions and their relationships
|
||||
- Different technical approaches and their characteristics
|
||||
- Comparative analysis of various solutions
|
||||
|
||||
3. Core Technologies and Methods
|
||||
- Key technological breakthroughs
|
||||
- Advantages and limitations of different methods
|
||||
- Technical challenges and solutions
|
||||
|
||||
4. Applications and Impact
|
||||
- Real-world applications and use cases
|
||||
- Industry influence and practical value
|
||||
- Implementation challenges and solutions
|
||||
|
||||
5. Future Trends and Prospects
|
||||
- Emerging research directions
|
||||
- Unsolved problems and challenges
|
||||
- Potential breakthrough points
|
||||
|
||||
CRITICAL REQUIREMENTS:
|
||||
|
||||
Citation Rules (MANDATORY):
|
||||
- EVERY finding or statement MUST be supported by citations [N], where N is the number matching the paper in the provided literature list
|
||||
- When reporting outcomes, ALWAYS cite the source studies using the exact paper numbers from the literature list
|
||||
- For findings supported by multiple studies, use consecutive numbers as shown in the literature list [1,2,3]
|
||||
- Use ONLY the papers provided in the available literature list above
|
||||
- Citations must appear immediately after each statement
|
||||
- Citation numbers MUST match the numbers assigned to papers in the literature list above (e.g., if a finding comes from the first paper in the list, cite it as [1])
|
||||
- DO NOT change or reorder the citation numbers - they must exactly match the paper numbers in the literature list
|
||||
|
||||
Writing Style:
|
||||
- Maintain academic and professional tone
|
||||
- Focus on objective analysis with proper citations
|
||||
- Ensure logical flow and clear structure
|
||||
|
||||
Content Requirements:
|
||||
- Base ALL analysis STRICTLY on the provided papers with explicit citations
|
||||
- When introducing any concept, method, or finding, immediately follow with [N]
|
||||
- For each research direction or approach, cite the specific papers [N] that proposed or developed it
|
||||
- When discussing limitations or challenges, cite the papers [N] that identified them
|
||||
- DO NOT include information from sources outside the provided paper list
|
||||
- DO NOT make unsupported claims or statements
|
||||
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
|
||||
return final_prompt
|
||||
|
||||
452
crazy_functions/review_fns/paper_processor/paper_llm_ranker.py
Normal file
452
crazy_functions/review_fns/paper_processor/paper_llm_ranker.py
Normal file
@@ -0,0 +1,452 @@
|
||||
from typing import List, Dict
|
||||
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
|
||||
from request_llms.embed_models.bge_llm import BGELLMRanker
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
import random
|
||||
from crazy_functions.review_fns.data_sources.journal_metrics import JournalMetrics
|
||||
|
||||
class PaperLLMRanker:
|
||||
"""使用LLM进行论文重排序"""
|
||||
|
||||
def __init__(self, llm_kwargs: Dict = None):
|
||||
self.ranker = BGELLMRanker(llm_kwargs=llm_kwargs)
|
||||
self.journal_metrics = JournalMetrics()
|
||||
|
||||
def _update_paper_metrics(self, papers: List[PaperMetadata]) -> None:
|
||||
"""更新论文的期刊指标"""
|
||||
for paper in papers:
|
||||
# 跳过arXiv来源的论文
|
||||
if getattr(paper, 'source', '') == 'arxiv':
|
||||
continue
|
||||
|
||||
if hasattr(paper, 'venue_name') or hasattr(paper, 'venue_info'):
|
||||
# 获取venue_name和venue_info
|
||||
venue_name = getattr(paper, 'venue_name', '')
|
||||
venue_info = getattr(paper, 'venue_info', {})
|
||||
|
||||
# 使用改进的匹配逻辑获取指标
|
||||
metrics = self.journal_metrics.get_journal_metrics(venue_name, venue_info)
|
||||
|
||||
# 更新论文的指标
|
||||
paper.if_factor = metrics.get('if_factor')
|
||||
paper.jcr_division = metrics.get('jcr_division')
|
||||
paper.cas_division = metrics.get('cas_division')
|
||||
|
||||
def _get_year_as_int(self, paper) -> int:
|
||||
"""统一获取论文年份为整数格式
|
||||
|
||||
Args:
|
||||
paper: 论文对象或直接是年份值
|
||||
|
||||
Returns:
|
||||
整数格式的年份,如果无法转换则返回0
|
||||
"""
|
||||
try:
|
||||
# 如果输入直接是年份而不是论文对象
|
||||
if isinstance(paper, int):
|
||||
return paper
|
||||
elif isinstance(paper, str):
|
||||
try:
|
||||
return int(paper)
|
||||
except ValueError:
|
||||
import re
|
||||
year_match = re.search(r'\d{4}', paper)
|
||||
if year_match:
|
||||
return int(year_match.group())
|
||||
return 0
|
||||
elif isinstance(paper, float):
|
||||
return int(paper)
|
||||
|
||||
# 处理论文对象
|
||||
year = getattr(paper, 'year', None)
|
||||
if year is None:
|
||||
return 0
|
||||
|
||||
# 如果是字符串,尝试转换为整数
|
||||
if isinstance(year, str):
|
||||
# 首先尝试直接转换整个字符串
|
||||
try:
|
||||
return int(year)
|
||||
except ValueError:
|
||||
# 如果直接转换失败,尝试提取第一个数字序列
|
||||
import re
|
||||
year_match = re.search(r'\d{4}', year)
|
||||
if year_match:
|
||||
return int(year_match.group())
|
||||
return 0
|
||||
# 如果是浮点数,转换为整数
|
||||
elif isinstance(year, float):
|
||||
return int(year)
|
||||
# 如果已经是整数,直接返回
|
||||
elif isinstance(year, int):
|
||||
return year
|
||||
return 0
|
||||
except (ValueError, TypeError):
|
||||
return 0
|
||||
|
||||
def rank_papers(
|
||||
self,
|
||||
query: str,
|
||||
papers: List[PaperMetadata],
|
||||
search_criteria: SearchCriteria = None,
|
||||
top_k: int = 40,
|
||||
use_rerank: bool = False,
|
||||
pre_filter_ratio: float = 0.5,
|
||||
max_papers: int = 150
|
||||
) -> List[PaperMetadata]:
|
||||
"""对论文进行重排序"""
|
||||
initial_count = len(papers) if papers else 0
|
||||
stats = {'initial': initial_count}
|
||||
|
||||
if not papers or not query:
|
||||
return []
|
||||
|
||||
# 更新论文的期刊指标
|
||||
self._update_paper_metrics(papers)
|
||||
|
||||
# 构建增强查询
|
||||
# enhanced_query = self._build_enhanced_query(query, search_criteria) if search_criteria else query
|
||||
enhanced_query = query
|
||||
# 首先过滤不满足年份要求的论文
|
||||
if search_criteria and search_criteria.start_year and search_criteria.end_year:
|
||||
before_year_filter = len(papers)
|
||||
filtered_papers = []
|
||||
start_year = int(search_criteria.start_year)
|
||||
end_year = int(search_criteria.end_year)
|
||||
|
||||
for paper in papers:
|
||||
paper_year = self._get_year_as_int(paper)
|
||||
if paper_year == 0 or start_year <= paper_year <= end_year:
|
||||
filtered_papers.append(paper)
|
||||
|
||||
papers = filtered_papers
|
||||
stats['after_year_filter'] = len(papers)
|
||||
|
||||
if not papers: # 如果过滤后没有论文,直接返回空列表
|
||||
return []
|
||||
|
||||
# 新增:对少量论文的快速处理
|
||||
SMALL_PAPER_THRESHOLD = 10 # 定义"少量"论文的阈值
|
||||
if len(papers) <= SMALL_PAPER_THRESHOLD:
|
||||
# 对于少量论文,直接根据查询类型进行简单排序
|
||||
if search_criteria:
|
||||
if search_criteria.query_type == "latest":
|
||||
papers.sort(key=lambda x: getattr(x, 'year', 0) or 0, reverse=True)
|
||||
elif search_criteria.query_type == "recommend":
|
||||
papers.sort(key=lambda x: getattr(x, 'citations', 0) or 0, reverse=True)
|
||||
elif search_criteria.query_type == "review":
|
||||
papers.sort(key=lambda x:
|
||||
1 if any(keyword in (getattr(x, 'title', '') or '').lower() or
|
||||
keyword in (getattr(x, 'abstract', '') or '').lower()
|
||||
for keyword in ['review', 'survey', 'overview'])
|
||||
else 0,
|
||||
reverse=True
|
||||
)
|
||||
return papers[:top_k]
|
||||
|
||||
# 1. 优先处理最新的论文
|
||||
if search_criteria and search_criteria.query_type == "latest":
|
||||
papers = sorted(papers, key=lambda x: self._get_year_as_int(x), reverse=True)
|
||||
|
||||
# 2. 如果是综述类查询,优先处理可能的综述论文
|
||||
if search_criteria and search_criteria.query_type == "review":
|
||||
papers = sorted(papers, key=lambda x:
|
||||
1 if any(keyword in (getattr(x, 'title', '') or '').lower() or
|
||||
keyword in (getattr(x, 'abstract', '') or '').lower()
|
||||
for keyword in ['review', 'survey', 'overview'])
|
||||
else 0,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 3. 如果论文数量超过限制,采用分层采样而不是完全随机
|
||||
if len(papers) > max_papers:
|
||||
before_max_limit = len(papers)
|
||||
papers = self._select_papers_strategically(papers, search_criteria, max_papers)
|
||||
stats['after_max_limit'] = len(papers)
|
||||
|
||||
try:
|
||||
paper_texts = []
|
||||
valid_papers = [] # 4. 跟踪有效论文
|
||||
|
||||
for paper in papers:
|
||||
if paper is None:
|
||||
continue
|
||||
# 5. 预先过滤明显不相关的论文
|
||||
if search_criteria and search_criteria.start_year:
|
||||
if getattr(paper, 'year', 0) and self._get_year_as_int(paper.year) < search_criteria.start_year:
|
||||
continue
|
||||
|
||||
doc = self._build_enhanced_document(paper, search_criteria)
|
||||
paper_texts.append(doc)
|
||||
valid_papers.append(paper) # 记录对应的论文
|
||||
|
||||
stats['after_valid_check'] = len(valid_papers)
|
||||
|
||||
if not paper_texts:
|
||||
return []
|
||||
|
||||
# 使用LLM判断相关性
|
||||
relevance_results = self.ranker.batch_check_relevance(
|
||||
query=enhanced_query, # 使用增强的查询
|
||||
paper_texts=paper_texts,
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
# 6. 优化相关论文的选择策略
|
||||
relevant_papers = []
|
||||
for paper, is_relevant in zip(valid_papers, relevance_results):
|
||||
if is_relevant:
|
||||
relevant_papers.append(paper)
|
||||
|
||||
stats['after_llm_filter'] = len(relevant_papers)
|
||||
|
||||
# 打印统计信息
|
||||
print(f"论文筛选统计: 初始数量={stats['initial']}, " +
|
||||
f"年份过滤后={stats.get('after_year_filter', stats['initial'])}, " +
|
||||
f"数量限制后={stats.get('after_max_limit', stats.get('after_year_filter', stats['initial']))}, " +
|
||||
f"有效性检查后={stats['after_valid_check']}, " +
|
||||
f"LLM筛选后={stats['after_llm_filter']}")
|
||||
|
||||
# 7. 改进回退策略
|
||||
if len(relevant_papers) < min(5, len(papers)):
|
||||
# 如果相关论文太少,返回按引用量排序的论文
|
||||
return sorted(
|
||||
papers[:top_k],
|
||||
key=lambda x: getattr(x, 'citations', 0) or 0,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 8. 对最终结果进行排序
|
||||
if search_criteria:
|
||||
if search_criteria.query_type == "latest":
|
||||
# 最新论文优先,但同年份按IF排序
|
||||
relevant_papers.sort(key=lambda x: (
|
||||
self._get_year_as_int(x),
|
||||
getattr(x, 'if_factor', 0) or 0
|
||||
), reverse=True)
|
||||
elif search_criteria.query_type == "recommend":
|
||||
# IF指数优先,其次是引用量
|
||||
relevant_papers.sort(key=lambda x: (
|
||||
getattr(x, 'if_factor', 0) or 0,
|
||||
getattr(x, 'citations', 0) or 0
|
||||
), reverse=True)
|
||||
else:
|
||||
# 默认按IF指数排序
|
||||
relevant_papers.sort(key=lambda x: getattr(x, 'if_factor', 0) or 0, reverse=True)
|
||||
|
||||
return relevant_papers[:top_k]
|
||||
|
||||
except Exception as e:
|
||||
print(f"论文排序时出错: {str(e)}")
|
||||
# 9. 改进错误处理的回退策略
|
||||
try:
|
||||
return sorted(
|
||||
papers[:top_k],
|
||||
key=lambda x: getattr(x, 'citations', 0) or 0,
|
||||
reverse=True
|
||||
)
|
||||
except:
|
||||
return papers[:top_k] if papers else []
|
||||
|
||||
def _build_enhanced_query(self, query: str, criteria: SearchCriteria) -> str:
|
||||
"""构建增强的查询文本"""
|
||||
components = []
|
||||
|
||||
# 强调这是用户的原始查询,是最重要的匹配依据
|
||||
components.append(f"Original user query that must be primarily matched: {query}")
|
||||
|
||||
if criteria:
|
||||
# 添加主题(如果与原始查询不同)
|
||||
if criteria.main_topic and criteria.main_topic != query:
|
||||
components.append(f"Additional context - The main topic is about: {criteria.main_topic}")
|
||||
|
||||
# 添加子主题
|
||||
if criteria.sub_topics:
|
||||
components.append(f"Secondary aspects to consider: {', '.join(criteria.sub_topics)}")
|
||||
|
||||
# 添加查询类型相关信息
|
||||
if criteria.query_type == "review":
|
||||
components.append("Paper type preference: Looking for comprehensive review papers, survey papers, or overview papers")
|
||||
elif criteria.query_type == "latest":
|
||||
components.append("Temporal preference: Focus on the most recent developments and latest papers")
|
||||
elif criteria.query_type == "recommend":
|
||||
components.append("Impact preference: Consider influential and fundamental papers")
|
||||
|
||||
# 直接连接所有组件,保持语序
|
||||
enhanced_query = ' '.join(components)
|
||||
|
||||
# 限制长度但不打乱顺序
|
||||
if len(enhanced_query) > 1000:
|
||||
enhanced_query = enhanced_query[:997] + "..."
|
||||
|
||||
return enhanced_query
|
||||
|
||||
def _build_enhanced_document(self, paper: PaperMetadata, criteria: SearchCriteria) -> str:
|
||||
"""构建增强的文档表示"""
|
||||
components = []
|
||||
|
||||
# 基本信息
|
||||
title = getattr(paper, 'title', '')
|
||||
authors = ', '.join(getattr(paper, 'authors', []))
|
||||
abstract = getattr(paper, 'abstract', '')
|
||||
year = getattr(paper, 'year', '')
|
||||
venue = getattr(paper, 'venue', '')
|
||||
|
||||
components.extend([
|
||||
f"Title: {title}",
|
||||
f"Authors: {authors}",
|
||||
f"Year: {year}",
|
||||
f"Venue: {venue}",
|
||||
f"Abstract: {abstract}"
|
||||
])
|
||||
|
||||
# 根据查询类型添加额外信息
|
||||
if criteria:
|
||||
if criteria.query_type == "review":
|
||||
# 对于综述类查询,强调论文的综述性质
|
||||
title_lower = (title or '').lower()
|
||||
abstract_lower = (abstract or '').lower()
|
||||
if any(keyword in title_lower or keyword in abstract_lower
|
||||
for keyword in ['review', 'survey', 'overview']):
|
||||
components.append("This is a review/survey paper")
|
||||
|
||||
elif criteria.query_type == "latest":
|
||||
# 对于最新论文查询,强调时间信息
|
||||
if year and int(year) >= criteria.start_year:
|
||||
components.append(f"This is a recent paper from {year}")
|
||||
|
||||
elif criteria.query_type == "recommend":
|
||||
# 对于推荐类查询,添加主题相关性信息
|
||||
if criteria.main_topic:
|
||||
title_lower = (title or '').lower()
|
||||
abstract_lower = (abstract or '').lower()
|
||||
topic_relevance = any(topic.lower() in title_lower or topic.lower() in abstract_lower
|
||||
for topic in [criteria.main_topic] + (criteria.sub_topics or []))
|
||||
if topic_relevance:
|
||||
components.append(f"This paper is directly related to {criteria.main_topic}")
|
||||
|
||||
return '\n'.join(components)
|
||||
|
||||
def _select_papers_strategically(
|
||||
self,
|
||||
papers: List[PaperMetadata],
|
||||
search_criteria: SearchCriteria,
|
||||
max_papers: int = 150
|
||||
) -> List[PaperMetadata]:
|
||||
"""战略性地选择论文子集,优先选择非Crossref来源的论文,
|
||||
当ADS论文充足时排除arXiv论文"""
|
||||
if len(papers) <= max_papers:
|
||||
return papers
|
||||
|
||||
# 1. 首先按来源分组
|
||||
papers_by_source = {
|
||||
'crossref': [],
|
||||
'adsabs': [],
|
||||
'arxiv': [],
|
||||
'others': [] # semantic, pubmed等其他来源
|
||||
}
|
||||
|
||||
for paper in papers:
|
||||
source = getattr(paper, 'source', '')
|
||||
if source == 'crossref':
|
||||
papers_by_source['crossref'].append(paper)
|
||||
elif source == 'adsabs':
|
||||
papers_by_source['adsabs'].append(paper)
|
||||
elif source == 'arxiv':
|
||||
papers_by_source['arxiv'].append(paper)
|
||||
else:
|
||||
papers_by_source['others'].append(paper)
|
||||
|
||||
# 2. 计算分数的通用函数
|
||||
def calculate_paper_score(paper):
|
||||
score = 0
|
||||
title = (getattr(paper, 'title', '') or '').lower()
|
||||
abstract = (getattr(paper, 'abstract', '') or '').lower()
|
||||
year = self._get_year_as_int(paper)
|
||||
citations = getattr(paper, 'citations', 0) or 0
|
||||
|
||||
# 安全地获取搜索条件
|
||||
main_topic = (getattr(search_criteria, 'main_topic', '') or '').lower()
|
||||
sub_topics = getattr(search_criteria, 'sub_topics', []) or []
|
||||
query_type = getattr(search_criteria, 'query_type', '')
|
||||
start_year = getattr(search_criteria, 'start_year', 0) or 0
|
||||
|
||||
# 主题相关性得分
|
||||
if main_topic and main_topic in title:
|
||||
score += 10
|
||||
if main_topic and main_topic in abstract:
|
||||
score += 5
|
||||
|
||||
# 子主题相关性得分
|
||||
for sub_topic in sub_topics:
|
||||
if sub_topic and sub_topic.lower() in title:
|
||||
score += 5
|
||||
if sub_topic and sub_topic.lower() in abstract:
|
||||
score += 2.5
|
||||
|
||||
# 根据查询类型调整分数
|
||||
if query_type == "review":
|
||||
review_keywords = ['review', 'survey', 'overview']
|
||||
if any(keyword in title for keyword in review_keywords):
|
||||
score *= 1.5
|
||||
if any(keyword in abstract for keyword in review_keywords):
|
||||
score *= 1.2
|
||||
elif query_type == "latest":
|
||||
if year and start_year:
|
||||
year_int = year if isinstance(year, int) else self._get_year_as_int(paper)
|
||||
start_year_int = start_year if isinstance(start_year, int) else int(start_year)
|
||||
if year_int >= start_year_int:
|
||||
recency_bonus = min(5, (year_int - start_year_int))
|
||||
score += recency_bonus * 2
|
||||
elif query_type == "recommend":
|
||||
citation_score = min(10, citations / 100)
|
||||
score += citation_score
|
||||
|
||||
return score
|
||||
|
||||
result = []
|
||||
|
||||
# 3. 处理ADS和arXiv论文
|
||||
non_crossref_papers = papers_by_source['others'] # 首先添加其他来源的论文
|
||||
|
||||
# 添加ADS论文
|
||||
if papers_by_source['adsabs']:
|
||||
non_crossref_papers.extend(papers_by_source['adsabs'])
|
||||
|
||||
# 只有当ADS论文不足20篇时,才添加arXiv论文
|
||||
if len(papers_by_source['adsabs']) <= 20:
|
||||
non_crossref_papers.extend(papers_by_source['arxiv'])
|
||||
elif not papers_by_source['adsabs'] and papers_by_source['arxiv']:
|
||||
# 如果没有ADS论文但有arXiv论文,也使用arXiv论文
|
||||
non_crossref_papers.extend(papers_by_source['arxiv'])
|
||||
|
||||
# 4. 对非Crossref论文评分和排序
|
||||
scored_non_crossref = [(p, calculate_paper_score(p)) for p in non_crossref_papers]
|
||||
scored_non_crossref.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 5. 先添加高分的非Crossref论文
|
||||
non_crossref_limit = max_papers * 0.9 # 90%的配额给非Crossref论文
|
||||
if len(scored_non_crossref) >= non_crossref_limit:
|
||||
result.extend([p[0] for p in scored_non_crossref[:int(non_crossref_limit)]])
|
||||
else:
|
||||
result.extend([p[0] for p in scored_non_crossref])
|
||||
|
||||
# 6. 如果还有剩余空间,考虑添加Crossref论文
|
||||
remaining_slots = max_papers - len(result)
|
||||
if remaining_slots > 0 and papers_by_source['crossref']:
|
||||
# 计算Crossref论文的最大数量(不超过总数的10%)
|
||||
max_crossref = min(remaining_slots, max_papers * 0.1)
|
||||
|
||||
# 对Crossref论文评分和排序
|
||||
scored_crossref = [(p, calculate_paper_score(p)) for p in papers_by_source['crossref']]
|
||||
scored_crossref.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 添加最高分的Crossref论文
|
||||
result.extend([p[0] for p in scored_crossref[:int(max_crossref)]])
|
||||
|
||||
# 7. 如果使用了Crossref论文后还有空位,继续使用非Crossref论文填充
|
||||
if len(result) < max_papers and len(scored_non_crossref) > len(result):
|
||||
remaining_non_crossref = [p[0] for p in scored_non_crossref[len(result):]]
|
||||
result.extend(remaining_non_crossref[:max_papers - len(result)])
|
||||
|
||||
return result
|
||||
76
crazy_functions/review_fns/prompts/adsabs_prompts.py
Normal file
76
crazy_functions/review_fns/prompts/adsabs_prompts.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# ADS query optimization prompt
|
||||
ADSABS_QUERY_PROMPT = """Analyze and optimize the following query for NASA ADS search.
|
||||
If the query is not related to astronomy, astrophysics, or physics, return <query>none</query>.
|
||||
If the query contains non-English terms, translate them to English first.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized ADS search query.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on the core research topic for the search query.
|
||||
|
||||
Relevant research areas for ADS:
|
||||
- Astronomy and astrophysics
|
||||
- Physics (theoretical and experimental)
|
||||
- Space science and exploration
|
||||
- Planetary science
|
||||
- Cosmology
|
||||
- Astrobiology
|
||||
- Related instrumentation and methods
|
||||
|
||||
Available search fields and filters:
|
||||
1. Basic fields:
|
||||
- title: Search in title (title:"term")
|
||||
- abstract: Search in abstract (abstract:"term")
|
||||
- author: Search for author names (author:"lastname, firstname")
|
||||
- year: Filter by year (year:2020-2023)
|
||||
- bibstem: Search by journal abbreviation (bibstem:ApJ)
|
||||
|
||||
2. Boolean operators:
|
||||
- AND
|
||||
- OR
|
||||
- NOT
|
||||
- (): Group terms
|
||||
- "": Exact phrase match
|
||||
|
||||
3. Special filters:
|
||||
- citations(identifier:paper): Papers citing a specific paper
|
||||
- references(identifier:paper): References of a specific paper
|
||||
- citation_count: Filter by citation count
|
||||
- database: Filter by database (database:astronomy)
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Black holes in galaxy centers after 2020"
|
||||
<query>title:"black hole" AND abstract:"galaxy center" AND year:2020-</query>
|
||||
|
||||
2. Query: "Papers by Neil deGrasse Tyson about exoplanets"
|
||||
<query>author:"Tyson, Neil deGrasse" AND title:exoplanet</query>
|
||||
|
||||
3. Query: "Most cited papers about dark matter in ApJ"
|
||||
<query>title:"dark matter" AND bibstem:ApJ AND citation_count:[100 TO *]</query>
|
||||
|
||||
4. Query: "Latest research on diabetes treatment"
|
||||
<query>none</query>
|
||||
|
||||
5. Query: "Machine learning for galaxy classification"
|
||||
<query>title:("machine learning" OR "deep learning") AND (title:galaxy OR abstract:galaxy) AND abstract:classification</query>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<query>Provide the optimized ADS search query using appropriate fields and operators, or "none" if not relevant</query>"""
|
||||
|
||||
# System prompt
|
||||
ADSABS_QUERY_SYSTEM_PROMPT = """You are an expert at crafting NASA ADS search queries.
|
||||
Your task is to:
|
||||
1. First determine if the query is relevant to astronomy, astrophysics, or physics research
|
||||
2. If relevant, optimize the natural language query for the ADS API
|
||||
3. If not relevant, return "none" to indicate the query should be handled by other databases
|
||||
|
||||
Focus on creating precise queries that will return relevant astronomical and physics literature.
|
||||
Always generate English search terms regardless of the input language.
|
||||
Consider using field-specific search terms and appropriate filters to improve search accuracy.
|
||||
|
||||
Remember: ADS is specifically for astronomy, astrophysics, and physics research.
|
||||
Medical, biological, or general research queries should return "none"."""
|
||||
341
crazy_functions/review_fns/prompts/arxiv_prompts.py
Normal file
341
crazy_functions/review_fns/prompts/arxiv_prompts.py
Normal file
@@ -0,0 +1,341 @@
|
||||
# Basic type analysis prompt
|
||||
ARXIV_TYPE_PROMPT = """Analyze the research query and determine if arXiv search is needed and its type.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task 1: Determine if this query requires arXiv search
|
||||
- arXiv is suitable for:
|
||||
* Computer science and AI/ML
|
||||
* Physics and mathematics
|
||||
* Quantitative biology and finance
|
||||
* Electrical engineering
|
||||
* Recent preprints in these fields
|
||||
- arXiv is NOT needed for:
|
||||
* Medical research (unless ML/AI applications)
|
||||
* Social sciences
|
||||
* Business studies
|
||||
* Humanities
|
||||
* Industry reports
|
||||
|
||||
Task 2: If arXiv search is needed, determine the most appropriate search type
|
||||
Available types:
|
||||
1. basic: Keyword-based search across all fields
|
||||
- For specific technical queries
|
||||
- When looking for particular methods or applications
|
||||
2. category: Category-based search within specific fields
|
||||
- For broad topic exploration
|
||||
- When surveying a research area
|
||||
3. none: arXiv search not needed for this query
|
||||
- When topic is outside arXiv's scope
|
||||
- For non-technical or clinical research
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "BERT transformer architecture"
|
||||
<search_type>basic</search_type>
|
||||
|
||||
2. Query: "latest developments in machine learning"
|
||||
<search_type>category</search_type>
|
||||
|
||||
3. Query: "COVID-19 clinical trials"
|
||||
<search_type>none</search_type>
|
||||
|
||||
4. Query: "psychological effects of social media"
|
||||
<search_type>none</search_type>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<search_type>Choose either 'basic', 'category', or 'none'</search_type>"""
|
||||
|
||||
# Query optimization prompt
|
||||
ARXIV_QUERY_PROMPT = """Optimize the following query for arXiv search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized arXiv search query using boolean operators and field tags.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on the core research topic for the search query.
|
||||
|
||||
Available field tags:
|
||||
- ti: Search in title
|
||||
- abs: Search in abstract
|
||||
- au: Search for author
|
||||
- all: Search in all fields (default)
|
||||
|
||||
Boolean operators:
|
||||
- AND: Both terms must appear
|
||||
- OR: Either term can appear
|
||||
- NOT: Exclude terms
|
||||
- (): Group terms
|
||||
- "": Exact phrase match
|
||||
|
||||
Examples:
|
||||
|
||||
1. Natural query: "Recent papers about transformer models by Vaswani"
|
||||
<query>ti:"transformer model" AND au:Vaswani AND year:[2017 TO 2024]</query>
|
||||
|
||||
2. Natural query: "Deep learning for computer vision, excluding surveys"
|
||||
<query>ti:(deep learning AND "computer vision") NOT (ti:survey OR ti:review)</query>
|
||||
|
||||
3. Natural query: "Attention mechanism in language models"
|
||||
<query>ti:(attention OR "attention mechanism") AND abs:"language model"</query>
|
||||
|
||||
4. Natural query: "GANs or generative adversarial networks for image generation"
|
||||
<query>(ti:GAN OR ti:"generative adversarial network") AND abs:"image generation"</query>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<query>Provide the optimized search query using appropriate operators and tags</query>
|
||||
|
||||
Note:
|
||||
- Use quotes for exact phrases
|
||||
- Combine multiple conditions with boolean operators
|
||||
- Consider both title and abstract for important concepts
|
||||
- Include author names when relevant
|
||||
- Use parentheses for complex logical groupings"""
|
||||
|
||||
# Sort parameters prompt
|
||||
ARXIV_SORT_PROMPT = """Determine optimal sorting parameters for the research query.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Select the most appropriate sorting parameters to help users find the most relevant papers.
|
||||
|
||||
Available sorting options:
|
||||
|
||||
1. Sort by:
|
||||
- relevance: Best match to query terms (default)
|
||||
- lastUpdatedDate: Most recently updated papers
|
||||
- submittedDate: Most recently submitted papers
|
||||
|
||||
2. Sort order:
|
||||
- descending: Newest/Most relevant first (default)
|
||||
- ascending: Oldest/Least relevant first
|
||||
|
||||
3. Result limit:
|
||||
- Minimum: 10 papers
|
||||
- Maximum: 50 papers
|
||||
- Recommended: 20-30 papers for most queries
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Latest developments in transformer models"
|
||||
<sort_by>submittedDate</sort_by>
|
||||
<sort_order>descending</sort_order>
|
||||
<limit>30</limit>
|
||||
|
||||
2. Query: "Foundational papers about neural networks"
|
||||
<sort_by>relevance</sort_by>
|
||||
<sort_order>descending</sort_order>
|
||||
<limit>20</limit>
|
||||
|
||||
3. Query: "Evolution of deep learning since 2012"
|
||||
<sort_by>submittedDate</sort_by>
|
||||
<sort_order>ascending</sort_order>
|
||||
<limit>50</limit>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<sort_by>Choose: relevance, lastUpdatedDate, or submittedDate</sort_by>
|
||||
<sort_order>Choose: ascending or descending</sort_order>
|
||||
<limit>Suggest number between 10-50</limit>
|
||||
|
||||
Note:
|
||||
- Choose relevance for specific technical queries
|
||||
- Use lastUpdatedDate for tracking paper revisions
|
||||
- Use submittedDate for following recent developments
|
||||
- Consider query context when setting the limit"""
|
||||
|
||||
# System prompts for each task
|
||||
ARXIV_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
|
||||
Your task is to determine whether the query is better suited for keyword search or category-based search.
|
||||
Consider the query's specificity, scope, and intended search area when making your decision.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
ARXIV_QUERY_SYSTEM_PROMPT = """You are an expert at crafting arXiv search queries.
|
||||
Your task is to optimize natural language queries using boolean operators and field tags.
|
||||
Focus on creating precise, targeted queries that will return the most relevant results.
|
||||
Always generate English search terms regardless of the input language."""
|
||||
|
||||
ARXIV_CATEGORIES_SYSTEM_PROMPT = """You are an expert at arXiv category classification.
|
||||
Your task is to select the most relevant categories for the given research query.
|
||||
Consider both primary and related interdisciplinary categories, while maintaining focus on the main research area.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
ARXIV_SORT_SYSTEM_PROMPT = """You are an expert at optimizing search results.
|
||||
Your task is to determine the best sorting parameters based on the query context.
|
||||
Consider the user's likely intent and temporal aspects of the research topic.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
# 添加新的搜索提示词
|
||||
ARXIV_SEARCH_PROMPT = """Analyze and optimize the research query for arXiv search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized arXiv search query.
|
||||
|
||||
Available search options:
|
||||
1. Basic search with field tags:
|
||||
- ti: Search in title
|
||||
- abs: Search in abstract
|
||||
- au: Search for author
|
||||
Example: "ti:transformer AND abs:attention"
|
||||
|
||||
2. Category-based search:
|
||||
- Use specific arXiv categories
|
||||
Example: "cat:cs.AI AND neural networks"
|
||||
|
||||
3. Date range:
|
||||
- Specify date range using submittedDate
|
||||
Example: "deep learning AND submittedDate:[20200101 TO 20231231]"
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Recent papers about transformer models by Vaswani"
|
||||
<search_criteria>
|
||||
<query>ti:"transformer model" AND au:Vaswani AND submittedDate:[20170101 TO 99991231]</query>
|
||||
<categories>cs.CL, cs.AI, cs.LG</categories>
|
||||
<sort_by>submittedDate</sort_by>
|
||||
<sort_order>descending</sort_order>
|
||||
<limit>30</limit>
|
||||
</search_criteria>
|
||||
|
||||
2. Query: "Latest developments in computer vision"
|
||||
<search_criteria>
|
||||
<query>cat:cs.CV AND submittedDate:[20220101 TO 99991231]</query>
|
||||
<categories>cs.CV, cs.AI, cs.LG</categories>
|
||||
<sort_by>submittedDate</sort_by>
|
||||
<sort_order>descending</sort_order>
|
||||
<limit>25</limit>
|
||||
</search_criteria>
|
||||
|
||||
Please analyze the query and respond with XML tags containing search criteria."""
|
||||
|
||||
ARXIV_SEARCH_SYSTEM_PROMPT = """You are an expert at crafting arXiv search queries.
|
||||
Your task is to analyze research queries and transform them into optimized arXiv search criteria.
|
||||
Consider query intent, relevant categories, and temporal aspects when creating the search parameters.
|
||||
Always generate English search terms and respond in English regardless of the input language."""
|
||||
|
||||
# Categories selection prompt
|
||||
ARXIV_CATEGORIES_PROMPT = """Select the most relevant arXiv categories for the research query.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Choose 2-4 most relevant categories that best match the research topic.
|
||||
|
||||
Available Categories:
|
||||
|
||||
Computer Science (cs):
|
||||
- cs.AI: Artificial Intelligence (neural networks, machine learning, NLP)
|
||||
- cs.CL: Computation and Language (NLP, machine translation)
|
||||
- cs.CV: Computer Vision and Pattern Recognition
|
||||
- cs.LG: Machine Learning (deep learning, reinforcement learning)
|
||||
- cs.NE: Neural and Evolutionary Computing
|
||||
- cs.RO: Robotics
|
||||
- cs.IR: Information Retrieval
|
||||
- cs.SE: Software Engineering
|
||||
- cs.DB: Databases
|
||||
- cs.DC: Distributed Computing
|
||||
- cs.CY: Computers and Society
|
||||
- cs.HC: Human-Computer Interaction
|
||||
|
||||
Mathematics (math):
|
||||
- math.OC: Optimization and Control
|
||||
- math.PR: Probability
|
||||
- math.ST: Statistics
|
||||
- math.NA: Numerical Analysis
|
||||
- math.DS: Dynamical Systems
|
||||
|
||||
Statistics (stat):
|
||||
- stat.ML: Machine Learning
|
||||
- stat.ME: Methodology
|
||||
- stat.TH: Theory
|
||||
- stat.AP: Applications
|
||||
|
||||
Physics (physics):
|
||||
- physics.comp-ph: Computational Physics
|
||||
- physics.data-an: Data Analysis
|
||||
- physics.soc-ph: Physics and Society
|
||||
|
||||
Electrical Engineering (eess):
|
||||
- eess.SP: Signal Processing
|
||||
- eess.AS: Audio and Speech Processing
|
||||
- eess.IV: Image and Video Processing
|
||||
- eess.SY: Systems and Control
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Deep learning for computer vision"
|
||||
<categories>cs.CV, cs.LG, stat.ML</categories>
|
||||
|
||||
2. Query: "Natural language processing with transformers"
|
||||
<categories>cs.CL, cs.AI, cs.LG</categories>
|
||||
|
||||
3. Query: "Reinforcement learning for robotics"
|
||||
<categories>cs.RO, cs.AI, cs.LG</categories>
|
||||
|
||||
4. Query: "Statistical methods in machine learning"
|
||||
<categories>stat.ML, cs.LG, math.ST</categories>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<categories>List 2-4 most relevant categories, comma-separated</categories>
|
||||
|
||||
Note:
|
||||
- Choose primary categories first, then add related ones
|
||||
- Limit to 2-4 most relevant categories
|
||||
- Order by relevance (most relevant first)
|
||||
- Use comma and space between categories (e.g., "cs.AI, cs.LG")"""
|
||||
|
||||
# 在文件末尾添加新的 prompt
|
||||
ARXIV_LATEST_PROMPT = """Determine if the query is requesting latest papers from arXiv.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Analyze if the query is specifically asking for recent/latest papers from arXiv.
|
||||
|
||||
IMPORTANT RULE:
|
||||
- The query MUST explicitly mention "arXiv" or "arxiv" to be considered a latest arXiv papers request
|
||||
- Queries only asking for recent/latest papers WITHOUT mentioning arXiv should return false
|
||||
|
||||
Indicators for latest papers request:
|
||||
1. MUST HAVE keywords about arXiv:
|
||||
- "arxiv"
|
||||
- "arXiv"
|
||||
AND
|
||||
|
||||
2. Keywords about recency:
|
||||
- "latest"
|
||||
- "recent"
|
||||
- "new"
|
||||
- "newest"
|
||||
- "just published"
|
||||
- "this week/month"
|
||||
|
||||
Examples:
|
||||
|
||||
1. Latest papers request (Valid):
|
||||
Query: "Show me the latest AI papers on arXiv"
|
||||
<is_latest_request>true</is_latest_request>
|
||||
|
||||
2. Latest papers request (Valid):
|
||||
Query: "What are the recent papers about transformers on arxiv"
|
||||
<is_latest_request>true</is_latest_request>
|
||||
|
||||
3. Not a latest papers request (Invalid - no mention of arXiv):
|
||||
Query: "Show me the latest papers about BERT"
|
||||
<is_latest_request>false</is_latest_request>
|
||||
|
||||
4. Not a latest papers request (Invalid - no recency):
|
||||
Query: "Find papers on arxiv about transformers"
|
||||
<is_latest_request>false</is_latest_request>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<is_latest_request>true/false</is_latest_request>
|
||||
|
||||
Note: The response should be true ONLY if both conditions are met:
|
||||
1. Query explicitly mentions arXiv/arxiv
|
||||
2. Query asks for recent/latest papers"""
|
||||
|
||||
ARXIV_LATEST_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
|
||||
Your task is to determine if the query is specifically requesting latest/recent papers from arXiv.
|
||||
Remember: The query MUST explicitly mention arXiv to be considered valid, even if it asks for recent papers.
|
||||
Always respond in English regardless of the input language."""
|
||||
55
crazy_functions/review_fns/prompts/crossref_prompts.py
Normal file
55
crazy_functions/review_fns/prompts/crossref_prompts.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Crossref query optimization prompt
|
||||
CROSSREF_QUERY_PROMPT = """Analyze and optimize the query for Crossref search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized Crossref search query.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on the core research topic for the search query.
|
||||
|
||||
Available search fields and filters:
|
||||
1. Basic fields:
|
||||
- title: Search in title
|
||||
- abstract: Search in abstract
|
||||
- author: Search for author names
|
||||
- container-title: Search in journal/conference name
|
||||
- publisher: Search by publisher name
|
||||
- type: Filter by work type (journal-article, book-chapter, etc.)
|
||||
- year: Filter by publication year
|
||||
|
||||
2. Boolean operators:
|
||||
- AND: Both terms must appear
|
||||
- OR: Either term can appear
|
||||
- NOT: Exclude terms
|
||||
- "": Exact phrase match
|
||||
|
||||
3. Special filters:
|
||||
- is-referenced-by-count: Filter by citation count
|
||||
- from-pub-date: Filter by publication date
|
||||
- has-abstract: Filter papers with abstracts
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Machine learning in healthcare after 2020"
|
||||
<query>title:"machine learning" AND title:healthcare AND from-pub-date:2020</query>
|
||||
|
||||
2. Query: "Papers by Geoffrey Hinton about deep learning"
|
||||
<query>author:"Hinton, Geoffrey" AND (title:"deep learning" OR abstract:"deep learning")</query>
|
||||
|
||||
3. Query: "Most cited papers about transformers in Nature"
|
||||
<query>title:transformer AND container-title:Nature AND is-referenced-by-count:[100 TO *]</query>
|
||||
|
||||
4. Query: "Recent BERT applications in medical domain"
|
||||
<query>title:BERT AND abstract:medical AND from-pub-date:2020 AND type:journal-article</query>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<query>Provide the optimized Crossref search query using appropriate fields and operators</query>"""
|
||||
|
||||
# System prompt
|
||||
CROSSREF_QUERY_SYSTEM_PROMPT = """You are an expert at crafting Crossref search queries.
|
||||
Your task is to optimize natural language queries for Crossref's API.
|
||||
Focus on creating precise queries that will return relevant results.
|
||||
Always generate English search terms regardless of the input language.
|
||||
Consider using field-specific search terms and appropriate filters to improve search accuracy."""
|
||||
47
crazy_functions/review_fns/prompts/paper_prompts.py
Normal file
47
crazy_functions/review_fns/prompts/paper_prompts.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# 新建文件,添加论文识别提示
|
||||
PAPER_IDENTIFY_PROMPT = """Analyze the query to identify paper details.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Extract paper identification information from the query.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on identifying paper details.
|
||||
|
||||
Possible paper identifiers:
|
||||
1. arXiv ID (e.g., 2103.14030, arXiv:2103.14030)
|
||||
2. DOI (e.g., 10.1234/xxx.xxx)
|
||||
3. Paper title (e.g., "Attention is All You Need")
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query with arXiv ID:
|
||||
Query: "Analyze paper 2103.14030"
|
||||
<paper_info>
|
||||
<paper_source>arxiv</paper_source>
|
||||
<paper_id>2103.14030</paper_id>
|
||||
<paper_title></paper_title>
|
||||
</paper_info>
|
||||
|
||||
2. Query with DOI:
|
||||
Query: "Review the paper with DOI 10.1234/xxx.xxx"
|
||||
<paper_info>
|
||||
<paper_source>doi</paper_source>
|
||||
<paper_id>10.1234/xxx.xxx</paper_id>
|
||||
<paper_title></paper_title>
|
||||
</paper_info>
|
||||
|
||||
3. Query with paper title:
|
||||
Query: "Analyze 'Attention is All You Need' paper"
|
||||
<paper_info>
|
||||
<paper_source>title</paper_source>
|
||||
<paper_id></paper_id>
|
||||
<paper_title>Attention is All You Need</paper_title>
|
||||
</paper_info>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags containing paper information."""
|
||||
|
||||
PAPER_IDENTIFY_SYSTEM_PROMPT = """You are an expert at identifying academic paper references.
|
||||
Your task is to extract paper identification information from queries.
|
||||
Look for arXiv IDs, DOIs, and paper titles."""
|
||||
108
crazy_functions/review_fns/prompts/pubmed_prompts.py
Normal file
108
crazy_functions/review_fns/prompts/pubmed_prompts.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# PubMed search type prompt
|
||||
PUBMED_TYPE_PROMPT = """Analyze the research query and determine the appropriate PubMed search type.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Available search types:
|
||||
1. basic: General keyword search for medical/biomedical topics
|
||||
2. author: Search by author name
|
||||
3. journal: Search within specific journals
|
||||
4. none: Query not related to medical/biomedical research
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "COVID-19 treatment outcomes"
|
||||
<search_type>basic</search_type>
|
||||
|
||||
2. Query: "Papers by Anthony Fauci"
|
||||
<search_type>author</search_type>
|
||||
|
||||
3. Query: "Recent papers in Nature about CRISPR"
|
||||
<search_type>journal</search_type>
|
||||
|
||||
4. Query: "Deep learning for computer vision"
|
||||
<search_type>none</search_type>
|
||||
|
||||
5. Query: "Transformer architecture for NLP"
|
||||
<search_type>none</search_type>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<search_type>Choose: basic, author, journal, or none</search_type>"""
|
||||
|
||||
# PubMed query optimization prompt
|
||||
PUBMED_QUERY_PROMPT = """Optimize the following query for PubMed search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized PubMed search query.
|
||||
Requirements:
|
||||
- Always generate English search terms regardless of input language
|
||||
- Translate any non-English terms to English before creating the query
|
||||
- Never include non-English characters in the final query
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on the core medical/biomedical topic for the search query.
|
||||
|
||||
Available field tags:
|
||||
- [Title] - Search in title
|
||||
- [Author] - Search for author
|
||||
- [Journal] - Search in journal name
|
||||
- [MeSH Terms] - Search using MeSH terms
|
||||
|
||||
Boolean operators:
|
||||
- AND
|
||||
- OR
|
||||
- NOT
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "COVID-19 treatment in elderly patients"
|
||||
<query>COVID-19[Title] AND treatment[Title/Abstract] AND elderly[Title/Abstract]</query>
|
||||
|
||||
2. Query: "Cancer immunotherapy review articles"
|
||||
<query>cancer immunotherapy[Title/Abstract] AND review[Publication Type]</query>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<query>Provide the optimized PubMed search query</query>"""
|
||||
|
||||
# PubMed sort parameters prompt
|
||||
PUBMED_SORT_PROMPT = """Determine optimal sorting parameters for PubMed results.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Select the most appropriate sorting method and result limit.
|
||||
|
||||
Available sort options:
|
||||
- relevance: Best match to query
|
||||
- date: Most recent first
|
||||
- journal: Sort by journal name
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Latest developments in gene therapy"
|
||||
<sort_by>date</sort_by>
|
||||
<limit>30</limit>
|
||||
|
||||
2. Query: "Classic papers about DNA structure"
|
||||
<sort_by>relevance</sort_by>
|
||||
<limit>20</limit>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<sort_by>Choose: relevance, date, or journal</sort_by>
|
||||
<limit>Suggest number between 10-50</limit>"""
|
||||
|
||||
# System prompts
|
||||
PUBMED_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing medical and scientific queries.
|
||||
Your task is to determine the most appropriate PubMed search type.
|
||||
Consider the query's focus and intended search scope.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
PUBMED_QUERY_SYSTEM_PROMPT = """You are an expert at crafting PubMed search queries.
|
||||
Your task is to optimize natural language queries using PubMed's search syntax.
|
||||
Focus on creating precise, targeted queries that will return relevant medical literature.
|
||||
Always generate English search terms regardless of the input language."""
|
||||
|
||||
PUBMED_SORT_SYSTEM_PROMPT = """You are an expert at optimizing PubMed search results.
|
||||
Your task is to determine the best sorting parameters based on the query context.
|
||||
Consider the balance between relevance and recency.
|
||||
Always respond in English regardless of the input language."""
|
||||
276
crazy_functions/review_fns/prompts/semantic_prompts.py
Normal file
276
crazy_functions/review_fns/prompts/semantic_prompts.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# Search type prompt
|
||||
SEMANTIC_TYPE_PROMPT = """Determine the most appropriate search type for Semantic Scholar.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Analyze the research query and select the most appropriate search type for Semantic Scholar API.
|
||||
|
||||
Available search types:
|
||||
|
||||
1. paper: General paper search
|
||||
- Use for broad topic searches
|
||||
- Looking for specific papers
|
||||
- Keyword-based searches
|
||||
Example: "transformer models in NLP"
|
||||
|
||||
2. author: Author-based search
|
||||
- Finding works by specific researchers
|
||||
- Author profile analysis
|
||||
Example: "papers by Yoshua Bengio"
|
||||
|
||||
3. paper_details: Specific paper lookup
|
||||
- Getting details about a known paper
|
||||
- Finding specific versions or citations
|
||||
Example: "Attention is All You Need paper details"
|
||||
|
||||
4. citations: Citation analysis
|
||||
- Finding papers that cite a specific work
|
||||
- Impact analysis
|
||||
Example: "papers citing BERT"
|
||||
|
||||
5. references: Reference analysis
|
||||
- Finding papers cited by a specific work
|
||||
- Background research
|
||||
Example: "references in GPT-3 paper"
|
||||
|
||||
6. recommendations: Paper recommendations
|
||||
- Finding similar papers
|
||||
- Research direction exploration
|
||||
Example: "papers similar to Transformer"
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Latest papers about deep learning"
|
||||
<search_type>paper</search_type>
|
||||
|
||||
2. Query: "Works by Geoffrey Hinton since 2020"
|
||||
<search_type>author</search_type>
|
||||
|
||||
3. Query: "Papers citing the original Transformer paper"
|
||||
<search_type>citations</search_type>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<search_type>Choose the most appropriate search type from the list above</search_type>"""
|
||||
|
||||
# Query optimization prompt
|
||||
SEMANTIC_QUERY_PROMPT = """Optimize the following query for Semantic Scholar search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized search query for maximum relevance.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on the core research topic for the search query.
|
||||
|
||||
Query optimization guidelines:
|
||||
|
||||
1. Use quotes for exact phrases
|
||||
- Ensures exact matching
|
||||
- Reduces irrelevant results
|
||||
Example: "\"attention mechanism\"" vs attention mechanism
|
||||
|
||||
2. Include key technical terms
|
||||
- Use specific technical terminology
|
||||
- Include common variations
|
||||
Example: "transformer architecture" neural networks
|
||||
|
||||
3. Author names (if relevant)
|
||||
- Include full names when known
|
||||
- Consider common name variations
|
||||
Example: "Geoffrey Hinton" OR "G. E. Hinton"
|
||||
|
||||
Examples:
|
||||
|
||||
1. Natural query: "Recent advances in transformer models"
|
||||
<query>"transformer model" "neural architecture" deep learning</query>
|
||||
|
||||
2. Natural query: "BERT applications in text classification"
|
||||
<query>"BERT" "text classification" "language model" application</query>
|
||||
|
||||
3. Natural query: "Deep learning for computer vision by Kaiming He"
|
||||
<query>"deep learning" "computer vision" author:"Kaiming He"</query>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<query>Provide the optimized search query</query>
|
||||
|
||||
Note:
|
||||
- Balance between specificity and coverage
|
||||
- Include important technical terms
|
||||
- Use quotes for key phrases
|
||||
- Consider synonyms and related terms"""
|
||||
|
||||
# Fields selection prompt
|
||||
SEMANTIC_FIELDS_PROMPT = """Select relevant fields to retrieve from Semantic Scholar.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Determine which paper fields should be retrieved based on the research needs.
|
||||
|
||||
Available fields:
|
||||
|
||||
Core fields:
|
||||
- title: Paper title (always included)
|
||||
- abstract: Full paper abstract
|
||||
- authors: Author information
|
||||
- year: Publication year
|
||||
- venue: Publication venue
|
||||
|
||||
Citation fields:
|
||||
- citations: Papers citing this work
|
||||
- references: Papers cited by this work
|
||||
|
||||
Additional fields:
|
||||
- embedding: Paper vector embedding
|
||||
- tldr: AI-generated summary
|
||||
- venue: Publication venue/journal
|
||||
- url: Paper URL
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Latest developments in NLP"
|
||||
<fields>title, abstract, authors, year, venue, citations</fields>
|
||||
|
||||
2. Query: "Most influential papers in deep learning"
|
||||
<fields>title, abstract, authors, year, citations, references</fields>
|
||||
|
||||
3. Query: "Survey of transformer architectures"
|
||||
<fields>title, abstract, authors, year, tldr, references</fields>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<fields>List relevant fields, comma-separated</fields>
|
||||
|
||||
Note:
|
||||
- Choose fields based on the query's purpose
|
||||
- Include citation data for impact analysis
|
||||
- Consider tldr for quick paper screening
|
||||
- Balance completeness with API efficiency"""
|
||||
|
||||
# Sort parameters prompt
|
||||
SEMANTIC_SORT_PROMPT = """Determine optimal sorting parameters for the query.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Select the most appropriate sorting method and result limit for the search.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
Sorting options:
|
||||
|
||||
1. relevance (default)
|
||||
- Best match to query terms
|
||||
- Recommended for specific technical searches
|
||||
Example: "specific algorithm implementations"
|
||||
|
||||
2. citations
|
||||
- Sort by citation count
|
||||
- Best for finding influential papers
|
||||
Example: "most important papers in deep learning"
|
||||
|
||||
3. year
|
||||
- Sort by publication date
|
||||
- Best for following recent developments
|
||||
Example: "latest advances in NLP"
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Recent breakthroughs in AI"
|
||||
<sort_by>year</sort_by>
|
||||
<limit>30</limit>
|
||||
|
||||
2. Query: "Most influential papers about GANs"
|
||||
<sort_by>citations</sort_by>
|
||||
<limit>20</limit>
|
||||
|
||||
3. Query: "Specific papers about BERT fine-tuning"
|
||||
<sort_by>relevance</sort_by>
|
||||
<limit>25</limit>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<sort_by>Choose: relevance, citations, or year</sort_by>
|
||||
<limit>Suggest number between 10-50</limit>
|
||||
|
||||
Note:
|
||||
- Consider the query's temporal aspects
|
||||
- Balance between comprehensive coverage and information overload
|
||||
- Use citation sorting for impact analysis
|
||||
- Use year sorting for tracking developments"""
|
||||
|
||||
# System prompts for each task
|
||||
SEMANTIC_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
|
||||
Your task is to determine the most appropriate type of search on Semantic Scholar.
|
||||
Consider the query's intent, scope, and specific research needs.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
SEMANTIC_QUERY_SYSTEM_PROMPT = """You are an expert at crafting Semantic Scholar search queries.
|
||||
Your task is to optimize natural language queries for maximum relevance.
|
||||
Focus on creating precise queries that leverage the platform's search capabilities.
|
||||
Always generate English search terms regardless of the input language."""
|
||||
|
||||
SEMANTIC_FIELDS_SYSTEM_PROMPT = """You are an expert at Semantic Scholar data fields.
|
||||
Your task is to select the most relevant fields based on the research context.
|
||||
Consider both essential and supplementary information needs.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
SEMANTIC_SORT_SYSTEM_PROMPT = """You are an expert at optimizing search results.
|
||||
Your task is to determine the best sorting parameters based on the query context.
|
||||
Consider the balance between relevance, impact, and recency.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
# 添加新的综合搜索提示词
|
||||
SEMANTIC_SEARCH_PROMPT = """Analyze and optimize the research query for Semantic Scholar search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into optimized search criteria for Semantic Scholar.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements when generating the search terms. These requirements
|
||||
should be considered only for post-search filtering, not as part of the core query.
|
||||
|
||||
Available search options:
|
||||
1. Paper search:
|
||||
- Title and abstract search
|
||||
- Author search
|
||||
- Field-specific search
|
||||
Example: "transformer architecture neural networks"
|
||||
|
||||
2. Field tags:
|
||||
- title: Search in title
|
||||
- abstract: Search in abstract
|
||||
- authors: Search by author names
|
||||
- venue: Search by publication venue
|
||||
Example: "title:transformer authors:\"Vaswani\""
|
||||
|
||||
3. Advanced options:
|
||||
- Year range filtering
|
||||
- Citation count filtering
|
||||
- Venue filtering
|
||||
Example: "deep learning year>2020 venue:\"NeurIPS\""
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Recent transformer papers by Vaswani with high impact"
|
||||
<search_criteria>
|
||||
<query>title:transformer authors:"Vaswani" year>2017</query>
|
||||
<search_type>paper</search_type>
|
||||
<fields>title,abstract,authors,year,citations,venue</fields>
|
||||
<sort_by>citations</sort_by>
|
||||
<limit>30</limit>
|
||||
</search_criteria>
|
||||
|
||||
2. Query: "Most cited papers about BERT in top conferences"
|
||||
<search_criteria>
|
||||
<query>title:BERT venue:"ACL|EMNLP|NAACL"</query>
|
||||
<search_type>paper</search_type>
|
||||
<fields>title,abstract,authors,year,citations,venue,references</fields>
|
||||
<sort_by>citations</sort_by>
|
||||
<limit>25</limit>
|
||||
</search_criteria>
|
||||
|
||||
Please analyze the query and respond with XML tags containing complete search criteria."""
|
||||
|
||||
SEMANTIC_SEARCH_SYSTEM_PROMPT = """You are an expert at crafting Semantic Scholar search queries.
|
||||
Your task is to analyze research queries and transform them into optimized search criteria.
|
||||
Consider query intent, field relevance, and citation impact when creating the search parameters.
|
||||
Focus on producing precise and comprehensive search criteria that will yield the most relevant results.
|
||||
Always generate English search terms and respond in English regardless of the input language."""
|
||||
493
crazy_functions/review_fns/query_analyzer.py
Normal file
493
crazy_functions/review_fns/query_analyzer.py
Normal file
@@ -0,0 +1,493 @@
|
||||
from typing import Dict, List
|
||||
from dataclasses import dataclass
|
||||
from textwrap import dedent
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchCriteria:
|
||||
"""搜索条件"""
|
||||
query_type: str # 查询类型: review/recommend/qa/paper
|
||||
main_topic: str # 主题
|
||||
sub_topics: List[str] # 子主题列表
|
||||
start_year: int # 起始年份
|
||||
end_year: int # 结束年份
|
||||
arxiv_params: Dict # arXiv搜索参数
|
||||
semantic_params: Dict # Semantic Scholar搜索参数
|
||||
pubmed_params: Dict # 新增: PubMed搜索参数
|
||||
crossref_params: Dict # 添加 Crossref 参数
|
||||
adsabs_params: Dict # 添加 ADS 参数
|
||||
paper_id: str = "" # 论文ID (arxiv ID 或 DOI)
|
||||
paper_title: str = "" # 论文标题
|
||||
paper_source: str = "" # 论文来源 (arxiv/doi/title)
|
||||
original_query: str = "" # 新增: 原始查询字符串
|
||||
|
||||
|
||||
class QueryAnalyzer:
|
||||
"""查询分析器"""
|
||||
|
||||
# 响应索引常量
|
||||
BASIC_QUERY_INDEX = 0
|
||||
PAPER_IDENTIFY_INDEX = 1
|
||||
ARXIV_QUERY_INDEX = 2
|
||||
ARXIV_CATEGORIES_INDEX = 3
|
||||
ARXIV_LATEST_INDEX = 4
|
||||
ARXIV_SORT_INDEX = 5
|
||||
SEMANTIC_QUERY_INDEX = 6
|
||||
SEMANTIC_FIELDS_INDEX = 7
|
||||
PUBMED_TYPE_INDEX = 8
|
||||
PUBMED_QUERY_INDEX = 9
|
||||
CROSSREF_QUERY_INDEX = 10
|
||||
ADSABS_QUERY_INDEX = 11
|
||||
|
||||
def __init__(self):
|
||||
self.current_year = datetime.now().year
|
||||
self.valid_types = {
|
||||
"review": ["review", "literature review", "survey"],
|
||||
"recommend": ["recommend", "recommendation", "suggest", "similar"],
|
||||
"qa": ["qa", "question", "answer", "explain", "what", "how", "why"],
|
||||
"paper": ["paper", "analyze", "analysis"]
|
||||
}
|
||||
|
||||
def analyze_query(self, query: str, chatbot: List, llm_kwargs: Dict):
|
||||
"""分析查询意图"""
|
||||
from crazy_functions.crazy_utils import \
|
||||
request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||
from crazy_functions.review_fns.prompts.arxiv_prompts import (
|
||||
ARXIV_QUERY_PROMPT, ARXIV_CATEGORIES_PROMPT, ARXIV_LATEST_PROMPT,
|
||||
ARXIV_SORT_PROMPT, ARXIV_QUERY_SYSTEM_PROMPT, ARXIV_CATEGORIES_SYSTEM_PROMPT, ARXIV_SORT_SYSTEM_PROMPT,
|
||||
ARXIV_LATEST_SYSTEM_PROMPT
|
||||
)
|
||||
from crazy_functions.review_fns.prompts.semantic_prompts import (
|
||||
SEMANTIC_QUERY_PROMPT, SEMANTIC_FIELDS_PROMPT,
|
||||
SEMANTIC_QUERY_SYSTEM_PROMPT, SEMANTIC_FIELDS_SYSTEM_PROMPT
|
||||
)
|
||||
from .prompts.paper_prompts import PAPER_IDENTIFY_PROMPT, PAPER_IDENTIFY_SYSTEM_PROMPT
|
||||
from .prompts.pubmed_prompts import (
|
||||
PUBMED_TYPE_PROMPT, PUBMED_QUERY_PROMPT, PUBMED_SORT_PROMPT,
|
||||
PUBMED_TYPE_SYSTEM_PROMPT, PUBMED_QUERY_SYSTEM_PROMPT, PUBMED_SORT_SYSTEM_PROMPT
|
||||
)
|
||||
from .prompts.crossref_prompts import (
|
||||
CROSSREF_QUERY_PROMPT,
|
||||
CROSSREF_QUERY_SYSTEM_PROMPT
|
||||
)
|
||||
from .prompts.adsabs_prompts import ADSABS_QUERY_PROMPT, ADSABS_QUERY_SYSTEM_PROMPT
|
||||
|
||||
# 1. 基本查询分析
|
||||
type_prompt = dedent(f"""Please analyze this academic query and respond STRICTLY in the following XML format:
|
||||
|
||||
Query: {query}
|
||||
|
||||
Instructions:
|
||||
1. Your response must use XML tags exactly as shown below
|
||||
2. Do not add any text outside the tags
|
||||
3. Choose query type from: review/recommend/qa/paper
|
||||
- review: for literature review or survey requests
|
||||
- recommend: for paper recommendation requests
|
||||
- qa: for general questions about research topics
|
||||
- paper: ONLY for queries about a SPECIFIC paper (with paper ID, DOI, or exact title)
|
||||
4. Identify main topic and subtopics
|
||||
5. Specify year range if mentioned
|
||||
|
||||
Required format:
|
||||
<query_type>ANSWER HERE</query_type>
|
||||
<main_topic>ANSWER HERE</main_topic>
|
||||
<sub_topics>SUBTOPIC1, SUBTOPIC2, ...</sub_topics>
|
||||
<year_range>START_YEAR-END_YEAR</year_range>
|
||||
|
||||
Example responses:
|
||||
|
||||
1. Literature Review Request:
|
||||
Query: "Review recent developments in transformer models for NLP from 2020 to 2023"
|
||||
<query_type>review</query_type>
|
||||
<main_topic>transformer models in natural language processing</main_topic>
|
||||
<sub_topics>architecture improvements, pre-training methods, fine-tuning techniques</sub_topics>
|
||||
<year_range>2020-2023</year_range>
|
||||
|
||||
2. Paper Recommendation Request:
|
||||
Query: "Suggest papers about reinforcement learning in robotics since 2018"
|
||||
<query_type>recommend</query_type>
|
||||
<main_topic>reinforcement learning in robotics</main_topic>
|
||||
<sub_topics>robot control, policy learning, sim-to-real transfer</sub_topics>
|
||||
<year_range>2018-2023</year_range>"""
|
||||
)
|
||||
|
||||
try:
|
||||
# 构建提示数组
|
||||
prompts = [
|
||||
type_prompt,
|
||||
PAPER_IDENTIFY_PROMPT.format(query=query),
|
||||
ARXIV_QUERY_PROMPT.format(query=query),
|
||||
ARXIV_CATEGORIES_PROMPT.format(query=query),
|
||||
ARXIV_LATEST_PROMPT.format(query=query),
|
||||
ARXIV_SORT_PROMPT.format(query=query),
|
||||
SEMANTIC_QUERY_PROMPT.format(query=query),
|
||||
SEMANTIC_FIELDS_PROMPT.format(query=query),
|
||||
PUBMED_TYPE_PROMPT.format(query=query),
|
||||
PUBMED_QUERY_PROMPT.format(query=query),
|
||||
CROSSREF_QUERY_PROMPT.format(query=query),
|
||||
ADSABS_QUERY_PROMPT.format(query=query)
|
||||
]
|
||||
|
||||
show_messages = [
|
||||
"Analyzing query type...",
|
||||
"Identifying paper details...",
|
||||
"Determining arXiv search type...",
|
||||
"Selecting arXiv categories...",
|
||||
"Checking if latest papers requested...",
|
||||
"Determining arXiv sort parameters...",
|
||||
"Optimizing Semantic Scholar query...",
|
||||
"Selecting Semantic Scholar fields...",
|
||||
"Determining PubMed search type...",
|
||||
"Optimizing PubMed query...",
|
||||
"Optimizing Crossref query...",
|
||||
"Optimizing ADS query..."
|
||||
]
|
||||
|
||||
sys_prompts = [
|
||||
"You are an expert at analyzing academic queries.",
|
||||
PAPER_IDENTIFY_SYSTEM_PROMPT,
|
||||
ARXIV_QUERY_SYSTEM_PROMPT,
|
||||
ARXIV_CATEGORIES_SYSTEM_PROMPT,
|
||||
ARXIV_LATEST_SYSTEM_PROMPT,
|
||||
ARXIV_SORT_SYSTEM_PROMPT,
|
||||
SEMANTIC_QUERY_SYSTEM_PROMPT,
|
||||
SEMANTIC_FIELDS_SYSTEM_PROMPT,
|
||||
PUBMED_TYPE_SYSTEM_PROMPT,
|
||||
PUBMED_QUERY_SYSTEM_PROMPT,
|
||||
CROSSREF_QUERY_SYSTEM_PROMPT,
|
||||
ADSABS_QUERY_SYSTEM_PROMPT
|
||||
]
|
||||
new_llm_kwargs = llm_kwargs.copy()
|
||||
# new_llm_kwargs['llm_model'] = 'deepseek-chat' # deepseek-ai/DeepSeek-V2.5
|
||||
|
||||
# 使用同步方式调用LLM
|
||||
responses = yield from request_gpt(
|
||||
inputs_array=prompts,
|
||||
inputs_show_user_array=show_messages,
|
||||
llm_kwargs=new_llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history_array=[[] for _ in prompts],
|
||||
sys_prompt_array=sys_prompts,
|
||||
max_workers=5
|
||||
)
|
||||
|
||||
# 从收集的响应中提取我们需要的内容
|
||||
extracted_responses = []
|
||||
for i in range(len(prompts)):
|
||||
if (i * 2 + 1) < len(responses):
|
||||
response = responses[i * 2 + 1]
|
||||
if response is None:
|
||||
raise Exception(f"Response {i} is None")
|
||||
if not isinstance(response, str):
|
||||
try:
|
||||
response = str(response)
|
||||
except:
|
||||
raise Exception(f"Cannot convert response {i} to string")
|
||||
extracted_responses.append(response)
|
||||
else:
|
||||
raise Exception(f"未收到第 {i + 1} 个响应")
|
||||
|
||||
# 解析基本信息
|
||||
query_type = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "query_type")
|
||||
if not query_type:
|
||||
print(
|
||||
f"Debug - Failed to extract query_type. Response was: {extracted_responses[self.BASIC_QUERY_INDEX]}")
|
||||
raise Exception("无法提取query_type标签内容")
|
||||
query_type = query_type.lower()
|
||||
|
||||
main_topic = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "main_topic")
|
||||
if not main_topic:
|
||||
print(f"Debug - Failed to extract main_topic. Using query as fallback.")
|
||||
main_topic = query
|
||||
|
||||
query_type = self._normalize_query_type(query_type, query)
|
||||
|
||||
# 解析arXiv参数
|
||||
try:
|
||||
arxiv_params = {
|
||||
"query": self._extract_tag(extracted_responses[self.ARXIV_QUERY_INDEX], "query"),
|
||||
"categories": [cat.strip() for cat in
|
||||
self._extract_tag(extracted_responses[self.ARXIV_CATEGORIES_INDEX],
|
||||
"categories").split(",")],
|
||||
"sort_by": self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "sort_by"),
|
||||
"sort_order": self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "sort_order"),
|
||||
"limit": 20
|
||||
}
|
||||
|
||||
# 安全地解析limit值
|
||||
limit_str = self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "limit")
|
||||
if limit_str and limit_str.isdigit():
|
||||
arxiv_params["limit"] = int(limit_str)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Error parsing arXiv parameters: {str(e)}")
|
||||
arxiv_params = {
|
||||
"query": "",
|
||||
"categories": [],
|
||||
"sort_by": "relevance",
|
||||
"sort_order": "descending",
|
||||
"limit": 0
|
||||
}
|
||||
|
||||
# 解析Semantic Scholar参数
|
||||
try:
|
||||
semantic_params = {
|
||||
"query": self._extract_tag(extracted_responses[self.SEMANTIC_QUERY_INDEX], "query"),
|
||||
"fields": [field.strip() for field in
|
||||
self._extract_tag(extracted_responses[self.SEMANTIC_FIELDS_INDEX], "fields").split(",")],
|
||||
"sort_by": "relevance",
|
||||
"limit": 20
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Warning: Error parsing Semantic Scholar parameters: {str(e)}")
|
||||
semantic_params = {
|
||||
"query": query,
|
||||
"fields": ["title", "abstract", "authors", "year"],
|
||||
"sort_by": "relevance",
|
||||
"limit": 20
|
||||
}
|
||||
|
||||
# 解析PubMed参数
|
||||
try:
|
||||
# 首先检查是否需要PubMed搜索
|
||||
pubmed_search_type = self._extract_tag(extracted_responses[self.PUBMED_TYPE_INDEX], "search_type")
|
||||
|
||||
if pubmed_search_type == "none":
|
||||
# 不需要PubMed搜索,使用空参数
|
||||
pubmed_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
else:
|
||||
# 需要PubMed搜索,解析完整参数
|
||||
pubmed_params = {
|
||||
"search_type": pubmed_search_type,
|
||||
"query": self._extract_tag(extracted_responses[self.PUBMED_QUERY_INDEX], "query"),
|
||||
"sort_by": "relevance",
|
||||
"limit": 200
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Warning: Error parsing PubMed parameters: {str(e)}")
|
||||
pubmed_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
|
||||
# 解析Crossref参数
|
||||
try:
|
||||
crossref_query = self._extract_tag(extracted_responses[self.CROSSREF_QUERY_INDEX], "query")
|
||||
|
||||
if not crossref_query:
|
||||
crossref_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
else:
|
||||
crossref_params = {
|
||||
"search_type": "basic",
|
||||
"query": crossref_query,
|
||||
"sort_by": "relevance",
|
||||
"limit": 20
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Warning: Error parsing Crossref parameters: {str(e)}")
|
||||
crossref_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
|
||||
# 解析ADS参数
|
||||
try:
|
||||
adsabs_query = self._extract_tag(extracted_responses[self.ADSABS_QUERY_INDEX], "query")
|
||||
|
||||
if not adsabs_query:
|
||||
adsabs_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
else:
|
||||
adsabs_params = {
|
||||
"search_type": "basic",
|
||||
"query": adsabs_query,
|
||||
"sort_by": "relevance",
|
||||
"limit": 20
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Warning: Error parsing ADS parameters: {str(e)}")
|
||||
adsabs_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
|
||||
print(f"Debug - Extracted information:")
|
||||
print(f"Query type: {query_type}")
|
||||
print(f"Main topic: {main_topic}")
|
||||
print(f"arXiv params: {arxiv_params}")
|
||||
print(f"Semantic params: {semantic_params}")
|
||||
print(f"PubMed params: {pubmed_params}")
|
||||
print(f"Crossref params: {crossref_params}")
|
||||
print(f"ADS params: {adsabs_params}")
|
||||
|
||||
# 提取子主题
|
||||
sub_topics = []
|
||||
if "sub_topics" in query.lower():
|
||||
sub_topics_text = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "sub_topics")
|
||||
if sub_topics_text:
|
||||
sub_topics = [topic.strip() for topic in sub_topics_text.split(",")]
|
||||
|
||||
# 提取年份范围
|
||||
start_year = self.current_year - 5 # 默认最近5年
|
||||
end_year = self.current_year
|
||||
year_range = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "year_range")
|
||||
if year_range:
|
||||
try:
|
||||
years = year_range.split("-")
|
||||
if len(years) == 2:
|
||||
start_year = int(years[0].strip())
|
||||
end_year = int(years[1].strip())
|
||||
except:
|
||||
pass
|
||||
|
||||
# 提取 latest request 判断
|
||||
is_latest_request = self._extract_tag(extracted_responses[self.ARXIV_LATEST_INDEX],
|
||||
"is_latest_request").lower() == "true"
|
||||
|
||||
# 如果是最新论文请求,将查询类型改为 "latest"
|
||||
if is_latest_request:
|
||||
query_type = "latest"
|
||||
|
||||
# 提取论文标识信息
|
||||
paper_source = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_source")
|
||||
paper_id = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_id")
|
||||
paper_title = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_title")
|
||||
if start_year > end_year:
|
||||
start_year, end_year = end_year, start_year
|
||||
# 更新返回的 SearchCriteria
|
||||
return SearchCriteria(
|
||||
query_type=query_type,
|
||||
main_topic=main_topic,
|
||||
sub_topics=sub_topics,
|
||||
start_year=start_year,
|
||||
end_year=end_year,
|
||||
arxiv_params=arxiv_params,
|
||||
semantic_params=semantic_params,
|
||||
pubmed_params=pubmed_params,
|
||||
crossref_params=crossref_params,
|
||||
paper_id=paper_id,
|
||||
paper_title=paper_title,
|
||||
paper_source=paper_source,
|
||||
original_query=query,
|
||||
adsabs_params=adsabs_params
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to analyze query: {str(e)}")
|
||||
|
||||
def _normalize_query_type(self, query_type: str, query: str) -> str:
|
||||
"""规范化查询类型"""
|
||||
if query_type in ["review", "recommend", "qa", "paper"]:
|
||||
return query_type
|
||||
|
||||
query_lower = query.lower()
|
||||
for type_name, keywords in self.valid_types.items():
|
||||
for keyword in keywords:
|
||||
if keyword in query_lower:
|
||||
return type_name
|
||||
|
||||
query_type_lower = query_type.lower()
|
||||
for type_name, keywords in self.valid_types.items():
|
||||
for keyword in keywords:
|
||||
if keyword in query_type_lower:
|
||||
return type_name
|
||||
|
||||
return "qa" # 默认返回qa类型
|
||||
|
||||
def _extract_tag(self, text: str, tag: str) -> str:
|
||||
"""提取标记内容"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# 1. 标准XML格式(处理多行和特殊字符)
|
||||
pattern = f"<{tag}>(.*?)</{tag}>"
|
||||
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
|
||||
if match:
|
||||
content = match.group(1).strip()
|
||||
if content:
|
||||
return content
|
||||
|
||||
# 2. 处理特定标签的复杂内容
|
||||
if tag == "categories":
|
||||
# 处理arXiv类别
|
||||
patterns = [
|
||||
# 标准格式:<categories>cs.CL, cs.AI, cs.LG</categories>
|
||||
r"<categories>\s*((?:(?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+(?:\s*,\s*)?)+)\s*</categories>",
|
||||
# 简单列表格式:cs.CL, cs.AI, cs.LG
|
||||
r"(?:^|\s)((?:(?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+(?:\s*,\s*)?)+)(?:\s|$)",
|
||||
# 单个类别格式:cs.AI
|
||||
r"(?:^|\s)((?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+)(?:\s|$)"
|
||||
]
|
||||
|
||||
elif tag == "query":
|
||||
# 处理搜索查询
|
||||
patterns = [
|
||||
# 完整的查询格式:<query>complex query</query>
|
||||
r"<query>\s*((?:(?:ti|abs|au|cat):[^\n]*?|(?:AND|OR|NOT|\(|\)|\d{4}|year:\d{4}|[\"'][^\"']*[\"']|\s+))+)\s*</query>",
|
||||
# 简单的关键词列表:keyword1, keyword2
|
||||
r"(?:^|\s)((?:\"[^\"]*\"|'[^']*'|[^\s,]+)(?:\s*,\s*(?:\"[^\"]*\"|'[^']*'|[^\s,]+))*)",
|
||||
# 字段搜索格式:field:value
|
||||
r"((?:ti|abs|au|cat):\s*(?:\"[^\"]*\"|'[^']*'|[^\s]+))"
|
||||
]
|
||||
|
||||
elif tag == "fields":
|
||||
# 处理字段列表
|
||||
patterns = [
|
||||
# 标准格式:<fields>field1, field2</fields>
|
||||
r"<fields>\s*([\w\s,]+)\s*</fields>",
|
||||
# 简单列表格式:field1, field2
|
||||
r"(?:^|\s)([\w]+(?:\s*,\s*[\w]+)*)",
|
||||
]
|
||||
|
||||
elif tag == "sort_by":
|
||||
# 处理排序字段
|
||||
patterns = [
|
||||
# 标准格式:<sort_by>value</sort_by>
|
||||
r"<sort_by>\s*(relevance|date|citations|submittedDate|year)\s*</sort_by>",
|
||||
# 简单值格式:relevance
|
||||
r"(?:^|\s)(relevance|date|citations|submittedDate|year)(?:\s|$)"
|
||||
]
|
||||
|
||||
else:
|
||||
# 通用模式
|
||||
patterns = [
|
||||
f"<{tag}>\s*([\s\S]*?)\s*</{tag}>", # 标准XML格式
|
||||
f"<{tag}>([\s\S]*?)(?:</{tag}>|$)", # 未闭合的标签
|
||||
f"[{tag}]([\s\S]*?)[/{tag}]", # 方括号格式
|
||||
f"{tag}:\s*(.*?)(?=\n\w|$)", # 冒号格式
|
||||
f"<{tag}>\s*(.*?)(?=<|$)" # 部分闭合
|
||||
]
|
||||
|
||||
# 3. 尝试所有模式
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
|
||||
if match:
|
||||
content = match.group(1).strip()
|
||||
if content: # 确保提取的内容不为空
|
||||
return content
|
||||
|
||||
# 4. 如果所有模式都失败,返回空字符串
|
||||
return ""
|
||||
|
||||
64
crazy_functions/review_fns/query_processor.py
Normal file
64
crazy_functions/review_fns/query_processor.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import List, Dict, Any
|
||||
from .query_analyzer import QueryAnalyzer, SearchCriteria
|
||||
from .data_sources.arxiv_source import ArxivSource
|
||||
from .data_sources.semantic_source import SemanticScholarSource
|
||||
from .handlers.review_handler import 文献综述功能
|
||||
from .handlers.recommend_handler import 论文推荐功能
|
||||
from .handlers.qa_handler import 学术问答功能
|
||||
from .handlers.paper_handler import 单篇论文分析功能
|
||||
|
||||
class QueryProcessor:
|
||||
"""查询处理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.analyzer = QueryAnalyzer()
|
||||
self.arxiv = ArxivSource()
|
||||
self.semantic = SemanticScholarSource()
|
||||
|
||||
# 初始化各种处理器
|
||||
self.handlers = {
|
||||
"review": 文献综述功能(self.arxiv, self.semantic),
|
||||
"recommend": 论文推荐功能(self.arxiv, self.semantic),
|
||||
"qa": 学术问答功能(self.arxiv, self.semantic),
|
||||
"paper": 单篇论文分析功能(self.arxiv, self.semantic)
|
||||
}
|
||||
|
||||
async def process_query(
|
||||
self,
|
||||
query: str,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> List[List[str]]:
|
||||
"""处理用户查询"""
|
||||
|
||||
# 设置默认的插件参数
|
||||
default_plugin_kwargs = {
|
||||
'max_papers': 20, # 最大论文数量
|
||||
'min_year': 2015, # 最早年份
|
||||
'search_multiplier': 3, # 检索倍数
|
||||
}
|
||||
# 更新插件参数
|
||||
plugin_kwargs.update({k: v for k, v in default_plugin_kwargs.items() if k not in plugin_kwargs})
|
||||
|
||||
# 1. 分析查询意图
|
||||
criteria = self.analyzer.analyze_query(query, chatbot, llm_kwargs)
|
||||
|
||||
# 2. 根据查询类型选择处理器
|
||||
handler = self.handlers.get(criteria.query_type)
|
||||
if not handler:
|
||||
handler = self.handlers["qa"] # 默认使用QA处理器
|
||||
|
||||
# 3. 处理查询
|
||||
response = await handler.handle(
|
||||
criteria,
|
||||
chatbot,
|
||||
history,
|
||||
system_prompt,
|
||||
llm_kwargs,
|
||||
plugin_kwargs
|
||||
)
|
||||
|
||||
return response
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user