允许加入ChatGLM微调模型

This commit is contained in:
binary-husky
2023-07-10 03:17:09 +08:00
parent 7ce4192c52
commit c010d50716
4 changed files with 287 additions and 1 deletions

View File

@@ -69,3 +69,57 @@ def 微调数据集生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
promote_file_to_downloadzone(txt+'.generated.json', rename_file='generated.json', chatbot=chatbot)
return
def 启动微调(arguments):
"""
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
llm_kwargs gpt模型参数如温度和top_p等一般原样传递下去就行
plugin_kwargs 插件模型的参数
chatbot 聊天显示框的句柄,用于显示给用户
history 聊天历史,前情提要
system_prompt 给gpt的静默提醒
web_port 当前软件运行的端口号
"""
history = [] # 清空历史,以免输入溢出
import subprocess
PRE_SEQ_LEN = 128
LR = 2e-2
NUM_GPUS = 1
JSON_FILE = 't_code.json'
tune_work_path = '/home/hmp/ChatGLM2-6B/ptuning'
command = f"torchrun --standalone --nnodes=1 --nproc-per-node={NUM_GPUS} main.py \
--do_train \
--train_file AdvertiseGen/{JSON_FILE} \
--validation_file AdvertiseGen/{JSON_FILE} \
--preprocessing_num_workers 20 \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path THUDM/chatglm2-6b \
--output_dir output/clothgen-chatglm2-6b-pt-{PRE_SEQ_LEN}-{LR} \
--overwrite_output_dir \
--max_source_length 256 \
--max_target_length 256 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 100 \
--logging_steps 10 \
--save_steps 20 \
--learning_rate {LR} \
--pre_seq_len {PRE_SEQ_LEN} \
--quantization_bit 4"
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=tune_work_path)
try:
stdout, stderr = process.communicate(timeout=3600*5)
except subprocess.TimeoutExpired:
process.kill()
stdout, stderr = process.communicate()
print("Process timed out!")
return False
return