From 710a65522c32dcfe10a42fde9bbd554c0dd653c0 Mon Sep 17 00:00:00 2001 From: binary-husky Date: Mon, 2 Sep 2024 15:55:06 +0000 Subject: [PATCH] add social worker (proto) --- TODO | 1 + crazy_functions/Social_Helper.py | 65 ++++++++++++++++++++++++++++++++ tests/test_social_helper.py | 11 ++++++ 3 files changed, 77 insertions(+) create mode 100644 TODO create mode 100644 crazy_functions/Social_Helper.py create mode 100644 tests/test_social_helper.py diff --git a/TODO b/TODO new file mode 100644 index 00000000..4ab3721b --- /dev/null +++ b/TODO @@ -0,0 +1 @@ +RAG忘了触发保存了! \ No newline at end of file diff --git a/crazy_functions/Social_Helper.py b/crazy_functions/Social_Helper.py new file mode 100644 index 00000000..9627f9c0 --- /dev/null +++ b/crazy_functions/Social_Helper.py @@ -0,0 +1,65 @@ +from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg +from crazy_functions.crazy_utils import input_clipping +from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive +import pickle, os + +SOCIAL_NETWOK_WORKER_REGISTER = {} + +class SocialNetwork(): + def __init__(self): + self.people = [] + +class SocialNetworkWorker(): + def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None: + self.user_name = user_name + self.checkpoint_dir = checkpoint_dir + if auto_load_checkpoint: + self.social_network = self.load_from_checkpoint(checkpoint_dir) + else: + self.social_network = SocialNetwork() + + def does_checkpoint_exist(self, checkpoint_dir=None): + import os, glob + if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir + if not os.path.exists(checkpoint_dir): return False + if len(glob.glob(os.path.join(checkpoint_dir, "social_network.pkl"))) == 0: return False + return True + + def save_to_checkpoint(self, checkpoint_dir=None): + if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir + with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "wb+") as f: + pickle.dump(self.social_network, f) + return + + def load_from_checkpoint(self, checkpoint_dir=None): + if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir + if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir): + with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "rb") as f: + social_network = pickle.load(f) + return social_network + else: + return SocialNetwork() + + +@CatchException +def I人助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, num_day=5): + + # 1. we retrieve worker from global context + user_name = chatbot.get_user() + checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag') + if user_name in SOCIAL_NETWOK_WORKER_REGISTER: + social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name] + else: + social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name] = SocialNetworkWorker( + user_name, + llm_kwargs, + checkpoint_dir=checkpoint_dir, + auto_load_checkpoint=True + ) + + # 2. save + social_network_worker.social_network.people.append("张三") + social_network_worker.save_to_checkpoint(checkpoint_dir) + chatbot.append(["good", "work"]) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + diff --git a/tests/test_social_helper.py b/tests/test_social_helper.py new file mode 100644 index 00000000..fd61a5da --- /dev/null +++ b/tests/test_social_helper.py @@ -0,0 +1,11 @@ +""" +对项目中的各个插件进行测试。运行方法:直接运行 python tests/test_plugins.py +""" + +import init_test +import os, sys + + +if __name__ == "__main__": + from test_utils import plugin_test + plugin_test(plugin='crazy_functions.Social_Helper->I人助手', main_input="|")