add milvus vector store

This commit is contained in:
binary-husky
2024-09-08 15:19:03 +00:00
parent e4e00b713f
commit 8b91d2ac0a
6 changed files with 128 additions and 8 deletions

9
TODO
View File

@@ -1 +1,10 @@
RAG忘了触发保存了 RAG忘了触发保存了
刘博寅: 用llama index 实现 RAG 文档向量化
RAG代码参考
crazy_functions/rag_fns/llama_index_worker.py
crazy_functions/rag_fns/milvus_worker.py
crazy_functions/rag_fns/vector_store_index.py
读取文件的代码参考使用glob
crazy_functions/SourceCode_Analyse.py

View File

@@ -1,7 +1,14 @@
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
VECTOR_STORE_TYPE = "Milvus"
if VECTOR_STORE_TYPE == "Simple":
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
if VECTOR_STORE_TYPE == "Milvus":
from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
RAG_WORKER_REGISTER = {} RAG_WORKER_REGISTER = {}

View File

@@ -1,4 +1,7 @@
import llama_index import llama_index
import os
import atexit
from typing import List
from llama_index.core import Document from llama_index.core import Document
from llama_index.core.schema import TextNode from llama_index.core.schema import TextNode
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
@@ -38,6 +41,7 @@ class SaveLoad():
return True return True
def save_to_checkpoint(self, checkpoint_dir=None): def save_to_checkpoint(self, checkpoint_dir=None):
print(f'saving vector store to: {checkpoint_dir}')
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
self.vs_index.storage_context.persist(persist_dir=checkpoint_dir) self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
@@ -65,7 +69,8 @@ class LlamaIndexRagWorker(SaveLoad):
if auto_load_checkpoint: if auto_load_checkpoint:
self.vs_index = self.load_from_checkpoint(checkpoint_dir) self.vs_index = self.load_from_checkpoint(checkpoint_dir)
else: else:
self.vs_index = self.create_new_vs() self.vs_index = self.create_new_vs(checkpoint_dir)
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
def assign_embedding_model(self): def assign_embedding_model(self):
pass pass
@@ -117,6 +122,3 @@ class LlamaIndexRagWorker(SaveLoad):
buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)])) buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
if self.debug_mode: print(buf) if self.debug_mode: print(buf)
return buf return buf

View File

@@ -0,0 +1,94 @@
import llama_index
import os
import atexit
from typing import List
from llama_index.core import Document
from llama_index.core.schema import TextNode
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
from shared_utils.connect_void_terminal import get_chat_default_kwargs
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core import StorageContext
from llama_index.vector_stores.milvus import MilvusVectorStore
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
---------------------
{context_str}
---------------------
Answer the user request below (use the context information if necessary, otherwise you can ignore them):
---------------------
{query_str}
"""
QUESTION_ANSWER_RECORD = """\
{{
"type": "This is a previous conversation with the user",
"question": "{question}",
"answer": "{answer}",
}}
"""
class MilvusSaveLoad():
def does_checkpoint_exist(self, checkpoint_dir=None):
import os, glob
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if not os.path.exists(checkpoint_dir): return False
if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
return True
def save_to_checkpoint(self, checkpoint_dir=None):
print(f'saving vector store to: {checkpoint_dir}')
# if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
# self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
def load_from_checkpoint(self, checkpoint_dir=None):
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
print('loading checkpoint from disk')
from llama_index.core import StorageContext, load_index_from_storage
storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
try:
self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
return self.vs_index
except:
return self.create_new_vs(checkpoint_dir)
else:
return self.create_new_vs(checkpoint_dir)
def create_new_vs(self, checkpoint_dir):
vector_store = MilvusVectorStore(
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
dim=self.embed_model.embedding_dimension()
)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
return index
class MilvusRagWorker(LlamaIndexRagWorker):
def inspect_vector_store(self):
# This function is for debugging
try:
self.vs_index.storage_context.index_store.to_dict()
docstore = self.vs_index.storage_context.docstore.docs
if not docstore.items():
raise ValueError("cannot inspect")
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
except:
dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
vector_store_preview = "\n".join(
[f"{node.id_} | {node.text}" for node in dummy_retrieve_res]
)
print('\n++ --------inspect_vector_store begin--------')
print(vector_store_preview)
print('oo --------inspect_vector_store end--------')
return vector_store_preview

View File

@@ -71,7 +71,13 @@ class OpenAiEmbeddingModel(EmbeddingModel):
embedding = res.data[0].embedding embedding = res.data[0].embedding
return embedding return embedding
def embedding_dimension(self, llm_kwargs): def embedding_dimension(self, llm_kwargs=None):
# load kwargs
if llm_kwargs is None:
llm_kwargs = self.llm_kwargs
if llm_kwargs is None:
raise RuntimeError("llm_kwargs is not provided!")
from .bridge_all_embed import embed_model_info from .bridge_all_embed import embed_model_info
return embed_model_info[llm_kwargs['embed_model']]['embed_dimension'] return embed_model_info[llm_kwargs['embed_model']]['embed_dimension']

View File

@@ -7,7 +7,9 @@ tiktoken>=0.3.3
requests[socks] requests[socks]
pydantic==2.5.2 pydantic==2.5.2
llama-index==0.10.47 llama-index==0.10.47
protobuf==3.18 llama-index-vector-stores-milvus==0.1.16
pymilvus==2.4.2
protobuf==3.20
transformers>=4.27.1,<4.42 transformers>=4.27.1,<4.42
scipdf_parser>=0.52 scipdf_parser>=0.52
anthropic>=0.18.1 anthropic>=0.18.1