This commit is contained in:
lbykkkk
2024-12-01 17:35:57 +08:00
parent cf51d4b205
commit b3aef6b393
13 changed files with 398 additions and 234 deletions

View File

@@ -1,20 +1,14 @@
import llama_index
import os
import atexit
import os
from typing import List
from loguru import logger
from llama_index.core import Document
from llama_index.core.schema import TextNode
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
from shared_utils.connect_void_terminal import get_chat_default_kwargs
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core import StorageContext
from llama_index.vector_stores.milvus import MilvusVectorStore
from loguru import logger
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
@@ -65,17 +59,19 @@ class MilvusSaveLoad():
def create_new_vs(self, checkpoint_dir, overwrite=False):
vector_store = MilvusVectorStore(
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
dim=self.embed_model.embedding_dimension(),
overwrite=overwrite
)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context,
embed_model=self.embed_model)
return index
def purge(self):
self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
@@ -96,7 +92,7 @@ class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
docstore = self.vs_index.storage_context.docstore.docs
if not docstore.items():
raise ValueError("cannot inspect")
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
vector_store_preview = "\n".join([f"{_id} | {tn.text}" for _id, tn in docstore.items()])
except:
dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
vector_store_preview = "\n".join(