From 09fd22091afdc0a77eaa9d8a4328e466b9c5c939 Mon Sep 17 00:00:00 2001 From: binary-husky Date: Sat, 21 Sep 2024 14:41:36 +0000 Subject: [PATCH] fix: console output --- crazy_functions/Social_Helper.py | 116 ++++++++++++++++-------- crazy_functions/json_fns/select_tool.py | 26 ++++++ request_llms/bridge_chatgpt.py | 2 +- request_llms/bridge_cohere.py | 2 +- request_llms/bridge_ollama.py | 2 +- request_llms/oai_std_model_template.py | 2 +- tests/test_social_helper.py | 15 ++- 7 files changed, 120 insertions(+), 45 deletions(-) create mode 100644 crazy_functions/json_fns/select_tool.py diff --git a/crazy_functions/Social_Helper.py b/crazy_functions/Social_Helper.py index 718e7125..c4aaec23 100644 --- a/crazy_functions/Social_Helper.py +++ b/crazy_functions/Social_Helper.py @@ -2,10 +2,11 @@ import pickle, os, random from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg from crazy_functions.crazy_utils import input_clipping from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive -from crazy_functions.json_fns.pydantic_io import GptJsonIO, JsonStringError from request_llms.bridge_all import predict_no_ui_long_connection +from crazy_functions.json_fns.select_tool import structure_output, select_tool from pydantic import BaseModel, Field from loguru import logger +from typing import List SOCIAL_NETWOK_WORKER_REGISTER = {} @@ -51,49 +52,84 @@ class Friend(BaseModel): friend_description: str = Field(description="description of a friend (everything about this friend)") friend_relationship: str = Field(description="The relationship with a friend (e.g. friend, family, colleague)") +class FriendList(BaseModel): + friends_list: List[Friend] = Field(description="The list of friends") -def structure_output(txt, prompt, err_msg, run_gpt_fn, pydantic_cls): - gpt_json_io = GptJsonIO(pydantic_cls) - analyze_res = run_gpt_fn( - txt, - sys_prompt=prompt + gpt_json_io.format_instructions - ) - try: - friend:Friend = gpt_json_io.generate_output_auto_repair(analyze_res, run_gpt_fn) - except JsonStringError as e: - return None, err_msg - - err_msg = "" - return friend, err_msg - class SocialNetworkWorker(SaveAndLoad): - def run(self, txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request): - run_gpt_fn = lambda inputs, sys_prompt: predict_no_ui_long_connection(inputs=inputs, llm_kwargs=llm_kwargs, history=[], sys_prompt=sys_prompt, observe_window=[]) - # adding friend: 🧑‍🤝‍🧑 - if txt.startswith("add-friend"): - friend, err_msg = structure_output( - txt=txt, - prompt="根据提示, 解析一个联系人的身份信息\n\n", - err_msg=f"不能理解该联系人", - run_gpt_fn=run_gpt_fn, - pydantic_cls=Friend - ) - if friend: - self.add_friend(friend) - else: - yield from update_ui_lastest_msg(lastmsg=err_msg, chatbot=chatbot, history=history, delay=0) + def ai_socail_advice(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type): + pass + + def ai_remove_friend(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type): + pass + + def ai_list_friends(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type): + pass + + def ai_add_multi_friends(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type): + friend, err_msg = structure_output( + txt=prompt, + prompt="根据提示, 解析多个联系人的身份信息\n\n", + err_msg=f"不能理解该联系人", + run_gpt_fn=run_gpt_fn, + pydantic_cls=FriendList + ) + if friend.friends_list: + for f in friend.friends_list: + self.add_friend(f) + msg = f"成功添加{len(friend.friends_list)}个联系人: {str(friend.friends_list)}" + yield from update_ui_lastest_msg(lastmsg=msg, chatbot=chatbot, history=history, delay=0) + + + def run(self, txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request): + prompt = txt + run_gpt_fn = lambda inputs, sys_prompt: predict_no_ui_long_connection(inputs=inputs, llm_kwargs=llm_kwargs, history=[], sys_prompt=sys_prompt, observe_window=[]) + self.tools_to_select = { + "SocialAdvice":{ + "explain_to_llm": "如果用户希望获取社交指导,调用SocialAdvice生成一些社交建议", + "callback": self.ai_socail_advice, + }, + "AddFriends":{ + "explain_to_llm": "如果用户给出了联系人,调用AddMultiFriends把联系人添加到数据库", + "callback": self.ai_add_multi_friends, + }, + "RemoveFriend":{ + "explain_to_llm": "如果用户希望移除某个联系人,调用RemoveFriend", + "callback": self.ai_remove_friend, + }, + "ListFriends":{ + "explain_to_llm": "如果用户列举联系人,调用ListFriends", + "callback": self.ai_list_friends, + } + } + + try: + Explaination = '\n'.join([f'{k}: {v["explain_to_llm"]}' for k, v in self.tools_to_select.items()]) + class UserSociaIntention(BaseModel): + intention_type: str = Field( + description= + f"The type of user intention. You must choose from {self.tools_to_select.keys()}.\n\n" + f"Explaination:\n{Explaination}", + default="SocialAdvice" + ) + pydantic_cls_instance, err_msg = select_tool( + prompt=txt, + run_gpt_fn=run_gpt_fn, + pydantic_cls=UserSociaIntention + ) + except Exception as e: + yield from update_ui_lastest_msg( + lastmsg=f"无法理解用户意图 {err_msg}", + chatbot=chatbot, + history=history, + delay=0 + ) + return + + intention_type = pydantic_cls_instance.intention_type + intention_callback = self.tools_to_select[pydantic_cls_instance.intention_type]['callback'] + yield from intention_callback(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type) - # learn friend info: 🧑‍🤝‍🧑 - if txt.startswith("give-advice"): - # randomly select a friend - if len(self.social_network.people) == 0: - yield from update_ui_lastest_msg(lastmsg="没有联系人", chatbot=chatbot, history=history, delay=0) - return - else: - # randomly select a friend - friend = random.choice(self.social_network.people) - yield from update_ui_lastest_msg(lastmsg=f"给你一个建议: {friend.friend_description}", chatbot=chatbot, history=history, delay=0) def add_friend(self, friend): # check whether the friend is already in the social network diff --git a/crazy_functions/json_fns/select_tool.py b/crazy_functions/json_fns/select_tool.py new file mode 100644 index 00000000..5ed61301 --- /dev/null +++ b/crazy_functions/json_fns/select_tool.py @@ -0,0 +1,26 @@ +from crazy_functions.json_fns.pydantic_io import GptJsonIO, JsonStringError + +def structure_output(txt, prompt, err_msg, run_gpt_fn, pydantic_cls): + gpt_json_io = GptJsonIO(pydantic_cls) + analyze_res = run_gpt_fn( + txt, + sys_prompt=prompt + gpt_json_io.format_instructions + ) + try: + friend = gpt_json_io.generate_output_auto_repair(analyze_res, run_gpt_fn) + except JsonStringError as e: + return None, err_msg + + err_msg = "" + return friend, err_msg + + +def select_tool(prompt, run_gpt_fn, pydantic_cls): + pydantic_cls_instance, err_msg = structure_output( + txt=prompt, + prompt="根据提示, 分析应该调用哪个工具函数\n\n", + err_msg=f"不能理解该联系人", + run_gpt_fn=run_gpt_fn, + pydantic_cls=pydantic_cls + ) + return pydantic_cls_instance, err_msg \ No newline at end of file diff --git a/request_llms/bridge_chatgpt.py b/request_llms/bridge_chatgpt.py index 763897a1..d4cf1ef5 100644 --- a/request_llms/bridge_chatgpt.py +++ b/request_llms/bridge_chatgpt.py @@ -192,7 +192,7 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[], if (not has_content) and (not has_role): continue # raise RuntimeError("发现不标准的第三方接口:"+delta) if has_content: # has_role = True/False result += delta["content"] - if not console_slience: logger.info(delta["content"], end='') + if not console_slience: print(delta["content"], end='') if observe_window is not None: # 观测窗,把已经获取的数据显示出去 if len(observe_window) >= 1: diff --git a/request_llms/bridge_cohere.py b/request_llms/bridge_cohere.py index 64941145..f5ab5070 100644 --- a/request_llms/bridge_cohere.py +++ b/request_llms/bridge_cohere.py @@ -111,7 +111,7 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[], if chunkjson['event_type'] == 'stream-start': continue if chunkjson['event_type'] == 'text-generation': result += chunkjson["text"] - if not console_slience: logger.info(chunkjson["text"], end='') + if not console_slience: print(chunkjson["text"], end='') if observe_window is not None: # 观测窗,把已经获取的数据显示出去 if len(observe_window) >= 1: diff --git a/request_llms/bridge_ollama.py b/request_llms/bridge_ollama.py index 90744fa6..9a2fb97f 100644 --- a/request_llms/bridge_ollama.py +++ b/request_llms/bridge_ollama.py @@ -99,7 +99,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", logger.info(f'[response] {result}') break result += chunkjson['message']["content"] - if not console_slience: logger.info(chunkjson['message']["content"], end='') + if not console_slience: print(chunkjson['message']["content"], end='') if observe_window is not None: # 观测窗,把已经获取的数据显示出去 if len(observe_window) >= 1: diff --git a/request_llms/oai_std_model_template.py b/request_llms/oai_std_model_template.py index 285ca38d..a19f05bf 100644 --- a/request_llms/oai_std_model_template.py +++ b/request_llms/oai_std_model_template.py @@ -224,7 +224,7 @@ def get_predict_function( try: if finish_reason == "stop": if not console_slience: - logger.info(f"[response] {result}") + print(f"[response] {result}") break result += response_text if observe_window is not None: diff --git a/tests/test_social_helper.py b/tests/test_social_helper.py index 9c74cebf..a633cd4e 100644 --- a/tests/test_social_helper.py +++ b/tests/test_social_helper.py @@ -8,4 +8,17 @@ import os, sys if __name__ == "__main__": from test_utils import plugin_test - plugin_test(plugin='crazy_functions.Social_Helper->I人助手', main_input="add-friend:a,我的师兄") + plugin_test( + plugin='crazy_functions.Social_Helper->I人助手', + main_input=""" +添加联系人: +艾德·史塔克:我的养父,他是临冬城的公爵。 +凯特琳·史塔克:我的养母,她对我态度冷淡,因为我是私生子。 +罗柏·史塔克:我的哥哥,他是北境的继承人。 +艾莉亚·史塔克:我的妹妹,她和我关系亲密,性格独立坚强。 +珊莎·史塔克:我的妹妹,她梦想成为一位淑女。 +布兰·史塔克:我的弟弟,他有预知未来的能力。 +瑞肯·史塔克:我的弟弟,他是个天真无邪的小孩。 +山姆威尔·塔利:我的朋友,他在守夜人军团中与我并肩作战。 +伊格瑞特:我的恋人,她是野人中的一员。 + """)