Files
gpt_academic/shared_utils/handle_upload.py
2025-03-02 01:58:45 +08:00

182 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import importlib
import time
import inspect
import re
import os
import base64
import gradio
import shutil
import glob
from shared_utils.config_loader import get_conf
from loguru import logger
def html_local_file(file):
    """Turn an existing local path into a ``file=...`` reference.

    When ``file`` names something that exists on disk, the directory that
    contains this module is replaced by ``.`` inside the path text and the
    result is prefixed with ``file=``. Any other value is returned unchanged.

    Args:
        file: The path to convert, either a ``str`` or an object whose
            ``str()`` is the path (for example ``pathlib.Path``).

    Returns:
        A ``file=``-prefixed ``str`` for existing paths; otherwise ``file``
        exactly as it was passed in.
    """
    base_path = os.path.dirname(__file__)  # directory of this module (the project directory)
    path_text = str(file)
    if os.path.exists(path_text):
        # Substitute on the string form: the existence check above already
        # accepts non-str paths, and pathlib.Path.replace() renames a file
        # instead of performing a substring substitution.
        file = f'file={path_text.replace(base_path, ".")}'
    return file
def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
    """Build the markup that displays a local image.

    Args:
        __file: Path of the image; it is first passed through
            ``html_local_file``.
        layout: Value of the ``align`` attribute of the wrapping ``<div>``
            (HTML output only).
        max_width: Optional CSS ``max-width`` value (HTML output only).
        max_height: Optional CSS ``max-height`` value (HTML output only).
        md: When true, return Markdown image syntax instead of HTML.

    Returns:
        A Markdown image reference when ``md`` is true, otherwise an HTML
        ``<div>`` wrapping an ``<img>`` tag.
    """
    css_rules = []
    if max_width is not None:
        css_rules.append(f"max-width: {max_width};")
    if max_height is not None:
        css_rules.append(f"max-height: {max_height};")
    style = "".join(css_rules)

    src = html_local_file(__file)
    if md:
        # The Markdown form carries neither the layout nor the size limits.
        return f"![{src}]({src})"
    return f'<div align="{layout}"><img src="{src}" style="{style}"></div>'
def file_manifest_filter_type(file_list, filter_: list = None):
    """Replace image entries of a file list with HTML ``<img>`` markup.

    The extension of an entry is the text after the last ``.`` of its base
    name (the whole base name when it contains no dot), compared
    case-sensitively against ``filter_``.

    Args:
        file_list: Iterable of file paths.
        filter_: Extensions (without the dot) treated as images. A missing
            or empty value falls back to ``["png", "jpg", "jpeg"]``.

    Returns:
        A new list in which every matching entry is replaced by the result
        of ``html_local_img(entry, md=False)`` and every other entry is kept
        as is.
    """
    image_suffixes = filter_ if filter_ else ["png", "jpg", "jpeg"]
    return [
        html_local_img(entry, md=False)
        if str(os.path.basename(entry)).split(".")[-1] in image_suffixes
        else entry
        for entry in file_list
    ]
def zip_extract_member_new(self, member, targetpath, pwd):
    """Extract one zip member to disk, repairing GBK-encoded file names.

    Variant of ``zipfile.ZipFile._extract_member`` that re-reads legacy
    (non-UTF-8) member names as GBK so that Chinese file names are not
    garbled on extraction.

    Args:
        self: The open ``zipfile.ZipFile`` the member belongs to.
        member: A ``zipfile.ZipInfo`` or the name of a member.
        targetpath: Directory below which the member is written.
        pwd: Password used to open the member, or ``None``.

    Returns:
        The normalized path of the extracted file or directory.
    """
    import zipfile

    if not isinstance(member, zipfile.ZipInfo):
        member = self.getinfo(member)

    # Build the destination pathname, replacing forward slashes with the
    # platform-specific separator.
    arcname = member.filename.replace('/', os.path.sep)

    # zipfile decodes member names as cp437 unless general-purpose bit 11
    # (0x800) marks them as UTF-8. Only the cp437-decoded names need to be
    # re-read as GBK; transcoding a name that is already correct UTF-8 would
    # turn every non-ASCII character into '?'.
    if not member.flag_bits & 0x800:
        arcname = arcname.encode('cp437', errors='replace').decode('gbk', errors='replace')

    if os.path.altsep:
        arcname = arcname.replace(os.path.altsep, os.path.sep)

    # Interpret absolute pathnames as relative: remove the drive letter or
    # UNC path, redundant separators, and "." / ".." components.
    arcname = os.path.splitdrive(arcname)[1]
    invalid_path_parts = ('', os.path.curdir, os.path.pardir)
    arcname = os.path.sep.join(x for x in arcname.split(os.path.sep)
                               if x not in invalid_path_parts)
    if os.path.sep == '\\':
        # Filter characters that are illegal in Windows file names.
        arcname = self._sanitize_windows_name(arcname, os.path.sep)

    targetpath = os.path.join(targetpath, arcname)
    targetpath = os.path.normpath(targetpath)

    # Create all upper directories if necessary.
    upperdirs = os.path.dirname(targetpath)
    if upperdirs and not os.path.exists(upperdirs):
        os.makedirs(upperdirs)

    if member.is_dir():
        if not os.path.isdir(targetpath):
            os.mkdir(targetpath)
        return targetpath

    with self.open(member, pwd=pwd) as source, \
            open(targetpath, "wb") as target:
        shutil.copyfileobj(source, target)

    return targetpath
