Add GLM INT8

binary-husky
2023-07-24 18:19:57 +08:00
parent dd4ba0ea22
commit e93b6fa3a6
3 changed files with 14 additions and 9 deletions


@@ -37,19 +37,23 @@ class GetGLMHandle(Process):
         # executed in the child process
         # on first run, load the model parameters
         retry = 0
-        pretrained_model_name_or_path = "THUDM/chatglm2-6b"
-        LOCAL_MODEL_QUANT = get_conf('LOCAL_MODEL_QUANT')
-        if LOCAL_MODEL_QUANT and len(LOCAL_MODEL_QUANT) > 0 and LOCAL_MODEL_QUANT[0] == "INT4":
-            pretrained_model_name_or_path = "THUDM/chatglm2-6b-int4"
+        LOCAL_MODEL_QUANT, device = get_conf('LOCAL_MODEL_QUANT', 'LOCAL_MODEL_DEVICE')
+
+        if LOCAL_MODEL_QUANT == "INT4":         # INT4
+            _model_name_ = "THUDM/chatglm2-6b-int4"
+        elif LOCAL_MODEL_QUANT == "INT8":       # INT8
+            _model_name_ = "THUDM/chatglm2-6b-int8"
+        else:
+            _model_name_ = "THUDM/chatglm2-6b"  # FP16
+
         while True:
             try:
                 if self.chatglm_model is None:
-                    self.chatglm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
-                    device, = get_conf('LOCAL_MODEL_DEVICE')
+                    self.chatglm_tokenizer = AutoTokenizer.from_pretrained(_model_name_, trust_remote_code=True)
                     if device=='cpu':
-                        self.chatglm_model = AutoModel.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True).float()
+                        self.chatglm_model = AutoModel.from_pretrained(_model_name_, trust_remote_code=True).float()
                     else:
-                        self.chatglm_model = AutoModel.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True).half().cuda()
+                        self.chatglm_model = AutoModel.from_pretrained(_model_name_, trust_remote_code=True).half().cuda()
                     self.chatglm_model = self.chatglm_model.eval()
                     break
                 else:
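
Taken together, the hunk reads both config keys in one get_conf call and maps LOCAL_MODEL_QUANT onto one of three ChatGLM2 checkpoints. Below is a minimal standalone sketch of that selection-and-load flow; the checkpoint names, config keys, and CPU/GPU branches come from the diff above, while the inline _CONF dict and the get_conf stub are hypothetical stand-ins for the project's real config reader.

# Standalone sketch of the new quantization switch; _CONF and the
# get_conf stub are assumptions made so the sketch runs on its own.
from transformers import AutoModel, AutoTokenizer

_CONF = {                          # hypothetical in-memory config for this sketch
    "LOCAL_MODEL_QUANT": "INT8",   # "FP16" (default), "INT4", or "INT8"
    "LOCAL_MODEL_DEVICE": "cpu",   # "cpu" or "cuda"
}

def get_conf(*keys):
    # stub standing in for the project's config reader
    return tuple(_CONF[k] for k in keys)

LOCAL_MODEL_QUANT, device = get_conf('LOCAL_MODEL_QUANT', 'LOCAL_MODEL_DEVICE')

if LOCAL_MODEL_QUANT == "INT4":
    _model_name_ = "THUDM/chatglm2-6b-int4"
elif LOCAL_MODEL_QUANT == "INT8":
    _model_name_ = "THUDM/chatglm2-6b-int8"
else:
    _model_name_ = "THUDM/chatglm2-6b"  # FP16

tokenizer = AutoTokenizer.from_pretrained(_model_name_, trust_remote_code=True)
if device == 'cpu':
    # CPU path keeps full-precision tensors, as in the diff
    model = AutoModel.from_pretrained(_model_name_, trust_remote_code=True).float()
else:
    # GPU path loads half-precision weights onto CUDA
    model = AutoModel.from_pretrained(_model_name_, trust_remote_code=True).half().cuda()
model = model.eval()

Fetching both keys in a single get_conf call removes the second config lookup from inside the retry loop, and routing all three precisions through one _model_name_ variable keeps the download-and-load path identical regardless of quantization.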