257 lines
8.8 KiB
Python
257 lines
8.8 KiB
Python
from __future__ import annotations
|
||
|
||
# The event-loop policy applied here only takes effect for plain asyncio. On
# win32, Uvicorn passes its own Proactor-based class as loop_factory, so this
# policy does not apply there. Start the service with ``python -m tg_bridge`` or
# ``uvicorn ... --loop tg_bridge.uvicorn_loop:selector_loop_factory``.
# NOTE: this runs before the remaining imports, presumably so the policy is in
# place before anything creates an event loop — keep it at the top.
from tg_bridge.winloop import apply_windows_selector_policy

apply_windows_selector_policy()
|
||
|
||
import asyncio
|
||
import logging
|
||
from contextlib import asynccontextmanager
|
||
from typing import Annotated, Any
|
||
|
||
from fastapi import Depends, FastAPI, Header, HTTPException
|
||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||
from pydantic import BaseModel, Field
|
||
from telethon import TelegramClient
|
||
from telethon.errors import RPCError
|
||
|
||
from tg_bridge.client_factory import create_telegram_client
|
||
from tg_bridge.config import Settings
|
||
from tg_bridge.http_logging import AccessLogMiddleware, log_preview_max_chars, preview_for_log
|
||
from tg_bridge.logging_setup import setup_logging
|
||
|
||
# Configure logging (tg_bridge.logging_setup) at import time, before anything
# below emits log records.
setup_logging()

logger = logging.getLogger(__name__)

# Cached settings; populated lazily by _get_settings() and None until then.
_settings: Settings | None = None

# Telegram client managed by the FastAPI lifespan; None before startup and
# after shutdown.
_client: TelegramClient | None = None

# Question/answer exchanges with a bot must be serialized; otherwise the
# Conversation mixes up replies or raises AlreadyInConversationError.
_telegram_bridge_lock = asyncio.Lock()

# Bearer-token extractor. With auto_error=False a missing/non-bearer
# Authorization header yields None instead of an error, so the
# X-Bridge-Token header can be checked as an alternative.
security = HTTPBearer(auto_error=False)
|
||
|
||
|
||
def _extract_bot_reply_text(message: Any) -> str:
|
||
if message is None:
|
||
return ""
|
||
raw = getattr(message, "text", None) or getattr(message, "message", None) or ""
|
||
return raw.strip() if isinstance(raw, str) else ""
|
||
|
||
|
||
def _get_settings() -> Settings:
    """Return the process-wide settings, calling ``Settings.load()`` on first use.

    The loaded instance is cached in the module global ``_settings`` and
    reused by every later call. If loading raises, nothing is cached.
    """
    global _settings
    cached = _settings
    if cached is None:
        cached = _settings = Settings.load()
    return cached
|
||
|
||
|
||
async def _ensure_client() -> TelegramClient:
    """Return the live Telegram client or abort the request with HTTP 503.

    Raises:
        HTTPException: 503 when the module-level client is not set, e.g.
            before lifespan startup completed or after shutdown cleared it.
    """
    client = _client
    if client is not None:
        return client
    raise HTTPException(status_code=503, detail="Telegram 客户端未初始化")
|
||
|
||
|
||
def _verify_bridge_auth(
    creds: HTTPAuthorizationCredentials | None,
    x_bridge_token: str | None,
) -> None:
    """Check the caller's token against the configured ``bridge_token``.

    Authentication is disabled when no ``bridge_token`` is configured. Otherwise
    the request is accepted if either the ``Authorization: Bearer <token>``
    credentials or the ``X-Bridge-Token`` header equal the configured token.

    Tokens are compared with :func:`hmac.compare_digest` so the time taken does
    not depend on how much of a guessed token is correct. Both sides are
    compared as UTF-8 bytes because ``compare_digest`` only accepts ASCII
    ``str`` values and header values may contain arbitrary characters.

    Raises:
        HTTPException: 401 when a token is configured and neither source matches.
    """
    import hmac

    expected = _get_settings().bridge_token
    if not expected:
        return

    # "surrogatepass" keeps the encoding total (never raises) and injective,
    # so byte equality is exactly string equality.
    expected_bytes = expected.encode("utf-8", "surrogatepass")

    def _matches(candidate: str | None) -> bool:
        if not candidate:
            return False
        return hmac.compare_digest(
            candidate.encode("utf-8", "surrogatepass"), expected_bytes
        )

    if creds and creds.scheme.lower() == "bearer" and _matches(creds.credentials):
        return
    if _matches(x_bridge_token):
        return
    raise HTTPException(status_code=401, detail="无效或未提供鉴权")
