修改变量命名,整理配置清单

This commit is contained in:
qingxu fu
2023-08-30 16:00:27 +08:00
parent a208782049
commit 89de49f31e
4 changed files with 28 additions and 14 deletions

View File

@@ -63,9 +63,9 @@ class GetGLMFTHandle(Process):
# if not os.path.exists(conf): raise RuntimeError('找不到微调模型信息')
# with open(conf, 'r', encoding='utf8') as f:
# model_args = json.loads(f.read())
ChatGLM_PTUNING_CHECKPOINT, = get_conf('ChatGLM_PTUNING_CHECKPOINT')
assert os.path.exists(ChatGLM_PTUNING_CHECKPOINT), "找不到微调模型检查点"
conf = os.path.join(ChatGLM_PTUNING_CHECKPOINT, "config.json")
CHATGLM_PTUNING_CHECKPOINT, = get_conf('CHATGLM_PTUNING_CHECKPOINT')
assert os.path.exists(CHATGLM_PTUNING_CHECKPOINT), "找不到微调模型检查点"
conf = os.path.join(CHATGLM_PTUNING_CHECKPOINT, "config.json")
with open(conf, 'r', encoding='utf8') as f:
model_args = json.loads(f.read())
if 'model_name_or_path' not in model_args:
@@ -78,9 +78,9 @@ class GetGLMFTHandle(Process):
config.pre_seq_len = model_args['pre_seq_len']
config.prefix_projection = model_args['prefix_projection']
print(f"Loading prefix_encoder weight from {ChatGLM_PTUNING_CHECKPOINT}")
print(f"Loading prefix_encoder weight from {CHATGLM_PTUNING_CHECKPOINT}")
model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(ChatGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
prefix_state_dict = torch.load(os.path.join(CHATGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):