Merge branch 'rag' into frontier
This commit is contained in:
@@ -407,22 +407,46 @@ model_info = {
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
# Gemini
|
||||
# Note: now gemini-pro is an alias of gemini-1.0-pro.
|
||||
# Warning: gemini-pro-vision has been deprecated.
|
||||
# Support for gemini-pro-vision has been removed.
|
||||
"gemini-pro": {
|
||||
"fn_with_ui": genai_ui,
|
||||
"fn_without_ui": genai_noui,
|
||||
"endpoint": gemini_endpoint,
|
||||
"has_multimodal_capacity": False,
|
||||
"max_token": 1024 * 32,
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
"gemini-pro-vision": {
|
||||
"gemini-1.0-pro": {
|
||||
"fn_with_ui": genai_ui,
|
||||
"fn_without_ui": genai_noui,
|
||||
"endpoint": gemini_endpoint,
|
||||
"has_multimodal_capacity": False,
|
||||
"max_token": 1024 * 32,
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
"gemini-1.5-pro": {
|
||||
"fn_with_ui": genai_ui,
|
||||
"fn_without_ui": genai_noui,
|
||||
"endpoint": gemini_endpoint,
|
||||
"has_multimodal_capacity": True,
|
||||
"max_token": 1024 * 204800,
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
"gemini-1.5-flash": {
|
||||
"fn_with_ui": genai_ui,
|
||||
"fn_without_ui": genai_noui,
|
||||
"endpoint": gemini_endpoint,
|
||||
"has_multimodal_capacity": True,
|
||||
"max_token": 1024 * 204800,
|
||||
"tokenizer": tokenizer_gpt35,
|
||||
"token_cnt": get_token_num_gpt35,
|
||||
},
|
||||
|
||||
# cohere
|
||||
"cohere-command-r-plus": {
|
||||
@@ -857,7 +881,7 @@ if "sparkv2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
|
||||
})
|
||||
except:
|
||||
print(trimmed_format_exc())
|
||||
if "sparkv3" in AVAIL_LLM_MODELS or "sparkv3.5" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
|
||||
if any(x in AVAIL_LLM_MODELS for x in ("sparkv3", "sparkv3.5", "sparkv4")): # 讯飞星火认知大模型
|
||||
try:
|
||||
from .bridge_spark import predict_no_ui_long_connection as spark_noui
|
||||
from .bridge_spark import predict as spark_ui
|
||||
|
||||
@@ -8,15 +8,15 @@ import os
|
||||
import time
|
||||
from request_llms.com_google import GoogleChatInit
|
||||
from toolbox import ChatBotWithCookies
|
||||
from toolbox import get_conf, update_ui, update_ui_lastest_msg, have_any_recent_upload_image_files, trimmed_format_exc, log_chat
|
||||
from toolbox import get_conf, update_ui, update_ui_lastest_msg, have_any_recent_upload_image_files, trimmed_format_exc, log_chat, encode_image
|
||||
|
||||
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
|
||||
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
|
||||
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
|
||||
|
||||
|
||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
|
||||
console_slience=False):
|
||||
def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[], sys_prompt:str="", observe_window:list=[],
|
||||
console_slience:bool=False):
|
||||
# 检查API_KEY
|
||||
if get_conf("GEMINI_API_KEY") == "":
|
||||
raise ValueError(f"请配置 GEMINI_API_KEY。")
|
||||
@@ -44,9 +44,20 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
||||
raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
|
||||
return gpt_replying_buffer
|
||||
|
||||
def make_media_input(inputs, image_paths):
    """Attach uploaded images to a prompt.

    For every path in ``image_paths`` a centered ``<img>`` tag pointing at the
    absolute file path is appended to ``inputs``, and the value returned by
    ``encode_image`` for that absolute path is collected.

    Args:
        inputs: Prompt text to extend.
        image_paths: Iterable of image file paths.

    Returns:
        Tuple ``(inputs, image_base64_array)``: the extended prompt and the
        list of ``encode_image`` results, in the order of ``image_paths``.
    """
    encoded_images = []
    for absolute_path in map(os.path.abspath, image_paths):
        img_tag = f'<br/><br/><div align="center"><img src="file={absolute_path}"></div>'
        inputs = inputs + img_tag
        encoded_images.append(encode_image(absolute_path))
    return inputs, encoded_images
|
||||
|
||||
def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWithCookies,
|
||||
history:list=[], system_prompt:str='', stream:bool=True, additional_fn:str=None):
|
||||
|
||||
from .bridge_all import model_info
|
||||
|
||||
# 检查API_KEY
|
||||
if get_conf("GEMINI_API_KEY") == "":
|
||||
yield from update_ui_lastest_msg(f"请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
|
||||
@@ -57,18 +68,17 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
from core_functional import handle_core_functionality
|
||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
||||
|
||||
if "vision" in llm_kwargs["llm_model"]:
|
||||
have_recent_file, image_paths = have_any_recent_upload_image_files(chatbot)
|
||||
if not have_recent_file:
|
||||
chatbot.append((inputs, "没有检测到任何近期上传的图像文件,请上传jpg格式的图片,此外,请注意拓展名需要小写"))
|
||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待图片") # 刷新界面
|
||||
return
|
||||
def make_media_input(inputs, image_paths):
|
||||
for image_path in image_paths:
|
||||
inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
|
||||
return inputs
|
||||
if have_recent_file:
|
||||
inputs = make_media_input(inputs, image_paths)
|
||||
# multimodal capacity
|
||||
# inspired by codes in bridge_chatgpt
|
||||
has_multimodal_capacity = model_info[llm_kwargs['llm_model']].get('has_multimodal_capacity', False)
|
||||
if has_multimodal_capacity:
|
||||
has_recent_image_upload, image_paths = have_any_recent_upload_image_files(chatbot, pop=True)
|
||||
else:
|
||||
has_recent_image_upload, image_paths = False, []
|
||||
if has_recent_image_upload:
|
||||
inputs, image_base64_array = make_media_input(inputs, image_paths)
|
||||
else:
|
||||
inputs, image_base64_array = inputs, []
|
||||
|
||||
chatbot.append((inputs, ""))
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
@@ -76,7 +86,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
retry = 0
|
||||
while True:
|
||||
try:
|
||||
stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
|
||||
stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt, image_base64_array, has_multimodal_capacity)
|
||||
break
|
||||
except Exception as e:
|
||||
retry += 1
|
||||
@@ -112,7 +122,6 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
||||
yield from update_ui(chatbot=chatbot, history=history)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
llm_kwargs = {'llm_model': 'gemini-pro'}
|
||||
|
||||
@@ -7,7 +7,7 @@ import os
|
||||
import re
|
||||
import requests
|
||||
from typing import List, Dict, Tuple
|
||||
from toolbox import get_conf, encode_image, get_pictures_list, to_markdown_tabs
|
||||
from toolbox import get_conf, update_ui, encode_image, get_pictures_list, to_markdown_tabs
|
||||
|
||||
proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")
|
||||
|
||||
@@ -112,6 +112,14 @@ def html_local_img(__file, layout="left", max_width=None, max_height=None, md=Tr
|
||||
return a
|
||||
|
||||
|
||||
def reverse_base64_from_input(inputs):
    """Collect the ``base64="..."`` attribute values of inline image tags in a message.

    Only tags of the exact form
    ``<br/><br/><div align="center"><img ... base64="..."></div>`` are matched.

    Args:
        inputs: Message text to scan.

    Returns:
        List of the captured attribute values, in order of appearance
        (empty when nothing matches).
    """
    return re.findall(
        r'<br/><br/><div align="center"><img[^<>]+base64="([^"]+)"></div>',
        inputs,
    )
|
||||
|
||||
def contain_base64(inputs):
    """Report whether a message embeds at least one inline image tag with a ``base64`` attribute.

    Args:
        inputs: Message text to scan.

    Returns:
        ``True`` when ``reverse_base64_from_input`` finds at least one match,
        ``False`` otherwise.
    """
    return bool(reverse_base64_from_input(inputs))
|
||||
|
||||
class GoogleChatInit:
|
||||
def __init__(self, llm_kwargs):
|
||||
@@ -119,9 +127,9 @@ class GoogleChatInit:
|
||||
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
|
||||
self.url_gemini = endpoint + "/%m:streamGenerateContent?key=%k"
|
||||
|
||||
def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
|
||||
def generate_chat(self, inputs, llm_kwargs, history, system_prompt, image_base64_array:list=[], has_multimodal_capacity:bool=False):
|
||||
headers, payload = self.generate_message_payload(
|
||||
inputs, llm_kwargs, history, system_prompt
|
||||
inputs, llm_kwargs, history, system_prompt, image_base64_array, has_multimodal_capacity
|
||||
)
|
||||
response = requests.post(
|
||||
url=self.url_gemini,
|
||||
@@ -133,13 +141,16 @@ class GoogleChatInit:
|
||||
)
|
||||
return response.iter_lines()
|
||||
|
||||
def __conversation_user(self, user_input, llm_kwargs):
|
||||
def __conversation_user(self, user_input, llm_kwargs, enable_multimodal_capacity):
|
||||
what_i_have_asked = {"role": "user", "parts": []}
|
||||
if "vision" not in self.url_gemini:
|
||||
from .bridge_all import model_info
|
||||
|
||||
if enable_multimodal_capacity:
|
||||
input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
|
||||
else:
|
||||
input_ = user_input
|
||||
encode_img = []
|
||||
else:
|
||||
input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
|
||||
|
||||
what_i_have_asked["parts"].append({"text": input_})
|
||||
if encode_img:
|
||||
for data in encode_img:
|
||||
@@ -153,12 +164,12 @@ class GoogleChatInit:
|
||||
)
|
||||
return what_i_have_asked
|
||||
|
||||
def __conversation_history(self, history, llm_kwargs):
|
||||
def __conversation_history(self, history, llm_kwargs, enable_multimodal_capacity):
|
||||
messages = []
|
||||
conversation_cnt = len(history) // 2
|
||||
if conversation_cnt:
|
||||
for index in range(0, 2 * conversation_cnt, 2):
|
||||
what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
|
||||
what_i_have_asked = self.__conversation_user(history[index], llm_kwargs, enable_multimodal_capacity)
|
||||
what_gpt_answer = {
|
||||
"role": "model",
|
||||
"parts": [{"text": history[index + 1]}],
|
||||
@@ -168,7 +179,7 @@ class GoogleChatInit:
|
||||
return messages
|
||||
|
||||
def generate_message_payload(
|
||||
self, inputs, llm_kwargs, history, system_prompt
|
||||
self, inputs, llm_kwargs, history, system_prompt, image_base64_array:list=[], has_multimodal_capacity:bool=False
|
||||
) -> Tuple[Dict, Dict]:
|
||||
messages = [
|
||||
# {"role": "system", "parts": [{"text": system_prompt}]}, # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。
|
||||
@@ -179,21 +190,29 @@ class GoogleChatInit:
|
||||
"%m", llm_kwargs["llm_model"]
|
||||
).replace("%k", get_conf("GEMINI_API_KEY"))
|
||||
header = {"Content-Type": "application/json"}
|
||||
if "vision" not in self.url_gemini: # 不是vision 才处理history
|
||||
|
||||
if has_multimodal_capacity:
|
||||
enable_multimodal_capacity = (len(image_base64_array) > 0) or any([contain_base64(h) for h in history])
|
||||
else:
|
||||
enable_multimodal_capacity = False
|
||||
|
||||
if not enable_multimodal_capacity:
|
||||
messages.extend(
|
||||
self.__conversation_history(history, llm_kwargs)
|
||||
self.__conversation_history(history, llm_kwargs, enable_multimodal_capacity)
|
||||
) # 处理 history
|
||||
messages.append(self.__conversation_user(inputs, llm_kwargs)) # 处理用户对话
|
||||
|
||||
messages.append(self.__conversation_user(inputs, llm_kwargs, enable_multimodal_capacity)) # 处理用户对话
|
||||
payload = {
|
||||
"contents": messages,
|
||||
"generationConfig": {
|
||||
# "maxOutputTokens": 800,
|
||||
# "maxOutputTokens": llm_kwargs.get("max_token", 1024),
|
||||
"stopSequences": str(llm_kwargs.get("stop", "")).split(" "),
|
||||
"temperature": llm_kwargs.get("temperature", 1),
|
||||
"topP": llm_kwargs.get("top_p", 0.8),
|
||||
"topK": 10,
|
||||
},
|
||||
}
|
||||
|
||||
return header, payload
|
||||
|
||||
|
||||
|
||||
40
request_llms/embed_models/bridge_all_embed.py
Normal file
40
request_llms/embed_models/bridge_all_embed.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import tiktoken, copy, re
|
||||
from functools import lru_cache
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from toolbox import get_conf, trimmed_format_exc, apply_gpt_academic_string_mask, read_one_api_model_name
|
||||
|
||||
# Endpoint redirection
API_URL_REDIRECT, AZURE_ENDPOINT, AZURE_ENGINE = get_conf("API_URL_REDIRECT", "AZURE_ENDPOINT", "AZURE_ENGINE")

# Default OpenAI chat endpoint; swapped for the configured redirect target when one exists.
openai_endpoint = "https://api.openai.com/v1/chat/completions"
if openai_endpoint in API_URL_REDIRECT:
    openai_endpoint = API_URL_REDIRECT[openai_endpoint]

# Azure deployment URL; AZURE_ENDPOINT is first normalized to end with a slash.
if not AZURE_ENDPOINT.endswith('/'):
    AZURE_ENDPOINT = AZURE_ENDPOINT + '/'
azure_endpoint = AZURE_ENDPOINT + f'openai/deployments/{AZURE_ENGINE}/chat/completions?api-version=2023-05-15'

# The embeddings endpoint is derived from the (possibly redirected) chat endpoint.
openai_embed_endpoint = openai_endpoint.replace("chat/completions", "embeddings")
|
||||
|
||||
from .openai_embed import OpenAiEmbeddingModel
|
||||
|
||||
# Registry of supported embedding models, keyed by model name. Every entry
# uses the OpenAI embedding class and the shared embeddings endpoint; only
# the output dimension differs.
embed_model_info = {
    model_name: {
        "embed_class": OpenAiEmbeddingModel,
        "embed_endpoint": openai_embed_endpoint,
        "embed_dimension": dimension,
    }
    for model_name, dimension in (
        # text-embedding-3-small	Increased performance over 2nd generation ada embedding model | 1,536
        ("text-embedding-3-small", 1536),
        # text-embedding-3-large	Most capable embedding model for both english and non-english tasks | 3,072
        ("text-embedding-3-large", 3072),
        # text-embedding-ada-002	Most capable 2nd generation embedding model, replacing 16 first generation models | 1,536
        ("text-embedding-ada-002", 1536),
    )
}
|
||||
79
request_llms/embed_models/openai_embed.py
Normal file
79
request_llms/embed_models/openai_embed.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from openai import OpenAI
|
||||
from toolbox import get_conf
|
||||
from toolbox import CatchException, update_ui, get_conf, select_api_key, get_log_folder, ProxyNetworkActivate
|
||||
from shared_utils.key_pattern_manager import select_api_key_for_embed_models
|
||||
from typing import List, Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
def mean_agg(embeddings):
    """Average a collection of embeddings along the first axis.

    Args:
        embeddings: Sequence of equally sized embedding vectors.

    Returns:
        The mean over axis 0, converted to plain Python values via ``tolist()``.
    """
    stacked = np.array(embeddings)
    return np.mean(stacked, axis=0).tolist()
|
||||
|
||||
class EmbeddingModel:
    """Shared helpers for embedding backends.

    The methods here delegate to ``get_query_embedding`` and
    ``compute_embedding``, which this class does not define itself; they are
    looked up on ``self`` at call time.
    """

    def get_agg_embedding_from_queries(
        self,
        queries: List[str],
        agg_fn = None,
    ):
        """Embed each query and combine the results into a single value.

        Args:
            queries: Query strings, embedded one by one via ``get_query_embedding``.
            agg_fn: Callable applied to the list of embeddings; ``mean_agg`` is
                used when this is falsy.

        Returns:
            Whatever the aggregation callable returns.
        """
        vectors = []
        for query in queries:
            vectors.append(self.get_query_embedding(query))
        aggregate = agg_fn if agg_fn else mean_agg
        return aggregate(vectors)

    def get_text_embedding_batch(
        self,
        texts: List[str],
        show_progress: bool = False,
    ):
        """Embed a list of texts with one ``compute_embedding`` call in batch mode.

        Args:
            texts: Texts to embed.
            show_progress: Accepted but not used by this method.

        Returns:
            The result of ``compute_embedding(texts, batch_mode=True)``.
        """
        return self.compute_embedding(texts, batch_mode=True)
|
||||
|
||||
|
||||
class OpenAiEmbeddingModel(EmbeddingModel):
    """Embedding backend that sends requests through the ``openai`` client.

    The endpoint and vector dimension for each model are looked up in
    ``embed_model_info`` (module ``bridge_all_embed``).
    """

    def __init__(self, llm_kwargs:dict=None):
        # Default request settings, used whenever a call does not pass its own llm_kwargs.
        self.llm_kwargs = llm_kwargs

    def get_query_embedding(self, query: str):
        """Return the embedding of a single query string."""
        return self.compute_embedding(query)

    def compute_embedding(self, text="这是要计算嵌入的文本", llm_kwargs:dict=None, batch_mode=False):
        """Request embeddings for one text or a batch of texts.

        Args:
            text: A string, or a list of strings when ``batch_mode`` is true.
            llm_kwargs: Request settings; must provide ``'api_key'`` and
                ``'embed_model'``. Falls back to the settings given to the
                constructor when omitted.
            batch_mode: When true, ``text`` is a list and a list of embeddings
                is returned; otherwise a single embedding is returned.

        Returns:
            One embedding, or a list of embeddings in batch mode.

        Raises:
            RuntimeError: No ``llm_kwargs`` were given here or at construction.
            TypeError: ``text`` does not match ``batch_mode``.
        """
        # Imported lazily: bridge_all_embed imports this module while it loads.
        from .bridge_all_embed import embed_model_info

        # load kwargs
        if llm_kwargs is None:
            llm_kwargs = self.llm_kwargs
        if llm_kwargs is None:
            raise RuntimeError("llm_kwargs is not provided!")

        # Validate with real exceptions (asserts vanish under ``python -O``),
        # and do it before any key selection or network setup.
        if batch_mode:
            if not isinstance(text, list):
                raise TypeError("text must be a list of strings when batch_mode=True")
            request_input = text
        else:
            if not isinstance(text, str):
                raise TypeError("text must be a string when batch_mode=False")
            request_input = [text]

        # setup api and req url
        embed_model = llm_kwargs['embed_model']
        api_key = select_api_key_for_embed_models(llm_kwargs['api_key'], embed_model)
        endpoint = embed_model_info[embed_model]['embed_endpoint']
        # Drop only the last 'embeddings' segment: a blanket str.replace would
        # also mangle a redirected host or path that contains the same word.
        head, found, tail = endpoint.rpartition('embeddings')
        base_url = head + tail if found else endpoint

        # send and compute
        with ProxyNetworkActivate("Connect_OpenAI_Embedding"):
            self.oai_client = OpenAI(api_key=api_key, base_url=base_url)
            res = self.oai_client.embeddings.create(input=request_input, model=embed_model)

        # parse result
        if batch_mode:
            return [d.embedding for d in res.data]
        return res.data[0].embedding

    def embedding_dimension(self, llm_kwargs=None):
        """Return the vector dimension registered for the configured embedding model.

        Args:
            llm_kwargs: Request settings providing ``'embed_model'``. Falls back
                to the settings given to the constructor when omitted.

        Raises:
            RuntimeError: No ``llm_kwargs`` were given here or at construction.
        """
        from .bridge_all_embed import embed_model_info

        if llm_kwargs is None:
            llm_kwargs = self.llm_kwargs
        if llm_kwargs is None:
            raise RuntimeError("llm_kwargs is not provided!")
        return embed_model_info[llm_kwargs['embed_model']]['embed_dimension']
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
Reference in New Issue
Block a user