适配星火大模型图片理解 增加上传图片view
This commit is contained in:
@@ -16,28 +16,13 @@ import base64
|
||||
import os
|
||||
import glob
|
||||
|
||||
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc, is_the_upload_folder, update_ui_lastest_msg, get_max_token
|
||||
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc, is_the_upload_folder, update_ui_lastest_msg, get_max_token, encode_image, have_any_recent_upload_image_files
|
||||
proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG, AZURE_CFG_ARRAY = \
|
||||
get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'API_ORG', 'AZURE_CFG_ARRAY')
|
||||
|
||||
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
|
||||
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
|
||||
|
||||
def have_any_recent_upload_image_files(chatbot):
|
||||
_5min = 5 * 60
|
||||
if chatbot is None: return False, None # chatbot is None
|
||||
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
|
||||
if not most_recent_uploaded: return False, None # most_recent_uploaded is None
|
||||
if time.time() - most_recent_uploaded["time"] < _5min:
|
||||
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
|
||||
path = most_recent_uploaded['path']
|
||||
file_manifest = [f for f in glob.glob(f'{path}/**/*.jpg', recursive=True)]
|
||||
file_manifest += [f for f in glob.glob(f'{path}/**/*.jpeg', recursive=True)]
|
||||
file_manifest += [f for f in glob.glob(f'{path}/**/*.png', recursive=True)]
|
||||
if len(file_manifest) == 0: return False, None
|
||||
return True, file_manifest # most_recent_uploaded is new
|
||||
else:
|
||||
return False, None # most_recent_uploaded is too old
|
||||
|
||||
def report_invalid_key(key):
|
||||
if get_conf("BLOCK_INVALID_APIKEY"):
|
||||
@@ -258,10 +243,6 @@ def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg,
|
||||
chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk_decoded)}")
|
||||
return chatbot, history
|
||||
|
||||
# Function to encode the image
|
||||
def encode_image(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
def generate_payload(inputs, llm_kwargs, history, system_prompt, image_paths):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from toolbox import get_conf
|
||||
from toolbox import get_conf, get_pictures_list, encode_image
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
@@ -65,6 +65,7 @@ class SparkRequestInstance():
|
||||
self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"
|
||||
self.gpt_url_v2 = "ws://spark-api.xf-yun.com/v2.1/chat"
|
||||
self.gpt_url_v3 = "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
self.gpt_url_img = "wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image"
|
||||
|
||||
self.time_to_yield_event = threading.Event()
|
||||
self.time_to_exit_event = threading.Event()
|
||||
@@ -92,7 +93,11 @@ class SparkRequestInstance():
|
||||
gpt_url = self.gpt_url_v3
|
||||
else:
|
||||
gpt_url = self.gpt_url
|
||||
|
||||
file_manifest = []
|
||||
if llm_kwargs.get('most_recent_uploaded'):
|
||||
if llm_kwargs['most_recent_uploaded'].get('path'):
|
||||
file_manifest = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
|
||||
gpt_url = self.gpt_url_img
|
||||
wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url)
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = wsParam.create_url()
|
||||
@@ -101,9 +106,8 @@ class SparkRequestInstance():
|
||||
def on_open(ws):
|
||||
import _thread as thread
|
||||
thread.start_new_thread(run, (ws,))
|
||||
|
||||
def run(ws, *args):
|
||||
data = json.dumps(gen_params(ws.appid, *ws.all_args))
|
||||
data = json.dumps(gen_params(ws.appid, *ws.all_args, file_manifest))
|
||||
ws.send(data)
|
||||
|
||||
# 收到websocket消息的处理
|
||||
@@ -142,9 +146,18 @@ class SparkRequestInstance():
|
||||
ws.all_args = (inputs, llm_kwargs, history, system_prompt)
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
|
||||
def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
|
||||
def generate_message_payload(inputs, llm_kwargs, history, system_prompt, file_manifest):
|
||||
conversation_cnt = len(history) // 2
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
messages = []
|
||||
if file_manifest:
|
||||
base64_images = []
|
||||
for image_path in file_manifest:
|
||||
base64_images.append(encode_image(image_path))
|
||||
for img_s in base64_images:
|
||||
if img_s not in str(messages):
|
||||
messages.append({"role": "user", "content": img_s, "content_type": "image"})
|
||||
else:
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
if conversation_cnt:
|
||||
for index in range(0, 2*conversation_cnt, 2):
|
||||
what_i_have_asked = {}
|
||||
@@ -167,7 +180,7 @@ def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
|
||||
return messages
|
||||
|
||||
|
||||
def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
|
||||
def gen_params(appid, inputs, llm_kwargs, history, system_prompt, file_manifest):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
"""
|
||||
@@ -176,6 +189,8 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
|
||||
"sparkv2": "generalv2",
|
||||
"sparkv3": "generalv3",
|
||||
}
|
||||
domains_select = domains[llm_kwargs['llm_model']]
|
||||
if file_manifest: domains_select = 'image'
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": appid,
|
||||
@@ -183,7 +198,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domains[llm_kwargs['llm_model']],
|
||||
"domain": domains_select,
|
||||
"temperature": llm_kwargs["temperature"],
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": 4096,
|
||||
@@ -192,7 +207,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": generate_message_payload(inputs, llm_kwargs, history, system_prompt)
|
||||
"text": generate_message_payload(inputs, llm_kwargs, history, system_prompt, file_manifest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user