150 lines
4.8 KiB
Python
150 lines
4.8 KiB
Python
from datetime import datetime
|
|
from sqlalchemy import select, func, desc, and_
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from app.models.message import Message
|
|
from app.services.wechat_client import send_msg
|
|
from app.config import settings
|
|
|
|
|
|
async def save_messages(db: AsyncSession, msg_list: list[dict]) -> int:
    """Store a batch of messages, de-duplicated on msg_id.

    Each message is inserted individually with
    ``ON CONFLICT (msg_id) DO NOTHING``, so messages already present in the
    table are skipped without error. A single commit is issued after the
    whole batch has been executed.

    Args:
        db: Async SQLAlchemy session to execute the inserts on.
        msg_list: Raw message dicts (presumably as returned by the WeChat
            customer-service sync API — confirm against the caller).

    Returns:
        How many of the messages were newly inserted.
    """
    inserted = 0
    for raw in msg_list:
        # A missing origin defaults to 3, which maps to "customer".
        origin_label = _normalize_origin(raw.get("origin", 3))
        # Only servicer-originated messages are outbound; customer, system
        # and unrecognised origins are all recorded as inbound.
        direction = "outbound" if origin_label == "servicer" else "inbound"
        row = dict(
            msg_id=str(raw.get("msgid") or raw.get("msg_id", "")),
            open_kfid=str(raw.get("open_kfid", settings.open_kfid)),
            external_userid=str(raw.get("external_userid", "")),
            servicer_userid=raw.get("servicer_userid"),
            send_time=datetime.fromtimestamp(raw.get("send_time", 0)),
            msgtype=str(raw.get("msgtype", "unknown")),
            origin=origin_label,
            content=_extract_text_content(raw),
            raw_data=raw,
            direction=direction,
            status="received",
        )
        stmt = (
            pg_insert(Message)
            .values(**row)
            .on_conflict_do_nothing(index_elements=["msg_id"])
        )
        outcome = await db.execute(stmt)
        # rowcount is 0 when the conflict clause suppressed the insert.
        if outcome.rowcount:
            inserted += 1

    await db.commit()
    return inserted
|
|
|
|
|
|
async def get_conversations(db: AsyncSession, open_kfid: str = "", limit: int = 50) -> list[dict]:
    """List conversations: one entry per external_userid with its latest message.

    Args:
        db: Async SQLAlchemy session.
        open_kfid: Customer-service account to list; falls back to
            ``settings.open_kfid`` when empty.
        limit: Maximum number of customers (conversations) returned.

    Returns:
        A list of dicts with keys ``external_userid``, ``latest_time``
        (ISO 8601 string), ``latest_content`` ("" when the stored content is
        null) and ``latest_msgtype``, ordered by ``latest_time`` descending.
        Each ``external_userid`` appears at most once.
    """
    kfid = open_kfid or settings.open_kfid

    # Subquery: the `limit` most recently active customers and the
    # send_time of each one's newest message.
    latest = (
        select(
            Message.external_userid,
            func.max(Message.send_time).label("latest_time"),
        )
        .where(Message.open_kfid == kfid)
        .group_by(Message.external_userid)
        .order_by(desc("latest_time"))
        .limit(limit)
        .subquery()
    )

    q = (
        select(Message)
        .join(
            latest,
            and_(
                Message.external_userid == latest.c.external_userid,
                Message.send_time == latest.c.latest_time,
            ),
        )
        .where(Message.open_kfid == kfid)
        # Secondary keys make the choice among tied rows deterministic.
        .order_by(desc(Message.send_time), Message.external_userid, Message.msg_id)
    )
    result = await db.execute(q)
    rows = result.scalars().all()

    conversations = []
    seen_users = set()
    for r in rows:
        # Several messages from one customer can share the same latest
        # send_time, so the join can yield more than one row per customer;
        # keep only the first so each conversation is listed once.
        if r.external_userid in seen_users:
            continue
        seen_users.add(r.external_userid)
        conversations.append(
            {
                "external_userid": r.external_userid,
                "latest_time": r.send_time.isoformat(),
                "latest_content": r.content or "",
                "latest_msgtype": r.msgtype,
            }
        )
    return conversations
