Compare commits
102 Commits
bold_front
...
boyin_essa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c83bf214d0 | ||
|
|
e34c49dce5 | ||
|
|
3890467c84 | ||
|
|
074b3c9828 | ||
|
|
b8e8457a01 | ||
|
|
2c93a24d7e | ||
|
|
e9af6ef3a0 | ||
|
|
5ae8981dbb | ||
|
|
adbed044e4 | ||
|
|
2fe5febaf0 | ||
|
|
f54d8e559a | ||
|
|
e68fc2bc69 | ||
|
|
f695d7f1da | ||
|
|
679352d896 | ||
|
|
12c9ab1e33 | ||
|
|
da4a5efc49 | ||
|
|
9ac450cfb6 | ||
|
|
172f9e220b | ||
|
|
a28b7d8475 | ||
|
|
7d3ed36899 | ||
|
|
a7bc5fa357 | ||
|
|
4f5dd9ebcf | ||
|
|
427feb99d8 | ||
|
|
a01ca93362 | ||
|
|
597c320808 | ||
|
|
18290fd138 | ||
|
|
0d0575a639 | ||
|
|
4e041e1d4e | ||
|
|
7ef39770c7 | ||
|
|
8222f638cf | ||
|
|
ab32c314ab | ||
|
|
dcfed97054 | ||
|
|
dd66ca26f7 | ||
|
|
8b91d2ac0a | ||
|
|
e4e00b713f | ||
|
|
710a65522c | ||
|
|
34784c1d40 | ||
|
|
80b1a6f99b | ||
|
|
08c3c56f53 | ||
|
|
294716c832 | ||
|
|
16f4fd636e | ||
|
|
e07caf7a69 | ||
|
|
a95b3daab9 | ||
|
|
4873e9dfdc | ||
|
|
a119ab36fe | ||
|
|
f9384e4e5f | ||
|
|
6fe5f6ee6e | ||
|
|
068d753426 | ||
|
|
5010537f3c | ||
|
|
f35f6633e0 | ||
|
|
573dc4d184 | ||
|
|
da8b2d69ce | ||
|
|
58e732c26f | ||
|
|
ca238daa8c | ||
|
|
60b3491513 | ||
|
|
c1175bfb7d | ||
|
|
b705afd5ff | ||
|
|
dfcd28abce | ||
|
|
1edaa9e234 | ||
|
|
f0cd617ec2 | ||
|
|
0b08bb2cea | ||
|
|
d1f8607ac8 | ||
|
|
7eb68a2086 | ||
|
|
ee9e99036a | ||
|
|
55e255220b | ||
|
|
019cd26ae8 | ||
|
|
a5b21d5cc0 | ||
|
|
ce940ff70f | ||
|
|
fc6a83c29f | ||
|
|
1d3212e367 | ||
|
|
8a835352a3 | ||
|
|
5456c9fa43 | ||
|
|
ea67054c30 | ||
|
|
1084108df6 | ||
|
|
40c9700a8d | ||
|
|
6da5623813 | ||
|
|
778c9cd9ec | ||
|
|
e290317146 | ||
|
|
85b92b7f07 | ||
|
|
ff899777ce | ||
|
|
c1b8c773c3 | ||
|
|
8747c48175 | ||
|
|
c0010c88bc | ||
|
|
68838da8ad | ||
|
|
ca7de8fcdd | ||
|
|
7ebc2d00e7 | ||
|
|
47fb81cfde | ||
|
|
83961c1002 | ||
|
|
a8621333af | ||
|
|
f402ef8134 | ||
|
|
65d0f486f1 | ||
|
|
41f25a6a9b | ||
|
|
4a6a032334 | ||
|
|
114192e025 | ||
|
|
9d11b17f25 | ||
|
|
1d9e9fa6a1 | ||
|
|
6babcb4a9c | ||
|
|
b7b4e201cb | ||
|
|
26e7677dc3 | ||
|
|
5e64a50898 | ||
|
|
60a42fb070 | ||
|
|
c94d5054a2 |
44
.github/workflows/build-with-jittorllms.yml
vendored
44
.github/workflows/build-with-jittorllms.yml
vendored
@@ -1,44 +0,0 @@
|
|||||||
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
|
|
||||||
name: build-with-jittorllms
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- 'master'
|
|
||||||
|
|
||||||
env:
|
|
||||||
REGISTRY: ghcr.io
|
|
||||||
IMAGE_NAME: ${{ github.repository }}_jittorllms
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build-and-push-image:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
packages: write
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: Log in to the Container registry
|
|
||||||
uses: docker/login-action@v2
|
|
||||||
with:
|
|
||||||
registry: ${{ env.REGISTRY }}
|
|
||||||
username: ${{ github.actor }}
|
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Extract metadata (tags, labels) for Docker
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@v4
|
|
||||||
with:
|
|
||||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
|
||||||
|
|
||||||
- name: Build and push Docker image
|
|
||||||
uses: docker/build-push-action@v4
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
push: true
|
|
||||||
file: docs/GithubAction+JittorLLMs
|
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
|
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
|
||||||
name: build-with-all-capacity-beta
|
name: build-with-latex-arm
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- 'master'
|
- "master"
|
||||||
|
|
||||||
env:
|
env:
|
||||||
REGISTRY: ghcr.io
|
REGISTRY: ghcr.io
|
||||||
IMAGE_NAME: ${{ github.repository }}_with_all_capacity_beta
|
IMAGE_NAME: ${{ github.repository }}_with_latex_arm
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-and-push-image:
|
build-and-push-image:
|
||||||
@@ -18,11 +18,17 @@ jobs:
|
|||||||
packages: write
|
packages: write
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: docker/setup-qemu-action@v3
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Log in to the Container registry
|
- name: Log in to the Container registry
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
registry: ${{ env.REGISTRY }}
|
registry: ${{ env.REGISTRY }}
|
||||||
username: ${{ github.actor }}
|
username: ${{ github.actor }}
|
||||||
@@ -35,10 +41,11 @@ jobs:
|
|||||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||||
|
|
||||||
- name: Build and push Docker image
|
- name: Build and push Docker image
|
||||||
uses: docker/build-push-action@v4
|
uses: docker/build-push-action@v6
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
push: true
|
push: true
|
||||||
file: docs/GithubAction+AllCapacityBeta
|
platforms: linux/arm64
|
||||||
|
file: docs/GithubAction+NoLocal+Latex+Arm
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -131,6 +131,9 @@ dmypy.json
|
|||||||
# Pyre type checker
|
# Pyre type checker
|
||||||
.pyre/
|
.pyre/
|
||||||
|
|
||||||
|
# macOS files
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
.vscode
|
.vscode
|
||||||
.idea
|
.idea
|
||||||
|
|
||||||
@@ -153,6 +156,8 @@ media
|
|||||||
flagged
|
flagged
|
||||||
request_llms/ChatGLM-6b-onnx-u8s8
|
request_llms/ChatGLM-6b-onnx-u8s8
|
||||||
.pre-commit-config.yaml
|
.pre-commit-config.yaml
|
||||||
test.html
|
test.*
|
||||||
|
temp.*
|
||||||
objdump*
|
objdump*
|
||||||
*.min.*.js
|
*.min.*.js
|
||||||
|
TODO
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> 2024.6.1: 版本3.80加入插件二级菜单功能(详见wiki)
|
> 2024.10.10: 突发停电,紧急恢复了提供[whl包](https://drive.google.com/file/d/19U_hsLoMrjOlQSzYS3pzWX9fTzyusArP/view?usp=sharing)的文件服务器
|
||||||
|
> 2024.10.8: 版本3.90加入对llama-index的初步支持,版本3.80加入插件二级菜单功能(详见wiki)
|
||||||
> 2024.5.1: 加入Doc2x翻译PDF论文的功能,[查看详情](https://github.com/binary-husky/gpt_academic/wiki/Doc2x)
|
> 2024.5.1: 加入Doc2x翻译PDF论文的功能,[查看详情](https://github.com/binary-husky/gpt_academic/wiki/Doc2x)
|
||||||
> 2024.3.11: 全力支持Qwen、GLM、DeepseekCoder等中文大语言模型! SoVits语音克隆模块,[查看详情](https://www.bilibili.com/video/BV1Rp421S7tF/)
|
> 2024.3.11: 全力支持Qwen、GLM、DeepseekCoder等中文大语言模型! SoVits语音克隆模块,[查看详情](https://www.bilibili.com/video/BV1Rp421S7tF/)
|
||||||
> 2024.1.17: 安装依赖时,请选择`requirements.txt`中**指定的版本**。 安装命令:`pip install -r requirements.txt`。本项目完全开源免费,您可通过订阅[在线服务](https://github.com/binary-husky/gpt_academic/wiki/online)的方式鼓励本项目的发展。
|
> 2024.1.17: 安装依赖时,请选择`requirements.txt`中**指定的版本**。 安装命令:`pip install -r requirements.txt`。本项目完全开源免费,您可通过订阅[在线服务](https://github.com/binary-husky/gpt_academic/wiki/online)的方式鼓励本项目的发展。
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from loguru import logger
|
||||||
|
|
||||||
def check_proxy(proxies, return_ip=False):
|
def check_proxy(proxies, return_ip=False):
|
||||||
import requests
|
import requests
|
||||||
@@ -19,14 +20,14 @@ def check_proxy(proxies, return_ip=False):
|
|||||||
else:
|
else:
|
||||||
result = f"代理配置 {proxies_https}, 代理数据解析失败:{data}"
|
result = f"代理配置 {proxies_https}, 代理数据解析失败:{data}"
|
||||||
if not return_ip:
|
if not return_ip:
|
||||||
print(result)
|
logger.warning(result)
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
return ip
|
return ip
|
||||||
except:
|
except:
|
||||||
result = f"代理配置 {proxies_https}, 代理所在地查询超时,代理可能无效"
|
result = f"代理配置 {proxies_https}, 代理所在地查询超时,代理可能无效"
|
||||||
if not return_ip:
|
if not return_ip:
|
||||||
print(result)
|
logger.warning(result)
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
return ip
|
return ip
|
||||||
@@ -82,25 +83,25 @@ def patch_and_restart(path):
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import glob
|
import glob
|
||||||
from shared_utils.colorful import print亮黄, print亮绿, print亮红
|
from shared_utils.colorful import log亮黄, log亮绿, log亮红
|
||||||
# if not using config_private, move origin config.py as config_private.py
|
# if not using config_private, move origin config.py as config_private.py
|
||||||
if not os.path.exists('config_private.py'):
|
if not os.path.exists('config_private.py'):
|
||||||
print亮黄('由于您没有设置config_private.py私密配置,现将您的现有配置移动至config_private.py以防止配置丢失,',
|
log亮黄('由于您没有设置config_private.py私密配置,现将您的现有配置移动至config_private.py以防止配置丢失,',
|
||||||
'另外您可以随时在history子文件夹下找回旧版的程序。')
|
'另外您可以随时在history子文件夹下找回旧版的程序。')
|
||||||
shutil.copyfile('config.py', 'config_private.py')
|
shutil.copyfile('config.py', 'config_private.py')
|
||||||
path_new_version = glob.glob(path + '/*-master')[0]
|
path_new_version = glob.glob(path + '/*-master')[0]
|
||||||
dir_util.copy_tree(path_new_version, './')
|
dir_util.copy_tree(path_new_version, './')
|
||||||
print亮绿('代码已经更新,即将更新pip包依赖……')
|
log亮绿('代码已经更新,即将更新pip包依赖……')
|
||||||
for i in reversed(range(5)): time.sleep(1); print(i)
|
for i in reversed(range(5)): time.sleep(1); log亮绿(i)
|
||||||
try:
|
try:
|
||||||
import subprocess
|
import subprocess
|
||||||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])
|
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])
|
||||||
except:
|
except:
|
||||||
print亮红('pip包依赖安装出现问题,需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
log亮红('pip包依赖安装出现问题,需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
||||||
print亮绿('更新完成,您可以随时在history子文件夹下找回旧版的程序,5s之后重启')
|
log亮绿('更新完成,您可以随时在history子文件夹下找回旧版的程序,5s之后重启')
|
||||||
print亮红('假如重启失败,您可能需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
log亮红('假如重启失败,您可能需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
|
||||||
print(' ------------------------------ -----------------------------------')
|
log亮绿(' ------------------------------ -----------------------------------')
|
||||||
for i in reversed(range(8)): time.sleep(1); print(i)
|
for i in reversed(range(8)): time.sleep(1); log亮绿(i)
|
||||||
os.execl(sys.executable, sys.executable, *sys.argv)
|
os.execl(sys.executable, sys.executable, *sys.argv)
|
||||||
|
|
||||||
|
|
||||||
@@ -135,9 +136,9 @@ def auto_update(raise_error=False):
|
|||||||
current_version = f.read()
|
current_version = f.read()
|
||||||
current_version = json.loads(current_version)['version']
|
current_version = json.loads(current_version)['version']
|
||||||
if (remote_version - current_version) >= 0.01-1e-5:
|
if (remote_version - current_version) >= 0.01-1e-5:
|
||||||
from shared_utils.colorful import print亮黄
|
from shared_utils.colorful import log亮黄
|
||||||
print亮黄(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
|
log亮黄(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
|
||||||
print('(1)Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
|
logger.info('(1)Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
|
||||||
user_instruction = input('(2)是否一键更新代码(Y+回车=确认,输入其他/无输入+回车=不更新)?')
|
user_instruction = input('(2)是否一键更新代码(Y+回车=确认,输入其他/无输入+回车=不更新)?')
|
||||||
if user_instruction in ['Y', 'y']:
|
if user_instruction in ['Y', 'y']:
|
||||||
path = backup_and_download(current_version, remote_version)
|
path = backup_and_download(current_version, remote_version)
|
||||||
@@ -148,9 +149,9 @@ def auto_update(raise_error=False):
|
|||||||
if raise_error:
|
if raise_error:
|
||||||
from toolbox import trimmed_format_exc
|
from toolbox import trimmed_format_exc
|
||||||
msg += trimmed_format_exc()
|
msg += trimmed_format_exc()
|
||||||
print(msg)
|
logger.warning(msg)
|
||||||
else:
|
else:
|
||||||
print('自动更新程序:已禁用')
|
logger.info('自动更新程序:已禁用')
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
@@ -159,10 +160,10 @@ def auto_update(raise_error=False):
|
|||||||
if raise_error:
|
if raise_error:
|
||||||
from toolbox import trimmed_format_exc
|
from toolbox import trimmed_format_exc
|
||||||
msg += trimmed_format_exc()
|
msg += trimmed_format_exc()
|
||||||
print(msg)
|
logger.info(msg)
|
||||||
|
|
||||||
def warm_up_modules():
|
def warm_up_modules():
|
||||||
print('正在执行一些模块的预热 ...')
|
logger.info('正在执行一些模块的预热 ...')
|
||||||
from toolbox import ProxyNetworkActivate
|
from toolbox import ProxyNetworkActivate
|
||||||
from request_llms.bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
with ProxyNetworkActivate("Warmup_Modules"):
|
with ProxyNetworkActivate("Warmup_Modules"):
|
||||||
@@ -172,7 +173,7 @@ def warm_up_modules():
|
|||||||
enc.encode("模块预热", disallowed_special=())
|
enc.encode("模块预热", disallowed_special=())
|
||||||
|
|
||||||
def warm_up_vectordb():
|
def warm_up_vectordb():
|
||||||
print('正在执行一些模块的预热 ...')
|
logger.info('正在执行一些模块的预热 ...')
|
||||||
from toolbox import ProxyNetworkActivate
|
from toolbox import ProxyNetworkActivate
|
||||||
with ProxyNetworkActivate("Warmup_Modules"):
|
with ProxyNetworkActivate("Warmup_Modules"):
|
||||||
import nltk
|
import nltk
|
||||||
|
|||||||
16
config.py
16
config.py
@@ -33,11 +33,14 @@ else:
|
|||||||
# [step 3]>> 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
|
# [step 3]>> 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
|
||||||
LLM_MODEL = "gpt-3.5-turbo-16k" # 可选 ↓↓↓
|
LLM_MODEL = "gpt-3.5-turbo-16k" # 可选 ↓↓↓
|
||||||
AVAIL_LLM_MODELS = ["gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-preview",
|
AVAIL_LLM_MODELS = ["gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-preview",
|
||||||
"gpt-4o", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
"gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
||||||
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
|
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
|
||||||
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
|
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
|
||||||
"gemini-pro", "chatglm3"
|
"gemini-1.5-pro", "chatglm3"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||||
|
|
||||||
# --- --- --- ---
|
# --- --- --- ---
|
||||||
# P.S. 其他可用的模型还包括
|
# P.S. 其他可用的模型还包括
|
||||||
# AVAIL_LLM_MODELS = [
|
# AVAIL_LLM_MODELS = [
|
||||||
@@ -50,12 +53,13 @@ AVAIL_LLM_MODELS = ["gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-p
|
|||||||
# "claude-3-haiku-20240307","claude-3-sonnet-20240229","claude-3-opus-20240229", "claude-2.1", "claude-instant-1.2",
|
# "claude-3-haiku-20240307","claude-3-sonnet-20240229","claude-3-opus-20240229", "claude-2.1", "claude-instant-1.2",
|
||||||
# "moss", "llama2", "chatglm_onnx", "internlm", "jittorllms_pangualpha", "jittorllms_llama",
|
# "moss", "llama2", "chatglm_onnx", "internlm", "jittorllms_pangualpha", "jittorllms_llama",
|
||||||
# "deepseek-chat" ,"deepseek-coder",
|
# "deepseek-chat" ,"deepseek-coder",
|
||||||
|
# "gemini-1.5-flash",
|
||||||
# "yi-34b-chat-0205","yi-34b-chat-200k","yi-large","yi-medium","yi-spark","yi-large-turbo","yi-large-preview",
|
# "yi-34b-chat-0205","yi-34b-chat-200k","yi-large","yi-medium","yi-spark","yi-large-turbo","yi-large-preview",
|
||||||
# ]
|
# ]
|
||||||
# --- --- --- ---
|
# --- --- --- ---
|
||||||
# 此外,您还可以在接入one-api/vllm/ollama时,
|
# 此外,您还可以在接入one-api/vllm/ollama/Openroute时,
|
||||||
# 使用"one-api-*","vllm-*","ollama-*"前缀直接使用非标准方式接入的模型,例如
|
# 使用"one-api-*","vllm-*","ollama-*","openrouter-*"前缀直接使用非标准方式接入的模型,例如
|
||||||
# AVAIL_LLM_MODELS = ["one-api-claude-3-sonnet-20240229(max_token=100000)", "ollama-phi3(max_token=4096)"]
|
# AVAIL_LLM_MODELS = ["one-api-claude-3-sonnet-20240229(max_token=100000)", "ollama-phi3(max_token=4096)","openrouter-openai/gpt-4o-mini","openrouter-openai/chatgpt-4o-latest"]
|
||||||
# --- --- --- ---
|
# --- --- --- ---
|
||||||
|
|
||||||
|
|
||||||
@@ -295,7 +299,7 @@ ARXIV_CACHE_DIR = "gpt_log/arxiv_cache"
|
|||||||
|
|
||||||
# 除了连接OpenAI之外,还有哪些场合允许使用代理,请尽量不要修改
|
# 除了连接OpenAI之外,还有哪些场合允许使用代理,请尽量不要修改
|
||||||
WHEN_TO_USE_PROXY = ["Download_LLM", "Download_Gradio_Theme", "Connect_Grobid",
|
WHEN_TO_USE_PROXY = ["Download_LLM", "Download_Gradio_Theme", "Connect_Grobid",
|
||||||
"Warmup_Modules", "Nougat_Download", "AutoGen"]
|
"Warmup_Modules", "Nougat_Download", "AutoGen", "Connect_OpenAI_Embedding"]
|
||||||
|
|
||||||
|
|
||||||
# 启用插件热加载
|
# 启用插件热加载
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ def get_core_functions():
|
|||||||
text_show_english=
|
text_show_english=
|
||||||
r"Below is a paragraph from an academic paper. Polish the writing to meet the academic style, "
|
r"Below is a paragraph from an academic paper. Polish the writing to meet the academic style, "
|
||||||
r"improve the spelling, grammar, clarity, concision and overall readability. When necessary, rewrite the whole sentence. "
|
r"improve the spelling, grammar, clarity, concision and overall readability. When necessary, rewrite the whole sentence. "
|
||||||
r"Firstly, you should provide the polished paragraph. "
|
r"Firstly, you should provide the polished paragraph (in English). "
|
||||||
r"Secondly, you should list all your modification and explain the reasons to do so in markdown table.",
|
r"Secondly, you should list all your modification and explain the reasons to do so in markdown table.",
|
||||||
text_show_chinese=
|
text_show_chinese=
|
||||||
r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性,"
|
r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性,"
|
||||||
|
|||||||
@@ -1,25 +1,26 @@
|
|||||||
from toolbox import HotReload # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效
|
from toolbox import HotReload # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效
|
||||||
from toolbox import trimmed_format_exc
|
from toolbox import trimmed_format_exc
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
def get_crazy_functions():
|
def get_crazy_functions():
|
||||||
from crazy_functions.读文章写摘要 import 读文章写摘要
|
from crazy_functions.读文章写摘要 import 读文章写摘要
|
||||||
from crazy_functions.生成函数注释 import 批量生成函数注释
|
from crazy_functions.生成函数注释 import 批量生成函数注释
|
||||||
from crazy_functions.解析项目源代码 import 解析项目本身
|
from crazy_functions.SourceCode_Analyse import 解析项目本身
|
||||||
from crazy_functions.解析项目源代码 import 解析一个Python项目
|
from crazy_functions.SourceCode_Analyse import 解析一个Python项目
|
||||||
from crazy_functions.解析项目源代码 import 解析一个Matlab项目
|
from crazy_functions.SourceCode_Analyse import 解析一个Matlab项目
|
||||||
from crazy_functions.解析项目源代码 import 解析一个C项目的头文件
|
from crazy_functions.SourceCode_Analyse import 解析一个C项目的头文件
|
||||||
from crazy_functions.解析项目源代码 import 解析一个C项目
|
from crazy_functions.SourceCode_Analyse import 解析一个C项目
|
||||||
from crazy_functions.解析项目源代码 import 解析一个Golang项目
|
from crazy_functions.SourceCode_Analyse import 解析一个Golang项目
|
||||||
from crazy_functions.解析项目源代码 import 解析一个Rust项目
|
from crazy_functions.SourceCode_Analyse import 解析一个Rust项目
|
||||||
from crazy_functions.解析项目源代码 import 解析一个Java项目
|
from crazy_functions.SourceCode_Analyse import 解析一个Java项目
|
||||||
from crazy_functions.解析项目源代码 import 解析一个前端项目
|
from crazy_functions.SourceCode_Analyse import 解析一个前端项目
|
||||||
from crazy_functions.高级功能函数模板 import 高阶功能模板函数
|
from crazy_functions.高级功能函数模板 import 高阶功能模板函数
|
||||||
from crazy_functions.高级功能函数模板 import Demo_Wrap
|
from crazy_functions.高级功能函数模板 import Demo_Wrap
|
||||||
from crazy_functions.Latex全文润色 import Latex英文润色
|
from crazy_functions.Latex全文润色 import Latex英文润色
|
||||||
from crazy_functions.询问多个大语言模型 import 同时问询
|
from crazy_functions.询问多个大语言模型 import 同时问询
|
||||||
from crazy_functions.解析项目源代码 import 解析一个Lua项目
|
from crazy_functions.SourceCode_Analyse import 解析一个Lua项目
|
||||||
from crazy_functions.解析项目源代码 import 解析一个CSharp项目
|
from crazy_functions.SourceCode_Analyse import 解析一个CSharp项目
|
||||||
from crazy_functions.总结word文档 import 总结word文档
|
from crazy_functions.总结word文档 import 总结word文档
|
||||||
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
|
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
|
||||||
from crazy_functions.Conversation_To_File import 载入对话历史存档
|
from crazy_functions.Conversation_To_File import 载入对话历史存档
|
||||||
@@ -45,6 +46,9 @@ def get_crazy_functions():
|
|||||||
from crazy_functions.Latex_Function_Wrap import PDF_Localize
|
from crazy_functions.Latex_Function_Wrap import PDF_Localize
|
||||||
from crazy_functions.Internet_GPT import 连接网络回答问题
|
from crazy_functions.Internet_GPT import 连接网络回答问题
|
||||||
from crazy_functions.Internet_GPT_Wrap import NetworkGPT_Wrap
|
from crazy_functions.Internet_GPT_Wrap import NetworkGPT_Wrap
|
||||||
|
from crazy_functions.Image_Generate import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2
|
||||||
|
from crazy_functions.Image_Generate_Wrap import ImageGen_Wrap
|
||||||
|
from crazy_functions.SourceCode_Comment import 注释Python项目
|
||||||
|
|
||||||
function_plugins = {
|
function_plugins = {
|
||||||
"虚空终端": {
|
"虚空终端": {
|
||||||
@@ -61,6 +65,13 @@ def get_crazy_functions():
|
|||||||
"Info": "解析一个Python项目的所有源文件(.py) | 输入参数为路径",
|
"Info": "解析一个Python项目的所有源文件(.py) | 输入参数为路径",
|
||||||
"Function": HotReload(解析一个Python项目),
|
"Function": HotReload(解析一个Python项目),
|
||||||
},
|
},
|
||||||
|
"注释Python项目": {
|
||||||
|
"Group": "编程",
|
||||||
|
"Color": "stop",
|
||||||
|
"AsButton": False,
|
||||||
|
"Info": "上传一系列python源文件(或者压缩包), 为这些代码添加docstring | 输入参数为路径",
|
||||||
|
"Function": HotReload(注释Python项目),
|
||||||
|
},
|
||||||
"载入对话历史存档(先上传存档或输入路径)": {
|
"载入对话历史存档(先上传存档或输入路径)": {
|
||||||
"Group": "对话",
|
"Group": "对话",
|
||||||
"Color": "stop",
|
"Color": "stop",
|
||||||
@@ -324,7 +335,7 @@ def get_crazy_functions():
|
|||||||
"ArgsReminder": "如果有必要, 请在此处追加更细致的矫错指令(使用英文)。",
|
"ArgsReminder": "如果有必要, 请在此处追加更细致的矫错指令(使用英文)。",
|
||||||
"Function": HotReload(Latex英文纠错加PDF对比),
|
"Function": HotReload(Latex英文纠错加PDF对比),
|
||||||
},
|
},
|
||||||
"Arxiv论文精细翻译(输入arxivID)[需Latex]": {
|
"📚Arxiv论文精细翻译(输入arxivID)[需Latex]": {
|
||||||
"Group": "学术",
|
"Group": "学术",
|
||||||
"Color": "stop",
|
"Color": "stop",
|
||||||
"AsButton": False,
|
"AsButton": False,
|
||||||
@@ -336,7 +347,7 @@ def get_crazy_functions():
|
|||||||
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
||||||
"Class": Arxiv_Localize, # 新一代插件需要注册Class
|
"Class": Arxiv_Localize, # 新一代插件需要注册Class
|
||||||
},
|
},
|
||||||
"本地Latex论文精细翻译(上传Latex项目)[需Latex]": {
|
"📚本地Latex论文精细翻译(上传Latex项目)[需Latex]": {
|
||||||
"Group": "学术",
|
"Group": "学术",
|
||||||
"Color": "stop",
|
"Color": "stop",
|
||||||
"AsButton": False,
|
"AsButton": False,
|
||||||
@@ -361,6 +372,39 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function_plugins.update(
|
||||||
|
{
|
||||||
|
"🎨图片生成(DALLE2/DALLE3, 使用前切换到GPT系列模型)": {
|
||||||
|
"Group": "对话",
|
||||||
|
"Color": "stop",
|
||||||
|
"AsButton": False,
|
||||||
|
"Info": "使用 DALLE2/DALLE3 生成图片 | 输入参数字符串,提供图像的内容",
|
||||||
|
"Function": HotReload(图片生成_DALLE2), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
||||||
|
"Class": ImageGen_Wrap # 新一代插件需要注册Class
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
function_plugins.update(
|
||||||
|
{
|
||||||
|
"🎨图片修改_DALLE2 (使用前请切换模型到GPT系列)": {
|
||||||
|
"Group": "对话",
|
||||||
|
"Color": "stop",
|
||||||
|
"AsButton": False,
|
||||||
|
"AdvancedArgs": False, # 调用时,唤起高级参数输入区(默认False)
|
||||||
|
# "Info": "使用DALLE2修改图片 | 输入参数字符串,提供图像的内容",
|
||||||
|
"Function": HotReload(图片修改_DALLE2),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# -=--=- 尚未充分测试的实验性插件 & 需要额外依赖的插件 -=--=-
|
# -=--=- 尚未充分测试的实验性插件 & 需要额外依赖的插件 -=--=-
|
||||||
try:
|
try:
|
||||||
@@ -378,8 +422,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# from crazy_functions.联网的ChatGPT import 连接网络回答问题
|
# from crazy_functions.联网的ChatGPT import 连接网络回答问题
|
||||||
@@ -409,11 +453,11 @@ def get_crazy_functions():
|
|||||||
# }
|
# }
|
||||||
# )
|
# )
|
||||||
# except:
|
# except:
|
||||||
# print(trimmed_format_exc())
|
# logger.error(trimmed_format_exc())
|
||||||
# print("Load function plugin failed")
|
# logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.解析项目源代码 import 解析任意code项目
|
from crazy_functions.SourceCode_Analyse import 解析任意code项目
|
||||||
|
|
||||||
function_plugins.update(
|
function_plugins.update(
|
||||||
{
|
{
|
||||||
@@ -428,8 +472,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.询问多个大语言模型 import 同时问询_指定模型
|
from crazy_functions.询问多个大语言模型 import 同时问询_指定模型
|
||||||
@@ -439,7 +483,7 @@ def get_crazy_functions():
|
|||||||
"询问多个GPT模型(手动指定询问哪些模型)": {
|
"询问多个GPT模型(手动指定询问哪些模型)": {
|
||||||
"Group": "对话",
|
"Group": "对话",
|
||||||
"Color": "stop",
|
"Color": "stop",
|
||||||
"AsButton": True,
|
"AsButton": False,
|
||||||
"AdvancedArgs": True, # 调用时,唤起高级参数输入区(默认False)
|
"AdvancedArgs": True, # 调用时,唤起高级参数输入区(默认False)
|
||||||
"ArgsReminder": "支持任意数量的llm接口,用&符号分隔。例如chatglm&gpt-3.5-turbo&gpt-4", # 高级参数输入区的显示提示
|
"ArgsReminder": "支持任意数量的llm接口,用&符号分隔。例如chatglm&gpt-3.5-turbo&gpt-4", # 高级参数输入区的显示提示
|
||||||
"Function": HotReload(同时问询_指定模型),
|
"Function": HotReload(同时问询_指定模型),
|
||||||
@@ -447,53 +491,10 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
|
||||||
from crazy_functions.图片生成 import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2
|
|
||||||
|
|
||||||
function_plugins.update(
|
|
||||||
{
|
|
||||||
"图片生成_DALLE2 (先切换模型到gpt-*)": {
|
|
||||||
"Group": "对话",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False,
|
|
||||||
"AdvancedArgs": True, # 调用时,唤起高级参数输入区(默认False)
|
|
||||||
"ArgsReminder": "在这里输入分辨率, 如1024x1024(默认),支持 256x256, 512x512, 1024x1024", # 高级参数输入区的显示提示
|
|
||||||
"Info": "使用DALLE2生成图片 | 输入参数字符串,提供图像的内容",
|
|
||||||
"Function": HotReload(图片生成_DALLE2),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
function_plugins.update(
|
|
||||||
{
|
|
||||||
"图片生成_DALLE3 (先切换模型到gpt-*)": {
|
|
||||||
"Group": "对话",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False,
|
|
||||||
"AdvancedArgs": True, # 调用时,唤起高级参数输入区(默认False)
|
|
||||||
"ArgsReminder": "在这里输入自定义参数「分辨率-质量(可选)-风格(可选)」, 参数示例「1024x1024-hd-vivid」 || 分辨率支持 「1024x1024」(默认) /「1792x1024」/「1024x1792」 || 质量支持 「-standard」(默认) /「-hd」 || 风格支持 「-vivid」(默认) /「-natural」", # 高级参数输入区的显示提示
|
|
||||||
"Info": "使用DALLE3生成图片 | 输入参数字符串,提供图像的内容",
|
|
||||||
"Function": HotReload(图片生成_DALLE3),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
function_plugins.update(
|
|
||||||
{
|
|
||||||
"图片修改_DALLE2 (先切换模型到gpt-*)": {
|
|
||||||
"Group": "对话",
|
|
||||||
"Color": "stop",
|
|
||||||
"AsButton": False,
|
|
||||||
"AdvancedArgs": False, # 调用时,唤起高级参数输入区(默认False)
|
|
||||||
# "Info": "使用DALLE2修改图片 | 输入参数字符串,提供图像的内容",
|
|
||||||
"Function": HotReload(图片修改_DALLE2),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
print(trimmed_format_exc())
|
|
||||||
print("Load function plugin failed")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.总结音视频 import 总结音视频
|
from crazy_functions.总结音视频 import 总结音视频
|
||||||
@@ -512,8 +513,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.数学动画生成manim import 动画生成
|
from crazy_functions.数学动画生成manim import 动画生成
|
||||||
@@ -530,8 +531,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.Markdown_Translate import Markdown翻译指定语言
|
from crazy_functions.Markdown_Translate import Markdown翻译指定语言
|
||||||
@@ -549,8 +550,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.知识库问答 import 知识库文件注入
|
from crazy_functions.知识库问答 import 知识库文件注入
|
||||||
@@ -568,8 +569,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.知识库问答 import 读取知识库作答
|
from crazy_functions.知识库问答 import 读取知识库作答
|
||||||
@@ -587,8 +588,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.交互功能函数模板 import 交互功能模板函数
|
from crazy_functions.交互功能函数模板 import 交互功能模板函数
|
||||||
@@ -604,8 +605,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -627,8 +628,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.批量翻译PDF文档_NOUGAT import 批量翻译PDF文档
|
from crazy_functions.批量翻译PDF文档_NOUGAT import 批量翻译PDF文档
|
||||||
@@ -644,8 +645,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.函数动态生成 import 函数动态生成
|
from crazy_functions.函数动态生成 import 函数动态生成
|
||||||
@@ -661,8 +662,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.多智能体 import 多智能体终端
|
from crazy_functions.多智能体 import 多智能体终端
|
||||||
@@ -678,8 +679,8 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from crazy_functions.互动小游戏 import 随机小游戏
|
from crazy_functions.互动小游戏 import 随机小游戏
|
||||||
@@ -695,8 +696,33 @@ def get_crazy_functions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
print("Load function plugin failed")
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from crazy_functions.Rag_Interface import Rag问答
|
||||||
|
|
||||||
|
function_plugins.update(
|
||||||
|
{
|
||||||
|
"Rag智能召回": {
|
||||||
|
"Group": "对话",
|
||||||
|
"Color": "stop",
|
||||||
|
"AsButton": False,
|
||||||
|
"Info": "将问答数据记录到向量库中,作为长期参考。",
|
||||||
|
"Function": HotReload(Rag问答),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
logger.error(trimmed_format_exc())
|
||||||
|
logger.error("Load function plugin failed")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# from crazy_functions.高级功能函数模板 import 测试图表渲染
|
# from crazy_functions.高级功能函数模板 import 测试图表渲染
|
||||||
@@ -709,7 +735,7 @@ def get_crazy_functions():
|
|||||||
# }
|
# }
|
||||||
# })
|
# })
|
||||||
# except:
|
# except:
|
||||||
# print(trimmed_format_exc())
|
# logger.error(trimmed_format_exc())
|
||||||
# print('Load function plugin failed')
|
# print('Load function plugin failed')
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ def 载入对话历史存档(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
|
|||||||
system_prompt 给gpt的静默提醒
|
system_prompt 给gpt的静默提醒
|
||||||
user_request 当前用户的请求信息(IP地址等)
|
user_request 当前用户的请求信息(IP地址等)
|
||||||
"""
|
"""
|
||||||
from .crazy_utils import get_files_from_everything
|
from crazy_functions.crazy_utils import get_files_from_everything
|
||||||
success, file_manifest, _ = get_files_from_everything(txt, type='.html')
|
success, file_manifest, _ = get_files_from_everything(txt, type='.html')
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ def gen_image(llm_kwargs, prompt, resolution="1024x1024", model="dall-e-2", qual
|
|||||||
if style is not None:
|
if style is not None:
|
||||||
data['style'] = style
|
data['style'] = style
|
||||||
response = requests.post(url, headers=headers, json=data, proxies=proxies)
|
response = requests.post(url, headers=headers, json=data, proxies=proxies)
|
||||||
print(response.content)
|
# logger.info(response.content)
|
||||||
try:
|
try:
|
||||||
image_url = json.loads(response.content.decode('utf8'))['data'][0]['url']
|
image_url = json.loads(response.content.decode('utf8'))['data'][0]['url']
|
||||||
except:
|
except:
|
||||||
@@ -76,7 +76,7 @@ def edit_image(llm_kwargs, prompt, image_path, resolution="1024x1024", model="da
|
|||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, files=files, proxies=proxies)
|
response = requests.post(url, headers=headers, files=files, proxies=proxies)
|
||||||
print(response.content)
|
# logger.info(response.content)
|
||||||
try:
|
try:
|
||||||
image_url = json.loads(response.content.decode('utf8'))['data'][0]['url']
|
image_url = json.loads(response.content.decode('utf8'))['data'][0]['url']
|
||||||
except:
|
except:
|
||||||
@@ -108,7 +108,7 @@ def 图片生成_DALLE2(prompt, llm_kwargs, plugin_kwargs, chatbot, history, sys
|
|||||||
chatbot.append((prompt, "[Local Message] 图像生成提示为空白,请在“输入区”输入图像生成提示。"))
|
chatbot.append((prompt, "[Local Message] 图像生成提示为空白,请在“输入区”输入图像生成提示。"))
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 界面更新
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 界面更新
|
||||||
return
|
return
|
||||||
chatbot.append(("您正在调用“图像生成”插件。", "[Local Message] 生成图像, 请先把模型切换至gpt-*。如果中文Prompt效果不理想, 请尝试英文Prompt。正在处理中 ....."))
|
chatbot.append(("您正在调用“图像生成”插件。", "[Local Message] 生成图像, 使用前请切换模型到GPT系列。如果中文Prompt效果不理想, 请尝试英文Prompt。正在处理中 ....."))
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 由于请求gpt需要一段时间,我们先及时地做一次界面更新
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 由于请求gpt需要一段时间,我们先及时地做一次界面更新
|
||||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||||
resolution = plugin_kwargs.get("advanced_arg", '1024x1024')
|
resolution = plugin_kwargs.get("advanced_arg", '1024x1024')
|
||||||
@@ -129,7 +129,7 @@ def 图片生成_DALLE3(prompt, llm_kwargs, plugin_kwargs, chatbot, history, sys
|
|||||||
chatbot.append((prompt, "[Local Message] 图像生成提示为空白,请在“输入区”输入图像生成提示。"))
|
chatbot.append((prompt, "[Local Message] 图像生成提示为空白,请在“输入区”输入图像生成提示。"))
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 界面更新
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 界面更新
|
||||||
return
|
return
|
||||||
chatbot.append(("您正在调用“图像生成”插件。", "[Local Message] 生成图像, 请先把模型切换至gpt-*。如果中文Prompt效果不理想, 请尝试英文Prompt。正在处理中 ....."))
|
chatbot.append(("您正在调用“图像生成”插件。", "[Local Message] 生成图像, 使用前请切换模型到GPT系列。如果中文Prompt效果不理想, 请尝试英文Prompt。正在处理中 ....."))
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 由于请求gpt需要一段时间,我们先及时地做一次界面更新
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 由于请求gpt需要一段时间,我们先及时地做一次界面更新
|
||||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||||
resolution_arg = plugin_kwargs.get("advanced_arg", '1024x1024-standard-vivid').lower()
|
resolution_arg = plugin_kwargs.get("advanced_arg", '1024x1024-standard-vivid').lower()
|
||||||
@@ -166,7 +166,7 @@ class ImageEditState(GptAcademicState):
|
|||||||
return confirm, file
|
return confirm, file
|
||||||
|
|
||||||
def lock_plugin(self, chatbot):
|
def lock_plugin(self, chatbot):
|
||||||
chatbot._cookies['lock_plugin'] = 'crazy_functions.图片生成->图片修改_DALLE2'
|
chatbot._cookies['lock_plugin'] = 'crazy_functions.Image_Generate->图片修改_DALLE2'
|
||||||
self.dump_state(chatbot)
|
self.dump_state(chatbot)
|
||||||
|
|
||||||
def unlock_plugin(self, chatbot):
|
def unlock_plugin(self, chatbot):
|
||||||
56
crazy_functions/Image_Generate_Wrap.py
Normal file
56
crazy_functions/Image_Generate_Wrap.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
|
||||||
|
from toolbox import get_conf, update_ui
|
||||||
|
from crazy_functions.Image_Generate import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2
|
||||||
|
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGen_Wrap(GptAcademicPluginTemplate):
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def define_arg_selection_menu(self):
|
||||||
|
"""
|
||||||
|
定义插件的二级选项菜单
|
||||||
|
|
||||||
|
第一个参数,名称`main_input`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
|
||||||
|
第二个参数,名称`advanced_arg`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
|
||||||
|
|
||||||
|
"""
|
||||||
|
gui_definition = {
|
||||||
|
"main_input":
|
||||||
|
ArgProperty(title="输入图片描述", description="需要生成图像的文本描述,尽量使用英文", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
|
||||||
|
"model_name":
|
||||||
|
ArgProperty(title="模型", options=["DALLE2", "DALLE3"], default_value="DALLE3", description="无", type="dropdown").model_dump_json(),
|
||||||
|
"resolution":
|
||||||
|
ArgProperty(title="分辨率", options=["256x256(限DALLE2)", "512x512(限DALLE2)", "1024x1024", "1792x1024(限DALLE3)", "1024x1792(限DALLE3)"], default_value="1024x1024", description="无", type="dropdown").model_dump_json(),
|
||||||
|
"quality (仅DALLE3生效)":
|
||||||
|
ArgProperty(title="质量", options=["standard", "hd"], default_value="standard", description="无", type="dropdown").model_dump_json(),
|
||||||
|
"style (仅DALLE3生效)":
|
||||||
|
ArgProperty(title="风格", options=["vivid", "natural"], default_value="vivid", description="无", type="dropdown").model_dump_json(),
|
||||||
|
|
||||||
|
}
|
||||||
|
return gui_definition
|
||||||
|
|
||||||
|
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
|
"""
|
||||||
|
执行插件
|
||||||
|
"""
|
||||||
|
# 分辨率
|
||||||
|
resolution = plugin_kwargs["resolution"].replace("(限DALLE2)", "").replace("(限DALLE3)", "")
|
||||||
|
|
||||||
|
if plugin_kwargs["model_name"] == "DALLE2":
|
||||||
|
plugin_kwargs["advanced_arg"] = resolution
|
||||||
|
yield from 图片生成_DALLE2(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||||
|
|
||||||
|
elif plugin_kwargs["model_name"] == "DALLE3":
|
||||||
|
quality = plugin_kwargs["quality (仅DALLE3生效)"]
|
||||||
|
style = plugin_kwargs["style (仅DALLE3生效)"]
|
||||||
|
plugin_kwargs["advanced_arg"] = f"{resolution}-{quality}-{style}"
|
||||||
|
yield from 图片生成_DALLE3(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||||
|
|
||||||
|
else:
|
||||||
|
chatbot.append([None, "抱歉,找不到该模型"])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
@@ -1,12 +1,109 @@
|
|||||||
from toolbox import CatchException, update_ui, get_conf
|
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
|
|
||||||
import requests
|
import requests
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
from request_llms.bridge_all import model_info
|
|
||||||
import urllib.request
|
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from itertools import zip_longest
|
||||||
from check_proxy import check_proxy
|
from check_proxy import check_proxy
|
||||||
|
from toolbox import CatchException, update_ui, get_conf
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||||
|
from crazy_functions.prompts.internet import SearchOptimizerPrompt, SearchAcademicOptimizerPrompt
|
||||||
|
|
||||||
|
def search_optimizer(
|
||||||
|
query,
|
||||||
|
proxies,
|
||||||
|
history,
|
||||||
|
llm_kwargs,
|
||||||
|
optimizer=1,
|
||||||
|
categories="general",
|
||||||
|
searxng_url=None,
|
||||||
|
engines=None,
|
||||||
|
):
|
||||||
|
# ------------- < 第1步:尝试进行搜索优化 > -------------
|
||||||
|
# * 增强优化,会尝试结合历史记录进行搜索优化
|
||||||
|
if optimizer == 2:
|
||||||
|
his = " "
|
||||||
|
if len(history) == 0:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
for i, h in enumerate(history):
|
||||||
|
if i % 2 == 0:
|
||||||
|
his += f"Q: {h}\n"
|
||||||
|
else:
|
||||||
|
his += f"A: {h}\n"
|
||||||
|
if categories == "general":
|
||||||
|
sys_prompt = SearchOptimizerPrompt.format(query=query, history=his, num=4)
|
||||||
|
elif categories == "science":
|
||||||
|
sys_prompt = SearchAcademicOptimizerPrompt.format(query=query, history=his, num=4)
|
||||||
|
else:
|
||||||
|
his = " "
|
||||||
|
if categories == "general":
|
||||||
|
sys_prompt = SearchOptimizerPrompt.format(query=query, history=his, num=3)
|
||||||
|
elif categories == "science":
|
||||||
|
sys_prompt = SearchAcademicOptimizerPrompt.format(query=query, history=his, num=3)
|
||||||
|
|
||||||
|
mutable = ["", time.time(), ""]
|
||||||
|
llm_kwargs["temperature"] = 0.8
|
||||||
|
try:
|
||||||
|
querys_json = predict_no_ui_long_connection(
|
||||||
|
inputs=query,
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
history=[],
|
||||||
|
sys_prompt=sys_prompt,
|
||||||
|
observe_window=mutable,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
querys_json = "1234"
|
||||||
|
#* 尝试解码优化后的搜索结果
|
||||||
|
querys_json = re.sub(r"```json|```", "", querys_json)
|
||||||
|
try:
|
||||||
|
querys = json.loads(querys_json)
|
||||||
|
except Exception:
|
||||||
|
#* 如果解码失败,降低温度再试一次
|
||||||
|
try:
|
||||||
|
llm_kwargs["temperature"] = 0.4
|
||||||
|
querys_json = predict_no_ui_long_connection(
|
||||||
|
inputs=query,
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
history=[],
|
||||||
|
sys_prompt=sys_prompt,
|
||||||
|
observe_window=mutable,
|
||||||
|
)
|
||||||
|
querys_json = re.sub(r"```json|```", "", querys_json)
|
||||||
|
querys = json.loads(querys_json)
|
||||||
|
except Exception:
|
||||||
|
#* 如果再次失败,直接返回原始问题
|
||||||
|
querys = [query]
|
||||||
|
links = []
|
||||||
|
success = 0
|
||||||
|
Exceptions = ""
|
||||||
|
for q in querys:
|
||||||
|
try:
|
||||||
|
link = searxng_request(q, proxies, categories, searxng_url, engines=engines)
|
||||||
|
if len(link) > 0:
|
||||||
|
links.append(link[:-5])
|
||||||
|
success += 1
|
||||||
|
except Exception:
|
||||||
|
Exceptions = Exception
|
||||||
|
pass
|
||||||
|
if success == 0:
|
||||||
|
raise ValueError(f"在线搜索失败!\n{Exceptions}")
|
||||||
|
# * 清洗搜索结果,依次放入每组第一,第二个搜索结果,并清洗重复的搜索结果
|
||||||
|
seen_links = set()
|
||||||
|
result = []
|
||||||
|
for tuple in zip_longest(*links, fillvalue=None):
|
||||||
|
for item in tuple:
|
||||||
|
if item is not None:
|
||||||
|
link = item["link"]
|
||||||
|
if link not in seen_links:
|
||||||
|
seen_links.add(link)
|
||||||
|
result.append(item)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def get_auth_ip():
|
def get_auth_ip():
|
||||||
@@ -15,14 +112,15 @@ def get_auth_ip():
|
|||||||
return '114.114.114.' + str(random.randint(1, 10))
|
return '114.114.114.' + str(random.randint(1, 10))
|
||||||
return ip
|
return ip
|
||||||
|
|
||||||
|
|
||||||
def searxng_request(query, proxies, categories='general', searxng_url=None, engines=None):
|
def searxng_request(query, proxies, categories='general', searxng_url=None, engines=None):
|
||||||
if searxng_url is None:
|
if searxng_url is None:
|
||||||
url = get_conf("SEARXNG_URL")
|
url = get_conf("SEARXNG_URL")
|
||||||
else:
|
else:
|
||||||
url = searxng_url
|
url = searxng_url
|
||||||
|
|
||||||
if engines is None:
|
if engines == "Mixed":
|
||||||
engines = 'bing'
|
engines = None
|
||||||
|
|
||||||
if categories == 'general':
|
if categories == 'general':
|
||||||
params = {
|
params = {
|
||||||
@@ -66,6 +164,7 @@ def searxng_request(query, proxies, categories='general', searxng_url=None, engi
|
|||||||
else:
|
else:
|
||||||
raise ValueError("在线搜索失败,状态码: " + str(response.status_code) + '\t' + response.content.decode('utf-8'))
|
raise ValueError("在线搜索失败,状态码: " + str(response.status_code) + '\t' + response.content.decode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def scrape_text(url, proxies) -> str:
|
def scrape_text(url, proxies) -> str:
|
||||||
"""Scrape text from a webpage
|
"""Scrape text from a webpage
|
||||||
|
|
||||||
@@ -93,9 +192,10 @@ def scrape_text(url, proxies) -> str:
|
|||||||
text = "\n".join(chunk for chunk in chunks if chunk)
|
text = "\n".join(chunk for chunk in chunks if chunk)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
@CatchException
|
@CatchException
|
||||||
def 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
def 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
|
optimizer_history = history[:-8]
|
||||||
history = [] # 清空历史,以免输入溢出
|
history = [] # 清空历史,以免输入溢出
|
||||||
chatbot.append((f"请结合互联网信息回答以下问题:{txt}", "检索中..."))
|
chatbot.append((f"请结合互联网信息回答以下问题:{txt}", "检索中..."))
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
@@ -106,16 +206,23 @@ def 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
|
|||||||
categories = plugin_kwargs.get('categories', 'general')
|
categories = plugin_kwargs.get('categories', 'general')
|
||||||
searxng_url = plugin_kwargs.get('searxng_url', None)
|
searxng_url = plugin_kwargs.get('searxng_url', None)
|
||||||
engines = plugin_kwargs.get('engine', None)
|
engines = plugin_kwargs.get('engine', None)
|
||||||
urls = searxng_request(txt, proxies, categories, searxng_url, engines=engines)
|
optimizer = plugin_kwargs.get('optimizer', "关闭")
|
||||||
|
if optimizer == "关闭":
|
||||||
|
urls = searxng_request(txt, proxies, categories, searxng_url, engines=engines)
|
||||||
|
else:
|
||||||
|
urls = search_optimizer(txt, proxies, optimizer_history, llm_kwargs, optimizer, categories, searxng_url, engines)
|
||||||
history = []
|
history = []
|
||||||
if len(urls) == 0:
|
if len(urls) == 0:
|
||||||
chatbot.append((f"结论:{txt}",
|
chatbot.append((f"结论:{txt}",
|
||||||
"[Local Message] 受到限制,无法从searxng获取信息!请尝试更换搜索引擎。"))
|
"[Local Message] 受到限制,无法从searxng获取信息!请尝试更换搜索引擎。"))
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
return
|
return
|
||||||
|
|
||||||
# ------------- < 第2步:依次访问网页 > -------------
|
# ------------- < 第2步:依次访问网页 > -------------
|
||||||
max_search_result = 5 # 最多收纳多少个网页的结果
|
max_search_result = 5 # 最多收纳多少个网页的结果
|
||||||
chatbot.append([f"联网检索中 ...", None])
|
if optimizer == "开启(增强)":
|
||||||
|
max_search_result = 8
|
||||||
|
chatbot.append(["联网检索中 ...", None])
|
||||||
for index, url in enumerate(urls[:max_search_result]):
|
for index, url in enumerate(urls[:max_search_result]):
|
||||||
res = scrape_text(url['link'], proxies)
|
res = scrape_text(url['link'], proxies)
|
||||||
prefix = f"第{index}份搜索结果 [源自{url['source'][0]}搜索] ({url['title'][:25]}):"
|
prefix = f"第{index}份搜索结果 [源自{url['source'][0]}搜索] ({url['title'][:25]}):"
|
||||||
@@ -125,18 +232,47 @@ def 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, s
|
|||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
# ------------- < 第3步:ChatGPT综合 > -------------
|
# ------------- < 第3步:ChatGPT综合 > -------------
|
||||||
i_say = f"从以上搜索结果中抽取信息,然后回答问题:{txt}"
|
if (optimizer != "开启(增强)"):
|
||||||
i_say, history = input_clipping( # 裁剪输入,从最长的条目开始裁剪,防止爆token
|
i_say = f"从以上搜索结果中抽取信息,然后回答问题:{txt}"
|
||||||
inputs=i_say,
|
i_say, history = input_clipping( # 裁剪输入,从最长的条目开始裁剪,防止爆token
|
||||||
history=history,
|
inputs=i_say,
|
||||||
max_token_limit=min(model_info[llm_kwargs['llm_model']]['max_token']*3//4, 8192)
|
history=history,
|
||||||
)
|
max_token_limit=min(model_info[llm_kwargs['llm_model']]['max_token']*3//4, 8192)
|
||||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
)
|
||||||
inputs=i_say, inputs_show_user=i_say,
|
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||||
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
|
inputs=i_say, inputs_show_user=i_say,
|
||||||
sys_prompt="请从给定的若干条搜索结果中抽取信息,对最相关的两个搜索结果进行总结,然后回答问题。"
|
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
|
||||||
)
|
sys_prompt="请从给定的若干条搜索结果中抽取信息,对最相关的两个搜索结果进行总结,然后回答问题。"
|
||||||
chatbot[-1] = (i_say, gpt_say)
|
)
|
||||||
history.append(i_say);history.append(gpt_say)
|
chatbot[-1] = (i_say, gpt_say)
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
|
history.append(i_say);history.append(gpt_say)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
|
||||||
|
|
||||||
|
#* 或者使用搜索优化器,这样可以保证后续问答能读取到有效的历史记录
|
||||||
|
else:
|
||||||
|
i_say = f"从以上搜索结果中抽取与问题:{txt} 相关的信息:"
|
||||||
|
i_say, history = input_clipping( # 裁剪输入,从最长的条目开始裁剪,防止爆token
|
||||||
|
inputs=i_say,
|
||||||
|
history=history,
|
||||||
|
max_token_limit=min(model_info[llm_kwargs['llm_model']]['max_token']*3//4, 8192)
|
||||||
|
)
|
||||||
|
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||||
|
inputs=i_say, inputs_show_user=i_say,
|
||||||
|
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
|
||||||
|
sys_prompt="请从给定的若干条搜索结果中抽取信息,对最相关的三个搜索结果进行总结"
|
||||||
|
)
|
||||||
|
chatbot[-1] = (i_say, gpt_say)
|
||||||
|
history = []
|
||||||
|
history.append(i_say);history.append(gpt_say)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
|
||||||
|
|
||||||
|
# ------------- < 第4步:根据综合回答问题 > -------------
|
||||||
|
i_say = f"请根据以上搜索结果回答问题:{txt}"
|
||||||
|
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||||
|
inputs=i_say, inputs_show_user=i_say,
|
||||||
|
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
|
||||||
|
sys_prompt="请根据给定的若干条搜索结果回答问题"
|
||||||
|
)
|
||||||
|
chatbot[-1] = (i_say, gpt_say)
|
||||||
|
history.append(i_say);history.append(gpt_say)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
@@ -22,11 +22,13 @@ class NetworkGPT_Wrap(GptAcademicPluginTemplate):
|
|||||||
"""
|
"""
|
||||||
gui_definition = {
|
gui_definition = {
|
||||||
"main_input":
|
"main_input":
|
||||||
ArgProperty(title="输入问题", description="待通过互联网检索的问题", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
|
ArgProperty(title="输入问题", description="待通过互联网检索的问题,会自动读取输入框内容", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
|
||||||
"categories":
|
"categories":
|
||||||
ArgProperty(title="搜索分类", options=["网页", "学术论文"], default_value="网页", description="无", type="dropdown").model_dump_json(),
|
ArgProperty(title="搜索分类", options=["网页", "学术论文"], default_value="网页", description="无", type="dropdown").model_dump_json(),
|
||||||
"engine":
|
"engine":
|
||||||
ArgProperty(title="选择搜索引擎", options=["bing", "google", "duckduckgo"], default_value="bing", description="无", type="dropdown").model_dump_json(),
|
ArgProperty(title="选择搜索引擎", options=["Mixed", "bing", "google", "duckduckgo"], default_value="google", description="无", type="dropdown").model_dump_json(),
|
||||||
|
"optimizer":
|
||||||
|
ArgProperty(title="搜索优化", options=["关闭", "开启", "开启(增强)"], default_value="关闭", description="是否使用搜索增强。注意这可能会消耗较多token", type="dropdown").model_dump_json(),
|
||||||
"searxng_url":
|
"searxng_url":
|
||||||
ArgProperty(title="Searxng服务地址", description="输入Searxng的地址", default_value=get_conf("SEARXNG_URL"), type="string").model_dump_json(), # 主输入,自动从输入框同步
|
ArgProperty(title="Searxng服务地址", description="输入Searxng的地址", default_value=get_conf("SEARXNG_URL"), type="string").model_dump_json(), # 主输入,自动从输入框同步
|
||||||
|
|
||||||
@@ -39,6 +41,5 @@ class NetworkGPT_Wrap(GptAcademicPluginTemplate):
|
|||||||
"""
|
"""
|
||||||
if plugin_kwargs["categories"] == "网页": plugin_kwargs["categories"] = "general"
|
if plugin_kwargs["categories"] == "网页": plugin_kwargs["categories"] = "general"
|
||||||
if plugin_kwargs["categories"] == "学术论文": plugin_kwargs["categories"] = "science"
|
if plugin_kwargs["categories"] == "学术论文": plugin_kwargs["categories"] = "science"
|
||||||
|
|
||||||
yield from 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
yield from 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from toolbox import update_ui, trimmed_format_exc, get_conf, get_log_folder, promote_file_to_downloadzone, check_repeat_upload, map_file_to_sha256
|
from toolbox import update_ui, trimmed_format_exc, get_conf, get_log_folder, promote_file_to_downloadzone, check_repeat_upload, map_file_to_sha256
|
||||||
from toolbox import CatchException, report_exception, update_ui_lastest_msg, zip_result, gen_time_str
|
from toolbox import CatchException, report_exception, update_ui_lastest_msg, zip_result, gen_time_str
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
import glob, os, requests, time, json, tarfile
|
import glob, os, requests, time, json, tarfile
|
||||||
|
|
||||||
pj = os.path.join
|
pj = os.path.join
|
||||||
@@ -136,25 +138,43 @@ def arxiv_download(chatbot, history, txt, allow_cache=True):
|
|||||||
cached_translation_pdf = check_cached_translation_pdf(arxiv_id)
|
cached_translation_pdf = check_cached_translation_pdf(arxiv_id)
|
||||||
if cached_translation_pdf and allow_cache: return cached_translation_pdf, arxiv_id
|
if cached_translation_pdf and allow_cache: return cached_translation_pdf, arxiv_id
|
||||||
|
|
||||||
url_tar = url_.replace('/abs/', '/e-print/')
|
|
||||||
translation_dir = pj(ARXIV_CACHE_DIR, arxiv_id, 'e-print')
|
|
||||||
extract_dst = pj(ARXIV_CACHE_DIR, arxiv_id, 'extract')
|
extract_dst = pj(ARXIV_CACHE_DIR, arxiv_id, 'extract')
|
||||||
os.makedirs(translation_dir, exist_ok=True)
|
translation_dir = pj(ARXIV_CACHE_DIR, arxiv_id, 'e-print')
|
||||||
|
|
||||||
# <-------------- download arxiv source file ------------->
|
|
||||||
dst = pj(translation_dir, arxiv_id + '.tar')
|
dst = pj(translation_dir, arxiv_id + '.tar')
|
||||||
if os.path.exists(dst):
|
os.makedirs(translation_dir, exist_ok=True)
|
||||||
yield from update_ui_lastest_msg("调用缓存", chatbot=chatbot, history=history) # 刷新界面
|
# <-------------- download arxiv source file ------------->
|
||||||
|
|
||||||
|
def fix_url_and_download():
|
||||||
|
# for url_tar in [url_.replace('/abs/', '/e-print/'), url_.replace('/abs/', '/src/')]:
|
||||||
|
for url_tar in [url_.replace('/abs/', '/src/'), url_.replace('/abs/', '/e-print/')]:
|
||||||
|
proxies = get_conf('proxies')
|
||||||
|
r = requests.get(url_tar, proxies=proxies)
|
||||||
|
if r.status_code == 200:
|
||||||
|
with open(dst, 'wb+') as f:
|
||||||
|
f.write(r.content)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
if os.path.exists(dst) and allow_cache:
|
||||||
|
yield from update_ui_lastest_msg(f"调用缓存 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
success = True
|
||||||
else:
|
else:
|
||||||
yield from update_ui_lastest_msg("开始下载", chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui_lastest_msg(f"开始下载 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||||
proxies = get_conf('proxies')
|
success = fix_url_and_download()
|
||||||
r = requests.get(url_tar, proxies=proxies)
|
yield from update_ui_lastest_msg(f"下载完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
|
||||||
with open(dst, 'wb+') as f:
|
|
||||||
f.write(r.content)
|
|
||||||
|
if not success:
|
||||||
|
yield from update_ui_lastest_msg(f"下载失败 {arxiv_id}", chatbot=chatbot, history=history)
|
||||||
|
raise tarfile.ReadError(f"论文下载失败 {arxiv_id}")
|
||||||
|
|
||||||
# <-------------- extract file ------------->
|
# <-------------- extract file ------------->
|
||||||
yield from update_ui_lastest_msg("下载完成", chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
from toolbox import extract_archive
|
from toolbox import extract_archive
|
||||||
extract_archive(file_path=dst, dest_dir=extract_dst)
|
try:
|
||||||
|
extract_archive(file_path=dst, dest_dir=extract_dst)
|
||||||
|
except tarfile.ReadError:
|
||||||
|
os.remove(dst)
|
||||||
|
raise tarfile.ReadError(f"论文下载失败")
|
||||||
return extract_dst, arxiv_id
|
return extract_dst, arxiv_id
|
||||||
|
|
||||||
|
|
||||||
@@ -178,7 +198,7 @@ def pdf2tex_project(pdf_file_path, plugin_kwargs):
|
|||||||
|
|
||||||
if response.ok:
|
if response.ok:
|
||||||
pdf_id = response.json()["pdf_id"]
|
pdf_id = response.json()["pdf_id"]
|
||||||
print(f"PDF processing initiated. PDF ID: {pdf_id}")
|
logger.info(f"PDF processing initiated. PDF ID: {pdf_id}")
|
||||||
|
|
||||||
# Step 2: Check processing status
|
# Step 2: Check processing status
|
||||||
while True:
|
while True:
|
||||||
@@ -186,12 +206,12 @@ def pdf2tex_project(pdf_file_path, plugin_kwargs):
|
|||||||
conversion_data = conversion_response.json()
|
conversion_data = conversion_response.json()
|
||||||
|
|
||||||
if conversion_data["status"] == "completed":
|
if conversion_data["status"] == "completed":
|
||||||
print("PDF processing completed.")
|
logger.info("PDF processing completed.")
|
||||||
break
|
break
|
||||||
elif conversion_data["status"] == "error":
|
elif conversion_data["status"] == "error":
|
||||||
print("Error occurred during processing.")
|
logger.info("Error occurred during processing.")
|
||||||
else:
|
else:
|
||||||
print(f"Processing status: {conversion_data['status']}")
|
logger.info(f"Processing status: {conversion_data['status']}")
|
||||||
time.sleep(5) # wait for a few seconds before checking again
|
time.sleep(5) # wait for a few seconds before checking again
|
||||||
|
|
||||||
# Step 3: Save results to local files
|
# Step 3: Save results to local files
|
||||||
@@ -206,7 +226,7 @@ def pdf2tex_project(pdf_file_path, plugin_kwargs):
|
|||||||
output_path = os.path.join(output_dir, output_name)
|
output_path = os.path.join(output_dir, output_name)
|
||||||
with open(output_path, "wb") as output_file:
|
with open(output_path, "wb") as output_file:
|
||||||
output_file.write(response.content)
|
output_file.write(response.content)
|
||||||
print(f"tex.zip file saved at: {output_path}")
|
logger.info(f"tex.zip file saved at: {output_path}")
|
||||||
|
|
||||||
import zipfile
|
import zipfile
|
||||||
unzip_dir = os.path.join(output_dir, file_name_wo_dot)
|
unzip_dir = os.path.join(output_dir, file_name_wo_dot)
|
||||||
@@ -216,7 +236,7 @@ def pdf2tex_project(pdf_file_path, plugin_kwargs):
|
|||||||
return unzip_dir
|
return unzip_dir
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f"Error sending PDF for processing. Status code: {response.status_code}")
|
logger.error(f"Error sending PDF for processing. Status code: {response.status_code}")
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
from crazy_functions.pdf_fns.parse_pdf_via_doc2x import 解析PDF_DOC2X_转Latex
|
from crazy_functions.pdf_fns.parse_pdf_via_doc2x import 解析PDF_DOC2X_转Latex
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from toolbox import update_ui, trimmed_format_exc, promote_file_to_downloadzone, get_log_folder
|
from toolbox import update_ui, trimmed_format_exc, promote_file_to_downloadzone, get_log_folder
|
||||||
from toolbox import CatchException, report_exception, write_history_to_file, zip_folder
|
from toolbox import CatchException, report_exception, write_history_to_file, zip_folder
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
class PaperFileGroup():
|
class PaperFileGroup():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -33,7 +33,7 @@ class PaperFileGroup():
|
|||||||
self.sp_file_index.append(index)
|
self.sp_file_index.append(index)
|
||||||
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.tex")
|
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.tex")
|
||||||
|
|
||||||
print('Segmentation: done')
|
logger.info('Segmentation: done')
|
||||||
def merge_result(self):
|
def merge_result(self):
|
||||||
self.file_result = ["" for _ in range(len(self.file_paths))]
|
self.file_result = ["" for _ in range(len(self.file_paths))]
|
||||||
for r, k in zip(self.sp_file_result, self.sp_file_index):
|
for r, k in zip(self.sp_file_result, self.sp_file_index):
|
||||||
@@ -56,7 +56,7 @@ class PaperFileGroup():
|
|||||||
|
|
||||||
def 多文件润色(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en', mode='polish'):
|
def 多文件润色(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en', mode='polish'):
|
||||||
import time, os, re
|
import time, os, re
|
||||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
|
|
||||||
|
|
||||||
# <-------- 读取Latex文件,删除其中的所有注释 ---------->
|
# <-------- 读取Latex文件,删除其中的所有注释 ---------->
|
||||||
@@ -122,7 +122,7 @@ def 多文件润色(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
|||||||
pfg.write_result()
|
pfg.write_result()
|
||||||
pfg.zip_result()
|
pfg.zip_result()
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
|
|
||||||
# <-------- 整理结果,退出 ---------->
|
# <-------- 整理结果,退出 ---------->
|
||||||
create_report_file_name = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + f"-chatgpt.polish.md"
|
create_report_file_name = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + f"-chatgpt.polish.md"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from toolbox import update_ui, promote_file_to_downloadzone
|
from toolbox import update_ui, promote_file_to_downloadzone
|
||||||
from toolbox import CatchException, report_exception, write_history_to_file
|
from toolbox import CatchException, report_exception, write_history_to_file
|
||||||
fast_debug = False
|
from loguru import logger
|
||||||
|
|
||||||
class PaperFileGroup():
|
class PaperFileGroup():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -33,11 +33,11 @@ class PaperFileGroup():
|
|||||||
self.sp_file_index.append(index)
|
self.sp_file_index.append(index)
|
||||||
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.tex")
|
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.tex")
|
||||||
|
|
||||||
print('Segmentation: done')
|
logger.info('Segmentation: done')
|
||||||
|
|
||||||
def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en'):
|
def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en'):
|
||||||
import time, os, re
|
import time, os, re
|
||||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
|
|
||||||
# <-------- 读取Latex文件,删除其中的所有注释 ---------->
|
# <-------- 读取Latex文件,删除其中的所有注释 ---------->
|
||||||
pfg = PaperFileGroup()
|
pfg = PaperFileGroup()
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import glob, shutil, os, re, logging
|
import glob, shutil, os, re
|
||||||
|
from loguru import logger
|
||||||
from toolbox import update_ui, trimmed_format_exc, gen_time_str
|
from toolbox import update_ui, trimmed_format_exc, gen_time_str
|
||||||
from toolbox import CatchException, report_exception, get_log_folder
|
from toolbox import CatchException, report_exception, get_log_folder
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||||
@@ -34,7 +35,7 @@ class PaperFileGroup():
|
|||||||
self.sp_file_contents.append(segment)
|
self.sp_file_contents.append(segment)
|
||||||
self.sp_file_index.append(index)
|
self.sp_file_index.append(index)
|
||||||
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.md")
|
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.md")
|
||||||
logging.info('Segmentation: done')
|
logger.info('Segmentation: done')
|
||||||
|
|
||||||
def merge_result(self):
|
def merge_result(self):
|
||||||
self.file_result = ["" for _ in range(len(self.file_paths))]
|
self.file_result = ["" for _ in range(len(self.file_paths))]
|
||||||
@@ -51,7 +52,7 @@ class PaperFileGroup():
|
|||||||
return manifest
|
return manifest
|
||||||
|
|
||||||
def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en'):
|
def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en'):
|
||||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
|
|
||||||
# <-------- 读取Markdown文件,删除其中的所有注释 ---------->
|
# <-------- 读取Markdown文件,删除其中的所有注释 ---------->
|
||||||
pfg = PaperFileGroup()
|
pfg = PaperFileGroup()
|
||||||
@@ -106,7 +107,7 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
|
|||||||
expected_f_name = plugin_kwargs['markdown_expected_output_path']
|
expected_f_name = plugin_kwargs['markdown_expected_output_path']
|
||||||
shutil.copyfile(output_file, expected_f_name)
|
shutil.copyfile(output_file, expected_f_name)
|
||||||
except:
|
except:
|
||||||
logging.error(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
|
|
||||||
# <-------- 整理结果,退出 ---------->
|
# <-------- 整理结果,退出 ---------->
|
||||||
create_report_file_name = gen_time_str() + f"-chatgpt.md"
|
create_report_file_name = gen_time_str() + f"-chatgpt.md"
|
||||||
@@ -126,7 +127,7 @@ def get_files_from_everything(txt, preference=''):
|
|||||||
proxies = get_conf('proxies')
|
proxies = get_conf('proxies')
|
||||||
# 网络的远程文件
|
# 网络的远程文件
|
||||||
if preference == 'Github':
|
if preference == 'Github':
|
||||||
logging.info('正在从github下载资源 ...')
|
logger.info('正在从github下载资源 ...')
|
||||||
if not txt.endswith('.md'):
|
if not txt.endswith('.md'):
|
||||||
# Make a request to the GitHub API to retrieve the repository information
|
# Make a request to the GitHub API to retrieve the repository information
|
||||||
url = txt.replace("https://github.com/", "https://api.github.com/repos/") + '/readme'
|
url = txt.replace("https://github.com/", "https://api.github.com/repos/") + '/readme'
|
||||||
|
|||||||
92
crazy_functions/Rag_Interface.py
Normal file
92
crazy_functions/Rag_Interface.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
|
||||||
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
|
|
||||||
|
RAG_WORKER_REGISTER = {}
|
||||||
|
MAX_HISTORY_ROUND = 5
|
||||||
|
MAX_CONTEXT_TOKEN_LIMIT = 4096
|
||||||
|
REMEMBER_PREVIEW = 1000
|
||||||
|
|
||||||
|
@CatchException
|
||||||
|
def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
|
|
||||||
|
# import vector store lib
|
||||||
|
VECTOR_STORE_TYPE = "Milvus"
|
||||||
|
if VECTOR_STORE_TYPE == "Milvus":
|
||||||
|
try:
|
||||||
|
from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
|
||||||
|
except:
|
||||||
|
VECTOR_STORE_TYPE = "Simple"
|
||||||
|
if VECTOR_STORE_TYPE == "Simple":
|
||||||
|
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
|
||||||
|
|
||||||
|
# 1. we retrieve rag worker from global context
|
||||||
|
user_name = chatbot.get_user()
|
||||||
|
checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
|
||||||
|
if user_name in RAG_WORKER_REGISTER:
|
||||||
|
rag_worker = RAG_WORKER_REGISTER[user_name]
|
||||||
|
else:
|
||||||
|
rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
|
||||||
|
user_name,
|
||||||
|
llm_kwargs,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
auto_load_checkpoint=True)
|
||||||
|
current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}"
|
||||||
|
tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库"
|
||||||
|
if txt == "清空向量数据库":
|
||||||
|
chatbot.append([txt, f'正在清空 ({current_context}) ...'])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
rag_worker.purge()
|
||||||
|
yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
|
||||||
|
return
|
||||||
|
|
||||||
|
chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
|
# 2. clip history to reduce token consumption
|
||||||
|
# 2-1. reduce chat round
|
||||||
|
txt_origin = txt
|
||||||
|
|
||||||
|
if len(history) > MAX_HISTORY_ROUND * 2:
|
||||||
|
history = history[-(MAX_HISTORY_ROUND * 2):]
|
||||||
|
txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True)
|
||||||
|
input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"])
|
||||||
|
|
||||||
|
# 2-2. if input is clipped, add input to vector store before retrieve
|
||||||
|
if input_is_clipped_flag:
|
||||||
|
yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
|
||||||
|
# save input to vector store
|
||||||
|
rag_worker.add_text_to_vector_store(txt_origin)
|
||||||
|
yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
|
||||||
|
if len(txt_origin) > REMEMBER_PREVIEW:
|
||||||
|
HALF = REMEMBER_PREVIEW//2
|
||||||
|
i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:]
|
||||||
|
if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF:
|
||||||
|
txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:]
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
i_say = txt_clip
|
||||||
|
else:
|
||||||
|
i_say_to_remember = i_say = txt_clip
|
||||||
|
else:
|
||||||
|
i_say_to_remember = i_say = txt_clip
|
||||||
|
|
||||||
|
# 3. we search vector store and build prompts
|
||||||
|
nodes = rag_worker.retrieve_from_store_with_query(i_say)
|
||||||
|
prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)
|
||||||
|
|
||||||
|
# 4. it is time to query llms
|
||||||
|
if len(chatbot) != 0: chatbot.pop(-1) # pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`
|
||||||
|
model_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||||
|
inputs=prompt, inputs_show_user=i_say,
|
||||||
|
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
|
||||||
|
sys_prompt=system_prompt,
|
||||||
|
retry_times_at_unknown_error=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. remember what has been asked / answered
|
||||||
|
yield from update_ui_lastest_msg(model_say + '</br></br>' + f'对话记忆中, 请稍等 ({current_context}) ...', chatbot, history, delay=0.5) # 刷新界面
|
||||||
|
rag_worker.remember_qa(i_say_to_remember, model_say)
|
||||||
|
history.extend([i_say, model_say])
|
||||||
|
|
||||||
|
yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) # 刷新界面
|
||||||
167
crazy_functions/Social_Helper.py
Normal file
167
crazy_functions/Social_Helper.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import pickle, os, random
|
||||||
|
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
|
||||||
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
|
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||||
|
from crazy_functions.json_fns.select_tool import structure_output, select_tool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from loguru import logger
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
SOCIAL_NETWOK_WORKER_REGISTER = {}
|
||||||
|
|
||||||
|
class SocialNetwork():
|
||||||
|
def __init__(self):
|
||||||
|
self.people = []
|
||||||
|
|
||||||
|
class SaveAndLoad():
|
||||||
|
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
|
||||||
|
self.user_name = user_name
|
||||||
|
self.checkpoint_dir = checkpoint_dir
|
||||||
|
if auto_load_checkpoint:
|
||||||
|
self.social_network = self.load_from_checkpoint(checkpoint_dir)
|
||||||
|
else:
|
||||||
|
self.social_network = SocialNetwork()
|
||||||
|
|
||||||
|
def does_checkpoint_exist(self, checkpoint_dir=None):
|
||||||
|
import os, glob
|
||||||
|
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||||
|
if not os.path.exists(checkpoint_dir): return False
|
||||||
|
if len(glob.glob(os.path.join(checkpoint_dir, "social_network.pkl"))) == 0: return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def save_to_checkpoint(self, checkpoint_dir=None):
|
||||||
|
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||||
|
with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "wb+") as f:
|
||||||
|
pickle.dump(self.social_network, f)
|
||||||
|
return
|
||||||
|
|
||||||
|
def load_from_checkpoint(self, checkpoint_dir=None):
|
||||||
|
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||||
|
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
|
||||||
|
with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "rb") as f:
|
||||||
|
social_network = pickle.load(f)
|
||||||
|
return social_network
|
||||||
|
else:
|
||||||
|
return SocialNetwork()
|
||||||
|
|
||||||
|
|
||||||
|
class Friend(BaseModel):
|
||||||
|
friend_name: str = Field(description="name of a friend")
|
||||||
|
friend_description: str = Field(description="description of a friend (everything about this friend)")
|
||||||
|
friend_relationship: str = Field(description="The relationship with a friend (e.g. friend, family, colleague)")
|
||||||
|
|
||||||
|
class FriendList(BaseModel):
|
||||||
|
friends_list: List[Friend] = Field(description="The list of friends")
|
||||||
|
|
||||||
|
|
||||||
|
class SocialNetworkWorker(SaveAndLoad):
|
||||||
|
def ai_socail_advice(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def ai_remove_friend(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def ai_list_friends(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def ai_add_multi_friends(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
|
||||||
|
friend, err_msg = structure_output(
|
||||||
|
txt=prompt,
|
||||||
|
prompt="根据提示, 解析多个联系人的身份信息\n\n",
|
||||||
|
err_msg=f"不能理解该联系人",
|
||||||
|
run_gpt_fn=run_gpt_fn,
|
||||||
|
pydantic_cls=FriendList
|
||||||
|
)
|
||||||
|
if friend.friends_list:
|
||||||
|
for f in friend.friends_list:
|
||||||
|
self.add_friend(f)
|
||||||
|
msg = f"成功添加{len(friend.friends_list)}个联系人: {str(friend.friends_list)}"
|
||||||
|
yield from update_ui_lastest_msg(lastmsg=msg, chatbot=chatbot, history=history, delay=0)
|
||||||
|
|
||||||
|
|
||||||
|
def run(self, txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
|
prompt = txt
|
||||||
|
run_gpt_fn = lambda inputs, sys_prompt: predict_no_ui_long_connection(inputs=inputs, llm_kwargs=llm_kwargs, history=[], sys_prompt=sys_prompt, observe_window=[])
|
||||||
|
self.tools_to_select = {
|
||||||
|
"SocialAdvice":{
|
||||||
|
"explain_to_llm": "如果用户希望获取社交指导,调用SocialAdvice生成一些社交建议",
|
||||||
|
"callback": self.ai_socail_advice,
|
||||||
|
},
|
||||||
|
"AddFriends":{
|
||||||
|
"explain_to_llm": "如果用户给出了联系人,调用AddMultiFriends把联系人添加到数据库",
|
||||||
|
"callback": self.ai_add_multi_friends,
|
||||||
|
},
|
||||||
|
"RemoveFriend":{
|
||||||
|
"explain_to_llm": "如果用户希望移除某个联系人,调用RemoveFriend",
|
||||||
|
"callback": self.ai_remove_friend,
|
||||||
|
},
|
||||||
|
"ListFriends":{
|
||||||
|
"explain_to_llm": "如果用户列举联系人,调用ListFriends",
|
||||||
|
"callback": self.ai_list_friends,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
Explaination = '\n'.join([f'{k}: {v["explain_to_llm"]}' for k, v in self.tools_to_select.items()])
|
||||||
|
class UserSociaIntention(BaseModel):
|
||||||
|
intention_type: str = Field(
|
||||||
|
description=
|
||||||
|
f"The type of user intention. You must choose from {self.tools_to_select.keys()}.\n\n"
|
||||||
|
f"Explaination:\n{Explaination}",
|
||||||
|
default="SocialAdvice"
|
||||||
|
)
|
||||||
|
pydantic_cls_instance, err_msg = select_tool(
|
||||||
|
prompt=txt,
|
||||||
|
run_gpt_fn=run_gpt_fn,
|
||||||
|
pydantic_cls=UserSociaIntention
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
yield from update_ui_lastest_msg(
|
||||||
|
lastmsg=f"无法理解用户意图 {err_msg}",
|
||||||
|
chatbot=chatbot,
|
||||||
|
history=history,
|
||||||
|
delay=0
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
intention_type = pydantic_cls_instance.intention_type
|
||||||
|
intention_callback = self.tools_to_select[pydantic_cls_instance.intention_type]['callback']
|
||||||
|
yield from intention_callback(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type)
|
||||||
|
|
||||||
|
|
||||||
|
def add_friend(self, friend):
|
||||||
|
# check whether the friend is already in the social network
|
||||||
|
for f in self.social_network.people:
|
||||||
|
if f.friend_name == friend.friend_name:
|
||||||
|
f.friend_description = friend.friend_description
|
||||||
|
f.friend_relationship = friend.friend_relationship
|
||||||
|
logger.info(f"Repeated friend, update info: {friend}")
|
||||||
|
return
|
||||||
|
logger.info(f"Add a new friend: {friend}")
|
||||||
|
self.social_network.people.append(friend)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@CatchException
|
||||||
|
def I人助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
|
|
||||||
|
# 1. we retrieve worker from global context
|
||||||
|
user_name = chatbot.get_user()
|
||||||
|
checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag')
|
||||||
|
if user_name in SOCIAL_NETWOK_WORKER_REGISTER:
|
||||||
|
social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name]
|
||||||
|
else:
|
||||||
|
social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name] = SocialNetworkWorker(
|
||||||
|
user_name,
|
||||||
|
llm_kwargs,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
auto_load_checkpoint=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. save
|
||||||
|
yield from social_network_worker.run(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
|
||||||
|
social_network_worker.save_to_checkpoint(checkpoint_dir)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
@@ -1,13 +1,12 @@
|
|||||||
from toolbox import update_ui, promote_file_to_downloadzone, disable_auto_promotion
|
from toolbox import update_ui, promote_file_to_downloadzone
|
||||||
from toolbox import CatchException, report_exception, write_history_to_file
|
from toolbox import CatchException, report_exception, write_history_to_file
|
||||||
from shared_utils.fastapi_server import validate_path_safety
|
from shared_utils.fastapi_server import validate_path_safety
|
||||||
from crazy_functions.crazy_utils import input_clipping
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
|
|
||||||
def 解析源代码新(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
def 解析源代码新(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||||
import os, copy
|
import os, copy
|
||||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
disable_auto_promotion(chatbot=chatbot)
|
|
||||||
|
|
||||||
summary_batch_isolation = True
|
summary_batch_isolation = True
|
||||||
inputs_array = []
|
inputs_array = []
|
||||||
@@ -24,7 +23,7 @@ def 解析源代码新(file_manifest, project_folder, llm_kwargs, plugin_kwargs,
|
|||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
prefix = "接下来请你逐文件分析下面的工程" if index==0 else ""
|
prefix = "接下来请你逐文件分析下面的工程" if index==0 else ""
|
||||||
i_say = prefix + f'请对下面的程序文件做一个概述文件名是{os.path.relpath(fp, project_folder)},文件代码是 ```{file_content}```'
|
i_say = prefix + f'请对下面的程序文件做一个概述文件名是{os.path.relpath(fp, project_folder)},文件代码是 ```{file_content}```'
|
||||||
i_say_show_user = prefix + f'[{index}/{len(file_manifest)}] 请对下面的程序文件做一个概述: {fp}'
|
i_say_show_user = prefix + f'[{index+1}/{len(file_manifest)}] 请对下面的程序文件做一个概述: {fp}'
|
||||||
# 装载请求内容
|
# 装载请求内容
|
||||||
inputs_array.append(i_say)
|
inputs_array.append(i_say)
|
||||||
inputs_show_user_array.append(i_say_show_user)
|
inputs_show_user_array.append(i_say_show_user)
|
||||||
138
crazy_functions/SourceCode_Comment.py
Normal file
138
crazy_functions/SourceCode_Comment.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
import os, copy, time
|
||||||
|
from toolbox import CatchException, report_exception, update_ui, zip_result, promote_file_to_downloadzone, update_ui_lastest_msg, get_conf, generate_file_link
|
||||||
|
from shared_utils.fastapi_server import validate_path_safety
|
||||||
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
|
from crazy_functions.agent_fns.python_comment_agent import PythonCodeComment
|
||||||
|
from crazy_functions.diagram_fns.file_tree import FileNode
|
||||||
|
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
|
||||||
|
|
||||||
|
def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||||
|
|
||||||
|
summary_batch_isolation = True
|
||||||
|
inputs_array = []
|
||||||
|
inputs_show_user_array = []
|
||||||
|
history_array = []
|
||||||
|
sys_prompt_array = []
|
||||||
|
|
||||||
|
assert len(file_manifest) <= 512, "源文件太多(超过512个), 请缩减输入文件的数量。或者,您也可以选择删除此行警告,并修改代码拆分file_manifest列表,从而实现分批次处理。"
|
||||||
|
|
||||||
|
# 建立文件树
|
||||||
|
file_tree_struct = FileNode("root", build_manifest=True)
|
||||||
|
for file_path in file_manifest:
|
||||||
|
file_tree_struct.add_file(file_path, file_path)
|
||||||
|
|
||||||
|
# <第一步,逐个文件分析,多线程>
|
||||||
|
for index, fp in enumerate(file_manifest):
|
||||||
|
# 读取文件
|
||||||
|
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
||||||
|
file_content = f.read()
|
||||||
|
prefix = ""
|
||||||
|
i_say = prefix + f'Please conclude the following source code at {os.path.relpath(fp, project_folder)} with only one sentence, the code is:\n```{file_content}```'
|
||||||
|
i_say_show_user = prefix + f'[{index+1}/{len(file_manifest)}] 请用一句话对下面的程序文件做一个整体概述: {fp}'
|
||||||
|
# 装载请求内容
|
||||||
|
MAX_TOKEN_SINGLE_FILE = 2560
|
||||||
|
i_say, _ = input_clipping(inputs=i_say, history=[], max_token_limit=MAX_TOKEN_SINGLE_FILE)
|
||||||
|
inputs_array.append(i_say)
|
||||||
|
inputs_show_user_array.append(i_say_show_user)
|
||||||
|
history_array.append([])
|
||||||
|
sys_prompt_array.append("You are a software architecture analyst analyzing a source code project. Do not dig into details, tell me what the code is doing in general. Your answer must be short, simple and clear.")
|
||||||
|
# 文件读取完成,对每一个源代码文件,生成一个请求线程,发送到大模型进行分析
|
||||||
|
gpt_response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||||
|
inputs_array = inputs_array,
|
||||||
|
inputs_show_user_array = inputs_show_user_array,
|
||||||
|
history_array = history_array,
|
||||||
|
sys_prompt_array = sys_prompt_array,
|
||||||
|
llm_kwargs = llm_kwargs,
|
||||||
|
chatbot = chatbot,
|
||||||
|
show_user_at_complete = True
|
||||||
|
)
|
||||||
|
|
||||||
|
# <第二步,逐个文件分析,生成带注释文件>
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
executor = ThreadPoolExecutor(max_workers=get_conf('DEFAULT_WORKER_NUM'))
|
||||||
|
def _task_multi_threading(i_say, gpt_say, fp, file_tree_struct):
|
||||||
|
pcc = PythonCodeComment(llm_kwargs, language='English')
|
||||||
|
pcc.read_file(path=fp, brief=gpt_say)
|
||||||
|
revised_path, revised_content = pcc.begin_comment_source_code(None, None)
|
||||||
|
file_tree_struct.manifest[fp].revised_path = revised_path
|
||||||
|
file_tree_struct.manifest[fp].revised_content = revised_content
|
||||||
|
# <将结果写回源文件>
|
||||||
|
with open(fp, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(file_tree_struct.manifest[fp].revised_content)
|
||||||
|
# <生成对比html>
|
||||||
|
with open("crazy_functions/agent_fns/python_comment_compare.html", 'r', encoding='utf-8') as f:
|
||||||
|
html_template = f.read()
|
||||||
|
warp = lambda x: "```python\n\n" + x + "\n\n```"
|
||||||
|
from themes.theme import advanced_css
|
||||||
|
html_template = html_template.replace("ADVANCED_CSS", advanced_css)
|
||||||
|
html_template = html_template.replace("REPLACE_CODE_FILE_LEFT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(pcc.original_content))))
|
||||||
|
html_template = html_template.replace("REPLACE_CODE_FILE_RIGHT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(revised_content))))
|
||||||
|
compare_html_path = fp + '.compare.html'
|
||||||
|
file_tree_struct.manifest[fp].compare_html = compare_html_path
|
||||||
|
with open(compare_html_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(html_template)
|
||||||
|
# print('done 1')
|
||||||
|
|
||||||
|
chatbot.append([None, f"正在处理:"])
|
||||||
|
futures = []
|
||||||
|
for i_say, gpt_say, fp in zip(gpt_response_collection[0::2], gpt_response_collection[1::2], file_manifest):
|
||||||
|
future = executor.submit(_task_multi_threading, i_say, gpt_say, fp, file_tree_struct)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
cnt = 0
|
||||||
|
while True:
|
||||||
|
cnt += 1
|
||||||
|
time.sleep(3)
|
||||||
|
worker_done = [h.done() for h in futures]
|
||||||
|
remain = len(worker_done) - sum(worker_done)
|
||||||
|
|
||||||
|
# <展示已经完成的部分>
|
||||||
|
preview_html_list = []
|
||||||
|
for done, fp in zip(worker_done, file_manifest):
|
||||||
|
if not done: continue
|
||||||
|
preview_html_list.append(file_tree_struct.manifest[fp].compare_html)
|
||||||
|
file_links = generate_file_link(preview_html_list)
|
||||||
|
|
||||||
|
yield from update_ui_lastest_msg(
|
||||||
|
f"剩余源文件数量: {remain}.\n\n" +
|
||||||
|
f"已完成的文件: {sum(worker_done)}.\n\n" +
|
||||||
|
file_links +
|
||||||
|
"\n\n" +
|
||||||
|
''.join(['.']*(cnt % 10 + 1)
|
||||||
|
), chatbot=chatbot, history=history, delay=0)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=[]) # 刷新界面
|
||||||
|
if all(worker_done):
|
||||||
|
executor.shutdown()
|
||||||
|
break
|
||||||
|
|
||||||
|
# <第四步,压缩结果>
|
||||||
|
zip_res = zip_result(project_folder)
|
||||||
|
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
|
||||||
|
|
||||||
|
# <END>
|
||||||
|
chatbot.append((None, "所有源文件均已处理完毕。"))
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@CatchException
|
||||||
|
def 注释Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
|
history = [] # 清空历史,以免输入溢出
|
||||||
|
import glob, os
|
||||||
|
if os.path.exists(txt):
|
||||||
|
project_folder = txt
|
||||||
|
validate_path_safety(project_folder, chatbot.get_user())
|
||||||
|
else:
|
||||||
|
if txt == "": txt = '空空如也的输入栏'
|
||||||
|
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
return
|
||||||
|
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.py', recursive=True)]
|
||||||
|
if len(file_manifest) == 0:
|
||||||
|
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到任何python文件: {txt}")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
return
|
||||||
|
|
||||||
|
yield from 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
from crazy_functions.agent_fns.pipe import PluginMultiprocessManager, PipeCom
|
from crazy_functions.agent_fns.pipe import PluginMultiprocessManager, PipeCom
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
class EchoDemo(PluginMultiprocessManager):
|
class EchoDemo(PluginMultiprocessManager):
|
||||||
def subprocess_worker(self, child_conn):
|
def subprocess_worker(self, child_conn):
|
||||||
@@ -16,4 +17,4 @@ class EchoDemo(PluginMultiprocessManager):
|
|||||||
elif msg.cmd == "terminate":
|
elif msg.cmd == "terminate":
|
||||||
self.child_conn.send(PipeCom("done", ""))
|
self.child_conn.send(PipeCom("done", ""))
|
||||||
break
|
break
|
||||||
print('[debug] subprocess_worker terminated')
|
logger.info('[debug] subprocess_worker terminated')
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from toolbox import get_log_folder, update_ui, gen_time_str, get_conf, promote_file_to_downloadzone
|
from toolbox import get_log_folder, update_ui, gen_time_str, get_conf, promote_file_to_downloadzone
|
||||||
from crazy_functions.agent_fns.watchdog import WatchDog
|
from crazy_functions.agent_fns.watchdog import WatchDog
|
||||||
|
from loguru import logger
|
||||||
import time, os
|
import time, os
|
||||||
|
|
||||||
class PipeCom:
|
class PipeCom:
|
||||||
@@ -47,7 +48,7 @@ class PluginMultiprocessManager:
|
|||||||
def terminate(self):
|
def terminate(self):
|
||||||
self.p.terminate()
|
self.p.terminate()
|
||||||
self.alive = False
|
self.alive = False
|
||||||
print("[debug] instance terminated")
|
logger.info("[debug] instance terminated")
|
||||||
|
|
||||||
def subprocess_worker(self, child_conn):
|
def subprocess_worker(self, child_conn):
|
||||||
# ⭐⭐ run in subprocess
|
# ⭐⭐ run in subprocess
|
||||||
|
|||||||
393
crazy_functions/agent_fns/python_comment_agent.py
Normal file
393
crazy_functions/agent_fns/python_comment_agent.py
Normal file
@@ -0,0 +1,393 @@
|
|||||||
|
import datetime
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
from loguru import logger
|
||||||
|
from textwrap import dedent
|
||||||
|
from toolbox import CatchException, update_ui
|
||||||
|
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
|
|
||||||
|
# TODO: 解决缩进问题
|
||||||
|
|
||||||
|
find_function_end_prompt = '''
|
||||||
|
Below is a page of code that you need to read. This page may not yet complete, you job is to split this page to sperate functions, class functions etc.
|
||||||
|
- Provide the line number where the first visible function ends.
|
||||||
|
- Provide the line number where the next visible function begins.
|
||||||
|
- If there are no other functions in this page, you should simply return the line number of the last line.
|
||||||
|
- Only focus on functions declared by `def` keyword. Ignore inline functions. Ignore function calls.
|
||||||
|
|
||||||
|
------------------ Example ------------------
|
||||||
|
INPUT:
|
||||||
|
|
||||||
|
```
|
||||||
|
L0000 |import sys
|
||||||
|
L0001 |import re
|
||||||
|
L0002 |
|
||||||
|
L0003 |def trimmed_format_exc():
|
||||||
|
L0004 | import os
|
||||||
|
L0005 | import traceback
|
||||||
|
L0006 | str = traceback.format_exc()
|
||||||
|
L0007 | current_path = os.getcwd()
|
||||||
|
L0008 | replace_path = "."
|
||||||
|
L0009 | return str.replace(current_path, replace_path)
|
||||||
|
L0010 |
|
||||||
|
L0011 |
|
||||||
|
L0012 |def trimmed_format_exc_markdown():
|
||||||
|
L0013 | ...
|
||||||
|
L0014 | ...
|
||||||
|
```
|
||||||
|
|
||||||
|
OUTPUT:
|
||||||
|
|
||||||
|
```
|
||||||
|
<first_function_end_at>L0009</first_function_end_at>
|
||||||
|
<next_function_begin_from>L0012</next_function_begin_from>
|
||||||
|
```
|
||||||
|
|
||||||
|
------------------ End of Example ------------------
|
||||||
|
|
||||||
|
|
||||||
|
------------------ the real INPUT you need to process NOW ------------------
|
||||||
|
```
|
||||||
|
{THE_TAGGED_CODE}
|
||||||
|
```
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
revise_funtion_prompt = '''
|
||||||
|
You need to read the following code, and revise the source code ({FILE_BASENAME}) according to following instructions:
|
||||||
|
1. You should analyze the purpose of the functions (if there are any).
|
||||||
|
2. You need to add docstring for the provided functions (if there are any).
|
||||||
|
|
||||||
|
Be aware:
|
||||||
|
1. You must NOT modify the indent of code.
|
||||||
|
2. You are NOT authorized to change or translate non-comment code, and you are NOT authorized to add empty lines either, toggle qu.
|
||||||
|
3. Use {LANG} to add comments and docstrings. Do NOT translate Chinese that is already in the code.
|
||||||
|
|
||||||
|
------------------ Example ------------------
|
||||||
|
INPUT:
|
||||||
|
```
|
||||||
|
L0000 |
|
||||||
|
L0001 |def zip_result(folder):
|
||||||
|
L0002 | t = gen_time_str()
|
||||||
|
L0003 | zip_folder(folder, get_log_folder(), f"result.zip")
|
||||||
|
L0004 | return os.path.join(get_log_folder(), f"result.zip")
|
||||||
|
L0005 |
|
||||||
|
L0006 |
|
||||||
|
```
|
||||||
|
|
||||||
|
OUTPUT:
|
||||||
|
|
||||||
|
<instruction_1_purpose>
|
||||||
|
This function compresses a given folder, and return the path of the resulting `zip` file.
|
||||||
|
</instruction_1_purpose>
|
||||||
|
<instruction_2_revised_code>
|
||||||
|
```
|
||||||
|
def zip_result(folder):
|
||||||
|
"""
|
||||||
|
Compresses the specified folder into a zip file and stores it in the log folder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder (str): The path to the folder that needs to be compressed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The path to the created zip file in the log folder.
|
||||||
|
"""
|
||||||
|
t = gen_time_str()
|
||||||
|
zip_folder(folder, get_log_folder(), f"result.zip") # ⭐ Execute the zipping of folder
|
||||||
|
return os.path.join(get_log_folder(), f"result.zip")
|
||||||
|
```
|
||||||
|
</instruction_2_revised_code>
|
||||||
|
------------------ End of Example ------------------
|
||||||
|
|
||||||
|
|
||||||
|
------------------ the real INPUT you need to process NOW ({FILE_BASENAME}) ------------------
|
||||||
|
```
|
||||||
|
{THE_CODE}
|
||||||
|
```
|
||||||
|
{INDENT_REMINDER}
|
||||||
|
{BRIEF_REMINDER}
|
||||||
|
{HINT_REMINDER}
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class PythonCodeComment():
|
||||||
|
|
||||||
|
def __init__(self, llm_kwargs, language) -> None:
|
||||||
|
self.original_content = ""
|
||||||
|
self.full_context = []
|
||||||
|
self.full_context_with_line_no = []
|
||||||
|
self.current_page_start = 0
|
||||||
|
self.page_limit = 100 # 100 lines of code each page
|
||||||
|
self.ignore_limit = 20
|
||||||
|
self.llm_kwargs = llm_kwargs
|
||||||
|
self.language = language
|
||||||
|
self.path = None
|
||||||
|
self.file_basename = None
|
||||||
|
self.file_brief = ""
|
||||||
|
|
||||||
|
def generate_tagged_code_from_full_context(self):
|
||||||
|
for i, code in enumerate(self.full_context):
|
||||||
|
number = i
|
||||||
|
padded_number = f"{number:04}"
|
||||||
|
result = f"L{padded_number}"
|
||||||
|
self.full_context_with_line_no.append(f"{result} | {code}")
|
||||||
|
return self.full_context_with_line_no
|
||||||
|
|
||||||
|
def read_file(self, path, brief):
|
||||||
|
with open(path, 'r', encoding='utf8') as f:
|
||||||
|
self.full_context = f.readlines()
|
||||||
|
self.original_content = ''.join(self.full_context)
|
||||||
|
self.file_basename = os.path.basename(path)
|
||||||
|
self.file_brief = brief
|
||||||
|
self.full_context_with_line_no = self.generate_tagged_code_from_full_context()
|
||||||
|
self.path = path
|
||||||
|
|
||||||
|
def find_next_function_begin(self, tagged_code:list, begin_and_end):
|
||||||
|
begin, end = begin_and_end
|
||||||
|
THE_TAGGED_CODE = ''.join(tagged_code)
|
||||||
|
self.llm_kwargs['temperature'] = 0
|
||||||
|
result = predict_no_ui_long_connection(
|
||||||
|
inputs=find_function_end_prompt.format(THE_TAGGED_CODE=THE_TAGGED_CODE),
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
history=[],
|
||||||
|
sys_prompt="",
|
||||||
|
observe_window=[],
|
||||||
|
console_slience=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def extract_number(text):
|
||||||
|
# 使用正则表达式匹配模式
|
||||||
|
match = re.search(r'<next_function_begin_from>L(\d+)</next_function_begin_from>', text)
|
||||||
|
if match:
|
||||||
|
# 提取匹配的数字部分并转换为整数
|
||||||
|
return int(match.group(1))
|
||||||
|
return None
|
||||||
|
|
||||||
|
line_no = extract_number(result)
|
||||||
|
if line_no is not None:
|
||||||
|
return line_no
|
||||||
|
else:
|
||||||
|
return end
|
||||||
|
|
||||||
|
def _get_next_window(self):
|
||||||
|
#
|
||||||
|
current_page_start = self.current_page_start
|
||||||
|
|
||||||
|
if self.current_page_start == len(self.full_context) + 1:
|
||||||
|
raise StopIteration
|
||||||
|
|
||||||
|
# 如果剩余的行数非常少,一鼓作气处理掉
|
||||||
|
if len(self.full_context) - self.current_page_start < self.ignore_limit:
|
||||||
|
future_page_start = len(self.full_context) + 1
|
||||||
|
self.current_page_start = future_page_start
|
||||||
|
return current_page_start, future_page_start
|
||||||
|
|
||||||
|
|
||||||
|
tagged_code = self.full_context_with_line_no[ self.current_page_start: self.current_page_start + self.page_limit]
|
||||||
|
line_no = self.find_next_function_begin(tagged_code, [self.current_page_start, self.current_page_start + self.page_limit])
|
||||||
|
|
||||||
|
if line_no > len(self.full_context) - 5:
|
||||||
|
line_no = len(self.full_context) + 1
|
||||||
|
|
||||||
|
future_page_start = line_no
|
||||||
|
self.current_page_start = future_page_start
|
||||||
|
|
||||||
|
# ! consider eof
|
||||||
|
return current_page_start, future_page_start
|
||||||
|
|
||||||
|
def dedent(self, text):
|
||||||
|
"""Remove any common leading whitespace from every line in `text`.
|
||||||
|
"""
|
||||||
|
# Look for the longest leading string of spaces and tabs common to
|
||||||
|
# all lines.
|
||||||
|
margin = None
|
||||||
|
_whitespace_only_re = re.compile('^[ \t]+$', re.MULTILINE)
|
||||||
|
_leading_whitespace_re = re.compile('(^[ \t]*)(?:[^ \t\n])', re.MULTILINE)
|
||||||
|
text = _whitespace_only_re.sub('', text)
|
||||||
|
indents = _leading_whitespace_re.findall(text)
|
||||||
|
for indent in indents:
|
||||||
|
if margin is None:
|
||||||
|
margin = indent
|
||||||
|
|
||||||
|
# Current line more deeply indented than previous winner:
|
||||||
|
# no change (previous winner is still on top).
|
||||||
|
elif indent.startswith(margin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Current line consistent with and no deeper than previous winner:
|
||||||
|
# it's the new winner.
|
||||||
|
elif margin.startswith(indent):
|
||||||
|
margin = indent
|
||||||
|
|
||||||
|
# Find the largest common whitespace between current line and previous
|
||||||
|
# winner.
|
||||||
|
else:
|
||||||
|
for i, (x, y) in enumerate(zip(margin, indent)):
|
||||||
|
if x != y:
|
||||||
|
margin = margin[:i]
|
||||||
|
break
|
||||||
|
|
||||||
|
# sanity check (testing/debugging only)
|
||||||
|
if 0 and margin:
|
||||||
|
for line in text.split("\n"):
|
||||||
|
assert not line or line.startswith(margin), \
|
||||||
|
"line = %r, margin = %r" % (line, margin)
|
||||||
|
|
||||||
|
if margin:
|
||||||
|
text = re.sub(r'(?m)^' + margin, '', text)
|
||||||
|
return text, len(margin)
|
||||||
|
else:
|
||||||
|
return text, 0
|
||||||
|
|
||||||
|
def get_next_batch(self):
|
||||||
|
current_page_start, future_page_start = self._get_next_window()
|
||||||
|
return ''.join(self.full_context[current_page_start: future_page_start]), current_page_start, future_page_start
|
||||||
|
|
||||||
|
def tag_code(self, fn, hint):
|
||||||
|
code = fn
|
||||||
|
_, n_indent = self.dedent(code)
|
||||||
|
indent_reminder = "" if n_indent == 0 else "(Reminder: as you can see, this piece of code has indent made up with {n_indent} whitespace, please preseve them in the OUTPUT.)"
|
||||||
|
brief_reminder = "" if self.file_brief == "" else f"({self.file_basename} abstract: {self.file_brief})"
|
||||||
|
hint_reminder = "" if hint is None else f"(Reminder: do not ignore or modify code such as `{hint}`, provide complete code in the OUTPUT.)"
|
||||||
|
self.llm_kwargs['temperature'] = 0
|
||||||
|
result = predict_no_ui_long_connection(
|
||||||
|
inputs=revise_funtion_prompt.format(
|
||||||
|
LANG=self.language,
|
||||||
|
FILE_BASENAME=self.file_basename,
|
||||||
|
THE_CODE=code,
|
||||||
|
INDENT_REMINDER=indent_reminder,
|
||||||
|
BRIEF_REMINDER=brief_reminder,
|
||||||
|
HINT_REMINDER=hint_reminder
|
||||||
|
),
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
history=[],
|
||||||
|
sys_prompt="",
|
||||||
|
observe_window=[],
|
||||||
|
console_slience=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_code_block(reply):
|
||||||
|
import re
|
||||||
|
pattern = r"```([\s\S]*?)```" # regex pattern to match code blocks
|
||||||
|
matches = re.findall(pattern, reply) # find all code blocks in text
|
||||||
|
if len(matches) == 1:
|
||||||
|
return matches[0].strip('python') # code block
|
||||||
|
return None
|
||||||
|
|
||||||
|
code_block = get_code_block(result)
|
||||||
|
if code_block is not None:
|
||||||
|
code_block = self.sync_and_patch(original=code, revised=code_block)
|
||||||
|
return code_block
|
||||||
|
else:
|
||||||
|
return code
|
||||||
|
|
||||||
|
def get_markdown_block_in_html(self, html):
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
soup = BeautifulSoup(html, 'lxml')
|
||||||
|
found_list = soup.find_all("div", class_="markdown-body")
|
||||||
|
if found_list:
|
||||||
|
res = found_list[0]
|
||||||
|
return res.prettify()
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def sync_and_patch(self, original, revised):
|
||||||
|
"""Ensure the number of pre-string empty lines in revised matches those in original."""
|
||||||
|
|
||||||
|
def count_leading_empty_lines(s, reverse=False):
|
||||||
|
"""Count the number of leading empty lines in a string."""
|
||||||
|
lines = s.split('\n')
|
||||||
|
if reverse: lines = list(reversed(lines))
|
||||||
|
count = 0
|
||||||
|
for line in lines:
|
||||||
|
if line.strip() == '':
|
||||||
|
count += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return count
|
||||||
|
|
||||||
|
original_empty_lines = count_leading_empty_lines(original)
|
||||||
|
revised_empty_lines = count_leading_empty_lines(revised)
|
||||||
|
|
||||||
|
if original_empty_lines > revised_empty_lines:
|
||||||
|
additional_lines = '\n' * (original_empty_lines - revised_empty_lines)
|
||||||
|
revised = additional_lines + revised
|
||||||
|
elif original_empty_lines < revised_empty_lines:
|
||||||
|
lines = revised.split('\n')
|
||||||
|
revised = '\n'.join(lines[revised_empty_lines - original_empty_lines:])
|
||||||
|
|
||||||
|
original_empty_lines = count_leading_empty_lines(original, reverse=True)
|
||||||
|
revised_empty_lines = count_leading_empty_lines(revised, reverse=True)
|
||||||
|
|
||||||
|
if original_empty_lines > revised_empty_lines:
|
||||||
|
additional_lines = '\n' * (original_empty_lines - revised_empty_lines)
|
||||||
|
revised = revised + additional_lines
|
||||||
|
elif original_empty_lines < revised_empty_lines:
|
||||||
|
lines = revised.split('\n')
|
||||||
|
revised = '\n'.join(lines[:-(revised_empty_lines - original_empty_lines)])
|
||||||
|
|
||||||
|
return revised
|
||||||
|
|
||||||
|
def begin_comment_source_code(self, chatbot=None, history=None):
|
||||||
|
# from toolbox import update_ui_lastest_msg
|
||||||
|
assert self.path is not None
|
||||||
|
assert '.py' in self.path # must be python source code
|
||||||
|
# write_target = self.path + '.revised.py'
|
||||||
|
|
||||||
|
write_content = ""
|
||||||
|
# with open(self.path + '.revised.py', 'w+', encoding='utf8') as f:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# yield from update_ui_lastest_msg(f"({self.file_basename}) 正在读取下一段代码片段:\n", chatbot=chatbot, history=history, delay=0)
|
||||||
|
next_batch, line_no_start, line_no_end = self.get_next_batch()
|
||||||
|
# yield from update_ui_lastest_msg(f"({self.file_basename}) 处理代码片段:\n\n{next_batch}", chatbot=chatbot, history=history, delay=0)
|
||||||
|
|
||||||
|
hint = None
|
||||||
|
MAX_ATTEMPT = 2
|
||||||
|
for attempt in range(MAX_ATTEMPT):
|
||||||
|
result = self.tag_code(next_batch, hint)
|
||||||
|
try:
|
||||||
|
successful, hint = self.verify_successful(next_batch, result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error('ignored exception:\n' + str(e))
|
||||||
|
break
|
||||||
|
if successful:
|
||||||
|
break
|
||||||
|
if attempt == MAX_ATTEMPT - 1:
|
||||||
|
# cannot deal with this, give up
|
||||||
|
result = next_batch
|
||||||
|
break
|
||||||
|
|
||||||
|
# f.write(result)
|
||||||
|
write_content += result
|
||||||
|
except StopIteration:
|
||||||
|
next_batch, line_no_start, line_no_end = [], -1, -1
|
||||||
|
return None, write_content
|
||||||
|
|
||||||
|
def verify_successful(self, original, revised):
|
||||||
|
""" Determine whether the revised code contains every line that already exists
|
||||||
|
"""
|
||||||
|
from crazy_functions.ast_fns.comment_remove import remove_python_comments
|
||||||
|
original = remove_python_comments(original)
|
||||||
|
original_lines = original.split('\n')
|
||||||
|
revised_lines = revised.split('\n')
|
||||||
|
|
||||||
|
for l in original_lines:
|
||||||
|
l = l.strip()
|
||||||
|
if '\'' in l or '\"' in l: continue # ast sometimes toggle " to '
|
||||||
|
found = False
|
||||||
|
for lt in revised_lines:
|
||||||
|
if l in lt:
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
if not found:
|
||||||
|
return False, l
|
||||||
|
return True, None
|
||||||
45
crazy_functions/agent_fns/python_comment_compare.html
Normal file
45
crazy_functions/agent_fns/python_comment_compare.html
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<style>ADVANCED_CSS</style>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>源文件对比</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: Arial, sans-serif;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
height: 100vh;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
display: flex;
|
||||||
|
width: 95%;
|
||||||
|
height: -webkit-fill-available;
|
||||||
|
}
|
||||||
|
.code-container {
|
||||||
|
flex: 1;
|
||||||
|
margin: 0px;
|
||||||
|
padding: 0px;
|
||||||
|
border: 1px solid #ccc;
|
||||||
|
background-color: #f9f9f9;
|
||||||
|
overflow: auto;
|
||||||
|
}
|
||||||
|
pre {
|
||||||
|
white-space: pre-wrap;
|
||||||
|
word-wrap: break-word;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<div class="code-container">
|
||||||
|
REPLACE_CODE_FILE_LEFT
|
||||||
|
</div>
|
||||||
|
<div class="code-container">
|
||||||
|
REPLACE_CODE_FILE_RIGHT
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import threading, time
|
import threading, time
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
class WatchDog():
|
class WatchDog():
|
||||||
def __init__(self, timeout, bark_fn, interval=3, msg="") -> None:
|
def __init__(self, timeout, bark_fn, interval=3, msg="") -> None:
|
||||||
@@ -13,7 +14,7 @@ class WatchDog():
|
|||||||
while True:
|
while True:
|
||||||
if self.kill_dog: break
|
if self.kill_dog: break
|
||||||
if time.time() - self.last_feed > self.timeout:
|
if time.time() - self.last_feed > self.timeout:
|
||||||
if len(self.msg) > 0: print(self.msg)
|
if len(self.msg) > 0: logger.info(self.msg)
|
||||||
self.bark_fn()
|
self.bark_fn()
|
||||||
break
|
break
|
||||||
time.sleep(self.interval)
|
time.sleep(self.interval)
|
||||||
|
|||||||
46
crazy_functions/ast_fns/comment_remove.py
Normal file
46
crazy_functions/ast_fns/comment_remove.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import ast
|
||||||
|
|
||||||
|
class CommentRemover(ast.NodeTransformer):
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
# 移除函数的文档字符串
|
||||||
|
if (node.body and isinstance(node.body[0], ast.Expr) and
|
||||||
|
isinstance(node.body[0].value, ast.Str)):
|
||||||
|
node.body = node.body[1:]
|
||||||
|
self.generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_ClassDef(self, node):
|
||||||
|
# 移除类的文档字符串
|
||||||
|
if (node.body and isinstance(node.body[0], ast.Expr) and
|
||||||
|
isinstance(node.body[0].value, ast.Str)):
|
||||||
|
node.body = node.body[1:]
|
||||||
|
self.generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_Module(self, node):
|
||||||
|
# 移除模块的文档字符串
|
||||||
|
if (node.body and isinstance(node.body[0], ast.Expr) and
|
||||||
|
isinstance(node.body[0].value, ast.Str)):
|
||||||
|
node.body = node.body[1:]
|
||||||
|
self.generic_visit(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def remove_python_comments(source_code):
|
||||||
|
# 解析源代码为 AST
|
||||||
|
tree = ast.parse(source_code)
|
||||||
|
# 移除注释
|
||||||
|
transformer = CommentRemover()
|
||||||
|
tree = transformer.visit(tree)
|
||||||
|
# 将处理后的 AST 转换回源代码
|
||||||
|
return ast.unparse(tree)
|
||||||
|
|
||||||
|
# 示例使用
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with open("source.py", "r", encoding="utf-8") as f:
|
||||||
|
source_code = f.read()
|
||||||
|
|
||||||
|
cleaned_code = remove_python_comments(source_code)
|
||||||
|
|
||||||
|
with open("cleaned_source.py", "w", encoding="utf-8") as f:
|
||||||
|
f.write(cleaned_code)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from toolbox import CatchException, update_ui, promote_file_to_downloadzone
|
from toolbox import CatchException, update_ui, promote_file_to_downloadzone
|
||||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
import datetime, json
|
import datetime, json
|
||||||
|
|
||||||
def fetch_items(list_of_items, batch_size):
|
def fetch_items(list_of_items, batch_size):
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from toolbox import update_ui, get_conf, trimmed_format_exc, get_max_token, Singleton
|
|
||||||
from shared_utils.char_visual_effect import scolling_visual_effect
|
|
||||||
import threading
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import threading
|
||||||
|
from loguru import logger
|
||||||
|
from shared_utils.char_visual_effect import scolling_visual_effect
|
||||||
|
from toolbox import update_ui, get_conf, trimmed_format_exc, get_max_token, Singleton
|
||||||
|
|
||||||
def input_clipping(inputs, history, max_token_limit):
|
def input_clipping(inputs, history, max_token_limit, return_clip_flags=False):
|
||||||
"""
|
"""
|
||||||
当输入文本 + 历史文本超出最大限制时,采取措施丢弃一部分文本。
|
当输入文本 + 历史文本超出最大限制时,采取措施丢弃一部分文本。
|
||||||
输入:
|
输入:
|
||||||
@@ -20,17 +20,20 @@ def input_clipping(inputs, history, max_token_limit):
|
|||||||
enc = model_info["gpt-3.5-turbo"]['tokenizer']
|
enc = model_info["gpt-3.5-turbo"]['tokenizer']
|
||||||
def get_token_num(txt): return len(enc.encode(txt, disallowed_special=()))
|
def get_token_num(txt): return len(enc.encode(txt, disallowed_special=()))
|
||||||
|
|
||||||
|
|
||||||
mode = 'input-and-history'
|
mode = 'input-and-history'
|
||||||
# 当 输入部分的token占比 小于 全文的一半时,只裁剪历史
|
# 当 输入部分的token占比 小于 全文的一半时,只裁剪历史
|
||||||
input_token_num = get_token_num(inputs)
|
input_token_num = get_token_num(inputs)
|
||||||
|
original_input_len = len(inputs)
|
||||||
if input_token_num < max_token_limit//2:
|
if input_token_num < max_token_limit//2:
|
||||||
mode = 'only-history'
|
mode = 'only-history'
|
||||||
max_token_limit = max_token_limit - input_token_num
|
max_token_limit = max_token_limit - input_token_num
|
||||||
|
|
||||||
everything = [inputs] if mode == 'input-and-history' else ['']
|
everything = [inputs] if mode == 'input-and-history' else ['']
|
||||||
everything.extend(history)
|
everything.extend(history)
|
||||||
n_token = get_token_num('\n'.join(everything))
|
full_token_num = n_token = get_token_num('\n'.join(everything))
|
||||||
everything_token = [get_token_num(e) for e in everything]
|
everything_token = [get_token_num(e) for e in everything]
|
||||||
|
everything_token_num = sum(everything_token)
|
||||||
delta = max(everything_token) // 16 # 截断时的颗粒度
|
delta = max(everything_token) // 16 # 截断时的颗粒度
|
||||||
|
|
||||||
while n_token > max_token_limit:
|
while n_token > max_token_limit:
|
||||||
@@ -43,10 +46,24 @@ def input_clipping(inputs, history, max_token_limit):
|
|||||||
|
|
||||||
if mode == 'input-and-history':
|
if mode == 'input-and-history':
|
||||||
inputs = everything[0]
|
inputs = everything[0]
|
||||||
|
full_token_num = everything_token_num
|
||||||
else:
|
else:
|
||||||
pass
|
full_token_num = everything_token_num + input_token_num
|
||||||
|
|
||||||
history = everything[1:]
|
history = everything[1:]
|
||||||
return inputs, history
|
|
||||||
|
flags = {
|
||||||
|
"mode": mode,
|
||||||
|
"original_input_token_num": input_token_num,
|
||||||
|
"original_full_token_num": full_token_num,
|
||||||
|
"original_input_len": original_input_len,
|
||||||
|
"clipped_input_len": len(inputs),
|
||||||
|
}
|
||||||
|
|
||||||
|
if not return_clip_flags:
|
||||||
|
return inputs, history
|
||||||
|
else:
|
||||||
|
return inputs, history, flags
|
||||||
|
|
||||||
def request_gpt_model_in_new_thread_with_ui_alive(
|
def request_gpt_model_in_new_thread_with_ui_alive(
|
||||||
inputs, inputs_show_user, llm_kwargs,
|
inputs, inputs_show_user, llm_kwargs,
|
||||||
@@ -116,7 +133,7 @@ def request_gpt_model_in_new_thread_with_ui_alive(
|
|||||||
except:
|
except:
|
||||||
# 【第三种情况】:其他错误:重试几次
|
# 【第三种情况】:其他错误:重试几次
|
||||||
tb_str = '```\n' + trimmed_format_exc() + '```'
|
tb_str = '```\n' + trimmed_format_exc() + '```'
|
||||||
print(tb_str)
|
logger.error(tb_str)
|
||||||
mutable[0] += f"[Local Message] 警告,在执行过程中遭遇问题, Traceback:\n\n{tb_str}\n\n"
|
mutable[0] += f"[Local Message] 警告,在执行过程中遭遇问题, Traceback:\n\n{tb_str}\n\n"
|
||||||
if retry_op > 0:
|
if retry_op > 0:
|
||||||
retry_op -= 1
|
retry_op -= 1
|
||||||
@@ -266,7 +283,7 @@ def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
|||||||
# 【第三种情况】:其他错误
|
# 【第三种情况】:其他错误
|
||||||
if detect_timeout(): raise RuntimeError("检测到程序终止。")
|
if detect_timeout(): raise RuntimeError("检测到程序终止。")
|
||||||
tb_str = '```\n' + trimmed_format_exc() + '```'
|
tb_str = '```\n' + trimmed_format_exc() + '```'
|
||||||
print(tb_str)
|
logger.error(tb_str)
|
||||||
gpt_say += f"[Local Message] 警告,线程{index}在执行过程中遭遇问题, Traceback:\n\n{tb_str}\n\n"
|
gpt_say += f"[Local Message] 警告,线程{index}在执行过程中遭遇问题, Traceback:\n\n{tb_str}\n\n"
|
||||||
if len(mutable[index][0]) > 0: gpt_say += "此线程失败前收到的回答:\n\n" + mutable[index][0]
|
if len(mutable[index][0]) > 0: gpt_say += "此线程失败前收到的回答:\n\n" + mutable[index][0]
|
||||||
if retry_op > 0:
|
if retry_op > 0:
|
||||||
@@ -361,7 +378,7 @@ def read_and_clean_pdf_text(fp):
|
|||||||
import fitz, copy
|
import fitz, copy
|
||||||
import re
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from shared_utils.colorful import print亮黄, print亮绿
|
# from shared_utils.colorful import print亮黄, print亮绿
|
||||||
fc = 0 # Index 0 文本
|
fc = 0 # Index 0 文本
|
||||||
fs = 1 # Index 1 字体
|
fs = 1 # Index 1 字体
|
||||||
fb = 2 # Index 2 框框
|
fb = 2 # Index 2 框框
|
||||||
@@ -578,7 +595,7 @@ class nougat_interface():
|
|||||||
def nougat_with_timeout(self, command, cwd, timeout=3600):
|
def nougat_with_timeout(self, command, cwd, timeout=3600):
|
||||||
import subprocess
|
import subprocess
|
||||||
from toolbox import ProxyNetworkActivate
|
from toolbox import ProxyNetworkActivate
|
||||||
logging.info(f'正在执行命令 {command}')
|
logger.info(f'正在执行命令 {command}')
|
||||||
with ProxyNetworkActivate("Nougat_Download"):
|
with ProxyNetworkActivate("Nougat_Download"):
|
||||||
process = subprocess.Popen(command, shell=False, cwd=cwd, env=os.environ)
|
process = subprocess.Popen(command, shell=False, cwd=cwd, env=os.environ)
|
||||||
try:
|
try:
|
||||||
@@ -586,7 +603,7 @@ class nougat_interface():
|
|||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
process.kill()
|
process.kill()
|
||||||
stdout, stderr = process.communicate()
|
stdout, stderr = process.communicate()
|
||||||
print("Process timed out!")
|
logger.error("Process timed out!")
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
from textwrap import indent
|
from textwrap import indent
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
class FileNode:
|
class FileNode:
|
||||||
def __init__(self, name):
|
def __init__(self, name, build_manifest=False):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.children = []
|
self.children = []
|
||||||
self.is_leaf = False
|
self.is_leaf = False
|
||||||
@@ -10,6 +11,8 @@ class FileNode:
|
|||||||
self.parenting_ship = []
|
self.parenting_ship = []
|
||||||
self.comment = ""
|
self.comment = ""
|
||||||
self.comment_maxlen_show = 50
|
self.comment_maxlen_show = 50
|
||||||
|
self.build_manifest = build_manifest
|
||||||
|
self.manifest = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_linebreaks_at_spaces(string, interval=10):
|
def add_linebreaks_at_spaces(string, interval=10):
|
||||||
@@ -29,6 +32,7 @@ class FileNode:
|
|||||||
level = 1
|
level = 1
|
||||||
if directory_names == "":
|
if directory_names == "":
|
||||||
new_node = FileNode(file_name)
|
new_node = FileNode(file_name)
|
||||||
|
self.manifest[file_path] = new_node
|
||||||
current_node.children.append(new_node)
|
current_node.children.append(new_node)
|
||||||
new_node.is_leaf = True
|
new_node.is_leaf = True
|
||||||
new_node.comment = self.sanitize_comment(file_comment)
|
new_node.comment = self.sanitize_comment(file_comment)
|
||||||
@@ -50,13 +54,14 @@ class FileNode:
|
|||||||
new_node.level = level - 1
|
new_node.level = level - 1
|
||||||
current_node = new_node
|
current_node = new_node
|
||||||
term = FileNode(file_name)
|
term = FileNode(file_name)
|
||||||
|
self.manifest[file_path] = term
|
||||||
term.level = level
|
term.level = level
|
||||||
term.comment = self.sanitize_comment(file_comment)
|
term.comment = self.sanitize_comment(file_comment)
|
||||||
term.is_leaf = True
|
term.is_leaf = True
|
||||||
current_node.children.append(term)
|
current_node.children.append(term)
|
||||||
|
|
||||||
def print_files_recursively(self, level=0, code="R0"):
|
def print_files_recursively(self, level=0, code="R0"):
|
||||||
print(' '*level + self.name + ' ' + str(self.is_leaf) + ' ' + str(self.level))
|
logger.info(' '*level + self.name + ' ' + str(self.is_leaf) + ' ' + str(self.level))
|
||||||
for j, child in enumerate(self.children):
|
for j, child in enumerate(self.children):
|
||||||
child.print_files_recursively(level=level+1, code=code+str(j))
|
child.print_files_recursively(level=level+1, code=code+str(j))
|
||||||
self.parenting_ship.extend(child.parenting_ship)
|
self.parenting_ship.extend(child.parenting_ship)
|
||||||
@@ -119,4 +124,4 @@ if __name__ == "__main__":
|
|||||||
"用于加载和分割文件中的文本的通用文件加载器用于加载和分割文件中的文本的通用文件加载器用于加载和分割文件中的文本的通用文件加载器",
|
"用于加载和分割文件中的文本的通用文件加载器用于加载和分割文件中的文本的通用文件加载器用于加载和分割文件中的文本的通用文件加载器",
|
||||||
"包含了用于构建和管理向量数据库的函数和类包含了用于构建和管理向量数据库的函数和类包含了用于构建和管理向量数据库的函数和类",
|
"包含了用于构建和管理向量数据库的函数和类包含了用于构建和管理向量数据库的函数和类包含了用于构建和管理向量数据库的函数和类",
|
||||||
]
|
]
|
||||||
print(build_file_tree_mermaid_diagram(file_manifest, file_comments, "项目文件树"))
|
logger.info(build_file_tree_mermaid_diagram(file_manifest, file_comments, "项目文件树"))
|
||||||
@@ -92,7 +92,7 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
|
|||||||
|
|
||||||
def generate_story_image(self, story_paragraph):
|
def generate_story_image(self, story_paragraph):
|
||||||
try:
|
try:
|
||||||
from crazy_functions.图片生成 import gen_image
|
from crazy_functions.Image_Generate import gen_image
|
||||||
prompt_ = predict_no_ui_long_connection(inputs=story_paragraph, llm_kwargs=self.llm_kwargs, history=[], sys_prompt='你需要根据用户给出的小说段落,进行简短的环境描写。要求:80字以内。')
|
prompt_ = predict_no_ui_long_connection(inputs=story_paragraph, llm_kwargs=self.llm_kwargs, history=[], sys_prompt='你需要根据用户给出的小说段落,进行简短的环境描写。要求:80字以内。')
|
||||||
image_url, image_path = gen_image(self.llm_kwargs, prompt_, '512x512', model="dall-e-2", quality='standard', style='natural')
|
image_url, image_path = gen_image(self.llm_kwargs, prompt_, '512x512', model="dall-e-2", quality='standard', style='natural')
|
||||||
return f'<br/><div align="center"><img src="file={image_path}"></div>'
|
return f'<br/><div align="center"><img src="file={image_path}"></div>'
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ class Actor(BaseModel):
|
|||||||
film_names: List[str] = Field(description="list of names of films they starred in")
|
film_names: List[str] = Field(description="list of names of films they starred in")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json, re, logging
|
import json, re
|
||||||
|
from loguru import logger as logging
|
||||||
|
|
||||||
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||||
|
|
||||||
|
|||||||
26
crazy_functions/json_fns/select_tool.py
Normal file
26
crazy_functions/json_fns/select_tool.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from crazy_functions.json_fns.pydantic_io import GptJsonIO, JsonStringError
|
||||||
|
|
||||||
|
def structure_output(txt, prompt, err_msg, run_gpt_fn, pydantic_cls):
|
||||||
|
gpt_json_io = GptJsonIO(pydantic_cls)
|
||||||
|
analyze_res = run_gpt_fn(
|
||||||
|
txt,
|
||||||
|
sys_prompt=prompt + gpt_json_io.format_instructions
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
friend = gpt_json_io.generate_output_auto_repair(analyze_res, run_gpt_fn)
|
||||||
|
except JsonStringError as e:
|
||||||
|
return None, err_msg
|
||||||
|
|
||||||
|
err_msg = ""
|
||||||
|
return friend, err_msg
|
||||||
|
|
||||||
|
|
||||||
|
def select_tool(prompt, run_gpt_fn, pydantic_cls):
|
||||||
|
pydantic_cls_instance, err_msg = structure_output(
|
||||||
|
txt=prompt,
|
||||||
|
prompt="根据提示, 分析应该调用哪个工具函数\n\n",
|
||||||
|
err_msg=f"不能理解该联系人",
|
||||||
|
run_gpt_fn=run_gpt_fn,
|
||||||
|
pydantic_cls=pydantic_cls
|
||||||
|
)
|
||||||
|
return pydantic_cls_instance, err_msg
|
||||||
@@ -1,15 +1,17 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import numpy as np
|
||||||
|
from loguru import logger
|
||||||
from toolbox import update_ui, update_ui_lastest_msg, get_log_folder
|
from toolbox import update_ui, update_ui_lastest_msg, get_log_folder
|
||||||
from toolbox import get_conf, promote_file_to_downloadzone
|
from toolbox import get_conf, promote_file_to_downloadzone
|
||||||
from .latex_toolbox import PRESERVE, TRANSFORM
|
from crazy_functions.latex_fns.latex_toolbox import PRESERVE, TRANSFORM
|
||||||
from .latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
|
from crazy_functions.latex_fns.latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
|
||||||
from .latex_toolbox import reverse_forbidden_text_careful_brace, reverse_forbidden_text, convert_to_linklist, post_process
|
from crazy_functions.latex_fns.latex_toolbox import reverse_forbidden_text_careful_brace, reverse_forbidden_text, convert_to_linklist, post_process
|
||||||
from .latex_toolbox import fix_content, find_main_tex_file, merge_tex_files, compile_latex_with_timeout
|
from crazy_functions.latex_fns.latex_toolbox import fix_content, find_main_tex_file, merge_tex_files, compile_latex_with_timeout
|
||||||
from .latex_toolbox import find_title_and_abs
|
from crazy_functions.latex_fns.latex_toolbox import find_title_and_abs
|
||||||
from .latex_pickle_io import objdump, objload
|
from crazy_functions.latex_fns.latex_pickle_io import objdump, objload
|
||||||
|
|
||||||
import os, shutil
|
|
||||||
import re
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
pj = os.path.join
|
pj = os.path.join
|
||||||
|
|
||||||
@@ -323,7 +325,7 @@ def remove_buggy_lines(file_path, log_path, tex_name, tex_name_pure, n_fix, work
|
|||||||
buggy_lines = [int(l) for l in buggy_lines]
|
buggy_lines = [int(l) for l in buggy_lines]
|
||||||
buggy_lines = sorted(buggy_lines)
|
buggy_lines = sorted(buggy_lines)
|
||||||
buggy_line = buggy_lines[0]-1
|
buggy_line = buggy_lines[0]-1
|
||||||
print("reversing tex line that has errors", buggy_line)
|
logger.warning("reversing tex line that has errors", buggy_line)
|
||||||
|
|
||||||
# 重组,逆转出错的段落
|
# 重组,逆转出错的段落
|
||||||
if buggy_line not in fixed_line:
|
if buggy_line not in fixed_line:
|
||||||
@@ -337,7 +339,7 @@ def remove_buggy_lines(file_path, log_path, tex_name, tex_name_pure, n_fix, work
|
|||||||
|
|
||||||
return True, f"{tex_name_pure}_fix_{n_fix}", buggy_lines
|
return True, f"{tex_name_pure}_fix_{n_fix}", buggy_lines
|
||||||
except:
|
except:
|
||||||
print("Fatal error occurred, but we cannot identify error, please download zip, read latex log, and compile manually.")
|
logger.error("Fatal error occurred, but we cannot identify error, please download zip, read latex log, and compile manually.")
|
||||||
return False, -1, [-1]
|
return False, -1, [-1]
|
||||||
|
|
||||||
|
|
||||||
@@ -380,7 +382,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
|
|||||||
|
|
||||||
if mode!='translate_zh':
|
if mode!='translate_zh':
|
||||||
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 使用latexdiff生成论文转化前后对比 ...', chatbot, history) # 刷新Gradio前端界面
|
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 使用latexdiff生成论文转化前后对比 ...', chatbot, history) # 刷新Gradio前端界面
|
||||||
print( f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex')
|
logger.info( f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex')
|
||||||
ok = compile_latex_with_timeout(f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex', os.getcwd())
|
ok = compile_latex_with_timeout(f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex', os.getcwd())
|
||||||
|
|
||||||
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 正在编译对比PDF ...', chatbot, history) # 刷新Gradio前端界面
|
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 正在编译对比PDF ...', chatbot, history) # 刷新Gradio前端界面
|
||||||
@@ -419,7 +421,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
|
|||||||
shutil.copyfile(concat_pdf, pj(work_folder, '..', 'translation', 'comparison.pdf'))
|
shutil.copyfile(concat_pdf, pj(work_folder, '..', 'translation', 'comparison.pdf'))
|
||||||
promote_file_to_downloadzone(concat_pdf, rename_file=None, chatbot=chatbot) # promote file to web UI
|
promote_file_to_downloadzone(concat_pdf, rename_file=None, chatbot=chatbot) # promote file to web UI
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
logger.error(e)
|
||||||
pass
|
pass
|
||||||
return True # 成功啦
|
return True # 成功啦
|
||||||
else:
|
else:
|
||||||
@@ -465,4 +467,4 @@ def write_html(sp_file_contents, sp_file_result, chatbot, project_folder):
|
|||||||
promote_file_to_downloadzone(file=res, chatbot=chatbot)
|
promote_file_to_downloadzone(file=res, chatbot=chatbot)
|
||||||
except:
|
except:
|
||||||
from toolbox import trimmed_format_exc
|
from toolbox import trimmed_format_exc
|
||||||
print('writing html result failed:', trimmed_format_exc())
|
logger.error('writing html result failed:', trimmed_format_exc())
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import os, shutil
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
PRESERVE = 0
|
PRESERVE = 0
|
||||||
TRANSFORM = 1
|
TRANSFORM = 1
|
||||||
@@ -55,7 +57,7 @@ def post_process(root):
|
|||||||
str_stack.append("{")
|
str_stack.append("{")
|
||||||
elif c == "}":
|
elif c == "}":
|
||||||
if len(str_stack) == 1:
|
if len(str_stack) == 1:
|
||||||
print("stack fix")
|
logger.warning("fixing brace error")
|
||||||
return i
|
return i
|
||||||
str_stack.pop(-1)
|
str_stack.pop(-1)
|
||||||
else:
|
else:
|
||||||
@@ -601,7 +603,7 @@ def compile_latex_with_timeout(command, cwd, timeout=60):
|
|||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
process.kill()
|
process.kill()
|
||||||
stdout, stderr = process.communicate()
|
stdout, stderr = process.communicate()
|
||||||
print("Process timed out!")
|
logger.error("Process timed out (compile_latex_with_timeout)!")
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -642,6 +644,213 @@ def run_in_subprocess(func):
|
|||||||
|
|
||||||
|
|
||||||
def _merge_pdfs(pdf1_path, pdf2_path, output_path):
|
def _merge_pdfs(pdf1_path, pdf2_path, output_path):
|
||||||
|
try:
|
||||||
|
logger.info("Merging PDFs using _merge_pdfs_ng")
|
||||||
|
_merge_pdfs_ng(pdf1_path, pdf2_path, output_path)
|
||||||
|
except:
|
||||||
|
logger.info("Merging PDFs using _merge_pdfs_legacy")
|
||||||
|
_merge_pdfs_legacy(pdf1_path, pdf2_path, output_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_pdfs_ng(pdf1_path, pdf2_path, output_path):
|
||||||
|
import PyPDF2 # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
|
||||||
|
from PyPDF2.generic import NameObject, TextStringObject, ArrayObject, FloatObject, NumberObject
|
||||||
|
|
||||||
|
Percent = 1
|
||||||
|
# raise RuntimeError('PyPDF2 has a serious memory leak problem, please use other tools to merge PDF files.')
|
||||||
|
# Open the first PDF file
|
||||||
|
with open(pdf1_path, "rb") as pdf1_file:
|
||||||
|
pdf1_reader = PyPDF2.PdfFileReader(pdf1_file)
|
||||||
|
# Open the second PDF file
|
||||||
|
with open(pdf2_path, "rb") as pdf2_file:
|
||||||
|
pdf2_reader = PyPDF2.PdfFileReader(pdf2_file)
|
||||||
|
# Create a new PDF file to store the merged pages
|
||||||
|
output_writer = PyPDF2.PdfFileWriter()
|
||||||
|
# Determine the number of pages in each PDF file
|
||||||
|
num_pages = max(pdf1_reader.numPages, pdf2_reader.numPages)
|
||||||
|
# Merge the pages from the two PDF files
|
||||||
|
for page_num in range(num_pages):
|
||||||
|
# Add the page from the first PDF file
|
||||||
|
if page_num < pdf1_reader.numPages:
|
||||||
|
page1 = pdf1_reader.getPage(page_num)
|
||||||
|
else:
|
||||||
|
page1 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
|
||||||
|
# Add the page from the second PDF file
|
||||||
|
if page_num < pdf2_reader.numPages:
|
||||||
|
page2 = pdf2_reader.getPage(page_num)
|
||||||
|
else:
|
||||||
|
page2 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
|
||||||
|
# Create a new empty page with double width
|
||||||
|
new_page = PyPDF2.PageObject.createBlankPage(
|
||||||
|
width=int(
|
||||||
|
int(page1.mediaBox.getWidth())
|
||||||
|
+ int(page2.mediaBox.getWidth()) * Percent
|
||||||
|
),
|
||||||
|
height=max(page1.mediaBox.getHeight(), page2.mediaBox.getHeight()),
|
||||||
|
)
|
||||||
|
new_page.mergeTranslatedPage(page1, 0, 0)
|
||||||
|
new_page.mergeTranslatedPage(
|
||||||
|
page2,
|
||||||
|
int(
|
||||||
|
int(page1.mediaBox.getWidth())
|
||||||
|
- int(page2.mediaBox.getWidth()) * (1 - Percent)
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
if "/Annots" in new_page:
|
||||||
|
annotations = new_page["/Annots"]
|
||||||
|
for i, annot in enumerate(annotations):
|
||||||
|
annot_obj = annot.get_object()
|
||||||
|
|
||||||
|
# 检查注释类型是否是链接(/Link)
|
||||||
|
if annot_obj.get("/Subtype") == "/Link":
|
||||||
|
# 检查是否为内部链接跳转(/GoTo)或外部URI链接(/URI)
|
||||||
|
action = annot_obj.get("/A")
|
||||||
|
if action:
|
||||||
|
|
||||||
|
if "/S" in action and action["/S"] == "/GoTo":
|
||||||
|
# 内部链接:跳转到文档中的某个页面
|
||||||
|
dest = action.get("/D") # 目标页或目标位置
|
||||||
|
# if dest and annot.idnum in page2_annot_id:
|
||||||
|
if dest in pdf2_reader.named_destinations:
|
||||||
|
# 获取原始文件中跳转信息,包括跳转页面
|
||||||
|
destination = pdf2_reader.named_destinations[
|
||||||
|
dest
|
||||||
|
]
|
||||||
|
page_number = (
|
||||||
|
pdf2_reader.get_destination_page_number(
|
||||||
|
destination
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
|
||||||
|
# “/D”:[10,'/XYZ',100,100,0]
|
||||||
|
if destination.dest_array[1] == "/XYZ":
|
||||||
|
annot_obj["/A"].update(
|
||||||
|
{
|
||||||
|
NameObject("/D"): ArrayObject(
|
||||||
|
[
|
||||||
|
NumberObject(page_number),
|
||||||
|
destination.dest_array[1],
|
||||||
|
FloatObject(
|
||||||
|
destination.dest_array[
|
||||||
|
2
|
||||||
|
]
|
||||||
|
+ int(
|
||||||
|
page1.mediaBox.getWidth()
|
||||||
|
)
|
||||||
|
),
|
||||||
|
destination.dest_array[3],
|
||||||
|
destination.dest_array[4],
|
||||||
|
]
|
||||||
|
) # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
annot_obj["/A"].update(
|
||||||
|
{
|
||||||
|
NameObject("/D"): ArrayObject(
|
||||||
|
[
|
||||||
|
NumberObject(page_number),
|
||||||
|
destination.dest_array[1],
|
||||||
|
]
|
||||||
|
) # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
rect = annot_obj.get("/Rect")
|
||||||
|
# 更新点击坐标
|
||||||
|
rect = ArrayObject(
|
||||||
|
[
|
||||||
|
FloatObject(
|
||||||
|
rect[0]
|
||||||
|
+ int(page1.mediaBox.getWidth())
|
||||||
|
),
|
||||||
|
rect[1],
|
||||||
|
FloatObject(
|
||||||
|
rect[2]
|
||||||
|
+ int(page1.mediaBox.getWidth())
|
||||||
|
),
|
||||||
|
rect[3],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
annot_obj.update(
|
||||||
|
{
|
||||||
|
NameObject(
|
||||||
|
"/Rect"
|
||||||
|
): rect # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# if dest and annot.idnum in page1_annot_id:
|
||||||
|
if dest in pdf1_reader.named_destinations:
|
||||||
|
|
||||||
|
# 获取原始文件中跳转信息,包括跳转页面
|
||||||
|
destination = pdf1_reader.named_destinations[
|
||||||
|
dest
|
||||||
|
]
|
||||||
|
page_number = (
|
||||||
|
pdf1_reader.get_destination_page_number(
|
||||||
|
destination
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
|
||||||
|
# “/D”:[10,'/XYZ',100,100,0]
|
||||||
|
if destination.dest_array[1] == "/XYZ":
|
||||||
|
annot_obj["/A"].update(
|
||||||
|
{
|
||||||
|
NameObject("/D"): ArrayObject(
|
||||||
|
[
|
||||||
|
NumberObject(page_number),
|
||||||
|
destination.dest_array[1],
|
||||||
|
FloatObject(
|
||||||
|
destination.dest_array[
|
||||||
|
2
|
||||||
|
]
|
||||||
|
),
|
||||||
|
destination.dest_array[3],
|
||||||
|
destination.dest_array[4],
|
||||||
|
]
|
||||||
|
) # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
annot_obj["/A"].update(
|
||||||
|
{
|
||||||
|
NameObject("/D"): ArrayObject(
|
||||||
|
[
|
||||||
|
NumberObject(page_number),
|
||||||
|
destination.dest_array[1],
|
||||||
|
]
|
||||||
|
) # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
rect = annot_obj.get("/Rect")
|
||||||
|
rect = ArrayObject(
|
||||||
|
[
|
||||||
|
FloatObject(rect[0]),
|
||||||
|
rect[1],
|
||||||
|
FloatObject(rect[2]),
|
||||||
|
rect[3],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
annot_obj.update(
|
||||||
|
{
|
||||||
|
NameObject(
|
||||||
|
"/Rect"
|
||||||
|
): rect # 确保键和值是 PdfObject
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
elif "/S" in action and action["/S"] == "/URI":
|
||||||
|
# 外部链接:跳转到某个URI
|
||||||
|
uri = action.get("/URI")
|
||||||
|
output_writer.addPage(new_page)
|
||||||
|
# Save the merged PDF file
|
||||||
|
with open(output_path, "wb") as output_file:
|
||||||
|
output_writer.write(output_file)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_pdfs_legacy(pdf1_path, pdf2_path, output_path):
|
||||||
import PyPDF2 # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
|
import PyPDF2 # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
|
||||||
|
|
||||||
Percent = 0.95
|
Percent = 0.95
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import time, logging, json, sys, struct
|
import time, json, sys, struct
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from loguru import logger as logging
|
||||||
from scipy.io.wavfile import WAVE_FORMAT
|
from scipy.io.wavfile import WAVE_FORMAT
|
||||||
|
|
||||||
def write_numpy_to_wave(filename, rate, data, add_header=False):
|
def write_numpy_to_wave(filename, rate, data, add_header=False):
|
||||||
@@ -106,18 +107,14 @@ def is_speaker_speaking(vad, data, sample_rate):
|
|||||||
class AliyunASR():
|
class AliyunASR():
|
||||||
|
|
||||||
def test_on_sentence_begin(self, message, *args):
|
def test_on_sentence_begin(self, message, *args):
|
||||||
# print("test_on_sentence_begin:{}".format(message))
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_on_sentence_end(self, message, *args):
|
def test_on_sentence_end(self, message, *args):
|
||||||
# print("test_on_sentence_end:{}".format(message))
|
|
||||||
message = json.loads(message)
|
message = json.loads(message)
|
||||||
self.parsed_sentence = message['payload']['result']
|
self.parsed_sentence = message['payload']['result']
|
||||||
self.event_on_entence_end.set()
|
self.event_on_entence_end.set()
|
||||||
# print(self.parsed_sentence)
|
|
||||||
|
|
||||||
def test_on_start(self, message, *args):
|
def test_on_start(self, message, *args):
|
||||||
# print("test_on_start:{}".format(message))
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_on_error(self, message, *args):
|
def test_on_error(self, message, *args):
|
||||||
@@ -129,13 +126,11 @@ class AliyunASR():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_on_result_chg(self, message, *args):
|
def test_on_result_chg(self, message, *args):
|
||||||
# print("test_on_chg:{}".format(message))
|
|
||||||
message = json.loads(message)
|
message = json.loads(message)
|
||||||
self.parsed_text = message['payload']['result']
|
self.parsed_text = message['payload']['result']
|
||||||
self.event_on_result_chg.set()
|
self.event_on_result_chg.set()
|
||||||
|
|
||||||
def test_on_completed(self, message, *args):
|
def test_on_completed(self, message, *args):
|
||||||
# print("on_completed:args=>{} message=>{}".format(args, message))
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def audio_convertion_thread(self, uuid):
|
def audio_convertion_thread(self, uuid):
|
||||||
@@ -248,14 +243,14 @@ class AliyunASR():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response = client.do_action_with_exception(request)
|
response = client.do_action_with_exception(request)
|
||||||
print(response)
|
logging.info(response)
|
||||||
jss = json.loads(response)
|
jss = json.loads(response)
|
||||||
if 'Token' in jss and 'Id' in jss['Token']:
|
if 'Token' in jss and 'Id' in jss['Token']:
|
||||||
token = jss['Token']['Id']
|
token = jss['Token']['Id']
|
||||||
expireTime = jss['Token']['ExpireTime']
|
expireTime = jss['Token']['ExpireTime']
|
||||||
print("token = " + token)
|
logging.info("token = " + token)
|
||||||
print("expireTime = " + str(expireTime))
|
logging.info("expireTime = " + str(expireTime))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
logging.error(e)
|
||||||
|
|
||||||
return token
|
return token
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from crazy_functions.ipc_fns.mp import run_in_subprocess_with_timeout
|
from crazy_functions.ipc_fns.mp import run_in_subprocess_with_timeout
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
def force_breakdown(txt, limit, get_token_fn):
|
def force_breakdown(txt, limit, get_token_fn):
|
||||||
""" 当无法用标点、空行分割时,我们用最暴力的方法切割
|
""" 当无法用标点、空行分割时,我们用最暴力的方法切割
|
||||||
@@ -76,7 +77,7 @@ def cut(limit, get_token_fn, txt_tocut, must_break_at_empty_line, break_anyway=F
|
|||||||
remain_txt_to_cut = post
|
remain_txt_to_cut = post
|
||||||
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
||||||
process = fin_len/total_len
|
process = fin_len/total_len
|
||||||
print(f'正在文本切分 {int(process*100)}%')
|
logger.info(f'正在文本切分 {int(process*100)}%')
|
||||||
if len(remain_txt_to_cut.strip()) == 0:
|
if len(remain_txt_to_cut.strip()) == 0:
|
||||||
break
|
break
|
||||||
return res
|
return res
|
||||||
@@ -119,7 +120,7 @@ if __name__ == '__main__':
|
|||||||
for i in range(5):
|
for i in range(5):
|
||||||
file_content += file_content
|
file_content += file_content
|
||||||
|
|
||||||
print(len(file_content))
|
logger.info(len(file_content))
|
||||||
TOKEN_LIMIT_PER_FRAGMENT = 2500
|
TOKEN_LIMIT_PER_FRAGMENT = 2500
|
||||||
res = breakdown_text_to_satisfy_token_limit(file_content, TOKEN_LIMIT_PER_FRAGMENT)
|
res = breakdown_text_to_satisfy_token_limit(file_content, TOKEN_LIMIT_PER_FRAGMENT)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_
|
|||||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
||||||
from shared_utils.colorful import *
|
from shared_utils.colorful import *
|
||||||
|
from loguru import logger
|
||||||
import os
|
import os
|
||||||
|
|
||||||
def 解析PDF_简单拆解(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
def 解析PDF_简单拆解(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||||
@@ -93,7 +94,7 @@ def 解析PDF_简单拆解(file_manifest, project_folder, llm_kwargs, plugin_kwa
|
|||||||
generated_html_files.append(ch.save_file(create_report_file_name))
|
generated_html_files.append(ch.save_file(create_report_file_name))
|
||||||
except:
|
except:
|
||||||
from toolbox import trimmed_format_exc
|
from toolbox import trimmed_format_exc
|
||||||
print('writing html result failed:', trimmed_format_exc())
|
logger.error('writing html result failed:', trimmed_format_exc())
|
||||||
|
|
||||||
# 准备文件的下载
|
# 准备文件的下载
|
||||||
for pdf_path in generated_conclusion_files:
|
for pdf_path in generated_conclusion_files:
|
||||||
|
|||||||
87
crazy_functions/prompts/internet.py
Normal file
87
crazy_functions/prompts/internet.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
SearchOptimizerPrompt="""作为一个网页搜索助手,你的任务是结合历史记录,从不同角度,为“原问题”生成个不同版本的“检索词”,从而提高网页检索的精度。生成的问题要求指向对象清晰明确,并与“原问题语言相同”。例如:
|
||||||
|
历史记录:
|
||||||
|
"
|
||||||
|
Q: 对话背景。
|
||||||
|
A: 当前对话是关于 Nginx 的介绍和在Ubuntu上的使用等。
|
||||||
|
"
|
||||||
|
原问题: 怎么下载
|
||||||
|
检索词: ["Nginx 下载","Ubuntu Nginx","Ubuntu安装Nginx"]
|
||||||
|
----------------
|
||||||
|
历史记录:
|
||||||
|
"
|
||||||
|
Q: 对话背景。
|
||||||
|
A: 当前对话是关于 Nginx 的介绍和使用等。
|
||||||
|
Q: 报错 "no connection"
|
||||||
|
A: 报错"no connection"可能是因为……
|
||||||
|
"
|
||||||
|
原问题: 怎么解决
|
||||||
|
检索词: ["Nginx报错"no connection" 解决","Nginx'no connection'报错 原因","Nginx提示'no connection'"]
|
||||||
|
----------------
|
||||||
|
历史记录:
|
||||||
|
"
|
||||||
|
|
||||||
|
"
|
||||||
|
原问题: 你知道 Python 么?
|
||||||
|
检索词: ["Python","Python 使用教程。","Python 特点和优势"]
|
||||||
|
----------------
|
||||||
|
历史记录:
|
||||||
|
"
|
||||||
|
Q: 列出Java的三种特点?
|
||||||
|
A: 1. Java 是一种编译型语言。
|
||||||
|
2. Java 是一种面向对象的编程语言。
|
||||||
|
3. Java 是一种跨平台的编程语言。
|
||||||
|
"
|
||||||
|
原问题: 介绍下第2点。
|
||||||
|
检索词: ["Java 面向对象特点","Java 面向对象编程优势。","Java 面向对象编程"]
|
||||||
|
----------------
|
||||||
|
现在有历史记录:
|
||||||
|
"
|
||||||
|
{history}
|
||||||
|
"
|
||||||
|
有其原问题: {query}
|
||||||
|
直接给出最多{num}个检索词,必须以json形式给出,不得有多余字符:
|
||||||
|
"""
|
||||||
|
|
||||||
|
SearchAcademicOptimizerPrompt="""作为一个学术论文搜索助手,你的任务是结合历史记录,从不同角度,为“原问题”生成个不同版本的“检索词”,从而提高学术论文检索的精度。生成的问题要求指向对象清晰明确,并与“原问题语言相同”。例如:
|
||||||
|
历史记录:
|
||||||
|
"
|
||||||
|
Q: 对话背景。
|
||||||
|
A: 当前对话是关于深度学习的介绍和在图像识别中的应用等。
|
||||||
|
"
|
||||||
|
原问题: 怎么下载相关论文
|
||||||
|
检索词: ["深度学习 图像识别 论文下载","图像识别 深度学习 研究论文","深度学习 图像识别 论文资源","Deep Learning Image Recognition Paper Download","Image Recognition Deep Learning Research Paper"]
|
||||||
|
----------------
|
||||||
|
历史记录:
|
||||||
|
"
|
||||||
|
Q: 对话背景。
|
||||||
|
A: 当前对话是关于深度学习的介绍和应用等。
|
||||||
|
Q: 报错 "模型不收敛"
|
||||||
|
A: 报错"模型不收敛"可能是因为……
|
||||||
|
"
|
||||||
|
原问题: 怎么解决
|
||||||
|
检索词: ["深度学习 模型不收敛 解决方案 论文","深度学习 模型不收敛 原因 研究","深度学习 模型不收敛 论文","Deep Learning Model Convergence Issue Solution Paper","Deep Learning Model Convergence Problem Research"]
|
||||||
|
----------------
|
||||||
|
历史记录:
|
||||||
|
"
|
||||||
|
|
||||||
|
"
|
||||||
|
原问题: 你知道 GAN 么?
|
||||||
|
检索词: ["生成对抗网络 论文","GAN 使用教程 论文","GAN 特点和优势 研究","Generative Adversarial Network Paper","GAN Usage Tutorial Paper"]
|
||||||
|
----------------
|
||||||
|
历史记录:
|
||||||
|
"
|
||||||
|
Q: 列出机器学习的三种应用?
|
||||||
|
A: 1. 机器学习在图像识别中的应用。
|
||||||
|
2. 机器学习在自然语言处理中的应用。
|
||||||
|
3. 机器学习在推荐系统中的应用。
|
||||||
|
"
|
||||||
|
原问题: 介绍下第2点。
|
||||||
|
检索词: ["机器学习 自然语言处理 应用 论文","机器学习 自然语言处理 研究","机器学习 NLP 应用 论文","Machine Learning Natural Language Processing Application Paper","Machine Learning NLP Research"]
|
||||||
|
----------------
|
||||||
|
现在有历史记录:
|
||||||
|
"
|
||||||
|
{history}
|
||||||
|
"
|
||||||
|
有其原问题: {query}
|
||||||
|
直接给出最多{num}个检索词,必须以json形式给出,不得有多余字符:
|
||||||
|
"""
|
||||||
130
crazy_functions/rag_fns/llama_index_worker.py
Normal file
130
crazy_functions/rag_fns/llama_index_worker.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
import llama_index
|
||||||
|
import os
|
||||||
|
import atexit
|
||||||
|
from loguru import logger
|
||||||
|
from typing import List
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.schema import TextNode
|
||||||
|
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
||||||
|
from shared_utils.connect_void_terminal import get_chat_default_kwargs
|
||||||
|
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
|
||||||
|
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
|
||||||
|
from llama_index.core.ingestion import run_transformations
|
||||||
|
from llama_index.core import PromptTemplate
|
||||||
|
from llama_index.core.response_synthesizers import TreeSummarize
|
||||||
|
|
||||||
|
DEFAULT_QUERY_GENERATION_PROMPT = """\
|
||||||
|
Now, you have context information as below:
|
||||||
|
---------------------
|
||||||
|
{context_str}
|
||||||
|
---------------------
|
||||||
|
Answer the user request below (use the context information if necessary, otherwise you can ignore them):
|
||||||
|
---------------------
|
||||||
|
{query_str}
|
||||||
|
"""
|
||||||
|
|
||||||
|
QUESTION_ANSWER_RECORD = """\
|
||||||
|
{{
|
||||||
|
"type": "This is a previous conversation with the user",
|
||||||
|
"question": "{question}",
|
||||||
|
"answer": "{answer}",
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class SaveLoad():
|
||||||
|
|
||||||
|
def does_checkpoint_exist(self, checkpoint_dir=None):
|
||||||
|
import os, glob
|
||||||
|
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||||
|
if not os.path.exists(checkpoint_dir): return False
|
||||||
|
if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def save_to_checkpoint(self, checkpoint_dir=None):
|
||||||
|
logger.info(f'saving vector store to: {checkpoint_dir}')
|
||||||
|
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||||
|
self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
|
||||||
|
|
||||||
|
def load_from_checkpoint(self, checkpoint_dir=None):
|
||||||
|
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||||
|
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
|
||||||
|
logger.info('loading checkpoint from disk')
|
||||||
|
from llama_index.core import StorageContext, load_index_from_storage
|
||||||
|
storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
|
||||||
|
self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
|
||||||
|
return self.vs_index
|
||||||
|
else:
|
||||||
|
return self.create_new_vs()
|
||||||
|
|
||||||
|
def create_new_vs(self):
|
||||||
|
return GptacVectorStoreIndex.default_vector_store(embed_model=self.embed_model)
|
||||||
|
|
||||||
|
def purge(self):
|
||||||
|
import shutil
|
||||||
|
shutil.rmtree(self.checkpoint_dir, ignore_errors=True)
|
||||||
|
self.vs_index = self.create_new_vs()
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaIndexRagWorker(SaveLoad):
|
||||||
|
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
|
||||||
|
self.debug_mode = True
|
||||||
|
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
|
||||||
|
self.user_name = user_name
|
||||||
|
self.checkpoint_dir = checkpoint_dir
|
||||||
|
if auto_load_checkpoint:
|
||||||
|
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
|
||||||
|
else:
|
||||||
|
self.vs_index = self.create_new_vs(checkpoint_dir)
|
||||||
|
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
|
||||||
|
|
||||||
|
def assign_embedding_model(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def inspect_vector_store(self):
|
||||||
|
# This function is for debugging
|
||||||
|
self.vs_index.storage_context.index_store.to_dict()
|
||||||
|
docstore = self.vs_index.storage_context.docstore.docs
|
||||||
|
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
|
||||||
|
logger.info('\n++ --------inspect_vector_store begin--------')
|
||||||
|
logger.info(vector_store_preview)
|
||||||
|
logger.info('oo --------inspect_vector_store end--------')
|
||||||
|
return vector_store_preview
|
||||||
|
|
||||||
|
def add_documents_to_vector_store(self, document_list):
|
||||||
|
documents = [Document(text=t) for t in document_list]
|
||||||
|
documents_nodes = run_transformations(
|
||||||
|
documents, # type: ignore
|
||||||
|
self.vs_index._transformations,
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
self.vs_index.insert_nodes(documents_nodes)
|
||||||
|
if self.debug_mode: self.inspect_vector_store()
|
||||||
|
|
||||||
|
def add_text_to_vector_store(self, text):
|
||||||
|
node = TextNode(text=text)
|
||||||
|
documents_nodes = run_transformations(
|
||||||
|
[node],
|
||||||
|
self.vs_index._transformations,
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
self.vs_index.insert_nodes(documents_nodes)
|
||||||
|
if self.debug_mode: self.inspect_vector_store()
|
||||||
|
|
||||||
|
def remember_qa(self, question, answer):
|
||||||
|
formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
|
||||||
|
self.add_text_to_vector_store(formatted_str)
|
||||||
|
|
||||||
|
def retrieve_from_store_with_query(self, query):
|
||||||
|
if self.debug_mode: self.inspect_vector_store()
|
||||||
|
retriever = self.vs_index.as_retriever()
|
||||||
|
return retriever.retrieve(query)
|
||||||
|
|
||||||
|
def build_prompt(self, query, nodes):
|
||||||
|
context_str = self.generate_node_array_preview(nodes)
|
||||||
|
return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
|
||||||
|
|
||||||
|
def generate_node_array_preview(self, nodes):
|
||||||
|
buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
|
||||||
|
if self.debug_mode: logger.info(buf)
|
||||||
|
return buf
|
||||||
108
crazy_functions/rag_fns/milvus_worker.py
Normal file
108
crazy_functions/rag_fns/milvus_worker.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
import llama_index
|
||||||
|
import os
|
||||||
|
import atexit
|
||||||
|
from typing import List
|
||||||
|
from loguru import logger
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.schema import TextNode
|
||||||
|
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
||||||
|
from shared_utils.connect_void_terminal import get_chat_default_kwargs
|
||||||
|
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
|
||||||
|
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
|
||||||
|
from llama_index.core.ingestion import run_transformations
|
||||||
|
from llama_index.core import PromptTemplate
|
||||||
|
from llama_index.core.response_synthesizers import TreeSummarize
|
||||||
|
from llama_index.core import StorageContext
|
||||||
|
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||||
|
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
|
||||||
|
|
||||||
|
DEFAULT_QUERY_GENERATION_PROMPT = """\
|
||||||
|
Now, you have context information as below:
|
||||||
|
---------------------
|
||||||
|
{context_str}
|
||||||
|
---------------------
|
||||||
|
Answer the user request below (use the context information if necessary, otherwise you can ignore them):
|
||||||
|
---------------------
|
||||||
|
{query_str}
|
||||||
|
"""
|
||||||
|
|
||||||
|
QUESTION_ANSWER_RECORD = """\
|
||||||
|
{{
|
||||||
|
"type": "This is a previous conversation with the user",
|
||||||
|
"question": "{question}",
|
||||||
|
"answer": "{answer}",
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusSaveLoad():
|
||||||
|
|
||||||
|
def does_checkpoint_exist(self, checkpoint_dir=None):
|
||||||
|
import os, glob
|
||||||
|
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||||
|
if not os.path.exists(checkpoint_dir): return False
|
||||||
|
if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def save_to_checkpoint(self, checkpoint_dir=None):
|
||||||
|
logger.info(f'saving vector store to: {checkpoint_dir}')
|
||||||
|
# if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||||
|
# self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
|
||||||
|
|
||||||
|
def load_from_checkpoint(self, checkpoint_dir=None):
|
||||||
|
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
|
||||||
|
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
|
||||||
|
logger.info('loading checkpoint from disk')
|
||||||
|
from llama_index.core import StorageContext, load_index_from_storage
|
||||||
|
storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
|
||||||
|
try:
|
||||||
|
self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
|
||||||
|
return self.vs_index
|
||||||
|
except:
|
||||||
|
return self.create_new_vs(checkpoint_dir)
|
||||||
|
else:
|
||||||
|
return self.create_new_vs(checkpoint_dir)
|
||||||
|
|
||||||
|
def create_new_vs(self, checkpoint_dir, overwrite=False):
|
||||||
|
vector_store = MilvusVectorStore(
|
||||||
|
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
|
||||||
|
dim=self.embed_model.embedding_dimension(),
|
||||||
|
overwrite=overwrite
|
||||||
|
)
|
||||||
|
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||||
|
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
|
||||||
|
return index
|
||||||
|
|
||||||
|
def purge(self):
|
||||||
|
self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
|
||||||
|
|
||||||
|
class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
|
||||||
|
|
||||||
|
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
|
||||||
|
self.debug_mode = True
|
||||||
|
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
|
||||||
|
self.user_name = user_name
|
||||||
|
self.checkpoint_dir = checkpoint_dir
|
||||||
|
if auto_load_checkpoint:
|
||||||
|
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
|
||||||
|
else:
|
||||||
|
self.vs_index = self.create_new_vs(checkpoint_dir)
|
||||||
|
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
|
||||||
|
|
||||||
|
def inspect_vector_store(self):
|
||||||
|
# This function is for debugging
|
||||||
|
try:
|
||||||
|
self.vs_index.storage_context.index_store.to_dict()
|
||||||
|
docstore = self.vs_index.storage_context.docstore.docs
|
||||||
|
if not docstore.items():
|
||||||
|
raise ValueError("cannot inspect")
|
||||||
|
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
|
||||||
|
except:
|
||||||
|
dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
|
||||||
|
vector_store_preview = "\n".join(
|
||||||
|
[f"{node.id_} | {node.text}" for node in dummy_retrieve_res]
|
||||||
|
)
|
||||||
|
logger.info('\n++ --------inspect_vector_store begin--------')
|
||||||
|
logger.info(vector_store_preview)
|
||||||
|
logger.info('oo --------inspect_vector_store end--------')
|
||||||
|
return vector_store_preview
|
||||||
58
crazy_functions/rag_fns/vector_store_index.py
Normal file
58
crazy_functions/rag_fns/vector_store_index.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from llama_index.core import VectorStoreIndex
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from llama_index.core.callbacks.base import CallbackManager
|
||||||
|
from llama_index.core.schema import TransformComponent
|
||||||
|
from llama_index.core.service_context import ServiceContext
|
||||||
|
from llama_index.core.settings import (
|
||||||
|
Settings,
|
||||||
|
callback_manager_from_settings_or_context,
|
||||||
|
transformations_from_settings_or_context,
|
||||||
|
)
|
||||||
|
from llama_index.core.storage.storage_context import StorageContext
|
||||||
|
|
||||||
|
|
||||||
|
class GptacVectorStoreIndex(VectorStoreIndex):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default_vector_store(
|
||||||
|
cls,
|
||||||
|
storage_context: Optional[StorageContext] = None,
|
||||||
|
show_progress: bool = False,
|
||||||
|
callback_manager: Optional[CallbackManager] = None,
|
||||||
|
transformations: Optional[List[TransformComponent]] = None,
|
||||||
|
# deprecated
|
||||||
|
service_context: Optional[ServiceContext] = None,
|
||||||
|
embed_model = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""Create index from documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents (Optional[Sequence[BaseDocument]]): List of documents to
|
||||||
|
build the index from.
|
||||||
|
|
||||||
|
"""
|
||||||
|
storage_context = storage_context or StorageContext.from_defaults()
|
||||||
|
docstore = storage_context.docstore
|
||||||
|
callback_manager = (
|
||||||
|
callback_manager
|
||||||
|
or callback_manager_from_settings_or_context(Settings, service_context)
|
||||||
|
)
|
||||||
|
transformations = transformations or transformations_from_settings_or_context(
|
||||||
|
Settings, service_context
|
||||||
|
)
|
||||||
|
|
||||||
|
with callback_manager.as_trace("index_construction"):
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
nodes=[],
|
||||||
|
storage_context=storage_context,
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
show_progress=show_progress,
|
||||||
|
transformations=transformations,
|
||||||
|
service_context=service_context,
|
||||||
|
embed_model=embed_model,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@@ -1,16 +1,17 @@
|
|||||||
# From project chatglm-langchain
|
# From project chatglm-langchain
|
||||||
|
|
||||||
import threading
|
|
||||||
from toolbox import Singleton
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
import tqdm
|
import tqdm
|
||||||
|
import shutil
|
||||||
|
import threading
|
||||||
|
import numpy as np
|
||||||
|
from toolbox import Singleton
|
||||||
|
from loguru import logger
|
||||||
from langchain.vectorstores import FAISS
|
from langchain.vectorstores import FAISS
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
import numpy as np
|
|
||||||
from crazy_functions.vector_fns.general_file_loader import load_file
|
from crazy_functions.vector_fns.general_file_loader import load_file
|
||||||
|
|
||||||
embedding_model_dict = {
|
embedding_model_dict = {
|
||||||
@@ -150,17 +151,17 @@ class LocalDocQA:
|
|||||||
failed_files = []
|
failed_files = []
|
||||||
if isinstance(filepath, str):
|
if isinstance(filepath, str):
|
||||||
if not os.path.exists(filepath):
|
if not os.path.exists(filepath):
|
||||||
print("路径不存在")
|
logger.error("路径不存在")
|
||||||
return None
|
return None
|
||||||
elif os.path.isfile(filepath):
|
elif os.path.isfile(filepath):
|
||||||
file = os.path.split(filepath)[-1]
|
file = os.path.split(filepath)[-1]
|
||||||
try:
|
try:
|
||||||
docs = load_file(filepath, SENTENCE_SIZE)
|
docs = load_file(filepath, SENTENCE_SIZE)
|
||||||
print(f"{file} 已成功加载")
|
logger.info(f"{file} 已成功加载")
|
||||||
loaded_files.append(filepath)
|
loaded_files.append(filepath)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
logger.error(e)
|
||||||
print(f"{file} 未能成功加载")
|
logger.error(f"{file} 未能成功加载")
|
||||||
return None
|
return None
|
||||||
elif os.path.isdir(filepath):
|
elif os.path.isdir(filepath):
|
||||||
docs = []
|
docs = []
|
||||||
@@ -170,23 +171,23 @@ class LocalDocQA:
|
|||||||
docs += load_file(fullfilepath, SENTENCE_SIZE)
|
docs += load_file(fullfilepath, SENTENCE_SIZE)
|
||||||
loaded_files.append(fullfilepath)
|
loaded_files.append(fullfilepath)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
logger.error(e)
|
||||||
failed_files.append(file)
|
failed_files.append(file)
|
||||||
|
|
||||||
if len(failed_files) > 0:
|
if len(failed_files) > 0:
|
||||||
print("以下文件未能成功加载:")
|
logger.error("以下文件未能成功加载:")
|
||||||
for file in failed_files:
|
for file in failed_files:
|
||||||
print(f"{file}\n")
|
logger.error(f"{file}\n")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
docs = []
|
docs = []
|
||||||
for file in filepath:
|
for file in filepath:
|
||||||
docs += load_file(file, SENTENCE_SIZE)
|
docs += load_file(file, SENTENCE_SIZE)
|
||||||
print(f"{file} 已成功加载")
|
logger.info(f"{file} 已成功加载")
|
||||||
loaded_files.append(file)
|
loaded_files.append(file)
|
||||||
|
|
||||||
if len(docs) > 0:
|
if len(docs) > 0:
|
||||||
print("文件加载完毕,正在生成向量库")
|
logger.info("文件加载完毕,正在生成向量库")
|
||||||
if vs_path and os.path.isdir(vs_path):
|
if vs_path and os.path.isdir(vs_path):
|
||||||
try:
|
try:
|
||||||
self.vector_store = FAISS.load_local(vs_path, text2vec)
|
self.vector_store = FAISS.load_local(vs_path, text2vec)
|
||||||
@@ -233,7 +234,7 @@ class LocalDocQA:
|
|||||||
prompt += "\n\n".join([f"({k}): " + doc.page_content for k, doc in enumerate(related_docs_with_score)])
|
prompt += "\n\n".join([f"({k}): " + doc.page_content for k, doc in enumerate(related_docs_with_score)])
|
||||||
prompt += "\n\n---\n\n"
|
prompt += "\n\n---\n\n"
|
||||||
prompt = prompt.encode('utf-8', 'ignore').decode() # avoid reading non-utf8 chars
|
prompt = prompt.encode('utf-8', 'ignore').decode() # avoid reading non-utf8 chars
|
||||||
# print(prompt)
|
# logger.info(prompt)
|
||||||
response = {"query": query, "source_documents": related_docs_with_score}
|
response = {"query": query, "source_documents": related_docs_with_score}
|
||||||
return response, prompt
|
return response, prompt
|
||||||
|
|
||||||
@@ -262,7 +263,7 @@ def construct_vector_store(vs_id, vs_path, files, sentence_size, history, one_co
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
# file_status = "文件未成功加载,请重新上传文件"
|
# file_status = "文件未成功加载,请重新上传文件"
|
||||||
# print(file_status)
|
# logger.info(file_status)
|
||||||
return local_doc_qa, vs_path
|
return local_doc_qa, vs_path
|
||||||
|
|
||||||
@Singleton
|
@Singleton
|
||||||
@@ -278,7 +279,7 @@ class knowledge_archive_interface():
|
|||||||
if self.text2vec_large_chinese is None:
|
if self.text2vec_large_chinese is None:
|
||||||
# < -------------------预热文本向量化模组--------------- >
|
# < -------------------预热文本向量化模组--------------- >
|
||||||
from toolbox import ProxyNetworkActivate
|
from toolbox import ProxyNetworkActivate
|
||||||
print('Checking Text2vec ...')
|
logger.info('Checking Text2vec ...')
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
|
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
|
||||||
self.text2vec_large_chinese = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
|
self.text2vec_large_chinese = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
|
import re, requests, unicodedata, os
|
||||||
from toolbox import update_ui, get_log_folder
|
from toolbox import update_ui, get_log_folder
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||||
from toolbox import CatchException, report_exception, get_conf
|
from toolbox import CatchException, report_exception, get_conf
|
||||||
import re, requests, unicodedata, os
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from loguru import logger
|
||||||
|
|
||||||
def download_arxiv_(url_pdf):
|
def download_arxiv_(url_pdf):
|
||||||
if 'arxiv.org' not in url_pdf:
|
if 'arxiv.org' not in url_pdf:
|
||||||
if ('.' in url_pdf) and ('/' not in url_pdf):
|
if ('.' in url_pdf) and ('/' not in url_pdf):
|
||||||
new_url = 'https://arxiv.org/abs/'+url_pdf
|
new_url = 'https://arxiv.org/abs/'+url_pdf
|
||||||
print('下载编号:', url_pdf, '自动定位:', new_url)
|
logger.info('下载编号:', url_pdf, '自动定位:', new_url)
|
||||||
# download_arxiv_(new_url)
|
# download_arxiv_(new_url)
|
||||||
return download_arxiv_(new_url)
|
return download_arxiv_(new_url)
|
||||||
else:
|
else:
|
||||||
print('不能识别的URL!')
|
logger.info('不能识别的URL!')
|
||||||
return None
|
return None
|
||||||
if 'abs' in url_pdf:
|
if 'abs' in url_pdf:
|
||||||
url_pdf = url_pdf.replace('abs', 'pdf')
|
url_pdf = url_pdf.replace('abs', 'pdf')
|
||||||
@@ -42,15 +44,12 @@ def download_arxiv_(url_pdf):
|
|||||||
requests_pdf_url = url_pdf
|
requests_pdf_url = url_pdf
|
||||||
file_path = download_dir+title_str
|
file_path = download_dir+title_str
|
||||||
|
|
||||||
print('下载中')
|
logger.info('下载中')
|
||||||
proxies = get_conf('proxies')
|
proxies = get_conf('proxies')
|
||||||
r = requests.get(requests_pdf_url, proxies=proxies)
|
r = requests.get(requests_pdf_url, proxies=proxies)
|
||||||
with open(file_path, 'wb+') as f:
|
with open(file_path, 'wb+') as f:
|
||||||
f.write(r.content)
|
f.write(r.content)
|
||||||
print('下载完成')
|
logger.info('下载完成')
|
||||||
|
|
||||||
# print('输出下载命令:','aria2c -o \"%s\" %s'%(title_str,url_pdf))
|
|
||||||
# subprocess.call('aria2c --all-proxy=\"172.18.116.150:11084\" -o \"%s\" %s'%(download_dir+title_str,url_pdf), shell=True)
|
|
||||||
|
|
||||||
x = "%s %s %s.bib" % (paper_id, other_info['year'], other_info['authors'])
|
x = "%s %s %s.bib" % (paper_id, other_info['year'], other_info['authors'])
|
||||||
x = x.replace('?', '?')\
|
x = x.replace('?', '?')\
|
||||||
@@ -63,19 +62,9 @@ def download_arxiv_(url_pdf):
|
|||||||
|
|
||||||
|
|
||||||
def get_name(_url_):
|
def get_name(_url_):
|
||||||
import os
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
print('正在获取文献名!')
|
logger.info('正在获取文献名!')
|
||||||
print(_url_)
|
logger.info(_url_)
|
||||||
|
|
||||||
# arxiv_recall = {}
|
|
||||||
# if os.path.exists('./arxiv_recall.pkl'):
|
|
||||||
# with open('./arxiv_recall.pkl', 'rb') as f:
|
|
||||||
# arxiv_recall = pickle.load(f)
|
|
||||||
|
|
||||||
# if _url_ in arxiv_recall:
|
|
||||||
# print('在缓存中')
|
|
||||||
# return arxiv_recall[_url_]
|
|
||||||
|
|
||||||
proxies = get_conf('proxies')
|
proxies = get_conf('proxies')
|
||||||
res = requests.get(_url_, proxies=proxies)
|
res = requests.get(_url_, proxies=proxies)
|
||||||
@@ -92,7 +81,7 @@ def get_name(_url_):
|
|||||||
other_details['abstract'] = abstract
|
other_details['abstract'] = abstract
|
||||||
except:
|
except:
|
||||||
other_details['year'] = ''
|
other_details['year'] = ''
|
||||||
print('年份获取失败')
|
logger.info('年份获取失败')
|
||||||
|
|
||||||
# get author
|
# get author
|
||||||
try:
|
try:
|
||||||
@@ -101,7 +90,7 @@ def get_name(_url_):
|
|||||||
other_details['authors'] = authors
|
other_details['authors'] = authors
|
||||||
except:
|
except:
|
||||||
other_details['authors'] = ''
|
other_details['authors'] = ''
|
||||||
print('authors获取失败')
|
logger.info('authors获取失败')
|
||||||
|
|
||||||
# get comment
|
# get comment
|
||||||
try:
|
try:
|
||||||
@@ -116,11 +105,11 @@ def get_name(_url_):
|
|||||||
other_details['comment'] = ''
|
other_details['comment'] = ''
|
||||||
except:
|
except:
|
||||||
other_details['comment'] = ''
|
other_details['comment'] = ''
|
||||||
print('年份获取失败')
|
logger.info('年份获取失败')
|
||||||
|
|
||||||
title_str = BeautifulSoup(
|
title_str = BeautifulSoup(
|
||||||
res.text, 'html.parser').find('title').contents[0]
|
res.text, 'html.parser').find('title').contents[0]
|
||||||
print('获取成功:', title_str)
|
logger.info('获取成功:', title_str)
|
||||||
# arxiv_recall[_url_] = (title_str+'.pdf', other_details)
|
# arxiv_recall[_url_] = (title_str+'.pdf', other_details)
|
||||||
# with open('./arxiv_recall.pkl', 'wb') as f:
|
# with open('./arxiv_recall.pkl', 'wb') as f:
|
||||||
# pickle.dump(arxiv_recall, f)
|
# pickle.dump(arxiv_recall, f)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from toolbox import CatchException, update_ui
|
from toolbox import CatchException, update_ui
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
|
|
||||||
|
|
||||||
@CatchException
|
@CatchException
|
||||||
def 交互功能模板函数(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
def 交互功能模板函数(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ Testing:
|
|||||||
|
|
||||||
from toolbox import CatchException, update_ui, gen_time_str, trimmed_format_exc, is_the_upload_folder
|
from toolbox import CatchException, update_ui, gen_time_str, trimmed_format_exc, is_the_upload_folder
|
||||||
from toolbox import promote_file_to_downloadzone, get_log_folder, update_ui_lastest_msg
|
from toolbox import promote_file_to_downloadzone, get_log_folder, update_ui_lastest_msg
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_plugin_arg
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_plugin_arg
|
||||||
from .crazy_utils import input_clipping, try_install_deps
|
from crazy_functions.crazy_utils import input_clipping, try_install_deps
|
||||||
from crazy_functions.gen_fns.gen_fns_shared import is_function_successfully_generated
|
from crazy_functions.gen_fns.gen_fns_shared import is_function_successfully_generated
|
||||||
from crazy_functions.gen_fns.gen_fns_shared import get_class_name
|
from crazy_functions.gen_fns.gen_fns_shared import get_class_name
|
||||||
from crazy_functions.gen_fns.gen_fns_shared import subprocess_worker
|
from crazy_functions.gen_fns.gen_fns_shared import subprocess_worker
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from toolbox import CatchException, update_ui, gen_time_str
|
from toolbox import CatchException, update_ui, gen_time_str
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
from .crazy_utils import input_clipping
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
import copy, json
|
import copy, json
|
||||||
|
|
||||||
@CatchException
|
@CatchException
|
||||||
|
|||||||
@@ -6,13 +6,14 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import time
|
||||||
from toolbox import CatchException, update_ui, gen_time_str, trimmed_format_exc, ProxyNetworkActivate
|
from toolbox import CatchException, update_ui, gen_time_str, trimmed_format_exc, ProxyNetworkActivate
|
||||||
from toolbox import get_conf, select_api_key, update_ui_lastest_msg, Singleton
|
from toolbox import get_conf, select_api_key, update_ui_lastest_msg, Singleton
|
||||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_plugin_arg
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_plugin_arg
|
||||||
from crazy_functions.crazy_utils import input_clipping, try_install_deps
|
from crazy_functions.crazy_utils import input_clipping, try_install_deps
|
||||||
from crazy_functions.agent_fns.persistent import GradioMultiuserManagerForPersistentClasses
|
from crazy_functions.agent_fns.persistent import GradioMultiuserManagerForPersistentClasses
|
||||||
from crazy_functions.agent_fns.auto_agent import AutoGenMath
|
from crazy_functions.agent_fns.auto_agent import AutoGenMath
|
||||||
import time
|
from loguru import logger
|
||||||
|
|
||||||
def remove_model_prefix(llm):
|
def remove_model_prefix(llm):
|
||||||
if llm.startswith('api2d-'): llm = llm.replace('api2d-', '')
|
if llm.startswith('api2d-'): llm = llm.replace('api2d-', '')
|
||||||
@@ -80,12 +81,12 @@ def 多智能体终端(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_
|
|||||||
persistent_key = f"{user_uuid}->多智能体终端"
|
persistent_key = f"{user_uuid}->多智能体终端"
|
||||||
if persistent_class_multi_user_manager.already_alive(persistent_key):
|
if persistent_class_multi_user_manager.already_alive(persistent_key):
|
||||||
# 当已经存在一个正在运行的多智能体终端时,直接将用户输入传递给它,而不是再次启动一个新的多智能体终端
|
# 当已经存在一个正在运行的多智能体终端时,直接将用户输入传递给它,而不是再次启动一个新的多智能体终端
|
||||||
print('[debug] feed new user input')
|
logger.info('[debug] feed new user input')
|
||||||
executor = persistent_class_multi_user_manager.get(persistent_key)
|
executor = persistent_class_multi_user_manager.get(persistent_key)
|
||||||
exit_reason = yield from executor.main_process_ui_control(txt, create_or_resume="resume")
|
exit_reason = yield from executor.main_process_ui_control(txt, create_or_resume="resume")
|
||||||
else:
|
else:
|
||||||
# 运行多智能体终端 (首次)
|
# 运行多智能体终端 (首次)
|
||||||
print('[debug] create new executor instance')
|
logger.info('[debug] create new executor instance')
|
||||||
history = []
|
history = []
|
||||||
chatbot.append(["正在启动: 多智能体终端", "插件动态生成, 执行开始, 作者 Microsoft & Binary-Husky."])
|
chatbot.append(["正在启动: 多智能体终端", "插件动态生成, 执行开始, 作者 Microsoft & Binary-Husky."])
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from toolbox import update_ui
|
from toolbox import update_ui
|
||||||
from toolbox import CatchException, report_exception
|
from toolbox import CatchException, report_exception
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
fast_debug = False
|
fast_debug = False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from toolbox import CatchException, report_exception, select_api_key, update_ui, get_conf
|
from toolbox import CatchException, report_exception, select_api_key, update_ui, get_conf
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone, get_log_folder
|
from toolbox import write_history_to_file, promote_file_to_downloadzone, get_log_folder
|
||||||
|
|
||||||
def split_audio_file(filename, split_duration=1000):
|
def split_audio_file(filename, split_duration=1000):
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
|
from loguru import logger
|
||||||
|
|
||||||
from toolbox import update_ui, promote_file_to_downloadzone, gen_time_str
|
from toolbox import update_ui, promote_file_to_downloadzone, gen_time_str
|
||||||
from toolbox import CatchException, report_exception
|
from toolbox import CatchException, report_exception
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
from .crazy_utils import read_and_clean_pdf_text
|
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
||||||
from .crazy_utils import input_clipping
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def 解析PDF(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
def 解析PDF(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||||
file_write_buffer = []
|
file_write_buffer = []
|
||||||
for file_name in file_manifest:
|
for file_name in file_manifest:
|
||||||
print('begin analysis on:', file_name)
|
logger.info('begin analysis on:', file_name)
|
||||||
############################## <第 0 步,切割PDF> ##################################
|
############################## <第 0 步,切割PDF> ##################################
|
||||||
# 递归地切割PDF文件,每一块(尽量是完整的一个section,比如introduction,experiment等,必要时再进行切割)
|
# 递归地切割PDF文件,每一块(尽量是完整的一个section,比如introduction,experiment等,必要时再进行切割)
|
||||||
# 的长度必须小于 2500 个 Token
|
# 的长度必须小于 2500 个 Token
|
||||||
@@ -38,7 +40,7 @@ def 解析PDF(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot,
|
|||||||
last_iteration_result = paper_meta # 初始值是摘要
|
last_iteration_result = paper_meta # 初始值是摘要
|
||||||
MAX_WORD_TOTAL = 4096 * 0.7
|
MAX_WORD_TOTAL = 4096 * 0.7
|
||||||
n_fragment = len(paper_fragments)
|
n_fragment = len(paper_fragments)
|
||||||
if n_fragment >= 20: print('文章极长,不能达到预期效果')
|
if n_fragment >= 20: logger.warning('文章极长,不能达到预期效果')
|
||||||
for i in range(n_fragment):
|
for i in range(n_fragment):
|
||||||
NUM_OF_WORD = MAX_WORD_TOTAL // n_fragment
|
NUM_OF_WORD = MAX_WORD_TOTAL // n_fragment
|
||||||
i_say = f"Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} Chinese characters: {paper_fragments[i]}"
|
i_say = f"Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} Chinese characters: {paper_fragments[i]}"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
|
from loguru import logger
|
||||||
from toolbox import update_ui
|
from toolbox import update_ui
|
||||||
from toolbox import CatchException, report_exception
|
from toolbox import CatchException, report_exception
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||||
|
|
||||||
fast_debug = False
|
fast_debug = False
|
||||||
@@ -57,7 +58,6 @@ def readPdf(pdfPath):
|
|||||||
layout = device.get_result()
|
layout = device.get_result()
|
||||||
for obj in layout._objs:
|
for obj in layout._objs:
|
||||||
if isinstance(obj, pdfminer.layout.LTTextBoxHorizontal):
|
if isinstance(obj, pdfminer.layout.LTTextBoxHorizontal):
|
||||||
# print(obj.get_text())
|
|
||||||
outTextList.append(obj.get_text())
|
outTextList.append(obj.get_text())
|
||||||
|
|
||||||
return outTextList
|
return outTextList
|
||||||
@@ -66,7 +66,7 @@ def readPdf(pdfPath):
|
|||||||
def 解析Paper(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
def 解析Paper(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||||
import time, glob, os
|
import time, glob, os
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
print('begin analysis on:', file_manifest)
|
logger.info('begin analysis on:', file_manifest)
|
||||||
for index, fp in enumerate(file_manifest):
|
for index, fp in enumerate(file_manifest):
|
||||||
if ".tex" in fp:
|
if ".tex" in fp:
|
||||||
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
||||||
@@ -77,7 +77,7 @@ def 解析Paper(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbo
|
|||||||
|
|
||||||
prefix = "接下来请你逐文件分析下面的论文文件,概括其内容" if index==0 else ""
|
prefix = "接下来请你逐文件分析下面的论文文件,概括其内容" if index==0 else ""
|
||||||
i_say = prefix + f'请对下面的文章片段用中文做一个概述,文件名是{os.path.relpath(fp, project_folder)},文章内容是 ```{file_content}```'
|
i_say = prefix + f'请对下面的文章片段用中文做一个概述,文件名是{os.path.relpath(fp, project_folder)},文章内容是 ```{file_content}```'
|
||||||
i_say_show_user = prefix + f'[{index}/{len(file_manifest)}] 请对下面的文章片段做一个概述: {os.path.abspath(fp)}'
|
i_say_show_user = prefix + f'[{index+1}/{len(file_manifest)}] 请对下面的文章片段做一个概述: {os.path.abspath(fp)}'
|
||||||
chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
|
chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from toolbox import CatchException, report_exception, get_log_folder, gen_time_str
|
from toolbox import CatchException, report_exception, get_log_folder, gen_time_str
|
||||||
from toolbox import update_ui, promote_file_to_downloadzone, update_ui_lastest_msg, disable_auto_promotion
|
from toolbox import update_ui, promote_file_to_downloadzone, update_ui_lastest_msg, disable_auto_promotion
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
from .crazy_utils import read_and_clean_pdf_text
|
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
||||||
from .pdf_fns.parse_pdf import parse_pdf, get_avail_grobid_url, translate_pdf
|
from .pdf_fns.parse_pdf import parse_pdf, get_avail_grobid_url, translate_pdf
|
||||||
from shared_utils.colorful import *
|
from shared_utils.colorful import *
|
||||||
import copy
|
import copy
|
||||||
@@ -60,7 +60,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
|||||||
# 清空历史,以免输入溢出
|
# 清空历史,以免输入溢出
|
||||||
history = []
|
history = []
|
||||||
|
|
||||||
from .crazy_utils import get_files_from_everything
|
from crazy_functions.crazy_utils import get_files_from_everything
|
||||||
success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf')
|
success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf')
|
||||||
if len(file_manifest) > 0:
|
if len(file_manifest) > 0:
|
||||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from loguru import logger
|
||||||
from toolbox import CatchException, update_ui, gen_time_str, promote_file_to_downloadzone
|
from toolbox import CatchException, update_ui, gen_time_str, promote_file_to_downloadzone
|
||||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
from crazy_functions.crazy_utils import input_clipping
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
@@ -34,10 +35,10 @@ def eval_manim(code):
|
|||||||
return f'gpt_log/{time_str}.mp4'
|
return f'gpt_log/{time_str}.mp4'
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
output = e.output.decode()
|
output = e.output.decode()
|
||||||
print(f"Command returned non-zero exit status {e.returncode}: {output}.")
|
logger.error(f"Command returned non-zero exit status {e.returncode}: {output}.")
|
||||||
return f"Evaluating python script failed: {e.output}."
|
return f"Evaluating python script failed: {e.output}."
|
||||||
except:
|
except:
|
||||||
print('generating mp4 failed')
|
logger.error('generating mp4 failed')
|
||||||
return "Generating mp4 failed."
|
return "Generating mp4 failed."
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
|
from loguru import logger
|
||||||
from toolbox import update_ui
|
from toolbox import update_ui
|
||||||
from toolbox import CatchException, report_exception
|
from toolbox import CatchException, report_exception
|
||||||
from .crazy_utils import read_and_clean_pdf_text
|
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
fast_debug = False
|
|
||||||
|
|
||||||
|
|
||||||
def 解析PDF(file_name, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
def 解析PDF(file_name, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||||
import tiktoken
|
logger.info('begin analysis on:', file_name)
|
||||||
print('begin analysis on:', file_name)
|
|
||||||
|
|
||||||
############################## <第 0 步,切割PDF> ##################################
|
############################## <第 0 步,切割PDF> ##################################
|
||||||
# 递归地切割PDF文件,每一块(尽量是完整的一个section,比如introduction,experiment等,必要时再进行切割)
|
# 递归地切割PDF文件,每一块(尽量是完整的一个section,比如introduction,experiment等,必要时再进行切割)
|
||||||
@@ -36,7 +35,7 @@ def 解析PDF(file_name, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
|
|||||||
last_iteration_result = paper_meta # 初始值是摘要
|
last_iteration_result = paper_meta # 初始值是摘要
|
||||||
MAX_WORD_TOTAL = 4096
|
MAX_WORD_TOTAL = 4096
|
||||||
n_fragment = len(paper_fragments)
|
n_fragment = len(paper_fragments)
|
||||||
if n_fragment >= 20: print('文章极长,不能达到预期效果')
|
if n_fragment >= 20: logger.warning('文章极长,不能达到预期效果')
|
||||||
for i in range(n_fragment):
|
for i in range(n_fragment):
|
||||||
NUM_OF_WORD = MAX_WORD_TOTAL // n_fragment
|
NUM_OF_WORD = MAX_WORD_TOTAL // n_fragment
|
||||||
i_say = f"Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} words: {paper_fragments[i]}"
|
i_say = f"Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} words: {paper_fragments[i]}"
|
||||||
@@ -57,7 +56,7 @@ def 解析PDF(file_name, llm_kwargs, plugin_kwargs, chatbot, history, system_pro
|
|||||||
chatbot.append([i_say_show_user, gpt_say])
|
chatbot.append([i_say_show_user, gpt_say])
|
||||||
|
|
||||||
############################## <第 4 步,设置一个token上限,防止回答时Token溢出> ##################################
|
############################## <第 4 步,设置一个token上限,防止回答时Token溢出> ##################################
|
||||||
from .crazy_utils import input_clipping
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
_, final_results = input_clipping("", final_results, max_token_limit=3200)
|
_, final_results = input_clipping("", final_results, max_token_limit=3200)
|
||||||
yield from update_ui(chatbot=chatbot, history=final_results) # 注意这里的历史记录被替代了
|
yield from update_ui(chatbot=chatbot, history=final_results) # 注意这里的历史记录被替代了
|
||||||
|
|
||||||
|
|||||||
@@ -1,37 +1,35 @@
|
|||||||
|
from loguru import logger
|
||||||
from toolbox import update_ui
|
from toolbox import update_ui
|
||||||
from toolbox import CatchException, report_exception
|
from toolbox import CatchException, report_exception
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
fast_debug = False
|
|
||||||
|
|
||||||
def 生成函数注释(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
def 生成函数注释(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||||
import time, os
|
import time, os
|
||||||
print('begin analysis on:', file_manifest)
|
logger.info('begin analysis on:', file_manifest)
|
||||||
for index, fp in enumerate(file_manifest):
|
for index, fp in enumerate(file_manifest):
|
||||||
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
|
|
||||||
i_say = f'请对下面的程序文件做一个概述,并对文件中的所有函数生成注释,使用markdown表格输出结果,文件名是{os.path.relpath(fp, project_folder)},文件内容是 ```{file_content}```'
|
i_say = f'请对下面的程序文件做一个概述,并对文件中的所有函数生成注释,使用markdown表格输出结果,文件名是{os.path.relpath(fp, project_folder)},文件内容是 ```{file_content}```'
|
||||||
i_say_show_user = f'[{index}/{len(file_manifest)}] 请对下面的程序文件做一个概述,并对文件中的所有函数生成注释: {os.path.abspath(fp)}'
|
i_say_show_user = f'[{index+1}/{len(file_manifest)}] 请对下面的程序文件做一个概述,并对文件中的所有函数生成注释: {os.path.abspath(fp)}'
|
||||||
chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
|
chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
if not fast_debug:
|
msg = '正常'
|
||||||
msg = '正常'
|
# ** gpt request **
|
||||||
# ** gpt request **
|
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
i_say, i_say_show_user, llm_kwargs, chatbot, history=[], sys_prompt=system_prompt) # 带超时倒计时
|
||||||
i_say, i_say_show_user, llm_kwargs, chatbot, history=[], sys_prompt=system_prompt) # 带超时倒计时
|
|
||||||
|
|
||||||
chatbot[-1] = (i_say_show_user, gpt_say)
|
chatbot[-1] = (i_say_show_user, gpt_say)
|
||||||
history.append(i_say_show_user); history.append(gpt_say)
|
history.append(i_say_show_user); history.append(gpt_say)
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
|
||||||
if not fast_debug: time.sleep(2)
|
|
||||||
|
|
||||||
if not fast_debug:
|
|
||||||
res = write_history_to_file(history)
|
|
||||||
promote_file_to_downloadzone(res, chatbot=chatbot)
|
|
||||||
chatbot.append(("完成了吗?", res))
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
res = write_history_to_file(history)
|
||||||
|
promote_file_to_downloadzone(res, chatbot=chatbot)
|
||||||
|
chatbot.append(("完成了吗?", res))
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from toolbox import CatchException, update_ui, report_exception
|
from toolbox import CatchException, update_ui, report_exception
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
from crazy_functions.plugin_template.plugin_class_template import (
|
from crazy_functions.plugin_template.plugin_class_template import (
|
||||||
GptAcademicPluginTemplate,
|
GptAcademicPluginTemplate,
|
||||||
)
|
)
|
||||||
@@ -201,8 +201,7 @@ def 解析历史输入(history, llm_kwargs, file_manifest, chatbot, plugin_kwarg
|
|||||||
MAX_WORD_TOTAL = 4096
|
MAX_WORD_TOTAL = 4096
|
||||||
n_txt = len(txt)
|
n_txt = len(txt)
|
||||||
last_iteration_result = "从以下文本中提取摘要。"
|
last_iteration_result = "从以下文本中提取摘要。"
|
||||||
if n_txt >= 20:
|
|
||||||
print("文章极长,不能达到预期效果")
|
|
||||||
for i in range(n_txt):
|
for i in range(n_txt):
|
||||||
NUM_OF_WORD = MAX_WORD_TOTAL // n_txt
|
NUM_OF_WORD = MAX_WORD_TOTAL // n_txt
|
||||||
i_say = f"Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} words in Chinese: {txt[i]}"
|
i_say = f"Read this section, recapitulate the content of this section with less than {NUM_OF_WORD} words in Chinese: {txt[i]}"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from toolbox import CatchException, update_ui, ProxyNetworkActivate, update_ui_lastest_msg, get_log_folder, get_user
|
from toolbox import CatchException, update_ui, ProxyNetworkActivate, update_ui_lastest_msg, get_log_folder, get_user
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_files_from_everything
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_files_from_everything
|
||||||
|
from loguru import logger
|
||||||
install_msg ="""
|
install_msg ="""
|
||||||
|
|
||||||
1. python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
|
1. python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||||
@@ -40,7 +40,7 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
chatbot.append(["依赖不足", f"{str(e)}\n\n导入依赖失败。请用以下命令安装" + install_msg])
|
chatbot.append(["依赖不足", f"{str(e)}\n\n导入依赖失败。请用以下命令安装" + install_msg])
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
# from .crazy_utils import try_install_deps
|
# from crazy_functions.crazy_utils import try_install_deps
|
||||||
# try_install_deps(['zh_langchain==0.2.1', 'pypinyin'], reload_m=['pypinyin', 'zh_langchain'])
|
# try_install_deps(['zh_langchain==0.2.1', 'pypinyin'], reload_m=['pypinyin', 'zh_langchain'])
|
||||||
# yield from update_ui_lastest_msg("安装完成,您可以再次重试。", chatbot, history)
|
# yield from update_ui_lastest_msg("安装完成,您可以再次重试。", chatbot, history)
|
||||||
return
|
return
|
||||||
@@ -60,7 +60,7 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
|||||||
# < -------------------预热文本向量化模组--------------- >
|
# < -------------------预热文本向量化模组--------------- >
|
||||||
chatbot.append(['<br/>'.join(file_manifest), "正在预热文本向量化模组, 如果是第一次运行, 将消耗较长时间下载中文向量化模型..."])
|
chatbot.append(['<br/>'.join(file_manifest), "正在预热文本向量化模组, 如果是第一次运行, 将消耗较长时间下载中文向量化模型..."])
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
print('Checking Text2vec ...')
|
logger.info('Checking Text2vec ...')
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
|
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
|
||||||
HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
|
HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
|
||||||
@@ -68,7 +68,7 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
|||||||
# < -------------------构建知识库--------------- >
|
# < -------------------构建知识库--------------- >
|
||||||
chatbot.append(['<br/>'.join(file_manifest), "正在构建知识库..."])
|
chatbot.append(['<br/>'.join(file_manifest), "正在构建知识库..."])
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
print('Establishing knowledge archive ...')
|
logger.info('Establishing knowledge archive ...')
|
||||||
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
|
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
|
||||||
kai = knowledge_archive_interface()
|
kai = knowledge_archive_interface()
|
||||||
vs_path = get_log_folder(user=get_user(chatbot), plugin_name='vec_store')
|
vs_path = get_log_folder(user=get_user(chatbot), plugin_name='vec_store')
|
||||||
@@ -93,7 +93,7 @@ def 读取知识库作答(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
chatbot.append(["依赖不足", f"{str(e)}\n\n导入依赖失败。请用以下命令安装" + install_msg])
|
chatbot.append(["依赖不足", f"{str(e)}\n\n导入依赖失败。请用以下命令安装" + install_msg])
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
# from .crazy_utils import try_install_deps
|
# from crazy_functions.crazy_utils import try_install_deps
|
||||||
# try_install_deps(['zh_langchain==0.2.1', 'pypinyin'], reload_m=['pypinyin', 'zh_langchain'])
|
# try_install_deps(['zh_langchain==0.2.1', 'pypinyin'], reload_m=['pypinyin', 'zh_langchain'])
|
||||||
# yield from update_ui_lastest_msg("安装完成,您可以再次重试。", chatbot, history)
|
# yield from update_ui_lastest_msg("安装完成,您可以再次重试。", chatbot, history)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from toolbox import CatchException, update_ui
|
from toolbox import CatchException, update_ui
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
|
||||||
import requests
|
import requests
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from request_llms.bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
@@ -23,8 +23,8 @@ def google(query, proxies):
|
|||||||
item = {'title': title, 'link': link}
|
item = {'title': title, 'link': link}
|
||||||
results.append(item)
|
results.append(item)
|
||||||
|
|
||||||
for r in results:
|
# for r in results:
|
||||||
print(r['link'])
|
# print(r['link'])
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def scrape_text(url, proxies) -> str:
|
def scrape_text(url, proxies) -> str:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from toolbox import CatchException, update_ui
|
from toolbox import CatchException, update_ui
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
|
||||||
import requests
|
import requests
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from request_llms.bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
@@ -22,8 +22,8 @@ def bing_search(query, proxies=None):
|
|||||||
item = {'title': title, 'link': link}
|
item = {'title': title, 'link': link}
|
||||||
results.append(item)
|
results.append(item)
|
||||||
|
|
||||||
for r in results:
|
# for r in results:
|
||||||
print(r['link'])
|
# print(r['link'])
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ def parseNotebook(filename, enable_markdown=1):
|
|||||||
|
|
||||||
|
|
||||||
def ipynb解释(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
def ipynb解释(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||||
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
|
|
||||||
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
|
||||||
enable_markdown = plugin_kwargs.get("advanced_arg", "1")
|
enable_markdown = plugin_kwargs.get("advanced_arg", "1")
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from toolbox import CatchException, update_ui, get_conf
|
from toolbox import CatchException, update_ui, get_conf
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
import datetime
|
import datetime
|
||||||
@CatchException
|
@CatchException
|
||||||
def 同时问询(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
def 同时问询(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
from toolbox import update_ui
|
from toolbox import update_ui
|
||||||
from toolbox import CatchException, get_conf, markdown_convertion
|
from toolbox import CatchException, get_conf, markdown_convertion
|
||||||
|
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||||
from crazy_functions.crazy_utils import input_clipping
|
from crazy_functions.crazy_utils import input_clipping
|
||||||
from crazy_functions.agent_fns.watchdog import WatchDog
|
from crazy_functions.agent_fns.watchdog import WatchDog
|
||||||
from request_llms.bridge_all import predict_no_ui_long_connection
|
from crazy_functions.live_audio.aliyunASR import AliyunASR
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
import threading, time
|
import threading, time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .live_audio.aliyunASR import AliyunASR
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -42,9 +44,9 @@ class AsyncGptTask():
|
|||||||
gpt_say_partial = predict_no_ui_long_connection(inputs=i_say, llm_kwargs=llm_kwargs, history=history, sys_prompt=sys_prompt,
|
gpt_say_partial = predict_no_ui_long_connection(inputs=i_say, llm_kwargs=llm_kwargs, history=history, sys_prompt=sys_prompt,
|
||||||
observe_window=observe_window[index], console_slience=True)
|
observe_window=observe_window[index], console_slience=True)
|
||||||
except ConnectionAbortedError as token_exceed_err:
|
except ConnectionAbortedError as token_exceed_err:
|
||||||
print('至少一个线程任务Token溢出而失败', e)
|
logger.error('至少一个线程任务Token溢出而失败', e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('至少一个线程任务意外失败', e)
|
logger.error('至少一个线程任务意外失败', e)
|
||||||
|
|
||||||
def add_async_gpt_task(self, i_say, chatbot_index, llm_kwargs, history, system_prompt):
|
def add_async_gpt_task(self, i_say, chatbot_index, llm_kwargs, history, system_prompt):
|
||||||
self.observe_future.append([""])
|
self.observe_future.append([""])
|
||||||
|
|||||||
@@ -1,19 +1,18 @@
|
|||||||
from toolbox import update_ui
|
from toolbox import update_ui
|
||||||
from toolbox import CatchException, report_exception
|
from toolbox import CatchException, report_exception
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
|
|
||||||
|
|
||||||
def 解析Paper(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
def 解析Paper(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
||||||
import time, glob, os
|
import time, glob, os
|
||||||
print('begin analysis on:', file_manifest)
|
|
||||||
for index, fp in enumerate(file_manifest):
|
for index, fp in enumerate(file_manifest):
|
||||||
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
|
|
||||||
prefix = "接下来请你逐文件分析下面的论文文件,概括其内容" if index==0 else ""
|
prefix = "接下来请你逐文件分析下面的论文文件,概括其内容" if index==0 else ""
|
||||||
i_say = prefix + f'请对下面的文章片段用中文做一个概述,文件名是{os.path.relpath(fp, project_folder)},文章内容是 ```{file_content}```'
|
i_say = prefix + f'请对下面的文章片段用中文做一个概述,文件名是{os.path.relpath(fp, project_folder)},文章内容是 ```{file_content}```'
|
||||||
i_say_show_user = prefix + f'[{index}/{len(file_manifest)}] 请对下面的文章片段做一个概述: {os.path.abspath(fp)}'
|
i_say_show_user = prefix + f'[{index+1}/{len(file_manifest)}] 请对下面的文章片段做一个概述: {os.path.abspath(fp)}'
|
||||||
chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
|
chatbot.append((i_say_show_user, "[Local Message] waiting gpt response."))
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
from toolbox import CatchException, report_exception, promote_file_to_downloadzone
|
from toolbox import CatchException, report_exception, promote_file_to_downloadzone
|
||||||
from toolbox import update_ui, update_ui_lastest_msg, disable_auto_promotion, write_history_to_file
|
from toolbox import update_ui, update_ui_lastest_msg, disable_auto_promotion, write_history_to_file
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
# 此Dockerfile不再维护,请前往docs/GithubAction+JittorLLMs
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
# docker build -t gpt-academic-all-capacity -f docs/GithubAction+AllCapacity --network=host --build-arg http_proxy=http://localhost:10881 --build-arg https_proxy=http://localhost:10881 .
|
|
||||||
# docker build -t gpt-academic-all-capacity -f docs/GithubAction+AllCapacityBeta --network=host .
|
|
||||||
# docker run -it --net=host gpt-academic-all-capacity bash
|
|
||||||
|
|
||||||
# 从NVIDIA源,从而支持显卡(检查宿主的nvidia-smi中的cuda版本必须>=11.3)
|
|
||||||
FROM fuqingxu/11.3.1-runtime-ubuntu20.04-with-texlive:latest
|
|
||||||
|
|
||||||
# edge-tts需要的依赖,某些pip包所需的依赖
|
|
||||||
RUN apt update && apt install ffmpeg build-essential -y
|
|
||||||
|
|
||||||
# use python3 as the system default python
|
|
||||||
WORKDIR /gpt
|
|
||||||
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.8
|
|
||||||
|
|
||||||
# # 非必要步骤,更换pip源 (以下三行,可以删除)
|
|
||||||
# RUN echo '[global]' > /etc/pip.conf && \
|
|
||||||
# echo 'index-url = https://mirrors.aliyun.com/pypi/simple/' >> /etc/pip.conf && \
|
|
||||||
# echo 'trusted-host = mirrors.aliyun.com' >> /etc/pip.conf
|
|
||||||
|
|
||||||
# 下载pytorch
|
|
||||||
RUN python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
|
|
||||||
# 准备pip依赖
|
|
||||||
RUN python3 -m pip install openai numpy arxiv rich
|
|
||||||
RUN python3 -m pip install colorama Markdown pygments pymupdf
|
|
||||||
RUN python3 -m pip install python-docx moviepy pdfminer
|
|
||||||
RUN python3 -m pip install zh_langchain==0.2.1 pypinyin
|
|
||||||
RUN python3 -m pip install rarfile py7zr
|
|
||||||
RUN python3 -m pip install aliyun-python-sdk-core==2.13.3 pyOpenSSL webrtcvad scipy git+https://github.com/aliyun/alibabacloud-nls-python-sdk.git
|
|
||||||
# 下载分支
|
|
||||||
WORKDIR /gpt
|
|
||||||
RUN git clone --depth=1 https://github.com/binary-husky/gpt_academic.git
|
|
||||||
WORKDIR /gpt/gpt_academic
|
|
||||||
RUN git clone --depth=1 https://github.com/OpenLMLab/MOSS.git request_llms/moss
|
|
||||||
|
|
||||||
RUN python3 -m pip install -r requirements.txt
|
|
||||||
RUN python3 -m pip install -r request_llms/requirements_moss.txt
|
|
||||||
RUN python3 -m pip install -r request_llms/requirements_qwen.txt
|
|
||||||
RUN python3 -m pip install -r request_llms/requirements_chatglm.txt
|
|
||||||
RUN python3 -m pip install -r request_llms/requirements_newbing.txt
|
|
||||||
RUN python3 -m pip install nougat-ocr
|
|
||||||
|
|
||||||
|
|
||||||
# 预热Tiktoken模块
|
|
||||||
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
|
|
||||||
|
|
||||||
# 安装知识库插件的额外依赖
|
|
||||||
RUN apt-get update && apt-get install libgl1 -y
|
|
||||||
RUN pip3 install transformers protobuf langchain sentence-transformers faiss-cpu nltk beautifulsoup4 bitsandbytes tabulate icetk --upgrade
|
|
||||||
RUN pip3 install unstructured[all-docs] --upgrade
|
|
||||||
RUN python3 -c 'from check_proxy import warm_up_vectordb; warm_up_vectordb()'
|
|
||||||
RUN rm -rf /usr/local/lib/python3.8/dist-packages/tests
|
|
||||||
|
|
||||||
|
|
||||||
# COPY .cache /root/.cache
|
|
||||||
# COPY config_private.py config_private.py
|
|
||||||
# 启动
|
|
||||||
CMD ["python3", "-u", "main.py"]
|
|
||||||
25
docs/GithubAction+NoLocal+Latex+Arm
Normal file
25
docs/GithubAction+NoLocal+Latex+Arm
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# 此Dockerfile适用于“无本地模型”的环境构建,如果需要使用chatglm等本地模型,请参考 docs/Dockerfile+ChatGLM
|
||||||
|
# - 1 修改 `config.py`
|
||||||
|
# - 2 构建 docker build -t gpt-academic-nolocal-latex -f docs/GithubAction+NoLocal+Latex .
|
||||||
|
# - 3 运行 docker run -v /home/fuqingxu/arxiv_cache:/root/arxiv_cache --rm -it --net=host gpt-academic-nolocal-latex
|
||||||
|
|
||||||
|
FROM menghuan1918/ubuntu_uv_ctex:latest
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
SHELL ["/bin/bash", "-c"]
|
||||||
|
WORKDIR /gpt
|
||||||
|
COPY . .
|
||||||
|
RUN /root/.cargo/bin/uv venv --seed \
|
||||||
|
&& source .venv/bin/activate \
|
||||||
|
&& /root/.cargo/bin/uv pip install openai numpy arxiv rich colorama Markdown pygments pymupdf python-docx pdfminer \
|
||||||
|
&& /root/.cargo/bin/uv pip install -r requirements.txt \
|
||||||
|
&& /root/.cargo/bin/uv clean
|
||||||
|
|
||||||
|
# 对齐python3
|
||||||
|
RUN rm -f /usr/bin/python3 && ln -s /gpt/.venv/bin/python /usr/bin/python3
|
||||||
|
RUN rm -f /usr/bin/python && ln -s /gpt/.venv/bin/python /usr/bin/python
|
||||||
|
|
||||||
|
# 可选步骤,用于预热模块
|
||||||
|
RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
|
||||||
|
|
||||||
|
# 启动
|
||||||
|
CMD ["python3", "-u", "main.py"]
|
||||||
@@ -4,7 +4,7 @@ We currently support fastapi in order to solve sub-path deploy issue.
|
|||||||
|
|
||||||
1. change CUSTOM_PATH setting in `config.py`
|
1. change CUSTOM_PATH setting in `config.py`
|
||||||
|
|
||||||
``` sh
|
```sh
|
||||||
nano config.py
|
nano config.py
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -35,9 +35,8 @@ if __name__ == "__main__":
|
|||||||
main()
|
main()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
3. Go!
|
3. Go!
|
||||||
|
|
||||||
``` sh
|
```sh
|
||||||
python main.py
|
python main.py
|
||||||
```
|
```
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -106,5 +106,24 @@
|
|||||||
"解析PDF_DOC2X_转Latex": "ParsePDF_DOC2X_toLatex",
|
"解析PDF_DOC2X_转Latex": "ParsePDF_DOC2X_toLatex",
|
||||||
"解析PDF_基于DOC2X": "ParsePDF_basedDOC2X",
|
"解析PDF_基于DOC2X": "ParsePDF_basedDOC2X",
|
||||||
"解析PDF_简单拆解": "ParsePDF_simpleDecomposition",
|
"解析PDF_简单拆解": "ParsePDF_simpleDecomposition",
|
||||||
"解析PDF_DOC2X_单文件": "ParsePDF_DOC2X_singleFile"
|
"解析PDF_DOC2X_单文件": "ParsePDF_DOC2X_singleFile",
|
||||||
|
"注释Python项目": "CommentPythonProject",
|
||||||
|
"注释源代码": "CommentSourceCode",
|
||||||
|
"log亮黄": "log_yellow",
|
||||||
|
"log亮绿": "log_green",
|
||||||
|
"log亮红": "log_red",
|
||||||
|
"log亮紫": "log_purple",
|
||||||
|
"log亮蓝": "log_blue",
|
||||||
|
"Rag问答": "RagQA",
|
||||||
|
"sprint红": "sprint_red",
|
||||||
|
"sprint绿": "sprint_green",
|
||||||
|
"sprint黄": "sprint_yellow",
|
||||||
|
"sprint蓝": "sprint_blue",
|
||||||
|
"sprint紫": "sprint_purple",
|
||||||
|
"sprint靛": "sprint_indigo",
|
||||||
|
"sprint亮红": "sprint_bright_red",
|
||||||
|
"sprint亮绿": "sprint_bright_green",
|
||||||
|
"sprint亮黄": "sprint_bright_yellow",
|
||||||
|
"sprint亮蓝": "sprint_bright_blue",
|
||||||
|
"sprint亮紫": "sprint_bright_purple"
|
||||||
}
|
}
|
||||||
93
main.py
93
main.py
@@ -13,24 +13,39 @@ help_menu_description = \
|
|||||||
</br></br>如何语音对话: 请阅读Wiki
|
</br></br>如何语音对话: 请阅读Wiki
|
||||||
</br></br>如何临时更换API_KEY: 在输入区输入临时API_KEY后提交(网页刷新后失效)"""
|
</br></br>如何临时更换API_KEY: 在输入区输入临时API_KEY后提交(网页刷新后失效)"""
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
def enable_log(PATH_LOGGING):
|
def enable_log(PATH_LOGGING):
|
||||||
import logging
|
from shared_utils.logging import setup_logging
|
||||||
admin_log_path = os.path.join(PATH_LOGGING, "admin")
|
setup_logging(PATH_LOGGING)
|
||||||
os.makedirs(admin_log_path, exist_ok=True)
|
|
||||||
log_dir = os.path.join(admin_log_path, "chat_secrets.log")
|
def encode_plugin_info(k, plugin)->str:
|
||||||
try:logging.basicConfig(filename=log_dir, level=logging.INFO, encoding="utf-8", format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
import copy
|
||||||
except:logging.basicConfig(filename=log_dir, level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
from themes.theme import to_cookie_str
|
||||||
# Disable logging output from the 'httpx' logger
|
plugin_ = copy.copy(plugin)
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
plugin_.pop("Function", None)
|
||||||
print(f"所有对话记录将自动保存在本地目录{log_dir}, 请注意自我隐私保护哦!")
|
plugin_.pop("Class", None)
|
||||||
|
plugin_.pop("Button", None)
|
||||||
|
plugin_["Info"] = plugin.get("Info", k)
|
||||||
|
if plugin.get("AdvancedArgs", False):
|
||||||
|
plugin_["Label"] = f"插件[{k}]的高级参数说明:" + plugin.get("ArgsReminder", f"没有提供高级参数功能说明")
|
||||||
|
else:
|
||||||
|
plugin_["Label"] = f"插件[{k}]不需要高级参数。"
|
||||||
|
return to_cookie_str(plugin_)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
if gr.__version__ not in ['3.32.9', '3.32.10', '3.32.11']:
|
if gr.__version__ not in ['3.32.9', '3.32.10', '3.32.11']:
|
||||||
raise ModuleNotFoundError("使用项目内置Gradio获取最优体验! 请运行 `pip install -r requirements.txt` 指令安装内置Gradio及其他依赖, 详情信息见requirements.txt.")
|
raise ModuleNotFoundError("使用项目内置Gradio获取最优体验! 请运行 `pip install -r requirements.txt` 指令安装内置Gradio及其他依赖, 详情信息见requirements.txt.")
|
||||||
from request_llms.bridge_all import predict
|
|
||||||
|
# 一些基础工具
|
||||||
from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf, ArgsGeneralWrapper, DummyWith
|
from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf, ArgsGeneralWrapper, DummyWith
|
||||||
|
|
||||||
|
# 对话、日志记录
|
||||||
|
enable_log(get_conf("PATH_LOGGING"))
|
||||||
|
|
||||||
|
# 对话句柄
|
||||||
|
from request_llms.bridge_all import predict
|
||||||
|
|
||||||
# 读取配置
|
# 读取配置
|
||||||
proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION = get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION')
|
proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION = get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION')
|
||||||
CHATBOT_HEIGHT, LAYOUT, AVAIL_LLM_MODELS, AUTO_CLEAR_TXT = get_conf('CHATBOT_HEIGHT', 'LAYOUT', 'AVAIL_LLM_MODELS', 'AUTO_CLEAR_TXT')
|
CHATBOT_HEIGHT, LAYOUT, AVAIL_LLM_MODELS, AUTO_CLEAR_TXT = get_conf('CHATBOT_HEIGHT', 'LAYOUT', 'AVAIL_LLM_MODELS', 'AUTO_CLEAR_TXT')
|
||||||
@@ -47,8 +62,6 @@ def main():
|
|||||||
from themes.theme import load_dynamic_theme, to_cookie_str, from_cookie_str, assign_user_uuid
|
from themes.theme import load_dynamic_theme, to_cookie_str, from_cookie_str, assign_user_uuid
|
||||||
title_html = f"<h1 align=\"center\">GPT 学术优化 {get_current_version()}</h1>{theme_declaration}"
|
title_html = f"<h1 align=\"center\">GPT 学术优化 {get_current_version()}</h1>{theme_declaration}"
|
||||||
|
|
||||||
# 对话、日志记录
|
|
||||||
enable_log(PATH_LOGGING)
|
|
||||||
|
|
||||||
# 一些普通功能模块
|
# 一些普通功能模块
|
||||||
from core_functional import get_core_functions
|
from core_functional import get_core_functions
|
||||||
@@ -98,8 +111,18 @@ def main():
|
|||||||
with gr.Accordion("输入区", open=True, elem_id="input-panel") as area_input_primary:
|
with gr.Accordion("输入区", open=True, elem_id="input-panel") as area_input_primary:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
txt = gr.Textbox(show_label=False, placeholder="Input question here.", elem_id='user_input_main').style(container=False)
|
txt = gr.Textbox(show_label=False, placeholder="Input question here.", elem_id='user_input_main').style(container=False)
|
||||||
with gr.Row():
|
with gr.Row(elem_id="gpt-submit-row"):
|
||||||
submitBtn = gr.Button("提交", elem_id="elem_submit", variant="primary")
|
multiplex_submit_btn = gr.Button("提交", elem_id="elem_submit_visible", variant="primary")
|
||||||
|
multiplex_sel = gr.Dropdown(
|
||||||
|
choices=[
|
||||||
|
"常规对话",
|
||||||
|
"多模型对话",
|
||||||
|
"智能召回 RAG",
|
||||||
|
# "智能上下文",
|
||||||
|
], value="常规对话",
|
||||||
|
interactive=True, label='', show_label=False,
|
||||||
|
elem_classes='normal_mut_select', elem_id="gpt-submit-dropdown").style(container=False)
|
||||||
|
submit_btn = gr.Button("提交", elem_id="elem_submit", variant="primary", visible=False)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
resetBtn = gr.Button("重置", elem_id="elem_reset", variant="secondary"); resetBtn.style(size="sm")
|
resetBtn = gr.Button("重置", elem_id="elem_reset", variant="secondary"); resetBtn.style(size="sm")
|
||||||
stopBtn = gr.Button("停止", elem_id="elem_stop", variant="secondary"); stopBtn.style(size="sm")
|
stopBtn = gr.Button("停止", elem_id="elem_stop", variant="secondary"); stopBtn.style(size="sm")
|
||||||
@@ -146,7 +169,7 @@ def main():
|
|||||||
if not plugin.get("AsButton", True): dropdown_fn_list.append(k) # 排除已经是按钮的插件
|
if not plugin.get("AsButton", True): dropdown_fn_list.append(k) # 排除已经是按钮的插件
|
||||||
elif plugin.get('AdvancedArgs', False): dropdown_fn_list.append(k) # 对于需要高级参数的插件,亦在下拉菜单中显示
|
elif plugin.get('AdvancedArgs', False): dropdown_fn_list.append(k) # 对于需要高级参数的插件,亦在下拉菜单中显示
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
dropdown = gr.Dropdown(dropdown_fn_list, value=r"点击这里搜索插件列表", label="", show_label=False).style(container=False)
|
dropdown = gr.Dropdown(dropdown_fn_list, value=r"点击这里输入「关键词」搜索插件", label="", show_label=False).style(container=False)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
plugin_advanced_arg = gr.Textbox(show_label=True, label="高级参数输入区", visible=False, elem_id="advance_arg_input_legacy",
|
plugin_advanced_arg = gr.Textbox(show_label=True, label="高级参数输入区", visible=False, elem_id="advance_arg_input_legacy",
|
||||||
placeholder="这里是特殊函数插件的高级参数输入区").style(container=False)
|
placeholder="这里是特殊函数插件的高级参数输入区").style(container=False)
|
||||||
@@ -163,7 +186,7 @@ def main():
|
|||||||
|
|
||||||
# 浮动菜单定义
|
# 浮动菜单定义
|
||||||
from themes.gui_floating_menu import define_gui_floating_menu
|
from themes.gui_floating_menu import define_gui_floating_menu
|
||||||
area_input_secondary, txt2, area_customize, submitBtn2, resetBtn2, clearBtn2, stopBtn2 = \
|
area_input_secondary, txt2, area_customize, _, resetBtn2, clearBtn2, stopBtn2 = \
|
||||||
define_gui_floating_menu(customize_btns, functional, predefined_btns, cookies, web_cookie_cache)
|
define_gui_floating_menu(customize_btns, functional, predefined_btns, cookies, web_cookie_cache)
|
||||||
|
|
||||||
# 插件二级菜单的实现
|
# 插件二级菜单的实现
|
||||||
@@ -195,11 +218,15 @@ def main():
|
|||||||
input_combo_order = ["cookies", "max_length_sl", "md_dropdown", "txt", "txt2", "top_p", "temperature", "chatbot", "history", "system_prompt", "plugin_advanced_arg"]
|
input_combo_order = ["cookies", "max_length_sl", "md_dropdown", "txt", "txt2", "top_p", "temperature", "chatbot", "history", "system_prompt", "plugin_advanced_arg"]
|
||||||
output_combo = [cookies, chatbot, history, status]
|
output_combo = [cookies, chatbot, history, status]
|
||||||
predict_args = dict(fn=ArgsGeneralWrapper(predict), inputs=[*input_combo, gr.State(True)], outputs=output_combo)
|
predict_args = dict(fn=ArgsGeneralWrapper(predict), inputs=[*input_combo, gr.State(True)], outputs=output_combo)
|
||||||
|
|
||||||
# 提交按钮、重置按钮
|
# 提交按钮、重置按钮
|
||||||
cancel_handles.append(txt.submit(**predict_args))
|
multiplex_submit_btn.click(
|
||||||
cancel_handles.append(txt2.submit(**predict_args))
|
None, [multiplex_sel], None, _js="""(multiplex_sel)=>multiplex_function_begin(multiplex_sel)""")
|
||||||
cancel_handles.append(submitBtn.click(**predict_args))
|
txt.submit(
|
||||||
cancel_handles.append(submitBtn2.click(**predict_args))
|
None, [multiplex_sel], None, _js="""(multiplex_sel)=>multiplex_function_begin(multiplex_sel)""")
|
||||||
|
multiplex_sel.select(
|
||||||
|
None, [multiplex_sel], None, _js=f"""(multiplex_sel)=>run_multiplex_shift(multiplex_sel)""")
|
||||||
|
cancel_handles.append(submit_btn.click(**predict_args))
|
||||||
resetBtn.click(None, None, [chatbot, history, status], _js=js_code_reset) # 先在前端快速清除chatbot&status
|
resetBtn.click(None, None, [chatbot, history, status], _js=js_code_reset) # 先在前端快速清除chatbot&status
|
||||||
resetBtn2.click(None, None, [chatbot, history, status], _js=js_code_reset) # 先在前端快速清除chatbot&status
|
resetBtn2.click(None, None, [chatbot, history, status], _js=js_code_reset) # 先在前端快速清除chatbot&status
|
||||||
reset_server_side_args = (lambda history: ([], [], "已重置", json.dumps(history)), [history], [chatbot, history, status, history_cache])
|
reset_server_side_args = (lambda history: ([], [], "已重置", json.dumps(history)), [history], [chatbot, history, status, history_cache])
|
||||||
@@ -208,10 +235,7 @@ def main():
|
|||||||
clearBtn.click(None, None, [txt, txt2], _js=js_code_clear)
|
clearBtn.click(None, None, [txt, txt2], _js=js_code_clear)
|
||||||
clearBtn2.click(None, None, [txt, txt2], _js=js_code_clear)
|
clearBtn2.click(None, None, [txt, txt2], _js=js_code_clear)
|
||||||
if AUTO_CLEAR_TXT:
|
if AUTO_CLEAR_TXT:
|
||||||
submitBtn.click(None, None, [txt, txt2], _js=js_code_clear)
|
submit_btn.click(None, None, [txt, txt2], _js=js_code_clear)
|
||||||
submitBtn2.click(None, None, [txt, txt2], _js=js_code_clear)
|
|
||||||
txt.submit(None, None, [txt, txt2], _js=js_code_clear)
|
|
||||||
txt2.submit(None, None, [txt, txt2], _js=js_code_clear)
|
|
||||||
# 基础功能区的回调函数注册
|
# 基础功能区的回调函数注册
|
||||||
for k in functional:
|
for k in functional:
|
||||||
if ("Visible" in functional[k]) and (not functional[k]["Visible"]): continue
|
if ("Visible" in functional[k]) and (not functional[k]["Visible"]): continue
|
||||||
@@ -224,21 +248,6 @@ def main():
|
|||||||
file_upload.upload(on_file_uploaded, [file_upload, chatbot, txt, txt2, checkboxes, cookies], [chatbot, txt, txt2, cookies]).then(None, None, None, _js=r"()=>{toast_push('上传完毕 ...'); cancel_loading_status();}")
|
file_upload.upload(on_file_uploaded, [file_upload, chatbot, txt, txt2, checkboxes, cookies], [chatbot, txt, txt2, cookies]).then(None, None, None, _js=r"()=>{toast_push('上传完毕 ...'); cancel_loading_status();}")
|
||||||
file_upload_2.upload(on_file_uploaded, [file_upload_2, chatbot, txt, txt2, checkboxes, cookies], [chatbot, txt, txt2, cookies]).then(None, None, None, _js=r"()=>{toast_push('上传完毕 ...'); cancel_loading_status();}")
|
file_upload_2.upload(on_file_uploaded, [file_upload_2, chatbot, txt, txt2, checkboxes, cookies], [chatbot, txt, txt2, cookies]).then(None, None, None, _js=r"()=>{toast_push('上传完毕 ...'); cancel_loading_status();}")
|
||||||
# 函数插件-固定按钮区
|
# 函数插件-固定按钮区
|
||||||
def encode_plugin_info(k, plugin)->str:
|
|
||||||
import copy
|
|
||||||
from themes.theme import to_cookie_str
|
|
||||||
plugin_ = copy.copy(plugin)
|
|
||||||
plugin_.pop("Function", None)
|
|
||||||
plugin_.pop("Class", None)
|
|
||||||
plugin_.pop("Button", None)
|
|
||||||
plugin_["Info"] = plugin.get("Info", k)
|
|
||||||
if plugin.get("AdvancedArgs", False):
|
|
||||||
plugin_["Label"] = f"插件[{k}]的高级参数说明:" + plugin.get("ArgsReminder", f"没有提供高级参数功能说明")
|
|
||||||
else:
|
|
||||||
plugin_["Label"] = f"插件[{k}]不需要高级参数。"
|
|
||||||
return to_cookie_str(plugin_)
|
|
||||||
|
|
||||||
# 插件的注册(前端代码注册)
|
|
||||||
for k in plugins:
|
for k in plugins:
|
||||||
register_advanced_plugin_init_arr += f"""register_plugin_init("{k}","{encode_plugin_info(k, plugins[k])}");"""
|
register_advanced_plugin_init_arr += f"""register_plugin_init("{k}","{encode_plugin_info(k, plugins[k])}");"""
|
||||||
if plugins[k].get("Class", None):
|
if plugins[k].get("Class", None):
|
||||||
@@ -329,9 +338,9 @@ def main():
|
|||||||
# Gradio的inbrowser触发不太稳定,回滚代码到原始的浏览器打开函数
|
# Gradio的inbrowser触发不太稳定,回滚代码到原始的浏览器打开函数
|
||||||
def run_delayed_tasks():
|
def run_delayed_tasks():
|
||||||
import threading, webbrowser, time
|
import threading, webbrowser, time
|
||||||
print(f"如果浏览器没有自动打开,请复制并转到以下URL:")
|
logger.info(f"如果浏览器没有自动打开,请复制并转到以下URL:")
|
||||||
if DARK_MODE: print(f"\t「暗色主题已启用(支持动态切换主题)」: http://localhost:{PORT}")
|
if DARK_MODE: logger.info(f"\t「暗色主题已启用(支持动态切换主题)」: http://localhost:{PORT}")
|
||||||
else: print(f"\t「亮色主题已启用(支持动态切换主题)」: http://localhost:{PORT}")
|
else: logger.info(f"\t「亮色主题已启用(支持动态切换主题)」: http://localhost:{PORT}")
|
||||||
|
|
||||||
def auto_updates(): time.sleep(0); auto_update()
|
def auto_updates(): time.sleep(0); auto_update()
|
||||||
def open_browser(): time.sleep(2); webbrowser.open_new_tab(f"http://localhost:{PORT}")
|
def open_browser(): time.sleep(2); webbrowser.open_new_tab(f"http://localhost:{PORT}")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
2. predict_no_ui_long_connection(...)
|
2. predict_no_ui_long_connection(...)
|
||||||
"""
|
"""
|
||||||
import tiktoken, copy, re
|
import tiktoken, copy, re
|
||||||
|
from loguru import logger
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from toolbox import get_conf, trimmed_format_exc, apply_gpt_academic_string_mask, read_one_api_model_name
|
from toolbox import get_conf, trimmed_format_exc, apply_gpt_academic_string_mask, read_one_api_model_name
|
||||||
@@ -51,9 +52,9 @@ class LazyloadTiktoken(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
@lru_cache(maxsize=128)
|
@lru_cache(maxsize=128)
|
||||||
def get_encoder(model):
|
def get_encoder(model):
|
||||||
print('正在加载tokenizer,如果是第一次运行,可能需要一点时间下载参数')
|
logger.info('正在加载tokenizer,如果是第一次运行,可能需要一点时间下载参数')
|
||||||
tmp = tiktoken.encoding_for_model(model)
|
tmp = tiktoken.encoding_for_model(model)
|
||||||
print('加载tokenizer完毕')
|
logger.info('加载tokenizer完毕')
|
||||||
return tmp
|
return tmp
|
||||||
|
|
||||||
def encode(self, *args, **kwargs):
|
def encode(self, *args, **kwargs):
|
||||||
@@ -83,7 +84,7 @@ try:
|
|||||||
API_URL = get_conf("API_URL")
|
API_URL = get_conf("API_URL")
|
||||||
if API_URL != "https://api.openai.com/v1/chat/completions":
|
if API_URL != "https://api.openai.com/v1/chat/completions":
|
||||||
openai_endpoint = API_URL
|
openai_endpoint = API_URL
|
||||||
print("警告!API_URL配置选项将被弃用,请更换为API_URL_REDIRECT配置")
|
logger.warning("警告!API_URL配置选项将被弃用,请更换为API_URL_REDIRECT配置")
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
# 新版配置
|
# 新版配置
|
||||||
@@ -201,6 +202,16 @@ model_info = {
|
|||||||
"token_cnt": get_token_num_gpt4,
|
"token_cnt": get_token_num_gpt4,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
"gpt-4o-mini": {
|
||||||
|
"fn_with_ui": chatgpt_ui,
|
||||||
|
"fn_without_ui": chatgpt_noui,
|
||||||
|
"endpoint": openai_endpoint,
|
||||||
|
"has_multimodal_capacity": True,
|
||||||
|
"max_token": 128000,
|
||||||
|
"tokenizer": tokenizer_gpt4,
|
||||||
|
"token_cnt": get_token_num_gpt4,
|
||||||
|
},
|
||||||
|
|
||||||
"gpt-4o-2024-05-13": {
|
"gpt-4o-2024-05-13": {
|
||||||
"fn_with_ui": chatgpt_ui,
|
"fn_with_ui": chatgpt_ui,
|
||||||
"fn_without_ui": chatgpt_noui,
|
"fn_without_ui": chatgpt_noui,
|
||||||
@@ -238,6 +249,27 @@ model_info = {
|
|||||||
"token_cnt": get_token_num_gpt4,
|
"token_cnt": get_token_num_gpt4,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
"o1-preview": {
|
||||||
|
"fn_with_ui": chatgpt_ui,
|
||||||
|
"fn_without_ui": chatgpt_noui,
|
||||||
|
"endpoint": openai_endpoint,
|
||||||
|
"max_token": 128000,
|
||||||
|
"tokenizer": tokenizer_gpt4,
|
||||||
|
"token_cnt": get_token_num_gpt4,
|
||||||
|
"openai_disable_system_prompt": True,
|
||||||
|
"openai_disable_stream": True,
|
||||||
|
},
|
||||||
|
"o1-mini": {
|
||||||
|
"fn_with_ui": chatgpt_ui,
|
||||||
|
"fn_without_ui": chatgpt_noui,
|
||||||
|
"endpoint": openai_endpoint,
|
||||||
|
"max_token": 128000,
|
||||||
|
"tokenizer": tokenizer_gpt4,
|
||||||
|
"token_cnt": get_token_num_gpt4,
|
||||||
|
"openai_disable_system_prompt": True,
|
||||||
|
"openai_disable_stream": True,
|
||||||
|
},
|
||||||
|
|
||||||
"gpt-4-turbo": {
|
"gpt-4-turbo": {
|
||||||
"fn_with_ui": chatgpt_ui,
|
"fn_with_ui": chatgpt_ui,
|
||||||
"fn_without_ui": chatgpt_noui,
|
"fn_without_ui": chatgpt_noui,
|
||||||
@@ -258,7 +290,6 @@ model_info = {
|
|||||||
"token_cnt": get_token_num_gpt4,
|
"token_cnt": get_token_num_gpt4,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|
||||||
"gpt-3.5-random": {
|
"gpt-3.5-random": {
|
||||||
"fn_with_ui": chatgpt_ui,
|
"fn_with_ui": chatgpt_ui,
|
||||||
"fn_without_ui": chatgpt_noui,
|
"fn_without_ui": chatgpt_noui,
|
||||||
@@ -398,22 +429,46 @@ model_info = {
|
|||||||
"tokenizer": tokenizer_gpt35,
|
"tokenizer": tokenizer_gpt35,
|
||||||
"token_cnt": get_token_num_gpt35,
|
"token_cnt": get_token_num_gpt35,
|
||||||
},
|
},
|
||||||
|
# Gemini
|
||||||
|
# Note: now gemini-pro is an alias of gemini-1.0-pro.
|
||||||
|
# Warning: gemini-pro-vision has been deprecated.
|
||||||
|
# Support for gemini-pro-vision has been removed.
|
||||||
"gemini-pro": {
|
"gemini-pro": {
|
||||||
"fn_with_ui": genai_ui,
|
"fn_with_ui": genai_ui,
|
||||||
"fn_without_ui": genai_noui,
|
"fn_without_ui": genai_noui,
|
||||||
"endpoint": gemini_endpoint,
|
"endpoint": gemini_endpoint,
|
||||||
|
"has_multimodal_capacity": False,
|
||||||
"max_token": 1024 * 32,
|
"max_token": 1024 * 32,
|
||||||
"tokenizer": tokenizer_gpt35,
|
"tokenizer": tokenizer_gpt35,
|
||||||
"token_cnt": get_token_num_gpt35,
|
"token_cnt": get_token_num_gpt35,
|
||||||
},
|
},
|
||||||
"gemini-pro-vision": {
|
"gemini-1.0-pro": {
|
||||||
"fn_with_ui": genai_ui,
|
"fn_with_ui": genai_ui,
|
||||||
"fn_without_ui": genai_noui,
|
"fn_without_ui": genai_noui,
|
||||||
"endpoint": gemini_endpoint,
|
"endpoint": gemini_endpoint,
|
||||||
|
"has_multimodal_capacity": False,
|
||||||
"max_token": 1024 * 32,
|
"max_token": 1024 * 32,
|
||||||
"tokenizer": tokenizer_gpt35,
|
"tokenizer": tokenizer_gpt35,
|
||||||
"token_cnt": get_token_num_gpt35,
|
"token_cnt": get_token_num_gpt35,
|
||||||
},
|
},
|
||||||
|
"gemini-1.5-pro": {
|
||||||
|
"fn_with_ui": genai_ui,
|
||||||
|
"fn_without_ui": genai_noui,
|
||||||
|
"endpoint": gemini_endpoint,
|
||||||
|
"has_multimodal_capacity": True,
|
||||||
|
"max_token": 1024 * 204800,
|
||||||
|
"tokenizer": tokenizer_gpt35,
|
||||||
|
"token_cnt": get_token_num_gpt35,
|
||||||
|
},
|
||||||
|
"gemini-1.5-flash": {
|
||||||
|
"fn_with_ui": genai_ui,
|
||||||
|
"fn_without_ui": genai_noui,
|
||||||
|
"endpoint": gemini_endpoint,
|
||||||
|
"has_multimodal_capacity": True,
|
||||||
|
"max_token": 1024 * 204800,
|
||||||
|
"tokenizer": tokenizer_gpt35,
|
||||||
|
"token_cnt": get_token_num_gpt35,
|
||||||
|
},
|
||||||
|
|
||||||
# cohere
|
# cohere
|
||||||
"cohere-command-r-plus": {
|
"cohere-command-r-plus": {
|
||||||
@@ -475,7 +530,7 @@ for model in AVAIL_LLM_MODELS:
|
|||||||
|
|
||||||
# -=-=-=-=-=-=- 以下部分是新加入的模型,可能附带额外依赖 -=-=-=-=-=-=-
|
# -=-=-=-=-=-=- 以下部分是新加入的模型,可能附带额外依赖 -=-=-=-=-=-=-
|
||||||
# claude家族
|
# claude家族
|
||||||
claude_models = ["claude-instant-1.2","claude-2.0","claude-2.1","claude-3-haiku-20240307","claude-3-sonnet-20240229","claude-3-opus-20240229"]
|
claude_models = ["claude-instant-1.2","claude-2.0","claude-2.1","claude-3-haiku-20240307","claude-3-sonnet-20240229","claude-3-opus-20240229","claude-3-5-sonnet-20240620"]
|
||||||
if any(item in claude_models for item in AVAIL_LLM_MODELS):
|
if any(item in claude_models for item in AVAIL_LLM_MODELS):
|
||||||
from .bridge_claude import predict_no_ui_long_connection as claude_noui
|
from .bridge_claude import predict_no_ui_long_connection as claude_noui
|
||||||
from .bridge_claude import predict as claude_ui
|
from .bridge_claude import predict as claude_ui
|
||||||
@@ -539,6 +594,16 @@ if any(item in claude_models for item in AVAIL_LLM_MODELS):
|
|||||||
"token_cnt": get_token_num_gpt35,
|
"token_cnt": get_token_num_gpt35,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
model_info.update({
|
||||||
|
"claude-3-5-sonnet-20240620": {
|
||||||
|
"fn_with_ui": claude_ui,
|
||||||
|
"fn_without_ui": claude_noui,
|
||||||
|
"endpoint": claude_endpoint,
|
||||||
|
"max_token": 200000,
|
||||||
|
"tokenizer": tokenizer_gpt35,
|
||||||
|
"token_cnt": get_token_num_gpt35,
|
||||||
|
},
|
||||||
|
})
|
||||||
if "jittorllms_rwkv" in AVAIL_LLM_MODELS:
|
if "jittorllms_rwkv" in AVAIL_LLM_MODELS:
|
||||||
from .bridge_jittorllms_rwkv import predict_no_ui_long_connection as rwkv_noui
|
from .bridge_jittorllms_rwkv import predict_no_ui_long_connection as rwkv_noui
|
||||||
from .bridge_jittorllms_rwkv import predict as rwkv_ui
|
from .bridge_jittorllms_rwkv import predict as rwkv_ui
|
||||||
@@ -619,7 +684,7 @@ if "newbing" in AVAIL_LLM_MODELS: # same with newbing-free
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
if "chatglmft" in AVAIL_LLM_MODELS: # same with newbing-free
|
if "chatglmft" in AVAIL_LLM_MODELS: # same with newbing-free
|
||||||
try:
|
try:
|
||||||
from .bridge_chatglmft import predict_no_ui_long_connection as chatglmft_noui
|
from .bridge_chatglmft import predict_no_ui_long_connection as chatglmft_noui
|
||||||
@@ -635,7 +700,7 @@ if "chatglmft" in AVAIL_LLM_MODELS: # same with newbing-free
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
# -=-=-=-=-=-=- 上海AI-LAB书生大模型 -=-=-=-=-=-=-
|
# -=-=-=-=-=-=- 上海AI-LAB书生大模型 -=-=-=-=-=-=-
|
||||||
if "internlm" in AVAIL_LLM_MODELS:
|
if "internlm" in AVAIL_LLM_MODELS:
|
||||||
try:
|
try:
|
||||||
@@ -652,7 +717,7 @@ if "internlm" in AVAIL_LLM_MODELS:
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
if "chatglm_onnx" in AVAIL_LLM_MODELS:
|
if "chatglm_onnx" in AVAIL_LLM_MODELS:
|
||||||
try:
|
try:
|
||||||
from .bridge_chatglmonnx import predict_no_ui_long_connection as chatglm_onnx_noui
|
from .bridge_chatglmonnx import predict_no_ui_long_connection as chatglm_onnx_noui
|
||||||
@@ -668,7 +733,7 @@ if "chatglm_onnx" in AVAIL_LLM_MODELS:
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
# -=-=-=-=-=-=- 通义-本地模型 -=-=-=-=-=-=-
|
# -=-=-=-=-=-=- 通义-本地模型 -=-=-=-=-=-=-
|
||||||
if "qwen-local" in AVAIL_LLM_MODELS:
|
if "qwen-local" in AVAIL_LLM_MODELS:
|
||||||
try:
|
try:
|
||||||
@@ -686,7 +751,7 @@ if "qwen-local" in AVAIL_LLM_MODELS:
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
# -=-=-=-=-=-=- 通义-在线模型 -=-=-=-=-=-=-
|
# -=-=-=-=-=-=- 通义-在线模型 -=-=-=-=-=-=-
|
||||||
if "qwen-turbo" in AVAIL_LLM_MODELS or "qwen-plus" in AVAIL_LLM_MODELS or "qwen-max" in AVAIL_LLM_MODELS: # zhipuai
|
if "qwen-turbo" in AVAIL_LLM_MODELS or "qwen-plus" in AVAIL_LLM_MODELS or "qwen-max" in AVAIL_LLM_MODELS: # zhipuai
|
||||||
try:
|
try:
|
||||||
@@ -722,7 +787,7 @@ if "qwen-turbo" in AVAIL_LLM_MODELS or "qwen-plus" in AVAIL_LLM_MODELS or "qwen-
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
# -=-=-=-=-=-=- 零一万物模型 -=-=-=-=-=-=-
|
# -=-=-=-=-=-=- 零一万物模型 -=-=-=-=-=-=-
|
||||||
yi_models = ["yi-34b-chat-0205","yi-34b-chat-200k","yi-large","yi-medium","yi-spark","yi-large-turbo","yi-large-preview"]
|
yi_models = ["yi-34b-chat-0205","yi-34b-chat-200k","yi-large","yi-medium","yi-spark","yi-large-turbo","yi-large-preview"]
|
||||||
if any(item in yi_models for item in AVAIL_LLM_MODELS):
|
if any(item in yi_models for item in AVAIL_LLM_MODELS):
|
||||||
@@ -802,7 +867,7 @@ if any(item in yi_models for item in AVAIL_LLM_MODELS):
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
# -=-=-=-=-=-=- 讯飞星火认知大模型 -=-=-=-=-=-=-
|
# -=-=-=-=-=-=- 讯飞星火认知大模型 -=-=-=-=-=-=-
|
||||||
if "spark" in AVAIL_LLM_MODELS:
|
if "spark" in AVAIL_LLM_MODELS:
|
||||||
try:
|
try:
|
||||||
@@ -820,7 +885,7 @@ if "spark" in AVAIL_LLM_MODELS:
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
if "sparkv2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
|
if "sparkv2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
|
||||||
try:
|
try:
|
||||||
from .bridge_spark import predict_no_ui_long_connection as spark_noui
|
from .bridge_spark import predict_no_ui_long_connection as spark_noui
|
||||||
@@ -837,8 +902,8 @@ if "sparkv2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
if "sparkv3" in AVAIL_LLM_MODELS or "sparkv3.5" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
|
if any(x in AVAIL_LLM_MODELS for x in ("sparkv3", "sparkv3.5", "sparkv4")): # 讯飞星火认知大模型
|
||||||
try:
|
try:
|
||||||
from .bridge_spark import predict_no_ui_long_connection as spark_noui
|
from .bridge_spark import predict_no_ui_long_connection as spark_noui
|
||||||
from .bridge_spark import predict as spark_ui
|
from .bridge_spark import predict as spark_ui
|
||||||
@@ -872,7 +937,7 @@ if "sparkv3" in AVAIL_LLM_MODELS or "sparkv3.5" in AVAIL_LLM_MODELS: # 讯飞
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
if "llama2" in AVAIL_LLM_MODELS: # llama2
|
if "llama2" in AVAIL_LLM_MODELS: # llama2
|
||||||
try:
|
try:
|
||||||
from .bridge_llama2 import predict_no_ui_long_connection as llama2_noui
|
from .bridge_llama2 import predict_no_ui_long_connection as llama2_noui
|
||||||
@@ -888,7 +953,7 @@ if "llama2" in AVAIL_LLM_MODELS: # llama2
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
# -=-=-=-=-=-=- 智谱 -=-=-=-=-=-=-
|
# -=-=-=-=-=-=- 智谱 -=-=-=-=-=-=-
|
||||||
if "zhipuai" in AVAIL_LLM_MODELS: # zhipuai 是glm-4的别名,向后兼容配置
|
if "zhipuai" in AVAIL_LLM_MODELS: # zhipuai 是glm-4的别名,向后兼容配置
|
||||||
try:
|
try:
|
||||||
@@ -903,7 +968,7 @@ if "zhipuai" in AVAIL_LLM_MODELS: # zhipuai 是glm-4的别名,向后兼容
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
# -=-=-=-=-=-=- 幻方-深度求索大模型 -=-=-=-=-=-=-
|
# -=-=-=-=-=-=- 幻方-深度求索大模型 -=-=-=-=-=-=-
|
||||||
if "deepseekcoder" in AVAIL_LLM_MODELS: # deepseekcoder
|
if "deepseekcoder" in AVAIL_LLM_MODELS: # deepseekcoder
|
||||||
try:
|
try:
|
||||||
@@ -920,7 +985,7 @@ if "deepseekcoder" in AVAIL_LLM_MODELS: # deepseekcoder
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
# -=-=-=-=-=-=- 幻方-深度求索大模型在线API -=-=-=-=-=-=-
|
# -=-=-=-=-=-=- 幻方-深度求索大模型在线API -=-=-=-=-=-=-
|
||||||
if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
|
if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
|
||||||
try:
|
try:
|
||||||
@@ -948,7 +1013,7 @@ if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
except:
|
except:
|
||||||
print(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
# -=-=-=-=-=-=- one-api 对齐支持 -=-=-=-=-=-=-
|
# -=-=-=-=-=-=- one-api 对齐支持 -=-=-=-=-=-=-
|
||||||
for model in [m for m in AVAIL_LLM_MODELS if m.startswith("one-api-")]:
|
for model in [m for m in AVAIL_LLM_MODELS if m.startswith("one-api-")]:
|
||||||
# 为了更灵活地接入one-api多模型管理界面,设计了此接口,例子:AVAIL_LLM_MODELS = ["one-api-mixtral-8x7b(max_token=6666)"]
|
# 为了更灵活地接入one-api多模型管理界面,设计了此接口,例子:AVAIL_LLM_MODELS = ["one-api-mixtral-8x7b(max_token=6666)"]
|
||||||
@@ -961,7 +1026,7 @@ for model in [m for m in AVAIL_LLM_MODELS if m.startswith("one-api-")]:
|
|||||||
# 如果是已知模型,则尝试获取其信息
|
# 如果是已知模型,则尝试获取其信息
|
||||||
original_model_info = model_info.get(origin_model_name.replace("one-api-", "", 1), None)
|
original_model_info = model_info.get(origin_model_name.replace("one-api-", "", 1), None)
|
||||||
except:
|
except:
|
||||||
print(f"one-api模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
logger.error(f"one-api模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
||||||
continue
|
continue
|
||||||
this_model_info = {
|
this_model_info = {
|
||||||
"fn_with_ui": chatgpt_ui,
|
"fn_with_ui": chatgpt_ui,
|
||||||
@@ -992,7 +1057,7 @@ for model in [m for m in AVAIL_LLM_MODELS if m.startswith("vllm-")]:
|
|||||||
try:
|
try:
|
||||||
_, max_token_tmp = read_one_api_model_name(model)
|
_, max_token_tmp = read_one_api_model_name(model)
|
||||||
except:
|
except:
|
||||||
print(f"vllm模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
logger.error(f"vllm模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
||||||
continue
|
continue
|
||||||
model_info.update({
|
model_info.update({
|
||||||
model: {
|
model: {
|
||||||
@@ -1019,7 +1084,7 @@ for model in [m for m in AVAIL_LLM_MODELS if m.startswith("ollama-")]:
|
|||||||
try:
|
try:
|
||||||
_, max_token_tmp = read_one_api_model_name(model)
|
_, max_token_tmp = read_one_api_model_name(model)
|
||||||
except:
|
except:
|
||||||
print(f"ollama模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
logger.error(f"ollama模型 {model} 的 max_token 配置不是整数,请检查配置文件。")
|
||||||
continue
|
continue
|
||||||
model_info.update({
|
model_info.update({
|
||||||
model: {
|
model: {
|
||||||
@@ -1055,6 +1120,24 @@ if len(AZURE_CFG_ARRAY) > 0:
|
|||||||
if azure_model_name not in AVAIL_LLM_MODELS:
|
if azure_model_name not in AVAIL_LLM_MODELS:
|
||||||
AVAIL_LLM_MODELS += [azure_model_name]
|
AVAIL_LLM_MODELS += [azure_model_name]
|
||||||
|
|
||||||
|
# -=-=-=-=-=-=- Openrouter模型对齐支持 -=-=-=-=-=-=-
|
||||||
|
# 为了更灵活地接入Openrouter路由,设计了此接口
|
||||||
|
for model in [m for m in AVAIL_LLM_MODELS if m.startswith("openrouter-")]:
|
||||||
|
from request_llms.bridge_openrouter import predict_no_ui_long_connection as openrouter_noui
|
||||||
|
from request_llms.bridge_openrouter import predict as openrouter_ui
|
||||||
|
model_info.update({
|
||||||
|
model: {
|
||||||
|
"fn_with_ui": openrouter_ui,
|
||||||
|
"fn_without_ui": openrouter_noui,
|
||||||
|
# 以下参数参考gpt-4o-mini的配置, 请根据实际情况修改
|
||||||
|
"endpoint": openai_endpoint,
|
||||||
|
"has_multimodal_capacity": True,
|
||||||
|
"max_token": 128000,
|
||||||
|
"tokenizer": tokenizer_gpt4,
|
||||||
|
"token_cnt": get_token_num_gpt4,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
# -=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=-=-=
|
# -=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=--=-=-=-=-=-=-=-=
|
||||||
# -=-=-=-=-=-=-=-=-=- ☝️ 以上是模型路由 -=-=-=-=-=-=-=-=-=
|
# -=-=-=-=-=-=-=-=-=- ☝️ 以上是模型路由 -=-=-=-=-=-=-=-=-=
|
||||||
@@ -1200,5 +1283,5 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot,
|
|||||||
if additional_fn: # 根据基础功能区 ModelOverride 参数调整模型类型
|
if additional_fn: # 根据基础功能区 ModelOverride 参数调整模型类型
|
||||||
llm_kwargs, additional_fn, method = execute_model_override(llm_kwargs, additional_fn, method)
|
llm_kwargs, additional_fn, method = execute_model_override(llm_kwargs, additional_fn, method)
|
||||||
|
|
||||||
|
# 更新一下llm_kwargs的参数,否则会出现参数不匹配的问题
|
||||||
yield from method(inputs, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, stream, additional_fn)
|
yield from method(inputs, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, stream, additional_fn)
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class GetGLM3Handle(LocalLLMHandle):
|
|||||||
|
|
||||||
def load_model_and_tokenizer(self):
|
def load_model_and_tokenizer(self):
|
||||||
# 🏃♂️🏃♂️🏃♂️ 子进程执行
|
# 🏃♂️🏃♂️🏃♂️ 子进程执行
|
||||||
from transformers import AutoModel, AutoTokenizer
|
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
|
||||||
import os, glob
|
import os, glob
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
@@ -45,15 +45,13 @@ class GetGLM3Handle(LocalLLMHandle):
|
|||||||
chatglm_model = AutoModel.from_pretrained(
|
chatglm_model = AutoModel.from_pretrained(
|
||||||
pretrained_model_name_or_path=_model_name_,
|
pretrained_model_name_or_path=_model_name_,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
device="cuda",
|
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
|
||||||
load_in_4bit=True,
|
|
||||||
)
|
)
|
||||||
elif LOCAL_MODEL_QUANT == "INT8": # INT8
|
elif LOCAL_MODEL_QUANT == "INT8": # INT8
|
||||||
chatglm_model = AutoModel.from_pretrained(
|
chatglm_model = AutoModel.from_pretrained(
|
||||||
pretrained_model_name_or_path=_model_name_,
|
pretrained_model_name_or_path=_model_name_,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
device="cuda",
|
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
|
||||||
load_in_8bit=True,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chatglm_model = AutoModel.from_pretrained(
|
chatglm_model = AutoModel.from_pretrained(
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
|
|
||||||
from transformers import AutoModel, AutoTokenizer
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
from loguru import logger
|
||||||
|
from toolbox import update_ui, get_conf
|
||||||
|
from multiprocessing import Process, Pipe
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
import importlib
|
import importlib
|
||||||
from toolbox import update_ui, get_conf
|
|
||||||
from multiprocessing import Process, Pipe
|
|
||||||
|
|
||||||
load_message = "ChatGLMFT尚未加载,加载需要一段时间。注意,取决于`config.py`的配置,ChatGLMFT消耗大量的内存(CPU)或显存(GPU),也许会导致低配计算机卡死 ……"
|
load_message = "ChatGLMFT尚未加载,加载需要一段时间。注意,取决于`config.py`的配置,ChatGLMFT消耗大量的内存(CPU)或显存(GPU),也许会导致低配计算机卡死 ……"
|
||||||
|
|
||||||
@@ -78,7 +79,7 @@ class GetGLMFTHandle(Process):
|
|||||||
config.pre_seq_len = model_args['pre_seq_len']
|
config.pre_seq_len = model_args['pre_seq_len']
|
||||||
config.prefix_projection = model_args['prefix_projection']
|
config.prefix_projection = model_args['prefix_projection']
|
||||||
|
|
||||||
print(f"Loading prefix_encoder weight from {CHATGLM_PTUNING_CHECKPOINT}")
|
logger.info(f"Loading prefix_encoder weight from {CHATGLM_PTUNING_CHECKPOINT}")
|
||||||
model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
|
model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
|
||||||
prefix_state_dict = torch.load(os.path.join(CHATGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
|
prefix_state_dict = torch.load(os.path.join(CHATGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
|
||||||
new_prefix_state_dict = {}
|
new_prefix_state_dict = {}
|
||||||
@@ -88,7 +89,7 @@ class GetGLMFTHandle(Process):
|
|||||||
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
||||||
|
|
||||||
if model_args['quantization_bit'] is not None and model_args['quantization_bit'] != 0:
|
if model_args['quantization_bit'] is not None and model_args['quantization_bit'] != 0:
|
||||||
print(f"Quantized to {model_args['quantization_bit']} bit")
|
logger.info(f"Quantized to {model_args['quantization_bit']} bit")
|
||||||
model = model.quantize(model_args['quantization_bit'])
|
model = model.quantize(model_args['quantization_bit'])
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
if model_args['pre_seq_len'] is not None:
|
if model_args['pre_seq_len'] is not None:
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import logging
|
|
||||||
import traceback
|
import traceback
|
||||||
import requests
|
import requests
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
# config_private.py放自己的秘密如API和代理网址
|
# config_private.py放自己的秘密如API和代理网址
|
||||||
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
||||||
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history
|
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history
|
||||||
@@ -133,21 +134,32 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
|
|||||||
observe_window = None:
|
observe_window = None:
|
||||||
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
|
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
|
||||||
"""
|
"""
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
|
||||||
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
|
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
|
||||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)
|
|
||||||
|
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
|
||||||
|
else: stream = True
|
||||||
|
|
||||||
|
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=stream)
|
||||||
retry = 0
|
retry = 0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# make a POST request to the API endpoint, stream=False
|
# make a POST request to the API endpoint, stream=False
|
||||||
from .bridge_all import model_info
|
|
||||||
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
|
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
|
||||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
|
json=payload, stream=stream, timeout=TIMEOUT_SECONDS); break
|
||||||
except requests.exceptions.ReadTimeout as e:
|
except requests.exceptions.ReadTimeout as e:
|
||||||
retry += 1
|
retry += 1
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
if retry > MAX_RETRY: raise TimeoutError
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
if MAX_RETRY!=0: logger.error(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
|
||||||
|
chunkjson = json.loads(response.content.decode())
|
||||||
|
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||||
|
return gpt_replying_buffer
|
||||||
|
|
||||||
stream_response = response.iter_lines()
|
stream_response = response.iter_lines()
|
||||||
result = ''
|
result = ''
|
||||||
@@ -190,10 +202,13 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
|
|||||||
if (time.time()-observe_window[1]) > watch_dog_patience:
|
if (time.time()-observe_window[1]) > watch_dog_patience:
|
||||||
raise RuntimeError("用户取消了程序。")
|
raise RuntimeError("用户取消了程序。")
|
||||||
else: raise RuntimeError("意外Json结构:"+delta)
|
else: raise RuntimeError("意外Json结构:"+delta)
|
||||||
if json_data and json_data['finish_reason'] == 'content_filter':
|
|
||||||
raise RuntimeError("由于提问含不合规内容被Azure过滤。")
|
finish_reason = json_data.get('finish_reason', None) if json_data else None
|
||||||
if json_data and json_data['finish_reason'] == 'length':
|
if finish_reason == 'content_filter':
|
||||||
|
raise RuntimeError("由于提问含不合规内容被过滤。")
|
||||||
|
if finish_reason == 'length':
|
||||||
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
|
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -208,7 +223,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
|
chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
|
||||||
additional_fn代表点击的哪个按钮,按钮见functional.py
|
additional_fn代表点击的哪个按钮,按钮见functional.py
|
||||||
"""
|
"""
|
||||||
from .bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
if is_any_api_key(inputs):
|
if is_any_api_key(inputs):
|
||||||
chatbot._cookies['api_key'] = inputs
|
chatbot._cookies['api_key'] = inputs
|
||||||
chatbot.append(("输入已识别为openai的api_key", what_keys(inputs)))
|
chatbot.append(("输入已识别为openai的api_key", what_keys(inputs)))
|
||||||
@@ -237,6 +252,10 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
chatbot.append((_inputs, ""))
|
chatbot.append((_inputs, ""))
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
||||||
|
|
||||||
|
# 禁用stream的特殊模型处理
|
||||||
|
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
|
||||||
|
else: stream = True
|
||||||
|
|
||||||
# check mis-behavior
|
# check mis-behavior
|
||||||
if is_the_upload_folder(user_input):
|
if is_the_upload_folder(user_input):
|
||||||
chatbot[-1] = (inputs, f"[Local Message] 检测到操作错误!当您上传文档之后,需点击“**函数插件区**”按钮进行处理,请勿点击“提交”按钮或者“基础功能区”按钮。")
|
chatbot[-1] = (inputs, f"[Local Message] 检测到操作错误!当您上传文档之后,需点击“**函数插件区**”按钮进行处理,请勿点击“提交”按钮或者“基础功能区”按钮。")
|
||||||
@@ -270,7 +289,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
try:
|
try:
|
||||||
# make a POST request to the API endpoint, stream=True
|
# make a POST request to the API endpoint, stream=True
|
||||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
|
json=payload, stream=stream, timeout=TIMEOUT_SECONDS);break
|
||||||
except:
|
except:
|
||||||
retry += 1
|
retry += 1
|
||||||
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
|
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
|
||||||
@@ -278,10 +297,15 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
|
||||||
if retry > MAX_RETRY: raise TimeoutError
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
|
|
||||||
gpt_replying_buffer = ""
|
|
||||||
|
|
||||||
is_head_of_the_stream = True
|
if not stream:
|
||||||
|
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
|
||||||
|
yield from handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history)
|
||||||
|
return
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
gpt_replying_buffer = ""
|
||||||
|
is_head_of_the_stream = True
|
||||||
stream_response = response.iter_lines()
|
stream_response = response.iter_lines()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -317,7 +341,6 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
||||||
if ('data: [DONE]' in chunk_decoded) or (len(chunkjson['choices'][0]["delta"]) == 0):
|
if ('data: [DONE]' in chunk_decoded) or (len(chunkjson['choices'][0]["delta"]) == 0):
|
||||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||||
# logging.info(f'[response] {gpt_replying_buffer}')
|
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
break
|
break
|
||||||
# 处理数据流的主体
|
# 处理数据流的主体
|
||||||
@@ -343,12 +366,24 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
chunk_decoded = chunk.decode()
|
chunk_decoded = chunk.decode()
|
||||||
error_msg = chunk_decoded
|
error_msg = chunk_decoded
|
||||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + error_msg) # 刷新界面
|
||||||
print(error_msg)
|
logger.error(error_msg)
|
||||||
return
|
return
|
||||||
|
return # return from stream-branch
|
||||||
|
|
||||||
|
def handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history):
|
||||||
|
try:
|
||||||
|
chunkjson = json.loads(response.content.decode())
|
||||||
|
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||||
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
|
history[-1] = gpt_replying_buffer
|
||||||
|
chatbot[-1] = (history[-2], history[-1])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
except Exception as e:
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + response.text) # 刷新界面
|
||||||
|
|
||||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
||||||
from .bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
|
openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
|
||||||
if "reduce the length" in error_msg:
|
if "reduce the length" in error_msg:
|
||||||
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
|
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
|
||||||
@@ -381,6 +416,8 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
|||||||
"""
|
"""
|
||||||
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
|
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
|
||||||
"""
|
"""
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
|
||||||
if not is_any_api_key(llm_kwargs['api_key']):
|
if not is_any_api_key(llm_kwargs['api_key']):
|
||||||
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
|
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
|
||||||
|
|
||||||
@@ -409,10 +446,16 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
|||||||
else:
|
else:
|
||||||
enable_multimodal_capacity = False
|
enable_multimodal_capacity = False
|
||||||
|
|
||||||
|
conversation_cnt = len(history) // 2
|
||||||
|
openai_disable_system_prompt = model_info[llm_kwargs['llm_model']].get('openai_disable_system_prompt', False)
|
||||||
|
|
||||||
|
if openai_disable_system_prompt:
|
||||||
|
messages = [{"role": "user", "content": system_prompt}]
|
||||||
|
else:
|
||||||
|
messages = [{"role": "system", "content": system_prompt}]
|
||||||
|
|
||||||
if not enable_multimodal_capacity:
|
if not enable_multimodal_capacity:
|
||||||
# 不使用多模态能力
|
# 不使用多模态能力
|
||||||
conversation_cnt = len(history) // 2
|
|
||||||
messages = [{"role": "system", "content": system_prompt}]
|
|
||||||
if conversation_cnt:
|
if conversation_cnt:
|
||||||
for index in range(0, 2*conversation_cnt, 2):
|
for index in range(0, 2*conversation_cnt, 2):
|
||||||
what_i_have_asked = {}
|
what_i_have_asked = {}
|
||||||
@@ -434,8 +477,6 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
|||||||
messages.append(what_i_ask_now)
|
messages.append(what_i_ask_now)
|
||||||
else:
|
else:
|
||||||
# 多模态能力
|
# 多模态能力
|
||||||
conversation_cnt = len(history) // 2
|
|
||||||
messages = [{"role": "system", "content": system_prompt}]
|
|
||||||
if conversation_cnt:
|
if conversation_cnt:
|
||||||
for index in range(0, 2*conversation_cnt, 2):
|
for index in range(0, 2*conversation_cnt, 2):
|
||||||
what_i_have_asked = {}
|
what_i_have_asked = {}
|
||||||
@@ -486,7 +527,6 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
|||||||
"gpt-3.5-turbo-16k-0613",
|
"gpt-3.5-turbo-16k-0613",
|
||||||
"gpt-3.5-turbo-0301",
|
"gpt-3.5-turbo-0301",
|
||||||
])
|
])
|
||||||
logging.info("Random select model:" + model)
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
@@ -496,10 +536,6 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
|||||||
"n": 1,
|
"n": 1,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
}
|
}
|
||||||
try:
|
|
||||||
print(f" {llm_kwargs['llm_model']} : {conversation_cnt} : {inputs[:100]} ..........")
|
|
||||||
except:
|
|
||||||
print('输入中可能存在乱码。')
|
|
||||||
return headers,payload
|
return headers,payload
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,15 +8,15 @@
|
|||||||
2. predict_no_ui_long_connection:支持多线程
|
2. predict_no_ui_long_connection:支持多线程
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import logging
|
|
||||||
import requests
|
import requests
|
||||||
import base64
|
import base64
|
||||||
import os
|
|
||||||
import glob
|
import glob
|
||||||
|
from loguru import logger
|
||||||
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc, is_the_upload_folder, \
|
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc, is_the_upload_folder, \
|
||||||
update_ui_lastest_msg, get_max_token, encode_image, have_any_recent_upload_image_files
|
update_ui_lastest_msg, get_max_token, encode_image, have_any_recent_upload_image_files, log_chat
|
||||||
|
|
||||||
|
|
||||||
proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG, AZURE_CFG_ARRAY = \
|
proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG, AZURE_CFG_ARRAY = \
|
||||||
@@ -100,7 +100,6 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
||||||
|
|
||||||
raw_input = inputs
|
raw_input = inputs
|
||||||
logging.info(f'[raw_input] {raw_input}')
|
|
||||||
def make_media_input(inputs, image_paths):
|
def make_media_input(inputs, image_paths):
|
||||||
for image_path in image_paths:
|
for image_path in image_paths:
|
||||||
inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
|
inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
|
||||||
@@ -185,7 +184,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||||
lastmsg = chatbot[-1][-1] + f"\n\n\n\n「{llm_kwargs['llm_model']}调用结束,该模型不具备上下文对话能力,如需追问,请及时切换模型。」"
|
lastmsg = chatbot[-1][-1] + f"\n\n\n\n「{llm_kwargs['llm_model']}调用结束,该模型不具备上下文对话能力,如需追问,请及时切换模型。」"
|
||||||
yield from update_ui_lastest_msg(lastmsg, chatbot, history, delay=1)
|
yield from update_ui_lastest_msg(lastmsg, chatbot, history, delay=1)
|
||||||
logging.info(f'[response] {gpt_replying_buffer}')
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
break
|
break
|
||||||
# 处理数据流的主体
|
# 处理数据流的主体
|
||||||
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
|
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
|
||||||
@@ -210,7 +209,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
error_msg = chunk_decoded
|
error_msg = chunk_decoded
|
||||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg, api_key)
|
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg, api_key)
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
||||||
print(error_msg)
|
logger.error(error_msg)
|
||||||
return
|
return
|
||||||
|
|
||||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg, api_key=""):
|
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg, api_key=""):
|
||||||
@@ -301,10 +300,7 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, image_paths):
|
|||||||
"presence_penalty": 0,
|
"presence_penalty": 0,
|
||||||
"frequency_penalty": 0,
|
"frequency_penalty": 0,
|
||||||
}
|
}
|
||||||
try:
|
|
||||||
print(f" {llm_kwargs['llm_model']} : {inputs[:100]} ..........")
|
|
||||||
except:
|
|
||||||
print('输入中可能存在乱码。')
|
|
||||||
return headers, payload, api_key
|
return headers, payload, api_key
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,281 +0,0 @@
|
|||||||
# 借鉴了 https://github.com/GaiZhenbiao/ChuanhuChatGPT 项目
|
|
||||||
|
|
||||||
"""
|
|
||||||
该文件中主要包含三个函数
|
|
||||||
|
|
||||||
不具备多线程能力的函数:
|
|
||||||
1. predict: 正常对话时使用,具备完备的交互功能,不可多线程
|
|
||||||
|
|
||||||
具备多线程调用能力的函数
|
|
||||||
2. predict_no_ui_long_connection:支持多线程
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import gradio as gr
|
|
||||||
import logging
|
|
||||||
import traceback
|
|
||||||
import requests
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
# config_private.py放自己的秘密如API和代理网址
|
|
||||||
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
|
||||||
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc
|
|
||||||
proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG = \
|
|
||||||
get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'API_ORG')
|
|
||||||
|
|
||||||
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
|
|
||||||
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
|
|
||||||
|
|
||||||
def get_full_error(chunk, stream_response):
|
|
||||||
"""
|
|
||||||
获取完整的从Openai返回的报错
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
chunk += next(stream_response)
|
|
||||||
except:
|
|
||||||
break
|
|
||||||
return chunk
|
|
||||||
|
|
||||||
|
|
||||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False):
|
|
||||||
"""
|
|
||||||
发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
|
|
||||||
inputs:
|
|
||||||
是本次问询的输入
|
|
||||||
sys_prompt:
|
|
||||||
系统静默prompt
|
|
||||||
llm_kwargs:
|
|
||||||
chatGPT的内部调优参数
|
|
||||||
history:
|
|
||||||
是之前的对话列表
|
|
||||||
observe_window = None:
|
|
||||||
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
|
|
||||||
"""
|
|
||||||
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
|
|
||||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)
|
|
||||||
retry = 0
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# make a POST request to the API endpoint, stream=False
|
|
||||||
from .bridge_all import model_info
|
|
||||||
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
|
|
||||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
|
||||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
|
|
||||||
except requests.exceptions.ReadTimeout as e:
|
|
||||||
retry += 1
|
|
||||||
traceback.print_exc()
|
|
||||||
if retry > MAX_RETRY: raise TimeoutError
|
|
||||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
|
||||||
|
|
||||||
stream_response = response.iter_lines()
|
|
||||||
result = ''
|
|
||||||
while True:
|
|
||||||
try: chunk = next(stream_response).decode()
|
|
||||||
except StopIteration:
|
|
||||||
break
|
|
||||||
except requests.exceptions.ConnectionError:
|
|
||||||
chunk = next(stream_response).decode() # 失败了,重试一次?再失败就没办法了。
|
|
||||||
if len(chunk)==0: continue
|
|
||||||
if not chunk.startswith('data:'):
|
|
||||||
error_msg = get_full_error(chunk.encode('utf8'), stream_response).decode()
|
|
||||||
if "reduce the length" in error_msg:
|
|
||||||
raise ConnectionAbortedError("OpenAI拒绝了请求:" + error_msg)
|
|
||||||
else:
|
|
||||||
raise RuntimeError("OpenAI拒绝了请求:" + error_msg)
|
|
||||||
if ('data: [DONE]' in chunk): break # api2d 正常完成
|
|
||||||
json_data = json.loads(chunk.lstrip('data:'))['choices'][0]
|
|
||||||
delta = json_data["delta"]
|
|
||||||
if len(delta) == 0: break
|
|
||||||
if "role" in delta: continue
|
|
||||||
if "content" in delta:
|
|
||||||
result += delta["content"]
|
|
||||||
if not console_slience: print(delta["content"], end='')
|
|
||||||
if observe_window is not None:
|
|
||||||
# 观测窗,把已经获取的数据显示出去
|
|
||||||
if len(observe_window) >= 1: observe_window[0] += delta["content"]
|
|
||||||
# 看门狗,如果超过期限没有喂狗,则终止
|
|
||||||
if len(observe_window) >= 2:
|
|
||||||
if (time.time()-observe_window[1]) > watch_dog_patience:
|
|
||||||
raise RuntimeError("用户取消了程序。")
|
|
||||||
else: raise RuntimeError("意外Json结构:"+delta)
|
|
||||||
if json_data['finish_reason'] == 'content_filter':
|
|
||||||
raise RuntimeError("由于提问含不合规内容被Azure过滤。")
|
|
||||||
if json_data['finish_reason'] == 'length':
|
|
||||||
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
|
|
||||||
"""
|
|
||||||
发送至chatGPT,流式获取输出。
|
|
||||||
用于基础的对话功能。
|
|
||||||
inputs 是本次问询的输入
|
|
||||||
top_p, temperature是chatGPT的内部调优参数
|
|
||||||
history 是之前的对话列表(注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误)
|
|
||||||
chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
|
|
||||||
additional_fn代表点击的哪个按钮,按钮见functional.py
|
|
||||||
"""
|
|
||||||
if additional_fn is not None:
|
|
||||||
from core_functional import handle_core_functionality
|
|
||||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
|
||||||
|
|
||||||
raw_input = inputs
|
|
||||||
logging.info(f'[raw_input] {raw_input}')
|
|
||||||
chatbot.append((inputs, ""))
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
|
||||||
|
|
||||||
try:
|
|
||||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt, stream)
|
|
||||||
except RuntimeError as e:
|
|
||||||
chatbot[-1] = (inputs, f"您提供的api-key不满足要求,不包含任何可用于{llm_kwargs['llm_model']}的api-key。您可能选择了错误的模型或请求源。")
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
|
|
||||||
return
|
|
||||||
|
|
||||||
history.append(inputs); history.append("")
|
|
||||||
|
|
||||||
retry = 0
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# make a POST request to the API endpoint, stream=True
|
|
||||||
from .bridge_all import model_info
|
|
||||||
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
|
|
||||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
|
||||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
|
|
||||||
except:
|
|
||||||
retry += 1
|
|
||||||
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
|
|
||||||
retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
|
|
||||||
if retry > MAX_RETRY: raise TimeoutError
|
|
||||||
|
|
||||||
gpt_replying_buffer = ""
|
|
||||||
|
|
||||||
is_head_of_the_stream = True
|
|
||||||
if stream:
|
|
||||||
stream_response = response.iter_lines()
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
chunk = next(stream_response)
|
|
||||||
except StopIteration:
|
|
||||||
# 非OpenAI官方接口的出现这样的报错,OpenAI和API2D不会走这里
|
|
||||||
chunk_decoded = chunk.decode()
|
|
||||||
error_msg = chunk_decoded
|
|
||||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="非Openai官方接口返回了错误:" + chunk.decode()) # 刷新界面
|
|
||||||
return
|
|
||||||
|
|
||||||
# print(chunk.decode()[6:])
|
|
||||||
if is_head_of_the_stream and (r'"object":"error"' not in chunk.decode()):
|
|
||||||
# 数据流的第一帧不携带content
|
|
||||||
is_head_of_the_stream = False; continue
|
|
||||||
|
|
||||||
if chunk:
|
|
||||||
try:
|
|
||||||
chunk_decoded = chunk.decode()
|
|
||||||
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
|
||||||
if 'data: [DONE]' in chunk_decoded:
|
|
||||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
|
||||||
logging.info(f'[response] {gpt_replying_buffer}')
|
|
||||||
break
|
|
||||||
# 处理数据流的主体
|
|
||||||
chunkjson = json.loads(chunk_decoded[6:])
|
|
||||||
status_text = f"finish_reason: {chunkjson['choices'][0]['finish_reason']}"
|
|
||||||
delta = chunkjson['choices'][0]["delta"]
|
|
||||||
if "content" in delta:
|
|
||||||
gpt_replying_buffer = gpt_replying_buffer + delta["content"]
|
|
||||||
history[-1] = gpt_replying_buffer
|
|
||||||
chatbot[-1] = (history[-2], history[-1])
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg=status_text) # 刷新界面
|
|
||||||
except Exception as e:
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析不合常规") # 刷新界面
|
|
||||||
chunk = get_full_error(chunk, stream_response)
|
|
||||||
chunk_decoded = chunk.decode()
|
|
||||||
error_msg = chunk_decoded
|
|
||||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
|
||||||
print(error_msg)
|
|
||||||
return
|
|
||||||
|
|
||||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
|
||||||
from .bridge_all import model_info
|
|
||||||
openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
|
|
||||||
if "reduce the length" in error_msg:
|
|
||||||
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
|
|
||||||
history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
|
|
||||||
max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])) # history至少释放二分之一
|
|
||||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
|
|
||||||
# history = [] # 清除历史
|
|
||||||
elif "does not exist" in error_msg:
|
|
||||||
chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在, 或者您没有获得体验资格.")
|
|
||||||
elif "Incorrect API key" in error_msg:
|
|
||||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由, 拒绝服务. " + openai_website)
|
|
||||||
elif "exceeded your current quota" in error_msg:
|
|
||||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由, 拒绝服务." + openai_website)
|
|
||||||
elif "account is not active" in error_msg:
|
|
||||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] Your account is not active. OpenAI以账户失效为由, 拒绝服务." + openai_website)
|
|
||||||
elif "associated with a deactivated account" in error_msg:
|
|
||||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] You are associated with a deactivated account. OpenAI以账户失效为由, 拒绝服务." + openai_website)
|
|
||||||
elif "bad forward key" in error_msg:
|
|
||||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] Bad forward key. API2D账户额度不足.")
|
|
||||||
elif "Not enough point" in error_msg:
|
|
||||||
chatbot[-1] = (chatbot[-1][0], "[Local Message] Not enough point. API2D账户点数不足.")
|
|
||||||
else:
|
|
||||||
from toolbox import regular_txt_to_markdown
|
|
||||||
tb_str = '```\n' + trimmed_format_exc() + '```'
|
|
||||||
chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk_decoded)}")
|
|
||||||
return chatbot, history
|
|
||||||
|
|
||||||
def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
|
|
||||||
"""
|
|
||||||
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
|
|
||||||
"""
|
|
||||||
if not is_any_api_key(llm_kwargs['api_key']):
|
|
||||||
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
conversation_cnt = len(history) // 2
|
|
||||||
|
|
||||||
messages = [{"role": "system", "content": system_prompt}]
|
|
||||||
if conversation_cnt:
|
|
||||||
for index in range(0, 2*conversation_cnt, 2):
|
|
||||||
what_i_have_asked = {}
|
|
||||||
what_i_have_asked["role"] = "user"
|
|
||||||
what_i_have_asked["content"] = history[index]
|
|
||||||
what_gpt_answer = {}
|
|
||||||
what_gpt_answer["role"] = "assistant"
|
|
||||||
what_gpt_answer["content"] = history[index+1]
|
|
||||||
if what_i_have_asked["content"] != "":
|
|
||||||
if what_gpt_answer["content"] == "": continue
|
|
||||||
if what_gpt_answer["content"] == timeout_bot_msg: continue
|
|
||||||
messages.append(what_i_have_asked)
|
|
||||||
messages.append(what_gpt_answer)
|
|
||||||
else:
|
|
||||||
messages[-1]['content'] = what_gpt_answer['content']
|
|
||||||
|
|
||||||
what_i_ask_now = {}
|
|
||||||
what_i_ask_now["role"] = "user"
|
|
||||||
what_i_ask_now["content"] = inputs
|
|
||||||
messages.append(what_i_ask_now)
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": llm_kwargs['llm_model'].strip('api2d-'),
|
|
||||||
"messages": messages,
|
|
||||||
"temperature": llm_kwargs['temperature'], # 1.0,
|
|
||||||
"top_p": llm_kwargs['top_p'], # 1.0,
|
|
||||||
"n": 1,
|
|
||||||
"stream": stream,
|
|
||||||
"presence_penalty": 0,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
print(f" {llm_kwargs['llm_model']} : {conversation_cnt} : {inputs[:100]} ..........")
|
|
||||||
except:
|
|
||||||
print('输入中可能存在乱码。')
|
|
||||||
return headers,payload
|
|
||||||
|
|
||||||
|
|
||||||
@@ -9,15 +9,16 @@
|
|||||||
具备多线程调用能力的函数
|
具备多线程调用能力的函数
|
||||||
2. predict_no_ui_long_connection:支持多线程
|
2. predict_no_ui_long_connection:支持多线程
|
||||||
"""
|
"""
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
|
from loguru import logger
|
||||||
from toolbox import get_conf, update_ui, trimmed_format_exc, encode_image, every_image_file_in_path, log_chat
|
from toolbox import get_conf, update_ui, trimmed_format_exc, encode_image, every_image_file_in_path, log_chat
|
||||||
|
|
||||||
picture_system_prompt = "\n当回复图像时,必须说明正在回复哪张图像。所有图像仅在最后一个问题中提供,即使它们在历史记录中被提及。请使用'这是第X张图像:'的格式来指明您正在描述的是哪张图像。"
|
picture_system_prompt = "\n当回复图像时,必须说明正在回复哪张图像。所有图像仅在最后一个问题中提供,即使它们在历史记录中被提及。请使用'这是第X张图像:'的格式来指明您正在描述的是哪张图像。"
|
||||||
Claude_3_Models = ["claude-3-haiku-20240307", "claude-3-sonnet-20240229", "claude-3-opus-20240229"]
|
Claude_3_Models = ["claude-3-haiku-20240307", "claude-3-sonnet-20240229", "claude-3-opus-20240229", "claude-3-5-sonnet-20240620"]
|
||||||
|
|
||||||
# config_private.py放自己的秘密如API和代理网址
|
# config_private.py放自己的秘密如API和代理网址
|
||||||
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
||||||
@@ -101,7 +102,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
retry += 1
|
retry += 1
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
if retry > MAX_RETRY: raise TimeoutError
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
if MAX_RETRY!=0: logger.error(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||||
stream_response = response.iter_lines()
|
stream_response = response.iter_lines()
|
||||||
result = ''
|
result = ''
|
||||||
while True:
|
while True:
|
||||||
@@ -116,12 +117,11 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
if need_to_pass:
|
if need_to_pass:
|
||||||
pass
|
pass
|
||||||
elif is_last_chunk:
|
elif is_last_chunk:
|
||||||
# logging.info(f'[response] {result}')
|
# logger.info(f'[response] {result}')
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if chunkjson and chunkjson['type'] == 'content_block_delta':
|
if chunkjson and chunkjson['type'] == 'content_block_delta':
|
||||||
result += chunkjson['delta']['text']
|
result += chunkjson['delta']['text']
|
||||||
print(chunkjson['delta']['text'], end='')
|
|
||||||
if observe_window is not None:
|
if observe_window is not None:
|
||||||
# 观测窗,把已经获取的数据显示出去
|
# 观测窗,把已经获取的数据显示出去
|
||||||
if len(observe_window) >= 1:
|
if len(observe_window) >= 1:
|
||||||
@@ -134,7 +134,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
chunk = get_full_error(chunk, stream_response)
|
chunk = get_full_error(chunk, stream_response)
|
||||||
chunk_decoded = chunk.decode()
|
chunk_decoded = chunk.decode()
|
||||||
error_msg = chunk_decoded
|
error_msg = chunk_decoded
|
||||||
print(error_msg)
|
logger.error(error_msg)
|
||||||
raise RuntimeError("Json解析不合常规")
|
raise RuntimeError("Json解析不合常规")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -200,7 +200,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
retry += 1
|
retry += 1
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
if retry > MAX_RETRY: raise TimeoutError
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
if MAX_RETRY!=0: logger.error(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||||
stream_response = response.iter_lines()
|
stream_response = response.iter_lines()
|
||||||
gpt_replying_buffer = ""
|
gpt_replying_buffer = ""
|
||||||
|
|
||||||
@@ -217,7 +217,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
pass
|
pass
|
||||||
elif is_last_chunk:
|
elif is_last_chunk:
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
# logging.info(f'[response] {gpt_replying_buffer}')
|
# logger.info(f'[response] {gpt_replying_buffer}')
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if chunkjson and chunkjson['type'] == 'content_block_delta':
|
if chunkjson and chunkjson['type'] == 'content_block_delta':
|
||||||
@@ -230,7 +230,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
chunk = get_full_error(chunk, stream_response)
|
chunk = get_full_error(chunk, stream_response)
|
||||||
chunk_decoded = chunk.decode()
|
chunk_decoded = chunk.decode()
|
||||||
error_msg = chunk_decoded
|
error_msg = chunk_decoded
|
||||||
print(error_msg)
|
logger.error(error_msg)
|
||||||
raise RuntimeError("Json解析不合常规")
|
raise RuntimeError("Json解析不合常规")
|
||||||
|
|
||||||
def multiple_picture_types(image_paths):
|
def multiple_picture_types(image_paths):
|
||||||
|
|||||||
@@ -13,11 +13,9 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import logging
|
|
||||||
import traceback
|
import traceback
|
||||||
import requests
|
import requests
|
||||||
import importlib
|
from loguru import logger
|
||||||
import random
|
|
||||||
|
|
||||||
# config_private.py放自己的秘密如API和代理网址
|
# config_private.py放自己的秘密如API和代理网址
|
||||||
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
||||||
@@ -98,7 +96,7 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
|
|||||||
retry += 1
|
retry += 1
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
if retry > MAX_RETRY: raise TimeoutError
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
if MAX_RETRY!=0: logger.error(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||||
|
|
||||||
stream_response = response.iter_lines()
|
stream_response = response.iter_lines()
|
||||||
result = ''
|
result = ''
|
||||||
@@ -153,7 +151,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
||||||
|
|
||||||
raw_input = inputs
|
raw_input = inputs
|
||||||
# logging.info(f'[raw_input] {raw_input}')
|
# logger.info(f'[raw_input] {raw_input}')
|
||||||
chatbot.append((inputs, ""))
|
chatbot.append((inputs, ""))
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
||||||
|
|
||||||
@@ -237,7 +235,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
error_msg = chunk_decoded
|
error_msg = chunk_decoded
|
||||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
||||||
print(error_msg)
|
logger.error(error_msg)
|
||||||
return
|
return
|
||||||
|
|
||||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
model_name = "deepseek-coder-6.7b-instruct"
|
model_name = "deepseek-coder-6.7b-instruct"
|
||||||
cmd_to_install = "未知" # "`pip install -r request_llms/requirements_qwen.txt`"
|
cmd_to_install = "未知" # "`pip install -r request_llms/requirements_qwen.txt`"
|
||||||
|
|
||||||
import os
|
|
||||||
from toolbox import ProxyNetworkActivate
|
from toolbox import ProxyNetworkActivate
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns
|
from request_llms.local_llm_class import LocalLLMHandle, get_local_llm_predict_fns
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from loguru import logger
|
||||||
import torch
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
def download_huggingface_model(model_name, max_retry, local_dir):
|
def download_huggingface_model(model_name, max_retry, local_dir):
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
@@ -15,7 +16,7 @@ def download_huggingface_model(model_name, max_retry, local_dir):
|
|||||||
snapshot_download(repo_id=model_name, local_dir=local_dir, resume_download=True)
|
snapshot_download(repo_id=model_name, local_dir=local_dir, resume_download=True)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'\n\n下载失败,重试第{i}次中...\n\n')
|
logger.error(f'\n\n下载失败,重试第{i}次中...\n\n')
|
||||||
return local_dir
|
return local_dir
|
||||||
# ------------------------------------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------------------------------------
|
||||||
# 🔌💻 Local Model
|
# 🔌💻 Local Model
|
||||||
@@ -112,7 +113,6 @@ class GetCoderLMHandle(LocalLLMHandle):
|
|||||||
generated_text = ""
|
generated_text = ""
|
||||||
for new_text in self._streamer:
|
for new_text in self._streamer:
|
||||||
generated_text += new_text
|
generated_text += new_text
|
||||||
# print(generated_text)
|
|
||||||
yield generated_text
|
yield generated_text
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,15 +8,15 @@ import os
|
|||||||
import time
|
import time
|
||||||
from request_llms.com_google import GoogleChatInit
|
from request_llms.com_google import GoogleChatInit
|
||||||
from toolbox import ChatBotWithCookies
|
from toolbox import ChatBotWithCookies
|
||||||
from toolbox import get_conf, update_ui, update_ui_lastest_msg, have_any_recent_upload_image_files, trimmed_format_exc, log_chat
|
from toolbox import get_conf, update_ui, update_ui_lastest_msg, have_any_recent_upload_image_files, trimmed_format_exc, log_chat, encode_image
|
||||||
|
|
||||||
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
|
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
|
||||||
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
|
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
|
||||||
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
|
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
|
||||||
|
|
||||||
|
|
||||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
|
def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[], sys_prompt:str="", observe_window:list=[],
|
||||||
console_slience=False):
|
console_slience:bool=False):
|
||||||
# 检查API_KEY
|
# 检查API_KEY
|
||||||
if get_conf("GEMINI_API_KEY") == "":
|
if get_conf("GEMINI_API_KEY") == "":
|
||||||
raise ValueError(f"请配置 GEMINI_API_KEY。")
|
raise ValueError(f"请配置 GEMINI_API_KEY。")
|
||||||
@@ -44,9 +44,20 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
|
raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
|
||||||
return gpt_replying_buffer
|
return gpt_replying_buffer
|
||||||
|
|
||||||
|
def make_media_input(inputs, image_paths):
|
||||||
|
image_base64_array = []
|
||||||
|
for image_path in image_paths:
|
||||||
|
path = os.path.abspath(image_path)
|
||||||
|
inputs = inputs + f'<br/><br/><div align="center"><img src="file={path}"></div>'
|
||||||
|
base64 = encode_image(path)
|
||||||
|
image_base64_array.append(base64)
|
||||||
|
return inputs, image_base64_array
|
||||||
|
|
||||||
def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWithCookies,
|
def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWithCookies,
|
||||||
history:list=[], system_prompt:str='', stream:bool=True, additional_fn:str=None):
|
history:list=[], system_prompt:str='', stream:bool=True, additional_fn:str=None):
|
||||||
|
|
||||||
|
from .bridge_all import model_info
|
||||||
|
|
||||||
# 检查API_KEY
|
# 检查API_KEY
|
||||||
if get_conf("GEMINI_API_KEY") == "":
|
if get_conf("GEMINI_API_KEY") == "":
|
||||||
yield from update_ui_lastest_msg(f"请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
|
yield from update_ui_lastest_msg(f"请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
|
||||||
@@ -57,18 +68,17 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
from core_functional import handle_core_functionality
|
from core_functional import handle_core_functionality
|
||||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
||||||
|
|
||||||
if "vision" in llm_kwargs["llm_model"]:
|
# multimodal capacity
|
||||||
have_recent_file, image_paths = have_any_recent_upload_image_files(chatbot)
|
# inspired by codes in bridge_chatgpt
|
||||||
if not have_recent_file:
|
has_multimodal_capacity = model_info[llm_kwargs['llm_model']].get('has_multimodal_capacity', False)
|
||||||
chatbot.append((inputs, "没有检测到任何近期上传的图像文件,请上传jpg格式的图片,此外,请注意拓展名需要小写"))
|
if has_multimodal_capacity:
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待图片") # 刷新界面
|
has_recent_image_upload, image_paths = have_any_recent_upload_image_files(chatbot, pop=True)
|
||||||
return
|
else:
|
||||||
def make_media_input(inputs, image_paths):
|
has_recent_image_upload, image_paths = False, []
|
||||||
for image_path in image_paths:
|
if has_recent_image_upload:
|
||||||
inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
|
inputs, image_base64_array = make_media_input(inputs, image_paths)
|
||||||
return inputs
|
else:
|
||||||
if have_recent_file:
|
inputs, image_base64_array = inputs, []
|
||||||
inputs = make_media_input(inputs, image_paths)
|
|
||||||
|
|
||||||
chatbot.append((inputs, ""))
|
chatbot.append((inputs, ""))
|
||||||
yield from update_ui(chatbot=chatbot, history=history)
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
@@ -76,7 +86,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
retry = 0
|
retry = 0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
|
stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt, image_base64_array, has_multimodal_capacity)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
retry += 1
|
retry += 1
|
||||||
@@ -112,7 +122,6 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
yield from update_ui(chatbot=chatbot, history=history)
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import sys
|
import sys
|
||||||
llm_kwargs = {'llm_model': 'gemini-pro'}
|
llm_kwargs = {'llm_model': 'gemini-pro'}
|
||||||
|
|||||||
@@ -65,10 +65,10 @@ class GetInternlmHandle(LocalLLMHandle):
|
|||||||
|
|
||||||
def llm_stream_generator(self, **kwargs):
|
def llm_stream_generator(self, **kwargs):
|
||||||
import torch
|
import torch
|
||||||
import logging
|
|
||||||
import copy
|
import copy
|
||||||
import warnings
|
import warnings
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from loguru import logger as logging
|
||||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
||||||
|
|
||||||
# 🏃♂️🏃♂️🏃♂️ 子进程执行
|
# 🏃♂️🏃♂️🏃♂️ 子进程执行
|
||||||
@@ -119,7 +119,7 @@ class GetInternlmHandle(LocalLLMHandle):
|
|||||||
elif generation_config.max_new_tokens is not None:
|
elif generation_config.max_new_tokens is not None:
|
||||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||||
if not has_default_max_length:
|
if not has_default_max_length:
|
||||||
logging.warn(
|
logging.warning(
|
||||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||||
"Please refer to the documentation for more information. "
|
"Please refer to the documentation for more information. "
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
|
||||||
|
|
||||||
from toolbox import get_conf, update_ui, log_chat
|
from toolbox import get_conf, update_ui, log_chat
|
||||||
from toolbox import ChatBotWithCookies
|
from toolbox import ChatBotWithCookies
|
||||||
|
|||||||
@@ -13,11 +13,11 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import logging
|
|
||||||
import traceback
|
import traceback
|
||||||
import requests
|
import requests
|
||||||
import importlib
|
import importlib
|
||||||
import random
|
import random
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
# config_private.py放自己的秘密如API和代理网址
|
# config_private.py放自己的秘密如API和代理网址
|
||||||
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
||||||
@@ -81,7 +81,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
retry += 1
|
retry += 1
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
if retry > MAX_RETRY: raise TimeoutError
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
if MAX_RETRY!=0: logger.error(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||||
|
|
||||||
stream_response = response.iter_lines()
|
stream_response = response.iter_lines()
|
||||||
result = ''
|
result = ''
|
||||||
@@ -96,7 +96,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
try:
|
try:
|
||||||
if is_last_chunk:
|
if is_last_chunk:
|
||||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||||
logging.info(f'[response] {result}')
|
logger.info(f'[response] {result}')
|
||||||
break
|
break
|
||||||
result += chunkjson['message']["content"]
|
result += chunkjson['message']["content"]
|
||||||
if not console_slience: print(chunkjson['message']["content"], end='')
|
if not console_slience: print(chunkjson['message']["content"], end='')
|
||||||
@@ -112,7 +112,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
chunk = get_full_error(chunk, stream_response)
|
chunk = get_full_error(chunk, stream_response)
|
||||||
chunk_decoded = chunk.decode()
|
chunk_decoded = chunk.decode()
|
||||||
error_msg = chunk_decoded
|
error_msg = chunk_decoded
|
||||||
print(error_msg)
|
logger.error(error_msg)
|
||||||
raise RuntimeError("Json解析不合常规")
|
raise RuntimeError("Json解析不合常规")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -134,7 +134,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
||||||
|
|
||||||
raw_input = inputs
|
raw_input = inputs
|
||||||
logging.info(f'[raw_input] {raw_input}')
|
logger.info(f'[raw_input] {raw_input}')
|
||||||
chatbot.append((inputs, ""))
|
chatbot.append((inputs, ""))
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
||||||
|
|
||||||
@@ -183,7 +183,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
try:
|
try:
|
||||||
if is_last_chunk:
|
if is_last_chunk:
|
||||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||||
logging.info(f'[response] {gpt_replying_buffer}')
|
logger.info(f'[response] {gpt_replying_buffer}')
|
||||||
break
|
break
|
||||||
# 处理数据流的主体
|
# 处理数据流的主体
|
||||||
try:
|
try:
|
||||||
@@ -202,7 +202,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
error_msg = chunk_decoded
|
error_msg = chunk_decoded
|
||||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
||||||
print(error_msg)
|
logger.error(error_msg)
|
||||||
return
|
return
|
||||||
|
|
||||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
||||||
@@ -265,8 +265,5 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
|
|||||||
"messages": messages,
|
"messages": messages,
|
||||||
"options": options,
|
"options": options,
|
||||||
}
|
}
|
||||||
try:
|
|
||||||
print(f" {llm_kwargs['llm_model']} : {conversation_cnt} : {inputs[:100]} ..........")
|
|
||||||
except:
|
|
||||||
print('输入中可能存在乱码。')
|
|
||||||
return headers,payload
|
return headers,payload
|
||||||
|
|||||||
541
request_llms/bridge_openrouter.py
Normal file
541
request_llms/bridge_openrouter.py
Normal file
@@ -0,0 +1,541 @@
|
|||||||
|
"""
|
||||||
|
该文件中主要包含三个函数
|
||||||
|
|
||||||
|
不具备多线程能力的函数:
|
||||||
|
1. predict: 正常对话时使用,具备完备的交互功能,不可多线程
|
||||||
|
|
||||||
|
具备多线程调用能力的函数
|
||||||
|
2. predict_no_ui_long_connection:支持多线程
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import requests
|
||||||
|
import random
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
# config_private.py放自己的秘密如API和代理网址
|
||||||
|
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
||||||
|
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history
|
||||||
|
from toolbox import trimmed_format_exc, is_the_upload_folder, read_one_api_model_name, log_chat
|
||||||
|
from toolbox import ChatBotWithCookies, have_any_recent_upload_image_files, encode_image
|
||||||
|
proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG, AZURE_CFG_ARRAY = \
|
||||||
|
get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'API_ORG', 'AZURE_CFG_ARRAY')
|
||||||
|
|
||||||
|
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
|
||||||
|
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
|
||||||
|
|
||||||
|
def get_full_error(chunk, stream_response):
|
||||||
|
"""
|
||||||
|
获取完整的从Openai返回的报错
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk += next(stream_response)
|
||||||
|
except:
|
||||||
|
break
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
def make_multimodal_input(inputs, image_paths):
|
||||||
|
image_base64_array = []
|
||||||
|
for image_path in image_paths:
|
||||||
|
path = os.path.abspath(image_path)
|
||||||
|
base64 = encode_image(path)
|
||||||
|
inputs = inputs + f'<br/><br/><div align="center"><img src="file={path}" base64="{base64}"></div>'
|
||||||
|
image_base64_array.append(base64)
|
||||||
|
return inputs, image_base64_array
|
||||||
|
|
||||||
|
def reverse_base64_from_input(inputs):
|
||||||
|
# 定义一个正则表达式来匹配 Base64 字符串(假设格式为 base64="<Base64编码>")
|
||||||
|
# pattern = re.compile(r'base64="([^"]+)"></div>')
|
||||||
|
pattern = re.compile(r'<br/><br/><div align="center"><img[^<>]+base64="([^"]+)"></div>')
|
||||||
|
# 使用 findall 方法查找所有匹配的 Base64 字符串
|
||||||
|
base64_strings = pattern.findall(inputs)
|
||||||
|
# 返回反转后的 Base64 字符串列表
|
||||||
|
return base64_strings
|
||||||
|
|
||||||
|
def contain_base64(inputs):
|
||||||
|
base64_strings = reverse_base64_from_input(inputs)
|
||||||
|
return len(base64_strings) > 0
|
||||||
|
|
||||||
|
def append_image_if_contain_base64(inputs):
|
||||||
|
if not contain_base64(inputs):
|
||||||
|
return inputs
|
||||||
|
else:
|
||||||
|
image_base64_array = reverse_base64_from_input(inputs)
|
||||||
|
pattern = re.compile(r'<br/><br/><div align="center"><img[^><]+></div>')
|
||||||
|
inputs = re.sub(pattern, '', inputs)
|
||||||
|
res = []
|
||||||
|
res.append({
|
||||||
|
"type": "text",
|
||||||
|
"text": inputs
|
||||||
|
})
|
||||||
|
for image_base64 in image_base64_array:
|
||||||
|
res.append({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return res
|
||||||
|
|
||||||
|
def remove_image_if_contain_base64(inputs):
|
||||||
|
if not contain_base64(inputs):
|
||||||
|
return inputs
|
||||||
|
else:
|
||||||
|
pattern = re.compile(r'<br/><br/><div align="center"><img[^><]+></div>')
|
||||||
|
inputs = re.sub(pattern, '', inputs)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def decode_chunk(chunk):
|
||||||
|
# 提前读取一些信息 (用于判断异常)
|
||||||
|
chunk_decoded = chunk.decode()
|
||||||
|
chunkjson = None
|
||||||
|
has_choices = False
|
||||||
|
choice_valid = False
|
||||||
|
has_content = False
|
||||||
|
has_role = False
|
||||||
|
try:
|
||||||
|
chunkjson = json.loads(chunk_decoded[6:])
|
||||||
|
has_choices = 'choices' in chunkjson
|
||||||
|
if has_choices: choice_valid = (len(chunkjson['choices']) > 0)
|
||||||
|
if has_choices and choice_valid: has_content = ("content" in chunkjson['choices'][0]["delta"])
|
||||||
|
if has_content: has_content = (chunkjson['choices'][0]["delta"]["content"] is not None)
|
||||||
|
if has_choices and choice_valid: has_role = "role" in chunkjson['choices'][0]["delta"]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return chunk_decoded, chunkjson, has_choices, choice_valid, has_content, has_role
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
@lru_cache(maxsize=32)
|
||||||
|
def verify_endpoint(endpoint):
|
||||||
|
"""
|
||||||
|
检查endpoint是否可用
|
||||||
|
"""
|
||||||
|
if "你亲手写的api名称" in endpoint:
|
||||||
|
raise ValueError("Endpoint不正确, 请检查AZURE_ENDPOINT的配置! 当前的Endpoint为:" + endpoint)
|
||||||
|
return endpoint
|
||||||
|
|
||||||
|
def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[], sys_prompt:str="", observe_window:list=None, console_slience:bool=False):
|
||||||
|
"""
|
||||||
|
发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
|
||||||
|
inputs:
|
||||||
|
是本次问询的输入
|
||||||
|
sys_prompt:
|
||||||
|
系统静默prompt
|
||||||
|
llm_kwargs:
|
||||||
|
chatGPT的内部调优参数
|
||||||
|
history:
|
||||||
|
是之前的对话列表
|
||||||
|
observe_window = None:
|
||||||
|
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
|
||||||
|
"""
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
|
||||||
|
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
|
||||||
|
|
||||||
|
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
|
||||||
|
else: stream = True
|
||||||
|
|
||||||
|
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=stream)
|
||||||
|
retry = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# make a POST request to the API endpoint, stream=False
|
||||||
|
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
|
||||||
|
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||||
|
json=payload, stream=stream, timeout=TIMEOUT_SECONDS); break
|
||||||
|
except requests.exceptions.ReadTimeout as e:
|
||||||
|
retry += 1
|
||||||
|
traceback.print_exc()
|
||||||
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
|
if MAX_RETRY!=0: logger.error(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
|
||||||
|
chunkjson = json.loads(response.content.decode())
|
||||||
|
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||||
|
return gpt_replying_buffer
|
||||||
|
|
||||||
|
stream_response = response.iter_lines()
|
||||||
|
result = ''
|
||||||
|
json_data = None
|
||||||
|
while True:
|
||||||
|
try: chunk = next(stream_response)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
chunk = next(stream_response) # 失败了,重试一次?再失败就没办法了。
|
||||||
|
chunk_decoded, chunkjson, has_choices, choice_valid, has_content, has_role = decode_chunk(chunk)
|
||||||
|
if len(chunk_decoded)==0: continue
|
||||||
|
if not chunk_decoded.startswith('data:'):
|
||||||
|
error_msg = get_full_error(chunk, stream_response).decode()
|
||||||
|
if "reduce the length" in error_msg:
|
||||||
|
raise ConnectionAbortedError("OpenAI拒绝了请求:" + error_msg)
|
||||||
|
elif """type":"upstream_error","param":"307""" in error_msg:
|
||||||
|
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
|
||||||
|
else:
|
||||||
|
raise RuntimeError("OpenAI拒绝了请求:" + error_msg)
|
||||||
|
if ('data: [DONE]' in chunk_decoded): break # api2d 正常完成
|
||||||
|
# 提前读取一些信息 (用于判断异常)
|
||||||
|
if (has_choices and not choice_valid) or ('OPENROUTER PROCESSING' in chunk_decoded):
|
||||||
|
# 一些垃圾第三方接口的出现这样的错误,openrouter的特殊处理
|
||||||
|
continue
|
||||||
|
json_data = chunkjson['choices'][0]
|
||||||
|
delta = json_data["delta"]
|
||||||
|
if len(delta) == 0: break
|
||||||
|
if (not has_content) and has_role: continue
|
||||||
|
if (not has_content) and (not has_role): continue # raise RuntimeError("发现不标准的第三方接口:"+delta)
|
||||||
|
if has_content: # has_role = True/False
|
||||||
|
result += delta["content"]
|
||||||
|
if not console_slience: print(delta["content"], end='')
|
||||||
|
if observe_window is not None:
|
||||||
|
# 观测窗,把已经获取的数据显示出去
|
||||||
|
if len(observe_window) >= 1:
|
||||||
|
observe_window[0] += delta["content"]
|
||||||
|
# 看门狗,如果超过期限没有喂狗,则终止
|
||||||
|
if len(observe_window) >= 2:
|
||||||
|
if (time.time()-observe_window[1]) > watch_dog_patience:
|
||||||
|
raise RuntimeError("用户取消了程序。")
|
||||||
|
else: raise RuntimeError("意外Json结构:"+delta)
|
||||||
|
if json_data and json_data['finish_reason'] == 'content_filter':
|
||||||
|
raise RuntimeError("由于提问含不合规内容被Azure过滤。")
|
||||||
|
if json_data and json_data['finish_reason'] == 'length':
|
||||||
|
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWithCookies,
|
||||||
|
history:list=[], system_prompt:str='', stream:bool=True, additional_fn:str=None):
|
||||||
|
"""
|
||||||
|
发送至chatGPT,流式获取输出。
|
||||||
|
用于基础的对话功能。
|
||||||
|
inputs 是本次问询的输入
|
||||||
|
top_p, temperature是chatGPT的内部调优参数
|
||||||
|
history 是之前的对话列表(注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误)
|
||||||
|
chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
|
||||||
|
additional_fn代表点击的哪个按钮,按钮见functional.py
|
||||||
|
"""
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
if is_any_api_key(inputs):
|
||||||
|
chatbot._cookies['api_key'] = inputs
|
||||||
|
chatbot.append(("输入已识别为openai的api_key", what_keys(inputs)))
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="api_key已导入") # 刷新界面
|
||||||
|
return
|
||||||
|
elif not is_any_api_key(chatbot._cookies['api_key']):
|
||||||
|
chatbot.append((inputs, "缺少api_key。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。"))
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="缺少api_key") # 刷新界面
|
||||||
|
return
|
||||||
|
|
||||||
|
user_input = inputs
|
||||||
|
if additional_fn is not None:
|
||||||
|
from core_functional import handle_core_functionality
|
||||||
|
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
||||||
|
|
||||||
|
# 多模态模型
|
||||||
|
has_multimodal_capacity = model_info[llm_kwargs['llm_model']].get('has_multimodal_capacity', False)
|
||||||
|
if has_multimodal_capacity:
|
||||||
|
has_recent_image_upload, image_paths = have_any_recent_upload_image_files(chatbot, pop=True)
|
||||||
|
else:
|
||||||
|
has_recent_image_upload, image_paths = False, []
|
||||||
|
if has_recent_image_upload:
|
||||||
|
_inputs, image_base64_array = make_multimodal_input(inputs, image_paths)
|
||||||
|
else:
|
||||||
|
_inputs, image_base64_array = inputs, []
|
||||||
|
chatbot.append((_inputs, ""))
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
||||||
|
|
||||||
|
# 禁用stream的特殊模型处理
|
||||||
|
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
|
||||||
|
else: stream = True
|
||||||
|
|
||||||
|
# check mis-behavior
|
||||||
|
if is_the_upload_folder(user_input):
|
||||||
|
chatbot[-1] = (inputs, f"[Local Message] 检测到操作错误!当您上传文档之后,需点击“**函数插件区**”按钮进行处理,请勿点击“提交”按钮或者“基础功能区”按钮。")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="正常") # 刷新界面
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt, image_base64_array, has_multimodal_capacity, stream)
|
||||||
|
except RuntimeError as e:
|
||||||
|
chatbot[-1] = (inputs, f"您提供的api-key不满足要求,不包含任何可用于{llm_kwargs['llm_model']}的api-key。您可能选择了错误的模型或请求源。")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查endpoint是否合法
|
||||||
|
try:
|
||||||
|
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
|
||||||
|
except:
|
||||||
|
tb_str = '```\n' + trimmed_format_exc() + '```'
|
||||||
|
chatbot[-1] = (inputs, tb_str)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="Endpoint不满足要求") # 刷新界面
|
||||||
|
return
|
||||||
|
|
||||||
|
# 加入历史
|
||||||
|
if has_recent_image_upload:
|
||||||
|
history.extend([_inputs, ""])
|
||||||
|
else:
|
||||||
|
history.extend([inputs, ""])
|
||||||
|
|
||||||
|
retry = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# make a POST request to the API endpoint, stream=True
|
||||||
|
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||||
|
json=payload, stream=stream, timeout=TIMEOUT_SECONDS);break
|
||||||
|
except:
|
||||||
|
retry += 1
|
||||||
|
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
|
||||||
|
retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
|
||||||
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
|
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
|
||||||
|
yield from handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history)
|
||||||
|
return
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
gpt_replying_buffer = ""
|
||||||
|
is_head_of_the_stream = True
|
||||||
|
stream_response = response.iter_lines()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk = next(stream_response)
|
||||||
|
except StopIteration:
|
||||||
|
# 非OpenAI官方接口的出现这样的报错,OpenAI和API2D不会走这里
|
||||||
|
chunk_decoded = chunk.decode()
|
||||||
|
error_msg = chunk_decoded
|
||||||
|
# 首先排除一个one-api没有done数据包的第三方Bug情形
|
||||||
|
if len(gpt_replying_buffer.strip()) > 0 and len(error_msg) == 0:
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="检测到有缺陷的非OpenAI官方接口,建议选择更稳定的接口。")
|
||||||
|
break
|
||||||
|
# 其他情况,直接返回报错
|
||||||
|
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="非OpenAI官方接口返回了错误:" + chunk.decode()) # 刷新界面
|
||||||
|
return
|
||||||
|
|
||||||
|
# 提前读取一些信息 (用于判断异常)
|
||||||
|
chunk_decoded, chunkjson, has_choices, choice_valid, has_content, has_role = decode_chunk(chunk)
|
||||||
|
|
||||||
|
if is_head_of_the_stream and (r'"object":"error"' not in chunk_decoded) and (r"content" not in chunk_decoded):
|
||||||
|
# 数据流的第一帧不携带content
|
||||||
|
is_head_of_the_stream = False; continue
|
||||||
|
|
||||||
|
if chunk:
|
||||||
|
try:
|
||||||
|
if (has_choices and not choice_valid) or ('OPENROUTER PROCESSING' in chunk_decoded):
|
||||||
|
# 一些垃圾第三方接口的出现这样的错误, 或者OPENROUTER的特殊处理,因为OPENROUTER的数据流未连接到模型时会出现OPENROUTER PROCESSING
|
||||||
|
continue
|
||||||
|
if ('data: [DONE]' not in chunk_decoded) and len(chunk_decoded) > 0 and (chunkjson is None):
|
||||||
|
# 传递进来一些奇怪的东西
|
||||||
|
raise ValueError(f'无法读取以下数据,请检查配置。\n\n{chunk_decoded}')
|
||||||
|
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
||||||
|
if ('data: [DONE]' in chunk_decoded) or (len(chunkjson['choices'][0]["delta"]) == 0):
|
||||||
|
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||||
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
|
break
|
||||||
|
# 处理数据流的主体
|
||||||
|
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
|
||||||
|
# 如果这里抛出异常,一般是文本过长,详情见get_full_error的输出
|
||||||
|
if has_content:
|
||||||
|
# 正常情况
|
||||||
|
gpt_replying_buffer = gpt_replying_buffer + chunkjson['choices'][0]["delta"]["content"]
|
||||||
|
elif has_role:
|
||||||
|
# 一些第三方接口的出现这样的错误,兼容一下吧
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 至此已经超出了正常接口应该进入的范围,一些垃圾第三方接口会出现这样的错误
|
||||||
|
if chunkjson['choices'][0]["delta"]["content"] is None: continue # 一些垃圾第三方接口出现这样的错误,兼容一下吧
|
||||||
|
gpt_replying_buffer = gpt_replying_buffer + chunkjson['choices'][0]["delta"]["content"]
|
||||||
|
|
||||||
|
history[-1] = gpt_replying_buffer
|
||||||
|
chatbot[-1] = (history[-2], history[-1])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg=status_text) # 刷新界面
|
||||||
|
except Exception as e:
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析不合常规") # 刷新界面
|
||||||
|
chunk = get_full_error(chunk, stream_response)
|
||||||
|
chunk_decoded = chunk.decode()
|
||||||
|
error_msg = chunk_decoded
|
||||||
|
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + error_msg) # 刷新界面
|
||||||
|
logger.error(error_msg)
|
||||||
|
return
|
||||||
|
return # return from stream-branch
|
||||||
|
|
||||||
|
def handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history):
|
||||||
|
try:
|
||||||
|
chunkjson = json.loads(response.content.decode())
|
||||||
|
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||||
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
|
history[-1] = gpt_replying_buffer
|
||||||
|
chatbot[-1] = (history[-2], history[-1])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
except Exception as e:
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + response.text) # 刷新界面
|
||||||
|
|
||||||
|
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
|
||||||
|
if "reduce the length" in error_msg:
|
||||||
|
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
|
||||||
|
history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
|
||||||
|
max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])) # history至少释放二分之一
|
||||||
|
chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
|
||||||
|
elif "does not exist" in error_msg:
|
||||||
|
chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在, 或者您没有获得体验资格.")
|
||||||
|
elif "Incorrect API key" in error_msg:
|
||||||
|
chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由, 拒绝服务. " + openai_website)
|
||||||
|
elif "exceeded your current quota" in error_msg:
|
||||||
|
chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由, 拒绝服务." + openai_website)
|
||||||
|
elif "account is not active" in error_msg:
|
||||||
|
chatbot[-1] = (chatbot[-1][0], "[Local Message] Your account is not active. OpenAI以账户失效为由, 拒绝服务." + openai_website)
|
||||||
|
elif "associated with a deactivated account" in error_msg:
|
||||||
|
chatbot[-1] = (chatbot[-1][0], "[Local Message] You are associated with a deactivated account. OpenAI以账户失效为由, 拒绝服务." + openai_website)
|
||||||
|
elif "API key has been deactivated" in error_msg:
|
||||||
|
chatbot[-1] = (chatbot[-1][0], "[Local Message] API key has been deactivated. OpenAI以账户失效为由, 拒绝服务." + openai_website)
|
||||||
|
elif "bad forward key" in error_msg:
|
||||||
|
chatbot[-1] = (chatbot[-1][0], "[Local Message] Bad forward key. API2D账户额度不足.")
|
||||||
|
elif "Not enough point" in error_msg:
|
||||||
|
chatbot[-1] = (chatbot[-1][0], "[Local Message] Not enough point. API2D账户点数不足.")
|
||||||
|
else:
|
||||||
|
from toolbox import regular_txt_to_markdown
|
||||||
|
tb_str = '```\n' + trimmed_format_exc() + '```'
|
||||||
|
chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk_decoded)}")
|
||||||
|
return chatbot, history
|
||||||
|
|
||||||
|
def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:str, image_base64_array:list=[], has_multimodal_capacity:bool=False, stream:bool=True):
|
||||||
|
"""
|
||||||
|
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
|
||||||
|
"""
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
|
||||||
|
if not is_any_api_key(llm_kwargs['api_key']):
|
||||||
|
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
|
||||||
|
|
||||||
|
if llm_kwargs['llm_model'].startswith('vllm-'):
|
||||||
|
api_key = 'no-api-key'
|
||||||
|
else:
|
||||||
|
api_key = select_api_key(llm_kwargs['api_key'], llm_kwargs['llm_model'])
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_key}"
|
||||||
|
}
|
||||||
|
if API_ORG.startswith('org-'): headers.update({"OpenAI-Organization": API_ORG})
|
||||||
|
if llm_kwargs['llm_model'].startswith('azure-'):
|
||||||
|
headers.update({"api-key": api_key})
|
||||||
|
if llm_kwargs['llm_model'] in AZURE_CFG_ARRAY.keys():
|
||||||
|
azure_api_key_unshared = AZURE_CFG_ARRAY[llm_kwargs['llm_model']]["AZURE_API_KEY"]
|
||||||
|
headers.update({"api-key": azure_api_key_unshared})
|
||||||
|
|
||||||
|
if has_multimodal_capacity:
|
||||||
|
# 当以下条件满足时,启用多模态能力:
|
||||||
|
# 1. 模型本身是多模态模型(has_multimodal_capacity)
|
||||||
|
# 2. 输入包含图像(len(image_base64_array) > 0)
|
||||||
|
# 3. 历史输入包含图像( any([contain_base64(h) for h in history]) )
|
||||||
|
enable_multimodal_capacity = (len(image_base64_array) > 0) or any([contain_base64(h) for h in history])
|
||||||
|
else:
|
||||||
|
enable_multimodal_capacity = False
|
||||||
|
|
||||||
|
conversation_cnt = len(history) // 2
|
||||||
|
openai_disable_system_prompt = model_info[llm_kwargs['llm_model']].get('openai_disable_system_prompt', False)
|
||||||
|
|
||||||
|
if openai_disable_system_prompt:
|
||||||
|
messages = [{"role": "user", "content": system_prompt}]
|
||||||
|
else:
|
||||||
|
messages = [{"role": "system", "content": system_prompt}]
|
||||||
|
|
||||||
|
if not enable_multimodal_capacity:
|
||||||
|
# 不使用多模态能力
|
||||||
|
if conversation_cnt:
|
||||||
|
for index in range(0, 2*conversation_cnt, 2):
|
||||||
|
what_i_have_asked = {}
|
||||||
|
what_i_have_asked["role"] = "user"
|
||||||
|
what_i_have_asked["content"] = remove_image_if_contain_base64(history[index])
|
||||||
|
what_gpt_answer = {}
|
||||||
|
what_gpt_answer["role"] = "assistant"
|
||||||
|
what_gpt_answer["content"] = remove_image_if_contain_base64(history[index+1])
|
||||||
|
if what_i_have_asked["content"] != "":
|
||||||
|
if what_gpt_answer["content"] == "": continue
|
||||||
|
if what_gpt_answer["content"] == timeout_bot_msg: continue
|
||||||
|
messages.append(what_i_have_asked)
|
||||||
|
messages.append(what_gpt_answer)
|
||||||
|
else:
|
||||||
|
messages[-1]['content'] = what_gpt_answer['content']
|
||||||
|
what_i_ask_now = {}
|
||||||
|
what_i_ask_now["role"] = "user"
|
||||||
|
what_i_ask_now["content"] = inputs
|
||||||
|
messages.append(what_i_ask_now)
|
||||||
|
else:
|
||||||
|
# 多模态能力
|
||||||
|
if conversation_cnt:
|
||||||
|
for index in range(0, 2*conversation_cnt, 2):
|
||||||
|
what_i_have_asked = {}
|
||||||
|
what_i_have_asked["role"] = "user"
|
||||||
|
what_i_have_asked["content"] = append_image_if_contain_base64(history[index])
|
||||||
|
what_gpt_answer = {}
|
||||||
|
what_gpt_answer["role"] = "assistant"
|
||||||
|
what_gpt_answer["content"] = append_image_if_contain_base64(history[index+1])
|
||||||
|
if what_i_have_asked["content"] != "":
|
||||||
|
if what_gpt_answer["content"] == "": continue
|
||||||
|
if what_gpt_answer["content"] == timeout_bot_msg: continue
|
||||||
|
messages.append(what_i_have_asked)
|
||||||
|
messages.append(what_gpt_answer)
|
||||||
|
else:
|
||||||
|
messages[-1]['content'] = what_gpt_answer['content']
|
||||||
|
what_i_ask_now = {}
|
||||||
|
what_i_ask_now["role"] = "user"
|
||||||
|
what_i_ask_now["content"] = []
|
||||||
|
what_i_ask_now["content"].append({
|
||||||
|
"type": "text",
|
||||||
|
"text": inputs
|
||||||
|
})
|
||||||
|
for image_base64 in image_base64_array:
|
||||||
|
what_i_ask_now["content"].append({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
messages.append(what_i_ask_now)
|
||||||
|
|
||||||
|
|
||||||
|
model = llm_kwargs['llm_model']
|
||||||
|
if llm_kwargs['llm_model'].startswith('api2d-'):
|
||||||
|
model = llm_kwargs['llm_model'][len('api2d-'):]
|
||||||
|
if llm_kwargs['llm_model'].startswith('one-api-'):
|
||||||
|
model = llm_kwargs['llm_model'][len('one-api-'):]
|
||||||
|
model, _ = read_one_api_model_name(model)
|
||||||
|
if llm_kwargs['llm_model'].startswith('vllm-'):
|
||||||
|
model = llm_kwargs['llm_model'][len('vllm-'):]
|
||||||
|
model, _ = read_one_api_model_name(model)
|
||||||
|
if llm_kwargs['llm_model'].startswith('openrouter-'):
|
||||||
|
model = llm_kwargs['llm_model'][len('openrouter-'):]
|
||||||
|
model= read_one_api_model_name(model)
|
||||||
|
if model == "gpt-3.5-random": # 随机选择, 绕过openai访问频率限制
|
||||||
|
model = random.choice([
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"gpt-3.5-turbo-16k",
|
||||||
|
"gpt-3.5-turbo-1106",
|
||||||
|
"gpt-3.5-turbo-0613",
|
||||||
|
"gpt-3.5-turbo-16k-0613",
|
||||||
|
"gpt-3.5-turbo-0301",
|
||||||
|
])
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": llm_kwargs['temperature'], # 1.0,
|
||||||
|
"top_p": llm_kwargs['top_p'], # 1.0,
|
||||||
|
"n": 1,
|
||||||
|
"stream": stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers,payload
|
||||||
|
|
||||||
|
|
||||||
@@ -1,12 +1,13 @@
|
|||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import importlib
|
||||||
|
|
||||||
from .bridge_newbingfree import preprocess_newbing_out, preprocess_newbing_out_simple
|
from .bridge_newbingfree import preprocess_newbing_out, preprocess_newbing_out_simple
|
||||||
from multiprocessing import Process, Pipe
|
from multiprocessing import Process, Pipe
|
||||||
from toolbox import update_ui, get_conf, trimmed_format_exc
|
from toolbox import update_ui, get_conf, trimmed_format_exc
|
||||||
import threading
|
from loguru import logger as logging
|
||||||
import importlib
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
import asyncio
|
|
||||||
|
|
||||||
load_message = "正在加载Claude组件,请稍候..."
|
load_message = "正在加载Claude组件,请稍候..."
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import json
|
|||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
import websockets
|
import websockets
|
||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import requests
|
import requests
|
||||||
from typing import List, Dict, Tuple
|
from typing import List, Dict, Tuple
|
||||||
from toolbox import get_conf, encode_image, get_pictures_list, to_markdown_tabs
|
from toolbox import get_conf, update_ui, encode_image, get_pictures_list, to_markdown_tabs
|
||||||
|
|
||||||
proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")
|
proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")
|
||||||
|
|
||||||
@@ -112,6 +112,14 @@ def html_local_img(__file, layout="left", max_width=None, max_height=None, md=Tr
|
|||||||
return a
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
def reverse_base64_from_input(inputs):
|
||||||
|
pattern = re.compile(r'<br/><br/><div align="center"><img[^<>]+base64="([^"]+)"></div>')
|
||||||
|
base64_strings = pattern.findall(inputs)
|
||||||
|
return base64_strings
|
||||||
|
|
||||||
|
def contain_base64(inputs):
|
||||||
|
base64_strings = reverse_base64_from_input(inputs)
|
||||||
|
return len(base64_strings) > 0
|
||||||
|
|
||||||
class GoogleChatInit:
|
class GoogleChatInit:
|
||||||
def __init__(self, llm_kwargs):
|
def __init__(self, llm_kwargs):
|
||||||
@@ -119,9 +127,9 @@ class GoogleChatInit:
|
|||||||
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
|
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
|
||||||
self.url_gemini = endpoint + "/%m:streamGenerateContent?key=%k"
|
self.url_gemini = endpoint + "/%m:streamGenerateContent?key=%k"
|
||||||
|
|
||||||
def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
|
def generate_chat(self, inputs, llm_kwargs, history, system_prompt, image_base64_array:list=[], has_multimodal_capacity:bool=False):
|
||||||
headers, payload = self.generate_message_payload(
|
headers, payload = self.generate_message_payload(
|
||||||
inputs, llm_kwargs, history, system_prompt
|
inputs, llm_kwargs, history, system_prompt, image_base64_array, has_multimodal_capacity
|
||||||
)
|
)
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url=self.url_gemini,
|
url=self.url_gemini,
|
||||||
@@ -133,13 +141,16 @@ class GoogleChatInit:
|
|||||||
)
|
)
|
||||||
return response.iter_lines()
|
return response.iter_lines()
|
||||||
|
|
||||||
def __conversation_user(self, user_input, llm_kwargs):
|
def __conversation_user(self, user_input, llm_kwargs, enable_multimodal_capacity):
|
||||||
what_i_have_asked = {"role": "user", "parts": []}
|
what_i_have_asked = {"role": "user", "parts": []}
|
||||||
if "vision" not in self.url_gemini:
|
from .bridge_all import model_info
|
||||||
|
|
||||||
|
if enable_multimodal_capacity:
|
||||||
|
input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
|
||||||
|
else:
|
||||||
input_ = user_input
|
input_ = user_input
|
||||||
encode_img = []
|
encode_img = []
|
||||||
else:
|
|
||||||
input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
|
|
||||||
what_i_have_asked["parts"].append({"text": input_})
|
what_i_have_asked["parts"].append({"text": input_})
|
||||||
if encode_img:
|
if encode_img:
|
||||||
for data in encode_img:
|
for data in encode_img:
|
||||||
@@ -153,12 +164,12 @@ class GoogleChatInit:
|
|||||||
)
|
)
|
||||||
return what_i_have_asked
|
return what_i_have_asked
|
||||||
|
|
||||||
def __conversation_history(self, history, llm_kwargs):
|
def __conversation_history(self, history, llm_kwargs, enable_multimodal_capacity):
|
||||||
messages = []
|
messages = []
|
||||||
conversation_cnt = len(history) // 2
|
conversation_cnt = len(history) // 2
|
||||||
if conversation_cnt:
|
if conversation_cnt:
|
||||||
for index in range(0, 2 * conversation_cnt, 2):
|
for index in range(0, 2 * conversation_cnt, 2):
|
||||||
what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
|
what_i_have_asked = self.__conversation_user(history[index], llm_kwargs, enable_multimodal_capacity)
|
||||||
what_gpt_answer = {
|
what_gpt_answer = {
|
||||||
"role": "model",
|
"role": "model",
|
||||||
"parts": [{"text": history[index + 1]}],
|
"parts": [{"text": history[index + 1]}],
|
||||||
@@ -168,7 +179,7 @@ class GoogleChatInit:
|
|||||||
return messages
|
return messages
|
||||||
|
|
||||||
def generate_message_payload(
|
def generate_message_payload(
|
||||||
self, inputs, llm_kwargs, history, system_prompt
|
self, inputs, llm_kwargs, history, system_prompt, image_base64_array:list=[], has_multimodal_capacity:bool=False
|
||||||
) -> Tuple[Dict, Dict]:
|
) -> Tuple[Dict, Dict]:
|
||||||
messages = [
|
messages = [
|
||||||
# {"role": "system", "parts": [{"text": system_prompt}]}, # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。
|
# {"role": "system", "parts": [{"text": system_prompt}]}, # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。
|
||||||
@@ -179,25 +190,31 @@ class GoogleChatInit:
|
|||||||
"%m", llm_kwargs["llm_model"]
|
"%m", llm_kwargs["llm_model"]
|
||||||
).replace("%k", get_conf("GEMINI_API_KEY"))
|
).replace("%k", get_conf("GEMINI_API_KEY"))
|
||||||
header = {"Content-Type": "application/json"}
|
header = {"Content-Type": "application/json"}
|
||||||
if "vision" not in self.url_gemini: # 不是vision 才处理history
|
|
||||||
|
if has_multimodal_capacity:
|
||||||
|
enable_multimodal_capacity = (len(image_base64_array) > 0) or any([contain_base64(h) for h in history])
|
||||||
|
else:
|
||||||
|
enable_multimodal_capacity = False
|
||||||
|
|
||||||
|
if not enable_multimodal_capacity:
|
||||||
messages.extend(
|
messages.extend(
|
||||||
self.__conversation_history(history, llm_kwargs)
|
self.__conversation_history(history, llm_kwargs, enable_multimodal_capacity)
|
||||||
) # 处理 history
|
) # 处理 history
|
||||||
messages.append(self.__conversation_user(inputs, llm_kwargs)) # 处理用户对话
|
|
||||||
|
messages.append(self.__conversation_user(inputs, llm_kwargs, enable_multimodal_capacity)) # 处理用户对话
|
||||||
payload = {
|
payload = {
|
||||||
"contents": messages,
|
"contents": messages,
|
||||||
"generationConfig": {
|
"generationConfig": {
|
||||||
# "maxOutputTokens": 800,
|
# "maxOutputTokens": llm_kwargs.get("max_token", 1024),
|
||||||
"stopSequences": str(llm_kwargs.get("stop", "")).split(" "),
|
"stopSequences": str(llm_kwargs.get("stop", "")).split(" "),
|
||||||
"temperature": llm_kwargs.get("temperature", 1),
|
"temperature": llm_kwargs.get("temperature", 1),
|
||||||
"topP": llm_kwargs.get("top_p", 0.8),
|
"topP": llm_kwargs.get("top_p", 0.8),
|
||||||
"topK": 10,
|
"topK": 10,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return header, payload
|
return header, payload
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
google = GoogleChatInit()
|
google = GoogleChatInit()
|
||||||
# print(gootle.generate_message_payload('你好呀', {}, ['123123', '3123123'], ''))
|
|
||||||
# gootle.input_encode_handle('123123[123123](./123123), ')
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
import threading
|
import threading
|
||||||
import logging
|
|
||||||
|
|
||||||
timeout_bot_msg = '[Local Message] Request timeout. Network error.'
|
timeout_bot_msg = '[Local Message] Request timeout. Network error.'
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from toolbox import get_conf
|
|
||||||
import threading
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
|
from toolbox import get_conf
|
||||||
|
from loguru import logger as logging
|
||||||
|
|
||||||
timeout_bot_msg = '[Local Message] Request timeout. Network error.'
|
timeout_bot_msg = '[Local Message] Request timeout. Network error.'
|
||||||
#os.environ['VOLC_ACCESSKEY'] = ''
|
#os.environ['VOLC_ACCESSKEY'] = ''
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
from toolbox import get_conf, get_pictures_list, encode_image
|
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import json
|
import json
|
||||||
from urllib.parse import urlparse
|
|
||||||
import ssl
|
import ssl
|
||||||
|
import websocket
|
||||||
|
import threading
|
||||||
|
from toolbox import get_conf, get_pictures_list, encode_image
|
||||||
|
from loguru import logger
|
||||||
|
from urllib.parse import urlparse
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import mktime
|
from time import mktime
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
from wsgiref.handlers import format_date_time
|
from wsgiref.handlers import format_date_time
|
||||||
import websocket
|
|
||||||
import threading, time
|
|
||||||
|
|
||||||
timeout_bot_msg = '[Local Message] Request timeout. Network error.'
|
timeout_bot_msg = '[Local Message] Request timeout. Network error.'
|
||||||
|
|
||||||
@@ -104,7 +105,7 @@ class SparkRequestInstance():
|
|||||||
if llm_kwargs['most_recent_uploaded'].get('path'):
|
if llm_kwargs['most_recent_uploaded'].get('path'):
|
||||||
file_manifest = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
|
file_manifest = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
|
||||||
if len(file_manifest) > 0:
|
if len(file_manifest) > 0:
|
||||||
print('正在使用讯飞图片理解API')
|
logger.info('正在使用讯飞图片理解API')
|
||||||
gpt_url = self.gpt_url_img
|
gpt_url = self.gpt_url_img
|
||||||
wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url)
|
wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url)
|
||||||
websocket.enableTrace(False)
|
websocket.enableTrace(False)
|
||||||
@@ -123,7 +124,7 @@ class SparkRequestInstance():
|
|||||||
data = json.loads(message)
|
data = json.loads(message)
|
||||||
code = data['header']['code']
|
code = data['header']['code']
|
||||||
if code != 0:
|
if code != 0:
|
||||||
print(f'请求错误: {code}, {data}')
|
logger.error(f'请求错误: {code}, {data}')
|
||||||
self.result_buf += str(data)
|
self.result_buf += str(data)
|
||||||
ws.close()
|
ws.close()
|
||||||
self.time_to_exit_event.set()
|
self.time_to_exit_event.set()
|
||||||
@@ -140,7 +141,7 @@ class SparkRequestInstance():
|
|||||||
|
|
||||||
# 收到websocket错误的处理
|
# 收到websocket错误的处理
|
||||||
def on_error(ws, error):
|
def on_error(ws, error):
|
||||||
print("error:", error)
|
logger.error("error:", error)
|
||||||
self.time_to_exit_event.set()
|
self.time_to_exit_event.set()
|
||||||
|
|
||||||
# 收到websocket关闭的处理
|
# 收到websocket关闭的处理
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
# @Descr : 兼容最新的智谱Ai
|
# @Descr : 兼容最新的智谱Ai
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
from toolbox import get_conf, encode_image, get_pictures_list
|
from toolbox import get_conf, encode_image, get_pictures_list
|
||||||
import logging, os, requests
|
import requests
|
||||||
import json
|
import json
|
||||||
class TaichuChatInit:
|
class TaichuChatInit:
|
||||||
def __init__(self): ...
|
def __init__(self): ...
|
||||||
@@ -43,7 +43,8 @@ class TaichuChatInit:
|
|||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
response.encoding = 'utf-8'
|
response.encoding = 'utf-8'
|
||||||
for line in response.iter_lines(decode_unicode=True):
|
for line in response.iter_lines(decode_unicode=True):
|
||||||
delta = json.loads(line)['choices'][0]['text']
|
try: delta = json.loads(line)['data']['content']
|
||||||
|
except: delta = json.loads(line)['choices'][0]['text']
|
||||||
results += delta
|
results += delta
|
||||||
yield delta, results
|
yield delta, results
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -5,7 +5,8 @@
|
|||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
from zhipuai import ZhipuAI
|
from zhipuai import ZhipuAI
|
||||||
from toolbox import get_conf, encode_image, get_pictures_list
|
from toolbox import get_conf, encode_image, get_pictures_list
|
||||||
import logging, os
|
from loguru import logger
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def input_encode_handler(inputs:str, llm_kwargs:dict):
|
def input_encode_handler(inputs:str, llm_kwargs:dict):
|
||||||
@@ -24,7 +25,7 @@ class ZhipuChatInit:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
ZHIPUAI_API_KEY, ZHIPUAI_MODEL = get_conf("ZHIPUAI_API_KEY", "ZHIPUAI_MODEL")
|
ZHIPUAI_API_KEY, ZHIPUAI_MODEL = get_conf("ZHIPUAI_API_KEY", "ZHIPUAI_MODEL")
|
||||||
if len(ZHIPUAI_MODEL) > 0:
|
if len(ZHIPUAI_MODEL) > 0:
|
||||||
logging.error('ZHIPUAI_MODEL 配置项选项已经弃用,请在LLM_MODEL中配置')
|
logger.error('ZHIPUAI_MODEL 配置项选项已经弃用,请在LLM_MODEL中配置')
|
||||||
self.zhipu_bro = ZhipuAI(api_key=ZHIPUAI_API_KEY)
|
self.zhipu_bro = ZhipuAI(api_key=ZHIPUAI_API_KEY)
|
||||||
self.model = ''
|
self.model = ''
|
||||||
|
|
||||||
@@ -37,8 +38,7 @@ class ZhipuChatInit:
|
|||||||
what_i_have_asked['content'].append({"type": 'text', "text": user_input})
|
what_i_have_asked['content'].append({"type": 'text', "text": user_input})
|
||||||
if encode_img:
|
if encode_img:
|
||||||
if len(encode_img) > 1:
|
if len(encode_img) > 1:
|
||||||
logging.warning("glm-4v只支持一张图片,将只取第一张图片进行处理")
|
logger.warning("glm-4v只支持一张图片,将只取第一张图片进行处理")
|
||||||
print("glm-4v只支持一张图片,将只取第一张图片进行处理")
|
|
||||||
img_d = {"type": "image_url",
|
img_d = {"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": encode_img[0]['data']
|
"url": encode_img[0]['data']
|
||||||
|
|||||||
40
request_llms/embed_models/bridge_all_embed.py
Normal file
40
request_llms/embed_models/bridge_all_embed.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import tiktoken, copy, re
|
||||||
|
from functools import lru_cache
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from toolbox import get_conf, trimmed_format_exc, apply_gpt_academic_string_mask, read_one_api_model_name
|
||||||
|
|
||||||
|
# Endpoint 重定向
|
||||||
|
API_URL_REDIRECT, AZURE_ENDPOINT, AZURE_ENGINE = get_conf("API_URL_REDIRECT", "AZURE_ENDPOINT", "AZURE_ENGINE")
|
||||||
|
openai_endpoint = "https://api.openai.com/v1/chat/completions"
|
||||||
|
if not AZURE_ENDPOINT.endswith('/'): AZURE_ENDPOINT += '/'
|
||||||
|
azure_endpoint = AZURE_ENDPOINT + f'openai/deployments/{AZURE_ENGINE}/chat/completions?api-version=2023-05-15'
|
||||||
|
|
||||||
|
|
||||||
|
if openai_endpoint in API_URL_REDIRECT: openai_endpoint = API_URL_REDIRECT[openai_endpoint]
|
||||||
|
|
||||||
|
openai_embed_endpoint = openai_endpoint.replace("chat/completions", "embeddings")
|
||||||
|
|
||||||
|
from .openai_embed import OpenAiEmbeddingModel
|
||||||
|
|
||||||
|
embed_model_info = {
|
||||||
|
# text-embedding-3-small Increased performance over 2nd generation ada embedding model | 1,536
|
||||||
|
"text-embedding-3-small": {
|
||||||
|
"embed_class": OpenAiEmbeddingModel,
|
||||||
|
"embed_endpoint": openai_embed_endpoint,
|
||||||
|
"embed_dimension": 1536,
|
||||||
|
},
|
||||||
|
|
||||||
|
# text-embedding-3-large Most capable embedding model for both english and non-english tasks | 3,072
|
||||||
|
"text-embedding-3-large": {
|
||||||
|
"embed_class": OpenAiEmbeddingModel,
|
||||||
|
"embed_endpoint": openai_embed_endpoint,
|
||||||
|
"embed_dimension": 3072,
|
||||||
|
},
|
||||||
|
|
||||||
|
# text-embedding-ada-002 Most capable 2nd generation embedding model, replacing 16 first generation models | 1,536
|
||||||
|
"text-embedding-ada-002": {
|
||||||
|
"embed_class": OpenAiEmbeddingModel,
|
||||||
|
"embed_endpoint": openai_embed_endpoint,
|
||||||
|
"embed_dimension": 1536,
|
||||||
|
},
|
||||||
|
}
|
||||||
85
request_llms/embed_models/openai_embed.py
Normal file
85
request_llms/embed_models/openai_embed.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||||
|
from openai import OpenAI
|
||||||
|
from toolbox import get_conf
|
||||||
|
from toolbox import CatchException, update_ui, get_conf, select_api_key, get_log_folder, ProxyNetworkActivate
|
||||||
|
from shared_utils.key_pattern_manager import select_api_key_for_embed_models
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def mean_agg(embeddings):
|
||||||
|
"""Mean aggregation for embeddings."""
|
||||||
|
return np.array(embeddings).mean(axis=0).tolist()
|
||||||
|
|
||||||
|
class EmbeddingModel():
|
||||||
|
|
||||||
|
def get_agg_embedding_from_queries(
|
||||||
|
self,
|
||||||
|
queries: List[str],
|
||||||
|
agg_fn = None,
|
||||||
|
):
|
||||||
|
"""Get aggregated embedding from multiple queries."""
|
||||||
|
query_embeddings = [self.get_query_embedding(query) for query in queries]
|
||||||
|
agg_fn = agg_fn or mean_agg
|
||||||
|
return agg_fn(query_embeddings)
|
||||||
|
|
||||||
|
def get_text_embedding_batch(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
show_progress: bool = False,
|
||||||
|
):
|
||||||
|
return self.compute_embedding(texts, batch_mode=True)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAiEmbeddingModel(EmbeddingModel):
|
||||||
|
|
||||||
|
def __init__(self, llm_kwargs:dict=None):
|
||||||
|
self.llm_kwargs = llm_kwargs
|
||||||
|
|
||||||
|
def get_query_embedding(self, query: str):
|
||||||
|
return self.compute_embedding(query)
|
||||||
|
|
||||||
|
def compute_embedding(self, text="这是要计算嵌入的文本", llm_kwargs:dict=None, batch_mode=False):
|
||||||
|
from .bridge_all_embed import embed_model_info
|
||||||
|
|
||||||
|
# load kwargs
|
||||||
|
if llm_kwargs is None:
|
||||||
|
llm_kwargs = self.llm_kwargs
|
||||||
|
if llm_kwargs is None:
|
||||||
|
raise RuntimeError("llm_kwargs is not provided!")
|
||||||
|
|
||||||
|
# setup api and req url
|
||||||
|
api_key = select_api_key_for_embed_models(llm_kwargs['api_key'], llm_kwargs['embed_model'])
|
||||||
|
embed_model = llm_kwargs['embed_model']
|
||||||
|
base_url = embed_model_info[llm_kwargs['embed_model']]['embed_endpoint'].replace('embeddings', '')
|
||||||
|
|
||||||
|
# send and compute
|
||||||
|
with ProxyNetworkActivate("Connect_OpenAI_Embedding"):
|
||||||
|
self.oai_client = OpenAI(api_key=api_key, base_url=base_url)
|
||||||
|
if batch_mode:
|
||||||
|
input = text
|
||||||
|
assert isinstance(text, list)
|
||||||
|
else:
|
||||||
|
input = [text]
|
||||||
|
assert isinstance(text, str)
|
||||||
|
res = self.oai_client.embeddings.create(input=input, model=embed_model)
|
||||||
|
|
||||||
|
# parse result
|
||||||
|
if batch_mode:
|
||||||
|
embedding = [d.embedding for d in res.data]
|
||||||
|
else:
|
||||||
|
embedding = res.data[0].embedding
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def embedding_dimension(self, llm_kwargs=None):
|
||||||
|
# load kwargs
|
||||||
|
if llm_kwargs is None:
|
||||||
|
llm_kwargs = self.llm_kwargs
|
||||||
|
if llm_kwargs is None:
|
||||||
|
raise RuntimeError("llm_kwargs is not provided!")
|
||||||
|
|
||||||
|
from .bridge_all_embed import embed_model_info
|
||||||
|
return embed_model_info[llm_kwargs['embed_model']]['embed_dimension']
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pass
|
||||||
@@ -5,6 +5,7 @@ from toolbox import ChatBotWithCookies
|
|||||||
from multiprocessing import Process, Pipe
|
from multiprocessing import Process, Pipe
|
||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
from request_llms.queued_pipe import create_queue_pipe
|
from request_llms.queued_pipe import create_queue_pipe
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
class ThreadLock(object):
|
class ThreadLock(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -51,7 +52,7 @@ def reset_tqdm_output():
|
|||||||
getattr(sys.stdout, 'flush', lambda: None)()
|
getattr(sys.stdout, 'flush', lambda: None)()
|
||||||
|
|
||||||
def fp_write(s):
|
def fp_write(s):
|
||||||
print(s)
|
logger.info(s)
|
||||||
last_len = [0]
|
last_len = [0]
|
||||||
|
|
||||||
def print_status(s):
|
def print_status(s):
|
||||||
@@ -199,7 +200,7 @@ class LocalLLMHandle(Process):
|
|||||||
if res.startswith(self.std_tag):
|
if res.startswith(self.std_tag):
|
||||||
new_output = res[len(self.std_tag):]
|
new_output = res[len(self.std_tag):]
|
||||||
std_out = std_out[:std_out_clip_len]
|
std_out = std_out[:std_out_clip_len]
|
||||||
print(new_output, end='')
|
logger.info(new_output, end='')
|
||||||
std_out = new_output + std_out
|
std_out = new_output + std_out
|
||||||
yield self.std_tag + '\n```\n' + std_out + '\n```\n'
|
yield self.std_tag + '\n```\n' + std_out + '\n```\n'
|
||||||
elif res == '[Finish]':
|
elif res == '[Finish]':
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import logging
|
|
||||||
import traceback
|
import traceback
|
||||||
import requests
|
import requests
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
# config_private.py放自己的秘密如API和代理网址
|
# config_private.py放自己的秘密如API和代理网址
|
||||||
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
||||||
@@ -44,7 +44,7 @@ def decode_chunk(chunk):
|
|||||||
try:
|
try:
|
||||||
chunk = json.loads(chunk[6:])
|
chunk = json.loads(chunk[6:])
|
||||||
except:
|
except:
|
||||||
respose = "API_ERROR"
|
respose = ""
|
||||||
finish_reason = chunk
|
finish_reason = chunk
|
||||||
# 错误处理部分
|
# 错误处理部分
|
||||||
if "error" in chunk:
|
if "error" in chunk:
|
||||||
@@ -106,10 +106,7 @@ def generate_message(input, model, key, history, max_output_token, system_prompt
|
|||||||
"stream": True,
|
"stream": True,
|
||||||
"max_tokens": max_output_token,
|
"max_tokens": max_output_token,
|
||||||
}
|
}
|
||||||
try:
|
|
||||||
print(f" {model} : {conversation_cnt} : {input[:100]} ..........")
|
|
||||||
except:
|
|
||||||
print("输入中可能存在乱码。")
|
|
||||||
return headers, playload
|
return headers, playload
|
||||||
|
|
||||||
|
|
||||||
@@ -196,14 +193,17 @@ def get_predict_function(
|
|||||||
if retry > MAX_RETRY:
|
if retry > MAX_RETRY:
|
||||||
raise TimeoutError
|
raise TimeoutError
|
||||||
if MAX_RETRY != 0:
|
if MAX_RETRY != 0:
|
||||||
print(f"请求超时,正在重试 ({retry}/{MAX_RETRY}) ……")
|
logger.error(f"请求超时,正在重试 ({retry}/{MAX_RETRY}) ……")
|
||||||
|
|
||||||
stream_response = response.iter_lines()
|
stream_response = response.iter_lines()
|
||||||
result = ""
|
result = ""
|
||||||
|
finish_reason = ""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = next(stream_response)
|
chunk = next(stream_response)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
|
if result == "":
|
||||||
|
raise RuntimeError(f"获得空的回复,可能原因:{finish_reason}")
|
||||||
break
|
break
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
chunk = next(stream_response) # 失败了,重试一次?再失败就没办法了。
|
chunk = next(stream_response) # 失败了,重试一次?再失败就没办法了。
|
||||||
@@ -216,18 +216,17 @@ def get_predict_function(
|
|||||||
):
|
):
|
||||||
chunk = get_full_error(chunk, stream_response)
|
chunk = get_full_error(chunk, stream_response)
|
||||||
chunk_decoded = chunk.decode()
|
chunk_decoded = chunk.decode()
|
||||||
print(chunk_decoded)
|
logger.error(chunk_decoded)
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"API异常,请检测终端输出。可能的原因是:{finish_reason}"
|
f"API异常,请检测终端输出。可能的原因是:{finish_reason}"
|
||||||
)
|
)
|
||||||
if chunk:
|
if chunk:
|
||||||
try:
|
try:
|
||||||
if finish_reason == "stop":
|
if finish_reason == "stop":
|
||||||
logging.info(f"[response] {result}")
|
if not console_slience:
|
||||||
|
print(f"[response] {result}")
|
||||||
break
|
break
|
||||||
result += response_text
|
result += response_text
|
||||||
if not console_slience:
|
|
||||||
print(response_text, end="")
|
|
||||||
if observe_window is not None:
|
if observe_window is not None:
|
||||||
# 观测窗,把已经获取的数据显示出去
|
# 观测窗,把已经获取的数据显示出去
|
||||||
if len(observe_window) >= 1:
|
if len(observe_window) >= 1:
|
||||||
@@ -240,7 +239,7 @@ def get_predict_function(
|
|||||||
chunk = get_full_error(chunk, stream_response)
|
chunk = get_full_error(chunk, stream_response)
|
||||||
chunk_decoded = chunk.decode()
|
chunk_decoded = chunk.decode()
|
||||||
error_msg = chunk_decoded
|
error_msg = chunk_decoded
|
||||||
print(error_msg)
|
logger.error(error_msg)
|
||||||
raise RuntimeError("Json解析不合常规")
|
raise RuntimeError("Json解析不合常规")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -273,7 +272,7 @@ def get_predict_function(
|
|||||||
inputs, history = handle_core_functionality(
|
inputs, history = handle_core_functionality(
|
||||||
additional_fn, inputs, history, chatbot
|
additional_fn, inputs, history, chatbot
|
||||||
)
|
)
|
||||||
logging.info(f"[raw_input] {inputs}")
|
logger.info(f"[raw_input] {inputs}")
|
||||||
chatbot.append((inputs, ""))
|
chatbot.append((inputs, ""))
|
||||||
yield from update_ui(
|
yield from update_ui(
|
||||||
chatbot=chatbot, history=history, msg="等待响应"
|
chatbot=chatbot, history=history, msg="等待响应"
|
||||||
@@ -351,6 +350,10 @@ def get_predict_function(
|
|||||||
response_text, finish_reason = decode_chunk(chunk)
|
response_text, finish_reason = decode_chunk(chunk)
|
||||||
# 返回的数据流第一次为空,继续等待
|
# 返回的数据流第一次为空,继续等待
|
||||||
if response_text == "" and finish_reason != "False":
|
if response_text == "" and finish_reason != "False":
|
||||||
|
status_text = f"finish_reason: {finish_reason}"
|
||||||
|
yield from update_ui(
|
||||||
|
chatbot=chatbot, history=history, msg=status_text
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
if chunk:
|
if chunk:
|
||||||
try:
|
try:
|
||||||
@@ -369,11 +372,11 @@ def get_predict_function(
|
|||||||
history=history,
|
history=history,
|
||||||
msg="API异常:" + chunk_decoded,
|
msg="API异常:" + chunk_decoded,
|
||||||
) # 刷新界面
|
) # 刷新界面
|
||||||
print(chunk_decoded)
|
logger.error(chunk_decoded)
|
||||||
return
|
return
|
||||||
|
|
||||||
if finish_reason == "stop":
|
if finish_reason == "stop":
|
||||||
logging.info(f"[response] {gpt_replying_buffer}")
|
logger.info(f"[response] {gpt_replying_buffer}")
|
||||||
break
|
break
|
||||||
status_text = f"finish_reason: {finish_reason}"
|
status_text = f"finish_reason: {finish_reason}"
|
||||||
gpt_replying_buffer += response_text
|
gpt_replying_buffer += response_text
|
||||||
@@ -396,7 +399,7 @@ def get_predict_function(
|
|||||||
yield from update_ui(
|
yield from update_ui(
|
||||||
chatbot=chatbot, history=history, msg="Json异常" + chunk_decoded
|
chatbot=chatbot, history=history, msg="Json异常" + chunk_decoded
|
||||||
) # 刷新界面
|
) # 刷新界面
|
||||||
print(chunk_decoded)
|
logger.error(chunk_decoded)
|
||||||
return
|
return
|
||||||
|
|
||||||
return predict_no_ui_long_connection, predict
|
return predict_no_ui_long_connection, predict
|
||||||
|
|||||||
@@ -2,13 +2,15 @@ https://public.agent-matrix.com/publish/gradio-3.32.10-py3-none-any.whl
|
|||||||
fastapi==0.110
|
fastapi==0.110
|
||||||
gradio-client==0.8
|
gradio-client==0.8
|
||||||
pypdf2==2.12.1
|
pypdf2==2.12.1
|
||||||
|
httpx<=0.25.2
|
||||||
zhipuai==2.0.1
|
zhipuai==2.0.1
|
||||||
tiktoken>=0.3.3
|
tiktoken>=0.3.3
|
||||||
requests[socks]
|
requests[socks]
|
||||||
pydantic==2.5.2
|
pydantic==2.9.2
|
||||||
protobuf==3.18
|
protobuf==3.20
|
||||||
transformers>=4.27.1
|
transformers>=4.27.1,<4.42
|
||||||
scipdf_parser>=0.52
|
scipdf_parser>=0.52
|
||||||
|
spacy==3.7.4
|
||||||
anthropic>=0.18.1
|
anthropic>=0.18.1
|
||||||
python-markdown-math
|
python-markdown-math
|
||||||
pymdown-extensions
|
pymdown-extensions
|
||||||
@@ -27,6 +29,18 @@ edge-tts
|
|||||||
pymupdf
|
pymupdf
|
||||||
openai
|
openai
|
||||||
rjsmin
|
rjsmin
|
||||||
|
loguru
|
||||||
arxiv
|
arxiv
|
||||||
numpy
|
numpy
|
||||||
rich
|
rich
|
||||||
|
|
||||||
|
|
||||||
|
llama-index-core==0.10.68
|
||||||
|
llama-index-legacy==0.9.48
|
||||||
|
llama-index-readers-file==0.1.33
|
||||||
|
llama-index-readers-llama-parse==0.1.6
|
||||||
|
llama-index-embeddings-azure-openai==0.1.10
|
||||||
|
llama-index-embeddings-openai==0.1.10
|
||||||
|
llama-parse==0.4.9
|
||||||
|
mdit-py-plugins>=0.3.3
|
||||||
|
linkify-it-py==2.0.3
|
||||||
@@ -2,6 +2,8 @@ import markdown
|
|||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pymdownx.superfences import fence_code_format
|
from pymdownx.superfences import fence_code_format
|
||||||
@@ -227,14 +229,14 @@ def fix_dollar_sticking_bug(txt):
|
|||||||
|
|
||||||
if single_stack_height > 0:
|
if single_stack_height > 0:
|
||||||
if txt[:(index+1)].find('\n') > 0 or txt[:(index+1)].find('<td>') > 0 or txt[:(index+1)].find('</td>') > 0:
|
if txt[:(index+1)].find('\n') > 0 or txt[:(index+1)].find('<td>') > 0 or txt[:(index+1)].find('</td>') > 0:
|
||||||
print('公式之中出现了异常 (Unexpect element in equation)')
|
logger.error('公式之中出现了异常 (Unexpect element in equation)')
|
||||||
single_stack_height = 0
|
single_stack_height = 0
|
||||||
txt_result += ' $'
|
txt_result += ' $'
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if double_stack_height > 0:
|
if double_stack_height > 0:
|
||||||
if txt[:(index+1)].find('\n\n') > 0:
|
if txt[:(index+1)].find('\n\n') > 0:
|
||||||
print('公式之中出现了异常 (Unexpect element in equation)')
|
logger.error('公式之中出现了异常 (Unexpect element in equation)')
|
||||||
double_stack_height = 0
|
double_stack_height = 0
|
||||||
txt_result += '$$'
|
txt_result += '$$'
|
||||||
continue
|
continue
|
||||||
@@ -253,13 +255,13 @@ def fix_dollar_sticking_bug(txt):
|
|||||||
txt = txt[(index+2):]
|
txt = txt[(index+2):]
|
||||||
else:
|
else:
|
||||||
if double_stack_height != 0:
|
if double_stack_height != 0:
|
||||||
# print(txt[:(index)])
|
# logger.info(txt[:(index)])
|
||||||
print('发现异常嵌套公式')
|
logger.info('发现异常嵌套公式')
|
||||||
if single_stack_height == 0:
|
if single_stack_height == 0:
|
||||||
single_stack_height = 1
|
single_stack_height = 1
|
||||||
else:
|
else:
|
||||||
single_stack_height = 0
|
single_stack_height = 0
|
||||||
# print(txt[:(index)])
|
# logger.info(txt[:(index)])
|
||||||
txt_result += txt[:(index+1)]
|
txt_result += txt[:(index+1)]
|
||||||
txt = txt[(index+1):]
|
txt = txt[(index+1):]
|
||||||
break
|
break
|
||||||
@@ -271,7 +273,7 @@ def markdown_convertion_for_file(txt):
|
|||||||
"""
|
"""
|
||||||
from themes.theme import advanced_css
|
from themes.theme import advanced_css
|
||||||
pre = f"""
|
pre = f"""
|
||||||
<!DOCTYPE html><head><meta charset="utf-8"><title>PDF文档翻译</title><style>{advanced_css}</style></head>
|
<!DOCTYPE html><head><meta charset="utf-8"><title>GPT-Academic输出文档</title><style>{advanced_css}</style></head>
|
||||||
<body>
|
<body>
|
||||||
<div class="test_temp1" style="width:10%; height: 500px; float:left;"></div>
|
<div class="test_temp1" style="width:10%; height: 500px; float:left;"></div>
|
||||||
<div class="test_temp2" style="width:80%;padding: 40px;float:left;padding-left: 20px;padding-right: 20px;box-shadow: rgba(0, 0, 0, 0.2) 0px 0px 8px 8px;border-radius: 10px;">
|
<div class="test_temp2" style="width:80%;padding: 40px;float:left;padding-left: 20px;padding-right: 20px;box-shadow: rgba(0, 0, 0, 0.2) 0px 0px 8px 8px;border-radius: 10px;">
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user