diff --git a/config.py b/config.py index f3c9ecb0..383549c2 100644 --- a/config.py +++ b/config.py @@ -43,7 +43,7 @@ AVAIL_LLM_MODELS = ["gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-p # AVAIL_LLM_MODELS = [ # "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-flash", # "qianfan", "deepseekcoder", -# "spark", "sparkv2", "sparkv3", "sparkv3.5", +# "spark", "sparkv2", "sparkv3", "sparkv3.5", "sparkv4", # "qwen-turbo", "qwen-plus", "qwen-max", "qwen-local", # "moonshot-v1-128k", "moonshot-v1-32k", "moonshot-v1-8k", # "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0125", "gpt-4o-2024-05-13" diff --git a/main.py b/main.py index 81acc4d7..b4c424b0 100644 --- a/main.py +++ b/main.py @@ -40,7 +40,7 @@ def encode_plugin_info(k, plugin)->str: def main(): import gradio as gr - if gr.__version__ not in ['3.32.9', '3.32.10']: + if gr.__version__ not in ['3.32.9', '3.32.10', '3.32.11']: raise ModuleNotFoundError("使用项目内置Gradio获取最优体验! 请运行 `pip install -r requirements.txt` 指令安装内置Gradio及其他依赖, 详情信息见requirements.txt.") from request_llms.bridge_all import predict from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf, ArgsGeneralWrapper, DummyWith diff --git a/request_llms/bridge_all.py b/request_llms/bridge_all.py index 92416ad6..1d408943 100644 --- a/request_llms/bridge_all.py +++ b/request_llms/bridge_all.py @@ -860,6 +860,15 @@ if "sparkv3" in AVAIL_LLM_MODELS or "sparkv3.5" in AVAIL_LLM_MODELS: # 讯飞 "max_token": 4096, "tokenizer": tokenizer_gpt35, "token_cnt": get_token_num_gpt35, + }, + "sparkv4":{ + "fn_with_ui": spark_ui, + "fn_without_ui": spark_noui, + "can_multi_thread": True, + "endpoint": None, + "max_token": 4096, + "tokenizer": tokenizer_gpt35, + "token_cnt": get_token_num_gpt35, } }) except: diff --git a/request_llms/com_sparkapi.py b/request_llms/com_sparkapi.py index 359e407a..a9196002 100644 --- a/request_llms/com_sparkapi.py +++ b/request_llms/com_sparkapi.py @@ -67,6 +67,7 @@ class SparkRequestInstance(): self.gpt_url_v3 = "ws://spark-api.xf-yun.com/v3.1/chat" self.gpt_url_v35 = "wss://spark-api.xf-yun.com/v3.5/chat" self.gpt_url_img = "wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image" + self.gpt_url_v4 = "wss://spark-api.xf-yun.com/v4.0/chat" self.time_to_yield_event = threading.Event() self.time_to_exit_event = threading.Event() @@ -94,6 +95,8 @@ class SparkRequestInstance(): gpt_url = self.gpt_url_v3 elif llm_kwargs['llm_model'] == 'sparkv3.5': gpt_url = self.gpt_url_v35 + elif llm_kwargs['llm_model'] == 'sparkv4': + gpt_url = self.gpt_url_v4 else: gpt_url = self.gpt_url file_manifest = [] @@ -194,6 +197,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt, file_manifest) "sparkv2": "generalv2", "sparkv3": "generalv3", "sparkv3.5": "generalv3.5", + "sparkv4": "4.0Ultra" } domains_select = domains[llm_kwargs['llm_model']] if file_manifest: domains_select = 'image' diff --git a/request_llms/oai_std_model_template.py b/request_llms/oai_std_model_template.py index 648dbe41..1d649af9 100644 --- a/request_llms/oai_std_model_template.py +++ b/request_llms/oai_std_model_template.py @@ -44,7 +44,8 @@ def decode_chunk(chunk): try: chunk = json.loads(chunk[6:]) except: - finish_reason = "JSON_ERROR" + respose = "API_ERROR" + finish_reason = chunk # 错误处理部分 if "error" in chunk: respose = "API_ERROR" diff --git a/shared_utils/fastapi_server.py b/shared_utils/fastapi_server.py index 17101392..45363d8d 100644 --- a/shared_utils/fastapi_server.py +++ b/shared_utils/fastapi_server.py @@ -159,6 +159,15 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS return "越权访问!" return await endpoint(path_or_url, request) + from fastapi import Request, status + from fastapi.responses import FileResponse, RedirectResponse + @gradio_app.get("/academic_logout") + async def logout(): + response = RedirectResponse(url=CUSTOM_PATH, status_code=status.HTTP_302_FOUND) + response.delete_cookie('access-token') + response.delete_cookie('access-token-unsecure') + return response + # --- --- enable TTS (text-to-speech) functionality --- --- TTS_TYPE = get_conf("TTS_TYPE") if TTS_TYPE != "DISABLE": @@ -236,6 +245,7 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS response = await call_next(request) return response + # --- --- uvicorn.Config --- --- ssl_keyfile = None if SSL_KEYFILE == "" else SSL_KEYFILE ssl_certfile = None if SSL_CERTFILE == "" else SSL_CERTFILE diff --git a/docs/test_markdown_format.py b/tests/test_markdown_format.py similarity index 100% rename from docs/test_markdown_format.py rename to tests/test_markdown_format.py