Merge branch 'master' of https://github.com/memset0/gpt_academic into memset0-master

binary-husky
2025-01-30 22:03:31 +08:00
2 changed files with 71 additions and 53 deletions

View File

@@ -1090,18 +1090,18 @@ if "deepseekcoder" in AVAIL_LLM_MODELS: # deepseekcoder
     except:
         logger.error(trimmed_format_exc())
 # -=-=-=-=-=-=- High-Flyer DeepSeek LLM online API -=-=-=-=-=-=-
-if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
+if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS or "deepseek-reasoner" in AVAIL_LLM_MODELS:
     try:
         deepseekapi_noui, deepseekapi_ui = get_predict_function(
             api_key_conf_name="DEEPSEEK_API_KEY", max_output_token=4096, disable_proxy=False
-            )
+        )
         model_info.update({
             "deepseek-chat":{
                 "fn_with_ui": deepseekapi_ui,
                 "fn_without_ui": deepseekapi_noui,
                 "endpoint": deepseekapi_endpoint,
                 "can_multi_thread": True,
-                "max_token": 32000,
+                "max_token": 64000,
                 "tokenizer": tokenizer_gpt35,
                 "token_cnt": get_token_num_gpt35,
             },
@@ -1114,6 +1114,16 @@ if "deepseek-chat" in AVAIL_LLM_MODELS or "deepseek-coder" in AVAIL_LLM_MODELS:
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
"deepseek-reasoner":{
"fn_with_ui": deepseekapi_ui,
"fn_without_ui": deepseekapi_noui,
"endpoint": deepseekapi_endpoint,
"can_multi_thread": True,
"max_token": 64000,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
"enable_reasoning": True
},
})
except:
logger.error(trimmed_format_exc())
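
For context on how the new `enable_reasoning` flag is consumed, here is a minimal sketch assuming only the `model_info` registry shown above (`supports_reasoning` is a hypothetical helper, not part of the repo):

```python
# Hypothetical helper, for illustration only: looks up the flag that the
# diff above registers for "deepseek-reasoner". Using .get() with a default
# keeps every older model entry (which lacks the key) working unchanged.
def supports_reasoning(llm_model: str, model_info: dict) -> bool:
    return model_info[llm_model].get("enable_reasoning", False)

# e.g. supports_reasoning("deepseek-reasoner", model_info) -> True
#      supports_reasoning("deepseek-chat", model_info)     -> False
```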

View File

@@ -36,10 +36,11 @@ def get_full_error(chunk, stream_response):
 def decode_chunk(chunk):
     """
-    Used to parse the "content" and "finish_reason" fields
+    Used to parse the "content" and "finish_reason" fields (if chain-of-thought is supported, the "reasoning_content" field is returned as well)
     """
     chunk = chunk.decode()
     respose = ""
+    reasoning_content = ""
     finish_reason = "False"
     try:
         chunk = json.loads(chunk[6:])
@@ -57,14 +58,20 @@ def decode_chunk(chunk):
         return respose, finish_reason
     try:
-        respose = chunk["choices"][0]["delta"]["content"]
+        if chunk["choices"][0]["delta"]["content"] is not None:
+            respose = chunk["choices"][0]["delta"]["content"]
     except:
         pass
+    try:
+        if chunk["choices"][0]["delta"]["reasoning_content"] is not None:
+            reasoning_content = chunk["choices"][0]["delta"]["reasoning_content"]
+    except:
+        pass
     try:
         finish_reason = chunk["choices"][0]["finish_reason"]
     except:
         pass
-    return respose, finish_reason
+    return respose, reasoning_content, finish_reason

 def generate_message(input, model, key, history, max_output_token, system_prompt, temperature):
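
To make the new three-tuple concrete, a small illustration (the SSE line below is fabricated in the OpenAI-compatible shape DeepSeek streams; it is not captured output):

```python
# A fabricated streamed line: "content" is null while the model is still
# thinking, so only "reasoning_content" carries text in this chunk.
sample = b'data: {"choices": [{"delta": {"content": null, "reasoning_content": "Let me check the premises."}}]}'

respose, reasoning_content, finish_reason = decode_chunk(sample)
# respose == ""               -> the "is not None" guard skips the null content
# reasoning_content == "Let me check the premises."
# finish_reason == "False"    -> the key is absent, so the default survives
```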
@@ -163,29 +170,23 @@ def get_predict_function(
             system_prompt=sys_prompt,
             temperature=llm_kwargs["temperature"],
         )
+        from .bridge_all import model_info
+        reasoning = model_info[llm_kwargs['llm_model']].get('enable_reasoning', False)
         retry = 0
         while True:
             try:
-                from .bridge_all import model_info
                 endpoint = model_info[llm_kwargs["llm_model"]]["endpoint"]
-                if not disable_proxy:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        proxies=proxies,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
-                else:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
+                response = requests.post(
+                    endpoint,
+                    headers=headers,
+                    proxies=None if disable_proxy else proxies,
+                    json=playload,
+                    stream=True,
+                    timeout=TIMEOUT_SECONDS,
+                )
                 break
             except:
                 retry += 1
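
The consolidation works because `requests` treats an explicit `proxies=None` exactly like omitting the argument, so the branchless call is behavior-preserving. A minimal standalone sketch (hypothetical function and parameter names, not the project's code):

```python
import requests

def post_stream(endpoint, headers, payload, proxies, disable_proxy, timeout=60):
    # proxies=None is the requests default, so this single call covers both
    # branches that the diff above removes (with and without a proxy).
    return requests.post(
        endpoint,
        headers=headers,
        proxies=None if disable_proxy else proxies,
        json=payload,
        stream=True,
        timeout=timeout,
    )
```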
@@ -194,10 +195,13 @@ def get_predict_function(
                     raise TimeoutError
                 if MAX_RETRY != 0:
                     logger.error(f"Request timed out, retrying ({retry}/{MAX_RETRY}) ...")
-        stream_response = response.iter_lines()
         result = ""
         finish_reason = ""
+        if reasoning:
+            resoning_buffer = ""
+        stream_response = response.iter_lines()
         while True:
             try:
                 chunk = next(stream_response)
@@ -207,9 +211,9 @@ def get_predict_function(
                 break
             except requests.exceptions.ConnectionError:
                 chunk = next(stream_response)  # Failed; retry once? If it fails again, there is nothing more we can do.
-            response_text, finish_reason = decode_chunk(chunk)
+            response_text, reasoning_content, finish_reason = decode_chunk(chunk)
             # The first packet of the returned stream is empty; keep waiting
-            if response_text == "" and finish_reason != "False":
+            if response_text == "" and (reasoning == False or reasoning_content == "") and finish_reason != "False":
                 continue
             if response_text == "API_ERROR" and (
                 finish_reason != "False" or finish_reason != "stop"
@@ -227,6 +231,8 @@ def get_predict_function(
print(f"[response] {result}")
break
result += response_text
if reasoning:
resoning_buffer += reasoning_content
if observe_window is not None:
# 观测窗,把已经获取的数据显示出去
if len(observe_window) >= 1:
@@ -241,6 +247,8 @@ def get_predict_function(
                 error_msg = chunk_decoded
                 logger.error(error_msg)
                 raise RuntimeError("JSON parsing failed unexpectedly")
+        if reasoning:
+            return '\n'.join(map(lambda x: '> ' + x, resoning_buffer.split('\n'))) + '\n\n' + result
         return result
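
The quoting transform in the new return path is easiest to see on a concrete value (illustrative strings only):

```python
# Each line of the accumulated chain of thought gets a Markdown "> " prefix,
# so the reasoning renders as a blockquote above the final answer.
resoning_buffer = "Check the units.\nConvert km to m."
result = "The answer is 4000 m."

output = '\n'.join(map(lambda x: '> ' + x, resoning_buffer.split('\n'))) + '\n\n' + result
print(output)
# > Check the units.
# > Convert km to m.
#
# The answer is 4000 m.
```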
def predict(
@@ -298,32 +306,25 @@ def get_predict_function(
             system_prompt=system_prompt,
             temperature=llm_kwargs["temperature"],
         )
+        from .bridge_all import model_info
+        reasoning = model_info[llm_kwargs['llm_model']].get('enable_reasoning', False)
         history.append(inputs)
         history.append("")
         retry = 0
         while True:
             try:
-                from .bridge_all import model_info
                 endpoint = model_info[llm_kwargs["llm_model"]]["endpoint"]
-                if not disable_proxy:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        proxies=proxies,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
-                else:
-                    response = requests.post(
-                        endpoint,
-                        headers=headers,
-                        json=playload,
-                        stream=True,
-                        timeout=TIMEOUT_SECONDS,
-                    )
+                response = requests.post(
+                    endpoint,
+                    headers=headers,
+                    proxies=None if disable_proxy else proxies,
+                    json=playload,
+                    stream=True,
+                    timeout=TIMEOUT_SECONDS,
+                )
                 break
             except:
                 retry += 1
@@ -338,6 +339,8 @@ def get_predict_function(
                     raise TimeoutError
         gpt_replying_buffer = ""
+        if reasoning:
+            gpt_reasoning_buffer = ""
         stream_response = response.iter_lines()
         while True:
@@ -347,9 +350,9 @@ def get_predict_function(
                 break
             except requests.exceptions.ConnectionError:
                 chunk = next(stream_response)  # Failed; retry once? If it fails again, there is nothing more we can do.
-            response_text, finish_reason = decode_chunk(chunk)
+            response_text, reasoning_content, finish_reason = decode_chunk(chunk)
             # The first packet of the returned stream is empty; keep waiting
-            if response_text == "" and finish_reason != "False":
+            if response_text == "" and (reasoning == False or reasoning_content == "") and finish_reason != "False":
                 status_text = f"finish_reason: {finish_reason}"
                 yield from update_ui(
                     chatbot=chatbot, history=history, msg=status_text
@@ -379,9 +382,14 @@ def get_predict_function(
                 logger.info(f"[response] {gpt_replying_buffer}")
                 break
             status_text = f"finish_reason: {finish_reason}"
-            gpt_replying_buffer += response_text
-            # If an exception is raised here, it is usually because the text is too long; see the output of get_full_error for details
-            history[-1] = gpt_replying_buffer
+            if reasoning:
+                gpt_replying_buffer += response_text
+                gpt_reasoning_buffer += reasoning_content
+                history[-1] = '\n'.join(map(lambda x: '> ' + x, gpt_reasoning_buffer.split('\n'))) + '\n\n' + gpt_replying_buffer
+            else:
+                gpt_replying_buffer += response_text
+                # If an exception is raised here, it is usually because the text is too long; see the output of get_full_error for details
+                history[-1] = gpt_replying_buffer
             chatbot[-1] = (history[-2], history[-1])
             yield from update_ui(
                 chatbot=chatbot, history=history, msg=status_text
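
For orientation, the invariant this loop maintains can be sketched with plain data (hypothetical stand-ins; `update_ui` and the real token stream are omitted):

```python
# history keeps [..., user_input, partial_reply]; chatbot mirrors the last
# pair on every streamed token, which is what update_ui then renders.
history = ["What is 2 + 2?", ""]
chatbot = [("What is 2 + 2?", "")]

for token in ["The answer ", "is 4."]:
    history[-1] += token
    chatbot[-1] = (history[-2], history[-1])

assert chatbot[-1] == ("What is 2 + 2?", "The answer is 4.")
```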