Address import LlamaIndexRagWorker problem

This commit is contained in:
lbykkkk
2024-10-13 18:01:33 +00:00
parent cdfe38d296
commit bbcdd9aa71

View File

@@ -1,9 +1,12 @@
import os, glob
import os,glob
from typing import List
from toolbox import report_exception
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from llama_index.core import Document
from shared_utils.fastapi_server import validate_path_safety
from toolbox import report_exception
from crazy_functions.rag_fns.rag_file_support import extract_text, supports_format
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
@@ -12,44 +15,14 @@ MAX_HISTORY_ROUND = 5
MAX_CONTEXT_TOKEN_LIMIT = 4096
REMEMBER_PREVIEW = 1000
# import vector store lib
VECTOR_STORE_TYPE = "Milvus"
if VECTOR_STORE_TYPE == "Milvus":
try:
from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
except:
VECTOR_STORE_TYPE = "Simple"
if VECTOR_STORE_TYPE == "Simple":
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
@CatchException
def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, rag_worker):
"""
Handles document uploads by extracting text and adding it to the vector store.
Args:
files (List[str]): List of file paths to process.
llm_kwargs: Language model keyword arguments.
plugin_kwargs: Plugin keyword arguments.
chatbot: Chatbot instance.
history: Chat history.
system_prompt: System prompt.
user_request: User request.
"""
from llama_index.core import Document
from crazy_functions.rag_fns.rag_file_support import extract_text, supports_format
user_name = chatbot.get_user()
checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
if user_name in RAG_WORKER_REGISTER:
rag_worker = RAG_WORKER_REGISTER[user_name]
else:
rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
user_name,
llm_kwargs,
checkpoint_dir=checkpoint_dir,
auto_load_checkpoint=True
)
for file_path in files:
try:
@@ -73,10 +46,19 @@ def handle_document_upload(files: List[str], llm_kwargs, plugin_kwargs, chatbot,
@CatchException
def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
# import vector store lib
VECTOR_STORE_TYPE = "Milvus"
if VECTOR_STORE_TYPE == "Milvus":
try:
from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
except:
VECTOR_STORE_TYPE = "Simple"
if VECTOR_STORE_TYPE == "Simple":
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
# 1. we retrieve rag worker from global context
user_name = chatbot.get_user()
checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
if user_name in RAG_WORKER_REGISTER:
rag_worker = RAG_WORKER_REGISTER[user_name]
else:
@@ -100,7 +82,7 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
chatbot.append([txt, f'正在处理上传的文档 ({current_context}) ...'])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
yield from handle_document_upload(file_paths, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
yield from handle_document_upload(file_paths, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request, rag_worker)
return
elif txt == "清空向量数据库":
@@ -145,7 +127,6 @@ def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, u
# 6. Search vector store and build prompts
nodes = rag_worker.retrieve_from_store_with_query(i_say)
prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)
# 7. Query language model
if len(chatbot) != 0:
chatbot.pop(-1) # Pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`