def safe_extract_rar(file_path, dest_dir):
    """Validate every member of a RAR archive, then extract it.

    All members are checked before anything is written: a member is rejected
    when its name contains a ``..`` path component, when it is a symbolic
    link, or when its resolved path is not inside ``dest_dir``.

    Args:
        file_path: Path of the ``.rar`` archive.
        dest_dir: Directory to extract into; created when missing.

    Raises:
        Exception: If a member attempts path traversal or is a symlink.
        ImportError: If the third-party ``rarfile`` package is not installed.
    """
    import rarfile
    import posixpath

    with rarfile.RarFile(file_path) as rf:
        os.makedirs(dest_dir, exist_ok=True)
        base_path = os.path.abspath(dest_dir)
        # Compare against the directory *with* a trailing separator so that a
        # sibling such as "<dest>_other" cannot pass the prefix test.
        base_prefix = os.path.join(base_path, "")

        for file_info in rf.infolist():
            orig_filename = file_info.filename
            filename = posixpath.normpath(orig_filename).lstrip('/')

            # Path traversal guard: reject ".." used as a path component
            # (with either separator style) while still allowing names that
            # merely contain consecutive dots, e.g. "notes..txt".
            if '..' in filename.replace('\\', '/').split('/'):
                raise Exception(f"Attempted Path Traversal in {orig_filename}")

            # Symlink guard.
            if hasattr(file_info, 'is_symlink') and file_info.is_symlink():
                raise Exception(f"Attempted Symlink in {orig_filename}")

            # Build the full destination path.
            target_path = os.path.join(base_path, filename)
            final_path = os.path.normpath(target_path)

            # Final containment check: the destination itself or anything
            # strictly below it is acceptable.
            if final_path != base_path and not final_path.startswith(base_prefix):
                raise Exception(f"Attempted Path Traversal in {orig_filename}")

        rf.extractall(dest_dir)
def extract_archive(file_path, dest_dir):
    """Extract an archive into ``dest_dir`` according to its file extension.

    Supported extensions are ``.zip``, ``.tar``, ``.gz``, ``.bz2``, ``.rar``
    and ``.7z``. A ``.gz`` file that is not a tar archive is decompressed to
    ``main.tex`` inside ``dest_dir``.

    Args:
        file_path: Path of the archive.
        dest_dir: Directory to extract into.

    Returns:
        ``""`` on success or for an unsupported extension; an HTML-formatted
        failure message when ``.rar`` or ``.7z`` extraction fails.

    Raises:
        Exception: If a tar member is a link or would land outside
            ``dest_dir``.
        tarfile.ReadError: If a ``.tar`` or ``.bz2`` file cannot be read.
    """
    import zipfile
    import tarfile

    # Pick the extraction strategy from the file extension.
    file_extension = os.path.splitext(file_path)[1]

    if file_extension == ".zip":
        with zipfile.ZipFile(file_path, "r") as zipobj:
            # Swap in the member extractor that fixes garbled Chinese names.
            zipobj._extract_member = lambda a,b,c: zip_extract_member_new(zipobj, a,b,c)
            zipobj.extractall(path=dest_dir)
            logger.info("Successfully extracted zip archive to {}".format(dest_dir))

    elif file_extension in [".tar", ".gz", ".bz2"]:
        try:
            with tarfile.open(file_path, "r:*") as tarobj:
                dest_root = os.path.abspath(dest_dir)
                dest_prefix = os.path.join(dest_root, "")
                # Validate every member before extracting anything.
                for member in tarobj.getmembers():
                    member_path = os.path.normpath(member.name)
                    full_path = os.path.abspath(os.path.join(dest_dir, member_path))
                    if member.islnk() or member.issym():
                        raise Exception(f"Attempted Symlink in {member.name}")
                    # The archive root entry "." resolves to dest_dir itself
                    # and is legitimate; everything else must be below it.
                    if full_path != dest_root and not full_path.startswith(dest_prefix):
                        raise Exception(f"Attempted Path Traversal in {member.name}")

                tarobj.extractall(path=dest_dir)
                logger.info("Successfully extracted tar archive to {}".format(dest_dir))
        except tarfile.ReadError:
            if file_extension == ".gz":
                # Some unusual projects ship a .gz that holds a single tex
                # file rather than a tar archive.
                import gzip
                with gzip.open(file_path, 'rb') as f_in:
                    with open(os.path.join(dest_dir, 'main.tex'), 'wb') as f_out:
                        f_out.write(f_in.read())
            else:
                raise

    # Requires the third-party package: pip install rarfile
    # On Windows, WinRAR must also be installed and its directory
    # (e.g. "C:\Program Files\WinRAR") added to the PATH.
    elif file_extension == ".rar":
        try:
            import rarfile  # availability check for rarfile; do not remove
            safe_extract_rar(file_path, dest_dir)
        except Exception:
            logger.info("Rar format requires additional dependencies to install")
            return "<br/><br/>解压失败! 需要安装pip install rarfile来解压rar文件。建议使用zip压缩格式。"

    # Requires the third-party package: pip install py7zr
    elif file_extension == ".7z":
        try:
            import py7zr
            with py7zr.SevenZipFile(file_path, mode="r") as f:
                f.extractall(path=dest_dir)
                logger.info("Successfully extracted 7z archive to {}".format(dest_dir))
        except Exception:
            logger.info("7z format requires additional dependencies to install")
            return "<br/><br/>解压失败! 需要安装pip install py7zr来解压7z文件"
    else:
        return ""
    return ""