fix: make return value count consistent and handle both response content types (#2129)
@@ -1,16 +1,13 @@
 import json
 import time
 import traceback

 import requests
 from loguru import logger

 # config_private.py holds your own secrets, such as API keys and proxy URLs
 # On load, first check for a private config_private file (not tracked by git); if it exists, it overrides the base config file
-from toolbox import (
-    get_conf,
-    update_ui,
-    is_the_upload_folder,
-)
+from toolbox import get_conf, is_the_upload_folder, update_ui

 proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf(
     "proxies", "TIMEOUT_SECONDS", "MAX_RETRY"
@@ -39,27 +36,35 @@ def decode_chunk(chunk):
     Decodes the "content" and "finish_reason" fields (models that support chain-of-thought also return a "reasoning_content" field)
     """
     chunk = chunk.decode()
-    respose = ""
+    response = ""
     reasoning_content = ""
     finish_reason = "False"

+    # Handle both possible response content types: text/json and text/event-stream
+    if chunk.startswith("data: "):
+        chunk = chunk[6:]
+    else:
+        chunk = chunk
+
     try:
-        chunk = json.loads(chunk[6:])
+        chunk = json.loads(chunk)
     except:
-        respose = ""
+        response = ""
         finish_reason = chunk
         # Error handling
         if "error" in chunk:
-            respose = "API_ERROR"
+            response = "API_ERROR"
             try:
                 chunk = json.loads(chunk)
                 finish_reason = chunk["error"]["code"]
             except:
                 finish_reason = "API_ERROR"
-        return respose, finish_reason
+        return response, reasoning_content, finish_reason

     try:
         if chunk["choices"][0]["delta"]["content"] is not None:
-            respose = chunk["choices"][0]["delta"]["content"]
+            response = chunk["choices"][0]["delta"]["content"]
     except:
         pass
     try:
@@ -71,7 +76,7 @@ def decode_chunk(chunk):
         finish_reason = chunk["choices"][0]["finish_reason"]
     except:
         pass
-    return respose, reasoning_content, finish_reason
+    return response, reasoning_content, finish_reason


 def generate_message(input, model, key, history, max_output_token, system_prompt, temperature):
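Note on the decode_chunk hunks above: the core of this fix is return-arity consistency. The old error branch returned a 2-tuple while the normal path returned a 3-tuple, so a caller unpacking all three of response, reasoning_content, and finish_reason would crash with a ValueError whenever the API errored. Below is a minimal sketch of the fixed behavior on both content types; the chunk bodies are invented for illustration, not taken from the commit:

import json

# A text/event-stream frame (SSE) and a plain text/json body (both hypothetical)
sse_chunk = b'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}'
json_chunk = b'{"choices": [{"delta": {"content": "Hi"}, "finish_reason": "stop"}]}'

for raw in (sse_chunk, json_chunk):
    text = raw.decode()
    if text.startswith("data: "):  # SSE frames carry a "data: " prefix
        text = text[6:]            # plain JSON bodies pass through untouched
    parsed = json.loads(text)
    delta = parsed["choices"][0]["delta"]
    response = delta.get("content") or ""
    reasoning_content = delta.get("reasoning_content", "")  # chain-of-thought models only
    finish_reason = parsed["choices"][0].get("finish_reason") or "False"
    print((response, reasoning_content, finish_reason))     # always a 3-tuple now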
@@ -106,7 +111,7 @@ def generate_message(input, model, key, history, max_output_token, system_prompt
     what_i_ask_now["role"] = "user"
     what_i_ask_now["content"] = input
     messages.append(what_i_ask_now)
-    playload = {
+    payload = {
         "model": model,
         "messages": messages,
         "temperature": temperature,
@@ -114,7 +119,7 @@ def generate_message(input, model, key, history, max_output_token, system_prompt
         "max_tokens": max_output_token,
     }

-    return headers, playload
+    return headers, payload


 def get_predict_function(
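For context on the playload → payload rename: generate_message assembles an OpenAI-style chat-completion request. A rough sketch of the pair it returns, with example values assumed (the exact header fields are not shown in this diff):

headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer sk-...",  # assumed: conventional OpenAI-style auth header
}
payload = {
    "model": "gpt-4o",                 # example model name
    "messages": [{"role": "system", "content": "..."},
                 {"role": "user", "content": "hello"}],
    "temperature": 0.7,
    "max_tokens": 4096,
}
# requests serializes the dict itself when it is passed as json=payload below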
@@ -141,7 +146,7 @@ def get_predict_function(
         history=[],
         sys_prompt="",
         observe_window=None,
-        console_slience=False,
+        console_silence=False,
     ):
         """
         Send to chatGPT and wait for the reply, completed in one shot without showing intermediate progress; internally, streaming is still used to keep the connection from being dropped midway.
@@ -162,7 +167,7 @@ def get_predict_function(
             raise RuntimeError(f"APIKEY为空,请检查配置文件的{APIKEY}")
         if inputs == "":
             inputs = "你好👋"
-        headers, playload = generate_message(
+        headers, payload = generate_message(
             input=inputs,
             model=llm_kwargs["llm_model"],
             key=APIKEY,
@@ -182,7 +187,7 @@ def get_predict_function(
                 endpoint,
                 headers=headers,
                 proxies=None if disable_proxy else proxies,
-                json=playload,
+                json=payload,
                 stream=True,
                 timeout=TIMEOUT_SECONDS,
             )
@@ -198,7 +203,7 @@ def get_predict_function(
         result = ""
         finish_reason = ""
         if reasoning:
-            resoning_buffer = ""
+            reasoning_buffer = ""

         stream_response = response.iter_lines()
         while True:
@@ -226,12 +231,12 @@ def get_predict_function(
             if chunk:
                 try:
                     if finish_reason == "stop":
-                        if not console_slience:
+                        if not console_silence:
                             print(f"[response] {result}")
                         break
                     result += response_text
                     if reasoning:
-                        resoning_buffer += reasoning_content
+                        reasoning_buffer += reasoning_content
                     if observe_window is not None:
                         # Observation window: display the data received so far
                         if len(observe_window) >= 1:
@@ -247,9 +252,9 @@ def get_predict_function(
                     logger.error(error_msg)
                     raise RuntimeError("Json解析不合常规")
         if reasoning:
-            # Quote the reasoning section with markdown blockquote markers (>)
-            return '\n'.join(map(lambda x: '> ' + x, resoning_buffer.split('\n'))) + \
-                '\n\n' + result
+            return f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
+{''.join([f'<p style="margin: 1.25em 0;">{line}</p>' for line in reasoning_buffer.split('\n')])}
+</div>\n\n''' + result
         return result

     def predict(
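The rendering change above is easiest to see side by side: the old code prefixed every reasoning line with "> " (a markdown blockquote), while the new code wraps each line in a <p> inside a dimmed <div>. A quick sketch with a made-up buffer; note that a backslash inside an f-string expression (the split('\n') in the added lines) is only legal from Python 3.12 onward, so the sketch splits the buffer outside the f-string:

reasoning_buffer = "Step 1: parse the question\nStep 2: draft the answer"
lines = reasoning_buffer.split('\n')

# Old rendering: markdown blockquote
old = '\n'.join(map(lambda x: '> ' + x, lines))
# '> Step 1: parse the question\n> Step 2: draft the answer'

# New rendering: one <p> per line inside a semi-transparent <div>
new = f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
{''.join([f'<p style="margin: 1.25em 0;">{line}</p>' for line in lines])}
</div>'''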
@@ -268,7 +273,7 @@ def get_predict_function(
         inputs is the input for this query
         top_p and temperature are chatGPT's internal tuning parameters
         history is the list of earlier conversation turns (note that if either inputs or history is too long, a token-overflow error will be triggered)
-        chatbot is the conversation list shown in the WebUI; modify it and then yeild it out to update the chat interface directly
+        chatbot is the conversation list shown in the WebUI; modify it and then yield it out to update the chat interface directly
         additional_fn indicates which button was clicked; see functional.py for the buttons
         """
         from .bridge_all import model_info
@@ -299,7 +304,7 @@ def get_predict_function(
         )  # Refresh the UI
         time.sleep(2)

-        headers, playload = generate_message(
+        headers, payload = generate_message(
             input=inputs,
             model=llm_kwargs["llm_model"],
             key=APIKEY,
@@ -321,7 +326,7 @@ def get_predict_function(
                 endpoint,
                 headers=headers,
                 proxies=None if disable_proxy else proxies,
-                json=playload,
+                json=payload,
                 stream=True,
                 timeout=TIMEOUT_SECONDS,
             )
@@ -367,7 +372,7 @@ def get_predict_function(
                 chunk_decoded = chunk.decode()
                 chatbot[-1] = (
                     chatbot[-1][0],
-                    "[Local Message] {finish_reason},获得以下报错信息:\n"
+                    f"[Local Message] {finish_reason},获得以下报错信息:\n"
                     + chunk_decoded,
                 )
                 yield from update_ui(
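The one-character fix above matters more than it looks: without the f prefix, the braces are never interpolated, so users would see the literal text {finish_reason} in the error message. For example:

finish_reason = "API_ERROR"
print("[Local Message] {finish_reason}")   # -> [Local Message] {finish_reason}
print(f"[Local Message] {finish_reason}")  # -> [Local Message] API_ERROR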
@@ -385,7 +390,9 @@ def get_predict_function(
                 if reasoning:
                     gpt_replying_buffer += response_text
                     gpt_reasoning_buffer += reasoning_content
-                    history[-1] = '\n'.join(map(lambda x: '> ' + x, gpt_reasoning_buffer.split('\n'))) + '\n\n' + gpt_replying_buffer
+                    history[-1] = f'''<div style="padding: 1em; line-height: 1.5; text-wrap: wrap; opacity: 0.8">
+{''.join([f'<p style="margin: 1.25em 0;">{line}</p>' for line in gpt_reasoning_buffer.split('\n')])}
+</div>\n\n''' + gpt_replying_buffer
                 else:
                     gpt_replying_buffer += response_text
                 # If an exception is raised here, it is usually because the text is too long; see the output of get_full_error for details