Compare commits


48 Commits

Author SHA1 Message Date
lbykkkk 2f946f3e6c fix 2024-12-02 00:11:41 +08:00
lbykkkk 51ea7f3b5e arxiv rag arXiv directory sharing 2024-12-01 23:40:55 +08:00
lbykkkk 795a6a9333 add tex html formatter 2024-12-01 23:26:02 +08:00
lbykkkk 3beb22a347 up 2024-12-01 22:00:41 +08:00
lbykkkk b3aef6b393 up 2024-12-01 17:35:57 +08:00
lbykkkk cf51d4b205 up 2024-12-01 17:28:23 +08:00
lbykkkk bd9c88e896 up 2024-12-01 13:46:51 +08:00
lbykkkk 27958b9030 up 2024-11-27 22:40:23 +08:00
lbykkkk 9b9d77eded up 2024-11-24 03:18:17 +08:00
lbykkkk 50dbff3a14 available 2024-11-23 19:13:10 +00:00
lbykkkk e1dc600030 availavle version with async 2024-11-23 17:44:33 +00:00
lbykkkk 6557c3822a up 2024-11-23 20:01:33 +08:00
lbykkkk 81ab9f91a4 up 2024-11-23 19:40:56 +08:00
lbykkkk 241c9641bb up 2024-11-23 11:31:11 +00:00
lbykkkk b2d6536974 up 2024-11-23 19:11:02 +08:00
lbykkkk 12be7c16e9 up 2024-11-23 19:00:02 +08:00
lbykkkk 724940a9d8 up 2024-11-23 17:59:17 +08:00
lbykkkk ea4cd95645 Add structured chunking 2024-11-22 02:25:43 +08:00
lbykkkk f8b60870e9 up 2024-11-17 17:36:01 +00:00
lbykkkk cbef9a908c up 2024-11-17 23:15:34 +08:00
lbykkkk 21626a44d5 up 2024-11-16 00:35:31 +08:00
lbykkkk dd902e9519 expect query search 2024-11-11 02:11:42 +08:00
lbykkkk 68aa846a89 up 2024-11-10 15:06:50 +08:00
lbykkkk b8617921f4 保存完整的section层级路径 2024-11-09 18:19:51 +08:00
lbykkkk c6687646e4 Merge branch 'refs/heads/boyin_summary' into boyin_rag 2024-11-09 15:10:46 +08:00
    # Conflicts:
    #   crazy_functions/rag_fns/rag_file_support.py
lbykkkk bfa72fb4cf up 2024-11-09 14:59:47 +08:00
lbykkkk 61676d0536 up 2024-11-06 00:47:56 +08:00
lbykkkk df2ef7940c up 2024-11-05 02:08:12 +08:00
lbykkkk 0afd27deca Merge branch 'refs/heads/master' into boyin_rag 2024-11-05 00:09:22 +08:00
lbykkkk c10f2b45e5 Default prompt word count control 2024-11-03 23:05:02 +08:00
lbykkkk 7e2ede2d12 up 2024-11-03 22:54:19 +08:00
lbykkkk ec10e2a3ac Merge branch 'refs/heads/batch-file-query' into boyin_summary 2024-11-03 22:49:29 +08:00
    # Conflicts:
    #   crazy_functional.py
binary-husky 7474d43433 stage connection 2024-11-03 14:19:16 +00:00
binary-husky 83489f9acf Merge remote-tracking branch 'origin/boyin_summary' 2024-11-03 14:12:04 +00:00
lbykkkk 36e50d490d up 2024-11-03 17:57:56 +08:00
lbykkkk 9172337695 Add batch document inquiry function 2024-11-03 17:17:16 +08:00
lbykkkk 5dab7b2290 refine 2024-10-29 23:54:55 +08:00
lbykkkk 89dc6c7265 refine 2024-10-21 22:58:04 +08:00
lbykkkk 21111d3bd0 refine 2024-10-21 00:57:29 +08:00
lbykkkk be9aead04a Merge remote-tracking branch 'origin/boyin_rag' into boyin_rag 2024-10-21 00:36:58 +08:00
lbykkkk 701018f48c up 2024-10-21 00:30:18 +08:00
lbykkkk 8733c4e1e9 file type support 2024-10-20 01:33:00 +08:00
lbykkkk 8498ddf6bf up 2024-10-19 17:31:30 +00:00
lbykkkk 3c3293818d Change the word document summary function to document summary function 2024-10-20 01:14:42 +08:00
binary-husky 9adc0ade71 change import order 2024-10-14 14:37:44 +00:00
lbykkkk bbcdd9aa71 Address import LlamaIndexRagWorker problem 2024-10-13 18:01:33 +00:00
lbykkkk cdfe38d296 new resolve 2024-10-13 17:07:40 +00:00
lbykkkk 159f628dfe Resolve LlamaIndexRagWorker bug 2024-10-13 17:04:34 +00:00
96 changed files with 6471 additions and 6104 deletions


@@ -1,56 +0,0 @@
name: Create Conda Environment Package
on:
  workflow_dispatch:

jobs:
  build:
    runs-on: windows-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Setup Miniconda
        uses: conda-incubator/setup-miniconda@v3
        with:
          auto-activate-base: true
          activate-environment: ""

      - name: Create new Conda environment
        shell: bash -l {0}
        run: |
          conda create -n gpt python=3.11 -y
          conda activate gpt

      - name: Install requirements
        shell: bash -l {0}
        run: |
          conda activate gpt
          pip install -r requirements.txt

      - name: Install conda-pack
        shell: bash -l {0}
        run: |
          conda activate gpt
          conda install conda-pack -y

      - name: Pack conda environment
        shell: bash -l {0}
        run: |
          conda activate gpt
          conda pack -n gpt -o gpt.tar.gz

      - name: Create workspace zip
        shell: pwsh
        run: |
          mkdir workspace
          Get-ChildItem -Exclude "workspace" | Copy-Item -Destination workspace -Recurse
          Remove-Item -Path workspace/.git* -Recurse -Force -ErrorAction SilentlyContinue
          Copy-Item gpt.tar.gz workspace/ -Force

      - name: Upload packed files
        uses: actions/upload-artifact@v4
        with:
          name: gpt-academic-package
          path: workspace
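For reference, the packing step that this deleted workflow ran through the CLI can also be driven from Python via conda-pack's API — a minimal sketch, assuming conda-pack is installed and a `gpt` environment exists locally:

```python
import conda_pack

# Pack the existing "gpt" environment into a relocatable archive,
# mirroring the workflow's `conda pack -n gpt -o gpt.tar.gz` step.
conda_pack.pack(name="gpt", output="gpt.tar.gz", force=True)
```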


@@ -7,7 +7,7 @@
 name: 'Close stale issues and PRs'
 on:
   schedule:
-    - cron: '*/30 * * * *'
+    - cron: '*/5 * * * *'

 jobs:
   stale:
@@ -19,6 +19,7 @@ jobs:
     steps:
       - uses: actions/stale@v8
         with:
-          stale-issue-message: 'This issue is stale because it has been open 100 days with no activity. Remove stale label or comment or this will be closed in 7 days.'
+          stale-issue-message: 'This issue is stale because it has been open 100 days with no activity. Remove stale label or comment or this will be closed in 1 days.'
           days-before-stale: 100
-          days-before-close: 7
+          days-before-close: 1
+          debug-only: true

.gitignore (vendored)

@@ -161,5 +161,4 @@ temp.*
objdump*
*.min.*.js
TODO
experimental_mods
search_results
*.cursorrules


@@ -15,7 +15,6 @@ RUN echo '[global]' > /etc/pip.conf && \
# 语音输出功能以下两行第一行更换阿里源第二行安装ffmpeg都可以删除
RUN UBUNTU_VERSION=$(awk -F= '/^VERSION_CODENAME=/{print $2}' /etc/os-release); echo "deb https://mirrors.aliyun.com/debian/ $UBUNTU_VERSION main non-free contrib" > /etc/apt/sources.list; apt-get update
RUN apt-get install ffmpeg -y
RUN apt-get clean
# 进入工作路径(必要)
@@ -34,7 +33,6 @@ RUN pip3 install -r requirements.txt
# 非必要步骤,用于预热模块(可以删除)
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
RUN python3 -m pip cache purge
# 启动(必要)


@@ -1,10 +1,5 @@
 > [!IMPORTANT]
-> `master主分支`最新动态(2025.2.4): 增加deepseek-r1支持增加字体自定义功能
-> `master主分支`最新动态(2025.2.2): 三分钟快速接入最强qwen2.5-max[视频](https://www.bilibili.com/video/BV1LeFuerEG4)
-> `frontier开发分支`最新动态(2024.12.9): 更新对话时间线功能优化xelatex论文翻译
-> `wiki文档`最新动态(2024.12.5): 更新ollama接入指南
->
-> 2024.10.10: 突发停电,紧急恢复了提供[whl包](https://drive.google.com/drive/folders/14kR-3V-lIbvGxri4AHc8TpiA1fqsw7SK?usp=sharing)的文件服务器
+> 2024.10.10: 突发停电,紧急恢复了提供[whl包](https://drive.google.com/file/d/19U_hsLoMrjOlQSzYS3pzWX9fTzyusArP/view?usp=sharing)的文件服务器
 > 2024.10.8: 版本3.90加入对llama-index的初步支持版本3.80加入插件二级菜单功能详见wiki
 > 2024.5.1: 加入Doc2x翻译PDF论文的功能[查看详情](https://github.com/binary-husky/gpt_academic/wiki/Doc2x)
 > 2024.3.11: 全力支持Qwen、GLM、DeepseekCoder等中文大语言模型 SoVits语音克隆模块[查看详情](https://www.bilibili.com/video/BV1Rp421S7tF/)
@@ -175,32 +170,26 @@ flowchart TD
```

-<details><summary>如果需要支持清华ChatGLM系列/复旦MOSS/RWKV作为后端请点击展开此处</summary>
+<details><summary>如果需要支持清华ChatGLM2/复旦MOSS/RWKV作为后端请点击展开此处</summary>
 <p>

-【可选步骤】如果需要支持清华ChatGLM系列/复旦MOSS作为后端需要额外安装更多依赖前提条件熟悉Python + 用过Pytorch + 电脑配置够强):
+【可选步骤】如果需要支持清华ChatGLM3/复旦MOSS作为后端需要额外安装更多依赖前提条件熟悉Python + 用过Pytorch + 电脑配置够强):

```sh
# 【可选步骤I】支持清华ChatGLM3。清华ChatGLM备注如果遇到"Call ChatGLM fail 不能正常加载ChatGLM的参数" 错误,参考如下: 1以上默认安装的为torch+cpu版使用cuda需要卸载torch重新安装torch+cuda 2如因本机配置不够无法加载模型可以修改request_llm/bridge_chatglm.py中的模型精度, 将 AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) 都修改为 AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
python -m pip install -r request_llms/requirements_chatglm.txt

-# 【可选步骤II】支持清华ChatGLM4 注意此模型至少需要24G显存
-python -m pip install -r request_llms/requirements_chatglm4.txt
-# 可使用modelscope下载ChatGLM4模型
-# pip install modelscope
-# modelscope download --model ZhipuAI/glm-4-9b-chat --local_dir ./THUDM/glm-4-9b-chat

-# 【可选步骤III】支持复旦MOSS
+# 【可选步骤II】支持复旦MOSS
python -m pip install -r request_llms/requirements_moss.txt
git clone --depth=1 https://github.com/OpenLMLab/MOSS.git request_llms/moss  # 注意执行此行代码时,必须处于项目根路径

-# 【可选步骤IV】支持RWKV Runner
+# 【可选步骤III】支持RWKV Runner
参考wikihttps://github.com/binary-husky/gpt_academic/wiki/%E9%80%82%E9%85%8DRWKV-Runner

-# 【可选步骤V】确保config.py配置文件的AVAIL_LLM_MODELS包含了期望的模型目前支持的全部模型如下(jittorllms系列目前仅支持docker方案)
+# 【可选步骤IV】确保config.py配置文件的AVAIL_LLM_MODELS包含了期望的模型目前支持的全部模型如下(jittorllms系列目前仅支持docker方案)
AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "api2d-gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "chatglm", "moss"] # + ["jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]

-# 【可选步骤VI】支持本地模型INT8,INT4量化这里所指的模型本身不是量化版本目前deepseek-coder支持后面测试后会加入更多模型量化选择
+# 【可选步骤V】支持本地模型INT8,INT4量化这里所指的模型本身不是量化版本目前deepseek-coder支持后面测试后会加入更多模型量化选择
pip install bitsandbytes
# windows用户安装bitsandbytes需要使用下面bitsandbytes-windows-webui
python -m pip install bitsandbytes --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
```
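The long note in the step-I comment above boils down to swapping the checkpoint name when VRAM is short. A minimal sketch of the int4 fallback it describes, assuming `transformers` is installed and the `THUDM/chatglm-6b-int4` checkpoint is reachable:

```python
from transformers import AutoModel, AutoTokenizer

# Load the quantized int4 checkpoint instead of "THUDM/chatglm-6b",
# as suggested for machines that cannot hold the full-precision weights.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float()
```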


@@ -7,16 +7,11 @@
    Configuration reading priority: environment variable > config_private.py > config.py
    """

-# [step 1-1]>> ( 接入GPT等模型 ) API_KEY = "sk-123456789xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx123456789"。极少数情况下还需要填写组织格式如org-123456789abcdefghijklmno的请向下翻找 API_ORG 设置项
-API_KEY = "此处填APIKEY"  # 可同时填写多个API-KEY用英文逗号分割例如API_KEY = "sk-openaikey1,sk-openaikey2,fkxxxx-api2dkey3,azure-apikey4"
+# [step 1]>> API_KEY = "sk-123456789xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx123456789"。极少数情况下还需要填写组织格式如org-123456789abcdefghijklmno的请向下翻找 API_ORG 设置项
+API_KEY = "此处填API密钥"  # 可同时填写多个API-KEY用英文逗号分割例如API_KEY = "sk-openaikey1,sk-openaikey2,fkxxxx-api2dkey3,azure-apikey4"

-# [step 1-2]>> ( 接入通义 qwen-max ) 接入通义千问在线大模型api-key获取地址 https://dashscope.console.aliyun.com/
-DASHSCOPE_API_KEY = ""  # 阿里灵积云API_KEY

-# [step 1-3]>> ( 接入 deepseek-reasoner, 即 deepseek-r1 ) 深度求索(DeepSeek) API KEY默认请求地址为"https://api.deepseek.com/v1/chat/completions"
-DEEPSEEK_API_KEY = ""

-# [step 2]>> 改为True应用代理。如果使用本地或无地域限制的大模型时此处不修改如果直接在海外服务器部署此处不修改
+# [step 2]>> 改为True应用代理如果直接在海外服务器部署此处不修改如果使用本地或无地域限制的大模型时此处也不需要修改
USE_PROXY = False
if USE_PROXY:
    """
@@ -37,13 +32,11 @@ else:
# [step 3]>> 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
LLM_MODEL = "gpt-3.5-turbo-16k"  # 可选 ↓↓↓
-AVAIL_LLM_MODELS = ["qwen-max", "o1-mini", "o1-mini-2024-09-12", "o1", "o1-2024-12-17", "o1-preview", "o1-preview-2024-09-12",
-                    "gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-preview",
+AVAIL_LLM_MODELS = ["gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-preview",
                    "gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
                    "gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
                    "gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
-                    "gemini-1.5-pro", "chatglm3", "chatglm4",
-                    "deepseek-chat", "deepseek-coder", "deepseek-reasoner"
+                    "gemini-1.5-pro", "chatglm3"
                    ]
EMBEDDING_MODEL = "text-embedding-3-small"
@@ -54,7 +47,7 @@ EMBEDDING_MODEL = "text-embedding-3-small"
# "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-flash",
# "qianfan", "deepseekcoder",
# "spark", "sparkv2", "sparkv3", "sparkv3.5", "sparkv4",
# "qwen-turbo", "qwen-plus", "qwen-local",
# "qwen-turbo", "qwen-plus", "qwen-max", "qwen-local",
# "moonshot-v1-128k", "moonshot-v1-32k", "moonshot-v1-8k",
# "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0125", "gpt-4o-2024-05-13"
# "claude-3-haiku-20240307","claude-3-sonnet-20240229","claude-3-opus-20240229", "claude-2.1", "claude-instant-1.2",
@@ -62,7 +55,6 @@ EMBEDDING_MODEL = "text-embedding-3-small"
# "deepseek-chat" ,"deepseek-coder",
# "gemini-1.5-flash",
# "yi-34b-chat-0205","yi-34b-chat-200k","yi-large","yi-medium","yi-spark","yi-large-turbo","yi-large-preview",
# "grok-beta",
# ]
# --- --- --- ---
# 此外您还可以在接入one-api/vllm/ollama/Openroute时
@@ -89,30 +81,6 @@ DEFAULT_WORKER_NUM = 3
THEME = "Default"
AVAIL_THEMES = ["Default", "Chuanhu-Small-and-Beautiful", "High-Contrast", "Gstaff/Xkcd", "NoCrypt/Miku"]
FONT = "Theme-Default-Font"
AVAIL_FONTS = [
"默认值(Theme-Default-Font)",
"宋体(SimSun)",
"黑体(SimHei)",
"楷体(KaiTi)",
"仿宋(FangSong)",
"华文细黑(STHeiti Light)",
"华文楷体(STKaiti)",
"华文仿宋(STFangsong)",
"华文宋体(STSong)",
"华文中宋(STZhongsong)",
"华文新魏(STXinwei)",
"华文隶书(STLiti)",
"思源宋体(Source Han Serif CN VF@https://chinese-fonts-cdn.deno.dev/packages/syst/dist/SourceHanSerifCN/result.css)",
"月星楷(Moon Stars Kai HW@https://chinese-fonts-cdn.deno.dev/packages/moon-stars-kai/dist/MoonStarsKaiHW-Regular/result.css)",
"珠圆体(MaokenZhuyuanTi@https://chinese-fonts-cdn.deno.dev/packages/mkzyt/dist/猫啃珠圆体/result.css)",
"平方萌萌哒(PING FANG MENG MNEG DA@https://chinese-fonts-cdn.deno.dev/packages/pfmmd/dist/平方萌萌哒/result.css)",
"Helvetica",
"ui-sans-serif",
"sans-serif",
"system-ui"
]
# 默认的系统提示词system prompt
INIT_SYS_PROMPT = "Serve me as a writing and programming assistant."
@@ -164,15 +132,16 @@ MULTI_QUERY_LLM_MODELS = "gpt-3.5-turbo&chatglm3"
QWEN_LOCAL_MODEL_SELECTION = "Qwen/Qwen-1_8B-Chat-Int8"
# 接入通义千问在线大模型 https://dashscope.console.aliyun.com/
DASHSCOPE_API_KEY = "" # 阿里灵积云API_KEY
# 百度千帆LLM_MODEL="qianfan"
BAIDU_CLOUD_API_KEY = ''
BAIDU_CLOUD_SECRET_KEY = ''
BAIDU_CLOUD_QIANFAN_MODEL = 'ERNIE-Bot' # 可选 "ERNIE-Bot-4"(文心大模型4.0), "ERNIE-Bot"(文心一言), "ERNIE-Bot-turbo", "BLOOMZ-7B", "Llama-2-70B-Chat", "Llama-2-13B-Chat", "Llama-2-7B-Chat", "ERNIE-Speed-128K", "ERNIE-Speed-8K", "ERNIE-Lite-8K"
# 如果使用ChatGLM3或ChatGLM4本地模型请把 LLM_MODEL="chatglm3" 或LLM_MODEL="chatglm4",并在此处指定模型路径
CHATGLM_LOCAL_MODEL_PATH = "THUDM/glm-4-9b-chat" # 例如"/home/hmp/ChatGLM3-6B/"
# 如果使用ChatGLM2微调模型请把 LLM_MODEL="chatglmft",并在此处指定模型路径
CHATGLM_PTUNING_CHECKPOINT = "" # 例如"/home/hmp/ChatGLM2-6B/ptuning/output/6b-pt-128-1e-2/checkpoint-100"
@@ -266,11 +235,13 @@ MOONSHOT_API_KEY = ""
YIMODEL_API_KEY = ""
# 深度求索(DeepSeek) API KEY默认请求地址为"https://api.deepseek.com/v1/chat/completions"
DEEPSEEK_API_KEY = ""
# 紫东太初大模型 https://ai-maas.wair.ac.cn
TAICHU_API_KEY = ""
# Grok API KEY
GROK_API_KEY = ""
# Mathpix 拥有执行PDF的OCR功能但是需要注册账号
MATHPIX_APPID = ""
@@ -302,8 +273,8 @@ GROBID_URLS = [
]

-# Searxng互联网检索服务这是一个huggingface空间请前往huggingface复制该空间然后把自己新的空间地址填在这里
-SEARXNG_URLS = [f"https://kaletianlre-beardvs{i}dd.hf.space/" for i in range(1, 5)]
+# Searxng互联网检索服务
+SEARXNG_URL = "https://cloud-1.agent-matrix.com/"
# 是否允许通过自然语言描述修改本页的配置,该功能具有一定的危险性,默认关闭
@@ -327,7 +298,7 @@ ARXIV_CACHE_DIR = "gpt_log/arxiv_cache"
# 除了连接OpenAI之外还有哪些场合允许使用代理请尽量不要修改
-WHEN_TO_USE_PROXY = ["Connect_OpenAI", "Download_LLM", "Download_Gradio_Theme", "Connect_Grobid",
+WHEN_TO_USE_PROXY = ["Download_LLM", "Download_Gradio_Theme", "Connect_Grobid",
                     "Warmup_Modules", "Nougat_Download", "AutoGen", "Connect_OpenAI_Embedding"]
@@ -339,10 +310,6 @@ PLUGIN_HOT_RELOAD = False
NUM_CUSTOM_BASIC_BTN = 4

-# 媒体智能体的服务地址这是一个huggingface空间请前往huggingface复制该空间然后把自己新的空间地址填在这里
-DAAS_SERVER_URLS = [f"https://niuziniu-biligpt{i}.hf.space/stream" for i in range(1, 5)]
"""
--------------- 配置关联关系说明 ---------------
@@ -402,7 +369,6 @@ DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in ran
本地大模型示意图
-├── "chatglm4"
├── "chatglm3"
├── "chatglm"
├── "chatglm_onnx"
@@ -433,7 +399,7 @@ DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in ran
插件在线服务配置依赖关系示意图
├── 互联网检索
-│   └── SEARXNG_URLS
+│   └── SEARXNG_URL
├── 语音功能
│   ├── ENABLE_AUDIO
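For orientation, the priority rule quoted at the top of this config diff (environment variable > config_private.py > config.py) can be pictured with a small resolver — a simplified sketch, not the project's actual `toolbox.get_conf` implementation:

```python
import importlib
import os

def read_conf_sketch(name, default=None):
    # Hypothetical resolver illustrating the documented priority:
    # environment variable first, then config_private.py, then config.py.
    if name in os.environ:
        return os.environ[name]
    for module_name in ("config_private", "config"):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        if hasattr(module, name):
            return getattr(module, name)
    return default

# read_conf_sketch("SEARXNG_URL") falls through to config.py unless
# overridden by config_private.py or the environment.
```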


@@ -2,6 +2,7 @@ from toolbox import HotReload # HotReload 的意思是热更新,修改函数
from toolbox import trimmed_format_exc
from loguru import logger


def get_crazy_functions():
    from crazy_functions.读文章写摘要 import 读文章写摘要
    from crazy_functions.生成函数注释 import 批量生成函数注释
@@ -14,13 +15,13 @@ def get_crazy_functions():
    from crazy_functions.SourceCode_Analyse import 解析一个Rust项目
    from crazy_functions.SourceCode_Analyse import 解析一个Java项目
    from crazy_functions.SourceCode_Analyse import 解析一个前端项目
+    from crazy_functions.Arxiv_论文对话 import Arxiv论文对话
    from crazy_functions.高级功能函数模板 import 高阶功能模板函数
    from crazy_functions.高级功能函数模板 import Demo_Wrap
-    from crazy_functions.Latex_Project_Polish import Latex英文润色
+    from crazy_functions.Latex全文润色 import Latex英文润色
    from crazy_functions.询问多个大语言模型 import 同时问询
    from crazy_functions.SourceCode_Analyse import 解析一个Lua项目
    from crazy_functions.SourceCode_Analyse import 解析一个CSharp项目
    from crazy_functions.总结word文档 import 总结word文档
    from crazy_functions.解析JupyterNotebook import 解析ipynb文件
-    from crazy_functions.Conversation_To_File import 载入对话历史存档
+    from crazy_functions.Conversation_To_File import 对话历史存档
@@ -30,10 +31,12 @@ def get_crazy_functions():
    from crazy_functions.Markdown_Translate import Markdown英译中
    from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
    from crazy_functions.PDF_Translate import 批量翻译PDF文档
+    from crazy_functions.批量文件询问 import 批量文件询问
    from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
    from crazy_functions.理解PDF文档内容 import 理解PDF文档内容标准文件输入
-    from crazy_functions.Latex_Project_Polish import Latex中文润色
-    from crazy_functions.Latex_Project_Polish import Latex英文纠错
+    from crazy_functions.Latex全文润色 import Latex中文润色
+    from crazy_functions.Latex全文润色 import Latex英文纠错
    from crazy_functions.Markdown_Translate import Markdown中译英
    from crazy_functions.虚空终端 import 虚空终端
    from crazy_functions.生成多种Mermaid图表 import Mermaid_Gen
@@ -49,16 +52,8 @@ def get_crazy_functions():
    from crazy_functions.Image_Generate_Wrap import ImageGen_Wrap
    from crazy_functions.SourceCode_Comment import 注释Python项目
    from crazy_functions.SourceCode_Comment_Wrap import SourceCodeComment_Wrap
-    from crazy_functions.VideoResource_GPT import 多媒体任务

    function_plugins = {
-        "多媒体智能体": {
-            "Group": "智能体",
-            "Color": "stop",
-            "AsButton": False,
-            "Info": "【仅测试】多媒体任务",
-            "Function": HotReload(多媒体任务),
-        },
        "虚空终端": {
            "Group": "对话|编程|学术|智能体",
            "Color": "stop",
@@ -66,34 +61,6 @@ def get_crazy_functions():
"Info": "使用自然语言实现您的想法",
"Function": HotReload(虚空终端),
},
"解析整个Python项目": {
"Group": "编程",
"Color": "stop",
"AsButton": True,
"Info": "解析一个Python项目的所有源文件(.py) | 输入参数为路径",
"Function": HotReload(解析一个Python项目),
},
"注释Python项目": {
"Group": "编程",
"Color": "stop",
"AsButton": False,
"Info": "上传一系列python源文件(或者压缩包), 为这些代码添加docstring | 输入参数为路径",
"Function": HotReload(注释Python项目),
"Class": SourceCodeComment_Wrap,
},
"载入对话历史存档(先上传存档或输入路径)": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"Info": "载入对话历史存档 | 输入参数为路径",
"Function": HotReload(载入对话历史存档),
},
"删除所有本地对话历史记录(谨慎操作)": {
"Group": "对话",
"AsButton": False,
"Info": "删除所有本地对话历史记录,谨慎操作 | 不需要输入参数",
"Function": HotReload(删除所有本地对话历史记录),
},
"清除所有缓存文件(谨慎操作)": {
"Group": "对话",
"Color": "stop",
@@ -101,14 +68,6 @@ def get_crazy_functions():
"Info": "清除所有缓存文件,谨慎操作 | 不需要输入参数",
"Function": HotReload(清除缓存),
},
"生成多种Mermaid图表(从当前对话或路径(.pdf/.md/.docx)中生产图表)": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"Info" : "基于当前对话或文件生成多种Mermaid图表,图表类型由模型判断",
"Function": None,
"Class": Mermaid_Gen
},
"Arxiv论文翻译": {
"Group": "学术",
"Color": "stop",
@@ -117,91 +76,25 @@ def get_crazy_functions():
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后Function旧接口仅会在“虚空终端”中起作用
"Class": Arxiv_Localize, # 新一代插件需要注册Class
},
"批量总结Word文档": {
"批量文件询问": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"Info": "批量总结word文档 | 输入参数为路径",
"Function": HotReload(总结word文档),
"AdvancedArgs": True,
"Info": "通过在高级参数区写入prompt可自定义询问逻辑默认情况下为总结逻辑 | 输入参数为路径",
"ArgsReminder": r"1、请不要更改上方输入框中以“private_upload/...”开头的路径。 "
r"2、请在下方高级参数区中输入你的prompt文档中的内容将被添加你的prompt后。3、示例“请总结下面的内容此时文档内容将添加在“”后 ",
"Function": HotReload(批量文件询问),
},
"解析整个Matlab项目": {
"Group": "编程",
"Color": "stop",
"AsButton": False,
"Info": "解析一个Matlab项目的所有源文件(.m) | 输入参数为路径",
"Function": HotReload(解析一个Matlab项目),
},
"解析整个C++项目头文件": {
"Group": "编程",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
"Info": "解析一个C++项目的所有头文件(.h/.hpp) | 输入参数为路径",
"Function": HotReload(解析一个C项目的头文件),
},
"解析整个C++项目(.cpp/.hpp/.c/.h": {
"Group": "编程",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
"Info": "解析一个C++项目的所有源文件(.cpp/.hpp/.c/.h| 输入参数为路径",
"Function": HotReload(解析一个C项目),
},
"解析整个Go项目": {
"Group": "编程",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
"Info": "解析一个Go项目的所有源文件 | 输入参数为路径",
"Function": HotReload(解析一个Golang项目),
},
"解析整个Rust项目": {
"Group": "编程",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
"Info": "解析一个Rust项目的所有源文件 | 输入参数为路径",
"Function": HotReload(解析一个Rust项目),
},
"解析整个Java项目": {
"Group": "编程",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
"Info": "解析一个Java项目的所有源文件 | 输入参数为路径",
"Function": HotReload(解析一个Java项目),
},
"解析整个前端项目js,ts,css等": {
"Group": "编程",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
"Info": "解析一个前端项目的所有源文件js,ts,css等 | 输入参数为路径",
"Function": HotReload(解析一个前端项目),
},
"解析整个Lua项目": {
"Group": "编程",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
"Info": "解析一个Lua项目的所有源文件 | 输入参数为路径",
"Function": HotReload(解析一个Lua项目),
},
"解析整个CSharp项目": {
"Group": "编程",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
"Info": "解析一个CSharp项目的所有源文件 | 输入参数为路径",
"Function": HotReload(解析一个CSharp项目),
},
"解析Jupyter Notebook文件": {
"Group": "编程",
"Color": "stop",
"AsButton": False,
"Info": "解析Jupyter Notebook文件 | 输入参数为路径",
"Function": HotReload(解析ipynb文件),
"AdvancedArgs": True, # 调用时唤起高级参数输入区默认False
"ArgsReminder": "若输入0则不解析notebook中的Markdown块", # 高级参数输入区的显示提示
},
"读Tex论文写摘要": {
"Arxiv论文对话": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"Info": "读取Tex论文并写摘要 | 输入参数为路径",
"Function": HotReload(读文章写摘要),
"AdvancedArgs": True,
"Info": "在输入区中输入论文ID在高级参数区中输入问题",
"ArgsReminder": r"1、请在输入区中输入arxiv ID。 "
r"2、请在下方高级参数区中输入你的问题示例“这篇文章的方法是什么请用中文回答我” ",
"Function": HotReload(Arxiv论文对话),
},
"翻译README或MD": {
"Group": "编程",
@@ -727,6 +620,12 @@ def get_crazy_functions():
logger.error("Load function plugin failed")
# try:
# from crazy_functions.高级功能函数模板 import 测试图表渲染
# function_plugins.update({
@@ -741,6 +640,19 @@ def get_crazy_functions():
    #     logger.error(trimmed_format_exc())
    #     print('Load function plugin failed')

+    # try:
+    #     from crazy_functions.chatglm微调工具 import 微调数据集生成
+    #     function_plugins.update({
+    #         "黑盒模型学习: 微调数据集生成 (先上传数据集)": {
+    #             "Color": "stop",
+    #             "AsButton": False,
+    #             "AdvancedArgs": True,
+    #             "ArgsReminder": "针对数据集输入(如 绿帽子*深蓝色衬衫*黑色运动裤)给出指令,例如您可以将以下命令复制到下方: --llm_to_learn=azure-gpt-3.5 --prompt_prefix='根据下面的服装类型提示想象一个穿着者对这个人外貌、身处的环境、内心世界、过去经历进行描写。要求100字以内用第二人称。' --system_prompt=''",
+    #             "Function": HotReload(微调数据集生成)
+    #         }
+    #     })
+    # except:
+    #     print('Load function plugin failed')

    """
    设置默认值:
@@ -760,23 +672,3 @@ def get_crazy_functions():
function_plugins[name]["Color"] = "secondary"
return function_plugins
def get_multiplex_button_functions():
"""多路复用主提交按钮的功能映射
"""
return {
"常规对话":
"",
"多模型对话":
"询问多个GPT模型", # 映射到上面的 `询问多个GPT模型` 插件
"智能召回 RAG":
"Rag智能召回", # 映射到上面的 `Rag智能召回` 插件
"多媒体查询":
"多媒体智能体", # 映射到上面的 `多媒体智能体` 插件
}
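The `get_multiplex_button_functions` mapping on the base side of this hunk resolves the main submit button's dropdown selection to a plugin name registered in `function_plugins`. Roughly how a caller would consume it — a hypothetical dispatch sketch, not the UI's actual wiring:

```python
plugins = get_crazy_functions()
mapping = get_multiplex_button_functions()

selection = "多模型对话"                  # what the user picked in the dropdown
plugin_name = mapping.get(selection, "")  # "" means plain conversation
if plugin_name:
    handler = plugins[plugin_name]["Function"]
    # the real UI then calls handler(txt, llm_kwargs, plugin_kwargs, chatbot, ...)
```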


@@ -0,0 +1,573 @@
import asyncio
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock as ThreadLock
from typing import Generator
from typing import List, Dict, Optional

from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file, process_arxiv_sync
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg

# 全局常量配置
MAX_HISTORY_ROUND = 5            # 最大历史对话轮数
MAX_CONTEXT_TOKEN_LIMIT = 4096   # 上下文最大token数
REMEMBER_PREVIEW = 1000          # 记忆预览长度
VECTOR_STORE_TYPE = "Simple"     # 向量存储类型Simple或Milvus
MAX_CONCURRENT_PAPERS = 20       # 最大并行处理论文数
MAX_WORKERS = 3                  # 最大工作线程数

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
@dataclass
class ProcessingTask:
    """论文处理任务数据类"""
    arxiv_id: str
    status: str = "pending"  # pending, processing, completed, failed
    error: Optional[str] = None
    fragments: List[Fragment] = None
    start_time: float = field(default_factory=time.time)
class ArxivRagWorker:
    def __init__(self, user_name: str, llm_kwargs: Dict, arxiv_id: str = None):
        """初始化ArxivRagWorker"""
        self.user_name = user_name
        self.llm_kwargs = llm_kwargs
        self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
        self.fragments = None

        # 初始化基础目录
        self.base_dir = Path(get_log_folder(plugin_name='arxiv_rag_cache'))
        self._setup_directories()

        # 初始化处理状态
        # 线程安全的计数器和集合
        self._processing_lock = ThreadLock()
        self._processed_fragments = set()
        self._processed_count = 0

        # 优化的线程池配置
        cpu_count = os.cpu_count() or 1
        self.thread_pool = ThreadPoolExecutor(
            max_workers=min(32, cpu_count * 4),
            thread_name_prefix="arxiv_worker"
        )

        # 批处理配置
        self._batch_size = min(20, cpu_count * 2)  # 动态设置批大小
        self.max_concurrent_papers = MAX_CONCURRENT_PAPERS
        self._semaphore = None
        self._loop = None

        # 初始化处理队列
        self.processing_queue = {}

        # 初始化工作组件
        self._init_workers()

    def _setup_directories(self):
        """设置工作目录"""
        if self.arxiv_id:
            self.checkpoint_dir = self.base_dir / self.arxiv_id
            self.vector_store_dir = self.checkpoint_dir / "vector_store"
            self.fragment_store_dir = self.checkpoint_dir / "fragments"
        else:
            self.checkpoint_dir = self.base_dir
            self.vector_store_dir = self.base_dir / "vector_store"
            self.fragment_store_dir = self.base_dir / "fragments"
        self.paper_path = self.checkpoint_dir / f"{self.arxiv_id}.processed"
        self.loading = self.paper_path.exists()

        # 创建必要的目录
        for directory in [self.checkpoint_dir, self.vector_store_dir, self.fragment_store_dir]:
            directory.mkdir(parents=True, exist_ok=True)
            logger.info(f"Created directory: {directory}")

    def _init_workers(self):
        """初始化工作组件"""
        try:
            self.rag_worker = LlamaIndexRagWorker(
                user_name=self.user_name,
                llm_kwargs=self.llm_kwargs,
                checkpoint_dir=str(self.vector_store_dir),
                auto_load_checkpoint=True
            )
            self.arxiv_splitter = ArxivSplitter(
                root_dir=str(self.checkpoint_dir / "arxiv_cache")
            )
        except Exception as e:
            logger.error(f"Error initializing workers: {str(e)}")
            raise

    def _ensure_loop(self):
        """确保存在事件循环"""
        if threading.current_thread() is threading.main_thread():
            if self._loop is None:
                self._loop = asyncio.get_event_loop()
        else:
            try:
                self._loop = asyncio.get_event_loop()
            except RuntimeError:
                self._loop = asyncio.new_event_loop()
                asyncio.set_event_loop(self._loop)
        return self._loop

    @property
    def semaphore(self):
        """延迟创建semaphore"""
        if self._semaphore is None:
            self._semaphore = asyncio.Semaphore(self.max_concurrent_papers)
        return self._semaphore
    async def _process_fragments(self, fragments: List[Fragment]) -> None:
        """优化的并行处理论文片段"""
        if not fragments:
            logger.warning("No fragments to process")
            return

        start_time = time.time()
        total_fragments = len(fragments)
        try:
            # 1. 处理论文概述
            overview = self._create_overview(fragments[0])
            overview_success = self._safe_add_to_vector_store_sync(overview['text'])
            if not overview_success:
                raise RuntimeError("Failed to add overview to vector store")

            # 2. 并行处理片段
            successful_fragments = await self._parallel_process_fragments(fragments)

            # 3. 保存处理结果
            if successful_fragments > 0:
                await self._save_results(fragments, overview['arxiv_id'], successful_fragments)
        except Exception as e:
            logger.error(f"Error in fragment processing: {str(e)}")
            raise
        finally:
            self._log_processing_stats(start_time, total_fragments)

    def _create_overview(self, first_fragment: Fragment) -> Dict:
        """创建论文概述"""
        return {
            'arxiv_id': first_fragment.arxiv_id,
            'text': (
                f"Paper Title: {first_fragment.title}\n"
                f"ArXiv ID: {first_fragment.arxiv_id}\n"
                f"Abstract: {first_fragment.abstract}\n"
                f"Table of contents:{first_fragment.catalogs}\n"
                f"Type: OVERVIEW"
            )
        }

    async def _parallel_process_fragments(self, fragments: List[Fragment]) -> int:
        """并行处理所有片段"""
        successful_count = 0
        loop = self._ensure_loop()
        for i in range(0, len(fragments), self._batch_size):
            batch = fragments[i:i + self._batch_size]
            batch_futures = []
            for j, fragment in enumerate(batch):
                if not self._is_fragment_processed(fragment, i + j):
                    future = loop.run_in_executor(
                        self.thread_pool,
                        self._process_single_fragment_sync,
                        fragment,
                        i + j
                    )
                    batch_futures.append(future)
            if batch_futures:
                results = await asyncio.gather(*batch_futures, return_exceptions=True)
                successful_count += sum(1 for r in results if isinstance(r, bool) and r)
        return successful_count

    def _is_fragment_processed(self, fragment: Fragment, index: int) -> bool:
        """检查片段是否已处理"""
        fragment_id = f"{fragment.arxiv_id}_{index}"
        with self._processing_lock:
            return fragment_id in self._processed_fragments

    def _safe_add_to_vector_store_sync(self, text: str) -> bool:
        """线程安全的向量存储添加"""
        with self._processing_lock:
            try:
                self.rag_worker.add_text_to_vector_store(text)
                return True
            except Exception as e:
                logger.error(f"Error adding to vector store: {str(e)}")
                return False

    def _process_single_fragment_sync(self, fragment: Fragment, index: int) -> bool:
        """处理单个片段"""
        fragment_id = f"{fragment.arxiv_id}_{index}"
        try:
            text = self._build_fragment_text(fragment)
            if self._safe_add_to_vector_store_sync(text):
                with self._processing_lock:
                    self._processed_fragments.add(fragment_id)
                    self._processed_count += 1
                logger.info(f"Successfully processed fragment {index}")
                return True
            return False
        except Exception as e:
            logger.error(f"Error processing fragment {index}: {str(e)}")
            return False

    def _build_fragment_text(self, fragment: Fragment) -> str:
        """构建片段文本"""
        return "".join([
            f"Paper Title: {fragment.title}\n",
            f"Section: {fragment.current_section}\n",
            f"Content: {fragment.content}\n",
            f"Bibliography: {fragment.bibliography}\n",
            "Type: FRAGMENT"
        ])
    async def _save_results(self, fragments: List[Fragment], arxiv_id: str, successful_count: int) -> None:
        """保存处理结果"""
        if successful_count > 0:
            loop = self._ensure_loop()
            await loop.run_in_executor(
                self.thread_pool,
                save_fragments_to_file,
                fragments,
                str(self.fragment_store_dir / f"{arxiv_id}_fragments.json")
            )

    def _log_processing_stats(self, start_time: float, total_fragments: int) -> None:
        """记录处理统计信息"""
        elapsed_time = time.time() - start_time
        processing_rate = total_fragments / elapsed_time if elapsed_time > 0 else 0
        logger.info(
            f"Processed {self._processed_count}/{total_fragments} fragments "
            f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)"
        )

    async def process_paper(self, fragments: List[Fragment]) -> bool:
        """处理论文主函数"""
        try:
            if self.paper_path.exists():
                logger.info(f"Paper {self.arxiv_id} already processed")
                return True
            task = self._create_processing_task(self.arxiv_id)
            try:
                async with self.semaphore:
                    await self._process_fragments(fragments)
                    self._complete_task(task, fragments, self.paper_path)
                    return True
            except Exception as e:
                self._fail_task(task, str(e))
                raise
        except Exception as e:
            logger.error(f"Error processing paper {self.arxiv_id}: {str(e)}")
            return False

    def _create_processing_task(self, arxiv_id: str) -> ProcessingTask:
        """创建处理任务"""
        task = ProcessingTask(arxiv_id=arxiv_id)
        with self._processing_lock:
            self.processing_queue[arxiv_id] = task
            task.status = "processing"
        return task

    def _complete_task(self, task: ProcessingTask, fragments: List[Fragment], paper_path: Path) -> None:
        """完成任务处理"""
        with self._processing_lock:
            task.status = "completed"
            task.fragments = fragments
            paper_path.touch()
        logger.info(f"Paper {task.arxiv_id} processed successfully with {self._processed_count} fragments")

    def _fail_task(self, task: ProcessingTask, error: str) -> None:
        """任务失败处理"""
        with self._processing_lock:
            task.status = "failed"
            task.error = error

    def _normalize_arxiv_id(self, input_str: str) -> str:
        """规范化ArXiv ID"""
        if not input_str:
            return ""
        input_str = input_str.strip().lower()
        if 'arxiv.org/' in input_str:
            if '/pdf/' in input_str:
                arxiv_id = input_str.split('/pdf/')[-1]
            else:
                arxiv_id = input_str.split('/abs/')[-1]
            return arxiv_id.split('v')[0].strip()
        return input_str.split('v')[0].strip()

    async def wait_for_paper(self, arxiv_id: str, timeout: float = 300.0) -> bool:
        """等待论文处理完成"""
        start_time = time.time()
        try:
            while True:
                with self._processing_lock:
                    task = self.processing_queue.get(arxiv_id)
                if not task:
                    return False
                if task.status == "completed":
                    return True
                if task.status == "failed":
                    return False
                if time.time() - start_time > timeout:
                    logger.error(f"Processing paper {arxiv_id} timed out")
                    return False
                await asyncio.sleep(0.1)
        except Exception as e:
            logger.error(f"Error waiting for paper {arxiv_id}: {str(e)}")
            return False
    def retrieve_and_generate(self, query: str) -> str:
        """检索相关内容并生成提示词"""
        try:
            nodes = self.rag_worker.retrieve_from_store_with_query(query)
            return self.rag_worker.build_prompt(query=query, nodes=nodes)
        except Exception as e:
            logger.error(f"Error in retrieve and generate: {str(e)}")
            return ""

    def remember_qa(self, question: str, answer: str) -> None:
        """记忆问答对"""
        try:
            self.rag_worker.remember_qa(question, answer)
        except Exception as e:
            logger.error(f"Error remembering QA: {str(e)}")

    async def auto_analyze_paper(self, chatbot: List, history: List, system_prompt: str) -> None:
        """自动分析论文的关键问题"""
        key_questions = [
            "What is the main research question or problem addressed in this paper?",
            "What methods or approaches did the authors use to investigate the problem?",
            "What are the key findings or results presented in the paper?",
            "How do the findings of this paper contribute to the broader field or topic of study?",
            "What are the limitations of this study, and what future research directions do the authors suggest?"
        ]
        results = []
        for question in key_questions:
            try:
                prompt = self.retrieve_and_generate(question)
                if prompt:
                    response = await request_gpt_model_in_new_thread_with_ui_alive(
                        inputs=prompt,
                        inputs_show_user=question,
                        llm_kwargs=self.llm_kwargs,
                        chatbot=chatbot,
                        history=history,
                        sys_prompt=system_prompt
                    )
                    results.append(f"Q: {question}\nA: {response}\n")
                    self.remember_qa(question, response)
            except Exception as e:
                logger.error(f"Error in auto analysis: {str(e)}")

        # 合并所有结果
        summary = "\n\n".join(results)
        chatbot[-1] = (chatbot[-1][0], f"论文已成功加载并完成初步分析:\n\n{summary}\n\n您现在可以继续提问更多细节。")
@CatchException
def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
             history: List, system_prompt: str, web_port: str) -> Generator:
    """
    Arxiv论文对话主函数

    Args:
        txt: arxiv ID/URL
        llm_kwargs: LLM配置参数
        plugin_kwargs: 插件配置参数,包含 advanced_arg 字段作为用户询问指令
        chatbot: 对话历史
        history: 聊天历史
        system_prompt: 系统提示词
        web_port: Web端口
    """
    # 初始化时,提示用户需要 arxiv ID/URL
    from toolbox import promote_file_to_downloadzone
    if len(history) == 0 and not txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')):
        chatbot.append((txt, "请先提供Arxiv论文链接或ID。"))
        yield from update_ui(chatbot=chatbot, history=history)
        return

    user_name = chatbot.get_user()
    arxiv_worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
    arxiv_id = arxiv_worker.arxiv_id

    # 处理新论文的情况
    if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not arxiv_worker.loading:
        chatbot.append((txt, "正在处理论文,请稍等..."))
        yield from update_ui(chatbot=chatbot, history=history)
        fragments, formatted_content, fragment_output_files = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_id)
        for file in fragment_output_files:
            promote_file_to_downloadzone(file, chatbot=chatbot)
        chatbot.append(["论文文字内容已保存至下载区,接下来将进行论文编码,请耐心等待三分钟,论文的文字内容为:", formatted_content])
        yield from update_ui(chatbot=chatbot, history=history)
        try:
            # 创建新的事件循环
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                # 设置超时时间为5分钟
                success = loop.run_until_complete(
                    asyncio.wait_for(arxiv_worker.process_paper(fragments), timeout=300)
                )
                if success:
                    success = loop.run_until_complete(
                        asyncio.wait_for(arxiv_worker.wait_for_paper(arxiv_id), timeout=60)
                    )
                    if success:
                        chatbot[-1] = (txt, "论文处理完成,您现在可以开始提问。")
                    else:
                        chatbot[-1] = (txt, "论文处理超时,请重试。")
                else:
                    chatbot[-1] = (txt, "论文处理失败请检查论文ID是否正确或稍后重试。")
            except asyncio.TimeoutError:
                chatbot[-1] = (txt, "论文处理超时,请重试。")
                success = False
            finally:
                loop.close()
            if not success:
                yield from update_ui(chatbot=chatbot, history=history)
                return
        except Exception as e:
            logger.error(f"Error in main process: {str(e)}")
            chatbot[-1] = (txt, f"处理过程中发生错误: {str(e)}")
            yield from update_ui(chatbot=chatbot, history=history)
            return

        yield from update_ui(chatbot=chatbot, history=history)
        return

    # 处理用户询问的情况
    # 获取用户询问指令
    user_query = plugin_kwargs.get("advanced_arg",
                                   "What is the main research question or problem addressed in this paper?")
    if len(history) < 2:
        fragments, formatted_content, fragment_output_files = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_id)
        for file in fragment_output_files:
            promote_file_to_downloadzone(file, chatbot=chatbot)
        chatbot.append(["论文文字内容已保存至下载区,论文的文字内容为:", formatted_content])
        yield from update_ui(chatbot=chatbot, history=history)
    if not user_query:
        user_query = "What is the main research question or problem addressed in this paper?"
        # chatbot.append((txt, "请提供您的问题。"))
        # yield from update_ui(chatbot=chatbot, history=history)
        # return

    # 处理历史对话长度
    if len(history) > MAX_HISTORY_ROUND * 2:
        history = history[-(MAX_HISTORY_ROUND * 2):]

    # 处理询问指令
    query_clip, history, flags = input_clipping(
        user_query,
        history,
        max_token_limit=MAX_CONTEXT_TOKEN_LIMIT,
        return_clip_flags=True
    )
    if flags["original_input_len"] != flags["clipped_input_len"]:
        yield from update_ui_lastest_msg('检测到长输入,正在处理...', chatbot, history, delay=0)
        if len(user_query) > REMEMBER_PREVIEW:
            HALF = REMEMBER_PREVIEW // 2
            query_to_remember = user_query[:HALF] + f" ...\n...(省略{len(user_query) - REMEMBER_PREVIEW}字)...\n... " + user_query[-HALF:]
        else:
            query_to_remember = query_clip
    else:
        query_to_remember = query_clip

    chatbot.append((user_query, "正在思考中..."))
    yield from update_ui(chatbot=chatbot, history=history)

    # 生成提示词
    prompt = arxiv_worker.retrieve_and_generate(query_clip)
    if not prompt:
        chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。")
        yield from update_ui(chatbot=chatbot, history=history)
        return

    # 获取回答
    response = yield from request_gpt_model_in_new_thread_with_ui_alive(
        inputs=prompt,
        inputs_show_user=query_clip,
        llm_kwargs=llm_kwargs,
        chatbot=chatbot,
        history=history,
        sys_prompt=system_prompt
    )

    # 记忆问答对
    # worker.remember_qa(query_to_remember, response)
    history.extend([user_query, response])
    yield from update_ui(chatbot=chatbot, history=history)
if __name__ == "__main__":
# 测试代码
llm_kwargs = {
'api_key': os.getenv("one_api_key"),
'client_ip': '127.0.0.1',
'embed_model': 'text-embedding-3-small',
'llm_model': 'one-api-Qwen2.5-72B-Instruct',
'max_length': 4096,
'most_recent_uploaded': None,
'temperature': 1,
'top_p': 1
}
plugin_kwargs = {}
chatbot = []
history = []
system_prompt = "You are a helpful assistant."
web_port = "8080"
# 测试论文导入
arxiv_url = "https://arxiv.org/abs/2312.12345"
for response in Arxiv论文对话(
arxiv_url, llm_kwargs, plugin_kwargs,
chatbot, history, system_prompt, web_port
):
print(response)
# 测试问答
question = "这篇论文的主要贡献是什么?"
for response in Arxiv论文对话(
question, llm_kwargs, plugin_kwargs,
chatbot, history, system_prompt, web_port
):
print(response)
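The heart of `_parallel_process_fragments` above is a batch-and-gather pattern: slice the fragment list, push each batch into the thread pool via `run_in_executor`, and count successes from `asyncio.gather`. The same pattern in isolation — a self-contained sketch with a hypothetical `work` task, not this plugin's code:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def work(item):
    # hypothetical blocking task standing in for _process_single_fragment_sync
    return item % 2 == 0

async def process_all(items, batch_size=8):
    loop = asyncio.get_running_loop()
    successes = 0
    with ThreadPoolExecutor(max_workers=4) as pool:
        for i in range(0, len(items), batch_size):
            futures = [loop.run_in_executor(pool, work, item)
                       for item in items[i:i + batch_size]]
            results = await asyncio.gather(*futures, return_exceptions=True)
            successes += sum(1 for r in results if r is True)
    return successes

print(asyncio.run(process_all(list(range(20)))))  # -> 10
```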


@@ -152,8 +152,6 @@ class Conversation_To_File_Wrap(GptAcademicPluginTemplate):
def hide_cwd(str):
    import os
    current_path = os.getcwd()
@@ -172,7 +170,7 @@ def 载入对话历史存档(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
        user_request 当前用户的请求信息IP地址等
    """
    from crazy_functions.crazy_utils import get_files_from_everything
-    success, file_manifest, _ = get_files_from_everything(txt, type='.html', chatbot=chatbot)
+    success, file_manifest, _ = get_files_from_everything(txt, type='.html')
    if not success:
        if txt == "": txt = '空空如也的输入栏'


@@ -7,7 +7,7 @@ from bs4 import BeautifulSoup
from functools import lru_cache
from itertools import zip_longest
from check_proxy import check_proxy
-from toolbox import CatchException, update_ui, get_conf, update_ui_lastest_msg
+from toolbox import CatchException, update_ui, get_conf
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
from request_llms.bridge_all import model_info
from request_llms.bridge_all import predict_no_ui_long_connection
@@ -115,8 +115,7 @@ def get_auth_ip():
def searxng_request(query, proxies, categories='general', searxng_url=None, engines=None):
    if searxng_url is None:
-        urls = get_conf("SEARXNG_URLS")
-        url = random.choice(urls)
+        url = get_conf("SEARXNG_URL")
    else:
        url = searxng_url
@@ -193,38 +192,6 @@ def scrape_text(url, proxies) -> str:
text = "\n".join(chunk for chunk in chunks if chunk)
return text
def internet_search_with_analysis_prompt(prompt, analysis_prompt, llm_kwargs, chatbot):
from toolbox import get_conf
proxies = get_conf('proxies')
categories = 'general'
searxng_url = None # 使用默认的searxng_url
engines = None # 使用默认的搜索引擎
yield from update_ui_lastest_msg(lastmsg=f"检索中: {prompt} ...", chatbot=chatbot, history=[], delay=1)
urls = searxng_request(prompt, proxies, categories, searxng_url, engines=engines)
yield from update_ui_lastest_msg(lastmsg=f"依次访问搜索到的网站 ...", chatbot=chatbot, history=[], delay=1)
if len(urls) == 0:
return None
max_search_result = 5 # 最多收纳多少个网页的结果
history = []
for index, url in enumerate(urls[:max_search_result]):
yield from update_ui_lastest_msg(lastmsg=f"依次访问搜索到的网站: {url['link']} ...", chatbot=chatbot, history=[], delay=1)
res = scrape_text(url['link'], proxies)
prefix = f"{index}份搜索结果 [源自{url['source'][0]}搜索] {url['title'][:25]}"
history.extend([prefix, res])
i_say = f"从以上搜索结果中抽取信息,然后回答问题:{prompt} {analysis_prompt}"
i_say, history = input_clipping( # 裁剪输入从最长的条目开始裁剪防止爆token
inputs=i_say,
history=history,
max_token_limit=8192
)
gpt_say = predict_no_ui_long_connection(
inputs=i_say,
llm_kwargs=llm_kwargs,
history=history,
sys_prompt="请从搜索结果中抽取信息,对最相关的两个搜索结果进行总结,然后回答问题。",
console_slience=False,
)
return gpt_say
@CatchException
def 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
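Both the removed helper above and the new arXiv plugin lean on `input_clipping` to keep the assembled prompt under a token budget, trimming the longest history entries first. Its call shape as used here — a usage sketch; the argument values are illustrative:

```python
# inputs comes back possibly shortened and history possibly pruned,
# so that inputs plus history fit within max_token_limit tokens.
i_say, history = input_clipping(
    inputs="从以上搜索结果中抽取信息,然后回答问题",
    history=["第0份搜索结果 ...", "scraped page text ..."],
    max_token_limit=8192,
)
```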


@@ -1,4 +1,4 @@
import random
from toolbox import get_conf
from crazy_functions.Internet_GPT import 连接网络回答问题
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
@@ -20,9 +20,6 @@ class NetworkGPT_Wrap(GptAcademicPluginTemplate):
        第三个参数,名称`allow_cache`,参数`type`声明这是一个下拉菜单,下拉菜单上方显示`title`+`description`,下拉菜单的选项为`options``default_value`为下拉菜单默认值;
        """
-        urls = get_conf("SEARXNG_URLS")
-        url = random.choice(urls)
        gui_definition = {
            "main_input":
                ArgProperty(title="输入问题", description="待通过互联网检索的问题,会自动读取输入框内容", default_value="", type="string").model_dump_json(),  # 主输入,自动从输入框同步
@@ -33,7 +30,7 @@ class NetworkGPT_Wrap(GptAcademicPluginTemplate):
"optimizer":
ArgProperty(title="搜索优化", options=["关闭", "开启", "开启(增强)"], default_value="关闭", description="是否使用搜索增强。注意这可能会消耗较多token", type="dropdown").model_dump_json(),
"searxng_url":
ArgProperty(title="Searxng服务地址", description="输入Searxng的地址", default_value=url, type="string").model_dump_json(), # 主输入,自动从输入框同步
ArgProperty(title="Searxng服务地址", description="输入Searxng的地址", default_value=get_conf("SEARXNG_URL"), type="string").model_dump_json(), # 主输入,自动从输入框同步
}
return gui_definition


@@ -559,7 +559,7 @@ def PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, h
    project_folder = move_project(project_folder)

    # <-------------- set a hash tag for repeat-checking ------------->
-    with open(pj(project_folder, hash_tag + '.tag'), 'w', encoding='utf8') as f:
+    with open(pj(project_folder, hash_tag + '.tag'), 'w') as f:
        f.write(hash_tag)
        f.close()


@@ -1,4 +1,3 @@
-from shared_utils.fastapi_server import validate_path_safety
from toolbox import update_ui, trimmed_format_exc, promote_file_to_downloadzone, get_log_folder
from toolbox import CatchException, report_exception, write_history_to_file, zip_folder
from loguru import logger
@@ -156,7 +155,6 @@ def Latex英文润色(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
    import glob, os
    if os.path.exists(txt):
        project_folder = txt
-        validate_path_safety(project_folder, chatbot.get_user())
    else:
        if txt == "": txt = '空空如也的输入栏'
        report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
@@ -195,7 +193,6 @@ def Latex中文润色(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
    import glob, os
    if os.path.exists(txt):
        project_folder = txt
-        validate_path_safety(project_folder, chatbot.get_user())
    else:
        if txt == "": txt = '空空如也的输入栏'
        report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
@@ -232,7 +229,6 @@ def Latex英文纠错(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
    import glob, os
    if os.path.exists(txt):
        project_folder = txt
-        validate_path_safety(project_folder, chatbot.get_user())
    else:
        if txt == "": txt = '空空如也的输入栏'
        report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
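Each removal in this file drops the `validate_path_safety` guard that the base side runs before touching a user-supplied path. Its purpose is to reject paths outside the requesting user's sandbox; a hypothetical minimal stand-in, not the project's `shared_utils.fastapi_server` implementation:

```python
import os

def validate_path_safety_sketch(path, user, root="private_upload"):
    # Resolve symlinks and ".." segments, then require the result to
    # stay inside the per-user upload directory; raise otherwise.
    real = os.path.realpath(path)
    allowed = os.path.realpath(os.path.join(root, user))
    if real != allowed and not real.startswith(allowed + os.sep):
        raise PermissionError(f"unsafe path access: {path}")
```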


@@ -1,6 +1,5 @@
import glob, shutil, os, re
from loguru import logger
-from shared_utils.fastapi_server import validate_path_safety
from toolbox import update_ui, trimmed_format_exc, gen_time_str
from toolbox import CatchException, report_exception, get_log_folder
from toolbox import write_history_to_file, promote_file_to_downloadzone
@@ -119,7 +118,7 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
    yield from update_ui(chatbot=chatbot, history=history)  # 刷新界面

-def get_files_from_everything(txt, preference='', chatbox=None):
+def get_files_from_everything(txt, preference=''):
    if txt == "": return False, None, None
    success = True
    if txt.startswith('http'):
@@ -147,11 +146,9 @@ def get_files_from_everything(txt, preference='', chatbox=None):
        # 直接给定文件
        file_manifest = [txt]
        project_folder = os.path.dirname(txt)
-        validate_path_safety(project_folder, chatbot.get_user())
    elif os.path.exists(txt):
        # 本地路径,递归搜索
        project_folder = txt
-        validate_path_safety(project_folder, chatbot.get_user())
        file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.md', recursive=True)]
    else:
        project_folder = None
@@ -180,7 +177,7 @@ def Markdown英译中(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
        return
    history = []  # 清空历史,以免输入溢出

-    success, file_manifest, project_folder = get_files_from_everything(txt, preference="Github", chatbox=chatbot)
+    success, file_manifest, project_folder = get_files_from_everything(txt, preference="Github")
    if not success:
        # 什么都没有


@@ -26,7 +26,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
    # 清空历史,以免输入溢出
    history = []

-    success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf', chatbot=chatbot)
+    success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf')

    # 检测输入参数,如没有给定输入参数,直接退出
    if (not success) and txt == "": txt = '空空如也的输入栏。提示请先上传文件把PDF文件拖入对话'
@@ -47,7 +47,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
        yield from 解析PDF_基于DOC2X(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, DOC2X_API_KEY, user_request)
        return
    except:
-        chatbot.append([None, f"DOC2X服务不可用请检查报错详细{trimmed_format_exc_markdown()}"])
+        chatbot.append([None, f"DOC2X服务不可用现在将执行效果稍差的旧版代码{trimmed_format_exc_markdown()}"])
        yield from update_ui(chatbot=chatbot, history=history)

    if method == "GROBID":


@@ -5,7 +5,6 @@ from shared_utils.fastapi_server import validate_path_safety
from toolbox import report_exception
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
-from shared_utils.fastapi_server import validate_path_safety
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
@@ -61,7 +60,6 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
    # 1. we retrieve rag worker from global context
    user_name = chatbot.get_user()
    checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
    if user_name in RAG_WORKER_REGISTER:
        rag_worker = RAG_WORKER_REGISTER[user_name]
    else:
@@ -95,6 +93,9 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
        yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0)  # 刷新界面
        return
    else:
+        report_exception(chatbot, history, a=f"上传文件路径错误: {txt}", b="请检查并提供正确路径。")

    # 3. Normal Q&A processing
    chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
    yield from update_ui(chatbot=chatbot, history=history)  # 刷新界面
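`RAG_WORKER_REGISTER` above implements a per-user singleton: one `LlamaIndexRagWorker` is built lazily per user name and reused across requests. The same pattern in isolation — a sketch reusing the constructor signature shown in the arXiv plugin earlier in this diff:

```python
RAG_WORKER_REGISTER = {}

def get_rag_worker(user_name, llm_kwargs, checkpoint_dir):
    # Build the worker once per user, then reuse it for later queries
    # so the loaded vector store survives across requests.
    if user_name not in RAG_WORKER_REGISTER:
        RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
            user_name=user_name,
            llm_kwargs=llm_kwargs,
            checkpoint_dir=checkpoint_dir,
            auto_load_checkpoint=True,
        )
    return RAG_WORKER_REGISTER[user_name]
```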


@@ -1,204 +0,0 @@
import requests
import random
import time
import re
import json
from bs4 import BeautifulSoup
from functools import lru_cache
from itertools import zip_longest
from check_proxy import check_proxy
from toolbox import CatchException, update_ui, get_conf, promote_file_to_downloadzone, update_ui_lastest_msg, generate_file_link
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
from request_llms.bridge_all import model_info
from request_llms.bridge_all import predict_no_ui_long_connection
from crazy_functions.prompts.internet import SearchOptimizerPrompt, SearchAcademicOptimizerPrompt
from crazy_functions.json_fns.pydantic_io import GptJsonIO, JsonStringError
from textwrap import dedent
from loguru import logger
from pydantic import BaseModel, Field


class Query(BaseModel):
    search_keyword: str = Field(description="search query for video resource")


class VideoResource(BaseModel):
    thought: str = Field(description="analysis of the search results based on the user's query")
    title: str = Field(description="title of the video")
    author: str = Field(description="author/uploader of the video")
    bvid: str = Field(description="unique ID of the video")
    another_failsafe_bvid: str = Field(description="provide another bvid, the other one is not working")


def get_video_resource(search_keyword):
    from crazy_functions.media_fns.get_media import search_videos
    # Search for videos and return the first result
    videos = search_videos(
        search_keyword
    )
    # Return the first video if results exist, otherwise return None
    return videos


def download_video(bvid, user_name, chatbot, history):
    # from experimental_mods.get_bilibili_resource import download_bilibili
    from crazy_functions.media_fns.get_media import download_video
    # pause a while
    tic_time = 8
    for i in range(tic_time):
        yield from update_ui_lastest_msg(
            lastmsg=f"即将下载音频。等待{tic_time-i}秒后自动继续, 点击“停止”键取消此操作。",
            chatbot=chatbot, history=[], delay=1)
    # download audio
    chatbot.append((None, "下载音频, 请稍等..."))
    yield from update_ui(chatbot=chatbot, history=history)
    downloaded_files = yield from download_video(bvid, only_audio=True, user_name=user_name, chatbot=chatbot, history=history)
    if len(downloaded_files) == 0:
        # failed to download audio
        return []
    # preview
    preview_list = [promote_file_to_downloadzone(fp) for fp in downloaded_files]
    file_links = generate_file_link(preview_list)
    yield from update_ui_lastest_msg(f"已完成的文件: <br/>" + file_links, chatbot=chatbot, history=history, delay=0)
    chatbot.append((None, f"即将下载视频。"))
    # pause a while
    tic_time = 16
    for i in range(tic_time):
        yield from update_ui_lastest_msg(
            lastmsg=f"即将下载视频。等待{tic_time-i}秒后自动继续, 点击“停止”键取消此操作。",
            chatbot=chatbot, history=[], delay=1)
    # download video
    chatbot.append((None, "下载视频, 请稍等..."))
    yield from update_ui(chatbot=chatbot, history=history)
    downloaded_files_part2 = yield from download_video(bvid, only_audio=False, user_name=user_name, chatbot=chatbot, history=history)
    # preview
    preview_list = [promote_file_to_downloadzone(fp) for fp in downloaded_files_part2]
    file_links = generate_file_link(preview_list)
    yield from update_ui_lastest_msg(f"已完成的文件: <br/>" + file_links, chatbot=chatbot, history=history, delay=0)
    # return
    return downloaded_files + downloaded_files_part2
class Strategy(BaseModel):
    thought: str = Field(description="analysis of the user's wish, for example, can you recall the name of the resource?")
    which_methods: str = Field(description="Which method to use to find the necessary information? choose from 'method_1' and 'method_2'.")
    method_1_search_keywords: str = Field(description="Generate keywords to search the internet if you choose method 1, otherwise empty.")
    method_2_generate_keywords: str = Field(description="Generate keywords for video download engine if you choose method 2, otherwise empty.")


@CatchException
def 多媒体任务(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
    user_wish: str = txt
    # query demos:
    # - "我想找一首歌里面有句歌词是“turn your face towards the sun”"
    # - "一首歌,第一句是红豆生南国"
    # - "一首音乐,中国航天任务专用的那首"
    # - "戴森球计划在熔岩星球的音乐"
    # - "hanser的百变什么精"
    # - "打大圣残躯时的bgm"
    # - "渊下宫战斗音乐"

    # 搜索
    chatbot.append((txt, "检索中, 请稍等..."))
    yield from update_ui(chatbot=chatbot, history=history)  # 刷新界面

    if "跳过联网搜索" not in user_wish:
        # 结构化生成
        internet_search_keyword = user_wish
        yield from update_ui_lastest_msg(lastmsg=f"发起互联网检索: {internet_search_keyword} ...", chatbot=chatbot, history=[], delay=1)
        from crazy_functions.Internet_GPT import internet_search_with_analysis_prompt
        result = yield from internet_search_with_analysis_prompt(
            prompt=internet_search_keyword,
            analysis_prompt="请根据搜索结果分析,获取用户需要找的资源的名称、作者、出处等信息。",
            llm_kwargs=llm_kwargs,
            chatbot=chatbot
        )
        yield from update_ui_lastest_msg(lastmsg=f"互联网检索结论: {result} \n\n 正在生成进一步检索方案 ...", chatbot=chatbot, history=[], delay=1)
        rf_req = dedent(f"""
            The user wish to get the following resource:
            {user_wish}
            Meanwhile, you can access another expert's opinion on the user's wish:
            {result}
            Generate search keywords (less than 5 keywords) for video download engine accordingly.
        """)
    else:
        user_wish = user_wish.replace("跳过联网搜索", "").strip()
        rf_req = dedent(f"""
            The user wish to get the following resource:
            {user_wish}
            Generate research keywords (less than 5 keywords) accordingly.
        """)
    gpt_json_io = GptJsonIO(Query)
    inputs = rf_req + gpt_json_io.format_instructions
    run_gpt_fn = lambda inputs, sys_prompt: predict_no_ui_long_connection(inputs=inputs, llm_kwargs=llm_kwargs, history=[], sys_prompt=sys_prompt, observe_window=[])
    analyze_res = run_gpt_fn(inputs, "")
    logger.info(analyze_res)
    query: Query = gpt_json_io.generate_output_auto_repair(analyze_res, run_gpt_fn)
    video_engine_keywords = query.search_keyword

    # 关键词展示
    chatbot.append((None, f"检索关键词已确认: {video_engine_keywords}。筛选中, 请稍等..."))
    yield from update_ui(chatbot=chatbot, history=history)  # 刷新界面

    # 获取候选资源
    candadate_dictionary: dict = get_video_resource(video_engine_keywords)
    candadate_dictionary_as_str = json.dumps(candadate_dictionary, ensure_ascii=False, indent=4)

    # 展示候选资源
    candadate_display = "\n".join([f"{i+1}. {it['title']}" for i, it in enumerate(candadate_dictionary)])
    chatbot.append((None, f"候选:\n\n{candadate_display}"))
    yield from update_ui(chatbot=chatbot, history=history)  # 刷新界面

    # 结构化生成
    rf_req_2 = dedent(f"""
        The user wish to get the following resource:
        {user_wish}
        Select the most relevant and suitable video resource from the following search results:
        {candadate_dictionary_as_str}
        Note:
        1. The first several search video results are more likely to satisfy the user's wish.
        2. The time duration of the video should be less than 10 minutes.
        3. You should analyze the search results first, before giving your answer.
        4. Use Chinese if possible.
        5. Beside the primary video selection, give a backup video resource `bvid`.
    """)
    gpt_json_io = GptJsonIO(VideoResource)
    inputs = rf_req_2 + gpt_json_io.format_instructions
    run_gpt_fn = lambda inputs, sys_prompt: predict_no_ui_long_connection(inputs=inputs, llm_kwargs=llm_kwargs, history=[], sys_prompt=sys_prompt, observe_window=[])
    analyze_res = run_gpt_fn(inputs, "")
    logger.info(analyze_res)
    video_resource: VideoResource = gpt_json_io.generate_output_auto_repair(analyze_res, run_gpt_fn)

    # Display
    chatbot.append(
        (None,
         f"分析:{video_resource.thought}" "<br/>"
         f"选择: `{video_resource.title}`。" "<br/>"
         f"作者:{video_resource.author}"
         )
    )
    chatbot.append((None, f"下载中, 请稍等..."))
    yield from update_ui(chatbot=chatbot, history=history)  # 刷新界面

    if video_resource and video_resource.bvid:
        logger.info(video_resource)
        downloaded = yield from download_video(video_resource.bvid, chatbot.get_user(), chatbot, history)
        if not downloaded:
            chatbot.append((None, f"下载失败, 尝试备选 ..."))
            yield from update_ui(chatbot=chatbot, history=history)  # 刷新界面
            downloaded = yield from download_video(video_resource.another_failsafe_bvid, chatbot.get_user(), chatbot, history)


@CatchException
def debug(bvid, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
    yield from download_video(bvid, chatbot.get_user(), chatbot, history)
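The deleted plugin's structured-output round trip — appending `format_instructions` to the prompt, then parsing and self-correcting the reply with `generate_output_auto_repair` — is reusable on its own. A minimal sketch with a toy schema, assuming the same `GptJsonIO` helper and some `run_gpt_fn` callable:

```python
from pydantic import BaseModel, Field
from crazy_functions.json_fns.pydantic_io import GptJsonIO

class Song(BaseModel):
    title: str = Field(description="title of the song")
    artist: str = Field(description="performer of the song")

gpt_json_io = GptJsonIO(Song)
inputs = "Identify the song containing the lyric 'red bean'. " + gpt_json_io.format_instructions
# analyze_res = run_gpt_fn(inputs, "")                                  # LLM call
# song = gpt_json_io.generate_output_auto_repair(analyze_res, run_gpt_fn)
# print(song.title, song.artist)
```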


@@ -0,0 +1,141 @@
from toolbox import CatchException, update_ui, promote_file_to_downloadzone
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
import datetime, json
def fetch_items(list_of_items, batch_size):
for i in range(0, len(list_of_items), batch_size):
yield list_of_items[i:i + batch_size]
def string_to_options(arguments):
import argparse
import shlex
# Create an argparse.ArgumentParser instance
parser = argparse.ArgumentParser()
# Add command-line arguments
parser.add_argument("--llm_to_learn", type=str, help="LLM model to learn", default="gpt-3.5-turbo")
parser.add_argument("--prompt_prefix", type=str, help="Prompt prefix", default='')
parser.add_argument("--system_prompt", type=str, help="System prompt", default='')
parser.add_argument("--batch", type=int, help="System prompt", default=50)
parser.add_argument("--pre_seq_len", type=int, help="pre_seq_len", default=50)
parser.add_argument("--learning_rate", type=float, help="learning_rate", default=2e-2)
parser.add_argument("--num_gpus", type=int, help="num_gpus", default=1)
parser.add_argument("--json_dataset", type=str, help="json_dataset", default="")
parser.add_argument("--ptuning_directory", type=str, help="ptuning_directory", default="")
# Parse the arguments
args = parser.parse_args(shlex.split(arguments))
return args
@CatchException
def 微调数据集生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
"""
txt             Text entered in the input field, e.g. a passage to translate, or a path containing files to process
llm_kwargs      GPT model parameters such as temperature and top_p; usually passed through unchanged
plugin_kwargs   Parameters for the plugin
chatbot         Handle to the chat display box, used to show output to the user
history         Chat history, i.e. the prior context
system_prompt   The silent system prompt given to GPT
user_request    Information about the current user request, such as the IP address
"""
history = [] # clear the history to prevent input overflow
chatbot.append(("这是什么功能?", "[Local Message] 微调数据集生成"))
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
args = plugin_kwargs.get("advanced_arg", None)
if args is None:
chatbot.append(("没给定指令", "退出"))
yield from update_ui(chatbot=chatbot, history=history); return
else:
arguments = string_to_options(arguments=args)
dat = []
with open(txt, 'r', encoding='utf8') as f:
for line in f.readlines():
json_dat = json.loads(line)
dat.append(json_dat["content"])
llm_kwargs['llm_model'] = arguments.llm_to_learn
for batch in fetch_items(dat, arguments.batch):
res = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=[f"{arguments.prompt_prefix}\n\n{b}" for b in (batch)],
inputs_show_user_array=[f"Show Nothing" for _ in (batch)],
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history_array=[[] for _ in (batch)],
sys_prompt_array=[arguments.system_prompt for _ in (batch)],
max_workers=10 # the maximum concurrency OpenAI allows
)
with open(txt+'.generated.json', 'a+', encoding='utf8') as f:
for b, r in zip(batch, res[1::2]):
f.write(json.dumps({"content":b, "summary":r}, ensure_ascii=False)+'\n')
promote_file_to_downloadzone(txt+'.generated.json', rename_file='generated.json', chatbot=chatbot)
return
@CatchException
def 启动微调(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
"""
txt             Text entered in the input field, e.g. a passage to translate, or a path containing files to process
llm_kwargs      GPT model parameters such as temperature and top_p; usually passed through unchanged
plugin_kwargs   Parameters for the plugin
chatbot         Handle to the chat display box, used to show output to the user
history         Chat history, i.e. the prior context
system_prompt   The silent system prompt given to GPT
user_request    Information about the current user request, such as the IP address
"""
import subprocess
history = [] # clear the history to prevent input overflow
chatbot.append(("这是什么功能?", "[Local Message] 启动微调"))
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
args = plugin_kwargs.get("advanced_arg", None)
if args is None:
chatbot.append(("没给定指令", "退出"))
yield from update_ui(chatbot=chatbot, history=history); return
else:
arguments = string_to_options(arguments=args)
pre_seq_len = arguments.pre_seq_len # 128
learning_rate = arguments.learning_rate # 2e-2
num_gpus = arguments.num_gpus # 1
json_dataset = arguments.json_dataset # 't_code.json'
ptuning_directory = arguments.ptuning_directory # '/home/hmp/ChatGLM2-6B/ptuning'
command = f"torchrun --standalone --nnodes=1 --nproc-per-node={num_gpus} main.py \
--do_train \
--train_file AdvertiseGen/{json_dataset} \
--validation_file AdvertiseGen/{json_dataset} \
--preprocessing_num_workers 20 \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path THUDM/chatglm2-6b \
--output_dir output/clothgen-chatglm2-6b-pt-{pre_seq_len}-{learning_rate} \
--overwrite_output_dir \
--max_source_length 256 \
--max_target_length 256 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 100 \
--logging_steps 10 \
--save_steps 20 \
--learning_rate {learning_rate} \
--pre_seq_len {pre_seq_len} \
--quantization_bit 4"
process = subprocess.Popen(command, shell=True, cwd=ptuning_directory)
try:
process.communicate(timeout=3600*24)
except subprocess.TimeoutExpired:
process.kill()
return
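# The two plugins above are linked by the generated JSONL file: 微调数据集生成 writes one
# {"content", "summary"} object per line, which is exactly what the ptuning command's
# --prompt_column content --response_column summary flags consume. A sanity check on
# one illustrative line:
# import json
# record = json.loads('{"content": "raw input text", "summary": "model-written summary"}')
# assert set(record) == {"content", "summary"}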

View File

@@ -2,7 +2,6 @@ import os
import threading
from loguru import logger
from shared_utils.char_visual_effect import scolling_visual_effect
from shared_utils.fastapi_server import validate_path_safety
from toolbox import update_ui, get_conf, trimmed_format_exc, get_max_token, Singleton
def input_clipping(inputs, history, max_token_limit, return_clip_flags=False):
@@ -170,7 +169,6 @@ def can_multi_process(llm) -> bool:
def default_condition(llm) -> bool:
# legacy condition
if llm.startswith('gpt-'): return True
if llm.startswith('chatgpt-'): return True
if llm.startswith('api2d-'): return True
if llm.startswith('azure-'): return True
if llm.startswith('spark'): return True
@@ -540,7 +538,7 @@ def read_and_clean_pdf_text(fp):
return meta_txt, page_one_meta
def get_files_from_everything(txt, type, chatbot=None): # type='.md'
def get_files_from_everything(txt, type): # type='.md'
"""
这个函数是用来获取指定目录下所有指定类型(如.md的文件并且对于网络上的文件也可以获取它。
下面是对每个参数和返回值的说明:
@@ -552,7 +550,6 @@ def get_files_from_everything(txt, type, chatbot=None): # type='.md'
- file_manifest: 文件路径列表,里面包含以指定类型为后缀名的所有文件的绝对路径。
- project_folder: 字符串,表示文件所在的文件夹路径。如果是网络上的文件,就是临时文件夹的路径。
该函数详细注释已添加,请确认是否满足您的需要。
- chatbot 带Cookies的Chatbot类为实现更多强大的功能做基础
"""
import glob, os
@@ -575,13 +572,9 @@ def get_files_from_everything(txt, type, chatbot=None): # type='.md'
# 直接给定文件
file_manifest = [txt]
project_folder = os.path.dirname(txt)
if chatbot is not None:
validate_path_safety(project_folder, chatbot.get_user())
elif os.path.exists(txt):
# 本地路径,递归搜索
project_folder = txt
if chatbot is not None:
validate_path_safety(project_folder, chatbot.get_user())
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*'+type, recursive=True)]
if len(file_manifest) == 0:
success = False

View File

@@ -0,0 +1,450 @@
import os
import time
from abc import ABC, abstractmethod
from datetime import datetime
from docx import Document
from docx.enum.style import WD_STYLE_TYPE
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
from docx.oxml.ns import qn
from docx.shared import Inches, Cm
from docx.shared import Pt, RGBColor, Inches
from typing import Dict, List, Tuple
class DocumentFormatter(ABC):
"""文档格式化基类,定义文档格式化的基本接口"""
def __init__(self, final_summary: str, file_summaries_map: Dict, failed_files: List[Tuple]):
self.final_summary = final_summary
self.file_summaries_map = file_summaries_map
self.failed_files = failed_files
@abstractmethod
def format_failed_files(self) -> str:
"""格式化失败文件列表"""
pass
@abstractmethod
def format_file_summaries(self) -> str:
"""格式化文件总结内容"""
pass
@abstractmethod
def create_document(self) -> str:
"""创建完整文档"""
pass
class WordFormatter(DocumentFormatter):
"""Word格式文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012),并进行了优化"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.doc = Document()
self._setup_document()
self._create_styles()
# 初始化三级标题编号系统
self.numbers = {
1: 0, # 一级标题编号
2: 0, # 二级标题编号
3: 0 # 三级标题编号
}
def _setup_document(self):
"""设置文档基本格式,包括页面设置和页眉"""
sections = self.doc.sections
for section in sections:
# 设置页面大小为A4
section.page_width = Cm(21)
section.page_height = Cm(29.7)
# 设置页边距
section.top_margin = Cm(3.7) # 上边距37mm
section.bottom_margin = Cm(3.5) # 下边距35mm
section.left_margin = Cm(2.8) # 左边距28mm
section.right_margin = Cm(2.6) # 右边距26mm
# 设置页眉页脚距离
section.header_distance = Cm(2.0)
section.footer_distance = Cm(2.0)
# 添加页眉
header = section.header
header_para = header.paragraphs[0]
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.RIGHT
header_run = header_para.add_run("该文档由GPT-academic生成")
header_run.font.name = '仿宋'
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
header_run.font.size = Pt(9)
def _create_styles(self):
"""创建文档样式"""
# 创建正文样式
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
style.font.name = '仿宋'
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
style.font.size = Pt(14)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
style.paragraph_format.space_after = Pt(0)
style.paragraph_format.first_line_indent = Pt(28)
# 创建各级标题样式
self._create_heading_style('Title_Custom', '方正小标宋简体', 32, WD_PARAGRAPH_ALIGNMENT.CENTER)
self._create_heading_style('Heading1_Custom', '黑体', 22, WD_PARAGRAPH_ALIGNMENT.LEFT)
self._create_heading_style('Heading2_Custom', '黑体', 18, WD_PARAGRAPH_ALIGNMENT.LEFT)
self._create_heading_style('Heading3_Custom', '黑体', 16, WD_PARAGRAPH_ALIGNMENT.LEFT)
def _create_heading_style(self, style_name: str, font_name: str, font_size: int, alignment):
"""创建标题样式"""
style = self.doc.styles.add_style(style_name, WD_STYLE_TYPE.PARAGRAPH)
style.font.name = font_name
style._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
style.font.size = Pt(font_size)
style.font.bold = True
style.paragraph_format.alignment = alignment
style.paragraph_format.space_before = Pt(12)
style.paragraph_format.space_after = Pt(12)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
return style
def _get_heading_number(self, level: int) -> str:
"""
生成标题编号
Args:
level: 标题级别 (0-3)
Returns:
str: 格式化的标题编号
"""
if level == 0: # 主标题不需要编号
return ""
self.numbers[level] += 1 # 增加当前级别的编号
# 重置下级标题编号
for i in range(level + 1, 4):
self.numbers[i] = 0
# 根据级别返回不同格式的编号
if level == 1:
return f"{self.numbers[1]}. "
elif level == 2:
return f"{self.numbers[1]}.{self.numbers[2]} "
elif level == 3:
return f"{self.numbers[1]}.{self.numbers[2]}.{self.numbers[3]} "
return ""
def _add_heading(self, text: str, level: int):
"""
添加带编号的标题
Args:
text: 标题文本
level: 标题级别 (0-3)
"""
style_map = {
0: 'Title_Custom',
1: 'Heading1_Custom',
2: 'Heading2_Custom',
3: 'Heading3_Custom'
}
number = self._get_heading_number(level)
paragraph = self.doc.add_paragraph(style=style_map[level])
if number:
number_run = paragraph.add_run(number)
font_size = 22 if level == 1 else (18 if level == 2 else 16)
self._get_run_style(number_run, '黑体', font_size, True)
text_run = paragraph.add_run(text)
font_size = 32 if level == 0 else (22 if level == 1 else (18 if level == 2 else 16))
self._get_run_style(text_run, '黑体', font_size, True)
# 主标题添加日期
if level == 0:
date_paragraph = self.doc.add_paragraph()
date_paragraph.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
date_run = date_paragraph.add_run(datetime.now().strftime('%Y年%m月%d日'))
self._get_run_style(date_run, '仿宋', 16, False)
return paragraph
def _get_run_style(self, run, font_name: str, font_size: int, bold: bool = False):
"""设置文本运行对象的样式"""
run.font.name = font_name
run._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
run.font.size = Pt(font_size)
run.font.bold = bold
def format_failed_files(self) -> str:
"""格式化失败文件列表"""
result = []
if not self.failed_files:
return "\n".join(result)
result.append("处理失败文件:")
for fp, reason in self.failed_files:
result.append(f"{os.path.basename(fp)}: {reason}")
self._add_heading("处理失败文件", 1)
for fp, reason in self.failed_files:
self._add_content(f"{os.path.basename(fp)}: {reason}", indent=False)
self.doc.add_paragraph()
return "\n".join(result)
def _add_content(self, text: str, indent: bool = True):
"""添加正文内容"""
paragraph = self.doc.add_paragraph(text, style='Normal_Custom')
if not indent:
paragraph.paragraph_format.first_line_indent = Pt(0)
return paragraph
def format_file_summaries(self) -> str:
"""
格式化文件总结内容,确保正确的标题层级
返回:
str: 格式化后的文件总结字符串
标题层级规则:
1. 一级标题为"各文件详细总结"
2. 如果文件有目录路径:
- 目录路径作为二级标题 (2.1, 2.2 等)
- 该目录下所有文件作为三级标题 (2.1.1, 2.1.2 等)
3. 如果文件没有目录路径:
- 文件直接作为二级标题 (2.1, 2.2 等)
"""
result = []
# 首先对文件路径进行分组整理
file_groups = {}
for path in sorted(self.file_summaries_map.keys()):
dir_path = os.path.dirname(path)
if dir_path not in file_groups:
file_groups[dir_path] = []
file_groups[dir_path].append(path)
# 处理没有目录的文件
root_files = file_groups.get("", [])
if root_files:
for path in sorted(root_files):
file_name = os.path.basename(path)
result.append(f"\n📄 {file_name}")
result.append(self.file_summaries_map[path])
# 无目录的文件作为二级标题
self._add_heading(f"📄 {file_name}", 2)
self._add_content(self.file_summaries_map[path])
self.doc.add_paragraph()
# 处理有目录的文件
for dir_path in sorted(file_groups.keys()):
if dir_path == "": # 跳过已处理的根目录文件
continue
# 添加目录作为二级标题
result.append(f"\n📁 {dir_path}")
self._add_heading(f"📁 {dir_path}", 2)
# 该目录下的所有文件作为三级标题
for path in sorted(file_groups[dir_path]):
file_name = os.path.basename(path)
result.append(f"\n📄 {file_name}")
result.append(self.file_summaries_map[path])
# 添加文件名作为三级标题
self._add_heading(f"📄 {file_name}", 3)
self._add_content(self.file_summaries_map[path])
self.doc.add_paragraph()
return "\n".join(result)
def create_document(self):
"""创建完整Word文档并返回文档对象"""
# 重置所有编号
for level in self.numbers:
self.numbers[level] = 0
# 添加主标题
self._add_heading("文档总结报告", 0)
self.doc.add_paragraph()
# 添加总体摘要
self._add_heading("总体摘要", 1)
self._add_content(self.final_summary)
self.doc.add_paragraph()
# 添加失败文件列表(如果有)
if self.failed_files:
self.format_failed_files()
# 添加文件详细总结
self._add_heading("各文件详细总结", 1)
self.format_file_summaries()
return self.doc
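# A minimal usage sketch (illustrative values; the constructor signature comes from
# DocumentFormatter above):
# formatter = WordFormatter(
#     final_summary="overall summary ...",
#     file_summaries_map={"docs/a.md": "summary of a", "b.txt": "summary of b"},
#     failed_files=[("c.pdf", "parse error")],
# )
# formatter.create_document().save("summary_report.docx")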
class MarkdownFormatter(DocumentFormatter):
"""Markdown格式文档生成器"""
def format_failed_files(self) -> str:
if not self.failed_files:
return ""
formatted_text = ["\n## ⚠️ 处理失败的文件"]
for fp, reason in self.failed_files:
formatted_text.append(f"- {os.path.basename(fp)}: {reason}")
formatted_text.append("\n---")
return "\n".join(formatted_text)
def format_file_summaries(self) -> str:
formatted_text = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
formatted_text.append(f"\n## 📁 {dir_path}")
current_dir = dir_path
file_name = os.path.basename(path)
formatted_text.append(f"\n### 📄 {file_name}")
formatted_text.append(self.file_summaries_map[path])
formatted_text.append("\n---")
return "\n".join(formatted_text)
def create_document(self) -> str:
document = [
"# 📑 文档总结报告",
"\n## 总体摘要",
self.final_summary
]
if self.failed_files:
document.append(self.format_failed_files())
document.extend([
"\n# 📚 各文件详细总结",
self.format_file_summaries()
])
return "\n".join(document)
class HtmlFormatter(DocumentFormatter):
"""HTML格式文档生成器"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.css_styles = """
body {
font-family: "Microsoft YaHei", Arial, sans-serif;
line-height: 1.6;
max-width: 1000px;
margin: 0 auto;
padding: 20px;
color: #333;
}
h1 {
color: #2c3e50;
border-bottom: 2px solid #eee;
padding-bottom: 10px;
font-size: 24px;
text-align: center;
}
h2 {
color: #34495e;
margin-top: 30px;
font-size: 20px;
border-left: 4px solid #3498db;
padding-left: 10px;
}
h3 {
color: #2c3e50;
font-size: 18px;
margin-top: 20px;
}
.summary {
background-color: #f8f9fa;
padding: 20px;
border-radius: 5px;
margin: 20px 0;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.details {
margin-top: 40px;
}
.failed-files {
background-color: #fff3f3;
padding: 15px;
border-left: 4px solid #e74c3c;
margin: 20px 0;
}
.file-summary {
background-color: #fff;
padding: 15px;
margin: 15px 0;
border-radius: 4px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
"""
def format_failed_files(self) -> str:
if not self.failed_files:
return ""
failed_files_html = ['<div class="failed-files">']
failed_files_html.append("<h2>⚠️ 处理失败的文件</h2>")
failed_files_html.append("<ul>")
for fp, reason in self.failed_files:
failed_files_html.append(f"<li><strong>{os.path.basename(fp)}:</strong> {reason}</li>")
failed_files_html.append("</ul></div>")
return "\n".join(failed_files_html)
def format_file_summaries(self) -> str:
formatted_html = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
formatted_html.append(f'<h2>📁 {dir_path}</h2>')
current_dir = dir_path
file_name = os.path.basename(path)
formatted_html.append('<div class="file-summary">')
formatted_html.append(f'<h3>📄 {file_name}</h3>')
formatted_html.append(f'<p>{self.file_summaries_map[path]}</p>')
formatted_html.append('</div>')
return "\n".join(formatted_html)
def create_document(self) -> str:
return f"""
<!DOCTYPE html>
<html>
<head>
<meta charset='utf-8'>
<title>文档总结报告</title>
<style>{self.css_styles}</style>
</head>
<body>
<h1>📑 文档总结报告</h1>
<h2>总体摘要</h2>
<div class="summary">{self.final_summary}</div>
{self.format_failed_files()}
<div class="details">
<h2>📚 各文件详细总结</h2>
{self.format_file_summaries()}
</div>
</body>
</html>
"""

View File

@@ -0,0 +1,387 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type, TypeVar, Generic, Union
from dataclasses import dataclass
from enum import Enum, auto
import logging
from datetime import datetime
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
# Logging setup
logger = logging.getLogger(__name__)
# Custom exception classes
class FoldingError(Exception):
"""Base class for folding-related exceptions"""
pass
class FormattingError(FoldingError):
"""Error raised during formatting"""
pass
class MetadataError(FoldingError):
"""Metadata-related error"""
pass
class ValidationError(FoldingError):
"""Validation error"""
pass
class FoldingStyle(Enum):
"""折叠样式枚举"""
SIMPLE = auto() # 简单折叠
DETAILED = auto() # 详细折叠(带有额外信息)
NESTED = auto() # 嵌套折叠
@dataclass
class FoldingOptions:
"""折叠选项配置"""
style: FoldingStyle = FoldingStyle.DETAILED
code_language: Optional[str] = None # 代码块的语言
show_timestamp: bool = False # 是否显示时间戳
indent_level: int = 0 # 缩进级别
custom_css: Optional[str] = None # 自定义CSS类
T = TypeVar('T') # 用于泛型类型
class BaseMetadata(ABC):
"""Base class for metadata"""
@abstractmethod
def validate(self) -> bool:
"""Validate the metadata"""
pass
def _validate_non_empty_str(self, value: Optional[str]) -> bool:
"""Check that a string is non-empty"""
return bool(value and value.strip())
@dataclass
class FileMetadata(BaseMetadata):
"""File metadata"""
rel_path: str
size: float
last_modified: Optional[datetime] = None
mime_type: Optional[str] = None
encoding: str = 'utf-8'
def validate(self) -> bool:
"""验证文件元数据的有效性"""
try:
if not self._validate_non_empty_str(self.rel_path):
return False
if self.size < 0:
return False
return True
except Exception as e:
logger.error(f"File metadata validation error: {str(e)}")
return False
class ContentFormatter(ABC, Generic[T]):
"""内容格式化抽象基类
支持泛型类型参数,可以指定具体的元数据类型。
"""
@abstractmethod
def format(self,
content: str,
metadata: T,
options: Optional[FoldingOptions] = None) -> str:
"""格式化内容
Args:
content: 需要格式化的内容
metadata: 类型化的元数据
options: 折叠选项
Returns:
str: 格式化后的内容
Raises:
FormattingError: 格式化过程中的错误
"""
pass
def _create_summary(self, metadata: T) -> str:
"""创建折叠摘要,可被子类重写"""
return str(metadata)
def _format_content_block(self,
content: str,
options: Optional[FoldingOptions]) -> str:
"""格式化内容块,处理代码块等特殊格式"""
if not options:
return content
if options.code_language:
return f"```{options.code_language}\n{content}\n```"
return content
def _add_indent(self, text: str, level: int) -> str:
"""添加缩进"""
if level <= 0:
return text
indent = " " * level
return "\n".join(indent + line for line in text.splitlines())
class FileContentFormatter(ContentFormatter[FileMetadata]):
"""文件内容格式化器"""
def format(self,
content: str,
metadata: FileMetadata,
options: Optional[FoldingOptions] = None) -> str:
"""格式化文件内容"""
if not metadata.validate():
raise MetadataError("Invalid file metadata")
try:
options = options or FoldingOptions()
# build the summary line
summary_parts = [
f"{metadata.rel_path} ({metadata.size:.2f}MB)",
f"Type: {metadata.mime_type}" if metadata.mime_type else None,
(f"Modified: {metadata.last_modified.strftime('%Y-%m-%d %H:%M:%S')}"
if metadata.last_modified and options.show_timestamp else None)
]
summary = " | ".join(filter(None, summary_parts))
# build the HTML class attribute
css_class = f' class="{options.custom_css}"' if options.custom_css else ''
# format the body
formatted_content = self._format_content_block(content, options)
# assemble the final result
result = (
f'<details{css_class}><summary>{summary}</summary>\n\n'
f'{formatted_content}\n\n'
f'</details>\n\n'
)
return self._add_indent(result, options.indent_level)
except Exception as e:
logger.error(f"Error formatting file content: {str(e)}")
raise FormattingError(f"Failed to format file content: {str(e)}")
class ContentFoldingManager:
"""内容折叠管理器"""
def __init__(self):
"""初始化折叠管理器"""
self._formatters: Dict[str, ContentFormatter] = {}
self._register_default_formatters()
def _register_default_formatters(self) -> None:
"""注册默认的格式化器"""
self.register_formatter('file', FileContentFormatter())
def register_formatter(self, name: str, formatter: ContentFormatter) -> None:
"""注册新的格式化器"""
if not isinstance(formatter, ContentFormatter):
raise TypeError("Formatter must implement ContentFormatter interface")
self._formatters[name] = formatter
def _guess_language(self, extension: str) -> Optional[str]:
"""根据文件扩展名猜测编程语言"""
extension = extension.lower().lstrip('.')
language_map = {
'py': 'python',
'js': 'javascript',
'java': 'java',
'cpp': 'cpp',
'cs': 'csharp',
'html': 'html',
'css': 'css',
'md': 'markdown',
'json': 'json',
'xml': 'xml',
'sql': 'sql',
'sh': 'bash',
'yaml': 'yaml',
'yml': 'yaml',
'txt': None # plain text needs no language tag
}
return language_map.get(extension)
def format_content(self,
content: str,
formatter_type: str,
metadata: FileMetadata,
options: Optional[FoldingOptions] = None) -> str:
"""格式化内容"""
formatter = self._formatters.get(formatter_type)
if not formatter:
raise KeyError(f"No formatter registered for type: {formatter_type}")
if not isinstance(metadata, FileMetadata):
raise TypeError("Invalid metadata type")
return formatter.format(content, metadata, options)
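# The manager is a thin registry over the formatters above; a sketch:
# manager = ContentFoldingManager()
# folded = manager.format_content(content, formatter_type="file", metadata=meta, options=opts)
# Additional formatter types can be plugged in via manager.register_formatter(name, formatter).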
@dataclass
class PaperMetadata(BaseMetadata):
"""论文元数据"""
title: str
authors: str
abstract: str
catalogs: str
arxiv_id: str = ""
def validate(self) -> bool:
"""验证论文元数据的有效性"""
try:
if not self._validate_non_empty_str(self.title):
return False
if not self._validate_non_empty_str(self.authors):
return False
if not self._validate_non_empty_str(self.abstract):
return False
if not self._validate_non_empty_str(self.catalogs):
return False
return True
except Exception as e:
logger.error(f"Paper metadata validation error: {str(e)}")
return False
class PaperContentFormatter(ContentFormatter[PaperMetadata]):
"""论文内容格式化器"""
def format(self,
fragments: list[SectionFragment],
metadata: PaperMetadata,
options: Optional[FoldingOptions] = None) -> str:
"""格式化论文内容
Args:
fragments: 论文片段列表
metadata: 论文元数据
options: 折叠选项
Returns:
str: 格式化后的论文内容
"""
if not metadata.validate():
raise MetadataError("Invalid paper metadata")
try:
options = options or FoldingOptions()
# 1. title line (not folded)
result = [f"# {metadata.title}\n"]
# 2. author information (folded)
result.append(self._create_folded_section(
"Authors",
metadata.authors,
options
))
# 3. abstract (folded)
result.append(self._create_folded_section(
"Abstract",
metadata.abstract,
options
))
# 4. table-of-contents tree (folded)
result.append(self._create_folded_section(
"Table of Contents",
f"```\n{metadata.catalogs}\n```",
options
))
# 5. group fragments by section and emit each section
sections = self._organize_sections(fragments)
for section, section_fragments in sections.items():
# concatenate all content belonging to this section
section_content = "\n\n".join(
fragment.content for fragment in section_fragments
)
result.append(self._create_folded_section(
section,
section_content,
options
))
# 6. bibliography (folded)
# collect every non-empty bibliography entry
all_refs = "\n".join(filter(None,
(fragment.bibliography for fragment in fragments)
))
if all_refs:
result.append(self._create_folded_section(
"Bibliography",
f"```bibtex\n{all_refs}\n```",
options
))
return "\n\n".join(result)
except Exception as e:
logger.error(f"Error formatting paper content: {str(e)}")
raise FormattingError(f"Failed to format paper content: {str(e)}")
def _create_folded_section(self,
title: str,
content: str,
options: FoldingOptions) -> str:
"""创建折叠区块
Args:
title: 区块标题
content: 区块内容
options: 折叠选项
Returns:
str: 格式化后的折叠区块
"""
css_class = f' class="{options.custom_css}"' if options.custom_css else ''
result = (
f'<details{css_class}><summary>{title}</summary>\n\n'
f'{content}\n\n'
f'</details>'
)
return self._add_indent(result, options.indent_level)
def _organize_sections(self,
fragments: list[SectionFragment]
) -> Dict[str, list[SectionFragment]]:
"""将片段按章节分组
Args:
fragments: 论文片段列表
Returns:
Dict[str, list[SectionFragment]]: 按章节分组的片段字典
"""
sections: Dict[str, list[SectionFragment]] = {}
for fragment in fragments:
section = fragment.current_section or "Uncategorized"
if section not in sections:
sections[section] = []
sections[section].append(fragment)
return sections
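# End-to-end sketch (illustrative values; SectionFragment is the dataclass imported above):
# meta = PaperMetadata(title="An Example Paper", authors="A. Author", abstract="...", catalogs="1 Introduction")
# frags = [SectionFragment(..., current_section="Introduction", content="Some text")]
# markdown = PaperContentFormatter().format(frags, meta, FoldingOptions())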

View File

@@ -0,0 +1,354 @@
from pathlib import Path
from typing import List, Dict
from dataclasses import dataclass
from datetime import datetime
import os
import re
@dataclass
class SectionFragment:
"""Arxiv论文片段数据类"""
title: str
authors: str
abstract: str
catalogs: str
arxiv_id: str = ""
current_section: str = "Introduction"
content: str = ''
bibliography: str = ''
class PaperHtmlFormatter:
"""HTML格式论文文档生成器"""
def __init__(self, fragments: List[SectionFragment], output_dir: Path):
self.fragments = fragments
self.output_dir = output_dir
self.css_styles = """
:root {
--primary-color: #1a73e8;
--secondary-color: #34495e;
--background-color: #f8f9fa;
--text-color: #2c3e50;
--border-color: #e0e0e0;
--code-bg-color: #f6f8fa;
}
body {
font-family: "Source Serif Pro", "Times New Roman", serif;
line-height: 1.8;
max-width: 1000px;
margin: 0 auto;
padding: 2rem;
color: var(--text-color);
background-color: var(--background-color);
font-size: 16px;
}
.container {
background: white;
padding: 2rem;
border-radius: 8px;
box-shadow: 0 2px 12px rgba(0,0,0,0.1);
}
h1 {
color: var(--primary-color);
font-size: 2.2em;
text-align: center;
margin: 1.5rem 0;
padding-bottom: 1rem;
border-bottom: 3px solid var(--primary-color);
}
h2 {
color: var(--secondary-color);
font-size: 1.8em;
margin-top: 2rem;
padding-left: 1rem;
border-left: 4px solid var(--primary-color);
}
h3 {
color: var(--text-color);
font-size: 1.5em;
margin-top: 1.5rem;
border-bottom: 2px solid var(--border-color);
padding-bottom: 0.5rem;
}
.authors {
text-align: center;
color: var(--secondary-color);
font-size: 1.1em;
margin: 1rem 0 2rem;
}
.abstract-container {
background: var(--background-color);
padding: 1.5rem;
border-radius: 6px;
margin: 2rem 0;
}
.abstract-title {
font-weight: bold;
color: var(--primary-color);
margin-bottom: 1rem;
}
.abstract-content {
font-style: italic;
line-height: 1.7;
}
.toc {
background: white;
padding: 1.5rem;
border-radius: 6px;
margin: 2rem 0;
box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
.toc-title {
color: var(--primary-color);
font-size: 1.4em;
margin-bottom: 1rem;
}
.section-content {
background: white;
padding: 1.5rem;
border-radius: 6px;
margin: 1.5rem 0;
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
}
.fragment {
margin: 2rem 0;
padding-left: 1rem;
border-left: 3px solid var(--border-color);
}
.fragment:hover {
border-left-color: var(--primary-color);
}
.bibliography {
background: var(--code-bg-color);
padding: 1rem;
border-radius: 4px;
font-family: "Source Code Pro", monospace;
font-size: 0.9em;
white-space: pre-wrap;
margin-top: 1rem;
}
pre {
background: var(--code-bg-color);
padding: 1rem;
border-radius: 4px;
overflow-x: auto;
font-family: "Source Code Pro", monospace;
}
.paper-info {
background: white;
padding: 2rem;
border-radius: 8px;
margin: 2rem 0;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
}
.arxiv-id {
text-align: center;
color: #666;
font-size: 0.9em;
margin: 1rem 0;
}
.section-title {
display: flex;
align-items: center;
gap: 0.5rem;
color: var(--secondary-color);
}
.section-icon {
color: var(--primary-color);
}
@media print {
body {
background: white;
}
.container {
box-shadow: none;
}
}
"""
def _sanitize_html(self, text: str) -> str:
"""清理HTML特殊字符"""
if not text:
return ""
replacements = {
"&": "&amp;",
"<": "&lt;",
">": "&gt;",
'"': "&quot;",
"'": "&#39;"
}
for old, new in replacements.items():
text = text.replace(old, new)
return text
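# Note: the table above matches the standard library's escaping; an equivalent
# implementation (up to the exact entity chosen for the apostrophe) would be:
# import html
# return html.escape(text, quote=True) if text else ""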
def _create_section_id(self, section: str) -> str:
"""创建section的ID"""
section = section.strip() or "uncategorized"
# 移除特殊字符,转换为小写并用连字符替换空格
section_id = re.sub(r'[^\w\s-]', '', section.lower())
return section_id.replace(' ', '-')
def format_paper_info(self) -> str:
"""格式化论文基本信息"""
if not self.fragments:
return ""
first_fragment = self.fragments[0]
paper_info = ['<div class="paper-info">']
# 添加标题
if first_fragment.title:
paper_info.append(f'<h1>{self._sanitize_html(first_fragment.title)}</h1>')
# 添加arXiv ID
if first_fragment.arxiv_id:
paper_info.append(f'<div class="arxiv-id">arXiv: {self._sanitize_html(first_fragment.arxiv_id)}</div>')
# 添加作者
if first_fragment.authors:
paper_info.append(f'<div class="authors">{self._sanitize_html(first_fragment.authors)}</div>')
# 添加摘要
if first_fragment.abstract:
paper_info.append('<div class="abstract-container">')
paper_info.append('<div class="abstract-title">Abstract</div>')
paper_info.append(f'<div class="abstract-content">{self._sanitize_html(first_fragment.abstract)}</div>')
paper_info.append('</div>')
# 添加目录结构
if first_fragment.catalogs:
paper_info.append('<h2>Document Structure</h2>')
paper_info.append('<pre>')
paper_info.append(self._sanitize_html(first_fragment.catalogs))
paper_info.append('</pre>')
paper_info.append('</div>')
return '\n'.join(paper_info)
def format_table_of_contents(self, sections: Dict[str, List[SectionFragment]]) -> str:
"""生成目录"""
toc = ['<div class="toc">']
toc.append('<div class="toc-title">Table of Contents</div>')
toc.append('<nav>')
for section in sections:
section_id = self._create_section_id(section)
clean_section = section.strip() or "Uncategorized"
toc.append(f'<div><a href="#{section_id}">{self._sanitize_html(clean_section)} '
f'</a></div>')
toc.append('</nav>')
toc.append('</div>')
return '\n'.join(toc)
def format_sections(self) -> str:
"""格式化论文各部分内容"""
sections = {}
for fragment in self.fragments:
section = fragment.current_section or "Uncategorized"
if section not in sections:
sections[section] = []
sections[section].append(fragment)
formatted_html = ['<div class="content">']
formatted_html.append(self.format_table_of_contents(sections))
# 生成各部分内容
for section, fragments in sections.items():
section_id = self._create_section_id(section)
formatted_html.append(f'<h2 id="{section_id}">')
formatted_html.append(f'<span class="section-title">')
formatted_html.append(f'<span class="section-icon">§</span>')
formatted_html.append(f'{self._sanitize_html(section)}')
formatted_html.append('</span>')
formatted_html.append('</h2>')
formatted_html.append('<div class="section-content">')
for i, fragment in enumerate(fragments, 1):
formatted_html.append('<div class="fragment">')
# 添加内容
if fragment.content:
formatted_html.append(
f'<div class="fragment-content">{self._sanitize_html(fragment.content)}</div>'
)
# 添加参考文献
if fragment.bibliography:
formatted_html.append('<div class="bibliography">')
formatted_html.append(f'{self._sanitize_html(fragment.bibliography)}')
formatted_html.append('</div>')
formatted_html.append('</div>')
formatted_html.append('</div>')
formatted_html.append('</div>')
return '\n'.join(formatted_html)
def save_html(self) -> Path:
"""保存HTML文档"""
try:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"paper_content_{timestamp}.html"
file_path = self.output_dir / filename
html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>{self._sanitize_html(self.fragments[0].title if self.fragments else 'Paper Content')}</title>
<style>
{self.css_styles}
</style>
</head>
<body>
<div class="container">
{self.format_paper_info()}
{self.format_sections()}
</div>
</body>
</html>
"""
with open(file_path, "w", encoding="utf-8") as f:
f.write(html_content)
print(f"HTML document saved to: {file_path}")
return file_path
except Exception as e:
print(f"Error saving HTML document: {str(e)}")
raise
# Usage example:
# formatter = PaperHtmlFormatter(fragments, output_dir)
# output_path = formatter.save_html()

View File

@@ -300,8 +300,7 @@ def Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin
write_html(pfg.sp_file_contents, pfg.sp_file_result, chatbot=chatbot, project_folder=project_folder)
# <-------- 写出文件 ---------->
model_name = llm_kwargs['llm_model'].replace('_', '\\_') # 替换LLM模型名称中的下划线为转义字符
msg = f"当前大语言模型: {model_name},当前语言模型温度设定: {llm_kwargs['temperature']}"
msg = f"当前大语言模型: {llm_kwargs['llm_model']},当前语言模型温度设定: {llm_kwargs['temperature']}"
final_tex = lps.merge_result(pfg.file_result, mode, msg)
objdump((lps, pfg.file_result, mode, msg), file=pj(project_folder,'merge_result.pkl'))
@@ -352,41 +351,6 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
chatbot.append([f"正在编译PDF文档", f'编译已经开始。当前工作路径为{work_folder}如果程序停顿5分钟以上请直接去该路径下取回翻译结果或者重启之后再度尝试 ...']); yield from update_ui(chatbot=chatbot, history=history)
chatbot.append([f"正在编译PDF文档", '...']); yield from update_ui(chatbot=chatbot, history=history); time.sleep(1); chatbot[-1] = list(chatbot[-1]) # 刷新界面
yield from update_ui_lastest_msg('编译已经开始...', chatbot, history) # 刷新Gradio前端界面
# 检查是否需要使用xelatex
def check_if_need_xelatex(tex_path):
try:
with open(tex_path, 'r', encoding='utf-8', errors='replace') as f:
content = f.read(5000)
# 检查是否有使用xelatex的宏包
need_xelatex = any(
pkg in content
for pkg in ['fontspec', 'xeCJK', 'xetex', 'unicode-math', 'xltxtra', 'xunicode']
)
if need_xelatex:
logger.info(f"检测到宏包需要xelatex编译, 切换至xelatex编译")
else:
logger.info(f"未检测到宏包需要xelatex编译, 使用pdflatex编译")
return need_xelatex
except Exception:
return False
# 根据编译器类型返回编译命令
def get_compile_command(compiler, filename):
compile_command = f'{compiler} -interaction=batchmode -file-line-error {filename}.tex'
logger.info('Latex 编译指令: ' + compile_command)
return compile_command
# 确定使用的编译器
compiler = 'pdflatex'
if check_if_need_xelatex(pj(work_folder_modified, f'{main_file_modified}.tex')):
logger.info("检测到宏包需要xelatex编译切换至xelatex编译")
# Check if xelatex is installed
try:
import subprocess
subprocess.run(['xelatex', '--version'], capture_output=True, check=True)
compiler = 'xelatex'
except (subprocess.CalledProcessError, FileNotFoundError):
raise RuntimeError("检测到需要使用xelatex编译但系统中未安装xelatex。请先安装texlive或其他提供xelatex的LaTeX发行版。")
while True:
import os
@@ -397,10 +361,10 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
# https://stackoverflow.com/questions/738755/dont-make-me-manually-abort-a-latex-compile-when-theres-an-error
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译原始PDF ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_original), work_folder_original)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_original}.tex', work_folder_original)
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译转化后的PDF ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_modified), work_folder_modified)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_modified}.tex', work_folder_modified)
if ok and os.path.exists(pj(work_folder_modified, f'{main_file_modified}.pdf')):
# 只有第二步成功,才能继续下面的步骤
@@ -411,10 +375,10 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
ok = compile_latex_with_timeout(f'bibtex {main_file_modified}.aux', work_folder_modified)
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译文献交叉引用 ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_original), work_folder_original)
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_modified), work_folder_modified)
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_original), work_folder_original)
ok = compile_latex_with_timeout(get_compile_command(compiler, main_file_modified), work_folder_modified)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_original}.tex', work_folder_original)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_modified}.tex', work_folder_modified)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_original}.tex', work_folder_original)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_modified}.tex', work_folder_modified)
if mode!='translate_zh':
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 使用latexdiff生成论文转化前后对比 ...', chatbot, history) # 刷新Gradio前端界面
@@ -422,10 +386,10 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
ok = compile_latex_with_timeout(f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex', os.getcwd())
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 正在编译对比PDF ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(get_compile_command(compiler, 'merge_diff'), work_folder)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error merge_diff.tex', work_folder)
ok = compile_latex_with_timeout(f'bibtex merge_diff.aux', work_folder)
ok = compile_latex_with_timeout(get_compile_command(compiler, 'merge_diff'), work_folder)
ok = compile_latex_with_timeout(get_compile_command(compiler, 'merge_diff'), work_folder)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error merge_diff.tex', work_folder)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error merge_diff.tex', work_folder)
# <---------- 检查结果 ----------->
results_ = ""

View File

@@ -1,43 +0,0 @@
from toolbox import update_ui, get_conf, promote_file_to_downloadzone, update_ui_lastest_msg, generate_file_link
from shared_utils.docker_as_service_api import stream_daas
from shared_utils.docker_as_service_api import DockerServiceApiComModel
import random
def download_video(video_id, only_audio, user_name, chatbot, history):
from toolbox import get_log_folder
chatbot.append([None, "Processing..."])
yield from update_ui(chatbot, history)
client_command = f'{video_id} --audio-only' if only_audio else video_id
server_urls = get_conf('DAAS_SERVER_URLS')
server_url = random.choice(server_urls)
docker_service_api_com_model = DockerServiceApiComModel(client_command=client_command)
save_file_dir = get_log_folder(user_name, plugin_name='media_downloader')
for output_manifest in stream_daas(docker_service_api_com_model, server_url, save_file_dir):
status_buf = ""
status_buf += "DaaS message: \n\n"
status_buf += output_manifest['server_message'].replace('\n', '<br/>')
status_buf += "\n\n"
status_buf += "DaaS standard error: \n\n"
status_buf += output_manifest['server_std_err'].replace('\n', '<br/>')
status_buf += "\n\n"
status_buf += "DaaS standard output: \n\n"
status_buf += output_manifest['server_std_out'].replace('\n', '<br/>')
status_buf += "\n\n"
status_buf += "DaaS file attach: \n\n"
status_buf += str(output_manifest['server_file_attach'])
yield from update_ui_lastest_msg(status_buf, chatbot, history)
return output_manifest['server_file_attach']
def search_videos(keywords):
from toolbox import get_log_folder
client_command = keywords
server_urls = get_conf('DAAS_SERVER_URLS')
server_url = random.choice(server_urls)
server_url = server_url.replace('stream', 'search')
docker_service_api_com_model = DockerServiceApiComModel(client_command=client_command)
save_file_dir = get_log_folder("default_user", plugin_name='media_downloader')
for output_manifest in stream_daas(docker_service_api_com_model, server_url, save_file_dir):
return output_manifest['server_message']

View File

@@ -6,128 +6,75 @@ from crazy_functions.crazy_utils import get_files_from_everything
from shared_utils.colorful import *
from loguru import logger
import os
import requests
import time
def refresh_key(doc2x_api_key):
import requests, json
url = "https://api.doc2x.noedgeai.com/api/token/refresh"
res = requests.post(
url,
headers={"Authorization": "Bearer " + doc2x_api_key}
)
res_json = []
if res.status_code == 200:
decoded = res.content.decode("utf-8")
res_json = json.loads(decoded)
doc2x_api_key = res_json['data']['token']
else:
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
return doc2x_api_key
def retry_request(max_retries=3, delay=3):
"""
Decorator for retrying HTTP requests
Args:
max_retries: Maximum number of retry attempts
delay: Delay between retries in seconds
"""
def decorator(func):
def wrapper(*args, **kwargs):
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except Exception as e:
if attempt < max_retries - 1:
logger.error(
f"Request failed, retrying... ({attempt + 1}/{max_retries}) Error: {e}"
)
time.sleep(delay)
continue
raise e
return None
return wrapper
return decorator
@retry_request()
def make_request(method, url, **kwargs):
"""
Make HTTP request with retry mechanism
"""
return requests.request(method, url, **kwargs)
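# Effect of the decorator, sketched: a transient failure inside make_request is
# retried up to max_retries times with `delay` seconds between attempts before the
# last exception is re-raised, e.g.
# res = make_request("GET", "https://v2.doc2x.noedgeai.com/api/v2/parse/status", timeout=15)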
def doc2x_api_response_status(response, uid=""):
"""
Check the status of Doc2x API response
Args:
response: Response object from Doc2x API
"""
response_json = response.json()
response_data = response_json.get("data", {})
code = response_json.get("code", "Unknown")
msg = response_data.get("message", response_json)
trace_id = response.headers.get("trace-id", "Failed to get trace-id")
if response.status_code != 200:
raise RuntimeError(
f"Doc2x return an error:\nTrace ID: {trace_id} {uid}\n{response.status_code} - {response_json}"
)
if code in ["parse_page_limit_exceeded", "parse_concurrency_limit"]:
raise RuntimeError(
f"Reached the limit of Doc2x:\nTrace ID: {trace_id} {uid}\n{code} - {meg}"
)
if code not in ["ok", "success"]:
raise RuntimeError(
f"Doc2x return an error:\nTrace ID: {trace_id} {uid}\n{code} - {meg}"
)
return response_data
def 解析PDF_DOC2X_转Latex(pdf_file_path):
zip_file_path, unzipped_folder = 解析PDF_DOC2X(pdf_file_path, format="tex")
zip_file_path, unzipped_folder = 解析PDF_DOC2X(pdf_file_path, format='tex')
return unzipped_folder
def 解析PDF_DOC2X(pdf_file_path, format="tex"):
def 解析PDF_DOC2X(pdf_file_path, format='tex'):
"""
format: 'tex', 'md', 'docx'
format: 'tex', 'md', 'docx'
"""
DOC2X_API_KEY = get_conf("DOC2X_API_KEY")
import requests, json, os
DOC2X_API_KEY = get_conf('DOC2X_API_KEY')
latex_dir = get_log_folder(plugin_name="pdf_ocr_latex")
markdown_dir = get_log_folder(plugin_name="pdf_ocr")
doc2x_api_key = DOC2X_API_KEY
# < ------ 第1步预上传获取URL然后上传文件 ------ >
logger.info("Doc2x 上传文件预上传获取URL")
res = make_request(
"POST",
"https://v2.doc2x.noedgeai.com/api/v2/parse/preupload",
headers={"Authorization": "Bearer " + doc2x_api_key},
timeout=15,
)
res_data = doc2x_api_response_status(res)
upload_url = res_data["url"]
uuid = res_data["uid"]
logger.info("Doc2x 上传文件:上传文件")
with open(pdf_file_path, "rb") as file:
res = make_request("PUT", upload_url, data=file, timeout=60)
res.raise_for_status()
# < ------ 第1步上传 ------ >
logger.info("Doc2x 第1步上传")
with open(pdf_file_path, 'rb') as file:
res = requests.post(
"https://v2.doc2x.noedgeai.com/api/v2/parse/pdf",
headers={"Authorization": "Bearer " + doc2x_api_key},
data=file
)
# res_json = []
if res.status_code == 200:
res_json = res.json()
else:
raise RuntimeError(f"Doc2x return an error: {res.json()}")
uuid = res_json['data']['uid']
# < ------ 第2步轮询等待 ------ >
logger.info("Doc2x 处理文件中:轮询等待")
params = {"uid": uuid}
max_attempts = 60
attempt = 0
while attempt < max_attempts:
res = make_request(
"GET",
"https://v2.doc2x.noedgeai.com/api/v2/parse/status",
logger.info("Doc2x 第2步:轮询等待")
params = {'uid': uuid}
while True:
res = requests.get(
'https://v2.doc2x.noedgeai.com/api/v2/parse/status',
headers={"Authorization": "Bearer " + doc2x_api_key},
params=params,
timeout=15,
params=params
)
res_data = doc2x_api_response_status(res)
if res_data["status"] == "success":
res_json = res.json()
if res_json['data']['status'] == "success":
break
elif res_data["status"] == "processing":
time.sleep(5)
logger.info(f"Doc2x is processing at {res_data['progress']}%")
attempt += 1
else:
raise RuntimeError(f"Doc2x return an error: {res_data}")
if attempt >= max_attempts:
raise RuntimeError("Doc2x processing timeout after maximum attempts")
elif res_json['data']['status'] == "processing":
time.sleep(3)
logger.info(f"Doc2x is processing at {res_json['data']['progress']}%")
elif res_json['data']['status'] == "failed":
raise RuntimeError(f"Doc2x return an error: {res_json}")
# < ------ 第3步提交转化 ------ >
logger.info("Doc2x 第3步提交转化")
@@ -137,44 +84,42 @@ def 解析PDF_DOC2X(pdf_file_path, format="tex"):
"formula_mode": "dollar",
"filename": "output"
}
res = make_request(
"POST",
"https://v2.doc2x.noedgeai.com/api/v2/convert/parse",
res = requests.post(
'https://v2.doc2x.noedgeai.com/api/v2/convert/parse',
headers={"Authorization": "Bearer " + doc2x_api_key},
json=data,
timeout=15,
json=data
)
doc2x_api_response_status(res, uid=f"uid: {uuid}")
if res.status_code == 200:
res_json = res.json()
else:
raise RuntimeError(f"Doc2x return an error: {res.json()}")
# < ------ 第4步等待结果 ------ >
logger.info("Doc2x 第4步等待结果")
params = {"uid": uuid}
max_attempts = 36
attempt = 0
while attempt < max_attempts:
res = make_request(
"GET",
"https://v2.doc2x.noedgeai.com/api/v2/convert/parse/result",
params = {'uid': uuid}
while True:
res = requests.get(
'https://v2.doc2x.noedgeai.com/api/v2/convert/parse/result',
headers={"Authorization": "Bearer " + doc2x_api_key},
params=params,
timeout=15,
params=params
)
res_data = doc2x_api_response_status(res, uid=f"uid: {uuid}")
if res_data["status"] == "success":
res_json = res.json()
if res_json['data']['status'] == "success":
break
elif res_data["status"] == "processing":
elif res_json['data']['status'] == "processing":
time.sleep(3)
logger.info("Doc2x still processing to convert file")
attempt += 1
if attempt >= max_attempts:
raise RuntimeError("Doc2x conversion timeout after maximum attempts")
logger.info(f"Doc2x still processing")
elif res_json['data']['status'] == "failed":
raise RuntimeError(f"Doc2x return an error: {res_json}")
# < ------ 第5步最后的处理 ------ >
logger.info("Doc2x 第5步下载转换后的文件")
logger.info("Doc2x 第5步最后的处理")
if format == "tex":
if format=='tex':
target_path = latex_dir
if format == "md":
if format=='md':
target_path = markdown_dir
os.makedirs(target_path, exist_ok=True)
@@ -182,18 +127,17 @@ def 解析PDF_DOC2X(pdf_file_path, format="tex"):
# < ------ 下载 ------ >
for attempt in range(max_attempt):
try:
result_url = res_data["url"]
res = make_request("GET", result_url, timeout=60)
zip_path = os.path.join(target_path, gen_time_str() + ".zip")
result_url = res_json['data']['url']
res = requests.get(result_url)
zip_path = os.path.join(target_path, gen_time_str() + '.zip')
unzip_path = os.path.join(target_path, gen_time_str())
if res.status_code == 200:
with open(zip_path, "wb") as f:
f.write(res.content)
with open(zip_path, "wb") as f: f.write(res.content)
else:
raise RuntimeError(f"Doc2x return an error: {res.json()}")
except Exception as e:
if attempt < max_attempt - 1:
logger.error(f"Failed to download uid = {uuid} file, retrying... {e}")
logger.error(f"Failed to download latex file, retrying... {e}")
time.sleep(3)
continue
else:
@@ -201,31 +145,22 @@ def 解析PDF_DOC2X(pdf_file_path, format="tex"):
# < ------ 解压 ------ >
import zipfile
with zipfile.ZipFile(zip_path, "r") as zip_ref:
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(unzip_path)
return zip_path, unzip_path
def 解析PDF_DOC2X_单文件(
fp,
project_folder,
llm_kwargs,
plugin_kwargs,
chatbot,
history,
system_prompt,
DOC2X_API_KEY,
user_request,
):
def 解析PDF_DOC2X_单文件(fp, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, DOC2X_API_KEY, user_request):
def pdf2markdown(filepath):
chatbot.append((None, f"Doc2x 解析中"))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
md_zip_path, unzipped_folder = 解析PDF_DOC2X(filepath, format="md")
md_zip_path, unzipped_folder = 解析PDF_DOC2X(filepath, format='md')
promote_file_to_downloadzone(md_zip_path, chatbot=chatbot)
chatbot.append((None, f"完成解析 {md_zip_path} ..."))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return md_zip_path
def deliver_to_markdown_plugin(md_zip_path, user_request):
@@ -239,95 +174,77 @@ def 解析PDF_DOC2X_单文件(
os.makedirs(target_path_base, exist_ok=True)
shutil.copyfile(md_zip_path, this_file_path)
ex_folder = this_file_path + ".extract"
extract_archive(file_path=this_file_path, dest_dir=ex_folder)
extract_archive(
file_path=this_file_path, dest_dir=ex_folder
)
# edit markdown files
success, file_manifest, project_folder = get_files_from_everything(ex_folder, type='.md', chatbot=chatbot)
success, file_manifest, project_folder = get_files_from_everything(ex_folder, type='.md')
for generated_fp in file_manifest:
# 修正一些公式问题
with open(generated_fp, "r", encoding="utf8") as f:
with open(generated_fp, 'r', encoding='utf8') as f:
content = f.read()
# 将公式中的\[ \]替换成$$
content = content.replace(r"\[", r"$$").replace(r"\]", r"$$")
content = content.replace(r'\[', r'$$').replace(r'\]', r'$$')
# 将公式中的\( \)替换成$
content = content.replace(r"\(", r"$").replace(r"\)", r"$")
content = content.replace("```markdown", "\n").replace("```", "\n")
with open(generated_fp, "w", encoding="utf8") as f:
content = content.replace(r'\(', r'$').replace(r'\)', r'$')
content = content.replace('```markdown', '\n').replace('```', '\n')
with open(generated_fp, 'w', encoding='utf8') as f:
f.write(content)
promote_file_to_downloadzone(generated_fp, chatbot=chatbot)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# 生成在线预览html
file_name = "在线预览翻译(原文)" + gen_time_str() + ".html"
file_name = '在线预览翻译(原文)' + gen_time_str() + '.html'
preview_fp = os.path.join(ex_folder, file_name)
from shared_utils.advanced_markdown_format import (
markdown_convertion_for_file,
)
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
with open(generated_fp, "r", encoding="utf-8") as f:
md = f.read()
# # Markdown中使用不标准的表格需要在表格前加上一个emoji以便公式渲染
# md = re.sub(r'^<table>', r'.<table>', md, flags=re.MULTILINE)
html = markdown_convertion_for_file(md)
with open(preview_fp, "w", encoding="utf-8") as f:
f.write(html)
with open(preview_fp, "w", encoding="utf-8") as f: f.write(html)
chatbot.append([None, f"生成在线预览:{generate_file_link([preview_fp])}"])
promote_file_to_downloadzone(preview_fp, chatbot=chatbot)
chatbot.append((None, f"调用Markdown插件 {ex_folder} ..."))
plugin_kwargs["markdown_expected_output_dir"] = ex_folder
translated_f_name = "translated_markdown.md"
generated_fp = plugin_kwargs["markdown_expected_output_path"] = os.path.join(
ex_folder, translated_f_name
)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
yield from Markdown英译中(
ex_folder,
llm_kwargs,
plugin_kwargs,
chatbot,
history,
system_prompt,
user_request,
)
chatbot.append((None, f"调用Markdown插件 {ex_folder} ..."))
plugin_kwargs['markdown_expected_output_dir'] = ex_folder
translated_f_name = 'translated_markdown.md'
generated_fp = plugin_kwargs['markdown_expected_output_path'] = os.path.join(ex_folder, translated_f_name)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
yield from Markdown英译中(ex_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
if os.path.exists(generated_fp):
# 修正一些公式问题
with open(generated_fp, "r", encoding="utf8") as f:
content = f.read()
content = content.replace("```markdown", "\n").replace("```", "\n")
with open(generated_fp, 'r', encoding='utf8') as f: content = f.read()
content = content.replace('```markdown', '\n').replace('```', '\n')
# Markdown中使用不标准的表格需要在表格前加上一个emoji以便公式渲染
# content = re.sub(r'^<table>', r'.<table>', content, flags=re.MULTILINE)
with open(generated_fp, "w", encoding="utf8") as f:
f.write(content)
with open(generated_fp, 'w', encoding='utf8') as f: f.write(content)
# 生成在线预览html
file_name = "在线预览翻译" + gen_time_str() + ".html"
file_name = '在线预览翻译' + gen_time_str() + '.html'
preview_fp = os.path.join(ex_folder, file_name)
from shared_utils.advanced_markdown_format import (
markdown_convertion_for_file,
)
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
with open(generated_fp, "r", encoding="utf-8") as f:
md = f.read()
html = markdown_convertion_for_file(md)
with open(preview_fp, "w", encoding="utf-8") as f:
f.write(html)
with open(preview_fp, "w", encoding="utf-8") as f: f.write(html)
promote_file_to_downloadzone(preview_fp, chatbot=chatbot)
# 生成包含图片的压缩包
dest_folder = get_log_folder(chatbot.get_user())
zip_name = "翻译后的带图文档.zip"
zip_folder(
source_folder=ex_folder, dest_folder=dest_folder, zip_name=zip_name
)
zip_name = '翻译后的带图文档.zip'
zip_folder(source_folder=ex_folder, dest_folder=dest_folder, zip_name=zip_name)
zip_fp = os.path.join(dest_folder, zip_name)
promote_file_to_downloadzone(zip_fp, chatbot=chatbot)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
md_zip_path = yield from pdf2markdown(fp)
yield from deliver_to_markdown_plugin(md_zip_path, user_request)
def 解析PDF_基于DOC2X(file_manifest, *args):
for index, fp in enumerate(file_manifest):
yield from 解析PDF_DOC2X_单文件(fp, *args)
return

View File

@@ -27,10 +27,10 @@ def extract_text_from_files(txt, chatbot, history):
return False, final_result, page_one, file_manifest, excption #如输入区内容不是文件则直接返回输入区内容
#查找输入区内容中的文件
file_pdf,pdf_manifest,folder_pdf = get_files_from_everything(txt, '.pdf', chatbot=chatbot)
file_md,md_manifest,folder_md = get_files_from_everything(txt, '.md', chatbot=chatbot)
file_word,word_manifest,folder_word = get_files_from_everything(txt, '.docx', chatbot=chatbot)
file_doc,doc_manifest,folder_doc = get_files_from_everything(txt, '.doc', chatbot=chatbot)
file_pdf,pdf_manifest,folder_pdf = get_files_from_everything(txt, '.pdf')
file_md,md_manifest,folder_md = get_files_from_everything(txt, '.md')
file_word,word_manifest,folder_word = get_files_from_everything(txt, '.docx')
file_doc,doc_manifest,folder_doc = get_files_from_everything(txt, '.doc')
if file_doc:
excption = "word"

View File

@@ -0,0 +1,115 @@
import logging
import tarfile
from pathlib import Path
from typing import Optional, Dict
import requests
class ArxivDownloader:
"""用于下载arXiv论文源码的下载器"""
def __init__(self, root_dir: str = "./papers", proxies: Optional[Dict[str, str]] = None):
"""
Initialize the downloader.
Args:
    root_dir: root directory for downloaded files
    proxies: proxy settings, e.g. {"http": "http://proxy:port", "https": "https://proxy:port"}
"""
self.root_dir = Path(root_dir)
self.root_dir.mkdir(exist_ok=True)
self.proxies = proxies
# logging configuration
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def _download_and_extract(self, arxiv_id: str) -> str:
"""
Download and extract the source of an arXiv paper.
Args:
    arxiv_id: arXiv paper ID, e.g. "2103.00020"
Returns:
    str: path to the extracted directory
Raises:
    RuntimeError: raised when the download fails
"""
paper_dir = self.root_dir / arxiv_id
tar_path = paper_dir / f"{arxiv_id}.tar.gz"
# check the cache
if paper_dir.exists() and any(paper_dir.iterdir()):
logging.info(f"Using cached version for {arxiv_id}")
return str(paper_dir)
paper_dir.mkdir(exist_ok=True)
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
for url in urls:
try:
logging.info(f"Downloading from {url}")
response = requests.get(url, proxies=self.proxies, timeout=60) # a timeout keeps a stalled download from hanging forever
if response.status_code == 200:
tar_path.write_bytes(response.content)
with tarfile.open(tar_path, 'r:gz') as tar:
tar.extractall(path=paper_dir)
return str(paper_dir)
except Exception as e:
logging.warning(f"Download failed for {url}: {e}")
continue
raise RuntimeError(f"Failed to download paper {arxiv_id}")
def download_paper(self, arxiv_id: str) -> str:
"""
Download the specified arXiv paper.
Args:
    arxiv_id: arXiv paper ID
Returns:
    str: directory containing the paper files
"""
return self._download_and_extract(arxiv_id)
def main():
"""测试下载功能"""
# 配置代理(如果需要)
proxies = {
"http": "http://your-proxy:port",
"https": "https://your-proxy:port"
}
# 创建下载器实例如果不需要代理可以不传入proxies参数
downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)
# 测试下载一篇论文这里使用一个示例ID
try:
paper_id = "2103.00020" # 这是一个示例ID
paper_dir = downloader.download_paper(paper_id)
print(f"Successfully downloaded paper to: {paper_dir}")
# inspect the downloaded files
paper_path = Path(paper_dir)
if paper_path.exists():
print("Downloaded files:")
for file in paper_path.rglob("*"):
if file.is_file():
print(f"- {file.relative_to(paper_path)}")
except Exception as e:
print(f"Error downloading paper: {e}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,836 @@
import asyncio
import logging
import re
import tarfile
import time
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Dict, Set
import aiohttp
from crazy_functions.rag_fns.arxiv_fns.author_extractor import LatexAuthorExtractor
from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructureParser, DocumentStructure, read_tex_file
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
from crazy_functions.doc_fns.content_folder import PaperContentFormatter, PaperMetadata
def save_fragments_to_file(fragments: List[SectionFragment], output_dir: Path) -> Path:
"""
Save all fragments to a single structured markdown file.
Args:
fragments: List of SectionFragment objects
output_dir: Output directory path
Returns:
Path: Path to the generated markdown file
"""
from datetime import datetime
# Create output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Generate filename
filename = f"paper_latex_content_{timestamp}.md"
file_path = output_path / filename
# Group fragments by section
sections = {}
for fragment in fragments:
section = fragment.current_section or "Uncategorized"
if section not in sections:
sections[section] = []
sections[section].append(fragment)
with open(file_path, "w", encoding="utf-8") as f:
# Write document header
f.write("# Document Fragments Analysis\n\n")
f.write("## Overview\n")
f.write(f"- Total Fragments: {len(fragments)}\n")
f.write(f"- Generated Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
# Add paper information if available
if fragments and (fragments[0].title or fragments[0].abstract):
f.write("\n## Paper Information\n")
if fragments[0].title:
f.write(f"### Title\n{fragments[0].title}\n")
if fragments[0].authors:
f.write(f"\n### Authors\n{fragments[0].authors}\n")
if fragments[0].abstract:
f.write(f"\n### Abstract\n{fragments[0].abstract}\n")
# Write section tree if available
if fragments and fragments[0].catalogs:
f.write("\n## Section Tree\n")
f.write("```\n") # 添加代码块开始标记
f.write(fragments[0].catalogs)
f.write("\n```") # 添加代码块结束标记
# Generate table of contents
f.write("\n## Table of Contents\n")
for section in sections:
clean_section = section.strip() or "Uncategorized"
fragment_count = len(sections[section])
f.write(f"- [{clean_section}](#{clean_section.lower().replace(' ', '-')}) "
f"({fragment_count} fragments)\n")
# Write content sections
f.write("\n## Content\n")
for section, section_fragments in sections.items():
clean_section = section.strip() or "Uncategorized"
f.write(f"\n### {clean_section}\n")
# Write each fragment
for i, fragment in enumerate(section_fragments, 1):
f.write(f"\n#### Fragment {i}\n")
# Metadata
f.write("**Metadata:**\n")
metadata = [
f"- Section: {fragment.current_section}",
f"- Length: {len(fragment.content)} chars",
f"- ArXiv ID: {fragment.arxiv_id}" if fragment.arxiv_id else None
]
f.write("\n".join(filter(None, metadata)) + "\n")
# Content
f.write("\n**Content:**\n")
f.write("\n")
f.write(fragment.content)
f.write("\n")
# Bibliography if exists
if fragment.bibliography:
f.write("\n**Bibliography:**\n")
f.write("```bibtex\n")
f.write(fragment.bibliography)
f.write("\n```\n")
# Add separator
if i < len(section_fragments):
f.write("\n---\n")
# Add statistics
f.write("\n## Statistics\n")
# Length distribution
lengths = [len(f.content) for f in fragments]
f.write("\n### Length Distribution\n")
f.write(f"- Minimum: {min(lengths)} chars\n")
f.write(f"- Maximum: {max(lengths)} chars\n")
f.write(f"- Average: {sum(lengths) / len(lengths):.1f} chars\n")
# Section distribution
f.write("\n### Section Distribution\n")
for section, section_fragments in sections.items():
percentage = (len(section_fragments) / len(fragments)) * 100
f.write(f"- {section}: {len(section_fragments)} ({percentage:.1f}%)\n")
print(f"Fragments saved to: {file_path}")
return file_path
# Patterns for the various citation commands
CITATION_PATTERNS = [
# Basic \cite{} form
r'\\cite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
# natbib forms
r'\\citep(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\citet(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\citeauthor(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\citeyear(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\citealt(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\citealp(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
# biblatex forms
r'\\textcite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\parencite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
r'\\autocite(?:\*)?(?:\[[^\]]*\])?{([^}]+)}',
# Custom [cite:...] form
r'\[cite:([^\]]+)\]',
]
# Compile all patterns once
COMPILED_PATTERNS = [re.compile(pattern) for pattern in CITATION_PATTERNS]
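# Illustrative example: for the string
# r"as shown in \citep[p. 7]{smith2020,jones2021} and [cite:lee2019]"
# the patterns above capture the key groups "smith2020,jones2021" and "lee2019",
# which the splitting below turns into three individual citation keys.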
class ArxivSplitter:
"""Arxiv论文智能分割器"""
def __init__(self,
root_dir: str = "gpt_log/arxiv_cache",
proxies: Optional[Dict[str, str]] = None,
cache_ttl: int = 7 * 24 * 60 * 60):
"""
初始化分割器
Args:
char_range: 字符数范围(最小值, 最大值)
root_dir: 缓存根目录
proxies: 代理设置
cache_ttl: 缓存过期时间(秒)
"""
self.root_dir = Path(root_dir)
self.root_dir.mkdir(parents=True, exist_ok=True)
self.proxies = proxies or {}
self.cache_ttl = cache_ttl
# Dynamically compute an optimal worker count
import multiprocessing
cpu_count = multiprocessing.cpu_count()
self.document_structure = DocumentStructure()
self.document_parser = EssayStructureParser()
# Scale with the CPU core count, capped to avoid excessive concurrency
self.max_workers = min(32, cpu_count * 2)
# Initialize the TeX processor
self.tex_processor = TexUtils()
# Configure logging
self._setup_logging()
def _setup_logging(self):
"""配置日志"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
def _normalize_arxiv_id(self, input_str: str) -> str:
"""规范化ArXiv ID"""
if 'arxiv.org/' in input_str.lower():
# 处理URL格式
if '/pdf/' in input_str:
arxiv_id = input_str.split('/pdf/')[-1]
else:
arxiv_id = input_str.split('/abs/')[-1]
# 移除版本号和其他后缀
return arxiv_id.split('v')[0].strip()
return input_str.split('v')[0].strip()
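# Examples: "https://arxiv.org/abs/2103.00020v2" -> "2103.00020",
# "https://arxiv.org/pdf/2103.00020" -> "2103.00020", "2103.00020" -> "2103.00020".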
def _check_cache(self, paper_dir: Path) -> bool:
"""
检查缓存是否有效,包括文件完整性检查
Args:
paper_dir: 论文目录路径
Returns:
bool: 如果缓存有效返回True否则返回False
"""
if not paper_dir.exists():
return False
# 检查目录中是否存在必要文件
has_tex_files = False
has_main_tex = False
for file_path in paper_dir.rglob("*"):
if file_path.suffix == '.tex':
has_tex_files = True
content = self.tex_processor.read_file(str(file_path))
if content and r'\documentclass' in content:
has_main_tex = True
break
if not (has_tex_files and has_main_tex):
return False
# Check the cache age
cache_time = paper_dir.stat().st_mtime
if (time.time() - cache_time) < self.cache_ttl:
self.logger.info(f"Using valid cache for {paper_dir.name}")
return True
return False
async def download_paper(self, arxiv_id: str, paper_dir: Path) -> bool:
"""
异步下载论文,包含重试机制和临时文件处理
Args:
arxiv_id: ArXiv论文ID
paper_dir: 目标目录路径
Returns:
bool: 下载成功返回True否则返回False
"""
from crazy_functions.rag_fns.arxiv_fns.arxiv_downloader import ArxivDownloader
temp_tar_path = paper_dir / f"{arxiv_id}_temp.tar.gz"
final_tar_path = paper_dir / f"{arxiv_id}.tar.gz"
# Make sure the directory exists
paper_dir.mkdir(parents=True, exist_ok=True)
# First try downloading with ArxivDownloader
try:
downloader = ArxivDownloader(root_dir=str(paper_dir), proxies=self.proxies)
downloaded_dir = downloader.download_paper(arxiv_id)
if downloaded_dir:
self.logger.info(f"Successfully downloaded using ArxivDownloader to {downloaded_dir}")
return True
except Exception as e:
self.logger.warning(f"ArxivDownloader failed: {str(e)}. Falling back to direct download.")
# If ArxivDownloader fails, fall back to direct download
urls = [
f"https://arxiv.org/src/{arxiv_id}",
f"https://arxiv.org/e-print/{arxiv_id}"
]
max_retries = 3
retry_delay = 1 # initial retry delay in seconds
for url in urls:
for attempt in range(max_retries):
try:
self.logger.info(f"Downloading from {url} (attempt {attempt + 1}/{max_retries})")
async with aiohttp.ClientSession() as session:
async with session.get(url, proxy=self.proxies.get('http')) as response:
if response.status == 200:
content = await response.read()
# Write to a temporary file
temp_tar_path.write_bytes(content)
try:
# Verify the tar file integrity and extract it
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._process_tar_file, temp_tar_path, paper_dir)
# On success, move the temp file to its final location
temp_tar_path.rename(final_tar_path)
return True
except Exception as e:
self.logger.warning(f"Invalid tar file: {str(e)}")
if temp_tar_path.exists():
temp_tar_path.unlink()
except Exception as e:
self.logger.warning(f"Download attempt {attempt + 1} failed from {url}: {str(e)}")
await asyncio.sleep(retry_delay * (attempt + 1)) # back off with an increasing delay
continue
return False
def _process_tar_file(self, tar_path: Path, extract_path: Path):
"""处理tar文件的同步操作"""
with tarfile.open(tar_path, 'r:gz') as tar:
tar.testall() # 验证文件完整性
tar.extractall(path=extract_path) # 解压文件
def process_references(self, doc_structure: DocumentStructure, ref_bib: str) -> DocumentStructure:
"""
Process citations in document structure and add referenced literature for each section
Args:
doc_structure: DocumentStructure object
ref_bib: String containing references separated by newlines
Returns:
Updated DocumentStructure object
"""
try:
# Create a copy to avoid modifying the original
doc = deepcopy(doc_structure)
# Parse references into a mapping
ref_map = self._parse_references(ref_bib)
if not ref_map:
self.logger.warning("No valid references found in ref_bib")
return doc
# Process all sections recursively
self._process_section_references(doc.toc, ref_map)
return doc
except Exception as e:
self.logger.error(f"Error processing references: {str(e)}")
return doc_structure # Return original if processing fails
def _process_section_references(self, sections: List[Section], ref_map: Dict[str, str]) -> None:
"""
Recursively process sections to add references
Args:
sections: List of Section objects
ref_map: Mapping of citation keys to full references
"""
for section in sections:
if section.content:
# Find citations in current section
cited_refs = self.find_citations(section.content)
if cited_refs:
# Get full references for citations
full_refs = []
for ref_key in cited_refs:
ref_text = ref_map.get(ref_key)
if ref_text:
full_refs.append(ref_text)
else:
self.logger.warning(f"Reference not found for citation key: {ref_key}")
# Add references to section content
if full_refs:
section.bibliography = "\n\n".join(full_refs)
# Process subsections recursively
if section.subsections:
self._process_section_references(section.subsections, ref_map)
def _parse_references(self, ref_bib: str) -> Dict[str, str]:
"""
Parse reference string into a mapping of citation keys to full references
Args:
ref_bib: Reference string with references separated by newlines
Returns:
Dict mapping citation keys to full reference text
"""
ref_map = {}
current_ref = []
current_key = None
try:
for line in ref_bib.split('\n'):
line = line.strip()
if not line:
continue
# New reference entry
if line.startswith('@'):
# Save previous reference if exists
if current_key and current_ref:
ref_map[current_key] = '\n'.join(current_ref)
current_ref = []
# Extract key from new reference
key_match = re.search(r'{(.*?),', line)
if key_match:
current_key = key_match.group(1)
current_ref.append(line)
else:
if current_ref is not None:
current_ref.append(line)
# Save last reference
if current_key and current_ref:
ref_map[current_key] = '\n'.join(current_ref)
except Exception as e:
self.logger.error(f"Error parsing references: {str(e)}")
return ref_map
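# Example: given the two entries "@article{smith2020, title={...}}" and
# "@book{jones2021, ...}", ref_map maps "smith2020" and "jones2021" to the
# full text of their respective entries.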
# The citation regexes are compiled once at module level for efficiency
@staticmethod
def _clean_citation_key(key: str) -> str:
"""Clean individual citation key."""
return key.strip().strip(',').strip()
def _extract_keys_from_group(self, keys_str: str) -> Set[str]:
"""Extract and clean individual citation keys from a group."""
try:
# Split grouped keys (comma- or semicolon-separated)
separators = '[,;]'
keys = re.split(separators, keys_str)
# Clean the keys and drop empty ones
return {self._clean_citation_key(k) for k in keys if self._clean_citation_key(k)}
except Exception as e:
self.logger.warning(f"Error processing citation group '{keys_str}': {e}")
return set()
def find_citations(self, content: str) -> Set[str]:
"""
Find citation keys in text content in various formats.
Args:
content: Text content to search for citations
Returns:
Set of unique citation keys
Examples:
Supported formats include:
- \cite{key1,key2}
- \cite[p. 1]{key}
- \citep{key}
- \citet{key}
- [cite:key1, key2]
- And many other variants
"""
citations = set()
if not content:
return citations
try:
# Search with every compiled pattern
for pattern in COMPILED_PATTERNS:
matches = pattern.finditer(content)
for match in matches:
# Grab the citation keys from the capture group
keys_str = match.group(1)
if keys_str:
# Extract and add all citation keys
new_keys = self._extract_keys_from_group(keys_str)
citations.update(new_keys)
except Exception as e:
self.logger.error(f"Error finding citations: {str(e)}")
# Drop obviously invalid keys
citations = {key for key in citations
if key and not key.startswith(('\\', '{', '}', '[', ']'))}
return citations
def get_citation_contexts(self, content: str, context_chars: int = 100) -> dict:
"""
Find citations and their surrounding context.
Args:
content: Text content to search for citations
context_chars: Number of characters of context to include before/after
Returns:
Dict mapping citation keys to lists of context strings
"""
contexts = {}
if not content:
return contexts
try:
for pattern in COMPILED_PATTERNS:
matches = pattern.finditer(content)
for match in matches:
# Compute the context window around the match
start = max(0, match.start() - context_chars)
end = min(len(content), match.end() + context_chars)
# Slice out the context
context = content[start:end]
# Extract and process the citation keys
keys_str = match.group(1)
keys = self._extract_keys_from_group(keys_str)
# Record the context for each key
for key in keys:
if key not in contexts:
contexts[key] = []
contexts[key].append(context)
except Exception as e:
self.logger.error(f"Error finding citation contexts: {str(e)}")
return contexts
async def process(self, arxiv_id_or_url: str) -> List[SectionFragment]:
"""
Process ArXiv paper and convert to list of SectionFragments.
Each fragment represents the smallest section unit.
Args:
arxiv_id_or_url: ArXiv paper ID or URL
Returns:
List[SectionFragment]: List of processed paper fragments
"""
try:
arxiv_id = self._normalize_arxiv_id(arxiv_id_or_url)
paper_dir = self.root_dir / arxiv_id
# Check if paper directory exists, if not, try to download
if not paper_dir.exists():
self.logger.info(f"Downloading paper {arxiv_id}")
await self.download_paper(arxiv_id, paper_dir)
# Find main TeX file
main_tex = self.tex_processor.find_main_tex_file(str(paper_dir))
if not main_tex:
raise RuntimeError(f"No main TeX file found in {paper_dir}")
# Read the main TeX file content
main_tex_content = read_tex_file(main_tex)
# Get all related TeX files and references
tex_files = self.tex_processor.resolve_includes(main_tex)
ref_bib = self.tex_processor.resolve_references(main_tex, paper_dir)
if not tex_files:
raise RuntimeError(f"No valid TeX files found for {arxiv_id}")
# Reset document structure for new processing
self.document_structure = DocumentStructure()
# Extract author information
author_extractor = LatexAuthorExtractor()
authors = author_extractor.extract_authors(main_tex_content)
self.document_structure.authors = authors # store on the document structure
# Process each TeX file
for file_path in tex_files:
self.logger.info(f"Processing TeX file: {file_path}")
tex_content = read_tex_file(file_path)
if tex_content:
additional_doc = self.document_parser.parse(tex_content)
self.document_structure = self.document_structure.merge(additional_doc)
# Process references if available
if ref_bib:
self.document_structure = self.process_references(self.document_structure, ref_bib)
self.logger.info("Successfully processed references")
else:
self.logger.info("No references found to process")
# Generate table of contents once
section_tree = self.document_structure.generate_toc_tree()
# Convert DocumentStructure to SectionFragments
fragments = self._convert_to_fragments(
doc_structure=self.document_structure,
arxiv_id=arxiv_id,
section_tree=section_tree
)
return fragments
except Exception as e:
self.logger.error(f"Failed to process {arxiv_id_or_url}: {str(e)}")
raise
def _convert_to_fragments(self,
doc_structure: DocumentStructure,
arxiv_id: str,
section_tree: str) -> List[SectionFragment]:
"""
Convert DocumentStructure to list of SectionFragments.
Creates a fragment for each leaf section in the document hierarchy.
Args:
doc_structure: Source DocumentStructure
arxiv_id: ArXiv paper ID
section_tree: Pre-generated table of contents tree
Returns:
List[SectionFragment]: List of paper fragments
"""
fragments = []
# Create a base template for all fragments to avoid repetitive assignments
base_fragment_template = {
'title': doc_structure.title,
'authors': doc_structure.authors,
'abstract': doc_structure.abstract,
'catalogs': section_tree,
'arxiv_id': arxiv_id
}
def get_leaf_sections(section: Section, path: List[str] = None) -> None:
"""
Recursively find all leaf sections and create fragments.
A leaf section is one that has content but no subsections, or has neither.
Args:
section: Current section being processed
path: List of section titles forming the path to current section
"""
if path is None:
path = []
current_path = path + [section.title]
if not section.subsections:
# This is a leaf section, create a fragment if it has content
if section.content or section.bibliography:
fragment = SectionFragment(
**base_fragment_template,
current_section="/".join(current_path),
content=self._clean_content(section.content),
bibliography=section.bibliography
)
if self._validate_fragment(fragment):
fragments.append(fragment)
else:
# Process each subsection
for subsection in section.subsections:
get_leaf_sections(subsection, current_path)
# Process all top-level sections
for section in doc_structure.toc:
get_leaf_sections(section)
# Add a fragment for the abstract if it exists
if doc_structure.abstract:
abstract_fragment = SectionFragment(
**base_fragment_template,
current_section="Abstract",
content=self._clean_content(doc_structure.abstract)
)
if self._validate_fragment(abstract_fragment):
fragments.insert(0, abstract_fragment)
self.logger.info(f"Created {len(fragments)} fragments")
return fragments
def _validate_fragment(self, fragment: SectionFragment) -> bool:
"""
Validate if the fragment has all required fields with meaningful content.
Args:
fragment: SectionFragment to validate
Returns:
bool: True if fragment is valid, False otherwise
"""
try:
return all([
fragment.title.strip(),
fragment.catalogs.strip(),
fragment.current_section.strip(),
fragment.content.strip() or fragment.bibliography.strip()
])
except AttributeError:
return False
def _clean_content(self, content: str) -> str:
"""
Clean and normalize content text.
Args:
content: Raw content text
Returns:
str: Cleaned content text
"""
if not content:
return ""
# Remove excessive whitespace
content = re.sub(r'\s+', ' ', content)
# Remove remaining LaTeX artifacts
content = re.sub(r'\\item\s*', '', content) # Strip \item markers
content = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', content) # Remove simple LaTeX commands
# Clean special characters
content = content.replace('\\\\', '\n') # Convert LaTeX newlines to actual newlines
content = re.sub(r'\s*\n\s*', '\n', content) # Clean up newlines
return content.strip()
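# Example: r"\textbf{Strong} result\\ with \item bullets" becomes
# "Strong result\nwith bullets" after command stripping and newline handling.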
def process_arxiv_sync(splitter: ArxivSplitter, arxiv_id: str) -> tuple[List[SectionFragment], str, List[Path]]:
"""
同步处理 ArXiv 文档并返回分割后的片段
Args:
splitter: ArxivSplitter 实例
arxiv_id: ArXiv 文档ID
Returns:
list: 分割后的文档片段列表
"""
try:
from crazy_functions.doc_fns.tex_html_formatter import PaperHtmlFormatter
# Wrap the async work in a coroutine
async def _process():
return await splitter.process(arxiv_id)
# Run the coroutine with asyncio.run()
output_files = []
fragments = asyncio.run(_process())
file_save_path = splitter.root_dir / "arxiv_fragments"
# Save the fragments to a markdown file (best effort)
try:
md_output_dir = save_fragments_to_file(
fragments,
output_dir=file_save_path
)
output_files.append(md_output_dir)
except Exception:
pass
# Create the paper formatter and prepare the metadata / formatting options
formatter = PaperContentFormatter()
metadata = PaperMetadata(
title=fragments[0].title,
authors=fragments[0].authors,
abstract=fragments[0].abstract,
catalogs=fragments[0].catalogs,
arxiv_id=fragments[0].arxiv_id
)
# Format the content
formatted_content = formatter.format(fragments, metadata)
# Render an HTML version as well (best effort)
try:
html_formatter = PaperHtmlFormatter(fragments, file_save_path)
html_output_dir = html_formatter.save_html()
output_files.append(html_output_dir)
except Exception:
pass
return fragments, formatted_content, output_files
except Exception as e:
print(f"✗ Processing failed for {arxiv_id}: {str(e)}")
raise
def test_arxiv_splitter():
"""测试ArXiv分割器的功能"""
# 测试配置
test_cases = [
{
"arxiv_id": "2411.03663",
"expected_title": "Large Language Models and Simple Scripts",
"min_fragments": 10,
},
# {
# "arxiv_id": "1805.10988",
# "expected_title": "RAG vs Fine-tuning",
# "min_fragments": 15,
# }
]
# Create a splitter instance
splitter = ArxivSplitter(
root_dir="private_upload/default_user"
)
for case in test_cases:
print(f"\nTesting paper: {case['arxiv_id']}")
try:
# fragments = await splitter.process(case['arxiv_id'])
fragments, formatted_content, output_dir = process_arxiv_sync(splitter, case['arxiv_id'])
# Inspect the fragments
for fragment in fragments:
# Length check
print(fragment.content)
print(len(fragment.content))
# Output files check
print(output_dir)
except Exception as e:
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
raise
if __name__ == "__main__":
test_arxiv_splitter()

View File

@@ -0,0 +1,177 @@
import re
from typing import Optional
class LatexAuthorExtractor:
def __init__(self):
# Patterns for matching author blocks with balanced braces
self.author_block_patterns = [
# Standard LaTeX patterns with optional arguments
r'\\author(?:\s*\[[^\]]*\])?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\(?:title)?author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\name[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\Author[s]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\AUTHOR[S]?\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
# Conference and journal specific patterns
r'\\addauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\IEEEauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\speaker\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\authorrunning\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
# Academic publisher specific patterns
r'\\alignauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\spauthor\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
r'\\authors\s*\{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*)\}',
]
# Cleaning patterns for LaTeX commands and formatting
self.cleaning_patterns = [
# Text formatting commands - preserve content
(r'\\textbf\{([^}]+)\}', r'\1'),
(r'\\textit\{([^}]+)\}', r'\1'),
(r'\\emph\{([^}]+)\}', r'\1'),
(r'\\texttt\{([^}]+)\}', r'\1'),
(r'\\textrm\{([^}]+)\}', r'\1'),
(r'\\text\{([^}]+)\}', r'\1'),
# Affiliation and footnote markers
(r'\$\^{[^}]+}\$', ''),
(r'\^{[^}]+}', ''),
(r'\\thanks\{[^}]+\}', ''),
(r'\\footnote\{[^}]+\}', ''),
# Email and contact formatting
(r'\\email\{([^}]+)\}', r'\1'),
(r'\\href\{[^}]+\}\{([^}]+)\}', r'\1'),
# Institution formatting
(r'\\inst\{[^}]+\}', ''),
(r'\\affil\{[^}]+\}', ''),
# Special characters and symbols
(r'\\and\b', ' and '), # keep the \and author separator as the word "and"
(r'\\&', '&'),
(r'\\\\\s*', ' '),
(r'\\,', ' '),
(r'\\;', ' '),
(r'\\quad', ' '),
(r'\\qquad', ' '),
# Math mode content
(r'\$[^$]+\$', ''),
# Common symbols
(r'\\dagger', ''),
(r'\\ddagger', ''),
(r'\\ast', '*'),
(r'\\star', ''),
# Remove remaining LaTeX commands
(r'\\[a-zA-Z]+', ''),
# Clean up remaining special characters
(r'[\\{}]', '')
]
def extract_author_block(self, text: str) -> Optional[str]:
"""
Extract the complete author block from LaTeX text.
Args:
text (str): Input LaTeX text
Returns:
Optional[str]: Extracted author block or None if not found
"""
try:
if not text:
return None
for pattern in self.author_block_patterns:
match = re.search(pattern, text, re.DOTALL | re.MULTILINE)
if match:
return match.group(1).strip()
return None
except (AttributeError, IndexError) as e:
print(f"Error extracting author block: {e}")
return None
def clean_tex_commands(self, text: str) -> str:
"""
Remove LaTeX commands and formatting from text while preserving content.
Args:
text (str): Text containing LaTeX commands
Returns:
str: Cleaned text with commands removed
"""
if not text:
return ""
cleaned_text = text
# Apply cleaning patterns
for pattern, replacement in self.cleaning_patterns:
cleaned_text = re.sub(pattern, replacement, cleaned_text)
# Clean up whitespace
cleaned_text = re.sub(r'\s+', ' ', cleaned_text)
cleaned_text = cleaned_text.strip()
return cleaned_text
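# Example (using the patterns above): clean_tex_commands(r"John Doe$^{1}$\thanks{Equal contribution} \and Jane Roe")
# yields "John Doe and Jane Roe", with the affiliation marker and footnote removed.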
def extract_authors(self, text: str) -> Optional[str]:
"""
Extract and clean author information from LaTeX text.
Args:
text (str): Input LaTeX text
Returns:
Optional[str]: Cleaned author information or None if extraction fails
"""
try:
if not text:
return None
# Extract author block
author_block = self.extract_author_block(text)
if not author_block:
return None
# Clean LaTeX commands
cleaned_authors = self.clean_tex_commands(author_block)
return cleaned_authors or None
except Exception as e:
print(f"Error processing text: {e}")
return None
def test_author_extractor():
"""Test the LatexAuthorExtractor with sample inputs."""
test_cases = [
# Basic test case
(r"\author{John Doe}", "John Doe"),
# Test with multiple authors
(r"\author{Alice Smith \and Bob Jones}", "Alice Smith and Bob Jones"),
# Test with affiliations
(r"\author[1]{John Smith}\affil[1]{University}", "John Smith"),
]
extractor = LatexAuthorExtractor()
for i, (input_tex, expected) in enumerate(test_cases, 1):
result = extractor.extract_authors(input_tex)
print(f"\nTest case {i}:")
print(f"Input: {input_tex[:50]}...")
print(f"Expected: {expected[:50]}...")
print(f"Got: {result[:50]}...")
print(f"Pass: {bool(result and result.strip() == expected.strip())}")
if __name__ == "__main__":
test_author_extractor()

View File

@@ -0,0 +1,290 @@
"""
LaTeX Document Parser
This module provides functionality for parsing and extracting structured information from LaTeX documents,
including metadata, document structure, and content. It uses modular design and clean architecture principles.
"""
import logging
import re
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass, field
from typing import List, Dict
from crazy_functions.rag_fns.arxiv_fns.latex_cleaner import clean_latex_commands
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section, EnhancedSectionExtractor
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def read_tex_file(file_path):
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
for encoding in encodings:
try:
with open(file_path, 'r', encoding=encoding) as f:
return f.read()
except UnicodeDecodeError:
continue
return None # none of the encodings could decode the file
@dataclass
class DocumentStructure:
title: str = ''
authors: str = ''
abstract: str = ''
toc: List[Section] = field(default_factory=list)
metadata: Dict[str, str] = field(default_factory=dict)
def merge(self, other: 'DocumentStructure', strategy: str = 'smart') -> 'DocumentStructure':
"""
Merge this document structure with another one.
Args:
other: Another DocumentStructure to merge with
strategy: Merge strategy - 'smart' (default) or 'append'
'smart' - Intelligently merge sections with same titles
'append' - Simply append sections from other document
"""
merged = deepcopy(self)
# Merge title if needed
if not merged.title and other.title:
merged.title = other.title
# Merge abstract
merged.abstract = self._merge_abstract(merged.abstract, other.abstract)
# Merge metadata
merged.metadata.update(other.metadata)
if strategy == 'append':
merged.toc.extend(deepcopy(other.toc))
else: # smart merge
# Create sections lookup for efficient merging
sections_map = {s.title: s for s in merged.toc}
for other_section in other.toc:
if other_section.title in sections_map:
# Merge existing section
idx = next(i for i, s in enumerate(merged.toc)
if s.title == other_section.title)
merged.toc[idx] = merged.toc[idx].merge(other_section)
else:
# Add new section
merged.toc.append(deepcopy(other_section))
return merged
@staticmethod
def _merge_abstract(abstract1: str, abstract2: str) -> str:
"""Merge abstracts intelligently."""
if not abstract1:
return abstract2
if not abstract2:
return abstract1
# Combine non-empty abstracts with a separator
return f"{abstract1}\n\n{abstract2}"
def generate_toc_tree(self, indent_char: str = " ", abstract_preview_length: int = 0) -> str:
"""
Generate a tree-like string representation of the table of contents including abstract.
Args:
indent_char: Character(s) used for indentation. Default is two spaces.
abstract_preview_length: Maximum length of the abstract preview. Default is 0 (no preview text shown).
Returns:
str: A formatted string showing the hierarchical document structure with abstract
"""
def _format_section(section: Section, level: int = 0) -> str:
# Create the current section line with proper indentation
current_line = f"{indent_char * level}{'' if level > 0 else ''} {section.title}\n"
# Recursively process subsections
subsections = ""
if section.subsections:
subsections = "".join(_format_section(subsec, level + 1)
for subsec in section.subsections)
return current_line + subsections
result = []
# Add document title if it exists
if self.title:
result.append(f"{self.title}\n")
# Add abstract if it exists
if self.abstract:
result.append("\n□ Abstract:")
# Format abstract content with word wrap
abstract_preview = self.abstract[:abstract_preview_length]
if len(self.abstract) > abstract_preview_length:
abstract_preview += "..."
# Split abstract into lines and indent them
wrapped_lines = []
current_line = ""
for word in abstract_preview.split():
if len(current_line) + len(word) + 1 <= 80: # 80 characters per line
current_line = (current_line + " " + word).strip()
else:
wrapped_lines.append(current_line)
current_line = word
if current_line:
wrapped_lines.append(current_line)
# Add formatted abstract lines
for line in wrapped_lines:
result.append(f"\n{indent_char}{line}")
result.append("\n") # Add extra newline after abstract
# Add table of contents header if there are sections
if self.toc:
result.append("\n◈ Table of Contents:\n")
# Add all top-level sections and their subsections
result.extend(_format_section(section, 0) for section in self.toc)
return "".join(result)
class BaseExtractor(ABC):
"""Base class for LaTeX content extractors."""
@abstractmethod
def extract(self, content: str) -> str:
"""Extract specific content from LaTeX document."""
pass
class TitleExtractor(BaseExtractor):
"""Extracts title from LaTeX document."""
PATTERNS = [
r'\\title{(.+?)}',
r'\\title\[.*?\]{(.+?)}',
r'\\Title{(.+?)}',
r'\\TITLE{(.+?)}',
r'\\begin{document}\s*\\section[*]?{(.+?)}',
r'\\maketitle\s*\\section[*]?{(.+?)}',
r'\\chapter[*]?{(.+?)}'
]
def extract(self, content: str) -> str:
"""Extract title using defined patterns."""
for pattern in self.PATTERNS:
matches = list(re.finditer(pattern, content, re.IGNORECASE | re.DOTALL))
for match in matches:
title = match.group(1).strip()
if title:
return clean_latex_commands(title)
return ''
class AbstractExtractor(BaseExtractor):
"""Extracts abstract from LaTeX document."""
PATTERNS = [
r'\\begin{abstract}(.*?)\\end{abstract}',
r'\\abstract{(.*?)}',
r'\\ABSTRACT{(.*?)}',
r'\\Abstract{(.*?)}',
r'\\begin{Abstract}(.*?)\\end{Abstract}',
r'\\section[*]?{(?:Abstract|ABSTRACT)}\s*(.*?)(?:\\section|\Z)',
r'\\chapter[*]?{(?:Abstract|ABSTRACT)}\s*(.*?)(?:\\chapter|\Z)'
]
def extract(self, content: str) -> str:
"""Extract abstract using defined patterns."""
for pattern in self.PATTERNS:
matches = list(re.finditer(pattern, content, re.IGNORECASE | re.DOTALL))
for match in matches:
abstract = match.group(1).strip()
if abstract:
return clean_latex_commands(abstract)
return ''
class EssayStructureParser:
"""Main class for parsing LaTeX documents."""
def __init__(self):
self.title_extractor = TitleExtractor()
self.abstract_extractor = AbstractExtractor()
self.section_extractor = EnhancedSectionExtractor() # Using the enhanced extractor
def parse(self, content: str) -> DocumentStructure:
"""Parse LaTeX document and extract structured information."""
try:
content = self._preprocess_content(content)
return DocumentStructure(
title=self.title_extractor.extract(content),
abstract=self.abstract_extractor.extract(content),
toc=self.section_extractor.extract(content)
)
except Exception as e:
logger.error(f"Error parsing LaTeX document: {str(e)}")
raise
def _preprocess_content(self, content: str) -> str:
"""Preprocess LaTeX content for parsing."""
# Remove comments
content = re.sub(r'(?<!\\)%.*$', '', content, flags=re.MULTILINE)
return content
def pretty_print_structure(doc: DocumentStructure, max_content_length: int = 100):
"""Print document structure in a readable format."""
print(f"Title: {doc.title}\n")
print(f"Abstract: {doc.abstract}\n")
print("Table of Contents:")
def print_section(section: Section, indent: int = 0):
print(" " * indent + f"- {section.title}")
if section.content:
preview = section.content[:max_content_length]
if len(section.content) > max_content_length:
preview += "..."
print(" " * (indent + 1) + f"Content: {preview}")
for subsection in section.subsections:
print_section(subsection, indent + 1)
for section in doc.toc:
print_section(section)
# Example usage:
if __name__ == "__main__":
# Test with a file
file_path = 'test_cache/2411.03663/neurips_2024.tex'
main_tex = read_tex_file(file_path)
# Parse main file
parser = EssayStructureParser()
main_doc = parser.parse(main_tex)
# Merge other documents
file_path_list = [
"test_cache/2411.03663/1_intro.tex",
"test_cache/2411.03663/0_abstract.tex",
"test_cache/2411.03663/2_pre.tex",
"test_cache/2411.03663/3_method.tex",
"test_cache/2411.03663/4_experiment.tex",
"test_cache/2411.03663/5_related_work.tex",
"test_cache/2411.03663/6_conclu.tex",
"test_cache/2411.03663/reference.bib"
]
for file_path in file_path_list:
tex_content = read_tex_file(file_path)
additional_doc = parser.parse(tex_content)
main_doc = main_doc.merge(additional_doc)
tree = main_doc.generate_toc_tree()
pretty_print_structure(main_doc)

View File

@@ -0,0 +1,329 @@
import logging
import re
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from typing import Set, Dict, Pattern, Optional, List, Tuple
class EnvType(Enum):
"""Environment classification types."""
PRESERVE = "preserve" # Preserve complete environment including commands
REMOVE = "remove" # Remove environment completely
EXTRACT = "extract" # Extract and clean content
@dataclass
class LatexConfig:
"""Configuration for LaTeX processing."""
preserve_envs: Set[str] = field(default_factory=lambda: {
# Math environments - preserve complete content
'equation', 'equation*', 'align', 'align*', 'displaymath',
'math', 'eqnarray', 'eqnarray*', 'gather', 'gather*',
'multline', 'multline*', 'flalign', 'flalign*',
'alignat', 'alignat*', 'cases', 'split', 'aligned',
# Tables and figures - preserve structure and content
'table', 'table*', 'tabular', 'tabularx', 'array', 'matrix',
'figure', 'figure*', 'subfigure', 'wrapfigure',
'minipage', 'tabbing', 'verbatim', 'longtable',
'sidewaystable', 'sidewaysfigure', 'floatrow',
# Arrays and matrices
'pmatrix', 'bmatrix', 'Bmatrix', 'vmatrix', 'Vmatrix',
'smallmatrix', 'array', 'matrix*', 'pmatrix*', 'bmatrix*',
# Algorithms and code
'algorithm', 'algorithmic', 'lstlisting', 'verbatim',
'minted', 'listing', 'algorithmic*', 'algorithm2e',
# Theorems and proofs
'theorem', 'proof', 'definition', 'lemma', 'corollary',
'proposition', 'example', 'remark', 'note', 'claim',
'axiom', 'property', 'assumption', 'conjecture', 'observation',
# Bibliography
'thebibliography', 'bibliography', 'references'
})
# Special-handling configuration for citation-type commands
citation_commands: Set[str] = field(default_factory=lambda: {
# Basic citations
'cite', 'citep', 'citet', 'citeyear', 'citeauthor',
'citeyearpar', 'citetext', 'citenum',
# Natbib citations
'citefullauthor', 'citealp', 'citealt', 'citename',
'citepalias', 'citetalias', 'citetext',
# Cross-references
'ref', 'eqref', 'pageref', 'autoref', 'nameref', 'cref',
'Cref', 'vref', 'Vref', 'fref', 'pref',
# Hyperref
'hyperref', 'href', 'url',
# Labels
'label', 'tag'
})
preserve_commands: Set[str] = field(default_factory=lambda: {
# Text formatting
'emph', 'textbf', 'textit', 'underline', 'texttt', 'footnote',
'section', 'subsection', 'subsubsection', 'paragraph', 'part',
'chapter', 'title', 'author', 'date', 'thanks',
# Math operators and symbols
'frac', 'sum', 'int', 'prod', 'lim', 'sup', 'inf',
'partial', 'nabla', 'implies', 'iff', 'therefore',
'exists', 'forall', 'in', 'subset', 'subseteq',
# Greek letters and math symbols
'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta',
'eta', 'theta', 'iota', 'kappa', 'lambda', 'mu',
'nu', 'xi', 'pi', 'rho', 'sigma', 'tau',
'upsilon', 'phi', 'chi', 'psi', 'omega',
'Gamma', 'Delta', 'Theta', 'Lambda', 'Xi', 'Pi',
'Sigma', 'Upsilon', 'Phi', 'Psi', 'Omega',
# Math commands
'left', 'right', 'big', 'Big', 'bigg', 'Bigg',
'mathbf', 'mathit', 'mathsf', 'mathtt', 'mathbb',
'mathcal', 'mathfrak', 'mathscr', 'mathrm', 'mathop',
'operatorname', 'overline', 'underline', 'overbrace',
'underbrace', 'overset', 'underset', 'stackrel',
# Spacing and alignment
'quad', 'qquad', 'hspace', 'vspace', 'medskip',
'bigskip', 'smallskip', 'hfill', 'vfill', 'centering',
'raggedright', 'raggedleft'
})
remove_commands: Set[str] = field(default_factory=lambda: {
# Document setup
'documentclass', 'usepackage', 'input', 'include', 'includeonly',
'bibliographystyle', 'frontmatter', 'mainmatter',
'newtheorem', 'theoremstyle', 'proofname',
'newcommand', 'renewcommand', 'providecommand', 'DeclareMathOperator',
'newenvironment',
# Layout and spacing
'pagestyle', 'thispagestyle', 'newpage', 'clearpage',
'pagebreak', 'linebreak', 'newline', 'setlength',
'setcounter', 'addtocounter', 'makeatletter',
'makeatother', 'pagenumbering'
})
latex_chars: Dict[str, str] = field(default_factory=lambda: {
'~': ' ', '\\&': '&', '\\%': '%', '\\_': '_', '\\$': '$',
'\\#': '#', '\\{': '{', '\\}': '}', '``': '"', "''": '"',
'\\textbackslash': '\\', '\\ldots': '...', '\\dots': '...',
'\\textasciitilde': '~', '\\textasciicircum': '^'
})
# Special command patterns whose original form is preserved
special_command_patterns: List[Tuple[str, str]] = field(default_factory=lambda: [
(r'\\cite\*?(?:\[[^\]]*\])?{([^}]+)}', r'\\cite{\1}'),
(r'\\ref\*?{([^}]+)}', r'\\ref{\1}'),
(r'\\label{([^}]+)}', r'\\label{\1}'),
(r'\\eqref{([^}]+)}', r'\\eqref{\1}'),
(r'\\autoref{([^}]+)}', r'\\autoref{\1}'),
(r'\\url{([^}]+)}', r'\\url{\1}'),
(r'\\href{([^}]+)}{([^}]+)}', r'\\href{\1}{\2}')
])
class LatexCleaner:
"""Enhanced LaTeX text cleaner that preserves mathematical content and citations."""
def __init__(self, config: Optional[LatexConfig] = None):
self.config = config or LatexConfig()
self.logger = logging.getLogger(__name__)
# Initialize the regex cache
self._regex_cache = {}
@lru_cache(maxsize=128)
def _get_env_pattern(self, env_name: str) -> Pattern:
"""Get cached regex pattern for environment matching."""
return re.compile(fr'\\begin{{{env_name}}}(.*?)\\end{{{env_name}}}', re.DOTALL)
def _get_env_type(self, env_name: str) -> EnvType:
"""Determine environment processing type."""
if env_name.rstrip('*') in {name.rstrip('*') for name in self.config.preserve_envs}:
return EnvType.PRESERVE
elif env_name in {'comment'}:
return EnvType.REMOVE
return EnvType.EXTRACT
def _preserve_special_commands(self, text: str) -> str:
"""Preserve special commands like citations and references with their complete structure."""
for pattern, replacement in self.config.special_command_patterns:
if pattern not in self._regex_cache:
self._regex_cache[pattern] = re.compile(pattern)
def replace_func(match):
# Keep the original command format
return match.group(0)
text = self._regex_cache[pattern].sub(replace_func, text)
return text
def _process_environment(self, match: re.Match) -> str:
"""Process LaTeX environments while preserving complete content for special environments."""
try:
env_name = match.group(1)
content = match.group(2)
env_type = self._get_env_type(env_name)
if env_type == EnvType.PRESERVE:
# Preserve the environment content in full
complete_env = match.group(0)
return f"\n[BEGIN_{env_name}]\n{complete_env}\n[END_{env_name}]\n"
elif env_type == EnvType.REMOVE:
return ' '
else:
# Process nested environments
return self._clean_nested_environments(content)
except Exception as e:
self.logger.error(f"Error processing environment {match.group(1) if match else 'unknown'}: {e}")
return match.group(0)
def _preserve_inline_math(self, text: str) -> str:
"""Preserve complete inline math content."""
def preserve_math(match):
return f" {match.group(0)} "
patterns = [
(r'\$[^$]+\$', preserve_math),
(r'\\[\(\[].*?\\[\)\]]', preserve_math),
(r'\\begin{math}.*?\\end{math}', preserve_math)
]
for pattern, handler in patterns:
if pattern not in self._regex_cache:
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL)
text = self._regex_cache[pattern].sub(handler, text)
return text
def _clean_nested_environments(self, text: str) -> str:
"""Process nested environments recursively."""
pattern = r'\\begin{(\w+)}(.*?)\\end{\1}'
if pattern not in self._regex_cache:
self._regex_cache[pattern] = re.compile(pattern, re.DOTALL)
return self._regex_cache[pattern].sub(self._process_environment, text)
def _clean_commands(self, text: str) -> str:
"""Clean LaTeX commands while preserving important content."""
# First handle the special commands
text = self._preserve_special_commands(text)
# Preserve inline math
text = self._preserve_inline_math(text)
# Remove the configured commands
for cmd in self.config.remove_commands:
if cmd not in self._regex_cache:
self._regex_cache[cmd] = re.compile(
fr'\\{cmd}\*?(?:\[.*?\])?(?:{{.*?}})*'
)
text = self._regex_cache[cmd].sub('', text)
# Handle commands that carry content
def handle_command(match: re.Match) -> str:
cmd = match.group(1).rstrip('*')
if cmd in self.config.preserve_commands or cmd in self.config.citation_commands:
return match.group(0) # keep the command and its content intact
return ' '
if 'command_pattern' not in self._regex_cache:
self._regex_cache['command_pattern'] = re.compile(
r'\\(\w+)\*?(?:\[.*?\])?{(.*?)}'
)
text = self._regex_cache['command_pattern'].sub(handle_command, text)
return text
def _normalize_text(self, text: str) -> str:
"""Normalize text while preserving special content markers."""
# Replace special characters
for char, replacement in self.config.latex_chars.items():
text = text.replace(char, replacement)
# Collapse whitespace while keeping the environment markers
text = re.sub(r'\s+', ' ', text)
text = re.sub(r'\s*\[BEGIN_(\w+)\]\s*', r'\n[BEGIN_\1]\n', text)
text = re.sub(r'\s*\[END_(\w+)\]\s*', r'\n[END_\1]\n', text)
# Keep block-level environments separated
text = re.sub(r'\n{3,}', '\n\n', text)
return text.strip()
def clean_text(self, text: str) -> str:
"""Clean LaTeX text while preserving mathematical content, citations, and special environments."""
if not text:
return ""
try:
# Strip comments
text = re.sub(r'(?<!\\)%.*?(?=\n|$)', '', text, flags=re.MULTILINE)
# Process environments
text = self._clean_nested_environments(text)
# Clean commands and normalize
text = self._clean_commands(text)
text = self._normalize_text(text)
return text
except Exception as e:
self.logger.error(f"Error cleaning text: {e}")
return text # on error, return the original text
def clean_latex_commands(text: str) -> str:
"""Convenience function for quick text cleaning with default config."""
cleaner = LatexCleaner()
return cleaner.clean_text(text)
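# Example: clean_latex_commands(r"See \cite{a}: $x^2$ % draft note")
# strips the trailing comment while keeping "\cite{a}" and "$x^2$" intact.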
# Example usage:
if __name__ == "__main__":
text = r"""
\documentclass{article}
\begin{document}
\section{Introduction}
This is a reference to \cite{smith2020} and equation \eqref{eq:main}.
\begin{equation}\label{eq:main}
E = mc^2 \times \sum_{i=1}^{n} x_i
\end{equation}
See Figure \ref{fig:example} for details.
\begin{figure}
\includegraphics{image.png}
\caption{Example figure\label
\textbf{Important} result: $E=mc^2$ and
\begin{equation}
F = ma
\end{equation}
\label{sec:intro}
"""
# Custom configuration
config = LatexConfig(
preserve_envs={},
preserve_commands={'textbf', 'emph'},
latex_chars={'~': ' ', '\\&': '&'}
)
def read_tex_file(file_path):
try:
with open(file_path, 'r', encoding='utf-8') as file:
content = file.read()
return content
except FileNotFoundError:
return "File not found; please check that the path is correct."
except Exception as e:
return f"Error while reading the file: {e}"
# Read a cached paper for file-based testing
file_path = 'test_cache/2411.03663/neurips_2024.tex'
content = read_tex_file(file_path)
cleaner = LatexCleaner(config)
text = cleaner.clean_text(text)
print(text)

View File

@@ -0,0 +1,396 @@
from dataclasses import dataclass
@dataclass
class LaTeXPatterns:
"""LaTeX模式存储类用于集中管理所有LaTeX相关的正则表达式模式"""
special_envs = {
'math': [
# Basic math environments
r'\\begin{(equation|align|gather|eqnarray|multline|flalign|alignat)\*?}.*?\\end{\1\*?}',
r'\$\$.*?\$\$',
r'\$[^$]+\$',
# Matrix environments
r'\\begin{(matrix|pmatrix|bmatrix|Bmatrix|vmatrix|Vmatrix|smallmatrix)\*?}.*?\\end{\1\*?}',
# Array environments
r'\\begin{(array|cases|aligned|gathered|split)\*?}.*?\\end{\1\*?}',
# Other math environments
r'\\begin{(subequations|math|displaymath)\*?}.*?\\end{\1\*?}'
],
'table': [
# Basic table environments
r'\\begin{(table|tabular|tabularx|tabulary|longtable)\*?}.*?\\end{\1\*?}',
# Complex table environments
r'\\begin{(tabu|supertabular|xtabular|mpsupertabular)\*?}.*?\\end{\1\*?}',
# Custom table environments
r'\\begin{(threeparttable|tablefootnote)\*?}.*?\\end{\1\*?}',
# Table-note environments
r'\\begin{(tablenotes)\*?}.*?\\end{\1\*?}'
],
'figure': [
# Figure environments
r'\\begin{figure\*?}.*?\\end{figure\*?}',
r'\\begin{(subfigure|wrapfigure)\*?}.*?\\end{\1\*?}',
# Image inclusion commands
r'\\includegraphics(\[.*?\])?\{.*?\}',
# tikz graphics environments
r'\\begin{(tikzpicture|pgfpicture)\*?}.*?\\end{\1\*?}',
# Other graphics environments
r'\\begin{(picture|pspicture)\*?}.*?\\end{\1\*?}'
],
'algorithm': [
# Algorithm environments
r'\\begin{(algorithm|algorithmic|algorithm2e|algorithmicx)\*?}.*?\\end{\1\*?}',
r'\\begin{(lstlisting|verbatim|minted|listing)\*?}.*?\\end{\1\*?}',
# Code-block environments
r'\\begin{(code|verbatimtab|verbatimwrite)\*?}.*?\\end{\1\*?}',
# Pseudocode environments
r'\\begin{(pseudocode|procedure)\*?}.*?\\end{\1\*?}'
],
'list': [
# List environments
r'\\begin{(itemize|enumerate|description)\*?}.*?\\end{\1\*?}',
r'\\begin{(list|compactlist|bulletlist)\*?}.*?\\end{\1\*?}',
# Custom list environments
r'\\begin{(tasks|todolist)\*?}.*?\\end{\1\*?}'
],
'theorem': [
# Theorem-like environments
r'\\begin{(theorem|lemma|proposition|corollary)\*?}.*?\\end{\1\*?}',
r'\\begin{(definition|example|proof|remark)\*?}.*?\\end{\1\*?}',
# Other proof-related environments
r'\\begin{(axiom|property|assumption|conjecture)\*?}.*?\\end{\1\*?}'
],
'box': [
# Text-box environments
r'\\begin{(tcolorbox|mdframed|framed|shaded)\*?}.*?\\end{\1\*?}',
r'\\begin{(boxedminipage|shadowbox)\*?}.*?\\end{\1\*?}',
# Emphasis environments
r'\\begin{(important|warning|info|note)\*?}.*?\\end{\1\*?}'
],
'quote': [
# Quotation environments
r'\\begin{(quote|quotation|verse|abstract)\*?}.*?\\end{\1\*?}',
r'\\begin{(excerpt|epigraph)\*?}.*?\\end{\1\*?}'
],
'bibliography': [
# Bibliography environments
r'\\begin{(thebibliography|bibliography)\*?}.*?\\end{\1\*?}',
r'\\begin{(biblist|citelist)\*?}.*?\\end{\1\*?}'
],
'index': [
# Index environments
r'\\begin{(theindex|printindex)\*?}.*?\\end{\1\*?}',
r'\\begin{(glossary|acronym)\*?}.*?\\end{\1\*?}'
]
}
# Sectioning patterns
section_patterns = [
# Basic sectioning commands
r'\\chapter\{([^}]+)\}',
r'\\section\{([^}]+)\}',
r'\\subsection\{([^}]+)\}',
r'\\subsubsection\{([^}]+)\}',
r'\\paragraph\{([^}]+)\}',
r'\\subparagraph\{([^}]+)\}',
# Starred (unnumbered) variants
r'\\chapter\*\{([^}]+)\}',
r'\\section\*\{([^}]+)\}',
r'\\subsection\*\{([^}]+)\}',
r'\\subsubsection\*\{([^}]+)\}',
r'\\paragraph\*\{([^}]+)\}',
r'\\subparagraph\*\{([^}]+)\}',
# Special sections
r'\\part\{([^}]+)\}',
r'\\part\*\{([^}]+)\}',
r'\\appendix\{([^}]+)\}',
# Front and back matter
r'\\frontmatter\{([^}]+)\}',
r'\\mainmatter\{([^}]+)\}',
r'\\backmatter\{([^}]+)\}',
# Table-of-contents commands
r'\\tableofcontents',
r'\\listoffigures',
r'\\listoftables',
# Custom sectioning commands
r'\\addchap\{([^}]+)\}', # KOMA-Script classes
r'\\addsec\{([^}]+)\}', # KOMA-Script classes
r'\\minisec\{([^}]+)\}', # KOMA-Script classes
# Sectioning commands with an optional argument
r'\\chapter\[([^]]+)\]\{([^}]+)\}',
r'\\section\[([^]]+)\]\{([^}]+)\}',
r'\\subsection\[([^]]+)\]\{([^}]+)\}'
]
# Include patterns
include_patterns = [
r'\\(input|include|subfile)\{([^}]+)\}'
]
metadata_patterns = {
# Title
'title': [
r'\\title\{([^}]+)\}',
r'\\Title\{([^}]+)\}',
r'\\doctitle\{([^}]+)\}',
r'\\subtitle\{([^}]+)\}',
r'\\chapter\*?\{([^}]+)\}', # the first chapter may serve as the title
r'\\maketitle\s*\\section\*?\{([^}]+)\}' # the first section may serve as the title
],
# Abstract
'abstract': [
r'\\begin{abstract}(.*?)\\end{abstract}',
r'\\abstract\{([^}]+)\}',
r'\\begin{摘要}(.*?)\\end{摘要}',
r'\\begin{Summary}(.*?)\\end{Summary}',
r'\\begin{synopsis}(.*?)\\end{synopsis}',
r'\\begin{abstracten}(.*?)\\end{abstracten}' # English abstract
],
# Author information
'author': [
r'\\author\{([^}]+)\}',
r'\\Author\{([^}]+)\}',
r'\\authorinfo\{([^}]+)\}',
r'\\authors\{([^}]+)\}',
r'\\author\[([^]]+)\]\{([^}]+)\}', # author with extra information
r'\\begin{authors}(.*?)\\end{authors}'
],
# Dates
'date': [
r'\\date\{([^}]+)\}',
r'\\Date\{([^}]+)\}',
r'\\submitdate\{([^}]+)\}',
r'\\publishdate\{([^}]+)\}',
r'\\revisiondate\{([^}]+)\}'
],
# Keywords
'keywords': [
r'\\keywords\{([^}]+)\}',
r'\\Keywords\{([^}]+)\}',
r'\\begin{keywords}(.*?)\\end{keywords}',
r'\\key\{([^}]+)\}',
r'\\begin{关键词}(.*?)\\end{关键词}'
],
# Institution / affiliation
'institution': [
r'\\institute\{([^}]+)\}',
r'\\institution\{([^}]+)\}',
r'\\affiliation\{([^}]+)\}',
r'\\organization\{([^}]+)\}',
r'\\department\{([^}]+)\}'
],
# Subject / field
'subject': [
r'\\subject\{([^}]+)\}',
r'\\Subject\{([^}]+)\}',
r'\\field\{([^}]+)\}',
r'\\discipline\{([^}]+)\}'
],
# Version information
'version': [
r'\\version\{([^}]+)\}',
r'\\revision\{([^}]+)\}',
r'\\release\{([^}]+)\}'
],
# License / copyright
'license': [
r'\\license\{([^}]+)\}',
r'\\copyright\{([^}]+)\}',
r'\\begin{license}(.*?)\\end{license}'
],
# Contact information
'contact': [
r'\\email\{([^}]+)\}',
r'\\phone\{([^}]+)\}',
r'\\address\{([^}]+)\}',
r'\\contact\{([^}]+)\}'
],
# Acknowledgments
'acknowledgments': [
r'\\begin{acknowledgments}(.*?)\\end{acknowledgments}',
r'\\acknowledgments\{([^}]+)\}',
r'\\thanks\{([^}]+)\}',
r'\\begin{致谢}(.*?)\\end{致谢}'
],
# Funding / grants
'funding': [
r'\\funding\{([^}]+)\}',
r'\\grant\{([^}]+)\}',
r'\\project\{([^}]+)\}',
r'\\support\{([^}]+)\}'
],
# Classification / identifiers
'classification': [
r'\\classification\{([^}]+)\}',
r'\\serialnumber\{([^}]+)\}',
r'\\id\{([^}]+)\}',
r'\\doi\{([^}]+)\}'
],
# Language
'language': [
r'\\documentlanguage\{([^}]+)\}',
r'\\lang\{([^}]+)\}',
r'\\language\{([^}]+)\}'
]
}
latex_only_patterns = {
# Document class and package imports
r'\\documentclass(\[.*?\])?\{.*?\}',
r'\\usepackage(\[.*?\])?\{.*?\}',
# Common document-setup commands
r'\\setlength\{.*?\}\{.*?\}',
r'\\newcommand\{.*?\}(\[.*?\])?\{.*?\}',
r'\\renewcommand\{.*?\}(\[.*?\])?\{.*?\}',
r'\\definecolor\{.*?\}\{.*?\}\{.*?\}',
# Page-style settings
r'\\pagestyle\{.*?\}',
r'\\thispagestyle\{.*?\}',
# Other common setup commands
r'\\bibliographystyle\{.*?\}',
r'\\bibliography\{.*?\}',
r'\\setcounter\{.*?\}\{.*?\}',
# Font and text settings
r'\\makeFNbottom',
r'\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}', # font-size settings
r'\\renewcommand\\[A-Z]+\{\\@setfontsize\\[A-Z]+\{.*?\}\{.*?\}\}',
r'\\renewcommand\{?\\thefootnote\}?\{\\fnsymbol\{footnote\}\}',
r'\\renewcommand\\footnoterule\{.*?\}',
r'\\color\{.*?\}',
# Page and section heading settings
r'\\setcounter\{secnumdepth\}\{.*?\}',
r'\\renewcommand\\@biblabel\[.*?\]\{.*?\}',
r'\\renewcommand\\@makefntext\[.*?\](\{.*?\})*',
r'\\renewcommand\{?\\figurename\}?\{.*?\}',
# Font-style settings
r'\\sectionfont\{.*?\}',
r'\\subsectionfont\{.*?\}',
r'\\subsubsectionfont\{.*?\}',
# Spacing and layout settings
r'\\setstretch\{.*?\}',
r'\\setlength\{\\skip\\footins\}\{.*?\}',
r'\\setlength\{\\footnotesep\}\{.*?\}',
r'\\setlength\{\\jot\}\{.*?\}',
r'\\hrule\s+width\s+.*?\s+height\s+.*?',
# makeatletter and makeatother
r'\\makeatletter\s*',
r'\\makeatother\s*',
r'\\footnotetext\{[^}]*\$\^{[^}]*}\$[^}]*\}', # footnotes with superscripts
# r'\\footnotetext\{[^}]*\}', # plain footnotes
# r'\\footnotetext\{.*?(?:\$\^{.*?}\$)?.*?(?:email\s*:\s*[^}]*)?.*?\}', # footnotes containing emails
# r'\\footnotetext\{.*?(?:ESI|DOI).*?\}', # footnotes with DOI or ESI references
# Document-structure commands
r'\\begin\{document\}',
r'\\end\{document\}',
r'\\maketitle',
r'\\printbibliography',
r'\\newpage',
# File-input commands
r'\\input\{[^}]*\}',
r'\\input\{.*?\.tex\}', # specifically matches inputs with a .tex suffix
# Footnote-related
# r'\\footnotetext\[\d+\]\{[^}]*\}', # numbered footnotes
# Acknowledgment environment
r'\\begin\{ack\}',
r'\\end\{ack\}',
r'\\begin\{ack\}[^\n]*(?:\n.*?)*?\\end\{ack\}', # matches the whole ack environment and its content
# Other document-control commands
r'\\renewcommand\{\\thefootnote\}\{\\fnsymbol\{footnote\}\}',
}
math_envs = [
# Basic math environments
(r'\\begin{equation\*?}.*?\\end{equation\*?}', 'equation'), # single displayed equation
(r'\\begin{align\*?}.*?\\end{align\*?}', 'align'), # multi-line aligned equations
(r'\\begin{gather\*?}.*?\\end{gather\*?}', 'gather'), # multi-line centered equations
(r'\$\$.*?\$\$', 'display'), # display math
(r'\$.*?\$', 'inline'), # inline math
# Matrix environments
(r'\\begin{matrix}.*?\\end{matrix}', 'matrix'), # plain matrix
(r'\\begin{pmatrix}.*?\\end{pmatrix}', 'pmatrix'), # parenthesized matrix
(r'\\begin{bmatrix}.*?\\end{bmatrix}', 'bmatrix'), # bracketed matrix
(r'\\begin{vmatrix}.*?\\end{vmatrix}', 'vmatrix'), # single-bar matrix
(r'\\begin{Vmatrix}.*?\\end{Vmatrix}', 'Vmatrix'), # double-bar matrix
(r'\\begin{smallmatrix}.*?\\end{smallmatrix}', 'smallmatrix'), # small inline matrix
# Array environments
(r'\\begin{array}.*?\\end{array}', 'array'), # array
(r'\\begin{cases}.*?\\end{cases}', 'cases'), # piecewise function
# Multi-line equation environments
(r'\\begin{multline\*?}.*?\\end{multline\*?}', 'multline'), # one equation across several lines
(r'\\begin{split}.*?\\end{split}', 'split'), # split a long equation
(r'\\begin{alignat\*?}.*?\\end{alignat\*?}', 'alignat'), # alignment with spacing control
(r'\\begin{flalign\*?}.*?\\end{flalign\*?}', 'flalign'), # fully left-aligned
# Special math environments
(r'\\begin{subequations}.*?\\end{subequations}', 'subequations'), # sub-equation numbering
(r'\\begin{gathered}.*?\\end{gathered}', 'gathered'), # centered group
(r'\\begin{aligned}.*?\\end{aligned}', 'aligned'), # internally aligned group
# Theorem-like environments
(r'\\begin{theorem}.*?\\end{theorem}', 'theorem'), # theorem
(r'\\begin{lemma}.*?\\end{lemma}', 'lemma'), # lemma
(r'\\begin{proof}.*?\\end{proof}', 'proof'), # proof
# Table environments in math mode
(r'\\begin{tabular}.*?\\end{tabular}', 'tabular'), # table
(r'\\begin{array}.*?\\end{array}', 'array'), # array
# Other specialized math environments
(r'\\begin{CD}.*?\\end{CD}', 'CD'), # commutative diagram
(r'\\begin{boxed}.*?\\end{boxed}', 'boxed'), # boxed equation
(r'\\begin{empheq}.*?\\end{empheq}', 'empheq'), # emphasized equation
# Chemistry environments (require the mhchem package)
(r'\\begin{reaction}.*?\\end{reaction}', 'reaction'), # chemical reaction
(r'\\ce\{.*?\}', 'chemequation'), # chemical equation
# Physical-unit commands (require the siunitx package)
(r'\\SI\{.*?\}\{.*?\}', 'SI'), # quantity with unit
(r'\\si\{.*?\}', 'si'), # unit
# Supplementary environments
(r'\\begin{equation\+}.*?\\end{equation\+}', 'equation+'), # breqn auto-breaking equation
(r'\\begin{dmath\*?}.*?\\end{dmath\*?}', 'dmath'), # breqn display math
(r'\\begin{dgroup\*?}.*?\\end{dgroup\*?}', 'dgroup'), # breqn equation group
]
# Example usage

View File

@@ -0,0 +1,416 @@
import logging
import re
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Tuple
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SectionLevel(Enum):
CHAPTER = 0
SECTION = 1
SUBSECTION = 2
SUBSUBSECTION = 3
PARAGRAPH = 4
SUBPARAGRAPH = 5
def __lt__(self, other):
if not isinstance(other, SectionLevel):
return NotImplemented
return self.value < other.value
def __le__(self, other):
if not isinstance(other, SectionLevel):
return NotImplemented
return self.value <= other.value
def __gt__(self, other):
if not isinstance(other, SectionLevel):
return NotImplemented
return self.value > other.value
def __ge__(self, other):
if not isinstance(other, SectionLevel):
return NotImplemented
return self.value >= other.value
@dataclass
class Section:
level: SectionLevel
title: str
content: str = ''
bibliography: str = ''
subsections: List['Section'] = field(default_factory=list)
def merge(self, other: 'Section') -> 'Section':
"""Merge this section with another section."""
if self.title != other.title or self.level != other.level:
raise ValueError("Can only merge sections with same title and level")
merged = deepcopy(self)
merged.content = self._merge_content(self.content, other.content)
# Create subsections lookup for efficient merging
subsections_map = {s.title: s for s in merged.subsections}
for other_subsection in other.subsections:
if other_subsection.title in subsections_map:
# Merge existing subsection
idx = next(i for i, s in enumerate(merged.subsections)
if s.title == other_subsection.title)
merged.subsections[idx] = merged.subsections[idx].merge(other_subsection)
else:
# Add new subsection
merged.subsections.append(deepcopy(other_subsection))
return merged
@staticmethod
def _merge_content(content1: str, content2: str) -> str:
"""Merge content strings intelligently."""
if not content1:
return content2
if not content2:
return content1
# Combine non-empty contents with a separator
return f"{content1}\n\n{content2}"
@dataclass
class LatexEnvironment:
"""表示LaTeX环境的数据类"""
name: str
start: int
end: int
content: str
raw: str
class EnhancedSectionExtractor:
"""Enhanced section extractor with comprehensive content handling and hierarchy management."""
def __init__(self, preserve_environments: bool = True):
"""
初始化Section提取器
Args:
preserve_environments: 是否保留特定环境如equation, figure等的原始LaTeX代码
"""
self.preserve_environments = preserve_environments
# Section level definitions
self.section_levels = {
'chapter': SectionLevel.CHAPTER,
'section': SectionLevel.SECTION,
'subsection': SectionLevel.SUBSECTION,
'subsubsection': SectionLevel.SUBSUBSECTION,
'paragraph': SectionLevel.PARAGRAPH,
'subparagraph': SectionLevel.SUBPARAGRAPH
}
# Environment types that must be preserved
self.important_environments = {
'equation', 'equation*', 'align', 'align*',
'figure', 'table', 'algorithm', 'algorithmic',
'definition', 'theorem', 'lemma', 'proof',
'itemize', 'enumerate', 'description'
}
# Improved section pattern
self.section_pattern = (
r'\\(?P<type>chapter|section|subsection|subsubsection|paragraph|subparagraph)'
r'\*?' # Optional star
r'(?:\[(?P<short>.*?)\])?' # Optional short title
r'{(?P<title>(?:[^{}]|\{[^{}]*\})*?)}' # Main title with nested braces support
)
# Environment matching pattern
self.environment_pattern = (
r'\\begin{(?P<env_name>[^}]+)}'
r'(?P<env_content>.*?)'
r'\\end{(?P=env_name)}'
)
def _find_environments(self, content: str) -> List[LatexEnvironment]:
"""
查找文档中的所有LaTeX环境。
支持嵌套环境的处理。
"""
environments = []
stack = []
# Find all begin and end markers with regular expressions
begin_pattern = r'\\begin{([^}]+)}'
end_pattern = r'\\end{([^}]+)}'
# Combined pattern that matches both begin and end markers
tokens = []
for match in re.finditer(fr'({begin_pattern})|({end_pattern})', content):
if match.group(1): # begin marker
tokens.append(('begin', match.group(2), match.start())) # group 2 holds the environment name
else: # end marker
tokens.append(('end', match.group(4), match.start())) # group 4 holds the environment name
# Resolve environment nesting
for token_type, env_name, pos in tokens:
if token_type == 'begin':
stack.append((env_name, pos))
elif token_type == 'end' and stack:
if stack[-1][0] == env_name:
start_env_name, start_pos = stack.pop()
env_content = content[start_pos:pos]
raw_content = content[start_pos:pos + len('\\end{' + env_name + '}')]
if start_env_name in self.important_environments:
environments.append(LatexEnvironment(
name=start_env_name,
start=start_pos,
end=pos + len('\\end{' + env_name + '}'),
content=env_content,
raw=raw_content
))
return sorted(environments, key=lambda x: x.start)
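# Example: in r"\begin{figure}\begin{table}...\end{table}\end{figure}" the stack
# pairs each \end with the innermost open \begin, so the inner table is matched
# first, then the enclosing figure; the result is sorted by start position.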
def _protect_environments(self, content: str) -> Tuple[str, Dict[str, str]]:
"""
保护重要的LaTeX环境用占位符替换它们。
返回处理后的内容和恢复映射。
"""
environments = self._find_environments(content)
replacements = {}
# Keep only the outermost environments; nested ones are already contained
# in the outer environment's raw text, so replacing them too would corrupt offsets
outermost = [e for e in environments
if not any(o is not e and o.start <= e.start and e.end <= o.end for o in environments)]
# Replace from back to front so earlier offsets stay valid
for env in sorted(outermost, key=lambda x: x.start, reverse=True):
placeholder = f'__ENV_{len(replacements)}__'
replacements[placeholder] = env.raw
content = content[:env.start] + placeholder + content[env.end:]
return content, replacements
def _restore_environments(self, content: str, replacements: Dict[str, str]) -> str:
"""
恢复之前保护的环境。
"""
for placeholder, original in replacements.items():
content = content.replace(placeholder, original)
return content
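# Round-trip sketch (illustrative, not original code): protect an equation,
# then restore it unchanged.
def _demo_protect_restore() -> None:
    extractor = EnhancedSectionExtractor()
    src = r"Before \begin{equation} E = mc^2 \end{equation} after."
    protected, repl = extractor._protect_environments(src)
    assert "__ENV_0__" in protected and "\\begin{equation}" not in protected
    assert extractor._restore_environments(protected, repl) == src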
def extract(self, content: str) -> List[Section]:
"""
从LaTeX文档中提取sections及其内容。
Args:
content: LaTeX文档内容
Returns:
List[Section]: 提取的section列表包含层次结构
"""
try:
# Preprocess: protect important environments
if self.preserve_environments:
content, env_replacements = self._protect_environments(content)
# Find all sections
sections = self._find_all_sections(content)
if not sections:
return []
# Build the section hierarchy
root_sections = self._process_sections(content, sections)
# Restore the environments if requested
if self.preserve_environments:
for section in self._traverse_sections(root_sections):
section.content = self._restore_environments(section.content, env_replacements)
return root_sections
except Exception as e:
logger.error(f"Error extracting sections: {str(e)}")
raise
def _find_all_sections(self, content: str) -> List[dict]:
"""查找所有section命令及其位置。"""
sections = []
for match in re.finditer(self.section_pattern, content, re.DOTALL | re.MULTILINE):
section_type = match.group('type').lower()
if section_type not in self.section_levels:
continue
section = {
'type': section_type,
'level': self.section_levels[section_type],
'title': self._clean_title(match.group('title')),
'start': match.start(),
'command_end': match.end(),
}
sections.append(section)
return sorted(sections, key=lambda x: x['start'])
def _process_sections(self, content: str, sections: List[dict]) -> List[Section]:
"""处理sections以构建层次结构和提取内容。"""
# 计算content范围
self._calculate_content_ranges(content, sections)
# 构建层次结构
root_sections = []
section_stack = []
for section_info in sections:
new_section = Section(
level=section_info['level'],
title=section_info['title'],
content=self._extract_clean_content(content, section_info),
subsections=[]
)
# Unwind the stack to find the correct parent section
while section_stack and section_stack[-1].level.value >= new_section.level.value:
section_stack.pop()
if section_stack:
section_stack[-1].subsections.append(new_section)
else:
root_sections.append(new_section)
section_stack.append(new_section)
return root_sections
def _calculate_content_ranges(self, content: str, sections: List[dict]):
"""Compute each section's content range, ending at the next section of any level."""
for i, current in enumerate(sections):
content_start = current['command_end']
# Content ends at the next section command, whatever its level
content_end = sections[i + 1]['start'] if i + 1 < len(sections) else len(content)
current['content_range'] = (content_start, content_end)
def _calculate_content_ranges_with_subsection_content(self, content: str, sections: List[dict]):
"""Compute each section's content range, including its subsections' text."""
for i, current in enumerate(sections):
content_start = current['command_end']
# Content ends at the next section of the same or a higher level
# (compare .value, as done in _process_sections above)
content_end = len(content)
for next_section in sections[i + 1:]:
if next_section['level'].value <= current['level'].value:
content_end = next_section['start']
break
current['content_range'] = (content_start, content_end)
def _extract_clean_content(self, content: str, section_info: dict) -> str:
"""提取并清理section内容。"""
start, end = section_info['content_range']
raw_content = content[start:end]
# Clean the content
clean_content = self._clean_content(raw_content)
return clean_content
def _clean_content(self, content: str) -> str:
"""清理LaTeX内容同时保留重要信息。"""
# 移除注释
content = re.sub(r'(?<!\\)%.*?\n', '\n', content)
# LaTeX命令处理规则
replacements = [
# 保留引用
(r'\\cite(?:\[.*?\])?{(.*?)}', r'[cite:\1]'),
# 保留脚注
(r'\\footnote{(.*?)}', r'[footnote:\1]'),
# 处理引用
(r'\\ref{(.*?)}', r'[ref:\1]'),
# 保留URL
(r'\\url{(.*?)}', r'[url:\1]'),
# 保留超链接
(r'\\href{(.*?)}{(.*?)}', r'[\2](\1)'),
# 处理文本格式命令
(r'\\(?:textbf|textit|emph){(.*?)}', r'\1'),
# 保留特殊字符
(r'\\([&%$#_{}])', r'\1'),
]
# 应用所有替换规则
for pattern, replacement in replacements:
content = re.sub(pattern, replacement, content, flags=re.DOTALL)
# 清理多余的空白
content = re.sub(r'\n\s*\n', '\n\n', content)
return content.strip()
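# Behaviour sketch (illustrative, not original code): the rewriting rules above
# turn citation and formatting commands into plain-text markers.
def _demo_clean_content() -> None:
    ex = EnhancedSectionExtractor()
    out = ex._clean_content("\\cite{smith2020} and \\textbf{bold} % a comment\n")
    assert out == "[cite:smith2020] and bold"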
def _clean_title(self, title: str) -> str:
"""清理section标题。"""
# 处理嵌套的花括号
while '{' in title:
title = re.sub(r'{([^{}]*)}', r'\1', title)
# 处理LaTeX命令
title = re.sub(r'\\[a-zA-Z]+(?:\[.*?\])?{(.*?)}', r'\1', title)
title = re.sub(r'\\([&%$#_{}])', r'\1', title)
return title.strip()
def _traverse_sections(self, sections: List[Section]) -> List[Section]:
"""遍历所有sections包括子sections"""
result = []
for section in sections:
result.append(section)
result.extend(self._traverse_sections(section.subsections))
return result
def test_enhanced_extractor():
"""使用复杂的测试用例测试提取器。"""
test_content = r"""
\section{Complex Examples}
Here's a complex section with various environments.
\begin{equation}
E = mc^2
\end{equation}
\subsection{Nested Environments}
This subsection has nested environments.
\begin{figure}
\begin{equation*}
f(x) = \int_0^x g(t) dt
\end{equation*}
\caption{A nested equation in a figure}
\end{figure}
"""
extractor = EnhancedSectionExtractor()
sections = extractor.extract(test_content)
def print_section(section, level=0):
print("\n" + " " * level + f"[{section.level.name}] {section.title}")
if section.content:
content_preview = section.content[:150] + "..." if len(section.content) > 150 else section.content
print(" " * (level + 1) + f"Content: {content_preview}")
for subsection in section.subsections:
print_section(subsection, level + 1)
print("\nExtracted Section Structure:")
for section in sections:
print_section(section)
if __name__ == "__main__":
test_enhanced_extractor()


@@ -0,0 +1,14 @@
from dataclasses import dataclass
@dataclass
class SectionFragment:
"""Arxiv论文片段数据类"""
title: str # 论文标题
authors: str
abstract: str # 论文摘要
catalogs: str # 文章各章节的目录结构
arxiv_id: str = "" # 添加 arxiv_id 属性
current_section: str = "Introduction" # 当前片段所属的section或者subsection或者孙subsubsection名字
content: str = '' # 当前片段的内容
bibliography: str = '' # 当前片段的参考文献


@@ -0,0 +1,266 @@
import logging
import os
import re
from pathlib import Path
from typing import List, Set, Optional
from crazy_functions.rag_fns.arxiv_fns.latex_patterns import LaTeXPatterns
class TexUtils:
"""TeX文档处理器类"""
def __init__(self):
"""Initialize the TeX processor."""
self.logger = logging.getLogger(__name__)
# Initialize the LaTeX environment and command patterns
self._init_patterns()
self.latex_only_patterns = LaTeXPatterns.latex_only_patterns
def _init_patterns(self):
"""初始化LaTeX模式匹配规则"""
# 特殊环境模式
self.special_envs = LaTeXPatterns.special_envs
# 章节模式
self.section_patterns = LaTeXPatterns.section_patterns
# 包含模式
self.include_patterns = LaTeXPatterns.include_patterns
# 元数据模式
self.metadata_patterns = LaTeXPatterns.metadata_patterns
def read_file(self, file_path: str) -> Optional[str]:
"""
读取TeX文件内容支持多种编码
Args:
file_path: 文件路径
Returns:
Optional[str]: 文件内容或None
"""
encodings = ['utf-8', 'latin1', 'gbk', 'gb2312', 'ascii']
for encoding in encodings:
try:
with open(file_path, 'r', encoding=encoding) as f:
return f.read()
except UnicodeDecodeError:
continue
self.logger.warning(f"Failed to read {file_path} with all encodings")
return None
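# Encoding-fallback sketch (illustrative, not part of the original file): a
# GBK-encoded file fails the utf-8 attempt and is decoded by a later entry.
def _demo_read_file_encodings() -> None:
    import os
    import tempfile
    fd, path = tempfile.mkstemp(suffix=".tex")
    with os.fdopen(fd, "wb") as f:
        f.write("中文内容".encode("gbk"))
    assert TexUtils().read_file(path) == "中文内容"
    os.remove(path)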
def find_main_tex_file(self, directory: str) -> Optional[str]:
"""
查找主TeX文件
Args:
directory: 目录路径
Returns:
Optional[str]: 主文件路径或None
"""
tex_files = list(Path(directory).rglob("*.tex"))
if not tex_files:
return None
# Search in priority order
for tex_file in tex_files:
content = self.read_file(str(tex_file))
if content:
if r'\documentclass' in content:
return str(tex_file)
if tex_file.name.lower() == 'main.tex':
return str(tex_file)
# Fall back to the largest .tex file
return str(max(tex_files, key=lambda x: x.stat().st_size))
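# Usage sketch (illustrative, with a temporary directory standing in for a real
# arXiv source tree): a file containing \documentclass wins the search.
def _demo_find_main_tex_file() -> None:
    import tempfile
    with tempfile.TemporaryDirectory() as d:
        main = Path(d) / "paper.tex"
        main.write_text(r"\documentclass{article} \begin{document}x\end{document}", encoding="utf-8")
        (Path(d) / "appendix.tex").write_text("appendix body", encoding="utf-8")
        assert TexUtils().find_main_tex_file(d) == str(main)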
def resolve_includes(self, tex_file: str, processed: Set[str] = None) -> List[str]:
"""
解析TeX文件中的include引用
Args:
tex_file: TeX文件路径
processed: 已处理的文件集合
Returns:
List[str]: 相关文件路径列表
"""
if processed is None:
processed = set()
if tex_file in processed:
return []
processed.add(tex_file)
result = [tex_file]
content = self.read_file(tex_file)
if not content:
return result
base_dir = Path(tex_file).parent
for pattern in self.include_patterns:
for match in re.finditer(pattern, content):
included_file = match.group(2)
if not included_file.endswith('.tex'):
included_file += '.tex'
full_path = str(base_dir / included_file)
if os.path.exists(full_path) and full_path not in processed:
result.extend(self.resolve_includes(full_path, processed))
return result
def resolve_references(self, tex_file: str, path_dir: str = None) -> str:
"""
解析TeX文件中的参考文献引用返回所有引用文献的内容只保留title、author和journal字段。
如果在tex_file目录下没找到bib文件会在path_dir中查找。
Args:
tex_file: TeX文件路径
path_dir: 额外的参考文献搜索路径
Returns:
str: 所有参考文献内容的字符串,只包含特定字段,不同参考文献之间用空行分隔
"""
all_references = []  # collected reference entries
content = self.read_file(tex_file)
if not content:
return ""
# Patterns covering the common bibliography commands
bib_patterns = [
r'\\bibliography\{([^}]+)\}',
r'\\addbibresource\{([^}]+)\}',
r'\\bibliographyfile\{([^}]+)\}',
r'\\begin\{thebibliography\}',
r'\\bibinput\{([^}]+)\}',
r'\\newrefsection\{([^}]+)\}'
]
base_dir = Path(tex_file).parent
found_in_tex_dir = False
# First look for explicitly referenced .bib files in the TeX file's directory
for pattern in bib_patterns:
for match in re.finditer(pattern, content):
if not match.groups():
continue
bib_files = match.group(1).split(',')
for bib_file in bib_files:
bib_file = bib_file.strip()
if not bib_file.endswith('.bib'):
bib_file += '.bib'
full_path = str(base_dir / bib_file)
if os.path.exists(full_path):
found_in_tex_dir = True
bib_content = self.read_file(full_path)
if bib_content:
processed_refs = self._process_bib_content(bib_content)
all_references.extend(processed_refs)
# If no .bib file was found in the TeX file's directory and an extra search path was given
if not found_in_tex_dir and path_dir:
search_dir = Path(path_dir)
try:
for bib_path in search_dir.glob('**/*.bib'):
bib_content = self.read_file(str(bib_path))
if bib_content:
processed_refs = self._process_bib_content(bib_content)
all_references.extend(processed_refs)
except Exception as e:
print(f"Error searching in path_dir: {e}")
# Join all reference entries, separated by blank lines
return "\n\n".join(all_references)
def _process_bib_content(self, content: str) -> List[str]:
"""
处理bib文件内容提取每个参考文献的特定字段
Args:
content: bib文件内容
Returns:
List[str]: 处理后的参考文献列表
"""
processed_refs = []
# Match a complete bibliography entry
ref_pattern = r'@\w+\{[^@]*\}'
# Match the entry type and citation key
entry_start_pattern = r'@(\w+)\{([^,]*?),'
# Match a single field
field_pattern = r'(\w+)\s*=\s*\{([^}]*)\}'
# Iterate over all bibliography entries
for ref_match in re.finditer(ref_pattern, content, re.DOTALL):
ref_content = ref_match.group(0)
# Get the entry type and citation key
entry_match = re.match(entry_start_pattern, ref_content)
if not entry_match:
continue
entry_type, cite_key = entry_match.groups()
# Extract the required fields
needed_fields = {'title': None, 'author': None, 'journal': None}
for field_match in re.finditer(field_pattern, ref_content):
field_name, field_value = field_match.groups()
field_name = field_name.lower()
if field_name in needed_fields:
needed_fields[field_name] = field_value.strip()
# Build the reduced entry
if any(needed_fields.values()):  # at least one required field is present
ref_lines = [f"@{entry_type}{{{cite_key},"]
for field_name, field_value in needed_fields.items():
if field_value:
ref_lines.append(f" {field_name}={{{field_value}}},")
ref_lines[-1] = ref_lines[-1][:-1]  # drop the trailing comma
ref_lines.append("}")
processed_refs.append("\n".join(ref_lines))
return processed_refs
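# Field-filtering sketch (illustrative, not part of the original file): a full
# BibTeX entry is reduced to its title/author/journal fields only.
def _demo_process_bib() -> None:
    bib = "@article{smith2020,\n title={A Study},\n author={J. Smith},\n journal={Nature},\n year={2020}\n}"
    refs = TexUtils()._process_bib_content(bib)
    assert refs[0] == "@article{smith2020,\n  title={A Study},\n  author={J. Smith},\n  journal={Nature}\n}"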
def _extract_inline_references(self, content: str) -> str:
"""
从tex文件内容中提取直接写在文件中的参考文献
Args:
content: tex文件内容
Returns:
str: 提取的参考文献内容,如果没有找到则返回空字符串
"""
# Look for the thebibliography environment
bib_start = r'\\begin\{thebibliography\}'
bib_end = r'\\end\{thebibliography\}'
start_match = re.search(bib_start, content)
end_match = re.search(bib_end, content)
if start_match and end_match:
return content[start_match.start():end_match.end()]
return ""
def _preprocess_content(self, content: str) -> str:
"""预处理TeX内容"""
# 移除注释
content = re.sub(r'(?m)%.*$', '', content)
# 规范化空白字符
# content = re.sub(r'\s+', ' ', content)
content = re.sub(r'\n\s*\n', '\n\n', content)
return content.strip()


@@ -1,10 +1,10 @@
import atexit
from loguru import logger
from typing import List
import os
from llama_index.core import Document
from llama_index.core.ingestion import run_transformations
from llama_index.core.schema import TextNode
from loguru import logger
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
@@ -59,7 +59,7 @@ class SaveLoad():
def purge(self):
import shutil
shutil.rmtree(self.checkpoint_dir, ignore_errors=True)
self.vs_index = self.create_new_vs(self.checkpoint_dir)
self.vs_index = self.create_new_vs()
class LlamaIndexRagWorker(SaveLoad):
@@ -68,11 +68,60 @@ class LlamaIndexRagWorker(SaveLoad):
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
self.user_name = user_name
self.checkpoint_dir = checkpoint_dir
if auto_load_checkpoint:
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
# Make sure checkpoint_dir exists
if checkpoint_dir:
os.makedirs(checkpoint_dir, exist_ok=True)
logger.info(f"Initializing LlamaIndexRagWorker with checkpoint_dir: {checkpoint_dir}")
# Initialize the vector store
if auto_load_checkpoint and self.does_checkpoint_exist():
logger.info("Loading existing vector store from checkpoint")
self.vs_index = self.load_from_checkpoint()
else:
logger.info("Creating new vector store")
self.vs_index = self.create_new_vs()
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
# Register a save on exit
atexit.register(self.save_to_checkpoint)
def add_text_to_vector_store(self, text: str) -> None:
"""添加文本到向量存储"""
try:
logger.info(f"Adding text to vector store (first 100 chars): {text[:100]}...")
node = TextNode(text=text)
nodes = run_transformations(
[node],
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(nodes)
# Persist immediately
self.save_to_checkpoint()
if self.debug_mode:
self.inspect_vector_store()
except Exception as e:
logger.error(f"Error adding text to vector store: {str(e)}")
raise
def save_to_checkpoint(self, checkpoint_dir=None):
"""保存向量存储到检查点"""
try:
if checkpoint_dir is None:
checkpoint_dir = self.checkpoint_dir
logger.info(f'Saving vector store to: {checkpoint_dir}')
if checkpoint_dir:
self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
logger.info('Vector store saved successfully')
else:
logger.warning('No checkpoint directory specified, skipping save')
except Exception as e:
logger.error(f"Error saving checkpoint: {str(e)}")
raise
def assign_embedding_model(self):
pass
@@ -81,44 +130,28 @@ class LlamaIndexRagWorker(SaveLoad):
# This function is for debugging
self.vs_index.storage_context.index_store.to_dict()
docstore = self.vs_index.storage_context.docstore.docs
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
logger.info('\n++ --------inspect_vector_store begin--------')
logger.info(vector_store_preview)
logger.info('oo --------inspect_vector_store end--------')
return vector_store_preview
def add_documents_to_vector_store(self, document_list: List[Document]):
"""
Adds a list of Document objects to the vector store after processing.
"""
documents = document_list
def add_documents_to_vector_store(self, document_list):
documents = [Document(text=t) for t in document_list]
documents_nodes = run_transformations(
documents, # type: ignore
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(documents_nodes)
if self.debug_mode:
self.inspect_vector_store()
def add_text_to_vector_store(self, text: str):
node = TextNode(text=text)
documents_nodes = run_transformations(
[node],
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(documents_nodes)
if self.debug_mode:
self.inspect_vector_store()
if self.debug_mode: self.inspect_vector_store()
def remember_qa(self, question, answer):
formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
self.add_text_to_vector_store(formatted_str)
def retrieve_from_store_with_query(self, query):
if self.debug_mode:
self.inspect_vector_store()
if self.debug_mode: self.inspect_vector_store()
retriever = self.vs_index.as_retriever()
return retriever.retrieve(query)
@@ -127,12 +160,6 @@ class LlamaIndexRagWorker(SaveLoad):
return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
def generate_node_array_preview(self, nodes):
buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
buf = "\n".join(([f"(No.{i + 1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
if self.debug_mode: logger.info(buf)
return buf
def purge_vector_store(self):
"""
Purges the current vector store and creates a new one.
"""
self.purge()


@@ -1,20 +1,14 @@
import llama_index
import os
import atexit
import os
from typing import List
from loguru import logger
from llama_index.core import Document
from llama_index.core.schema import TextNode
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
from shared_utils.connect_void_terminal import get_chat_default_kwargs
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core import StorageContext
from llama_index.vector_stores.milvus import MilvusVectorStore
from loguru import logger
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
@@ -65,17 +59,19 @@ class MilvusSaveLoad():
def create_new_vs(self, checkpoint_dir, overwrite=False):
vector_store = MilvusVectorStore(
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
dim=self.embed_model.embedding_dimension(),
overwrite=overwrite
)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context,
embed_model=self.embed_model)
return index
def purge(self):
self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
@@ -96,7 +92,7 @@ class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
docstore = self.vs_index.storage_context.docstore.docs
if not docstore.items():
raise ValueError("cannot inspect")
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
except:
dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
vector_store_preview = "\n".join(


@@ -1,22 +1,47 @@
import os
from llama_index.core import SimpleDirectoryReader
supports_format = ['.csv', '.docx', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
'.pptm', '.pptx']
supports_format = ['.csv', '.docx', '.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
'.pptm', '.pptx', '.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml', '.m']
def read_docx_doc(file_path):
if file_path.split(".")[-1] == "docx":
from docx import Document
doc = Document(file_path)
file_content = "\n".join([para.text for para in doc.paragraphs])
else:
try:
import win32com.client
word = win32com.client.Dispatch("Word.Application")
word.visible = False
# Open the file
doc = word.Documents.Open(os.getcwd() + '/' + file_path)
# file_content = doc.Content.Text
doc = word.ActiveDocument
file_content = doc.Range().Text
doc.Close()
word.Quit()
except Exception:
raise RuntimeError('Please convert the .doc file to .docx format first.')
return file_content
# Revised extract_text: combines SimpleDirectoryReader with custom parsing logic
def extract_text(file_path):
_, ext = os.path.splitext(file_path.lower())
# Use SimpleDirectoryReader for the formats it supports
if ext in supports_format:
try:
reader = SimpleDirectoryReader(input_files=[file_path])
documents = reader.load_data()
if len(documents) > 0:
return documents[0].text
except Exception as e:
pass
if ext in ['.docx', '.doc']:
return read_docx_doc(file_path)
try:
reader = SimpleDirectoryReader(input_files=[file_path])
documents = reader.load_data()
if len(documents) > 0:
return documents[0].text
except Exception as e:
pass
return None
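# Usage sketch (illustrative, not part of the original file): extract_text
# returns plain text for a supported file, or None when every parser fails.
def _demo_extract_text() -> None:
    import os
    import tempfile
    fd, path = tempfile.mkstemp(suffix=".txt")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write("hello rag")
    assert extract_text(path) == "hello rag"
    os.remove(path)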


@@ -1,6 +1,6 @@
from llama_index.core import VectorStoreIndex
from typing import Any, List, Optional
from typing import Any, List, Optional
from llama_index.core import VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.schema import TransformComponent
from llama_index.core.service_context import ServiceContext
@@ -13,18 +13,18 @@ from llama_index.core.storage.storage_context import StorageContext
class GptacVectorStoreIndex(VectorStoreIndex):
@classmethod
def default_vector_store(
cls,
storage_context: Optional[StorageContext] = None,
show_progress: bool = False,
callback_manager: Optional[CallbackManager] = None,
transformations: Optional[List[TransformComponent]] = None,
# deprecated
service_context: Optional[ServiceContext] = None,
embed_model = None,
**kwargs: Any,
cls,
storage_context: Optional[StorageContext] = None,
show_progress: bool = False,
callback_manager: Optional[CallbackManager] = None,
transformations: Optional[List[TransformComponent]] = None,
# deprecated
service_context: Optional[ServiceContext] = None,
embed_model=None,
**kwargs: Any,
):
"""Create index from documents.
@@ -36,15 +36,14 @@ class GptacVectorStoreIndex(VectorStoreIndex):
storage_context = storage_context or StorageContext.from_defaults()
docstore = storage_context.docstore
callback_manager = (
callback_manager
or callback_manager_from_settings_or_context(Settings, service_context)
callback_manager
or callback_manager_from_settings_or_context(Settings, service_context)
)
transformations = transformations or transformations_from_settings_or_context(
Settings, service_context
)
with callback_manager.as_trace("index_construction"):
return cls(
nodes=[],
storage_context=storage_context,
@@ -55,4 +54,3 @@ class GptacVectorStoreIndex(VectorStoreIndex):
embed_model=embed_model,
**kwargs,
)

File diff suppressed because it is too large


@@ -1,129 +0,0 @@
from toolbox import update_ui
from toolbox import CatchException, report_exception
from toolbox import write_history_to_file, promote_file_to_downloadzone
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
fast_debug = False
def 解析docx(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
import time, os
# pip install python-docx 用于docx格式跨平台
# pip install pywin32 用于doc格式仅支持Win平台
for index, fp in enumerate(file_manifest):
if fp.split(".")[-1] == "docx":
from docx import Document
doc = Document(fp)
file_content = "\n".join([para.text for para in doc.paragraphs])
else:
try:
import win32com.client
word = win32com.client.Dispatch("Word.Application")
word.visible = False
# 打开文件
doc = word.Documents.Open(os.getcwd() + '/' + fp)
# file_content = doc.Content.Text
doc = word.ActiveDocument
file_content = doc.Range().Text
doc.Close()
word.Quit()
except:
raise RuntimeError('请先将.doc文档转换为.docx文档。')
# private_upload里面的文件名在解压zip后容易出现乱码rar和7z格式正常故可以只分析文章内容不输入文件名
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
from request_llms.bridge_all import model_info
max_token = model_info[llm_kwargs['llm_model']]['max_token']
TOKEN_LIMIT_PER_FRAGMENT = max_token * 3 // 4
paper_fragments = breakdown_text_to_satisfy_token_limit(txt=file_content, limit=TOKEN_LIMIT_PER_FRAGMENT, llm_model=llm_kwargs['llm_model'])
this_paper_history = []
for i, paper_frag in enumerate(paper_fragments):
i_say = f'请对下面的文章片段用中文做概述,文件名是{os.path.relpath(fp, project_folder)},文章内容是 ```{paper_frag}```'
i_say_show_user = f'请对下面的文章片段做概述: {os.path.abspath(fp)}的第{i+1}/{len(paper_fragments)}个片段。'
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=i_say,
inputs_show_user=i_say_show_user,
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history=[],
sys_prompt="总结文章。"
)
chatbot[-1] = (i_say_show_user, gpt_say)
history.extend([i_say_show_user,gpt_say])
this_paper_history.extend([i_say_show_user,gpt_say])
# 已经对该文章的所有片段总结完毕,如果文章被切分了,
if len(paper_fragments) > 1:
i_say = f"根据以上的对话,总结文章{os.path.abspath(fp)}的主要内容。"
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=i_say,
inputs_show_user=i_say,
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history=this_paper_history,
sys_prompt="总结文章。"
)
history.extend([i_say,gpt_say])
this_paper_history.extend([i_say,gpt_say])
res = write_history_to_file(history)
promote_file_to_downloadzone(res, chatbot=chatbot)
chatbot.append(("完成了吗?", res))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
res = write_history_to_file(history)
promote_file_to_downloadzone(res, chatbot=chatbot)
chatbot.append(("所有文件都总结完成了吗?", res))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
@CatchException
def 总结word文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
import glob, os
# 基本信息:功能、贡献者
chatbot.append([
"函数插件功能?",
"批量总结Word文档。函数插件贡献者: JasonGuo1。注意, 如果是.doc文件, 请先转化为.docx格式。"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# 尝试导入依赖,如果缺少依赖,则给出安装建议
try:
from docx import Document
except:
report_exception(chatbot, history,
a=f"解析项目: {txt}",
b=f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade python-docx pywin32```。")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# 清空历史,以免输入溢出
history = []
# 检测输入参数,如没有给定输入参数,直接退出
if os.path.exists(txt):
project_folder = txt
from shared_utils.fastapi_server import validate_path_safety
validate_path_safety(project_folder, chatbot.get_user())
else:
if txt == "": txt = '空空如也的输入栏'
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# 搜索需要处理的文件清单
if txt.endswith('.docx') or txt.endswith('.doc'):
file_manifest = [txt]
else:
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.docx', recursive=True)] + \
[f for f in glob.glob(f'{project_folder}/**/*.doc', recursive=True)]
# 如果没找到任何文件
if len(file_manifest) == 0:
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何.docx或doc文件: {txt}")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# 开始正式执行任务
yield from 解析docx(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)


@@ -0,0 +1,496 @@
import os
import threading
import time
from dataclasses import dataclass
from typing import List, Tuple, Dict, Generator
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
from crazy_functions.rag_fns.rag_file_support import extract_text
from request_llms.bridge_all import model_info
from toolbox import update_ui, CatchException, report_exception
@dataclass
class FileFragment:
"""文件片段数据类,用于组织处理单元"""
file_path: str
content: str
rel_path: str
fragment_index: int
total_fragments: int
class BatchDocumentSummarizer:
"""优化的文档总结器 - 批处理版本"""
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
"""初始化总结器"""
self.llm_kwargs = llm_kwargs
self.plugin_kwargs = plugin_kwargs
self.chatbot = chatbot
self.history = history
self.system_prompt = system_prompt
self.failed_files = []
self.file_summaries_map = {}
def _get_token_limit(self) -> int:
"""获取模型token限制"""
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
return max_token * 3 // 4
def _create_batch_inputs(self, fragments: List[FileFragment]) -> Tuple[List, List, List]:
"""创建批处理输入"""
inputs_array = []
inputs_show_user_array = []
history_array = []
for frag in fragments:
if self.plugin_kwargs.get("advanced_arg"):
i_say = (f'请按照用户要求对文件内容进行处理,文件名为{os.path.basename(frag.file_path)}'
f'用户要求为:{self.plugin_kwargs["advanced_arg"]}'
f'文件内容是 ```{frag.content}```')
i_say_show_user = (f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})')
else:
i_say = (f'请对下面的内容用中文做总结不超过500字文件名是{os.path.basename(frag.file_path)}'
f'内容是 ```{frag.content}```')
i_say_show_user = f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})'
inputs_array.append(i_say)
inputs_show_user_array.append(i_say_show_user)
history_array.append([])
return inputs_array, inputs_show_user_array, history_array
def _process_single_file_with_timeout(self, file_info: Tuple[str, str], mutable_status: List) -> List[FileFragment]:
"""包装了超时控制的文件处理函数"""
def timeout_handler():
thread = threading.current_thread()
if hasattr(thread, '_timeout_occurred'):
thread._timeout_occurred = True
# 设置超时标记
thread = threading.current_thread()
thread._timeout_occurred = False
# Start the timeout timer
timer = threading.Timer(self.watch_dog_patience, timeout_handler)
timer.start()
try:
fp, project_folder = file_info
fragments = []
# Periodically check for timeout
def check_timeout():
if hasattr(thread, '_timeout_occurred') and thread._timeout_occurred:
raise TimeoutError("处理超时")
# Update status
mutable_status[0] = "检查文件大小"
mutable_status[1] = time.time()
check_timeout()
# File size check
if os.path.getsize(fp) > self.max_file_size:
self.failed_files.append((fp, f"文件过大:超过{self.max_file_size / 1024 / 1024}MB"))
mutable_status[2] = "文件过大"
return fragments
check_timeout()
# Update status
mutable_status[0] = "提取文件内容"
mutable_status[1] = time.time()
# Extract the content
content = extract_text(fp)
if content is None:
self.failed_files.append((fp, "文件解析失败:不支持的格式或文件损坏"))
mutable_status[2] = "格式不支持"
return fragments
elif not content.strip():
self.failed_files.append((fp, "文件内容为空"))
mutable_status[2] = "内容为空"
return fragments
check_timeout()
# Update status
mutable_status[0] = "分割文本"
mutable_status[1] = time.time()
# Split the text
try:
paper_fragments = breakdown_text_to_satisfy_token_limit(
txt=content,
limit=self._get_token_limit(),
llm_model=self.llm_kwargs['llm_model']
)
except Exception as e:
self.failed_files.append((fp, f"文本分割失败:{str(e)}"))
mutable_status[2] = "分割失败"
return fragments
check_timeout()
# Build fragments from the split text
rel_path = os.path.relpath(fp, project_folder)
for i, frag in enumerate(paper_fragments):
if frag.strip():
fragments.append(FileFragment(
file_path=fp,
content=frag,
rel_path=rel_path,
fragment_index=i,
total_fragments=len(paper_fragments)
))
mutable_status[2] = "处理完成"
return fragments
except TimeoutError as e:
self.failed_files.append((fp, "处理超时"))
mutable_status[2] = "处理超时"
return []
except Exception as e:
self.failed_files.append((fp, f"处理失败:{str(e)}"))
mutable_status[2] = "处理异常"
return []
finally:
timer.cancel()
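# Standalone sketch of the watchdog pattern used above (illustrative, not part
# of the original file): a Timer flags the worker thread, and the worker polls
# the flag to abort cooperatively.
def _demo_watchdog(patience_s: float = 0.05) -> str:
    import threading
    import time
    worker = threading.current_thread()
    worker._timeout_occurred = False
    timer = threading.Timer(patience_s, lambda: setattr(worker, "_timeout_occurred", True))
    timer.start()
    try:
        while True:  # stands in for the real per-file work
            time.sleep(0.01)
            if worker._timeout_occurred:
                return "timed out"
    finally:
        timer.cancel()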
def prepare_fragments(self, project_folder: str, file_paths: List[str]) -> Generator:
"""Prepare processing fragments for all files in parallel"""
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
all_fragments = []
total_files = len(file_paths)
# Configuration
self.refresh_interval = 0.2  # UI refresh interval in seconds
self.watch_dog_patience = 5  # watchdog timeout in seconds
self.max_file_size = 10 * 1024 * 1024  # 10 MB limit
self.max_workers = min(32, max(1, len(file_paths)))  # at most 32 threads
# Thread pool; each task carries its own timeout control
executor = ThreadPoolExecutor(max_workers=self.max_workers)
# Mutable lists for cross-thread status reporting: [state, timestamp, description, file name]
mutable_status_array = [["等待中", time.time(), "pending", file_path] for file_path in file_paths]
# Build the file-processing task inputs
file_infos = [(fp, project_folder) for fp in file_paths]
# Submit every task through the timeout-aware processing function
futures = [
executor.submit(
self._process_single_file_with_timeout,
file_info,
mutable_status_array[i]
) for i, file_info in enumerate(file_infos)
]
# UI update counter
cnt = 0
try:
# Monitor task execution
while True:
time.sleep(self.refresh_interval)
cnt += 1
# Check which tasks have finished
worker_done = [f.done() for f in futures]
# Refresh the status display
status_str = ""
for i, (status, timestamp, desc, file_path) in enumerate(mutable_status_array):
# File name without its path
file_name = os.path.basename(file_path)
if worker_done[i]:
status_str += f"文件 {file_name}: {desc}\n"
else:
status_str += f"文件 {file_name}: {status} {desc}\n"
# Refresh the UI
self.chatbot[-1] = [
"处理进度",
f"正在处理文件...\n\n{status_str}" + "." * (cnt % 10 + 1)
]
yield from update_ui(chatbot=self.chatbot, history=self.history)
# Stop once every task has finished
if all(worker_done):
break
finally:
# Make sure the thread pool is shut down
executor.shutdown(wait=False)
# Collect the results
processed_files = 0
for future in futures:
try:
fragments = future.result(timeout=0.1)  # brief timeout while collecting the result
all_fragments.extend(fragments)
processed_files += 1
except concurrent.futures.TimeoutError:
# Result collection timed out
file_index = futures.index(future)
self.failed_files.append((file_paths[file_index], "结果获取超时"))
continue
except Exception as e:
# Any other failure
file_index = futures.index(future)
self.failed_files.append((file_paths[file_index], f"未知错误:{str(e)}"))
continue
# Final progress update
self.chatbot.append([
"文件处理完成",
f"成功处理 {len(all_fragments)} 个片段,失败 {len(self.failed_files)} 个文件"
])
yield from update_ui(chatbot=self.chatbot, history=self.history)
return all_fragments
def _process_fragments_batch(self, fragments: List[FileFragment]) -> Generator:
"""批量处理文件片段"""
from collections import defaultdict
batch_size = 64 # 每批处理的片段数
max_retries = 3 # 最大重试次数
retry_delay = 5 # 重试延迟(秒)
results = defaultdict(list)
# Process batch by batch
for i in range(0, len(fragments), batch_size):
batch = fragments[i:i + batch_size]
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch)
sys_prompt_array = ["请总结以下内容:"] * len(batch)
# Retry mechanism
for retry in range(max_retries):
try:
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=inputs_array,
inputs_show_user_array=inputs_show_user_array,
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=history_array,
sys_prompt_array=sys_prompt_array,
)
# Handle the responses
for j, frag in enumerate(batch):
summary = response_collection[j * 2 + 1]
if summary and summary.strip():
results[frag.rel_path].append({
'index': frag.fragment_index,
'summary': summary,
'total': frag.total_fragments
})
break  # success; leave the retry loop
except Exception as e:
if retry == max_retries - 1:  # the final retry also failed
for frag in batch:
self.failed_files.append((frag.file_path, f"处理失败:{str(e)}"))
else:
self.chatbot.append([f"批次处理失败,{retry_delay}秒后重试...", str(e)])
yield from update_ui(chatbot=self.chatbot, history=self.history)
time.sleep(retry_delay)
return results
def _generate_final_summary_request(self) -> Tuple[List, List, List]:
"""准备最终总结请求"""
if not self.file_summaries_map:
return (["无可用的文件总结"], ["生成最终总结"], [[]])
summaries = list(self.file_summaries_map.values())
if all(not summary for summary in summaries):
return (["所有文件处理均失败"], ["生成最终总结"], [[]])
if self.plugin_kwargs.get("advanced_arg"):
i_say = "根据以上所有文件的处理结果,按要求进行综合处理:" + self.plugin_kwargs['advanced_arg']
else:
i_say = "请根据以上所有文件的处理结果生成最终的总结不超过1000字。"
return ([i_say], [i_say], [summaries])
def process_files(self, project_folder: str, file_paths: List[str]) -> Generator:
"""处理所有文件"""
total_files = len(file_paths)
self.chatbot.append([f"开始处理", f"总计 {total_files} 个文件"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# 1. 准备所有文件片段
# 在 process_files 函数中:
fragments = yield from self.prepare_fragments(project_folder, file_paths)
if not fragments:
self.chatbot.append(["处理失败", "没有可处理的文件内容"])
return "没有可处理的文件内容"
# 2. Process all fragments in batches
self.chatbot.append([f"文件分析", f"共计 {len(fragments)} 个处理单元"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
try:
file_summaries = yield from self._process_fragments_batch(fragments)
except Exception as e:
self.chatbot.append(["处理错误", f"批处理过程失败:{str(e)}"])
return "处理过程发生错误"
# 3. Generate an overall summary for each file
self.chatbot.append(["生成总结", "正在汇总文件内容..."])
yield from update_ui(chatbot=self.chatbot, history=self.history)
# Summarize each file
for rel_path, summaries in file_summaries.items():
if len(summaries) > 1:  # multi-fragment files need an overall summary
sorted_summaries = sorted(summaries, key=lambda x: x['index'])
if self.plugin_kwargs.get("advanced_arg"):
i_say = f'请按照用户要求对文件内容进行处理,用户要求为:{self.plugin_kwargs["advanced_arg"]}'
else:
i_say = f"请总结文件 {os.path.basename(rel_path)} 的主要内容不超过500字。"
try:
summary_texts = [s['summary'] for s in sorted_summaries]
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=[i_say],
inputs_show_user_array=[f"生成 {rel_path} 的处理结果"],
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=[summary_texts],
sys_prompt_array=["你是一个优秀的助手,"],
)
self.file_summaries_map[rel_path] = response_collection[1]
except Exception as e:
self.chatbot.append(["警告", f"文件 {rel_path} 总结生成失败:{str(e)}"])
self.file_summaries_map[rel_path] = "总结生成失败"
else:  # single-fragment files use their only summary directly
self.file_summaries_map[rel_path] = summaries[0]['summary']
# 4. Generate the final summary
if total_files == 1:
return "文件数为1此时不调用总结模块"
else:
try:
# Collect all file summaries for the final report
file_summaries_for_final = []
for rel_path, summary in self.file_summaries_map.items():
file_summaries_for_final.append(f"文件 {rel_path} 的总结:\n{summary}")
if self.plugin_kwargs.get("advanced_arg"):
final_summary_prompt = ("根据以下所有文件的总结内容,按要求进行综合处理:" +
self.plugin_kwargs['advanced_arg'])
else:
final_summary_prompt = "请根据以下所有文件的总结内容,生成最终的总结报告。"
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=[final_summary_prompt],
inputs_show_user_array=["生成最终总结报告"],
llm_kwargs=self.llm_kwargs,
chatbot=self.chatbot,
history_array=[file_summaries_for_final],
sys_prompt_array=["总结所有文件内容。"],
max_workers=1
)
return response_collection[1] if len(response_collection) > 1 else "生成总结失败"
except Exception as e:
self.chatbot.append(["错误", f"最终总结生成失败:{str(e)}"])
return "生成总结失败"
def save_results(self, final_summary: str):
"""保存结果到文件"""
from toolbox import promote_file_to_downloadzone, write_history_to_file
from crazy_functions.doc_fns.batch_file_query_doc import MarkdownFormatter, HtmlFormatter, WordFormatter
import os
timestamp = time.strftime("%Y%m%d_%H%M%S")
# 创建各种格式化器
md_formatter = MarkdownFormatter(final_summary, self.file_summaries_map, self.failed_files)
html_formatter = HtmlFormatter(final_summary, self.file_summaries_map, self.failed_files)
word_formatter = WordFormatter(final_summary, self.file_summaries_map, self.failed_files)
result_files = []
# Save Markdown
md_content = md_formatter.create_document()
result_file_md = write_history_to_file(
history=[md_content],  # pass the content list directly
file_basename=f"文档总结_{timestamp}.md"
)
result_files.append(result_file_md)
# Save HTML
html_content = html_formatter.create_document()
result_file_html = write_history_to_file(
history=[html_content],
file_basename=f"文档总结_{timestamp}.html"
)
result_files.append(result_file_html)
# Save Word
doc = word_formatter.create_document()
# The Word document needs doc.save(), so reuse the directory of the Markdown file
result_file_docx = os.path.join(
os.path.dirname(result_file_md),
f"文档总结_{timestamp}.docx"
)
doc.save(result_file_docx)
result_files.append(result_file_docx)
# Add to the download zone
for file in result_files:
promote_file_to_downloadzone(file, chatbot=self.chatbot)
self.chatbot.append(["处理完成", f"结果已保存至: {', '.join(result_files)}"])
@CatchException
def 批量文件询问(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
"""主函数 - 优化版本"""
# 初始化
import glob
import re
from crazy_functions.rag_fns.rag_file_support import supports_format
from toolbox import report_exception
summarizer = BatchDocumentSummarizer(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
chatbot.append(["函数插件功能", f"作者lbykkkk批量总结文件。支持格式: {', '.join(supports_format)}等其他文本格式文件如果长时间卡在文件处理过程请查看处理进度然后删除所有处于“pending”状态的文件然后重新上传处理。"])
yield from update_ui(chatbot=chatbot, history=history)
# Validate the input path
if not os.path.exists(txt):
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到项目或无权访问: {txt}")
yield from update_ui(chatbot=chatbot, history=history)
return
# Collect the file list
project_folder = txt
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
file_manifest = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
if not file_manifest:
report_exception(chatbot, history, a=f"解析项目: {txt}", b="未找到支持的文件类型")
yield from update_ui(chatbot=chatbot, history=history)
return
# Process all files and generate the summary
final_summary = yield from summarizer.process_files(project_folder, file_manifest)
yield from update_ui(chatbot=chatbot, history=history)
# Save the results
summarizer.save_results(final_summary)
yield from update_ui(chatbot=chatbot, history=history)


@@ -61,7 +61,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
history = []
from crazy_functions.crazy_utils import get_files_from_everything
success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf', chatbot=chatbot)
success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf')
if len(file_manifest) > 0:
# 尝试导入依赖,如果缺少依赖,则给出安装建议
try:
@@ -73,7 +73,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
b=f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade nougat-ocr tiktoken```。")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
success_mmd, file_manifest_mmd, _ = get_files_from_everything(txt, type='.mmd', chatbot=chatbot)
success_mmd, file_manifest_mmd, _ = get_files_from_everything(txt, type='.mmd')
success = success or success_mmd
file_manifest += file_manifest_mmd
chatbot.append(["文件列表:", ", ".join([e.split('/')[-1] for e in file_manifest])]);


@@ -87,8 +87,6 @@ def 理解PDF文档内容标准文件输入(txt, llm_kwargs, plugin_kwargs, chat
# 检测输入参数,如没有给定输入参数,直接退出
if os.path.exists(txt):
project_folder = txt
from shared_utils.fastapi_server import validate_path_safety
validate_path_safety(project_folder, chatbot.get_user())
else:
if txt == "":
txt = '空空如也的输入栏'


@@ -39,8 +39,6 @@ def 批量生成函数注释(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
import glob, os
if os.path.exists(txt):
project_folder = txt
from shared_utils.fastapi_server import validate_path_safety
validate_path_safety(project_folder, chatbot.get_user())
else:
if txt == "": txt = '空空如也的输入栏'
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")


@@ -49,7 +49,7 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
file_manifest = []
spl = ["txt", "doc", "docx", "email", "epub", "html", "json", "md", "msg", "pdf", "ppt", "pptx", "rtf"]
for sp in spl:
_, file_manifest_tmp, _ = get_files_from_everything(txt, type=f'.{sp}', chatbot=chatbot)
_, file_manifest_tmp, _ = get_files_from_everything(txt, type=f'.{sp}')
file_manifest += file_manifest_tmp
if len(file_manifest) == 0:


@@ -126,8 +126,6 @@ def 解析ipynb文件(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_p
import os
if os.path.exists(txt):
project_folder = txt
from shared_utils.fastapi_server import validate_path_safety
validate_path_safety(project_folder, chatbot.get_user())
else:
if txt == "":
txt = '空空如也的输入栏'


@@ -48,8 +48,6 @@ def 读文章写摘要(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_
import glob, os
if os.path.exists(txt):
project_folder = txt
from shared_utils.fastapi_server import validate_path_safety
validate_path_safety(project_folder, chatbot.get_user())
else:
if txt == "": txt = '空空如也的输入栏'
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")


@@ -5,10 +5,6 @@ FROM fuqingxu/11.3.1-runtime-ubuntu20.04-with-texlive:latest
# edge-tts需要的依赖某些pip包所需的依赖
RUN apt update && apt install ffmpeg build-essential -y
RUN apt-get install -y fontconfig
RUN ln -s /usr/local/texlive/2023/texmf-dist/fonts/truetype /usr/share/fonts/truetype/texlive
RUN fc-cache -fv
RUN apt-get clean
# use python3 as the system default python
WORKDIR /gpt
@@ -34,7 +30,7 @@ RUN python3 -m pip install -r request_llms/requirements_qwen.txt
RUN python3 -m pip install -r request_llms/requirements_chatglm.txt
RUN python3 -m pip install -r request_llms/requirements_newbing.txt
RUN python3 -m pip install nougat-ocr
RUN python3 -m pip cache purge
# 预热Tiktoken模块
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'


@@ -7,7 +7,6 @@ RUN apt-get install -y git python python3 python-dev python3-dev --fix-missing
# edge-tts需要的依赖某些pip包所需的依赖
RUN apt update && apt install ffmpeg build-essential -y
RUN apt-get clean
# use python3 as the system default python
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.8
@@ -23,7 +22,6 @@ RUN python3 -m pip install -r request_llms/requirements_moss.txt
RUN python3 -m pip install -r request_llms/requirements_qwen.txt
RUN python3 -m pip install -r request_llms/requirements_chatglm.txt
RUN python3 -m pip install -r request_llms/requirements_newbing.txt
RUN python3 -m pip cache purge
# 预热Tiktoken模块


@@ -18,7 +18,5 @@ RUN apt update && apt install ffmpeg -y
# 可选步骤,用于预热模块
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
RUN python3 -m pip cache purge && apt-get clean
# 启动
CMD ["python3", "-u", "main.py"]


@@ -30,7 +30,5 @@ COPY --chown=gptuser:gptuser . .
# 可选步骤,用于预热模块
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
RUN python3 -m pip cache purge
# 启动
CMD ["python3", "-u", "main.py"]


@@ -24,8 +24,6 @@ RUN apt update && apt install ffmpeg -y
# 可选步骤,用于预热模块
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
RUN python3 -m pip cache purge && apt-get clean
# 启动
CMD ["python3", "-u", "main.py"]


@@ -1,26 +0,0 @@
@echo off
setlocal
:: 设置环境变量
set ENV_NAME=gpt
set ENV_PATH=%~dp0%ENV_NAME%
set SCRIPT_PATH=%~dp0main.py
:: 判断环境是否已解压
if not exist "%ENV_PATH%" (
echo Extracting environment...
mkdir "%ENV_PATH%"
tar -xzf gpt.tar.gz -C "%ENV_PATH%"
:: 运行conda环境激活脚本
call "%ENV_PATH%\Scripts\activate.bat"
) else (
:: 如果环境已存在,直接激活
call "%ENV_PATH%\Scripts\activate.bat"
)
echo Start to run program:
:: 运行Python脚本
python "%SCRIPT_PATH%"
endlocal
pause

instruction.txt

@@ -0,0 +1,216 @@
1. GPT Academic project structure
.
├── Dockerfile
├── LICENSE
├── README.md
├── check_proxy.py
├── config.py
├── config_private.py
├── core_functional.py
├── crazy_functional.py
├── crazy_functions
│ ├── Arxiv_论文对话.py
│ ├── Conversation_To_File.py
│ ├── Image_Generate.py
│ ├── Image_Generate_Wrap.py
│ ├── Internet_GPT.py
│ ├── Internet_GPT_Wrap.py
│ ├── Latex_Function.py
│ ├── Latex_Function_Wrap.py
│ ├── Latex全文润色.py
│ ├── Latex全文翻译.py
│ ├── Markdown_Translate.py
│ ├── PDF_Translate.py
│ ├── PDF_Translate_Wrap.py
│ ├── Rag_Interface.py
│ ├── Social_Helper.py
│ ├── SourceCode_Analyse.py
│ ├── SourceCode_Comment.py
│ ├── SourceCode_Comment_Wrap.py
│ ├── __init__.py
│ │ ├── auto_agent.py
│ │ ├── echo_agent.py
│ │ ├── general.py
│ │ ├── persistent.py
│ │ ├── pipe.py
│ │ ├── python_comment_agent.py
│ │ ├── python_comment_compare.html
│ │ └── watchdog.py
│ ├── ast_fns
│ │ └── comment_remove.py
│ ├── chatglm微调工具.py
│ ├── crazy_utils.py
│ ├── diagram_fns
│ │ └── file_tree.py
│ ├── game_fns
│ │ ├── game_ascii_art.py
│ │ ├── game_interactive_story.py
│ │ └── game_utils.py
│ ├── gen_fns
│ │ └── gen_fns_shared.py
│ ├── ipc_fns
│ │ └── mp.py
│ ├── json_fns
│ │ ├── pydantic_io.py
│ │ └── select_tool.py
│ ├── latex_fns
│ │ ├── latex_actions.py
│ │ ├── latex_pickle_io.py
│ │ └── latex_toolbox.py
│ ├── live_audio
│ │ ├── aliyunASR.py
│ │ └── audio_io.py
│ ├── multi_stage
│ │ └── multi_stage_utils.py
│ ├── rag_essay_fns
│ │ └── multi_stage_utils.py
│ ├── pdf_fns
│ │ ├── breakdown_txt.py
│ │ ├── parse_pdf.py
│ │ ├── parse_pdf_grobid.py
│ │ ├── parse_pdf_legacy.py
│ │ ├── parse_pdf_via_doc2x.py
│ │ ├── parse_word.py
│ │ ├── report_gen_html.py
│ │ ├── report_template.html
│ │ └── report_template_v2.html
│ ├── plugin_template
│ │ └── plugin_class_template.py
│ ├── prompts
│ │ └── internet.py
│ ├── rag_fns
│ │ ├── llama_index_worker.py
│ │ ├── milvus_worker.py
│ │ ├── rag_file_support.py
│ │ └── vector_store_index.py
│ ├── vector_fns
│ │ ├── __init__.py
│ │ ├── general_file_loader.py
│ │ └── vector_database.py
│ ├── vt_fns
│ │ ├── vt_call_plugin.py
│ │ ├── vt_modify_config.py
│ │ └── vt_state.py
│ ├── 下载arxiv论文翻译摘要.py
│ ├── 互动小游戏.py
│ ├── 交互功能函数模板.py
│ ├── 函数动态生成.py
│ ├── 命令行助手.py
│ ├── 多智能体.py
│ ├── 总结word文档.py
│ ├── 总结音视频.py
│ ├── 批量总结PDF文档.py
│ ├── 批量总结PDF文档pdfminer.py
│ ├── 批量文件询问.py
│ ├── 批量翻译PDF文档_NOUGAT.py
│ ├── 数学动画生成manim.py
│ ├── 理解PDF文档内容.py
│ ├── 生成函数注释.py
│ ├── 生成多种Mermaid图表.py
│ ├── 知识库问答.py
│ ├── 联网的ChatGPT.py
│ ├── 联网的ChatGPT_bing版.py
│ ├── 虚空终端.py
│ ├── 解析JupyterNotebook.py
│ ├── 询问多个大语言模型.py
│ ├── 语音助手.py
│ ├── 读文章写摘要.py
│ ├── 谷歌检索小助手.py
│ ├── 辅助功能.py
│ └── 高级功能函数模板.py
├── docker-compose.yml
├── instruction.txt
├── main.py
├── multi_language.py
├── requirements.txt
├── shared_utils
│ ├── advanced_markdown_format.py
│ ├── char_visual_effect.py
│ ├── colorful.py
│ ├── config_loader.py
│ ├── connect_void_terminal.py
│ ├── cookie_manager.py
│ ├── fastapi_server.py
│ ├── handle_upload.py
│ ├── key_pattern_manager.py
│ ├── logging.py
│ ├── map_names.py
│ └── text_mask.py
├── toolbox.py
└── version
2. The light_rag implementation lives at crazy_functions/rag_fns/LightRAG. The main functionality is implemented in operate.py; the other files the RAG uses are prompt.py, base.py, storage.py and utils.py. Please implement the plugin features by following this implementation. For usage examples of light_rag, see lightrag_hf_demo.py and lightrag_lmdeploy_demo.py under crazy_functions/rag_fns/LightRAG/examples.
The directory structure of this path is:
├── README.md
├── examples
│   ├── batch_eval.py
│   ├── generate_query.py
│   ├── graph_visual_with_html.py
│   ├── graph_visual_with_neo4j.py
│   ├── lightrag_azure_openai_demo.py
│   ├── lightrag_bedrock_demo.py
│   ├── lightrag_hf_demo.py
│   ├── lightrag_ollama_demo.py
│   ├── lightrag_openai_compatible_demo.py
│   ├── lightrag_openai_demo.py
│   └── vram_management_demo.py
├── lightrag
│   ├── __init__.py
│   ├── base.py
│   ├── lightrag.py
│   ├── llm.py
│   ├── operate.py
│   ├── prompt.py
│   ├── storage.py
│   └── utils.py
├── reproduce
│   ├── Step_0.py
│   ├── Step_1.py
│   ├── Step_1_openai_compatible.py
│   ├── Step_2.py
│   ├── Step_3.py
│   └── Step_3_openai_compatible.py
├── requirements.txt
└── setup.py
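A minimal usage sketch of the library (assumed API, modeled on lightrag_openai_demo.py in the examples directory above):

    from lightrag import LightRAG, QueryParam

    rag = LightRAG(working_dir="./arxiv_cache")  # assumed working-directory argument
    rag.insert(open("paper.tex").read())         # builds the entity/relation graph
    print(rag.query("What is the main contribution?",
                    param=QueryParam(mode="local")))  # local_query under the hood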
3. I need to develop a RAG plugin. Please implement a plugin named "rag论文总结". Its main entry is the rag论文对话 function in crazy_functions/Arxiv_论文对话.py, and the plugin works in two stages, file processing and RAG. The specific requirements are:
I. The function header is as follows:
@CatchException
def rag论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
II. For the return values, refer to the 批量文件询问 function in crazy_functions/批量文件询问.py; mainly use the yield style.
III. For the RAG part I want to follow the light_rag approach; referring to the existing implementation, its core functionality is:
The main functions include:
a. Create separate rag_handlers for project and arxiv. The fragment dataclass for the project type is:
@dataclass
class DocFragment:
"""Dataclass for a text fragment"""
file_path: str  # original file path
content: str  # fragment content
segment_index: int  # fragment index
total_segments: int  # total number of fragments
rel_path: str  # relative path
The fragment dataclass for arxiv is:
@dataclass
class ArxivFragment:
"""Dataclass for an arXiv paper fragment"""
file_path: str
content: str
segment_index: int
total_segments: int
rel_path: str
segment_type: str
title: str
abstract: str
section: str
is_appendix: bool
b. If the directory does not already contain extracted entity/relation summaries, use the `_handle_entity_relation_summary` function to summarize the entities or relations of the text chunks produced in step d, and store the result under the project or arxiv path (the first three levels of fragment.file_path, splitting each level by "/"). If extracted results already exist in that directory, use them directly instead of re-extracting.
f. Use `_handle_single_entity_extraction` and `_handle_single_relationship_extraction` to extract a single entity or relation from a record.
g. `_merge_nodes_then_upsert` and `_merge_edges_then_upsert`: merge and upsert nodes or edges.
h. `extract_entities`: process multiple text chunks, extract entities and relations, and store them in the knowledge graph and the vector database.
i. `local_query`: extract keywords from the query and generate a response. A minimal plugin skeleton following these steps is sketched below.
main.py

@@ -1,4 +1,4 @@
import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
import os, json; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
help_menu_description = \
"""Github源代码开源和更新[地址🚀](https://github.com/binary-husky/gpt_academic),
@@ -34,9 +34,9 @@ def encode_plugin_info(k, plugin)->str:
def main():
import gradio as gr
if gr.__version__ not in ['3.32.12']:
if gr.__version__ not in ['3.32.9', '3.32.10', '3.32.11']:
raise ModuleNotFoundError("使用项目内置Gradio获取最优体验! 请运行 `pip install -r requirements.txt` 指令安装内置Gradio及其他依赖, 详情信息见requirements.txt.")
# 一些基础工具
from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf, ArgsGeneralWrapper, DummyWith
@@ -49,7 +49,7 @@ def main():
# 读取配置
proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION = get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION')
CHATBOT_HEIGHT, LAYOUT, AVAIL_LLM_MODELS, AUTO_CLEAR_TXT = get_conf('CHATBOT_HEIGHT', 'LAYOUT', 'AVAIL_LLM_MODELS', 'AUTO_CLEAR_TXT')
ENABLE_AUDIO, AUTO_CLEAR_TXT, AVAIL_FONTS, AVAIL_THEMES, THEME, ADD_WAIFU = get_conf('ENABLE_AUDIO', 'AUTO_CLEAR_TXT', 'AVAIL_FONTS', 'AVAIL_THEMES', 'THEME', 'ADD_WAIFU')
ENABLE_AUDIO, AUTO_CLEAR_TXT, PATH_LOGGING, AVAIL_THEMES, THEME, ADD_WAIFU = get_conf('ENABLE_AUDIO', 'AUTO_CLEAR_TXT', 'PATH_LOGGING', 'AVAIL_THEMES', 'THEME', 'ADD_WAIFU')
NUM_CUSTOM_BASIC_BTN, SSL_KEYFILE, SSL_CERTFILE = get_conf('NUM_CUSTOM_BASIC_BTN', 'SSL_KEYFILE', 'SSL_CERTFILE')
DARK_MODE, INIT_SYS_PROMPT, ADD_WAIFU, TTS_TYPE = get_conf('DARK_MODE', 'INIT_SYS_PROMPT', 'ADD_WAIFU', 'TTS_TYPE')
if LLM_MODEL not in AVAIL_LLM_MODELS: AVAIL_LLM_MODELS += [LLM_MODEL]
@@ -57,8 +57,8 @@ def main():
# 如果WEB_PORT是-1, 则随机选取WEB端口
PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
from check_proxy import get_current_version
from themes.theme import adjust_theme, advanced_css, theme_declaration, js_code_clear, js_code_show_or_hide, js_code_show_or_hide_group2
from themes.theme import js_code_for_toggle_darkmode
from themes.theme import adjust_theme, advanced_css, theme_declaration, js_code_clear, js_code_reset, js_code_show_or_hide, js_code_show_or_hide_group2
from themes.theme import js_code_for_toggle_darkmode, js_code_for_persistent_cookie_init
from themes.theme import load_dynamic_theme, to_cookie_str, from_cookie_str, assign_user_uuid
title_html = f"<h1 align=\"center\">GPT 学术优化 {get_current_version()}</h1>{theme_declaration}"
@@ -68,7 +68,7 @@ def main():
functional = get_core_functions()
# 高级函数插件
from crazy_functional import get_crazy_functions, get_multiplex_button_functions
from crazy_functional import get_crazy_functions
DEFAULT_FN_GROUPS = get_conf('DEFAULT_FN_GROUPS')
plugins = get_crazy_functions()
all_plugin_groups = list(set([g for _, plugin in plugins.items() for g in plugin['Group'].split('|')]))
@@ -106,7 +106,7 @@ def main():
with gr_L2(scale=2, elem_id="gpt-chat"):
chatbot = gr.Chatbot(label=f"当前模型:{LLM_MODEL}", elem_id="gpt-chatbot")
if LAYOUT == "TOP-DOWN": chatbot.style(height=CHATBOT_HEIGHT)
history, _, _ = make_history_cache() # 定义 后端statehistory、前端history_cache、后端setterhistory_cache_update三兄弟
history, history_cache, history_cache_update = make_history_cache() # 定义 后端statehistory、前端history_cache、后端setterhistory_cache_update三兄弟
with gr_L2(scale=1, elem_id="gpt-panel"):
with gr.Accordion("输入区", open=True, elem_id="input-panel") as area_input_primary:
with gr.Row():
@@ -114,7 +114,12 @@ def main():
with gr.Row(elem_id="gpt-submit-row"):
multiplex_submit_btn = gr.Button("提交", elem_id="elem_submit_visible", variant="primary")
multiplex_sel = gr.Dropdown(
choices=get_multiplex_button_functions().keys(), value="常规对话",
choices=[
"常规对话",
"多模型对话",
"智能召回 RAG",
# "智能上下文",
], value="常规对话",
interactive=True, label='', show_label=False,
elem_classes='normal_mut_select', elem_id="gpt-submit-dropdown").style(container=False)
submit_btn = gr.Button("提交", elem_id="elem_submit", variant="primary", visible=False)
@@ -174,19 +179,15 @@ def main():
with gr.Accordion("点击展开“文件下载区”。", open=False) as area_file_up:
file_upload = gr.Files(label="任何文件, 推荐上传压缩文件(zip, tar)", file_count="multiple", elem_id="elem_upload")
# 左上角工具栏定义
from themes.gui_toolbar import define_gui_toolbar
checkboxes, checkboxes_2, max_length_sl, theme_dropdown, system_prompt, file_upload_2, md_dropdown, top_p, temperature = \
define_gui_toolbar(AVAIL_LLM_MODELS, LLM_MODEL, INIT_SYS_PROMPT, THEME, AVAIL_THEMES, AVAIL_FONTS, ADD_WAIFU, help_menu_description, js_code_for_toggle_darkmode)
define_gui_toolbar(AVAIL_LLM_MODELS, LLM_MODEL, INIT_SYS_PROMPT, THEME, AVAIL_THEMES, ADD_WAIFU, help_menu_description, js_code_for_toggle_darkmode)
# 浮动菜单定义
from themes.gui_floating_menu import define_gui_floating_menu
area_input_secondary, txt2, area_customize, _, resetBtn2, clearBtn2, stopBtn2 = \
define_gui_floating_menu(customize_btns, functional, predefined_btns, cookies, web_cookie_cache)
# 浮动时间线定义
gr.Spark()
# 插件二级菜单的实现
from themes.gui_advanced_plugin_class import define_gui_advanced_plugin_class
@@ -226,8 +227,11 @@ def main():
multiplex_sel.select(
None, [multiplex_sel], None, _js=f"""(multiplex_sel)=>run_multiplex_shift(multiplex_sel)""")
cancel_handles.append(submit_btn.click(**predict_args))
resetBtn.click(None, None, [chatbot, history, status], _js= """clear_conversation""") # 先在前端快速清除chatbot&status
resetBtn2.click(None, None, [chatbot, history, status], _js="""clear_conversation""") # 先在前端快速清除chatbot&status
resetBtn.click(None, None, [chatbot, history, status], _js=js_code_reset) # 先在前端快速清除chatbot&status
resetBtn2.click(None, None, [chatbot, history, status], _js=js_code_reset) # 先在前端快速清除chatbot&status
reset_server_side_args = (lambda history: ([], [], "已重置", json.dumps(history)), [history], [chatbot, history, status, history_cache])
resetBtn.click(*reset_server_side_args) # 再在后端清除history把history转存history_cache备用
resetBtn2.click(*reset_server_side_args) # 再在后端清除history把history转存history_cache备用
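A note on the reset wiring above: the reset is split into a fast front-end phase (js_code_reset clears the chatbot in the browser immediately) and a back-end phase that clears the server-side state while stashing the old history. A minimal sketch of the server-side half, using only the names visible in this hunk:

```python
import json

def reset_server_side(history):
    # Clear chatbot, history and status, and serialize the old history into
    # history_cache so the front end can restore the conversation later.
    return [], [], "已重置", json.dumps(history)

# wired as: resetBtn.click(reset_server_side, [history], [chatbot, history, status, history_cache])
```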
clearBtn.click(None, None, [txt, txt2], _js=js_code_clear)
clearBtn2.click(None, None, [txt, txt2], _js=js_code_clear)
if AUTO_CLEAR_TXT:
@@ -327,7 +331,7 @@ def main():
from shared_utils.cookie_manager import load_web_cookie_cache__fn_builder
load_web_cookie_cache = load_web_cookie_cache__fn_builder(customize_btns, cookies, predefined_btns)
app_block.load(load_web_cookie_cache, inputs = [web_cookie_cache, cookies],
outputs = [web_cookie_cache, cookies, *customize_btns.values(), *predefined_btns.values()], _js="""persistent_cookie_init""")
outputs = [web_cookie_cache, cookies, *customize_btns.values(), *predefined_btns.values()], _js=js_code_for_persistent_cookie_init)
app_block.load(None, inputs=[], outputs=None, _js=f"""()=>GptAcademicJavaScriptInit("{DARK_MODE}","{INIT_SYS_PROMPT}","{ADD_WAIFU}","{LAYOUT}","{TTS_TYPE}")""") # 配置暗色主题或亮色主题
app_block.load(None, inputs=[], outputs=None, _js="""()=>{REP}""".replace("REP", register_advanced_plugin_init_arr))


@@ -26,9 +26,6 @@ from .bridge_chatglm import predict as chatglm_ui
from .bridge_chatglm3 import predict_no_ui_long_connection as chatglm3_noui
from .bridge_chatglm3 import predict as chatglm3_ui
from .bridge_chatglm4 import predict_no_ui_long_connection as chatglm4_noui
from .bridge_chatglm4 import predict as chatglm4_ui
from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui
from .bridge_qianfan import predict as qianfan_ui
@@ -79,7 +76,6 @@ cohere_endpoint = "https://api.cohere.ai/v1/chat"
ollama_endpoint = "http://localhost:11434/api/chat"
yimodel_endpoint = "https://api.lingyiwanwu.com/v1/chat/completions"
deepseekapi_endpoint = "https://api.deepseek.com/v1/chat/completions"
grok_model_endpoint = "https://api.x.ai/v1/chat/completions"
if not AZURE_ENDPOINT.endswith('/'): AZURE_ENDPOINT += '/'
azure_endpoint = AZURE_ENDPOINT + f'openai/deployments/{AZURE_ENGINE}/chat/completions?api-version=2023-05-15'
@@ -101,7 +97,6 @@ if cohere_endpoint in API_URL_REDIRECT: cohere_endpoint = API_URL_REDIRECT[coher
if ollama_endpoint in API_URL_REDIRECT: ollama_endpoint = API_URL_REDIRECT[ollama_endpoint]
if yimodel_endpoint in API_URL_REDIRECT: yimodel_endpoint = API_URL_REDIRECT[yimodel_endpoint]
if deepseekapi_endpoint in API_URL_REDIRECT: deepseekapi_endpoint = API_URL_REDIRECT[deepseekapi_endpoint]
if grok_model_endpoint in API_URL_REDIRECT: grok_model_endpoint = API_URL_REDIRECT[grok_model_endpoint]
# 获取tokenizer
tokenizer_gpt35 = LazyloadTiktoken("gpt-3.5-turbo")
@@ -217,16 +212,6 @@ model_info = {
"token_cnt": get_token_num_gpt4,
},
"chatgpt-4o-latest": {
"fn_with_ui": chatgpt_ui,
"fn_without_ui": chatgpt_noui,
"endpoint": openai_endpoint,
"has_multimodal_capacity": True,
"max_token": 128000,
"tokenizer": tokenizer_gpt4,
"token_cnt": get_token_num_gpt4,
},
"gpt-4o-2024-05-13": {
"fn_with_ui": chatgpt_ui,
"fn_without_ui": chatgpt_noui,
@@ -273,9 +258,7 @@ model_info = {
"token_cnt": get_token_num_gpt4,
"openai_disable_system_prompt": True,
"openai_disable_stream": True,
"openai_force_temperature_one": True,
},
"o1-mini": {
"fn_with_ui": chatgpt_ui,
"fn_without_ui": chatgpt_noui,
@@ -285,31 +268,6 @@ model_info = {
"token_cnt": get_token_num_gpt4,
"openai_disable_system_prompt": True,
"openai_disable_stream": True,
"openai_force_temperature_one": True,
},
"o1-2024-12-17": {
"fn_with_ui": chatgpt_ui,
"fn_without_ui": chatgpt_noui,
"endpoint": openai_endpoint,
"max_token": 200000,
"tokenizer": tokenizer_gpt4,
"token_cnt": get_token_num_gpt4,
"openai_disable_system_prompt": True,
"openai_disable_stream": True,
"openai_force_temperature_one": True,
},
"o1": {
"fn_with_ui": chatgpt_ui,
"fn_without_ui": chatgpt_noui,
"endpoint": openai_endpoint,
"max_token": 200000,
"tokenizer": tokenizer_gpt4,
"token_cnt": get_token_num_gpt4,
"openai_disable_system_prompt": True,
"openai_disable_stream": True,
"openai_force_temperature_one": True,
},
"gpt-4-turbo": {
@@ -446,7 +404,6 @@ model_info = {
"token_cnt": get_token_num_gpt4,
},
# ChatGLM本地模型
# 将 chatglm 直接对齐到 chatglm2
"chatglm": {
"fn_with_ui": chatglm_ui,
@@ -472,14 +429,6 @@ model_info = {
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"chatglm4": {
"fn_with_ui": chatglm4_ui,
"fn_without_ui": chatglm4_noui,
"endpoint": None,
"max_token": 8192,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"qianfan": {
"fn_with_ui": qianfan_ui,
"fn_without_ui": qianfan_noui,
@@ -812,8 +761,7 @@ if "qwen-local" in AVAIL_LLM_MODELS:
except:
logger.error(trimmed_format_exc())
# -=-=-=-=-=-=- 通义-在线模型 -=-=-=-=-=-=-
qwen_models = ["qwen-max-latest", "qwen-max-2025-01-25","qwen-max","qwen-turbo","qwen-plus"]
if any(item in qwen_models for item in AVAIL_LLM_MODELS):
if "qwen-turbo" in AVAIL_LLM_MODELS or "qwen-plus" in AVAIL_LLM_MODELS or "qwen-max" in AVAIL_LLM_MODELS: # zhipuai
try:
from .bridge_qwen import predict_no_ui_long_connection as qwen_noui
from .bridge_qwen import predict as qwen_ui
@@ -823,7 +771,7 @@ if any(item in qwen_models for item in AVAIL_LLM_MODELS):
"fn_without_ui": qwen_noui,
"can_multi_thread": True,
"endpoint": None,
"max_token": 100000,
"max_token": 6144,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
@@ -832,7 +780,7 @@ if any(item in qwen_models for item in AVAIL_LLM_MODELS):
"fn_without_ui": qwen_noui,
"can_multi_thread": True,
"endpoint": None,
"max_token": 129024,
"max_token": 30720,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
@@ -841,25 +789,7 @@ if any(item in qwen_models for item in AVAIL_LLM_MODELS):
"fn_without_ui": qwen_noui,
"can_multi_thread": True,
"endpoint": None,
"max_token": 30720,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"qwen-max-latest": {
"fn_with_ui": qwen_ui,
"fn_without_ui": qwen_noui,
"can_multi_thread": True,
"endpoint": None,
"max_token": 30720,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"qwen-max-2025-01-25": {
"fn_with_ui": qwen_ui,
"fn_without_ui": qwen_noui,
"can_multi_thread": True,
"endpoint": None,
"max_token": 30720,
"max_token": 28672,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
}
@@ -946,31 +876,6 @@ if any(item in yi_models for item in AVAIL_LLM_MODELS):
})
except:
logger.error(trimmed_format_exc())
# -=-=-=-=-=-=- Grok model from x.ai -=-=-=-=-=-=-
grok_models = ["grok-beta"]
if any(item in grok_models for item in AVAIL_LLM_MODELS):
try:
grok_beta_128k_noui, grok_beta_128k_ui = get_predict_function(
api_key_conf_name="GROK_API_KEY", max_output_token=8192, disable_proxy=False
)
model_info.update({
"grok-beta": {
"fn_with_ui": grok_beta_128k_ui,
"fn_without_ui": grok_beta_128k_noui,
"can_multi_thread": True,
"endpoint": grok_model_endpoint,
"max_token": 128000,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
})
except:
logger.error(trimmed_format_exc())
# -=-=-=-=-=-=- 讯飞星火认知大模型 -=-=-=-=-=-=-
if "spark" in AVAIL_LLM_MODELS:
try:
@@ -1090,18 +995,18 @@ if "deepseekcoder" in AVAIL_LLM_MODELS: # deepseekcoder
except:
logger.error(trimmed_format_exc())
# -=-=-=-=-=-=- 幻方-深度求索大模型在线API -=-=-=-=-=-=-
if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS or "deepseek-reasoner" in AVAIL_LLM_MODELS:
if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
try:
deepseekapi_noui, deepseekapi_ui = get_predict_function(
api_key_conf_name="DEEPSEEK_API_KEY", max_output_token=4096, disable_proxy=False
)
)
model_info.update({
"deepseek-chat":{
"fn_with_ui": deepseekapi_ui,
"fn_without_ui": deepseekapi_noui,
"endpoint": deepseekapi_endpoint,
"can_multi_thread": True,
"max_token": 64000,
"max_token": 32000,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
@@ -1114,16 +1019,6 @@ if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS o
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"deepseek-reasoner":{
"fn_with_ui": deepseekapi_ui,
"fn_without_ui": deepseekapi_noui,
"endpoint": deepseekapi_endpoint,
"can_multi_thread": True,
"max_token": 64000,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
"enable_reasoning": True
},
})
except:
logger.error(trimmed_format_exc())
@@ -1391,11 +1286,6 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot,
inputs = apply_gpt_academic_string_mask(inputs, mode="show_llm")
if llm_kwargs['llm_model'] not in model_info:
from toolbox import update_ui
chatbot.append([inputs, f"很抱歉,模型 '{llm_kwargs['llm_model']}' 暂不支持<br/>(1) 检查config中的AVAIL_LLM_MODELS选项<br/>(2) 检查request_llms/bridge_all.py中的模型路由"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
method = model_info[llm_kwargs['llm_model']]["fn_with_ui"] # 如果这里报错检查config中的AVAIL_LLM_MODELS选项
if additional_fn: # 根据基础功能区 ModelOverride 参数调整模型类型


@@ -23,33 +23,39 @@ class GetGLM3Handle(LocalLLMHandle):
import os
import platform
LOCAL_MODEL_PATH, LOCAL_MODEL_QUANT, device = get_conf("CHATGLM_LOCAL_MODEL_PATH", "LOCAL_MODEL_QUANT", "LOCAL_MODEL_DEVICE")
model_path = LOCAL_MODEL_PATH
LOCAL_MODEL_QUANT, device = get_conf("LOCAL_MODEL_QUANT", "LOCAL_MODEL_DEVICE")
_model_name_ = "THUDM/chatglm3-6b"
# if LOCAL_MODEL_QUANT == "INT4": # INT4
# _model_name_ = "THUDM/chatglm3-6b-int4"
# elif LOCAL_MODEL_QUANT == "INT8": # INT8
# _model_name_ = "THUDM/chatglm3-6b-int8"
# else:
# _model_name_ = "THUDM/chatglm3-6b" # FP16
with ProxyNetworkActivate("Download_LLM"):
chatglm_tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True
_model_name_, trust_remote_code=True
)
if device == "cpu":
chatglm_model = AutoModel.from_pretrained(
model_path,
_model_name_,
trust_remote_code=True,
device="cpu",
).float()
elif LOCAL_MODEL_QUANT == "INT4": # INT4
chatglm_model = AutoModel.from_pretrained(
pretrained_model_name_or_path=model_path,
pretrained_model_name_or_path=_model_name_,
trust_remote_code=True,
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
elif LOCAL_MODEL_QUANT == "INT8": # INT8
chatglm_model = AutoModel.from_pretrained(
pretrained_model_name_or_path=model_path,
pretrained_model_name_or_path=_model_name_,
trust_remote_code=True,
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
else:
chatglm_model = AutoModel.from_pretrained(
pretrained_model_name_or_path=model_path,
pretrained_model_name_or_path=_model_name_,
trust_remote_code=True,
device="cuda",
)


@@ -1,81 +0,0 @@
model_name = "ChatGLM4"
cmd_to_install = """
`pip install -r request_llms/requirements_chatglm4.txt`
`pip install modelscope`
`modelscope download --model ZhipuAI/glm-4-9b-chat --local_dir ./THUDM/glm-4-9b-chat`
"""
from toolbox import get_conf, ProxyNetworkActivate
from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns
# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 Local Model
# ------------------------------------------------------------------------------------------------------------------------
class GetGLM4Handle(LocalLLMHandle):
def load_model_info(self):
# 🏃‍♂️🏃‍♂️🏃‍♂️ 子进程执行
self.model_name = model_name
self.cmd_to_install = cmd_to_install
def load_model_and_tokenizer(self):
# 🏃‍♂️🏃‍♂️🏃‍♂️ 子进程执行
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
import os
LOCAL_MODEL_PATH, device = get_conf("CHATGLM_LOCAL_MODEL_PATH", "LOCAL_MODEL_DEVICE")
model_path = LOCAL_MODEL_PATH
chatglm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
chatglm_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True,
device=device
).eval().to(device)
self._model = chatglm_model
self._tokenizer = chatglm_tokenizer
return self._model, self._tokenizer
def llm_stream_generator(self, **kwargs):
# 🏃‍♂️🏃‍♂️🏃‍♂️ 子进程执行
def adaptor(kwargs):
query = kwargs["query"]
max_length = kwargs["max_length"]
top_p = kwargs["top_p"]
temperature = kwargs["temperature"]
history = kwargs["history"]
return query, max_length, top_p, temperature, history
query, max_length, top_p, temperature, history = adaptor(kwargs)
inputs = self._tokenizer.apply_chat_template([{"role": "user", "content": query}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True
).to(self._model.device)
gen_kwargs = {"max_length": max_length, "do_sample": True, "top_k": top_p}
outputs = self._model.generate(**inputs, **gen_kwargs)
outputs = outputs[:, inputs['input_ids'].shape[1]:]
response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
yield response
def try_to_import_special_deps(self, **kwargs):
# import something that will raise error if the user does not install requirement_*.txt
# 🏃‍♂️🏃‍♂️🏃‍♂️ 主进程执行
import importlib
# importlib.import_module('modelscope')
# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 GPT-Academic Interface
# ------------------------------------------------------------------------------------------------------------------------
predict_no_ui_long_connection, predict = get_local_llm_predict_fns(
GetGLM4Handle, model_name, history_format="chatglm3"
)


@@ -23,13 +23,8 @@ from loguru import logger
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history
from toolbox import trimmed_format_exc, is_the_upload_folder, read_one_api_model_name, log_chat
from toolbox import ChatBotWithCookies, have_any_recent_upload_image_files, encode_image
proxies, WHEN_TO_USE_PROXY, TIMEOUT_SECONDS, MAX_RETRY, API_ORG, AZURE_CFG_ARRAY = \
get_conf('proxies', 'WHEN_TO_USE_PROXY', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'API_ORG', 'AZURE_CFG_ARRAY')
if "Connect_OpenAI" not in WHEN_TO_USE_PROXY:
if proxies is not None:
logger.error("虽然您配置了代理设置但不会在连接OpenAI的过程中起作用请检查WHEN_TO_USE_PROXY配置。")
proxies = None
proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG, AZURE_CFG_ARRAY = \
get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'API_ORG', 'AZURE_CFG_ARRAY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
@@ -185,20 +180,14 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
raise ConnectionAbortedError("正常结束,但显示Token不足导致输出不完整,请削减单次输入的文本量。")
else:
raise RuntimeError("OpenAI拒绝了请求" + error_msg)
if ('data: [DONE]' in chunk_decoded): break # api2d & one-api 正常完成
if ('data: [DONE]' in chunk_decoded): break # api2d 正常完成
# 提前读取一些信息 (用于判断异常)
if has_choices and not choice_valid:
# 一些垃圾第三方接口会出现这样的错误
continue
json_data = chunkjson['choices'][0]
delta = json_data["delta"]
if len(delta) == 0:
is_termination_certain = False
if (has_choices) and (chunkjson['choices'][0].get('finish_reason', 'null') == 'stop'): is_termination_certain = True
if is_termination_certain: break
else: continue # 对于不符合规范的狗屎接口,这里需要继续
if len(delta) == 0: break
if (not has_content) and has_role: continue
if (not has_content) and (not has_role): continue # raise RuntimeError("发现不标准的第三方接口:"+delta)
if has_content: # has_role = True/False
@@ -296,8 +285,6 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
history.extend([inputs, ""])
retry = 0
previous_ui_reflesh_time = 0
ui_reflesh_min_interval = 0.0
while True:
try:
# make a POST request to the API endpoint, stream=True
@@ -310,13 +297,13 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
if retry > MAX_RETRY: raise TimeoutError
if not stream:
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
yield from handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history)
return
if stream:
reach_termination = False # 处理一些 new-api 的奇葩异常
gpt_replying_buffer = ""
is_head_of_the_stream = True
stream_response = response.iter_lines()
@@ -329,14 +316,11 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
error_msg = chunk_decoded
# 首先排除一个one-api没有done数据包的第三方Bug情形
if len(gpt_replying_buffer.strip()) > 0 and len(error_msg) == 0:
yield from update_ui(chatbot=chatbot, history=history, msg="检测到有缺陷的接口,建议选择更稳定的接口。")
if not reach_termination:
reach_termination = True
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
yield from update_ui(chatbot=chatbot, history=history, msg="检测到有缺陷的非OpenAI官方接口,建议选择更稳定的接口。")
break
# 其他情况,直接返回报错
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
yield from update_ui(chatbot=chatbot, history=history, msg="接口返回了错误:" + chunk.decode()) # 刷新界面
yield from update_ui(chatbot=chatbot, history=history, msg="非OpenAI官方接口返回了错误:" + chunk.decode()) # 刷新界面
return
# 提前读取一些信息 (用于判断异常)
@@ -346,8 +330,6 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
# 数据流的第一帧不携带content
is_head_of_the_stream = False; continue
if "error" in chunk_decoded: logger.error(f"接口返回了未知错误: {chunk_decoded}")
if chunk:
try:
if has_choices and not choice_valid:
@@ -356,25 +338,14 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
if ('data: [DONE]' not in chunk_decoded) and len(chunk_decoded) > 0 and (chunkjson is None):
# 传递进来一些奇怪的东西
raise ValueError(f'无法读取以下数据,请检查配置。\n\n{chunk_decoded}')
# 前者是API2D & One-API的结束条件后者是OPENAI的结束条件
one_api_terminate = ('data: [DONE]' in chunk_decoded)
openai_terminate = (has_choices) and (len(chunkjson['choices'][0]["delta"]) == 0)
if one_api_terminate or openai_terminate:
is_termination_certain = False
if one_api_terminate: is_termination_certain = True # 抓取符合规范的结束条件
elif (has_choices) and (chunkjson['choices'][0].get('finish_reason', 'null') == 'stop'): is_termination_certain = True # 抓取符合规范的结束条件
if is_termination_certain:
reach_termination = True
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
break # 对于符合规范的接口,这里可以break
else:
continue # 对于不符合规范的狗屎接口,这里需要继续
# 到这里我们已经可以假定必须包含choice了
try:
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
except:
logger.error(f"一些垃圾第三方接口出现这样的错误,兼容一下吧: {chunk_decoded}")
# 前者是API2D的结束条件,后者是OPENAI的结束条件
if ('data: [DONE]' in chunk_decoded) or (len(chunkjson['choices'][0]["delta"]) == 0):
# 判定为数据流的结束,gpt_replying_buffer也写完了
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
break
# 处理数据流的主体
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
# 如果这里抛出异常,一般是文本过长,详情见get_full_error的输出
if has_content:
# 正常情况
gpt_replying_buffer = gpt_replying_buffer + chunkjson['choices'][0]["delta"]["content"]
@@ -383,26 +354,21 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
continue
else:
# 至此已经超出了正常接口应该进入的范围,一些垃圾第三方接口会出现这样的错误
if chunkjson['choices'][0]["delta"].get("content", None) is None:
logger.error(f"一些垃圾第三方接口出现这样的错误,兼容一下吧: {chunk_decoded}")
continue
if chunkjson['choices'][0]["delta"]["content"] is None: continue # 一些垃圾第三方接口出现这样的错误,兼容一下吧
gpt_replying_buffer = gpt_replying_buffer + chunkjson['choices'][0]["delta"]["content"]
history[-1] = gpt_replying_buffer
chatbot[-1] = (history[-2], history[-1])
if time.time() - previous_ui_reflesh_time > ui_reflesh_min_interval:
yield from update_ui(chatbot=chatbot, history=history, msg=status_text) # 刷新界面
previous_ui_reflesh_time = time.time()
yield from update_ui(chatbot=chatbot, history=history, msg=status_text) # 刷新界面
except Exception as e:
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析不合常规") # 刷新界面
chunk = get_full_error(chunk, stream_response)
chunk_decoded = chunk.decode()
error_msg = chunk_decoded
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
logger.error(error_msg)
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + error_msg) # 刷新界面
logger.error(error_msg)
return
yield from update_ui(chatbot=chatbot, history=history, msg="完成") # 刷新界面
return # return from stream-branch
def handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history):
@@ -570,8 +536,6 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
"n": 1,
"stream": stream,
}
openai_force_temperature_one = model_info[llm_kwargs['llm_model']].get('openai_force_temperature_one', False)
if openai_force_temperature_one:
payload.pop('temperature')
return headers,payload
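The openai_force_temperature_one flag above exists because the o1-series endpoints reject any non-default temperature, so generate_payload simply drops the field. A self-contained sketch of the effect (the model_info stub is illustrative, not the real table):

```python
model_info = {"o1-mini": {"openai_force_temperature_one": True}}  # illustrative stub

payload = {"model": "o1-mini", "temperature": 0.7, "n": 1, "stream": False}
if model_info["o1-mini"].get("openai_force_temperature_one", False):
    payload.pop("temperature")  # o1 endpoints only accept the default temperature
```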


@@ -26,7 +26,7 @@ class GetLlamaHandle(LocalLLMHandle):
import platform
huggingface_token, device = get_conf('HUGGINGFACE_ACCESS_TOKEN', 'LOCAL_MODEL_DEVICE')
assert len(huggingface_token) != 0, "没有填写 HUGGINGFACE_ACCESS_TOKEN"
with open(os.path.expanduser('~/.cache/huggingface/token'), 'w', encoding='utf8') as f:
with open(os.path.expanduser('~/.cache/huggingface/token'), 'w') as f:
f.write(huggingface_token)
model_id = 'meta-llama/Llama-2-7b-chat-hf'
with ProxyNetworkActivate('Download_LLM'):


@@ -31,7 +31,7 @@ class MoonShotInit:
files.append(f)
for file in files:
if file.split('.')[-1] in ['pdf']:
with open(file, 'r', encoding='utf8') as fp:
with open(file, 'r') as fp:
from crazy_functions.crazy_utils import read_and_clean_pdf_text
file_content, _ = read_and_clean_pdf_text(fp)
what_ask.append({"role": "system", "content": file_content})


@@ -75,7 +75,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
# make a POST request to the API endpoint, stream=False
from .bridge_all import model_info
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
response = requests.post(endpoint, headers=headers, proxies=None,
response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
except requests.exceptions.ReadTimeout as e:
retry += 1
@@ -152,12 +152,10 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
history.append(inputs); history.append("")
retry = 0
if proxies is not None:
logger.error("Ollama不会使用代理服务器, 忽略了proxies的设置。")
while True:
try:
# make a POST request to the API endpoint, stream=True
response = requests.post(endpoint, headers=headers, proxies=None,
response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
except:
retry += 1


@@ -170,7 +170,7 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
except requests.exceptions.ConnectionError:
chunk = next(stream_response) # 失败了,重试一次?再失败就没办法了。
chunk_decoded, chunkjson, has_choices, choice_valid, has_content, has_role = decode_chunk(chunk)
if len(chunk_decoded)==0 or chunk_decoded.startswith(':'): continue
if len(chunk_decoded)==0: continue
if not chunk_decoded.startswith('data:'):
error_msg = get_full_error(chunk, stream_response).decode()
if "reduce the length" in error_msg:
@@ -181,6 +181,9 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
raise RuntimeError("OpenAI拒绝了请求" + error_msg)
if ('data: [DONE]' in chunk_decoded): break # api2d 正常完成
# 提前读取一些信息 (用于判断异常)
if (has_choices and not choice_valid) or ('OPENROUTER PROCESSING' in chunk_decoded):
# 一些垃圾第三方接口会出现这样的错误;openrouter的特殊处理
continue
json_data = chunkjson['choices'][0]
delta = json_data["delta"]
if len(delta) == 0: break
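Both variants above filter SSE noise: comment lines begin with ':' (OpenRouter sends ': OPENROUTER PROCESSING' keep-alives before the model is attached), and some relays emit malformed choices. A sketch of the combined skip test, assuming the flags returned by decode_chunk in this file:

```python
def should_skip_chunk(chunk_decoded: str, has_choices: bool, choice_valid: bool) -> bool:
    return (
        len(chunk_decoded) == 0
        or chunk_decoded.startswith(':')                 # SSE comment / keep-alive line
        or ('OPENROUTER PROCESSING' in chunk_decoded)    # OpenRouter-specific keep-alive
        or (has_choices and not choice_valid)            # malformed third-party payload
    )
```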
@@ -325,7 +328,8 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
if chunk:
try:
if (has_choices and not choice_valid) or chunk_decoded.startswith(':'):
if (has_choices and not choice_valid) or ('OPENROUTER PROCESSING' in chunk_decoded):
# 一些垃圾第三方接口会出现这样的错误,或者是OPENROUTER的特殊处理:OPENROUTER的数据流未连接到模型时会出现OPENROUTER PROCESSING
continue
if ('data: [DONE]' not in chunk_decoded) and len(chunk_decoded) > 0 and (chunkjson is None):
# 传递进来一些奇怪的东西


@@ -202,29 +202,16 @@ class GoogleChatInit:
) # 处理 history
messages.append(self.__conversation_user(inputs, llm_kwargs, enable_multimodal_capacity)) # 处理用户对话
stop_sequences = str(llm_kwargs.get("stop", "")).split(" ")
# 过滤空字符串并确保至少有一个停止序列
stop_sequences = [s for s in stop_sequences if s]
if not stop_sequences:
payload = {
"contents": messages,
"generationConfig": {
"temperature": llm_kwargs.get("temperature", 1),
"topP": llm_kwargs.get("top_p", 0.8),
"topK": 10,
},
}
else:
payload = {
"contents": messages,
"generationConfig": {
# "maxOutputTokens": llm_kwargs.get("max_token", 1024),
"stopSequences": stop_sequences,
"temperature": llm_kwargs.get("temperature", 1),
"topP": llm_kwargs.get("top_p", 0.8),
"topK": 10,
},
}
payload = {
"contents": messages,
"generationConfig": {
# "maxOutputTokens": llm_kwargs.get("max_token", 1024),
"stopSequences": str(llm_kwargs.get("stop", "")).split(" "),
"temperature": llm_kwargs.get("temperature", 1),
"topP": llm_kwargs.get("top_p", 0.8),
"topK": 10,
},
}
return header, payload
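The branching variant above avoids sending an empty stopSequences list to the Gemini API. A condensed sketch of the same logic, assuming llm_kwargs as used in this class:

```python
def build_generation_config(llm_kwargs: dict) -> dict:
    config = {
        "temperature": llm_kwargs.get("temperature", 1),
        "topP": llm_kwargs.get("top_p", 0.8),
        "topK": 10,
    }
    # Keep only non-empty stop sequences; omit the key entirely when none remain.
    stop_sequences = [s for s in str(llm_kwargs.get("stop", "")).split(" ") if s]
    if stop_sequences:
        config["stopSequences"] = stop_sequences
    return config
```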


@@ -24,13 +24,18 @@ class QwenRequestInstance():
def generate(self, inputs, llm_kwargs, history, system_prompt):
# import _thread as thread
from dashscope import Generation
QWEN_MODEL = {
'qwen-turbo': Generation.Models.qwen_turbo,
'qwen-plus': Generation.Models.qwen_plus,
'qwen-max': Generation.Models.qwen_max,
}[llm_kwargs['llm_model']]
top_p = llm_kwargs.get('top_p', 0.8)
if top_p == 0: top_p += 1e-5
if top_p == 1: top_p -= 1e-5
self.result_buf = ""
responses = Generation.call(
model=llm_kwargs['llm_model'],
model=QWEN_MODEL,
messages=generate_message_payload(inputs, llm_kwargs, history, system_prompt),
top_p=top_p,
temperature=llm_kwargs.get('temperature', 1.0),
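The epsilon nudges above exist because DashScope requires top_p to lie strictly inside (0, 1). The same clamp as a reusable sketch:

```python
def clamp_top_p(top_p: float, eps: float = 1e-5) -> float:
    # DashScope rejects the closed endpoints 0 and 1; nudge them inward.
    if top_p <= 0:
        return eps
    if top_p >= 1:
        return 1 - eps
    return top_p
```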


@@ -1,4 +1,3 @@
from llama_index.embeddings.openai import OpenAIEmbedding
from openai import OpenAI
from toolbox import get_conf
from toolbox import CatchException, update_ui, get_conf, select_api_key, get_log_folder, ProxyNetworkActivate


@@ -36,11 +36,10 @@ def get_full_error(chunk, stream_response):
def decode_chunk(chunk):
"""
用于解读"content""finish_reason"的内容(如果支持思维链也会返回"reasoning_content"内容)
用于解读"content""finish_reason"的内容
"""
chunk = chunk.decode()
respose = ""
reasoning_content = ""
finish_reason = "False"
try:
chunk = json.loads(chunk[6:])
@@ -58,20 +57,14 @@ def decode_chunk(chunk):
return respose, finish_reason
try:
if chunk["choices"][0]["delta"]["content"] is not None:
respose = chunk["choices"][0]["delta"]["content"]
except:
pass
try:
if chunk["choices"][0]["delta"]["reasoning_content"] is not None:
reasoning_content = chunk["choices"][0]["delta"]["reasoning_content"]
respose = chunk["choices"][0]["delta"]["content"]
except:
pass
try:
finish_reason = chunk["choices"][0]["finish_reason"]
except:
pass
return respose, reasoning_content, finish_reason
return respose, finish_reason
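A usage sketch for the three-value variant of decode_chunk shown above. The SSE chunk is hand-built for illustration; finish_reason stays the string "False" until the stream reports one:

```python
chunk = b'data: {"choices": [{"delta": {"content": "Hi", "reasoning_content": "step 1"}}]}'
respose, reasoning_content, finish_reason = decode_chunk(chunk)
print(respose, reasoning_content, finish_reason)  # expected: Hi step 1 False
```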
def generate_message(input, model, key, history, max_output_token, system_prompt, temperature):
@@ -156,7 +149,6 @@ def get_predict_function(
observe_window = None
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
"""
from .bridge_all import model_info
watch_dog_patience = 5 # 看门狗的耐心,设置5秒,不准咬人(咬的也不是人)
if len(APIKEY) == 0:
raise RuntimeError(f"APIKEY为空,请检查配置文件的{APIKEY}")
@@ -171,21 +163,29 @@ def get_predict_function(
system_prompt=sys_prompt,
temperature=llm_kwargs["temperature"],
)
reasoning = model_info[llm_kwargs['llm_model']].get('enable_reasoning', False)
retry = 0
while True:
try:
from .bridge_all import model_info
endpoint = model_info[llm_kwargs["llm_model"]]["endpoint"]
response = requests.post(
endpoint,
headers=headers,
proxies=None if disable_proxy else proxies,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
if not disable_proxy:
response = requests.post(
endpoint,
headers=headers,
proxies=proxies,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
else:
response = requests.post(
endpoint,
headers=headers,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
break
except:
retry += 1
@@ -194,13 +194,10 @@ def get_predict_function(
raise TimeoutError
if MAX_RETRY != 0:
logger.error(f"请求超时,正在重试 ({retry}/{MAX_RETRY}) ……")
stream_response = response.iter_lines()
result = ""
finish_reason = ""
if reasoning:
resoning_buffer = ""
stream_response = response.iter_lines()
while True:
try:
chunk = next(stream_response)
@@ -210,9 +207,9 @@ def get_predict_function(
break
except requests.exceptions.ConnectionError:
chunk = next(stream_response) # 失败了,重试一次?再失败就没办法了。
response_text, reasoning_content, finish_reason = decode_chunk(chunk)
response_text, finish_reason = decode_chunk(chunk)
# 返回的数据流第一次为空,继续等待
if response_text == "" and (reasoning == False or reasoning_content == "") and finish_reason != "False":
if response_text == "" and finish_reason != "False":
continue
if response_text == "API_ERROR" and (
finish_reason != "False" or finish_reason != "stop"
@@ -230,8 +227,6 @@ def get_predict_function(
print(f"[response] {result}")
break
result += response_text
if reasoning:
resoning_buffer += reasoning_content
if observe_window is not None:
# 观测窗,把已经获取的数据显示出去
if len(observe_window) >= 1:
@@ -246,10 +241,6 @@ def get_predict_function(
error_msg = chunk_decoded
logger.error(error_msg)
raise RuntimeError("Json解析不合常规")
if reasoning:
# reasoning 的部分加上框 (>)
return '\n'.join(map(lambda x: '> ' + x, resoning_buffer.split('\n'))) + \
'\n\n' + result
return result
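The reasoning-related lines above render the model's chain of thought as a markdown blockquote placed before the final answer. The transformation as a standalone sketch:

```python
def quote_reasoning(reasoning: str, answer: str) -> str:
    # Prefix every reasoning line with '> ' so it renders as a blockquote.
    quoted = '\n'.join('> ' + line for line in reasoning.split('\n'))
    return quoted + '\n\n' + answer

print(quote_reasoning("step 1\nstep 2", "final answer"))
```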
def predict(
@@ -271,7 +262,6 @@ def get_predict_function(
chatbot 为WebUI中显示的对话列表,修改它,然后yield出去,可以直接修改对话界面内容
additional_fn代表点击的哪个按钮,按钮见functional.py
"""
from .bridge_all import model_info
if len(APIKEY) == 0:
raise RuntimeError(f"APIKEY为空,请检查配置文件的{APIKEY}")
if inputs == "":
@@ -308,23 +298,32 @@ def get_predict_function(
system_prompt=system_prompt,
temperature=llm_kwargs["temperature"],
)
reasoning = model_info[llm_kwargs['llm_model']].get('enable_reasoning', False)
history.append(inputs)
history.append("")
retry = 0
while True:
try:
from .bridge_all import model_info
endpoint = model_info[llm_kwargs["llm_model"]]["endpoint"]
response = requests.post(
endpoint,
headers=headers,
proxies=None if disable_proxy else proxies,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
if not disable_proxy:
response = requests.post(
endpoint,
headers=headers,
proxies=proxies,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
else:
response = requests.post(
endpoint,
headers=headers,
json=playload,
stream=True,
timeout=TIMEOUT_SECONDS,
)
break
except:
retry += 1
@@ -339,8 +338,6 @@ def get_predict_function(
raise TimeoutError
gpt_replying_buffer = ""
if reasoning:
gpt_reasoning_buffer = ""
stream_response = response.iter_lines()
while True:
@@ -350,9 +347,9 @@ def get_predict_function(
break
except requests.exceptions.ConnectionError:
chunk = next(stream_response) # 失败了,重试一次?再失败就没办法了。
response_text, reasoning_content, finish_reason = decode_chunk(chunk)
response_text, finish_reason = decode_chunk(chunk)
# 返回的数据流第一次为空,继续等待
if response_text == "" and (reasoning == False or reasoning_content == "") and finish_reason != "False":
if response_text == "" and finish_reason != "False":
status_text = f"finish_reason: {finish_reason}"
yield from update_ui(
chatbot=chatbot, history=history, msg=status_text
@@ -382,14 +379,9 @@ def get_predict_function(
logger.info(f"[response] {gpt_replying_buffer}")
break
status_text = f"finish_reason: {finish_reason}"
if reasoning:
gpt_replying_buffer += response_text
gpt_reasoning_buffer += reasoning_content
history[-1] = '\n'.join(map(lambda x: '> ' + x, gpt_reasoning_buffer.split('\n'))) + '\n\n' + gpt_replying_buffer
else:
gpt_replying_buffer += response_text
# 如果这里抛出异常,一般是文本过长,详情见get_full_error的输出
history[-1] = gpt_replying_buffer
gpt_replying_buffer += response_text
# 如果这里抛出异常,一般是文本过长,详情见get_full_error的输出
history[-1] = gpt_replying_buffer
chatbot[-1] = (history[-2], history[-1])
yield from update_ui(
chatbot=chatbot, history=history, msg=status_text


@@ -1,7 +0,0 @@
protobuf
cpm_kernels
torch>=1.10
transformers>=4.44
mdtex2html
sentencepiece
accelerate


@@ -1,4 +1,4 @@
https://public.agent-matrix.com/publish/gradio-3.32.12-py3-none-any.whl
https://public.agent-matrix.com/publish/gradio-3.32.10-py3-none-any.whl
fastapi==0.110
gradio-client==0.8
pypdf2==2.12.1
@@ -12,12 +12,14 @@ transformers>=4.27.1,<4.42
scipdf_parser>=0.52
spacy==3.7.4
anthropic>=0.18.1
sentence-transformers
python-markdown-math
pymdown-extensions>=10.14
pymdown-extensions
websocket-client
beautifulsoup4
prompt_toolkit
latex2mathml
scikit-learn
python-docx
mdtex2html
dashscope
@@ -25,7 +27,7 @@ pyautogen
colorama
Markdown
pygments
edge-tts>=7.0.0
edge-tts
pymupdf
openai
rjsmin
@@ -43,4 +45,4 @@ llama-index-embeddings-azure-openai==0.1.10
llama-index-embeddings-openai==0.1.10
llama-parse==0.4.9
mdit-py-plugins>=0.3.3
linkify-it-py==2.0.3
linkify-it-py==2.0.3


@@ -2,7 +2,6 @@ import markdown
import re
import os
import math
import html
from loguru import logger
from textwrap import dedent
@@ -385,24 +384,6 @@ def markdown_convertion(txt):
)
def code_block_title_replace_format(match):
lang = match.group(1)
filename = match.group(2)
return f"```{lang} {{title=\"{filename}\"}}\n"
def get_last_backticks_indent(text):
# 从后向前查找最后一个 ```
lines = text.splitlines()
for line in reversed(lines):
if '```' in line:
# 计算前面的空格数量
indent = len(line) - len(line.lstrip())
return indent
return 0 # 如果没找到返回0
@lru_cache(maxsize=16) # 使用lru缓存
def close_up_code_segment_during_stream(gpt_reply):
"""
在gpt输出代码的中途(输出了前面的```,但还没输出完后面的```),补上后面的```
@@ -416,12 +397,6 @@ def close_up_code_segment_during_stream(gpt_reply):
"""
if "```" not in gpt_reply:
return gpt_reply
# replace [```python:warp.py] to [```python {title="warp.py"}]
pattern = re.compile(r"```([a-z]{1,12}):([^:\n]{1,35}\.([a-zA-Z^:\n]{1,3}))\n")
if pattern.search(gpt_reply):
gpt_reply = pattern.sub(code_block_title_replace_format, gpt_reply)
if gpt_reply.endswith("```"):
return gpt_reply
@@ -429,11 +404,7 @@ def close_up_code_segment_during_stream(gpt_reply):
segments = gpt_reply.split("```")
n_mark = len(segments) - 1
if n_mark % 2 == 1:
try:
num_padding = get_last_backticks_indent(gpt_reply)
except:
num_padding = 0
return gpt_reply + "\n" + " "*num_padding + "```" # 输出代码片段中!
return gpt_reply + "\n```" # 输出代码片段中!
else:
return gpt_reply
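A quick usage sketch of close_up_code_segment_during_stream on a reply whose closing fence has not arrived yet (an odd number of fence marks triggers the padding branch):

```python
partial_reply = "Here is the code:\n```python\nprint('hi')"
print(close_up_code_segment_during_stream(partial_reply))
# expected: the same text with a closing fence appended on a new line
```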
@@ -450,19 +421,6 @@ def special_render_issues_for_mermaid(text):
return text
def contain_html_tag(text):
"""
判断文本中是否包含HTML标签。
"""
pattern = r'</?([a-zA-Z0-9_]{3,16})>|<script\s+[^>]*src=["\']([^"\']+)["\'][^>]*>'
return re.search(pattern, text) is not None
def contain_image(text):
pattern = r'<br/><br/><div align="center"><img src="file=(.*?)" base64="(.*?)"></div>'
return re.search(pattern, text) is not None
def compat_non_markdown_input(text):
"""
改善非markdown输入的显示效果,例如将空格转换为&nbsp;,将换行符转换为</br>等。
@@ -471,13 +429,9 @@ def compat_non_markdown_input(text):
# careful input:markdown输入
text = special_render_issues_for_mermaid(text) # 处理特殊的渲染问题
return text
elif ("<" in text) and (">" in text) and contain_html_tag(text):
elif "</div>" in text:
# careful input:html输入
if contain_image(text):
return text
else:
escaped_text = html.escape(text)
return escaped_text
return text
else:
# whatever input:非markdown输入
lines = text.split("\n")


@@ -77,28 +77,16 @@ def make_history_cache():
# 定义 后端state(history)、前端(history_cache)、后端setter(history_cache_update)三兄弟
import gradio as gr
# 定义history的后端state
# history = gr.State([])
history = gr.Textbox(visible=False, elem_id="history-ng")
# # 定义history的一个孪生的前端存储区(隐藏)
# history_cache = gr.Textbox(visible=False, elem_id="history_cache")
# # 定义history_cache->history的更新方法(隐藏)。在触发这个按钮时,会先执行js代码更新history_cache,然后再执行python代码更新history
# def process_history_cache(history_cache):
# return json.loads(history_cache)
# # 另一种更简单的setter方法
# history_cache_update = gr.Button("", elem_id="elem_update_history", visible=False).click(
# process_history_cache, inputs=[history_cache], outputs=[history])
# # save history to history_cache
# def process_history_cache(history_cache):
# return json.dumps(history_cache)
# # 定义history->history_cache的更新方法(隐藏)
# def sync_history_cache(history):
# print("sync_history_cache", history)
# return json.dumps(history)
# # history.change(sync_history_cache, inputs=[history], outputs=[history_cache])
# # history_cache_sync = gr.Button("", elem_id="elem_sync_history", visible=False).click(
# # lambda history: (json.dumps(history)), inputs=[history_cache], outputs=[history])
return history, None, None
history = gr.State([])
# 定义history的一个孪生的前端存储区(隐藏)
history_cache = gr.Textbox(visible=False, elem_id="history_cache")
# 定义history_cache->history的更新方法(隐藏)。在触发这个按钮时,会先执行js代码更新history_cache,然后再执行python代码更新history
def process_history_cache(history_cache):
return json.loads(history_cache)
# 另一种更简单的setter方法
history_cache_update = gr.Button("", elem_id="elem_update_history", visible=False).click(
process_history_cache, inputs=[history_cache], outputs=[history])
return history, history_cache, history_cache_update
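For orientation, the restored gr.State variant above forms a round trip: front-end JS serializes the conversation into the hidden #history_cache textbox and clicks the hidden #elem_update_history button, whose handler deserializes the JSON back into the gr.State. A minimal sketch of both halves (the serializing half is done in JS in the real app):

```python
import json

def to_cache(history: list) -> str:
    # history (gr.State) -> history_cache (hidden textbox); JS side in the real app
    return json.dumps(history)

def process_history_cache(history_cache: str) -> list:
    # history_cache -> history; triggered by the hidden #elem_update_history button
    return json.loads(history_cache)

assert process_history_cache(to_cache(["你好", "hello"])) == ["你好", "hello"]
```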


@@ -1,83 +0,0 @@
import requests
import pickle
import io
import os
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
from loguru import logger
class DockerServiceApiComModel(BaseModel):
client_command: Optional[str] = Field(default=None, title="Client command", description="The command to be executed on the client side")
client_file_attach: Optional[dict] = Field(default=None, title="Client file attach", description="The file to be attached to the client side")
server_message: Optional[Any] = Field(default=None, title="Server message", description="The message from the server side")
server_std_err: Optional[str] = Field(default=None, title="Server standard error", description="The standard error from the server side")
server_std_out: Optional[str] = Field(default=None, title="Server standard output", description="The standard output from the server side")
server_file_attach: Optional[dict] = Field(default=None, title="Server file attach", description="The file to be attached to the server side")
def process_received(received: DockerServiceApiComModel, save_file_dir="./daas_output", output_manifest=None):
# Process the received data
if received.server_message:
try:
output_manifest['server_message'] += received.server_message
except:
output_manifest['server_message'] = received.server_message
if received.server_std_err:
output_manifest['server_std_err'] += received.server_std_err
if received.server_std_out:
output_manifest['server_std_out'] += received.server_std_out
if received.server_file_attach:
# print(f"Recv file attach: {received.server_file_attach}")
for file_name, file_content in received.server_file_attach.items():
new_fp = os.path.join(save_file_dir, file_name)
new_fp_dir = os.path.dirname(new_fp)
if not os.path.exists(new_fp_dir):
os.makedirs(new_fp_dir, exist_ok=True)
with open(new_fp, 'wb') as f:
f.write(file_content)
output_manifest['server_file_attach'].append(new_fp)
return output_manifest
def stream_daas(docker_service_api_com_model, server_url, save_file_dir):
# Prepare the file
# Pickle the object
pickled_data = pickle.dumps(docker_service_api_com_model)
# Create a file-like object from the pickled data
file_obj = io.BytesIO(pickled_data)
# Prepare the file for sending
files = {'file': ('docker_service_api_com_model.pkl', file_obj, 'application/octet-stream')}
# Send the POST request
response = requests.post(server_url, files=files, stream=True)
max_full_package_size = 1024 * 1024 * 1024 * 1 # 1 GB
received_output_manifest = {}
received_output_manifest['server_message'] = ""
received_output_manifest['server_std_err'] = ""
received_output_manifest['server_std_out'] = ""
received_output_manifest['server_file_attach'] = []
# Check if the request was successful
if response.status_code == 200:
# Process the streaming response
chunk_buf = None
for chunk in response.iter_content(max_full_package_size):
if chunk:
if chunk_buf is None: chunk_buf = chunk
else: chunk_buf += chunk
try:
received = pickle.loads(chunk_buf)
chunk_buf = None
received_output_manifest = process_received(received, save_file_dir, output_manifest = received_output_manifest)
yield received_output_manifest
except Exception as e:
# logger.error(f"pickle data was truncated, but don't worry, we will continue to receive the rest of the data.")
continue
else:
logger.error(f"Error: Received status code {response.status_code}, response.text: {response.text}")
return received_output_manifest
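A hypothetical usage sketch of the streaming client above; the server URL and command are placeholders, not taken from this diff:

```python
request = DockerServiceApiComModel(client_command="echo hello")  # placeholder command
for manifest in stream_daas(request,
                            server_url="http://localhost:8000/stream",  # placeholder URL
                            save_file_dir="./daas_output"):
    print(manifest["server_std_out"], end="")
```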


@@ -51,7 +51,7 @@ def validate_path_safety(path_or_url, user):
from toolbox import get_conf, default_user_name
from toolbox import FriendlyException
PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING')
sensitive_path = None # 必须不能包含 '/',即不能是多级路径
sensitive_path = None
path_or_url = os.path.relpath(path_or_url)
if path_or_url.startswith(PATH_LOGGING): # 日志文件(按用户划分)
sensitive_path = PATH_LOGGING


@@ -104,27 +104,17 @@ def extract_archive(file_path, dest_dir):
logger.info("Successfully extracted zip archive to {}".format(dest_dir))
elif file_extension in [".tar", ".gz", ".bz2"]:
try:
with tarfile.open(file_path, "r:*") as tarobj:
# 清理提取路径,移除任何不安全的元素
for member in tarobj.getmembers():
member_path = os.path.normpath(member.name)
full_path = os.path.join(dest_dir, member_path)
full_path = os.path.abspath(full_path)
if not full_path.startswith(os.path.abspath(dest_dir) + os.sep):
raise Exception(f"Attempted Path Traversal in {member.name}")
with tarfile.open(file_path, "r:*") as tarobj:
# 清理提取路径,移除任何不安全的元素
for member in tarobj.getmembers():
member_path = os.path.normpath(member.name)
full_path = os.path.join(dest_dir, member_path)
full_path = os.path.abspath(full_path)
if not full_path.startswith(os.path.abspath(dest_dir) + os.sep):
raise Exception(f"Attempted Path Traversal in {member.name}")
tarobj.extractall(path=dest_dir)
logger.info("Successfully extracted tar archive to {}".format(dest_dir))
except tarfile.ReadError as e:
if file_extension == ".gz":
# 一些特别奇葩的项目:是一个gz文件,里面不是tar,只有一个tex文件
import gzip
with gzip.open(file_path, 'rb') as f_in:
with open(os.path.join(dest_dir, 'main.tex'), 'wb') as f_out:
f_out.write(f_in.read())
else:
raise e
tarobj.extractall(path=dest_dir)
logger.info("Successfully extracted tar archive to {}".format(dest_dir))
# 第三方库需要预先pip install rarfile
# 此外,Windows上还需要安装winrar软件,配置其Path环境变量(如"C:\Program Files\WinRAR")才可以


@@ -4,6 +4,7 @@ from functools import wraps, lru_cache
from shared_utils.advanced_markdown_format import format_io
from shared_utils.config_loader import get_conf as get_conf
pj = os.path.join
default_user_name = 'default_user'
@@ -11,13 +12,10 @@ default_user_name = 'default_user'
openai_regex = re.compile(
r"sk-[a-zA-Z0-9_-]{48}$|" +
r"sk-[a-zA-Z0-9_-]{92}$|" +
r"sk-proj-[a-zA-Z0-9_-]{48}$|" +
r"sk-proj-[a-zA-Z0-9_-]{124}$|" +
r"sk-proj-[a-zA-Z0-9_-]{156}$|" + #新版apikey位数不匹配故修改此正则表达式
r"sk-proj-[a-zA-Z0-9_-]{48}$|"+
r"sk-proj-[a-zA-Z0-9_-]{124}$|"+
r"sess-[a-zA-Z0-9]{40}$"
)
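A usage sketch of the pattern style above, with a shortened two-alternative regex for illustration (the matched strings are placeholders, not real keys):

```python
import re

demo_regex = re.compile(
    r"sk-[a-zA-Z0-9_-]{48}$|" +
    r"sess-[a-zA-Z0-9]{40}$"
)
assert demo_regex.match("sk-" + "a" * 48)      # placeholder key shape
assert not demo_regex.match("not-an-api-key")
```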
def is_openai_api_key(key):
CUSTOM_API_KEY_PATTERN = get_conf('CUSTOM_API_KEY_PATTERN')
if len(CUSTOM_API_KEY_PATTERN) != 0:
@@ -28,7 +26,7 @@ def is_openai_api_key(key):
def is_azure_api_key(key):
API_MATCH_AZURE = re.match(r"^[a-zA-Z0-9]{32}$|^[a-zA-Z0-9]{84}", key)
API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{32}$", key)
return bool(API_MATCH_AZURE)
@@ -36,25 +34,16 @@ def is_api2d_key(key):
API_MATCH_API2D = re.match(r"fk[a-zA-Z0-9]{6}-[a-zA-Z0-9]{32}$", key)
return bool(API_MATCH_API2D)
def is_openroute_api_key(key):
API_MATCH_OPENROUTE = re.match(r"sk-or-v1-[a-zA-Z0-9]{64}$", key)
return bool(API_MATCH_OPENROUTE)
def is_cohere_api_key(key):
API_MATCH_AZURE = re.match(r"[a-zA-Z0-9]{40}$", key)
return bool(API_MATCH_AZURE)
def is_any_api_key(key):
# key 一般只包含字母、数字、下划线、逗号、中划线
if not re.match(r"^[a-zA-Z0-9_\-,]+$", key):
# 如果配置了 CUSTOM_API_KEY_PATTERN,再检查一下,以免误杀
if CUSTOM_API_KEY_PATTERN := get_conf('CUSTOM_API_KEY_PATTERN'):
return bool(re.match(CUSTOM_API_KEY_PATTERN, key))
return False
if ',' in key:
keys = key.split(',')
for k in keys:
@@ -88,8 +77,7 @@ def select_api_key(keys, llm_model):
avail_key_list = []
key_list = keys.split(',')
if llm_model.startswith('gpt-') or llm_model.startswith('chatgpt-') or \
llm_model.startswith('one-api-') or llm_model == 'o1' or llm_model.startswith('o1-'):
if llm_model.startswith('gpt-') or llm_model.startswith('one-api-') or llm_model.startswith('o1-'):
for k in key_list:
if is_openai_api_key(k): avail_key_list.append(k)
@@ -104,7 +92,7 @@ def select_api_key(keys, llm_model):
if llm_model.startswith('cohere-'):
for k in key_list:
if is_cohere_api_key(k): avail_key_list.append(k)
if llm_model.startswith('openrouter-'):
for k in key_list:
if is_openroute_api_key(k): avail_key_list.append(k)
@@ -112,7 +100,7 @@ def select_api_key(keys, llm_model):
if len(avail_key_list) == 0:
raise RuntimeError(f"您提供的api-key不满足要求不包含任何可用于{llm_model}的api-key。您可能选择了错误的模型或请求源左上角更换模型菜单中可切换openai,azure,claude,cohere等请求源")
api_key = random.choice(avail_key_list) # 随机负载均衡
api_key = random.choice(avail_key_list) # 随机负载均衡
return api_key
@@ -128,5 +116,5 @@ def select_api_key_for_embed_models(keys, llm_model):
if len(avail_key_list) == 0:
raise RuntimeError(f"您提供的api-key不满足要求不包含任何可用于{llm_model}的api-key。您可能选择了错误的模型或请求源。")
api_key = random.choice(avail_key_list) # 随机负载均衡
api_key = random.choice(avail_key_list) # 随机负载均衡
return api_key


@@ -1,15 +0,0 @@
"""
对项目中的各个插件进行测试。运行方法:直接运行 python tests/test_plugins.py
"""
import init_test
import os, sys
if __name__ == "__main__":
from experimental_mods.get_bilibili_resource import download_bilibili
download_bilibili("BV1LSSHYXEtv", only_audio=True, user_name="test")
# if __name__ == "__main__":
# from test_utils import plugin_test
# plugin_test(plugin='crazy_functions.VideoResource_GPT->视频任务', main_input="帮我找到《天文馆的猫》,歌手泠鸢")


@@ -19,8 +19,4 @@ if __name__ == "__main__":
plugin_test = importlib.import_module('test_utils').plugin_test
# plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="2203.01927")
# plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="gpt_log/arxiv_cache/2203.01927/workfolder")
# plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="2410.05779")
plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="gpt_log/default_user/workfolder")
plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="2203.01927")


@@ -20,7 +20,7 @@ Replace 'Tex/' with the actual directory path where your files are located befor
md = """
Following code including wrapper
```python:wrapper.py
```mermaid
graph TD
A[Enter Chart Definition] --> B(Preview)
B --> C{decide}
@@ -29,45 +29,8 @@ graph TD
E --> B
D --> F[Save Image and Code]
F --> B
```
<details>
<summary><b>My section header in bold</b></summary>
Any folded content here. It requires an empty line just above it.
</details>
"""
md ="""
在这种场景中,您希望机器 B 能够通过轮询机制来间接地“请求”机器 A而实际上机器 A 只能主动向机器 B 发出请求。这是一种典型的客户端-服务器轮询模式。下面是如何实现这种机制的详细步骤:
### 机器 B 的实现
1. **安装 FastAPI 和必要的依赖库**
```bash
pip install fastapi uvicorn
```
2. **创建 FastAPI 服务**
```python
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from uuid import uuid4
from threading import Lock
import time
app = FastAPI()
# 字典用于存储请求和状态
requests = {}
process_lock = Lock()
"""
def validate_path():
import os, sys
@@ -80,12 +43,10 @@ def validate_path():
validate_path() # validate path so you can run from base directory
from toolbox import markdown_convertion
# from shared_utils.advanced_markdown_format import markdown_convertion_for_file
from shared_utils.advanced_markdown_format import close_up_code_segment_during_stream
# with open("gpt_log/default_user/shared/2024-04-22-01-27-43.zip.extract/translated_markdown.md", "r", encoding="utf-8") as f:
# md = f.read()
md = close_up_code_segment_during_stream(md)
html = markdown_convertion(md)
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
with open("gpt_log/default_user/shared/2024-04-22-01-27-43.zip.extract/translated_markdown.md", "r", encoding="utf-8") as f:
md = f.read()
html = markdown_convertion_for_file(md)
# print(html)
with open("test.html", "w", encoding="utf-8") as f:
f.write(html)


@@ -1,67 +0,0 @@
"""
对项目中的各个插件进行测试。运行方法:直接运行 python tests/test_plugins.py
"""
import init_test
import os, sys
if __name__ == "__main__":
from test_utils import plugin_test
plugin_test(plugin='crazy_functions.VideoResource_GPT->多媒体任务', main_input="我想找一首歌里面有句歌词是“turn your face towards the sun”")
# plugin_test(plugin='crazy_functions.Internet_GPT->连接网络回答问题', main_input="谁是应急食品?")
# plugin_test(plugin='crazy_functions.函数动态生成->函数动态生成', main_input='交换图像的蓝色通道和红色通道', advanced_arg={"file_path_arg": "./build/ants.jpg"})
# plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="2307.07522")
# plugin_test(plugin='crazy_functions.PDF_Translate->批量翻译PDF文档', main_input='build/pdf/t1.pdf')
# plugin_test(
# plugin="crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF",
# main_input="G:/SEAFILE_LOCAL/50503047/我的资料库/学位/paperlatex/aaai/Fu_8368_with_appendix",
# )
# plugin_test(plugin='crazy_functions.虚空终端->虚空终端', main_input='修改api-key为sk-jhoejriotherjep')
# plugin_test(plugin='crazy_functions.批量翻译PDF文档_NOUGAT->批量翻译PDF文档', main_input='crazy_functions/test_project/pdf_and_word/aaai.pdf')
# plugin_test(plugin='crazy_functions.虚空终端->虚空终端', main_input='调用插件对C:/Users/fuqingxu/Desktop/旧文件/gpt/chatgpt_academic/crazy_functions/latex_fns中的python文件进行解析')
# plugin_test(plugin='crazy_functions.命令行助手->命令行助手', main_input='查看当前的docker容器列表')
# plugin_test(plugin='crazy_functions.SourceCode_Analyse->解析一个Python项目', main_input="crazy_functions/test_project/python/dqn")
# plugin_test(plugin='crazy_functions.SourceCode_Analyse->解析一个C项目', main_input="crazy_functions/test_project/cpp/cppipc")
# plugin_test(plugin='crazy_functions.Latex_Project_Polish->Latex英文润色', main_input="crazy_functions/test_project/latex/attention")
# plugin_test(plugin='crazy_functions.Markdown_Translate->Markdown中译英', main_input="README.md")
# plugin_test(plugin='crazy_functions.PDF_Translate->批量翻译PDF文档', main_input='crazy_functions/test_project/pdf_and_word/aaai.pdf')
# plugin_test(plugin='crazy_functions.谷歌检索小助手->谷歌检索小助手', main_input="https://scholar.google.com/scholar?hl=en&as_sdt=0%2C5&q=auto+reinforcement+learning&btnG=")
# plugin_test(plugin='crazy_functions.总结word文档->总结word文档', main_input="crazy_functions/test_project/pdf_and_word")
# plugin_test(plugin='crazy_functions.下载arxiv论文翻译摘要->下载arxiv论文并翻译摘要', main_input="1812.10695")
# plugin_test(plugin='crazy_functions.联网的ChatGPT->连接网络回答问题', main_input="谁是应急食品?")
# plugin_test(plugin='crazy_functions.解析JupyterNotebook->解析ipynb文件', main_input="crazy_functions/test_samples")
# plugin_test(plugin='crazy_functions.数学动画生成manim->动画生成', main_input="A ball split into 2, and then split into 4, and finally split into 8.")
# for lang in ["English", "French", "Japanese", "Korean", "Russian", "Italian", "German", "Portuguese", "Arabic"]:
# plugin_test(plugin='crazy_functions.Markdown_Translate->Markdown翻译指定语言', main_input="README.md", advanced_arg={"advanced_arg": lang})
# plugin_test(plugin='crazy_functions.知识库文件注入->知识库文件注入', main_input="./")
# plugin_test(plugin='crazy_functions.知识库文件注入->读取知识库作答', main_input="What is the installation method")
# plugin_test(plugin='crazy_functions.知识库文件注入->读取知识库作答', main_input="远程云服务器部署?")
# plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="2210.03629")


@@ -36,7 +36,7 @@ if __name__ == "__main__":
# plugin_test(plugin='crazy_functions.SourceCode_Analyse->解析一个C项目', main_input="crazy_functions/test_project/cpp/cppipc")
# plugin_test(plugin='crazy_functions.Latex_Project_Polish->Latex英文润色', main_input="crazy_functions/test_project/latex/attention")
# plugin_test(plugin='crazy_functions.Latex全文润色->Latex英文润色', main_input="crazy_functions/test_project/latex/attention")
# plugin_test(plugin='crazy_functions.Markdown_Translate->Markdown中译英', main_input="README.md")
@@ -65,3 +65,8 @@ if __name__ == "__main__":
# plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="2210.03629")
# advanced_arg = {"advanced_arg":"--llm_to_learn=gpt-3.5-turbo --prompt_prefix='根据下面的服装类型提示想象一个穿着者对这个人外貌、身处的环境、内心世界、人设进行描写。要求100字以内用第二人称。' --system_prompt=''" }
# plugin_test(plugin='crazy_functions.chatglm微调工具->微调数据集生成', main_input='build/dev.json', advanced_arg=advanced_arg)
# advanced_arg = {"advanced_arg":"--pre_seq_len=128 --learning_rate=2e-2 --num_gpus=1 --json_dataset='t_code.json' --ptuning_directory='/home/hmp/ChatGLM2-6B/ptuning' " }
# plugin_test(plugin='crazy_functions.chatglm微调工具->启动微调', main_input='build/dev.json', advanced_arg=advanced_arg)


@@ -1,33 +0,0 @@
import edge_tts
import os
import httpx
from toolbox import get_conf
async def test_tts():
async with httpx.AsyncClient() as client:
try:
# Forward the request to the target service
import tempfile
import edge_tts
import wave
import uuid
from pydub import AudioSegment
voice = get_conf("EDGE_TTS_VOICE")
tts = edge_tts.Communicate(text="测试", voice=voice)
temp_folder = tempfile.gettempdir()
temp_file_name = str(uuid.uuid4().hex)
temp_file = os.path.join(temp_folder, f'{temp_file_name}.mp3')
await tts.save(temp_file)
try:
mp3_audio = AudioSegment.from_file(temp_file, format="mp3")
mp3_audio.export(temp_file, format="wav")
with open(temp_file, 'rb') as wav_file: t = wav_file.read()
except:
raise RuntimeError("ffmpeg未安装无法处理EdgeTTS音频。安装方法见`https://github.com/jiaaro/pydub#getting-ffmpeg-set-up`")
except httpx.RequestError as e:
raise RuntimeError(f"请求失败: {e}")
if __name__ == "__main__":
import asyncio
asyncio.run(test_tts())


@@ -1,16 +1,9 @@
:root {
--gpt-academic-message-font-size: 15px;
}
.message {
font-size: var(--gpt-academic-message-font-size) !important;
}
#plugin_arg_menu {
transform: translate(-50%, -50%);
border: dashed;
}
/* hide remove all button */
.remove-all.svelte-aqlk7e.svelte-aqlk7e.svelte-aqlk7e {
visibility: hidden;
@@ -32,6 +25,7 @@
visibility: hidden;
}
/* height of the upload box */
.wrap.svelte-xwlu1w {
min-height: var(--size-32);
@@ -103,9 +97,13 @@
min-width: min(80px, 100%);
}
#cbs,
#cbs {
background-color: var(--block-background-fill) !important;
}
#cbsc {
background-color: rgba(var(--block-background-fill), 0.5) !important;
background-color: var(--block-background-fill) !important;
}
#interact-panel .form {
@@ -157,7 +155,7 @@
transform: translate(-50%, -50%);
flex-wrap: wrap;
justify-content: center;
transition: opacity 0.6s ease-in-out;
transition: opacity 1s ease-in-out;
opacity: 0;
}
.welcome-card-container.show {
@@ -209,7 +207,6 @@
.welcome-content {
text-wrap: balance;
height: 55px;
font-size: 13px;
display: flex;
align-items: center;
}
@@ -273,41 +270,4 @@
}
#gpt-submit-row #gpt-submit-dropdown > *:hover {
cursor: context-menu;
}
.tooltip.svelte-p2nen8.svelte-p2nen8 {
box-shadow: 10px 10px 15px rgba(0, 0, 0, 0.5);
left: 10px;
}
#tooltip .hidden {
/* display: none; */
opacity: 0;
transition: opacity 0.5s ease;
}
#tooltip .visible {
/* display: block; */
opacity: 1;
transition: opacity 0.5s ease;
}
#elem_fontsize,
#elem_top_p,
#elem_temperature,
#elem_max_length_sl,
#elem_prompt {
/* 左右为0,顶部为0,底部为4px */
padding: 0 0 4px 0;
backdrop-filter: blur(10px);
background-color: rgba(var(--block-background-fill), 0.5);
}
#tooltip #cbs,
#tooltip #cbsc,
#tooltip .svelte-b6y5bg,
#tooltip .tabitem {
backdrop-filter: blur(10px);
background-color: rgba(var(--block-background-fill), 0.5);
}
}


@@ -318,7 +318,7 @@ function addCopyButton(botElement, index, is_last_in_arr) {
}
});
if (enable_tts) {
if (enable_tts){
var audioButton = document.createElement('button');
audioButton.classList.add('audio-toggle-btn');
audioButton.innerHTML = audioIcon;
@@ -346,7 +346,7 @@ function addCopyButton(botElement, index, is_last_in_arr) {
var messageBtnColumn = document.createElement('div');
messageBtnColumn.classList.add('message-btn-row');
messageBtnColumn.appendChild(copyButton);
if (enable_tts) {
if (enable_tts){
messageBtnColumn.appendChild(audioButton);
}
botElement.appendChild(messageBtnColumn);
@@ -391,9 +391,6 @@ function chatbotContentChanged(attempt = 1, force = false) {
// Now pass both the message element and the is_last_in_arr boolean to addCopyButton
addCopyButton(message, index, is_last_in_arr);
// save_conversation_history
save_conversation_history_slow_down();
});
// gradioApp().querySelectorAll('#gpt-chatbot .message-wrap .message.bot').forEach(addCopyButton);
}, i === 0 ? 0 : 200);
@@ -750,24 +747,10 @@ function minor_ui_adjustment() {
var bar_btn_width = [];
// Automatically hide toolbar buttons that overflow the page
function auto_hide_toolbar() {
// If the chatbot hits the upper page border, hide everything
const elem_chatbot = document.getElementById('gpt-chatbot');
const chatbot_top = elem_chatbot.getBoundingClientRect().top;
var tooltip = document.getElementById('tooltip');
var tab_nav = tooltip.getElementsByClassName('tab-nav')[0];
// 20 px is roughly the height of one character
if (chatbot_top < 20) {
// tab_nav.style.display = 'none';
if (tab_nav.classList.contains('visible')) {tab_nav.classList.remove('visible');}
if (!tab_nav.classList.contains('hidden')) {tab_nav.classList.add('hidden');}
return;
}
if (tab_nav.classList.contains('hidden')) {tab_nav.classList.remove('hidden');}
if (!tab_nav.classList.contains('visible')) {tab_nav.classList.add('visible');}
// tab_nav.style.display = '';
var qq = document.getElementById('tooltip');
var tab_nav = qq.getElementsByClassName('tab-nav');
if (tab_nav.length == 0) { return; }
var btn_list = tab_nav.getElementsByTagName('button')
var btn_list = tab_nav[0].getElementsByTagName('button')
if (btn_list.length == 0) { return; }
// Get the page width
var page_width = document.documentElement.clientWidth;
@@ -871,7 +854,8 @@ function limit_scroll_position() {
// -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
function loadLive2D() {
if (document.querySelector(".waifu")) {
if (document.querySelector(".waifu") )
{
$('.waifu').show();
} else {
try {
@@ -938,12 +922,12 @@ function gpt_academic_gradio_saveload(
if (save_or_load === "load") {
let value = getCookie(cookie_key);
if (value) {
// console.log('加载cookie', elem_id, value)
console.log('加载cookie', elem_id, value)
push_data_to_gradio_component(value, elem_id, load_type);
}
else {
if (load_default) {
// console.log('加载cookie的默认值', elem_id, load_default_value)
console.log('加载cookie的默认值', elem_id, load_default_value)
push_data_to_gradio_component(load_default_value, elem_id, load_type);
}
}
@@ -953,153 +937,113 @@ function gpt_academic_gradio_saveload(
}
}
function generateUUID() {
// Generate a random number and convert it to a hexadecimal string
function randomHexDigit() {
return Math.floor((1 + Math.random()) * 0x10000).toString(16).slice(1);
}
// Construct the UUID using the randomHexDigit function
return (
randomHexDigit() + randomHexDigit() + '-' +
randomHexDigit() + '-' +
'4' + randomHexDigit().slice(0, 3) + '-' + // Version 4 UUID
((Math.floor(Math.random() * 4) + 8).toString(16)) + randomHexDigit().slice(0, 3) + '-' +
randomHexDigit() + randomHexDigit() + randomHexDigit()
);
}
function update_conversation_metadata() {
// Create a conversation UUID and timestamp
try {
const conversationId = generateUUID();
console.log('Create conversation ID:', conversationId);
const timestamp = new Date().toISOString();
const conversationMetaData = {
id: conversationId,
timestamp: timestamp
};
localStorage.setItem("conversation_metadata", JSON.stringify(conversationMetaData));
} catch (e) {
console.error('Error in updating conversation metadata:', e);
}
}
// Helper function to generate conversation preview
function generatePreview(conversation, timestamp, maxLength = 100) {
if (!conversation || conversation.length === 0) return "";
// Join all messages with newline separators
let preview = conversation.join("\n");
const readableDate = new Date(timestamp).toLocaleString();
preview = readableDate + "\n" + preview;
if (preview.length <= maxLength) return preview;
return preview.substring(0, maxLength) + "...";
}
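// Persist the current chat, history, and metadata into localStorage (capped at 10 conversations, newest first)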
async function save_conversation_history() {
let chatbot = await get_data_from_gradio_component('gpt-chatbot');
let history = await get_data_from_gradio_component('history-ng');
let conversation = {};
let conversation_metadata = localStorage.getItem("conversation_metadata");
try {
conversation_metadata = JSON.parse(conversation_metadata);
conversation = {
timestamp: conversation_metadata.timestamp,
id: conversation_metadata.id,
metadata: conversation_metadata,
conversation: chatbot,
history: history,
preview: generatePreview(JSON.parse(history), conversation_metadata.timestamp)
};
} catch (e) {
// console.error('Conversation metadata parse error, recreate conversation metadata');
update_conversation_metadata();
return;
}
// Get existing conversation history from local storage
let conversation_history = [];
try {
const stored = localStorage.getItem('conversation_history');
if (stored) {
conversation_history = JSON.parse(stored);
}
} catch (e) {
// console.error('Error reading conversation history from localStorage:', e);
}
// Find existing conversation with same ID
const existingIndex = conversation_history.findIndex(c => c.id === conversation.id);
if (existingIndex >= 0) {
// Update existing conversation
conversation_history[existingIndex] = conversation;
} else {
// Add new conversation
conversation_history.push(conversation);
}
// Sort conversations by timestamp, newest first
conversation_history.sort((a, b) => {
const timeA = new Date(a.timestamp).getTime();
const timeB = new Date(b.timestamp).getTime();
return timeB - timeA;
});
const max_chat_preserve = 10;
if (conversation_history.length >= max_chat_preserve + 1) {
toast_push('对话时间线记录已满,正在移除最早的对话记录。您也可以点击左侧的记录点进行手动清理。', 3000);
conversation_history = conversation_history.slice(0, max_chat_preserve);
}
// Save back to local storage
try {
localStorage.setItem('conversation_history', JSON.stringify(conversation_history));
const LOCAL_STORAGE_UPDATED = "gptac_conversation_history_updated";
window.dispatchEvent(
new CustomEvent(LOCAL_STORAGE_UPDATED, {
detail: conversation_history
})
);
} catch (e) {
console.error('Error saving conversation history to localStorage:', e);
}
}
save_conversation_history_slow_down = do_something_but_not_too_frequently(300, save_conversation_history);
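// Restore a conversation picked from the localStorage timeline back into the chatbot and history components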
function restore_chat_from_local_storage(event) {
let conversation = event.detail;
push_data_to_gradio_component(conversation.conversation, "gpt-chatbot", "obj");
push_data_to_gradio_component(conversation.history, "history-ng", "obj");
const conversationId = conversation.id;
const timestamp = conversation.timestamp;
const conversationData = {
id: conversationId,
timestamp: timestamp
};
localStorage.setItem("conversation_metadata", JSON.stringify(conversationData));
}
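// Save the current conversation, start fresh metadata, stop any running generation, then reset the UI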
async function clear_conversation(a, b, c) {
await save_conversation_history();
update_conversation_metadata();
let stopButton = document.getElementById("elem_stop");
stopButton.click();
return reset_conversation(a, b);
}
function reset_conversation(a, b) {
// console.log("js_code_reset");
a = btoa(unescape(encodeURIComponent(JSON.stringify(a))));
setCookie("js_previous_chat_cookie", a, 1);
gen_restore_btn();
return [[], [], "已重置"];
}
// clear -> cache history into history_cache -> click restore -> restore_previous_chat() -> trigger elem_update_history -> read history_cache
function restore_previous_chat() {
console.log("restore_previous_chat");
let chat = getCookie("js_previous_chat_cookie");
chat = JSON.parse(decodeURIComponent(escape(atob(chat))));
push_data_to_gradio_component(chat, "gpt-chatbot", "obj");
document.querySelector("#elem_update_history").click(); // in order to call set_history_gr_state, and send history state to server
}
function gen_restore_btn() {
// Create the button element
const button = document.createElement('div');
// const recvIcon = '<span><svg stroke="currentColor" fill="none" stroke-width="2" viewBox="0 0 24 24" stroke-linecap="round" stroke-linejoin="round" height=".8em" width=".8em" xmlns="http://www.w3.org/2000/svg"><polyline points="20 6 9 17 4 12"></polyline></svg></span>';
const rec_svg = '<svg t="1714361184567" style="transform:translate(1px, 2.5px)" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="4389" width="35" height="35"><path d="M320 512h384v64H320zM320 384h384v64H320zM320 640h192v64H320z" p-id="4390" fill="#ffffff"></path><path d="M863.7 544c-1.9 44-11.4 86.8-28.5 127.2-18.5 43.8-45.1 83.2-78.9 117-33.8 33.8-73.2 60.4-117 78.9C593.9 886.3 545.7 896 496 896s-97.9-9.7-143.2-28.9c-43.8-18.5-83.2-45.1-117-78.9-33.8-33.8-60.4-73.2-78.9-117C137.7 625.9 128 577.7 128 528s9.7-97.9 28.9-143.2c18.5-43.8 45.1-83.2 78.9-117s73.2-60.4 117-78.9C398.1 169.7 446.3 160 496 160s97.9 9.7 143.2 28.9c23.5 9.9 45.8 22.2 66.5 36.7l-119.7 20 9.9 59.4 161.6-27 59.4-9.9-9.9-59.4-27-161.5-59.4 9.9 19 114.2C670.3 123.8 586.4 96 496 96 257.4 96 64 289.4 64 528s193.4 432 432 432c233.2 0 423.3-184.8 431.7-416h-64z" p-id="4391" fill="#ffffff"></path></svg>'
const recvIcon = '<span>' + rec_svg + '</span>';
// Set the button's style and attributes
button.id = 'floatingButton';
button.className = 'glow';
button.style.textAlign = 'center';
button.style.position = 'fixed';
button.style.bottom = '10px';
button.style.left = '10px';
button.style.width = '50px';
button.style.height = '50px';
button.style.borderRadius = '50%';
button.style.backgroundColor = '#007bff';
button.style.color = 'white';
button.style.display = 'flex';
button.style.alignItems = 'center';
button.style.justifyContent = 'center';
button.style.cursor = 'pointer';
button.style.transition = 'all 0.3s ease';
button.style.boxShadow = '0 0 10px rgba(0,0,0,0.2)';
button.innerHTML = recvIcon;
// Add the keyframes for the glow animation
const styleSheet = document.createElement('style');
styleSheet.id = 'floatingButtonStyle';
styleSheet.innerText = `
@keyframes glow {
from {
box-shadow: 0 0 10px rgba(0,0,0,0.2);
}
to {
box-shadow: 0 0 13px rgba(0,0,0,0.5);
}
}
#floatingButton.glow {
animation: glow 1s infinite alternate;
}
#floatingButton:hover {
transform: scale(1.2);
box-shadow: 0 0 20px rgba(0,0,0,0.4);
}
#floatingButton.disappearing {
animation: shrinkAndDisappear 0.5s forwards;
}
`;
// Only append the stylesheet if it is not already present
if (!document.getElementById('floatingButtonStyle'))
{
document.head.appendChild(styleSheet);
}
// Event listeners for mouse over and mouse out
button.addEventListener('mouseover', function () {
this.textContent = "还原\n对话";
});
button.addEventListener('mouseout', function () {
this.innerHTML = recvIcon;
});
// Click event listener
button.addEventListener('click', function () {
// Add a class to trigger the shrink-and-disappear animation
restore_previous_chat();
this.classList.add('disappearing');
// Remove the button once the animation finishes
document.body.removeChild(this);
});
// Only append the button if it is not already present
if (!document.getElementById('floatingButton'))
{
document.body.appendChild(button);
}
// Append the button to the page
}
async function on_plugin_exe_complete(fn_name) {
// console.log(fn_name);
console.log(fn_name);
if (fn_name === "保存当前的对话") {
// get chat profile path
let chatbot = await get_data_from_gradio_component('gpt-chatbot');
@@ -1118,15 +1062,15 @@ async function on_plugin_exe_complete(fn_name) {
}
let href = get_href(may_have_chat_profile_info);
if (href) {
const cleanedHref = href.replace('file=', ''); // gpt_log/default_user/chat_history/GPT-Academic对话存档2024-04-12-00-35-06.html
// console.log(cleanedHref);
const cleanedHref = href.replace('file=', ''); // /home/fuqingxu/chatgpt_academic/gpt_log/default_user/chat_history/GPT-Academic对话存档2024-04-12-00-35-06.html
console.log(cleanedHref);
}
}
}
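// Decode the plugin's base64-encoded GUI definition and populate the pop-up argument menu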
async function generate_menu(guiBase64String, btnName) {
async function generate_menu(guiBase64String, btnName){
// assign the button and menu data
push_data_to_gradio_component(guiBase64String, "invisible_current_pop_up_plugin_arg", "string");
push_data_to_gradio_component(btnName, "invisible_callback_btn_for_plugin_exe", "string");
@@ -1160,22 +1104,22 @@ async function generate_menu(guiBase64String, btnName) {
///////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////// Textbox ////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////
if (gui_args[key].type == 'string') { // PLUGIN_ARG_MENU
if (gui_args[key].type=='string'){ // PLUGIN_ARG_MENU
const component_name = "plugin_arg_txt_" + text_cnt;
push_data_to_gradio_component({
visible: true,
label: gui_args[key].title + "(" + gui_args[key].description + ")",
label: gui_args[key].title + "(" + gui_args[key].description + ")",
// label: gui_args[key].title,
placeholder: gui_args[key].description,
__type__: 'update'
}, component_name, "obj");
if (key === "main_input") {
if (key === "main_input"){
// For compatibility with legacy plugins, auto-load the main input box's value when generating the menu
let current_main_input = await get_data_from_gradio_component('user_input_main');
let current_main_input_2 = await get_data_from_gradio_component('user_input_float');
push_data_to_gradio_component(current_main_input + current_main_input_2, component_name, "obj");
}
else if (key === "advanced_arg") {
else if (key === "advanced_arg"){
// For compatibility with legacy plugins, auto-load the legacy advanced-argument input's value when generating the menu
let advance_arg_input_legacy = await get_data_from_gradio_component('advance_arg_input_legacy');
push_data_to_gradio_component(advance_arg_input_legacy, component_name, "obj");
@@ -1190,12 +1134,12 @@ async function generate_menu(guiBase64String, btnName) {
///////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////// Dropdown ////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////
if (gui_args[key].type == 'dropdown') { // PLUGIN_ARG_MENU
if (gui_args[key].type=='dropdown'){ // PLUGIN_ARG_MENU
const component_name = "plugin_arg_drop_" + dropdown_cnt;
push_data_to_gradio_component({
visible: true,
choices: gui_args[key].options,
label: gui_args[key].title + "(" + gui_args[key].description + ")",
label: gui_args[key].title + "(" + gui_args[key].description + ")",
// label: gui_args[key].title,
placeholder: gui_args[key].description,
__type__: 'update'
@@ -1210,7 +1154,7 @@ async function generate_menu(guiBase64String, btnName) {
}
}
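// Collect the user-confirmed values of every textbox and dropdown in the pop-up menu before invoking the plugin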
async function execute_current_pop_up_plugin() {
async function execute_current_pop_up_plugin(){
let guiBase64String = await get_data_from_gradio_component('invisible_current_pop_up_plugin_arg');
const stringData = atob(guiBase64String);
let guiJsonData = JSON.parse(stringData);
@@ -1226,8 +1170,8 @@ async function execute_current_pop_up_plugin() {
let text_cnt = 0;
for (const key in gui_args) {
if (gui_args.hasOwnProperty(key)) {
if (gui_args[key].type == 'string') { // PLUGIN_ARG_MENU
corrisponding_elem_id = "plugin_arg_txt_" + text_cnt
if (gui_args[key].type=='string'){ // PLUGIN_ARG_MENU
corrisponding_elem_id = "plugin_arg_txt_"+text_cnt
gui_args[key].user_confirmed_value = await get_data_from_gradio_component(corrisponding_elem_id);
text_cnt += 1;
}
@@ -1236,8 +1180,8 @@ async function execute_current_pop_up_plugin() {
let dropdown_cnt = 0;
for (const key in gui_args) {
if (gui_args.hasOwnProperty(key)) {
if (gui_args[key].type == 'dropdown') { // PLUGIN_ARG_MENU
corrisponding_elem_id = "plugin_arg_drop_" + dropdown_cnt
if (gui_args[key].type=='dropdown'){ // PLUGIN_ARG_MENU
corrisponding_elem_id = "plugin_arg_drop_"+dropdown_cnt
gui_args[key].user_confirmed_value = await get_data_from_gradio_component(corrisponding_elem_id);
dropdown_cnt += 1;
}
@@ -1256,29 +1200,29 @@ async function execute_current_pop_up_plugin() {
}
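// Hide all plugin-argument textboxes and dropdowns (eight of each)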
function hide_all_elem() {
// PLUGIN_ARG_MENU
for (text_cnt = 0; text_cnt < 8; text_cnt++) {
function hide_all_elem(){
// PLUGIN_ARG_MENU
for (text_cnt = 0; text_cnt < 8; text_cnt++){
push_data_to_gradio_component({
visible: false,
label: "",
__type__: 'update'
}, "plugin_arg_txt_" + text_cnt, "obj");
document.getElementById("plugin_arg_txt_" + text_cnt).parentNode.parentNode.style.display = 'none';
}, "plugin_arg_txt_"+text_cnt, "obj");
document.getElementById("plugin_arg_txt_"+text_cnt).parentNode.parentNode.style.display = 'none';
}
for (dropdown_cnt = 0; dropdown_cnt < 8; dropdown_cnt++) {
for (dropdown_cnt = 0; dropdown_cnt < 8; dropdown_cnt++){
push_data_to_gradio_component({
visible: false,
choices: [],
label: "",
__type__: 'update'
}, "plugin_arg_drop_" + dropdown_cnt, "obj");
document.getElementById("plugin_arg_drop_" + dropdown_cnt).parentNode.style.display = 'none';
}, "plugin_arg_drop_"+dropdown_cnt, "obj");
document.getElementById("plugin_arg_drop_"+dropdown_cnt).parentNode.style.display = 'none';
}
}
function close_current_pop_up_plugin() {
// PLUGIN_ARG_MENU
function close_current_pop_up_plugin(){
// PLUGIN_ARG_MENU
push_data_to_gradio_component({
visible: false,
__type__: 'update'
@@ -1289,13 +1233,15 @@ function close_current_pop_up_plugin() {
// Generate the selection menu for advanced plugins
plugin_init_info_lib = {}
function register_plugin_init(key, base64String) {
function register_plugin_init(key, base64String){
// console.log('x')
const stringData = atob(base64String);
let guiJsonData = JSON.parse(stringData);
if (key in plugin_init_info_lib) {
if (key in plugin_init_info_lib)
{
}
else {
else
{
plugin_init_info_lib[key] = {};
}
plugin_init_info_lib[key].info = guiJsonData.Info;
@@ -1305,26 +1251,28 @@ function register_plugin_init(key, base64String) {
plugin_init_info_lib[key].enable_advanced_arg = guiJsonData.AdvancedArgs;
plugin_init_info_lib[key].arg_reminder = guiJsonData.ArgsReminder;
}
function register_advanced_plugin_init_code(key, code) {
if (key in plugin_init_info_lib) {
function register_advanced_plugin_init_code(key, code){
if (key in plugin_init_info_lib)
{
}
else {
else
{
plugin_init_info_lib[key] = {};
}
plugin_init_info_lib[key].secondary_menu_code = code;
}
function run_advanced_plugin_launch_code(key) {
function run_advanced_plugin_launch_code(key){
// Launch the plugin's secondary menu from its registered definition
generate_menu(plugin_init_info_lib[key].secondary_menu_code, key);
}
function on_flex_button_click(key) {
if (plugin_init_info_lib.hasOwnProperty(key) && plugin_init_info_lib[key].hasOwnProperty('secondary_menu_code')) {
function on_flex_button_click(key){
if (plugin_init_info_lib.hasOwnProperty(key) && plugin_init_info_lib[key].hasOwnProperty('secondary_menu_code')){
run_advanced_plugin_launch_code(key);
} else {
}else{
document.getElementById("old_callback_btn_for_plugin_exe").click();
}
}
async function run_dropdown_shift(dropdown) {
async function run_dropdown_shift(dropdown){
let key = dropdown;
push_data_to_gradio_component({
value: key,
@@ -1333,7 +1281,7 @@ async function run_dropdown_shift(dropdown) {
__type__: 'update'
}, "elem_switchy_bt", "obj");
if (plugin_init_info_lib[key].enable_advanced_arg) {
if (plugin_init_info_lib[key].enable_advanced_arg){
push_data_to_gradio_component({
visible: true,
label: plugin_init_info_lib[key].label,
@@ -1355,9 +1303,9 @@ async function duplicate_in_new_window() {
window.open(url, '_blank');
}
async function run_classic_plugin_via_id(plugin_elem_id) {
for (key in plugin_init_info_lib) {
if (plugin_init_info_lib[key].elem_id == plugin_elem_id) {
async function run_classic_plugin_via_id(plugin_elem_id){
for (key in plugin_init_info_lib){
if (plugin_init_info_lib[key].elem_id == plugin_elem_id){
// Get the button name
let current_btn_name = await get_data_from_gradio_component(plugin_elem_id);
// Execute the plugin
@@ -1378,7 +1326,7 @@ async function call_plugin_via_name(current_btn_name) {
hide_all_elem();
// For compatibility with legacy plugins, auto-load the legacy advanced-argument input's value
let advance_arg_input_legacy = await get_data_from_gradio_component('advance_arg_input_legacy');
if (advance_arg_input_legacy.length != 0) {
if (advance_arg_input_legacy.length != 0){
gui_args["advanced_arg"] = {};
gui_args["advanced_arg"].user_confirmed_value = advance_arg_input_legacy;
}
@@ -1401,11 +1349,18 @@ async function multiplex_function_begin(multiplex_sel) {
click_real_submit_btn();
return;
}
// do not delete `REPLACE_EXTENDED_MULTIPLEX_FUNCTIONS_HERE`! It will be read and replaced by Python code.
// REPLACE_EXTENDED_MULTIPLEX_FUNCTIONS_HERE
if (multiplex_sel === "多模型对话") {
let _align_name_in_crazy_function_py = "询问多个GPT模型";
call_plugin_via_name(_align_name_in_crazy_function_py);
return;
}
if (multiplex_sel === "智能召回 RAG") {
let _align_name_in_crazy_function_py = "Rag智能召回";
call_plugin_via_name(_align_name_in_crazy_function_py);
return;
}
}
async function run_multiplex_shift(multiplex_sel) {
async function run_multiplex_shift(multiplex_sel){
let key = multiplex_sel;
if (multiplex_sel === "常规对话") {
key = "提交";
@@ -1417,8 +1372,3 @@ async function run_multiplex_shift(multiplex_sel) {
__type__: 'update'
}, "elem_submit_visible", "obj");
}
async function persistent_cookie_init(web_cookie_cache, cookie) {
return [localStorage.getItem('web_cookie_cache'), cookie];
}

View File

@@ -2,25 +2,6 @@ from functools import lru_cache
from toolbox import get_conf
CODE_HIGHLIGHT, ADD_WAIFU, LAYOUT = get_conf("CODE_HIGHLIGHT", "ADD_WAIFU", "LAYOUT")
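# Inject a JS dispatch branch for every registered multiplex-button function into common.js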
def inject_mutex_button_code(js_content):
from crazy_functional import get_multiplex_button_functions
fns = get_multiplex_button_functions()
template = """
if (multiplex_sel === "{x}") {
let _align_name_in_crazy_function_py = "{y}";
call_plugin_via_name(_align_name_in_crazy_function_py);
return;
}
"""
replacement = ""
for fn in fns.keys():
if fn == "常规对话": continue
replacement += template.replace("{x}", fn).replace("{y}", fns[fn])
js_content = js_content.replace("// REPLACE_EXTENDED_MULTIPLEX_FUNCTIONS_HERE", replacement)
return js_content
def minimize_js(common_js_path):
try:
import rjsmin, hashlib, glob, os
@@ -29,16 +10,14 @@ def minimize_js(common_js_path):
os.remove(old_min_js)
# use rjsmin to minimize `common_js_path`
c_jsmin = rjsmin.jsmin
with open(common_js_path, "r", encoding='utf-8') as f:
with open(common_js_path, "r") as f:
js_content = f.read()
if common_js_path == "themes/common.js":
js_content = inject_mutex_button_code(js_content)
minimized_js_content = c_jsmin(js_content)
# compute sha256 hash of minimized js content
sha_hash = hashlib.sha256(minimized_js_content.encode()).hexdigest()[:8]
minimized_js_path = common_js_path + '.min.' + sha_hash + '.js'
# save to minimized js file
with open(minimized_js_path, "w", encoding='utf-8') as f:
with open(minimized_js_path, "w") as f:
f.write(minimized_js_content)
# return minimized js file path
return minimized_js_path

File diff suppressed because it is too large

View File

@@ -10,14 +10,6 @@ theme_dir = os.path.dirname(__file__)
def adjust_theme():
try:
set_theme = gr.themes.Soft(
font=[
"Helvetica",
"Microsoft YaHei",
"ui-sans-serif",
"sans-serif",
"system-ui",
],
font_mono=["ui-monospace", "Consolas", "monospace"],
primary_hue=gr.themes.Color(
c50="#EBFAF2",
c100="#CFF3E1",

View File

@@ -1,7 +1,7 @@
import gradio as gr
def define_gui_floating_menu(customize_btns, functional, predefined_btns, cookies, web_cookie_cache):
with gr.Floating(init_x="20%", init_y="50%", visible=False, width="40%", drag="top", elem_id="f_area_input_secondary") as area_input_secondary:
with gr.Floating(init_x="20%", init_y="50%", visible=False, width="40%", drag="top") as area_input_secondary:
with gr.Accordion("浮动输入区", open=True, elem_id="input-panel2"):
with gr.Row() as row:
row.style(equal_height=True)
@@ -17,7 +17,7 @@ def define_gui_floating_menu(customize_btns, functional, predefined_btns, cookie
clearBtn2 = gr.Button("清除", elem_id="elem_clear2", variant="secondary", visible=False); clearBtn2.style(size="sm")
with gr.Floating(init_x="20%", init_y="50%", visible=False, width="40%", drag="top", elem_id="f_area_customize") as area_customize:
with gr.Floating(init_x="20%", init_y="50%", visible=False, width="40%", drag="top") as area_customize:
with gr.Accordion("自定义菜单", open=True, elem_id="edit-panel"):
with gr.Row() as row:
with gr.Column(scale=10):
@@ -35,9 +35,9 @@ def define_gui_floating_menu(customize_btns, functional, predefined_btns, cookie
# update btn
h = basic_fn_confirm.click(assign_btn, [web_cookie_cache, cookies, basic_btn_dropdown, basic_fn_title, basic_fn_prefix, basic_fn_suffix],
[web_cookie_cache, cookies, *customize_btns.values(), *predefined_btns.values()])
h.then(None, [web_cookie_cache], None, _js="""(web_cookie_cache)=>{localStorage.setItem("web_cookie_cache", web_cookie_cache);}""")
h.then(None, [web_cookie_cache], None, _js="""(web_cookie_cache)=>{setCookie("web_cookie_cache", web_cookie_cache, 365);}""")
# clean up btn
h2 = basic_fn_clean.click(assign_btn, [web_cookie_cache, cookies, basic_btn_dropdown, basic_fn_title, basic_fn_prefix, basic_fn_suffix, gr.State(True)],
[web_cookie_cache, cookies, *customize_btns.values(), *predefined_btns.values()])
h2.then(None, [web_cookie_cache], None, _js="""(web_cookie_cache)=>{localStorage.setItem("web_cookie_cache", web_cookie_cache);}""")
h2.then(None, [web_cookie_cache], None, _js="""(web_cookie_cache)=>{setCookie("web_cookie_cache", web_cookie_cache, 365);}""")
return area_input_secondary, txt2, area_customize, submitBtn2, resetBtn2, clearBtn2, stopBtn2

View File

@@ -1,7 +1,6 @@
import gradio as gr
from toolbox import get_conf
def define_gui_toolbar(AVAIL_LLM_MODELS, LLM_MODEL, INIT_SYS_PROMPT, THEME, AVAIL_THEMES, AVAIL_FONTS, ADD_WAIFU, help_menu_description, js_code_for_toggle_darkmode):
def define_gui_toolbar(AVAIL_LLM_MODELS, LLM_MODEL, INIT_SYS_PROMPT, THEME, AVAIL_THEMES, ADD_WAIFU, help_menu_description, js_code_for_toggle_darkmode):
with gr.Floating(init_x="0%", init_y="0%", visible=True, width=None, drag="forbidden", elem_id="tooltip"):
with gr.Row():
with gr.Tab("上传文件", elem_id="interact-panel"):
@@ -10,12 +9,12 @@ def define_gui_toolbar(AVAIL_LLM_MODELS, LLM_MODEL, INIT_SYS_PROMPT, THEME, AVAI
with gr.Tab("更换模型", elem_id="interact-panel"):
md_dropdown = gr.Dropdown(AVAIL_LLM_MODELS, value=LLM_MODEL, elem_id="elem_model_sel", label="更换LLM模型/请求源").style(container=False)
top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.01,interactive=True, label="Top-p (nucleus sampling)", elem_id="elem_top_p")
top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.01,interactive=True, label="Top-p (nucleus sampling)",)
temperature = gr.Slider(minimum=-0, maximum=2.0, value=1.0, step=0.01, interactive=True, label="Temperature", elem_id="elem_temperature")
max_length_sl = gr.Slider(minimum=256, maximum=1024*32, value=4096, step=128, interactive=True, label="Local LLM MaxLength", elem_id="elem_max_length_sl")
max_length_sl = gr.Slider(minimum=256, maximum=1024*32, value=4096, step=128, interactive=True, label="Local LLM MaxLength",)
system_prompt = gr.Textbox(show_label=True, lines=2, placeholder=f"System Prompt", label="System prompt", value=INIT_SYS_PROMPT, elem_id="elem_prompt")
temperature.change(None, inputs=[temperature], outputs=None,
_js="""(temperature)=>gpt_academic_gradio_saveload("save", "elem_temperature", "js_temperature_cookie", temperature)""")
_js="""(temperature)=>gpt_academic_gradio_saveload("save", "elem_prompt", "js_temperature_cookie", temperature)""")
system_prompt.change(None, inputs=[system_prompt], outputs=None,
_js="""(system_prompt)=>gpt_academic_gradio_saveload("save", "elem_prompt", "js_system_prompt_cookie", system_prompt)""")
md_dropdown.change(None, inputs=[md_dropdown], outputs=None,
@@ -23,8 +22,6 @@ def define_gui_toolbar(AVAIL_LLM_MODELS, LLM_MODEL, INIT_SYS_PROMPT, THEME, AVAI
with gr.Tab("界面外观", elem_id="interact-panel"):
theme_dropdown = gr.Dropdown(AVAIL_THEMES, value=THEME, label="更换UI主题").style(container=False)
fontfamily_dropdown = gr.Dropdown(AVAIL_FONTS, value=get_conf("FONT"), elem_id="elem_fontfamily", label="更换字体类型").style(container=False)
fontsize_slider = gr.Slider(minimum=5, maximum=25, value=15, step=1, interactive=True, label="字体大小(默认15)", elem_id="elem_fontsize")
checkboxes = gr.CheckboxGroup(["基础功能区", "函数插件区", "浮动输入区", "输入清除键", "插件参数区"], value=["基础功能区", "函数插件区"], label="显示/隐藏功能区", elem_id='cbs').style(container=False)
opt = ["自定义菜单"]
value=[]
@@ -34,10 +31,7 @@ def define_gui_toolbar(AVAIL_LLM_MODELS, LLM_MODEL, INIT_SYS_PROMPT, THEME, AVAI
dark_mode_btn.click(None, None, None, _js=js_code_for_toggle_darkmode)
open_new_tab = gr.Button("打开新对话", variant="secondary").style(size="sm")
open_new_tab.click(None, None, None, _js=f"""()=>duplicate_in_new_window()""")
fontfamily_dropdown.select(None, inputs=[fontfamily_dropdown], outputs=None,
_js="""(fontfamily)=>{gpt_academic_gradio_saveload("save", "elem_fontfamily", "js_fontfamily", fontfamily); gpt_academic_change_chatbot_font(fontfamily, null, null);}""")
fontsize_slider.change(None, inputs=[fontsize_slider], outputs=None,
_js="""(fontsize)=>{gpt_academic_gradio_saveload("save", "elem_fontsize", "js_fontsize", fontsize); gpt_academic_change_chatbot_font(null, fontsize, null);}""")
with gr.Tab("帮助", elem_id="interact-panel"):
gr.Markdown(help_menu_description)

View File

@@ -1,142 +1,9 @@
function remove_legacy_cookie() {
setCookie("web_cookie_cache", "", -1);
setCookie("js_previous_chat_cookie", "", -1);
setCookie("js_previous_history_cookie", "", -1);
}
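// Parse a font spec such as "Alias(FontName @ FontURL)" into { fontName, fontUrl }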
function processFontFamily(fontfamily) {
// Check whether it contains parentheses
if (fontfamily.includes('(')) {
// Split the string
const parts = fontfamily.split('(');
const fontNamePart = parts[1].split(')')[0].trim(); // Get the part inside the parentheses
// Check whether it contains @
if (fontNamePart.includes('@')) {
const [fontName, fontUrl] = fontNamePart.split('@').map(part => part.trim());
return { fontName, fontUrl };
} else {
return { fontName: fontNamePart, fontUrl: null };
}
} else {
return { fontName: fontfamily, fontUrl: null };
}
}
// Check whether the font is available locally
function checkFontAvailability(fontfamily) {
return new Promise((resolve) => {
const canvas = document.createElement('canvas');
const context = canvas.getContext('2d');
// Render the same text with two different fonts and compare widths
const testText = 'abcdefghijklmnopqrstuvwxyz0123456789';
context.font = `16px ${fontfamily}, sans-serif`;
const widthWithFont = context.measureText(testText).width;
context.font = '16px sans-serif';
const widthWithFallback = context.measureText(testText).width;
// If the widths match, the font does not exist
resolve(widthWithFont !== widthWithFallback);
});
}
async function checkFontAvailabilityV2(fontfamily) {
const fontName = fontfamily;
console.log('Checking font availability:', fontName);
if ('queryLocalFonts' in window) {
try {
const fonts = await window.queryLocalFonts();
const fontExists = fonts.some(font => font.family === fontName);
console.log(`Local Font "${fontName}" exists:`, fontExists);
return fontExists;
} catch (error) {
console.error('Error querying local fonts:', error);
return false;
}
} else {
console.error('queryLocalFonts is not supported in this browser.');
return false;
}
}
// Dynamically load a font
function loadFont(fontfamily, fontUrl) {
return new Promise((resolve, reject) => {
// Use Google Fonts or another font source
const link = document.createElement('link');
link.rel = 'stylesheet';
link.href = fontUrl;
link.onload = () => {
toast_push(`字体 "${fontfamily}" 已成功加载`, 3000);
resolve();
};
link.onerror = (error) => {
reject(error);
};
document.head.appendChild(link);
});
}
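// Apply font family / size / color changes to the chatbot, loading the font from its URL when it is not available locally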
function gpt_academic_change_chatbot_font(fontfamily, fontsize, fontcolor) {
const chatbot = document.querySelector('#gpt-chatbot');
// Check that the element exists
if (chatbot) {
if (fontfamily != null) {
// 更改字体
const result = processFontFamily(fontfamily);
if (result.fontName == "Theme-Default-Font") {
chatbot.style.fontFamily = result.fontName;
return;
}
// Check whether the font exists
checkFontAvailability(result.fontName).then((isAvailable) => {
if (isAvailable) {
// 如果字体存在,直接应用
chatbot.style.fontFamily = result.fontName;
} else {
if (result.fontUrl == null) {
// toast_push('无法加载字体本地字体不存在且URL未提供', 3000);
// Apply the unavailable font anyway and let the system fall back automatically
chatbot.style.fontFamily = result.fontName;
return;
} else {
toast_push('正在下载字体', 3000);
// If the font doesn't exist, try to load it
loadFont(result.fontName, result.fontUrl).then(() => {
chatbot.style.fontFamily = result.fontName;
}).catch((error) => {
console.error(`无法加载字体 "${result.fontName}":`, error);
});
}
}
});
}
if (fontsize != null) {
// Change the font size
document.documentElement.style.setProperty(
'--gpt-academic-message-font-size',
`${fontsize}px`
);
}
if (fontcolor != null) {
// Change the font color
chatbot.style.color = fontcolor;
}
} else {
console.error('#gpt-chatbot is missing');
}
}
async function GptAcademicJavaScriptInit(dark, prompt, live2d, layout, tts) {
// Part 1: layout initialization
remove_legacy_cookie();
audio_fn_init();
minor_ui_adjustment();
ButtonWithDropdown_init();
update_conversation_metadata();
window.addEventListener("gptac_restore_chat_from_local_storage", restore_chat_from_local_storage);
// Load the welcome page
const welcomeMessage = new WelcomeMessage();
welcomeMessage.begin_render();
@@ -146,7 +13,7 @@ async function GptAcademicJavaScriptInit(dark, prompt, live2d, layout, tts) {
welcomeMessage.update();
});
chatbotObserver.observe(chatbotIndicator, { attributes: true, childList: true, subtree: true });
if (layout === "LEFT-RIGHT") { chatbotAutoHeight(); }
if (layout === "LEFT-RIGHT") { limit_scroll_position(); }
@@ -169,7 +36,7 @@ async function GptAcademicJavaScriptInit(dark, prompt, live2d, layout, tts) {
}
// Auto read-aloud (TTS)
if (tts != "DISABLE") {
if (tts != "DISABLE"){
enable_tts = true;
if (getCookie("js_auto_read_cookie")) {
auto_read_tts = getCookie("js_auto_read_cookie")
@@ -179,11 +46,7 @@ async function GptAcademicJavaScriptInit(dark, prompt, live2d, layout, tts) {
}
}
}
// Fonts
gpt_academic_gradio_saveload("load", "elem_fontfamily", "js_fontfamily", null, "str");
gpt_academic_change_chatbot_font(getCookie("js_fontfamily"), null, null);
gpt_academic_gradio_saveload("load", "elem_fontsize", "js_fontsize", null, "str");
gpt_academic_change_chatbot_font(null, getCookie("js_fontsize"), null);
// SysPrompt: the silent system prompt
gpt_academic_gradio_saveload("load", "elem_prompt", "js_system_prompt_cookie", null, "str");
// Temperature: the LLM temperature parameter
@@ -193,7 +56,7 @@ async function GptAcademicJavaScriptInit(dark, prompt, live2d, layout, tts) {
const cached_model = getCookie("js_md_dropdown_cookie");
var model_sel = await get_gradio_component("elem_model_sel");
// determine whether the cached model is in the choices
if (model_sel.props.choices.includes(cached_model)) {
if (model_sel.props.choices.includes(cached_model)){
// change dropdown
gpt_academic_gradio_saveload("load", "elem_model_sel", "js_md_dropdown_cookie", null, "str");
// Chain-update the chatbot's label as well

View File

@@ -87,6 +87,21 @@ js_code_for_toggle_darkmode = """() => {
}"""
js_code_for_persistent_cookie_init = """(web_cookie_cache, cookie) => {
return [getCookie("web_cookie_cache"), cookie];
}
"""
# See themes/common.js for details
js_code_reset = """
(a,b,c)=>{
let stopButton = document.getElementById("elem_stop");
stopButton.click();
return reset_conversation(a,b);
}
"""
js_code_clear = """
(a,b)=>{
return ["", ""];

View File

@@ -84,9 +84,8 @@ class WelcomeMessage {
this.max_welcome_card_num = 6;
this.card_array = [];
this.static_welcome_message_previous = [];
this.reflesh_time_interval = 15 * 1000;
this.update_time_interval = 2 * 1000;
this.major_title = "欢迎使用GPT-Academic";
this.reflesh_time_interval = 15*1000;
const reflesh_render_status = () => {
for (let index = 0; index < this.card_array.length; index++) {
@@ -97,31 +96,16 @@ class WelcomeMessage {
};
const pageFocusHandler = new PageFocusHandler();
pageFocusHandler.addFocusCallback(reflesh_render_status);
// Call this.update whenever the page is resized
window.addEventListener('resize', this.update.bind(this));
// Start a loop that refreshes the cards
this.startRefleshCards();
this.startAutoUpdate();
}
begin_render() {
this.update();
}
async startAutoUpdate() {
// Sleep for the update interval
await new Promise(r => setTimeout(r, this.update_time_interval));
this.update();
}
async startRefleshCards() {
// Sleep for the refresh interval
await new Promise(r => setTimeout(r, this.reflesh_time_interval));
// Check visibility status
if (this.visible) {
// If visible, refresh the cards
await this.reflesh_cards();
await this.reflesh_cards();
if (this.visible){
setTimeout(() => {
this.startRefleshCards.call(this);
}, 1);
@@ -129,7 +113,7 @@ class WelcomeMessage {
}
async reflesh_cards() {
if (!this.visible) {
if (!this.visible){
return;
}
@@ -142,7 +126,6 @@ class WelcomeMessage {
// combine two lists
this.static_welcome_message_previous = not_shown_previously.concat(already_shown_previously);
this.static_welcome_message_previous = this.static_welcome_message_previous.slice(0, this.max_welcome_card_num);
(async () => {
// Use a for...of loop to handle the async operations
@@ -159,10 +142,8 @@ class WelcomeMessage {
continue;
}
card.classList.add('hide');
const timeout = 100; // keep consistent with the CSS transition duration (0.1s)
setTimeout(() => {
// Wait for the animation to finish
card.addEventListener('transitionend', () => {
// Update the card's content
const message = this.static_welcome_message_previous[index];
const title = card.getElementsByClassName('welcome-card-title')[0];
@@ -174,14 +155,16 @@ class WelcomeMessage {
text.href = message.url;
content.textContent = message.content;
card.classList.remove('hide');
// Wait for the animation to finish
card.classList.add('show');
const timeout = 100; // keep consistent with the CSS transition duration (0.1s)
setTimeout(() => {
card.classList.remove('show');
}, timeout);
}, timeout);
// Wait for the animation to finish
card.addEventListener('transitionend', () => {
card.classList.remove('show');
}, { once: true });
card.classList.add('show');
}, { once: true });
card.classList.add('hide');
// Wait 200 ms between cards
await new Promise(r => setTimeout(r, 200));
@@ -190,55 +173,43 @@ class WelcomeMessage {
}
shuffle(array) {
var currentIndex = array.length, randomIndex;
var currentIndex = array.length, randomIndex;
// While there remain elements to shuffle...
while (currentIndex != 0) {
// Pick a remaining element...
randomIndex = Math.floor(Math.random() * currentIndex);
currentIndex--;
// Pick a remaining element...
randomIndex = Math.floor(Math.random() * currentIndex);
currentIndex--;
// And swap it with the current element.
[array[currentIndex], array[randomIndex]] = [
array[randomIndex], array[currentIndex]];
// And swap it with the current element.
[array[currentIndex], array[randomIndex]] = [
array[randomIndex], array[currentIndex]];
}
return array;
}
async can_display() {
// update the card visibility
const elem_chatbot = document.getElementById('gpt-chatbot');
const chatbot_top = elem_chatbot.getBoundingClientRect().top;
const welcome_card_container = document.getElementsByClassName('welcome-card-container')[0];
// Detect whether the welcome cards overflow the chatbot area
let welcome_card_overflow = false;
if (welcome_card_container) {
const welcome_card_top = welcome_card_container.getBoundingClientRect().top;
if (welcome_card_top < chatbot_top) {
welcome_card_overflow = true;
}
}
async update() {
// console.log('update')
var page_width = document.documentElement.clientWidth;
const width_to_hide_welcome = 1200;
if (!await this.isChatbotEmpty() || page_width < width_to_hide_welcome || welcome_card_overflow) {
// cannot display
return false;
}
return true;
}
async update() {
const can_display = await this.can_display();
if (can_display && !this.visible) {
this.showWelcome();
if (!await this.isChatbotEmpty() || page_width < width_to_hide_welcome) {
if (this.visible) {
this.removeWelcome();
this.visible = false;
this.card_array = [];
this.static_welcome_message_previous = [];
}
return;
}
if (!can_display && this.visible) {
this.removeWelcome();
if (this.visible){
return;
}
// console.log("welcome");
this.showWelcome();
this.visible = true;
this.startRefleshCards();
}
showCard(message) {
@@ -249,28 +220,28 @@ class WelcomeMessage {
const title = document.createElement('div');
title.classList.add('welcome-card-title');
// Create the icon
const svg = document.createElement('img');
svg.classList.add('welcome-svg');
svg.src = message.svg;
svg.style.height = '30px';
title.appendChild(svg);
// Create the icon
const svg = document.createElement('img');
svg.classList.add('welcome-svg');
svg.src = message.svg;
svg.style.height = '30px';
title.appendChild(svg);
// Create the title
const text = document.createElement('a');
text.textContent = message.title;
text.classList.add('welcome-title-text');
text.href = message.url;
text.target = "_blank";
title.appendChild(text)
// Create the title
const text = document.createElement('a');
text.textContent = message.title;
text.classList.add('welcome-title-text');
text.href = message.url;
text.target = "_blank";
title.appendChild(text)
// Create the content
const content = document.createElement('div');
content.classList.add('welcome-content');
const content_c = document.createElement('div');
content_c.classList.add('welcome-content-c');
content_c.textContent = message.content;
content.appendChild(content_c);
const content_c = document.createElement('div');
content_c.classList.add('welcome-content-c');
content_c.textContent = message.content;
content.appendChild(content_c);
// Append the title and content to the card div
card.appendChild(title);
@@ -279,7 +250,7 @@ class WelcomeMessage {
}
async showWelcome() {
this.visible = true;
// First, find the parent element to attach the children to
const elem_chatbot = document.getElementById('gpt-chatbot');
@@ -290,7 +261,7 @@ class WelcomeMessage {
// Create the main title
const major_title = document.createElement('div');
major_title.classList.add('welcome-title');
major_title.textContent = this.major_title;
major_title.textContent = "欢迎使用GPT-Academic";
welcome_card_container.appendChild(major_title)
// Create the cards
@@ -305,16 +276,6 @@ class WelcomeMessage {
});
elem_chatbot.appendChild(welcome_card_container);
const can_display = await this.can_display();
if (!can_display) {
// undo
this.visible = false;
this.card_array = [];
this.static_welcome_message_previous = [];
elem_chatbot.removeChild(welcome_card_container);
await new Promise(r => setTimeout(r, this.update_time_interval / 2));
return;
}
// Add the show animation
requestAnimationFrame(() => {
@@ -323,24 +284,15 @@ class WelcomeMessage {
}
async removeWelcome() {
this.visible = false;
// remove welcome-card-container
const elem_chatbot = document.getElementById('gpt-chatbot');
const welcome_card_container = document.getElementsByClassName('welcome-card-container')[0];
// begin hide animation
// Add the hide animation
welcome_card_container.classList.add('hide');
// Remove the element only after the animation ends
welcome_card_container.addEventListener('transitionend', () => {
elem_chatbot.removeChild(welcome_card_container);
this.card_array = [];
this.static_welcome_message_previous = [];
}, { once: true });
// Add a fail-safe timeout
const timeout = 600; // keep consistent with the CSS transition duration (0.6s)
setTimeout(() => {
if (welcome_card_container.parentNode) {
elem_chatbot.removeChild(welcome_card_container);
}
}, timeout);
}
async isChatbotEmpty() {
@@ -355,28 +307,28 @@ class WelcomeMessage {
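// Invokes registered callbacks whenever the user returns focus to the page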
class PageFocusHandler {
constructor() {
this.hasReturned = false;
this.focusCallbacks = [];
this.hasReturned = false;
this.focusCallbacks = [];
// Bind the focus and blur event handlers
window.addEventListener('visibilitychange', this.handleFocus.bind(this));
// Bind the focus and blur event handlers
window.addEventListener('visibilitychange', this.handleFocus.bind(this));
}
// Method to handle the focus event
handleFocus() {
if (this.hasReturned) {
this.focusCallbacks.forEach(callback => callback());
}
this.hasReturned = true;
if (this.hasReturned) {
this.focusCallbacks.forEach(callback => callback());
}
this.hasReturned = true;
}
// Method to add a custom callback function
addFocusCallback(callback) {
if (typeof callback === 'function') {
this.focusCallbacks.push(callback);
} else {
throw new Error('Callback must be a function');
}
if (typeof callback === 'function') {
this.focusCallbacks.push(callback);
} else {
throw new Error('Callback must be a function');
}
}
}

View File

@@ -8,7 +8,6 @@ import base64
import gradio
import shutil
import glob
import json
import uuid
from loguru import logger
from functools import wraps
@@ -93,9 +92,8 @@ def ArgsGeneralWrapper(f):
"""
def decorated(request: gradio.Request, cookies:dict, max_length:int, llm_model:str,
txt:str, txt2:str, top_p:float, temperature:float, chatbot:list,
json_history:str, system_prompt:str, plugin_advanced_arg:dict, *args):
history:list, system_prompt:str, plugin_advanced_arg:dict, *args):
txt_passon = txt
history = json.loads(json_history) if json_history else []
if txt == "" and txt2 != "": txt_passon = txt2
# Bring in a chatbot that carries cookies
if request.username is not None:
@@ -150,11 +148,10 @@ def ArgsGeneralWrapper(f):
return decorated
def update_ui(chatbot:ChatBotWithCookies, history:list, msg:str="正常", **kwargs): # 刷新界面
def update_ui(chatbot:ChatBotWithCookies, history, msg="正常", **kwargs): # 刷新界面
"""
Refresh the user interface
"""
assert isinstance(history, list), "history必须是一个list"
assert isinstance(
chatbot, ChatBotWithCookies
), "在传递chatbot的过程中不要将其丢弃。必要时, 可用clear将其清空, 然后用for+append循环重新赋值。"
@@ -178,12 +175,10 @@ def update_ui(chatbot:ChatBotWithCookies, history:list, msg:str="正常", **kwar
else:
chatbot_gr = chatbot
history = [str(history_item) for history_item in history] # ensure all items are string
json_history = json.dumps(history, ensure_ascii=False)
yield cookies, chatbot_gr, json_history, msg
yield cookies, chatbot_gr, history, msg
def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay:float=1, msg:str="正常"): # 刷新界面
def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay=1, msg="正常"): # 刷新界面
"""
Refresh the user interface
"""
@@ -499,22 +494,6 @@ def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False, om
return tabs_list
def validate_file_size(files, max_size_mb=500):
"""
Validate that the total file size stays within the allowed limit.
:param files: list of full file paths
:param max_size_mb: maximum total size in MB, 500 MB by default
:return: True if the size is valid, otherwise an exception is raised
"""
# Accumulate the file sizes (in bytes)
total_size = 0
max_size_bytes = max_size_mb * 1024 * 1024
for file in files:
total_size += os.path.getsize(file.name)
if total_size > max_size_bytes:
raise ValueError(f"File size exceeds the allowed limit of {max_size_mb} MB. "
f"Current size: {total_size / (1024 * 1024):.2f} MB")
return True
def on_file_uploaded(
request: gradio.Request, files:List[str], chatbot:ChatBotWithCookies,
@@ -526,7 +505,6 @@ def on_file_uploaded(
if len(files) == 0:
return chatbot, txt
validate_file_size(files, max_size_mb=500)
# Create the working directory
user_name = default_user_name if not request.username else request.username
time_tag = gen_time_str()

View File

@@ -1,5 +1,5 @@
{
"version": 3.93,
"version": 3.90,
"show_feature": true,
"new_feature": "支持deepseek-reason(r1) <-> 字体和字体大小自定义 <-> 优化前端并修复TTS的BUG <-> 添加时间线回溯功能 <-> 支持chatgpt-4o-latest <-> 增加RAG组件 <-> 升级多合一主提交键"
"new_feature": "增加RAG组件 <-> 升级多合一主提交键"
}