Fix internlm input device bug

qingxu fu
2023-11-11 23:22:50 +08:00
parent f75e39dc27
commit 2d91e438d6
2 changed files with 17 additions and 19 deletions


@@ -94,8 +94,9 @@ class GetInternlmHandle(LocalLLMHandle):
 inputs = tokenizer([prompt], padding=True, return_tensors="pt")
 input_length = len(inputs["input_ids"][0])
+device = get_conf('LOCAL_MODEL_DEVICE')
 for k, v in inputs.items():
-    inputs[k] = v.cuda()
+    inputs[k] = v.to(device)
 input_ids = inputs["input_ids"]
 batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
 if generation_config is None:
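
For context, a minimal standalone sketch of the pattern this patch applies: read the target device from configuration and move every tensor returned by the tokenizer onto it, instead of calling .cuda() unconditionally (which only works when the configured device is CUDA). Only get_conf('LOCAL_MODEL_DEVICE') and the .to(device) loop come from the diff above; the get_conf stub, the model name, and the surrounding setup are illustrative assumptions, not the project's actual code.

import torch
from transformers import AutoTokenizer

def get_conf(key):
    # Hypothetical stand-in for the project's config reader.
    return {"LOCAL_MODEL_DEVICE": "cuda" if torch.cuda.is_available() else "cpu"}[key]

# Illustrative tokenizer load; in the repo this happens inside GetInternlmHandle.
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)

prompt = "Hello"
inputs = tokenizer([prompt], padding=True, return_tensors="pt")
input_length = len(inputs["input_ids"][0])

device = get_conf('LOCAL_MODEL_DEVICE')
# Move every tensor (input_ids, attention_mask, ...) onto the configured
# device rather than hard-coding .cuda().
for k, v in inputs.items():
    inputs[k] = v.to(device)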