|
||
|
||
|
||
@asynccontextmanager
async def _lifespan(app: FastAPI):
    """Keep a connected Telegram client for the lifetime of the FastAPI app.

    Startup logs whether BRIDGE_TOKEN auth and a Telegram proxy are configured,
    creates the client with ``create_telegram_client``, connects it, and refuses
    to start when the stored session is not authorised. The client is published
    in the module-level ``_client`` only after the session has been verified.

    The ``try``/``finally`` guarantees that ``_client`` is cleared and
    ``disconnect()`` is awaited both on normal shutdown and when any startup
    step (``connect``, ``is_user_authorized``, ``get_me``) raises, so a failed
    start never leaves a half-open connection or a stale global behind.

    Raises:
        RuntimeError: when the Telegram session is not authorised.
    """
    global _client

    s = _get_settings()
    if s.bridge_token:
        logger.info("已启用 BRIDGE_TOKEN,请求需携带 Bearer 或 X-Bridge-Token")
    else:
        logger.warning("未设置 BRIDGE_TOKEN,/v1/forward 对同网段可达者开放,生产环境请设置")

    if s.proxy_type and s.proxy_host and s.proxy_port:
        logger.info(
            "Telegram 使用代理: type=%s host=%s port=%s rdns=%s connection=%s timeout=%ss",
            s.proxy_type,
            s.proxy_host,
            s.proxy_port,
            s.proxy_rdns,
            s.connection_mode,
            s.connect_timeout,
        )

    client = create_telegram_client(s)
    try:
        await client.connect()
        if not await client.is_user_authorized():
            raise RuntimeError(
                "Telegram 会话未授权。请在 wx_python 目录执行: python -m tg_bridge.login_cli"
            )
        me = await client.get_me()
        logger.info("Telegram 已连接: user_id=%s username=%s", me.id, me.username)
        # Publish only a verified client so handlers never see a half-initialised one.
        _client = client
        yield
    finally:
        # Clear the global first so request handlers stop picking up a client
        # that is being torn down, then always release the connection.
        _client = None
        await client.disconnect()
|
||
|
||
|
||
# ASGI application. _lifespan connects the Telegram client on startup and
# disconnects it on shutdown.
app = FastAPI(title="tg_bridge", version="0.1.0", lifespan=_lifespan)
# Access-log middleware from tg_bridge.http_logging; presumably logs each HTTP
# request/response — see that module for the exact format.
app.add_middleware(AccessLogMiddleware)
|
||
|
||
|
||
class ForwardBody(BaseModel):
    """转发到 Telegram Bot 的请求体。"""

    # NOTE: the docstring and every description below are published in the
    # OpenAPI schema, so they are runtime-visible text.

    # Message to forward; required and at least one character long.
    text: str = Field(default=..., min_length=1, description="要发送的文本")

    # Target bot username; omitted means the first configured bot.
    bot: str | None = Field(
        default=None,
        description="目标 Bot 用户名(与 .env 中配置的某一个一致);省略则发往列表中的第一个",
    )

    # Optional caller context (e.g. a WeCom UserID) prefixed to the message.
    context: str | None = Field(
        default=None,
        description="可选,如企业微信 UserID,会加在正文前便于区分来源",
    )

    # When true, wait for the bot's reply and return it as reply_text.
    wait_reply: bool = Field(
        default=True,
        description="为 true 时等待 Bot 下一条文本回复,并放入 reply_text(企业微信桥接常用)",
    )

    # Per-reply wait in seconds (5..600); None falls back to the server default.
    reply_timeout_sec: int | None = Field(
        default=None,
        ge=5,
        le=600,
        description="等待回复超时秒数,默认使用服务端 TELEGRAM_BOT_REPLY_TIMEOUT",
    )

    # Which of the bot's consecutive replies to return (1..20); None uses the server default.
    reply_take_nth: int | None = Field(
        default=None,
        ge=1,
        le=20,
        description="取 Bot 连续回复的第几条(1=第一条,2=第二条)。省略则用服务端 BOT_REPLY_TAKE_NTH",
    )
|
||
|
||
|
||
class ForwardResponse(BaseModel):
    # Result of POST /v1/forward. No docstring on purpose: it would become the
    # schema description. Field descriptions below are runtime-visible text.

    # Always True on success; failures are reported via HTTP errors instead.
    ok: bool = Field(default=True)
    # Status label; the endpoint leaves the default value in place.
    detail: str = Field(default="sent")
    # Bot reply text; only set when waiting for a reply produced text.
    reply_text: str | None = Field(
        default=None,
        description="Bot 回复正文;仅 wait_reply 为 true 且收到消息时可能有值",
    )
|
||
|
||
|
||
async def _auth_dep(
    creds: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
    x_bridge_token: Annotated[str | None, Header(alias="X-Bridge-Token")] = None,
):
    """FastAPI dependency enforcing bridge authentication.

    Collects the optional Bearer credentials and the ``X-Bridge-Token`` header
    and hands them to ``_verify_bridge_auth``, which raises HTTP 401 on a
    mismatch. Used only for that side effect; the return value is None.
    """
    return _verify_bridge_auth(creds=creds, x_bridge_token=x_bridge_token)
