tiktoken做lazyload处理

This commit is contained in:
Your Name
2023-04-19 14:27:34 +08:00
parent 28aa6d1dc0
commit b0409b929b
12 changed files with 83 additions and 35 deletions

View File

@@ -9,7 +9,7 @@
2. predict_no_ui_long_connection在实验过程中发现调用predict_no_ui处理长文档时和openai的连接容易断掉这个函数用stream的方式解决这个问题同样支持多线程
"""
import tiktoken
from functools import wraps, lru_cache
from concurrent.futures import ThreadPoolExecutor
from .bridge_chatgpt import predict_no_ui_long_connection as chatgpt_noui
@@ -18,13 +18,31 @@ from .bridge_chatgpt import predict as chatgpt_ui
from .bridge_chatglm import predict_no_ui_long_connection as chatglm_noui
from .bridge_chatglm import predict as chatglm_ui
from .bridge_tgui import predict_no_ui_long_connection as tgui_noui
from .bridge_tgui import predict as tgui_ui
# from .bridge_tgui import predict_no_ui_long_connection as tgui_noui
# from .bridge_tgui import predict as tgui_ui
colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']
get_token_num_gpt35 = lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=()))
get_token_num_gpt4 = lambda txt: len(tiktoken.encoding_for_model("gpt-4").encode(txt, disallowed_special=()))
class LazyloadTiktoken(object):
def __init__(self, model):
self.model = model
@staticmethod
@lru_cache(maxsize=128)
def get_encoder(model):
print('正在加载tokenizer如果是第一次运行可能需要一点时间下载参数')
tmp = tiktoken.encoding_for_model(model)
print('加载tokenizer完毕')
return tmp
def encode(self, *args, **kwargs):
encoder = self.get_encoder(self.model)
return encoder.encode(*args, **kwargs)
tokenizer_gpt35 = LazyloadTiktoken("gpt-3.5-turbo")
tokenizer_gpt4 = LazyloadTiktoken("gpt-4")
get_token_num_gpt35 = lambda txt: len(tokenizer_gpt35.encode(txt, disallowed_special=()))
get_token_num_gpt4 = lambda txt: len(tokenizer_gpt4.encode(txt, disallowed_special=()))
model_info = {
# openai
@@ -33,7 +51,7 @@ model_info = {
"fn_without_ui": chatgpt_noui,
"endpoint": "https://api.openai.com/v1/chat/completions",
"max_token": 4096,
"tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
@@ -42,7 +60,7 @@ model_info = {
"fn_without_ui": chatgpt_noui,
"endpoint": "https://api.openai.com/v1/chat/completions",
"max_token": 8192,
"tokenizer": tiktoken.encoding_for_model("gpt-4"),
"tokenizer": tokenizer_gpt4,
"token_cnt": get_token_num_gpt4,
},
@@ -52,7 +70,7 @@ model_info = {
"fn_without_ui": chatgpt_noui,
"endpoint": "https://openai.api2d.net/v1/chat/completions",
"max_token": 4096,
"tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},
@@ -61,7 +79,7 @@ model_info = {
"fn_without_ui": chatgpt_noui,
"endpoint": "https://openai.api2d.net/v1/chat/completions",
"max_token": 8192,
"tokenizer": tiktoken.encoding_for_model("gpt-4"),
"tokenizer": tokenizer_gpt4,
"token_cnt": get_token_num_gpt4,
},
@@ -71,7 +89,7 @@ model_info = {
"fn_without_ui": chatglm_noui,
"endpoint": None,
"max_token": 1024,
"tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
},

View File

@@ -5,6 +5,8 @@ import importlib
from toolbox import update_ui, get_conf
from multiprocessing import Process, Pipe
load_message = "ChatGLM尚未加载加载需要一段时间。注意取决于`config.py`的配置ChatGLM消耗大量的内存CPU或显存GPU也许会导致低配计算机卡死 ……"
#################################################################################
class GetGLMHandle(Process):
def __init__(self):
@@ -12,13 +14,26 @@ class GetGLMHandle(Process):
self.parent, self.child = Pipe()
self.chatglm_model = None
self.chatglm_tokenizer = None
self.info = ""
self.success = True
self.check_dependency()
self.start()
print('初始化')
def check_dependency(self):
try:
import sentencepiece
self.info = "依赖检测通过"
self.success = True
except:
self.info = "缺少ChatGLM的依赖如果要使用ChatGLM除了基础的pip依赖以外您还需要运行`pip install -r request_llm/requirements_chatglm.txt`安装ChatGLM的依赖。"
self.success = False
def ready(self):
return self.chatglm_model is not None
def run(self):
# 第一次运行,加载参数
retry = 0
while True:
try:
if self.chatglm_model is None:
@@ -33,7 +48,12 @@ class GetGLMHandle(Process):
else:
break
except:
pass
retry += 1
if retry > 3:
self.child.send('[Local Message] Call ChatGLM fail 不能正常加载ChatGLM的参数。')
raise RuntimeError("不能正常加载ChatGLM的参数")
# 进入任务等待状态
while True:
kwargs = self.child.recv()
try:
@@ -64,7 +84,11 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
global glm_handle
if glm_handle is None:
glm_handle = GetGLMHandle()
observe_window[0] = "ChatGLM尚未加载加载需要一段时间。注意取决于`config.py`的配置ChatGLM消耗大量的内存CPU或显存GPU也许会导致低配计算机卡死 ……"
observe_window[0] = load_message + "\n\n" + glm_handle.info
if not glm_handle.success:
error = glm_handle.info
glm_handle = None
raise RuntimeError(error)
# chatglm 没有 sys_prompt 接口因此把prompt加入 history
history_feedin = []
@@ -93,8 +117,11 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
global glm_handle
if glm_handle is None:
glm_handle = GetGLMHandle()
chatbot[-1] = (inputs, "ChatGLM尚未加载加载需要一段时间。注意取决于`config.py`的配置ChatGLM消耗大量的内存CPU或显存GPU也许会导致低配计算机卡死 ……")
chatbot[-1] = (inputs, load_message + "\n\n" + glm_handle.info)
yield from update_ui(chatbot=chatbot, history=[])
if not glm_handle.success:
glm_handle = None
return
if additional_fn is not None:
import core_functional