diff --git a/shared_utils/fastapi_server.py b/shared_utils/fastapi_server.py index 6c9b1d1c..2993c987 100644 --- a/shared_utils/fastapi_server.py +++ b/shared_utils/fastapi_server.py @@ -138,7 +138,9 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS app_block.is_sagemaker = False gradio_app = App.create_app(app_block) - + for route in list(gradio_app.router.routes): + if route.path == "/proxy={url_path:path}": + gradio_app.router.routes.remove(route) # --- --- replace gradio endpoint to forbid access to sensitive files --- --- if len(AUTHENTICATION) > 0: dependencies = [] @@ -154,9 +156,13 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS @gradio_app.head("/file={path_or_url:path}", dependencies=dependencies) @gradio_app.get("/file={path_or_url:path}", dependencies=dependencies) async def file(path_or_url: str, request: fastapi.Request): - if len(AUTHENTICATION) > 0: - if not _authorize_user(path_or_url, request, gradio_app): - return "越权访问!" + if not _authorize_user(path_or_url, request, gradio_app): + return "越权访问!" + stripped = path_or_url.lstrip().lower() + if stripped.startswith("https://") or stripped.startswith("http://"): + return "账户密码授权模式下, 禁止链接!" + if '../' in stripped: + return "非法路径!" return await endpoint(path_or_url, request) from fastapi import Request, status @@ -167,6 +173,26 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS response.delete_cookie('access-token') response.delete_cookie('access-token-unsecure') return response + else: + dependencies = [] + endpoint = None + for route in list(gradio_app.router.routes): + if route.path == "/file/{path:path}": + gradio_app.router.routes.remove(route) + if route.path == "/file={path_or_url:path}": + dependencies = route.dependencies + endpoint = route.endpoint + gradio_app.router.routes.remove(route) + @gradio_app.get("/file/{path:path}", dependencies=dependencies) + @gradio_app.head("/file={path_or_url:path}", dependencies=dependencies) + @gradio_app.get("/file={path_or_url:path}", dependencies=dependencies) + async def file(path_or_url: str, request: fastapi.Request): + stripped = path_or_url.lstrip().lower() + if stripped.startswith("https://") or stripped.startswith("http://"): + return "账户密码授权模式下, 禁止链接!" + if '../' in stripped: + return "非法路径!" + return await endpoint(path_or_url, request) # --- --- enable TTS (text-to-speech) functionality --- --- TTS_TYPE = get_conf("TTS_TYPE")