Merge branch 'master' into frontier

This commit is contained in:
binary-husky
2024-11-10 13:41:44 +00:00
5 changed files with 63 additions and 16 deletions

View File

@@ -6,12 +6,16 @@ class SafeUnpickler(pickle.Unpickler):
def get_safe_classes(self):
from crazy_functions.latex_fns.latex_actions import LatexPaperFileGroup, LatexPaperSplit
from crazy_functions.latex_fns.latex_toolbox import LinkedListNode
from numpy.core.multiarray import scalar
from numpy import dtype
# 定义允许的安全类
safe_classes = {
# 在这里添加其他安全的类
'LatexPaperFileGroup': LatexPaperFileGroup,
'LatexPaperSplit': LatexPaperSplit,
'LinkedListNode': LinkedListNode,
'scalar': scalar,
'dtype': dtype,
}
return safe_classes
@@ -22,8 +26,6 @@ class SafeUnpickler(pickle.Unpickler):
for class_name in self.safe_classes.keys():
if (class_name in f'{module}.{name}'):
match_class_name = class_name
if module == 'numpy' or module.startswith('numpy.'):
return super().find_class(module, name)
if match_class_name is not None:
return self.safe_classes[match_class_name]
# 如果尝试加载未授权的类,则抛出异常

View File

@@ -385,6 +385,14 @@ model_info = {
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"glm-4-plus":{
"fn_with_ui": zhipu_ui,
"fn_without_ui": zhipu_noui,
"endpoint": None,
"max_token": 10124 * 8,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
# api_2d (此后不需要在此处添加api2d的接口了因为下面的代码会自动添加)
"api2d-gpt-4": {

View File

@@ -138,7 +138,9 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
app_block.is_sagemaker = False
gradio_app = App.create_app(app_block)
for route in list(gradio_app.router.routes):
if route.path == "/proxy={url_path:path}":
gradio_app.router.routes.remove(route)
# --- --- replace gradio endpoint to forbid access to sensitive files --- ---
if len(AUTHENTICATION) > 0:
dependencies = []
@@ -154,9 +156,13 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
async def file(path_or_url: str, request: fastapi.Request):
if len(AUTHENTICATION) > 0:
if not _authorize_user(path_or_url, request, gradio_app):
return "越权访问!"
if not _authorize_user(path_or_url, request, gradio_app):
return "越权访问!"
stripped = path_or_url.lstrip().lower()
if stripped.startswith("https://") or stripped.startswith("http://"):
return "账户密码授权模式下, 禁止链接!"
if '../' in stripped:
return "非法路径!"
return await endpoint(path_or_url, request)
from fastapi import Request, status
@@ -167,6 +173,26 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
response.delete_cookie('access-token')
response.delete_cookie('access-token-unsecure')
return response
else:
dependencies = []
endpoint = None
for route in list(gradio_app.router.routes):
if route.path == "/file/{path:path}":
gradio_app.router.routes.remove(route)
if route.path == "/file={path_or_url:path}":
dependencies = route.dependencies
endpoint = route.endpoint
gradio_app.router.routes.remove(route)
@gradio_app.get("/file/{path:path}", dependencies=dependencies)
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
async def file(path_or_url: str, request: fastapi.Request):
stripped = path_or_url.lstrip().lower()
if stripped.startswith("https://") or stripped.startswith("http://"):
return "账户密码授权模式下, 禁止链接!"
if '../' in stripped:
return "非法路径!"
return await endpoint(path_or_url, request)
# --- --- enable TTS (text-to-speech) functionality --- ---
TTS_TYPE = get_conf("TTS_TYPE")

View File

@@ -104,17 +104,27 @@ def extract_archive(file_path, dest_dir):
logger.info("Successfully extracted zip archive to {}".format(dest_dir))
elif file_extension in [".tar", ".gz", ".bz2"]:
with tarfile.open(file_path, "r:*") as tarobj:
# 清理提取路径,移除任何不安全的元素
for member in tarobj.getmembers():
member_path = os.path.normpath(member.name)
full_path = os.path.join(dest_dir, member_path)
full_path = os.path.abspath(full_path)
if not full_path.startswith(os.path.abspath(dest_dir) + os.sep):
raise Exception(f"Attempted Path Traversal in {member.name}")
try:
with tarfile.open(file_path, "r:*") as tarobj:
# 清理提取路径,移除任何不安全的元素
for member in tarobj.getmembers():
member_path = os.path.normpath(member.name)
full_path = os.path.join(dest_dir, member_path)
full_path = os.path.abspath(full_path)
if not full_path.startswith(os.path.abspath(dest_dir) + os.sep):
raise Exception(f"Attempted Path Traversal in {member.name}")
tarobj.extractall(path=dest_dir)
logger.info("Successfully extracted tar archive to {}".format(dest_dir))
tarobj.extractall(path=dest_dir)
logger.info("Successfully extracted tar archive to {}".format(dest_dir))
except tarfile.ReadError as e:
if file_extension == ".gz":
# 一些特别奇葩的项目是一个gz文件里面不是tar只有一个tex文件
import gzip
with gzip.open(file_path, 'rb') as f_in:
with open(os.path.join(dest_dir, 'main.tex'), 'wb') as f_out:
f_out.write(f_in.read())
else:
raise e
# 第三方库需要预先pip install rarfile
# 此外Windows上还需要安装winrar软件配置其Path环境变量如"C:\Program Files\WinRAR"才可以

View File

@@ -14,6 +14,7 @@ openai_regex = re.compile(
r"sk-[a-zA-Z0-9_-]{92}$|" +
r"sk-proj-[a-zA-Z0-9_-]{48}$|"+
r"sk-proj-[a-zA-Z0-9_-]{124}$|"+
r"sk-proj-[a-zA-Z0-9_-]{156}$|"+ #新版apikey位数不匹配故修改此正则表达式
r"sess-[a-zA-Z0-9]{40}$"
)
def is_openai_api_key(key):