|
|
|
|
|
|
async def get_messages(db: AsyncSession, external_userid: str,
                       open_kfid: str = "", page: int = 1,
                       page_size: int = 50) -> list[dict]:
    """Return one page of a customer's messages in ascending time order.

    Args:
        db: Async SQLAlchemy session.
        external_userid: Customer whose messages are fetched.
        open_kfid: Customer-service account; falls back to
            ``settings.open_kfid`` when empty.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Messages per page; negative values are treated as 0.

    Returns:
        A list of dicts with keys ``msg_id``, ``content``, ``msgtype``,
        ``send_time`` (ISO 8601 string), ``origin`` and ``direction``.
    """
    kfid = open_kfid or settings.open_kfid
    # PostgreSQL rejects a negative OFFSET or LIMIT, so clamp the inputs.
    page = max(page, 1)
    page_size = max(page_size, 0)
    offset = (page - 1) * page_size

    q = (
        select(Message)
        .where(
            Message.open_kfid == kfid,
            Message.external_userid == external_userid,
        )
        # send_time alone is not unique; the msg_id tie-breaker gives a total
        # order so OFFSET pagination doesn't repeat or skip tied rows.
        .order_by(Message.send_time.asc(), Message.msg_id.asc())
        .offset(offset)
        .limit(page_size)
    )
    result = await db.execute(q)
    rows = result.scalars().all()

    return [
        {
            "msg_id": r.msg_id,
            "content": r.content,
            "msgtype": r.msgtype,
            "send_time": r.send_time.isoformat(),
            "origin": r.origin,
            "direction": r.direction,
        }
        for r in rows
    ]
|
|
|
|
|
|
async def send_and_save(db: AsyncSession, external_userid: str, content: str,
                        open_kfid: str = "") -> dict:
    """Send a text message through WeChat and record it in the database.

    The outbound message row is written whether or not the send succeeded:
    its ``status`` is "sent" when the response has ``errcode == 0`` and
    "failed" otherwise (a missing errcode counts as failure).

    Args:
        db: Async SQLAlchemy session; committed before returning.
        external_userid: Recipient customer.
        content: Text to send.
        open_kfid: Customer-service account; falls back to
            ``settings.open_kfid`` when empty.

    Returns:
        The raw response dict from ``send_msg``.
    """
    import uuid  # only needed for the fallback msg_id below

    kfid = open_kfid or settings.open_kfid
    result = await send_msg(
        touser=external_userid,
        open_kfid=kfid,
        msgtype="text",
        content=content,
    )

    errcode = result.get("errcode", -1)
    status = "sent" if errcode == 0 else "failed"

    now = datetime.now()
    # When the response carries no msgid (e.g. a failed send), synthesise one.
    # The random suffix keeps two fallbacks created within the same second
    # from colliding on the msg_id unique index.
    msg_id = result.get("msgid") or f"out_{int(now.timestamp())}_{uuid.uuid4().hex}"

    msg = Message(
        msg_id=msg_id,
        open_kfid=kfid,
        external_userid=external_userid,
        servicer_userid="",
        send_time=now,
        msgtype="text",
        origin="servicer",
        content=content,
        raw_data=result,
        direction="outbound",
        status=status,
    )
    db.add(msg)
    await db.commit()
    return result
|
|
|
|
|
|
def _normalize_origin(origin) -> str:
|
|
"""将微信 API 返回的 origin 整数转为字符串"""
|
|
if isinstance(origin, int):
|
|
return {3: "customer", 4: "system", 5: "servicer"}.get(origin, str(origin))
|
|
return str(origin)
|
|
|
|
|
|
def _extract_text_content(msg: dict) -> str:
|
|
"""从消息中提取文本内容"""
|
|
if msg.get("msgtype") == "text":
|
|
text_data = msg.get("text", {})
|
|
return text_data.get("content", "") if isinstance(text_data, dict) else str(text_data)
|
|
return ""
|