Compare commits
16 Commits
chat_log_n
...
boyin_summ
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
61676d0536 | ||
|
|
df2ef7940c | ||
|
|
c10f2b45e5 | ||
|
|
7e2ede2d12 | ||
|
|
ec10e2a3ac | ||
|
|
7474d43433 | ||
|
|
83489f9acf | ||
|
|
36e50d490d | ||
|
|
9172337695 | ||
|
|
5dab7b2290 | ||
|
|
89dc6c7265 | ||
|
|
21111d3bd0 | ||
|
|
701018f48c | ||
|
|
8733c4e1e9 | ||
|
|
8498ddf6bf | ||
|
|
3c3293818d |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -160,6 +160,4 @@ test.*
|
|||||||
temp.*
|
temp.*
|
||||||
objdump*
|
objdump*
|
||||||
*.min.*.js
|
*.min.*.js
|
||||||
TODO
|
TODO
|
||||||
experimental_mods
|
|
||||||
search_results
|
|
||||||
@@ -21,13 +21,13 @@ def get_crazy_functions():
|
|||||||
from crazy_functions.询问多个大语言模型 import 同时问询
|
from crazy_functions.询问多个大语言模型 import 同时问询
|
||||||
from crazy_functions.SourceCode_Analyse import 解析一个Lua项目
|
from crazy_functions.SourceCode_Analyse import 解析一个Lua项目
|
||||||
from crazy_functions.SourceCode_Analyse import 解析一个CSharp项目
|
from crazy_functions.SourceCode_Analyse import 解析一个CSharp项目
|
||||||
from crazy_functions.总结word文档 import 总结word文档
|
|
||||||
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
|
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
|
||||||
from crazy_functions.Conversation_To_File import 载入对话历史存档
|
from crazy_functions.Conversation_To_File import 载入对话历史存档
|
||||||
from crazy_functions.Conversation_To_File import 对话历史存档
|
from crazy_functions.Conversation_To_File import 对话历史存档
|
||||||
from crazy_functions.Conversation_To_File import Conversation_To_File_Wrap
|
from crazy_functions.Conversation_To_File import Conversation_To_File_Wrap
|
||||||
from crazy_functions.Conversation_To_File import 删除所有本地对话历史记录
|
from crazy_functions.Conversation_To_File import 删除所有本地对话历史记录
|
||||||
from crazy_functions.辅助功能 import 清除缓存
|
from crazy_functions.辅助功能 import 清除缓存
|
||||||
|
from crazy_functions.批量文件询问 import 批量文件询问
|
||||||
from crazy_functions.Markdown_Translate import Markdown英译中
|
from crazy_functions.Markdown_Translate import Markdown英译中
|
||||||
from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
|
from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
|
||||||
from crazy_functions.PDF_Translate import 批量翻译PDF文档
|
from crazy_functions.PDF_Translate import 批量翻译PDF文档
|
||||||
@@ -110,12 +110,13 @@ def get_crazy_functions():
|
|||||||
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
|
||||||
"Class": Arxiv_Localize, # 新一代插件需要注册Class
|
"Class": Arxiv_Localize, # 新一代插件需要注册Class
|
||||||
},
|
},
|
||||||
"批量总结Word文档": {
|
"批量文件询问": {
|
||||||
"Group": "学术",
|
"Group": "学术",
|
||||||
"Color": "stop",
|
"Color": "stop",
|
||||||
"AsButton": False,
|
"AsButton": False,
|
||||||
"Info": "批量总结word文档 | 输入参数为路径",
|
"AdvancedArgs": True,
|
||||||
"Function": HotReload(总结word文档),
|
"Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
|
||||||
|
"Function": HotReload(批量文件询问),
|
||||||
},
|
},
|
||||||
"解析整个Matlab项目": {
|
"解析整个Matlab项目": {
|
||||||
"Group": "编程",
|
"Group": "编程",
|
||||||
|
|||||||
450
crazy_functions/doc_fns/batch_file_query_doc.py
Normal file
450
crazy_functions/doc_fns/batch_file_query_doc.py
Normal file
@@ -0,0 +1,450 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from docx import Document
|
||||||
|
from docx.enum.style import WD_STYLE_TYPE
|
||||||
|
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
|
||||||
|
from docx.oxml.ns import qn
|
||||||
|
from docx.shared import Inches, Cm
|
||||||
|
from docx.shared import Pt, RGBColor, Inches
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentFormatter(ABC):
|
||||||
|
"""文档格式化基类,定义文档格式化的基本接口"""
|
||||||
|
|
||||||
|
def __init__(self, final_summary: str, file_summaries_map: Dict, failed_files: List[Tuple]):
|
||||||
|
self.final_summary = final_summary
|
||||||
|
self.file_summaries_map = file_summaries_map
|
||||||
|
self.failed_files = failed_files
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_failed_files(self) -> str:
|
||||||
|
"""格式化失败文件列表"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_file_summaries(self) -> str:
|
||||||
|
"""格式化文件总结内容"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_document(self) -> str:
|
||||||
|
"""创建完整文档"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WordFormatter(DocumentFormatter):
|
||||||
|
"""Word格式文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012),并进行了优化"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.doc = Document()
|
||||||
|
self._setup_document()
|
||||||
|
self._create_styles()
|
||||||
|
# 初始化三级标题编号系统
|
||||||
|
self.numbers = {
|
||||||
|
1: 0, # 一级标题编号
|
||||||
|
2: 0, # 二级标题编号
|
||||||
|
3: 0 # 三级标题编号
|
||||||
|
}
|
||||||
|
|
||||||
|
def _setup_document(self):
|
||||||
|
"""设置文档基本格式,包括页面设置和页眉"""
|
||||||
|
sections = self.doc.sections
|
||||||
|
for section in sections:
|
||||||
|
# 设置页面大小为A4
|
||||||
|
section.page_width = Cm(21)
|
||||||
|
section.page_height = Cm(29.7)
|
||||||
|
# 设置页边距
|
||||||
|
section.top_margin = Cm(3.7) # 上边距37mm
|
||||||
|
section.bottom_margin = Cm(3.5) # 下边距35mm
|
||||||
|
section.left_margin = Cm(2.8) # 左边距28mm
|
||||||
|
section.right_margin = Cm(2.6) # 右边距26mm
|
||||||
|
# 设置页眉页脚距离
|
||||||
|
section.header_distance = Cm(2.0)
|
||||||
|
section.footer_distance = Cm(2.0)
|
||||||
|
|
||||||
|
# 添加页眉
|
||||||
|
header = section.header
|
||||||
|
header_para = header.paragraphs[0]
|
||||||
|
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.RIGHT
|
||||||
|
header_run = header_para.add_run("该文档由GPT-academic生成")
|
||||||
|
header_run.font.name = '仿宋'
|
||||||
|
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||||
|
header_run.font.size = Pt(9)
|
||||||
|
|
||||||
|
def _create_styles(self):
|
||||||
|
"""创建文档样式"""
|
||||||
|
# 创建正文样式
|
||||||
|
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||||
|
style.font.name = '仿宋'
|
||||||
|
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||||
|
style.font.size = Pt(14)
|
||||||
|
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||||
|
style.paragraph_format.space_after = Pt(0)
|
||||||
|
style.paragraph_format.first_line_indent = Pt(28)
|
||||||
|
|
||||||
|
# 创建各级标题样式
|
||||||
|
self._create_heading_style('Title_Custom', '方正小标宋简体', 32, WD_PARAGRAPH_ALIGNMENT.CENTER)
|
||||||
|
self._create_heading_style('Heading1_Custom', '黑体', 22, WD_PARAGRAPH_ALIGNMENT.LEFT)
|
||||||
|
self._create_heading_style('Heading2_Custom', '黑体', 18, WD_PARAGRAPH_ALIGNMENT.LEFT)
|
||||||
|
self._create_heading_style('Heading3_Custom', '黑体', 16, WD_PARAGRAPH_ALIGNMENT.LEFT)
|
||||||
|
|
||||||
|
def _create_heading_style(self, style_name: str, font_name: str, font_size: int, alignment):
|
||||||
|
"""创建标题样式"""
|
||||||
|
style = self.doc.styles.add_style(style_name, WD_STYLE_TYPE.PARAGRAPH)
|
||||||
|
style.font.name = font_name
|
||||||
|
style._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
|
||||||
|
style.font.size = Pt(font_size)
|
||||||
|
style.font.bold = True
|
||||||
|
style.paragraph_format.alignment = alignment
|
||||||
|
style.paragraph_format.space_before = Pt(12)
|
||||||
|
style.paragraph_format.space_after = Pt(12)
|
||||||
|
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||||
|
return style
|
||||||
|
|
||||||
|
def _get_heading_number(self, level: int) -> str:
|
||||||
|
"""
|
||||||
|
生成标题编号
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: 标题级别 (0-3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 格式化的标题编号
|
||||||
|
"""
|
||||||
|
if level == 0: # 主标题不需要编号
|
||||||
|
return ""
|
||||||
|
|
||||||
|
self.numbers[level] += 1 # 增加当前级别的编号
|
||||||
|
|
||||||
|
# 重置下级标题编号
|
||||||
|
for i in range(level + 1, 4):
|
||||||
|
self.numbers[i] = 0
|
||||||
|
|
||||||
|
# 根据级别返回不同格式的编号
|
||||||
|
if level == 1:
|
||||||
|
return f"{self.numbers[1]}. "
|
||||||
|
elif level == 2:
|
||||||
|
return f"{self.numbers[1]}.{self.numbers[2]} "
|
||||||
|
elif level == 3:
|
||||||
|
return f"{self.numbers[1]}.{self.numbers[2]}.{self.numbers[3]} "
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _add_heading(self, text: str, level: int):
|
||||||
|
"""
|
||||||
|
添加带编号的标题
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 标题文本
|
||||||
|
level: 标题级别 (0-3)
|
||||||
|
"""
|
||||||
|
style_map = {
|
||||||
|
0: 'Title_Custom',
|
||||||
|
1: 'Heading1_Custom',
|
||||||
|
2: 'Heading2_Custom',
|
||||||
|
3: 'Heading3_Custom'
|
||||||
|
}
|
||||||
|
|
||||||
|
number = self._get_heading_number(level)
|
||||||
|
paragraph = self.doc.add_paragraph(style=style_map[level])
|
||||||
|
|
||||||
|
if number:
|
||||||
|
number_run = paragraph.add_run(number)
|
||||||
|
font_size = 22 if level == 1 else (18 if level == 2 else 16)
|
||||||
|
self._get_run_style(number_run, '黑体', font_size, True)
|
||||||
|
|
||||||
|
text_run = paragraph.add_run(text)
|
||||||
|
font_size = 32 if level == 0 else (22 if level == 1 else (18 if level == 2 else 16))
|
||||||
|
self._get_run_style(text_run, '黑体', font_size, True)
|
||||||
|
|
||||||
|
# 主标题添加日期
|
||||||
|
if level == 0:
|
||||||
|
date_paragraph = self.doc.add_paragraph()
|
||||||
|
date_paragraph.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||||
|
date_run = date_paragraph.add_run(datetime.now().strftime('%Y年%m月%d日'))
|
||||||
|
self._get_run_style(date_run, '仿宋', 16, False)
|
||||||
|
|
||||||
|
return paragraph
|
||||||
|
|
||||||
|
def _get_run_style(self, run, font_name: str, font_size: int, bold: bool = False):
|
||||||
|
"""设置文本运行对象的样式"""
|
||||||
|
run.font.name = font_name
|
||||||
|
run._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
|
||||||
|
run.font.size = Pt(font_size)
|
||||||
|
run.font.bold = bold
|
||||||
|
|
||||||
|
def format_failed_files(self) -> str:
|
||||||
|
"""格式化失败文件列表"""
|
||||||
|
result = []
|
||||||
|
if not self.failed_files:
|
||||||
|
return "\n".join(result)
|
||||||
|
|
||||||
|
result.append("处理失败文件:")
|
||||||
|
for fp, reason in self.failed_files:
|
||||||
|
result.append(f"• {os.path.basename(fp)}: {reason}")
|
||||||
|
|
||||||
|
self._add_heading("处理失败文件", 1)
|
||||||
|
for fp, reason in self.failed_files:
|
||||||
|
self._add_content(f"• {os.path.basename(fp)}: {reason}", indent=False)
|
||||||
|
self.doc.add_paragraph()
|
||||||
|
|
||||||
|
return "\n".join(result)
|
||||||
|
|
||||||
|
def _add_content(self, text: str, indent: bool = True):
|
||||||
|
"""添加正文内容"""
|
||||||
|
paragraph = self.doc.add_paragraph(text, style='Normal_Custom')
|
||||||
|
if not indent:
|
||||||
|
paragraph.paragraph_format.first_line_indent = Pt(0)
|
||||||
|
return paragraph
|
||||||
|
|
||||||
|
def format_file_summaries(self) -> str:
|
||||||
|
"""
|
||||||
|
格式化文件总结内容,确保正确的标题层级
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 格式化后的文件总结字符串
|
||||||
|
|
||||||
|
标题层级规则:
|
||||||
|
1. 一级标题为"各文件详细总结"
|
||||||
|
2. 如果文件有目录路径:
|
||||||
|
- 目录路径作为二级标题 (2.1, 2.2 等)
|
||||||
|
- 该目录下所有文件作为三级标题 (2.1.1, 2.1.2 等)
|
||||||
|
3. 如果文件没有目录路径:
|
||||||
|
- 文件直接作为二级标题 (2.1, 2.2 等)
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
# 首先对文件路径进行分组整理
|
||||||
|
file_groups = {}
|
||||||
|
for path in sorted(self.file_summaries_map.keys()):
|
||||||
|
dir_path = os.path.dirname(path)
|
||||||
|
if dir_path not in file_groups:
|
||||||
|
file_groups[dir_path] = []
|
||||||
|
file_groups[dir_path].append(path)
|
||||||
|
|
||||||
|
# 处理没有目录的文件
|
||||||
|
root_files = file_groups.get("", [])
|
||||||
|
if root_files:
|
||||||
|
for path in sorted(root_files):
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
result.append(f"\n📄 {file_name}")
|
||||||
|
result.append(self.file_summaries_map[path])
|
||||||
|
# 无目录的文件作为二级标题
|
||||||
|
self._add_heading(f"📄 {file_name}", 2)
|
||||||
|
self._add_content(self.file_summaries_map[path])
|
||||||
|
self.doc.add_paragraph()
|
||||||
|
|
||||||
|
# 处理有目录的文件
|
||||||
|
for dir_path in sorted(file_groups.keys()):
|
||||||
|
if dir_path == "": # 跳过已处理的根目录文件
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 添加目录作为二级标题
|
||||||
|
result.append(f"\n📁 {dir_path}")
|
||||||
|
self._add_heading(f"📁 {dir_path}", 2)
|
||||||
|
|
||||||
|
# 该目录下的所有文件作为三级标题
|
||||||
|
for path in sorted(file_groups[dir_path]):
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
result.append(f"\n📄 {file_name}")
|
||||||
|
result.append(self.file_summaries_map[path])
|
||||||
|
|
||||||
|
# 添加文件名作为三级标题
|
||||||
|
self._add_heading(f"📄 {file_name}", 3)
|
||||||
|
self._add_content(self.file_summaries_map[path])
|
||||||
|
self.doc.add_paragraph()
|
||||||
|
|
||||||
|
return "\n".join(result)
|
||||||
|
|
||||||
|
|
||||||
|
def create_document(self):
|
||||||
|
"""创建完整Word文档并返回文档对象"""
|
||||||
|
# 重置所有编号
|
||||||
|
for level in self.numbers:
|
||||||
|
self.numbers[level] = 0
|
||||||
|
|
||||||
|
# 添加主标题
|
||||||
|
self._add_heading("文档总结报告", 0)
|
||||||
|
self.doc.add_paragraph()
|
||||||
|
|
||||||
|
# 添加总体摘要
|
||||||
|
self._add_heading("总体摘要", 1)
|
||||||
|
self._add_content(self.final_summary)
|
||||||
|
self.doc.add_paragraph()
|
||||||
|
|
||||||
|
# 添加失败文件列表(如果有)
|
||||||
|
if self.failed_files:
|
||||||
|
self.format_failed_files()
|
||||||
|
|
||||||
|
# 添加文件详细总结
|
||||||
|
self._add_heading("各文件详细总结", 1)
|
||||||
|
self.format_file_summaries()
|
||||||
|
|
||||||
|
return self.doc
|
||||||
|
|
||||||
|
|
||||||
|
class MarkdownFormatter(DocumentFormatter):
|
||||||
|
"""Markdown格式文档生成器"""
|
||||||
|
|
||||||
|
def format_failed_files(self) -> str:
|
||||||
|
if not self.failed_files:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
formatted_text = ["\n## ⚠️ 处理失败的文件"]
|
||||||
|
for fp, reason in self.failed_files:
|
||||||
|
formatted_text.append(f"- {os.path.basename(fp)}: {reason}")
|
||||||
|
formatted_text.append("\n---")
|
||||||
|
return "\n".join(formatted_text)
|
||||||
|
|
||||||
|
def format_file_summaries(self) -> str:
|
||||||
|
formatted_text = []
|
||||||
|
sorted_paths = sorted(self.file_summaries_map.keys())
|
||||||
|
current_dir = ""
|
||||||
|
|
||||||
|
for path in sorted_paths:
|
||||||
|
dir_path = os.path.dirname(path)
|
||||||
|
if dir_path != current_dir:
|
||||||
|
if dir_path:
|
||||||
|
formatted_text.append(f"\n## 📁 {dir_path}")
|
||||||
|
current_dir = dir_path
|
||||||
|
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
formatted_text.append(f"\n### 📄 {file_name}")
|
||||||
|
formatted_text.append(self.file_summaries_map[path])
|
||||||
|
formatted_text.append("\n---")
|
||||||
|
|
||||||
|
return "\n".join(formatted_text)
|
||||||
|
|
||||||
|
def create_document(self) -> str:
|
||||||
|
document = [
|
||||||
|
"# 📑 文档总结报告",
|
||||||
|
"\n## 总体摘要",
|
||||||
|
self.final_summary
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.failed_files:
|
||||||
|
document.append(self.format_failed_files())
|
||||||
|
|
||||||
|
document.extend([
|
||||||
|
"\n# 📚 各文件详细总结",
|
||||||
|
self.format_file_summaries()
|
||||||
|
])
|
||||||
|
|
||||||
|
return "\n".join(document)
|
||||||
|
|
||||||
|
|
||||||
|
class HtmlFormatter(DocumentFormatter):
|
||||||
|
"""HTML格式文档生成器"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.css_styles = """
|
||||||
|
body {
|
||||||
|
font-family: "Microsoft YaHei", Arial, sans-serif;
|
||||||
|
line-height: 1.6;
|
||||||
|
max-width: 1000px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 20px;
|
||||||
|
color: #333;
|
||||||
|
}
|
||||||
|
h1 {
|
||||||
|
color: #2c3e50;
|
||||||
|
border-bottom: 2px solid #eee;
|
||||||
|
padding-bottom: 10px;
|
||||||
|
font-size: 24px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
h2 {
|
||||||
|
color: #34495e;
|
||||||
|
margin-top: 30px;
|
||||||
|
font-size: 20px;
|
||||||
|
border-left: 4px solid #3498db;
|
||||||
|
padding-left: 10px;
|
||||||
|
}
|
||||||
|
h3 {
|
||||||
|
color: #2c3e50;
|
||||||
|
font-size: 18px;
|
||||||
|
margin-top: 20px;
|
||||||
|
}
|
||||||
|
.summary {
|
||||||
|
background-color: #f8f9fa;
|
||||||
|
padding: 20px;
|
||||||
|
border-radius: 5px;
|
||||||
|
margin: 20px 0;
|
||||||
|
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
.details {
|
||||||
|
margin-top: 40px;
|
||||||
|
}
|
||||||
|
.failed-files {
|
||||||
|
background-color: #fff3f3;
|
||||||
|
padding: 15px;
|
||||||
|
border-left: 4px solid #e74c3c;
|
||||||
|
margin: 20px 0;
|
||||||
|
}
|
||||||
|
.file-summary {
|
||||||
|
background-color: #fff;
|
||||||
|
padding: 15px;
|
||||||
|
margin: 15px 0;
|
||||||
|
border-radius: 4px;
|
||||||
|
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def format_failed_files(self) -> str:
|
||||||
|
if not self.failed_files:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
failed_files_html = ['<div class="failed-files">']
|
||||||
|
failed_files_html.append("<h2>⚠️ 处理失败的文件</h2>")
|
||||||
|
failed_files_html.append("<ul>")
|
||||||
|
for fp, reason in self.failed_files:
|
||||||
|
failed_files_html.append(f"<li><strong>{os.path.basename(fp)}:</strong> {reason}</li>")
|
||||||
|
failed_files_html.append("</ul></div>")
|
||||||
|
return "\n".join(failed_files_html)
|
||||||
|
|
||||||
|
def format_file_summaries(self) -> str:
|
||||||
|
formatted_html = []
|
||||||
|
sorted_paths = sorted(self.file_summaries_map.keys())
|
||||||
|
current_dir = ""
|
||||||
|
|
||||||
|
for path in sorted_paths:
|
||||||
|
dir_path = os.path.dirname(path)
|
||||||
|
if dir_path != current_dir:
|
||||||
|
if dir_path:
|
||||||
|
formatted_html.append(f'<h2>📁 {dir_path}</h2>')
|
||||||
|
current_dir = dir_path
|
||||||
|
|
||||||
|
file_name = os.path.basename(path)
|
||||||
|
formatted_html.append('<div class="file-summary">')
|
||||||
|
formatted_html.append(f'<h3>📄 {file_name}</h3>')
|
||||||
|
formatted_html.append(f'<p>{self.file_summaries_map[path]}</p>')
|
||||||
|
formatted_html.append('</div>')
|
||||||
|
|
||||||
|
return "\n".join(formatted_html)
|
||||||
|
|
||||||
|
def create_document(self) -> str:
|
||||||
|
return f"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset='utf-8'>
|
||||||
|
<title>文档总结报告</title>
|
||||||
|
<style>{self.css_styles}</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>📑 文档总结报告</h1>
|
||||||
|
<h2>总体摘要</h2>
|
||||||
|
<div class="summary">{self.final_summary}</div>
|
||||||
|
{self.format_failed_files()}
|
||||||
|
<div class="details">
|
||||||
|
<h2>📚 各文件详细总结</h2>
|
||||||
|
{self.format_file_summaries()}
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -6,16 +6,12 @@ class SafeUnpickler(pickle.Unpickler):
|
|||||||
def get_safe_classes(self):
|
def get_safe_classes(self):
|
||||||
from crazy_functions.latex_fns.latex_actions import LatexPaperFileGroup, LatexPaperSplit
|
from crazy_functions.latex_fns.latex_actions import LatexPaperFileGroup, LatexPaperSplit
|
||||||
from crazy_functions.latex_fns.latex_toolbox import LinkedListNode
|
from crazy_functions.latex_fns.latex_toolbox import LinkedListNode
|
||||||
from numpy.core.multiarray import scalar
|
|
||||||
from numpy import dtype
|
|
||||||
# 定义允许的安全类
|
# 定义允许的安全类
|
||||||
safe_classes = {
|
safe_classes = {
|
||||||
# 在这里添加其他安全的类
|
# 在这里添加其他安全的类
|
||||||
'LatexPaperFileGroup': LatexPaperFileGroup,
|
'LatexPaperFileGroup': LatexPaperFileGroup,
|
||||||
'LatexPaperSplit': LatexPaperSplit,
|
'LatexPaperSplit': LatexPaperSplit,
|
||||||
'LinkedListNode': LinkedListNode,
|
'LinkedListNode': LinkedListNode,
|
||||||
'scalar': scalar,
|
|
||||||
'dtype': dtype,
|
|
||||||
}
|
}
|
||||||
return safe_classes
|
return safe_classes
|
||||||
|
|
||||||
@@ -26,6 +22,8 @@ class SafeUnpickler(pickle.Unpickler):
|
|||||||
for class_name in self.safe_classes.keys():
|
for class_name in self.safe_classes.keys():
|
||||||
if (class_name in f'{module}.{name}'):
|
if (class_name in f'{module}.{name}'):
|
||||||
match_class_name = class_name
|
match_class_name = class_name
|
||||||
|
if module == 'numpy' or module.startswith('numpy.'):
|
||||||
|
return super().find_class(module, name)
|
||||||
if match_class_name is not None:
|
if match_class_name is not None:
|
||||||
return self.safe_classes[match_class_name]
|
return self.safe_classes[match_class_name]
|
||||||
# 如果尝试加载未授权的类,则抛出异常
|
# 如果尝试加载未授权的类,则抛出异常
|
||||||
|
|||||||
@@ -1,17 +1,13 @@
|
|||||||
import llama_index
|
|
||||||
import os
|
|
||||||
import atexit
|
import atexit
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_index.core import Document
|
from llama_index.core import Document
|
||||||
from llama_index.core.schema import TextNode
|
|
||||||
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
|
||||||
from shared_utils.connect_void_terminal import get_chat_default_kwargs
|
|
||||||
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
|
|
||||||
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
|
|
||||||
from llama_index.core.ingestion import run_transformations
|
from llama_index.core.ingestion import run_transformations
|
||||||
from llama_index.core import PromptTemplate
|
from llama_index.core.schema import TextNode
|
||||||
from llama_index.core.response_synthesizers import TreeSummarize
|
|
||||||
|
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
|
||||||
|
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
||||||
|
|
||||||
DEFAULT_QUERY_GENERATION_PROMPT = """\
|
DEFAULT_QUERY_GENERATION_PROMPT = """\
|
||||||
Now, you have context information as below:
|
Now, you have context information as below:
|
||||||
@@ -63,7 +59,7 @@ class SaveLoad():
|
|||||||
def purge(self):
|
def purge(self):
|
||||||
import shutil
|
import shutil
|
||||||
shutil.rmtree(self.checkpoint_dir, ignore_errors=True)
|
shutil.rmtree(self.checkpoint_dir, ignore_errors=True)
|
||||||
self.vs_index = self.create_new_vs()
|
self.vs_index = self.create_new_vs(self.checkpoint_dir)
|
||||||
|
|
||||||
|
|
||||||
class LlamaIndexRagWorker(SaveLoad):
|
class LlamaIndexRagWorker(SaveLoad):
|
||||||
@@ -75,7 +71,7 @@ class LlamaIndexRagWorker(SaveLoad):
|
|||||||
if auto_load_checkpoint:
|
if auto_load_checkpoint:
|
||||||
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
|
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
|
||||||
else:
|
else:
|
||||||
self.vs_index = self.create_new_vs(checkpoint_dir)
|
self.vs_index = self.create_new_vs()
|
||||||
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
|
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
|
||||||
|
|
||||||
def assign_embedding_model(self):
|
def assign_embedding_model(self):
|
||||||
@@ -91,40 +87,52 @@ class LlamaIndexRagWorker(SaveLoad):
|
|||||||
logger.info('oo --------inspect_vector_store end--------')
|
logger.info('oo --------inspect_vector_store end--------')
|
||||||
return vector_store_preview
|
return vector_store_preview
|
||||||
|
|
||||||
def add_documents_to_vector_store(self, document_list):
|
def add_documents_to_vector_store(self, document_list: List[Document]):
|
||||||
documents = [Document(text=t) for t in document_list]
|
"""
|
||||||
|
Adds a list of Document objects to the vector store after processing.
|
||||||
|
"""
|
||||||
|
documents = document_list
|
||||||
documents_nodes = run_transformations(
|
documents_nodes = run_transformations(
|
||||||
documents, # type: ignore
|
documents, # type: ignore
|
||||||
self.vs_index._transformations,
|
self.vs_index._transformations,
|
||||||
show_progress=True
|
show_progress=True
|
||||||
)
|
)
|
||||||
self.vs_index.insert_nodes(documents_nodes)
|
self.vs_index.insert_nodes(documents_nodes)
|
||||||
if self.debug_mode: self.inspect_vector_store()
|
if self.debug_mode:
|
||||||
|
self.inspect_vector_store()
|
||||||
|
|
||||||
def add_text_to_vector_store(self, text):
|
def add_text_to_vector_store(self, text: str):
|
||||||
node = TextNode(text=text)
|
node = TextNode(text=text)
|
||||||
documents_nodes = run_transformations(
|
documents_nodes = run_transformations(
|
||||||
[node],
|
[node],
|
||||||
self.vs_index._transformations,
|
self.vs_index._transformations,
|
||||||
show_progress=True
|
show_progress=True
|
||||||
)
|
)
|
||||||
self.vs_index.insert_nodes(documents_nodes)
|
self.vs_index.insert_nodes(documents_nodes)
|
||||||
if self.debug_mode: self.inspect_vector_store()
|
if self.debug_mode:
|
||||||
|
self.inspect_vector_store()
|
||||||
|
|
||||||
def remember_qa(self, question, answer):
|
def remember_qa(self, question, answer):
|
||||||
formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
|
formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
|
||||||
self.add_text_to_vector_store(formatted_str)
|
self.add_text_to_vector_store(formatted_str)
|
||||||
|
|
||||||
def retrieve_from_store_with_query(self, query):
|
def retrieve_from_store_with_query(self, query):
|
||||||
if self.debug_mode: self.inspect_vector_store()
|
if self.debug_mode:
|
||||||
|
self.inspect_vector_store()
|
||||||
retriever = self.vs_index.as_retriever()
|
retriever = self.vs_index.as_retriever()
|
||||||
return retriever.retrieve(query)
|
return retriever.retrieve(query)
|
||||||
|
|
||||||
def build_prompt(self, query, nodes):
|
def build_prompt(self, query, nodes):
|
||||||
context_str = self.generate_node_array_preview(nodes)
|
context_str = self.generate_node_array_preview(nodes)
|
||||||
return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
|
return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
|
||||||
|
|
||||||
def generate_node_array_preview(self, nodes):
|
def generate_node_array_preview(self, nodes):
|
||||||
buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
|
buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
|
||||||
if self.debug_mode: logger.info(buf)
|
if self.debug_mode: logger.info(buf)
|
||||||
return buf
|
return buf
|
||||||
|
|
||||||
|
def purge_vector_store(self):
|
||||||
|
"""
|
||||||
|
Purges the current vector store and creates a new one.
|
||||||
|
"""
|
||||||
|
self.purge()
|
||||||
45
crazy_functions/rag_fns/rag_file_support.py
Normal file
45
crazy_functions/rag_fns/rag_file_support.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import os
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
supports_format = ['.csv', '.docx','.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
|
||||||
|
'.pptm', '.pptx','.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml' ,'.m']
|
||||||
|
|
||||||
|
def read_docx_doc(file_path):
|
||||||
|
if file_path.split(".")[-1] == "docx":
|
||||||
|
from docx import Document
|
||||||
|
doc = Document(file_path)
|
||||||
|
file_content = "\n".join([para.text for para in doc.paragraphs])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import win32com.client
|
||||||
|
word = win32com.client.Dispatch("Word.Application")
|
||||||
|
word.visible = False
|
||||||
|
# 打开文件
|
||||||
|
doc = word.Documents.Open(os.getcwd() + '/' + file_path)
|
||||||
|
# file_content = doc.Content.Text
|
||||||
|
doc = word.ActiveDocument
|
||||||
|
file_content = doc.Range().Text
|
||||||
|
doc.Close()
|
||||||
|
word.Quit()
|
||||||
|
except:
|
||||||
|
raise RuntimeError('请先将.doc文档转换为.docx文档。')
|
||||||
|
return file_content
|
||||||
|
|
||||||
|
# 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑
|
||||||
|
import os
|
||||||
|
|
||||||
|
def extract_text(file_path):
|
||||||
|
_, ext = os.path.splitext(file_path.lower())
|
||||||
|
|
||||||
|
# 使用 SimpleDirectoryReader 处理它支持的文件格式
|
||||||
|
if ext in ['.docx', '.doc']:
|
||||||
|
return read_docx_doc(file_path)
|
||||||
|
try:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[file_path])
|
||||||
|
documents = reader.load_data()
|
||||||
|
if len(documents) > 0:
|
||||||
|
return documents[0].text
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
from toolbox import update_ui
|
|
||||||
from toolbox import CatchException, report_exception
|
|
||||||
from toolbox import write_history_to_file, promote_file_to_downloadzone
|
|
||||||
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
|
||||||
fast_debug = False
|
|
||||||
|
|
||||||
|
|
||||||
def 解析docx(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
|
|
||||||
import time, os
|
|
||||||
# pip install python-docx 用于docx格式,跨平台
|
|
||||||
# pip install pywin32 用于doc格式,仅支持Win平台
|
|
||||||
for index, fp in enumerate(file_manifest):
|
|
||||||
if fp.split(".")[-1] == "docx":
|
|
||||||
from docx import Document
|
|
||||||
doc = Document(fp)
|
|
||||||
file_content = "\n".join([para.text for para in doc.paragraphs])
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
import win32com.client
|
|
||||||
word = win32com.client.Dispatch("Word.Application")
|
|
||||||
word.visible = False
|
|
||||||
# 打开文件
|
|
||||||
doc = word.Documents.Open(os.getcwd() + '/' + fp)
|
|
||||||
# file_content = doc.Content.Text
|
|
||||||
doc = word.ActiveDocument
|
|
||||||
file_content = doc.Range().Text
|
|
||||||
doc.Close()
|
|
||||||
word.Quit()
|
|
||||||
except:
|
|
||||||
raise RuntimeError('请先将.doc文档转换为.docx文档。')
|
|
||||||
|
|
||||||
# private_upload里面的文件名在解压zip后容易出现乱码(rar和7z格式正常),故可以只分析文章内容,不输入文件名
|
|
||||||
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
|
|
||||||
from request_llms.bridge_all import model_info
|
|
||||||
max_token = model_info[llm_kwargs['llm_model']]['max_token']
|
|
||||||
TOKEN_LIMIT_PER_FRAGMENT = max_token * 3 // 4
|
|
||||||
paper_fragments = breakdown_text_to_satisfy_token_limit(txt=file_content, limit=TOKEN_LIMIT_PER_FRAGMENT, llm_model=llm_kwargs['llm_model'])
|
|
||||||
this_paper_history = []
|
|
||||||
for i, paper_frag in enumerate(paper_fragments):
|
|
||||||
i_say = f'请对下面的文章片段用中文做概述,文件名是{os.path.relpath(fp, project_folder)},文章内容是 ```{paper_frag}```'
|
|
||||||
i_say_show_user = f'请对下面的文章片段做概述: {os.path.abspath(fp)}的第{i+1}/{len(paper_fragments)}个片段。'
|
|
||||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
|
||||||
inputs=i_say,
|
|
||||||
inputs_show_user=i_say_show_user,
|
|
||||||
llm_kwargs=llm_kwargs,
|
|
||||||
chatbot=chatbot,
|
|
||||||
history=[],
|
|
||||||
sys_prompt="总结文章。"
|
|
||||||
)
|
|
||||||
|
|
||||||
chatbot[-1] = (i_say_show_user, gpt_say)
|
|
||||||
history.extend([i_say_show_user,gpt_say])
|
|
||||||
this_paper_history.extend([i_say_show_user,gpt_say])
|
|
||||||
|
|
||||||
# 已经对该文章的所有片段总结完毕,如果文章被切分了,
|
|
||||||
if len(paper_fragments) > 1:
|
|
||||||
i_say = f"根据以上的对话,总结文章{os.path.abspath(fp)}的主要内容。"
|
|
||||||
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
|
||||||
inputs=i_say,
|
|
||||||
inputs_show_user=i_say,
|
|
||||||
llm_kwargs=llm_kwargs,
|
|
||||||
chatbot=chatbot,
|
|
||||||
history=this_paper_history,
|
|
||||||
sys_prompt="总结文章。"
|
|
||||||
)
|
|
||||||
|
|
||||||
history.extend([i_say,gpt_say])
|
|
||||||
this_paper_history.extend([i_say,gpt_say])
|
|
||||||
|
|
||||||
res = write_history_to_file(history)
|
|
||||||
promote_file_to_downloadzone(res, chatbot=chatbot)
|
|
||||||
chatbot.append(("完成了吗?", res))
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
|
|
||||||
res = write_history_to_file(history)
|
|
||||||
promote_file_to_downloadzone(res, chatbot=chatbot)
|
|
||||||
chatbot.append(("所有文件都总结完成了吗?", res))
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
|
|
||||||
|
|
||||||
@CatchException
|
|
||||||
def 总结word文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
|
|
||||||
import glob, os
|
|
||||||
|
|
||||||
# 基本信息:功能、贡献者
|
|
||||||
chatbot.append([
|
|
||||||
"函数插件功能?",
|
|
||||||
"批量总结Word文档。函数插件贡献者: JasonGuo1。注意, 如果是.doc文件, 请先转化为.docx格式。"])
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
|
|
||||||
# 尝试导入依赖,如果缺少依赖,则给出安装建议
|
|
||||||
try:
|
|
||||||
from docx import Document
|
|
||||||
except:
|
|
||||||
report_exception(chatbot, history,
|
|
||||||
a=f"解析项目: {txt}",
|
|
||||||
b=f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade python-docx pywin32```。")
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
return
|
|
||||||
|
|
||||||
# 清空历史,以免输入溢出
|
|
||||||
history = []
|
|
||||||
|
|
||||||
# 检测输入参数,如没有给定输入参数,直接退出
|
|
||||||
if os.path.exists(txt):
|
|
||||||
project_folder = txt
|
|
||||||
else:
|
|
||||||
if txt == "": txt = '空空如也的输入栏'
|
|
||||||
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到本地项目或无权访问: {txt}")
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
return
|
|
||||||
|
|
||||||
# 搜索需要处理的文件清单
|
|
||||||
if txt.endswith('.docx') or txt.endswith('.doc'):
|
|
||||||
file_manifest = [txt]
|
|
||||||
else:
|
|
||||||
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.docx', recursive=True)] + \
|
|
||||||
[f for f in glob.glob(f'{project_folder}/**/*.doc', recursive=True)]
|
|
||||||
|
|
||||||
# 如果没找到任何文件
|
|
||||||
if len(file_manifest) == 0:
|
|
||||||
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何.docx或doc文件: {txt}")
|
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
|
||||||
return
|
|
||||||
|
|
||||||
# 开始正式执行任务
|
|
||||||
yield from 解析docx(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
|
||||||
496
crazy_functions/批量文件询问.py
Normal file
496
crazy_functions/批量文件询问.py
Normal file
@@ -0,0 +1,496 @@
|
|||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Tuple, Dict, Generator
|
||||||
|
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
|
||||||
|
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
|
||||||
|
from crazy_functions.rag_fns.rag_file_support import extract_text
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
from toolbox import update_ui, CatchException, report_exception
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FileFragment:
|
||||||
|
"""文件片段数据类,用于组织处理单元"""
|
||||||
|
file_path: str
|
||||||
|
content: str
|
||||||
|
rel_path: str
|
||||||
|
fragment_index: int
|
||||||
|
total_fragments: int
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDocumentSummarizer:
|
||||||
|
"""优化的文档总结器 - 批处理版本"""
|
||||||
|
|
||||||
|
def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
|
||||||
|
"""初始化总结器"""
|
||||||
|
self.llm_kwargs = llm_kwargs
|
||||||
|
self.plugin_kwargs = plugin_kwargs
|
||||||
|
self.chatbot = chatbot
|
||||||
|
self.history = history
|
||||||
|
self.system_prompt = system_prompt
|
||||||
|
self.failed_files = []
|
||||||
|
self.file_summaries_map = {}
|
||||||
|
|
||||||
|
def _get_token_limit(self) -> int:
|
||||||
|
"""获取模型token限制"""
|
||||||
|
max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
|
||||||
|
return max_token * 3 // 4
|
||||||
|
|
||||||
|
def _create_batch_inputs(self, fragments: List[FileFragment]) -> Tuple[List, List, List]:
|
||||||
|
"""创建批处理输入"""
|
||||||
|
inputs_array = []
|
||||||
|
inputs_show_user_array = []
|
||||||
|
history_array = []
|
||||||
|
|
||||||
|
for frag in fragments:
|
||||||
|
if self.plugin_kwargs.get("advanced_arg"):
|
||||||
|
i_say = (f'请按照用户要求对文件内容进行处理,文件名为{os.path.basename(frag.file_path)},'
|
||||||
|
f'用户要求为:{self.plugin_kwargs["advanced_arg"]}:'
|
||||||
|
f'文件内容是 ```{frag.content}```')
|
||||||
|
i_say_show_user = (f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})')
|
||||||
|
else:
|
||||||
|
i_say = (f'请对下面的内容用中文做总结,不超过500字,文件名是{os.path.basename(frag.file_path)},'
|
||||||
|
f'内容是 ```{frag.content}```')
|
||||||
|
i_say_show_user = f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})'
|
||||||
|
|
||||||
|
inputs_array.append(i_say)
|
||||||
|
inputs_show_user_array.append(i_say_show_user)
|
||||||
|
history_array.append([])
|
||||||
|
|
||||||
|
return inputs_array, inputs_show_user_array, history_array
|
||||||
|
|
||||||
|
def _process_single_file_with_timeout(self, file_info: Tuple[str, str], mutable_status: List) -> List[FileFragment]:
|
||||||
|
"""包装了超时控制的文件处理函数"""
|
||||||
|
|
||||||
|
def timeout_handler():
|
||||||
|
thread = threading.current_thread()
|
||||||
|
if hasattr(thread, '_timeout_occurred'):
|
||||||
|
thread._timeout_occurred = True
|
||||||
|
|
||||||
|
# 设置超时标记
|
||||||
|
thread = threading.current_thread()
|
||||||
|
thread._timeout_occurred = False
|
||||||
|
|
||||||
|
# 设置超时定时器
|
||||||
|
timer = threading.Timer(self.watch_dog_patience, timeout_handler)
|
||||||
|
timer.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
fp, project_folder = file_info
|
||||||
|
fragments = []
|
||||||
|
|
||||||
|
# 定期检查是否超时
|
||||||
|
def check_timeout():
|
||||||
|
if hasattr(thread, '_timeout_occurred') and thread._timeout_occurred:
|
||||||
|
raise TimeoutError("处理超时")
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
mutable_status[0] = "检查文件大小"
|
||||||
|
mutable_status[1] = time.time()
|
||||||
|
check_timeout()
|
||||||
|
|
||||||
|
# 文件大小检查
|
||||||
|
if os.path.getsize(fp) > self.max_file_size:
|
||||||
|
self.failed_files.append((fp, f"文件过大:超过{self.max_file_size / 1024 / 1024}MB"))
|
||||||
|
mutable_status[2] = "文件过大"
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
check_timeout()
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
mutable_status[0] = "提取文件内容"
|
||||||
|
mutable_status[1] = time.time()
|
||||||
|
|
||||||
|
# 提取内容
|
||||||
|
content = extract_text(fp)
|
||||||
|
if content is None:
|
||||||
|
self.failed_files.append((fp, "文件解析失败:不支持的格式或文件损坏"))
|
||||||
|
mutable_status[2] = "格式不支持"
|
||||||
|
return fragments
|
||||||
|
elif not content.strip():
|
||||||
|
self.failed_files.append((fp, "文件内容为空"))
|
||||||
|
mutable_status[2] = "内容为空"
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
check_timeout()
|
||||||
|
|
||||||
|
# 更新状态
|
||||||
|
mutable_status[0] = "分割文本"
|
||||||
|
mutable_status[1] = time.time()
|
||||||
|
|
||||||
|
# 分割文本
|
||||||
|
try:
|
||||||
|
paper_fragments = breakdown_text_to_satisfy_token_limit(
|
||||||
|
txt=content,
|
||||||
|
limit=self._get_token_limit(),
|
||||||
|
llm_model=self.llm_kwargs['llm_model']
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.failed_files.append((fp, f"文本分割失败:{str(e)}"))
|
||||||
|
mutable_status[2] = "分割失败"
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
check_timeout()
|
||||||
|
|
||||||
|
# 处理片段
|
||||||
|
rel_path = os.path.relpath(fp, project_folder)
|
||||||
|
for i, frag in enumerate(paper_fragments):
|
||||||
|
if frag.strip():
|
||||||
|
fragments.append(FileFragment(
|
||||||
|
file_path=fp,
|
||||||
|
content=frag,
|
||||||
|
rel_path=rel_path,
|
||||||
|
fragment_index=i,
|
||||||
|
total_fragments=len(paper_fragments)
|
||||||
|
))
|
||||||
|
|
||||||
|
mutable_status[2] = "处理完成"
|
||||||
|
return fragments
|
||||||
|
|
||||||
|
except TimeoutError as e:
|
||||||
|
self.failed_files.append((fp, "处理超时"))
|
||||||
|
mutable_status[2] = "处理超时"
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
self.failed_files.append((fp, f"处理失败:{str(e)}"))
|
||||||
|
mutable_status[2] = "处理异常"
|
||||||
|
return []
|
||||||
|
finally:
|
||||||
|
timer.cancel()
|
||||||
|
|
||||||
|
def prepare_fragments(self, project_folder: str, file_paths: List[str]) -> Generator:
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Generator, List
|
||||||
|
"""并行准备所有文件的处理片段"""
|
||||||
|
all_fragments = []
|
||||||
|
total_files = len(file_paths)
|
||||||
|
|
||||||
|
# 配置参数
|
||||||
|
self.refresh_interval = 0.2 # UI刷新间隔
|
||||||
|
self.watch_dog_patience = 5 # 看门狗超时时间
|
||||||
|
self.max_file_size = 10 * 1024 * 1024 # 10MB限制
|
||||||
|
self.max_workers = min(32, len(file_paths)) # 最多32个线程
|
||||||
|
|
||||||
|
# 创建有超时控制的线程池
|
||||||
|
executor = ThreadPoolExecutor(max_workers=self.max_workers)
|
||||||
|
|
||||||
|
# 用于跨线程状态传递的可变列表 - 增加文件名信息
|
||||||
|
mutable_status_array = [["等待中", time.time(), "pending", file_path] for file_path in file_paths]
|
||||||
|
|
||||||
|
# 创建文件处理任务
|
||||||
|
file_infos = [(fp, project_folder) for fp in file_paths]
|
||||||
|
|
||||||
|
# 提交所有任务,使用带超时控制的处理函数
|
||||||
|
futures = [
|
||||||
|
executor.submit(
|
||||||
|
self._process_single_file_with_timeout,
|
||||||
|
file_info,
|
||||||
|
mutable_status_array[i]
|
||||||
|
) for i, file_info in enumerate(file_infos)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 更新UI的计数器
|
||||||
|
cnt = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 监控任务执行
|
||||||
|
while True:
|
||||||
|
time.sleep(self.refresh_interval)
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
# 检查任务完成状态
|
||||||
|
worker_done = [f.done() for f in futures]
|
||||||
|
|
||||||
|
# 更新状态显示
|
||||||
|
status_str = ""
|
||||||
|
for i, (status, timestamp, desc, file_path) in enumerate(mutable_status_array):
|
||||||
|
# 获取文件名(去掉路径)
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
if worker_done[i]:
|
||||||
|
status_str += f"文件 {file_name}: {desc}\n"
|
||||||
|
else:
|
||||||
|
status_str += f"文件 {file_name}: {status} {desc}\n"
|
||||||
|
|
||||||
|
# 更新UI
|
||||||
|
self.chatbot[-1] = [
|
||||||
|
"处理进度",
|
||||||
|
f"正在处理文件...\n\n{status_str}" + "." * (cnt % 10 + 1)
|
||||||
|
]
|
||||||
|
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||||
|
|
||||||
|
# 检查是否所有任务完成
|
||||||
|
if all(worker_done):
|
||||||
|
break
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 确保线程池正确关闭
|
||||||
|
executor.shutdown(wait=False)
|
||||||
|
|
||||||
|
# 收集结果
|
||||||
|
processed_files = 0
|
||||||
|
for future in futures:
|
||||||
|
try:
|
||||||
|
fragments = future.result(timeout=0.1) # 给予一个短暂的超时时间来获取结果
|
||||||
|
all_fragments.extend(fragments)
|
||||||
|
processed_files += 1
|
||||||
|
except concurrent.futures.TimeoutError:
|
||||||
|
# 处理获取结果超时
|
||||||
|
file_index = futures.index(future)
|
||||||
|
self.failed_files.append((file_paths[file_index], "结果获取超时"))
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
# 处理其他异常
|
||||||
|
file_index = futures.index(future)
|
||||||
|
self.failed_files.append((file_paths[file_index], f"未知错误:{str(e)}"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 最终进度更新
|
||||||
|
self.chatbot.append([
|
||||||
|
"文件处理完成",
|
||||||
|
f"成功处理 {len(all_fragments)} 个片段,失败 {len(self.failed_files)} 个文件"
|
||||||
|
])
|
||||||
|
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||||
|
|
||||||
|
return all_fragments
|
||||||
|
|
||||||
|
def _process_fragments_batch(self, fragments: List[FileFragment]) -> Generator:
|
||||||
|
"""批量处理文件片段"""
|
||||||
|
from collections import defaultdict
|
||||||
|
batch_size = 64 # 每批处理的片段数
|
||||||
|
max_retries = 3 # 最大重试次数
|
||||||
|
retry_delay = 5 # 重试延迟(秒)
|
||||||
|
results = defaultdict(list)
|
||||||
|
|
||||||
|
# 按批次处理
|
||||||
|
for i in range(0, len(fragments), batch_size):
|
||||||
|
batch = fragments[i:i + batch_size]
|
||||||
|
|
||||||
|
inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch)
|
||||||
|
sys_prompt_array = ["请总结以下内容:"] * len(batch)
|
||||||
|
|
||||||
|
# 添加重试机制
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||||
|
inputs_array=inputs_array,
|
||||||
|
inputs_show_user_array=inputs_show_user_array,
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
chatbot=self.chatbot,
|
||||||
|
history_array=history_array,
|
||||||
|
sys_prompt_array=sys_prompt_array,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理响应
|
||||||
|
for j, frag in enumerate(batch):
|
||||||
|
summary = response_collection[j * 2 + 1]
|
||||||
|
if summary and summary.strip():
|
||||||
|
results[frag.rel_path].append({
|
||||||
|
'index': frag.fragment_index,
|
||||||
|
'summary': summary,
|
||||||
|
'total': frag.total_fragments
|
||||||
|
})
|
||||||
|
break # 成功处理,跳出重试循环
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry == max_retries - 1: # 最后一次重试失败
|
||||||
|
for frag in batch:
|
||||||
|
self.failed_files.append((frag.file_path, f"处理失败:{str(e)}"))
|
||||||
|
else:
|
||||||
|
yield from update_ui(self.chatbot.append([f"批次处理失败,{retry_delay}秒后重试...", str(e)]))
|
||||||
|
time.sleep(retry_delay)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _generate_final_summary_request(self) -> Tuple[List, List, List]:
|
||||||
|
"""准备最终总结请求"""
|
||||||
|
if not self.file_summaries_map:
|
||||||
|
return (["无可用的文件总结"], ["生成最终总结"], [[]])
|
||||||
|
|
||||||
|
summaries = list(self.file_summaries_map.values())
|
||||||
|
if all(not summary for summary in summaries):
|
||||||
|
return (["所有文件处理均失败"], ["生成最终总结"], [[]])
|
||||||
|
|
||||||
|
if self.plugin_kwargs.get("advanced_arg"):
|
||||||
|
i_say = "根据以上所有文件的处理结果,按要求进行综合处理:" + self.plugin_kwargs['advanced_arg']
|
||||||
|
else:
|
||||||
|
i_say = "请根据以上所有文件的处理结果,生成最终的总结,不超过1000字。"
|
||||||
|
|
||||||
|
return ([i_say], [i_say], [summaries])
|
||||||
|
|
||||||
|
def process_files(self, project_folder: str, file_paths: List[str]) -> Generator:
|
||||||
|
"""处理所有文件"""
|
||||||
|
total_files = len(file_paths)
|
||||||
|
self.chatbot.append([f"开始处理", f"总计 {total_files} 个文件"])
|
||||||
|
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||||
|
|
||||||
|
# 1. 准备所有文件片段
|
||||||
|
# 在 process_files 函数中:
|
||||||
|
fragments = yield from self.prepare_fragments(project_folder, file_paths)
|
||||||
|
if not fragments:
|
||||||
|
self.chatbot.append(["处理失败", "没有可处理的文件内容"])
|
||||||
|
return "没有可处理的文件内容"
|
||||||
|
|
||||||
|
# 2. 批量处理所有文件片段
|
||||||
|
self.chatbot.append([f"文件分析", f"共计 {len(fragments)} 个处理单元"])
|
||||||
|
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_summaries = yield from self._process_fragments_batch(fragments)
|
||||||
|
except Exception as e:
|
||||||
|
self.chatbot.append(["处理错误", f"批处理过程失败:{str(e)}"])
|
||||||
|
return "处理过程发生错误"
|
||||||
|
|
||||||
|
# 3. 为每个文件生成整体总结
|
||||||
|
self.chatbot.append(["生成总结", "正在汇总文件内容..."])
|
||||||
|
yield from update_ui(chatbot=self.chatbot, history=self.history)
|
||||||
|
|
||||||
|
# 处理每个文件的总结
|
||||||
|
for rel_path, summaries in file_summaries.items():
|
||||||
|
if len(summaries) > 1: # 多片段文件需要生成整体总结
|
||||||
|
sorted_summaries = sorted(summaries, key=lambda x: x['index'])
|
||||||
|
if self.plugin_kwargs.get("advanced_arg"):
|
||||||
|
|
||||||
|
i_say = f'请按照用户要求对文件内容进行处理,用户要求为:{self.plugin_kwargs["advanced_arg"]}:'
|
||||||
|
else:
|
||||||
|
i_say = f"请总结文件 {os.path.basename(rel_path)} 的主要内容,不超过500字。"
|
||||||
|
|
||||||
|
try:
|
||||||
|
summary_texts = [s['summary'] for s in sorted_summaries]
|
||||||
|
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||||
|
inputs_array=[i_say],
|
||||||
|
inputs_show_user_array=[f"生成 {rel_path} 的处理结果"],
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
chatbot=self.chatbot,
|
||||||
|
history_array=[summary_texts],
|
||||||
|
sys_prompt_array=["你是一个优秀的助手,"],
|
||||||
|
)
|
||||||
|
self.file_summaries_map[rel_path] = response_collection[1]
|
||||||
|
except Exception as e:
|
||||||
|
self.chatbot.append(["警告", f"文件 {rel_path} 总结生成失败:{str(e)}"])
|
||||||
|
self.file_summaries_map[rel_path] = "总结生成失败"
|
||||||
|
else: # 单片段文件直接使用其唯一的总结
|
||||||
|
self.file_summaries_map[rel_path] = summaries[0]['summary']
|
||||||
|
|
||||||
|
# 4. 生成最终总结
|
||||||
|
if total_files ==1:
|
||||||
|
return "文件数为1,此时不调用总结模块"
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# 收集所有文件的总结用于生成最终总结
|
||||||
|
file_summaries_for_final = []
|
||||||
|
for rel_path, summary in self.file_summaries_map.items():
|
||||||
|
file_summaries_for_final.append(f"文件 {rel_path} 的总结:\n{summary}")
|
||||||
|
|
||||||
|
if self.plugin_kwargs.get("advanced_arg"):
|
||||||
|
final_summary_prompt = ("根据以下所有文件的总结内容,按要求进行综合处理:" +
|
||||||
|
self.plugin_kwargs['advanced_arg'])
|
||||||
|
else:
|
||||||
|
final_summary_prompt = "请根据以下所有文件的总结内容,生成最终的总结报告。"
|
||||||
|
|
||||||
|
response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
|
||||||
|
inputs_array=[final_summary_prompt],
|
||||||
|
inputs_show_user_array=["生成最终总结报告"],
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
chatbot=self.chatbot,
|
||||||
|
history_array=[file_summaries_for_final],
|
||||||
|
sys_prompt_array=["总结所有文件内容。"],
|
||||||
|
max_workers=1
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_collection[1] if len(response_collection) > 1 else "生成总结失败"
|
||||||
|
except Exception as e:
|
||||||
|
self.chatbot.append(["错误", f"最终总结生成失败:{str(e)}"])
|
||||||
|
return "生成总结失败"
|
||||||
|
|
||||||
|
def save_results(self, final_summary: str):
|
||||||
|
"""保存结果到文件"""
|
||||||
|
from toolbox import promote_file_to_downloadzone, write_history_to_file
|
||||||
|
from crazy_functions.doc_fns.batch_file_query_doc import MarkdownFormatter, HtmlFormatter, WordFormatter
|
||||||
|
import os
|
||||||
|
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
# 创建各种格式化器
|
||||||
|
md_formatter = MarkdownFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||||
|
html_formatter = HtmlFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||||
|
word_formatter = WordFormatter(final_summary, self.file_summaries_map, self.failed_files)
|
||||||
|
|
||||||
|
result_files = []
|
||||||
|
|
||||||
|
# 保存 Markdown
|
||||||
|
md_content = md_formatter.create_document()
|
||||||
|
result_file_md = write_history_to_file(
|
||||||
|
history=[md_content], # 直接传入内容列表
|
||||||
|
file_basename=f"文档总结_{timestamp}.md"
|
||||||
|
)
|
||||||
|
result_files.append(result_file_md)
|
||||||
|
|
||||||
|
# 保存 HTML
|
||||||
|
html_content = html_formatter.create_document()
|
||||||
|
result_file_html = write_history_to_file(
|
||||||
|
history=[html_content],
|
||||||
|
file_basename=f"文档总结_{timestamp}.html"
|
||||||
|
)
|
||||||
|
result_files.append(result_file_html)
|
||||||
|
|
||||||
|
# 保存 Word
|
||||||
|
doc = word_formatter.create_document()
|
||||||
|
# 由于 Word 文档需要用 doc.save(),我们使用与 md 文件相同的目录
|
||||||
|
result_file_docx = os.path.join(
|
||||||
|
os.path.dirname(result_file_md),
|
||||||
|
f"文档总结_{timestamp}.docx"
|
||||||
|
)
|
||||||
|
doc.save(result_file_docx)
|
||||||
|
result_files.append(result_file_docx)
|
||||||
|
|
||||||
|
# 添加到下载区
|
||||||
|
for file in result_files:
|
||||||
|
promote_file_to_downloadzone(file, chatbot=self.chatbot)
|
||||||
|
|
||||||
|
self.chatbot.append(["处理完成", f"结果已保存至: {', '.join(result_files)}"])
|
||||||
|
@CatchException
|
||||||
|
def 批量文件询问(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||||
|
history: List, system_prompt: str, user_request: str):
|
||||||
|
"""主函数 - 优化版本"""
|
||||||
|
# 初始化
|
||||||
|
import glob
|
||||||
|
import re
|
||||||
|
from crazy_functions.rag_fns.rag_file_support import supports_format
|
||||||
|
from toolbox import report_exception
|
||||||
|
|
||||||
|
summarizer = BatchDocumentSummarizer(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
|
||||||
|
chatbot.append(["函数插件功能", f"作者:lbykkkk,批量总结文件。支持格式: {', '.join(supports_format)}等其他文本格式文件,如果长时间卡在文件处理过程,请查看处理进度,然后删除所有处于“pending”状态的文件,然后重新上传处理。"])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
|
# 验证输入路径
|
||||||
|
if not os.path.exists(txt):
|
||||||
|
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到项目或无权访问: {txt}")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取文件列表
|
||||||
|
project_folder = txt
|
||||||
|
extract_folder = next((d for d in glob.glob(f'{project_folder}/*')
|
||||||
|
if os.path.isdir(d) and d.endswith('.extract')), project_folder)
|
||||||
|
|
||||||
|
exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
|
||||||
|
file_manifest = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
|
||||||
|
if os.path.isfile(f) and not re.search(exclude_patterns, f)]
|
||||||
|
|
||||||
|
if not file_manifest:
|
||||||
|
report_exception(chatbot, history, a=f"解析项目: {txt}", b="未找到支持的文件类型")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 处理所有文件并生成总结
|
||||||
|
final_summary = yield from summarizer.process_files(project_folder, file_manifest)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
summarizer.save_results(final_summary)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
@@ -385,14 +385,6 @@ model_info = {
|
|||||||
"tokenizer": tokenizer_gpt35,
|
"tokenizer": tokenizer_gpt35,
|
||||||
"token_cnt": get_token_num_gpt35,
|
"token_cnt": get_token_num_gpt35,
|
||||||
},
|
},
|
||||||
"glm-4-plus":{
|
|
||||||
"fn_with_ui": zhipu_ui,
|
|
||||||
"fn_without_ui": zhipu_noui,
|
|
||||||
"endpoint": None,
|
|
||||||
"max_token": 10124 * 8,
|
|
||||||
"tokenizer": tokenizer_gpt35,
|
|
||||||
"token_cnt": get_token_num_gpt35,
|
|
||||||
},
|
|
||||||
|
|
||||||
# api_2d (此后不需要在此处添加api2d的接口了,因为下面的代码会自动添加)
|
# api_2d (此后不需要在此处添加api2d的接口了,因为下面的代码会自动添加)
|
||||||
"api2d-gpt-4": {
|
"api2d-gpt-4": {
|
||||||
|
|||||||
@@ -341,7 +341,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
||||||
if ('data: [DONE]' in chunk_decoded) or (len(chunkjson['choices'][0]["delta"]) == 0):
|
if ('data: [DONE]' in chunk_decoded) or (len(chunkjson['choices'][0]["delta"]) == 0):
|
||||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
break
|
break
|
||||||
# 处理数据流的主体
|
# 处理数据流的主体
|
||||||
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
|
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
|
||||||
@@ -375,7 +375,7 @@ def handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history):
|
|||||||
try:
|
try:
|
||||||
chunkjson = json.loads(response.content.decode())
|
chunkjson = json.loads(response.content.decode())
|
||||||
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
history[-1] = gpt_replying_buffer
|
history[-1] = gpt_replying_buffer
|
||||||
chatbot[-1] = (history[-2], history[-1])
|
chatbot[-1] = (history[-2], history[-1])
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||||
lastmsg = chatbot[-1][-1] + f"\n\n\n\n「{llm_kwargs['llm_model']}调用结束,该模型不具备上下文对话能力,如需追问,请及时切换模型。」"
|
lastmsg = chatbot[-1][-1] + f"\n\n\n\n「{llm_kwargs['llm_model']}调用结束,该模型不具备上下文对话能力,如需追问,请及时切换模型。」"
|
||||||
yield from update_ui_lastest_msg(lastmsg, chatbot, history, delay=1)
|
yield from update_ui_lastest_msg(lastmsg, chatbot, history, delay=1)
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
break
|
break
|
||||||
# 处理数据流的主体
|
# 处理数据流的主体
|
||||||
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
|
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
|
||||||
|
|||||||
@@ -216,7 +216,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
if need_to_pass:
|
if need_to_pass:
|
||||||
pass
|
pass
|
||||||
elif is_last_chunk:
|
elif is_last_chunk:
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
# logger.info(f'[response] {gpt_replying_buffer}')
|
# logger.info(f'[response] {gpt_replying_buffer}')
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
chatbot[-1] = (history[-2], history[-1])
|
chatbot[-1] = (history[-2], history[-1])
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="正常") # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="正常") # 刷新界面
|
||||||
if chunkjson['event_type'] == 'stream-end':
|
if chunkjson['event_type'] == 'stream-end':
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
history[-1] = gpt_replying_buffer
|
history[-1] = gpt_replying_buffer
|
||||||
chatbot[-1] = (history[-2], history[-1])
|
chatbot[-1] = (history[-2], history[-1])
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="正常") # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="正常") # 刷新界面
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
gpt_replying_buffer += paraphrase['text'] # 使用 json 解析库进行处理
|
gpt_replying_buffer += paraphrase['text'] # 使用 json 解析库进行处理
|
||||||
chatbot[-1] = (inputs, gpt_replying_buffer)
|
chatbot[-1] = (inputs, gpt_replying_buffer)
|
||||||
history[-1] = gpt_replying_buffer
|
history[-1] = gpt_replying_buffer
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
yield from update_ui(chatbot=chatbot, history=history)
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
if error_match:
|
if error_match:
|
||||||
history = history[-2] # 错误的不纳入对话
|
history = history[-2] # 错误的不纳入对话
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
history = history[:-2]
|
history = history[:-2]
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
break
|
break
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_bro_result, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_bro_result)
|
||||||
|
|
||||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
|
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
|
||||||
console_slience=False):
|
console_slience=False):
|
||||||
|
|||||||
@@ -337,7 +337,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
||||||
if ('data: [DONE]' in chunk_decoded) or (len(chunkjson['choices'][0]["delta"]) == 0):
|
if ('data: [DONE]' in chunk_decoded) or (len(chunkjson['choices'][0]["delta"]) == 0):
|
||||||
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
# 判定为数据流的结束,gpt_replying_buffer也写完了
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
break
|
break
|
||||||
# 处理数据流的主体
|
# 处理数据流的主体
|
||||||
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
|
status_text = f"finish_reason: {chunkjson['choices'][0].get('finish_reason', 'null')}"
|
||||||
@@ -371,7 +371,7 @@ def handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history):
|
|||||||
try:
|
try:
|
||||||
chunkjson = json.loads(response.content.decode())
|
chunkjson = json.loads(response.content.decode())
|
||||||
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
|
||||||
history[-1] = gpt_replying_buffer
|
history[-1] = gpt_replying_buffer
|
||||||
chatbot[-1] = (history[-2], history[-1])
|
chatbot[-1] = (history[-2], history[-1])
|
||||||
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
chatbot[-1] = (inputs, response)
|
chatbot[-1] = (inputs, response)
|
||||||
yield from update_ui(chatbot=chatbot, history=history)
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=response, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=response)
|
||||||
# 总结输出
|
# 总结输出
|
||||||
if response == f"[Local Message] 等待{model_name}响应中 ...":
|
if response == f"[Local Message] 等待{model_name}响应中 ...":
|
||||||
response = f"[Local Message] {model_name}响应异常 ..."
|
response = f"[Local Message] {model_name}响应异常 ..."
|
||||||
|
|||||||
@@ -68,5 +68,5 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
chatbot[-1] = [inputs, response]
|
chatbot[-1] = [inputs, response]
|
||||||
yield from update_ui(chatbot=chatbot, history=history)
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
history.extend([inputs, response])
|
history.extend([inputs, response])
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=response, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=response)
|
||||||
yield from update_ui(chatbot=chatbot, history=history)
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
@@ -97,5 +97,5 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
chatbot[-1] = [inputs, response]
|
chatbot[-1] = [inputs, response]
|
||||||
yield from update_ui(chatbot=chatbot, history=history)
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
history.extend([inputs, response])
|
history.extend([inputs, response])
|
||||||
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=response, user_name=chatbot.get_user())
|
log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=response)
|
||||||
yield from update_ui(chatbot=chatbot, history=history)
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
@@ -138,9 +138,7 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
|
|||||||
app_block.is_sagemaker = False
|
app_block.is_sagemaker = False
|
||||||
|
|
||||||
gradio_app = App.create_app(app_block)
|
gradio_app = App.create_app(app_block)
|
||||||
for route in list(gradio_app.router.routes):
|
|
||||||
if route.path == "/proxy={url_path:path}":
|
|
||||||
gradio_app.router.routes.remove(route)
|
|
||||||
# --- --- replace gradio endpoint to forbid access to sensitive files --- ---
|
# --- --- replace gradio endpoint to forbid access to sensitive files --- ---
|
||||||
if len(AUTHENTICATION) > 0:
|
if len(AUTHENTICATION) > 0:
|
||||||
dependencies = []
|
dependencies = []
|
||||||
@@ -156,13 +154,9 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
|
|||||||
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
|
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
|
||||||
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
|
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
|
||||||
async def file(path_or_url: str, request: fastapi.Request):
|
async def file(path_or_url: str, request: fastapi.Request):
|
||||||
if not _authorize_user(path_or_url, request, gradio_app):
|
if len(AUTHENTICATION) > 0:
|
||||||
return "越权访问!"
|
if not _authorize_user(path_or_url, request, gradio_app):
|
||||||
stripped = path_or_url.lstrip().lower()
|
return "越权访问!"
|
||||||
if stripped.startswith("https://") or stripped.startswith("http://"):
|
|
||||||
return "账户密码授权模式下, 禁止链接!"
|
|
||||||
if '../' in stripped:
|
|
||||||
return "非法路径!"
|
|
||||||
return await endpoint(path_or_url, request)
|
return await endpoint(path_or_url, request)
|
||||||
|
|
||||||
from fastapi import Request, status
|
from fastapi import Request, status
|
||||||
@@ -173,26 +167,6 @@ def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SS
|
|||||||
response.delete_cookie('access-token')
|
response.delete_cookie('access-token')
|
||||||
response.delete_cookie('access-token-unsecure')
|
response.delete_cookie('access-token-unsecure')
|
||||||
return response
|
return response
|
||||||
else:
|
|
||||||
dependencies = []
|
|
||||||
endpoint = None
|
|
||||||
for route in list(gradio_app.router.routes):
|
|
||||||
if route.path == "/file/{path:path}":
|
|
||||||
gradio_app.router.routes.remove(route)
|
|
||||||
if route.path == "/file={path_or_url:path}":
|
|
||||||
dependencies = route.dependencies
|
|
||||||
endpoint = route.endpoint
|
|
||||||
gradio_app.router.routes.remove(route)
|
|
||||||
@gradio_app.get("/file/{path:path}", dependencies=dependencies)
|
|
||||||
@gradio_app.head("/file={path_or_url:path}", dependencies=dependencies)
|
|
||||||
@gradio_app.get("/file={path_or_url:path}", dependencies=dependencies)
|
|
||||||
async def file(path_or_url: str, request: fastapi.Request):
|
|
||||||
stripped = path_or_url.lstrip().lower()
|
|
||||||
if stripped.startswith("https://") or stripped.startswith("http://"):
|
|
||||||
return "账户密码授权模式下, 禁止链接!"
|
|
||||||
if '../' in stripped:
|
|
||||||
return "非法路径!"
|
|
||||||
return await endpoint(path_or_url, request)
|
|
||||||
|
|
||||||
# --- --- enable TTS (text-to-speech) functionality --- ---
|
# --- --- enable TTS (text-to-speech) functionality --- ---
|
||||||
TTS_TYPE = get_conf("TTS_TYPE")
|
TTS_TYPE = get_conf("TTS_TYPE")
|
||||||
|
|||||||
@@ -104,27 +104,17 @@ def extract_archive(file_path, dest_dir):
|
|||||||
logger.info("Successfully extracted zip archive to {}".format(dest_dir))
|
logger.info("Successfully extracted zip archive to {}".format(dest_dir))
|
||||||
|
|
||||||
elif file_extension in [".tar", ".gz", ".bz2"]:
|
elif file_extension in [".tar", ".gz", ".bz2"]:
|
||||||
try:
|
with tarfile.open(file_path, "r:*") as tarobj:
|
||||||
with tarfile.open(file_path, "r:*") as tarobj:
|
# 清理提取路径,移除任何不安全的元素
|
||||||
# 清理提取路径,移除任何不安全的元素
|
for member in tarobj.getmembers():
|
||||||
for member in tarobj.getmembers():
|
member_path = os.path.normpath(member.name)
|
||||||
member_path = os.path.normpath(member.name)
|
full_path = os.path.join(dest_dir, member_path)
|
||||||
full_path = os.path.join(dest_dir, member_path)
|
full_path = os.path.abspath(full_path)
|
||||||
full_path = os.path.abspath(full_path)
|
if not full_path.startswith(os.path.abspath(dest_dir) + os.sep):
|
||||||
if not full_path.startswith(os.path.abspath(dest_dir) + os.sep):
|
raise Exception(f"Attempted Path Traversal in {member.name}")
|
||||||
raise Exception(f"Attempted Path Traversal in {member.name}")
|
|
||||||
|
|
||||||
tarobj.extractall(path=dest_dir)
|
tarobj.extractall(path=dest_dir)
|
||||||
logger.info("Successfully extracted tar archive to {}".format(dest_dir))
|
logger.info("Successfully extracted tar archive to {}".format(dest_dir))
|
||||||
except tarfile.ReadError as e:
|
|
||||||
if file_extension == ".gz":
|
|
||||||
# 一些特别奇葩的项目,是一个gz文件,里面不是tar,只有一个tex文件
|
|
||||||
import gzip
|
|
||||||
with gzip.open(file_path, 'rb') as f_in:
|
|
||||||
with open(os.path.join(dest_dir, 'main.tex'), 'wb') as f_out:
|
|
||||||
f_out.write(f_in.read())
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# 第三方库,需要预先pip install rarfile
|
# 第三方库,需要预先pip install rarfile
|
||||||
# 此外,Windows上还需要安装winrar软件,配置其Path环境变量,如"C:\Program Files\WinRAR"才可以
|
# 此外,Windows上还需要安装winrar软件,配置其Path环境变量,如"C:\Program Files\WinRAR"才可以
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ openai_regex = re.compile(
|
|||||||
r"sk-[a-zA-Z0-9_-]{92}$|" +
|
r"sk-[a-zA-Z0-9_-]{92}$|" +
|
||||||
r"sk-proj-[a-zA-Z0-9_-]{48}$|"+
|
r"sk-proj-[a-zA-Z0-9_-]{48}$|"+
|
||||||
r"sk-proj-[a-zA-Z0-9_-]{124}$|"+
|
r"sk-proj-[a-zA-Z0-9_-]{124}$|"+
|
||||||
r"sk-proj-[a-zA-Z0-9_-]{156}$|"+ #新版apikey位数不匹配故修改此正则表达式
|
|
||||||
r"sess-[a-zA-Z0-9]{40}$"
|
r"sess-[a-zA-Z0-9]{40}$"
|
||||||
)
|
)
|
||||||
def is_openai_api_key(key):
|
def is_openai_api_key(key):
|
||||||
|
|||||||
@@ -1029,7 +1029,7 @@ def check_repeat_upload(new_pdf_path, pdf_hash):
|
|||||||
# 如果所有页的内容都相同,返回 True
|
# 如果所有页的内容都相同,返回 True
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
def log_chat(llm_model: str, input_str: str, output_str: str, user_name: str=default_user_name):
|
def log_chat(llm_model: str, input_str: str, output_str: str):
|
||||||
try:
|
try:
|
||||||
if output_str and input_str and llm_model:
|
if output_str and input_str and llm_model:
|
||||||
uid = str(uuid.uuid4().hex)
|
uid = str(uuid.uuid4().hex)
|
||||||
@@ -1038,8 +1038,8 @@ def log_chat(llm_model: str, input_str: str, output_str: str, user_name: str=def
|
|||||||
logger.bind(chat_msg=True).info(dedent(
|
logger.bind(chat_msg=True).info(dedent(
|
||||||
"""
|
"""
|
||||||
╭──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
|
╭──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
|
||||||
[UID/USER]
|
[UID]
|
||||||
{uid}/{user_name}
|
{uid}
|
||||||
[Model]
|
[Model]
|
||||||
{llm_model}
|
{llm_model}
|
||||||
[Query]
|
[Query]
|
||||||
@@ -1047,6 +1047,6 @@ def log_chat(llm_model: str, input_str: str, output_str: str, user_name: str=def
|
|||||||
[Response]
|
[Response]
|
||||||
{output_str}
|
{output_str}
|
||||||
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
|
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
|
||||||
""").format(uid=uid, user_name=user_name, llm_model=llm_model, input_str=input_str, output_str=output_str))
|
""").format(uid=uid, llm_model=llm_model, input_str=input_str, output_str=output_str))
|
||||||
except:
|
except:
|
||||||
logger.error(trimmed_format_exc())
|
logger.error(trimmed_format_exc())
|
||||||
|
|||||||
Reference in New Issue
Block a user