feat: add support for R1 model and display CoT
@@ -1071,18 +1071,18 @@ if "deepseekcoder" in AVAIL_LLM_MODELS: # deepseekcoder
     except:
         logger.error(trimmed_format_exc())
 # -=-=-=-=-=-=- 幻方-深度求索大模型在线API -=-=-=-=-=-=-
-if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
+if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS or "deepseek-reasoner" in AVAIL_LLM_MODELS:
     try:
         deepseekapi_noui, deepseekapi_ui = get_predict_function(
             api_key_conf_name="DEEPSEEK_API_KEY", max_output_token=4096, disable_proxy=False
         )
         model_info.update({
             "deepseek-chat":{
                 "fn_with_ui": deepseekapi_ui,
                 "fn_without_ui": deepseekapi_noui,
                 "endpoint": deepseekapi_endpoint,
                 "can_multi_thread": True,
-                "max_token": 32000,
+                "max_token": 64000,
                 "tokenizer": tokenizer_gpt35,
                 "token_cnt": get_token_num_gpt35,
             },
@@ -1095,6 +1095,16 @@ if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
                 "tokenizer": tokenizer_gpt35,
                 "token_cnt": get_token_num_gpt35,
             },
+            "deepseek-reasoner":{
+                "fn_with_ui": deepseekapi_ui,
+                "fn_without_ui": deepseekapi_noui,
+                "endpoint": deepseekapi_endpoint,
+                "can_multi_thread": True,
+                "max_token": 64000,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+                "enable_reasoning": True
+            },
         })
     except:
         logger.error(trimmed_format_exc())
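
Note: the two hunks above register "deepseek-reasoner" alongside the existing DeepSeek entries and mark it with "enable_reasoning": True; the request layer later reads that flag via model_info[...].get('enable_reasoning', False). The snippet below is a minimal, self-contained sketch of that lookup pattern (the dict and helper are illustrative stand-ins, not the project's actual module layout).

# Sketch: how an "enable_reasoning" flag in a model registry is queried with
# .get() and a False default, so models without the flag keep working unchanged.
model_info = {
    "deepseek-chat": {"max_token": 64000},                                # no reasoning flag
    "deepseek-reasoner": {"max_token": 64000, "enable_reasoning": True},  # reasoning-capable
}

def supports_reasoning(llm_model: str) -> bool:
    return model_info.get(llm_model, {}).get("enable_reasoning", False)

assert supports_reasoning("deepseek-reasoner") is True
assert supports_reasoning("deepseek-chat") is False
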
@@ -36,11 +36,12 @@ def get_full_error(chunk, stream_response):
 
 def decode_chunk(chunk):
     """
-    用于解读"content"和"finish_reason"的内容
+    用于解读"content"和"finish_reason"的内容(如果支持思维链也会返回"reasoning_content"内容)
     """
     chunk = chunk.decode()
     respose = ""
     finish_reason = "False"
+    reasoning_content = ""
     try:
         chunk = json.loads(chunk[6:])
     except:
@@ -57,14 +58,20 @@ def decode_chunk(chunk):
         return respose, finish_reason
 
     try:
-        respose = chunk["choices"][0]["delta"]["content"]
+        if chunk["choices"][0]["delta"]["content"] is not None:
+            respose = chunk["choices"][0]["delta"]["content"]
     except:
         pass
     try:
         finish_reason = chunk["choices"][0]["finish_reason"]
     except:
         pass
-    return respose, finish_reason
+    try:
+        if chunk["choices"][0]["delta"]["reasoning_content"] is not None:
+            reasoning_content = chunk["choices"][0]["delta"]["reasoning_content"]
+    except:
+        pass
+    return respose, finish_reason, reasoning_content
 
 
 def generate_message(input, model, key, history, max_output_token, system_prompt, temperature):
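
Note: decode_chunk now also extracts "reasoning_content" from each streamed delta and guards against null "content" (the reasoner model streams null content while it is still thinking). The sketch below shows the same parsing pattern on a hand-made SSE line; the sample chunk is illustrative, not captured from the real API, and the helper name is hypothetical.

# Sketch: parse one "data: {...}" SSE line the way the updated decode_chunk does,
# returning (respose, finish_reason, reasoning_content).
import json

sample = b'data: {"choices":[{"delta":{"content":null,"reasoning_content":"Let me think..."},"finish_reason":null}]}'

def decode_chunk_sketch(chunk: bytes):
    respose, finish_reason, reasoning_content = "", "False", ""
    try:
        payload = json.loads(chunk.decode()[6:])    # drop the 6-char "data: " prefix
    except Exception:
        return respose, finish_reason, reasoning_content
    delta = payload["choices"][0]["delta"]
    if delta.get("content") is not None:            # null while the model is still "thinking"
        respose = delta["content"]
    if delta.get("reasoning_content") is not None:  # present only for reasoning-capable models
        reasoning_content = delta["reasoning_content"]
    if payload["choices"][0].get("finish_reason") is not None:
        finish_reason = payload["choices"][0]["finish_reason"]
    return respose, finish_reason, reasoning_content

print(decode_chunk_sketch(sample))  # ('', 'False', 'Let me think...')
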
@@ -163,29 +170,23 @@ def get_predict_function(
            system_prompt=sys_prompt,
            temperature=llm_kwargs["temperature"],
        )
 
+        from .bridge_all import model_info
+
+        reasoning = model_info[llm_kwargs['llm_model']].get('enable_reasoning', False)
+
        retry = 0
        while True:
            try:
-                from .bridge_all import model_info
                endpoint = model_info[llm_kwargs["llm_model"]]["endpoint"]
-                if not disable_proxy:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        proxies=proxies,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
-                else:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
+                response = requests.post(
+                    endpoint,
+                    headers=headers,
+                    proxies=None if disable_proxy else proxies,
+                    json=playload,
+                    stream=True,
+                    timeout=TIMEOUT_SECONDS,
+                )
                break
            except:
                retry += 1
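
Note: besides reading the enable_reasoning flag up front, this hunk collapses the duplicated proxied/unproxied requests.post branches into a single call, relying on the fact that requests treats proxies=None the same as omitting the argument. A minimal sketch of that consolidated call (endpoint, headers and payload here are placeholders; in the patched code they come from model_info and generate_message):

# Sketch: one requests.post covers both the proxied and the direct case.
import requests

def post_stream(endpoint, headers, playload, proxies, disable_proxy, timeout=30):
    return requests.post(
        endpoint,
        headers=headers,
        proxies=None if disable_proxy else proxies,  # None behaves like "no proxies argument"
        json=playload,
        stream=True,
        timeout=timeout,
    )
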
@@ -194,10 +195,13 @@ def get_predict_function(
                    raise TimeoutError
                if MAX_RETRY != 0:
                    logger.error(f"请求超时,正在重试 ({retry}/{MAX_RETRY}) ……")
 
-        stream_response = response.iter_lines()
        result = ""
        finish_reason = ""
+        if reasoning:
+            resoning_buffer = ""
+
+        stream_response = response.iter_lines()
        while True:
            try:
                chunk = next(stream_response)
@@ -207,9 +211,12 @@ def get_predict_function(
                break
            except requests.exceptions.ConnectionError:
                chunk = next(stream_response) # 失败了,重试一次?再失败就没办法了。
-            response_text, finish_reason = decode_chunk(chunk)
+            if reasoning:
+                response_text, finish_reason, reasoning_content = decode_chunk(chunk)
+            else:
+                response_text, finish_reason = decode_chunk(chunk)
            # 返回的数据流第一次为空,继续等待
-            if response_text == "" and finish_reason != "False":
+            if response_text == "" and (reasoning == False or reasoning_content == "") and finish_reason != "False":
                continue
            if response_text == "API_ERROR" and (
                finish_reason != "False" or finish_reason != "stop"
@@ -227,6 +234,8 @@ def get_predict_function(
                    print(f"[response] {result}")
                break
            result += response_text
+            if reasoning:
+                resoning_buffer += reasoning_content
            if observe_window is not None:
                # 观测窗,把已经获取的数据显示出去
                if len(observe_window) >= 1:
@@ -241,6 +250,8 @@ def get_predict_function(
                error_msg = chunk_decoded
                logger.error(error_msg)
                raise RuntimeError("Json解析不合常规")
+        if reasoning:
+            return '\n'.join(map(lambda x: '> ' + x, resoning_buffer.split('\n'))) + '\n\n' + result
        return result

    def predict(
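
Note: in the no-UI path the accumulated chain of thought is returned above the answer, with every reasoning line prefixed by "> " so it renders as a Markdown blockquote. A small sketch of exactly that quoting step (buffer contents made up):

# Sketch: format accumulated reasoning as a blockquote followed by the answer.
resoning_buffer = "First, restate the problem.\nThen check edge cases."
result = "The answer is 42."

formatted = "\n".join(map(lambda x: "> " + x, resoning_buffer.split("\n"))) + "\n\n" + result
print(formatted)
# > First, restate the problem.
# > Then check edge cases.
#
# The answer is 42.
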
@@ -298,32 +309,25 @@ def get_predict_function(
            system_prompt=system_prompt,
            temperature=llm_kwargs["temperature"],
        )
 
+        from .bridge_all import model_info
+
+        reasoning = model_info[llm_kwargs['llm_model']].get('enable_reasoning', False)
+
        history.append(inputs)
        history.append("")
        retry = 0
        while True:
            try:
-                from .bridge_all import model_info
-
                endpoint = model_info[llm_kwargs["llm_model"]]["endpoint"]
-                if not disable_proxy:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        proxies=proxies,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
-                else:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
+                response = requests.post(
+                    endpoint,
+                    headers=headers,
+                    proxies=None if disable_proxy else proxies,
+                    json=playload,
+                    stream=True,
+                    timeout=TIMEOUT_SECONDS,
+                )
                break
            except:
                retry += 1
@@ -338,6 +342,8 @@ def get_predict_function(
                raise TimeoutError
 
        gpt_replying_buffer = ""
+        if reasoning:
+            gpt_reasoning_buffer = ""
 
        stream_response = response.iter_lines()
        while True:
@@ -347,9 +353,12 @@ def get_predict_function(
                break
            except requests.exceptions.ConnectionError:
                chunk = next(stream_response) # 失败了,重试一次?再失败就没办法了。
-            response_text, finish_reason = decode_chunk(chunk)
+            if reasoning:
+                response_text, finish_reason, reasoning_content = decode_chunk(chunk)
+            else:
+                response_text, finish_reason = decode_chunk(chunk)
            # 返回的数据流第一次为空,继续等待
-            if response_text == "" and finish_reason != "False":
+            if response_text == "" and (reasoning == False or reasoning_content == "") and finish_reason != "False":
                status_text = f"finish_reason: {finish_reason}"
                yield from update_ui(
                    chatbot=chatbot, history=history, msg=status_text
@@ -379,9 +388,14 @@ def get_predict_function(
                    logger.info(f"[response] {gpt_replying_buffer}")
                break
            status_text = f"finish_reason: {finish_reason}"
-            gpt_replying_buffer += response_text
-            # 如果这里抛出异常,一般是文本过长,详情见get_full_error的输出
-            history[-1] = gpt_replying_buffer
+            if reasoning:
+                gpt_replying_buffer += response_text
+                gpt_reasoning_buffer += reasoning_content
+                history[-1] = '\n'.join(map(lambda x: '> ' + x, gpt_reasoning_buffer.split('\n'))) + '\n\n' + gpt_replying_buffer
+            else:
+                gpt_replying_buffer += response_text
+                # 如果这里抛出异常,一般是文本过长,详情见get_full_error的输出
+                history[-1] = gpt_replying_buffer
            chatbot[-1] = (history[-2], history[-1])
            yield from update_ui(
                chatbot=chatbot, history=history, msg=status_text
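
Note: in the UI path the same blockquote formatting is applied incrementally: on every streamed chunk both buffers grow and history[-1] is rebuilt as "quoted CoT + blank line + answer", so the chatbot shows the reasoning expanding above the answer while streaming. A self-contained sketch of that loop (the chunk list is simulated; the real code gets chunks from decode_chunk and refreshes the UI via update_ui):

# Sketch: incremental (answer, reasoning) updates rebuilding history[-1] each step.
chunks = [("", "Consider the input."), ("", " Check the format."), ("The result", ""), (" is ready.", "")]

gpt_replying_buffer, gpt_reasoning_buffer = "", ""
history = ["user question", ""]
for response_text, reasoning_content in chunks:
    gpt_replying_buffer += response_text
    gpt_reasoning_buffer += reasoning_content
    history[-1] = "\n".join("> " + x for x in gpt_reasoning_buffer.split("\n")) + "\n\n" + gpt_replying_buffer
print(history[-1])
# > Consider the input. Check the format.
#
# The result is ready.
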