From 725f60fba3bc981043140ae6a8f444e32a20a87f Mon Sep 17 00:00:00 2001
From: binary-husky
Date: Tue, 3 Jun 2025 00:51:18 +0800
Subject: [PATCH] add context clip policy

---
 config.py                           |  11 ++
 shared_utils/context_clip_policy.py | 296 ++++++++++++++++++++++++++++
 toolbox.py                          |  70 ++-----
 3 files changed, 319 insertions(+), 58 deletions(-)
 create mode 100644 shared_utils/context_clip_policy.py

diff --git a/config.py b/config.py
index 90e7fe7e..4af8ffd3 100644
--- a/config.py
+++ b/config.py
@@ -354,6 +354,17 @@ DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in ran
 # 在互联网搜索组件中,负责将搜索结果整理成干净的Markdown
 JINA_API_KEY = ""

+
+# Whether to clip the context length automatically (disabled by default)
+AUTO_CONTEXT_CLIP_ENABLE = False
+# Token-length threshold that triggers clipping (contexts longer than this are clipped automatically)
+AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN = 30*1000
+# Unconditionally discard any conversation rounds beyond this count
+AUTO_CONTEXT_MAX_ROUND = 64
+# When clipping, the maximum share of AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN that the x-th most recent message may keep
+AUTO_CONTEXT_MAX_CLIP_RATIO = [0.80, 0.60, 0.45, 0.25, 0.20, 0.18, 0.16, 0.14, 0.12, 0.10, 0.08, 0.07, 0.06, 0.05, 0.04, 0.03, 0.02, 0.01]
+
+
 """
 --------------- 配置关联关系说明 ---------------

diff --git a/shared_utils/context_clip_policy.py b/shared_utils/context_clip_policy.py
new file mode 100644
index 00000000..02b852a9
--- /dev/null
+++ b/shared_utils/context_clip_policy.py
@@ -0,0 +1,296 @@
+import copy
+from shared_utils.config_loader import get_conf
+
+def get_token_num(txt, tokenizer):
+    return len(tokenizer.encode(txt, disallowed_special=()))
+
+def get_model_info():
+    from request_llms.bridge_all import model_info
+    return model_info
+
+def clip_history(inputs, history, tokenizer, max_token_limit):
+    """
+    Reduce the length of the history by clipping.
+    This function searches for the longest entries and clips them, little by little,
+    until the token count of the history falls below the threshold.
+
+    Clipping here is triggered passively (only when a caller asks for it).
+    """
+    import numpy as np
+
+    input_token_num = get_token_num(inputs, tokenizer)
+
+    if max_token_limit < 5000:
+        output_token_expect = 256  # 4k & 2k models
+    elif max_token_limit < 9000:
+        output_token_expect = 512  # 8k models
+    else:
+        output_token_expect = 1024  # 16k & 32k models
+
+    if input_token_num < max_token_limit * 3 / 4:
+        # When the input takes up less than 3/4 of the limit:
+        # 1. reserve room for the input
+        max_token_limit = max_token_limit - input_token_num
+        # 2. reserve room for the output
+        max_token_limit = max_token_limit - output_token_expect
+        # 3. if the remaining budget is too small, clear the history entirely
+        if max_token_limit < output_token_expect:
+            history = []
+            return history
+    else:
+        # when the input takes up more than 3/4 of the limit, clear the history directly
+        history = []
+        return history
+
+    everything = [""]
+    everything.extend(history)
+    n_token = get_token_num("\n".join(everything), tokenizer)
+    everything_token = [get_token_num(e, tokenizer) for e in everything]
+
+    # granularity of each truncation step
+    delta = max(everything_token) // 16
+
+    while n_token > max_token_limit:
+        where = np.argmax(everything_token)
+        encoded = tokenizer.encode(everything[where], disallowed_special=())
+        clipped_encoded = encoded[: len(encoded) - delta]
+        everything[where] = tokenizer.decode(clipped_encoded)[
+            :-1
+        ]  # -1 to drop a possibly broken trailing character
+        everything_token[where] = get_token_num(everything[where], tokenizer)
+        n_token = get_token_num("\n".join(everything), tokenizer)
+
+    history = everything[1:]
+    return history
+
+
+
+def auto_context_clip_each_message(current, history):
+    """
+    Unlike clip_history, which is triggered passively, this function clips the
+    context proactively, entry by entry, before the request is sent.
+    """
+    context = history + [current]
+    trigger_clip_token_len = get_conf('AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN')
+    model_info = get_model_info()
+    tokenizer = model_info['gpt-4']['tokenizer']
+    # keep only the most recent AUTO_CONTEXT_MAX_ROUND messages, regardless of token length,
+    # to avoid spending too much time counting tokens
+    max_round = get_conf('AUTO_CONTEXT_MAX_ROUND')
+    char_len = sum([len(h) for h in context])
+    if char_len < trigger_clip_token_len*2:
+        # no clipping needed
+        history = context[:-1]
+        current = context[-1]
+        return current, history
+    if len(context) > max_round:
+        context = context[-max_round:]
+    # token length of each context entry
+    context_token_num = [get_token_num(h, tokenizer) for h in context]
+    context_token_num_old = copy.copy(context_token_num)
+    total_token_num = total_token_num_old = sum(context_token_num)
+    if total_token_num < trigger_clip_token_len:
+        # no clipping needed
+        history = context[:-1]
+        current = context[-1]
+        return current, history
+    clip_token_len = trigger_clip_token_len * 0.85
+    # the longer an entry and the older it is, the earlier it gets clipped
+    max_clip_ratio: list[float] = get_conf('AUTO_CONTEXT_MAX_CLIP_RATIO')
+    max_clip_ratio = list(reversed(max_clip_ratio))
+    if len(context) > len(max_clip_ratio):
+        # give up the oldest context entries
+        context = context[-len(max_clip_ratio):]
+        context_token_num = context_token_num[-len(max_clip_ratio):]
+    if len(context) < len(max_clip_ratio):
+        # match the lengths of the two arrays
+        max_clip_ratio = max_clip_ratio[-len(context):]
+
+    # compute clipping priority
+    clip_prior_weight = [(token_num/clip_token_len + (len(context) - index)*0.1) for index, token_num in enumerate(context_token_num)]
+    # print('clip_prior_weight', clip_prior_weight)
+    # sort indices by clipping priority, from largest to smallest
+    sorted_index = sorted(range(len(context_token_num)), key=lambda k: clip_prior_weight[k], reverse=True)
+
+    # clip entries in priority order until the target length is reached
+    for index in sorted_index:
+        print('index', index, f'current total {total_token_num}, target {clip_token_len}')
+        if total_token_num < clip_token_len:
+            # no need to clip further
+            break
+        # how many tokens still have to be removed
+        clip_room_left = total_token_num - clip_token_len
+        # maximum number of tokens this entry is allowed to keep
+        allowed_token_num_this_entry = max_clip_ratio[index] * clip_token_len
+        if context_token_num[index] < allowed_token_num_this_entry:
+            print('index', index, '[allowed] before', context_token_num[index], 'allowed', allowed_token_num_this_entry)
+            continue
+
+        token_to_clip = context_token_num[index] - allowed_token_num_this_entry
+        if token_to_clip*0.85 > clip_room_left:
+            print('index', index, '[careful clip] token_to_clip', token_to_clip, 'clip_room_left', clip_room_left)
+            token_to_clip = clip_room_left
+
+        token_percent_to_clip = token_to_clip / context_token_num[index]
+        char_percent_to_clip = token_percent_to_clip
+        text_this_entry = context[index]
+        char_num_to_clip = int(len(text_this_entry) * char_percent_to_clip)
+        if char_num_to_clip < 500:
+            # skip this entry if fewer than 500 characters would be removed
+            print('index', index, 'before', context_token_num[index], 'allowed', allowed_token_num_this_entry)
+            continue
+        char_num_to_clip += 200  # clip a little extra as margin
+        char_to_preserve = len(text_this_entry) - char_num_to_clip
+        _half = int(char_to_preserve / 2)
+        # first half + ... (content clipped because token overflows) ... + second half
+        text_this_entry_clip = text_this_entry[:_half] + \
+            " ... (content clipped because token overflows) ... " \
+            + text_this_entry[-_half:]
+        context[index] = text_this_entry_clip
+        post_clip_token_cnt = get_token_num(text_this_entry_clip, tokenizer)
+        print('index', index, 'before', context_token_num[index], 'allowed', allowed_token_num_this_entry, 'after', post_clip_token_cnt)
+        context_token_num[index] = post_clip_token_cnt
+        total_token_num = sum(context_token_num)
+    context_token_num_final = [get_token_num(h, tokenizer) for h in context]
+    print('context_token_num_old', context_token_num_old)
+    print('context_token_num_final', context_token_num_final)
+    print('token change from', total_token_num_old, 'to', sum(context_token_num_final), 'target', clip_token_len)
+    history = context[:-1]
+    current = context[-1]
+    return current, history
+
+
+
+def auto_context_clip_search_optimal(current, history, promote_latest_long_message=False):
+    """
+    current: the current message
+    history: the list of history messages
+    promote_latest_long_message: whether to give the latest long message extra weight
+                                 so that it is not clipped too aggressively
+
+    Proactively triggered clipping: searches for a near-optimal per-entry budget
+    before clipping.
+    """
+    context = history + [current]
+    trigger_clip_token_len = get_conf('AUTO_CONTEXT_CLIP_TRIGGER_TOKEN_LEN')
+    model_info = get_model_info()
+    tokenizer = model_info['gpt-4']['tokenizer']
+    # keep only the most recent AUTO_CONTEXT_MAX_ROUND messages, regardless of token length,
+    # to avoid spending too much time counting tokens
+    max_round = get_conf('AUTO_CONTEXT_MAX_ROUND')
+    char_len = sum([len(h) for h in context])
+    if char_len < trigger_clip_token_len:
+        # no clipping needed
+        history = context[:-1]
+        current = context[-1]
+        return current, history
+    if len(context) > max_round:
+        context = context[-max_round:]
+    # token length of each context entry
+    context_token_num = [get_token_num(h, tokenizer) for h in context]
+    context_token_num_old = copy.copy(context_token_num)
+    total_token_num = total_token_num_old = sum(context_token_num)
+    if total_token_num < trigger_clip_token_len:
+        # no clipping needed
+        history = context[:-1]
+        current = context[-1]
+        return current, history
+    clip_token_len = trigger_clip_token_len * 0.90
+    max_clip_ratio: list[float] = get_conf('AUTO_CONTEXT_MAX_CLIP_RATIO')
+    max_clip_ratio = list(reversed(max_clip_ratio))
+    if len(context) > len(max_clip_ratio):
+        # give up the oldest context entries
+        context = context[-len(max_clip_ratio):]
+        context_token_num = context_token_num[-len(max_clip_ratio):]
+    if len(context) < len(max_clip_ratio):
+        # match the lengths of the two arrays
+        max_clip_ratio = max_clip_ratio[-len(context):]
+
+    _scale = _scale_init = 1.25
+    token_percent_arr = [token_num/clip_token_len for token_num in context_token_num]
+
+    # promote the last long message, avoid clipping it too much
+    if promote_latest_long_message:
+        promote_weight_constant = 1.6
+        promote_index = -1
+        threshold = 0.50
+        for index, token_percent in enumerate(token_percent_arr):
+            if token_percent > threshold:
+                promote_index = index
+        if promote_index >= 0:
+            max_clip_ratio[promote_index] = promote_weight_constant
+
+    max_clip_ratio_arr = max_clip_ratio
+    step = 0.05
+    # search, from large to small, for the largest scale such that capping every entry
+    # at (its max ratio * scale) brings the whole context under the target length
+    for i in range(int(_scale_init / step) - 1):
+        _take = 0
+        for max_clip, token_r in zip(max_clip_ratio_arr, token_percent_arr):
+            _take += min(max_clip * _scale, token_r)
+        if _take < 1.0:
+            break
+        _scale -= step
+
+    # print('optimal scale', _scale)
+    # print([_scale * max_clip for max_clip in max_clip_ratio_arr])
+    # print([token_r for token_r in token_percent_arr])
+    # print([min(token_r, _scale * max_clip) for token_r, max_clip in zip(token_percent_arr, max_clip_ratio_arr)])
+    eps = 0.05
+    max_clip_ratio = [_scale * max_clip + eps for max_clip in max_clip_ratio_arr]
+
+    # compute clipping priority
+    # clip_prior_weight_old = [(token_num/clip_token_len + (len(context) - index)*0.1) for index, token_num in enumerate(context_token_num)]
+    clip_prior_weight = [token_r / max_clip for max_clip, token_r in zip(max_clip_ratio_arr, token_percent_arr)]
+
+    # sorted_index_old = sorted(range(len(context_token_num)), key=lambda k: clip_prior_weight_old[k], reverse=True)
+    # print('sorted_index_old', sorted_index_old)
+    sorted_index = sorted(range(len(context_token_num)), key=lambda k: clip_prior_weight[k], reverse=True)
+    # print('sorted_index', sorted_index)
+
+    # clip entries in priority order until the target length is reached
+    for index in sorted_index:
+        # print('index', index, f'current total {total_token_num}, target {clip_token_len}')
+        if total_token_num < clip_token_len:
+            # no need to clip further
+            break
+        # how many tokens still have to be removed
+        clip_room_left = total_token_num - clip_token_len
+        # maximum number of tokens this entry is allowed to keep
+        allowed_token_num_this_entry = max_clip_ratio[index] * clip_token_len
+        if context_token_num[index] < allowed_token_num_this_entry:
+            # print('index', index, '[allowed] before', context_token_num[index], 'allowed', allowed_token_num_this_entry)
+            continue
+
+        token_to_clip = context_token_num[index] - allowed_token_num_this_entry
+        if token_to_clip*0.85 > clip_room_left:
+            # print('index', index, '[careful clip] token_to_clip', token_to_clip, 'clip_room_left', clip_room_left)
+            token_to_clip = clip_room_left
+
+        token_percent_to_clip = token_to_clip / context_token_num[index]
+        char_percent_to_clip = token_percent_to_clip
+        text_this_entry = context[index]
+        char_num_to_clip = int(len(text_this_entry) * char_percent_to_clip)
+        if char_num_to_clip < 500:
+            # skip this entry if fewer than 500 characters would be removed
+            # print('index', index, 'before', context_token_num[index], 'allowed', allowed_token_num_this_entry)
+            continue
+        char_num_to_clip += 200  # clip a little extra as margin
+        char_to_preserve = len(text_this_entry) - char_num_to_clip
+        _half = int(char_to_preserve / 2)
+        # first half + ... (content clipped because token overflows) ... + second half
+        text_this_entry_clip = text_this_entry[:_half] + \
+            " ... (content clipped because token overflows) ... " \
+            + text_this_entry[-_half:]
+        context[index] = text_this_entry_clip
+        post_clip_token_cnt = get_token_num(text_this_entry_clip, tokenizer)
+        # print('index', index, 'before', context_token_num[index], 'allowed', allowed_token_num_this_entry, 'after', post_clip_token_cnt)
+        context_token_num[index] = post_clip_token_cnt
+        total_token_num = sum(context_token_num)
+    context_token_num_final = [get_token_num(h, tokenizer) for h in context]
+    # print('context_token_num_old', context_token_num_old)
+    # print('context_token_num_final', context_token_num_final)
+    # print('token change from', total_token_num_old, 'to', sum(context_token_num_final), 'target', clip_token_len)
+    history = context[:-1]
+    current = context[-1]
+    return current, history
diff --git a/toolbox.py b/toolbox.py
index 3f2a85c5..3f0ca5cf 100644
--- a/toolbox.py
+++ b/toolbox.py
@@ -37,6 +37,9 @@ from shared_utils.handle_upload import html_local_file
 from shared_utils.handle_upload import html_local_img
 from shared_utils.handle_upload import file_manifest_filter_type
 from shared_utils.handle_upload import extract_archive
+from shared_utils.context_clip_policy import clip_history
+from shared_utils.context_clip_policy import auto_context_clip_each_message
+from shared_utils.context_clip_policy import auto_context_clip_search_optimal
 from typing import List
 pj = os.path.join
 default_user_name = "default_user"
@@ -133,6 +136,9 @@ def ArgsGeneralWrapper(f):
         if len(args) == 0:  # 插件通道
             yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, request)
         else:  # 对话通道,或者基础功能通道
+            # chat channel or basic feature channel: clip the context proactively if enabled
+            if get_conf('AUTO_CONTEXT_CLIP_ENABLE'):
+                txt_passon, history = auto_context_clip(txt_passon, history)
             yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args)
     else:
         # 处理少数情况下的特殊插件的锁定状态
@@ -712,66 +718,14 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
     app = gr.mount_gradio_app(app, demo, path=custom_path)
     uvicorn.run(app, host="0.0.0.0", port=port)  # , auth=auth

-
-def clip_history(inputs, history, tokenizer, max_token_limit):
-    """
-    reduce the length of history by clipping.
-    this function search for the longest entries to clip, little by little,
-    until the number of token of history is reduced under threshold.
-    通过裁剪来缩短历史记录的长度。
-    此函数逐渐地搜索最长的条目进行剪辑,
-    直到历史记录的标记数量降低到阈值以下。
-    """
-    import numpy as np
-    from request_llms.bridge_all import model_info
-
-    def get_token_num(txt):
-        return len(tokenizer.encode(txt, disallowed_special=()))
-
-    input_token_num = get_token_num(inputs)
-
-    if max_token_limit < 5000:
-        output_token_expect = 256  # 4k & 2k models
-    elif max_token_limit < 9000:
-        output_token_expect = 512  # 8k models
+def auto_context_clip(current, history, policy='search_optimal'):
+    if policy == 'each_message':
+        return auto_context_clip_each_message(current, history)
+    elif policy == 'search_optimal':
+        return auto_context_clip_search_optimal(current, history)
     else:
-        output_token_expect = 1024  # 16k & 32k models
+        raise RuntimeError(f"Unknown auto context clip policy: {policy}.")

-    if input_token_num < max_token_limit * 3 / 4:
-        # 当输入部分的token占比小于限制的3/4时,裁剪时
-        # 1. 把input的余量留出来
-        max_token_limit = max_token_limit - input_token_num
-        # 2. 把输出用的余量留出来
-        max_token_limit = max_token_limit - output_token_expect
-        # 3. 如果余量太小了,直接清除历史
-        if max_token_limit < output_token_expect:
-            history = []
-            return history
-    else:
-        # 当输入部分的token占比 > 限制的3/4时,直接清除历史
-        history = []
-        return history
-
-    everything = [""]
-    everything.extend(history)
-    n_token = get_token_num("\n".join(everything))
-    everything_token = [get_token_num(e) for e in everything]
-
-    # 截断时的颗粒度
-    delta = max(everything_token) // 16
-
-    while n_token > max_token_limit:
-        where = np.argmax(everything_token)
-        encoded = tokenizer.encode(everything[where], disallowed_special=())
-        clipped_encoded = encoded[: len(encoded) - delta]
-        everything[where] = tokenizer.decode(clipped_encoded)[
-            :-1
-        ]  # -1 to remove the may-be illegal char
-        everything_token[where] = get_token_num(everything[where])
-        n_token = get_token_num("\n".join(everything))
-
-    history = everything[1:]
-    return history
 """
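
Usage sketch (illustrative, not part of the patch): the snippet below exercises the relocated clip_history helper on its own, assuming the patch is applied and shared_utils.context_clip_policy is importable. ToyTokenizer is a hypothetical stand-in for the tiktoken-style tokenizer that the real code fetches via model_info['gpt-4']['tokenizer'].

    from shared_utils.context_clip_policy import clip_history

    class ToyTokenizer:
        # hypothetical stand-in: whitespace "tokens" instead of real BPE tokenization
        def encode(self, txt, disallowed_special=()):
            return txt.split()

        def decode(self, tokens):
            return " ".join(tokens)

    history = ["short question", "a very long answer " * 3000, "follow-up", "short reply"]
    # passive clipping: shrink the history so it fits a fixed per-request token budget
    clipped = clip_history(inputs="next question", history=history,
                           tokenizer=ToyTokenizer(), max_token_limit=8192)
    print([len(h) for h in clipped])  # the longest entry is truncated to fit the budget

The proactive path added by this patch is separate: when AUTO_CONTEXT_CLIP_ENABLE is set, ArgsGeneralWrapper calls auto_context_clip(txt_passon, history) before dispatching the request, which uses the 'search_optimal' policy by default.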