fix: support o1 models
This commit is contained in:
@@ -133,22 +133,33 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
|
||||
observe_window = None:
|
||||
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
|
||||
"""
|
||||
from request_llms.bridge_all import model_info
|
||||
|
||||
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
|
||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)
|
||||
|
||||
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
|
||||
else: stream = True
|
||||
|
||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=stream)
|
||||
retry = 0
|
||||
while True:
|
||||
try:
|
||||
# make a POST request to the API endpoint, stream=False
|
||||
from .bridge_all import model_info
|
||||
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
|
||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
|
||||
json=payload, stream=stream, timeout=TIMEOUT_SECONDS); break
|
||||
except requests.exceptions.ReadTimeout as e:
|
||||
retry += 1
|
||||
traceback.print_exc()
|
||||
if retry > MAX_RETRY: raise TimeoutError
|
||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||
|
||||
if not stream:
|
||||
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
|
||||
chunkjson = json.loads(response.content.decode())
|
||||
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||
return gpt_replying_buffer
|
||||
|
||||
stream_response = response.iter_lines()
|
||||
result = ''
|
||||
json_data = None
|
||||
@@ -208,7 +219,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
|
||||
additional_fn代表点击的哪个按钮,按钮见functional.py
|
||||
"""
|
||||
from .bridge_all import model_info
|
||||
from request_llms.bridge_all import model_info
|
||||
if is_any_api_key(inputs):
|
||||
chatbot._cookies['api_key'] = inputs
|
||||
chatbot.append(("输入已识别为openai的api_key", what_keys(inputs)))
|
||||
@@ -237,6 +248,10 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
chatbot.append((_inputs, ""))
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
||||
|
||||
# 禁用stream的特殊模型处理
|
||||
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
|
||||
else: stream = True
|
||||
|
||||
# check mis-behavior
|
||||
if is_the_upload_folder(user_input):
|
||||
chatbot[-1] = (inputs, f"[Local Message] 检测到操作错误!当您上传文档之后,需点击“**函数插件区**”按钮进行处理,请勿点击“提交”按钮或者“基础功能区”按钮。")
|
||||
@@ -270,7 +285,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
try:
|
||||
# make a POST request to the API endpoint, stream=True
|
||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
|
||||
json=payload, stream=stream, timeout=TIMEOUT_SECONDS);break
|
||||
except:
|
||||
retry += 1
|
||||
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
|
||||
@@ -278,10 +293,15 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
|
||||
if retry > MAX_RETRY: raise TimeoutError
|
||||
|
||||
gpt_replying_buffer = ""
|
||||
|
||||
is_head_of_the_stream = True
|
||||
if not stream:
|
||||
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
|
||||
yield from handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history)
|
||||
return
|
||||
|
||||
if stream:
|
||||
gpt_replying_buffer = ""
|
||||
is_head_of_the_stream = True
|
||||
stream_response = response.iter_lines()
|
||||
while True:
|
||||
try:
|
||||
@@ -343,12 +363,24 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
chunk_decoded = chunk.decode()
|
||||
error_msg = chunk_decoded
|
||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + error_msg) # 刷新界面
|
||||
print(error_msg)
|
||||
return
|
||||
return # return from stream-branch
|
||||
|
||||
def handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history):
|
||||
try:
|
||||
chunkjson = json.loads(response.content.decode())
|
||||
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||
history[-1] = gpt_replying_buffer
|
||||
chatbot[-1] = (history[-2], history[-1])
|
||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||
except Exception as e:
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + response.text) # 刷新界面
|
||||
|
||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
||||
from .bridge_all import model_info
|
||||
from request_llms.bridge_all import model_info
|
||||
openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
|
||||
if "reduce the length" in error_msg:
|
||||
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
|
||||
@@ -381,6 +413,8 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
||||
"""
|
||||
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
|
||||
"""
|
||||
from request_llms.bridge_all import model_info
|
||||
|
||||
if not is_any_api_key(llm_kwargs['api_key']):
|
||||
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
|
||||
|
||||
@@ -409,10 +443,16 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
||||
else:
|
||||
enable_multimodal_capacity = False
|
||||
|
||||
conversation_cnt = len(history) // 2
|
||||
openai_disable_system_prompt = model_info[llm_kwargs['llm_model']].get('openai_disable_system_prompt', False)
|
||||
|
||||
if openai_disable_system_prompt:
|
||||
messages = []
|
||||
else:
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
if not enable_multimodal_capacity:
|
||||
# 不使用多模态能力
|
||||
conversation_cnt = len(history) // 2
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
if conversation_cnt:
|
||||
for index in range(0, 2*conversation_cnt, 2):
|
||||
what_i_have_asked = {}
|
||||
@@ -434,8 +474,6 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
||||
messages.append(what_i_ask_now)
|
||||
else:
|
||||
# 多模态能力
|
||||
conversation_cnt = len(history) // 2
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
if conversation_cnt:
|
||||
for index in range(0, 2*conversation_cnt, 2):
|
||||
what_i_have_asked = {}
|
||||
|
||||
Reference in New Issue
Block a user