diff --git a/.gitignore b/.gitignore
index be959f73..be33f58c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,4 @@ test.*
 temp.*
 objdump*
 *.min.*.js
+TODO
\ No newline at end of file
diff --git a/TODO b/TODO
deleted file mode 100644
index 4ab3721b..00000000
--- a/TODO
+++ /dev/null
@@ -1 +0,0 @@
-RAG忘了触发保存了!
\ No newline at end of file
diff --git a/crazy_functions/Rag_Interface.py b/crazy_functions/Rag_Interface.py
index 9e1d9075..0de42b0c 100644
--- a/crazy_functions/Rag_Interface.py
+++ b/crazy_functions/Rag_Interface.py
@@ -1,7 +1,14 @@
 from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
 from crazy_functions.crazy_utils import input_clipping
 from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
-from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+
+VECTOR_STORE_TYPE = "Milvus"
+
+if VECTOR_STORE_TYPE == "Simple":
+    from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+if VECTOR_STORE_TYPE == "Milvus":
+    from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
+
 RAG_WORKER_REGISTER = {}
@@ -14,16 +21,25 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
 
     # 1. we retrieve rag worker from global context
     user_name = chatbot.get_user()
+    checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
     if user_name in RAG_WORKER_REGISTER:
         rag_worker = RAG_WORKER_REGISTER[user_name]
     else:
         rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
             user_name,
             llm_kwargs,
-            checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag'),
+            checkpoint_dir=checkpoint_dir,
             auto_load_checkpoint=True)
+    current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}"
+    tip = "提示:输入“清空向量数据库”可以清空RAG向量数据库"
+    if txt == "清空向量数据库":
+        chatbot.append([txt, f'正在清空 ({current_context}) ...'])
+        yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
+        rag_worker.purge()
+        yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
+        return
 
-    chatbot.append([txt, '正在召回知识 ...'])
+    chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
     yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
 
     # 2. clip history to reduce token consumption
@@ -68,8 +84,8 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
     )
 
     # 5. remember what has been asked / answered
-    yield from update_ui_lastest_msg(model_say + '<br/><br/>' + '对话记忆中, 请稍等 ...', chatbot, history, delay=0.5) # 刷新界面
+    yield from update_ui_lastest_msg(model_say + '<br/><br/>' + f'对话记忆中, 请稍等 ({current_context}) ...', chatbot, history, delay=0.5) # 刷新界面
     rag_worker.remember_qa(i_say_to_remember, model_say)
     history.extend([i_say, model_say])
 
-    yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0) # 刷新界面
+    yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) # 刷新界面
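Reviewer note (outside the patch): Rag_Interface.py caches one RAG worker per user in the module-level RAG_WORKER_REGISTER dict. A condensed sketch of that lazy per-user lookup, equivalent to the if/else in the hunk above; get_rag_worker is a hypothetical helper name introduced only for illustration:

    def get_rag_worker(user_name, llm_kwargs, checkpoint_dir):
        # Build the worker on first use, then reuse it for the process lifetime;
        # auto_load_checkpoint restores any previously persisted vector store.
        if user_name not in RAG_WORKER_REGISTER:
            RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
                user_name, llm_kwargs,
                checkpoint_dir=checkpoint_dir,
                auto_load_checkpoint=True)
        return RAG_WORKER_REGISTER[user_name]

The "清空向量数据库" ("clear the vector database") branch is an exact match on the user input: it purges the store and returns before any retrieval runs.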
diff --git a/crazy_functions/rag_fns/llama_index_worker.py b/crazy_functions/rag_fns/llama_index_worker.py
index de1ef38d..761d6943 100644
--- a/crazy_functions/rag_fns/llama_index_worker.py
+++ b/crazy_functions/rag_fns/llama_index_worker.py
@@ -1,4 +1,7 @@
 import llama_index
+import os
+import atexit
+from typing import List
 from llama_index.core import Document
 from llama_index.core.schema import TextNode
 from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
@@ -38,6 +41,7 @@ class SaveLoad():
         return True
 
     def save_to_checkpoint(self, checkpoint_dir=None):
+        print(f'saving vector store to: {checkpoint_dir}')
         if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
         self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
@@ -65,7 +69,8 @@ class LlamaIndexRagWorker(SaveLoad):
         if auto_load_checkpoint:
             self.vs_index = self.load_from_checkpoint(checkpoint_dir)
         else:
-            self.vs_index = self.create_new_vs()
+            self.vs_index = self.create_new_vs(checkpoint_dir)
+        atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
 
     def assign_embedding_model(self):
         pass
@@ -117,6 +122,3 @@ class LlamaIndexRagWorker(SaveLoad):
         buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
         if self.debug_mode: print(buf)
         return buf
-
-
-
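Reviewer note (outside the patch): this atexit hook is the fix the deleted TODO pointed at ("RAG忘了触发保存了!" — "RAG forgot to trigger saving!"). A minimal sketch of the pattern, using a hypothetical class name; the lambda captures self, so every constructed instance stays reachable (and registers its own callback) until shutdown, and atexit only fires on a clean interpreter exit, not on SIGKILL or a hard crash:

    import atexit

    class SavedOnExit:
        def __init__(self, checkpoint_dir):
            self.checkpoint_dir = checkpoint_dir
            # Persist exactly once, when the interpreter exits cleanly.
            atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))

        def save_to_checkpoint(self, checkpoint_dir):
            print(f'saving vector store to: {checkpoint_dir}')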
diff --git a/crazy_functions/rag_fns/milvus_worker.py b/crazy_functions/rag_fns/milvus_worker.py
new file mode 100644
index 00000000..4b5b0ad9
--- /dev/null
+++ b/crazy_functions/rag_fns/milvus_worker.py
@@ -0,0 +1,107 @@
+import llama_index
+import os
+import atexit
+from typing import List
+from llama_index.core import Document
+from llama_index.core.schema import TextNode
+from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
+from shared_utils.connect_void_terminal import get_chat_default_kwargs
+from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
+from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
+from llama_index.core.ingestion import run_transformations
+from llama_index.core import PromptTemplate
+from llama_index.core.response_synthesizers import TreeSummarize
+from llama_index.core import StorageContext
+from llama_index.vector_stores.milvus import MilvusVectorStore
+from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
+
+DEFAULT_QUERY_GENERATION_PROMPT = """\
+Now, you have context information as below:
+---------------------
+{context_str}
+---------------------
+Answer the user request below (use the context information if necessary, otherwise you can ignore them):
+---------------------
+{query_str}
+"""
+
+QUESTION_ANSWER_RECORD = """\
+{{
+    "type": "This is a previous conversation with the user",
+    "question": "{question}",
+    "answer": "{answer}",
+}}
+"""
+
+
+class MilvusSaveLoad():
+
+    def does_checkpoint_exist(self, checkpoint_dir=None):
+        import os, glob
+        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
+        if not os.path.exists(checkpoint_dir): return False
+        if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
+        return True
+
+    def save_to_checkpoint(self, checkpoint_dir=None):
+        print(f'saving vector store to: {checkpoint_dir}')
+        # if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
+        # self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
+
+    def load_from_checkpoint(self, checkpoint_dir=None):
+        if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
+        if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
+            print('loading checkpoint from disk')
+            from llama_index.core import StorageContext, load_index_from_storage
+            storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
+            try:
+                self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
+                return self.vs_index
+            except:
+                return self.create_new_vs(checkpoint_dir)
+        else:
+            return self.create_new_vs(checkpoint_dir)
+
+    def create_new_vs(self, checkpoint_dir, overwrite=False):
+        vector_store = MilvusVectorStore(
+            uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
+            dim=self.embed_model.embedding_dimension(),
+            overwrite=overwrite
+        )
+        storage_context = StorageContext.from_defaults(vector_store=vector_store)
+        index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
+        return index
+
+    def purge(self):
+        self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
+
+
+class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
+
+    def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
+        self.debug_mode = True
+        self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
+        self.user_name = user_name
+        self.checkpoint_dir = checkpoint_dir
+        if auto_load_checkpoint:
+            self.vs_index = self.load_from_checkpoint(checkpoint_dir)
+        else:
+            self.vs_index = self.create_new_vs(checkpoint_dir)
+        atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
+
+    def inspect_vector_store(self):
+        # This function is for debugging
+        try:
+            self.vs_index.storage_context.index_store.to_dict()
+            docstore = self.vs_index.storage_context.docstore.docs
+            if not docstore.items():
+                raise ValueError("cannot inspect")
+            vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
+        except:
+            dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
+            vector_store_preview = "\n".join(
+                [f"{node.id_} | {node.text}" for node in dummy_retrieve_res]
+            )
+        print('\n++ --------inspect_vector_store begin--------')
+        print(vector_store_preview)
+        print('oo --------inspect_vector_store end--------')
+        return vector_store_preview
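Reviewer note (outside the patch): pointing MilvusVectorStore at a local *.db uri runs embedded Milvus Lite, so no Milvus server is required and the data is persisted by Milvus itself — which is why save_to_checkpoint above shrinks to a log line. A minimal standalone sketch under those assumptions (dim=1536 matches OpenAI's text-embedding-ada-002; an embedding model must still be configured, e.g. via llama_index Settings):

    from llama_index.core import StorageContext, VectorStoreIndex
    from llama_index.vector_stores.milvus import MilvusVectorStore

    vector_store = MilvusVectorStore(
        uri="./milvus_demo.db",  # local file => embedded Milvus Lite, no server
        dim=1536,                # must equal the embedding model's output dimension
        overwrite=False,         # True drops any existing collection, as purge() does
    )
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents([], storage_context=storage_context)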
diff --git a/request_llms/embed_models/openai_embed.py b/request_llms/embed_models/openai_embed.py
index c559e1c3..9d565173 100644
--- a/request_llms/embed_models/openai_embed.py
+++ b/request_llms/embed_models/openai_embed.py
@@ -71,7 +71,13 @@ class OpenAiEmbeddingModel(EmbeddingModel):
         embedding = res.data[0].embedding
         return embedding
 
-    def embedding_dimension(self, llm_kwargs):
+    def embedding_dimension(self, llm_kwargs=None):
+        # load kwargs
+        if llm_kwargs is None:
+            llm_kwargs = self.llm_kwargs
+        if llm_kwargs is None:
+            raise RuntimeError("llm_kwargs is not provided!")
+
         from .bridge_all_embed import embed_model_info
         return embed_model_info[llm_kwargs['embed_model']]['embed_dimension']
diff --git a/requirements.txt b/requirements.txt
index 757df774..99c841e2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,9 @@ tiktoken>=0.3.3
 requests[socks]
 pydantic==2.5.2
 llama-index==0.10.47
-protobuf==3.18
+llama-index-vector-stores-milvus==0.1.16
+pymilvus==2.4.2
+protobuf==3.20
 transformers>=4.27.1,<4.42
 scipdf_parser>=0.52
 anthropic>=0.18.1
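Reviewer note (outside the patch): embedding_dimension gained a default because create_new_vs in milvus_worker.py calls it with no argument; the method now falls back to the kwargs captured at construction time:

    embed_model = OpenAiEmbeddingModel(llm_kwargs)  # kwargs stored on the instance
    dim = embed_model.embedding_dimension()         # legal now: uses the stored kwargs

The protobuf pin moves from 3.18 to 3.20 to satisfy pymilvus 2.4.x, which requires protobuf>=3.20.0.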
diff --git a/toolbox.py b/toolbox.py
index 6b2f4c10..900cf234 100644
--- a/toolbox.py
+++ b/toolbox.py
@@ -178,7 +178,7 @@ def update_ui(chatbot:ChatBotWithCookies, history, msg="正常", **kwargs): #
     yield cookies, chatbot_gr, history, msg
 
 
-def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay=1): # 刷新界面
+def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay=1, msg="正常"): # 刷新界面
     """
     刷新用户界面
     """
@@ -186,7 +186,7 @@ def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list,
     chatbot.append(["update_ui_last_msg", lastmsg])
     chatbot[-1] = list(chatbot[-1])
     chatbot[-1][-1] = lastmsg
-    yield from update_ui(chatbot=chatbot, history=history)
+    yield from update_ui(chatbot=chatbot, history=history, msg=msg)
     time.sleep(delay)
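Reviewer note (outside the patch): the new msg keyword lets callers attach a status string to the UI refresh instead of the default "正常" ("normal"), which keeps every existing call site unchanged. Rag_Interface.py uses it to surface the clear-database tip after each answer:

    yield from update_ui_lastest_msg(
        model_say, chatbot, history, delay=0,
        msg="提示:输入“清空向量数据库”可以清空RAG向量数据库")
    # msg reads: "Tip: type '清空向量数据库' to clear the RAG vector database"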