diff --git a/.gitignore b/.gitignore
index 01cc7985..f433f9b4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,3 +163,5 @@ objdump*
 TODO
 experimental_mods
 search_results
+gg.docx
+unstructured_reader.py
diff --git a/Dockerfile b/Dockerfile
index 57646eaa..c54dcc79 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,7 +11,7 @@ RUN echo '[global]' > /etc/pip.conf && \
     echo 'index-url = https://mirrors.aliyun.com/pypi/simple/' >> /etc/pip.conf && \
     echo 'trusted-host = mirrors.aliyun.com' >> /etc/pip.conf
 
-# 语音输出功能(以下1,2行更换阿里源,第3,4行安装ffmpeg,都可以删除)
+# 语音输出功能(以下1,2行更换阿里源,第3,4行安装ffmpeg,都可以删除)
 RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources && \
     sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources && \
     apt-get update
@@ -34,5 +34,7 @@ RUN uv venv --python=3.12 && uv pip install -r requirements.txt -i https://mirro
 #
 # 非必要步骤,用于预热模块(可以删除)
 RUN python -c 'from check_proxy import warm_up_modules; warm_up_modules()'
+ENV CGO_ENABLED=0
+
 # 启动(必要)
 CMD ["bash", "-c", "python main.py"]
diff --git a/check_proxy.py b/check_proxy.py
index 6124a6ef..0acebbbc 100644
--- a/check_proxy.py
+++ b/check_proxy.py
@@ -230,6 +230,48 @@ def warm_up_modules():
     enc.encode("模块预热", disallowed_special=())
     enc = model_info["gpt-4"]['tokenizer']
     enc.encode("模块预热", disallowed_special=())
+    try_warm_up_vectordb()
+
+
+# def try_warm_up_vectordb():
+#     try:
+#         import os
+#         import nltk
+#         target = os.path.expanduser('~/nltk_data')
+#         logger.info(f'模块预热: nltk punkt (从Github下载部分文件到 {target})')
+#         nltk.data.path.append(target)
+#         nltk.download('punkt', download_dir=target)
+#         logger.info('模块预热完成: nltk punkt')
+#     except:
+#         logger.exception('模块预热: nltk punkt 失败,可能需要手动安装 nltk punkt')
+#         logger.error('模块预热: nltk punkt 失败,可能需要手动安装 nltk punkt')
+
+
+def try_warm_up_vectordb():
+    import os
+    import nltk
+    target = os.path.expanduser('~/nltk_data')
+    nltk.data.path.append(target)
+    try:
+        # 尝试加载 punkt(nltk.data.find 未命中时抛出 LookupError)
+        logger.info('nltk模块预热')
+        nltk.data.find('tokenizers/punkt')
+        nltk.data.find('tokenizers/punkt_tab')
+        nltk.data.find('taggers/averaged_perceptron_tagger_eng')
+        logger.info('nltk模块预热完成(读取本地缓存)')
+    except LookupError:
+        # 如果找不到,则尝试下载
+        try:
+            logger.info(f'模块预热: nltk punkt (从 Github 下载部分文件到 {target})')
+            from shared_utils.nltk_downloader import Downloader
+            _downloader = Downloader()
+            _downloader.download('punkt', download_dir=target)
+            _downloader.download('punkt_tab', download_dir=target)
+            _downloader.download('averaged_perceptron_tagger_eng', download_dir=target)
+            logger.info('nltk模块预热完成')
+        except Exception:
+            logger.exception('模块预热: nltk punkt 失败,可能需要手动安装 nltk punkt')
+
 
 def warm_up_vectordb():
     """
diff --git a/config.py b/config.py
index 4af8ffd3..acc1e94a 100644
--- a/config.py
+++ b/config.py
@@ -43,7 +43,7 @@ AVAIL_LLM_MODELS = ["qwen-max", "o1-mini", "o1-mini-2024-09-12", "o1", "o1-2024-
                     "gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k",
                     "gpt-3.5-turbo", "azure-gpt-3.5", "gpt-4", "gpt-4-32k", "azure-gpt-4",
                     "glm-4", "glm-4v", "glm-3-turbo", "gemini-1.5-pro", "chatglm3", "chatglm4",
-                    "deepseek-chat", "deepseek-coder", "deepseek-reasoner", 
+                    "deepseek-chat", "deepseek-coder", "deepseek-reasoner",
                     "volcengine-deepseek-r1-250120", "volcengine-deepseek-v3-241226",
                     "dashscope-deepseek-r1", "dashscope-deepseek-v3", "dashscope-qwen3-14b",
                     "dashscope-qwen3-235b-a22b", "dashscope-qwen3-32b",
@@ -94,19 +94,19 @@ AVAIL_THEMES = ["Default", "Chuanhu-Small-and-Beautiful", "High-Contrast", "Gsta
 FONT = "Theme-Default-Font"
 AVAIL_FONTS = [
-    "默认值(Theme-Default-Font)",
-    "宋体(SimSun)",
-
"黑体(SimHei)", - "楷体(KaiTi)", - "仿宋(FangSong)", + "默认值(Theme-Default-Font)", + "宋体(SimSun)", + "黑体(SimHei)", + "楷体(KaiTi)", + "仿宋(FangSong)", "华文细黑(STHeiti Light)", - "华文楷体(STKaiti)", - "华文仿宋(STFangsong)", - "华文宋体(STSong)", - "华文中宋(STZhongsong)", - "华文新魏(STXinwei)", - "华文隶书(STLiti)", - # 备注:以下字体需要网络支持,您可以自定义任意您喜欢的字体,如下所示,需要满足的格式为 "字体昵称(字体英文真名@字体css下载链接)" + "华文楷体(STKaiti)", + "华文仿宋(STFangsong)", + "华文宋体(STSong)", + "华文中宋(STZhongsong)", + "华文新魏(STXinwei)", + "华文隶书(STLiti)", + # 备注:以下字体需要网络支持,您可以自定义任意您喜欢的字体,如下所示,需要满足的格式为 "字体昵称(字体英文真名@字体css下载链接)" "思源宋体(Source Han Serif CN VF@https://chinese-fonts-cdn.deno.dev/packages/syst/dist/SourceHanSerifCN/result.css)", "月星楷(Moon Stars Kai HW@https://chinese-fonts-cdn.deno.dev/packages/moon-stars-kai/dist/MoonStarsKaiHW-Regular/result.css)", "珠圆体(MaokenZhuyuanTi@https://chinese-fonts-cdn.deno.dev/packages/mkzyt/dist/猫啃珠圆体/result.css)", @@ -355,6 +355,10 @@ DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in ran JINA_API_KEY = "" +# SEMANTIC SCHOLAR API KEY +SEMANTIC_SCHOLAR_KEY = "" + + # 是否自动裁剪上下文长度(是否启动,默认不启动) AUTO_CONTEXT_CLIP_ENABLE = False # 目标裁剪上下文的token长度(如果超过这个长度,则会自动裁剪) diff --git a/crazy_functional.py b/crazy_functional.py index dde48786..0f0fbaf0 100644 --- a/crazy_functional.py +++ b/crazy_functional.py @@ -50,6 +50,9 @@ def get_crazy_functions(): from crazy_functions.SourceCode_Comment import 注释Python项目 from crazy_functions.SourceCode_Comment_Wrap import SourceCodeComment_Wrap from crazy_functions.VideoResource_GPT import 多媒体任务 + from crazy_functions.Document_Conversation import 批量文件询问 + from crazy_functions.Document_Conversation_Wrap import Document_Conversation_Wrap + function_plugins = { "多媒体智能体": { @@ -378,7 +381,16 @@ def get_crazy_functions(): "Info": "PDF翻译中文,并重新编译PDF | 输入参数为路径", "Function": HotReload(PDF翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用 "Class": PDF_Localize # 新一代插件需要注册Class - } + }, + "批量文件询问 (支持自定义总结各种文件)": { + "Group": "学术", + "Color": "stop", + "AsButton": False, + "AdvancedArgs": False, + "Info": "先上传文件,点击此按钮,进行提问", + "Function": HotReload(批量文件询问), + "Class": Document_Conversation_Wrap, + }, } function_plugins.update( @@ -414,8 +426,6 @@ def get_crazy_functions(): - - # -=--=- 尚未充分测试的实验性插件 & 需要额外依赖的插件 -=--=- try: from crazy_functions.下载arxiv论文翻译摘要 import 下载arxiv论文并翻译摘要 @@ -644,22 +654,21 @@ def get_crazy_functions(): logger.error(trimmed_format_exc()) logger.error("Load function plugin failed") - try: - from crazy_functions.多智能体 import 多智能体终端 - - function_plugins.update( - { - "AutoGen多智能体终端(仅供测试)": { - "Group": "智能体", - "Color": "stop", - "AsButton": False, - "Function": HotReload(多智能体终端), - } - } - ) - except: - logger.error(trimmed_format_exc()) - logger.error("Load function plugin failed") + # try: + # from crazy_functions.多智能体 import 多智能体终端 + # function_plugins.update( + # { + # "AutoGen多智能体终端(仅供测试)": { + # "Group": "智能体", + # "Color": "stop", + # "AsButton": False, + # "Function": HotReload(多智能体终端), + # } + # } + # ) + # except: + # logger.error(trimmed_format_exc()) + # logger.error("Load function plugin failed") try: from crazy_functions.互动小游戏 import 随机小游戏 @@ -696,6 +705,44 @@ def get_crazy_functions(): logger.error(trimmed_format_exc()) logger.error("Load function plugin failed") + # try: + # from crazy_functions.Document_Optimize import 自定义智能文档处理 + # function_plugins.update( + # { + # "一键处理文档(支持自定义全文润色、降重等)": { + # "Group": "学术", + # "Color": "stop", + # "AsButton": False, + # "AdvancedArgs": True, + # "ArgsReminder": "请输入处理指令和要求(可以详细描述),如:请帮我润色文本,要求幽默点。默认调用润色指令。", + # "Info": 
"保留文档结构,智能处理文档内容 | 输入参数为文件路径", + # "Function": HotReload(自定义智能文档处理) + # }, + # } + # ) + # except: + # logger.error(trimmed_format_exc()) + # logger.error("Load function plugin failed") + + + + # try: + # from crazy_functions.Paper_Reading import 快速论文解读 + # function_plugins.update( + # { + # "速读论文": { + # "Group": "学术", + # "Color": "stop", + # "AsButton": False, + # "Info": "上传一篇论文进行快速分析和解读 | 输入参数为论文路径或DOI/arXiv ID", + # "Function": HotReload(快速论文解读), + # }, + # } + # ) + # except: + # logger.error(trimmed_format_exc()) + # logger.error("Load function plugin failed") + # try: # from crazy_functions.高级功能函数模板 import 测试图表渲染 @@ -744,12 +791,12 @@ def get_multiplex_button_functions(): "查互联网后回答": "查互联网后回答", - "多模型对话": + "多模型对话": "询问多个GPT模型", # 映射到上面的 `询问多个GPT模型` 插件 - "智能召回 RAG": + "智能召回 RAG": "Rag智能召回", # 映射到上面的 `Rag智能召回` 插件 - "多媒体查询": + "多媒体查询": "多媒体智能体", # 映射到上面的 `多媒体智能体` 插件 } diff --git a/crazy_functions/Academic_Conversation.py b/crazy_functions/Academic_Conversation.py new file mode 100644 index 00000000..99be6a6a --- /dev/null +++ b/crazy_functions/Academic_Conversation.py @@ -0,0 +1,290 @@ +import re +import os +import asyncio +from typing import List, Dict, Tuple +from dataclasses import dataclass +from textwrap import dedent +from toolbox import CatchException, get_conf, update_ui, promote_file_to_downloadzone, get_log_folder, get_user +from toolbox import update_ui, CatchException, report_exception, write_history_to_file +from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource +from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource +from crazy_functions.review_fns.query_analyzer import QueryAnalyzer +from crazy_functions.review_fns.handlers.review_handler import 文献综述功能 +from crazy_functions.review_fns.handlers.recommend_handler import 论文推荐功能 +from crazy_functions.review_fns.handlers.qa_handler import 学术问答功能 +from crazy_functions.review_fns.handlers.paper_handler import 单篇论文分析功能 +from crazy_functions.Conversation_To_File import write_chat_to_file +from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive +from crazy_functions.review_fns.handlers.latest_handler import Arxiv最新论文推荐功能 +from datetime import datetime + +@CatchException +def 学术对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, + history: List, system_prompt: str, user_request: str): + """主函数""" + + # 初始化数据源 + arxiv_source = ArxivSource() + semantic_source = SemanticScholarSource( + api_key=get_conf("SEMANTIC_SCHOLAR_KEY") + ) + + # 初始化处理器 + handlers = { + "review": 文献综述功能(arxiv_source, semantic_source, llm_kwargs), + "recommend": 论文推荐功能(arxiv_source, semantic_source, llm_kwargs), + "qa": 学术问答功能(arxiv_source, semantic_source, llm_kwargs), + "paper": 单篇论文分析功能(arxiv_source, semantic_source, llm_kwargs), + "latest": Arxiv最新论文推荐功能(arxiv_source, semantic_source, llm_kwargs), + } + + # 分析查询意图 + chatbot.append([None, "正在分析研究主题和查询要求..."]) + yield from update_ui(chatbot=chatbot, history=history) + + query_analyzer = QueryAnalyzer() + search_criteria = yield from query_analyzer.analyze_query(txt, chatbot, llm_kwargs) + handler = handlers.get(search_criteria.query_type) + if not handler: + handler = handlers["qa"] # 默认使用QA处理器 + + # 处理查询 + chatbot.append([None, f"使用{handler.__class__.__name__}处理...,可能需要您耐心等待3~5分钟..."]) + yield from update_ui(chatbot=chatbot, history=history) + + final_prompt = asyncio.run(handler.handle( + criteria=search_criteria, + chatbot=chatbot, + history=history, + system_prompt=system_prompt, + llm_kwargs=llm_kwargs, 
+        plugin_kwargs=plugin_kwargs
+    ))
+
+    if final_prompt:
+        # 检查是否是道歉提示
+        if "很抱歉,我们未能找到" in final_prompt:
+            chatbot.append([txt, final_prompt])
+            yield from update_ui(chatbot=chatbot, history=history)
+            return
+        # 在 final_prompt 末尾添加用户原始查询要求
+        final_prompt += dedent(f"""
+        Original user query: "{txt}"
+
+        IMPORTANT NOTE:
+        - Your response must directly address the user's original query above
+        - While following the previous guidelines, prioritize answering what the user specifically asked
+        - Make sure your response format and content align with the user's expectations
+        - Do not translate paper titles, keep them in their original language
+        - Do not generate a reference list in your response - references will be handled separately
+        """)
+
+        # 使用最终的prompt生成回答
+        response = yield from request_gpt_model_in_new_thread_with_ui_alive(
+            inputs=final_prompt,
+            inputs_show_user=txt,
+            llm_kwargs=llm_kwargs,
+            chatbot=chatbot,
+            history=[],
+            sys_prompt="You are a helpful academic assistant. Respond in Chinese by default unless a specific language is required in the user's query."
+        )
+
+        # 1. 获取文献列表
+        papers_list = handler.ranked_papers  # 直接使用原始论文数据
+
+        # 在新的对话中添加格式化的参考文献列表
+        if papers_list:
+            references = ""
+            for idx, paper in enumerate(papers_list, 1):
+                # 构建作者列表
+                authors = paper.authors[:3]
+                if len(paper.authors) > 3:
+                    authors.append("et al.")
+                authors_str = ", ".join(authors)
+
+                # 构建期刊指标信息
+                metrics = []
+                if hasattr(paper, 'if_factor') and paper.if_factor:
+                    metrics.append(f"IF: {paper.if_factor}")
+                if hasattr(paper, 'jcr_division') and paper.jcr_division:
+                    metrics.append(f"JCR: {paper.jcr_division}")
+                if hasattr(paper, 'cas_division') and paper.cas_division:
+                    metrics.append(f"中科院分区: {paper.cas_division}")
+                metrics_str = f" [{', '.join(metrics)}]" if metrics else ""
+
+                # 构建DOI链接
+                doi_link = ""
+                if paper.doi:
+                    if "arxiv.org" in str(paper.doi):
+                        doi_url = paper.doi
+                    else:
+                        doi_url = f"https://doi.org/{paper.doi}"
+                    doi_link = f" DOI: {doi_url}"
+
+                # 构建完整的引用
+                reference = f"[{idx}] {authors_str}. *{paper.title}*"
+                if paper.venue_name:
+                    reference += f". {paper.venue_name}"
+                if paper.year:
+                    reference += f", {paper.year}"
+                reference += metrics_str
+                if doi_link:
+                    reference += f".{doi_link}"
+                reference += " \n"
+
+                references += reference
+
+            # 添加新的对话显示参考文献
+            chatbot.append(["参考文献如下:", references])
+            yield from update_ui(chatbot=chatbot, history=history)
+
+        # 2. 保存为不同格式
+        from .review_fns.conversation_doc.word_doc import WordFormatter
+        from .review_fns.conversation_doc.word2pdf import WordToPdfConverter
+        from .review_fns.conversation_doc.markdown_doc import MarkdownFormatter
+        from .review_fns.conversation_doc.html_doc import HtmlFormatter
+
+        # 创建保存目录
+        save_dir = get_log_folder(get_user(chatbot), plugin_name='chatscholar')
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+
+        # 生成文件名
+        def get_safe_filename(txt, max_length=10):
+            # 获取文本前max_length个字符作为文件名
+            filename = txt[:max_length].strip()
+            # 移除不安全的文件名字符
+            filename = re.sub(r'[\\/:*?"<>|]', '', filename)
+            # 如果文件名为空,使用时间戳
+            if not filename:
+                filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+            return filename
+
+        base_filename = get_safe_filename(txt)
+
+        result_files = []      # 收集所有生成的文件
+        pdf_path = None        # 用于跟踪PDF是否成功生成
+        result_file_md = None  # 预置为 None,避免 Markdown 保存失败时后续引用未定义变量
+
+        # 保存为Markdown
+        try:
+            md_formatter = MarkdownFormatter()
+            md_content = md_formatter.create_document(txt, response, papers_list)
+            result_file_md = write_history_to_file(
+                history=[md_content],
+                file_basename=f"markdown_{base_filename}.md"
+            )
+            result_files.append(result_file_md)
+        except Exception as e:
+            print(f"Markdown保存失败: {str(e)}")
+
+        # 保存为HTML
+        try:
+            html_formatter = HtmlFormatter()
+            html_content = html_formatter.create_document(txt, response, papers_list)
+            result_file_html = write_history_to_file(
+                history=[html_content],
+                file_basename=f"html_{base_filename}.html"
+            )
+            result_files.append(result_file_html)
+        except Exception as e:
+            print(f"HTML保存失败: {str(e)}")
+
+        # 保存为Word
+        try:
+            word_formatter = WordFormatter()
+            try:
+                doc = word_formatter.create_document(txt, response, papers_list)
+            except Exception as e:
+                print(f"Word文档内容生成失败: {str(e)}")
+                raise e
+
+            try:
+                result_file_docx = os.path.join(
+                    os.path.dirname(result_file_md) if result_file_md else save_dir,
+                    f"docx_{base_filename}.docx"
+                )
+                doc.save(result_file_docx)
+                result_files.append(result_file_docx)
+                print(f"Word文档已保存到: {result_file_docx}")
+
+                # 转换为PDF
+                try:
+                    pdf_path = WordToPdfConverter.convert_to_pdf(result_file_docx)
+                    if pdf_path:
+                        result_files.append(pdf_path)
+                        print(f"PDF文档已生成: {pdf_path}")
+                except Exception as e:
+                    print(f"PDF转换失败: {str(e)}")
+
+            except Exception as e:
+                print(f"Word文档保存失败: {str(e)}")
+                raise e
+
+        except Exception as e:
+            print(f"Word格式化失败: {str(e)}")
+            import traceback
+            print(f"详细错误信息: {traceback.format_exc()}")
+
+        # 保存为BibTeX格式
+        try:
+            from .review_fns.conversation_doc.reference_formatter import ReferenceFormatter
+            ref_formatter = ReferenceFormatter()
+            bibtex_content = ref_formatter.create_document(papers_list)
+
+            # 在与其他文件相同目录下创建BibTeX文件
+            result_file_bib = os.path.join(
+                os.path.dirname(result_file_md) if result_file_md else save_dir,
+                f"references_{base_filename}.bib"
+            )
+
+            # 直接写入文件
+            with open(result_file_bib, 'w', encoding='utf-8') as f:
+                f.write(bibtex_content)
+
+            result_files.append(result_file_bib)
+            print(f"BibTeX文件已保存到: {result_file_bib}")
+        except Exception as e:
+            print(f"BibTeX格式保存失败: {str(e)}")
+
+        # 保存为EndNote格式
+        try:
+            from .review_fns.conversation_doc.endnote_doc import EndNoteFormatter
+            endnote_formatter = EndNoteFormatter()
+            endnote_content = endnote_formatter.create_document(papers_list)
+
+            # 在与其他文件相同目录下创建EndNote文件
+            result_file_enw = os.path.join(
+                os.path.dirname(result_file_md) if result_file_md else save_dir,
+                f"references_{base_filename}.enw"
+            )
+
+            # 直接写入文件
+            with open(result_file_enw, 'w', encoding='utf-8') as f:
+                f.write(endnote_content)
+
+            result_files.append(result_file_enw)
+            print(f"EndNote文件已保存到: {result_file_enw}")
+        except Exception as e:
+            print(f"EndNote格式保存失败: {str(e)}")
+
+        # 添加所有文件到下载区
+        success_files = []
+        for file in result_files:
+            try:
+                promote_file_to_downloadzone(file, chatbot=chatbot)
+                success_files.append(os.path.basename(file))
+            except Exception as e:
+                print(f"文件添加到下载区失败: {str(e)}")
+
+        # 更新成功提示消息
+        if success_files:
+            chatbot.append(["保存对话记录成功,bib和enw文件支持导入到EndNote、Zotero、JabRef、Mendeley等文献管理软件,HTML文件支持在浏览器中打开,里面包含详细论文源信息", "对话已保存并添加到下载区,可以在下载区找到相关文件"])
+        else:
+            chatbot.append(["保存对话记录", "所有格式的保存都失败了,请检查错误日志。"])
+
+        yield from update_ui(chatbot=chatbot, history=history)
+    else:
+        report_exception(chatbot, history, a="处理失败", b="请尝试其他查询")
+        yield from update_ui(chatbot=chatbot, history=history)
\ No newline at end of file
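下面给出一个自包含的最小示例(非仓库代码,_StubPaper 为假设的演示桩类,字段名对应上文实际访问的 paper 属性),演示参考文献条目的拼装逻辑:作者截断为前三位并追加 "et al.",DOI 统一转成 https://doi.org/ 链接,期刊指标部分从略:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class _StubPaper:  # 假设的演示桩类,并非仓库中的真实数据结构
    title: str
    authors: List[str]
    venue_name: Optional[str] = None
    year: Optional[int] = None
    doi: Optional[str] = None

def format_reference(idx: int, paper: _StubPaper) -> str:
    authors = paper.authors[:3]  # 最多保留前三位作者
    if len(paper.authors) > 3:
        authors.append("et al.")
    ref = f"[{idx}] {', '.join(authors)}. *{paper.title}*"
    if paper.venue_name:
        ref += f". {paper.venue_name}"
    if paper.year:
        ref += f", {paper.year}"
    if paper.doi:  # arXiv 链接原样保留,其余拼成 doi.org 链接
        url = paper.doi if "arxiv.org" in str(paper.doi) else f"https://doi.org/{paper.doi}"
        ref += f". DOI: {url}"
    return ref

print(format_reference(1, _StubPaper(
    title="Attention Is All You Need",
    authors=["Vaswani", "Shazeer", "Parmar", "Uszkoreit"],
    venue_name="NeurIPS", year=2017, doi="10.48550/arXiv.1706.03762")))
# 输出: [1] Vaswani, Shazeer, Parmar, et al. *Attention Is All You Need*. NeurIPS, 2017. DOI: https://doi.org/10.48550/arXiv.1706.03762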
diff --git a/crazy_functions/Document_Conversation.py b/crazy_functions/Document_Conversation.py
new file mode 100644
index 00000000..abc800e9
--- /dev/null
+++ b/crazy_functions/Document_Conversation.py
@@ -0,0 +1,537 @@
+import os
+import threading
+import time
+from dataclasses import dataclass
+from typing import List, Tuple, Dict, Generator
+from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
+from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
+from crazy_functions.rag_fns.rag_file_support import extract_text
+from request_llms.bridge_all import model_info
+from toolbox import update_ui, CatchException, report_exception
+from shared_utils.fastapi_server import validate_path_safety
+
+
+@dataclass
+class FileFragment:
+    """文件片段数据类,用于组织处理单元"""
+    file_path: str
+    content: str
+    rel_path: str
+    fragment_index: int
+    total_fragments: int
+
+
+class BatchDocumentSummarizer:
+    """优化的文档总结器 - 批处理版本"""
+
+    def __init__(self, llm_kwargs: Dict, query: str, chatbot: List, history: List, system_prompt: str):
+        """初始化总结器"""
+        self.llm_kwargs = llm_kwargs
+        self.query = query
+        self.chatbot = chatbot
+        self.history = history
+        self.system_prompt = system_prompt
+        self.failed_files = []
+        self.file_summaries_map = {}
+        # query 即插件二级菜单传入的 advanced_arg;这里补上 self.plugin_kwargs,
+        # 供 process_files 等方法中的 self.plugin_kwargs.get("advanced_arg") 使用
+        self.plugin_kwargs = {"advanced_arg": query} if query else {}
+
+    def _get_token_limit(self) -> int:
+        """获取模型token限制"""
+        max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
+        return max_token * 3 // 4
+
+    def _create_batch_inputs(self, fragments: List[FileFragment]) -> Tuple[List, List, List]:
+        """创建批处理输入"""
+        inputs_array = []
+        inputs_show_user_array = []
+        history_array = []
+
+        for frag in fragments:
+            if self.query:
+                i_say = (f'请按照用户要求对文件内容进行处理,文件名为{os.path.basename(frag.file_path)},'
+                         f'用户要求为:{self.query}:'
+                         f'文件内容是 ```{frag.content}```')
+                i_say_show_user = (f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})')
+            else:
+                i_say = (f'请对下面的内容用中文做总结,不超过500字,文件名是{os.path.basename(frag.file_path)},'
+                         f'内容是 ```{frag.content}```')
+                i_say_show_user = f'正在处理 {frag.rel_path} (片段 {frag.fragment_index + 1}/{frag.total_fragments})'
+
+            inputs_array.append(i_say)
+            inputs_show_user_array.append(i_say_show_user)
+            history_array.append([])
+
+        return inputs_array, inputs_show_user_array, history_array
+
+    def _process_single_file_with_timeout(self, file_info: Tuple[str, str], mutable_status: List) -> List[FileFragment]:
+        """包装了超时控制的文件处理函数"""
+        # 设置超时标记:先在工作线程中捕获自身引用。
+        # Timer 回调运行在独立线程中,若在回调里再调用 threading.current_thread(),
+        # 拿到的是 Timer 线程而非工作线程,超时标记永远不会被置位
+        thread = threading.current_thread()
+        thread._timeout_occurred = False
+
+        def timeout_handler():
+            thread._timeout_occurred = True
+
+        # 设置超时时间为30秒,给予更多处理时间
+        TIMEOUT_SECONDS = 30
+        timer =
threading.Timer(TIMEOUT_SECONDS, timeout_handler) + timer.start() + + try: + fp, project_folder = file_info + fragments = [] + + # 定期检查是否超时 + def check_timeout(): + if hasattr(thread, '_timeout_occurred') and thread._timeout_occurred: + raise TimeoutError(f"处理文件 {os.path.basename(fp)} 超时({TIMEOUT_SECONDS}秒)") + + # 更新状态 + mutable_status[0] = "检查文件大小" + mutable_status[1] = time.time() + check_timeout() + + # 文件大小检查 + if os.path.getsize(fp) > self.max_file_size: + self.failed_files.append((fp, f"文件过大:超过{self.max_file_size / 1024 / 1024}MB")) + mutable_status[2] = "文件过大" + return fragments + + # 更新状态 + mutable_status[0] = "提取文件内容" + mutable_status[1] = time.time() + + # 提取内容 - 使用单独的超时控制 + content = None + extract_start_time = time.time() + try: + while True: + check_timeout() # 检查全局超时 + + # 检查提取过程是否超时(10秒) + if time.time() - extract_start_time > 10: + raise TimeoutError("文件内容提取超时(10秒)") + + try: + content = extract_text(fp) + break + except Exception as e: + if "timeout" in str(e).lower(): + continue # 如果是临时超时,重试 + raise # 其他错误直接抛出 + + except Exception as e: + self.failed_files.append((fp, f"文件读取失败:{str(e)}")) + mutable_status[2] = "读取失败" + return fragments + + if content is None: + self.failed_files.append((fp, "文件解析失败:不支持的格式或文件损坏")) + mutable_status[2] = "格式不支持" + return fragments + elif not content.strip(): + self.failed_files.append((fp, "文件内容为空")) + mutable_status[2] = "内容为空" + return fragments + + check_timeout() + + # 更新状态 + mutable_status[0] = "分割文本" + mutable_status[1] = time.time() + + # 分割文本 - 添加超时检查 + split_start_time = time.time() + try: + while True: + check_timeout() # 检查全局超时 + + # 检查分割过程是否超时(5秒) + if time.time() - split_start_time > 5: + raise TimeoutError("文本分割超时(5秒)") + + paper_fragments = breakdown_text_to_satisfy_token_limit( + txt=content, + limit=self._get_token_limit(), + llm_model=self.llm_kwargs['llm_model'] + ) + break + + except Exception as e: + self.failed_files.append((fp, f"文本分割失败:{str(e)}")) + mutable_status[2] = "分割失败" + return fragments + + # 处理片段 + rel_path = os.path.relpath(fp, project_folder) + for i, frag in enumerate(paper_fragments): + check_timeout() # 每处理一个片段检查一次超时 + if frag.strip(): + fragments.append(FileFragment( + file_path=fp, + content=frag, + rel_path=rel_path, + fragment_index=i, + total_fragments=len(paper_fragments) + )) + + mutable_status[2] = "处理完成" + return fragments + + except TimeoutError as e: + self.failed_files.append((fp, str(e))) + mutable_status[2] = "处理超时" + return [] + except Exception as e: + self.failed_files.append((fp, f"处理失败:{str(e)}")) + mutable_status[2] = "处理异常" + return [] + finally: + timer.cancel() + + def prepare_fragments(self, project_folder: str, file_paths: List[str]) -> Generator: + import concurrent.futures + + from concurrent.futures import ThreadPoolExecutor + from typing import Generator, List + """并行准备所有文件的处理片段""" + all_fragments = [] + total_files = len(file_paths) + + # 配置参数 + self.refresh_interval = 0.2 # UI刷新间隔 + self.watch_dog_patience = 5 # 看门狗超时时间 + self.max_file_size = 10 * 1024 * 1024 # 10MB限制 + self.max_workers = min(32, len(file_paths)) # 最多32个线程 + + # 创建有超时控制的线程池 + executor = ThreadPoolExecutor(max_workers=self.max_workers) + + # 用于跨线程状态传递的可变列表 - 增加文件名信息 + mutable_status_array = [["等待中", time.time(), "pending", file_path] for file_path in file_paths] + + # 创建文件处理任务 + file_infos = [(fp, project_folder) for fp in file_paths] + + # 提交所有任务,使用带超时控制的处理函数 + futures = [ + executor.submit( + self._process_single_file_with_timeout, + file_info, + mutable_status_array[i] + ) for i, file_info in 
enumerate(file_infos) + ] + + # 更新UI的计数器 + cnt = 0 + + try: + # 监控任务执行 + while True: + time.sleep(self.refresh_interval) + cnt += 1 + + # 检查任务完成状态 + worker_done = [f.done() for f in futures] + + # 更新状态显示 + status_str = "" + for i, (status, timestamp, desc, file_path) in enumerate(mutable_status_array): + # 获取文件名(去掉路径) + file_name = os.path.basename(file_path) + if worker_done[i]: + status_str += f"文件 {file_name}: {desc}\n\n" + else: + status_str += f"文件 {file_name}: {status} {desc}\n\n" + + # 更新UI + self.chatbot[-1] = [ + "处理进度", + f"正在处理文件...\n\n{status_str}" + "." * (cnt % 10 + 1) + ] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 检查是否所有任务完成 + if all(worker_done): + break + + finally: + # 确保线程池正确关闭 + executor.shutdown(wait=False) + + # 收集结果 + processed_files = 0 + for future in futures: + try: + fragments = future.result(timeout=0.1) # 给予一个短暂的超时时间来获取结果 + all_fragments.extend(fragments) + processed_files += 1 + except concurrent.futures.TimeoutError: + # 处理获取结果超时 + file_index = futures.index(future) + self.failed_files.append((file_paths[file_index], "结果获取超时")) + continue + except Exception as e: + # 处理其他异常 + file_index = futures.index(future) + self.failed_files.append((file_paths[file_index], f"未知错误:{str(e)}")) + continue + + # 最终进度更新 + self.chatbot.append([ + "文件处理完成", + f"成功处理 {len(all_fragments)} 个片段,失败 {len(self.failed_files)} 个文件" + ]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + return all_fragments + + def _process_fragments_batch(self, fragments: List[FileFragment]) -> Generator: + """批量处理文件片段""" + from collections import defaultdict + batch_size = 64 # 每批处理的片段数 + max_retries = 3 # 最大重试次数 + retry_delay = 5 # 重试延迟(秒) + results = defaultdict(list) + + # 按批次处理 + for i in range(0, len(fragments), batch_size): + batch = fragments[i:i + batch_size] + + inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch) + sys_prompt_array = ["请总结以下内容:"] * len(batch) + + # 添加重试机制 + for retry in range(max_retries): + try: + response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency( + inputs_array=inputs_array, + inputs_show_user_array=inputs_show_user_array, + llm_kwargs=self.llm_kwargs, + chatbot=self.chatbot, + history_array=history_array, + sys_prompt_array=sys_prompt_array, + ) + + # 处理响应 + for j, frag in enumerate(batch): + summary = response_collection[j * 2 + 1] + if summary and summary.strip(): + results[frag.rel_path].append({ + 'index': frag.fragment_index, + 'summary': summary, + 'total': frag.total_fragments + }) + break # 成功处理,跳出重试循环 + + except Exception as e: + if retry == max_retries - 1: # 最后一次重试失败 + for frag in batch: + self.failed_files.append((frag.file_path, f"处理失败:{str(e)}")) + else: + yield from update_ui(self.chatbot.append([f"批次处理失败,{retry_delay}秒后重试...", str(e)])) + time.sleep(retry_delay) + + return results + + def _generate_final_summary_request(self) -> Tuple[List, List, List]: + """准备最终总结请求""" + if not self.file_summaries_map: + return (["无可用的文件总结"], ["生成最终总结"], [[]]) + + summaries = list(self.file_summaries_map.values()) + if all(not summary for summary in summaries): + return (["所有文件处理均失败"], ["生成最终总结"], [[]]) + + if self.plugin_kwargs.get("advanced_arg"): + i_say = "根据以上所有文件的处理结果,按要求进行综合处理:" + self.plugin_kwargs['advanced_arg'] + else: + i_say = "请根据以上所有文件的处理结果,生成最终的总结,不超过1000字。" + + return ([i_say], [i_say], [summaries]) + + def process_files(self, project_folder: str, file_paths: List[str]) -> Generator: + """处理所有文件""" + total_files = 
len(file_paths) + self.chatbot.append([f"开始处理", f"总计 {total_files} 个文件"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 1. 准备所有文件片段 + # 在 process_files 函数中: + fragments = yield from self.prepare_fragments(project_folder, file_paths) + if not fragments: + self.chatbot.append(["处理失败", "没有可处理的文件内容"]) + return "没有可处理的文件内容" + + # 2. 批量处理所有文件片段 + self.chatbot.append([f"文件分析", f"共计 {len(fragments)} 个处理单元"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + try: + file_summaries = yield from self._process_fragments_batch(fragments) + except Exception as e: + self.chatbot.append(["处理错误", f"批处理过程失败:{str(e)}"]) + return "处理过程发生错误" + + # 3. 为每个文件生成整体总结 + self.chatbot.append(["生成总结", "正在汇总文件内容..."]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 处理每个文件的总结 + for rel_path, summaries in file_summaries.items(): + if len(summaries) > 1: # 多片段文件需要生成整体总结 + sorted_summaries = sorted(summaries, key=lambda x: x['index']) + if self.plugin_kwargs.get("advanced_arg"): + + i_say = f'请按照用户要求对文件内容进行处理,用户要求为:{self.plugin_kwargs["advanced_arg"]}:' + else: + i_say = f"请总结文件 {os.path.basename(rel_path)} 的主要内容,不超过500字。" + + try: + summary_texts = [s['summary'] for s in sorted_summaries] + response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency( + inputs_array=[i_say], + inputs_show_user_array=[f"生成 {rel_path} 的处理结果"], + llm_kwargs=self.llm_kwargs, + chatbot=self.chatbot, + history_array=[summary_texts], + sys_prompt_array=["你是一个优秀的助手,"], + ) + self.file_summaries_map[rel_path] = response_collection[1] + except Exception as e: + self.chatbot.append(["警告", f"文件 {rel_path} 总结生成失败:{str(e)}"]) + self.file_summaries_map[rel_path] = "总结生成失败" + else: # 单片段文件直接使用其唯一的总结 + self.file_summaries_map[rel_path] = summaries[0]['summary'] + + # 4. 
生成最终总结 + if total_files == 1: + return "文件数为1,此时不调用总结模块" + else: + try: + # 收集所有文件的总结用于生成最终总结 + file_summaries_for_final = [] + for rel_path, summary in self.file_summaries_map.items(): + file_summaries_for_final.append(f"文件 {rel_path} 的总结:\n{summary}") + + if self.plugin_kwargs.get("advanced_arg"): + final_summary_prompt = ("根据以下所有文件的总结内容,按要求进行综合处理:" + + self.plugin_kwargs['advanced_arg']) + else: + final_summary_prompt = "请根据以下所有文件的总结内容,生成最终的总结报告。" + + response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency( + inputs_array=[final_summary_prompt], + inputs_show_user_array=["生成最终总结报告"], + llm_kwargs=self.llm_kwargs, + chatbot=self.chatbot, + history_array=[file_summaries_for_final], + sys_prompt_array=["总结所有文件内容。"], + max_workers=1 + ) + + return response_collection[1] if len(response_collection) > 1 else "生成总结失败" + except Exception as e: + self.chatbot.append(["错误", f"最终总结生成失败:{str(e)}"]) + return "生成总结失败" + + def save_results(self, final_summary: str): + """保存结果到文件""" + from toolbox import promote_file_to_downloadzone, write_history_to_file + from crazy_functions.doc_fns.batch_file_query_doc import MarkdownFormatter, HtmlFormatter, WordFormatter + import os + timestamp = time.strftime("%Y%m%d_%H%M%S") + + # 创建各种格式化器 + md_formatter = MarkdownFormatter(final_summary, self.file_summaries_map, self.failed_files) + html_formatter = HtmlFormatter(final_summary, self.file_summaries_map, self.failed_files) + word_formatter = WordFormatter(final_summary, self.file_summaries_map, self.failed_files) + + result_files = [] + + # 保存 Markdown + try: + md_content = md_formatter.create_document() + result_file_md = write_history_to_file( + history=[md_content], # 直接传入内容列表 + file_basename=f"文档总结_{timestamp}.md" + ) + result_files.append(result_file_md) + except: + pass + + # 保存 HTML + try: + html_content = html_formatter.create_document() + result_file_html = write_history_to_file( + history=[html_content], + file_basename=f"文档总结_{timestamp}.html" + ) + result_files.append(result_file_html) + except: + pass + + # 保存 Word + try: + doc = word_formatter.create_document() + # 由于 Word 文档需要用 doc.save(),我们使用与 md 文件相同的目录 + result_file_docx = os.path.join( + os.path.dirname(result_file_md), + f"文档总结_{timestamp}.docx" + ) + doc.save(result_file_docx) + result_files.append(result_file_docx) + except: + pass + + # 添加到下载区 + for file in result_files: + promote_file_to_downloadzone(file, chatbot=self.chatbot) + + self.chatbot.append(["处理完成", f"结果已保存至: {', '.join(result_files)}"]) + + +@CatchException +def 批量文件询问(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, + history: List, system_prompt: str, user_request: str): + """主函数 - 优化版本""" + # 初始化 + import glob + import re + from crazy_functions.rag_fns.rag_file_support import supports_format + from toolbox import report_exception + query = plugin_kwargs.get("advanced_arg") + summarizer = BatchDocumentSummarizer(llm_kwargs, query, chatbot, history, system_prompt) + chatbot.append(["函数插件功能", f"作者:lbykkkk,批量总结文件。支持格式: {', '.join(supports_format)}等其他文本格式文件,如果长时间卡在文件处理过程,请查看处理进度,然后删除所有处于“pending”状态的文件,然后重新上传处理。"]) + yield from update_ui(chatbot=chatbot, history=history) + + # 验证输入路径 + if not os.path.exists(txt): + report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到项目或无权访问: {txt}") + yield from update_ui(chatbot=chatbot, history=history) + return + + # 获取文件列表 + project_folder = txt + user_name = chatbot.get_user() + validate_path_safety(project_folder, user_name) + extract_folder = next((d for d in 
glob.glob(f'{project_folder}/*')
+                              if os.path.isdir(d) and d.endswith('.extract')), project_folder)
+    exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$'
+    file_manifest = [f for f in glob.glob(f'{extract_folder}/**', recursive=True)
+                     if os.path.isfile(f) and not re.search(exclude_patterns, f)]
+
+    if not file_manifest:
+        report_exception(chatbot, history, a=f"解析项目: {txt}", b="未找到支持的文件类型")
+        yield from update_ui(chatbot=chatbot, history=history)
+        return
+
+    # 处理所有文件并生成总结
+    final_summary = yield from summarizer.process_files(project_folder, file_manifest)
+    yield from update_ui(chatbot=chatbot, history=history)
+
+    # 保存结果
+    summarizer.save_results(final_summary)
+    yield from update_ui(chatbot=chatbot, history=history)
\ No newline at end of file
diff --git a/crazy_functions/Document_Conversation_Wrap.py b/crazy_functions/Document_Conversation_Wrap.py
new file mode 100644
index 00000000..6555bc19
--- /dev/null
+++ b/crazy_functions/Document_Conversation_Wrap.py
@@ -0,0 +1,36 @@
+from crazy_functions.Document_Conversation import 批量文件询问
+from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
+
+
+class Document_Conversation_Wrap(GptAcademicPluginTemplate):
+    def __init__(self):
+        """
+        请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
+        """
+        pass
+
+    def define_arg_selection_menu(self):
+        """
+        定义插件的二级选项菜单
+
+        第一个参数,名称`main_input`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
+        第二个参数,名称`advanced_arg`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
+        """
+        gui_definition = {
+            "main_input":
+                ArgProperty(title="已上传的文件", description="上传文件后自动填充", default_value="", type="string").model_dump_json(),
+            "advanced_arg":
+                ArgProperty(title="对材料提问", description="提问", default_value="", type="string").model_dump_json(),  # 键名需为 advanced_arg,与 批量文件询问 中 plugin_kwargs.get("advanced_arg") 对应
+        }
+        return gui_definition
+
+    def execute(txt, llm_kwargs, plugin_kwargs: dict, chatbot, history, system_prompt, user_request):
+        """
+        执行插件
+        """
+        yield from 批量文件询问(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
+
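下面是一个极简的示意(非仓库代码,demo_execute 为假设的演示函数),说明二级菜单键名与插件参数的对应关系:假设框架按 gui_definition 的键名填充 plugin_kwargs(与仓库中其他 Wrap 插件的约定一致),提问文本框必须命名为 advanced_arg,批量文件询问 内部才能通过 plugin_kwargs.get("advanced_arg") 取到提问内容:

def demo_execute(plugin_kwargs: dict) -> str:
    # 与 Document_Conversation.批量文件询问 中的读取方式保持一致
    query = plugin_kwargs.get("advanced_arg")
    return query or "(未填写提问,走默认的全文总结逻辑)"

print(demo_execute({"main_input": "/path/to/uploads", "advanced_arg": "总结每个文件的核心观点"}))
print(demo_execute({"main_input": "/path/to/uploads"}))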
diff --git a/crazy_functions/Document_Optimize.py b/crazy_functions/Document_Optimize.py
new file mode 100644
index 00000000..c7728afe
--- /dev/null
+++ b/crazy_functions/Document_Optimize.py
@@ -0,0 +1,673 @@
+import os
+import time
+import glob
+import re
+import threading
+from typing import Dict, List, Generator, Tuple
+from dataclasses import dataclass
+
+from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
+from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
+from crazy_functions.rag_fns.rag_file_support import extract_text, supports_format, convert_to_markdown
+from request_llms.bridge_all import model_info
+from toolbox import update_ui, CatchException, report_exception, promote_file_to_downloadzone, write_history_to_file
+from shared_utils.fastapi_server import validate_path_safety
+
+# 新增:导入结构化论文提取器
+from crazy_functions.doc_fns.read_fns.unstructured_all.paper_structure_extractor import PaperStructureExtractor, ExtractorConfig, StructuredPaper
+
+# 导入格式化器
+from crazy_functions.paper_fns.file2file_doc import (
+    TxtFormatter,
+    MarkdownFormatter,
+    HtmlFormatter,
+    WordFormatter
+)
+
+
+@dataclass
+class TextFragment:
+    """文本片段数据类,用于组织处理单元"""
+    content: str
+    fragment_index: int
+    total_fragments: int
+
+
+class DocumentProcessor:
+    """文档处理器 - 处理单个文档并输出结果"""
+
+    def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str):
+        """初始化处理器"""
+        self.llm_kwargs = llm_kwargs
+        self.plugin_kwargs = plugin_kwargs
+        self.chatbot = chatbot
+        self.history = history
+        self.system_prompt = system_prompt
+        self.processed_results = []
+        self.failed_fragments = []
+        # 新增:初始化论文结构提取器
+        self.paper_extractor = PaperStructureExtractor()
+
+    def _get_token_limit(self) -> int:
+        """获取模型token限制,返回更小的值以确保更细粒度的分割"""
+        max_token = model_info[self.llm_kwargs['llm_model']]['max_token']
+        # 降低token限制,使每个片段更小
+        return max_token // 4  # 从3/4降低到1/4
+
+    def _create_batch_inputs(self, fragments: List[TextFragment]) -> Tuple[List, List, List]:
+        """创建批处理输入"""
+        inputs_array = []
+        inputs_show_user_array = []
+        history_array = []
+
+        user_instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下学术文本,提高其语言表达的准确性、专业性和流畅度,保持学术风格,确保逻辑连贯,但不改变原文的科学内容和核心观点")
+
+        for frag in fragments:
+            i_say = (f'请按照以下要求处理文本内容:{user_instruction}\n\n'
+                     f'请将对文本的处理结果放在<decision></decision>标签之间。\n\n'
+                     f'文本内容:\n```\n{frag.content}\n```')
+
+            i_say_show_user = f'正在处理文本片段 {frag.fragment_index + 1}/{frag.total_fragments}'
+
+            inputs_array.append(i_say)
+            inputs_show_user_array.append(i_say_show_user)
+            history_array.append([])
+
+        return inputs_array, inputs_show_user_array, history_array
+
+    def _extract_decision(self, text: str) -> str:
+        """从LLM响应中提取<decision>标签内的内容"""
+        import re
+        pattern = r'<decision>(.*?)</decision>'
+        matches = re.findall(pattern, text, re.DOTALL)
+
+        if matches:
+            return matches[0].strip()
+        else:
+            # 如果没有找到<decision>标签,返回原始文本
+            return text.strip()
+
+    def process_file(self, file_path: str) -> Generator:
+        """处理单个文件"""
+        self.chatbot.append(["开始处理文件", f"文件路径: {file_path}"])
+        yield from update_ui(chatbot=self.chatbot, history=self.history)
+
+        try:
+            # 首先尝试转换为Markdown
+            from crazy_functions.rag_fns.rag_file_support import convert_to_markdown
+            file_path = convert_to_markdown(file_path)
+
+            # 1. 检查文件是否为支持的论文格式
+            is_paper_format = any(file_path.lower().endswith(ext) for ext in self.paper_extractor.SUPPORTED_EXTENSIONS)
+
+            if is_paper_format:
+                # 使用结构化提取器处理论文
+                return (yield from self._process_structured_paper(file_path))
+            else:
+                # 使用原有方式处理普通文档
+                return (yield from self._process_regular_file(file_path))
+
+        except Exception as e:
+            self.chatbot.append(["处理错误", f"文件处理失败: {str(e)}"])
+            yield from update_ui(chatbot=self.chatbot, history=self.history)
+            return None
+
+    def _process_structured_paper(self, file_path: str) -> Generator:
+        """处理结构化论文文件"""
+        # 1.
提取论文结构 + self.chatbot[-1] = ["正在分析论文结构", f"文件路径: {file_path}"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + try: + paper = self.paper_extractor.extract_paper_structure(file_path) + + if not paper or not paper.sections: + self.chatbot.append(["无法提取论文结构", "将使用全文内容进行处理"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 使用全文内容进行段落切分 + if paper and paper.full_text: + # 使用增强的分割函数进行更细致的分割 + fragments = self._breakdown_section_content(paper.full_text) + + # 创建文本片段对象 + text_fragments = [] + for i, frag in enumerate(fragments): + if frag.strip(): + text_fragments.append(TextFragment( + content=frag, + fragment_index=i, + total_fragments=len(fragments) + )) + + # 批量处理片段 + if text_fragments: + self.chatbot[-1] = ["开始处理文本", f"共 {len(text_fragments)} 个片段"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 一次性准备所有输入 + inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(text_fragments) + + # 使用系统提示 + instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下学术文本,提高其语言表达的准确性、专业性和流畅度,保持学术风格,确保逻辑连贯,但不改变原文的科学内容和核心观点") + sys_prompt_array = [f"你是一个专业的学术文献编辑助手。请按照用户的要求:'{instruction}'处理文本。保持学术风格,增强表达的准确性和专业性。"] * len(text_fragments) + + # 调用LLM一次性处理所有片段 + response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency( + inputs_array=inputs_array, + inputs_show_user_array=inputs_show_user_array, + llm_kwargs=self.llm_kwargs, + chatbot=self.chatbot, + history_array=history_array, + sys_prompt_array=sys_prompt_array, + ) + + # 处理响应 + for j, frag in enumerate(text_fragments): + try: + llm_response = response_collection[j * 2 + 1] + processed_text = self._extract_decision(llm_response) + + if processed_text and processed_text.strip(): + self.processed_results.append({ + 'index': frag.fragment_index, + 'content': processed_text + }) + else: + self.failed_fragments.append(frag) + self.processed_results.append({ + 'index': frag.fragment_index, + 'content': frag.content + }) + except Exception as e: + self.failed_fragments.append(frag) + self.processed_results.append({ + 'index': frag.fragment_index, + 'content': frag.content + }) + + # 按原始顺序合并结果 + self.processed_results.sort(key=lambda x: x['index']) + final_content = "\n".join([item['content'] for item in self.processed_results]) + + # 更新UI + success_count = len(text_fragments) - len(self.failed_fragments) + self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{len(text_fragments)} 个片段"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + return final_content + else: + self.chatbot.append(["处理失败", "未能提取到有效的文本内容"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return None + else: + self.chatbot.append(["处理失败", "未能提取到论文内容"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return None + + # 2. 准备处理章节内容(不处理标题) + self.chatbot[-1] = ["已提取论文结构", f"共 {len(paper.sections)} 个主要章节"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 3. 
收集所有需要处理的章节内容并分割为合适大小 + sections_to_process = [] + section_map = {} # 用于映射处理前后的内容 + + def collect_section_contents(sections, parent_path=""): + """递归收集章节内容,跳过参考文献部分""" + for i, section in enumerate(sections): + current_path = f"{parent_path}/{i}" if parent_path else f"{i}" + + # 检查是否为参考文献部分,如果是则跳过 + if section.section_type == 'references' or section.title.lower() in ['references', '参考文献', 'bibliography', '文献']: + continue # 跳过参考文献部分 + + # 只处理内容非空的章节 + if section.content and section.content.strip(): + # 使用增强的分割函数进行更细致的分割 + fragments = self._breakdown_section_content(section.content) + + for fragment_idx, fragment_content in enumerate(fragments): + if fragment_content.strip(): + fragment_index = len(sections_to_process) + sections_to_process.append(TextFragment( + content=fragment_content, + fragment_index=fragment_index, + total_fragments=0 # 临时值,稍后更新 + )) + + # 保存映射关系,用于稍后更新章节内容 + # 为每个片段存储原始章节和片段索引信息 + section_map[fragment_index] = (current_path, section, fragment_idx, len(fragments)) + + # 递归处理子章节 + if section.subsections: + collect_section_contents(section.subsections, current_path) + + # 收集所有章节内容 + collect_section_contents(paper.sections) + + # 更新总片段数 + total_fragments = len(sections_to_process) + for frag in sections_to_process: + frag.total_fragments = total_fragments + + # 4. 如果没有内容需要处理,直接返回 + if not sections_to_process: + self.chatbot.append(["处理完成", "未找到需要处理的内容"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return None + + # 5. 批量处理章节内容 + self.chatbot[-1] = ["开始处理论文内容", f"共 {len(sections_to_process)} 个内容片段"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 一次性准备所有输入 + inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(sections_to_process) + + # 使用系统提示 + instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下学术文本,提高其语言表达的准确性、专业性和流畅度,保持学术风格,确保逻辑连贯,但不改变原文的科学内容和核心观点") + sys_prompt_array = [f"你是一个专业的学术文献编辑助手。请按照用户的要求:'{instruction}'处理文本。保持学术风格,增强表达的准确性和专业性。"] * len(sections_to_process) + + # 调用LLM一次性处理所有片段 + response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency( + inputs_array=inputs_array, + inputs_show_user_array=inputs_show_user_array, + llm_kwargs=self.llm_kwargs, + chatbot=self.chatbot, + history_array=history_array, + sys_prompt_array=sys_prompt_array, + ) + + # 处理响应,重组章节内容 + section_contents = {} # 用于重组各章节的处理后内容 + + for j, frag in enumerate(sections_to_process): + try: + llm_response = response_collection[j * 2 + 1] + processed_text = self._extract_decision(llm_response) + + if processed_text and processed_text.strip(): + # 保存处理结果 + self.processed_results.append({ + 'index': frag.fragment_index, + 'content': processed_text + }) + + # 存储处理后的文本片段,用于后续重组 + fragment_index = frag.fragment_index + if fragment_index in section_map: + path, section, fragment_idx, total_fragments = section_map[fragment_index] + + # 初始化此章节的内容容器(如果尚未创建) + if path not in section_contents: + section_contents[path] = [""] * total_fragments + + # 将处理后的片段放入正确位置 + section_contents[path][fragment_idx] = processed_text + else: + self.failed_fragments.append(frag) + except Exception as e: + self.failed_fragments.append(frag) + + # 重组每个章节的内容 + for path, fragments in section_contents.items(): + section = None + for idx in section_map: + if section_map[idx][0] == path: + section = section_map[idx][1] + break + + if section: + # 合并该章节的所有处理后片段 + section.content = "\n".join(fragments) + + # 6. 
更新UI + success_count = total_fragments - len(self.failed_fragments) + self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{total_fragments} 个内容片段"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 收集参考文献部分(不进行处理) + references_sections = [] + def collect_references(sections, parent_path=""): + """递归收集参考文献部分""" + for i, section in enumerate(sections): + current_path = f"{parent_path}/{i}" if parent_path else f"{i}" + + # 检查是否为参考文献部分 + if section.section_type == 'references' or section.title.lower() in ['references', '参考文献', 'bibliography', '文献']: + references_sections.append((current_path, section)) + + # 递归检查子章节 + if section.subsections: + collect_references(section.subsections, current_path) + + # 收集参考文献 + collect_references(paper.sections) + + # 7. 将处理后的结构化论文转换为Markdown + markdown_content = self.paper_extractor.generate_markdown(paper) + + # 8. 返回处理后的内容 + self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{total_fragments} 个内容片段,参考文献部分未处理"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + return markdown_content + + except Exception as e: + self.chatbot.append(["结构化处理失败", f"错误: {str(e)},将尝试作为普通文件处理"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return (yield from self._process_regular_file(file_path)) + + def _process_regular_file(self, file_path: str) -> Generator: + """使用原有方式处理普通文件""" + # 原有的文件处理逻辑 + self.chatbot[-1] = ["正在读取文件", f"文件路径: {file_path}"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + content = extract_text(file_path) + if not content or not content.strip(): + self.chatbot.append(["处理失败", "文件内容为空或无法提取内容"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return None + + # 2. 分割文本 + self.chatbot[-1] = ["正在分析文件", "将文件内容分割为适当大小的片段"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 使用增强的分割函数 + fragments = self._breakdown_section_content(content) + + # 3. 创建文本片段对象 + text_fragments = [] + for i, frag in enumerate(fragments): + if frag.strip(): + text_fragments.append(TextFragment( + content=frag, + fragment_index=i, + total_fragments=len(fragments) + )) + + # 4. 
处理所有片段 + self.chatbot[-1] = ["开始处理文本", f"共 {len(text_fragments)} 个片段"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 批量处理片段 + batch_size = 8 # 每批处理的片段数 + for i in range(0, len(text_fragments), batch_size): + batch = text_fragments[i:i + batch_size] + + inputs_array, inputs_show_user_array, history_array = self._create_batch_inputs(batch) + + # 使用系统提示 + instruction = self.plugin_kwargs.get("advanced_arg", "请润色以下文本") + sys_prompt_array = [f"你是一个专业的文本处理助手。请按照用户的要求:'{instruction}'处理文本。"] * len(batch) + + # 调用LLM处理 + response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency( + inputs_array=inputs_array, + inputs_show_user_array=inputs_show_user_array, + llm_kwargs=self.llm_kwargs, + chatbot=self.chatbot, + history_array=history_array, + sys_prompt_array=sys_prompt_array, + ) + + # 处理响应 + for j, frag in enumerate(batch): + try: + llm_response = response_collection[j * 2 + 1] + processed_text = self._extract_decision(llm_response) + + if processed_text and processed_text.strip(): + self.processed_results.append({ + 'index': frag.fragment_index, + 'content': processed_text + }) + else: + self.failed_fragments.append(frag) + self.processed_results.append({ + 'index': frag.fragment_index, + 'content': frag.content # 如果处理失败,使用原始内容 + }) + except Exception as e: + self.failed_fragments.append(frag) + self.processed_results.append({ + 'index': frag.fragment_index, + 'content': frag.content # 如果处理失败,使用原始内容 + }) + + # 5. 按原始顺序合并结果 + self.processed_results.sort(key=lambda x: x['index']) + final_content = "\n".join([item['content'] for item in self.processed_results]) + + # 6. 更新UI + success_count = len(text_fragments) - len(self.failed_fragments) + self.chatbot[-1] = ["处理完成", f"成功处理 {success_count}/{len(text_fragments)} 个片段"] + yield from update_ui(chatbot=self.chatbot, history=self.history) + + return final_content + + def save_results(self, content: str, original_file_path: str) -> List[str]: + """保存处理结果为多种格式""" + if not content: + return [] + + timestamp = time.strftime("%Y%m%d_%H%M%S") + original_filename = os.path.basename(original_file_path) + filename_without_ext = os.path.splitext(original_filename)[0] + base_filename = f"{filename_without_ext}_processed_{timestamp}" + + result_files = [] + + # 获取用户指定的处理类型 + processing_type = self.plugin_kwargs.get("advanced_arg", "文本处理") + + # 1. 保存为TXT + try: + txt_formatter = TxtFormatter() + txt_content = txt_formatter.create_document(content) + txt_file = write_history_to_file( + history=[txt_content], + file_basename=f"{base_filename}.txt" + ) + result_files.append(txt_file) + except Exception as e: + self.chatbot.append(["警告", f"TXT格式保存失败: {str(e)}"]) + + # 2. 保存为Markdown + try: + md_formatter = MarkdownFormatter() + md_content = md_formatter.create_document(content, processing_type) + md_file = write_history_to_file( + history=[md_content], + file_basename=f"{base_filename}.md" + ) + result_files.append(md_file) + except Exception as e: + self.chatbot.append(["警告", f"Markdown格式保存失败: {str(e)}"]) + + # 3. 保存为HTML + try: + html_formatter = HtmlFormatter(processing_type=processing_type) + html_content = html_formatter.create_document(content) + html_file = write_history_to_file( + history=[html_content], + file_basename=f"{base_filename}.html" + ) + result_files.append(html_file) + except Exception as e: + self.chatbot.append(["警告", f"HTML格式保存失败: {str(e)}"]) + + # 4. 
保存为Word + try: + word_formatter = WordFormatter() + doc = word_formatter.create_document(content, processing_type) + + # 获取保存路径 + from toolbox import get_log_folder + word_path = os.path.join(get_log_folder(), f"{base_filename}.docx") + doc.save(word_path) + + # 5. 保存为PDF(通过Word转换) + try: + from crazy_functions.paper_fns.file2file_doc.word2pdf import WordToPdfConverter + pdf_path = WordToPdfConverter.convert_to_pdf(word_path) + result_files.append(pdf_path) + except Exception as e: + self.chatbot.append(["警告", f"PDF格式保存失败: {str(e)}"]) + + except Exception as e: + self.chatbot.append(["警告", f"Word格式保存失败: {str(e)}"]) + + # 添加到下载区 + for file in result_files: + promote_file_to_downloadzone(file, chatbot=self.chatbot) + + return result_files + + def _breakdown_section_content(self, content: str) -> List[str]: + """对文本内容进行分割与合并 + + 主要按段落进行组织,只合并较小的段落以减少片段数量 + 保留原始段落结构,不对长段落进行强制分割 + 针对中英文设置不同的阈值,因为字符密度不同 + """ + # 先按段落分割文本 + paragraphs = content.split('\n\n') + + # 检测语言类型 + chinese_char_count = sum(1 for char in content if '\u4e00' <= char <= '\u9fff') + is_chinese_text = chinese_char_count / max(1, len(content)) > 0.3 + + # 根据语言类型设置不同的阈值(只用于合并小段落) + if is_chinese_text: + # 中文文本:一个汉字就是一个字符,信息密度高 + min_chunk_size = 300 # 段落合并的最小阈值 + target_size = 800 # 理想的段落大小 + else: + # 英文文本:一个单词由多个字符组成,信息密度低 + min_chunk_size = 600 # 段落合并的最小阈值 + target_size = 1600 # 理想的段落大小 + + # 1. 只合并小段落,不对长段落进行分割 + result_fragments = [] + current_chunk = [] + current_length = 0 + + for para in paragraphs: + # 如果段落太小且不会超过目标大小,则合并 + if len(para) < min_chunk_size and current_length + len(para) <= target_size: + current_chunk.append(para) + current_length += len(para) + # 否则,创建新段落 + else: + # 如果当前块非空且与当前段落无关,先保存它 + if current_chunk and current_length > 0: + result_fragments.append('\n\n'.join(current_chunk)) + + # 当前段落作为新块 + current_chunk = [para] + current_length = len(para) + + # 如果当前块大小已接近目标大小,保存并开始新块 + if current_length >= target_size: + result_fragments.append('\n\n'.join(current_chunk)) + current_chunk = [] + current_length = 0 + + # 保存最后一个块 + if current_chunk: + result_fragments.append('\n\n'.join(current_chunk)) + + # 2. 
处理可能过大的片段(确保不超过token限制) + final_fragments = [] + max_token = self._get_token_limit() + + for fragment in result_fragments: + # 检查fragment是否可能超出token限制 + # 根据语言类型调整token估算 + if is_chinese_text: + estimated_tokens = len(fragment) / 1.5 # 中文每个token约1-2个字符 + else: + estimated_tokens = len(fragment) / 4 # 英文每个token约4个字符 + + if estimated_tokens > max_token: + # 即使可能超出限制,也尽量保持段落的完整性 + # 使用breakdown_text但设置更大的限制来减少分割 + larger_limit = max_token * 0.95 # 使用95%的限制 + sub_fragments = breakdown_text_to_satisfy_token_limit( + txt=fragment, + limit=larger_limit, + llm_model=self.llm_kwargs['llm_model'] + ) + final_fragments.extend(sub_fragments) + else: + final_fragments.append(fragment) + + return final_fragments + + +@CatchException +def 自定义智能文档处理(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, + history: List, system_prompt: str, user_request: str): + """主函数 - 文件到文件处理""" + # 初始化 + processor = DocumentProcessor(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt) + chatbot.append(["函数插件功能", "文件内容处理:将文档内容按照指定要求处理后输出为多种格式"]) + yield from update_ui(chatbot=chatbot, history=history) + + # 验证输入路径 + if not os.path.exists(txt): + report_exception(chatbot, history, a=f"解析路径: {txt}", b=f"找不到路径或无权访问: {txt}") + yield from update_ui(chatbot=chatbot, history=history) + return + + # 验证路径安全性 + user_name = chatbot.get_user() + validate_path_safety(txt, user_name) + + # 获取文件列表 + if os.path.isfile(txt): + # 单个文件处理 + file_paths = [txt] + else: + # 目录处理 - 类似批量文件询问插件 + project_folder = txt + extract_folder = next((d for d in glob.glob(f'{project_folder}/*') + if os.path.isdir(d) and d.endswith('.extract')), project_folder) + + # 排除压缩文件 + exclude_patterns = r'/[^/]+\.(zip|rar|7z|tar|gz)$' + file_paths = [f for f in glob.glob(f'{extract_folder}/**', recursive=True) + if os.path.isfile(f) and not re.search(exclude_patterns, f)] + + # 过滤支持的文件格式 + file_paths = [f for f in file_paths if any(f.lower().endswith(ext) for ext in + list(processor.paper_extractor.SUPPORTED_EXTENSIONS) + ['.json', '.csv', '.xlsx', '.xls'])] + + if not file_paths: + report_exception(chatbot, history, a=f"解析路径: {txt}", b="未找到支持的文件类型") + yield from update_ui(chatbot=chatbot, history=history) + return + + # 处理文件 + if len(file_paths) > 1: + chatbot.append(["发现多个文件", f"共找到 {len(file_paths)} 个文件,将处理第一个文件"]) + yield from update_ui(chatbot=chatbot, history=history) + + # 只处理第一个文件 + file_to_process = file_paths[0] + processed_content = yield from processor.process_file(file_to_process) + + if processed_content: + # 保存结果 + result_files = processor.save_results(processed_content, file_to_process) + + if result_files: + chatbot.append(["处理完成", f"已生成 {len(result_files)} 个结果文件"]) + else: + chatbot.append(["处理完成", "但未能保存任何结果文件"]) + else: + chatbot.append(["处理失败", "未能生成有效的处理结果"]) + + yield from update_ui(chatbot=chatbot, history=history) diff --git a/crazy_functions/Paper_Reading.py b/crazy_functions/Paper_Reading.py new file mode 100644 index 00000000..93c6dd1b --- /dev/null +++ b/crazy_functions/Paper_Reading.py @@ -0,0 +1,360 @@ +import os +import time +import glob +from pathlib import Path +from datetime import datetime +from dataclasses import dataclass +from typing import Dict, List, Generator, Tuple +from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive +from toolbox import update_ui, promote_file_to_downloadzone, write_history_to_file, CatchException, report_exception +from shared_utils.fastapi_server import validate_path_safety +from crazy_functions.paper_fns.paper_download import extract_paper_id, 
extract_paper_ids, get_arxiv_paper, format_arxiv_id + + + +@dataclass +class PaperQuestion: + """论文分析问题类""" + id: str # 问题ID + question: str # 问题内容 + importance: int # 重要性 (1-5,5最高) + description: str # 问题描述 + + +class PaperAnalyzer: + """论文快速分析器""" + + def __init__(self, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, history: List, system_prompt: str): + """初始化分析器""" + self.llm_kwargs = llm_kwargs + self.plugin_kwargs = plugin_kwargs + self.chatbot = chatbot + self.history = history + self.system_prompt = system_prompt + self.paper_content = "" + self.results = {} + + # 定义论文分析问题库(已合并为4个核心问题) + self.questions = [ + PaperQuestion( + id="research_and_methods", + question="这篇论文的主要研究问题、目标和方法是什么?请分析:1)论文的核心研究问题和研究动机;2)论文提出的关键方法、模型或理论框架;3)这些方法如何解决研究问题。", + importance=5, + description="研究问题与方法" + ), + PaperQuestion( + id="findings_and_innovation", + question="论文的主要发现、结论及创新点是什么?请分析:1)论文的核心结果与主要发现;2)作者得出的关键结论;3)研究的创新点与对领域的贡献;4)与已有工作的区别。", + importance=4, + description="研究发现与创新" + ), + PaperQuestion( + id="methodology_and_data", + question="论文使用了什么研究方法和数据?请详细分析:1)研究设计与实验设置;2)数据收集方法与数据集特点;3)分析技术与评估方法;4)方法学上的合理性。", + importance=3, + description="研究方法与数据" + ), + PaperQuestion( + id="limitations_and_impact", + question="论文的局限性、未来方向及潜在影响是什么?请分析:1)研究的不足与限制因素;2)作者提出的未来研究方向;3)该研究对学术界和行业可能产生的影响;4)研究结果的适用范围与推广价值。", + importance=2, + description="局限性与影响" + ), + ] + + # 按重要性排序 + self.questions.sort(key=lambda q: q.importance, reverse=True) + + def _load_paper(self, paper_path: str) -> Generator: + from crazy_functions.doc_fns.text_content_loader import TextContentLoader + """加载论文内容""" + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 使用TextContentLoader读取文件 + loader = TextContentLoader(self.chatbot, self.history) + + yield from loader.execute_single_file(paper_path) + + # 获取加载的内容 + if len(self.history) >= 2 and self.history[-2]: + self.paper_content = self.history[-2] + yield from update_ui(chatbot=self.chatbot, history=self.history) + return True + else: + self.chatbot.append(["错误", "无法读取论文内容,请检查文件是否有效"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return False + + def _analyze_question(self, question: PaperQuestion) -> Generator: + """分析单个问题 - 直接显示问题和答案""" + try: + # 创建分析提示 + prompt = f"请基于以下论文内容回答问题:\n\n{self.paper_content}\n\n问题:{question.question}" + + # 使用单线程版本的请求函数 + response = yield from request_gpt_model_in_new_thread_with_ui_alive( + inputs=prompt, + inputs_show_user=question.question, # 显示问题本身 + llm_kwargs=self.llm_kwargs, + chatbot=self.chatbot, + history=[], # 空历史,确保每个问题独立分析 + sys_prompt="你是一个专业的科研论文分析助手,需要仔细阅读论文内容并回答问题。请保持客观、准确,并基于论文内容提供深入分析。" + ) + + if response: + self.results[question.id] = response + return True + return False + + except Exception as e: + self.chatbot.append(["错误", f"分析问题时出错: {str(e)}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return False + + def _generate_summary(self) -> Generator: + """生成最终总结报告""" + self.chatbot.append(["生成报告", "正在整合分析结果,生成最终报告..."]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + summary_prompt = "请基于以下对论文的各个方面的分析,生成一份全面的论文解读报告。报告应该简明扼要地呈现论文的关键内容,并保持逻辑连贯性。" + + for q in self.questions: + if q.id in self.results: + summary_prompt += f"\n\n关于{q.description}的分析:\n{self.results[q.id]}" + + try: + # 使用单线程版本的请求函数,可以在前端实时显示生成结果 + response = yield from request_gpt_model_in_new_thread_with_ui_alive( + inputs=summary_prompt, + inputs_show_user="生成论文解读报告", + llm_kwargs=self.llm_kwargs, + chatbot=self.chatbot, + history=[], + 
sys_prompt="你是一个科研论文解读专家,请将多个方面的分析整合为一份完整、连贯、有条理的报告。报告应当重点突出,层次分明,并且保持学术性和客观性。" + ) + + if response: + return response + return "报告生成失败" + + except Exception as e: + self.chatbot.append(["错误", f"生成报告时出错: {str(e)}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return "报告生成失败: " + str(e) + + def save_report(self, report: str) -> Generator: + """保存分析报告""" + timestamp = time.strftime("%Y%m%d_%H%M%S") + + # 保存为Markdown文件 + try: + md_content = f"# 论文快速解读报告\n\n{report}" + for q in self.questions: + if q.id in self.results: + md_content += f"\n\n## {q.description}\n\n{self.results[q.id]}" + + result_file = write_history_to_file( + history=[md_content], + file_basename=f"论文解读_{timestamp}.md" + ) + + if result_file and os.path.exists(result_file): + promote_file_to_downloadzone(result_file, chatbot=self.chatbot) + self.chatbot.append(["保存成功", f"解读报告已保存至: {os.path.basename(result_file)}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + else: + self.chatbot.append(["警告", "保存报告成功但找不到文件"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + except Exception as e: + self.chatbot.append(["警告", f"保存报告失败: {str(e)}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + def analyze_paper(self, paper_path: str) -> Generator: + """分析论文主流程""" + # 加载论文 + success = yield from self._load_paper(paper_path) + if not success: + return + + # 分析关键问题 - 直接询问每个问题,不显示进度信息 + for question in self.questions: + yield from self._analyze_question(question) + + # 生成总结报告 + final_report = yield from self._generate_summary() + + # 显示最终报告 + # self.chatbot.append(["论文解读报告", final_report]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + # 保存报告 + yield from self.save_report(final_report) + + +def _find_paper_file(path: str) -> str: + """查找路径中的论文文件(简化版)""" + if os.path.isfile(path): + return path + + # 支持的文件扩展名(按优先级排序) + extensions = ["pdf", "docx", "doc", "txt", "md", "tex"] + + # 简单地遍历目录 + if os.path.isdir(path): + try: + for ext in extensions: + # 手动检查每个可能的文件,而不使用glob + potential_file = os.path.join(path, f"paper.{ext}") + if os.path.exists(potential_file) and os.path.isfile(potential_file): + return potential_file + + # 如果没找到特定命名的文件,检查目录中的所有文件 + for file in os.listdir(path): + file_path = os.path.join(path, file) + if os.path.isfile(file_path): + file_ext = file.split('.')[-1].lower() if '.' 
in file else "" + if file_ext in extensions: + return file_path + except Exception: + pass # 忽略任何错误 + + return None + + +def download_paper_by_id(paper_info, chatbot, history) -> str: + """下载论文并返回保存路径 + + Args: + paper_info: 元组,包含论文ID类型(arxiv或doi)和ID值 + chatbot: 聊天机器人对象 + history: 历史记录 + + Returns: + str: 下载的论文路径或None + """ + from crazy_functions.review_fns.data_sources.scihub_source import SciHub + id_type, paper_id = paper_info + + # 创建保存目录 - 使用时间戳创建唯一文件夹 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + user_name = chatbot.get_user() if hasattr(chatbot, 'get_user') else "default" + from toolbox import get_log_folder, get_user + base_save_dir = get_log_folder(get_user(chatbot), plugin_name='paper_download') + save_dir = os.path.join(base_save_dir, f"papers_{timestamp}") + if not os.path.exists(save_dir): + os.makedirs(save_dir) + save_path = Path(save_dir) + + chatbot.append([f"下载论文", f"正在下载{'arXiv' if id_type == 'arxiv' else 'DOI'} {paper_id} 的论文..."]) + update_ui(chatbot=chatbot, history=history) + + pdf_path = None + + try: + if id_type == 'arxiv': + # 使用改进的arxiv查询方法 + formatted_id = format_arxiv_id(paper_id) + paper_result = get_arxiv_paper(formatted_id) + + if not paper_result: + chatbot.append([f"下载失败", f"未找到arXiv论文: {paper_id}"]) + update_ui(chatbot=chatbot, history=history) + return None + + # 下载PDF + filename = f"arxiv_{paper_id.replace('/', '_')}.pdf" + pdf_path = str(save_path / filename) + paper_result.download_pdf(filename=pdf_path) + + else: # doi + # 下载DOI + sci_hub = SciHub( + doi=paper_id, + path=save_path + ) + pdf_path = sci_hub.fetch() + + # 检查下载结果 + if pdf_path and os.path.exists(pdf_path): + promote_file_to_downloadzone(pdf_path, chatbot=chatbot) + chatbot.append([f"下载成功", f"已成功下载论文: {os.path.basename(pdf_path)}"]) + update_ui(chatbot=chatbot, history=history) + return pdf_path + else: + chatbot.append([f"下载失败", f"论文下载失败: {paper_id}"]) + update_ui(chatbot=chatbot, history=history) + return None + + except Exception as e: + chatbot.append([f"下载错误", f"下载论文时出错: {str(e)}"]) + update_ui(chatbot=chatbot, history=history) + return None + + +@CatchException +def 快速论文解读(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List, + history: List, system_prompt: str, user_request: str): + """主函数 - 论文快速解读""" + # 初始化分析器 + chatbot.append(["函数插件功能及使用方式", "论文快速解读:通过分析论文的关键要素,帮助您迅速理解论文内容,适用于各学科领域的科研论文。

📋 使用方式:
1、直接上传PDF文件或者输入DOI号(仅针对 Sci-Hub 收录的论文)或arXiv ID(如2501.03916)
2、点击插件开始分析"]) + yield from update_ui(chatbot=chatbot, history=history) + + paper_file = None + + # 检查输入是否为论文ID(arxiv或DOI) + paper_info = extract_paper_id(txt) + + if paper_info: + # 如果是论文ID,下载论文 + chatbot.append(["检测到论文ID", f"检测到{'arXiv' if paper_info[0] == 'arxiv' else 'DOI'} ID: {paper_info[1]},准备下载论文..."]) + yield from update_ui(chatbot=chatbot, history=history) + + # 下载论文 - 完全重新实现 + paper_file = download_paper_by_id(paper_info, chatbot, history) + + if not paper_file: + report_exception(chatbot, history, a=f"下载论文失败", b=f"无法下载{'arXiv' if paper_info[0] == 'arxiv' else 'DOI'}论文: {paper_info[1]}") + yield from update_ui(chatbot=chatbot, history=history) + return + else: + # 检查输入路径 + if not os.path.exists(txt): + report_exception(chatbot, history, a=f"解析论文: {txt}", b=f"找不到文件或无权访问: {txt}") + yield from update_ui(chatbot=chatbot, history=history) + return + + # 验证路径安全性 + user_name = chatbot.get_user() + validate_path_safety(txt, user_name) + + # 查找论文文件 + paper_file = _find_paper_file(txt) + + if not paper_file: + report_exception(chatbot, history, a=f"解析论文", b=f"在路径 {txt} 中未找到支持的论文文件") + yield from update_ui(chatbot=chatbot, history=history) + return + + yield from update_ui(chatbot=chatbot, history=history) + + # 增加调试信息,检查paper_file的类型和值 + chatbot.append(["文件类型检查", f"paper_file类型: {type(paper_file)}, 值: {paper_file}"]) + yield from update_ui(chatbot=chatbot, history=history) + chatbot.pop() # 移除调试信息 + + # 确保paper_file是字符串 + if paper_file is not None and not isinstance(paper_file, str): + # 尝试转换为字符串 + try: + paper_file = str(paper_file) + except: + report_exception(chatbot, history, a=f"类型错误", b=f"论文路径不是有效的字符串: {type(paper_file)}") + yield from update_ui(chatbot=chatbot, history=history) + return + + # 分析论文 + chatbot.append(["开始分析", f"正在分析论文: {os.path.basename(paper_file)}"]) + yield from update_ui(chatbot=chatbot, history=history) + + analyzer = PaperAnalyzer(llm_kwargs, plugin_kwargs, chatbot, history, system_prompt) + yield from analyzer.analyze_paper(paper_file) \ No newline at end of file diff --git a/crazy_functions/doc_fns/read_fns/docx_reader.py b/crazy_functions/doc_fns/read_fns/docx_reader.py index 9308940b..531b9360 100644 --- a/crazy_functions/doc_fns/read_fns/docx_reader.py +++ b/crazy_functions/doc_fns/read_fns/docx_reader.py @@ -1,6 +1,4 @@ import nltk nltk.data.path.append('~/nltk_data') -nltk.download('averaged_perceptron_tagger', download_dir='~/nltk_data', - ) -nltk.download('punkt', download_dir='~/nltk_data', - ) \ No newline at end of file +nltk.download('averaged_perceptron_tagger', download_dir='~/nltk_data') +nltk.download('punkt', download_dir='~/nltk_data') \ No newline at end of file diff --git a/crazy_functions/doc_fns/text_content_loader.py b/crazy_functions/doc_fns/text_content_loader.py new file mode 100644 index 00000000..021e100c --- /dev/null +++ b/crazy_functions/doc_fns/text_content_loader.py @@ -0,0 +1,451 @@ +import os +import re +import glob +import time +import queue +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Generator, Tuple, Set, Optional, Dict +from dataclasses import dataclass +from loguru import logger +from toolbox import update_ui +from crazy_functions.rag_fns.rag_file_support import extract_text +from crazy_functions.doc_fns.content_folder import ContentFoldingManager, FileMetadata, FoldingOptions, FoldingStyle, FoldingError +from shared_utils.fastapi_server import validate_path_safety +from datetime import datetime +import mimetypes + +@dataclass +class FileInfo: + """文件信息数据类""" 
+    path: str  # 完整路径
+    rel_path: str  # 相对路径
+    size: float  # 文件大小(MB)
+    extension: str  # 文件扩展名
+    last_modified: str  # 最后修改时间
+
+
+class TextContentLoader:
+    """优化版本的文本内容加载器 - 保持原有接口"""
+
+    # 压缩文件扩展名
+    COMPRESSED_EXTENSIONS: Set[str] = {'.zip', '.rar', '.7z', '.tar', '.gz', '.bz2', '.xz'}
+
+    # 系统配置
+    MAX_FILE_SIZE: int = 100 * 1024 * 1024  # 最大文件大小(100MB)
+    MAX_TOTAL_SIZE: int = 100 * 1024 * 1024  # 最大总大小(100MB)
+    MAX_FILES: int = 100  # 最大文件数量
+    CHUNK_SIZE: int = 1024 * 1024  # 文件读取块大小(1MB)
+    MAX_WORKERS: int = min(32, (os.cpu_count() or 1) * 4)  # 最大工作线程数
+    BATCH_SIZE: int = 5  # 批处理大小
+
+    def __init__(self, chatbot: List, history: List):
+        """初始化加载器"""
+        self.chatbot = chatbot
+        self.history = history
+        self.failed_files: List[Tuple[str, str]] = []
+        self.processed_size: int = 0
+        self.start_time: float = 0
+        self.file_cache: Dict[str, str] = {}
+        self._lock = threading.Lock()
+        self.executor = ThreadPoolExecutor(max_workers=self.MAX_WORKERS)
+        self.results_queue = queue.Queue()
+        self.folding_manager = ContentFoldingManager()
+
+    def _create_file_info(self, entry: os.DirEntry, root_path: str) -> Optional[FileInfo]:
+        """优化的文件信息创建
+
+        Args:
+            entry: 目录入口对象
+            root_path: 根路径
+
+        Returns:
+            Optional[FileInfo]: 文件信息对象,读取文件状态失败时为 None
+        """
+        try:
+            stats = entry.stat()  # 使用缓存的文件状态
+            return FileInfo(
+                path=entry.path,
+                rel_path=os.path.relpath(entry.path, root_path),
+                size=stats.st_size / (1024 * 1024),
+                extension=os.path.splitext(entry.path)[1].lower(),
+                last_modified=time.strftime('%Y-%m-%d %H:%M:%S',
+                                            time.localtime(stats.st_mtime))
+            )
+        except (OSError, ValueError):
+            return None
+
+    def _process_file_batch(self, file_batch: List[FileInfo]) -> List[Tuple[FileInfo, Optional[str]]]:
+        """批量处理文件
+
+        Args:
+            file_batch: 要处理的文件信息列表
+
+        Returns:
+            List[Tuple[FileInfo, Optional[str]]]: 处理结果列表
+        """
+        results = []
+        futures = {}
+
+        for file_info in file_batch:
+            if file_info.path in self.file_cache:
+                results.append((file_info, self.file_cache[file_info.path]))
+                continue
+
+            if file_info.size * 1024 * 1024 > self.MAX_FILE_SIZE:
+                with self._lock:
+                    self.failed_files.append(
+                        (file_info.rel_path,
+                         f"文件过大({file_info.size:.2f}MB > {self.MAX_FILE_SIZE / (1024 * 1024)}MB)")
+                    )
+                continue
+
+            future = self.executor.submit(self._read_file_content, file_info)
+            futures[future] = file_info
+
+        for future in as_completed(futures):
+            file_info = futures[future]
+            try:
+                content = future.result()
+                if content:
+                    with self._lock:
+                        self.file_cache[file_info.path] = content
+                        self.processed_size += file_info.size * 1024 * 1024
+                    results.append((file_info, content))
+            except Exception as e:
+                with self._lock:
+                    self.failed_files.append((file_info.rel_path, f"读取失败: {str(e)}"))
+
+        return results
+
+    def _read_file_content(self, file_info: FileInfo) -> Optional[str]:
+        """读取单个文件内容
+
+        Args:
+            file_info: 文件信息对象
+
+        Returns:
+            Optional[str]: 文件内容
+        """
+        try:
+            content = extract_text(file_info.path)
+            if not content or not content.strip():
+                return None
+            return content
+        except Exception as e:
+            logger.exception(f"读取文件失败: {str(e)}")
+            raise Exception(f"读取文件失败: {str(e)}")
+
+    def _is_valid_file(self, file_path: str) -> bool:
+        """检查文件是否有效
+
+        Args:
+            file_path: 文件路径
+
+        Returns:
+            bool: 是否为有效文件
+        """
+        if not os.path.isfile(file_path):
+            return False
+
+        extension = os.path.splitext(file_path)[1].lower()
+        if (extension in self.COMPRESSED_EXTENSIONS or
+                os.path.basename(file_path).startswith('.') or
+                not os.access(file_path, os.R_OK)):
+            return False
+
+        # 只要文件可以访问且不在排除列表中就认为是有效的
+        return True
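编辑者附注:下面是一段独立的最小示例(非补丁内容),演示上面 _process_file_batch 采用的"先批量 submit、再按 as_completed 顺序收集"的并发读取模式;其中 read_one 与示例文件名均为演示假设,真实代码中对应 self._read_file_content。

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def read_one(name: str) -> str:
    # 演示用的读取函数;真实实现中这里是耗时的文本抽取
    return f"content of {name}"

def process_batch(names):
    results = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        # 先全部提交,建立 future -> 任务 的映射
        futures = {executor.submit(read_one, n): n for n in names}
        # 再按完成顺序收集,单个任务失败不影响其余任务
        for future in as_completed(futures):
            name = futures[future]
            try:
                results.append((name, future.result()))
            except Exception as e:
                results.append((name, f"读取失败: {e}"))
    return results

if __name__ == "__main__":
    print(process_batch(["a.txt", "b.txt"]))
```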
+
+    def _collect_files(self, path: str) -> List[FileInfo]:
+        """收集文件信息
+
+        Args:
+            path: 目标路径
+
+        Returns:
+            List[FileInfo]: 有效文件信息列表
+        """
+        files = []
+        total_size = 0
+
+        # 处理单个文件的情况(os.DirEntry 无法手动构造,单文件场景直接用 os.stat 组装 FileInfo)
+        if os.path.isfile(path):
+            if self._is_valid_file(path):
+                try:
+                    stats = os.stat(path)
+                    return [FileInfo(
+                        path=path,
+                        rel_path=os.path.basename(path),
+                        size=stats.st_size / (1024 * 1024),
+                        extension=os.path.splitext(path)[1].lower(),
+                        last_modified=time.strftime('%Y-%m-%d %H:%M:%S',
+                                                    time.localtime(stats.st_mtime))
+                    )]
+                except OSError as e:
+                    self.failed_files.append((path, f"处理文件失败: {str(e)}"))
+            return []
+
+        # 处理目录的情况
+        try:
+            # 使用os.walk来递归遍历目录
+            for root, _, filenames in os.walk(path):
+                for filename in filenames:
+                    if len(files) >= self.MAX_FILES:
+                        self.failed_files.append((filename, f"超出最大文件数限制({self.MAX_FILES})"))
+                        continue
+
+                    file_path = os.path.join(root, filename)
+
+                    if not self._is_valid_file(file_path):
+                        continue
+
+                    try:
+                        stats = os.stat(file_path)
+                        file_size = stats.st_size / (1024 * 1024)  # 转换为MB
+
+                        if file_size * 1024 * 1024 > self.MAX_FILE_SIZE:
+                            self.failed_files.append((file_path,
+                                f"文件过大({file_size:.2f}MB > {self.MAX_FILE_SIZE / (1024 * 1024)}MB)"))
+                            continue
+
+                        if total_size + file_size * 1024 * 1024 > self.MAX_TOTAL_SIZE:
+                            self.failed_files.append((file_path, "超出总大小限制"))
+                            continue
+
+                        file_info = FileInfo(
+                            path=file_path,
+                            rel_path=os.path.relpath(file_path, path),
+                            size=file_size,
+                            extension=os.path.splitext(file_path)[1].lower(),
+                            last_modified=time.strftime('%Y-%m-%d %H:%M:%S',
+                                                        time.localtime(stats.st_mtime))
+                        )
+
+                        total_size += file_size * 1024 * 1024
+                        files.append(file_info)
+
+                    except Exception as e:
+                        self.failed_files.append((file_path, f"处理文件失败: {str(e)}"))
+                        continue
+
+        except Exception as e:
+            self.failed_files.append(("目录扫描", f"扫描失败: {str(e)}"))
+            return []
+
+        return sorted(files, key=lambda x: x.rel_path)
+
+    def _format_content_with_fold(self, file_info, content: str) -> str:
+        """使用折叠管理器格式化文件内容"""
+        try:
+            metadata = FileMetadata(
+                rel_path=file_info.rel_path,
+                size=file_info.size,
+                last_modified=datetime.fromtimestamp(
+                    os.path.getmtime(file_info.path)
+                ),
+                mime_type=mimetypes.guess_type(file_info.path)[0]
+            )
+
+            options = FoldingOptions(
+                style=FoldingStyle.DETAILED,
+                code_language=self.folding_manager._guess_language(
+                    os.path.splitext(file_info.path)[1]
+                ),
+                show_timestamp=True
+            )
+
+            return self.folding_manager.format_content(
+                content=content,
+                formatter_type='file',
+                metadata=metadata,
+                options=options
+            )
+
+        except Exception as e:
+            return f"Error formatting content: {str(e)}"
+
+    def _format_content_for_llm(self, file_infos: List[FileInfo], contents: List[str]) -> str:
+        """格式化用于LLM的内容
+
+        Args:
+            file_infos: 文件信息列表
+            contents: 内容列表
+
+        Returns:
+            str: 格式化后的内容
+        """
+        if len(file_infos) != len(contents):
+            raise ValueError("文件信息和内容数量不匹配")
+
+        result = [
+            "以下是多个文件的内容集合。每个文件的内容都以 '===== 文件 {序号}: {文件名} =====' 开始,",
+            "以 '===== 文件 {序号} 结束 =====' 结束。你可以根据这些分隔符来识别不同文件的内容。\n\n"
+        ]
+
+        for idx, (file_info, content) in enumerate(zip(file_infos, contents), 1):
+            result.extend([
+                f"===== 文件 {idx}: {file_info.rel_path} =====",
+                "文件内容:",
+                content.strip(),
+                f"===== 文件 {idx} 结束 =====\n"
+            ])
+
+        return "\n".join(result)
+
+    def execute(self, txt: str) -> Generator:
+        """执行文本加载和显示 - 保持原有接口
+
+        Args:
+            txt: 目标路径
+
+        Yields:
+            Generator: UI更新生成器
+        """
+        try:
+            # 首先显示正在处理的提示信息
+            self.chatbot.append(["提示", "正在提取文本内容,请稍作等待..."])
+            yield from update_ui(chatbot=self.chatbot, history=self.history)
+
+            user_name = self.chatbot.get_user()
+            validate_path_safety(txt, user_name)
+            self.start_time = time.time()
+            self.processed_size = 0
+            self.failed_files.clear()
+            successful_files = []
+            successful_contents = []
+
+            # 收集文件
+            files = self._collect_files(txt)
+            if not files:
+                # 
移除之前的提示信息 + self.chatbot.pop() + self.chatbot.append(["提示", "未找到任何有效文件"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return + + # 批量处理文件 + content_blocks = [] + for i in range(0, len(files), self.BATCH_SIZE): + batch = files[i:i + self.BATCH_SIZE] + results = self._process_file_batch(batch) + + for file_info, content in results: + if content: + content_blocks.append(self._format_content_with_fold(file_info, content)) + successful_files.append(file_info) + successful_contents.append(content) + + # 显示文件内容,替换之前的提示信息 + if content_blocks: + # 移除之前的提示信息 + self.chatbot.pop() + self.chatbot.append(["文件内容", "\n".join(content_blocks)]) + self.history.extend([ + self._format_content_for_llm(successful_files, successful_contents), + "我已经接收到你上传的文件的内容,请提问" + ]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + yield from update_ui(chatbot=self.chatbot, history=self.history) + + except Exception as e: + # 发生错误时,移除之前的提示信息 + if len(self.chatbot) > 0 and self.chatbot[-1][0] == "提示": + self.chatbot.pop() + self.chatbot.append(["错误", f"处理过程中出现错误: {str(e)}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + finally: + self.executor.shutdown(wait=False) + self.file_cache.clear() + + def execute_single_file(self, file_path: str) -> Generator: + """执行单个文件的加载和显示 + + Args: + file_path: 文件路径 + + Yields: + Generator: UI更新生成器 + """ + try: + # 首先显示正在处理的提示信息 + self.chatbot.append(["提示", "正在提取文本内容,请稍作等待..."]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + user_name = self.chatbot.get_user() + validate_path_safety(file_path, user_name) + self.start_time = time.time() + self.processed_size = 0 + self.failed_files.clear() + + # 验证文件是否存在且可读 + if not os.path.isfile(file_path): + self.chatbot.pop() + self.chatbot.append(["错误", f"指定路径不是文件: {file_path}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return + + if not self._is_valid_file(file_path): + self.chatbot.pop() + self.chatbot.append(["错误", f"无效的文件类型或无法读取: {file_path}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return + + # 创建文件信息 + try: + stats = os.stat(file_path) + file_size = stats.st_size / (1024 * 1024) # 转换为MB + + if file_size * 1024 * 1024 > self.MAX_FILE_SIZE: + self.chatbot.pop() + self.chatbot.append(["错误", f"文件过大({file_size:.2f}MB > {self.MAX_FILE_SIZE / (1024 * 1024)}MB)"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return + + file_info = FileInfo( + path=file_path, + rel_path=os.path.basename(file_path), + size=file_size, + extension=os.path.splitext(file_path)[1].lower(), + last_modified=time.strftime('%Y-%m-%d %H:%M:%S', + time.localtime(stats.st_mtime)) + ) + except Exception as e: + self.chatbot.pop() + self.chatbot.append(["错误", f"处理文件失败: {str(e)}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return + + # 读取文件内容 + try: + content = self._read_file_content(file_info) + if not content: + self.chatbot.pop() + self.chatbot.append(["提示", f"文件内容为空或无法提取: {file_path}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return + except Exception as e: + self.chatbot.pop() + self.chatbot.append(["错误", f"读取文件失败: {str(e)}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + return + + # 格式化内容并更新UI + formatted_content = self._format_content_with_fold(file_info, content) + + # 移除之前的提示信息 + self.chatbot.pop() + self.chatbot.append(["文件内容", formatted_content]) + + # 更新历史记录,便于LLM处理 + llm_content = self._format_content_for_llm([file_info], 
[content]) + self.history.extend([llm_content, "我已经接收到你上传的文件的内容,请提问"]) + + yield from update_ui(chatbot=self.chatbot, history=self.history) + + except Exception as e: + # 发生错误时,移除之前的提示信息 + if len(self.chatbot) > 0 and self.chatbot[-1][0] == "提示": + self.chatbot.pop() + self.chatbot.append(["错误", f"处理过程中出现错误: {str(e)}"]) + yield from update_ui(chatbot=self.chatbot, history=self.history) + + def __del__(self): + """析构函数 - 确保资源被正确释放""" + if hasattr(self, 'executor'): + self.executor.shutdown(wait=False) + if hasattr(self, 'file_cache'): + self.file_cache.clear() \ No newline at end of file diff --git a/crazy_functions/paper_fns/__init__.py b/crazy_functions/paper_fns/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/crazy_functions/paper_fns/auto_git/handlers/base_handler.py b/crazy_functions/paper_fns/auto_git/handlers/base_handler.py new file mode 100644 index 00000000..9852a910 --- /dev/null +++ b/crazy_functions/paper_fns/auto_git/handlers/base_handler.py @@ -0,0 +1,386 @@ +from abc import ABC, abstractmethod +from typing import List, Dict, Any +from ..query_analyzer import SearchCriteria +from ..sources.github_source import GitHubSource +import asyncio +import re +from datetime import datetime + +class BaseHandler(ABC): + """处理器基类""" + + def __init__(self, github: GitHubSource, llm_kwargs: Dict = None): + self.github = github + self.llm_kwargs = llm_kwargs or {} + self.ranked_repos = [] # 存储排序后的仓库列表 + + def _get_search_params(self, plugin_kwargs: Dict) -> Dict: + """获取搜索参数""" + return { + 'max_repos': plugin_kwargs.get('max_repos', 150), # 最大仓库数量,从30改为150 + 'max_details': plugin_kwargs.get('max_details', 80), # 最多展示详情的仓库数量,新增参数 + 'search_multiplier': plugin_kwargs.get('search_multiplier', 3), # 检索倍数 + 'min_stars': plugin_kwargs.get('min_stars', 0), # 最少星标数 + } + + @abstractmethod + async def handle( + self, + criteria: SearchCriteria, + chatbot: List[List[str]], + history: List[List[str]], + system_prompt: str, + llm_kwargs: Dict[str, Any], + plugin_kwargs: Dict[str, Any], + ) -> str: + """处理查询""" + pass + + async def _search_repositories(self, query: str, language: str = None, min_stars: int = 0, + sort: str = "stars", per_page: int = 30) -> List[Dict]: + """搜索仓库""" + try: + # 构建查询字符串 + if min_stars > 0 and "stars:>" not in query: + query += f" stars:>{min_stars}" + + if language and "language:" not in query: + query += f" language:{language}" + + # 执行搜索 + result = await self.github.search_repositories( + query=query, + sort=sort, + per_page=per_page + ) + + if result and "items" in result: + return result["items"] + return [] + except Exception as e: + print(f"仓库搜索出错: {str(e)}") + return [] + + async def _search_bilingual_repositories(self, english_query: str, chinese_query: str, language: str = None, min_stars: int = 0, + sort: str = "stars", per_page: int = 30) -> List[Dict]: + """同时搜索中英文仓库并合并结果""" + try: + # 搜索英文仓库 + english_results = await self._search_repositories( + query=english_query, + language=language, + min_stars=min_stars, + sort=sort, + per_page=per_page + ) + + # 搜索中文仓库 + chinese_results = await self._search_repositories( + query=chinese_query, + language=language, + min_stars=min_stars, + sort=sort, + per_page=per_page + ) + + # 合并结果,去除重复项 + merged_results = [] + seen_repos = set() + + # 优先添加英文结果 + for repo in english_results: + repo_id = repo.get('id') + if repo_id and repo_id not in seen_repos: + seen_repos.add(repo_id) + merged_results.append(repo) + + # 添加中文结果(排除重复) + for repo in chinese_results: + repo_id = repo.get('id') + if repo_id and 
repo_id not in seen_repos: + seen_repos.add(repo_id) + merged_results.append(repo) + + # 按星标数重新排序 + merged_results.sort(key=lambda x: x.get('stargazers_count', 0), reverse=True) + + return merged_results[:per_page] # 返回合并后的前per_page个结果 + except Exception as e: + print(f"双语仓库搜索出错: {str(e)}") + return [] + + async def _search_code(self, query: str, language: str = None, per_page: int = 30) -> List[Dict]: + """搜索代码""" + try: + # 构建查询字符串 + if language and "language:" not in query: + query += f" language:{language}" + + # 执行搜索 + result = await self.github.search_code( + query=query, + per_page=per_page + ) + + if result and "items" in result: + return result["items"] + return [] + except Exception as e: + print(f"代码搜索出错: {str(e)}") + return [] + + async def _search_bilingual_code(self, english_query: str, chinese_query: str, language: str = None, per_page: int = 30) -> List[Dict]: + """同时搜索中英文代码并合并结果""" + try: + # 搜索英文代码 + english_results = await self._search_code( + query=english_query, + language=language, + per_page=per_page + ) + + # 搜索中文代码 + chinese_results = await self._search_code( + query=chinese_query, + language=language, + per_page=per_page + ) + + # 合并结果,去除重复项 + merged_results = [] + seen_files = set() + + # 优先添加英文结果 + for item in english_results: + # 使用文件URL作为唯一标识 + file_url = item.get('html_url', '') + if file_url and file_url not in seen_files: + seen_files.add(file_url) + merged_results.append(item) + + # 添加中文结果(排除重复) + for item in chinese_results: + file_url = item.get('html_url', '') + if file_url and file_url not in seen_files: + seen_files.add(file_url) + merged_results.append(item) + + # 对结果进行排序,优先显示匹配度高的结果 + # 由于无法直接获取匹配度,这里使用仓库的星标数作为替代指标 + merged_results.sort(key=lambda x: x.get('repository', {}).get('stargazers_count', 0), reverse=True) + + return merged_results[:per_page] # 返回合并后的前per_page个结果 + except Exception as e: + print(f"双语代码搜索出错: {str(e)}") + return [] + + async def _search_users(self, query: str, per_page: int = 30) -> List[Dict]: + """搜索用户""" + try: + result = await self.github.search_users( + query=query, + per_page=per_page + ) + + if result and "items" in result: + return result["items"] + return [] + except Exception as e: + print(f"用户搜索出错: {str(e)}") + return [] + + async def _search_bilingual_users(self, english_query: str, chinese_query: str, per_page: int = 30) -> List[Dict]: + """同时搜索中英文用户并合并结果""" + try: + # 搜索英文用户 + english_results = await self._search_users( + query=english_query, + per_page=per_page + ) + + # 搜索中文用户 + chinese_results = await self._search_users( + query=chinese_query, + per_page=per_page + ) + + # 合并结果,去除重复项 + merged_results = [] + seen_users = set() + + # 优先添加英文结果 + for user in english_results: + user_id = user.get('id') + if user_id and user_id not in seen_users: + seen_users.add(user_id) + merged_results.append(user) + + # 添加中文结果(排除重复) + for user in chinese_results: + user_id = user.get('id') + if user_id and user_id not in seen_users: + seen_users.add(user_id) + merged_results.append(user) + + # 按关注者数量进行排序 + merged_results.sort(key=lambda x: x.get('followers', 0), reverse=True) + + return merged_results[:per_page] # 返回合并后的前per_page个结果 + except Exception as e: + print(f"双语用户搜索出错: {str(e)}") + return [] + + async def _search_topics(self, query: str, per_page: int = 30) -> List[Dict]: + """搜索主题""" + try: + result = await self.github.search_topics( + query=query, + per_page=per_page + ) + + if result and "items" in result: + return result["items"] + return [] + except Exception as e: + print(f"主题搜索出错: {str(e)}") + return [] + + async 
def _search_bilingual_topics(self, english_query: str, chinese_query: str, per_page: int = 30) -> List[Dict]: + """同时搜索中英文主题并合并结果""" + try: + # 搜索英文主题 + english_results = await self._search_topics( + query=english_query, + per_page=per_page + ) + + # 搜索中文主题 + chinese_results = await self._search_topics( + query=chinese_query, + per_page=per_page + ) + + # 合并结果,去除重复项 + merged_results = [] + seen_topics = set() + + # 优先添加英文结果 + for topic in english_results: + topic_name = topic.get('name') + if topic_name and topic_name not in seen_topics: + seen_topics.add(topic_name) + merged_results.append(topic) + + # 添加中文结果(排除重复) + for topic in chinese_results: + topic_name = topic.get('name') + if topic_name and topic_name not in seen_topics: + seen_topics.add(topic_name) + merged_results.append(topic) + + # 可以按流行度进行排序(如果有) + if merged_results and 'featured' in merged_results[0]: + merged_results.sort(key=lambda x: x.get('featured', False), reverse=True) + + return merged_results[:per_page] # 返回合并后的前per_page个结果 + except Exception as e: + print(f"双语主题搜索出错: {str(e)}") + return [] + + async def _get_repo_details(self, repos: List[Dict]) -> List[Dict]: + """获取仓库详细信息""" + enhanced_repos = [] + + for repo in repos: + try: + # 获取README信息 + owner = repo.get('owner', {}).get('login') if repo.get('owner') is not None else None + repo_name = repo.get('name') + + if owner and repo_name: + readme = await self.github.get_repo_readme(owner, repo_name) + if readme and "decoded_content" in readme: + # 提取README的前1000个字符作为摘要 + repo['readme_excerpt'] = readme["decoded_content"][:1000] + "..." + + # 获取语言使用情况 + languages = await self.github.get_repository_languages(owner, repo_name) + if languages: + repo['languages_detail'] = languages + + # 获取最新发布版本 + releases = await self.github.get_repo_releases(owner, repo_name, per_page=1) + if releases and len(releases) > 0: + repo['latest_release'] = releases[0] + + # 获取主题标签 + topics = await self.github.get_repo_topics(owner, repo_name) + if topics and "names" in topics: + repo['topics'] = topics["names"] + + enhanced_repos.append(repo) + except Exception as e: + print(f"获取仓库 {repo.get('full_name')} 详情时出错: {str(e)}") + enhanced_repos.append(repo) # 添加原始仓库信息 + + return enhanced_repos + + def _format_repos(self, repos: List[Dict]) -> str: + """格式化仓库列表""" + formatted = [] + + for i, repo in enumerate(repos, 1): + # 构建仓库URL + repo_url = repo.get('html_url', '') + + # 构建完整的引用 + reference = ( + f"{i}. **{repo.get('full_name', '')}**\n" + f" - 描述: {repo.get('description', 'N/A')}\n" + f" - 语言: {repo.get('language', 'N/A')}\n" + f" - 星标: {repo.get('stargazers_count', 0)}\n" + f" - Fork数: {repo.get('forks_count', 0)}\n" + f" - 更新时间: {repo.get('updated_at', 'N/A')[:10]}\n" + f" - 创建时间: {repo.get('created_at', 'N/A')[:10]}\n" + f" - URL: {repo_url}\n" + ) + + # 添加主题标签(如果有) + if repo.get('topics'): + topics_str = ", ".join(repo.get('topics')) + reference += f" - 主题标签: {topics_str}\n" + + # 添加最新发布版本(如果有) + if repo.get('latest_release'): + release = repo.get('latest_release') + reference += f" - 最新版本: {release.get('tag_name', 'N/A')} ({release.get('published_at', 'N/A')[:10]})\n" + + # 添加README摘要(如果有) + if repo.get('readme_excerpt'): + # 截断README,只取前300个字符 + readme_short = repo.get('readme_excerpt')[:300].replace('\n', ' ') + reference += f" - README摘要: {readme_short}...\n" + + formatted.append(reference) + + return "\n".join(formatted) + + def _generate_apology_prompt(self, criteria: SearchCriteria) -> str: + """生成道歉提示""" + return f"""很抱歉,我们未能找到与"{criteria.main_topic}"相关的GitHub项目。 + +可能的原因: +1. 
搜索词过于具体或冷门 +2. 星标数要求过高 +3. 编程语言限制过于严格 + +建议解决方案: + 1. 尝试使用更通用的关键词 + 2. 降低最低星标数要求 + 3. 移除或更改编程语言限制 +请根据以上建议调整后重试。""" + + def _get_current_time(self) -> str: + """获取当前时间信息""" + now = datetime.now() + return now.strftime("%Y年%m月%d日") \ No newline at end of file diff --git a/crazy_functions/paper_fns/auto_git/handlers/code_handler.py b/crazy_functions/paper_fns/auto_git/handlers/code_handler.py new file mode 100644 index 00000000..3f672a22 --- /dev/null +++ b/crazy_functions/paper_fns/auto_git/handlers/code_handler.py @@ -0,0 +1,156 @@ +from typing import List, Dict, Any +from .base_handler import BaseHandler +from ..query_analyzer import SearchCriteria +import asyncio + +class CodeSearchHandler(BaseHandler): + """代码搜索处理器""" + + def __init__(self, github, llm_kwargs=None): + super().__init__(github, llm_kwargs) + + async def handle( + self, + criteria: SearchCriteria, + chatbot: List[List[str]], + history: List[List[str]], + system_prompt: str, + llm_kwargs: Dict[str, Any], + plugin_kwargs: Dict[str, Any], + ) -> str: + """处理代码搜索请求,返回最终的prompt""" + + search_params = self._get_search_params(plugin_kwargs) + + # 搜索代码 + code_results = await self._search_bilingual_code( + english_query=criteria.github_params["query"], + chinese_query=criteria.github_params["chinese_query"], + language=criteria.language, + per_page=search_params['max_repos'] + ) + + if not code_results: + return self._generate_apology_prompt(criteria) + + # 获取代码文件内容 + enhanced_code_results = await self._get_code_details(code_results[:search_params['max_details']]) + self.ranked_repos = [item["repository"] for item in enhanced_code_results if "repository" in item] + + if not enhanced_code_results: + return self._generate_apology_prompt(criteria) + + # 构建最终的prompt + current_time = self._get_current_time() + final_prompt = f"""当前时间: {current_time} + +基于用户对{criteria.main_topic}的查询,我找到了以下代码示例。 + +代码搜索结果: +{self._format_code_results(enhanced_code_results)} + +请提供: + +1. 对于搜索的"{criteria.main_topic}"主题的综合解释: + - 概念和原理介绍 + - 常见实现方法和技术 + - 最佳实践和注意事项 + +2. 对每个代码示例: + - 解释代码的主要功能和实现方式 + - 分析代码质量、可读性和效率 + - 指出代码中的亮点和潜在改进空间 + - 说明代码的适用场景 + +3. 代码实现比较: + - 不同实现方法的优缺点 + - 性能和可维护性分析 + - 适用不同场景的实现建议 + +4. 
学习建议:
+   - 理解和使用这些代码需要的背景知识
+   - 如何扩展或改进所展示的代码
+   - 进一步学习相关技术的资源
+
+重要提示:
+- 深入解释代码的核心逻辑和实现思路
+- 提供专业、技术性的分析
+- 优先关注代码的实现质量和技术价值
+- 当代码实现有问题时,指出并提供改进建议
+- 对于复杂代码,分解解释其组成部分
+- 根据用户查询的具体问题提供针对性答案
+- 所有链接请使用链接文本格式,确保链接在新窗口打开
+
+使用markdown格式提供清晰的分节回复。
+"""
+
+        return final_prompt
+
+    async def _get_code_details(self, code_results: List[Dict]) -> List[Dict]:
+        """获取代码详情"""
+        enhanced_results = []
+
+        for item in code_results:
+            try:
+                repo = item.get('repository', {})
+                file_path = item.get('path', '')
+                repo_name = repo.get('full_name', '')
+
+                if repo_name and file_path:
+                    owner, repo_name = repo_name.split('/')
+
+                    # 获取文件内容
+                    file_content = await self.github.get_file_content(owner, repo_name, file_path)
+                    if file_content and "decoded_content" in file_content:
+                        item['code_content'] = file_content["decoded_content"]
+
+                    # 获取仓库基本信息
+                    repo_details = await self.github.get_repo(owner, repo_name)
+                    if repo_details:
+                        item['repository'] = repo_details
+
+                enhanced_results.append(item)
+            except Exception as e:
+                print(f"获取代码详情时出错: {str(e)}")
+                enhanced_results.append(item)  # 添加原始信息
+
+        return enhanced_results
+
+    def _format_code_results(self, code_results: List[Dict]) -> str:
+        """格式化代码搜索结果"""
+        formatted = []
+
+        for i, item in enumerate(code_results, 1):
+            # 构建仓库信息
+            repo = item.get('repository', {})
+            repo_name = repo.get('full_name', 'N/A')
+            repo_url = repo.get('html_url', '')
+            stars = repo.get('stargazers_count', 0)
+            # GitHub API 对无语言仓库会返回 None,这里兜底为 'N/A',避免后续 .lower() 抛 AttributeError
+            language = repo.get('language') or 'N/A'
+
+            # 构建文件信息
+            file_path = item.get('path', 'N/A')
+            file_url = item.get('html_url', '')
+
+            # 构建代码内容
+            code_content = item.get('code_content', '')
+            if code_content:
+                # 只显示前30行代码
+                code_lines = code_content.split("\n")
+                if len(code_lines) > 30:
+                    displayed_code = "\n".join(code_lines[:30]) + "\n... (代码太长已截断) ..."
+                else:
+                    displayed_code = code_content
+            else:
+                displayed_code = "(代码内容获取失败)"
+
+            # 代码围栏语言标记:语言未知时留空
+            fence_lang = language.lower() if language != 'N/A' else ''
+            reference = (
+                f"### {i}. {file_path} (在 {repo_name} 中)\n\n"
+                f"- **仓库**: {repo_name} (⭐ {stars}, 语言: {language})\n"
+                f"- **文件路径**: {file_path}\n\n"
+                f"```{fence_lang}\n{displayed_code}\n```\n\n"
+            )
+
+            formatted.append(reference)
+
+        return "\n".join(formatted)
\ No newline at end of file
diff --git a/crazy_functions/paper_fns/auto_git/handlers/repo_handler.py b/crazy_functions/paper_fns/auto_git/handlers/repo_handler.py
new file mode 100644
index 00000000..2038c2eb
--- /dev/null
+++ b/crazy_functions/paper_fns/auto_git/handlers/repo_handler.py
@@ -0,0 +1,192 @@
+from typing import List, Dict, Any
+from .base_handler import BaseHandler
+from ..query_analyzer import SearchCriteria
+import asyncio
+
+class RepositoryHandler(BaseHandler):
+    """仓库搜索处理器"""
+
+    def __init__(self, github, llm_kwargs=None):
+        super().__init__(github, llm_kwargs)
+
+    async def handle(
+        self,
+        criteria: SearchCriteria,
+        chatbot: List[List[str]],
+        history: List[List[str]],
+        system_prompt: str,
+        llm_kwargs: Dict[str, Any],
+        plugin_kwargs: Dict[str, Any],
+    ) -> str:
+        """处理仓库搜索请求,返回最终的prompt"""
+
+        search_params = self._get_search_params(plugin_kwargs)
+
+        # 如果是特定仓库查询
+        if criteria.repo_id:
+            try:
+                owner, repo = criteria.repo_id.split('/')
+                repo_details = await self.github.get_repo(owner, repo)
+                if repo_details:
+                    # 获取推荐的相似仓库
+                    similar_repos = await self.github.get_repo_recommendations(criteria.repo_id, limit=5)
+
+                    # 添加详细信息
+                    all_repos = [repo_details] + similar_repos
+                    enhanced_repos = await self._get_repo_details(all_repos)
+
+                    self.ranked_repos = enhanced_repos
+
+                    # 构建最终的prompt
+                    current_time = self._get_current_time()
+                    final_prompt = self._build_repo_detail_prompt(enhanced_repos[0], enhanced_repos[1:], current_time)
+                    return final_prompt
+                else:
+                    return self._generate_apology_prompt(criteria)
+            except Exception as e:
+                print(f"处理特定仓库时出错: {str(e)}")
+                return self._generate_apology_prompt(criteria)
+
+        # 一般仓库搜索
+        repos = await self._search_bilingual_repositories(
+            english_query=criteria.github_params["query"],
+            chinese_query=criteria.github_params["chinese_query"],
+            language=criteria.language,
+            min_stars=criteria.min_stars,
+            per_page=search_params['max_repos']
+        )
+
+        if not repos:
+            return self._generate_apology_prompt(criteria)
+
+        # 获取仓库详情
+        enhanced_repos = await self._get_repo_details(repos[:search_params['max_details']])  # 使用max_details参数
+        self.ranked_repos = enhanced_repos
+
+        if not enhanced_repos:
+            return self._generate_apology_prompt(criteria)
+
+        # 构建最终的prompt
+        current_time = self._get_current_time()
+        final_prompt = f"""当前时间: {current_time}
+
+基于用户对{criteria.main_topic}的兴趣,以下是相关的GitHub仓库。
+
+可供推荐的GitHub仓库:
+{self._format_repos(enhanced_repos)}
+
+请提供:
+1. 按功能、用途或成熟度对仓库进行分组
+
+2. 对每个仓库:
+   - 简要描述其主要功能和用途
+   - 分析其技术特点和优势
+   - 说明其适用场景和使用难度
+   - 指出其与同类产品相比的独特优势
+   - 解释其星标数量和活跃度代表的意义
+
+3. 使用建议:
+   - 新手最适合入门的仓库
+   - 生产环境中最稳定可靠的选择
+   - 最新技术栈或创新方案的代表
+   - 学习特定技术的最佳资源
+
+4. 
相关资源: + - 学习这些项目需要的前置知识 + - 项目间的关联和技术栈兼容性 + - 可能的使用组合方案 + +重要提示: +- 重点解释为什么每个仓库值得关注 +- 突出项目间的关联性和差异性 +- 考虑用户不同水平的需求(初学者vs专业人士) +- 在介绍项目时,使用文本格式,确保链接在新窗口打开 +- 根据仓库的活跃度、更新频率、维护状态提供使用建议 +- 仅基于提供的信息,不要做无根据的猜测 +- 在信息缺失或不明确时,坦诚说明 + +使用markdown格式提供清晰的分节回复。 +""" + + return final_prompt + + def _build_repo_detail_prompt(self, main_repo: Dict, similar_repos: List[Dict], current_time: str) -> str: + """构建仓库详情prompt""" + + # 提取README摘要 + readme_content = "未提供" + if main_repo.get('readme_excerpt'): + readme_content = main_repo.get('readme_excerpt') + + # 构建语言分布 + languages = main_repo.get('languages_detail', {}) + lang_distribution = [] + if languages: + total = sum(languages.values()) + for lang, bytes_val in languages.items(): + percentage = (bytes_val / total) * 100 + lang_distribution.append(f"{lang}: {percentage:.1f}%") + + lang_str = "未知" + if lang_distribution: + lang_str = ", ".join(lang_distribution) + + # 构建最终prompt + prompt = f"""当前时间: {current_time} + +## 主要仓库信息 + +### {main_repo.get('full_name')} + +- **描述**: {main_repo.get('description', '未提供')} +- **星标数**: {main_repo.get('stargazers_count', 0)} +- **Fork数**: {main_repo.get('forks_count', 0)} +- **Watch数**: {main_repo.get('watchers_count', 0)} +- **Issues数**: {main_repo.get('open_issues_count', 0)} +- **语言分布**: {lang_str} +- **许可证**: {main_repo.get('license', {}).get('name', '未指定') if main_repo.get('license') is not None else '未指定'} +- **创建时间**: {main_repo.get('created_at', '')[:10]} +- **最近更新**: {main_repo.get('updated_at', '')[:10]} +- **主题标签**: {', '.join(main_repo.get('topics', ['无']))} +- **GitHub链接**: 链接 + +### README摘要: +{readme_content} + +## 类似仓库: +{self._format_repos(similar_repos)} + +请提供以下内容: + +1. **项目概述** + - 详细解释{main_repo.get('name', '')}项目的主要功能和用途 + - 分析其技术特点、架构和实现原理 + - 讨论其在所属领域的地位和影响力 + - 评估项目成熟度和稳定性 + +2. **优势与特点** + - 与同类项目相比的独特优势 + - 显著的技术创新或设计模式 + - 值得学习或借鉴的代码实践 + +3. **使用场景** + - 最适合的应用场景 + - 潜在的使用限制和注意事项 + - 入门门槛和学习曲线评估 + - 产品级应用的可行性分析 + +4. **资源与生态** + - 相关学习资源推荐 + - 配套工具和库的建议 + - 社区支持和活跃度评估 + +5. 
**类似项目对比** + - 与列出的类似项目的详细对比 + - 不同场景下的最佳选择建议 + - 潜在的互补使用方案 + +提示:所有链接请使用链接文本格式,确保链接在新窗口打开。 + +请以专业、客观的技术分析角度回答,使用markdown格式提供结构化信息。 +""" + return prompt \ No newline at end of file diff --git a/crazy_functions/paper_fns/auto_git/handlers/topic_handler.py b/crazy_functions/paper_fns/auto_git/handlers/topic_handler.py new file mode 100644 index 00000000..6d6b4637 --- /dev/null +++ b/crazy_functions/paper_fns/auto_git/handlers/topic_handler.py @@ -0,0 +1,217 @@ +from typing import List, Dict, Any +from .base_handler import BaseHandler +from ..query_analyzer import SearchCriteria +import asyncio + +class TopicHandler(BaseHandler): + """主题搜索处理器""" + + def __init__(self, github, llm_kwargs=None): + super().__init__(github, llm_kwargs) + + async def handle( + self, + criteria: SearchCriteria, + chatbot: List[List[str]], + history: List[List[str]], + system_prompt: str, + llm_kwargs: Dict[str, Any], + plugin_kwargs: Dict[str, Any], + ) -> str: + """处理主题搜索请求,返回最终的prompt""" + + search_params = self._get_search_params(plugin_kwargs) + + # 搜索主题 + topics = await self._search_bilingual_topics( + english_query=criteria.github_params["query"], + chinese_query=criteria.github_params["chinese_query"], + per_page=search_params['max_repos'] + ) + + if not topics: + # 尝试用主题搜索仓库 + search_query = criteria.github_params["query"] + chinese_search_query = criteria.github_params["chinese_query"] + if "topic:" not in search_query: + search_query += " topic:" + criteria.main_topic.replace(" ", "-") + if "topic:" not in chinese_search_query: + chinese_search_query += " topic:" + criteria.main_topic.replace(" ", "-") + + repos = await self._search_bilingual_repositories( + english_query=search_query, + chinese_query=chinese_search_query, + language=criteria.language, + min_stars=criteria.min_stars, + per_page=search_params['max_repos'] + ) + + if not repos: + return self._generate_apology_prompt(criteria) + + # 获取仓库详情 + enhanced_repos = await self._get_repo_details(repos[:10]) + self.ranked_repos = enhanced_repos + + if not enhanced_repos: + return self._generate_apology_prompt(criteria) + + # 构建基于主题的仓库列表prompt + current_time = self._get_current_time() + final_prompt = f"""当前时间: {current_time} + +基于用户对主题"{criteria.main_topic}"的查询,我找到了以下相关GitHub仓库。 + +主题相关仓库: +{self._format_repos(enhanced_repos)} + +请提供: + +1. 主题综述: + - "{criteria.main_topic}"主题的概述和重要性 + - 该主题在技术领域中的应用和发展趋势 + - 主题相关的主要技术栈和知识体系 + +2. 仓库分析: + - 按功能、技术栈或应用场景对仓库进行分类 + - 每个仓库在该主题领域的定位和贡献 + - 不同仓库间的技术路线对比 + +3. 学习路径建议: + - 初学者入门该主题的推荐仓库和学习顺序 + - 进阶学习的关键仓库和技术要点 + - 实际应用中的最佳实践选择 + +4. 
技术生态分析: + - 该主题下的主流工具和库 + - 社区活跃度和维护状况 + - 与其他相关技术的集成方案 + +重要提示: +- 主题"{criteria.main_topic}"是用户查询的核心,请围绕此主题展开分析 +- 注重仓库质量评估和使用建议 +- 提供基于事实的客观技术分析 +- 在介绍仓库时使用链接文本格式,确保链接在新窗口打开 +- 考虑不同技术水平用户的需求 + +使用markdown格式提供清晰的分节回复。 +""" + return final_prompt + + # 如果找到了主题,则获取主题下的热门仓库 + topic_repos = [] + for topic in topics[:5]: # 增加到5个主题 + topic_name = topic.get('name', '') + if topic_name: + # 搜索该主题下的仓库 + repos = await self._search_repositories( + query=f"topic:{topic_name}", + language=criteria.language, + min_stars=criteria.min_stars, + per_page=20 # 每个主题最多20个仓库 + ) + + if repos: + for repo in repos: + repo['topic_source'] = topic_name + topic_repos.append(repo) + + if not topic_repos: + return self._generate_apology_prompt(criteria) + + # 获取前N个仓库的详情 + enhanced_repos = await self._get_repo_details(topic_repos[:search_params['max_details']]) + self.ranked_repos = enhanced_repos + + if not enhanced_repos: + return self._generate_apology_prompt(criteria) + + # 构建最终的prompt + current_time = self._get_current_time() + final_prompt = f"""当前时间: {current_time} + +基于用户对"{criteria.main_topic}"主题的查询,我找到了以下相关GitHub主题和仓库。 + +主题相关仓库: +{self._format_topic_repos(enhanced_repos)} + +请提供: + +1. 主题概述: + - 对"{criteria.main_topic}"相关主题的介绍和技术背景 + - 这些主题在软件开发中的重要性和应用范围 + - 主题间的关联性和技术演进路径 + +2. 精选仓库分析: + - 每个主题下最具代表性的仓库详解 + - 仓库的技术亮点和创新点 + - 使用场景和技术成熟度评估 + +3. 技术趋势分析: + - 基于主题和仓库活跃度的技术发展趋势 + - 新兴解决方案和传统方案的对比 + - 未来可能的技术方向预测 + +4. 实践建议: + - 不同应用场景下的最佳仓库选择 + - 学习路径和资源推荐 + - 实际项目中的应用策略 + +重要提示: +- 将分析重点放在主题的技术内涵和价值上 +- 突出主题间的关联性和技术演进脉络 +- 提供基于数据(星标数、更新频率等)的客观分析 +- 考虑不同技术背景用户的需求 +- 所有链接请使用链接文本格式,确保链接在新窗口打开 + +使用markdown格式提供清晰的分节回复。 +""" + + return final_prompt + + def _format_topic_repos(self, repos: List[Dict]) -> str: + """按主题格式化仓库列表""" + # 按主题分组 + topics_dict = {} + for repo in repos: + topic = repo.get('topic_source', '其他') + if topic not in topics_dict: + topics_dict[topic] = [] + topics_dict[topic].append(repo) + + # 格式化输出 + formatted = [] + for topic, topic_repos in topics_dict.items(): + formatted.append(f"## 主题: {topic}\n") + + for i, repo in enumerate(topic_repos, 1): + # 构建仓库URL + repo_url = repo.get('html_url', '') + + # 构建引用 + reference = ( + f"{i}. 
**{repo.get('full_name', '')}**\n" + f" - 描述: {repo.get('description', 'N/A')}\n" + f" - 语言: {repo.get('language', 'N/A')}\n" + f" - 星标: {repo.get('stargazers_count', 0)}\n" + f" - Fork数: {repo.get('forks_count', 0)}\n" + f" - 更新时间: {repo.get('updated_at', 'N/A')[:10]}\n" + f" - URL: {repo_url}\n" + ) + + # 添加主题标签(如果有) + if repo.get('topics'): + topics_str = ", ".join(repo.get('topics')) + reference += f" - 主题标签: {topics_str}\n" + + # 添加README摘要(如果有) + if repo.get('readme_excerpt'): + # 截断README,只取前200个字符 + readme_short = repo.get('readme_excerpt')[:200].replace('\n', ' ') + reference += f" - README摘要: {readme_short}...\n" + + formatted.append(reference) + + formatted.append("\n") # 主题之间添加空行 + + return "\n".join(formatted) \ No newline at end of file diff --git a/crazy_functions/paper_fns/auto_git/handlers/user_handler.py b/crazy_functions/paper_fns/auto_git/handlers/user_handler.py new file mode 100644 index 00000000..923d0e90 --- /dev/null +++ b/crazy_functions/paper_fns/auto_git/handlers/user_handler.py @@ -0,0 +1,164 @@ +from typing import List, Dict, Any +from .base_handler import BaseHandler +from ..query_analyzer import SearchCriteria +import asyncio + +class UserSearchHandler(BaseHandler): + """用户搜索处理器""" + + def __init__(self, github, llm_kwargs=None): + super().__init__(github, llm_kwargs) + + async def handle( + self, + criteria: SearchCriteria, + chatbot: List[List[str]], + history: List[List[str]], + system_prompt: str, + llm_kwargs: Dict[str, Any], + plugin_kwargs: Dict[str, Any], + ) -> str: + """处理用户搜索请求,返回最终的prompt""" + + search_params = self._get_search_params(plugin_kwargs) + + # 搜索用户 + users = await self._search_bilingual_users( + english_query=criteria.github_params["query"], + chinese_query=criteria.github_params["chinese_query"], + per_page=search_params['max_repos'] + ) + + if not users: + return self._generate_apology_prompt(criteria) + + # 获取用户详情和仓库 + enhanced_users = await self._get_user_details(users[:search_params['max_details']]) + self.ranked_repos = [] # 添加用户top仓库进行展示 + + for user in enhanced_users: + if user.get('top_repos'): + self.ranked_repos.extend(user.get('top_repos')) + + if not enhanced_users: + return self._generate_apology_prompt(criteria) + + # 构建最终的prompt + current_time = self._get_current_time() + final_prompt = f"""当前时间: {current_time} + +基于用户对{criteria.main_topic}的查询,我找到了以下GitHub用户。 + +GitHub用户搜索结果: +{self._format_users(enhanced_users)} + +请提供: + +1. 用户综合分析: + - 各开发者的专业领域和技术专长 + - 他们在GitHub开源社区的影响力 + - 技术实力和项目质量评估 + +2. 对每位开发者: + - 其主要贡献领域和技术栈 + - 代表性项目及其价值 + - 编程风格和技术特点 + - 在相关领域的影响力 + +3. 项目推荐: + - 针对用户查询的最有价值项目 + - 值得学习和借鉴的代码实践 + - 不同用户项目的相互补充关系 + +4. 
如何学习和使用: + - 如何从这些开发者项目中学习 + - 最适合入门学习的项目 + - 进阶学习的路径建议 + +重要提示: +- 关注开发者的技术专长和核心贡献 +- 分析其开源项目的技术价值 +- 根据用户的原始查询提供相关建议 +- 避免过度赞美或主观评价 +- 基于事实数据(项目数、星标数等)进行客观分析 +- 所有链接请使用链接文本格式,确保链接在新窗口打开 + +使用markdown格式提供清晰的分节回复。 +""" + + return final_prompt + + async def _get_user_details(self, users: List[Dict]) -> List[Dict]: + """获取用户详情和仓库""" + enhanced_users = [] + + for user in users: + try: + username = user.get('login') + + if username: + # 获取用户详情 + user_details = await self.github.get_user(username) + if user_details: + user.update(user_details) + + # 获取用户仓库 + repos = await self.github.get_user_repos( + username, + sort="stars", + per_page=10 # 增加到10个仓库 + ) + if repos: + user['top_repos'] = repos + + enhanced_users.append(user) + except Exception as e: + print(f"获取用户 {user.get('login')} 详情时出错: {str(e)}") + enhanced_users.append(user) # 添加原始信息 + + return enhanced_users + + def _format_users(self, users: List[Dict]) -> str: + """格式化用户列表""" + formatted = [] + + for i, user in enumerate(users, 1): + # 构建用户信息 + username = user.get('login', 'N/A') + name = user.get('name', username) + profile_url = user.get('html_url', '') + bio = user.get('bio', '无简介') + followers = user.get('followers', 0) + public_repos = user.get('public_repos', 0) + company = user.get('company', '未指定') + location = user.get('location', '未指定') + blog = user.get('blog', '') + + user_info = ( + f"### {i}. {name} (@{username})\n\n" + f"- **简介**: {bio}\n" + f"- **关注者**: {followers} | **公开仓库**: {public_repos}\n" + f"- **公司**: {company} | **地点**: {location}\n" + f"- **个人网站**: {blog}\n" + f"- **GitHub**: {username}\n\n" + ) + + # 添加用户的热门仓库 + top_repos = user.get('top_repos', []) + if top_repos: + user_info += "**热门仓库**:\n\n" + for repo in top_repos: + repo_name = repo.get('name', '') + repo_url = repo.get('html_url', '') + repo_desc = repo.get('description', '无描述') + repo_stars = repo.get('stargazers_count', 0) + repo_language = repo.get('language', '未指定') + + user_info += ( + f"- {repo_name} - ⭐ {repo_stars}, {repo_language}\n" + f" {repo_desc}\n\n" + ) + + formatted.append(user_info) + + return "\n".join(formatted) \ No newline at end of file diff --git a/crazy_functions/paper_fns/auto_git/query_analyzer.py b/crazy_functions/paper_fns/auto_git/query_analyzer.py new file mode 100644 index 00000000..605de715 --- /dev/null +++ b/crazy_functions/paper_fns/auto_git/query_analyzer.py @@ -0,0 +1,356 @@ +from typing import Dict, List +from dataclasses import dataclass +import re + +@dataclass +class SearchCriteria: + """搜索条件""" + query_type: str # 查询类型: repo/code/user/topic + main_topic: str # 主题 + sub_topics: List[str] # 子主题列表 + language: str # 编程语言 + min_stars: int # 最少星标数 + github_params: Dict # GitHub搜索参数 + original_query: str = "" # 原始查询字符串 + repo_id: str = "" # 特定仓库ID或名称 + +class QueryAnalyzer: + """查询分析器""" + + # 响应索引常量 + BASIC_QUERY_INDEX = 0 + GITHUB_QUERY_INDEX = 1 + + def __init__(self): + self.valid_types = { + "repo": ["repository", "project", "library", "framework", "tool"], + "code": ["code", "snippet", "implementation", "function", "class", "algorithm"], + "user": ["user", "developer", "organization", "contributor", "maintainer"], + "topic": ["topic", "category", "tag", "field", "area", "domain"] + } + + def analyze_query(self, query: str, chatbot: List, llm_kwargs: Dict): + """分析查询意图""" + from crazy_functions.crazy_utils import \ + request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt + + # 1. 基本查询分析 + type_prompt = f"""请分析这个与GitHub相关的查询,并严格按照以下XML格式回答: + +查询: {query} + +说明: +1. 
你的回答必须使用下面显示的XML标签,不要有任何标签外的文本
+2. 从以下选项中选择查询类型: repo/code/user/topic
+   - repo: 用于查找仓库、项目、框架或库
+   - code: 用于查找代码片段、函数实现或算法
+   - user: 用于查找用户、开发者或组织
+   - topic: 用于查找主题、类别或领域相关项目
+3. 识别主题和子主题
+4. 识别首选编程语言(如果有)
+5. 确定最低星标数(如果适用)
+
+必需格式:
+<query_type>此处回答</query_type>
+<main_topic>此处回答</main_topic>
+<sub_topics>子主题1, 子主题2, ...</sub_topics>
+<language>此处回答</language>
+<min_stars>此处回答</min_stars>
+
+示例回答:
+
+1. 仓库查询:
+查询: "查找有至少1000颗星的Python web框架"
+<query_type>repo</query_type>
+<main_topic>web框架</main_topic>
+<sub_topics>后端开发, HTTP服务器, ORM</sub_topics>
+<language>Python</language>
+<min_stars>1000</min_stars>
+
+2. 代码查询:
+查询: "如何用JavaScript实现防抖函数"
+<query_type>code</query_type>
+<main_topic>防抖函数</main_topic>
+<sub_topics>事件处理, 性能优化, 函数节流</sub_topics>
+<language>JavaScript</language>
+<min_stars>0</min_stars>"""
+
+        # 2. 生成英文搜索条件
+        github_prompt = f"""Optimize the following GitHub search query:
+
+Query: {query}
+
+Task: Convert the natural language query into an optimized GitHub search query.
+Please use English, regardless of the language of the input query.
+
+Available search fields and filters:
+1. Basic fields:
+   - in:name - Search in repository names
+   - in:description - Search in repository descriptions
+   - in:readme - Search in README files
+   - in:topic - Search in topics
+   - language:X - Filter by programming language
+   - user:X - Repositories from a specific user
+   - org:X - Repositories from a specific organization
+
+2. Code search fields:
+   - extension:X - Filter by file extension
+   - path:X - Filter by path
+   - filename:X - Filter by filename
+
+3. Metric filters:
+   - stars:>X - Has more than X stars
+   - forks:>X - Has more than X forks
+   - size:>X - Size greater than X KB
+   - created:>YYYY-MM-DD - Created after a specific date
+   - pushed:>YYYY-MM-DD - Updated after a specific date
+
+4. Other filters:
+   - is:public/private - Public or private repositories
+   - archived:true/false - Archived or not archived
+   - license:X - Specific license
+   - topic:X - Contains specific topic tag
+
+Examples:
+
+1. Query: "Find Python machine learning libraries with at least 1000 stars"
+<query>machine learning in:description language:python stars:>1000</query>
+
+2. Query: "Recently updated React UI component libraries"
+<query>UI components library in:readme in:description language:javascript topic:react pushed:>2023-01-01</query>
+
+3. Query: "Open source projects developed by Facebook"
+<query>org:facebook is:public</query>
+
+4. Query: "Depth-first search implementation in JavaScript"
+<query>depth first search in:file language:javascript</query>
+
+Please analyze the query and answer using only the <query> XML tag:
+<query>Provide the optimized GitHub search query, using appropriate fields and operators</query>"""
+
+        # 3. 生成中文搜索条件
+        chinese_github_prompt = f"""优化以下GitHub搜索查询:
+
+查询: {query}
+
+任务: 将自然语言查询转换为优化的GitHub搜索查询语句。
+为了搜索中文内容,请提取原始查询的关键词并使用中文形式,同时保留GitHub特定的搜索语法为英文。
+
+可用的搜索字段和过滤器:
+1. 基本字段:
+   - in:name - 在仓库名称中搜索
+   - in:description - 在仓库描述中搜索
+   - in:readme - 在README文件中搜索
+   - in:topic - 在主题中搜索
+   - language:X - 按编程语言筛选
+   - user:X - 特定用户的仓库
+   - org:X - 特定组织的仓库
+
+2. 代码搜索字段:
+   - extension:X - 按文件扩展名筛选
+   - path:X - 按路径筛选
+   - filename:X - 按文件名筛选
+
+3. 指标过滤器:
+   - stars:>X - 有超过X颗星
+   - forks:>X - 有超过X个分支
+   - size:>X - 大小超过X KB
+   - created:>YYYY-MM-DD - 在特定日期后创建
+   - pushed:>YYYY-MM-DD - 在特定日期后更新
+
+4. 其他过滤器:
+   - is:public/private - 公开或私有仓库
+   - archived:true/false - 已归档或未归档
+   - license:X - 特定许可证
+   - topic:X - 含特定主题标签
+
+示例:
+
+1. 查询: "找有关机器学习的Python库,至少1000颗星"
+<query>机器学习 in:description language:python stars:>1000</query>
+
+2. 查询: "最近更新的React UI组件库"
+<query>UI 组件库 in:readme in:description language:javascript topic:react pushed:>2023-01-01</query>
+
+3. 查询: "微信小程序开发框架"
+<query>微信小程序 开发框架 in:name in:description in:readme</query>
+
+请分析查询并仅使用 <query> XML标签回答:
+<query>提供优化的GitHub搜索查询,使用适当的字段和运算符,保留中文关键词</query>"""
+
+        try:
+            # 构建提示数组
+            prompts = [
+                type_prompt,
+                github_prompt,
+                chinese_github_prompt,
+            ]
+
+            show_messages = [
+                "分析查询类型...",
+                "优化英文GitHub搜索参数...",
+                "优化中文GitHub搜索参数...",
+            ]
+
+            sys_prompts = [
+                "你是一个精通GitHub生态系统的专家,擅长分析与GitHub相关的查询。",
+                "You are a GitHub search expert, specialized in converting natural language queries into optimized GitHub search queries in English.",
+                "你是一个GitHub搜索专家,擅长处理查询并保留中文关键词进行搜索。",
+            ]
+
+            # 使用同步方式调用LLM
+            responses = yield from request_gpt(
+                inputs_array=prompts,
+                inputs_show_user_array=show_messages,
+                llm_kwargs=llm_kwargs,
+                chatbot=chatbot,
+                history_array=[[] for _ in prompts],
+                sys_prompt_array=sys_prompts,
+                max_workers=3
+            )
+
+            # 从收集的响应中提取我们需要的内容(responses 为 [输入, 输出] 交替排列)
+            extracted_responses = []
+            for i in range(len(prompts)):
+                if (i * 2 + 1) < len(responses):
+                    response = responses[i * 2 + 1]
+                    if response is None:
+                        raise Exception(f"Response {i} is None")
+                    if not isinstance(response, str):
+                        try:
+                            response = str(response)
+                        except:
+                            raise Exception(f"Cannot convert response {i} to string")
+                    extracted_responses.append(response)
+                else:
+                    raise Exception(f"未收到第 {i + 1} 个响应")
+
+            # 解析基本信息
+            query_type = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "query_type")
+            if not query_type:
+                print(
+                    f"Debug - Failed to extract query_type. Response was: {extracted_responses[self.BASIC_QUERY_INDEX]}")
+                raise Exception("无法提取query_type标签内容")
+            query_type = query_type.lower()
+
+            main_topic = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "main_topic")
+            if not main_topic:
+                print(f"Debug - Failed to extract main_topic. Using query as fallback.")
+                main_topic = query
+
+            query_type = self._normalize_query_type(query_type, query)
+
+            # 提取子主题
+            sub_topics = []
+            sub_topics_text = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "sub_topics")
+            if sub_topics_text:
+                sub_topics = [topic.strip() for topic in sub_topics_text.split(",")]
+
+            # 提取语言
+            language = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "language")
+
+            # 提取最低星标数
+            min_stars = 0
+            min_stars_text = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "min_stars")
+            if min_stars_text and min_stars_text.isdigit():
+                min_stars = int(min_stars_text)
+
+            # 解析GitHub搜索参数 - 英文
+            english_github_query = self._extract_tag(extracted_responses[self.GITHUB_QUERY_INDEX], "query")
+
+            # 解析GitHub搜索参数 - 中文
+            chinese_github_query = self._extract_tag(extracted_responses[2], "query")
+
+            # 构建GitHub参数
+            github_params = {
+                "query": english_github_query,
+                "chinese_query": chinese_github_query,
+                "sort": "stars",  # 默认按星标排序
+                "order": "desc",  # 默认降序
+                "per_page": 30,  # 默认每页30条
+                "page": 1  # 默认第1页
+            }
+
+            # 检查是否为特定仓库查询
+            repo_id = ""
+            if "repo:" in english_github_query or "repository:" in english_github_query:
+                repo_match = re.search(r'(repo|repository):([a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+)', english_github_query)
+                if repo_match:
+                    repo_id = repo_match.group(2)
+
+            print(f"Debug - 提取的信息:")
+            print(f"查询类型: {query_type}")
+            print(f"主题: {main_topic}")
+            print(f"子主题: {sub_topics}")
+            print(f"语言: {language}")
+            print(f"最低星标数: {min_stars}")
+            print(f"英文GitHub参数: {english_github_query}")
+            print(f"中文GitHub参数: {chinese_github_query}")
+            print(f"特定仓库: {repo_id}")
+
+            # 更新返回的 SearchCriteria,包含中英文查询
+            return SearchCriteria(
+                query_type=query_type,
+                main_topic=main_topic,
+                sub_topics=sub_topics,
+                language=language,
+                min_stars=min_stars,
+                github_params=github_params,
+                original_query=query,
+                repo_id=repo_id
+            )
+
+        except Exception as e:
+            raise Exception(f"分析查询失败: {str(e)}")
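编辑者附注:下面是一段独立的最小示例(非补丁内容),演示上述提示词与 _extract_tag 之间约定的 XML 标签协议——只要 LLM 按"必需格式"返回带标签的回答,各字段即可被逐一解析。extract_tag 是 _extract_tag 主模式的等价简化,sample_response 为演示假设的数据。

```python
import re

def extract_tag(text: str, tag: str) -> str:
    # 与上文 _extract_tag 的标准XML模式等价的简化版
    match = re.search(f"<{tag}>(.*?)</{tag}>", text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else ""

# 假设 LLM 按提示中的"必需格式"返回:
sample_response = (
    "<query_type>repo</query_type>\n"
    "<main_topic>web框架</main_topic>\n"
    "<sub_topics>后端开发, HTTP服务器</sub_topics>\n"
    "<language>Python</language>\n"
    "<min_stars>1000</min_stars>"
)

assert extract_tag(sample_response, "query_type") == "repo"
assert extract_tag(sample_response, "min_stars") == "1000"
```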
+
+    def _normalize_query_type(self, query_type: str, query: str) -> str:
+        """规范化查询类型"""
+        if query_type in ["repo", "code", "user", "topic"]:
+            return query_type
+
+        query_lower = query.lower()
+        for type_name, keywords in self.valid_types.items():
+            for keyword in keywords:
+                if keyword in query_lower:
+                    return type_name
+
+        query_type_lower = query_type.lower()
+        for type_name, keywords in self.valid_types.items():
+            for keyword in keywords:
+                if keyword in query_type_lower:
+                    return type_name
+
+        return "repo"  # 默认返回repo类型
+
+    def _extract_tag(self, text: str, tag: str) -> str:
+        """提取标记内容"""
+        if not text:
+            return ""
+
+        # 标准XML格式(非贪婪匹配到对应的闭合标签,处理多行和特殊字符)
+        pattern = f"<{tag}>(.*?)</{tag}>"
+        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
+        if match:
+            content = match.group(1).strip()
+            if content:
+                return content
+
+        # 备用模式
+        patterns = [
+            rf"<{tag}>\s*([\s\S]*?)\s*</{tag}>",  # 标准XML格式
+            rf"<{tag}>([\s\S]*?)(?:</{tag}>|$)",  # 未闭合的标签
+            rf"\[{tag}\]([\s\S]*?)\[/{tag}\]",  # 方括号格式
+            rf"{tag}:\s*(.*?)(?=\n\w|$)",  # 冒号格式
+            rf"<{tag}>\s*(.*?)(?=<|$)"  # 部分闭合
+        ]
+
+        # 尝试所有模式
+        for pattern in patterns:
+            match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
+            if match:
+                content = match.group(1).strip()
+                if content:  # 确保提取的内容不为空
+                    return content
+
+        # 如果所有模式都失败,返回空字符串
+        return ""
\ No newline at end of file
diff --git a/crazy_functions/paper_fns/auto_git/sources/github_source.py b/crazy_functions/paper_fns/auto_git/sources/github_source.py
new file mode 100644
index 00000000..28cd80a6
--- /dev/null
+++ b/crazy_functions/paper_fns/auto_git/sources/github_source.py
@@ -0,0 +1,701 @@
+import aiohttp
+import asyncio
+import base64
+import json
+import random
+from datetime import datetime
+from typing import List, Dict, Optional, Union, Any
+
+class GitHubSource:
+    """GitHub API实现"""
+
+    # 默认API密钥列表 - 可以放置多个GitHub令牌(以下为占位符,使用前需替换为真实令牌)
+    API_KEYS = [
+        "github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
+        "github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
+        # "your_github_token_1",
+        # "your_github_token_2",
+        # "your_github_token_3"
+    ]
+
+    def __init__(self, api_key: Optional[Union[str, List[str]]] = None):
+        """初始化GitHub API客户端
+
+        Args:
+            api_key: GitHub个人访问令牌或令牌列表
+        """
+        if api_key is None:
+            self.api_keys = self.API_KEYS
+        elif isinstance(api_key, str):
+            self.api_keys = [api_key]
+        else:
+            self.api_keys = api_key
+
+        self._initialize()
+
+    def _initialize(self) -> None:
+        """初始化客户端,设置默认参数"""
+        self.base_url = "https://api.github.com"
+        self.headers = {
+            "Accept": "application/vnd.github+json",
+            "X-GitHub-Api-Version": "2022-11-28",
+            "User-Agent": "GitHub-API-Python-Client"
+        }
+
+        # 如果有可用的API密钥,随机选择一个
+        if self.api_keys:
+            selected_key = random.choice(self.api_keys)
+            self.headers["Authorization"] = f"Bearer {selected_key}"
+            print("已随机选择API密钥进行认证")
+        else:
+            print("警告: 未提供API密钥,将受到GitHub API请求限制")
+
+    async def _request(self, method: str, endpoint: str, params: Dict = None, data: Dict = None) -> Any:
+        """发送API请求
+
+        Args:
+            method: HTTP方法 (GET, POST, PUT, DELETE等)
+            endpoint: API端点
+            params: URL参数
+            data: 请求体数据
+
+        Returns:
+            解析后的响应JSON
+        """
+        async with aiohttp.ClientSession(headers=self.headers) as session:
+            url = f"{self.base_url}{endpoint}"
+
+            # 为调试目的打印请求信息
+            print(f"请求: {method} {url}")
+            if params:
+                print(f"参数: {params}")
+
+            # 发送请求
+            request_kwargs = {}
+            if params:
+                request_kwargs["params"] = params
+            if data:
+                request_kwargs["json"] = data
+
+            async with 
session.request(method, url, **request_kwargs) as response: + response_text = await response.text() + + # 检查HTTP状态码 + if response.status >= 400: + print(f"API请求失败: HTTP {response.status}") + print(f"响应内容: {response_text}") + return None + + # 解析JSON响应 + try: + return json.loads(response_text) + except json.JSONDecodeError: + print(f"JSON解析错误: {response_text}") + return None + + # ===== 用户相关方法 ===== + + async def get_user(self, username: Optional[str] = None) -> Dict: + """获取用户信息 + + Args: + username: 指定用户名,不指定则获取当前授权用户 + + Returns: + 用户信息字典 + """ + endpoint = "/user" if username is None else f"/users/{username}" + return await self._request("GET", endpoint) + + async def get_user_repos(self, username: Optional[str] = None, sort: str = "updated", + direction: str = "desc", per_page: int = 30, page: int = 1) -> List[Dict]: + """获取用户的仓库列表 + + Args: + username: 指定用户名,不指定则获取当前授权用户 + sort: 排序方式 (created, updated, pushed, full_name) + direction: 排序方向 (asc, desc) + per_page: 每页结果数量 + page: 页码 + + Returns: + 仓库列表 + """ + endpoint = "/user/repos" if username is None else f"/users/{username}/repos" + params = { + "sort": sort, + "direction": direction, + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + async def get_user_starred(self, username: Optional[str] = None, + per_page: int = 30, page: int = 1) -> List[Dict]: + """获取用户星标的仓库 + + Args: + username: 指定用户名,不指定则获取当前授权用户 + per_page: 每页结果数量 + page: 页码 + + Returns: + 星标仓库列表 + """ + endpoint = "/user/starred" if username is None else f"/users/{username}/starred" + params = { + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + # ===== 仓库相关方法 ===== + + async def get_repo(self, owner: str, repo: str) -> Dict: + """获取仓库信息 + + Args: + owner: 仓库所有者 + repo: 仓库名 + + Returns: + 仓库信息 + """ + endpoint = f"/repos/{owner}/{repo}" + return await self._request("GET", endpoint) + + async def get_repo_branches(self, owner: str, repo: str, per_page: int = 30, page: int = 1) -> List[Dict]: + """获取仓库的分支列表 + + Args: + owner: 仓库所有者 + repo: 仓库名 + per_page: 每页结果数量 + page: 页码 + + Returns: + 分支列表 + """ + endpoint = f"/repos/{owner}/{repo}/branches" + params = { + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + async def get_repo_commits(self, owner: str, repo: str, sha: Optional[str] = None, + path: Optional[str] = None, per_page: int = 30, page: int = 1) -> List[Dict]: + """获取仓库的提交历史 + + Args: + owner: 仓库所有者 + repo: 仓库名 + sha: 特定提交SHA或分支名 + path: 文件路径筛选 + per_page: 每页结果数量 + page: 页码 + + Returns: + 提交列表 + """ + endpoint = f"/repos/{owner}/{repo}/commits" + params = { + "per_page": per_page, + "page": page + } + if sha: + params["sha"] = sha + if path: + params["path"] = path + + return await self._request("GET", endpoint, params=params) + + async def get_commit_details(self, owner: str, repo: str, commit_sha: str) -> Dict: + """获取特定提交的详情 + + Args: + owner: 仓库所有者 + repo: 仓库名 + commit_sha: 提交SHA + + Returns: + 提交详情 + """ + endpoint = f"/repos/{owner}/{repo}/commits/{commit_sha}" + return await self._request("GET", endpoint) + + # ===== 内容相关方法 ===== + + async def get_file_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> Dict: + """获取文件内容 + + Args: + owner: 仓库所有者 + repo: 仓库名 + path: 文件路径 + ref: 分支名、标签名或提交SHA + + Returns: + 文件内容信息 + """ + endpoint = f"/repos/{owner}/{repo}/contents/{path}" + params = {} + if ref: + params["ref"] = ref + + response = await self._request("GET", endpoint, 
params=params) + if response and isinstance(response, dict) and "content" in response: + try: + # 解码Base64编码的文件内容 + content = base64.b64decode(response["content"].encode()).decode() + response["decoded_content"] = content + except Exception as e: + print(f"解码文件内容时出错: {str(e)}") + + return response + + async def get_directory_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> List[Dict]: + """获取目录内容 + + Args: + owner: 仓库所有者 + repo: 仓库名 + path: 目录路径 + ref: 分支名、标签名或提交SHA + + Returns: + 目录内容列表 + """ + # 注意:此方法与get_file_content使用相同的端点,但对于目录会返回列表 + endpoint = f"/repos/{owner}/{repo}/contents/{path}" + params = {} + if ref: + params["ref"] = ref + + return await self._request("GET", endpoint, params=params) + + # ===== Issues相关方法 ===== + + async def get_issues(self, owner: str, repo: str, state: str = "open", + sort: str = "created", direction: str = "desc", + per_page: int = 30, page: int = 1) -> List[Dict]: + """获取仓库的Issues列表 + + Args: + owner: 仓库所有者 + repo: 仓库名 + state: Issue状态 (open, closed, all) + sort: 排序方式 (created, updated, comments) + direction: 排序方向 (asc, desc) + per_page: 每页结果数量 + page: 页码 + + Returns: + Issues列表 + """ + endpoint = f"/repos/{owner}/{repo}/issues" + params = { + "state": state, + "sort": sort, + "direction": direction, + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + async def get_issue(self, owner: str, repo: str, issue_number: int) -> Dict: + """获取特定Issue的详情 + + Args: + owner: 仓库所有者 + repo: 仓库名 + issue_number: Issue编号 + + Returns: + Issue详情 + """ + endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}" + return await self._request("GET", endpoint) + + async def get_issue_comments(self, owner: str, repo: str, issue_number: int) -> List[Dict]: + """获取Issue的评论 + + Args: + owner: 仓库所有者 + repo: 仓库名 + issue_number: Issue编号 + + Returns: + 评论列表 + """ + endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}/comments" + return await self._request("GET", endpoint) + + # ===== Pull Requests相关方法 ===== + + async def get_pull_requests(self, owner: str, repo: str, state: str = "open", + sort: str = "created", direction: str = "desc", + per_page: int = 30, page: int = 1) -> List[Dict]: + """获取仓库的Pull Request列表 + + Args: + owner: 仓库所有者 + repo: 仓库名 + state: PR状态 (open, closed, all) + sort: 排序方式 (created, updated, popularity, long-running) + direction: 排序方向 (asc, desc) + per_page: 每页结果数量 + page: 页码 + + Returns: + Pull Request列表 + """ + endpoint = f"/repos/{owner}/{repo}/pulls" + params = { + "state": state, + "sort": sort, + "direction": direction, + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + async def get_pull_request(self, owner: str, repo: str, pr_number: int) -> Dict: + """获取特定Pull Request的详情 + + Args: + owner: 仓库所有者 + repo: 仓库名 + pr_number: Pull Request编号 + + Returns: + Pull Request详情 + """ + endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}" + return await self._request("GET", endpoint) + + async def get_pull_request_files(self, owner: str, repo: str, pr_number: int) -> List[Dict]: + """获取Pull Request中修改的文件 + + Args: + owner: 仓库所有者 + repo: 仓库名 + pr_number: Pull Request编号 + + Returns: + 修改文件列表 + """ + endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/files" + return await self._request("GET", endpoint) + + # ===== 搜索相关方法 ===== + + async def search_repositories(self, query: str, sort: str = "stars", + order: str = "desc", per_page: int = 30, page: int = 1) -> Dict: + """搜索仓库 + + Args: + query: 搜索关键词 + sort: 排序方式 (stars, forks, 
updated) + order: 排序顺序 (asc, desc) + per_page: 每页结果数量 + page: 页码 + + Returns: + 搜索结果 + """ + endpoint = "/search/repositories" + params = { + "q": query, + "sort": sort, + "order": order, + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + async def search_code(self, query: str, sort: str = "indexed", + order: str = "desc", per_page: int = 30, page: int = 1) -> Dict: + """搜索代码 + + Args: + query: 搜索关键词 + sort: 排序方式 (indexed) + order: 排序顺序 (asc, desc) + per_page: 每页结果数量 + page: 页码 + + Returns: + 搜索结果 + """ + endpoint = "/search/code" + params = { + "q": query, + "sort": sort, + "order": order, + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + async def search_issues(self, query: str, sort: str = "created", + order: str = "desc", per_page: int = 30, page: int = 1) -> Dict: + """搜索Issues和Pull Requests + + Args: + query: 搜索关键词 + sort: 排序方式 (created, updated, comments) + order: 排序顺序 (asc, desc) + per_page: 每页结果数量 + page: 页码 + + Returns: + 搜索结果 + """ + endpoint = "/search/issues" + params = { + "q": query, + "sort": sort, + "order": order, + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + async def search_users(self, query: str, sort: str = "followers", + order: str = "desc", per_page: int = 30, page: int = 1) -> Dict: + """搜索用户 + + Args: + query: 搜索关键词 + sort: 排序方式 (followers, repositories, joined) + order: 排序顺序 (asc, desc) + per_page: 每页结果数量 + page: 页码 + + Returns: + 搜索结果 + """ + endpoint = "/search/users" + params = { + "q": query, + "sort": sort, + "order": order, + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + # ===== 组织相关方法 ===== + + async def get_organization(self, org: str) -> Dict: + """获取组织信息 + + Args: + org: 组织名称 + + Returns: + 组织信息 + """ + endpoint = f"/orgs/{org}" + return await self._request("GET", endpoint) + + async def get_organization_repos(self, org: str, type: str = "all", + sort: str = "created", direction: str = "desc", + per_page: int = 30, page: int = 1) -> List[Dict]: + """获取组织的仓库列表 + + Args: + org: 组织名称 + type: 仓库类型 (all, public, private, forks, sources, member, internal) + sort: 排序方式 (created, updated, pushed, full_name) + direction: 排序方向 (asc, desc) + per_page: 每页结果数量 + page: 页码 + + Returns: + 仓库列表 + """ + endpoint = f"/orgs/{org}/repos" + params = { + "type": type, + "sort": sort, + "direction": direction, + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + async def get_organization_members(self, org: str, per_page: int = 30, page: int = 1) -> List[Dict]: + """获取组织成员列表 + + Args: + org: 组织名称 + per_page: 每页结果数量 + page: 页码 + + Returns: + 成员列表 + """ + endpoint = f"/orgs/{org}/members" + params = { + "per_page": per_page, + "page": page + } + return await self._request("GET", endpoint, params=params) + + # ===== 更复杂的操作 ===== + + async def get_repository_languages(self, owner: str, repo: str) -> Dict: + """获取仓库使用的编程语言及其比例 + + Args: + owner: 仓库所有者 + repo: 仓库名 + + Returns: + 语言使用情况 + """ + endpoint = f"/repos/{owner}/{repo}/languages" + return await self._request("GET", endpoint) + + async def get_repository_stats_contributors(self, owner: str, repo: str) -> List[Dict]: + """获取仓库的贡献者统计 + + Args: + owner: 仓库所有者 + repo: 仓库名 + + Returns: + 贡献者统计信息 + """ + endpoint = f"/repos/{owner}/{repo}/stats/contributors" + return await self._request("GET", endpoint) + + async def 
get_repository_stats_commit_activity(self, owner: str, repo: str) -> List[Dict]: + """获取仓库的提交活动 + + Args: + owner: 仓库所有者 + repo: 仓库名 + + Returns: + 提交活动统计 + """ + endpoint = f"/repos/{owner}/{repo}/stats/commit_activity" + return await self._request("GET", endpoint) + +async def example_usage(): + """GitHubSource使用示例""" + # 创建客户端实例(可选传入API令牌) + # github = GitHubSource(api_key="your_github_token") + github = GitHubSource() + + try: + # 示例1:搜索热门Python仓库 + print("\n=== 示例1:搜索热门Python仓库 ===") + repos = await github.search_repositories( + query="language:python stars:>1000", + sort="stars", + order="desc", + per_page=5 + ) + + if repos and "items" in repos: + for i, repo in enumerate(repos["items"], 1): + print(f"\n--- 仓库 {i} ---") + print(f"名称: {repo['full_name']}") + print(f"描述: {repo['description']}") + print(f"星标数: {repo['stargazers_count']}") + print(f"Fork数: {repo['forks_count']}") + print(f"最近更新: {repo['updated_at']}") + print(f"URL: {repo['html_url']}") + + # 示例2:获取特定仓库的详情 + print("\n=== 示例2:获取特定仓库的详情 ===") + repo_details = await github.get_repo("microsoft", "vscode") + if repo_details: + print(f"名称: {repo_details['full_name']}") + print(f"描述: {repo_details['description']}") + print(f"星标数: {repo_details['stargazers_count']}") + print(f"Fork数: {repo_details['forks_count']}") + print(f"默认分支: {repo_details['default_branch']}") + print(f"开源许可: {repo_details.get('license', {}).get('name', '无')}") + print(f"语言: {repo_details['language']}") + print(f"Open Issues数: {repo_details['open_issues_count']}") + + # 示例3:获取仓库的提交历史 + print("\n=== 示例3:获取仓库的最近提交 ===") + commits = await github.get_repo_commits("tensorflow", "tensorflow", per_page=5) + if commits: + for i, commit in enumerate(commits, 1): + print(f"\n--- 提交 {i} ---") + print(f"SHA: {commit['sha'][:7]}") + print(f"作者: {commit['commit']['author']['name']}") + print(f"日期: {commit['commit']['author']['date']}") + print(f"消息: {commit['commit']['message'].splitlines()[0]}") + + # 示例4:搜索代码 + print("\n=== 示例4:搜索代码 ===") + code_results = await github.search_code( + query="filename:README.md language:markdown pytorch in:file", + per_page=3 + ) + if code_results and "items" in code_results: + print(f"共找到: {code_results['total_count']} 个结果") + for i, item in enumerate(code_results["items"], 1): + print(f"\n--- 代码 {i} ---") + print(f"仓库: {item['repository']['full_name']}") + print(f"文件: {item['path']}") + print(f"URL: {item['html_url']}") + + # 示例5:获取文件内容 + print("\n=== 示例5:获取文件内容 ===") + file_content = await github.get_file_content("python", "cpython", "README.rst") + if file_content and "decoded_content" in file_content: + content = file_content["decoded_content"] + print(f"文件名: {file_content['name']}") + print(f"大小: {file_content['size']} 字节") + print(f"内容预览: {content[:200]}...") + + # 示例6:获取仓库使用的编程语言 + print("\n=== 示例6:获取仓库使用的编程语言 ===") + languages = await github.get_repository_languages("facebook", "react") + if languages: + print(f"React仓库使用的编程语言:") + for lang, bytes_of_code in languages.items(): + print(f"- {lang}: {bytes_of_code} 字节") + + # 示例7:获取组织信息 + print("\n=== 示例7:获取组织信息 ===") + org_info = await github.get_organization("google") + if org_info: + print(f"名称: {org_info['name']}") + print(f"描述: {org_info.get('description', '无')}") + print(f"位置: {org_info.get('location', '未指定')}") + print(f"公共仓库数: {org_info['public_repos']}") + print(f"成员数: {org_info.get('public_members', 0)}") + print(f"URL: {org_info['html_url']}") + + # 示例8:获取用户信息 + print("\n=== 示例8:获取用户信息 ===") + user_info = await github.get_user("torvalds") + if user_info: + print(f"名称: 
{user_info['name']}") + print(f"公司: {user_info.get('company', '无')}") + print(f"博客: {user_info.get('blog', '无')}") + print(f"位置: {user_info.get('location', '未指定')}") + print(f"公共仓库数: {user_info['public_repos']}") + print(f"关注者数: {user_info['followers']}") + print(f"URL: {user_info['html_url']}") + + except Exception as e: + print(f"发生错误: {str(e)}") + import traceback + print(traceback.format_exc()) + +if __name__ == "__main__": + import asyncio + + # 运行示例 + asyncio.run(example_usage()) \ No newline at end of file diff --git a/crazy_functions/paper_fns/document_structure_extractor.py b/crazy_functions/paper_fns/document_structure_extractor.py new file mode 100644 index 00000000..18334106 --- /dev/null +++ b/crazy_functions/paper_fns/document_structure_extractor.py @@ -0,0 +1,593 @@ +from typing import List, Dict, Optional, Tuple, Union, Any +from dataclasses import dataclass, field +import os +import re +import logging + +from crazy_functions.doc_fns.read_fns.unstructured_all.paper_structure_extractor import ( + PaperStructureExtractor, PaperSection, StructuredPaper +) +from unstructured.partition.auto import partition +from unstructured.documents.elements import ( + Text, Title, NarrativeText, ListItem, Table, + Footer, Header, PageBreak, Image, Address +) + +@dataclass +class DocumentSection: + """通用文档章节数据类""" + title: str # 章节标题,如果没有标题则为空字符串 + content: str # 章节内容 + level: int = 0 # 标题级别,0为主标题,1为一级标题,以此类推 + section_type: str = "content" # 章节类型 + is_heading_only: bool = False # 是否仅包含标题 + subsections: List['DocumentSection'] = field(default_factory=list) # 子章节列表 + + +@dataclass +class StructuredDocument: + """结构化文档数据类""" + title: str = "" # 文档标题 + metadata: Dict[str, Any] = field(default_factory=dict) # 元数据 + sections: List[DocumentSection] = field(default_factory=list) # 章节列表 + full_text: str = "" # 完整文本 + is_paper: bool = False # 是否为学术论文 + + +class GenericDocumentStructureExtractor: + """通用文档结构提取器 + + 可以从各种文档格式中提取结构信息,包括标题和内容。 + 支持论文、报告、文章和一般文本文档。 + """ + + # 支持的文件扩展名 + SUPPORTED_EXTENSIONS = [ + '.pdf', '.docx', '.doc', '.pptx', '.ppt', + '.txt', '.md', '.html', '.htm', '.xml', + '.rtf', '.odt', '.epub', '.msg', '.eml' + ] + + # 常见的标题前缀模式 + HEADING_PATTERNS = [ + # 数字标题 (1., 1.1., etc.) + r'^\s*(\d+\.)+\s+', + # 中文数字标题 (一、, 二、, etc.) + r'^\s*[一二三四五六七八九十]+[、::]\s+', + # 带括号的数字标题 ((1), (2), etc.) + r'^\s*\(\s*\d+\s*\)\s+', + # 特定标记的标题 (Chapter 1, Section 1, etc.) + r'^\s*(chapter|section|part|附录|章|节)\s+\d+[\.::]\s+', + ] + + # 常见的文档分段标记词 + SECTION_MARKERS = { + 'introduction': ['简介', '导言', '引言', 'introduction', '概述', 'overview'], + 'background': ['背景', '现状', 'background', '理论基础', '相关工作'], + 'main_content': ['主要内容', '正文', 'main content', '分析', '讨论'], + 'conclusion': ['结论', '总结', 'conclusion', '结语', '小结', 'summary'], + 'reference': ['参考', '参考文献', 'references', '文献', 'bibliography'], + 'appendix': ['附录', 'appendix', '补充资料', 'supplementary'] + } + + def __init__(self): + """初始化提取器""" + self.paper_extractor = PaperStructureExtractor() # 论文专用提取器 + self._setup_logging() + + def _setup_logging(self): + """配置日志""" + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + self.logger = logging.getLogger(__name__) + + def extract_document_structure(self, file_path: str, strategy: str = "fast") -> StructuredDocument: + """提取文档结构 + + Args: + file_path: 文件路径 + strategy: 提取策略 ("fast" 或 "accurate") + + Returns: + StructuredDocument: 结构化文档对象 + """ + try: + self.logger.info(f"正在处理文档结构: {file_path}") + + # 1. 
首先尝试使用论文提取器 + try: + paper_result = self.paper_extractor.extract_paper_structure(file_path) + if paper_result and len(paper_result.sections) > 2: # 如果成功识别为论文结构 + self.logger.info(f"成功识别为学术论文: {file_path}") + # 将论文结构转换为通用文档结构 + return self._convert_paper_to_document(paper_result) + except Exception as e: + self.logger.debug(f"论文结构提取失败,将尝试通用提取: {str(e)}") + + # 2. 使用通用方法提取文档结构 + elements = partition( + str(file_path), + strategy=strategy, + include_metadata=True, + nlp=False + ) + + # 3. 使用通用提取器处理 + doc = self._extract_generic_structure(elements) + return doc + + except Exception as e: + self.logger.error(f"文档结构提取失败: {str(e)}") + # 返回一个空的结构化文档 + return StructuredDocument( + title="未能提取文档标题", + sections=[DocumentSection( + title="", + content="", + level=0, + section_type="content" + )] + ) + + def _convert_paper_to_document(self, paper: StructuredPaper) -> StructuredDocument: + """将论文结构转换为通用文档结构 + + Args: + paper: 结构化论文对象 + + Returns: + StructuredDocument: 转换后的通用文档结构 + """ + doc = StructuredDocument( + title=paper.metadata.title, + is_paper=True, + full_text=paper.full_text + ) + + # 转换元数据 + doc.metadata = { + 'title': paper.metadata.title, + 'authors': paper.metadata.authors, + 'keywords': paper.keywords, + 'abstract': paper.metadata.abstract if hasattr(paper.metadata, 'abstract') else "", + 'is_paper': True + } + + # 转换章节结构 + doc.sections = self._convert_paper_sections(paper.sections) + + return doc + + def _convert_paper_sections(self, paper_sections: List[PaperSection], level: int = 0) -> List[DocumentSection]: + """递归转换论文章节为通用文档章节 + + Args: + paper_sections: 论文章节列表 + level: 当前章节级别 + + Returns: + List[DocumentSection]: 通用文档章节列表 + """ + doc_sections = [] + + for section in paper_sections: + doc_section = DocumentSection( + title=section.title, + content=section.content, + level=section.level, + section_type=section.section_type, + is_heading_only=False if section.content else True + ) + + # 递归处理子章节 + if section.subsections: + doc_section.subsections = self._convert_paper_sections( + section.subsections, level + 1 + ) + + doc_sections.append(doc_section) + + return doc_sections + + def _extract_generic_structure(self, elements) -> StructuredDocument: + """从元素列表中提取通用文档结构 + + Args: + elements: 文档元素列表 + + Returns: + StructuredDocument: 结构化文档对象 + """ + # 创建结构化文档对象 + doc = StructuredDocument(full_text="") + + # 1. 提取文档标题 + title_candidates = [] + for i, element in enumerate(elements[:5]): # 只检查前5个元素 + if isinstance(element, Title): + title_text = str(element).strip() + title_candidates.append((i, title_text)) + + if title_candidates: + # 使用第一个标题作为文档标题 + doc.title = title_candidates[0][1] + + # 2. 
识别所有标题元素和内容
+        title_elements = []
+
+        # 2.1 首先识别所有标题
+        for i, element in enumerate(elements):
+            is_heading = False
+            title_text = ""
+            level = 0
+
+            # 检查元素类型
+            if isinstance(element, Title):
+                is_heading = True
+                title_text = str(element).strip()
+
+                # 进一步检查是否为真正的标题
+                if self._is_likely_heading(title_text, element, i, elements):
+                    level = self._estimate_heading_level(title_text, element)
+                else:
+                    is_heading = False
+
+            # 也检查格式像标题的普通文本
+            elif isinstance(element, (Text, NarrativeText)) and i > 0:
+                text = str(element).strip()
+                # 检查是否匹配标题模式
+                if any(re.match(pattern, text) for pattern in self.HEADING_PATTERNS):
+                    # 检查长度和后续内容以确认是否为标题
+                    if len(text) < 100 and self._has_sufficient_following_content(i, elements):
+                        is_heading = True
+                        title_text = text
+                        level = self._estimate_heading_level(title_text, element)
+
+            if is_heading:
+                section_type = self._identify_section_type(title_text)
+                title_elements.append((i, title_text, level, section_type))
+
+        # 2.2 为每个标题提取内容
+        sections = []
+
+        for i, (index, title_text, level, section_type) in enumerate(title_elements):
+            # 确定内容范围
+            content_start = index + 1
+            content_end = len(elements)  # 默认到文档结束
+
+            # 如果有下一个标题,内容到下一个标题开始
+            if i < len(title_elements) - 1:
+                content_end = title_elements[i + 1][0]
+
+            # 提取内容
+            content = self._extract_content_between(elements, content_start, content_end)
+
+            # 创建章节
+            section = DocumentSection(
+                title=title_text,
+                content=content,
+                level=level,
+                section_type=section_type,
+                is_heading_only=False if content.strip() else True
+            )
+
+            sections.append(section)
+
+        # 3. 如果没有识别到任何章节,创建一个默认章节
+        if not sections:
+            all_content = self._extract_content_between(elements, 0, len(elements))
+
+            # 尝试从内容中提取标题
+            first_line = all_content.split('\n')[0] if all_content else ""
+            if first_line and len(first_line) < 100:
+                doc.title = first_line
+                all_content = '\n'.join(all_content.split('\n')[1:])
+
+            default_section = DocumentSection(
+                title="",
+                content=all_content,
+                level=0,
+                section_type="content"
+            )
+            sections.append(default_section)
+
+        # 4. 构建层次结构
+        doc.sections = self._build_section_hierarchy(sections)
+
+        # 5. 提取完整文本
+        doc.full_text = "\n\n".join([str(element) for element in elements if isinstance(element, (Text, NarrativeText, Title, ListItem))])
+
+        return doc
+
+    def _build_section_hierarchy(self, sections: List[DocumentSection]) -> List[DocumentSection]:
+        """构建章节层次结构
+
+        Args:
+            sections: 章节列表
+
+        Returns:
+            List[DocumentSection]: 具有层次结构的章节列表
+        """
+        if not sections:
+            return []
+
+        # 逐个挂载章节,维护每个层级当前的父节点
+        top_level_sections = []
+        current_parents = {0: None}  # 每个层级的当前父节点
+
+        for section in sections:
+            # 找到当前节点的父节点
+            parent_level = None
+            for level in sorted([k for k in current_parents.keys() if k < section.level], reverse=True):
+                parent_level = level
+                break
+
+            if parent_level is None:
+                # 顶级章节
+                top_level_sections.append(section)
+            else:
+                # 子章节
+                parent = current_parents[parent_level]
+                if parent:
+                    parent.subsections.append(section)
+                else:
+                    top_level_sections.append(section)
+
+            # 更新当前层级的父节点
+            current_parents[section.level] = section
+
+            # 清除所有更深层级的父节点缓存
+            deeper_levels = [k for k in current_parents.keys() if k > section.level]
+            for level in deeper_levels:
+                current_parents.pop(level, None)
+
+        return top_level_sections
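
`_build_section_hierarchy` keeps only the most recently seen section per level in `current_parents`, so a later level-1 heading implicitly closes every deeper branch opened before it. A hedged walk-through (the titles are invented, and constructing the extractor assumes the `unstructured` dependencies are importable):

```python
# 扁平的章节序列: 1 → 1.1 → 1.2 → 2
flat = [
    DocumentSection(title="1. 引言", content="...", level=1),
    DocumentSection(title="1.1 背景", content="...", level=2),
    DocumentSection(title="1.2 动机", content="...", level=2),
    DocumentSection(title="2. 方法", content="...", level=1),
]

extractor = GenericDocumentStructureExtractor()
tree = extractor._build_section_hierarchy(flat)

# 两个一级章节成为顶层节点,1.1/1.2 挂在 "1. 引言" 之下
assert [s.title for s in tree] == ["1. 引言", "2. 方法"]
assert [s.title for s in tree[0].subsections] == ["1.1 背景", "1.2 动机"]
assert tree[1].subsections == []
```
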
+
+    def _is_likely_heading(self, text: str, element, index: int, elements) -> bool:
+        """判断文本是否可能是标题
+
+        Args:
+            text: 文本内容
+            element: 元素对象
+            index: 元素索引
+            elements: 所有元素列表
+
+        Returns:
+            bool: 是否可能是标题
+        """
+        # 1. 检查文本长度 - 标题通常不会太长
+        if len(text) > 150:  # 标题通常不超过150个字符
+            return False
+
+        # 2. 检查是否匹配标题的数字编号模式
+        if any(re.match(pattern, text) for pattern in self.HEADING_PATTERNS):
+            return True
+
+        # 3. 检查是否包含常见章节标记词
+        lower_text = text.lower()
+        for markers in self.SECTION_MARKERS.values():
+            if any(marker.lower() in lower_text for marker in markers):
+                return True
+
+        # 4. 检查后续内容数量 - 标题后通常有足够多的内容
+        if not self._has_sufficient_following_content(index, elements, min_chars=100):
+            # 但如果文本很短且以特定格式开头,仍可能是标题
+            if len(text) < 50 and (text.endswith(':') or text.endswith(':')):
+                return True
+            return False
+
+        # 5. 检查格式特征
+        # 标题通常是元素的开头,不在段落中间
+        if len(text.split('\n')) > 1:
+            # 多行文本不太可能是标题
+            return False
+
+        # 如果有元数据,检查字体特征(字体大小等)
+        if hasattr(element, 'metadata') and element.metadata:
+            try:
+                font_size = getattr(element.metadata, 'font_size', None)
+                is_bold = getattr(element.metadata, 'is_bold', False)
+
+                # 字体较大或加粗的文本更可能是标题
+                if font_size and font_size > 12:
+                    return True
+                if is_bold:
+                    return True
+            except (AttributeError, TypeError):
+                pass
+
+        # 默认返回True,因为元素已被识别为Title类型
+        return True
+
+    def _estimate_heading_level(self, text: str, element) -> int:
+        """估计标题的层级
+
+        Args:
+            text: 标题文本
+            element: 元素对象
+
+        Returns:
+            int: 标题层级 (0为主标题,1为一级标题, 等等)
+        """
+        # 1. 通过编号模式判断层级
+        for pattern, level in [
+            (r'^\s*\d+\.\s+', 1),                 # 1. 开头 (一级标题)
+            (r'^\s*\d+\.\d+\.\s+', 2),            # 1.1. 开头 (二级标题)
+            (r'^\s*\d+\.\d+\.\d+\.\s+', 3),       # 1.1.1. 开头 (三级标题)
+            (r'^\s*\d+\.\d+\.\d+\.\d+\.\s+', 4),  # 1.1.1.1. 开头 (四级标题)
+        ]:
+            if re.match(pattern, text):
+                return level
+
+        # 2. 检查是否是常见的主要章节标题
+        lower_text = text.lower()
+        main_sections = [
+            'abstract', 'introduction', 'background', 'methodology',
+            'results', 'discussion', 'conclusion', 'references'
+        ]
+        for section in main_sections:
+            if section in lower_text:
+                return 1  # 主要章节为一级标题
+
+        # 3. 根据文本特征判断
+        if text.isupper():  # 全大写文本可能是章标题
+            return 1
+
+        # 4. 
通过元数据判断层级 + if hasattr(element, 'metadata') and element.metadata: + try: + # 根据字体大小判断层级 + font_size = getattr(element.metadata, 'font_size', None) + if font_size is not None: + if font_size > 18: # 假设主标题字体最大 + return 0 + elif font_size > 16: + return 1 + elif font_size > 14: + return 2 + else: + return 3 + except (AttributeError, TypeError): + pass + + # 默认为二级标题 + return 2 + + def _identify_section_type(self, title_text: str) -> str: + """识别章节类型,包括参考文献部分""" + lower_text = title_text.lower() + + # 特别检查是否为参考文献部分 + references_patterns = [ + r'references', r'参考文献', r'bibliography', r'引用文献', + r'literature cited', r'^cited\s+literature', r'^文献$', r'^引用$' + ] + + for pattern in references_patterns: + if re.search(pattern, lower_text, re.IGNORECASE): + return "references" + + # 检查是否匹配其他常见章节类型 + for section_type, markers in self.SECTION_MARKERS.items(): + if any(marker.lower() in lower_text for marker in markers): + return section_type + + # 检查带编号的章节 + if re.match(r'^\d+\.', lower_text): + return "content" + + # 默认为内容章节 + return "content" + + def _has_sufficient_following_content(self, index: int, elements, min_chars: int = 150) -> bool: + """检查元素后是否有足够的内容 + + Args: + index: 当前元素索引 + elements: 所有元素列表 + min_chars: 最小字符数要求 + + Returns: + bool: 是否有足够的内容 + """ + total_chars = 0 + for i in range(index + 1, min(index + 5, len(elements))): + if isinstance(elements[i], Title): + # 如果紧接着是标题,就停止检查 + break + if isinstance(elements[i], (Text, NarrativeText, ListItem, Table)): + total_chars += len(str(elements[i])) + if total_chars >= min_chars: + return True + + return total_chars >= min_chars + + def _extract_content_between(self, elements, start_index: int, end_index: int) -> str: + """提取指定范围内的内容文本 + + Args: + elements: 元素列表 + start_index: 开始索引 + end_index: 结束索引 + + Returns: + str: 提取的内容文本 + """ + content_parts = [] + + for i in range(start_index, end_index): + if isinstance(elements[i], (Text, NarrativeText, ListItem, Table)): + content_parts.append(str(elements[i]).strip()) + + return "\n\n".join([part for part in content_parts if part]) + + def generate_markdown(self, doc: StructuredDocument) -> str: + """将结构化文档转换为Markdown格式 + + Args: + doc: 结构化文档对象 + + Returns: + str: Markdown格式文本 + """ + md_parts = [] + + # 添加标题 + if doc.title: + md_parts.append(f"# {doc.title}\n") + + # 添加元数据 + if doc.is_paper: + # 作者信息 + if 'authors' in doc.metadata and doc.metadata['authors']: + authors_str = ", ".join(doc.metadata['authors']) + md_parts.append(f"**作者:** {authors_str}\n") + + # 关键词 + if 'keywords' in doc.metadata and doc.metadata['keywords']: + keywords_str = ", ".join(doc.metadata['keywords']) + md_parts.append(f"**关键词:** {keywords_str}\n") + + # 摘要 + if 'abstract' in doc.metadata and doc.metadata['abstract']: + md_parts.append(f"## 摘要\n\n{doc.metadata['abstract']}\n") + + # 添加章节内容 + md_parts.append(self._format_sections_markdown(doc.sections)) + + return "\n".join(md_parts) + + def _format_sections_markdown(self, sections: List[DocumentSection], base_level: int = 0) -> str: + """递归格式化章节为Markdown + + Args: + sections: 章节列表 + base_level: 基础层级 + + Returns: + str: Markdown格式文本 + """ + md_parts = [] + + for section in sections: + # 计算标题级别 (确保不超过6级) + header_level = min(section.level + base_level + 1, 6) + + # 添加标题和内容 + if section.title: + md_parts.append(f"{'#' * header_level} {section.title}\n") + + if section.content: + md_parts.append(f"{section.content}\n") + + # 递归处理子章节 + if section.subsections: + md_parts.append(self._format_sections_markdown( + section.subsections, base_level + )) + + return "\n".join(md_parts) 
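
Taken together, the intended call pattern for this extractor looks roughly like the sketch below (a hedged example: the file names are hypothetical, `strategy="fast"` mirrors the default above, and partitioning requires the `unstructured` parsers for the given format):

```python
# 端到端示意:提取结构,检视章节,再导出为Markdown
extractor = GenericDocumentStructureExtractor()
doc = extractor.extract_document_structure("report.pdf", strategy="fast")

print(doc.title, "| is_paper:", doc.is_paper)
for section in doc.sections:
    print("-", section.title or "(untitled)", f"[{section.section_type}]")

# 将恢复出的层次结构渲染为Markdown
markdown_text = extractor.generate_markdown(doc)
with open("report.md", "w", encoding="utf-8") as f:
    f.write(markdown_text)
```
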
\ No newline at end of file
diff --git a/crazy_functions/paper_fns/file2file_doc/__init__.py b/crazy_functions/paper_fns/file2file_doc/__init__.py
new file mode 100644
index 00000000..7992185e
--- /dev/null
+++ b/crazy_functions/paper_fns/file2file_doc/__init__.py
@@ -0,0 +1,4 @@
+from .txt_doc import TxtFormatter
+from .markdown_doc import MarkdownFormatter
+from .html_doc import HtmlFormatter
+from .word_doc import WordFormatter
\ No newline at end of file
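
The package exports the four formatters side by side, which suggests call sites dispatch on the requested output format. A hedged sketch of such a dispatch (the shared `processing_type` constructor argument is inferred from `HtmlFormatter` below; the other three signatures are assumptions):

```python
from crazy_functions.paper_fns.file2file_doc import (
    TxtFormatter, MarkdownFormatter, HtmlFormatter, WordFormatter
)

# 假设的分发表:真实调用点可能不同
FORMATTERS = {
    ".txt": TxtFormatter,
    ".md": MarkdownFormatter,
    ".html": HtmlFormatter,
    ".docx": WordFormatter,
}

def get_formatter(output_ext: str, processing_type: str = "文本处理"):
    """按目标扩展名选择格式化器,未知类型回退到纯文本。"""
    cls = FORMATTERS.get(output_ext, TxtFormatter)
    return cls(processing_type)  # 假设四个类共享同一构造签名
```
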
diff --git a/crazy_functions/paper_fns/file2file_doc/html_doc.py b/crazy_functions/paper_fns/file2file_doc/html_doc.py
new file mode 100644
index 00000000..9ff14799
--- /dev/null
+++ b/crazy_functions/paper_fns/file2file_doc/html_doc.py
@@ -0,0 +1,300 @@
+class HtmlFormatter:
+    """HTML格式文档生成器 - 保留原始文档结构"""
+
+    def __init__(self, processing_type="文本处理"):
+        self.processing_type = processing_type
+        self.css_styles = """
+        :root {
+            --primary-color: #2563eb;
+            --primary-light: #eff6ff;
+            --secondary-color: #1e293b;
+            --background-color: #f8fafc;
+            --text-color: #334155;
+            --border-color: #e2e8f0;
+            --card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
+        }
+
+        body {
+            font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
+            line-height: 1.8;
+            margin: 0;
+            padding: 2rem;
+            color: var(--text-color);
+            background-color: var(--background-color);
+        }
+
+        .container {
+            max-width: 1200px;
+            margin: 0 auto;
+            background: white;
+            padding: 2rem;
+            border-radius: 16px;
+            box-shadow: var(--card-shadow);
+        }
+        ::selection {
+            background: var(--primary-light);
+            color: var(--primary-color);
+        }
+        @keyframes fadeIn {
+            from { opacity: 0; transform: translateY(20px); }
+            to { opacity: 1; transform: translateY(0); }
+        }
+
+        .container {
+            animation: fadeIn 0.6s ease-out;
+        }
+
+        .document-title {
+            color: var(--primary-color);
+            font-size: 2em;
+            text-align: center;
+            margin: 1rem 0 2rem;
+            padding-bottom: 1rem;
+            border-bottom: 2px solid var(--primary-color);
+        }
+
+        .document-body {
+            display: flex;
+            flex-direction: column;
+            gap: 1.5rem;
+            margin: 2rem 0;
+        }
+
+        .document-header {
+            display: flex;
+            flex-direction: column;
+            align-items: center;
+            margin-bottom: 2rem;
+        }
+
+        .processing-type {
+            color: var(--secondary-color);
+            font-size: 1.2em;
+            margin: 0.5rem 0;
+        }
+
+        .processing-date {
+            color: var(--text-color);
+            font-size: 0.9em;
+            opacity: 0.8;
+        }
+
+        .document-content {
+            background: white;
+            padding: 1.5rem;
+            border-radius: 8px;
+            border-left: 4px solid var(--primary-color);
+            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
+        }
+
+        /* 保留文档结构的样式 */
+        h1, h2, h3, h4, h5, h6 {
+            color: var(--secondary-color);
+            margin-top: 1.5em;
+            margin-bottom: 0.5em;
+        }
+
+        h1 { font-size: 1.8em; }
+        h2 { font-size: 1.5em; }
+        h3 { font-size: 1.3em; }
+        h4 { font-size: 1.1em; }
+
+        p {
+            margin: 0.8em 0;
+        }
+
+        ul, ol {
+            margin: 1em 0;
+            padding-left: 2em;
+        }
+
+        li {
+            margin: 0.5em 0;
+        }
+
+        blockquote {
+            margin: 1em 0;
+            padding: 0.5em 1em;
+            border-left: 4px solid var(--primary-light);
+            background: rgba(0,0,0,0.02);
+        }
+
+        code {
+            font-family: monospace;
+            background: rgba(0,0,0,0.05);
+            padding: 0.2em 0.4em;
+            border-radius: 3px;
+        }
+
+        pre {
+            background: rgba(0,0,0,0.05);
+            padding: 1em;
+            border-radius: 5px;
+            overflow-x: auto;
+        }
+
+        pre code {
+            background: transparent;
+            padding: 0;
+        }
+
+        @media (prefers-color-scheme: dark) {
+            :root {
+                --background-color: #0f172a;
+                --text-color: #e2e8f0;
+                --border-color: #1e293b;
+            }
+
+            .container, .document-content {
+                background: #1e293b;
+            }
+
+            blockquote {
+                background: rgba(255,255,255,0.05);
+            }
+
+            code, pre {
+                background: rgba(255,255,255,0.05);
+            }
+        }
+        """
+
+    def _escape_html(self, text):
+        """转义HTML特殊字符"""
+        import html
+        return html.escape(text)
+
+    def _markdown_to_html(self, text):
+        """将Markdown格式转换为HTML格式,保留文档结构"""
+        try:
+            import markdown
+            # 使用Python-Markdown库将markdown转换为HTML,启用更多扩展以支持嵌套列表
+            return markdown.markdown(text, extensions=['tables', 'fenced_code', 'codehilite', 'nl2br', 'sane_lists', 'smarty', 'extra'])
+        except ImportError:
+            # 如果没有markdown库,使用更复杂的替换来处理嵌套列表
+            import re
+
+            # 替换标题
+            text = re.sub(r'^# (.+)$', r'<h1>\1</h1>', text, flags=re.MULTILINE)
+            text = re.sub(r'^## (.+)$', r'<h2>\1</h2>', text, flags=re.MULTILINE)
+            text = re.sub(r'^### (.+)$', r'<h3>\1</h3>', text, flags=re.MULTILINE)
+
+            # 预处理列表 - 在列表项之间添加空行以正确分隔
+            # 处理编号列表
+            text = re.sub(r'(\n\d+\.\s.+)(\n\d+\.\s)', r'\1\n\2', text)
+            # 处理项目符号列表
+            text = re.sub(r'(\n•\s.+)(\n•\s)', r'\1\n\2', text)
+            text = re.sub(r'(\n\*\s.+)(\n\*\s)', r'\1\n\2', text)
+            text = re.sub(r'(\n-\s.+)(\n-\s)', r'\1\n\2', text)
+
+            # 处理嵌套列表 - 确保正确的缩进和结构
+            lines = text.split('\n')
+            in_list = False
+            list_type = None  # 'ol' 或 'ul'
+            list_html = []
+            normal_lines = []
+
+            i = 0
+            while i < len(lines):
+                line = lines[i]
+
+                # 匹配编号列表项
+                numbered_match = re.match(r'^(\d+)\.\s+(.+)$', line)
+                # 匹配项目符号列表项
+                bullet_match = re.match(r'^[•\*-]\s+(.+)$', line)
+
+                if numbered_match:
+                    if not in_list or list_type != 'ol':
+                        # 开始新的编号列表
+                        if in_list:
+                            # 关闭前一个列表
+                            list_html.append(f'</{list_type}>')
+                        list_html.append('<ol>')
+                        in_list = True
+                        list_type = 'ol'
+
+                    num, content = numbered_match.groups()
+                    list_html.append(f'<li>{content}</li>')
+
+                elif bullet_match:
+                    if not in_list or list_type != 'ul':
+                        # 开始新的项目符号列表
+                        if in_list:
+                            # 关闭前一个列表
+                            list_html.append(f'</{list_type}>')
+                        list_html.append('<ul>