feat: add support for R1 model and display CoT

commit a1f7ae5b55 (parent 1213ef19e5)
Author: memset0
Date: 2025-01-24 14:43:49 +08:00

2 changed files with 77 additions and 53 deletions

==== File 1 of 2 ====

@@ -1071,18 +1071,18 @@ if "deepseekcoder" in AVAIL_LLM_MODELS: # deepseekcoder
     except:
         logger.error(trimmed_format_exc())
 # -=-=-=-=-=-=- High-Flyer DeepSeek online LLM API -=-=-=-=-=-=-
-if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
+if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS or "deepseek-reasoner" in AVAIL_LLM_MODELS:
     try:
         deepseekapi_noui, deepseekapi_ui = get_predict_function(
             api_key_conf_name="DEEPSEEK_API_KEY", max_output_token=4096, disable_proxy=False
         )
         model_info.update({
             "deepseek-chat":{
                 "fn_with_ui": deepseekapi_ui,
                 "fn_without_ui": deepseekapi_noui,
                 "endpoint": deepseekapi_endpoint,
                 "can_multi_thread": True,
-                "max_token": 32000,
+                "max_token": 64000,
                 "tokenizer": tokenizer_gpt35,
                 "token_cnt": get_token_num_gpt35,
             },
@@ -1095,6 +1095,16 @@ if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
                 "tokenizer": tokenizer_gpt35,
                 "token_cnt": get_token_num_gpt35,
             },
+            "deepseek-reasoner":{
+                "fn_with_ui": deepseekapi_ui,
+                "fn_without_ui": deepseekapi_noui,
+                "endpoint": deepseekapi_endpoint,
+                "can_multi_thread": True,
+                "max_token": 64000,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+                "enable_reasoning": True
+            },
         })
     except:
         logger.error(trimmed_format_exc())
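The new `enable_reasoning` flag is what the streaming code in the second file keys on. A minimal sketch of how a consumer of `model_info` can read it (the helper function is hypothetical; `.get()` with a default keeps older entries without the flag working unchanged):

```python
# Hypothetical helper mirroring how the predict functions read the new flag.
def supports_reasoning(model_info: dict, llm_model: str) -> bool:
    return model_info.get(llm_model, {}).get("enable_reasoning", False)

# supports_reasoning(model_info, "deepseek-reasoner")  -> True
# supports_reasoning(model_info, "deepseek-chat")      -> False
```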

==== File 2 of 2 ====

@@ -36,11 +36,12 @@ def get_full_error(chunk, stream_response):
 def decode_chunk(chunk):
     """
-    Parse the "content" and "finish_reason" fields
+    Parse the "content" and "finish_reason" fields (models with chain-of-thought support also return "reasoning_content")
     """
     chunk = chunk.decode()
     respose = ""
     finish_reason = "False"
+    reasoning_content = ""
     try:
         chunk = json.loads(chunk[6:])
     except:
@@ -57,14 +58,20 @@ def decode_chunk(chunk):
         return respose, finish_reason
     try:
-        respose = chunk["choices"][0]["delta"]["content"]
+        if chunk["choices"][0]["delta"]["content"] is not None:
+            respose = chunk["choices"][0]["delta"]["content"]
     except:
         pass
     try:
         finish_reason = chunk["choices"][0]["finish_reason"]
     except:
         pass
-    return respose, finish_reason
+    try:
+        if chunk["choices"][0]["delta"]["reasoning_content"] is not None:
+            reasoning_content = chunk["choices"][0]["delta"]["reasoning_content"]
+    except:
+        pass
+    return respose, finish_reason, reasoning_content


 def generate_message(input, model, key, history, max_output_token, system_prompt, temperature):
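For reference, a self-contained sketch of the input `decode_chunk` now handles, assuming the OpenAI-compatible SSE format that DeepSeek streams; the payload below is fabricated. The `data: ` prefix is six characters, which is what the `chunk[6:]` slice strips:

```python
import json

# Fabricated SSE line: while R1 is reasoning, "content" is null and the
# chain-of-thought arrives in "reasoning_content".
chunk = b'data: {"choices":[{"delta":{"content":null,"reasoning_content":"Let me think..."},"finish_reason":null}]}'

delta = json.loads(chunk.decode()[6:])["choices"][0]["delta"]
response_text = delta.get("content") or ""                # "" -> no answer text yet
reasoning_content = delta.get("reasoning_content") or ""  # "Let me think..."
```

This is also why the `is not None` checks were added: for reasoning models, `delta["content"]` is present but explicitly `null` during the thinking phase.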
@@ -163,29 +170,23 @@ def get_predict_function(
             system_prompt=sys_prompt,
             temperature=llm_kwargs["temperature"],
         )
+        from .bridge_all import model_info
+        reasoning = model_info[llm_kwargs['llm_model']].get('enable_reasoning', False)
         retry = 0
         while True:
             try:
-                from .bridge_all import model_info
                 endpoint = model_info[llm_kwargs["llm_model"]]["endpoint"]
-                if not disable_proxy:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        proxies=proxies,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
-                else:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
+                response = requests.post(
+                    endpoint,
+                    headers=headers,
+                    proxies=None if disable_proxy else proxies,
+                    json=playload,
+                    stream=True,
+                    timeout=TIMEOUT_SECONDS,
+                )
                 break
             except:
                 retry += 1
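The retry loop also drops the duplicated `requests.post` branches: `requests` treats `proxies=None` the same as omitting the argument, so a single call with a conditional expression is behavior-preserving. A sketch of the pattern (the helper name is illustrative):

```python
import requests

def post_with_optional_proxy(url, disable_proxy, proxies, **kwargs):
    # proxies=None is the parameter's default, so this single call covers
    # both the proxied and the direct case.
    return requests.post(url, proxies=None if disable_proxy else proxies, **kwargs)
```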
@@ -194,10 +195,13 @@ def get_predict_function(
                     raise TimeoutError
                 if MAX_RETRY != 0:
                     logger.error(f"Request timed out, retrying ({retry}/{MAX_RETRY}) ...")

-        stream_response = response.iter_lines()
         result = ""
         finish_reason = ""
+        if reasoning:
+            reasoning_buffer = ""
+        stream_response = response.iter_lines()
         while True:
             try:
                 chunk = next(stream_response)
@@ -207,9 +211,12 @@ def get_predict_function(
                 break
             except requests.exceptions.ConnectionError:
                 chunk = next(stream_response)  # Failed; retry once, and if that fails too there is nothing more to do.
-            response_text, finish_reason = decode_chunk(chunk)
+            # decode_chunk always returns three values now
+            response_text, finish_reason, reasoning_content = decode_chunk(chunk)
             # The first chunk of the stream may be empty; keep waiting
-            if response_text == "" and finish_reason != "False":
+            if response_text == "" and (not reasoning or reasoning_content == "") and finish_reason != "False":
                 continue
             if response_text == "API_ERROR" and (
                 finish_reason != "False" or finish_reason != "stop"
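The widened guard matters because a reasoning model can stream chunks whose `content` is empty while `reasoning_content` is not; without the extra clause those chain-of-thought chunks would be skipped. A sketch of the condition, using the diff's names:

```python
# Skip a chunk only if it carries neither answer text nor, for reasoning
# models, any chain-of-thought text.
def should_skip(response_text, reasoning_content, reasoning, finish_reason):
    no_answer = response_text == ""
    no_cot = (not reasoning) or reasoning_content == ""
    return no_answer and no_cot and finish_reason != "False"
```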
@@ -227,6 +234,8 @@ def get_predict_function(
                     print(f"[response] {result}")
                 break
             result += response_text
+            if reasoning:
+                reasoning_buffer += reasoning_content
             if observe_window is not None:
                 # Observation window: display the data received so far
                 if len(observe_window) >= 1:
@@ -241,6 +250,8 @@ def get_predict_function(
                 error_msg = chunk_decoded
                 logger.error(error_msg)
                 raise RuntimeError("Unexpected JSON structure")
+        if reasoning:
+            return '\n'.join(map(lambda x: '> ' + x, reasoning_buffer.split('\n'))) + '\n\n' + result
         return result

     def predict(
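The new return path renders the accumulated chain-of-thought as a Markdown blockquote above the answer by prefixing every line with `> `. A worked example of that exact transformation (the strings are invented):

```python
reasoning_buffer = "First, parse the question.\nThen check edge cases."
result = "The answer is 42."

formatted = "\n".join("> " + line for line in reasoning_buffer.split("\n")) + "\n\n" + result
print(formatted)
# > First, parse the question.
# > Then check edge cases.
#
# The answer is 42.
```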
@@ -298,32 +309,25 @@ def get_predict_function(
             system_prompt=system_prompt,
             temperature=llm_kwargs["temperature"],
         )
+        from .bridge_all import model_info
+        reasoning = model_info[llm_kwargs['llm_model']].get('enable_reasoning', False)
         history.append(inputs)
         history.append("")
         retry = 0
         while True:
             try:
-                from .bridge_all import model_info
                 endpoint = model_info[llm_kwargs["llm_model"]]["endpoint"]
-                if not disable_proxy:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        proxies=proxies,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
-                else:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
+                response = requests.post(
+                    endpoint,
+                    headers=headers,
+                    proxies=None if disable_proxy else proxies,
+                    json=playload,
+                    stream=True,
+                    timeout=TIMEOUT_SECONDS,
+                )
                 break
             except:
                 retry += 1
@@ -338,6 +342,8 @@ def get_predict_function(
                 raise TimeoutError
         gpt_replying_buffer = ""
+        if reasoning:
+            gpt_reasoning_buffer = ""

         stream_response = response.iter_lines()
         while True:
@@ -347,9 +353,12 @@ def get_predict_function(
                 break
             except requests.exceptions.ConnectionError:
                 chunk = next(stream_response)  # Failed; retry once, and if that fails too there is nothing more to do.
-            response_text, finish_reason = decode_chunk(chunk)
+            # decode_chunk always returns three values now
+            response_text, finish_reason, reasoning_content = decode_chunk(chunk)
             # The first chunk of the stream may be empty; keep waiting
-            if response_text == "" and finish_reason != "False":
+            if response_text == "" and (not reasoning or reasoning_content == "") and finish_reason != "False":
                 status_text = f"finish_reason: {finish_reason}"
                 yield from update_ui(
                     chatbot=chatbot, history=history, msg=status_text
@@ -379,9 +388,14 @@ def get_predict_function(
                     logger.info(f"[response] {gpt_replying_buffer}")
                 break
             status_text = f"finish_reason: {finish_reason}"
-            gpt_replying_buffer += response_text
-            # An exception here usually means the text is too long; see the output of get_full_error for details
-            history[-1] = gpt_replying_buffer
+            if reasoning:
+                gpt_replying_buffer += response_text
+                gpt_reasoning_buffer += reasoning_content
+                history[-1] = '\n'.join(map(lambda x: '> ' + x, gpt_reasoning_buffer.split('\n'))) + '\n\n' + gpt_replying_buffer
+            else:
+                gpt_replying_buffer += response_text
+                # An exception here usually means the text is too long; see the output of get_full_error for details
+                history[-1] = gpt_replying_buffer
             chatbot[-1] = (history[-2], history[-1])
             yield from update_ui(
                 chatbot=chatbot, history=history, msg=status_text
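In the UI path the blockquote string is rebuilt from the full buffers on every chunk, so the rendered message always shows the complete chain-of-thought above the partial answer. A condensed, self-contained sketch of that loop over a fabricated stream of `(reasoning_content, content)` pairs:

```python
# Fabricated stream: CoT deltas arrive first, then answer deltas.
stream = [("Consider n=1.", ""), (" It holds.", ""), ("", "Yes, "), ("", "it holds for all n.")]

gpt_replying_buffer, gpt_reasoning_buffer = "", ""
for reasoning_content, response_text in stream:
    gpt_replying_buffer += response_text
    gpt_reasoning_buffer += reasoning_content
    # Rebuild the displayed message, exactly as the diff does for history[-1].
    shown = "\n".join("> " + line for line in gpt_reasoning_buffer.split("\n")) + "\n\n" + gpt_replying_buffer

print(shown)
# > Consider n=1. It holds.
#
# Yes, it holds for all n.
```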