add security patch
This commit is contained in:
@@ -138,7 +138,9 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
|
||||
app_block.is_sagemaker = False
|
||||
|
||||
gradio_app = App.create_app(app_block)
|
||||
|
||||
for route in list(gradio_app.router.routes):
|
||||
if route.path == "/proxy={url_path:path}":
|
||||
gradio_app.router.routes.remove(route)
|
||||
# --- --- replace gradio endpoint to forbid access to sensitive files --- ---
|
||||
if len(AUTHENTICATION) > 0:
|
||||
dependencies = []
|
||||
@@ -154,9 +156,13 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
|
||||
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
|
||||
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
|
||||
async def file(path_or_url: str, request: fastapi.Request):
|
||||
if len(AUTHENTICATION) > 0:
|
||||
if not _authorize_user(path_or_url, request, gradio_app):
|
||||
return "越权访问!"
|
||||
if not _authorize_user(path_or_url, request, gradio_app):
|
||||
return "越权访问!"
|
||||
stripped = path_or_url.lstrip().lower()
|
||||
if stripped.startswith("https://") or stripped.startswith("http://"):
|
||||
return "账户密码授权模式下, 禁止链接!"
|
||||
if '../' in stripped:
|
||||
return "非法路径!"
|
||||
return await endpoint(path_or_url, request)
|
||||
|
||||
from fastapi import Request, status
|
||||
@@ -167,6 +173,26 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
|
||||
response.delete_cookie('access-token')
|
||||
response.delete_cookie('access-token-unsecure')
|
||||
return response
|
||||
else:
|
||||
dependencies = []
|
||||
endpoint = None
|
||||
for route in list(gradio_app.router.routes):
|
||||
if route.path == "/file/{path:path}":
|
||||
gradio_app.router.routes.remove(route)
|
||||
if route.path == "/file={path_or_url:path}":
|
||||
dependencies = route.dependencies
|
||||
endpoint = route.endpoint
|
||||
gradio_app.router.routes.remove(route)
|
||||
@gradio_app.get("/file/{path:path}", dependencies=dependencies)
|
||||
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
|
||||
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
|
||||
async def file(path_or_url: str, request: fastapi.Request):
|
||||
stripped = path_or_url.lstrip().lower()
|
||||
if stripped.startswith("https://") or stripped.startswith("http://"):
|
||||
return "账户密码授权模式下, 禁止链接!"
|
||||
if '../' in stripped:
|
||||
return "非法路径!"
|
||||
return await endpoint(path_or_url, request)
|
||||
|
||||
# --- --- enable TTS (text-to-speech) functionality --- ---
|
||||
TTS_TYPE = get_conf("TTS_TYPE")
|
||||
|
||||
Reference in New Issue
Block a user