|
||
|
||
|
||
@app.get("/health")
async def health():
    """Report liveness plus a snapshot of the non-secret Telegram/bot settings.

    ``telegram_authorized`` is False when no client is connected or when the
    authorisation check raises; ``status`` is always ``"ok"``.
    """
    settings = _get_settings()
    client = _client

    authorized = False
    if client:
        try:
            authorized = await client.is_user_authorized()
        except Exception:
            # Best effort: a failing probe is reported as unauthorised rather
            # than turning the health check into an error response.
            authorized = False

    return dict(
        status="ok",
        telegram_authorized=authorized,
        bot=settings.default_bot_username,
        bots=list(settings.bot_usernames),
        telegram_proxy=bool(
            settings.proxy_type and settings.proxy_host and settings.proxy_port
        ),
        telegram_proxy_rdns=settings.proxy_rdns,
        telegram_connect_timeout_sec=settings.connect_timeout,
        telegram_connection=settings.connection_mode,
        bot_reply_timeout_sec=settings.bot_reply_timeout,
        default_bot_reply_take_nth=settings.bot_reply_take_nth,
    )
|
||
|
||
|
||
@app.post("/v1/forward", response_model=ForwardResponse)
async def forward_message(
    body: ForwardBody,
    _: Annotated[None, Depends(_auth_dep)],
):
    """将 text 用当前登录的 Telegram 个人号发给配置的 Bot。

    若设置 BRIDGE_TOKEN:请求头需 `Authorization: Bearer <token>` 或 `X-Bridge-Token: <token>`。
    """
    # (The docstring above is the OpenAPI description and is kept verbatim.)
    # Flow: resolve the target bot, build the payload (optionally prefixed with
    # "[wx:<context>]"), send it while holding the bridge lock, optionally wait
    # for the n-th bot reply, and map failures to HTTP errors:
    #   400 unknown bot / blank text, 504 reply timeout, 502 Telegram RPC error,
    #   500 anything else.
    s = _get_settings()
    client = await _ensure_client()

    try:
        target_bot = s.resolve_bot_username(body.bot)
    except ValueError as e:
        logger.warning("forward 请求被拒绝: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e

    text = body.text.strip()
    if not text:
        # min_length=1 on the model still admits whitespace-only text, which
        # strips to an empty message; Telethon refuses to send empty messages,
        # and the generic handler below would report that as a 500.
        reason = "text 不能为空白"
        logger.warning("forward 请求被拒绝: %s", reason)
        raise HTTPException(status_code=400, detail=reason)

    context = body.context.strip() if body.context else ""
    payload = f"[wx:{context}]\n{text}" if context else text

    prev_n = log_preview_max_chars()
    preview_part = (
        f" text_preview={preview_for_log(payload, prev_n)!r}" if prev_n else ""
    )
    logger.info(
        "forward 请求: bot=%s wait_reply=%s text_len=%d context=%s timeout_sec=%s take_nth=%s%s",
        target_bot,
        body.wait_reply,
        len(payload),
        bool(context),
        body.reply_timeout_sec,
        body.reply_take_nth,
        preview_part,
    )

    reply_text: str | None = None
    try:
        # Every outbound message goes through the lock, not just conversations:
        # a fire-and-forget send issued while a conversation is waiting could
        # make the bot's answer to it be read as that conversation's reply.
        async with _telegram_bridge_lock:
            if body.wait_reply:
                rt = float(body.reply_timeout_sec or s.bot_reply_timeout)
                nth = (
                    body.reply_take_nth
                    if body.reply_take_nth is not None
                    else s.bot_reply_take_nth
                )
                # Each reply gets its own timeout; the overall conversation
                # limit is widened so total_timeout does not fire first.
                total_rt = rt * float(nth) + 15.0
                async with client.conversation(
                    target_bot,
                    exclusive=True,
                    timeout=rt,
                    total_timeout=total_rt,
                ) as conv:
                    await conv.send_message(payload)
                    response = None
                    for _reply_index in range(nth):
                        response = await conv.get_response(timeout=rt)
                # _extract_bot_reply_text maps None/no-text to "", stored as None.
                reply_text = _extract_bot_reply_text(response) or None
            else:
                await client.send_message(target_bot, payload)
    except asyncio.TimeoutError:
        logger.warning("forward: bot=%s 等待 Bot 回复超时", target_bot)
        raise HTTPException(status_code=504, detail="等待 Bot 回复超时") from None
    except RPCError as e:
        logger.exception("Telegram RPC 失败")
        raise HTTPException(status_code=502, detail=str(e)) from e
    except Exception as e:
        logger.exception("发送失败")
        raise HTTPException(status_code=500, detail=str(e)) from e

    out = ForwardResponse(reply_text=reply_text)
    rlen = len(out.reply_text or "")
    if prev_n and out.reply_text:
        logger.info(
            "forward 响应: reply_len=%d reply_preview=%r",
            rlen,
            preview_for_log(out.reply_text, prev_n),
        )
    else:
        logger.info("forward 响应: reply_len=%d", rlen)
    return out
|