1.初始化项目 #
本项目基于 FastAPI 搭建,采用 SQLAlchemy 进行 ORM 数据库操作,适用于构建以「智能体」为中心的对话或多模型应用后端。以下为初始化项目与依赖配置的详细说明:
创建虚拟环境与依赖安装
推荐使用 uv 工具进行依赖和环境管理,兼容 npm/yarn 的开发体验,大幅加快 Python 依赖管理的效率。- 初始化项目目录(生成 pyproject.toml 等):
uv init - 安装所需依赖包:
uv add python-multipart fastapi uvicorn[standard] sqlalchemy ...
- 初始化项目目录(生成 pyproject.toml 等):
依赖说明
fastapi:主框架,支持异步 RESTful API 和自动生成 OpenAPI 文档,适合敏捷开发。uvicorn[standard]:主流高性能的 ASGI 服务器,用于运行 FastAPI 项目,[standard]字样表示包含优化依赖如 uvloop、httptools 等,提高吞吐量和响应速度。sqlalchemy:行业标准的 Python ORM,统一支持多种数据库,定义模型与数据结构。pymysql:MySQL 适配驱动,配合 SQLAlchemy 实现数据存取。cryptography:用于敏感信息的加解密保障安全。pydantic/pydantic-settings:用于数据模型的类型验证和自动配置环境变量,与 FastAPI 深度集成。httpx:现代化 HTTP 客户端,支持同步与异步请求,适配微服务或第三方 API 调用。python-dotenv:方便本地开发时读取.env环境变量文件。mcp[cli]:多模型协同平台的管理 CLI,可以集中便捷地管理各类 LLM 服务/插件/模型资源。langchain&langchain-deepseek:打造基于大模型(如 DeepSeek、大语言模型等)的复杂推理、记忆和工具调用能力。
目录结构建议
├── app/ │ ├── __init__.py │ ├── main.py # FastAPI 应用主入口 │ ├── models.py # ORM 数据模型定义 │ └── database.py # 数据库连接与Session管理 ├── main.py # 作为服务器启动入口(可选) ├── .env # 环境变量(开发/部署配置) ├── requirements.txt / pyproject.toml └── README.md开发与启动
- 初始化数据库表结构可通过 SQLAlchemy 自动生成。
- 使用如下命令启动开发服务器(默认热重载):
uvicorn app.main:app --reload - 自定义配置(端口、host、日志等)可在
uvicorn.run()或启动命令中指定。
uv init
uv add python-multipart fastapi uvicorn[standard] sqlalchemy pymysql cryptography pydantic pydantic-settings httpx python-dotenv mcp[cli] langchain langchain-deepseek| 模块名称 | 作用简介 |
|---|---|
| python-multipart | 处理 HTTP 的 multipart/form-data(常用于文件上传) |
| fastapi | 高性能 Web API 框架,支持异步、自动文档生成 |
| uvicorn[standard] | ASGI 服务器,运行 FastAPI 等异步 Python 应用,standard 包含性能增强依赖 |
| sqlalchemy | Python ORM 框架,简化数据库操作 |
| pymysql | MySQL 数据库的 Python 客户端 |
| cryptography | 常用的加密和安全功能库 |
| pydantic | 数据验证与解析库,基于 Python 类型注解 |
| pydantic-settings | 基于 pydantic 的配置管理工具 |
| httpx | 新一代 Python HTTP 客户端,支持异步和同步请求 |
| python-dotenv | 加载 .env 文件中的环境变量,便于配置管理 |
| mcp[cli] | 多模型协同平台的命令行工具,便于模型及资源管理 |
| langchain | 用于大语言模型(LLM)应用开发的框架 |
| langchain-deepseek | Langchain 的 deepseek 插件支持,便于集成 DeepSeek LLM |
2. 启动服务器 #
本项目推荐使用 FastAPI 作为 Web 服务主框架,并通过 Uvicorn 启动应用服务器,配合 SQLAlchemy 实现数据库管理。具体启动流程如下:
本地开发环境准备
确保 Python 环境和必需依赖已安装,可以参考前面提供的uv add ...安装命令。若有.env配置文件,建议同步环境变量。启动 API 服务方式
方法一:直接使用 Uvicorn 命令行
在项目根目录下运行:uvicorn app.main:app --reload --host 0.0.0.0 --port 8000其中:
app.main:app指定 FastAPI 实例的位置(app/main.py 里的app对象)。--reload让开发阶段的服务变动代码自动重启(建议开发环境开启,生产关闭)。- 可以通过
--port设置监听端口。
方法二:运行 main.py 脚本
也可直接运行根目录下的main.py,其内部会调用uvicorn.run启动服务(底层同样使用 Uvicorn)。python main.py
验证服务启动情况
启动后可访问 http://localhost:8000/docs 查看自动生成的 API 文档;
访问 http://localhost:8000/health 可验证健康检查接口是否正常响应:{ "status": "ok" }可选配置
- 可通过
.env文件自定义 FastAPI 的标题、描述、数据库地址等参数。 - 若需绑定公网 IP 或自定义域名,请确保服务器开放相关端口。
- 可通过
注意:
- 生产部署建议关闭
--reload,并可结合多进程服务(如 Gunicorn 搭配 Uvicorn worker)。- 若需对接前端/第三方系统,可提前规划跨域(CORS)、接口鉴权等中间件功能。
由此,可快速搭建起数据驱动、接口友好、可拓展的智能体后端服务。
2.1. init.py #
app/init.py
2.2. main.py #
app/main.py
# 导入FastAPI库
from fastapi import FastAPI
# 创建FastAPI应用实例,设置标题和版本号
app = FastAPI(title="智能体服务", version="0.1.0")
# 定义一个GET类型的/health路由用于健康检查
@app.get("/health")
def health():
# 返回服务状态为ok的JSON响应
return {"status": "ok"}2.3. main.py #
main.py
# 导入uvicorn库,用于运行ASGI服务器
+import uvicorn
# 判断当前模块是否为主模块(直接运行该脚本时为True)
if __name__ == "__main__":
# 启动uvicorn服务器,加载app.main模块中的app对象
# host设置为0.0.0.0,允许外部访问
# port设置为8000
# reload=True表示代码变动时自动重启服务(开发环境常用)
+ uvicorn.run(
+ "app.main:app",
+ host="0.0.0.0",
+ port=8000,
+ reload=True,
+ )2.4 .env #
DATABASE_URL=mysql+pymysql://root:root@127.0.0.1:3306/ai_agent?charset=utf8mb4
CORS_ORIGINS=http://localhost:5173,http://127.0.0.1:51733. 连接数据库 #
在「连接数据库」这一小节,我们将详细介绍如何在 FastAPI 项目中集成 SQLAlchemy 来操作 MySQL 数据库。
一般流程分为以下几步:
配置数据库连接
在.env和app/config.py中定义数据库连接字符串,便于统一管理和灵活切换。.env文件通过DATABASE_URL配置连接参数,不建议把账号密码硬编码到代码仓库。- 在
app/config.py中通过 pydantic-settings 自动加载.env配置,提供便捷的全局访问。
创建 SQLAlchemy 数据库引擎和会话
- 在
app/database.py中定义engine(连接池)、SessionLocal(数据库会话工厂)和Base(ORM模型基类)。 - 所有 ORM 模型需继承自
Base,以便 SQLAlchemy 能自动发现和创建数据表。
- 在
依赖注入数据库会话
- 在业务路由中,需要为每个请求提供一个独立的数据库会话。通常可通过 FastAPI 的依赖注入机制(
Depends)实现。 - 推荐定义
get_db函数:用yield生成 session,自动实现请求结束时的关闭和资源回收,避免连接泄漏。
- 在业务路由中,需要为每个请求提供一个独立的数据库会话。通常可通过 FastAPI 的依赖注入机制(
数据库表模型定义与迁移(可选)
- 以 Python 类的方式定义数据表结构,继承自
Base。 - 若需要自动迁移/同步表结构,可借助 Alembic 等工具完成。
- 以 Python 类的方式定义数据表结构,继承自
安全与性能建议
- 生产环境下建议使用更复杂的数据库账号密码,并控制连接池数量。
- 谨慎处理事务,避免长事务与死锁。
通过上述步骤,即可为智能体后端服务建立可靠的数据库访问能力,为后续业务开发打下坚实基础。
3.1 创建数据库 #
CREATE DATABASE ai_agent CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;3.2. .env #
.env
DATABASE_URL=mysql+pymysql://root:root@127.0.0.1:3306/ai_agent?charset=utf8mb4
CORS_ORIGINS=http://localhost:5173,http://127.0.0.1:51733.3. config.py #
app/config.py
# 导入Path用于路径处理(虽然本文件未直接用到)
from pathlib import Path
# 导入pydantic_settings中的BaseSettings和SettingsConfigDict用于配置管理
from pydantic_settings import BaseSettings, SettingsConfigDict
# 定义Settings类,继承自BaseSettings,便于环境变量和配置文件的读取
class Settings(BaseSettings):
# 配置SettingsConfigDict,指定.env文件路径及编码格式,并设置额外字段忽略
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
# 数据库连接字符串,默认连接本地mysql数据库
database_url: str = "mysql+pymysql://root:password@127.0.0.1:3306/agent?charset=utf8mb4"
# 允许跨域的前端地址,多个地址以逗号分隔
cors_origins: str = "http://localhost:5173,http://127.0.0.1:5173"
# 创建Settings实例,供项目其他部分导入使用
settings = Settings()
3.4. database.py #
app/database.py
# 导入SQLAlchemy的create_engine用于创建数据库引擎
from sqlalchemy import create_engine
# 导入sessionmaker用于会话创建,DeclarativeBase用于基类声明
from sqlalchemy.orm import sessionmaker, DeclarativeBase
# 从app.config中导入settings对象,读取配置信息
from app.config import settings
# 定义ORM模型的基类,所有模型都将继承该类
class Base(DeclarativeBase):
pass
# 创建数据库引擎,连接方式使用settings中的database_url
# pool_pre_ping=True用于防止数据库连接断开
# pool_recycle=3600设置连接池中连接的最大存活时间为3600秒
engine = create_engine(
settings.database_url,
pool_pre_ping=True,
pool_recycle=3600,
)
# 创建会话工厂,autocommit=False表示手动提交事务
# autoflush=False表示不自动刷新
# bind=engine用于绑定数据库引擎
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 定义数据库会话的生成器,在依赖中使用
def get_db():
# 创建一个数据库会话实例
db = SessionLocal()
try:
# 使用yield返回会话对象
yield db
finally:
# 关闭会话,释放资源
db.close()3.5. main.py #
main.py
# 导入uvicorn库,用于运行ASGI服务器
import uvicorn
# 从sqlalchemy导入text,用于执行原生SQL语句
+from sqlalchemy import text
# 从app.database模块导入get_db,用于获取数据库会话
+from app.database import get_db
# 尝试与数据库建立连接
+try:
# 获取数据库会话生成器
+ db_gen = get_db()
# 获取数据库会话实例
+ db = next(db_gen)
# 执行一条简单的SQL语句以测试数据库连接
+ db.execute(text("SELECT 1"))
# 打印数据库连接正常的信息
+ print("数据库连接正常")
# 如果发生异常
+except Exception as e:
# 打印数据库连接失败的错误信息
+ print(f"数据库连接失败: {e}")
# 无论是否发生异常都会执行
+finally:
+ try:
# 尝试关闭数据库会话生成器,释放资源
+ db_gen.close()
+ except Exception:
# 如果关闭时发生异常则忽略
+ pass
# 判断是否为主程序运行入口
if __name__ == "__main__":
# 启动uvicorn服务器,加载app.main模块下的app实例
# host设置为0.0.0.0以便外部主机访问
# port设置为8000
# reload设置为True用于开发时自动重启
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=8000,
reload=True,
)4. 数据模型 #
以下是对数据模型的详细讲解:
本项目的数据模型主要使用SQLAlchemy的声明式ORM方式定义,所有模型都继承自 app.database 中的 Base 基类。
McpService 模型
用于存储第三方MCP服务的配置信息,主要字段说明如下:
- id:主键,自增,为每个服务分配唯一标识。
- name:服务名称,字符串类型,设置唯一约束(
unique=True)和索引(index=True),保证每个服务名称唯一并提升查询效率。 - description:服务描述,可为
NULL,用于详细说明服务作用等扩展信息。 - protocol:协议类型,如
HTTP,MQTT等,不能为空。 - config:用于保存服务的配置信息,采用 MySQL 的
JSON类型字段,可灵活扩展不同服务的配置参数。
该表设计可拓展性强,新增服务类型或配置信息时无需改动表结构。
LlmModel 模型
用于存储大模型 API 提供方的信息和其支持的大模型列表,每个字段解释如下:
- id:主键,自增。
- provider_name:提供方名称,唯一且有索引,如
OpenAI、Azure。 - provider_icon:提供方的图标地址,可以为
NULL,便于前端展示。 - api_base_url:API 的基础请求地址,为必填项。
- api_key:用于访问该 LLM 提供方 API 的密钥,为必填项。
- api_key_url:密钥申请或获取地址,为可选项,便于新用户配置。
- model_names:存储提供方所支持的具体模型列表,数据类型为
JSON,例如["gpt-3.5-turbo", "gpt-4"]。
这种设计便于后续动态添加更多模型或多个提供方,同时支持前端读取支持模型以动态渲染 UI。
数据类型说明
- 采用
Mapped[...]注解配合mapped_column明确字段类型,兼容 SQLAlchemy 2.0+。 - 字符串长度(如
String(255),String(1024))根据实际业务场景预留充足长度。 JSON类型便于存储动态配置信息和模型名称列表(仅在支持 JSON 的数据库后端如 MySQL 5.7+ 有效)。- 支持可空字段均加了
nullable=True标记。
示例数据
McpService示例:{ "id": 1, "name": "mqtt-service", "description": "用于IoT设备通讯的MQTT服务", "protocol": "MQTT", "config": { "host": "localhost", "port": 1883, "username": "user1", "password": "*****" } }LlmModel示例:{ "id": 1, "provider_name": "OpenAI", "provider_icon": "https://example.com/icon.png", "api_base_url": "https://api.openai.com/v1/", "api_key": "sk-****************", "api_key_url": "https://platform.openai.com/account/api-keys", "model_names": ["gpt-3.5-turbo", "gpt-4"] }
通过上述模型设计,可以方便扩展系统对接多种服务和大模型能力,并具备良好的灵活性与可维护性。
4.1. models.py #
app/models.py
# 导入datetime用于处理日期和时间
from datetime import datetime
# 导入SQLAlchemy中的字段类型和时间函数
from sqlalchemy import DateTime, String, Text, func
# 导入MySQL方言中的JSON类型
from sqlalchemy.dialects.mysql import JSON
# 导入ORM映射辅助工具
from sqlalchemy.orm import Mapped, mapped_column
# 导入数据库基类
from app.database import Base
# 定义McpService服务模型
class McpService(Base):
# 设置对应的表名
__tablename__ = "mcp_services"
# 主键ID,自增
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
# 服务名,唯一且有索引
name: Mapped[str] = mapped_column(String(255), unique=True, index=True)
# 服务描述,允许为null
description: Mapped[str | None] = mapped_column(Text, nullable=True)
# 协议类型,必填
protocol: Mapped[str] = mapped_column(String(32), nullable=False)
# 配置信息,JSON格式,必填
config: Mapped[dict] = mapped_column(JSON, nullable=False)
# 定义LlmModel大模型提供方模型
class LlmModel(Base):
# 设置表名
__tablename__ = "llm_models"
# 主键ID,自增
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
# 提供方名称,唯一且有索引
provider_name: Mapped[str] = mapped_column(String(255), unique=True, index=True)
# 图标地址,允许为null
provider_icon: Mapped[str | None] = mapped_column(Text, nullable=True)
# API基础URL,必填
api_base_url: Mapped[str] = mapped_column(String(1024), nullable=False)
# API密钥,必填
api_key: Mapped[str] = mapped_column(String(1024), nullable=False)
# 密钥获取地址,可空
api_key_url: Mapped[str | None] = mapped_column(String(1024), nullable=True)
# 支持的模型名列表,JSON格式,必填
model_names: Mapped[list[str]] = mapped_column(JSON, nullable=False)
# 定义Agent智能体模型
class Agent(Base):
# 表名设置
__tablename__ = "agents"
# 主键、自增
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
# 头像,可空
avatar: Mapped[str | None] = mapped_column(Text, nullable=True)
# 名称,有索引
name: Mapped[str] = mapped_column(String(255), index=True)
# 描述,可空
description: Mapped[str | None] = mapped_column(Text, nullable=True)
# 开场消息,可空
opening_message: Mapped[str | None] = mapped_column(Text, nullable=True)
# 系统提示,必填
system_prompt: Mapped[str] = mapped_column(Text, nullable=False)
# LLM提供方名称,必填
llm_provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
# LLM模型名称,必填
llm_model_name: Mapped[str] = mapped_column(String(255), nullable=False)
# MCP服务ID列表,JSON格式,必填
mcp_service_ids: Mapped[list[int]] = mapped_column(JSON, nullable=False)
# 询问提示词(模板),可空
ask_prompt_template: Mapped[str | None] = mapped_column(Text, nullable=True)
# 询问变量,默认为空列表,JSON格式,不能为空
ask_variables: Mapped[list[dict]] = mapped_column(JSON, nullable=False, default=list)
# 定义智能体对话会话模型
class AgentChatSession(Base):
# 表名
__tablename__ = "agent_chat_sessions"
# 主键,自增
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
# 智能体ID,有索引且必填
agent_id: Mapped[int] = mapped_column(nullable=False, index=True)
# 会话标题,必填,默认“新对话”
title: Mapped[str] = mapped_column(String(255), nullable=False, default="新对话")
# 创建时间,默认当前时间
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), server_default=func.now(), nullable=False
)
# 更新时间,默认当前时间,修改时更新
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), server_default=func.now(), onupdate=func.now(), nullable=False
)
# 定义智能体对话消息模型
class AgentChatMessage(Base):
# 表名
__tablename__ = "agent_chat_messages"
# 主键,自增
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
# 会话ID,有索引且必填
session_id: Mapped[int] = mapped_column(nullable=False, index=True)
# 发送者角色(如user/agent),必填
role: Mapped[str] = mapped_column(String(32), nullable=False)
# 消息内容,必填
content: Mapped[str] = mapped_column(Text, nullable=False)
# 附加元信息,默认为空dict,JSON格式
meta: Mapped[dict] = mapped_column(JSON, nullable=False, default=dict)
# 消息创建时间,默认为当前时间
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), server_default=func.now(), nullable=False
)4.2. 表说明 #
| 表名 | 中文说明 |
|---|---|
mcp_services |
MCP 服务注册(名称唯一),存协议类型与连接配置 JSON |
llm_models |
大模型提供商配置(provider_name 唯一),含接口地址、密钥、可用模型列表 |
agents |
智能体配置:系统提示、绑定提供商/模型名、绑定的 MCP id 列表、询问变量与模板 |
agent_chat_sessions |
某智能体下的一次对话会话 |
agent_chat_messages |
会话内的单条消息(角色、正文、扩展元数据) |
4.2.1 mcp_services #
| 字段 | 类型 | 约束 | 说明 |
|---|---|---|---|
id |
int |
PK,自增 | 主键 |
name |
varchar(255) |
NOT NULL,唯一 | 服务展示/管理用名称 |
description |
text |
可空 | 服务说明 |
protocol |
varchar(32) |
NOT NULL | 传输协议:stdio / sse / streamable-http 等 |
config |
json |
NOT NULL | 连接参数(命令行、URL、headers、env 等) |
4.2.2 llm_models #
| 字段 | 类型 | 约束 | 说明 |
|---|---|---|---|
id |
int |
PK,自增 | 主键 |
provider_name |
varchar(255) |
NOT NULL,唯一 | 提供商/配置名称,与智能体里 llm_provider_name 对应 |
provider_icon |
text |
可空 | 图标 URL 等 |
api_base_url |
varchar(1024) |
NOT NULL | OpenAI 兼容 API 根地址 |
api_key |
varchar(1024) |
NOT NULL | 调用密钥(敏感) |
api_key_url |
varchar(1024) |
可空 | 申请密钥的页面等 |
model_names |
json |
NOT NULL | 该配置下模型 id 列表(JSON 数组) |
4.2.3 agents #
| 字段 | 类型 | 约束 | 说明 |
|---|---|---|---|
id |
int |
PK,自增 | 主键 |
avatar |
text |
可空 | 头像 URL 路径等 |
name |
varchar(255) |
NOT NULL,有索引 | 智能体名称 |
description |
text |
可空 | 描述 |
opening_message |
text |
可空 | 开场白文案 |
system_prompt |
text |
NOT NULL | 系统提示词 |
llm_provider_name |
varchar(255) |
NOT NULL | 对应 llm_models.provider_name(逻辑关联,非 FK) |
llm_model_name |
varchar(255) |
NOT NULL | 模型名,须在对应配置的 model_names 中可用 |
mcp_service_ids |
json |
NOT NULL | 绑定的 mcp_services.id 列表(JSON 数组) |
ask_prompt_template |
text |
可空 | 收集完变量后拼装用户侧提示的模板(可含 {{变量}}) |
ask_variables |
json |
NOT NULL | 变量定义列表(key、question、required 等),驱动多轮提问 |
4.2.4 agent_chat_sessions #
| 字段 | 类型 | 约束 | 说明 |
|---|---|---|---|
id |
int |
PK,自增 | 主键 |
agent_id |
int |
NOT NULL,有索引 | 所属智能体 agents.id(逻辑 FK) |
title |
varchar(255) |
NOT NULL | 会话标题,默认如「新对话」 |
created_at |
datetime |
NOT NULL,默认 now() |
创建时间 |
updated_at |
datetime |
NOT NULL,默认 now() |
更新时间(ORM 侧可能 onupdate,以库内实际为准) |
4.2.5 agent_chat_messages #
| 字段 | 类型 | 约束 | 说明 |
|---|---|---|---|
id |
int |
PK,自增 | 主键 |
session_id |
int |
NOT NULL,有索引 | 所属会话 agent_chat_sessions.id(逻辑 FK) |
role |
varchar(32) |
NOT NULL | 角色:user / assistant 等 |
content |
text |
NOT NULL | 消息正文(含模型输出、用户追问等) |
meta |
json |
NOT NULL | 扩展信息,如 kind:opening_message、ask_variable、ask_variable_answer 等 |
created_at |
datetime |
NOT NULL,默认 now() |
创建时间 |
4.2.6 逻辑关系 #
| 从 | 到 | 关联方式 |
|---|---|---|
agent_chat_sessions.agent_id |
agents.id |
整数引用 |
agent_chat_messages.session_id |
agent_chat_sessions.id |
整数引用 |
agents.llm_provider_name |
llm_models.provider_name |
字符串匹配 |
agents.mcp_service_ids[] |
mcp_services.id |
JSON 数组中的 id |
4.3. ER 图 #
agent_chat_sessions.agent_id→agents.idagent_chat_messages.session_id→agent_chat_sessions.idagents.llm_provider_name↔llm_models.provider_name(逻辑)agents.mcp_service_ids内含mcp_services.id(逻辑 N:M)
4.4. 高层关系 #
5. 建表 #
本节将详细说明如何通过 SQLAlchemy 数据模型管理上述数据表结构。
SQLAlchemy 数据模型定义
在 app/models.py 文件中,你需要为每个表定义一个对应的 SQLAlchemy ORM 类。例如:
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, ForeignKey
from sqlalchemy.orm import declarative_base, relationship
import datetime
Base = declarative_base()
class Agent(Base):
__tablename__ = "agents"
class AgentChatSession(Base):
__tablename__ = "agent_chat_sessions"
class AgentChatMessage(Base):
__tablename__ = "agent_chat_messages"自动建表
- 自动建表:如
main.py片段所示,Base.metadata.create_all(bind=engine)可在应用启动时自动创建表结构。
关联关系说明
agent_chat_sessions.agent_id通过外键指向agents.id,实现一个 Agent 关联多个 Session 的 1:N 关系。agent_chat_messages.session_id通过外键指向agent_chat_sessions.id,实现一个 Session 关联多条 Message 的 1:N 关系。agents.llm_provider_name与llm_models.provider_name通过名称字段逻辑绑定,无物理外键。agents.mcp_service_ids以 JSON 数组存储多个mcp_services.id,实现灵活关联(N:M,需应用层逻辑管理)。
5.1. main.py #
app/main.py
# 导入FastAPI库
from fastapi import FastAPI
# 导入logging库以便后续日志记录
+import logging
# 导入asynccontextmanager用于异步上下文管理器
+from contextlib import asynccontextmanager
# 从app.database模块导入Base和engine,用于数据库相关操作
+from app.database import Base, engine
# 导入app.models模块下的所有内容(类、函数等)
+from app.models import *
# 配置日志的基本设置,日志级别为INFO
+logging.basicConfig(level=logging.INFO)
# 定义一个异步上下文管理器,用于FastAPI生命周期
+@asynccontextmanager
+async def lifespan(app: FastAPI):
# 创建所有数据库表结构(如果未存在则自动创建)
+ Base.metadata.create_all(bind=engine)
# 通过yield挂起,等待应用关闭时进行清理
+ yield
# 创建FastAPI应用实例,设置API标题和版本号,并指定生命周期管理器
+app = FastAPI(title="智能体服务", version="0.1.0", lifespan=lifespan)
# 定义一个GET类型的/health路由用于健康检查
@app.get("/health")
def health():
# 返回服务状态为ok的JSON响应
return {"status": "ok"}5.2. main.py #
main.py
# 导入uvicorn库,用于运行ASGI服务器
import uvicorn
# 判断是否为主程序运行入口
if __name__ == "__main__":
# 启动uvicorn服务器,加载app.main模块下的app实例
# host设置为0.0.0.0以便外部主机访问
# port设置为8000
# reload设置为True用于开发时自动重启
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=8000,
reload=True,
)
6. 跨域 #
在FastAPI中实现跨域(CORS)支持,最常用的方法是引入CORSMiddleware中间件。这样可以确保你的API能够被浏览器中的前端应用安全地访问,尤其是在本地和线上环境存在不同域名或端口时。
- CORS(跨域资源共享):默认情况下,浏览器出于安全考虑会阻止网页访问不同源(协议、域名或端口不同)下的API。CORS是一种机制,允许服务端声明可被哪些源访问,从而实现安全的跨域请求。
- CORSMiddleware:FastAPI中集成的中间件,配置后自动为API响应添加适当的CORS头部信息。
- 导入中间件和配置
from fastapi.middleware.cors import CORSMiddlewarefrom app.config import settings
- 解析 CORS 允许的来源
- 通常会将允许的域名列表写在环境变量(示例:
settings.cors_origins),用英文逗号分隔。 - 通过列表推导式进行分割和清理空格、去除空字符串,得到最终的
origins列表。
- 通常会将允许的域名列表写在环境变量(示例:
- 注册中间件
- 使用
app.add_middleware(...)方法,把CORSMiddleware加到FastAPI应用上。 - 通常建议设置:
allow_origins: 可访问的源组成的列表(如开发阶段通常允许所有源,生产环境请精确配置)。allow_credentials: 是否允许cookie、认证等凭证。allow_methods与allow_headers均设为["*"],表示不限制方法和头部字段。
- 使用
示例场景
- 前端(如本地 http://localhost:3000)开发时访问本后端API
- 生产环境下只允许公司域名访问API
此配置提升了服务的灵活性和安全性。
常见问题
- 配置了 CORS 但仍报跨域错误?请检查:
- 前端请求地址(端口、协议等是否与允许列表对应)
allow_origins是否包含了请求源- nginx、网关等外层代理是否覆盖或删改了CORS相关头信息
推荐做法
- 开发环境:
cors_origins可设为*或http://localhost:3000等前端地址。 - 生产环境:
cors_origins应精确枚举允许的正式域名,防止被恶意第三方利用。
6.1. main.py #
app/main.py
# 导入FastAPI库
from fastapi import FastAPI
# 导入logging库以便后续日志记录
import logging
# 导入asynccontextmanager用于异步上下文管理器
from contextlib import asynccontextmanager
# 导入FastAPI的CORS中间件,用于跨域资源共享
+from fastapi.middleware.cors import CORSMiddleware
# 导入项目配置settings对象
+from app.config import settings
# 从app.database模块导入Base和engine,用于数据库相关操作
from app.database import Base, engine
# 导入app.models模块下的所有内容(类、函数等)
from app.models import *
# 配置日志的基本设置,日志级别为INFO
logging.basicConfig(level=logging.INFO)
# 定义一个异步上下文管理器,用于FastAPI生命周期
@asynccontextmanager
async def lifespan(app: FastAPI):
# 创建所有数据库表结构(如果未存在则自动创建)
Base.metadata.create_all(bind=engine)
# 通过yield挂起,等待应用关闭时进行清理
yield
# 创建FastAPI应用实例,设置API标题和版本号,并指定生命周期管理器
app = FastAPI(title="智能体服务", version="0.1.0", lifespan=lifespan)
# 解析配置中的CORS来源列表,去除空白项和空字符串
+origins = [o.strip() for o in settings.cors_origins.split(",") if o.strip()]
# 向FastAPI应用添加CORS中间件
+app.add_middleware(
+ CORSMiddleware,
# 允许访问的来源列表,如果为空则允许所有来源("*")
+ allow_origins=origins or ["*"],
# 允许携带cookie等凭证
+ allow_credentials=True,
# 允许所有HTTP方法
+ allow_methods=["*"],
# 允许所有HTTP头
+ allow_headers=["*"],
+)
# 定义一个GET类型的/health路由用于健康检查
@app.get("/health")
def health():
# 返回服务状态为ok的JSON响应
return {"status": "ok"}7. 添加MCP服务 #
本节我们将为项目添加 MCP 服务 (即多通道处理服务,Multi-Channel Processing Service)。这一部分内容主要包括:MCP 服务的数据模型定义、接口(API)定义、数据库操作方法(Repository)实现、以及路由(Router)注册流程。
主要内容如下:
数据模型(Model)和序列化(Schema)
- 在
app/models.py中增加McpService模型,描述 MCP 服务的数据结构及其字段(如名称、描述、协议类型、配置信息等)。 - 在
app/schemas.py中定义与模型对应的输入和输出序列化类,以用于数据校验和文档自动生成。
- 在
数据库操作(Repository)
- 在
app/repositories/mcp_repository.py中实现 MCP 服务的增删查改方法。比如create_mcp_service函数,用于插入新的 MCP 服务记录,该函数会接收数据库会话对象和待插入的数据对象,处理后将数据持久化到数据库。
- 在
接口路由(Router)
- 路由文件(比如
app/routers/mcp.py)定义 MCP 服务的相关 API 接口,如创建服务、查询服务列表、获取详情、删除服务等。这些接口会调用 repository 层实现实际的数据操作。 - 路由注册通常会在 FastAPI 主实例中(如
app/main.py)统一挂载。
- 路由文件(比如
依赖注入与请求/响应模型
- 结合 FastAPI 的依赖注入特性,通过参数注入
Session数据库会话、以及参数校验自动对接相关pydantic数据模型。
- 结合 FastAPI 的依赖注入特性,通过参数注入
具体开发流程
- 先设计并创建数据表和模型;
- 再实现操作数据库的 repository 层方法;
- 编写与之匹配的 schema;
- 最后写 API 路由和接口逻辑,并将路由注册到应用主程序。
本节内容有助于你了解如何使用 FastAPI 构建结构清晰、解耦良好的 RESTful 服务,便于后续的功能扩展和维护。你可以根据需求灵活调整服务字段及接口设计。
7.1. init.py #
app/repositories/init.py
from . import mcp_repository
__all__ = ["mcp_repository"]
7.2. mcp_repository.py #
app/repositories/mcp_repository.py
# 导入SQLAlchemy的select
from sqlalchemy import select
# 导入SQLAlchemy的Session对象
from sqlalchemy.orm import Session
# 导入项目中的models模块
from app import models
# 导入项目中的schemas模块
from app import schemas
# 定义创建McpService的函数,接收数据库会话db和待创建数据data,返回新建的McpService对象
def create_mcp_service(db: Session, data: schemas.McpServiceCreate) -> models.McpService:
# 构造McpService模型对象,strip去除名字首尾空白
row = models.McpService(
name=data.name.strip(),#去除名字首尾空白
description=data.description,#服务描述
protocol=data.protocol.value,#协议类型
config=data.config,#配置信息
)
# 添加新对象到会话
db.add(row)
# 提交事务,将更改保存到数据库
db.commit()
# 刷新实例,确保row包含数据库自动生成的字段值
db.refresh(row)
# 返回新建的McpService对象
return row
7.3. init.py #
app/routers/init.py
# routers
7.4. mcp_services.py #
app/routers/mcp_services.py
# 引入日志模块
import logging
# 从FastAPI导入路由器、依赖项和HTTP异常类
from fastapi import APIRouter, Depends, HTTPException
# 从SQLAlchemy导入唯一性错误异常
from sqlalchemy.exc import IntegrityError
# 从SQLAlchemy导入ORM会话对象
from sqlalchemy.orm import Session
# 导入应用程序的数据模型
from app import schemas
# 导入mcp_repository模块,包含数据库操作方法
from app.repositories import mcp_repository
# 导入用于获取数据库会话的依赖函数
from app.database import get_db
# 创建API路由器,设置前缀和标签
router = APIRouter(prefix="/api/mcp-services", tags=["mcp-services"])
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 定义创建服务的POST接口,响应模型为McpServiceOut
@router.post("", response_model=schemas.McpServiceOut)
def create_service(payload: schemas.McpServiceCreate, db: Session = Depends(get_db)):
# 使用try-except来捕捉数据库插入冲突
try:
# 调用仓库方法创建服务
return mcp_repository.create_mcp_service(db, payload)
# 捕获唯一性约束异常(如名称重复)
except IntegrityError:
# 回滚数据库会话,撤消操作
db.rollback()
# 抛出409冲突异常,返回错误信息
raise HTTPException(status_code=409, detail="名称已存在")
7.5. schemas.py #
app/schemas.py
# 导入枚举类型
from enum import Enum
# 导入Any类型用于类型注解
from typing import Any
# 从pydantic导入BaseModel、Field和field_validator用于数据验证
from pydantic import BaseModel, Field, field_validator
# 定义MCP协议类型的枚举类
class McpProtocol(str, Enum):
# MCP 协议类型注释
"""MCP 协议类型"""
# stdio协议
stdio = "stdio"
# streamable-http协议
streamable_http = "streamable-http"
# sse协议
sse = "sse"
# 定义MCP服务基础模型
class McpServiceBase(BaseModel):
# MCP 服务基础模型注释
"""MCP 服务基础模型"""
# 服务名称,字符串类型,长度1-255
name: str = Field(..., min_length=1, max_length=255)
# 服务描述,可选字段,字符串或None
description: str | None = None
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,确保其为字典类型
@field_validator("config", mode="before")
@classmethod
def config_not_empty(cls, v: Any) -> Any:
# 验证 config 是否为 JSON 对象,如果不是字典则抛出异常
"""验证 config 是否为 JSON 对象"""
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
return v
# 定义MCP服务创建模型,继承自McpServiceBase
class McpServiceCreate(McpServiceBase):
# 不增加额外内容,直接继承
pass
# 定义MCP服务输出模型
class McpServiceOut(BaseModel):
# ID字段,整型
id: int
# 服务名称
name: str
# 服务描述,可选字段
description: str | None
# 协议类型,字符串
protocol: str
# 配置信息,字典类型
config: dict[str, Any]
# 设置模型配置,允许从ORM对象属性读取数据
model_config = {"from_attributes": True}7.6. main.py #
app/main.py
# 导入FastAPI库
from fastapi import FastAPI
# 导入logging库以便后续日志记录
import logging
# 导入asynccontextmanager用于异步上下文管理器
from contextlib import asynccontextmanager
# 导入FastAPI的CORS中间件,用于跨域资源共享
from fastapi.middleware.cors import CORSMiddleware
# 导入项目配置settings对象
from app.config import settings
# 从app.database模块导入Base和engine,用于数据库相关操作
from app.database import Base, engine
# 导入app.models模块下的所有内容(类、函数等)
from app.models import *
# 导入MCP服务路由
+from app.routers import mcp_services
# 配置日志的基本设置,日志级别为INFO
logging.basicConfig(level=logging.INFO)
# 定义一个异步上下文管理器,用于FastAPI生命周期
@asynccontextmanager
async def lifespan(app: FastAPI):
# 创建所有数据库表结构(如果未存在则自动创建)
Base.metadata.create_all(bind=engine)
# 通过yield挂起,等待应用关闭时进行清理
yield
# 创建FastAPI应用实例,设置API标题和版本号,并指定生命周期管理器
app = FastAPI(title="智能体服务", version="0.1.0", lifespan=lifespan)
# 解析配置中的CORS来源列表,去除空白项和空字符串
origins = [o.strip() for o in settings.cors_origins.split(",") if o.strip()]
# 向FastAPI应用添加CORS中间件
app.add_middleware(
CORSMiddleware,
# 允许访问的来源列表,如果为空则允许所有来源("*")
allow_origins=origins or ["*"],
# 允许携带cookie等凭证
allow_credentials=True,
# 允许所有HTTP方法
allow_methods=["*"],
# 允许所有HTTP头
allow_headers=["*"],
)
# 包含MCP服务路由
+app.include_router(mcp_services.router)
# 定义一个GET类型的/health路由用于健康检查
@app.get("/health")
def health():
# 返回服务状态为ok的JSON响应
return {"status": "ok"}7.7 测试 #
curl --location --request POST "http://127.0.0.1:8000/api/mcp-services" --header "Content-Type: application/json" --data-raw "{\"name\": \"name\",\"description\": \"description\",\"protocol\": \"streamable-http\",\"config\": {\"url\": \"http://127.0.0.1:8002/mcp\",\"headers\": {\"BAIDU_MAP_AK\": \"xxx\",\"DEEPSEEK_API_KEY\": \"yyy\"}}}"
curl --location --request POST "http://127.0.0.1:8000/api/mcp-services" --header "Content-Type: application/json" --data-binary "@payload.json"payload.json
{
"name": "路线规划服务",
"description": "路线规划服务",
"protocol": "streamable-http",
"config": {
"url": "http://127.0.0.1:8002/mcp",
"headers": {
"BAIDU_MAP_AK": "e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX",
"DEEPSEEK_API_KEY": "sk-24088156e9ab48f3adddaf5a9c0c4ede"
}
}
}curl
- 命令行工具,用于发送 HTTP/HTTPS 等协议的请求。
--location
- 如果服务器返回 3xx 重定向响应,curl 会自动跟随重定向,向新的地址再次发起请求。
- 这里可能用不到,但加上可确保在服务端有重定向时请求仍能成功。
--request POST
- 显式指定 HTTP 方法为 POST。
- 如果省略,curl 会根据数据自动判断(例如有
--data时会默认为 POST),但显式声明更清晰。
"http://127.0.0.1:8000/api/mcp-services"
- 请求的目标 URL。
127.0.0.1是本机地址,端口8000,路径/api/mcp-services。- 通常这是一个 RESTful 接口,用于创建新的 MCP 服务。
`--header "Content-Type: application/json"
- 添加 HTTP 请求头,告知服务器请求体是 JSON 格式。
- 服务器一般会根据这个头来解析请求体内容。
`--data-raw "..."
- 指定请求体(body)的原始内容。
- 与
--data的区别:--data会将@开头的字符串解释为文件名,读取文件内容;--data-raw对@不作特殊处理,原样发送。
- 这里使用
--data-raw可以避免 JSON 中的@被误解析(虽然本例 JSON 中没有@)。 - 注意整个 JSON 被双引号包裹,内部的引号需要用反斜杠转义。
请求体(JSON 结构)
{
"name": "name",
"description": "description",
"protocol": "streamable-http",
"config": {
"url": "http://127.0.0.1:8002/mcp",
"headers": {
"BAIDU_MAP_AK": "xxx",
"DEEPSEEK_API_KEY": "yyy"
}
}
}name和description:服务的名称和描述。protocol:服务使用的协议,这里是streamable-http(一种可能用于流式传输的 HTTP 扩展)。config:服务的具体配置。url:服务实际监听的地址。headers:调用该服务时需要附加的 HTTP 头,包含百度地图 API Key 和 DeepSeek API Key(示例中占位为xxx和yyy)。
总结:该命令通过 POST 方式向本机 8000 端口的 MCP 服务管理接口发送一个 JSON 格式的服务定义,用于注册一个新的路线规划服务。
8. 查看MCP服务列表 #
在本节将讲解如何通过 API 查看已注册的所有 MCP 服务。
- 典型用法:获取所有已注册的路线规划服务列表。
示例 curl 命令
curl -X GET "http://127.0.0.1:8000/api/mcp-services"请求说明
-X GET:指定 HTTP 方法为 GET,获取资源。- URL:指向用于获取 MCP 服务列表的 API 路径。
响应格式
接口会返回 MCP 服务的 JSON 数组。每个元素包含服务的详细信息。例如:
[
{
"id": 1,
"name": "name",
"description": "description",
"protocol": "streamable-http",
"config": {
"url": "http://127.0.0.1:8002/mcp",
"headers": {
"BAIDU_MAP_AK": "xxx",
"DEEPSEEK_API_KEY": "yyy"
}
},
"created_at": "2024-05-01T12:00:00"
}
]字段说明
id:服务唯一标识。name:服务名称。description:服务描述。protocol:协议类型(如streamable-http)。config:配置信息(如服务调用地址和所需 header)。created_at:服务注册时间。
实现原理
- 后端路由
/api/mcp-services配置了 GET 方法的处理函数。 - 查询数据库,按
id倒序取出所有 MCP 服务。 - 使用 Pydantic Schema 自动转换数据库对象为 JSON 格式返回前端。
总结:通过该接口可以随时查看注册在 MCP 管理平台下的所有服务及其详情,便于统一运维和后续调用。
8.1. mcp_repository.py #
app/repositories/mcp_repository.py
# 导入SQLAlchemy的select
from sqlalchemy import select
# 导入SQLAlchemy的Session对象
from sqlalchemy.orm import Session
# 导入项目中的models模块
from app import models
# 导入项目中的schemas模块
from app import schemas
# 定义创建McpService的函数,接收数据库会话db和待创建数据data,返回新建的McpService对象
def create_mcp_service(db: Session, data: schemas.McpServiceCreate) -> models.McpService:
# 构造McpService模型对象,strip去除名字首尾空白
row = models.McpService(
name=data.name.strip(),#去除名字首尾空白
description=data.description,#服务描述
protocol=data.protocol.value,#协议类型
config=data.config,#配置信息
)
# 添加新对象到会话
db.add(row)
# 提交事务,将更改保存到数据库
db.commit()
# 刷新实例,确保row包含数据库自动生成的字段值
db.refresh(row)
# 返回新建的McpService对象
return row
# 定义获取所有McpService对象的函数,参数为数据库会话db,返回McpService对象列表
+def list_mcp_services(db: Session) -> list[models.McpService]:
# 构造按id倒序排序的查询,获取所有McpService记录
+ return list(db.scalars(select(models.McpService).order_by(models.McpService.id.desc())).all())
8.2. mcp_services.py #
app/routers/mcp_services.py
# 引入日志模块
import logging
# 从FastAPI导入路由器、依赖项和HTTP异常类
from fastapi import APIRouter, Depends, HTTPException
# 从SQLAlchemy导入唯一性错误异常
from sqlalchemy.exc import IntegrityError
# 从SQLAlchemy导入ORM会话对象
from sqlalchemy.orm import Session
# 导入应用程序的数据模型
from app import schemas
# 导入mcp_repository模块,包含数据库操作方法
from app.repositories import mcp_repository
# 导入用于获取数据库会话的依赖函数
from app.database import get_db
# 创建API路由器,设置前缀和标签
router = APIRouter(prefix="/api/mcp-services", tags=["mcp-services"])
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 定义GET方法,用于列出所有已注册的MCP服务,返回值为McpServiceOut的列表
+@router.get("", response_model=list[schemas.McpServiceOut])
+def list_services(db: Session = Depends(get_db)):
# 调用mcp_repository中的list_mcp_services方法,查询数据库中的所有服务
+ return mcp_repository.list_mcp_services(db)
# 定义POST方法,用于创建一个新的MCP服务,接收McpServiceCreate模式对象作为请求体
@router.post("", response_model=schemas.McpServiceOut)
def create_service(payload: schemas.McpServiceCreate, db: Session = Depends(get_db)):
try:
# 调用mcp_repository中的create_mcp_service方法,将新服务信息写入数据库并返回
return mcp_repository.create_mcp_service(db, payload)
except IntegrityError:
# 捕获唯一性约束异常,如服务名称已存在,进行回滚操作
db.rollback()
# 抛出HTTP异常,状态码409,表示名称冲突
+ raise HTTPException(status_code=409, detail="名称已存在")8.3. main.py #
main.py
# 导入uvicorn库,用于运行ASGI服务器
import uvicorn
# 判断是否为主程序运行入口
if __name__ == "__main__":
# 启动uvicorn服务器,加载app.main模块下的app实例
# host设置为0.0.0.0以便外部主机访问
# port设置为8000
# reload设置为True用于开发时自动重启
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=8000,
+ reload=False,
)
9. 查看MCP服务 #
实际开发中,通常我们需要按唯一主键(如ID)来获取某个具体的服务(即查询详情页)。相比“全部列表”,详情接口返回单条数据,通常会在前端页面的“详情弹窗”或“编辑表单”场景里用到。
在本项目中,你需要实现如下步骤:
Repository 层实现
在app/repositories/mcp_repository.py中,新增了get_mcp_service函数。它接收数据库会话db和服务主键mcp_id:- 核心语句
db.get(models.McpService, mcp_id),利用 SQLAlchemy 的get方法直接按主键查询数据,高效且简洁。 - 如果数据库存在该 ID 的记录,则返回模型实例;否则返回
None。
- 核心语句
Router 层接口实现
在app/routers/mcp_services.py中,新增了获取单个服务的GET路由:- 路径为
/api/mcp-services/{mcp_id},参数取自 URL 路径。 - 首先利用 repository 层的
get_mcp_service查询。 - 若结果不存在,使用 FastAPI 的
HTTPException抛出 404 状态码提示“记录不存在”;否则将记录返回,自动转换为输出 Schema。
- 路径为
接口使用示例
使用 curl 或 Postman 请求:GET /api/mcp-services/1返回值为指定 ID 的 MCP 服务详细信息,若找不到则返回 404 错误提示。
这种方式能更好地支持前端“点击查看详情”或“编辑实体”这种典型场景,编程实践中十分常见。
小提示: repository 层一般只负责数据库操作,不直接处理异常和 HTTP 相关逻辑,路由层负责捕获异常并给出联动响应。
9.1. mcp_repository.py #
app/repositories/mcp_repository.py
# 导入SQLAlchemy的select
from sqlalchemy import select
# 导入SQLAlchemy的Session对象
from sqlalchemy.orm import Session
# 导入项目中的models模块
from app import models
# 导入项目中的schemas模块
from app import schemas
# 定义创建McpService的函数,接收数据库会话db和待创建数据data,返回新建的McpService对象
def create_mcp_service(db: Session, data: schemas.McpServiceCreate) -> models.McpService:
# 构造McpService模型对象,strip去除名字首尾空白
row = models.McpService(
name=data.name.strip(),#去除名字首尾空白
description=data.description,#服务描述
protocol=data.protocol.value,#协议类型
config=data.config,#配置信息
)
# 添加新对象到会话
db.add(row)
# 提交事务,将更改保存到数据库
db.commit()
# 刷新实例,确保row包含数据库自动生成的字段值
db.refresh(row)
# 返回新建的McpService对象
return row
# 定义获取所有McpService对象的函数,参数为数据库会话db,返回McpService对象列表
def list_mcp_services(db: Session) -> list[models.McpService]:
# 构造按id倒序排序的查询,获取所有McpService记录
return list(db.scalars(select(models.McpService).order_by(models.McpService.id.desc())).all())
# 定义一个函数,根据传入的mcp_id从数据库中获取对应的McpService对象
# 参数db为数据库会话对象,mcp_id为服务的主键ID
# 如果找到则返回对应的McpService对象,否则返回None
+def get_mcp_service(db: Session, mcp_id: int) -> models.McpService | None:
# 调用SQLAlchemy的get方法根据主键查询McpService
+ return db.get(models.McpService, mcp_id)9.2. mcp_services.py #
app/routers/mcp_services.py
# 引入日志模块
import logging
# 从FastAPI导入路由器、依赖项和HTTP异常类
from fastapi import APIRouter, Depends, HTTPException
# 从SQLAlchemy导入唯一性错误异常
from sqlalchemy.exc import IntegrityError
# 从SQLAlchemy导入ORM会话对象
from sqlalchemy.orm import Session
# 导入应用程序的数据模型
from app import schemas
# 导入mcp_repository模块,包含数据库操作方法
from app.repositories import mcp_repository
# 导入用于获取数据库会话的依赖函数
from app.database import get_db
# 创建API路由器,设置前缀和标签
router = APIRouter(prefix="/api/mcp-services", tags=["mcp-services"])
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 定义GET方法,用于列出所有已注册的MCP服务,返回值为McpServiceOut的列表
@router.get("", response_model=list[schemas.McpServiceOut])
def list_services(db: Session = Depends(get_db)):
# 调用mcp_repository中的list_mcp_services方法,查询数据库中的所有服务
return mcp_repository.list_mcp_services(db)
# 定义POST方法,用于创建一个新的MCP服务,接收McpServiceCreate模式对象作为请求体
@router.post("", response_model=schemas.McpServiceOut)
def create_service(payload: schemas.McpServiceCreate, db: Session = Depends(get_db)):
try:
# 调用mcp_repository中的create_mcp_service方法,将新服务信息写入数据库并返回
return mcp_repository.create_mcp_service(db, payload)
except IntegrityError:
# 捕获唯一性约束异常,如服务名称已存在,进行回滚操作
db.rollback()
# 抛出HTTP异常,状态码409,表示名称冲突
raise HTTPException(status_code=409, detail="名称已存在")
# 定义一个GET类型的路由,用于根据mcp_id获取单个MCP服务,返回McpServiceOut模型
+@router.get("/{mcp_id}", response_model=schemas.McpServiceOut)
# 处理函数:根据传入的mcp_id和数据库会话读取指定的服务
+def get_service(mcp_id: int, db: Session = Depends(get_db)):
# 调用仓库方法获取对应id的MCP服务记录
+ row = mcp_repository.get_mcp_service(db, mcp_id)
# 如果记录不存在,则抛出404异常
+ if not row:
+ raise HTTPException(status_code=404, detail="记录不存在")
# 返回查找到的服务记录
+ return row curl --location --request GET "http://127.0.0.1:8000/api/mcp-services?mcp_id=1"10. 更新MCP服务 #
本节将介绍如何通过API接口实现对已有MCP服务信息的更新操作。通常我们使用HTTP的PUT或PATCH方法,通过指定服务的mcp_id并传递需要更新的数据,对已有的服务进行修改。此操作适用于修改服务的描述、配置、协议类型等字段。
接口说明
- 接口路径:
/api/mcp-services/{mcp_id} - 请求方法:
PUT(全部字段更新)或PATCH(部分字段更新,推荐) - 请求参数:
- 路径参数:
mcp_id(int) — MCP服务的唯一标识符 - 请求体:JSON格式,内容需符合
McpServiceUpdate或McpServicePatch模型
- 路径参数:
- 返回数据:更新后的MCP服务对象
典型请求示例
curl --location --request PATCH "http://127.0.0.1:8000/api/mcp-services/1" \
--header "Content-Type: application/json" \
--data-raw '{
"description": "新的服务描述",
"protocol": "http",
"config": {
"url": "https://example.com/updated"
}
}'典型响应
{
"id": 1,
"name": "MyMcp",
"description": "新的服务描述",
"protocol": "http",
"config": {
"url": "https://example.com/updated"
}
}注意事项
- 更新操作会根据传入的内容,按需更新MCP服务表中的相关字段。未提供的字段保持不变。
- 若更新的服务名称与数据库中已存在的其他服务名称冲突,会返回409冲突错误。
- 修改不存在的
mcp_id将返回404错误。
10.1. mcp_repository.py #
app/repositories/mcp_repository.py
# 导入SQLAlchemy的select
from sqlalchemy import select
# 导入SQLAlchemy的Session对象
from sqlalchemy.orm import Session
# 导入项目中的models模块
from app import models
# 导入项目中的schemas模块
from app import schemas
# 定义创建McpService的函数,接收数据库会话db和待创建数据data,返回新建的McpService对象
def create_mcp_service(db: Session, data: schemas.McpServiceCreate) -> models.McpService:
# 构造McpService模型对象,strip去除名字首尾空白
row = models.McpService(
name=data.name.strip(),#去除名字首尾空白
description=data.description,#服务描述
protocol=data.protocol.value,#协议类型
config=data.config,#配置信息
)
# 添加新对象到会话
db.add(row)
# 提交事务,将更改保存到数据库
db.commit()
# 刷新实例,确保row包含数据库自动生成的字段值
db.refresh(row)
# 返回新建的McpService对象
return row
# 定义获取所有McpService对象的函数,参数为数据库会话db,返回McpService对象列表
def list_mcp_services(db: Session) -> list[models.McpService]:
# 构造按id倒序排序的查询,获取所有McpService记录
return list(db.scalars(select(models.McpService).order_by(models.McpService.id.desc())).all())
# 定义一个函数,根据传入的mcp_id从数据库中获取对应的McpService对象
# 参数db为数据库会话对象,mcp_id为服务的主键ID
# 如果找到则返回对应的McpService对象,否则返回None
def get_mcp_service(db: Session, mcp_id: int) -> models.McpService | None:
# 调用SQLAlchemy的get方法根据主键查询McpService
return db.get(models.McpService, mcp_id)
# 定义一个函数,通过服务名称从数据库中获取对应的McpService对象
# 参数db为数据库会话对象,name为服务名称
# 如果找到则返回对应的McpService对象,否则返回None
+def get_by_name(db: Session, name: str) -> models.McpService | None:
# 构造查询,根据名称筛选McpService记录,并返回第一条结果
+ return db.scalar(select(models.McpService).where(models.McpService.name == name))
# 定义一个函数,更新指定的McpService对象内容
# 参数db为数据库会话对象,row为待更新的McpService对象,data为更新数据
# 返回更新后的McpService对象
+def update_mcp_service(db: Session, row: models.McpService, data: schemas.McpServiceUpdate) -> models.McpService:
# 如果更新数据中name字段不为空,则去除首尾空白后赋值
+ if data.name is not None:
+ row.name = data.name.strip()
# 如果更新数据中description字段不为空,则赋值
+ if data.description is not None:
+ row.description = data.description
# 如果更新数据中protocol字段不为空,则将其值转为字符串后赋值
+ if data.protocol is not None:
+ row.protocol = data.protocol.value
# 如果更新数据中config字段不为空,则赋值
+ if data.config is not None:
+ row.config = data.config
# 提交事务,将更改保存到数据库
+ db.commit()
# 刷新实例,确保row包含最新的字段值
+ db.refresh(row)
# 返回更新后的McpService对象
+ return row10.2. mcp_services.py #
app/routers/mcp_services.py
# 引入日志模块
import logging
# 从FastAPI导入路由器、依赖项和HTTP异常类
from fastapi import APIRouter, Depends, HTTPException
# 从SQLAlchemy导入唯一性错误异常
from sqlalchemy.exc import IntegrityError
# 从SQLAlchemy导入ORM会话对象
from sqlalchemy.orm import Session
# 导入应用程序的数据模型
from app import schemas
# 导入mcp_repository模块,包含数据库操作方法
from app.repositories import mcp_repository
# 导入用于获取数据库会话的依赖函数
from app.database import get_db
# 创建API路由器,设置前缀和标签
router = APIRouter(prefix="/api/mcp-services", tags=["mcp-services"])
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 定义GET方法,用于列出所有已注册的MCP服务,返回值为McpServiceOut的列表
@router.get("", response_model=list[schemas.McpServiceOut])
def list_services(db: Session = Depends(get_db)):
# 调用mcp_repository中的list_mcp_services方法,查询数据库中的所有服务
return mcp_repository.list_mcp_services(db)
# 定义POST方法,用于创建一个新的MCP服务,接收McpServiceCreate模式对象作为请求体
@router.post("", response_model=schemas.McpServiceOut)
def create_service(payload: schemas.McpServiceCreate, db: Session = Depends(get_db)):
try:
# 调用mcp_repository中的create_mcp_service方法,将新服务信息写入数据库并返回
return mcp_repository.create_mcp_service(db, payload)
except IntegrityError:
# 捕获唯一性约束异常,如服务名称已存在,进行回滚操作
db.rollback()
# 抛出HTTP异常,状态码409,表示名称冲突
raise HTTPException(status_code=409, detail="名称已存在")
# 定义一个GET类型的路由,用于根据mcp_id获取单个MCP服务,返回McpServiceOut模型
@router.get("/{mcp_id}", response_model=schemas.McpServiceOut)
# 处理函数:根据传入的mcp_id和数据库会话读取指定的服务
def get_service(mcp_id: int, db: Session = Depends(get_db)):
# 调用仓库方法获取对应id的MCP服务记录
row = mcp_repository.get_mcp_service(db, mcp_id)
# 如果记录不存在,则抛出404异常
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 返回查找到的服务记录
return row
# 定义PUT方法用于更新指定ID的MCP服务,返回更新后的服务对象
+@router.put("/{mcp_id}", response_model=schemas.McpServiceOut)
# 处理函数,参数包括服务ID、更新请求体、数据库会话
+def update_service(mcp_id: int, payload: schemas.McpServiceUpdate, db: Session = Depends(get_db)):
# 首先根据mcp_id从数据库查找对应的服务记录
+ row = mcp_repository.get_mcp_service(db, mcp_id)
# 如果未找到记录,则抛出404异常提示记录不存在
+ if not row:
+ raise HTTPException(status_code=404, detail="记录不存在")
# 如果更新请求中包含新的名称
+ if payload.name is not None:
# 通过名称查找是否存在其他服务
+ other = mcp_repository.get_by_name(db, payload.name.strip())
# 如果找到的其他服务ID与当前更新目标不同,说明名称已被占用
+ if other and other.id != mcp_id:
+ raise HTTPException(status_code=409, detail="名称已被其他记录使用")
+ try:
# 调用repository方法执行数据库更新并返回结果
+ return mcp_repository.update_mcp_service(db, row, payload)
+ except IntegrityError:
# 捕获唯一约束冲突,回滚事务
+ db.rollback()
# 抛出409异常提示名称冲突
+ raise HTTPException(status_code=409, detail="名称冲突") 10.3. schemas.py #
app/schemas.py
# 导入枚举类型
from enum import Enum
# 导入Any类型用于类型注解
from typing import Any
# 从pydantic导入BaseModel、Field和field_validator用于数据验证
from pydantic import BaseModel, Field, field_validator
# 定义MCP协议类型的枚举类
class McpProtocol(str, Enum):
# MCP 协议类型注释
"""MCP 协议类型"""
# stdio协议
stdio = "stdio"
# streamable-http协议
streamable_http = "streamable-http"
# sse协议
sse = "sse"
# 定义MCP服务基础模型
class McpServiceBase(BaseModel):
# MCP 服务基础模型注释
"""MCP 服务基础模型"""
# 服务名称,字符串类型,长度1-255
name: str = Field(..., min_length=1, max_length=255)
# 服务描述,可选字段,字符串或None
description: str | None = None
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,确保其为字典类型
@field_validator("config", mode="before")
@classmethod
def config_not_empty(cls, v: Any) -> Any:
# 验证 config 是否为 JSON 对象,如果不是字典则抛出异常
"""验证 config 是否为 JSON 对象"""
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
return v
# 定义MCP服务创建模型,继承自McpServiceBase
class McpServiceCreate(McpServiceBase):
# 不增加额外内容,直接继承
pass
# 定义MCP服务输出模型
class McpServiceOut(BaseModel):
# ID字段,整型
id: int
# 服务名称
name: str
# 服务描述,可选字段
description: str | None
# 协议类型,字符串
protocol: str
# 配置信息,字典类型
config: dict[str, Any]
# 设置模型配置,允许从ORM对象属性读取数据
model_config = {"from_attributes": True}
# 定义McpServiceUpdate模型,用于更新MCP服务,支持部分字段可选更新
+class McpServiceUpdate(BaseModel):
# 服务名称,允许为None,最小长度1,最大长度255
+ name: str | None = Field(None, min_length=1, max_length=255)
# 服务描述字段,允许为None
+ description: str | None = None
# 协议类型,使用McpProtocol枚举,允许为None
+ protocol: McpProtocol | None = None
# 配置信息,允许为None,类型为字典
+ config: dict[str, Any] | None = None 10.4 测试 #
curl --location --request PUT "http://127.0.0.1:8000/api/mcp-services/1" ^
--header "Content-Type: application/json" ^
--data-raw "{ \"name\": \"路线规划服务\", \"description\": \"路线规划服务\", \"protocol\": \"streamable-http\", \"config\": { \"url\": \"http://127.0.0.1:8002/mcp\", \"headers\": { \"BAIDU_MAP_AK\": \"e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX\", \"DEEPSEEK_API_KEY\": \"sk-24088156e9ab48f3adddaf5a9c0c4ede\" } }}"11. 删除MCP服务 #
本节将介绍如何通过API接口删除(移除)已有的MCP服务。删除操作会物理移除数据库中的对应记录,无法恢复,请谨慎调用。
接口说明
- 接口路径:
/api/mcp-services/{mcp_id} - 请求方法:
DELETE - 请求参数:
- 路径参数:
mcp_id(int) — 需要删除的MCP服务的唯一标识符
- 路径参数:
- 返回数据:
{"ok": true}表示删除成功
典型请求示例
curl --location --request DELETE "http://127.0.0.1:8000/api/mcp-services/1"典型响应
{
"ok": true
}注意事项
- 若
mcp_id对应的服务不存在,API将返回404异常,响应内容为记录不存在的信息。 - 删除操作会立即生效,无法恢复。
- 只有有权的用户/系统应调用该接口,以防误删服务数据。
实现思路简述
- 路由处理器根据
mcp_id查询数据库中的服务记录。 - 若找到,则调用repository的删除方法删除记录。
- 若未找到,则返回404错误。
- 删除成功后,响应
{"ok": true}。
11.1. mcp_repository.py #
app/repositories/mcp_repository.py
# 导入SQLAlchemy的select
from sqlalchemy import select
# 导入SQLAlchemy的Session对象
from sqlalchemy.orm import Session
# 导入项目中的models模块
from app import models
# 导入项目中的schemas模块
from app import schemas
# 定义创建McpService的函数,接收数据库会话db和待创建数据data,返回新建的McpService对象
def create_mcp_service(db: Session, data: schemas.McpServiceCreate) -> models.McpService:
# 构造McpService模型对象,strip去除名字首尾空白
row = models.McpService(
name=data.name.strip(),#去除名字首尾空白
description=data.description,#服务描述
protocol=data.protocol.value,#协议类型
config=data.config,#配置信息
)
# 添加新对象到会话
db.add(row)
# 提交事务,将更改保存到数据库
db.commit()
# 刷新实例,确保row包含数据库自动生成的字段值
db.refresh(row)
# 返回新建的McpService对象
return row
# 定义获取所有McpService对象的函数,参数为数据库会话db,返回McpService对象列表
def list_mcp_services(db: Session) -> list[models.McpService]:
# 构造按id倒序排序的查询,获取所有McpService记录
return list(db.scalars(select(models.McpService).order_by(models.McpService.id.desc())).all())
# 定义一个函数,根据传入的mcp_id从数据库中获取对应的McpService对象
# 参数db为数据库会话对象,mcp_id为服务的主键ID
# 如果找到则返回对应的McpService对象,否则返回None
def get_mcp_service(db: Session, mcp_id: int) -> models.McpService | None:
# 调用SQLAlchemy的get方法根据主键查询McpService
return db.get(models.McpService, mcp_id)
# 定义一个函数,通过服务名称从数据库中获取对应的McpService对象
# 参数db为数据库会话对象,name为服务名称
# 如果找到则返回对应的McpService对象,否则返回None
def get_by_name(db: Session, name: str) -> models.McpService | None:
# 构造查询,根据名称筛选McpService记录,并返回第一条结果
return db.scalar(select(models.McpService).where(models.McpService.name == name))
# 定义一个函数,更新指定的McpService对象内容
# 参数db为数据库会话对象,row为待更新的McpService对象,data为更新数据
# 返回更新后的McpService对象
def update_mcp_service(db: Session, row: models.McpService, data: schemas.McpServiceUpdate) -> models.McpService:
# 如果更新数据中name字段不为空,则去除首尾空白后赋值
if data.name is not None:
row.name = data.name.strip()
# 如果更新数据中description字段不为空,则赋值
if data.description is not None:
row.description = data.description
# 如果更新数据中protocol字段不为空,则将其值转为字符串后赋值
if data.protocol is not None:
row.protocol = data.protocol.value
# 如果更新数据中config字段不为空,则赋值
if data.config is not None:
row.config = data.config
# 提交事务,将更改保存到数据库
db.commit()
# 刷新实例,确保row包含最新的字段值
db.refresh(row)
# 返回更新后的McpService对象
return row
# 定义删除McpService服务的函数,接收数据库会话db和待删除的McpService对象row,无返回值
+def delete_mcp_service(db: Session, row: models.McpService) -> None:
# 从数据库会话中删除指定的row对象
+ db.delete(row)
# 提交删除操作,将更改保存到数据库
+ db.commit()
11.2. mcp_services.py #
app/routers/mcp_services.py
# 引入日志模块
import logging
# 从FastAPI导入路由器、依赖项和HTTP异常类
from fastapi import APIRouter, Depends, HTTPException
# 从SQLAlchemy导入唯一性错误异常
from sqlalchemy.exc import IntegrityError
# 从SQLAlchemy导入ORM会话对象
from sqlalchemy.orm import Session
# 导入应用程序的数据模型
from app import schemas
# 导入mcp_repository模块,包含数据库操作方法
from app.repositories import mcp_repository
# 导入用于获取数据库会话的依赖函数
from app.database import get_db
# 创建API路由器,设置前缀和标签
router = APIRouter(prefix="/api/mcp-services", tags=["mcp-services"])
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 定义GET方法,用于列出所有已注册的MCP服务,返回值为McpServiceOut的列表
@router.get("", response_model=list[schemas.McpServiceOut])
def list_services(db: Session = Depends(get_db)):
# 调用mcp_repository中的list_mcp_services方法,查询数据库中的所有服务
return mcp_repository.list_mcp_services(db)
# 定义POST方法,用于创建一个新的MCP服务,接收McpServiceCreate模式对象作为请求体
@router.post("", response_model=schemas.McpServiceOut)
def create_service(payload: schemas.McpServiceCreate, db: Session = Depends(get_db)):
try:
# 调用mcp_repository中的create_mcp_service方法,将新服务信息写入数据库并返回
return mcp_repository.create_mcp_service(db, payload)
except IntegrityError:
# 捕获唯一性约束异常,如服务名称已存在,进行回滚操作
db.rollback()
# 抛出HTTP异常,状态码409,表示名称冲突
raise HTTPException(status_code=409, detail="名称已存在")
# 定义一个GET类型的路由,用于根据mcp_id获取单个MCP服务,返回McpServiceOut模型
@router.get("/{mcp_id}", response_model=schemas.McpServiceOut)
# 处理函数:根据传入的mcp_id和数据库会话读取指定的服务
def get_service(mcp_id: int, db: Session = Depends(get_db)):
# 调用仓库方法获取对应id的MCP服务记录
row = mcp_repository.get_mcp_service(db, mcp_id)
# 如果记录不存在,则抛出404异常
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 返回查找到的服务记录
return row
# 定义PUT方法用于更新指定ID的MCP服务,返回更新后的服务对象
@router.put("/{mcp_id}", response_model=schemas.McpServiceOut)
# 处理函数,参数包括服务ID、更新请求体、数据库会话
def update_service(mcp_id: int, payload: schemas.McpServiceUpdate, db: Session = Depends(get_db)):
# 首先根据mcp_id从数据库查找对应的服务记录
row = mcp_repository.get_mcp_service(db, mcp_id)
# 如果未找到记录,则抛出404异常提示记录不存在
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 如果更新请求中包含新的名称
if payload.name is not None:
# 通过名称查找是否存在其他服务
other = mcp_repository.get_by_name(db, payload.name.strip())
# 如果找到的其他服务ID与当前更新目标不同,说明名称已被占用
if other and other.id != mcp_id:
raise HTTPException(status_code=409, detail="名称已被其他记录使用")
try:
# 调用repository方法执行数据库更新并返回结果
return mcp_repository.update_mcp_service(db, row, payload)
except IntegrityError:
# 捕获唯一约束冲突,回滚事务
db.rollback()
# 抛出409异常提示名称冲突
raise HTTPException(status_code=409, detail="名称冲突")
# 定义DELETE方法的路由,指定路径参数mcp_id
+@router.delete("/{mcp_id}")
# 删除指定ID的MCP服务,db为数据库会话依赖
+def delete_service(mcp_id: int, db: Session = Depends(get_db)):
# 根据mcp_id查询对应的MCP服务记录
+ row = mcp_repository.get_mcp_service(db, mcp_id)
# 如果记录不存在,则抛出404异常
+ if not row:
+ raise HTTPException(status_code=404, detail="记录不存在")
# 调用仓库方法删除该服务记录
+ mcp_repository.delete_mcp_service(db, row)
# 返回成功标志
+ return {"ok": True} 11.3 测试 #
curl --location --request DELETE "http://127.0.0.1:8000/api/mcp-services/1"11. 测试MCP服务 #
本部分将介绍如何基于 百度地图路线规划API 搭建并测试自定义的 MCP 路线规划服务及客户端。
主要内容包括:
- 路线规划服务端(
mcp-services/direction-server.py)部署和接口说明 - 调试客户端(
mcp-services/direction-client.py)调用方式 - API 测试方法示例
- 关键模型与字段说明
路线规划服务端说明
mcp-services/direction-server.py 是基于 FastMCP 的 MCP 路线规划服务实现。主要功能:
- 支持通过 MCP 协议进行路线规划(如驾车、公共交通等),调用百度地图 API 实时查询并返回规划结果。
- 核心接口为
plan_route工具,入参包含自然语言描述、最大步数等,并可自动提取起点、终点及规划方式。 - 支持异步运行,面向流式 HTTP 交互,方便对接各类前端和中台应用。
启动方式:
python mcp-services/direction-server.py服务将默认在 http://127.0.0.1:8002/mcp 上通过 Streamable HTTP MCP 协议提供服务。
请求头要求
所有调用均需提供有效的以下请求头:
BAIDU_MAP_AK:你的百度地图开放平台 AKDEEPSEEK_API_KEY:DeepSeek AI Key(用于智能解析自然语言输入)
调试用客户端说明
mcp-services/direction-client.py 提供了与上述服务端配套的异步 HTTP 客户端样例,演示如何通过 Streamable HTTP 协议调用服务:
- 内置两种调用示例:普通驾车路线、长途公共交通路线
- 日志实时输出路线结果摘要
- 可根据实际需求修改
user_input或 headers 实现自定义调用
运行方式:
python mcp-services/direction-client.pyAPI 测试示例
你可以直接用 curl 对 MCP 服务进行测试,如下所示:
curl --location --request POST "http://127.0.0.1:8000/api/mcp-services/test" \
--header "Content-Type: application/json" \
--data-raw '{
"protocol": "streamable-http",
"config": {
"url": "http://127.0.0.1:8002/mcp",
"headers": {
"BAIDU_MAP_AK": "你的百度地图AK",
"DEEPSEEK_API_KEY": "你的DeepSeek Key"
}
}
}'其中 AK 和 Key 请替换为你自己申请的有效密钥。
返回结果示例
调用成功时,将以结构化文本或 JSON 返回,包括:
- 输入参数(如起终点、方式等)
- 规划总里程与预计时长
- 导航摘要(分步文字路线说明)
模型与关键参数说明
- 服务参数、响应模型定义详见主入口
plan_route user_input支持自然语言路线描述,自动提取目标(例如:“从北京市海淀区到天津市滨海新区驾车,尽量不走高速”)max_steps可控制分步详细程度,范围 1 ~ 20
注意事项
- 确保百度地图 AK、DeepSeek API Key 有效且配置正确
- 勿泄露密钥;生产环境请务必妥善保护敏感配置
- 一些错误和异常会以结构化方式返回,可根据响应 message 字段进行排查
11.1. mcp_tester.py #
app/services/mcp_tester.py
# 导入异步库asyncio
import asyncio
# 导入自定义的streamable_http_client客户端
from mcp.client.streamable_http import streamable_http_client
# 导入日志库
import logging
# 导入httpx用于网络请求
import httpx
# 从mcp模块导入ClientSession用于会话管理
from mcp import ClientSession
# 导入自定义的mcp_httpx_client_factory工厂函数
from app.services.mcp_httpx import mcp_httpx_client_factory
# 获取logger对象用于日志输出
logger = logging.getLogger(__name__)
# 合并headers辅助函数,将配置中的header全部转换为字符串格式
def _merge_headers(config):
# 从config中获取headers
raw = config.get("headers")
# 如果headers不存在或不是dict,则返回空字典
if not raw or not isinstance(raw, dict):
return {}
# 初始化输出字典
out = {}
# 遍历原始headers,将键值都转为字符串
for k, v in raw.items():
out[str(k)] = str(v)
# 返回处理后的headers
return out
# 提取tools辅助函数
def _extract_tools(result):
# 如果result为None,返回空列表
if result is None:
return []
# 如果result已经是列表,直接返回
if isinstance(result, list):
return result
# 尝试获取result对象上的tools属性
tools = getattr(result, "tools", None)
if isinstance(tools, list):
return tools
# 如果result是字典,则优先直接取tools字段
if isinstance(result, dict):
t = result.get("tools")
if isinstance(t, list):
return t
# 如果没有,尝试进入result字段嵌套查找
nested = result.get("result")
if isinstance(nested, dict):
t = nested.get("tools")
if isinstance(t, list):
return t
# 如果无法提取到tools,则返回空列表
return []
# 标准化tool对象,确保返回字典结构{name, description, input_schema}
def _normalize_tool(tool):
# 如果tool为None,返回默认空结构
if tool is None:
return {"name": "", "description": "", "input_schema": {}}
# 如果tool是dict类型,从中提取字段并构造返回
if isinstance(tool, dict):
return {
"name": str(tool.get("name") or ""),
"description": str(tool.get("description") or ""),
"input_schema": tool.get("inputSchema") or tool.get("input_schema") or {},
}
# 否则尝试通过对象属性获取name、description、input_schema
return {
"name": str(getattr(tool, "name", "") or ""),
"description": str(getattr(tool, "description", "") or ""),
"input_schema": getattr(tool, "inputSchema", None)
or getattr(tool, "input_schema", None)
or {},
}
# 通过给定的read/write对象与MCP服务初始化,并探测tools列表
async def _probe_with_session(read, write, label):
# 使用ClientSession创建异步上下文
async with ClientSession(read, write) as session:
# 执行session初始化,带超时时间
init_result = await asyncio.wait_for(session.initialize(), timeout=15)
tools_result = None
raw_tools = []
# 最多循环尝试3次获取工具列表
for i in range(3):
# 列出tools,带超时时间
tools_result = await asyncio.wait_for(session.list_tools(), timeout=15)
# 提取tools
raw_tools = _extract_tools(tools_result)
# 如果成功提取到tool就跳出循环
if raw_tools:
break
# 若还未成功则短暂等待后重试(最多重试2次)
if i < 2:
await asyncio.sleep(0.2)
# 获取初始化结果中的serverInfo属性
server_info = getattr(init_result, "serverInfo", None)
# 对工具列表进行标准化处理
tools = [_normalize_tool(t) for t in raw_tools]
# 日志输出工具数量
logger.info("%s probe list_tools count=%s", label, len(tools))
# 如果serverInfo存在且有name则写入返回消息
if server_info and getattr(server_info, "name", None):
return True, f"{label} MCP 初始化成功: {server_info.name}", tools
# 否则只是简单初始化成功
return True, f"{label} MCP 初始化成功", tools
# 测试 streamable-http 协议的 MCP 服务
def test_streamable_http(config):
# 从配置中获取 url
url = config.get("url")
# 若 url 无效或不是字符串,直接返回错误信息
if not url or not isinstance(url, str):
return False, "streamable-http 配置需要字符串 url", []
# 标准化合并 headers
headers = _merge_headers(config)
# 定义异步检测函数
async def _run_http_check():
try:
# 使用定制的 httpx async client 工厂函数创建客户端(避免本地代理干扰)
async with mcp_httpx_client_factory(headers=headers) as http_client:
# 以异步方式建立与 streamable-http MCP 服务的连接,并获取读写对象
async with streamable_http_client(url, http_client=http_client) as (read, write, _):
# 调用内部探测逻辑初始化 session 并获取工具列表
return await _probe_with_session(read, write, "streamable-http")
# 捕获超时异常(如连接或响应过慢),返回特定的提示
except asyncio.TimeoutError:
return False, "等待响应超时;请检查 streamable-http 服务地址与请求头", []
# 捕获 httpx 的请求异常,如网络不可达等
except httpx.RequestError as e:
return False, f"请求失败: {e}", []
# 捕获所有其他异常,写日志,返回简要错误说明
except Exception as e:
logger.exception("streamable-http 检测失败")
return False, f"streamable-http MCP 检测失败: {e}", []
# 在主线程执行异步检测函数并返回结果
return asyncio.run(_run_http_check())
# MCP通用检测入口,根据协议选择检测方法
def test_mcp(protocol, config):
# 将协议字符串归一化为小写去除空格
p = (protocol or "").lower().strip()
# 如果协议是streamable-http或http则用http检测方法
if p in {"streamable-http", "http"}:
return test_streamable_http(config)
# 否则返回不支持的协议
return False, f"不支持的协议: {protocol}", []
11.2. mcp_httpx.py #
app\services\mcp_httpx.py
"""MCP SSE / Streamable HTTP 使用的 httpx 客户端:不读取环境变量里的代理,避免本机 127.0.0.1 被错误走代理。"""
# 导入 httpx 库,用于实现 HTTP 客户端功能
import httpx
# 从共享模块中导入默认的 SSE 读取超时时间和通用超时时间常量
from mcp.shared._httpx_utils import MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT
# 定义生成 httpx.AsyncClient 的工厂函数
def mcp_httpx_client_factory(
headers=None, # 可选的 HTTP 请求头参数
timeout=None, # 可选的超时时间
auth=None, # 可选的认证参数
):
# 初始化 httpx.AsyncClient 构造参数
kwargs = {
"follow_redirects": True, # 启用自动跟随重定向
"trust_env": False, # 不从环境变量读取代理配置
}
# 如果未指定超时时间,则使用自定义的默认超时
if timeout is None:
# 设置连接和读取操作的超时时间
kwargs["timeout"] = httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT)
else:
# 使用外部传入的超时配置
kwargs["timeout"] = timeout
# 如果有自定义 headers,则加入到参数中
if headers is not None:
kwargs["headers"] = headers
# 如果有认证参数,则加入到参数中
if auth is not None:
kwargs["auth"] = auth
# 返回配置好的异步 httpx 客户端对象
return httpx.AsyncClient(**kwargs)
11.3. direction-client.py #
mcp-services/direction-client.py
# 调试用:连接 direction Streamable HTTP MCP 并调用 plan_route。
"""调试用:连接 direction Streamable HTTP MCP 并调用 plan_route。"""
# 导入异步IO库
import asyncio
# 导入日志模块
import logging
# 导入httpx库,用于发起HTTP请求
import httpx
# 导入MCP客户端会话
from mcp import ClientSession
# 导入Streamable HTTP客户端连接方法
from mcp.client.streamable_http import streamable_http_client
# 配置日志输出格式和日志级别
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
# 获取一个名为 direction-client 的logger对象
logger = logging.getLogger("direction-client")
# MCP服务HTTP地址
HTTP_URL = "http://127.0.0.1:8002/mcp"
# 百度地图AK
BAIDU_MAP_AK = "e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX"
# DeepSeek API Key
DEEPSEEK_API_KEY = "sk-24088156e9ab48f3adddaf5a9c0c4ede"
# 定义异步主运行函数
async def run():
# 输出连接流式HTTP服务日志
logger.info("连接 Streamable HTTP 服务:%s", HTTP_URL)
# 构造请求头,包含百度地图AK和DeepSeek API Key
headers = {"BAIDU_MAP_AK": BAIDU_MAP_AK, "DEEPSEEK_API_KEY": DEEPSEEK_API_KEY}
# 使用httpx异步客户端设置请求头和超时时间
async with httpx.AsyncClient(headers=headers, timeout=30.0) as http_client:
# 使用streamable HTTP客户端连接到MCP服务,获取读写接口
async with streamable_http_client(HTTP_URL, http_client=http_client) as (read, write, _):
# 新建一个MCP客户端会话
async with ClientSession(read, write) as session:
# 初始化MCP会话
await session.initialize()
# 日志记录会话初始化成功
logger.info("MCP 会话初始化成功")
# 日志记录准备调用plan_route(驾车)
logger.info("调用 plan_route(驾车)")
# 调用plan_route工具,参数为驾车路线规划
driving = await session.call_tool(
"plan_route",
{"user_input": "从北京市海淀区到天津市滨海新区驾车,尽量不走高速", "max_steps": 6},
)
# 遍历返回的内容
for item in driving.content:
# 提取每个返回内容的text字段
text = getattr(item, "text", "")
# 如果text内容不为空,输出日志
if text:
logger.info("驾车结果:\n%s", text)
# 日志记录准备调用plan_route(公共交通)
logger.info("调用 plan_route(公共交通)")
# 调用plan_route工具,参数为公交路线规划
transit = await session.call_tool(
"plan_route",
{"user_input": "从北京市朝阳区到上海市浦东新区公共交通,优先火车", "max_steps": 6},
)
# 遍历返回的内容
for item in transit.content:
# 提取每个返回内容的text字段
text = getattr(item, "text", "")
# 如果text内容不为空,输出日志
if text:
logger.info("公共交通结果:\n%s", text)
# 定义主入口函数
def main():
# 使用asyncio运行主异步任务
asyncio.run(run())
# 如果作为主程序运行
if __name__ == "__main__":
# 调用主函数
main()
11.4. direction-server.py #
mcp-services/direction-server.py
# 路线规划 MCP 服务器(Streamable HTTP)
"""路线规划 MCP 服务器(Streamable HTTP)。"""
# 导入相关标准库
import json
import logging
from typing import Annotated
# 导入第三方库
import httpx
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_deepseek import ChatDeepSeek
from mcp.server.fastmcp import Context, FastMCP
from pydantic import Field
# 配置日志输出格式和等级
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
# 获取logger对象
logger = logging.getLogger("direction-service")
# 设置服务监听地址和端口
HOST = "127.0.0.1"
PORT = 8002
# 设置HTTP接口路径
HTTP_PATH = "/mcp"
# 创建FastMCP实例
mcp = FastMCP("路线规划服务", host=HOST, port=PORT, streamable_http_path=HTTP_PATH)
# 百度地理编码API地址
GEOCODE_URL = "https://api.map.baidu.com/geocoding/v3/"
# 百度路线规划API基础地址
ROUTE_V2_BASE = "https://api.map.baidu.com/direction/v2"
# DeepSeek 服务基础地址
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
# DeepSeek 使用的模型
DEEPSEEK_MODEL = "deepseek-chat"
# 获取请求头中的指定key的值
def _header_value(ctx, key):
# 初始化request为None
request = None
# 如果上下文和request_context存在
if ctx and ctx.request_context:
# 获取request对象
request = getattr(ctx.request_context, "request", None)
# 如果request不存在则抛异常
if request is None:
raise RuntimeError(f"缺少请求上下文,无法读取请求头 {key}")
# 读取header中的key或小写key
raw = request.headers.get(key) or request.headers.get(key.lower()) or ""
# 去除首尾空白
v = str(raw).strip()
# 如果仍为空则抛出异常
if not v:
raise RuntimeError(f"缺少请求头 {key}")
# 返回header值
return v
# 获取百度地图AK
def _ak(ctx):
return _header_value(ctx, "BAIDU_MAP_AK")
# 获取DeepSeek会话对象
def _deepseek(ctx):
# 获取DeepSeek API KEY
key = _header_value(ctx, "DEEPSEEK_API_KEY")
# 返回ChatDeepSeek实例
return ChatDeepSeek(
api_key=key,
base_url=DEEPSEEK_BASE_URL,
model=DEEPSEEK_MODEL,
temperature=0,
)
# 抽取用户输入的路线参数
async def _extract_route_args(user_input, ctx):
# 记录日志,开始参数抽取
logger.info("开始抽取路线参数,input=%s", user_input)
# 创建json输出解析器
parser = JsonOutputParser()
# 构建语言模型提示模板
prompt = ChatPromptTemplate.from_template(
"""
从输入中抽取百度路线规划参数,只输出 JSON:
{{
"origin": "",
"destination": "",
"mode": "driving",
"tactics": ""
}}
约束:
1) mode 只能是 driving / transit。
2) 如果用户提到“公共交通/地铁/公交/高铁/动车/火车/飞机”,mode 设为 transit。
3) tactics 可选(仅 driving 生效,如 “avoid_highway”),否则空字符串。
输入:{user_input}
输出格式要求:{format_instructions}
""".strip()
)
# 组装prompt->模型->json解析链
chain = prompt | _deepseek(ctx) | parser
# 用链式方式抽取参数
data = await chain.ainvoke(
{"user_input": user_input, "format_instructions": parser.get_format_instructions()}
)
# 获取mode,如果不正确则置为driving
mode = str(data.get("mode") or "driving").strip().lower()
if mode not in {"driving", "transit"}:
mode = "driving"
# 组装抽取后的参数
extracted = {
"origin": str(data.get("origin") or "").strip(),# 起点
"destination": str(data.get("destination") or "").strip(),# 终点
"mode": mode,# 模式
"tactics": str(data.get("tactics") or "").strip(),# 策略
}
# 打印参数抽取完成日志
logger.info("参数抽取完成:%s", json.dumps(extracted, ensure_ascii=False))# 打印参数抽取完成日志
return extracted# 返回参数字典
# 地理编码(地址->经纬度)
async def _geocode(address, ctx):
# 构造地理编码请求参数
req = {"address": address, "output": "json", "ak": _ak(ctx)}
# 创建异步HTTP客户端
async with httpx.AsyncClient(timeout=20.0) as client:
# 调用百度地理编码接口
r = await client.get(GEOCODE_URL, params=req)
# 检查响应状态码
r.raise_for_status()
# 解析返回json
data = r.json()
# 检查百度状态码
if data.get("status") != 0:
raise RuntimeError(f"地理编码失败: {address}, status={data.get('status')}")
# 获取返回的位置信息
result = data.get("result") or {}
loc = result.get("location") or {}
lng, lat = loc.get("lng"), loc.get("lat")
# 无经纬度则抛异常
if lng is None or lat is None:
raise RuntimeError(f"地理编码失败: {address}, 无坐标")
# 返回经纬度字符串和格式化地址
return f"{lat},{lng}", result.get("formatted_address") or address
# 秒转为小时分钟
def _seconds_to_hhmm(sec):
# 转换为整数
sec = int(sec or 0)
# 小时数
h = sec // 3600
# 分钟数
m = (sec % 3600) // 60
# 若有小时则返回带小时文本
if h > 0:
return f"{h}小时{m}分钟"
# 否则只返回分钟
return f"{m}分钟"
# 构建每一步路线的文本描述
def _step_text(step):
# 优先从多个字段尝试提取文本说明
text = str(
step.get("instruction")
or step.get("instructions")
or step.get("html_instructions")
or ""
).strip()
# 有文本则去除标签后返回
if text:
return text.replace("<b>", "").replace("</b>", "")
# 若没有instruction再拼凑道路名和距离等
road = str(step.get("road_name") or "道路").strip()
dist = int(step.get("distance") or 0)
dur = int(step.get("duration") or 0)
# 距离大于0则打印完整描述
if dist > 0:
return f"沿{road}行驶约{dist}米,预计{_seconds_to_hhmm(dur)}"
# 否则只描述沿路行驶
return f"沿{road}行驶"
# 查询路线主流程
async def _query_route(route_args, ctx):
# 地理编码原始起点
origin_ll, origin_fmt = await _geocode(route_args["origin"], ctx)
# 地理编码终点
dest_ll, dest_fmt = await _geocode(route_args["destination"], ctx)
# 路线方式
mode = route_args["mode"]
# 组装百度路线接口地址
endpoint = f"{ROUTE_V2_BASE}/{mode}"
# 组装请求参数
req = {
"origin": origin_ll,# 起点
"destination": dest_ll,# 终点
"ak": _ak(ctx),# 百度地图AK
}
# 若为驾车方式,添加tactics参数
if mode == "driving":
req["tactics"] = "11" if route_args["tactics"] == "avoid_highway" else "0"# 策略
# 记录调用路线API的日志
logger.info("调用路线接口,mode=%s params=%s", mode, json.dumps(req, ensure_ascii=False))# 打印调用路线接口日志
# 调用百度路线接口
async with httpx.AsyncClient(timeout=30.0) as client:
r = await client.get(endpoint, params=req)
r.raise_for_status()# 检查响应状态码
data = r.json()# 解析返回json
# 校验返回状态
if data.get("status") != 0:
raise RuntimeError(f"路线规划失败: status={data.get('status')} message={data.get('message')}")# 抛出异常
# 提取并处理返回的数据
result = data.get("result") or {}
routes = list(result.get("routes") or [])# 提取路线列表
# 没有可用路线则报错
if not routes:
raise RuntimeError("路线规划失败: 无可用路线")# 抛出异常
# 选取最佳路线
best = routes[0]# 最佳路线
# 返回路线主要信息
return {
"origin": origin_fmt,# 起点
"destination": dest_fmt,# 终点
"mode": mode,# 模式
"distance": int(best.get("distance") or 0),# 距离
"duration": int(best.get("duration") or 0),# 时长
"steps": best.get("steps") or [],# 步骤
}
# 注册为MCP工具
@mcp.tool()
# 路线规划主接口
async def plan_route(
# 用户输入的需求
user_input: Annotated[str, Field(description="用户输入的路线规划需求,例如:从北京市海淀区到天津市滨海新区驾车,尽量不走高速")],
# 最大步数(可选,默认6,范围1~20)
max_steps: Annotated[int, Field(description="最大步数,范围 1~20,默认 6")] = 6,
# 上下文参数
ctx: Annotated[Context, Field(description="上下文")] = None,
):
# 路线规划服务文档字符串
"""路线规划服务。"""
# 抽取路线参数
route_args = await _extract_route_args(user_input, ctx)
# 获取路线结果
data = await _query_route(route_args, ctx)
# 限制最大步数
max_steps = max(1, min(int(max_steps), 20))
# 构建返回描述文本
lines = [
f"输入:{user_input}",
f"参数:{json.dumps(route_args, ensure_ascii=False)}",
f"出发地:{data['origin']}",
f"目的地:{data['destination']}",
f"方式:{data['mode']}",
f"总里程:{round(data['distance'] / 1000, 2)} 公里",
f"预计时长:{_seconds_to_hhmm(data['duration'])}",
"",
"导航摘要:",
]
# 步数已展示计数器
shown = 0
# 遍历每个路线步骤
for step in data["steps"]:
# 达到最多步数则结束
if shown >= max_steps:
break
# 某些方案一步为嵌套列表
if isinstance(step, list):
# 遍历内部每一个
for s in step:
if shown >= max_steps:
break
lines.append(f"- {_step_text(s)}")
shown += 1
else:
lines.append(f"- {_step_text(step)}")
shown += 1
# 组合所有文本并返回
return "\n".join(lines)
# 启动入口
def main():
# 打印启动日志
logger.info("启动 Direction MCP Streamable HTTP 服务: http://%s:%s%s", HOST, PORT, HTTP_PATH)
# 启动MCP服务
mcp.run(transport="streamable-http")
# 如果作为主程序执行
if __name__ == "__main__":
main()
11.5. mcp_services.py #
app/routers/mcp_services.py
# 引入日志模块
import logging
# 从FastAPI导入路由器、依赖项和HTTP异常类
from fastapi import APIRouter, Depends, HTTPException
# 从SQLAlchemy导入唯一性错误异常
from sqlalchemy.exc import IntegrityError
# 从SQLAlchemy导入ORM会话对象
from sqlalchemy.orm import Session
# 导入应用程序的数据模型
from app import schemas
# 导入mcp_repository模块,包含数据库操作方法
from app.repositories import mcp_repository
# 导入用于获取数据库会话的依赖函数
from app.database import get_db
# 导入mcp_tester模块,包含测试MCP服务的方法
+from app.services.mcp_tester import test_mcp
# 创建API路由器,设置前缀和标签
router = APIRouter(prefix="/api/mcp-services", tags=["mcp-services"])
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 定义GET方法,用于列出所有已注册的MCP服务,返回值为McpServiceOut的列表
@router.get("", response_model=list[schemas.McpServiceOut])
def list_services(db: Session = Depends(get_db)):
# 调用mcp_repository中的list_mcp_services方法,查询数据库中的所有服务
return mcp_repository.list_mcp_services(db)
# 定义POST方法,用于创建一个新的MCP服务,接收McpServiceCreate模式对象作为请求体
@router.post("", response_model=schemas.McpServiceOut)
def create_service(payload: schemas.McpServiceCreate, db: Session = Depends(get_db)):
try:
# 调用mcp_repository中的create_mcp_service方法,将新服务信息写入数据库并返回
return mcp_repository.create_mcp_service(db, payload)
except IntegrityError:
# 捕获唯一性约束异常,如服务名称已存在,进行回滚操作
db.rollback()
# 抛出HTTP异常,状态码409,表示名称冲突
raise HTTPException(status_code=409, detail="名称已存在")
# 定义一个GET类型的路由,用于根据mcp_id获取单个MCP服务,返回McpServiceOut模型
@router.get("/{mcp_id}", response_model=schemas.McpServiceOut)
# 处理函数:根据传入的mcp_id和数据库会话读取指定的服务
def get_service(mcp_id: int, db: Session = Depends(get_db)):
# 调用仓库方法获取对应id的MCP服务记录
row = mcp_repository.get_mcp_service(db, mcp_id)
# 如果记录不存在,则抛出404异常
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 返回查找到的服务记录
return row
# 定义PUT方法用于更新指定ID的MCP服务,返回更新后的服务对象
@router.put("/{mcp_id}", response_model=schemas.McpServiceOut)
# 处理函数,参数包括服务ID、更新请求体、数据库会话
def update_service(mcp_id: int, payload: schemas.McpServiceUpdate, db: Session = Depends(get_db)):
# 首先根据mcp_id从数据库查找对应的服务记录
row = mcp_repository.get_mcp_service(db, mcp_id)
# 如果未找到记录,则抛出404异常提示记录不存在
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 如果更新请求中包含新的名称
if payload.name is not None:
# 通过名称查找是否存在其他服务
other = mcp_repository.get_by_name(db, payload.name.strip())
# 如果找到的其他服务ID与当前更新目标不同,说明名称已被占用
if other and other.id != mcp_id:
raise HTTPException(status_code=409, detail="名称已被其他记录使用")
try:
# 调用repository方法执行数据库更新并返回结果
return mcp_repository.update_mcp_service(db, row, payload)
except IntegrityError:
# 捕获唯一约束冲突,回滚事务
db.rollback()
# 抛出409异常提示名称冲突
raise HTTPException(status_code=409, detail="名称冲突")
# 定义DELETE方法的路由,指定路径参数mcp_id
@router.delete("/{mcp_id}")
# 删除指定ID的MCP服务,db为数据库会话依赖
def delete_service(mcp_id: int, db: Session = Depends(get_db)):
# 根据mcp_id查询对应的MCP服务记录
row = mcp_repository.get_mcp_service(db, mcp_id)
# 如果记录不存在,则抛出404异常
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 调用仓库方法删除该服务记录
mcp_repository.delete_mcp_service(db, row)
# 返回成功标志
return {"ok": True}
# 定义POST接口,路径为/test,响应模型为McpTestResult
+@router.post("/test", response_model=schemas.McpTestResult)
# 定义处理函数,接收McpTestRequest为请求体
+def test_service(payload: schemas.McpTestRequest):
# 调用test_mcp函数,传入协议类型和配置,获取测试结果、消息和工具列表
+ ok, msg, tools = test_mcp(payload.protocol.value, payload.config)
# 记录日志,包括协议类型、测试是否成功和工具数量
+ logger.info("mcp test protocol=%s ok=%s tools=%s", payload.protocol.value, ok, len(tools))
# 返回McpTestResult对象,包含测试结果、消息和工具列表
+ return schemas.McpTestResult(ok=ok, message=msg, tools=tools) 11.6. schemas.py #
app/schemas.py
# 导入枚举类型
from enum import Enum
# 导入Any类型用于类型注解
from typing import Any
# 从pydantic导入BaseModel、Field和field_validator用于数据验证
from pydantic import BaseModel, Field, field_validator
# 定义MCP协议类型的枚举类
class McpProtocol(str, Enum):
# MCP 协议类型注释
"""MCP 协议类型"""
# stdio协议
stdio = "stdio"
# streamable-http协议
streamable_http = "streamable-http"
# sse协议
sse = "sse"
# 定义MCP服务基础模型
class McpServiceBase(BaseModel):
# MCP 服务基础模型注释
"""MCP 服务基础模型"""
# 服务名称,字符串类型,长度1-255
name: str = Field(..., min_length=1, max_length=255)
# 服务描述,可选字段,字符串或None
description: str | None = None
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,确保其为字典类型
@field_validator("config", mode="before")
@classmethod
def config_not_empty(cls, v: Any) -> Any:
# 验证 config 是否为 JSON 对象,如果不是字典则抛出异常
"""验证 config 是否为 JSON 对象"""
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
return v
# 定义MCP服务创建模型,继承自McpServiceBase
class McpServiceCreate(McpServiceBase):
# 不增加额外内容,直接继承
pass
# 定义MCP服务输出模型
class McpServiceOut(BaseModel):
# ID字段,整型
id: int
# 服务名称
name: str
# 服务描述,可选字段
description: str | None
# 协议类型,字符串
protocol: str
# 配置信息,字典类型
config: dict[str, Any]
# 设置模型配置,允许从ORM对象属性读取数据
model_config = {"from_attributes": True}
# 定义McpServiceUpdate模型,用于更新MCP服务,支持部分字段可选更新
class McpServiceUpdate(BaseModel):
# 服务名称,允许为None,最小长度1,最大长度255
name: str | None = Field(None, min_length=1, max_length=255)
# 服务描述字段,允许为None
description: str | None = None
# 协议类型,使用McpProtocol枚举,允许为None
protocol: McpProtocol | None = None
# 配置信息,允许为None,类型为字典
config: dict[str, Any] | None = None
# 定义McpTestRequest模型,继承自BaseModel
+class McpTestRequest(BaseModel):
# 协议类型,使用McpProtocol枚举
+ protocol: McpProtocol
# 配置信息,要求为字典类型
+ config: dict[str, Any]
# 对config字段添加验证器,在赋值前进行校验
+ @field_validator("config", mode="before")
+ @classmethod
+ def config_is_object(cls, v: Any) -> Any:
# 如果config不是字典类型,则抛出异常
+ if not isinstance(v, dict):
+ raise ValueError("config 必须为 JSON 对象")
# 返回config原值
+ return v
# 定义McpTestResult模型,继承自BaseModel
+class McpTestResult(BaseModel):
# 测试是否成功的标志
+ ok: bool
# 返回的信息或说明
+ message: str
# 工具列表,默认为空列表
+ tools: list[dict[str, Any]] = Field(default_factory=list)
11.7 测试 #
curl --location --request POST "http://127.0.0.1:8000/api/mcp-services/test" ^
--header "Content-Type: application/json" ^
--data-raw "{ \"protocol\": \"streamable-http\", \"config\": { \"url\": \"http://127.0.0.1:8002/mcp\", \"headers\": { \"BAIDU_MAP_AK\": \"e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX\", \"DEEPSEEK_API_KEY\": \"sk-24088156e9ab48f3adddaf5a9c0c4ede\" } }}"11.8 路线规划服务 #
11.8.1 描述 #
基于 MCP 的智能路线规划服务,采用 Streamable HTTP 协议对外提供能力。服务可将自然语言出行需求自动解析为结构化参数,调用百度地图路线规划接口生成结果,当前支持自驾(driving)与公共交通(transit,含地铁/公交/高铁/动车/飞机场景)。返回内容包含出发地、目的地、总里程、预计时长与导航摘要步骤,适合在 AI 助手中用于行程建议与路线查询。
11.8.2 协议 #
streamable-http11.8.3 URL #
http://127.0.0.1:8002/mcp11.8.4 请求头 #
{
"BAIDU_MAP_AK": "e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX",
"DEEPSEEK_API_KEY": "sk-24088156e9ab48f3adddaf5a9c0c4ede"
}12. 地点检索服务 #
本节详细介绍“地点检索服务”的功能特性、适用场景、接口协议,以及如何集成和调用。
服务简介
地点检索服务基于 MCP 的工具协议实现,能够将自然语言表达的地点需求转为结构化检索参数,并调用百度地图的 Place API 获取相关的地点列表,如景点、酒店等。该服务特别面向中文出行与旅游等场景:
- 用户可输入如“北京景点”“上海酒店”这类自然语句;
- 服务会自动抽取所需的检索参数(如地区、类型、关键词),并返回包含名称、地址、经纬度等信息的地点列表;
- 支持分页、补充详情(如电话、人均、营业时间等)和无结果时的智能建议。
适用于 AI 助手、智能对话、旅游产品推荐等类型的应用程序。
服务接入方式
- 协议:使用 Server-Sent Events (
sse) - 服务地址:
http://127.0.0.1:8001/sse - 认证参数(放在 headers):
BAIDU_MAP_AK:百度地图开放平台申请的密钥DEEPSEEK_API_KEY:DeepSeek API Key
示例 headers:
{
"BAIDU_MAP_AK": "e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX",
"DEEPSEEK_API_KEY": "sk-24088156e9ab48f3adddaf5a9c0c4ede"
}调用流程示例
用户调用流程如下:
- 客户端通过 SSE 协议建立长连接,携带有效 Headers。
- 通过 MCP 协议的
initialize进行会话初始化。 - 使用
call_tool,调用工具search_place,并传入参数:- 地区 (
region) - 关键词 (
query) - 标签(可选)(如“景点”, “酒店”)
- 是否仅限本地(可选)
- 分页参数(可选)
- 地区 (
调用参数示例:
{
"tool": "search_place",
"input": {
"region": "北京",
"query": "景点",
"tags": ["5A级", "博物馆"],
"strict_local": false,
"page_num": 1,
"page_size": 10
}
}返回结果说明
- 返回为地点列表,每一项包含名称、简要描述、地址、经纬度等关键信息。
- 可根据实际需要为指定地点补充详情字段,如电话、营业时间、评分等。
- 若本轮输入无结果,将自动返回智能联想词建议,辅助用户调整查询。
快速测试方法
可使用如下命令在本地完成 HTTP 接口测试(注意更换为真实的服务地址和有效的 headers):
curl --location --request POST "http://127.0.0.1:8000/api/mcp-services/test" ^
--header "Content-Type: application/json" ^
--data-raw "{ \"protocol\": \"sse\", \"config\": { \"url\": \"http://127.0.0.1:8001/sse\", \"headers\": { \"BAIDU_MAP_AK\": \"e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX\", \"DEEPSEEK_API_KEY\": \"sk-24088156e9ab48f3adddaf5a9c0c4ede\" } }}"服务将返回测试结果、初始化说明及支持的工具描述,便于开发时验证部署环境和接口配置。
12.5 深入集成示例 #
你可以参考 mcp-services/place-client.py,使用 Python 代码通过 sse_client 和 ClientSession 接入 MCP 服务并发起检索请求,实现自动化集成与演示。详细代码见下方 12.1 节。
12.1. place-client.py #
mcp-services/place-client.py
# 调试用:连接 place SSE MCP 并调用 search_place。
"""调试用:连接 place SSE MCP 并调用 search_place。"""
# 导入异步IO模块
import asyncio
# 导入日志模块
import logging
# 从mcp模块导入ClientSession
from mcp import ClientSession
# 从mcp.client.sse模块导入sse_client
from mcp.client.sse import sse_client
# 配置日志格式及级别
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
# 创建名为"place-client"的日志记录器
logger = logging.getLogger("place-client")
# SSE服务url
SSE_URL = "http://127.0.0.1:8001/sse"
# 百度地图AK
BAIDU_MAP_AK = "e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX"
# DeepSeek API KEY
DEEPSEEK_API_KEY = "sk-24088156e9ab48f3adddaf5a9c0c4ede"
# 定义异步运行主流程的函数
async def run():
# 输出正在连接SSE服务的信息
logger.info("连接 SSE 服务:%s", SSE_URL)
# 构造请求头,包含AK和API KEY
headers = {"BAIDU_MAP_AK": BAIDU_MAP_AK, "DEEPSEEK_API_KEY": DEEPSEEK_API_KEY}
# 异步连接至SSE服务,获取读写流
async with sse_client(SSE_URL, headers=headers) as (read, write):
# 使用ClientSession管理MCP会话
async with ClientSession(read, write) as session:
# 初始化MCP会话
await session.initialize()
# 输出会话初始化成功的信息
logger.info("MCP 会话初始化成功")
# 输出将调用search_place工具(查询景点)的信息
logger.info("调用工具 search_place(景点)")
# 调用search_place工具,查询北京景点,分页参数设定
scenic = await session.call_tool(
"search_place",
{"user_input": "北京市 景点", "page_size": 5, "page_num": 0, "with_detail": False},
)
# 如果响应包含错误,抛出异常
if scenic.isError:
raise RuntimeError(str(scenic.content))
# 输出景点查询成功的信息
logger.info("景点查询成功")
# 遍历返回的景点结果,输出描述文本
for item in scenic.content:
text = getattr(item, "text", "")
if text:
logger.info("景点结果:\n%s", text)
# 输出将调用search_place工具(查询酒店)的信息
logger.info("调用工具 search_place(酒店)")
# 调用search_place工具,查询北京酒店,要求带详细信息
hotel = await session.call_tool(
"search_place",
{"user_input": "北京市 酒店", "page_size": 5, "page_num": 0, "with_detail": True},
)
# 如果响应包含错误,抛出异常
if hotel.isError:
raise RuntimeError(str(hotel.content))
# 输出酒店查询成功的信息
logger.info("酒店查询成功")
# 遍历返回的酒店结果,输出描述文本
for item in hotel.content:
text = getattr(item, "text", "")
if text:
logger.info("酒店结果:\n%s", text)
# 定义程序入口函数
def main():
# 运行异步事件循环,启动run函数
asyncio.run(run())
# 判断当前脚本是否为主模块运行
if __name__ == "__main__":
# 调用主函数
main()
12.2. place-server.py #
mcp-services/place-server.py
"""地点检索 MCP 服务器(SSE)。"""
# 导入json用于处理JSON格式数据
import json
# 导入logging用于日志记录
import logging
# 导入os用于获取环境变量
import os
# 导入类型标注工具
from typing import Annotated
# 导入httpx异步HTTP客户端
import httpx
# 导入LangChain的JSON输出解析器
from langchain_core.output_parsers import JsonOutputParser
# 导入LangChain聊天提示模板
from langchain_core.prompts import ChatPromptTemplate
# 导入DeepSeek聊天模型
from langchain_deepseek import ChatDeepSeek
# 导入FastMCP相关
from mcp.server.fastmcp import Context, FastMCP
# 导入pydantic字段定义
from pydantic import Field
# 配置日志输出格式和级别
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
# 获取logger实例
logger = logging.getLogger("place-service")
# 设置服务主机地址
HOST = "127.0.0.1"
# 设置服务端口号
PORT = 8001
# 设置SSE路径
SSE_PATH = "/sse"
# 实例化FastMCP服务
mcp = FastMCP("地点检索服务", host=HOST, port=PORT, sse_path=SSE_PATH)
# 百度地图地点检索API地址
PLACE_SEARCH_URL = "https://api.map.baidu.com/place/v2/search"
# 百度地图地点详情API地址
PLACE_DETAIL_URL = "https://api.map.baidu.com/place/v2/detail"
# 百度地图建议词/补全API地址
PLACE_SUGGESTION_URL = "https://api.map.baidu.com/place/v2/suggestion"
# 获取请求头指定key的值(用于提取AK和API KEY)
def _header_value(ctx, key):
# 初始化request对象
request = None
# 检查上下文中是否有request_context,并获取request
if ctx and ctx.request_context:
request = getattr(ctx.request_context, "request", None)
# 如无request,抛异常
if request is None:
raise RuntimeError(f"缺少请求上下文,无法读取请求头 {key}")
# 获取请求头中的key值,区分大小写
raw = request.headers.get(key) or request.headers.get(key.lower()) or ""
# 转成字符串并去除首尾空白
v = str(raw).strip()
# 如果值为空,抛出异常
if not v:
raise RuntimeError(f"缺少请求头 {key}")
# 返回请求头值
return v
# 获取百度地图AK(从请求头)
def _ak(ctx):
return _header_value(ctx, "BAIDU_MAP_AK")
# 获取DeepSeek模型实例(自动从请求头和环境变量读取base_url/model)
def _deepseek(ctx):
key = _header_value(ctx, "DEEPSEEK_API_KEY")
return ChatDeepSeek(
api_key=key,
base_url=(os.environ.get("DEEPSEEK_BASE_URL") or "https://api.deepseek.com").strip(),
model=(os.environ.get("DEEPSEEK_MODEL") or "deepseek-chat").strip(),
temperature=0,
)
# 异步抽取用户输入中的地点检索参数
async def _extract_place_args(user_input, ctx):
# 记录日志:开始参数抽取
logger.info("开始抽取地点检索参数,input=%s", user_input)
# 实例化JSON解析器
parser = JsonOutputParser()
# 构造聊天提示模板
prompt = ChatPromptTemplate.from_template(
"""
从用户输入中抽取百度地点检索参数,只输出 JSON:
{{
"region": "",
"query": "",
"tag": "",
"city_limit": false
}}
规则:
1) region 必须是地区名(市/州/区县),如“延边朝鲜族自治州”。
2) query 填检索词,如“景点”“酒店”“火锅”等。
3) tag 可选(如“旅游景点”“酒店”),无法确定可空字符串。
4) city_limit 仅当用户强调“仅本地区”时设为 true。
输入:{user_input}
输出格式要求:{format_instructions}
""".strip()
)
# 组装prompt、模型和解析器
chain = prompt | _deepseek(ctx) | parser
# 异步执行链,获得抽取结果
data = await chain.ainvoke(
{"user_input": user_input, "format_instructions": parser.get_format_instructions()}
)
# 整合并规范化抽取结果
extracted = {
"region": str(data.get("region") or "").strip(),
"query": str(data.get("query") or "").strip(),
"tag": str(data.get("tag") or "").strip(),
"city_limit": bool(data.get("city_limit") or False),
}
# 记录日志:抽取结果
logger.info("参数抽取完成:%s", json.dumps(extracted, ensure_ascii=False))
# 返回参数字典
return extracted
# 异步调用百度地点检索API
async def _search_places(params_from_llm, page_size, page_num, ctx):
# 构造API请求参数
req = {
"query": params_from_llm["query"] or "景点",
"region": params_from_llm["region"],
"tag": params_from_llm["tag"],
"city_limit": "true" if params_from_llm["city_limit"] else "false",
"output": "json",
"scope": "2",
"page_size": str(page_size),
"page_num": str(page_num),
"ak": _ak(ctx),
}
# 如果无region,抛异常
if not req["region"]:
raise RuntimeError("LLM 未提取出 region,无法检索地点")
# 记录日志:将要调用百度API
logger.info("调用地点检索接口,params=%s", json.dumps(req, ensure_ascii=False))
# 使用httpx进行异步HTTP请求
async with httpx.AsyncClient(timeout=30.0) as client:
r = await client.get(PLACE_SEARCH_URL, params=req)
r.raise_for_status()
data = r.json()
# 如果 API 返回非0状态,表示错误
if data.get("status") != 0:
raise RuntimeError(f"地点检索失败: status={data.get('status')} message={data.get('message')}")
# 返回API响应
return data
# 根据uid查询地点详细信息
async def _detail_by_uid(uid, ctx):
# 构造请求参数
req = {"uid": uid, "scope": "2", "output": "json", "ak": _ak(ctx)}
# 发送Http请求,获取详情
async with httpx.AsyncClient(timeout=30.0) as client:
r = await client.get(PLACE_DETAIL_URL, params=req)
r.raise_for_status()
data = r.json()
# 如请求失败,返回空
if data.get("status") != 0:
return {}
# 返回结果字段
return data.get("result") or {}
# 根据region和query给出建议词/模糊补全
async def _suggest_region_keyword(region, query, ctx):
# 构造请求参数
req = {
"query": query or region,
"region": region or "全国",
"city_limit": "false",
"output": "json",
"ak": _ak(ctx),
}
# 调用建议词接口
async with httpx.AsyncClient(timeout=30.0) as client:
r = await client.get(PLACE_SUGGESTION_URL, params=req)
r.raise_for_status()
data = r.json()
# 如失败,返回空列表
if data.get("status") != 0:
return []
# 返回建议结果列表
return list(data.get("results") or [])
# 注册为MCP工具
@mcp.tool()
# 定义search_place主函数
async def search_place(
user_input: Annotated[
str,
Field(description="用户输入的地点检索需求,例如:北京 景点"),
],
page_size: Annotated[
int,
Field(description="每页数量,范围 1~20,默认 5"),
] = 5,
page_num: Annotated[
int,
Field(description="页码,从 0 开始,默认 0"),
] = 0,
with_detail: Annotated[
bool,
Field(description="是否补充地点详情信息(电话、人均、营业时间)"),
] = False,
ctx: Context = None,
):
# 文档注释:地点检索 MCP 服务
"""地点检索 MCP 服务。"""
# 限定page_size范围
page_size = max(1, min(int(page_size), 20))
# 限定page_num不小于0
page_num = max(0, int(page_num))
# 转换with_detail为布尔型
with_detail = bool(with_detail)
# 调用LLM参数抽取
params_from_llm = await _extract_place_args(user_input, ctx)
# 调用检索API获取数据
data = await _search_places(params_from_llm, page_size, page_num, ctx)
# 获取检索结果
results = list(data.get("results") or [])
# 获取总条数
total = int(data.get("total") or 0)
# 组装输出行
lines = [
f"输入:{user_input}",
f"参数:{json.dumps(params_from_llm, ensure_ascii=False)}",
f"总数:{total},当前页数量:{len(results)}",
"",
]
# 如果没有检索到结果
if not results:
# 获取建议词(如无结果)
suggestions = await _suggest_region_keyword(params_from_llm["region"], params_from_llm["query"], ctx)
# 输出未检索到提示
lines.append("未检索到结果。可参考候选词:")
# 前5条建议逐条加入
for i, it in enumerate(suggestions[:5], 1):
lines.append(f"{i}. {it.get('name', '')} {it.get('city', '')}{it.get('district', '')}")
# 返回拼接后的文本
return "\n".join(lines)
# 对每个结果进行格式化输出
for i, item in enumerate(results, 1):
lines.append(f"{i}. {item.get('name', '')}")
lines.append(f" 地址:{item.get('address', '未知')}")
lines.append(f" 区域:{item.get('province', '')}{item.get('city', '')}{item.get('area', '')}")
lines.append(f" 评分:{item.get('detail_info', {}).get('overall_rating', '无')}")
lines.append(f" 类型:{item.get('detail_info', {}).get('tag', '无')}")
# 如需补充详情,调用详情接口
if with_detail:
uid = str(item.get("uid") or "").strip()
if uid:
detail = await _detail_by_uid(uid, ctx)
d = detail.get("detail_info") or {}
lines.append(f" 电话:{detail.get('telephone', '无')}")
lines.append(f" 人均:{d.get('price', '无')}")
lines.append(f" 营业时间:{d.get('shop_hours', '无')}")
# 加空行分隔
lines.append("")
# 返回最终整合文本
return "\n".join(lines)
# 主函数,启动MCP SSE服务
def main():
# 输出服务启动日志
logger.info("启动 Place MCP SSE 服务: http://%s:%s%s", HOST, PORT, SSE_PATH)
# 启动MCP SSE服务
mcp.run(transport="sse")
# 作为主程序入口时执行main方法
if __name__ == "__main__":
main()
12.3. mcp_tester.py #
app/services/mcp_tester.py
# 导入异步库asyncio
import asyncio
# 导入自定义的streamable_http_client客户端
from mcp.client.streamable_http import streamable_http_client
# 导入日志库
import logging
# 导入httpx用于网络请求
import httpx
# 从mcp模块导入ClientSession用于会话管理
from mcp import ClientSession
# 导入自定义的mcp_httpx_client_factory工厂函数
from app.services.mcp_httpx import mcp_httpx_client_factory
# 导入sse_client客户端
+from mcp.client.sse import sse_client
# 获取logger对象用于日志输出
logger = logging.getLogger(__name__)
# 合并headers辅助函数,将配置中的header全部转换为字符串格式
def _merge_headers(config):
# 从config中获取headers
raw = config.get("headers")
# 如果headers不存在或不是dict,则返回空字典
if not raw or not isinstance(raw, dict):
return {}
# 初始化输出字典
out = {}
# 遍历原始headers,将键值都转为字符串
for k, v in raw.items():
out[str(k)] = str(v)
# 返回处理后的headers
return out
# 提取tools辅助函数
def _extract_tools(result):
# 如果result为None,返回空列表
if result is None:
return []
# 如果result已经是列表,直接返回
if isinstance(result, list):
return result
# 尝试获取result对象上的tools属性
tools = getattr(result, "tools", None)
if isinstance(tools, list):
return tools
# 如果result是字典,则优先直接取tools字段
if isinstance(result, dict):
t = result.get("tools")
if isinstance(t, list):
return t
# 如果没有,尝试进入result字段嵌套查找
nested = result.get("result")
if isinstance(nested, dict):
t = nested.get("tools")
if isinstance(t, list):
return t
# 如果无法提取到tools,则返回空列表
return []
# 标准化tool对象,确保返回字典结构{name, description, input_schema}
def _normalize_tool(tool):
# 如果tool为None,返回默认空结构
if tool is None:
return {"name": "", "description": "", "input_schema": {}}
# 如果tool是dict类型,从中提取字段并构造返回
if isinstance(tool, dict):
return {
"name": str(tool.get("name") or ""),
"description": str(tool.get("description") or ""),
"input_schema": tool.get("inputSchema") or tool.get("input_schema") or {},
}
# 否则尝试通过对象属性获取name、description、input_schema
return {
"name": str(getattr(tool, "name", "") or ""),
"description": str(getattr(tool, "description", "") or ""),
"input_schema": getattr(tool, "inputSchema", None)
or getattr(tool, "input_schema", None)
or {},
}
# 通过给定的read/write对象与MCP服务初始化,并探测tools列表
async def _probe_with_session(read, write, label):
# 使用ClientSession创建异步上下文
async with ClientSession(read, write) as session:
# 执行session初始化,带超时时间
init_result = await asyncio.wait_for(session.initialize(), timeout=15)
tools_result = None
raw_tools = []
# 最多循环尝试3次获取工具列表
for i in range(3):
# 列出tools,带超时时间
tools_result = await asyncio.wait_for(session.list_tools(), timeout=15)
# 提取tools
raw_tools = _extract_tools(tools_result)
# 如果成功提取到tool就跳出循环
if raw_tools:
break
# 若还未成功则短暂等待后重试(最多重试2次)
if i < 2:
await asyncio.sleep(0.2)
# 获取初始化结果中的serverInfo属性
server_info = getattr(init_result, "serverInfo", None)
# 对工具列表进行标准化处理
tools = [_normalize_tool(t) for t in raw_tools]
# 日志输出工具数量
logger.info("%s probe list_tools count=%s", label, len(tools))
# 如果serverInfo存在且有name则写入返回消息
if server_info and getattr(server_info, "name", None):
return True, f"{label} MCP 初始化成功: {server_info.name}", tools
# 否则只是简单初始化成功
return True, f"{label} MCP 初始化成功", tools
# 测试 streamable-http 协议的 MCP 服务
def test_streamable_http(config):
# 从配置中获取 url
url = config.get("url")
# 若 url 无效或不是字符串,直接返回错误信息
if not url or not isinstance(url, str):
return False, "streamable-http 配置需要字符串 url", []
# 标准化合并 headers
headers = _merge_headers(config)
# 定义异步检测函数
async def _run_http_check():
try:
# 使用定制的 httpx async client 工厂函数创建客户端(避免本地代理干扰)
async with mcp_httpx_client_factory(headers=headers) as http_client:
# 以异步方式建立与 streamable-http MCP 服务的连接,并获取读写对象
async with streamable_http_client(url, http_client=http_client) as (read, write, _):
# 调用内部探测逻辑初始化 session 并获取工具列表
return await _probe_with_session(read, write, "streamable-http")
# 捕获超时异常(如连接或响应过慢),返回特定的提示
except asyncio.TimeoutError:
return False, "等待响应超时;请检查 streamable-http 服务地址与请求头", []
# 捕获 httpx 的请求异常,如网络不可达等
except httpx.RequestError as e:
return False, f"请求失败: {e}", []
# 捕获所有其他异常,写日志,返回简要错误说明
except Exception as e:
logger.exception("streamable-http 检测失败")
return False, f"streamable-http MCP 检测失败: {e}", []
# 在主线程执行异步检测函数并返回结果
return asyncio.run(_run_http_check())
# 定义测试sse协议MCP服务的函数
+def test_sse(config):
# 从配置中获取url
+ url = config.get("url")
# 如果url不存在或不是字符串,返回错误信息
+ if not url or not isinstance(url, str):
+ return False, "sse 配置需要字符串 url", []
# 合并请求头
+ headers = _merge_headers(config)
# 定义异步检测函数
+ async def _run_sse_check():
+ try:
# 建立与sse服务的异步连接,获取读写对象
+ async with sse_client(
+ url, headers=headers, httpx_client_factory=mcp_httpx_client_factory
+ ) as (read, write):
# 调用探测工具初始化session并获取工具列表
+ return await _probe_with_session(read, write, "sse")
# 捕获超时异常,返回特定提示
+ except asyncio.TimeoutError:
+ return False, "等待响应超时;请检查 sse 服务地址与请求头", []
# 捕获httpx请求异常,如网络不可达等
+ except httpx.RequestError as e:
+ return False, f"请求失败: {e}", []
# 捕获其它异常,写日志,返回简要错误说明
+ except Exception as e: # noqa: BLE001
+ logger.exception("sse 检测失败")
+ return False, f"sse MCP 检测失败: {e}", []
# 在主线程运行异步检测函数并返回结果
+ return asyncio.run(_run_sse_check())
# MCP通用检测入口,根据协议选择检测方法
def test_mcp(protocol, config):
# 将协议字符串归一化为小写去除空格
p = (protocol or "").lower().strip()
# 如果协议是streamable-http或http则用http检测方法
if p in {"streamable-http", "http"}:
return test_streamable_http(config)
+ if p == "sse":
+ return test_sse(config)
# 否则返回不支持的协议
return False, f"不支持的协议: {protocol}", []12.4 测试 #
curl --location --request POST "http://127.0.0.1:8000/api/mcp-services/test" ^
--header "Content-Type: application/json" ^
--data-raw "{ \"protocol\": \"sse\", \"config\": { \"url\": \"http://127.0.0.1:8001/sse\", \"headers\": { \"BAIDU_MAP_AK\": \"e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX\", \"DEEPSEEK_API_KEY\": \"sk-24088156e9ab48f3adddaf5a9c0c4ede\" } }}"12.5 地点检索服务 #
12.5.1 描述 #
面向中文出行场景的地点检索服务。用户输入自然语言需求(如“北京景点”“上海酒店”)后,服务会先用 DeepSeek 抽取检索参数(地区、关键词、标签、是否限制本地),再调用百度 Place API 返回地点列表;可按需补充详情信息(电话、人均、营业时间),并在无结果时给出联想词建议。适用于 AI 助手的景点/酒店查询与行程决策场景。
12.5.2 协议 #
sse12.5.3 URL #
http://127.0.0.1:8001/sse12.5.4 请求头 #
{
"BAIDU_MAP_AK": "e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX",
"DEEPSEEK_API_KEY": "sk-24088156e9ab48f3adddaf5a9c0c4ede"
}13. 天气查询服务 #
服务综述
天气查询服务(weather MCP 服务)基于 MCP 通信协议,设计用于获取中国大陆地区未来天气预报。与传统 API 接口不同,本服务采用 stdio 协议作为输入输出通道,适用于被各类 AI 助手客户端(如 Cursor、Claude Desktop 等支持 MCP 的程序)直接集成。
本服务支持通过自然语言输入目的地及天数。它自动通过 DeepSeek 大模型抽取并规整请求参数,然后调用百度天气开放接口获取对应的天气数据,最终以易读中文回复。
主要功能亮点:
- 支持自然语言输入:用户只需输入「去成都3天天气」等描述,无需提前查找地区编码。
- 精准参数抽取:结合 DeepSeek 大模型能力解析自然语言,自动匹配百度天气 API 所需参数(如
district_id)。 - 国内天气覆盖:基于百度开放平台,准确查询中国境内绝大多数省市区县的7日预报。
- 标准 MCP 工具协议:适配 AI 智能体工具链,并支持多种客户端直连。
协议与调用方法
本服务基于 MCP stdio 协议实现,需要由外部进程(如 AI 客户端或命令行用户)以标准输入/输出的方式启动和交互。启动服务的推荐命令如下:
uv run --directory D:\aprepare\mcp_agent\mcp-services weather-server.py其中:
uv是快速启动 Python 脚本的 runner,具备「热重载」和更低延迟的优点。run与--directory及脚本路径需要根据实际部署目录调整。weather-server.py为本服务主程序文件。
环境变量请务必正确设置,包括:
BAIDU_MAP_AK:百度地图开放平台的 API Key(用于天气及地理相关接口调用)DEEPSEEK_MODEL:DeepSeek 所用模型(如"deepseek-chat",如无可省略)DEEPSEEK_API_KEY:DeepSeek 大模型的 API KeyDEEPSEEK_BASE_URL:DeepSeek 后端地址
可参照如下 shell 伪代码:
export BAIDU_MAP_AK=你的百度API_KEY
export DEEPSEEK_API_KEY=你的DeepSeek_KEY
# ...其余变量...
uv run --directory 项目路径 weather-server.py配置参数说明
| 字段名 | 类型 | 说明 | 示例 |
|---|---|---|---|
| command | str | 可执行命令,推荐 uv |
"uv" |
| args | list | 启动参数,依次填写子命令及路径等 | ["run", "--directory", "D:/aprepare/mcp_agent/mcp-services", "weather-server.py"] |
| env | object | 环境变量字典,参见上文 | {...} |
请求体示例:
{
"protocol": "stdio",
"config": {
"command": "uv",
"args": [
"run",
"--directory",
"D:/aprepare/mcp_agent/mcp-services",
"weather-server.py"
],
"env": {
"BAIDU_MAP_AK": "e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX",
"DEEPSEEK_MODEL": "deepseek-chat",
"DEEPSEEK_API_KEY": "sk-24088156e9ab48f3adddaf5a9c0c4ede",
"DEEPSEEK_BASE_URL": "https://api.deepseek.com"
}
}
}检测与测试方法
可用 HTTP POST 方式检测服务可用性,只需调用 /api/mcp-services/test,请求体举例:
curl --location --request POST "http://127.0.0.1:8000/api/mcp-services/test" ^
--header "Content-Type: application/json" ^
--data-raw "{ \"protocol\": \"stdio\", \"config\": { \"env\": { \"BAIDU_MAP_AK\": \"e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX\", \"DEEPSEEK_MODEL\": \"deepseek-chat\", \"DEEPSEEK_API_KEY\": \"sk-24088156e9ab48f3adddaf5a9c0c4ede\", \"DEEPSEEK_BASE_URL\": \"https://api.deepseek.com\" }, \"args\": [ \"run\", \"--directory\", \"D:/aprepare/mcp-backend/mcp-services\", \"weather-server.py\" ], \"command\": \"uv\" }}"场景举例与用途
- 用户:直接询问「我去杭州玩4天,这几天天气如何?」
- AI 助手:自动调用本服务,返回易于直接引用的天气描述、未来多日温度与降水预报,并结合天气辅助旅行建议。
与客户端集成建议
- 推荐使用
mcp.client.stdio标准库或你所用智能体的 MCP Stdio 客户端模块。 - 若需 CLI 级调试可参考
weather-client.py示例代码。 - 建议在生产环境下合理配置环境变量,确保 DeepSeek 及百度相关密钥安全。
13.1. weather-client.py #
mcp-services/weather-client.py
"""调试用:启动 weather stdio MCP 并调用 get_travel_forecast。"""
# 导入异步IO模块
import asyncio
# 导入系统模块(用于获取解释器路径)
import sys
# 导入Path对象,用于操作文件路径
from pathlib import Path
# 从mcp库导入ClientSession和StdioServerParameters工具
from mcp import ClientSession, StdioServerParameters
# 从mcp.client.stdio导入stdio_client,用于与stdio协议MCP服务通讯
from mcp.client.stdio import stdio_client
# 获取当前文件同目录下的 weather-server.py 路径
SERVER_FILE = Path(__file__).with_name("weather-server.py")
# 设定要查询的目的地
DESTINATION = "北京"
# 设定要查询的天数
DAYS = 3
# 设置百度地图AK
BAIDU_MAP_AK = "e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX"
# 设置DeepSeek API Key
DEEPSEEK_API_KEY = "sk-24088156e9ab48f3adddaf5a9c0c4ede"
# 设置DeepSeek的基础URL
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
# 设置DeepSeek所用的模型名
DEEPSEEK_MODEL = "deepseek-chat"
# 定义异步主流程
async def run():
# 初始化StdioServerParameters,配置待启动的天气服务所需命令和环境变量
server = StdioServerParameters(
command=sys.executable,
args=[str(SERVER_FILE)],
env={
"BAIDU_MAP_AK": BAIDU_MAP_AK,
"DEEPSEEK_API_KEY": DEEPSEEK_API_KEY,
"DEEPSEEK_BASE_URL": DEEPSEEK_BASE_URL,
"DEEPSEEK_MODEL": DEEPSEEK_MODEL,
},
)
# 启动并连接weather-server进程,获取读写流
async with stdio_client(server) as (read, write):
# 创建MCP会话
async with ClientSession(read, write) as session:
# 初始化会话
await session.initialize()
# 调用天气查询工具
result = await session.call_tool(
"get_travel_forecast",
{"destination": DESTINATION, "days": DAYS},
)
# 检查是否有错误,如有,抛出异常
if result.isError:
raise RuntimeError(str(result.content))
# 遍历返回的内容并输出其中的“text”字段
for item in result.content:
text = getattr(item, "text", "")
if text:
print(text)
# 定义主入口,运行异步主流程
def main():
asyncio.run(run())
# 判断是否直接运行此文件,若是则执行main
if __name__ == "__main__":
main()
13.2. weather-server.py #
mcp-services/weather-server.py
# 天气 MCP 服务器(stdio):目的地 + N 天 -> DeepSeek 参数抽取 -> 百度天气预报。
"""
天气 MCP 服务器(stdio):目的地 + N 天 -> DeepSeek 参数抽取 -> 百度天气预报。
"""
# 导入标准库模块
import json
import logging
import os
from typing import Annotated
# 导入第三方库
import httpx
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_deepseek import ChatDeepSeek
from mcp.server.fastmcp import FastMCP
from pydantic import Field
# 配置日志格式与日志级别
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
# 创建名为 weather-service 的日志记录器
logger = logging.getLogger("weather-service")
# 创建 FastMCP 实例,服务名为“天气查询服务”
mcp = FastMCP("天气查询服务")
# 百度天气 API 的 URL
WEATHER_V1_URL = "https://api.map.baidu.com/weather/v1/"
# 获取百度地图 AK,若未设置则抛出异常
def _ak():
v = (os.environ.get("BAIDU_MAP_AK") or "").strip()
if not v:
raise RuntimeError("缺少环境变量 BAIDU_MAP_AK")
return v
# 获取 DeepSeek 对象,检查 API KEY 环境变量是否存在
def _deepseek():
key = (os.environ.get("DEEPSEEK_API_KEY") or "").strip()
if not key:
raise RuntimeError("缺少环境变量 DEEPSEEK_API_KEY")
# 返回 DeepSeek 对象,支持自定义 base_url 和模型,默认为 deepseek-chat
return ChatDeepSeek(
api_key=key,
base_url=(os.environ.get("DEEPSEEK_BASE_URL") or "https://api.deepseek.com").strip(),
model=(os.environ.get("DEEPSEEK_MODEL") or "deepseek-chat").strip(),
temperature=0,
)
# 异步方法:调用大模型抽取天气查询参数
async def _extract_params(destination):
# 日志记录参数抽取开始
logger.info("开始参数抽取,destination=%s", destination)
# 创建 JSON 输出解析器
parser = JsonOutputParser()
# 构建用于参数抽取的 PromptTemplate
prompt = ChatPromptTemplate.from_template(
"""
从输入中抽取百度天气接口参数,只输出 JSON:
{{
"district_id": "",
"province": "",
"city": "",
"district": ""
}}
输入:{destination}
输出格式要求:{format_instructions}
""".strip()
)
# 串联 prompt、deepseek 和 JSON parser
chain = prompt | _deepseek() | parser
# 调用链式执行,抽取参数
data = await chain.ainvoke(
{
"destination": destination,
"format_instructions": parser.get_format_instructions(),
}
)
# 整理抽取结果,保证字段存在且类型为字符串
extracted = {
"district_id": str(data.get("district_id") or "").strip(),
"province": str(data.get("province") or "").strip(),
"city": str(data.get("city") or "").strip(),
"district": str(data.get("district") or "").strip(),
}
# 记录抽取结果日志
logger.info("参数抽取完成:%s", json.dumps(extracted, ensure_ascii=False))
# 返回抽取的参数
return extracted
# 异步方法:调用百度天气接口获取天气预报
async def _fetch_forecast(params_from_llm):
# 构建基础查询参数,data_type=fc 表示未来天气预报
params = {"data_type": "fc", "output": "json", "ak": _ak()}
# 如果抽取结果存在 district_id,优先使用
if params_from_llm["district_id"]:
params["district_id"] = params_from_llm["district_id"]
# 否则按 district -> province/city 顺序补全
elif params_from_llm["district"]:
params["district"] = params_from_llm["district"]
if params_from_llm["province"]:
params["province"] = params_from_llm["province"]
if params_from_llm["city"]:
params["city"] = params_from_llm["city"]
# 如果有 city,用 city 作为 district
elif params_from_llm["city"]:
params["district"] = params_from_llm["city"]
if params_from_llm["province"]:
params["province"] = params_from_llm["province"]
# 如果只剩 province,用 province 作为 district
elif params_from_llm["province"]:
params["district"] = params_from_llm["province"]
# 若以上都没有,抛出异常
else:
raise RuntimeError("LLM 未提取出有效地点(district_id/district/city/province)")
# 日志记录即将调用百度天气接口
logger.info("调用百度天气接口,params=%s", json.dumps(params, ensure_ascii=False))
# 创建 httpx 异步客户端,设置超时时间为 30 秒
async with httpx.AsyncClient(timeout=30.0) as client:
# 向百度天气接口发起 GET 请求
r = await client.get(WEATHER_V1_URL, params=params)
# 请求异常时抛出异常
r.raise_for_status()
# 解析返回的 json 数据
data = r.json()
# 检查百度天气 API 返回状态
if data.get("status") != 0:
raise RuntimeError(f"百度天气接口失败: status={data.get('status')} message={data.get('message')}")
# 日志记录接口调用成功
logger.info("百度天气接口调用成功")
# 返回“result”字段内容
return data.get("result") or {}
# 注册 MCP 工具
@mcp.tool()
# 异步方法:获取指定目的地与天数的天气预报
async def get_travel_forecast(
destination: Annotated[
str,
Field(description="旅游目的地,例如:吉林省延边朝鲜族自治州龙井市"),
],
days: Annotated[
int,
Field(description="展示未来天气天数,范围 1~15,默认 3"),
] = 3,
):
# 限定展示天数在 1~15 范围
days = max(1, min(int(days), 15))
# 调用参数抽取逻辑
params_from_llm = await _extract_params(destination)
# 调用接口查询天气预报结果
result = await _fetch_forecast(params_from_llm)
# 获取 forecasts 列表(未来各天的天气预报)
forecasts = list(result.get("forecasts") or [])
# 日志记录生成天气预报结果
logger.info("生成天气预报结果,destination=%s days=%s", destination, days)
# 生成结果的文本行列表
lines = [
f"输入目的地:{destination}",
f"参数:{json.dumps(params_from_llm, ensure_ascii=False)}",
f"展示天数:{min(days, len(forecasts))}",
"",
]
# 遍历前 days 条天气信息,依次生成每一天的信息
for day in forecasts[:days]:
lines.append(f"{day.get('date', '')} {day.get('week', '')}")
lines.append(f"白天:{day.get('text_day', '')},夜间:{day.get('text_night', '')}")
lines.append(f"温度:{day.get('low', '')}~{day.get('high', '')}℃")
lines.append("")
# 返回多行文本,用于 AI 助手等调用
return "\n".join(lines)
# 程序主入口,启动 MCP 服务器(stdio 模式)
def main():
mcp.run(transport="stdio")
# 判断是否直接运行此文件,若是则执行 main()
if __name__ == "__main__":
main()
13.3. mcp_tester.py #
app/services/mcp_tester.py
# 导入异步库asyncio
import asyncio
# 导入自定义的streamable_http_client客户端
from mcp.client.streamable_http import streamable_http_client
# 导入日志库
import logging
# 导入httpx用于网络请求
import httpx
# 导入os模块用于操作系统
+import os
# 导入shutil模块用于文件操作
+import shutil
# 导入Any类型用于类型注解
+from typing import Any
# 导入stdio_client客户端
+from mcp.client.stdio import stdio_client
# 从mcp模块导入ClientSession用于会话管理
+from mcp import ClientSession,StdioServerParameters
# 导入自定义的mcp_httpx_client_factory工厂函数
from app.services.mcp_httpx import mcp_httpx_client_factory
# 导入sse_client客户端
from mcp.client.sse import sse_client
# 获取logger对象用于日志输出
logger = logging.getLogger(__name__)
# 合并headers辅助函数,将配置中的header全部转换为字符串格式
def _merge_headers(config):
# 从config中获取headers
raw = config.get("headers")
# 如果headers不存在或不是dict,则返回空字典
if not raw or not isinstance(raw, dict):
return {}
# 初始化输出字典
out = {}
# 遍历原始headers,将键值都转为字符串
for k, v in raw.items():
out[str(k)] = str(v)
# 返回处理后的headers
return out
# 提取tools辅助函数
def _extract_tools(result):
# 如果result为None,返回空列表
if result is None:
return []
# 如果result已经是列表,直接返回
if isinstance(result, list):
return result
# 尝试获取result对象上的tools属性
tools = getattr(result, "tools", None)
if isinstance(tools, list):
return tools
# 如果result是字典,则优先直接取tools字段
if isinstance(result, dict):
t = result.get("tools")
if isinstance(t, list):
return t
# 如果没有,尝试进入result字段嵌套查找
nested = result.get("result")
if isinstance(nested, dict):
t = nested.get("tools")
if isinstance(t, list):
return t
# 如果无法提取到tools,则返回空列表
return []
# 标准化tool对象,确保返回字典结构{name, description, input_schema}
def _normalize_tool(tool):
# 如果tool为None,返回默认空结构
if tool is None:
return {"name": "", "description": "", "input_schema": {}}
# 如果tool是dict类型,从中提取字段并构造返回
if isinstance(tool, dict):
return {
"name": str(tool.get("name") or ""),
"description": str(tool.get("description") or ""),
"input_schema": tool.get("inputSchema") or tool.get("input_schema") or {},
}
# 否则尝试通过对象属性获取name、description、input_schema
return {
"name": str(getattr(tool, "name", "") or ""),
"description": str(getattr(tool, "description", "") or ""),
"input_schema": getattr(tool, "inputSchema", None)
or getattr(tool, "input_schema", None)
or {},
}
# 通过给定的read/write对象与MCP服务初始化,并探测tools列表
async def _probe_with_session(read, write, label):
# 使用ClientSession创建异步上下文
async with ClientSession(read, write) as session:
# 执行session初始化,带超时时间
init_result = await asyncio.wait_for(session.initialize(), timeout=15)
tools_result = None
raw_tools = []
# 最多循环尝试3次获取工具列表
for i in range(3):
# 列出tools,带超时时间
tools_result = await asyncio.wait_for(session.list_tools(), timeout=15)
# 提取tools
raw_tools = _extract_tools(tools_result)
# 如果成功提取到tool就跳出循环
if raw_tools:
break
# 若还未成功则短暂等待后重试(最多重试2次)
if i < 2:
await asyncio.sleep(0.2)
# 获取初始化结果中的serverInfo属性
server_info = getattr(init_result, "serverInfo", None)
# 对工具列表进行标准化处理
tools = [_normalize_tool(t) for t in raw_tools]
# 日志输出工具数量
logger.info("%s probe list_tools count=%s", label, len(tools))
# 如果serverInfo存在且有name则写入返回消息
if server_info and getattr(server_info, "name", None):
return True, f"{label} MCP 初始化成功: {server_info.name}", tools
# 否则只是简单初始化成功
return True, f"{label} MCP 初始化成功", tools
# 测试 streamable-http 协议的 MCP 服务
def test_streamable_http(config):
# 从配置中获取 url
url = config.get("url")
# 若 url 无效或不是字符串,直接返回错误信息
if not url or not isinstance(url, str):
return False, "streamable-http 配置需要字符串 url", []
# 标准化合并 headers
headers = _merge_headers(config)
# 定义异步检测函数
async def _run_http_check():
try:
# 使用定制的 httpx async client 工厂函数创建客户端(避免本地代理干扰)
async with mcp_httpx_client_factory(headers=headers) as http_client:
# 以异步方式建立与 streamable-http MCP 服务的连接,并获取读写对象
async with streamable_http_client(url, http_client=http_client) as (read, write, _):
# 调用内部探测逻辑初始化 session 并获取工具列表
return await _probe_with_session(read, write, "streamable-http")
# 捕获超时异常(如连接或响应过慢),返回特定的提示
except asyncio.TimeoutError:
return False, "等待响应超时;请检查 streamable-http 服务地址与请求头", []
# 捕获 httpx 的请求异常,如网络不可达等
except httpx.RequestError as e:
return False, f"请求失败: {e}", []
# 捕获所有其他异常,写日志,返回简要错误说明
except Exception as e:
logger.exception("streamable-http 检测失败")
return False, f"streamable-http MCP 检测失败: {e}", []
# 在主线程执行异步检测函数并返回结果
return asyncio.run(_run_http_check())
# 定义测试sse协议MCP服务的函数
def test_sse(config):
# 从配置中获取url
url = config.get("url")
# 如果url不存在或不是字符串,返回错误信息
if not url or not isinstance(url, str):
return False, "sse 配置需要字符串 url", []
# 合并请求头
headers = _merge_headers(config)
# 定义异步检测函数
async def _run_sse_check():
try:
# 建立与sse服务的异步连接,获取读写对象
async with sse_client(
url, headers=headers, httpx_client_factory=mcp_httpx_client_factory
) as (read, write):
# 调用探测工具初始化session并获取工具列表
return await _probe_with_session(read, write, "sse")
# 捕获超时异常,返回特定提示
except asyncio.TimeoutError:
return False, "等待响应超时;请检查 sse 服务地址与请求头", []
# 捕获httpx请求异常,如网络不可达等
except httpx.RequestError as e:
return False, f"请求失败: {e}", []
# 捕获其它异常,写日志,返回简要错误说明
except Exception as e: # noqa: BLE001
logger.exception("sse 检测失败")
return False, f"sse MCP 检测失败: {e}", []
# 在主线程运行异步检测函数并返回结果
return asyncio.run(_run_sse_check())
# 定义用于测试 stdio 协议 MCP 服务的函数
+def test_stdio(config):
# 从配置中获取 command 字段
+ command = config.get("command")
# 检查 command 是否存在且为字符串
+ if not command or not isinstance(command, str):
+ return False, "stdio 配置需要字符串字段 command", []
# 获取 args 参数,为空时默认为空列表
+ args = config.get("args") or []
# 检查 args 是否为列表类型
+ if not isinstance(args, list):
+ return False, "stdio 配置的 args 须为数组", []
# 获取 env(环境变量)字段
+ env_vars = config.get("env")
# 复制当前环境变量
+ merged_env = os.environ.copy()
# 如果配置提供了额外环境变量且为字典,将其合并进当前环境变量
+ if env_vars and isinstance(env_vars, dict):
+ for k, v in env_vars.items():
+ merged_env[str(k)] = str(v)
# 取出 command 作为可执行文件
+ exe = command
# 如果 exe 不是绝对路径且不含路径分隔符,则在 PATH 里查找实际路径
+ if exe and not os.path.isabs(exe) and os.sep not in exe:
+ found = shutil.which(exe)
# 如果找不到可执行文件则返回错误
+ if not found:
+ return False, f"找不到可执行文件: {command}", []
# 找到则更新 exe 为实际路径
+ exe = found
# 定义内部异步检测函数
+ async def _run_stdio_check():
+ try:
# 生成 StdioServerParameters 实例,包装启动参数
+ server = StdioServerParameters(
+ command=exe,
+ args=[str(a) for a in args],
+ env=merged_env,
+ )
# 以异步方式创建 stdio 客户端并建立连接
+ async with stdio_client(server) as (read, write):
# 调用探测函数检查 stdio MCP 服务
+ return await _probe_with_session(read, write, "stdio")
# 捕获超时异常,返回特定的错误提示
+ except asyncio.TimeoutError:
+ return False, "等待响应超时;请检查命令与参数是否为 MCP stdio 服务", []
# 捕获找不到文件异常,返回具体的错误信息
+ except FileNotFoundError:
+ return False, f"找不到可执行文件: {exe}", []
# 捕获无法启动进程的异常
+ except OSError as e:
+ return False, f"无法启动进程: {e}", []
# 捕获所有其他异常,写入日志并返回通用错误说明
+ except Exception as e: # noqa: BLE001
+ logger.exception("stdio 检测失败")
+ return False, f"stdio MCP 检测失败: {e}", []
# 在主线程运行异步检测,并返回检测结果
+ return asyncio.run(_run_stdio_check())
# MCP通用检测入口,根据协议选择检测方法
def test_mcp(protocol, config):
# 将协议字符串归一化为小写去除空格
p = (protocol or "").lower().strip()
# 如果协议是streamable-http或http则用http检测方法
if p in {"streamable-http", "http"}:
return test_streamable_http(config)
if p == "sse":
return test_sse(config)
+ if p == "stdio":
+ return test_stdio(config)
# 否则返回不支持的协议
return False, f"不支持的协议: {protocol}", []13.4 测试 #
curl --location --request POST "http://127.0.0.1:8000/api/mcp-services/test" ^
--header "Content-Type: application/json" ^
--data-raw "{ \"protocol\": \"stdio\", \"config\": { \"env\": { \"BAIDU_MAP_AK\": \"e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX\", \"DEEPSEEK_MODEL\": \"deepseek-chat\", \"DEEPSEEK_API_KEY\": \"sk-24088156e9ab48f3adddaf5a9c0c4ede\", \"DEEPSEEK_BASE_URL\": \"https://api.deepseek.com\" }, \"args\": [ \"run\", \"--directory\", \"D:/aprepare/mcp-backend/mcp-services\", \"weather-server.py\" ], \"command\": \"uv\" }}"13.5 数据 #
13.5.1 描述 #
一个面向中文出行场景的天气查询 MCP 服务。用户只需输入旅游目的地和天数,服务会先通过 DeepSeek 从自然语言中提取百度天气接口所需的区域参数(如 district_id、省市区),再调用百度天气预报 API 获取未来天气,并以易读文本返回给 AI 助手用于行程建议和天气提醒。服务采用 stdio 传输协议,适合被 Cursor、Claude Desktop 等支持 MCP 的客户端直接接入。
13.5.2 协议 #
stdio13.5.3 命令 #
uv13.5.4 参数 #
run
--directory
D:\aprepare\mcp_agent\mcp-services
weather-server.py13.5.5 环境变量 #
BAIDU_MAP_AK=e0QsxCTdlt6qPNoQQNJwa89qoJ4OXieX
DEEPSEEK_MODEL=deepseek-chat
DEEPSEEK_API_KEY=sk-24088156e9ab48f3adddaf5a9c0c4ede
DEEPSEEK_BASE_URL=https://api.deepseek.com14. 添加大模型 #
本章节将介绍如何在系统中添加和管理大模型(如 DeepSeek、ChatGLM 等)。
支持存储和管理以下信息:模型提供商、图标、API 基础地址、API 密钥、密钥申请地址、模型名称列表等,支持通过管理后台或 API 快速添加第三方大模型,便于灵活扩展后续的 LLM 底座能力。
主要能力说明
- 模型信息注册:可在系统中注册不同厂商的大模型,包括 API 地址、密钥、模型名称等关键信息。
- 多模型管理:支持一套系统中接入多个大模型,便于按需扩展和切换。
- 字段校验与标准化:注册数据通过 Pydantic 严格校验及去重、去空格处理,确保数据合规和规范。
- 接口调用安全:敏感信息(如密钥)支持表单加密传输,系统仅在实际调用时访问,并可配置密钥申请帮助地址。
- 使用场景示例:通过 API、管理后台页面等方式,管理员可动态添加/维护大模型配置,从而灵活适配新模型。
适用场景举例
- 统一运维各类 LLM 能力(DeepSeek、Azure OpenAI、聊聊GLM等),按需切换模型自动适配。
- 企业/个人研发者根据自身授权情况,导入属于自己的 API Key 并指定模型列表,系统内用户即可调用。
- 快速接入新兴大模型(如 DeepSeek Reasoner),只需在管理后台新增条目,无需变更代码。
14.1. llm_repository.py #
app/repositories/llm_repository.py
# 导入SQLAlchemy的Session类,用于数据库会话
from sqlalchemy.orm import Session
# 从app包导入models模块,包含数据库模型
from app import models
# 从app包导入schemas模块,包含数据校验模型
from app import schemas
# 定义创建LLM模型记录的函数,接收数据库会话和入参数据,返回新建的LlmModel实例
def create_llm_model(db: Session, data: schemas.LlmModelCreate) -> models.LlmModel:
# 创建LlmModel数据库对象,剔除多余空白并赋值各字段
row = models.LlmModel(
provider_name=data.provider_name.strip(), # 提供商名称,去除首尾空白
provider_icon=data.provider_icon, # 提供商图标
api_base_url=data.api_base_url.strip(), # API基础地址,去除首尾空白
api_key=data.api_key.strip(), # API密钥,去除首尾空白
api_key_url=data.api_key_url.strip() if data.api_key_url else None, # API密钥申请地址,有则去除空白,无则为None
model_names=data.model_names, # LLM模型名称列表
)
# 添加新纪录到数据库session
db.add(row)
# 提交事务保存数据
db.commit()
# 刷新session获取新数据
db.refresh(row)
# 返回新建的LlmModel对象
return row14.2. llm_models.py #
app/routers/llm_models.py
# 导入urljoin用于拼接URL
from urllib.parse import urljoin
# 导入httpx库(可用于发送HTTP请求,本文件未用到)
import httpx
# 从fastapi库导入APIRouter用于创建路由,Depends用于依赖注入,HTTPException用于异常处理
from fastapi import APIRouter, Depends, HTTPException
# 从sqlalchemy导入IntegrityError,用于捕获唯一性冲突异常
from sqlalchemy.exc import IntegrityError
# 导入Session会话,用于数据库操作
from sqlalchemy.orm import Session
# 导入schemas定义的Pydantic模型
from app import schemas
# 导入llm_repository,用于大模型数据操作
from app.repositories import llm_repository
# 导入get_db获取数据库依赖
from app.database import get_db
# 从返回的payload中提取模型名称列表
def _extract_model_names(payload):
# 获取payload中的"data"字段
data = payload.get("data")
# 如果data不是列表,则直接返回空列表
if not isinstance(data, list):
return []
# 准备输出列表
out = []
# 用于去重的集合
seen = set()
# 遍历data中的每一项
for item in data:
# 如果当前项不是字典,跳过
if not isinstance(item, dict):
continue
# 取得"id"字段,转为字符串并去两端空白
name = str(item.get("id") or "").strip()
# 如果名称为空,则跳过
if not name:
continue
# 转为小写做去重key
key = name.lower()
# 如果已经出现过,跳过
if key in seen:
continue
# 添加到已见集合
seen.add(key)
# 添加到输出列表
out.append(name)
# 返回整理后的模型名称列表
return out
# 创建APIRouter实例,设置路由前缀与标签
router = APIRouter(prefix="/api/llm-models", tags=["llm-models"])
# 定义POST类型接口,路径为"",响应模型为LlmModelOut
@router.post("", response_model=schemas.LlmModelOut)
def create_model(payload: schemas.LlmModelCreate, db: Session = Depends(get_db)):
# 尝试创建大模型记录
try:
return llm_repository.create_llm_model(db, payload)
# 捕获唯一性冲突异常(例如Provider的名字已存在)
except IntegrityError:
db.rollback()
# 抛出HTTP 409异常,提示"提供商名称已存在"
raise HTTPException(status_code=409, detail="提供商名称已存在")
# 定义路由POST接口 /probe,返回值为LlmModelTestResult
@router.post("/probe", response_model=schemas.LlmModelTestResult)
# 定义测试大模型服务的函数
def test_model_service(payload: schemas.LlmModelTestRequest):
# 去除api_base_url首尾空白字符
base_url = payload.api_base_url.strip()
# 去除api_key首尾空白字符
api_key = payload.api_key.strip()
# 构造模型列表接口的URL
models_url = urljoin(base_url.rstrip("/") + "/", "models")
# 准备请求头,添加Authorization字段
headers = {
"Authorization": f"Bearer {api_key}",
}
try:
# 创建HTTP客户端,设置超时时间和请求头
with httpx.Client(timeout=20.0, headers=headers) as client:
# 发送GET请求获取模型列表
resp = client.get(models_url)
# 检查HTTP响应状态码,若有异常则抛出
resp.raise_for_status()
# 解析响应体为JSON
body = resp.json()
# 调用工具函数提取模型名称列表
names = _extract_model_names(body if isinstance(body, dict) else {})
# 如果检测到模型名称
if names:
# 返回检测通过且包含模型数量和名称的结果
return schemas.LlmModelTestResult(
ok=True,
message=f"模型服务检测通过,可用模型 {len(names)} 个",
models=names,
)
# 如果没检测到模型列表,返回通过但无模型名提示
return schemas.LlmModelTestResult(
ok=True,
message="模型服务检测通过,但未识别到模型列表",
models=[],
)
# 捕获请求超时异常,返回相应错误信息
except httpx.TimeoutException:
return schemas.LlmModelTestResult(ok=False, message="请求超时,请检查 API 地址")
# 捕获HTTP状态异常,返回状态码和响应内容前300字符
except httpx.HTTPStatusError as e:
return schemas.LlmModelTestResult(ok=False, message=f"HTTP {e.response.status_code}: {e.response.text[:300]}")
# 捕获所有其他异常,返回异常信息
except Exception as e: # noqa: BLE001
return schemas.LlmModelTestResult(ok=False, message=f"模型服务检测失败: {e}") 14.3. uploads.py #
app/routers/uploads.py
# 导入 uuid 库用于生成唯一文件名
import uuid
# 从 typing 导入 ClassVar,用于类型标注
from typing import ClassVar
# 从 fastapi 导入相关函数和类
from fastapi import APIRouter, File, HTTPException, UploadFile
# 从 pydantic 导入 BaseModel,用于定义响应模型
from pydantic import BaseModel
# 导入自定义的配置 settings
from app.config import settings
# 创建带有前缀和 tags 的 FastAPI 路由对象
router = APIRouter(prefix="/api/uploads", tags=["uploads"])
# 定义最大上传文件大小为 5MB
_MAX_BYTES = 5 * 1024 * 1024
# 支持的图片 MIME 类型与文件扩展名映射
_CONTENT_TYPES: ClassVar[dict[str, str]] = {
"image/jpeg": ".jpg",
"image/png": ".png",
"image/gif": ".gif",
"image/webp": ".webp",
}
# 定义图片上传响应模型
class UploadImageResult(BaseModel):
# 图片访问的 URL 路径
url: str
# 定义图片上传接口,响应模型为 UploadImageResult
@router.post("/image", response_model=UploadImageResult)
async def upload_image(file: UploadFile = File(...)) -> UploadImageResult:
# 获取 Content-Type,分号前部分,转换小写,去除空白
raw_ct = (file.content_type or "").split(";")[0].strip().lower()
# 检查 Content-Type 是否为支持的图片类型
if raw_ct not in _CONTENT_TYPES:
# 若不支持则抛出 400 错误
raise HTTPException(status_code=400, detail="只支持 JPEG、PNG、GIF、WebP 图片")
# 根据 MIME 类型决定文件扩展名
ext = _CONTENT_TYPES[raw_ct]
# 读取上传的文件内容为二进制
body = await file.read()
# 检查文件大小是否超过最大限制
if len(body) > _MAX_BYTES:
# 超过限制抛出 400 错误
raise HTTPException(status_code=400, detail="图片不能超过 5MB")
# 检查文件内容是否为空
if not body:
# 空文件抛出 400 错误
raise HTTPException(status_code=400, detail="空文件")
# 生成唯一的文件名,附上文件扩展名
name = f"{uuid.uuid4().hex}{ext}"
# 计算目标文件保存路径
dest = settings.upload_path() / name
# 将图片内容写入目标路径
dest.write_bytes(body)
# 返回图片访问 URL
return UploadImageResult(url=f"/uploads/{name}")
14.4. 测试 #
curl --location --request POST "http://127.0.0.1:8000/api/llm-models" ^
--header "Content-Type: application/json" ^
--data-raw "{ \"provider_name\": \"深度探索\", \"provider_icon\": \"/uploads/08b305a11707447696aecc2040f96ecc.png\", \"api_base_url\": \"https://api.deepseek.com\", \"api_key\": \"sk-24088156e9ab48f3adddaf5a9c0c4ede\", \"api_key_url\": \"https://platform.deepseek.com/api_keys\", \"model_names\": [ \"deepseek-chat\", \"deepseek-reasoner\" ]}"14.5. 数据 #
14.5.1 API 地址 #
https://api.deepseek.com14.5.2 LOGO #

14.5.3 API 密钥 #
sk-24088156e9ab48f3adddaf5a9c0c4ede14.5.4 获取密钥地址 #
https://platform.deepseek.com/api_keys14.5.5 模型名称 #
deepseek-chat
deepseek-reasoner15. 大模型列表 #
本节介绍如何通过 API 快速获取和管理系统内已注册的大模型(LLM)列表,便于查看和维护多模型接入情况。
能力说明
- 支持通过 RESTful API 查询全部已注册大模型配置,返回包含提供商、API 地址、密钥信息、支持的模型名称等关键信息的列表。
- 可在管理后台或集成环境中动态展示与维护已添加的大模型条目。
- 响应按新近添加顺序(id 倒序)排列,便于查看最近接入的模型。
API 说明
路由
GET /api/llm-models请求参数
无
返回结果
返回一个包含所有大模型信息的数组,每个模型字段说明如下:
| 字段 | 说明 |
|---|---|
| id | 记录自增ID |
| provider_name | 模型提供商名称 |
| provider_icon | 提供商 LOGO 图片 URL |
| api_base_url | API 基础地址 |
| api_key | 当前存储的 API 密钥 |
| api_key_url | 密钥申请网址 |
| model_names | 支持的模型名称数组 |
示例响应:
[
{
"id": 1,
"provider_name": "深度探索",
"provider_icon": "http://127.0.0.1:8000/uploads/08b305a11707447696aecc2040f96ecc.png",
"api_base_url": "https://api.deepseek.com",
"api_key": "sk-xxxxxxx",
"api_key_url": "https://platform.deepseek.com/api_keys",
"model_names": [
"deepseek-chat",
"deepseek-reasoner"
]
}
]数据获取范例
可通过如下 curl 命令请求大模型列表:
curl --location --request GET "http://127.0.0.1:8000/api/llm-models"返回内容即为模型数组,可直接用于前端界面展示或接口调用。
15.1. llm_repository.py #
app/repositories/llm_repository.py
# 导入SQLAlchemy的Session类,用于数据库会话
from sqlalchemy.orm import Session
+from sqlalchemy import select
# 从app包导入models模块,包含数据库模型
from app import models
# 从app包导入schemas模块,包含数据校验模型
from app import schemas
# 定义创建LLM模型记录的函数,接收数据库会话和入参数据,返回新建的LlmModel实例
def create_llm_model(db: Session, data: schemas.LlmModelCreate) -> models.LlmModel:
# 创建LlmModel数据库对象,剔除多余空白并赋值各字段
row = models.LlmModel(
provider_name=data.provider_name.strip(), # 提供商名称,去除首尾空白
provider_icon=data.provider_icon, # 提供商图标
api_base_url=data.api_base_url.strip(), # API基础地址,去除首尾空白
api_key=data.api_key.strip(), # API密钥,去除首尾空白
api_key_url=data.api_key_url.strip() if data.api_key_url else None, # API密钥申请地址,有则去除空白,无则为None
model_names=data.model_names, # LLM模型名称列表
)
# 添加新纪录到数据库session
db.add(row)
# 提交事务保存数据
db.commit()
# 刷新session获取新数据
db.refresh(row)
# 返回新建的LlmModel对象
return row
# 定义获取所有大模型记录的函数,传入数据库会话,返回LlmModel对象的列表
+def list_llm_models(db: Session) -> list[models.LlmModel]:
# 按id倒序查询所有LlmModel记录,转换为列表返回
+ return list(db.scalars(select(models.LlmModel).order_by(models.LlmModel.id.desc())).all())
15.2. llm_models.py #
app/routers/llm_models.py
# 导入urljoin用于拼接URL
from urllib.parse import urljoin
# 导入httpx库(可用于发送HTTP请求,本文件未用到)
import httpx
# 从fastapi库导入APIRouter用于创建路由,Depends用于依赖注入,HTTPException用于异常处理
from fastapi import APIRouter, Depends, HTTPException
# 从sqlalchemy导入IntegrityError,用于捕获唯一性冲突异常
from sqlalchemy.exc import IntegrityError
# 导入Session会话,用于数据库操作
from sqlalchemy.orm import Session
# 导入schemas定义的Pydantic模型
from app import schemas
# 导入llm_repository,用于大模型数据操作
from app.repositories import llm_repository
# 导入get_db获取数据库依赖
from app.database import get_db
# 从返回的payload中提取模型名称列表
def _extract_model_names(payload):
# 获取payload中的"data"字段
data = payload.get("data")
# 如果data不是列表,则直接返回空列表
if not isinstance(data, list):
return []
# 准备输出列表
out = []
# 用于去重的集合
seen = set()
# 遍历data中的每一项
for item in data:
# 如果当前项不是字典,跳过
if not isinstance(item, dict):
continue
# 取得"id"字段,转为字符串并去两端空白
name = str(item.get("id") or "").strip()
# 如果名称为空,则跳过
if not name:
continue
# 转为小写做去重key
key = name.lower()
# 如果已经出现过,跳过
if key in seen:
continue
# 添加到已见集合
seen.add(key)
# 添加到输出列表
out.append(name)
# 返回整理后的模型名称列表
return out
# 创建APIRouter实例,设置路由前缀与标签
router = APIRouter(prefix="/api/llm-models", tags=["llm-models"])
# 定义POST类型接口,路径为"",响应模型为LlmModelOut
@router.post("", response_model=schemas.LlmModelOut)
def create_model(payload: schemas.LlmModelCreate, db: Session = Depends(get_db)):
# 尝试创建大模型记录
try:
return llm_repository.create_llm_model(db, payload)
# 捕获唯一性冲突异常(例如Provider的名字已存在)
except IntegrityError:
db.rollback()
# 抛出HTTP 409异常,提示"提供商名称已存在"
raise HTTPException(status_code=409, detail="提供商名称已存在")
# 定义路由POST接口 /probe,返回值为LlmModelTestResult
@router.post("/probe", response_model=schemas.LlmModelTestResult)
# 定义测试大模型服务的函数
def test_model_service(payload: schemas.LlmModelTestRequest):
# 去除api_base_url首尾空白字符
base_url = payload.api_base_url.strip()
# 去除api_key首尾空白字符
api_key = payload.api_key.strip()
# 构造模型列表接口的URL
models_url = urljoin(base_url.rstrip("/") + "/", "models")
# 准备请求头,添加Authorization字段
headers = {
"Authorization": f"Bearer {api_key}",
}
try:
# 创建HTTP客户端,设置超时时间和请求头
with httpx.Client(timeout=20.0, headers=headers) as client:
# 发送GET请求获取模型列表
resp = client.get(models_url)
# 检查HTTP响应状态码,若有异常则抛出
resp.raise_for_status()
# 解析响应体为JSON
body = resp.json()
# 调用工具函数提取模型名称列表
names = _extract_model_names(body if isinstance(body, dict) else {})
# 如果检测到模型名称
if names:
# 返回检测通过且包含模型数量和名称的结果
return schemas.LlmModelTestResult(
ok=True,
message=f"模型服务检测通过,可用模型 {len(names)} 个",
models=names,
)
# 如果没检测到模型列表,返回通过但无模型名提示
return schemas.LlmModelTestResult(
ok=True,
message="模型服务检测通过,但未识别到模型列表",
models=[],
)
# 捕获请求超时异常,返回相应错误信息
except httpx.TimeoutException:
return schemas.LlmModelTestResult(ok=False, message="请求超时,请检查 API 地址")
# 捕获HTTP状态异常,返回状态码和响应内容前300字符
except httpx.HTTPStatusError as e:
return schemas.LlmModelTestResult(ok=False, message=f"HTTP {e.response.status_code}: {e.response.text[:300]}")
# 捕获所有其他异常,返回异常信息
except Exception as e: # noqa: BLE001
return schemas.LlmModelTestResult(ok=False, message=f"模型服务检测失败: {e}")
# 定义GET接口用于获取所有大模型列表,响应为LlmModelOut对象的列表
+@router.get("", response_model=list[schemas.LlmModelOut])
# 声明依赖注入数据库会话db
+def list_models(db: Session = Depends(get_db)):
# 调用llm_repository中的方法获取所有大模型数据
+ return llm_repository.list_llm_models(db) 15.3 测试 #
curl --location --request GET "http://127.0.0.1:8000/api/llm-models" ^
--header "Content-Type: application/json"16. 更新大语言模型 #
在本节中,我们将讲解如何通过接口更新已有大语言模型(LLM)的信息,实现供应商名称、图标、API 基础地址、API 密钥、密钥申请地址及支持的模型名称列表等字段的变更。
接口说明
接口路径:
PUT /api/llm-models/{llm_id}请求头:
Content-Type: application/json请求体参数(部分或全部可选):
| 字段名 | 类型 | 说明 |
|---|---|---|
| provider_name | string | 供应商名称,长度1-255,唯一。 |
| provider_icon | string | 供应商图标URL,可为空。 |
| api_base_url | string | API基础地址,长度1-1024 |
| api_key | string | API密钥,长度1-1024 |
| api_key_url | string | API密钥申请地址,最长1024,可为空。 |
| model_names | array of string | 支持的模型名称列表,可为空或不传,元素自动去重与去空白。 |
所有字段均可选,只有提交的字段才会被更新,未传递的字段保持原值不变。
响应结果
成功调用后返回更新后的大模型信息对象,结构与创建接口一致。
若出现以下情况,将返回对应错误码和消息:
- 404:记录不存在
- 409:供应商名称已被其他记录占用
示例
请求示例:
curl --location --request PUT "http://127.0.0.1:8000/api/llm-models/1" ^
--header "Content-Type: application/json" ^
--data-raw "{
\"provider_name\": \"深度探索4\",
\"provider_icon\": \"/uploads/08b305a11707447696aecc2040f96ecc.png\",
\"api_base_url\": \"https://api.deepseek.com\",
\"api_key\": \"sk-24088156e9ab48f3adddaf5a9c0c4ede\",
\"api_key_url\": \"https://platform.deepseek.com/api_keys2\",
\"model_names\": [
\"deepseek-chat\",
\"deepseek-reasoner\"
]
}"成功响应示例:
{
"id": 1,
"provider_name": "深度探索4",
"provider_icon": "/uploads/08b305a11707447696aecc2040f96ecc.png",
"api_base_url": "https://api.deepseek.com",
"api_key": "sk-24088156e9ab48f3adddaf5a9c0c4ede",
"api_key_url": "https://platform.deepseek.com/api_keys2",
"model_names": [
"deepseek-chat",
"deepseek-reasoner"
]
}常见错误:
{"detail":"记录不存在"}(修改不存在的ID){"detail":"提供商名称已被其他记录使用"}(名称冲突)
16.1. llm_repository.py #
app/repositories/llm_repository.py
# 导入SQLAlchemy的Session类,用于数据库会话
from sqlalchemy.orm import Session
from sqlalchemy import select
# 从app包导入models模块,包含数据库模型
from app import models
# 从app包导入schemas模块,包含数据校验模型
from app import schemas
# 定义创建LLM模型记录的函数,接收数据库会话和入参数据,返回新建的LlmModel实例
def create_llm_model(db: Session, data: schemas.LlmModelCreate) -> models.LlmModel:
# 创建LlmModel数据库对象,剔除多余空白并赋值各字段
row = models.LlmModel(
provider_name=data.provider_name.strip(), # 提供商名称,去除首尾空白
provider_icon=data.provider_icon, # 提供商图标
api_base_url=data.api_base_url.strip(), # API基础地址,去除首尾空白
api_key=data.api_key.strip(), # API密钥,去除首尾空白
api_key_url=data.api_key_url.strip() if data.api_key_url else None, # API密钥申请地址,有则去除空白,无则为None
model_names=data.model_names, # LLM模型名称列表
)
# 添加新纪录到数据库session
db.add(row)
# 提交事务保存数据
db.commit()
# 刷新session获取新数据
db.refresh(row)
# 返回新建的LlmModel对象
return row
# 定义获取所有大模型记录的函数,传入数据库会话,返回LlmModel对象的列表
def list_llm_models(db: Session) -> list[models.LlmModel]:
# 按id倒序查询所有LlmModel记录,转换为列表返回
return list(db.scalars(select(models.LlmModel).order_by(models.LlmModel.id.desc())).all())
# 根据指定的llm_id主键,从数据库中获取对应的LlmModel对象,无则返回None
+def get_llm_model(db: Session, llm_id: int) -> models.LlmModel | None:
+ return db.get(models.LlmModel, llm_id)
# 定义根据提供商名称获取对应 LlmModel 记录的函数
+def get_llm_by_provider_name(db: Session, provider_name: str) -> models.LlmModel | None:
# 构造查询,通过 provider_name 精确查找对应的 LlmModel,返回首个结果或 None
+ return db.scalar(select(models.LlmModel).where(models.LlmModel.provider_name == provider_name))
# 定义更新 LLM 模型记录的函数,接收数据库会话、待更新行和更新数据
+def update_llm_model(db: Session, row: models.LlmModel, data: schemas.LlmModelUpdate) -> models.LlmModel:
# 如果 provider_name 不为 None,则更新并去除首尾空白
+ if data.provider_name is not None:
+ row.provider_name = data.provider_name.strip()
# 如果 provider_icon 不为 None,则直接赋值
+ if data.provider_icon is not None:
+ row.provider_icon = data.provider_icon
# 如果 api_base_url 不为 None,则更新并去除首尾空白
+ if data.api_base_url is not None:
+ row.api_base_url = data.api_base_url.strip()
# 如果 api_key 不为 None,则更新并去除首尾空白
+ if data.api_key is not None:
+ row.api_key = data.api_key.strip()
# 如果 api_key_url 不为 None,则去除首尾空白,否则设为 None
+ if data.api_key_url is not None:
+ row.api_key_url = data.api_key_url.strip() if data.api_key_url else None
# 如果 model_names 不为 None,则进行赋值
+ if data.model_names is not None:
+ row.model_names = data.model_names
# 提交数据库事务
+ db.commit()
# 刷新 session 以获取最新数据
+ db.refresh(row)
# 返回更新后的模型对象
+ return row 16.2. llm_models.py #
app/routers/llm_models.py
# 导入urljoin用于拼接URL
from urllib.parse import urljoin
# 导入httpx库(可用于发送HTTP请求,本文件未用到)
import httpx
# 从fastapi库导入APIRouter用于创建路由,Depends用于依赖注入,HTTPException用于异常处理
from fastapi import APIRouter, Depends, HTTPException
# 从sqlalchemy导入IntegrityError,用于捕获唯一性冲突异常
from sqlalchemy.exc import IntegrityError
# 导入Session会话,用于数据库操作
from sqlalchemy.orm import Session
# 导入schemas定义的Pydantic模型
from app import schemas
# 导入llm_repository,用于大模型数据操作
from app.repositories import llm_repository
# 导入get_db获取数据库依赖
from app.database import get_db
# 从返回的payload中提取模型名称列表
def _extract_model_names(payload):
# 获取payload中的"data"字段
data = payload.get("data")
# 如果data不是列表,则直接返回空列表
if not isinstance(data, list):
return []
# 准备输出列表
out = []
# 用于去重的集合
seen = set()
# 遍历data中的每一项
for item in data:
# 如果当前项不是字典,跳过
if not isinstance(item, dict):
continue
# 取得"id"字段,转为字符串并去两端空白
name = str(item.get("id") or "").strip()
# 如果名称为空,则跳过
if not name:
continue
# 转为小写做去重key
key = name.lower()
# 如果已经出现过,跳过
if key in seen:
continue
# 添加到已见集合
seen.add(key)
# 添加到输出列表
out.append(name)
# 返回整理后的模型名称列表
return out
# 创建APIRouter实例,设置路由前缀与标签
router = APIRouter(prefix="/api/llm-models", tags=["llm-models"])
# 定义POST类型接口,路径为"",响应模型为LlmModelOut
@router.post("", response_model=schemas.LlmModelOut)
def create_model(payload: schemas.LlmModelCreate, db: Session = Depends(get_db)):
# 尝试创建大模型记录
try:
return llm_repository.create_llm_model(db, payload)
# 捕获唯一性冲突异常(例如Provider的名字已存在)
except IntegrityError:
db.rollback()
# 抛出HTTP 409异常,提示"提供商名称已存在"
raise HTTPException(status_code=409, detail="提供商名称已存在")
# 定义路由POST接口 /probe,返回值为LlmModelTestResult
@router.post("/probe", response_model=schemas.LlmModelTestResult)
# 定义测试大模型服务的函数
def test_model_service(payload: schemas.LlmModelTestRequest):
# 去除api_base_url首尾空白字符
base_url = payload.api_base_url.strip()
# 去除api_key首尾空白字符
api_key = payload.api_key.strip()
# 构造模型列表接口的URL
models_url = urljoin(base_url.rstrip("/") + "/", "models")
# 准备请求头,添加Authorization字段
headers = {
"Authorization": f"Bearer {api_key}",
}
try:
# 创建HTTP客户端,设置超时时间和请求头
with httpx.Client(timeout=20.0, headers=headers) as client:
# 发送GET请求获取模型列表
resp = client.get(models_url)
# 检查HTTP响应状态码,若有异常则抛出
resp.raise_for_status()
# 解析响应体为JSON
body = resp.json()
# 调用工具函数提取模型名称列表
names = _extract_model_names(body if isinstance(body, dict) else {})
# 如果检测到模型名称
if names:
# 返回检测通过且包含模型数量和名称的结果
return schemas.LlmModelTestResult(
ok=True,
message=f"模型服务检测通过,可用模型 {len(names)} 个",
models=names,
)
# 如果没检测到模型列表,返回通过但无模型名提示
return schemas.LlmModelTestResult(
ok=True,
message="模型服务检测通过,但未识别到模型列表",
models=[],
)
# 捕获请求超时异常,返回相应错误信息
except httpx.TimeoutException:
return schemas.LlmModelTestResult(ok=False, message="请求超时,请检查 API 地址")
# 捕获HTTP状态异常,返回状态码和响应内容前300字符
except httpx.HTTPStatusError as e:
return schemas.LlmModelTestResult(ok=False, message=f"HTTP {e.response.status_code}: {e.response.text[:300]}")
# 捕获所有其他异常,返回异常信息
except Exception as e: # noqa: BLE001
return schemas.LlmModelTestResult(ok=False, message=f"模型服务检测失败: {e}")
# 定义GET接口用于获取所有大模型列表,响应为LlmModelOut对象的列表
@router.get("", response_model=list[schemas.LlmModelOut])
# 声明依赖注入数据库会话db
def list_models(db: Session = Depends(get_db)):
# 调用llm_repository中的方法获取所有大模型数据
return llm_repository.list_llm_models(db)
# 定义PUT接口用于更新指定ID的大模型记录,响应为LlmModelOut对象
+@router.put("/{llm_id}", response_model=schemas.LlmModelOut)
# 定义视图函数,llm_id为要更新的模型ID,payload为更新内容,db为数据库会话(依赖注入)
+def update_model(llm_id: int, payload: schemas.LlmModelUpdate, db: Session = Depends(get_db)):
# 根据llm_id查询对应的模型记录
+ row = llm_repository.get_llm_model(db, llm_id)
# 如果没有查到该记录,则返回404异常
+ if not row:
+ raise HTTPException(status_code=404, detail="记录不存在")
# 如果提交的provider_name不为None,则检查唯一性
+ if payload.provider_name is not None:
# 根据去空格后的provider_name查找数据库中是否存在其他重名记录
+ other = llm_repository.get_llm_by_provider_name(db, payload.provider_name.strip())
# 如果找到的记录不是当前要更新的这条,则冲突
+ if other and other.id != llm_id:
+ raise HTTPException(status_code=409, detail="提供商名称已被其他记录使用")
+ try:
# 调用仓库层方法执行数据库更新,返回更新后的对象
+ return llm_repository.update_llm_model(db, row, payload)
+ except IntegrityError:
# 捕获唯一性约束冲突,回滚事务并返回409错误
+ db.rollback()
+ raise HTTPException(status_code=409, detail="提供商名称冲突") 16.3. schemas.py #
app/schemas.py
# 导入枚举类型
from enum import Enum
# 导入Any类型用于类型注解
from typing import Any
# 从pydantic导入BaseModel、Field和field_validator用于数据验证
from pydantic import BaseModel, Field, field_validator
# 定义MCP协议类型的枚举类
class McpProtocol(str, Enum):
# MCP 协议类型注释
"""MCP 协议类型"""
# stdio协议
stdio = "stdio"
# streamable-http协议
streamable_http = "streamable-http"
# sse协议
sse = "sse"
# 定义MCP服务基础模型
class McpServiceBase(BaseModel):
# MCP 服务基础模型注释
"""MCP 服务基础模型"""
# 服务名称,字符串类型,长度1-255
name: str = Field(..., min_length=1, max_length=255)
# 服务描述,可选字段,字符串或None
description: str | None = None
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,确保其为字典类型
@field_validator("config", mode="before")
@classmethod
def config_not_empty(cls, v: Any) -> Any:
# 验证 config 是否为 JSON 对象,如果不是字典则抛出异常
"""验证 config 是否为 JSON 对象"""
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
return v
# 定义MCP服务创建模型,继承自McpServiceBase
class McpServiceCreate(McpServiceBase):
# 不增加额外内容,直接继承
pass
# 定义MCP服务输出模型
class McpServiceOut(BaseModel):
# ID字段,整型
id: int
# 服务名称
name: str
# 服务描述,可选字段
description: str | None
# 协议类型,字符串
protocol: str
# 配置信息,字典类型
config: dict[str, Any]
# 设置模型配置,允许从ORM对象属性读取数据
model_config = {"from_attributes": True}
# 定义McpServiceUpdate模型,用于更新MCP服务,支持部分字段可选更新
class McpServiceUpdate(BaseModel):
# 服务名称,允许为None,最小长度1,最大长度255
name: str | None = Field(None, min_length=1, max_length=255)
# 服务描述字段,允许为None
description: str | None = None
# 协议类型,使用McpProtocol枚举,允许为None
protocol: McpProtocol | None = None
# 配置信息,允许为None,类型为字典
config: dict[str, Any] | None = None
# 定义McpTestRequest模型,继承自BaseModel
class McpTestRequest(BaseModel):
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,在赋值前进行校验
@field_validator("config", mode="before")
@classmethod
def config_is_object(cls, v: Any) -> Any:
# 如果config不是字典类型,则抛出异常
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
# 返回config原值
return v
# 定义McpTestResult模型,继承自BaseModel
class McpTestResult(BaseModel):
# 测试是否成功的标志
ok: bool
# 返回的信息或说明
message: str
# 工具列表,默认为空列表
tools: list[dict[str, Any]] = Field(default_factory=list)
# 定义LlmModelBase基类,用于存储大模型相关信息
class LlmModelBase(BaseModel):
# 提供商名称,必填,最小长度1,最大长度255
provider_name: str = Field(..., min_length=1, max_length=255)
# 提供商图标,可选
provider_icon: str | None = None
# API基础地址,必填,最小长度1,最大长度1024
api_base_url: str = Field(..., min_length=1, max_length=1024)
# API密钥,必填,最小长度1,最大长度1024
api_key: str = Field(..., min_length=1, max_length=1024)
# API密钥申请地址,可选,最大长度1024
api_key_url: str | None = Field(None, max_length=1024)
# 模型名称列表,默认为空列表
model_names: list[str] = Field(default_factory=list)
# 对model_names字段进行校验和归一化处理
@field_validator("model_names")
@classmethod
def normalize_model_names(cls, v: list[str]) -> list[str]:
# 定义输出列表
out: list[str] = []
# 用于存储已经出现过的模型名称(小写形式)以去重
seen: set[str] = set()
# 遍历输入值,如果为空则用空列表
for item in v or []:
# 将每个名称转为字符串并去除首尾空格
name = str(item or "").strip()
# 如果名称为空,则跳过
if not name:
continue
# 转为小写做去重key
key = name.lower()
# 如果已经见过该名称,则跳过
if key in seen:
continue
# 添加到已见集合和输出列表
seen.add(key)
out.append(name)
# 返回归一化后的模型名称列表
return out
# 定义LlmModelCreate模型,继承自LlmModelBase,没有扩展字段
class LlmModelCreate(LlmModelBase):
pass
# 定义 LlmModelOut 类,继承自 BaseModel,用于大模型的返回数据结构
class LlmModelOut(BaseModel):
# 唯一ID,类型为整数
id: int
# 提供商名称,字符串类型
provider_name: str
# 提供商图标,可选,字符串或None
provider_icon: str | None
# API 基础地址,字符串类型
api_base_url: str
# API 密钥,字符串类型
api_key: str
# API 密钥申请地址,可选,字符串或None
api_key_url: str | None
# 模型名称列表,字符串列表类型
model_names: list[str]
# Pydantic 配置,启用从属性赋值(用于ORM模式)
model_config = {"from_attributes": True}
# 定义用于大模型服务测试请求体的Pydantic模型
class LlmModelTestRequest(BaseModel):
# API基础地址,必须为非空字符串,最大长度1024
api_base_url: str = Field(..., min_length=1, max_length=1024)
# API密钥,必须为非空字符串,最大长度1024
api_key: str = Field(..., min_length=1, max_length=1024)
# 定义 LlmModelTestResult 类,继承自 BaseModel,用于返回大模型服务检测的结果
class LlmModelTestResult(BaseModel):
# 检测是否通过,布尔类型
ok: bool
# 检测的信息提示,字符串类型
message: str
# 检测到的模型名称列表,默认为空列表
models: list[str] = Field(default_factory=list)
# 定义用于更新大模型信息的Pydantic模型
+class LlmModelUpdate(BaseModel):
# 提供商名称,可选字段,限定最小长度1,最大长度255
+ provider_name: str | None = Field(None, min_length=1, max_length=255)
# 提供商图标,可选字段
+ provider_icon: str | None = None
# API 基础地址,可选字段,限定最小长度1,最大长度1024
+ api_base_url: str | None = Field(None, min_length=1, max_length=1024)
# API 密钥,可选字段,限定最小长度1,最大长度1024
+ api_key: str | None = Field(None, min_length=1, max_length=1024)
# API 密钥申请地址,可选字段,最大长度1024
+ api_key_url: str | None = Field(None, max_length=1024)
# 模型名称列表,可选字段
+ model_names: list[str] | None = None
# 对model_names字段做校验和归一化处理
+ @field_validator("model_names")
+ @classmethod
+ def normalize_model_names_optional(cls, v: list[str] | None) -> list[str] | None:
# 如果为None,直接返回None
+ if v is None:
+ return None
# 用于存储归一化后的模型名称
+ out: list[str] = []
# 用于去重
+ seen: set[str] = set()
# 遍历输入列表
+ for item in v:
# 转为字符串并去除首尾空白
+ name = str(item or "").strip()
# 如果名称为空则跳过
+ if not name:
+ continue
# 使用小写做去重key
+ key = name.lower()
# 如果已经见过则跳过
+ if key in seen:
+ continue
# 添加进已见集合
+ seen.add(key)
# 添加进输出列表
+ out.append(name)
# 返回归一化去重后的列表
+ return out 16.4.测试 #
curl --location --request PUT "http://127.0.0.1:8000/api/llm-models/1" ^
--header "Content-Type: application/json" ^
--data-raw "{ \"provider_name\": \"深度探索4\", \"provider_icon\": \"/uploads/08b305a11707447696aecc2040f96ecc.png\", \"api_base_url\": \"https://api.deepseek.com\", \"api_key\": \"sk-24088156e9ab48f3adddaf5a9c0c4ede\", \"api_key_url\": \"https://platform.deepseek.com/api_keys2\", \"model_names\": [ \"deepseek-chat\", \"deepseek-reasoner\" ]}"17. 删除大语言模型 #
本节将介绍如何通过 API 删除已有的大语言模型(LLM)记录,包括相关接口说明、使用方法和注意事项。
后端实现说明
删除大语言模型的实现涉及到两部分代码改动:
仓库层(app/repositories/llm_repository.py)
新增delete_llm_model方法,传入数据库会话和要删除的模型对象,完成模型删除和事务提交。路由层(app/routers/llm_models.py)
新增DELETE /api/llm-models/{llm_id}接口:- 首先根据
llm_id查询待删除对象; - 如果对象不存在,返回 404;
- 调用仓库删除方法成功后返回
{ "ok": true }。
- 首先根据
接口使用方法
请求说明
- 接口路径:
DELETE /api/llm-models/{llm_id} - 路径参数:
llm_id需要删除的模型ID - 请求体:无
- 响应结果:
- 删除成功:
{ "ok": true } - 若ID不存在:
HTTP 404 Not Found
- 删除成功:
示例命令
curl --location --request DELETE "http://127.0.0.1:8000/api/llm-models/2" ^
--header "Content-Type: application/json"可能返回值
成功
{ "ok": true }记录不存在
{ "detail": "记录不存在" }
17.1. llm_repository.py #
app/repositories/llm_repository.py
# 导入SQLAlchemy的Session类,用于数据库会话
from sqlalchemy.orm import Session
from sqlalchemy import select
# 从app包导入models模块,包含数据库模型
from app import models
# 从app包导入schemas模块,包含数据校验模型
from app import schemas
# 定义创建LLM模型记录的函数,接收数据库会话和入参数据,返回新建的LlmModel实例
def create_llm_model(db: Session, data: schemas.LlmModelCreate) -> models.LlmModel:
# 创建LlmModel数据库对象,剔除多余空白并赋值各字段
row = models.LlmModel(
provider_name=data.provider_name.strip(), # 提供商名称,去除首尾空白
provider_icon=data.provider_icon, # 提供商图标
api_base_url=data.api_base_url.strip(), # API基础地址,去除首尾空白
api_key=data.api_key.strip(), # API密钥,去除首尾空白
api_key_url=data.api_key_url.strip() if data.api_key_url else None, # API密钥申请地址,有则去除空白,无则为None
model_names=data.model_names, # LLM模型名称列表
)
# 添加新纪录到数据库session
db.add(row)
# 提交事务保存数据
db.commit()
# 刷新session获取新数据
db.refresh(row)
# 返回新建的LlmModel对象
return row
# 定义获取所有大模型记录的函数,传入数据库会话,返回LlmModel对象的列表
def list_llm_models(db: Session) -> list[models.LlmModel]:
# 按id倒序查询所有LlmModel记录,转换为列表返回
return list(db.scalars(select(models.LlmModel).order_by(models.LlmModel.id.desc())).all())
# 根据指定的llm_id主键,从数据库中获取对应的LlmModel对象,无则返回None
def get_llm_model(db: Session, llm_id: int) -> models.LlmModel | None:
return db.get(models.LlmModel, llm_id)
# 定义根据提供商名称获取对应 LlmModel 记录的函数
def get_llm_by_provider_name(db: Session, provider_name: str) -> models.LlmModel | None:
# 构造查询,通过 provider_name 精确查找对应的 LlmModel,返回首个结果或 None
return db.scalar(select(models.LlmModel).where(models.LlmModel.provider_name == provider_name))
# 定义更新 LLM 模型记录的函数,接收数据库会话、待更新行和更新数据
def update_llm_model(db: Session, row: models.LlmModel, data: schemas.LlmModelUpdate) -> models.LlmModel:
# 如果 provider_name 不为 None,则更新并去除首尾空白
if data.provider_name is not None:
row.provider_name = data.provider_name.strip()
# 如果 provider_icon 不为 None,则直接赋值
if data.provider_icon is not None:
row.provider_icon = data.provider_icon
# 如果 api_base_url 不为 None,则更新并去除首尾空白
if data.api_base_url is not None:
row.api_base_url = data.api_base_url.strip()
# 如果 api_key 不为 None,则更新并去除首尾空白
if data.api_key is not None:
row.api_key = data.api_key.strip()
# 如果 api_key_url 不为 None,则去除首尾空白,否则设为 None
if data.api_key_url is not None:
row.api_key_url = data.api_key_url.strip() if data.api_key_url else None
# 如果 model_names 不为 None,则进行赋值
if data.model_names is not None:
row.model_names = data.model_names
# 提交数据库事务
db.commit()
# 刷新 session 以获取最新数据
db.refresh(row)
# 返回更新后的模型对象
return row
# 定义删除 LLM 模型记录的函数,接收数据库会话和要删除的模型行
+def delete_llm_model(db: Session, row: models.LlmModel) -> None:
# 从数据库会话中删除指定的模型对象
+ db.delete(row)
# 提交事务,执行删除操作
+ db.commit()
17.2. llm_models.py #
app/routers/llm_models.py
# 导入urljoin用于拼接URL
from urllib.parse import urljoin
# 导入httpx库(可用于发送HTTP请求,本文件未用到)
import httpx
# 从fastapi库导入APIRouter用于创建路由,Depends用于依赖注入,HTTPException用于异常处理
from fastapi import APIRouter, Depends, HTTPException
# 从sqlalchemy导入IntegrityError,用于捕获唯一性冲突异常
from sqlalchemy.exc import IntegrityError
# 导入Session会话,用于数据库操作
from sqlalchemy.orm import Session
# 导入schemas定义的Pydantic模型
from app import schemas
# 导入llm_repository,用于大模型数据操作
from app.repositories import llm_repository
# 导入get_db获取数据库依赖
from app.database import get_db
# 从返回的payload中提取模型名称列表
def _extract_model_names(payload):
# 获取payload中的"data"字段
data = payload.get("data")
# 如果data不是列表,则直接返回空列表
if not isinstance(data, list):
return []
# 准备输出列表
out = []
# 用于去重的集合
seen = set()
# 遍历data中的每一项
for item in data:
# 如果当前项不是字典,跳过
if not isinstance(item, dict):
continue
# 取得"id"字段,转为字符串并去两端空白
name = str(item.get("id") or "").strip()
# 如果名称为空,则跳过
if not name:
continue
# 转为小写做去重key
key = name.lower()
# 如果已经出现过,跳过
if key in seen:
continue
# 添加到已见集合
seen.add(key)
# 添加到输出列表
out.append(name)
# 返回整理后的模型名称列表
return out
# 创建APIRouter实例,设置路由前缀与标签
router = APIRouter(prefix="/api/llm-models", tags=["llm-models"])
# 定义POST类型接口,路径为"",响应模型为LlmModelOut
@router.post("", response_model=schemas.LlmModelOut)
def create_model(payload: schemas.LlmModelCreate, db: Session = Depends(get_db)):
# 尝试创建大模型记录
try:
return llm_repository.create_llm_model(db, payload)
# 捕获唯一性冲突异常(例如Provider的名字已存在)
except IntegrityError:
db.rollback()
# 抛出HTTP 409异常,提示"提供商名称已存在"
raise HTTPException(status_code=409, detail="提供商名称已存在")
# 定义路由POST接口 /probe,返回值为LlmModelTestResult
@router.post("/probe", response_model=schemas.LlmModelTestResult)
# 定义测试大模型服务的函数
def test_model_service(payload: schemas.LlmModelTestRequest):
# 去除api_base_url首尾空白字符
base_url = payload.api_base_url.strip()
# 去除api_key首尾空白字符
api_key = payload.api_key.strip()
# 构造模型列表接口的URL
models_url = urljoin(base_url.rstrip("/") + "/", "models")
# 准备请求头,添加Authorization字段
headers = {
"Authorization": f"Bearer {api_key}",
}
try:
# 创建HTTP客户端,设置超时时间和请求头
with httpx.Client(timeout=20.0, headers=headers) as client:
# 发送GET请求获取模型列表
resp = client.get(models_url)
# 检查HTTP响应状态码,若有异常则抛出
resp.raise_for_status()
# 解析响应体为JSON
body = resp.json()
# 调用工具函数提取模型名称列表
names = _extract_model_names(body if isinstance(body, dict) else {})
# 如果检测到模型名称
if names:
# 返回检测通过且包含模型数量和名称的结果
return schemas.LlmModelTestResult(
ok=True,
message=f"模型服务检测通过,可用模型 {len(names)} 个",
models=names,
)
# 如果没检测到模型列表,返回通过但无模型名提示
return schemas.LlmModelTestResult(
ok=True,
message="模型服务检测通过,但未识别到模型列表",
models=[],
)
# 捕获请求超时异常,返回相应错误信息
except httpx.TimeoutException:
return schemas.LlmModelTestResult(ok=False, message="请求超时,请检查 API 地址")
# 捕获HTTP状态异常,返回状态码和响应内容前300字符
except httpx.HTTPStatusError as e:
return schemas.LlmModelTestResult(ok=False, message=f"HTTP {e.response.status_code}: {e.response.text[:300]}")
# 捕获所有其他异常,返回异常信息
except Exception as e: # noqa: BLE001
return schemas.LlmModelTestResult(ok=False, message=f"模型服务检测失败: {e}")
# 定义GET接口用于获取所有大模型列表,响应为LlmModelOut对象的列表
@router.get("", response_model=list[schemas.LlmModelOut])
# 声明依赖注入数据库会话db
def list_models(db: Session = Depends(get_db)):
# 调用llm_repository中的方法获取所有大模型数据
return llm_repository.list_llm_models(db)
# 定义PUT接口用于更新指定ID的大模型记录,响应为LlmModelOut对象
@router.put("/{llm_id}", response_model=schemas.LlmModelOut)
# 定义视图函数,llm_id为要更新的模型ID,payload为更新内容,db为数据库会话(依赖注入)
def update_model(llm_id: int, payload: schemas.LlmModelUpdate, db: Session = Depends(get_db)):
# 根据llm_id查询对应的模型记录
row = llm_repository.get_llm_model(db, llm_id)
# 如果没有查到该记录,则返回404异常
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 如果提交的provider_name不为None,则检查唯一性
if payload.provider_name is not None:
# 根据去空格后的provider_name查找数据库中是否存在其他重名记录
other = llm_repository.get_llm_by_provider_name(db, payload.provider_name.strip())
# 如果找到的记录不是当前要更新的这条,则冲突
if other and other.id != llm_id:
raise HTTPException(status_code=409, detail="提供商名称已被其他记录使用")
try:
# 调用仓库层方法执行数据库更新,返回更新后的对象
return llm_repository.update_llm_model(db, row, payload)
except IntegrityError:
# 捕获唯一性约束冲突,回滚事务并返回409错误
db.rollback()
raise HTTPException(status_code=409, detail="提供商名称冲突")
# 定义DELETE接口用于删除指定ID的大模型记录
+@router.delete("/{llm_id}")
# 定义删除大模型的函数,接收llm_id参数和数据库会话(依赖注入)
+def delete_model(llm_id: int, db: Session = Depends(get_db)):
# 根据传入的llm_id从数据库查询对应的模型记录
+ row = llm_repository.get_llm_model(db, llm_id)
# 如果未查到该记录,则抛出404 Not Found异常
+ if not row:
+ raise HTTPException(status_code=404, detail="记录不存在")
# 调用仓库层方法删除该模型记录
+ llm_repository.delete_llm_model(db, row)
# 删除成功后返回ok为True的响应
+ return {"ok": True} 17.3 测试 #
curl --location --request DELETE "http://127.0.0.1:8000/api/llm-models/2" ^
--header "Content-Type: application/json"18. 添加查看智能体 #
本节介绍如何在平台中查看已创建的所有智能体(Agent)信息。
功能说明
“查看智能体”功能允许你通过接口一次性获取所有已注册的智能体(如旅游规划助手等)的详细配置信息,便于管理、展示或调试。返回结果包括每个智能体的基础属性(如ID、头像、名称、描述、引擎型号、提示词、参数配置等),适用于管理后台展示或前端下拉选择。
主要请求接口
接口路径:GET /api/agents
主要用途:
- 获取所有智能体(Agent)配置信息列表
- 可用于智能体管理页面、表单智能体选择、参数预览等场景
返回示例
每个智能体返回如下示例字段(字段类型见 AgentOut 响应模型,可根据实际数据库表结构自动序列化):
[
{
"id": 1,
"avatar": "/uploads/b20afbe9728742579b53d003ab7a7008.png",
"name": "旅游规划智能体",
"description": "面向中文出行场景的一站式旅游规划助手:可先问清出发地、目的地、行程与偏好...",
"opening_message": "你好,我是你的旅游规划智能体 👋 ...",
"system_prompt": "你是“旅游规划智能体”,目标是帮助用户完成从“去哪玩”到“怎么去、天气如何、是否适合出行”的一站式决策...",
"llm_provider_name": "深度探索",
"llm_model_name": "deepseek-chat",
"mcp_service_ids": [5,4,3],
"ask_prompt_template": "你是一名资深中文旅行规划师。请基于以下用户信息,输出一份可执行、具体、务实的旅行方案...",
"ask_variables": [
{"key": "出发地", "label": "出发地", "question": "你从哪个城市出发?", "required": true, "default": ""},
{"key": "目的地", "label": "目的地", "question": "你要去哪个城市?", "required": true, "default": ""},
...
]
},
...
]使用方法
- 前端或测试工具(如 curl、Postman)直接向该接口发送 GET 请求。
- 服务端返回所有已注册智能体的完整配置信息。
- 可结合分页、筛选等业务需求进行扩展。
示例:
curl --location --request GET "http://127.0.0.1:8000/api/agents" ^
--header "Content-Type: application/json"预期响应
- 返回 200,数据为 AgentOut 类型的列表
- 每个智能体均为一条完整的 JSON 记录
相关代码说明
- 智能体管理相关的数据访问(如 list_agents)已在
app/repositories/agent_repository.py实现。 - 路由接口部分见
app/routers/agents.py文件。 - 响应序列化结构见
schemas.AgentOut,可根据业务需求灵活调整。
18.1. agent_repository.py #
app/repositories/agent_repository.py
# 导入Session类用于数据库会话管理
from sqlalchemy.orm import Session
from sqlalchemy import select
# 从app模块导入models和schemas
from app import models
from app import schemas
# 定义创建Agent的函数,参数为数据库会话和Agent创建数据,返回创建后的Agent模型
def create_agent(db: Session, data: schemas.AgentCreate) -> models.Agent:
# 创建Agent模型实例,将前端传入的数据赋值到各个字段
row = models.Agent(
avatar=data.avatar, # 头像
name=data.name.strip(), # 名称,去除首尾空格
description=data.description, # 描述
opening_message=(data.opening_message or "").strip() or None, # 开场白,去除空格,允许为None
system_prompt=data.system_prompt.strip(), # 系统提示,去除空格
llm_provider_name=data.llm_provider_name.strip(), # LLM提供商名称,去除空格
llm_model_name=data.llm_model_name.strip(), # LLM模型名称,去除空格
mcp_service_ids=data.mcp_service_ids, # 关联MCP服务ID列表
ask_prompt_template=(data.ask_prompt_template or "").strip() or None, # 提问提示词模板,去除空格,允许为None
ask_variables=data.ask_variables or [], # 提问变量,默认为空列表
)
# 将Agent实例添加到数据库会话
db.add(row)
# 提交事务,保存到数据库
db.commit()
# 刷新实例,获取数据库生成的最新字段(如自增ID)
db.refresh(row)
# 返回创建后的Agent对象
return row
# 定义list_agents函数,接收一个数据库会话db作为参数,返回Agent对象列表
def list_agents(db: Session) -> list[models.Agent]:
# 使用select语句查询Agent表,并按id倒序排序,获取所有Agent对象
return list(db.scalars(select(models.Agent).order_by(models.Agent.id.desc())).all()) 18.2. agents.py #
app/routers/agents.py
# 从 fastapi 导入 APIRouter, Depends 以及 HTTPException 异常
from fastapi import APIRouter, Depends, HTTPException
# 从 sqlalchemy.orm 导入 Session 会话对象
from sqlalchemy.orm import Session
# 从 app 包分别导入 schemas 模块
from app import schemas
# 导入数据库依赖获取函数
from app.database import get_db
# 导入 agent_repository、llm_repository 和 mcp_repository,分别处理不同的数据操作
from app.repositories import agent_repository, llm_repository, mcp_repository
# 创建一个 APIRouter 实例,设置路由的前缀和标签
router = APIRouter(prefix="/api/agents", tags=["agents"])
# 校验相关引用有效性(大模型提供商、模型、MCP 服务)
def _validate_refs(db: Session, provider_name: str, model_name: str, mcp_service_ids: list[int]) -> None:
# 根据提供商名称查询 LLM 提供商数据
llm_row = llm_repository.get_llm_by_provider_name(db, provider_name.strip())
# 如果没有找到对应的 LLM 提供商则抛出 HTTP 400 异常
if not llm_row:
raise HTTPException(status_code=400, detail=f"大语言模型提供商不存在: {provider_name}")
# 获取该提供商下的所有模型名,并做字符串清洗
llm_models = [str(m or "").strip() for m in (llm_row.model_names or [])]
# 如果传入的模型名称不属于当前提供商的模型,则抛出 HTTP 400 异常
if model_name.strip() not in llm_models:
raise HTTPException(status_code=400, detail=f"模型名称不属于提供商 {provider_name}: {model_name}")
# 遍历所有 mcp_service_ids
for sid in mcp_service_ids:
# 校验每一个 mcp_service 是否存在,如果不存在则抛出 400 异常
if not mcp_repository.get_mcp_service(db, int(sid)):
raise HTTPException(status_code=400, detail=f"MCP 服务不存在: {sid}")
# 定义创建 agent 的接口,POST 请求,响应体为 AgentOut 模型
@router.post("", response_model=schemas.AgentOut)
def create_agent(payload: schemas.AgentCreate, db: Session = Depends(get_db)):
# 调用 _validate_refs 校验 Agent 创建时关联的 LLM 提供商、模型、MCP 服务是否有效
_validate_refs(db, payload.llm_provider_name, payload.llm_model_name, payload.mcp_service_ids)
# 校验通过后,调用 agent_repository 创建新的 Agent,并返回创建结果
return agent_repository.create_agent(db, payload)
# 定义 GET 接口用于获取所有 Agent 列表,响应为 AgentOut 对象列表
@router.get("", response_model=list[schemas.AgentOut])
# 定义 list_agents 视图函数,依赖注入数据库会话 db
def list_agents(db: Session = Depends(get_db)):
# 调用 agent_repository 的 list_agents 方法,获取所有 Agent 数据
return agent_repository.list_agents(db) 18.3. main.py #
app/main.py
# 导入FastAPI库
from fastapi import FastAPI
# 导入logging库以便后续日志记录
import logging
# 导入asynccontextmanager用于异步上下文管理器
from contextlib import asynccontextmanager
# 导入FastAPI的静态文件中间件,用于静态文件服务
from fastapi.staticfiles import StaticFiles
# 导入FastAPI的CORS中间件,用于跨域资源共享
from fastapi.middleware.cors import CORSMiddleware
# 导入项目配置settings对象
from app.config import settings
# 从app.database模块导入Base和engine,用于数据库相关操作
from app.database import Base, engine
# 导入app.models模块下的所有内容(类、函数等)
from app.models import *
# 导入MCP服务路由 和 大模型路由
+from app.routers import mcp_services,llm_models,uploads,agents
# 配置日志的基本设置,日志级别为INFO
logging.basicConfig(level=logging.INFO)
# 定义一个异步上下文管理器,用于FastAPI生命周期
@asynccontextmanager
async def lifespan(app: FastAPI):
# 创建所有数据库表结构(如果未存在则自动创建)
Base.metadata.create_all(bind=engine)
# 通过yield挂起,等待应用关闭时进行清理
yield
# 创建FastAPI应用实例,设置API标题和版本号,并指定生命周期管理器
app = FastAPI(title="智能体服务", version="0.1.0", lifespan=lifespan)
# 解析配置中的CORS来源列表,去除空白项和空字符串
origins = [o.strip() for o in settings.cors_origins.split(",") if o.strip()]
# 向FastAPI应用添加CORS中间件
app.add_middleware(
CORSMiddleware,
# 允许访问的来源列表,如果为空则允许所有来源("*")
allow_origins=origins or ["*"],
# 允许携带cookie等凭证
allow_credentials=True,
# 允许所有HTTP方法
allow_methods=["*"],
# 允许所有HTTP头
allow_headers=["*"],
)
# 包含MCP服务路由
app.include_router(mcp_services.router)
app.include_router(llm_models.router)
app.include_router(uploads.router)
+app.include_router(agents.router)
upload_root = settings.upload_path()
upload_root.mkdir(parents=True, exist_ok=True)
app.mount("/uploads", StaticFiles(directory=str(upload_root)), name="uploads")
# 定义一个GET类型的/health路由用于健康检查
@app.get("/health")
def health():
# 返回服务状态为ok的JSON响应
return {"status": "ok"}18.4. schemas.py #
app/schemas.py
# 导入枚举类型
from enum import Enum
# 导入Any类型用于类型注解
from typing import Any
# 从pydantic导入BaseModel、Field和field_validator用于数据验证
from pydantic import BaseModel, Field, field_validator
# 定义MCP协议类型的枚举类
class McpProtocol(str, Enum):
# MCP 协议类型注释
"""MCP 协议类型"""
# stdio协议
stdio = "stdio"
# streamable-http协议
streamable_http = "streamable-http"
# sse协议
sse = "sse"
# 定义MCP服务基础模型
class McpServiceBase(BaseModel):
# MCP 服务基础模型注释
"""MCP 服务基础模型"""
# 服务名称,字符串类型,长度1-255
name: str = Field(..., min_length=1, max_length=255)
# 服务描述,可选字段,字符串或None
description: str | None = None
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,确保其为字典类型
@field_validator("config", mode="before")
@classmethod
def config_not_empty(cls, v: Any) -> Any:
# 验证 config 是否为 JSON 对象,如果不是字典则抛出异常
"""验证 config 是否为 JSON 对象"""
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
return v
# 定义MCP服务创建模型,继承自McpServiceBase
class McpServiceCreate(McpServiceBase):
# 不增加额外内容,直接继承
pass
# 定义MCP服务输出模型
class McpServiceOut(BaseModel):
# ID字段,整型
id: int
# 服务名称
name: str
# 服务描述,可选字段
description: str | None
# 协议类型,字符串
protocol: str
# 配置信息,字典类型
config: dict[str, Any]
# 设置模型配置,允许从ORM对象属性读取数据
model_config = {"from_attributes": True}
# 定义McpServiceUpdate模型,用于更新MCP服务,支持部分字段可选更新
class McpServiceUpdate(BaseModel):
# 服务名称,允许为None,最小长度1,最大长度255
name: str | None = Field(None, min_length=1, max_length=255)
# 服务描述字段,允许为None
description: str | None = None
# 协议类型,使用McpProtocol枚举,允许为None
protocol: McpProtocol | None = None
# 配置信息,允许为None,类型为字典
config: dict[str, Any] | None = None
# 定义McpTestRequest模型,继承自BaseModel
class McpTestRequest(BaseModel):
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,在赋值前进行校验
@field_validator("config", mode="before")
@classmethod
def config_is_object(cls, v: Any) -> Any:
# 如果config不是字典类型,则抛出异常
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
# 返回config原值
return v
# 定义McpTestResult模型,继承自BaseModel
class McpTestResult(BaseModel):
# 测试是否成功的标志
ok: bool
# 返回的信息或说明
message: str
# 工具列表,默认为空列表
tools: list[dict[str, Any]] = Field(default_factory=list)
# 定义LlmModelBase基类,用于存储大模型相关信息
class LlmModelBase(BaseModel):
# 提供商名称,必填,最小长度1,最大长度255
provider_name: str = Field(..., min_length=1, max_length=255)
# 提供商图标,可选
provider_icon: str | None = None
# API基础地址,必填,最小长度1,最大长度1024
api_base_url: str = Field(..., min_length=1, max_length=1024)
# API密钥,必填,最小长度1,最大长度1024
api_key: str = Field(..., min_length=1, max_length=1024)
# API密钥申请地址,可选,最大长度1024
api_key_url: str | None = Field(None, max_length=1024)
# 模型名称列表,默认为空列表
model_names: list[str] = Field(default_factory=list)
# 对model_names字段进行校验和归一化处理
@field_validator("model_names")
@classmethod
def normalize_model_names(cls, v: list[str]) -> list[str]:
# 定义输出列表
out: list[str] = []
# 用于存储已经出现过的模型名称(小写形式)以去重
seen: set[str] = set()
# 遍历输入值,如果为空则用空列表
for item in v or []:
# 将每个名称转为字符串并去除首尾空格
name = str(item or "").strip()
# 如果名称为空,则跳过
if not name:
continue
# 转为小写做去重key
key = name.lower()
# 如果已经见过该名称,则跳过
if key in seen:
continue
# 添加到已见集合和输出列表
seen.add(key)
out.append(name)
# 返回归一化后的模型名称列表
return out
# 定义LlmModelCreate模型,继承自LlmModelBase,没有扩展字段
class LlmModelCreate(LlmModelBase):
pass
# 定义 LlmModelOut 类,继承自 BaseModel,用于大模型的返回数据结构
class LlmModelOut(BaseModel):
# 唯一ID,类型为整数
id: int
# 提供商名称,字符串类型
provider_name: str
# 提供商图标,可选,字符串或None
provider_icon: str | None
# API 基础地址,字符串类型
api_base_url: str
# API 密钥,字符串类型
api_key: str
# API 密钥申请地址,可选,字符串或None
api_key_url: str | None
# 模型名称列表,字符串列表类型
model_names: list[str]
# Pydantic 配置,启用从属性赋值(用于ORM模式)
model_config = {"from_attributes": True}
# 定义用于大模型服务测试请求体的Pydantic模型
class LlmModelTestRequest(BaseModel):
# API基础地址,必须为非空字符串,最大长度1024
api_base_url: str = Field(..., min_length=1, max_length=1024)
# API密钥,必须为非空字符串,最大长度1024
api_key: str = Field(..., min_length=1, max_length=1024)
# 定义 LlmModelTestResult 类,继承自 BaseModel,用于返回大模型服务检测的结果
class LlmModelTestResult(BaseModel):
# 检测是否通过,布尔类型
ok: bool
# 检测的信息提示,字符串类型
message: str
# 检测到的模型名称列表,默认为空列表
models: list[str] = Field(default_factory=list)
# 定义用于更新大模型信息的Pydantic模型
class LlmModelUpdate(BaseModel):
# 提供商名称,可选字段,限定最小长度1,最大长度255
provider_name: str | None = Field(None, min_length=1, max_length=255)
# 提供商图标,可选字段
provider_icon: str | None = None
# API 基础地址,可选字段,限定最小长度1,最大长度1024
api_base_url: str | None = Field(None, min_length=1, max_length=1024)
# API 密钥,可选字段,限定最小长度1,最大长度1024
api_key: str | None = Field(None, min_length=1, max_length=1024)
# API 密钥申请地址,可选字段,最大长度1024
api_key_url: str | None = Field(None, max_length=1024)
# 模型名称列表,可选字段
model_names: list[str] | None = None
# 对model_names字段做校验和归一化处理
@field_validator("model_names")
@classmethod
def normalize_model_names_optional(cls, v: list[str] | None) -> list[str] | None:
# 如果为None,直接返回None
if v is None:
return None
# 用于存储归一化后的模型名称
out: list[str] = []
# 用于去重
seen: set[str] = set()
# 遍历输入列表
for item in v:
# 转为字符串并去除首尾空白
name = str(item or "").strip()
# 如果名称为空则跳过
if not name:
continue
# 使用小写做去重key
key = name.lower()
# 如果已经见过则跳过
if key in seen:
continue
# 添加进已见集合
seen.add(key)
# 添加进输出列表
out.append(name)
# 返回归一化去重后的列表
return out
# 定义函数,用于规范化 ask_variables 字段,输入为可选的字典列表,返回规范化后的字典列表
+def _normalize_ask_variables(v):
# 用于保存处理后的变量字典
+ out = []
# 用于记录已出现过的变量key,实现去重
+ seen = set()
# 遍历输入列表(如果为None则转为空列表)
+ for item in v or []:
# 如果当前元素不是字典类型,则跳过
+ if not isinstance(item, dict):
+ continue
# 从字典中获取key字段,去除首尾空白转字符串
+ key = str(item.get("key") or "").strip()
# 如果key为空,则跳过本轮
+ if not key:
+ continue
# 如果key已出现,则跳过实现去重
+ if key in seen:
+ continue
# 将当前key加入去重集合
+ seen.add(key)
# 获取question字段、去空白
+ question = str(item.get("question") or "").strip()
# 获取label字段、去空白
+ label = str(item.get("label") or "").strip()
# 获取default字段、去空白
+ default_value = str(item.get("default") or "").strip()
# 获取required字段,默认为True
+ required = bool(item.get("required", True))
# 将变量信息归一化后放入输出列表
+ out.append(
+ {
+ "key": key,
+ "label": label,
+ "question": question or f"请提供 {label or key}",
+ "required": required,
+ "default": default_value,
+ }
+ )
# 返回归一化和去重后的变量列表
+ return out
# 定义 AgentBase 基础模型,继承自 Pydantic 的 BaseModel
+class AgentBase(BaseModel):
# 头像字段,可为空
+ avatar: str | None = None
# 智能体名称,必填,长度1~255
+ name: str = Field(..., min_length=1, max_length=255)
# 描述信息,可为空
+ description: str | None = None
# 开场白内容,可为空
+ opening_message: str | None = None
# 智能体系统提示,必填,最小长度1
+ system_prompt: str = Field(..., min_length=1)
# LLM 提供商名称,必填,长度1~255
+ llm_provider_name: str = Field(..., min_length=1, max_length=255)
# LLM 模型名称,必填,长度1~255
+ llm_model_name: str = Field(..., min_length=1, max_length=255)
# 关联的 MCP 服务ID列表,默认为空列表
+ mcp_service_ids: list[int] = Field(default_factory=list)
# 询问提示词模板,可为空
+ ask_prompt_template: str | None = None
# 询问变量列表,默认为空列表
+ ask_variables: list[dict[str, Any]] = Field(default_factory=list)
# 对 mcp_service_ids 字段做归一化校验
+ @field_validator("mcp_service_ids")
+ @classmethod
+ def normalize_mcp_service_ids(cls, v):
# 存储去重后的有效 id
+ out = []
# 已经出现过的 id 集合
+ seen = set()
# 遍历 id 列表(防止为 None)
+ for item in v or []:
# 转成整数类型
+ num = int(item)
# 跳过小于等于0的无效 id
+ if num <= 0:
+ continue
# 跳过重复 id
+ if num in seen:
+ continue
# 添加到去重集合
+ seen.add(num)
# 添加到输出集合
+ out.append(num)
# 返回整理后的 id 列表
+ return out
# 对 ask_variables 字段做归一化校验
+ @field_validator("ask_variables")
+ @classmethod
+ def normalize_ask_variables(cls, v):
# 使用 _normalize_ask_variables 函数处理
+ return _normalize_ask_variables(v)
# 定义 AgentCreate 创建模型,继承自 AgentBase,无额外字段
+class AgentCreate(AgentBase):
+ pass
# 定义AgentOut响应模型,继承自BaseModel
+class AgentOut(BaseModel):
# 主键ID
+ id: int
# 头像,允许为None
+ avatar: str | None
# 智能体名称
+ name: str
# 智能体描述,允许为None
+ description: str | None
# 开场白,允许为None
+ opening_message: str | None
# 智能体系统提示,不可为None
+ system_prompt: str
# LLM提供商名称
+ llm_provider_name: str
# LLM模型名称
+ llm_model_name: str
# 关联的MCP服务ID列表
+ mcp_service_ids: list[int]
# 询问提示词模板,允许为None
+ ask_prompt_template: str | None
# 询问变量列表,默认为空列表
+ ask_variables: list[dict[str, Any]] = Field(default_factory=list)
# Pydantic配置:允许根据对象属性创建模型(ORM模式)
+ model_config = {"from_attributes": True}18.5 测试 #
curl --location --request POST "http://127.0.0.1:8000/api/agents" ^
--header "Content-Type: application/json" ^
--data-raw "{ \"avatar\": \"/uploads/b20afbe9728742579b53d003ab7a7008.png\", \"name\": \"旅游规划智能体\", \"description\": \"面向中文出行场景的一站式旅游规划助手:可先问清出发地、目的地、行程与偏好,再结合地点检索、路线规划与天气预报等能力,输出可执行的行程摘要、逐日安排、交通与住宿建议及注意事项。\", \"opening_message\": \"你好,我是你的旅游规划智能体 👋 \\n我可以帮你一站式完成:**去哪玩、怎么去、天气如何、是否适合出行**。\\n\\n为了给你生成一份可执行的旅行方案,我会先快速确认几个关键信息(如出发地、目的地、天数、预算、人数和偏好),然后再结合工具结果给你:\\n1. 结论摘要 \\n2. 逐日行程 \\n3. 路线与交通建议 \\n4. 天气与出行提醒 \\n5. 预算与注意事项\\n\\n我们现在开始吧,我先问你第一个问题。\", \"system_prompt\": \"你是“旅游规划智能体”,目标是帮助用户完成从“去哪玩”到“怎么去、天气如何、是否适合出行”的一站式决策。\\n\\n【核心职责】\\n1. 根据用户需求推荐地点(景点/酒店/餐饮等)。\\n2. 提供路线规划(自驾或公共交通),并给出关键出行建议。\\n3. 提供目的地未来天气预报,结合天气给出出行提醒。\\n4. 回答必须真实、可执行、结构清晰,不夸大、不编造。\\n\\n【工具使用规则】\\n- 与地点检索相关:调用 search_place。\\n- 与路线规划相关:调用 plan_route。\\n- 与天气相关:调用 get_travel_forecast。\\n- 只要问题涉及“地点推荐、路线、天气”中的任意一项,应优先调用工具,不要凭空臆测。\\n- 若用户信息不足(如缺少出发地、目的地、天数、偏好),先用 1-2 个关键问题补齐后再调用工具。\\n- 当工具调用失败或结果为空时,说明原因并给出下一步可操作建议(例如改关键词、改地区、改出行方式)。\\n\\n【对话策略】\\n- 用户目标优先:先判断用户是“选目的地”“查路线”“看天气”还是“完整行程规划”。\\n- 组合调用:\\n - 完整规划时,默认按“地点 -> 路线 -> 天气”顺序调用。\\n - 若用户已指定地点,可直接“路线 + 天气”。\\n- 对用户提到“高铁/动车/地铁/公交/飞机”等,路线优先使用公共交通思路。\\n- 对用户提到“不走高速”等偏好,路线建议中明确体现。\\n\\n【输出风格】\\n- 默认中文,简洁专业,先结论后细节。\\n- 输出结构建议:\\n 1) 结论摘要(1-3 行)\\n 2) 推荐方案(地点/路线/天气)\\n 3) 注意事项(预算、天气、时段、备选)\\n- 涉及时间、温度、里程、时长时,尽量保留工具结果中的关键数据。\\n- 不输出工具内部实现细节,不暴露密钥、请求头等敏感信息。\\n\\n【安全与边界】\\n- 不编造未获取到的数据;不确定时明确说“不确定”并建议重新检索。\\n- 医疗、法律、政策等高风险问题仅给一般性建议并提示咨询专业机构。\\n- 用户要求与旅游无关的内容时,礼貌回应并引导回到旅游场景。\\n\\n现在开始:优先理解用户旅行意图,并按需调用工具给出可执行建议。\", \"ask_prompt_template\": \"你是一名资深中文旅行规划师。请基于以下用户信息,输出一份可执行、具体、务实的旅行方案。\\n\\n【用户信息】\\n- 出发地:{{出发地}}\\n- 目的地:{{目的地}}\\n- 游玩天数:{{游玩天数}} 天\\n- 出发日期:{{出发日期}}\\n- 预算档位:{{预算档位}}\\n- 出行人数:{{出行人数}}\\n- 出行偏好:{{出行偏好}}\\n- 交通偏好:{{交通偏好}}\\n- 住宿偏好:{{住宿偏好}}\\n- 必去点:{{必去点}}\\n- 避开项:{{避开项}}\\n\\n【输出要求】\\n1) 先给“总览结论”(3-5条)\\n2) 再给逐日行程(Day1 ~ DayN),每一天包含:\\n - 上午/下午/晚上安排\\n - 核心景点与停留时长建议\\n - 餐饮建议(当地特色)\\n3) 给交通方案对比(至少2种:时间、成本、优缺点)\\n4) 给住宿区域建议(2-3个片区,适合人群、预算)\\n5) 给预算拆分(交通/住宿/餐饮/门票)\\n6) 给注意事项(天气、穿衣、预约、避坑)\\n7) 如果信息不足,先明确列出“仍需补充的信息”,再给“可先执行的临时方案”。\\n\\n请使用清晰的 Markdown 格式输出,标题层级明确,内容尽量具体到可直接执行。\", \"ask_variables\": [ { \"key\": \"出发地\", \"label\": \"出发地\", \"question\": \"你从哪个城市出发?\", \"required\": true, \"default\": \"\" }, { \"key\": \"目的地\", \"label\": \"目的地\", \"question\": \"你要去哪个城市?\", \"required\": true, \"default\": \"\" }, { \"key\": \"游玩天数\", \"label\": \"游玩天数\", \"question\": \"你计划游玩几天?\", \"required\": true, \"default\": \"\" }, { \"key\": \"出发日期\", \"label\": \"出发日期\", \"question\": \"你的出发日期是?\", \"required\": true, \"default\": \"\" }, { \"key\": \"预算档位\", \"label\": \"预算档位\", \"question\": \"你的预算档位是?(低/中/高)\", \"required\": true, \"default\": \"\" }, { \"key\": \"出行人数\", \"label\": \"出行人数\", \"question\": \"有几人同行?\", \"required\": true, \"default\": \"\" }, { \"key\": \"出行偏好\", \"label\": \"出行偏好\", \"question\": \"你偏好哪种旅行风格?(可选自然风光/人文历史/亲子/美食/轻松等,可多选)\", \"required\": true, \"default\": \"\" }, { \"key\": \"交通偏好\", \"label\": \"交通偏好\", \"question\": \"你偏好哪种交通方式?(高铁/飞机/自驾/无偏好)\", \"required\": true, \"default\": \"\" }, { \"key\": \"住宿偏好\", \"label\": \"住宿偏好\", \"question\": \"你的住宿偏好是?(经济/舒适/高档)\", \"required\": true, \"default\": \"\" }, { \"key\": \"必去点\", \"label\": \"必去点\", \"question\": \"有哪些必去的景点或地点?\", \"required\": true, \"default\": \"\" }, { \"key\": \"避开项\", \"label\": \"避开项\", \"question\": \"有哪些不想去的地方或忌口/限制?\", \"required\": true, \"default\": \"\" } ], \"llm_provider_name\": \"深度探索\", \"llm_model_name\": \"deepseek-chat\", \"mcp_service_ids\": [ 5, 4, 3 ]}"18.6 旅游规划智能体 #
18.6.1 描述 #
面向中文出行场景的一站式旅游规划助手:可先问清出发地、目的地、行程与偏好,再结合地点检索、路线规划与天气预报等能力,输出可执行的行程摘要、逐日安排、交通与住宿建议及注意事项。18.6.2 开场白 #
你好,我是你的旅游规划智能体 👋
我可以帮你一站式完成:**去哪玩、怎么去、天气如何、是否适合出行**。
为了给你生成一份可执行的旅行方案,我会先快速确认几个关键信息(如出发地、目的地、天数、预算、人数和偏好),然后再结合工具结果给你:
1. 结论摘要
2. 逐日行程
3. 路线与交通建议
4. 天气与出行提醒
5. 预算与注意事项
我们现在开始吧,我先问你第一个问题。18.6.3 系统提示词 #
你是“旅游规划智能体”,目标是帮助用户完成从“去哪玩”到“怎么去、天气如何、是否适合出行”的一站式决策。
【核心职责】
1. 根据用户需求推荐地点(景点/酒店/餐饮等)。
2. 提供路线规划(自驾或公共交通),并给出关键出行建议。
3. 提供目的地未来天气预报,结合天气给出出行提醒。
4. 回答必须真实、可执行、结构清晰,不夸大、不编造。
【工具使用规则】
- 与地点检索相关:调用 search_place。
- 与路线规划相关:调用 plan_route。
- 与天气相关:调用 get_travel_forecast。
- 只要问题涉及“地点推荐、路线、天气”中的任意一项,应优先调用工具,不要凭空臆测。
- 若用户信息不足(如缺少出发地、目的地、天数、偏好),先用 1-2 个关键问题补齐后再调用工具。
- 当工具调用失败或结果为空时,说明原因并给出下一步可操作建议(例如改关键词、改地区、改出行方式)。
【对话策略】
- 用户目标优先:先判断用户是“选目的地”“查路线”“看天气”还是“完整行程规划”。
- 组合调用:
- 完整规划时,默认按“地点 -> 路线 -> 天气”顺序调用。
- 若用户已指定地点,可直接“路线 + 天气”。
- 对用户提到“高铁/动车/地铁/公交/飞机”等,路线优先使用公共交通思路。
- 对用户提到“不走高速”等偏好,路线建议中明确体现。
【输出风格】
- 默认中文,简洁专业,先结论后细节。
- 输出结构建议:
1) 结论摘要(1-3 行)
2) 推荐方案(地点/路线/天气)
3) 注意事项(预算、天气、时段、备选)
- 涉及时间、温度、里程、时长时,尽量保留工具结果中的关键数据。
- 不输出工具内部实现细节,不暴露密钥、请求头等敏感信息。
【安全与边界】
- 不编造未获取到的数据;不确定时明确说“不确定”并建议重新检索。
- 医疗、法律、政策等高风险问题仅给一般性建议并提示咨询专业机构。
- 用户要求与旅游无关的内容时,礼貌回应并引导回到旅游场景。
现在开始:优先理解用户旅行意图,并按需调用工具给出可执行建议。18.6.4 提问提示词模板 #
你是一名资深中文旅行规划师。请基于以下用户信息,输出一份可执行、具体、务实的旅行方案。
【用户信息】
- 出发地:{{出发地}}
- 目的地:{{目的地}}
- 游玩天数:{{游玩天数}} 天
- 出发日期:{{出发日期}}
- 预算档位:{{预算档位}}
- 出行人数:{{出行人数}}
- 出行偏好:{{出行偏好}}
- 交通偏好:{{交通偏好}}
- 住宿偏好:{{住宿偏好}}
- 必去点:{{必去点}}
- 避开项:{{避开项}}
【输出要求】
1) 先给“总览结论”(3-5条)
2) 再给逐日行程(Day1 ~ DayN),每一天包含:
- 上午/下午/晚上安排
- 核心景点与停留时长建议
- 餐饮建议(当地特色)
3) 给交通方案对比(至少2种:时间、成本、优缺点)
4) 给住宿区域建议(2-3个片区,适合人群、预算)
5) 给预算拆分(交通/住宿/餐饮/门票)
6) 给注意事项(天气、穿衣、预约、避坑)
7) 如果信息不足,先明确列出“仍需补充的信息”,再给“可先执行的临时方案”。
请使用清晰的 Markdown 格式输出,标题层级明确,内容尽量具体到可直接执行。18.6.5 提问变量配置 #
[
{"key": "departure_city", "label": "出发地", "default": "", "question": "你从哪个城市出发?", "required": true},
{"key": "destination_city", "label": "目的地", "default": "", "question": "你要去哪个城市?", "required": true},
{"key": "days", "label": "游玩天数", "default": "", "question": "你计划游玩几天?", "required": true},
{"key": "start_date", "label": "出发日期", "default": "", "question": "你的出发日期是?", "required": true},
{"key": "budget_level", "label": "预算档位", "default": "", "question": "你的预算档位是?(低/中/高)", "required": true},
{"key": "travelers", "label": "出行人数", "default": "", "question": "有几人同行?", "required": true},
{"key": "travel_style", "label": "旅行偏好", "default": "", "question": "你偏好哪种旅行风格?(可选自然风光/人文历史/亲子/美食/轻松等,可多选)", "required": true},
{"key": "transport_preference", "label": "交通偏好", "default": "", "question": "你偏好哪种交通方式?(高铁/飞机/自驾/无偏好)", "required": true},
{"key": "hotel_preference", "label": "住宿偏好", "default": "", "question": "你的住宿偏好是?(经济/舒适/高档)", "required": true},
{"key": "must_visit", "label": "必去景点", "default": "", "question": "有哪些必去的景点或地点?", "required": true},
{"key": "avoid", "label": "避开事项", "default": "", "question": "有哪些不想去的地方或忌口/限制?", "required": true}
]19. 修改智能体 #
本章节介绍如何通过 API 修改(更新)现有的智能体(Agent)信息。
功能说明
更新智能体接口允许你根据智能体 ID,批量修改其名称、头像、描述、系统提示词、大模型配置、MCP 服务列表、提问模板和变量等属性。未提交的字段保持原值,提交为 null(None)的字段将被设为 null。接口会验证相关引用数据的有效性(如大模型提供商、模型、MCP 服务是否存在与匹配)。
请求路径
PUT /api/agents/{agent_id}agent_id:要修改的智能体的主键 ID。
请求参数
请求体需为 JSON,字段参考 AgentUpdate 模型,常用字段包括:
| 字段名 | 类型 | 是否可选 | 含义 |
|---|---|---|---|
| avatar | str | 是 | 智能体头像(URL) |
| name | str | 是 | 智能体名称 |
| description | str | 是 | 智能体描述 |
| opening_message | str | 是 | 智能体开场白 |
| system_prompt | str | 是 | 智能体系统提示词 |
| llm_provider_name | str | 是 | 大语言模型提供商(如"深度探索") |
| llm_model_name | str | 是 | LLM模型名称(如"deepseek-chat") |
| mcp_service_ids | list[int] | 是 | 关联的MCP服务ID列表 |
| ask_prompt_template | str | 是 | 提问题模板 |
| ask_variables | list[dict[str, Any]] | 是 | 提问变量配置(每项结构见下方) |
注意:
- 所有字段均为可选,未包含的字段自动保持原值。
- 如要将某字段“清空”,可设为
null或类型的空值。- 只有传递的值才会被更新。
提问变量配置说明
ask_variables 字段为一个列表,每项为一个字典,包含如下常用键:
key: 变量字段名(建议英文/拼音)label: 变量中文名question: 用户交互时的问题required: 是否必填default: 默认值
例如:
[
{ "key": "departure_city", "label": "出发地", "question": "你从哪个城市出发?", "required": true, "default": "" },
{ "key": "destination_city", "label": "目的地", "question": "你要去哪个城市?", "required": true, "default": "" }
]示例
下面是一个完整的 PUT 修改请求示例:
curl --location --request PUT "http://127.0.0.1:8000/api/agents/1" ^
--header "Content-Type: application/json" ^
--data-raw '{
"avatar": "/uploads/your_avatar.png",
"name": "智能体新名称",
"description": "更新后的描述内容",
"opening_message": "新开场白",
"system_prompt": "你是新的智能体系统提示词",
"llm_provider_name": "深度探索",
"llm_model_name": "deepseek-chat",
"mcp_service_ids": [1,2,3],
"ask_prompt_template": "请基于以下信息输出……",
"ask_variables": [
{ "key": "a", "label": "A", "question": "问题A?", "required": true, "default": "" },
{ "key": "b", "label": "B", "question": "问题B?", "required": false, "default": "B1" }
]
}'响应结果
成功时,返回更新后的 Agent 详细信息(结构同创建)。
出现错误时返回标准的 API 错误对象,比如:
- 404:指定 ID 的 Agent 不存在
- 400:大模型提供商、模型、或 MCP 服务引用有误
注意事项
- 更新时会做基础校验和引用合法性校验。
- 提交的 mcp_service_ids 若非整数列表或含无效 ID,会导致 400 错误。
- 字符串字段会自动去除首尾空格;可为 null 的字段可设为 null。
- 部分字段如 ask_variables 支持嵌套结构,注意 JSON 格式和内容完整性。
若需仅更新部分字段,仅提交对应字段即可,其他字段将保持当前值不变。
19.1. agent_repository.py #
app/repositories/agent_repository.py
# 导入Session类用于数据库会话管理
from sqlalchemy.orm import Session
from sqlalchemy import select
# 从app模块导入models和schemas
from app import models
from app import schemas
# 定义创建Agent的函数,参数为数据库会话和Agent创建数据,返回创建后的Agent模型
def create_agent(db: Session, data: schemas.AgentCreate) -> models.Agent:
# 创建Agent模型实例,将前端传入的数据赋值到各个字段
row = models.Agent(
avatar=data.avatar, # 头像
name=data.name.strip(), # 名称,去除首尾空格
description=data.description, # 描述
opening_message=(data.opening_message or "").strip() or None, # 开场白,去除空格,允许为None
system_prompt=data.system_prompt.strip(), # 系统提示,去除空格
llm_provider_name=data.llm_provider_name.strip(), # LLM提供商名称,去除空格
llm_model_name=data.llm_model_name.strip(), # LLM模型名称,去除空格
mcp_service_ids=data.mcp_service_ids, # 关联MCP服务ID列表
ask_prompt_template=(data.ask_prompt_template or "").strip() or None, # 提问提示词模板,去除空格,允许为None
ask_variables=data.ask_variables or [], # 提问变量,默认为空列表
)
# 将Agent实例添加到数据库会话
db.add(row)
# 提交事务,保存到数据库
db.commit()
# 刷新实例,获取数据库生成的最新字段(如自增ID)
db.refresh(row)
# 返回创建后的Agent对象
return row
# 定义list_agents函数,接收一个数据库会话db作为参数,返回Agent对象列表
def list_agents(db: Session) -> list[models.Agent]:
# 使用select语句查询Agent表,并按id倒序排序,获取所有Agent对象
return list(db.scalars(select(models.Agent).order_by(models.Agent.id.desc())).all())
# 定义一个用于更新智能体(Agent)的函数
+def update_agent(db: Session, row: models.Agent, data: schemas.AgentUpdate) -> models.Agent:
# 如果数据中的avatar不为None,则更新头像字段
+ if data.avatar is not None:
+ row.avatar = data.avatar
# 如果数据中的name不为None,则去除空格后更新名称字段
+ if data.name is not None:
+ row.name = data.name.strip()
# 如果数据中的description不为None,则更新描述信息
+ if data.description is not None:
+ row.description = data.description
# 如果数据中的opening_message不为None,则去除空格后更新开场白,允许为None
+ if data.opening_message is not None:
+ row.opening_message = (data.opening_message or "").strip() or None
# 如果数据中的system_prompt不为None,则去除空格后更新系统提示词
+ if data.system_prompt is not None:
+ row.system_prompt = data.system_prompt.strip()
# 如果数据中的llm_provider_name不为None,则去除空格后更新LLM提供商名称
+ if data.llm_provider_name is not None:
+ row.llm_provider_name = data.llm_provider_name.strip()
# 如果数据中的llm_model_name不为None,则去除空格后更新LLM模型名称
+ if data.llm_model_name is not None:
+ row.llm_model_name = data.llm_model_name.strip()
# 如果数据中的mcp_service_ids不为None,则更新关联的MCP服务ID列表
+ if data.mcp_service_ids is not None:
+ row.mcp_service_ids = data.mcp_service_ids
# 如果数据中的ask_prompt_template不为None,则去除空格后更新提问提示词模板,允许为None
+ if data.ask_prompt_template is not None:
+ row.ask_prompt_template = (data.ask_prompt_template or "").strip() or None
# 如果数据中的ask_variables不为None,则更新提问变量配置
+ if data.ask_variables is not None:
+ row.ask_variables = data.ask_variables
# 提交事务保存更改
+ db.commit()
# 刷新实例获取数据库中的最新数据
+ db.refresh(row)
# 返回更新后的Agent对象
+ return row
# 定义get_agent函数,根据给定的agent_id,从数据库中获取对应的Agent对象
# 参数db为数据库会话,agent_id为智能体主键ID
# 若找到对应的Agent,返回Agent对象,否则返回None
+def get_agent(db: Session, agent_id: int) -> models.Agent | None:
+ return db.get(models.Agent, agent_id)19.2. agents.py #
app/routers/agents.py
# 从 fastapi 导入 APIRouter, Depends 以及 HTTPException 异常
from fastapi import APIRouter, Depends, HTTPException
# 从 sqlalchemy.orm 导入 Session 会话对象
from sqlalchemy.orm import Session
# 从 app 包分别导入 schemas 模块
from app import schemas
# 导入数据库依赖获取函数
from app.database import get_db
# 导入 agent_repository、llm_repository 和 mcp_repository,分别处理不同的数据操作
from app.repositories import agent_repository, llm_repository, mcp_repository
# 创建一个 APIRouter 实例,设置路由的前缀和标签
router = APIRouter(prefix="/api/agents", tags=["agents"])
# 校验相关引用有效性(大模型提供商、模型、MCP 服务)
def _validate_refs(db: Session, provider_name: str, model_name: str, mcp_service_ids: list[int]) -> None:
# 根据提供商名称查询 LLM 提供商数据
llm_row = llm_repository.get_llm_by_provider_name(db, provider_name.strip())
# 如果没有找到对应的 LLM 提供商则抛出 HTTP 400 异常
if not llm_row:
raise HTTPException(status_code=400, detail=f"大语言模型提供商不存在: {provider_name}")
# 获取该提供商下的所有模型名,并做字符串清洗
llm_models = [str(m or "").strip() for m in (llm_row.model_names or [])]
# 如果传入的模型名称不属于当前提供商的模型,则抛出 HTTP 400 异常
if model_name.strip() not in llm_models:
raise HTTPException(status_code=400, detail=f"模型名称不属于提供商 {provider_name}: {model_name}")
# 遍历所有 mcp_service_ids
for sid in mcp_service_ids:
# 校验每一个 mcp_service 是否存在,如果不存在则抛出 400 异常
if not mcp_repository.get_mcp_service(db, int(sid)):
raise HTTPException(status_code=400, detail=f"MCP 服务不存在: {sid}")
# 定义创建 agent 的接口,POST 请求,响应体为 AgentOut 模型
@router.post("", response_model=schemas.AgentOut)
def create_agent(payload: schemas.AgentCreate, db: Session = Depends(get_db)):
# 调用 _validate_refs 校验 Agent 创建时关联的 LLM 提供商、模型、MCP 服务是否有效
_validate_refs(db, payload.llm_provider_name, payload.llm_model_name, payload.mcp_service_ids)
# 校验通过后,调用 agent_repository 创建新的 Agent,并返回创建结果
return agent_repository.create_agent(db, payload)
# 定义 GET 接口用于获取所有 Agent 列表,响应为 AgentOut 对象列表
@router.get("", response_model=list[schemas.AgentOut])
# 定义 list_agents 视图函数,依赖注入数据库会话 db
def list_agents(db: Session = Depends(get_db)):
# 调用 agent_repository 的 list_agents 方法,获取所有 Agent 数据
return agent_repository.list_agents(db)
# 定义更新指定 agent_id 智能体信息的接口,返回更新后的 AgentOut 响应模型
+@router.put("/{agent_id}", response_model=schemas.AgentOut)
# update_agent 视图函数,接收 agent_id、更新数据 payload,和数据库会话 db
+def update_agent(agent_id: int, payload: schemas.AgentUpdate, db: Session = Depends(get_db)):
# 根据 agent_id 从数据库查询原有的 agent 记录
+ row = agent_repository.get_agent(db, agent_id)
# 如果未查到该记录,则抛出 404 错误,提示“记录不存在”
+ if not row:
+ raise HTTPException(status_code=404, detail="记录不存在")
# 优先使用提交的数据,否则使用原有值,确定最终的 provider 名称
+ provider_name = payload.llm_provider_name if payload.llm_provider_name is not None else row.llm_provider_name
# 优先使用提交的数据,否则使用原有值,确定最终的 model 名称
+ model_name = payload.llm_model_name if payload.llm_model_name is not None else row.llm_model_name
# 优先使用提交的数据,否则使用原有值,确定 mcp_service_ids
+ mcp_service_ids = payload.mcp_service_ids if payload.mcp_service_ids is not None else row.mcp_service_ids
# 校验 provider/model/mcp_service 引用是否合法
+ _validate_refs(db, provider_name, model_name, mcp_service_ids)
# 调用仓库方法更新 agent 记录并返回更新结果
+ return agent_repository.update_agent(db, row, payload)
19.3. schemas.py #
app/schemas.py
# 导入枚举类型
from enum import Enum
# 导入Any类型用于类型注解
from typing import Any
# 从pydantic导入BaseModel、Field和field_validator用于数据验证
from pydantic import BaseModel, Field, field_validator
# 定义MCP协议类型的枚举类
class McpProtocol(str, Enum):
# MCP 协议类型注释
"""MCP 协议类型"""
# stdio协议
stdio = "stdio"
# streamable-http协议
streamable_http = "streamable-http"
# sse协议
sse = "sse"
# 定义MCP服务基础模型
class McpServiceBase(BaseModel):
# MCP 服务基础模型注释
"""MCP 服务基础模型"""
# 服务名称,字符串类型,长度1-255
name: str = Field(..., min_length=1, max_length=255)
# 服务描述,可选字段,字符串或None
description: str | None = None
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,确保其为字典类型
@field_validator("config", mode="before")
@classmethod
def config_not_empty(cls, v: Any) -> Any:
# 验证 config 是否为 JSON 对象,如果不是字典则抛出异常
"""验证 config 是否为 JSON 对象"""
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
return v
# 定义MCP服务创建模型,继承自McpServiceBase
class McpServiceCreate(McpServiceBase):
# 不增加额外内容,直接继承
pass
# 定义MCP服务输出模型
class McpServiceOut(BaseModel):
# ID字段,整型
id: int
# 服务名称
name: str
# 服务描述,可选字段
description: str | None
# 协议类型,字符串
protocol: str
# 配置信息,字典类型
config: dict[str, Any]
# 设置模型配置,允许从ORM对象属性读取数据
model_config = {"from_attributes": True}
# 定义McpServiceUpdate模型,用于更新MCP服务,支持部分字段可选更新
class McpServiceUpdate(BaseModel):
# 服务名称,允许为None,最小长度1,最大长度255
name: str | None = Field(None, min_length=1, max_length=255)
# 服务描述字段,允许为None
description: str | None = None
# 协议类型,使用McpProtocol枚举,允许为None
protocol: McpProtocol | None = None
# 配置信息,允许为None,类型为字典
config: dict[str, Any] | None = None
# 定义McpTestRequest模型,继承自BaseModel
class McpTestRequest(BaseModel):
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,在赋值前进行校验
@field_validator("config", mode="before")
@classmethod
def config_is_object(cls, v: Any) -> Any:
# 如果config不是字典类型,则抛出异常
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
# 返回config原值
return v
# 定义McpTestResult模型,继承自BaseModel
class McpTestResult(BaseModel):
# 测试是否成功的标志
ok: bool
# 返回的信息或说明
message: str
# 工具列表,默认为空列表
tools: list[dict[str, Any]] = Field(default_factory=list)
# 定义LlmModelBase基类,用于存储大模型相关信息
class LlmModelBase(BaseModel):
# 提供商名称,必填,最小长度1,最大长度255
provider_name: str = Field(..., min_length=1, max_length=255)
# 提供商图标,可选
provider_icon: str | None = None
# API基础地址,必填,最小长度1,最大长度1024
api_base_url: str = Field(..., min_length=1, max_length=1024)
# API密钥,必填,最小长度1,最大长度1024
api_key: str = Field(..., min_length=1, max_length=1024)
# API密钥申请地址,可选,最大长度1024
api_key_url: str | None = Field(None, max_length=1024)
# 模型名称列表,默认为空列表
model_names: list[str] = Field(default_factory=list)
# 对model_names字段进行校验和归一化处理
@field_validator("model_names")
@classmethod
def normalize_model_names(cls, v: list[str]) -> list[str]:
# 定义输出列表
out: list[str] = []
# 用于存储已经出现过的模型名称(小写形式)以去重
seen: set[str] = set()
# 遍历输入值,如果为空则用空列表
for item in v or []:
# 将每个名称转为字符串并去除首尾空格
name = str(item or "").strip()
# 如果名称为空,则跳过
if not name:
continue
# 转为小写做去重key
key = name.lower()
# 如果已经见过该名称,则跳过
if key in seen:
continue
# 添加到已见集合和输出列表
seen.add(key)
out.append(name)
# 返回归一化后的模型名称列表
return out
# 定义LlmModelCreate模型,继承自LlmModelBase,没有扩展字段
class LlmModelCreate(LlmModelBase):
pass
# 定义 LlmModelOut 类,继承自 BaseModel,用于大模型的返回数据结构
class LlmModelOut(BaseModel):
# 唯一ID,类型为整数
id: int
# 提供商名称,字符串类型
provider_name: str
# 提供商图标,可选,字符串或None
provider_icon: str | None
# API 基础地址,字符串类型
api_base_url: str
# API 密钥,字符串类型
api_key: str
# API 密钥申请地址,可选,字符串或None
api_key_url: str | None
# 模型名称列表,字符串列表类型
model_names: list[str]
# Pydantic 配置,启用从属性赋值(用于ORM模式)
model_config = {"from_attributes": True}
# 定义用于大模型服务测试请求体的Pydantic模型
class LlmModelTestRequest(BaseModel):
# API基础地址,必须为非空字符串,最大长度1024
api_base_url: str = Field(..., min_length=1, max_length=1024)
# API密钥,必须为非空字符串,最大长度1024
api_key: str = Field(..., min_length=1, max_length=1024)
# 定义 LlmModelTestResult 类,继承自 BaseModel,用于返回大模型服务检测的结果
class LlmModelTestResult(BaseModel):
# 检测是否通过,布尔类型
ok: bool
# 检测的信息提示,字符串类型
message: str
# 检测到的模型名称列表,默认为空列表
models: list[str] = Field(default_factory=list)
# 定义用于更新大模型信息的Pydantic模型
class LlmModelUpdate(BaseModel):
# 提供商名称,可选字段,限定最小长度1,最大长度255
provider_name: str | None = Field(None, min_length=1, max_length=255)
# 提供商图标,可选字段
provider_icon: str | None = None
# API 基础地址,可选字段,限定最小长度1,最大长度1024
api_base_url: str | None = Field(None, min_length=1, max_length=1024)
# API 密钥,可选字段,限定最小长度1,最大长度1024
api_key: str | None = Field(None, min_length=1, max_length=1024)
# API 密钥申请地址,可选字段,最大长度1024
api_key_url: str | None = Field(None, max_length=1024)
# 模型名称列表,可选字段
model_names: list[str] | None = None
# 对model_names字段做校验和归一化处理
@field_validator("model_names")
@classmethod
def normalize_model_names_optional(cls, v: list[str] | None) -> list[str] | None:
# 如果为None,直接返回None
if v is None:
return None
# 用于存储归一化后的模型名称
out: list[str] = []
# 用于去重
seen: set[str] = set()
# 遍历输入列表
for item in v:
# 转为字符串并去除首尾空白
name = str(item or "").strip()
# 如果名称为空则跳过
if not name:
continue
# 使用小写做去重key
key = name.lower()
# 如果已经见过则跳过
if key in seen:
continue
# 添加进已见集合
seen.add(key)
# 添加进输出列表
out.append(name)
# 返回归一化去重后的列表
return out
# 定义函数,用于规范化 ask_variables 字段,输入为可选的字典列表,返回规范化后的字典列表
def _normalize_ask_variables(v):
# 用于保存处理后的变量字典
out = []
# 用于记录已出现过的变量key,实现去重
seen = set()
# 遍历输入列表(如果为None则转为空列表)
for item in v or []:
# 如果当前元素不是字典类型,则跳过
if not isinstance(item, dict):
continue
# 从字典中获取key字段,去除首尾空白转字符串
key = str(item.get("key") or "").strip()
# 如果key为空,则跳过本轮
if not key:
continue
# 如果key已出现,则跳过实现去重
if key in seen:
continue
# 将当前key加入去重集合
seen.add(key)
# 获取question字段、去空白
question = str(item.get("question") or "").strip()
# 获取label字段、去空白
label = str(item.get("label") or "").strip()
# 获取default字段、去空白
default_value = str(item.get("default") or "").strip()
# 获取required字段,默认为True
required = bool(item.get("required", True))
# 将变量信息归一化后放入输出列表
out.append(
{
"key": key,
"label": label,
"question": question or f"请提供 {label or key}",
"required": required,
"default": default_value,
}
)
# 返回归一化和去重后的变量列表
return out
# 定义 AgentBase 基础模型,继承自 Pydantic 的 BaseModel
class AgentBase(BaseModel):
# 头像字段,可为空
avatar: str | None = None
# 智能体名称,必填,长度1~255
name: str = Field(..., min_length=1, max_length=255)
# 描述信息,可为空
description: str | None = None
# 开场白内容,可为空
opening_message: str | None = None
# 智能体系统提示,必填,最小长度1
system_prompt: str = Field(..., min_length=1)
# LLM 提供商名称,必填,长度1~255
llm_provider_name: str = Field(..., min_length=1, max_length=255)
# LLM 模型名称,必填,长度1~255
llm_model_name: str = Field(..., min_length=1, max_length=255)
# 关联的 MCP 服务ID列表,默认为空列表
mcp_service_ids: list[int] = Field(default_factory=list)
# 询问提示词模板,可为空
ask_prompt_template: str | None = None
# 询问变量列表,默认为空列表
ask_variables: list[dict[str, Any]] = Field(default_factory=list)
# 对 mcp_service_ids 字段做归一化校验
@field_validator("mcp_service_ids")
@classmethod
def normalize_mcp_service_ids(cls, v):
# 存储去重后的有效 id
out = []
# 已经出现过的 id 集合
seen = set()
# 遍历 id 列表(防止为 None)
for item in v or []:
# 转成整数类型
num = int(item)
# 跳过小于等于0的无效 id
if num <= 0:
continue
# 跳过重复 id
if num in seen:
continue
# 添加到去重集合
seen.add(num)
# 添加到输出集合
out.append(num)
# 返回整理后的 id 列表
return out
# 对 ask_variables 字段做归一化校验
@field_validator("ask_variables")
@classmethod
def normalize_ask_variables(cls, v):
# 使用 _normalize_ask_variables 函数处理
return _normalize_ask_variables(v)
# 定义 AgentCreate 创建模型,继承自 AgentBase,无额外字段
class AgentCreate(AgentBase):
pass
# 定义AgentOut响应模型,继承自BaseModel
class AgentOut(BaseModel):
# 主键ID
id: int
# 头像,允许为None
avatar: str | None
# 智能体名称
name: str
# 智能体描述,允许为None
description: str | None
# 开场白,允许为None
opening_message: str | None
# 智能体系统提示,不可为None
system_prompt: str
# LLM提供商名称
llm_provider_name: str
# LLM模型名称
llm_model_name: str
# 关联的MCP服务ID列表
mcp_service_ids: list[int]
# 询问提示词模板,允许为None
ask_prompt_template: str | None
# 询问变量列表,默认为空列表
ask_variables: list[dict[str, Any]] = Field(default_factory=list)
# Pydantic配置:允许根据对象属性创建模型(ORM模式)
model_config = {"from_attributes": True}
# 定义AgentUpdate模型,用于部分更新Agent,继承自Pydantic的BaseModel
+class AgentUpdate(BaseModel):
# 头像字段,允许为None
+ avatar: str | None = None
# 名称字段,允许为None,若不为None则要求长度1~255
+ name: str | None = Field(None, min_length=1, max_length=255)
# 描述字段,允许为None
+ description: str | None = None
# 开场白字段,允许为None
+ opening_message: str | None = None
# 系统提示词字段,允许为None,若不为None则要求最小长度1
+ system_prompt: str | None = Field(None, min_length=1)
# LLM提供商名称,允许为None,若不为None长度1~255
+ llm_provider_name: str | None = Field(None, min_length=1, max_length=255)
# LLM模型名称,允许为None,若不为None长度1~255
+ llm_model_name: str | None = Field(None, min_length=1, max_length=255)
# 关联的MCP服务ID列表,允许为None
+ mcp_service_ids: list[int] | None = None
# 询问提示词模板,允许为None
+ ask_prompt_template: str | None = None
# 询问变量列表,允许为None
+ ask_variables: list[dict[str, Any]] | None = None
# 对mcp_service_ids字段进行校验和去重,允许为None
+ @field_validator("mcp_service_ids")
+ @classmethod
+ def normalize_mcp_service_ids_optional(cls, v: list[int] | None) -> list[int] | None:
# 如果为None,直接返回None
+ if v is None:
+ return None
# 定义用于存放合法、去重后的ID的列表
+ out: list[int] = []
# 定义用于去重的集合
+ seen: set[int] = set()
# 遍历传入的每个ID
+ for item in v:
# 转为整数
+ num = int(item)
# 跳过小于等于0的无效ID
+ if num <= 0:
+ continue
# 跳过重复ID
+ if num in seen:
+ continue
# 加入去重集合
+ seen.add(num)
# 加入输出列表
+ out.append(num)
# 返回整理后的ID列表
+ return out
# 对ask_variables字段进行归一化与合法性检查,允许为None
+ @field_validator("ask_variables")
+ @classmethod
+ def normalize_ask_variables_optional(cls, v: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
# 如果为None直接返回None
+ if v is None:
+ return None
# 使用_normalize_ask_variables工具函数归一化处理
+ return _normalize_ask_variables(v) 19.4 测试 #
curl --location --request PUT "http://127.0.0.1:8000/api/agents/1" ^
--header "Content-Type: application/json" ^
--data-raw "{ \"avatar\": \"/uploads/b20afbe9728742579b53d003ab7a7008.png\", \"name\": \"旅游规划智能体2\", \"description\": \"面向中文出行场景的一站式旅游规划助手:可先问清出发地、目的地、行程与偏好,再结合地点检索、路线规划与天气预报等能力,输出可执行的行程摘要、逐日安排、交通与住宿建议及注意事项。\", \"opening_message\": \"你好,我是你的旅游规划智能体 👋 \\n我可以帮你一站式完成:**去哪玩、怎么去、天气如何、是否适合出行**。\\n\\n为了给你生成一份可执行的旅行方案,我会先快速确认几个关键信息(如出发地、目的地、天数、预算、人数和偏好),然后再结合工具结果给你:\\n1. 结论摘要 \\n2. 逐日行程 \\n3. 路线与交通建议 \\n4. 天气与出行提醒 \\n5. 预算与注意事项\\n\\n我们现在开始吧,我先问你第一个问题。\", \"system_prompt\": \"你是“旅游规划智能体”,目标是帮助用户完成从“去哪玩”到“怎么去、天气如何、是否适合出行”的一站式决策。\\n\\n【核心职责】\\n1. 根据用户需求推荐地点(景点/酒店/餐饮等)。\\n2. 提供路线规划(自驾或公共交通),并给出关键出行建议。\\n3. 提供目的地未来天气预报,结合天气给出出行提醒。\\n4. 回答必须真实、可执行、结构清晰,不夸大、不编造。\\n\\n【工具使用规则】\\n- 与地点检索相关:调用 search_place。\\n- 与路线规划相关:调用 plan_route。\\n- 与天气相关:调用 get_travel_forecast。\\n- 只要问题涉及“地点推荐、路线、天气”中的任意一项,应优先调用工具,不要凭空臆测。\\n- 若用户信息不足(如缺少出发地、目的地、天数、偏好),先用 1-2 个关键问题补齐后再调用工具。\\n- 当工具调用失败或结果为空时,说明原因并给出下一步可操作建议(例如改关键词、改地区、改出行方式)。\\n\\n【对话策略】\\n- 用户目标优先:先判断用户是“选目的地”“查路线”“看天气”还是“完整行程规划”。\\n- 组合调用:\\n - 完整规划时,默认按“地点 -> 路线 -> 天气”顺序调用。\\n - 若用户已指定地点,可直接“路线 + 天气”。\\n- 对用户提到“高铁/动车/地铁/公交/飞机”等,路线优先使用公共交通思路。\\n- 对用户提到“不走高速”等偏好,路线建议中明确体现。\\n\\n【输出风格】\\n- 默认中文,简洁专业,先结论后细节。\\n- 输出结构建议:\\n 1) 结论摘要(1-3 行)\\n 2) 推荐方案(地点/路线/天气)\\n 3) 注意事项(预算、天气、时段、备选)\\n- 涉及时间、温度、里程、时长时,尽量保留工具结果中的关键数据。\\n- 不输出工具内部实现细节,不暴露密钥、请求头等敏感信息。\\n\\n【安全与边界】\\n- 不编造未获取到的数据;不确定时明确说“不确定”并建议重新检索。\\n- 医疗、法律、政策等高风险问题仅给一般性建议并提示咨询专业机构。\\n- 用户要求与旅游无关的内容时,礼貌回应并引导回到旅游场景。\\n\\n现在开始:优先理解用户旅行意图,并按需调用工具给出可执行建议。\", \"ask_prompt_template\": \"你是一名资深中文旅行规划师。请基于以下用户信息,输出一份可执行、具体、务实的旅行方案。\\n\\n【用户信息】\\n- 出发地:{{出发地}}\\n- 目的地:{{目的地}}\\n- 游玩天数:{{游玩天数}} 天\\n- 出发日期:{{出发日期}}\\n- 预算档位:{{预算档位}}\\n- 出行人数:{{出行人数}}\\n- 出行偏好:{{出行偏好}}\\n- 交通偏好:{{交通偏好}}\\n- 住宿偏好:{{住宿偏好}}\\n- 必去点:{{必去点}}\\n- 避开项:{{避开项}}\\n\\n【输出要求】\\n1) 先给“总览结论”(3-5条)\\n2) 再给逐日行程(Day1 ~ DayN),每一天包含:\\n - 上午/下午/晚上安排\\n - 核心景点与停留时长建议\\n - 餐饮建议(当地特色)\\n3) 给交通方案对比(至少2种:时间、成本、优缺点)\\n4) 给住宿区域建议(2-3个片区,适合人群、预算)\\n5) 给预算拆分(交通/住宿/餐饮/门票)\\n6) 给注意事项(天气、穿衣、预约、避坑)\\n7) 如果信息不足,先明确列出“仍需补充的信息”,再给“可先执行的临时方案”。\\n\\n请使用清晰的 Markdown 格式输出,标题层级明确,内容尽量具体到可直接执行。\", \"ask_variables\": [ { \"key\": \"出发地\", \"label\": \"出发地\", \"question\": \"你从哪个城市出发?\", \"required\": true, \"default\": \"\" }, { \"key\": \"目的地\", \"label\": \"目的地\", \"question\": \"你要去哪个城市?\", \"required\": true, \"default\": \"\" }, { \"key\": \"游玩天数\", \"label\": \"游玩天数\", \"question\": \"你计划游玩几天?\", \"required\": true, \"default\": \"\" }, { \"key\": \"出发日期\", \"label\": \"出发日期\", \"question\": \"你的出发日期是?\", \"required\": true, \"default\": \"\" }, { \"key\": \"预算档位\", \"label\": \"预算档位\", \"question\": \"你的预算档位是?(低/中/高)\", \"required\": true, \"default\": \"\" }, { \"key\": \"出行人数\", \"label\": \"出行人数\", \"question\": \"有几人同行?\", \"required\": true, \"default\": \"\" }, { \"key\": \"出行偏好\", \"label\": \"出行偏好\", \"question\": \"你偏好哪种旅行风格?(可选自然风光/人文历史/亲子/美食/轻松等,可多选)\", \"required\": true, \"default\": \"\" }, { \"key\": \"交通偏好\", \"label\": \"交通偏好\", \"question\": \"你偏好哪种交通方式?(高铁/飞机/自驾/无偏好)\", \"required\": true, \"default\": \"\" }, { \"key\": \"住宿偏好\", \"label\": \"住宿偏好\", \"question\": \"你的住宿偏好是?(经济/舒适/高档)\", \"required\": true, \"default\": \"\" }, { \"key\": \"必去点\", \"label\": \"必去点\", \"question\": \"有哪些必去的景点或地点?\", \"required\": true, \"default\": \"\" }, { \"key\": \"避开项\", \"label\": \"避开项\", \"question\": \"有哪些不想去的地方或忌口/限制?\", \"required\": true, \"default\": \"\" } ], \"llm_provider_name\": \"深度探索\", \"llm_model_name\": \"deepseek-chat\", \"mcp_service_ids\": [ 5, 4, 3 ]}"20. 删除智能体 #
在本节中,我们将实现“删除智能体(Agent)”的接口。该接口允许通过指定 agent_id,从数据库中删除对应的智能体,并返回操作结果。
agent_repository.p
在 app/repositories/agent_repository.py 中,已经增加了删除智能体的核心方法:
def delete_agent(db: Session, row: models.Agent) -> None:
# 从数据库会话中删除指定的Agent对象
db.delete(row)
# 提交事务,保存删除操作
db.commit()该方法接收数据库会话和一个已获取的 Agent 实例对象,并将其从数据库中删除。
路由实现
在相应的 FastAPI 路由层(通常在 app/api/routers/agent.py 或类似文件中),提供了删除接口定义:
@router.delete("/{agent_id}")
def delete_agent(agent_id: int, db: Session = Depends(get_db)):
# 根据 agent_id 查询 agent
row = agent_repository.get_agent(db, agent_id)
# 如果查询不到,返回 404
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 调用仓库方法执行删除
agent_repository.delete_agent(db, row)
# 返回操作成功的响应
return {"ok": True}此接口会先验证 agent 记录是否存在,如果不存在,直接返回 404 错误;如果存在,调用仓库方法删除,并返回一个简单的 JSON 成功响应。
测试方式
可以通过 curl 命令行或 Postman 等工具调用该接口进行测试,例如:
curl --location --request DELETE "http://127.0.0.1:8000/api/agents/2" \
--header "Content-Type: application/json"如果删除成功,将返回:
{"ok": true}如果指定的 agent_id 不存在,返回如下错误:
{
"detail": "记录不存在"
}通过以上步骤,即可实现智能体的删除功能,保持了和其他 Agent 管理接口一致的设计风格和异常处理方式。
20.1. agent_repository.py #
app/repositories/agent_repository.py
# 导入Session类用于数据库会话管理
from sqlalchemy.orm import Session
from sqlalchemy import select
# 从app模块导入models和schemas
from app import models
from app import schemas
# 定义创建Agent的函数,参数为数据库会话和Agent创建数据,返回创建后的Agent模型
def create_agent(db: Session, data: schemas.AgentCreate) -> models.Agent:
# 创建Agent模型实例,将前端传入的数据赋值到各个字段
row = models.Agent(
avatar=data.avatar, # 头像
name=data.name.strip(), # 名称,去除首尾空格
description=data.description, # 描述
opening_message=(data.opening_message or "").strip() or None, # 开场白,去除空格,允许为None
system_prompt=data.system_prompt.strip(), # 系统提示,去除空格
llm_provider_name=data.llm_provider_name.strip(), # LLM提供商名称,去除空格
llm_model_name=data.llm_model_name.strip(), # LLM模型名称,去除空格
mcp_service_ids=data.mcp_service_ids, # 关联MCP服务ID列表
ask_prompt_template=(data.ask_prompt_template or "").strip() or None, # 提问提示词模板,去除空格,允许为None
ask_variables=data.ask_variables or [], # 提问变量,默认为空列表
)
# 将Agent实例添加到数据库会话
db.add(row)
# 提交事务,保存到数据库
db.commit()
# 刷新实例,获取数据库生成的最新字段(如自增ID)
db.refresh(row)
# 返回创建后的Agent对象
return row
# 定义list_agents函数,接收一个数据库会话db作为参数,返回Agent对象列表
def list_agents(db: Session) -> list[models.Agent]:
# 使用select语句查询Agent表,并按id倒序排序,获取所有Agent对象
return list(db.scalars(select(models.Agent).order_by(models.Agent.id.desc())).all())
# 定义一个用于更新智能体(Agent)的函数
def update_agent(db: Session, row: models.Agent, data: schemas.AgentUpdate) -> models.Agent:
# 如果数据中的avatar不为None,则更新头像字段
if data.avatar is not None:
row.avatar = data.avatar
# 如果数据中的name不为None,则去除空格后更新名称字段
if data.name is not None:
row.name = data.name.strip()
# 如果数据中的description不为None,则更新描述信息
if data.description is not None:
row.description = data.description
# 如果数据中的opening_message不为None,则去除空格后更新开场白,允许为None
if data.opening_message is not None:
row.opening_message = (data.opening_message or "").strip() or None
# 如果数据中的system_prompt不为None,则去除空格后更新系统提示词
if data.system_prompt is not None:
row.system_prompt = data.system_prompt.strip()
# 如果数据中的llm_provider_name不为None,则去除空格后更新LLM提供商名称
if data.llm_provider_name is not None:
row.llm_provider_name = data.llm_provider_name.strip()
# 如果数据中的llm_model_name不为None,则去除空格后更新LLM模型名称
if data.llm_model_name is not None:
row.llm_model_name = data.llm_model_name.strip()
# 如果数据中的mcp_service_ids不为None,则更新关联的MCP服务ID列表
if data.mcp_service_ids is not None:
row.mcp_service_ids = data.mcp_service_ids
# 如果数据中的ask_prompt_template不为None,则去除空格后更新提问提示词模板,允许为None
if data.ask_prompt_template is not None:
row.ask_prompt_template = (data.ask_prompt_template or "").strip() or None
# 如果数据中的ask_variables不为None,则更新提问变量配置
if data.ask_variables is not None:
row.ask_variables = data.ask_variables
# 提交事务保存更改
db.commit()
# 刷新实例获取数据库中的最新数据
db.refresh(row)
# 返回更新后的Agent对象
return row
# 定义get_agent函数,根据给定的agent_id,从数据库中获取对应的Agent对象
# 参数db为数据库会话,agent_id为智能体主键ID
# 若找到对应的Agent,返回Agent对象,否则返回None
def get_agent(db: Session, agent_id: int) -> models.Agent | None:
return db.get(models.Agent, agent_id)
# 定义删除智能体(Agent)的方法,接收数据库会话和待删除的Agent对象
+def delete_agent(db: Session, row: models.Agent) -> None:
# 从数据库会话中删除指定的Agent对象
+ db.delete(row)
# 提交事务,保存删除操作
+ db.commit() 20.2. agents.py #
app/routers/agents.py
# 从 fastapi 导入 APIRouter, Depends 以及 HTTPException 异常
from fastapi import APIRouter, Depends, HTTPException
# 从 sqlalchemy.orm 导入 Session 会话对象
from sqlalchemy.orm import Session
# 从 app 包分别导入 schemas 模块
from app import schemas
# 导入数据库依赖获取函数
from app.database import get_db
# 导入 agent_repository、llm_repository 和 mcp_repository,分别处理不同的数据操作
from app.repositories import agent_repository, llm_repository, mcp_repository
# 创建一个 APIRouter 实例,设置路由的前缀和标签
router = APIRouter(prefix="/api/agents", tags=["agents"])
# 校验相关引用有效性(大模型提供商、模型、MCP 服务)
def _validate_refs(db: Session, provider_name: str, model_name: str, mcp_service_ids: list[int]) -> None:
# 根据提供商名称查询 LLM 提供商数据
llm_row = llm_repository.get_llm_by_provider_name(db, provider_name.strip())
# 如果没有找到对应的 LLM 提供商则抛出 HTTP 400 异常
if not llm_row:
raise HTTPException(status_code=400, detail=f"大语言模型提供商不存在: {provider_name}")
# 获取该提供商下的所有模型名,并做字符串清洗
llm_models = [str(m or "").strip() for m in (llm_row.model_names or [])]
# 如果传入的模型名称不属于当前提供商的模型,则抛出 HTTP 400 异常
if model_name.strip() not in llm_models:
raise HTTPException(status_code=400, detail=f"模型名称不属于提供商 {provider_name}: {model_name}")
# 遍历所有 mcp_service_ids
for sid in mcp_service_ids:
# 校验每一个 mcp_service 是否存在,如果不存在则抛出 400 异常
if not mcp_repository.get_mcp_service(db, int(sid)):
raise HTTPException(status_code=400, detail=f"MCP 服务不存在: {sid}")
# 定义创建 agent 的接口,POST 请求,响应体为 AgentOut 模型
@router.post("", response_model=schemas.AgentOut)
def create_agent(payload: schemas.AgentCreate, db: Session = Depends(get_db)):
# 调用 _validate_refs 校验 Agent 创建时关联的 LLM 提供商、模型、MCP 服务是否有效
_validate_refs(db, payload.llm_provider_name, payload.llm_model_name, payload.mcp_service_ids)
# 校验通过后,调用 agent_repository 创建新的 Agent,并返回创建结果
return agent_repository.create_agent(db, payload)
# 定义 GET 接口用于获取所有 Agent 列表,响应为 AgentOut 对象列表
@router.get("", response_model=list[schemas.AgentOut])
# 定义 list_agents 视图函数,依赖注入数据库会话 db
def list_agents(db: Session = Depends(get_db)):
# 调用 agent_repository 的 list_agents 方法,获取所有 Agent 数据
return agent_repository.list_agents(db)
# 定义更新指定 agent_id 智能体信息的接口,返回更新后的 AgentOut 响应模型
@router.put("/{agent_id}", response_model=schemas.AgentOut)
# update_agent 视图函数,接收 agent_id、更新数据 payload,和数据库会话 db
def update_agent(agent_id: int, payload: schemas.AgentUpdate, db: Session = Depends(get_db)):
# 根据 agent_id 从数据库查询原有的 agent 记录
row = agent_repository.get_agent(db, agent_id)
# 如果未查到该记录,则抛出 404 错误,提示“记录不存在”
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 优先使用提交的数据,否则使用原有值,确定最终的 provider 名称
provider_name = payload.llm_provider_name if payload.llm_provider_name is not None else row.llm_provider_name
# 优先使用提交的数据,否则使用原有值,确定最终的 model 名称
model_name = payload.llm_model_name if payload.llm_model_name is not None else row.llm_model_name
# 优先使用提交的数据,否则使用原有值,确定 mcp_service_ids
mcp_service_ids = payload.mcp_service_ids if payload.mcp_service_ids is not None else row.mcp_service_ids
# 校验 provider/model/mcp_service 引用是否合法
_validate_refs(db, provider_name, model_name, mcp_service_ids)
# 调用仓库方法更新 agent 记录并返回更新结果
return agent_repository.update_agent(db, row, payload)
# 定义一个用于删除指定 agent_id 智能体信息的接口,HTTP 方法为 DELETE
+@router.delete("/{agent_id}")
# delete_agent 视图函数,接收 agent_id 与数据库会话 db
+def delete_agent(agent_id: int, db: Session = Depends(get_db)):
# 根据 agent_id 从数据库查询对应的 agent 记录
+ row = agent_repository.get_agent(db, agent_id)
# 如果没有查到对应记录,则抛出 404 异常,并提示“记录不存在”
+ if not row:
+ raise HTTPException(status_code=404, detail="记录不存在")
# 调用仓库方法删除该 agent 记录
+ agent_repository.delete_agent(db, row)
# 返回操作成功的响应
+ return {"ok": True}20.3 测试 #
curl --location --request DELETE "http://127.0.0.1:8000/api/agents/2" ^
--header "Content-Type: application/json"21. 对话 #
本节介绍与智能体对话相关的数据结构、API 接口以及主要的业务逻辑实现。
数据模型说明
对话相关的核心数据模型定义在 app/schemas.py 文件中,包括会话和消息的输入/输出模型:
AgentChatSessionOut:用于输出单个会话信息,包括id,agent_id,title,created_at,updated_at等字段。AgentChatSessionCreate:用于新建会话时的输入数据,只包含可选的title。AgentChatMessageOut:用于输出单个对话消息,包含id,session_id,role,content,meta,created_at等信息。AgentChatMessageCreate:用于新建对话消息的输入,包括role,content及可选meta。
仓库方法(Repository Methods)
所有与会话和会话消息相关的数据增删查改方法定义在 app/repositories/agent_chat_repository.py。主要方法包括:
list_sessions(db, agent_id):获取指定 agent 下的所有对话会话(按更新时间倒序排列)。create_session(db, agent_id, title):新建一个对话会话。get_session(db, session_id):按 ID 获取会话对象(无则返回 None)。list_messages(db, session_id):获取会话下的所有消息(升序排列)。create_message(db, session_id, role, content, meta):在指定会话下新增消息。delete_session(db, row):删除会话及该会话的所有消息。
API 路由接口
相关 API 接口定义在 app/routers/agent_chat.py 路由文件,具体包括:
获取某 Agent 下的所有会话
GET /api/agents/{agent_id}/chat-sessions 返回: list[AgentChatSessionOut]创建新的会话
POST /api/agents/{agent_id}/chat-sessions Body: AgentChatSessionCreate 返回: AgentChatSessionOut获取会话的所有消息
GET /api/agent-chat-sessions/{session_id}/messages 返回: list[AgentChatMessageOut]向会话发送消息
POST /api/agent-chat-sessions/{session_id}/messages Body: AgentChatMessageCreate 返回: AgentChatMessageOut删除会话及其所有消息
DELETE /api/agent-chat-sessions/{session_id} 返回: {"ok": True}
接口在实现时均会 check agent/session 是否存在,不存在时返回 404,并对 role 字段做值校验(仅支持 user、assistant、tool)。
常见对话 API 使用示例
1. 新建会话
curl --location --request POST "http://127.0.0.1:8000/api/agents/{agent_id}/chat-sessions" ^
--header "Content-Type: application/json" ^
--data-raw "{\"title\": \"新的聊天\"}"2. 增加对话消息
curl --location --request POST "http://127.0.0.1:8000/api/agent-chat-sessions/{session_id}/messages" ^
--header "Content-Type: application/json" ^
--data-raw "{\"role\": \"user\", \"content\": \"你好\"}"3. 获取某会话下所有消息
curl --location --request GET "http://127.0.0.1:8000/api/agent-chat-sessions/{session_id}/messages"4. 删除对话会话(及其历史)
curl --location --request DELETE "http://127.0.0.1:8000/api/agent-chat-sessions/{session_id}"设计要点说明
- 对话会话(Session)和消息(Message)为一对多关系
- 所有 API 调用均通过依赖注入获得数据库会话,保证事务安全
- 会话标题可自定义,若为空自动填充“新对话”
- 消息支持可扩展 meta 字段,结构灵活
- 删除会话时会级联删除该会话下的所有消息
21.1. agent_chat_repository.py #
app/repositories/agent_chat_repository.py
# 导入 SQLAlchemy 的 func 和 select 方法用于数据库查询
from sqlalchemy import func, select
# 导入 SQLAlchemy 的 Session 用于数据库会话管理
from sqlalchemy.orm import Session
# 从 app 包导入 models 模块,用于数据库模型操作
from app import models
# 定义一个函数,根据 agent_id 查询对应的所有 AgentChatSession 记录
def list_sessions(db: Session, agent_id: int) -> list[models.AgentChatSession]:
# 构造一个查询语句,筛选 agent_id 等于传入参数的聊天会话
stmt = select(models.AgentChatSession).where(models.AgentChatSession.agent_id == agent_id)
# 按照 updated_at 字段倒序、然后 id 字段倒序排序,确保最新的会话排在前面
stmt = stmt.order_by(models.AgentChatSession.updated_at.desc(), models.AgentChatSession.id.desc())
# 执行查询,将所有结果转换为列表并返回
return list(db.scalars(stmt).all())
# 定义创建聊天会话的方法,传入数据库会话 db、智能体ID agent_id、会话标题 title(可选,默认为 None)
def create_session(db: Session, agent_id: int, title: str | None = None) -> models.AgentChatSession:
# 创建 AgentChatSession 实例,指定 agent_id 和标题(默认为‘新对话’,去除空格后为空也用‘新对话’)
row = models.AgentChatSession(
agent_id=agent_id,
title=(title or "新对话").strip() or "新对话",
)
# 将新建的会话对象添加到数据库 session 中,准备写入数据库
db.add(row)
# 提交事务,将新添加的会话保存到数据库
db.commit()
# 刷新 row 对象,确保获取数据库自动生成的字段(如主键、时间等)的最新值
db.refresh(row)
# 返回新创建的会话对象
return row
# 定义一个函数,通过 session_id 获取指定的 AgentChatSession 记录
def get_session(db: Session, session_id: int) -> models.AgentChatSession | None:
# 调用 db.get 方法,根据主键 session_id 查询 AgentChatSession,如果不存在则返回 None
return db.get(models.AgentChatSession, session_id)
# 定义一个函数,根据会话ID查询对应的所有聊天消息,并按消息ID升序排序
def list_messages(db: Session, session_id: int) -> list[models.AgentChatMessage]:
# 构造查询语句,只筛选session_id为指定值的消息
stmt = select(models.AgentChatMessage).where(models.AgentChatMessage.session_id == session_id)
# 按id字段升序排列消息,确保按先后顺序返回
stmt = stmt.order_by(models.AgentChatMessage.id.asc())
# 执行查询并返回所有结果转换为列表
return list(db.scalars(stmt).all())
# 定义一个函数用于创建 AgentChatMessage 消息记录
def create_message(
db: Session, # 数据库会话对象
session_id: int, # 聊天会话ID
role: str, # 消息角色(如 user、assistant、tool)
content: str, # 消息正文内容
meta: dict | None = None, # 消息附加元数据,默认为None
) -> models.AgentChatMessage:
# 创建 AgentChatMessage 实例,去除角色两端空白,meta为None时使用空字典
row = models.AgentChatMessage(
session_id=session_id,
role=role.strip(),
content=content,
meta=meta or {},
)
# 将新消息加入数据库会话
db.add(row)
# 提交事务,将消息写入数据库
db.commit()
# 刷新对象,获取数据库自动生成的字段(如id、创建时间等)的最新值
db.refresh(row)
# 返回新建的消息对象
return row
# 定义删除指定会话及其所有消息的函数
def delete_session(db: Session, row: models.AgentChatSession) -> None:
# 调用 list_messages 获取该会话下的所有消息
msgs = list_messages(db, row.id)
# 遍历所有消息,逐条从数据库中删除
for item in msgs:
db.delete(item)
# 删除会话记录本身
db.delete(row)
# 提交事务,保存删除操作到数据库
db.commit()
21.2. agent_chat.py #
app/routers/agent_chat.py
# 从 fastapi 导入 APIRouter、Depends 和 HTTPException,用于路由定义和依赖注入及异常处理
from fastapi import APIRouter, Depends, HTTPException
# 从 sqlalchemy.orm 导入 Session,用于数据库会话管理
from sqlalchemy.orm import Session
# 从 app.repositories 导入 agent_chat_repository 和 agent_repository,用于数据持久层访问
from app.repositories import agent_chat_repository, agent_repository
# 从 app 导入 schemas,用于数据模型
from app import schemas
# 从 app.database 导入 get_db,用于获取数据库会话
from app.database import get_db
# 创建 APIRouter 实例,并设置 tags 标签为 "agent-chat"
router = APIRouter(tags=["agent-chat"])
# 声明 GET 接口,路径为 /api/agents/{agent_id}/chat-sessions,返回值为 AgentChatSessionOut 数据模型列表
@router.get("/api/agents/{agent_id}/chat-sessions", response_model=list[schemas.AgentChatSessionOut])
# 定义 list_sessions 视图函数,接收 agent_id 作为路径参数,db 为依赖注入的 Session 对象
def list_sessions(agent_id: int, db: Session = Depends(get_db)):
# 先检查数据库里是否存在指定 agent,如果不存在则抛出 404 错误
if not agent_repository.get_agent(db, agent_id):
raise HTTPException(status_code=404, detail="智能体不存在")
# 如果智能体存在,查询所有该 agent 下的对话会话,并将结果返回
return agent_chat_repository.list_sessions(db, agent_id)
# 声明 POST 路由,用于创建新的智能体聊天会话,响应模型为 AgentChatSessionOut
@router.post("/api/agents/{agent_id}/chat-sessions", response_model=schemas.AgentChatSessionOut)
# 定义 create_session 函数,接收 agent_id、会话创建载荷 payload、依赖注入的数据库会话 db
def create_session(agent_id: int, payload: schemas.AgentChatSessionCreate, db: Session = Depends(get_db)):
# 先判断数据库中是否存在指定 agent,如果不存在则抛出 404 错误
if not agent_repository.get_agent(db, agent_id):
raise HTTPException(status_code=404, detail="智能体不存在")
# 如果智能体存在,则调用 agent_chat_repository 创建会话,传入数据库会话、智能体ID和会话标题,返回结果
return agent_chat_repository.create_session(db, agent_id, payload.title)
# 声明一个 GET 路由,根据 session_id 获取会话消息,返回值为 AgentChatMessageOut 列表
@router.get("/api/agent-chat-sessions/{session_id}/messages", response_model=list[schemas.AgentChatMessageOut])
# 定义 list_messages 视图函数,接收 session_id 和数据库会话 db(依赖注入方式获取)
def list_messages(session_id: int, db: Session = Depends(get_db)):
# 调用仓库方法获取指定会话 session 的数据库对象
row = agent_chat_repository.get_session(db, session_id)
# 如果没有找到对应会话,则抛出 404 异常,并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 如果会话存在,则调用仓库方法获取所有消息并返回
return agent_chat_repository.list_messages(db, session_id)
# 定义一个 POST 路由,路径为 /api/agent-chat-sessions/{session_id}/messages,响应体为 AgentChatMessageOut 模型
@router.post("/api/agent-chat-sessions/{session_id}/messages", response_model=schemas.AgentChatMessageOut)
# 定义 create_message 视图函数,参数为 session_id、payload(通过 Pydantic 校验的消息数据),db 为依赖注入的数据库会话
def create_message(session_id: int, payload: schemas.AgentChatMessageCreate, db: Session = Depends(get_db)):
# 调用 agent_chat_repository.get_session 检查指定 session_id 的会话是否存在
row = agent_chat_repository.get_session(db, session_id)
# 如果会话不存在,抛出 404 异常并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 获取请求的消息角色参数,去除首尾空格并转为小写字符串
role = str(payload.role or "").strip().lower()
# 判断 role 是否在支持的角色集合中(user/assistant/tool),否则抛出 400 异常
if role not in {"user", "assistant", "tool"}:
raise HTTPException(status_code=400, detail="role 仅支持 user / assistant / tool")
# 调用 agent_chat_repository.create_message 创建一条消息并返回
return agent_chat_repository.create_message(db, session_id, role, payload.content, payload.meta)
# 声明 delete 路由,路径为 /api/agent-chat-sessions/{session_id}
@router.delete("/api/agent-chat-sessions/{session_id}")
# 定义 delete_session 视图函数,接收 session_id 路径参数和数据库会话 db(依赖注入)
def delete_session(session_id: int, db: Session = Depends(get_db)):
# 根据 session_id 获取会话对象
row = agent_chat_repository.get_session(db, session_id)
# 如果会话不存在,抛出 404 异常并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 会话存在,调用仓库方法删除该会话
agent_chat_repository.delete_session(db, row)
# 返回删除成功的响应
return {"ok": True} 21.3. main.py #
app/main.py
# 导入FastAPI库
from fastapi import FastAPI
# 导入logging库以便后续日志记录
import logging
# 导入asynccontextmanager用于异步上下文管理器
from contextlib import asynccontextmanager
# 导入FastAPI的静态文件中间件,用于静态文件服务
from fastapi.staticfiles import StaticFiles
# 导入FastAPI的CORS中间件,用于跨域资源共享
from fastapi.middleware.cors import CORSMiddleware
# 导入项目配置settings对象
from app.config import settings
# 从app.database模块导入Base和engine,用于数据库相关操作
from app.database import Base, engine
# 导入app.models模块下的所有内容(类、函数等)
from app.models import *
# 导入MCP服务路由 和 大模型路由
+from app.routers import mcp_services,llm_models,uploads,agents,agent_chat
# 配置日志的基本设置,日志级别为INFO
logging.basicConfig(level=logging.INFO)
# 定义一个异步上下文管理器,用于FastAPI生命周期
@asynccontextmanager
async def lifespan(app: FastAPI):
# 创建所有数据库表结构(如果未存在则自动创建)
Base.metadata.create_all(bind=engine)
# 通过yield挂起,等待应用关闭时进行清理
yield
# 创建FastAPI应用实例,设置API标题和版本号,并指定生命周期管理器
app = FastAPI(title="智能体服务", version="0.1.0", lifespan=lifespan)
# 解析配置中的CORS来源列表,去除空白项和空字符串
origins = [o.strip() for o in settings.cors_origins.split(",") if o.strip()]
# 向FastAPI应用添加CORS中间件
app.add_middleware(
CORSMiddleware,
# 允许访问的来源列表,如果为空则允许所有来源("*")
allow_origins=origins or ["*"],
# 允许携带cookie等凭证
allow_credentials=True,
# 允许所有HTTP方法
allow_methods=["*"],
# 允许所有HTTP头
allow_headers=["*"],
)
# 包含MCP服务路由
app.include_router(mcp_services.router)
app.include_router(llm_models.router)
app.include_router(uploads.router)
app.include_router(agents.router)
+app.include_router(agent_chat.router)
upload_root = settings.upload_path()
upload_root.mkdir(parents=True, exist_ok=True)
app.mount("/uploads", StaticFiles(directory=str(upload_root)), name="uploads")
# 定义一个GET类型的/health路由用于健康检查
@app.get("/health")
def health():
# 返回服务状态为ok的JSON响应
return {"status": "ok"}21.4. agents.py #
app/routers/agents.py
# 从 fastapi 导入 APIRouter, Depends 以及 HTTPException 异常
from fastapi import APIRouter, Depends, HTTPException
# 从 sqlalchemy.orm 导入 Session 会话对象
from sqlalchemy.orm import Session
# 从 app 包分别导入 schemas 模块
from app import schemas
# 导入数据库依赖获取函数
from app.database import get_db
# 导入 agent_repository、llm_repository 和 mcp_repository,分别处理不同的数据操作
+from app.repositories import agent_repository, llm_repository, mcp_repository,agent_chat_repository
# 创建一个 APIRouter 实例,设置路由的前缀和标签
router = APIRouter(prefix="/api/agents", tags=["agents"])
# 校验相关引用有效性(大模型提供商、模型、MCP 服务)
def _validate_refs(db: Session, provider_name: str, model_name: str, mcp_service_ids: list[int]) -> None:
# 根据提供商名称查询 LLM 提供商数据
llm_row = llm_repository.get_llm_by_provider_name(db, provider_name.strip())
# 如果没有找到对应的 LLM 提供商则抛出 HTTP 400 异常
if not llm_row:
raise HTTPException(status_code=400, detail=f"大语言模型提供商不存在: {provider_name}")
# 获取该提供商下的所有模型名,并做字符串清洗
llm_models = [str(m or "").strip() for m in (llm_row.model_names or [])]
# 如果传入的模型名称不属于当前提供商的模型,则抛出 HTTP 400 异常
if model_name.strip() not in llm_models:
raise HTTPException(status_code=400, detail=f"模型名称不属于提供商 {provider_name}: {model_name}")
# 遍历所有 mcp_service_ids
for sid in mcp_service_ids:
# 校验每一个 mcp_service 是否存在,如果不存在则抛出 400 异常
if not mcp_repository.get_mcp_service(db, int(sid)):
raise HTTPException(status_code=400, detail=f"MCP 服务不存在: {sid}")
# 定义创建 agent 的接口,POST 请求,响应体为 AgentOut 模型
@router.post("", response_model=schemas.AgentOut)
def create_agent(payload: schemas.AgentCreate, db: Session = Depends(get_db)):
# 调用 _validate_refs 校验 Agent 创建时关联的 LLM 提供商、模型、MCP 服务是否有效
_validate_refs(db, payload.llm_provider_name, payload.llm_model_name, payload.mcp_service_ids)
# 校验通过后,调用 agent_repository 创建新的 Agent,并返回创建结果
return agent_repository.create_agent(db, payload)
# 定义 GET 接口用于获取所有 Agent 列表,响应为 AgentOut 对象列表
@router.get("", response_model=list[schemas.AgentOut])
# 定义 list_agents 视图函数,依赖注入数据库会话 db
def list_agents(db: Session = Depends(get_db)):
# 调用 agent_repository 的 list_agents 方法,获取所有 Agent 数据
return agent_repository.list_agents(db)
# 定义更新指定 agent_id 智能体信息的接口,返回更新后的 AgentOut 响应模型
@router.put("/{agent_id}", response_model=schemas.AgentOut)
# update_agent 视图函数,接收 agent_id、更新数据 payload,和数据库会话 db
def update_agent(agent_id: int, payload: schemas.AgentUpdate, db: Session = Depends(get_db)):
# 根据 agent_id 从数据库查询原有的 agent 记录
row = agent_repository.get_agent(db, agent_id)
# 如果未查到该记录,则抛出 404 错误,提示“记录不存在”
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 优先使用提交的数据,否则使用原有值,确定最终的 provider 名称
provider_name = payload.llm_provider_name if payload.llm_provider_name is not None else row.llm_provider_name
# 优先使用提交的数据,否则使用原有值,确定最终的 model 名称
model_name = payload.llm_model_name if payload.llm_model_name is not None else row.llm_model_name
# 优先使用提交的数据,否则使用原有值,确定 mcp_service_ids
mcp_service_ids = payload.mcp_service_ids if payload.mcp_service_ids is not None else row.mcp_service_ids
# 校验 provider/model/mcp_service 引用是否合法
_validate_refs(db, provider_name, model_name, mcp_service_ids)
# 调用仓库方法更新 agent 记录并返回更新结果
return agent_repository.update_agent(db, row, payload)
# 定义一个用于删除指定 agent_id 智能体信息的接口,HTTP 方法为 DELETE
@router.delete("/{agent_id}")
# delete_agent 视图函数,接收 agent_id 与数据库会话 db
def delete_agent(agent_id: int, db: Session = Depends(get_db)):
# 根据 agent_id 从数据库查询对应的 agent 记录
row = agent_repository.get_agent(db, agent_id)
# 如果没有查到对应记录,则抛出 404 异常,并提示“记录不存在”
if not row:
raise HTTPException(status_code=404, detail="记录不存在")
# 调用仓库方法删除该 agent 记录
agent_repository.delete_agent(db, row)
# 返回操作成功的响应
return {"ok": True}
# 定义一个 GET 接口,根据 agent_id 获取指定智能体信息,返回 AgentOut 响应模型
+@router.get("/{agent_id}", response_model=schemas.AgentOut)
# get_agent 视图函数,接收 agent_id 和数据库会话 db 作为参数
+def get_agent(agent_id: int, db: Session = Depends(get_db)):
# 调用 agent_repository 的 get_agent 方法,根据 agent_id 查询数据库中的智能体记录
+ row = agent_repository.get_agent(db, agent_id)
# 如果未查到对应记录,则抛出 404 异常,并返回“记录不存在”信息
+ if not row:
+ raise HTTPException(status_code=404, detail="记录不存在")
# 查询成功则返回该智能体的数据库记录
+ return row
21.5. schemas.py #
app/schemas.py
# 导入枚举类型
from enum import Enum
# 导入Any类型用于类型注解
from typing import Any
# 从pydantic导入BaseModel、Field和field_validator用于数据验证
from pydantic import BaseModel, Field, field_validator
# 定义MCP协议类型的枚举类
class McpProtocol(str, Enum):
# MCP 协议类型注释
"""MCP 协议类型"""
# stdio协议
stdio = "stdio"
# streamable-http协议
streamable_http = "streamable-http"
# sse协议
sse = "sse"
# 定义MCP服务基础模型
class McpServiceBase(BaseModel):
# MCP 服务基础模型注释
"""MCP 服务基础模型"""
# 服务名称,字符串类型,长度1-255
name: str = Field(..., min_length=1, max_length=255)
# 服务描述,可选字段,字符串或None
description: str | None = None
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,确保其为字典类型
@field_validator("config", mode="before")
@classmethod
def config_not_empty(cls, v: Any) -> Any:
# 验证 config 是否为 JSON 对象,如果不是字典则抛出异常
"""验证 config 是否为 JSON 对象"""
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
return v
# 定义MCP服务创建模型,继承自McpServiceBase
class McpServiceCreate(McpServiceBase):
# 不增加额外内容,直接继承
pass
# 定义MCP服务输出模型
class McpServiceOut(BaseModel):
# ID字段,整型
id: int
# 服务名称
name: str
# 服务描述,可选字段
description: str | None
# 协议类型,字符串
protocol: str
# 配置信息,字典类型
config: dict[str, Any]
# 设置模型配置,允许从ORM对象属性读取数据
model_config = {"from_attributes": True}
# 定义McpServiceUpdate模型,用于更新MCP服务,支持部分字段可选更新
class McpServiceUpdate(BaseModel):
# 服务名称,允许为None,最小长度1,最大长度255
name: str | None = Field(None, min_length=1, max_length=255)
# 服务描述字段,允许为None
description: str | None = None
# 协议类型,使用McpProtocol枚举,允许为None
protocol: McpProtocol | None = None
# 配置信息,允许为None,类型为字典
config: dict[str, Any] | None = None
# 定义McpTestRequest模型,继承自BaseModel
class McpTestRequest(BaseModel):
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,在赋值前进行校验
@field_validator("config", mode="before")
@classmethod
def config_is_object(cls, v: Any) -> Any:
# 如果config不是字典类型,则抛出异常
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
# 返回config原值
return v
# 定义McpTestResult模型,继承自BaseModel
class McpTestResult(BaseModel):
# 测试是否成功的标志
ok: bool
# 返回的信息或说明
message: str
# 工具列表,默认为空列表
tools: list[dict[str, Any]] = Field(default_factory=list)
# 定义LlmModelBase基类,用于存储大模型相关信息
class LlmModelBase(BaseModel):
# 提供商名称,必填,最小长度1,最大长度255
provider_name: str = Field(..., min_length=1, max_length=255)
# 提供商图标,可选
provider_icon: str | None = None
# API基础地址,必填,最小长度1,最大长度1024
api_base_url: str = Field(..., min_length=1, max_length=1024)
# API密钥,必填,最小长度1,最大长度1024
api_key: str = Field(..., min_length=1, max_length=1024)
# API密钥申请地址,可选,最大长度1024
api_key_url: str | None = Field(None, max_length=1024)
# 模型名称列表,默认为空列表
model_names: list[str] = Field(default_factory=list)
# 对model_names字段进行校验和归一化处理
@field_validator("model_names")
@classmethod
def normalize_model_names(cls, v: list[str]) -> list[str]:
# 定义输出列表
out: list[str] = []
# 用于存储已经出现过的模型名称(小写形式)以去重
seen: set[str] = set()
# 遍历输入值,如果为空则用空列表
for item in v or []:
# 将每个名称转为字符串并去除首尾空格
name = str(item or "").strip()
# 如果名称为空,则跳过
if not name:
continue
# 转为小写做去重key
key = name.lower()
# 如果已经见过该名称,则跳过
if key in seen:
continue
# 添加到已见集合和输出列表
seen.add(key)
out.append(name)
# 返回归一化后的模型名称列表
return out
# 定义LlmModelCreate模型,继承自LlmModelBase,没有扩展字段
class LlmModelCreate(LlmModelBase):
pass
# 定义 LlmModelOut 类,继承自 BaseModel,用于大模型的返回数据结构
class LlmModelOut(BaseModel):
# 唯一ID,类型为整数
id: int
# 提供商名称,字符串类型
provider_name: str
# 提供商图标,可选,字符串或None
provider_icon: str | None
# API 基础地址,字符串类型
api_base_url: str
# API 密钥,字符串类型
api_key: str
# API 密钥申请地址,可选,字符串或None
api_key_url: str | None
# 模型名称列表,字符串列表类型
model_names: list[str]
# Pydantic 配置,启用从属性赋值(用于ORM模式)
model_config = {"from_attributes": True}
# 定义用于大模型服务测试请求体的Pydantic模型
class LlmModelTestRequest(BaseModel):
# API基础地址,必须为非空字符串,最大长度1024
api_base_url: str = Field(..., min_length=1, max_length=1024)
# API密钥,必须为非空字符串,最大长度1024
api_key: str = Field(..., min_length=1, max_length=1024)
# 定义 LlmModelTestResult 类,继承自 BaseModel,用于返回大模型服务检测的结果
class LlmModelTestResult(BaseModel):
# 检测是否通过,布尔类型
ok: bool
# 检测的信息提示,字符串类型
message: str
# 检测到的模型名称列表,默认为空列表
models: list[str] = Field(default_factory=list)
# 定义用于更新大模型信息的Pydantic模型
class LlmModelUpdate(BaseModel):
# 提供商名称,可选字段,限定最小长度1,最大长度255
provider_name: str | None = Field(None, min_length=1, max_length=255)
# 提供商图标,可选字段
provider_icon: str | None = None
# API 基础地址,可选字段,限定最小长度1,最大长度1024
api_base_url: str | None = Field(None, min_length=1, max_length=1024)
# API 密钥,可选字段,限定最小长度1,最大长度1024
api_key: str | None = Field(None, min_length=1, max_length=1024)
# API 密钥申请地址,可选字段,最大长度1024
api_key_url: str | None = Field(None, max_length=1024)
# 模型名称列表,可选字段
model_names: list[str] | None = None
# 对model_names字段做校验和归一化处理
@field_validator("model_names")
@classmethod
def normalize_model_names_optional(cls, v: list[str] | None) -> list[str] | None:
# 如果为None,直接返回None
if v is None:
return None
# 用于存储归一化后的模型名称
out: list[str] = []
# 用于去重
seen: set[str] = set()
# 遍历输入列表
for item in v:
# 转为字符串并去除首尾空白
name = str(item or "").strip()
# 如果名称为空则跳过
if not name:
continue
# 使用小写做去重key
key = name.lower()
# 如果已经见过则跳过
if key in seen:
continue
# 添加进已见集合
seen.add(key)
# 添加进输出列表
out.append(name)
# 返回归一化去重后的列表
return out
# 定义函数,用于规范化 ask_variables 字段,输入为可选的字典列表,返回规范化后的字典列表
def _normalize_ask_variables(v):
# 用于保存处理后的变量字典
out = []
# 用于记录已出现过的变量key,实现去重
seen = set()
# 遍历输入列表(如果为None则转为空列表)
for item in v or []:
# 如果当前元素不是字典类型,则跳过
if not isinstance(item, dict):
continue
# 从字典中获取key字段,去除首尾空白转字符串
key = str(item.get("key") or "").strip()
# 如果key为空,则跳过本轮
if not key:
continue
# 如果key已出现,则跳过实现去重
if key in seen:
continue
# 将当前key加入去重集合
seen.add(key)
# 获取question字段、去空白
question = str(item.get("question") or "").strip()
# 获取label字段、去空白
label = str(item.get("label") or "").strip()
# 获取default字段、去空白
default_value = str(item.get("default") or "").strip()
# 获取required字段,默认为True
required = bool(item.get("required", True))
# 将变量信息归一化后放入输出列表
out.append(
{
"key": key,
"label": label,
"question": question or f"请提供 {label or key}",
"required": required,
"default": default_value,
}
)
# 返回归一化和去重后的变量列表
return out
# 定义 AgentBase 基础模型,继承自 Pydantic 的 BaseModel
class AgentBase(BaseModel):
# 头像字段,可为空
avatar: str | None = None
# 智能体名称,必填,长度1~255
name: str = Field(..., min_length=1, max_length=255)
# 描述信息,可为空
description: str | None = None
# 开场白内容,可为空
opening_message: str | None = None
# 智能体系统提示,必填,最小长度1
system_prompt: str = Field(..., min_length=1)
# LLM 提供商名称,必填,长度1~255
llm_provider_name: str = Field(..., min_length=1, max_length=255)
# LLM 模型名称,必填,长度1~255
llm_model_name: str = Field(..., min_length=1, max_length=255)
# 关联的 MCP 服务ID列表,默认为空列表
mcp_service_ids: list[int] = Field(default_factory=list)
# 询问提示词模板,可为空
ask_prompt_template: str | None = None
# 询问变量列表,默认为空列表
ask_variables: list[dict[str, Any]] = Field(default_factory=list)
# 对 mcp_service_ids 字段做归一化校验
@field_validator("mcp_service_ids")
@classmethod
def normalize_mcp_service_ids(cls, v):
# 存储去重后的有效 id
out = []
# 已经出现过的 id 集合
seen = set()
# 遍历 id 列表(防止为 None)
for item in v or []:
# 转成整数类型
num = int(item)
# 跳过小于等于0的无效 id
if num <= 0:
continue
# 跳过重复 id
if num in seen:
continue
# 添加到去重集合
seen.add(num)
# 添加到输出集合
out.append(num)
# 返回整理后的 id 列表
return out
# 对 ask_variables 字段做归一化校验
@field_validator("ask_variables")
@classmethod
def normalize_ask_variables(cls, v):
# 使用 _normalize_ask_variables 函数处理
return _normalize_ask_variables(v)
# 定义 AgentCreate 创建模型,继承自 AgentBase,无额外字段
class AgentCreate(AgentBase):
pass
# 定义AgentOut响应模型,继承自BaseModel
class AgentOut(BaseModel):
# 主键ID
id: int
# 头像,允许为None
avatar: str | None
# 智能体名称
name: str
# 智能体描述,允许为None
description: str | None
# 开场白,允许为None
opening_message: str | None
# 智能体系统提示,不可为None
system_prompt: str
# LLM提供商名称
llm_provider_name: str
# LLM模型名称
llm_model_name: str
# 关联的MCP服务ID列表
mcp_service_ids: list[int]
# 询问提示词模板,允许为None
ask_prompt_template: str | None
# 询问变量列表,默认为空列表
ask_variables: list[dict[str, Any]] = Field(default_factory=list)
# Pydantic配置:允许根据对象属性创建模型(ORM模式)
model_config = {"from_attributes": True}
# 定义AgentUpdate模型,用于部分更新Agent,继承自Pydantic的BaseModel
class AgentUpdate(BaseModel):
# 头像字段,允许为None
avatar: str | None = None
# 名称字段,允许为None,若不为None则要求长度1~255
name: str | None = Field(None, min_length=1, max_length=255)
# 描述字段,允许为None
description: str | None = None
# 开场白字段,允许为None
opening_message: str | None = None
# 系统提示词字段,允许为None,若不为None则要求最小长度1
system_prompt: str | None = Field(None, min_length=1)
# LLM提供商名称,允许为None,若不为None长度1~255
llm_provider_name: str | None = Field(None, min_length=1, max_length=255)
# LLM模型名称,允许为None,若不为None长度1~255
llm_model_name: str | None = Field(None, min_length=1, max_length=255)
# 关联的MCP服务ID列表,允许为None
mcp_service_ids: list[int] | None = None
# 询问提示词模板,允许为None
ask_prompt_template: str | None = None
# 询问变量列表,允许为None
ask_variables: list[dict[str, Any]] | None = None
# 对mcp_service_ids字段进行校验和去重,允许为None
@field_validator("mcp_service_ids")
@classmethod
def normalize_mcp_service_ids_optional(cls, v: list[int] | None) -> list[int] | None:
# 如果为None,直接返回None
if v is None:
return None
# 定义用于存放合法、去重后的ID的列表
out: list[int] = []
# 定义用于去重的集合
seen: set[int] = set()
# 遍历传入的每个ID
for item in v:
# 转为整数
num = int(item)
# 跳过小于等于0的无效ID
if num <= 0:
continue
# 跳过重复ID
if num in seen:
continue
# 加入去重集合
seen.add(num)
# 加入输出列表
out.append(num)
# 返回整理后的ID列表
return out
# 对ask_variables字段进行归一化与合法性检查,允许为None
@field_validator("ask_variables")
@classmethod
def normalize_ask_variables_optional(cls, v: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
# 如果为None直接返回None
if v is None:
return None
# 使用_normalize_ask_variables工具函数归一化处理
return _normalize_ask_variables(v)
# 定义用于输出AgentChatSession(智能体聊天会话)的Pydantic模型
+class AgentChatSessionOut(BaseModel):
# 会话记录的主键ID
+ id: int
# 关联的Agent智能体ID
+ agent_id: int
# 会话标题
+ title: str
# 会话创建时间
+ created_at: Any
# 会话最近更新时间
+ updated_at: Any
# 配置项:允许Pydantic模型支持数据库ORM对象的属性映射
+ model_config = {"from_attributes": True}
# 定义AgentChatSessionCreate用于创建聊天会话的Pydantic模型
+class AgentChatSessionCreate(BaseModel):
# 可选的会话标题字段,最大长度255个字符,默认为None
+ title: str | None = Field(None, max_length=255)
# 定义用于输出Agent聊天消息的Pydantic模型
+class AgentChatMessageOut(BaseModel):
# 消息主键ID
+ id: int
# 关联的聊天会话session_id
+ session_id: int
# 消息所属角色(如user/assistant/system等)
+ role: str
# 消息正文内容
+ content: str
# 扩展元数据,默认为空字典
+ meta: dict[str, Any] = Field(default_factory=dict)
# 消息创建时间
+ created_at: Any
# 配置项:允许支持通过ORM对象转为模型
+ model_config = {"from_attributes": True}
# 定义用于创建Agent聊天消息的Pydantic模型
+class AgentChatMessageCreate(BaseModel):
# 消息角色字段(如 user/assistant/tool),必填,最小长度1,最大32
+ role: str = Field(..., min_length=1, max_length=32)
# 消息内容字段,必填,最小长度1
+ content: str = Field(..., min_length=1)
# 扩展元数据,默认为空字典
+ meta: dict[str, Any] = Field(default_factory=dict) 22. 开启对话 #
本节介绍系统中的对话(agent_chat)模块的接口与数据模型设计。
对话模块的核心作用是支持用户与智能体的多轮交互,并结合会话记录、消息存储等能力,实现高效、灵活的对话管理。主要涉及以下内容:
数据模型
- AgentChatSessionOut:用于输出会话信息(会话ID、所属Agent、标题、创建/更新时间等)。
- AgentChatSessionCreate:用于创建会话时的输入(如会话标题,支持可选)。
- AgentChatMessageOut:用于输出一条消息的详细内容,包括消息ID、会话ID、角色(user/assistant/tool/system)、正文、元数据及时间等。
- AgentChatMessageCreate:用于新建一条消息时的输入信息,角色、内容必填,支持扩展元数据。
- AgentChatSendRequest:用于流式对话消息推送场景(如SSE/MCP接入),除消息内容外支持是否启用“工具调用”能力。
路由接口
- 查询指定 Agent 下的所有会话(GET
/api/agents/{agent_id}/chat-sessions)。 - 新建 Agent 的会话(POST
/api/agents/{agent_id}/chat-sessions)。 - 获取会话下全部消息记录(GET
/api/agent-chat-sessions/{session_id}/messages)。 - 向某个会话发送一条新消息(POST
/api/agent-chat-sessions/{session_id}/messages)。 - 删除指定会话(DELETE
/api/agent-chat-sessions/{session_id})。 - 支持流式响应的消息(如 SSE,POST
/api/agent-chat-sessions/{session_id}/messages/stream)。
每个接口都具备详尽的输入/输出参数说明、异常统一处理机制,以及与数据仓库的隔离实现,可满足“分布式对话存储、多类型角色、多会话管理”等现代智能体框架的需求。
场景举例
- 用户可为某个Agent(如旅游规划智能体)新建一轮对话(会话),并在该会话内反复提问/获取建议,每次交互都会被系统结构化存储,方便后续追溯、继续多轮问答或做上下文增强。
- 消息角色区分支持丰富的AI应用场景,如普通用户消息、智能回复、调用外部工具/检索服务后的插入响应等。
扩展性设计
- 所有出入参采用Pydantic数据模型,类型安全且易于前后端协作。
- 会话、消息结构均支持灵活扩展,如增加元数据、新角色(如 tool)、消息摘要等。
- 支持流式对话(SSE/MCP/工具链),便于集成多种大模型和外部知识工具。
通过上述设计,对话模块能够作为智能体框架的核心交互层,高效承载多轮、异步、结构化的人机对话任务。
22.1. agent_chat.py #
app/routers/agent_chat.py
# 从 fastapi 导入 APIRouter、Depends 和 HTTPException,用于路由定义和依赖注入及异常处理
from fastapi import APIRouter, Depends, HTTPException
# 从 sqlalchemy.orm 导入 Session,用于数据库会话管理
from sqlalchemy.orm import Session
# 导入 json,用于 JSON 序列化
+import json
# 从 fastapi.responses 导入 StreamingResponse,用于流式响应
+from fastapi.responses import StreamingResponse
# 从 app.repositories 导入 agent_chat_repository 和 agent_repository,用于数据持久层访问
+from app.repositories import agent_chat_repository, agent_repository,llm_repository
# 从 app 导入 schemas,用于数据模型
from app import schemas
# 从 app.database 导入 get_db,用于获取数据库会话
from app.database import get_db
# 创建 APIRouter 实例,并设置 tags 标签为 "agent-chat"
router = APIRouter(tags=["agent-chat"])
# 声明 GET 接口,路径为 /api/agents/{agent_id}/chat-sessions,返回值为 AgentChatSessionOut 数据模型列表
@router.get("/api/agents/{agent_id}/chat-sessions", response_model=list[schemas.AgentChatSessionOut])
# 定义 list_sessions 视图函数,接收 agent_id 作为路径参数,db 为依赖注入的 Session 对象
def list_sessions(agent_id: int, db: Session = Depends(get_db)):
# 先检查数据库里是否存在指定 agent,如果不存在则抛出 404 错误
if not agent_repository.get_agent(db, agent_id):
raise HTTPException(status_code=404, detail="智能体不存在")
# 如果智能体存在,查询所有该 agent 下的对话会话,并将结果返回
return agent_chat_repository.list_sessions(db, agent_id)
# 声明 POST 路由,用于创建新的智能体聊天会话,响应模型为 AgentChatSessionOut
@router.post("/api/agents/{agent_id}/chat-sessions", response_model=schemas.AgentChatSessionOut)
# 定义 create_session 函数,接收 agent_id、会话创建载荷 payload、依赖注入的数据库会话 db
def create_session(agent_id: int, payload: schemas.AgentChatSessionCreate, db: Session = Depends(get_db)):
# 先判断数据库中是否存在指定 agent,如果不存在则抛出 404 错误
if not agent_repository.get_agent(db, agent_id):
raise HTTPException(status_code=404, detail="智能体不存在")
# 如果智能体存在,则调用 agent_chat_repository 创建会话,传入数据库会话、智能体ID和会话标题,返回结果
return agent_chat_repository.create_session(db, agent_id, payload.title)
# 声明一个 GET 路由,根据 session_id 获取会话消息,返回值为 AgentChatMessageOut 列表
@router.get("/api/agent-chat-sessions/{session_id}/messages", response_model=list[schemas.AgentChatMessageOut])
# 定义 list_messages 视图函数,接收 session_id 和数据库会话 db(依赖注入方式获取)
def list_messages(session_id: int, db: Session = Depends(get_db)):
# 调用仓库方法获取指定会话 session 的数据库对象
row = agent_chat_repository.get_session(db, session_id)
# 如果没有找到对应会话,则抛出 404 异常,并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 如果会话存在,则调用仓库方法获取所有消息并返回
return agent_chat_repository.list_messages(db, session_id)
# 定义一个 POST 路由,路径为 /api/agent-chat-sessions/{session_id}/messages,响应体为 AgentChatMessageOut 模型
@router.post("/api/agent-chat-sessions/{session_id}/messages", response_model=schemas.AgentChatMessageOut)
# 定义 create_message 视图函数,参数为 session_id、payload(通过 Pydantic 校验的消息数据),db 为依赖注入的数据库会话
def create_message(session_id: int, payload: schemas.AgentChatMessageCreate, db: Session = Depends(get_db)):
# 调用 agent_chat_repository.get_session 检查指定 session_id 的会话是否存在
row = agent_chat_repository.get_session(db, session_id)
# 如果会话不存在,抛出 404 异常并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 获取请求的消息角色参数,去除首尾空格并转为小写字符串
role = str(payload.role or "").strip().lower()
# 判断 role 是否在支持的角色集合中(user/assistant/tool),否则抛出 400 异常
if role not in {"user", "assistant", "tool"}:
raise HTTPException(status_code=400, detail="role 仅支持 user / assistant / tool")
# 调用 agent_chat_repository.create_message 创建一条消息并返回
return agent_chat_repository.create_message(db, session_id, role, payload.content, payload.meta)
# 声明 delete 路由,路径为 /api/agent-chat-sessions/{session_id}
@router.delete("/api/agent-chat-sessions/{session_id}")
# 定义 delete_session 视图函数,接收 session_id 路径参数和数据库会话 db(依赖注入)
def delete_session(session_id: int, db: Session = Depends(get_db)):
# 根据 session_id 获取会话对象
row = agent_chat_repository.get_session(db, session_id)
# 如果会话不存在,抛出 404 异常并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 会话存在,调用仓库方法删除该会话
agent_chat_repository.delete_session(db, row)
# 返回删除成功的响应
return {"ok": True}
# 异步生成事件流(SSE)的辅助生成器函数
+async def _event_stream():
# 首先推送一条“开始”类型的事件数据
+ yield f"data: {json.dumps({'type': 'start'})}\n\n"
# 最后推送一条“完成”类型的事件,标志流式响应结束
+ yield f"data: {json.dumps({'type': 'done'})}\n\n"
# 定义 POST 路由,支持对话消息的流式传输,路径为 /api/agent-chat-sessions/{session_id}/messages/stream
+@router.post("/api/agent-chat-sessions/{session_id}/messages/stream")
# 异步视图函数 stream_message,接收 session_id、payload 以及依赖注入的数据库会话 db
+async def stream_message(session_id, payload: schemas.AgentChatSendRequest, db = Depends(get_db)):
# 根据会话ID在数据库中查询对应的会话记录
+ session_row = agent_chat_repository.get_session(db, session_id)
# 如果未找到会话,抛出 404 异常并提示“会话不存在”
+ if not session_row:
+ raise HTTPException(status_code=404, detail="会话不存在")
# 根据会话记录的 agent_id 查询对应的智能体对象
+ agent = agent_repository.get_agent(db, session_row.agent_id)
# 如果智能体不存在,抛出 404 异常并提示“智能体不存在”
+ if not agent:
+ raise HTTPException(status_code=404, detail="智能体不存在")
# 按照智能体的llm_provider_name查询所需的大语言模型(LLM)
+ llm = llm_repository.get_llm_by_provider_name(db, agent.llm_provider_name)
# 如果找不到 LLM 提供商,抛出 400 异常,并输出相关的提示信息
+ if not llm:
+ raise HTTPException(status_code=400, detail=f"找不到大语言模型提供商: {agent.llm_provider_name}")
# 返回基于 _event_stream 生成器函数的 StreamingResponse 响应,SSE文本类型
+ return StreamingResponse(_event_stream(), media_type="text/event-stream") 22.2. schemas.py #
app/schemas.py
# 导入枚举类型
from enum import Enum
# 导入Any类型用于类型注解
from typing import Any
# 从pydantic导入BaseModel、Field和field_validator用于数据验证
from pydantic import BaseModel, Field, field_validator
# 定义MCP协议类型的枚举类
class McpProtocol(str, Enum):
# MCP 协议类型注释
"""MCP 协议类型"""
# stdio协议
stdio = "stdio"
# streamable-http协议
streamable_http = "streamable-http"
# sse协议
sse = "sse"
# 定义MCP服务基础模型
class McpServiceBase(BaseModel):
# MCP 服务基础模型注释
"""MCP 服务基础模型"""
# 服务名称,字符串类型,长度1-255
name: str = Field(..., min_length=1, max_length=255)
# 服务描述,可选字段,字符串或None
description: str | None = None
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,确保其为字典类型
@field_validator("config", mode="before")
@classmethod
def config_not_empty(cls, v: Any) -> Any:
# 验证 config 是否为 JSON 对象,如果不是字典则抛出异常
"""验证 config 是否为 JSON 对象"""
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
return v
# 定义MCP服务创建模型,继承自McpServiceBase
class McpServiceCreate(McpServiceBase):
# 不增加额外内容,直接继承
pass
# 定义MCP服务输出模型
class McpServiceOut(BaseModel):
# ID字段,整型
id: int
# 服务名称
name: str
# 服务描述,可选字段
description: str | None
# 协议类型,字符串
protocol: str
# 配置信息,字典类型
config: dict[str, Any]
# 设置模型配置,允许从ORM对象属性读取数据
model_config = {"from_attributes": True}
# 定义McpServiceUpdate模型,用于更新MCP服务,支持部分字段可选更新
class McpServiceUpdate(BaseModel):
# 服务名称,允许为None,最小长度1,最大长度255
name: str | None = Field(None, min_length=1, max_length=255)
# 服务描述字段,允许为None
description: str | None = None
# 协议类型,使用McpProtocol枚举,允许为None
protocol: McpProtocol | None = None
# 配置信息,允许为None,类型为字典
config: dict[str, Any] | None = None
# 定义McpTestRequest模型,继承自BaseModel
class McpTestRequest(BaseModel):
# 协议类型,使用McpProtocol枚举
protocol: McpProtocol
# 配置信息,要求为字典类型
config: dict[str, Any]
# 对config字段添加验证器,在赋值前进行校验
@field_validator("config", mode="before")
@classmethod
def config_is_object(cls, v: Any) -> Any:
# 如果config不是字典类型,则抛出异常
if not isinstance(v, dict):
raise ValueError("config 必须为 JSON 对象")
# 返回config原值
return v
# 定义McpTestResult模型,继承自BaseModel
class McpTestResult(BaseModel):
# 测试是否成功的标志
ok: bool
# 返回的信息或说明
message: str
# 工具列表,默认为空列表
tools: list[dict[str, Any]] = Field(default_factory=list)
# 定义LlmModelBase基类,用于存储大模型相关信息
class LlmModelBase(BaseModel):
# 提供商名称,必填,最小长度1,最大长度255
provider_name: str = Field(..., min_length=1, max_length=255)
# 提供商图标,可选
provider_icon: str | None = None
# API基础地址,必填,最小长度1,最大长度1024
api_base_url: str = Field(..., min_length=1, max_length=1024)
# API密钥,必填,最小长度1,最大长度1024
api_key: str = Field(..., min_length=1, max_length=1024)
# API密钥申请地址,可选,最大长度1024
api_key_url: str | None = Field(None, max_length=1024)
# 模型名称列表,默认为空列表
model_names: list[str] = Field(default_factory=list)
# 对model_names字段进行校验和归一化处理
@field_validator("model_names")
@classmethod
def normalize_model_names(cls, v: list[str]) -> list[str]:
# 定义输出列表
out: list[str] = []
# 用于存储已经出现过的模型名称(小写形式)以去重
seen: set[str] = set()
# 遍历输入值,如果为空则用空列表
for item in v or []:
# 将每个名称转为字符串并去除首尾空格
name = str(item or "").strip()
# 如果名称为空,则跳过
if not name:
continue
# 转为小写做去重key
key = name.lower()
# 如果已经见过该名称,则跳过
if key in seen:
continue
# 添加到已见集合和输出列表
seen.add(key)
out.append(name)
# 返回归一化后的模型名称列表
return out
# 定义LlmModelCreate模型,继承自LlmModelBase,没有扩展字段
class LlmModelCreate(LlmModelBase):
pass
# 定义 LlmModelOut 类,继承自 BaseModel,用于大模型的返回数据结构
class LlmModelOut(BaseModel):
# 唯一ID,类型为整数
id: int
# 提供商名称,字符串类型
provider_name: str
# 提供商图标,可选,字符串或None
provider_icon: str | None
# API 基础地址,字符串类型
api_base_url: str
# API 密钥,字符串类型
api_key: str
# API 密钥申请地址,可选,字符串或None
api_key_url: str | None
# 模型名称列表,字符串列表类型
model_names: list[str]
# Pydantic 配置,启用从属性赋值(用于ORM模式)
model_config = {"from_attributes": True}
# 定义用于大模型服务测试请求体的Pydantic模型
class LlmModelTestRequest(BaseModel):
# API基础地址,必须为非空字符串,最大长度1024
api_base_url: str = Field(..., min_length=1, max_length=1024)
# API密钥,必须为非空字符串,最大长度1024
api_key: str = Field(..., min_length=1, max_length=1024)
# 定义 LlmModelTestResult 类,继承自 BaseModel,用于返回大模型服务检测的结果
class LlmModelTestResult(BaseModel):
# 检测是否通过,布尔类型
ok: bool
# 检测的信息提示,字符串类型
message: str
# 检测到的模型名称列表,默认为空列表
models: list[str] = Field(default_factory=list)
# 定义用于更新大模型信息的Pydantic模型
class LlmModelUpdate(BaseModel):
# 提供商名称,可选字段,限定最小长度1,最大长度255
provider_name: str | None = Field(None, min_length=1, max_length=255)
# 提供商图标,可选字段
provider_icon: str | None = None
# API 基础地址,可选字段,限定最小长度1,最大长度1024
api_base_url: str | None = Field(None, min_length=1, max_length=1024)
# API 密钥,可选字段,限定最小长度1,最大长度1024
api_key: str | None = Field(None, min_length=1, max_length=1024)
# API 密钥申请地址,可选字段,最大长度1024
api_key_url: str | None = Field(None, max_length=1024)
# 模型名称列表,可选字段
model_names: list[str] | None = None
# 对model_names字段做校验和归一化处理
@field_validator("model_names")
@classmethod
def normalize_model_names_optional(cls, v: list[str] | None) -> list[str] | None:
# 如果为None,直接返回None
if v is None:
return None
# 用于存储归一化后的模型名称
out: list[str] = []
# 用于去重
seen: set[str] = set()
# 遍历输入列表
for item in v:
# 转为字符串并去除首尾空白
name = str(item or "").strip()
# 如果名称为空则跳过
if not name:
continue
# 使用小写做去重key
key = name.lower()
# 如果已经见过则跳过
if key in seen:
continue
# 添加进已见集合
seen.add(key)
# 添加进输出列表
out.append(name)
# 返回归一化去重后的列表
return out
# 定义函数,用于规范化 ask_variables 字段,输入为可选的字典列表,返回规范化后的字典列表
def _normalize_ask_variables(v):
# 用于保存处理后的变量字典
out = []
# 用于记录已出现过的变量key,实现去重
seen = set()
# 遍历输入列表(如果为None则转为空列表)
for item in v or []:
# 如果当前元素不是字典类型,则跳过
if not isinstance(item, dict):
continue
# 从字典中获取key字段,去除首尾空白转字符串
key = str(item.get("key") or "").strip()
# 如果key为空,则跳过本轮
if not key:
continue
# 如果key已出现,则跳过实现去重
if key in seen:
continue
# 将当前key加入去重集合
seen.add(key)
# 获取question字段、去空白
question = str(item.get("question") or "").strip()
# 获取label字段、去空白
label = str(item.get("label") or "").strip()
# 获取default字段、去空白
default_value = str(item.get("default") or "").strip()
# 获取required字段,默认为True
required = bool(item.get("required", True))
# 将变量信息归一化后放入输出列表
out.append(
{
"key": key,
"label": label,
"question": question or f"请提供 {label or key}",
"required": required,
"default": default_value,
}
)
# 返回归一化和去重后的变量列表
return out
# 定义 AgentBase 基础模型,继承自 Pydantic 的 BaseModel
class AgentBase(BaseModel):
# 头像字段,可为空
avatar: str | None = None
# 智能体名称,必填,长度1~255
name: str = Field(..., min_length=1, max_length=255)
# 描述信息,可为空
description: str | None = None
# 开场白内容,可为空
opening_message: str | None = None
# 智能体系统提示,必填,最小长度1
system_prompt: str = Field(..., min_length=1)
# LLM 提供商名称,必填,长度1~255
llm_provider_name: str = Field(..., min_length=1, max_length=255)
# LLM 模型名称,必填,长度1~255
llm_model_name: str = Field(..., min_length=1, max_length=255)
# 关联的 MCP 服务ID列表,默认为空列表
mcp_service_ids: list[int] = Field(default_factory=list)
# 询问提示词模板,可为空
ask_prompt_template: str | None = None
# 询问变量列表,默认为空列表
ask_variables: list[dict[str, Any]] = Field(default_factory=list)
# 对 mcp_service_ids 字段做归一化校验
@field_validator("mcp_service_ids")
@classmethod
def normalize_mcp_service_ids(cls, v):
# 存储去重后的有效 id
out = []
# 已经出现过的 id 集合
seen = set()
# 遍历 id 列表(防止为 None)
for item in v or []:
# 转成整数类型
num = int(item)
# 跳过小于等于0的无效 id
if num <= 0:
continue
# 跳过重复 id
if num in seen:
continue
# 添加到去重集合
seen.add(num)
# 添加到输出集合
out.append(num)
# 返回整理后的 id 列表
return out
# 对 ask_variables 字段做归一化校验
@field_validator("ask_variables")
@classmethod
def normalize_ask_variables(cls, v):
# 使用 _normalize_ask_variables 函数处理
return _normalize_ask_variables(v)
# 定义 AgentCreate 创建模型,继承自 AgentBase,无额外字段
class AgentCreate(AgentBase):
pass
# 定义AgentOut响应模型,继承自BaseModel
class AgentOut(BaseModel):
# 主键ID
id: int
# 头像,允许为None
avatar: str | None
# 智能体名称
name: str
# 智能体描述,允许为None
description: str | None
# 开场白,允许为None
opening_message: str | None
# 智能体系统提示,不可为None
system_prompt: str
# LLM提供商名称
llm_provider_name: str
# LLM模型名称
llm_model_name: str
# 关联的MCP服务ID列表
mcp_service_ids: list[int]
# 询问提示词模板,允许为None
ask_prompt_template: str | None
# 询问变量列表,默认为空列表
ask_variables: list[dict[str, Any]] = Field(default_factory=list)
# Pydantic配置:允许根据对象属性创建模型(ORM模式)
model_config = {"from_attributes": True}
# 定义AgentUpdate模型,用于部分更新Agent,继承自Pydantic的BaseModel
class AgentUpdate(BaseModel):
# 头像字段,允许为None
avatar: str | None = None
# 名称字段,允许为None,若不为None则要求长度1~255
name: str | None = Field(None, min_length=1, max_length=255)
# 描述字段,允许为None
description: str | None = None
# 开场白字段,允许为None
opening_message: str | None = None
# 系统提示词字段,允许为None,若不为None则要求最小长度1
system_prompt: str | None = Field(None, min_length=1)
# LLM提供商名称,允许为None,若不为None长度1~255
llm_provider_name: str | None = Field(None, min_length=1, max_length=255)
# LLM模型名称,允许为None,若不为None长度1~255
llm_model_name: str | None = Field(None, min_length=1, max_length=255)
# 关联的MCP服务ID列表,允许为None
mcp_service_ids: list[int] | None = None
# 询问提示词模板,允许为None
ask_prompt_template: str | None = None
# 询问变量列表,允许为None
ask_variables: list[dict[str, Any]] | None = None
# 对mcp_service_ids字段进行校验和去重,允许为None
@field_validator("mcp_service_ids")
@classmethod
def normalize_mcp_service_ids_optional(cls, v: list[int] | None) -> list[int] | None:
# 如果为None,直接返回None
if v is None:
return None
# 定义用于存放合法、去重后的ID的列表
out: list[int] = []
# 定义用于去重的集合
seen: set[int] = set()
# 遍历传入的每个ID
for item in v:
# 转为整数
num = int(item)
# 跳过小于等于0的无效ID
if num <= 0:
continue
# 跳过重复ID
if num in seen:
continue
# 加入去重集合
seen.add(num)
# 加入输出列表
out.append(num)
# 返回整理后的ID列表
return out
# 对ask_variables字段进行归一化与合法性检查,允许为None
@field_validator("ask_variables")
@classmethod
def normalize_ask_variables_optional(cls, v: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
# 如果为None直接返回None
if v is None:
return None
# 使用_normalize_ask_variables工具函数归一化处理
return _normalize_ask_variables(v)
# 定义用于输出AgentChatSession(智能体聊天会话)的Pydantic模型
class AgentChatSessionOut(BaseModel):
# 会话记录的主键ID
id: int
# 关联的Agent智能体ID
agent_id: int
# 会话标题
title: str
# 会话创建时间
created_at: Any
# 会话最近更新时间
updated_at: Any
# 配置项:允许Pydantic模型支持数据库ORM对象的属性映射
model_config = {"from_attributes": True}
# 定义AgentChatSessionCreate用于创建聊天会话的Pydantic模型
class AgentChatSessionCreate(BaseModel):
# 可选的会话标题字段,最大长度255个字符,默认为None
title: str | None = Field(None, max_length=255)
# 定义用于输出Agent聊天消息的Pydantic模型
class AgentChatMessageOut(BaseModel):
# 消息主键ID
id: int
# 关联的聊天会话session_id
session_id: int
# 消息所属角色(如user/assistant/system等)
role: str
# 消息正文内容
content: str
# 扩展元数据,默认为空字典
meta: dict[str, Any] = Field(default_factory=dict)
# 消息创建时间
created_at: Any
# 配置项:允许支持通过ORM对象转为模型
model_config = {"from_attributes": True}
# 定义用于创建Agent聊天消息的Pydantic模型
class AgentChatMessageCreate(BaseModel):
# 消息角色字段(如 user/assistant/tool),必填,最小长度1,最大32
role: str = Field(..., min_length=1, max_length=32)
# 消息内容字段,必填,最小长度1
content: str = Field(..., min_length=1)
# 扩展元数据,默认为空字典
meta: dict[str, Any] = Field(default_factory=dict)
# 定义 AgentChatSendRequest 类,继承自 BaseModel,用于发送消息请求的输入模型
+class AgentChatSendRequest(BaseModel):
# 消息内容字段,必填,最小长度为1
+ content: str = Field(..., min_length=1)
# 是否启用 MCP 工具(默认为 True)
+ enable_mcp_tools: bool = True23. 添加对话消息 #
本节将详细介绍在智能体对话系统中“添加消息”的核心处理逻辑和设计要点。
功能概述
“添加消息”接口允许用户或系统向某一对话会话(session)中追加新的消息(message)。消息可带有角色属性(如 user/assistant/tool),并支持元数据扩展。接口支持普通消息流式和非流式两种方式。
流程说明
参数校验:
- 检查目标会话(session)是否存在,若不存在返回404。
- 校验消息角色(
role)参数,要求在user/assistant/tool之中,否则返回400。 - 消息内容不得为空,否则返回400。
消息创建与保存:
- 使用
agent_chat_repository.create_message方法,将新消息持久化到数据库中,并返回消息模型对象。
- 使用
流式消息(SSE)场景:
- 针对支持流式(如 GPT 等大模型)的场景,通过
/api/agent-chat-sessions/{session_id}/messages/stream路由实现。 - 新用户消息会被自动保存,如为该会话第一条消息且标题为空/为“新对话”,则基于消息内容自动生成并更新会话标题。
- 最终采用
StreamingResponse按SSE标准实时推送交互反馈。
- 针对支持流式(如 GPT 等大模型)的场景,通过
23.1. agent_chat.py #
app/routers/agent_chat.py
# 从 fastapi 导入 APIRouter、Depends 和 HTTPException,用于路由定义和依赖注入及异常处理
from fastapi import APIRouter, Depends, HTTPException
# 从 sqlalchemy.orm 导入 Session,用于数据库会话管理
from sqlalchemy.orm import Session
# 导入 json,用于 JSON 序列化
import json
# 从 fastapi.responses 导入 StreamingResponse,用于流式响应
from fastapi.responses import StreamingResponse
# 从 app.repositories 导入 agent_chat_repository 和 agent_repository,用于数据持久层访问
+from app.repositories import agent_chat_repository, agent_repository,llm_repository,mcp_repository
# 从 app 导入 schemas,用于数据模型
from app import schemas
# 从 app.database 导入 get_db,用于获取数据库会话
from app.database import get_db
# 创建 APIRouter 实例,并设置 tags 标签为 "agent-chat"
router = APIRouter(tags=["agent-chat"])
# 根据用户提问内容自动生成会话标题的函数
+def _build_session_title_from_question(content: str) -> str:
# 对传入的内容进行去除首尾空白和多余空格,保证标题精简
+ text = " ".join(str(content or "").strip().split())
# 如果内容为空,则返回默认标题“新对话”
+ if not text:
+ return "新对话"
# 最大标题长度设置为24个字符
+ max_len = 24
# 如果内容长度小于等于最大长度,直接返回内容作为标题
+ if len(text) <= max_len:
+ return text
# 否则,截取前24个字符并加省略号作为标题
+ return f"{text[:max_len]}..."
# 声明 GET 接口,路径为 /api/agents/{agent_id}/chat-sessions,返回值为 AgentChatSessionOut 数据模型列表
@router.get("/api/agents/{agent_id}/chat-sessions", response_model=list[schemas.AgentChatSessionOut])
# 定义 list_sessions 视图函数,接收 agent_id 作为路径参数,db 为依赖注入的 Session 对象
def list_sessions(agent_id: int, db: Session = Depends(get_db)):
# 先检查数据库里是否存在指定 agent,如果不存在则抛出 404 错误
if not agent_repository.get_agent(db, agent_id):
raise HTTPException(status_code=404, detail="智能体不存在")
# 如果智能体存在,查询所有该 agent 下的对话会话,并将结果返回
return agent_chat_repository.list_sessions(db, agent_id)
# 声明 POST 路由,用于创建新的智能体聊天会话,响应模型为 AgentChatSessionOut
@router.post("/api/agents/{agent_id}/chat-sessions", response_model=schemas.AgentChatSessionOut)
# 定义 create_session 函数,接收 agent_id、会话创建载荷 payload、依赖注入的数据库会话 db
def create_session(agent_id: int, payload: schemas.AgentChatSessionCreate, db: Session = Depends(get_db)):
# 先判断数据库中是否存在指定 agent,如果不存在则抛出 404 错误
if not agent_repository.get_agent(db, agent_id):
raise HTTPException(status_code=404, detail="智能体不存在")
# 如果智能体存在,则调用 agent_chat_repository 创建会话,传入数据库会话、智能体ID和会话标题,返回结果
return agent_chat_repository.create_session(db, agent_id, payload.title)
# 声明一个 GET 路由,根据 session_id 获取会话消息,返回值为 AgentChatMessageOut 列表
@router.get("/api/agent-chat-sessions/{session_id}/messages", response_model=list[schemas.AgentChatMessageOut])
# 定义 list_messages 视图函数,接收 session_id 和数据库会话 db(依赖注入方式获取)
def list_messages(session_id: int, db: Session = Depends(get_db)):
# 调用仓库方法获取指定会话 session 的数据库对象
row = agent_chat_repository.get_session(db, session_id)
# 如果没有找到对应会话,则抛出 404 异常,并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 如果会话存在,则调用仓库方法获取所有消息并返回
return agent_chat_repository.list_messages(db, session_id)
# 定义一个 POST 路由,路径为 /api/agent-chat-sessions/{session_id}/messages,响应体为 AgentChatMessageOut 模型
@router.post("/api/agent-chat-sessions/{session_id}/messages", response_model=schemas.AgentChatMessageOut)
# 定义 create_message 视图函数,参数为 session_id、payload(通过 Pydantic 校验的消息数据),db 为依赖注入的数据库会话
def create_message(session_id: int, payload: schemas.AgentChatMessageCreate, db: Session = Depends(get_db)):
# 调用 agent_chat_repository.get_session 检查指定 session_id 的会话是否存在
row = agent_chat_repository.get_session(db, session_id)
# 如果会话不存在,抛出 404 异常并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 获取请求的消息角色参数,去除首尾空格并转为小写字符串
role = str(payload.role or "").strip().lower()
# 判断 role 是否在支持的角色集合中(user/assistant/tool),否则抛出 400 异常
if role not in {"user", "assistant", "tool"}:
raise HTTPException(status_code=400, detail="role 仅支持 user / assistant / tool")
# 调用 agent_chat_repository.create_message 创建一条消息并返回
return agent_chat_repository.create_message(db, session_id, role, payload.content, payload.meta)
# 声明 delete 路由,路径为 /api/agent-chat-sessions/{session_id}
@router.delete("/api/agent-chat-sessions/{session_id}")
# 定义 delete_session 视图函数,接收 session_id 路径参数和数据库会话 db(依赖注入)
def delete_session(session_id: int, db: Session = Depends(get_db)):
# 根据 session_id 获取会话对象
row = agent_chat_repository.get_session(db, session_id)
# 如果会话不存在,抛出 404 异常并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 会话存在,调用仓库方法删除该会话
agent_chat_repository.delete_session(db, row)
# 返回删除成功的响应
return {"ok": True}
# 异步生成事件流(SSE)的辅助生成器函数
async def _event_stream():
# 首先推送一条“开始”类型的事件数据
yield f"data: {json.dumps({'type': 'start'})}\n\n"
# 最后推送一条“完成”类型的事件,标志流式响应结束
yield f"data: {json.dumps({'type': 'done'})}\n\n"
# 定义 POST 路由,支持对话消息的流式传输,路径为 /api/agent-chat-sessions/{session_id}/messages/stream
@router.post("/api/agent-chat-sessions/{session_id}/messages/stream")
# 异步视图函数 stream_message,接收 session_id、payload 以及依赖注入的数据库会话 db
async def stream_message(session_id, payload: schemas.AgentChatSendRequest, db = Depends(get_db)):
# 根据会话ID在数据库中查询对应的会话记录
session_row = agent_chat_repository.get_session(db, session_id)
# 如果未找到会话,抛出 404 异常并提示“会话不存在”
if not session_row:
raise HTTPException(status_code=404, detail="会话不存在")
# 根据会话记录的 agent_id 查询对应的智能体对象
agent = agent_repository.get_agent(db, session_row.agent_id)
# 如果智能体不存在,抛出 404 异常并提示“智能体不存在”
if not agent:
raise HTTPException(status_code=404, detail="智能体不存在")
# 按照智能体的llm_provider_name查询所需的大语言模型(LLM)
llm = llm_repository.get_llm_by_provider_name(db, agent.llm_provider_name)
# 如果找不到 LLM 提供商,抛出 400 异常,并输出相关的提示信息
if not llm:
raise HTTPException(status_code=400, detail=f"找不到大语言模型提供商: {agent.llm_provider_name}")
# 获取大语言模型的 API 基础地址,去除首尾空白后保存为字符串
+ llm_api_base_url = str(llm.api_base_url or "").strip()
# 获取大语言模型的 API 密钥,去除首尾空白后保存为字符串
+ llm_api_key = str(llm.api_key or "").strip()
# 获取当前智能体配置的大语言模型名称,去除首尾空白后保存为字符串
+ llm_model_name = str(agent.llm_model_name or "").strip()
# 获取当前智能体的系统提示词,去除首尾空白后保存为字符串
+ system_prompt = str(agent.system_prompt or "").strip()
# 初始化用于存储 MCP 服务信息的空列表
+ mcp_services = []
# 遍历配置的 MCP 服务 ID 列表(如果有)
+ for sid in agent.mcp_service_ids or []:
# 根据 MCP 服务 ID 查询对应服务的数据库记录
+ row = mcp_repository.get_mcp_service(db, int(sid))
# 如果查到服务,则将其信息追加到 mcp_services 列表
+ if row:
+ mcp_services.append(
+ {
+ "id": int(row.id),
+ "name": str(row.name or ""),
+ "protocol": str(row.protocol or ""),
+ "config": row.config or {},
+ }
+ )
# 获取用户消息内容,并去除其首尾空白
+ user_content = payload.content.strip()
# 如果消息内容为空,抛出 400 错误,提示“消息内容不能为空”
+ if not user_content:
+ raise HTTPException(status_code=400, detail="消息内容不能为空")
# 查询当前会话已存在的历史消息记录
+ existing_rows = agent_chat_repository.list_messages(db, session_id)
# 如果没有历史消息,且会话标题为空或为“新对话”,则自动生成并更新会话标题
+ if not existing_rows and str(session_row.title or "").strip() in {"", "新对话"}:
+ agent_chat_repository.update_session_title(db, session_row, _build_session_title_from_question(user_content))
# 创建一条新的用户消息到当前会话
+ agent_chat_repository.create_message(db, session_id, "user", user_content)
# 返回基于 _event_stream 生成器函数的 StreamingResponse 响应,SSE文本类型
return StreamingResponse(_event_stream(), media_type="text/event-stream") 24. 流式输出回答 #
本节详细讲解如何实现面向 AI 智能体的「流式输出回答」功能,涵盖核心业务逻辑、异步事件流生成、消息持久化等关键步骤。
功能简介
流式输出(Stream Output)是指后端服务将 AI 助手生成的回复内容,以“边生成边发送”的方式实时推送到前端客户端,实现对话内容的即时展示体验。本方案主要基于 Server-Sent Events(SSE) 实现。
主要技术路径与流程
- 查询并保存用户问题为消息记录。
- 获取历史消息(用于上下文)。
- 构造消息序列(包括系统提示词、历史对话等)。
- 异步生成助手的回复内容(可与大模型 API 对接,也支持模拟/演示文本)。
- 采用异步生成器函数将内容逐步“分片”推送给客户端,实现流式显示。
- 助手回复内容保存为消息记录,自动刷新会话活跃时间。
- 封装为
StreamingResponse并声明text/event-streamMIME 类型。
典型应用场景
- 用户每次发起对话请求,快速获得“正在生成”和“逐步显示”体验。
- 可对接 OpenAI、Azure、通义千问、GLM、讯飞、文心一言等主流 LLM 提供商的流式 API,按需调整
_event_stream逻辑实现真实流式对接。
前端配合建议
- 使用
EventSource监听后端 SSE,实时渲染新分片(delta)。 - 根据
type字段区分“开始”、“分片内容”、“完成”或“异常”。
参考与最佳实践
- 推荐每次推送的最小单位可按“token”、“句子”或“段落”灵活定制。
- 对话上下文应完整传递至 LLM,方便多轮记忆。
- 如需支持富文本,可按需传输 Markdown 或富内容结构。
通过上述实现,可极大优化智能体对话的用户体验,实现高效、流畅、实时的 AI 流式问答交互流程。
24.1. agent_chat_repository.py #
app/repositories/agent_chat_repository.py
# 导入 SQLAlchemy 的 func 和 select 方法用于数据库查询
from sqlalchemy import func, select
# 导入 SQLAlchemy 的 Session 用于数据库会话管理
from sqlalchemy.orm import Session
# 从 app 包导入 models 模块,用于数据库模型操作
from app import models
# 定义一个函数,根据 agent_id 查询对应的所有 AgentChatSession 记录
def list_sessions(db: Session, agent_id: int) -> list[models.AgentChatSession]:
# 构造一个查询语句,筛选 agent_id 等于传入参数的聊天会话
stmt = select(models.AgentChatSession).where(models.AgentChatSession.agent_id == agent_id)
# 按照 updated_at 字段倒序、然后 id 字段倒序排序,确保最新的会话排在前面
stmt = stmt.order_by(models.AgentChatSession.updated_at.desc(), models.AgentChatSession.id.desc())
# 执行查询,将所有结果转换为列表并返回
return list(db.scalars(stmt).all())
# 定义创建聊天会话的方法,传入数据库会话 db、智能体ID agent_id、会话标题 title(可选,默认为 None)
def create_session(db: Session, agent_id: int, title: str | None = None) -> models.AgentChatSession:
# 创建 AgentChatSession 实例,指定 agent_id 和标题(默认为‘新对话’,去除空格后为空也用‘新对话’)
row = models.AgentChatSession(
agent_id=agent_id,
title=(title or "新对话").strip() or "新对话",
)
# 将新建的会话对象添加到数据库 session 中,准备写入数据库
db.add(row)
# 提交事务,将新添加的会话保存到数据库
db.commit()
# 刷新 row 对象,确保获取数据库自动生成的字段(如主键、时间等)的最新值
db.refresh(row)
# 返回新创建的会话对象
return row
# 定义一个函数,通过 session_id 获取指定的 AgentChatSession 记录
def get_session(db: Session, session_id: int) -> models.AgentChatSession | None:
# 调用 db.get 方法,根据主键 session_id 查询 AgentChatSession,如果不存在则返回 None
return db.get(models.AgentChatSession, session_id)
# 定义一个函数,根据会话ID查询对应的所有聊天消息,并按消息ID升序排序
def list_messages(db: Session, session_id: int) -> list[models.AgentChatMessage]:
# 构造查询语句,只筛选session_id为指定值的消息
stmt = select(models.AgentChatMessage).where(models.AgentChatMessage.session_id == session_id)
# 按id字段升序排列消息,确保按先后顺序返回
stmt = stmt.order_by(models.AgentChatMessage.id.asc())
# 执行查询并返回所有结果转换为列表
return list(db.scalars(stmt).all())
# 定义一个函数用于创建 AgentChatMessage 消息记录
def create_message(
db: Session, # 数据库会话对象
session_id: int, # 聊天会话ID
role: str, # 消息角色(如 user、assistant、tool)
content: str, # 消息正文内容
meta: dict | None = None, # 消息附加元数据,默认为None
) -> models.AgentChatMessage:
# 创建 AgentChatMessage 实例,去除角色两端空白,meta为None时使用空字典
row = models.AgentChatMessage(
session_id=session_id,
role=role.strip(),
content=content,
meta=meta or {},
)
# 将新消息加入数据库会话
db.add(row)
# 提交事务,将消息写入数据库
db.commit()
# 刷新对象,获取数据库自动生成的字段(如id、创建时间等)的最新值
db.refresh(row)
# 返回新建的消息对象
return row
# 定义删除指定会话及其所有消息的函数
def delete_session(db: Session, row: models.AgentChatSession) -> None:
# 调用 list_messages 获取该会话下的所有消息
msgs = list_messages(db, row.id)
# 遍历所有消息,逐条从数据库中删除
for item in msgs:
db.delete(item)
# 删除会话记录本身
db.delete(row)
# 提交事务,保存删除操作到数据库
db.commit()
# 定义 touch_session 函数,用于更新会话的更新时间戳
+def touch_session(db: Session, row: models.AgentChatSession) -> models.AgentChatSession:
# 将会话对象的 updated_at 字段设置为当前时间
+ row.updated_at = func.now()
# 将更新后的会话对象添加到数据库会话
+ db.add(row)
# 提交事务,使更改生效
+ db.commit()
# 刷新 row 对象,确保获取到数据库生成的最新字段值
+ db.refresh(row)
# 返回更新后的会话对象
+ return row24.2. agent_chat.py #
app/routers/agent_chat.py
# 从 fastapi 导入 APIRouter、Depends 和 HTTPException,用于路由定义和依赖注入及异常处理
from fastapi import APIRouter, Depends, HTTPException
# 从 sqlalchemy.orm 导入 Session,用于数据库会话管理
from sqlalchemy.orm import Session
# 导入 json,用于 JSON 序列化
import json
# 导入 asyncio,用于异步队列管理
+import asyncio
# 从 fastapi.responses 导入 StreamingResponse,用于流式响应
from fastapi.responses import StreamingResponse
# 从 app.repositories 导入 agent_chat_repository 和 agent_repository,用于数据持久层访问
from app.repositories import agent_chat_repository, agent_repository,llm_repository,mcp_repository
# 从 app 导入 schemas,用于数据模型
from app import schemas
# 从 app.database 导入 get_db,用于获取数据库会话
from app.database import get_db
# 创建 APIRouter 实例,并设置 tags 标签为 "agent-chat"
router = APIRouter(tags=["agent-chat"])
# 根据用户提问内容自动生成会话标题的函数
def _build_session_title_from_question(content: str) -> str:
# 对传入的内容进行去除首尾空白和多余空格,保证标题精简
text = " ".join(str(content or "").strip().split())
# 如果内容为空,则返回默认标题“新对话”
if not text:
return "新对话"
# 最大标题长度设置为24个字符
max_len = 24
# 如果内容长度小于等于最大长度,直接返回内容作为标题
if len(text) <= max_len:
return text
# 否则,截取前24个字符并加省略号作为标题
return f"{text[:max_len]}..."
# 声明 GET 接口,路径为 /api/agents/{agent_id}/chat-sessions,返回值为 AgentChatSessionOut 数据模型列表
@router.get("/api/agents/{agent_id}/chat-sessions", response_model=list[schemas.AgentChatSessionOut])
# 定义 list_sessions 视图函数,接收 agent_id 作为路径参数,db 为依赖注入的 Session 对象
def list_sessions(agent_id: int, db: Session = Depends(get_db)):
# 先检查数据库里是否存在指定 agent,如果不存在则抛出 404 错误
if not agent_repository.get_agent(db, agent_id):
raise HTTPException(status_code=404, detail="智能体不存在")
# 如果智能体存在,查询所有该 agent 下的对话会话,并将结果返回
return agent_chat_repository.list_sessions(db, agent_id)
# 声明 POST 路由,用于创建新的智能体聊天会话,响应模型为 AgentChatSessionOut
@router.post("/api/agents/{agent_id}/chat-sessions", response_model=schemas.AgentChatSessionOut)
# 定义 create_session 函数,接收 agent_id、会话创建载荷 payload、依赖注入的数据库会话 db
def create_session(agent_id: int, payload: schemas.AgentChatSessionCreate, db: Session = Depends(get_db)):
# 先判断数据库中是否存在指定 agent,如果不存在则抛出 404 错误
if not agent_repository.get_agent(db, agent_id):
raise HTTPException(status_code=404, detail="智能体不存在")
# 如果智能体存在,则调用 agent_chat_repository 创建会话,传入数据库会话、智能体ID和会话标题,返回结果
return agent_chat_repository.create_session(db, agent_id, payload.title)
# 声明一个 GET 路由,根据 session_id 获取会话消息,返回值为 AgentChatMessageOut 列表
@router.get("/api/agent-chat-sessions/{session_id}/messages", response_model=list[schemas.AgentChatMessageOut])
# 定义 list_messages 视图函数,接收 session_id 和数据库会话 db(依赖注入方式获取)
def list_messages(session_id: int, db: Session = Depends(get_db)):
# 调用仓库方法获取指定会话 session 的数据库对象
row = agent_chat_repository.get_session(db, session_id)
# 如果没有找到对应会话,则抛出 404 异常,并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 如果会话存在,则调用仓库方法获取所有消息并返回
return agent_chat_repository.list_messages(db, session_id)
# 定义一个 POST 路由,路径为 /api/agent-chat-sessions/{session_id}/messages,响应体为 AgentChatMessageOut 模型
@router.post("/api/agent-chat-sessions/{session_id}/messages", response_model=schemas.AgentChatMessageOut)
# 定义 create_message 视图函数,参数为 session_id、payload(通过 Pydantic 校验的消息数据),db 为依赖注入的数据库会话
def create_message(session_id: int, payload: schemas.AgentChatMessageCreate, db: Session = Depends(get_db)):
# 调用 agent_chat_repository.get_session 检查指定 session_id 的会话是否存在
row = agent_chat_repository.get_session(db, session_id)
# 如果会话不存在,抛出 404 异常并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 获取请求的消息角色参数,去除首尾空格并转为小写字符串
role = str(payload.role or "").strip().lower()
# 判断 role 是否在支持的角色集合中(user/assistant/tool),否则抛出 400 异常
if role not in {"user", "assistant", "tool"}:
raise HTTPException(status_code=400, detail="role 仅支持 user / assistant / tool")
# 调用 agent_chat_repository.create_message 创建一条消息并返回
return agent_chat_repository.create_message(db, session_id, role, payload.content, payload.meta)
# 声明 delete 路由,路径为 /api/agent-chat-sessions/{session_id}
@router.delete("/api/agent-chat-sessions/{session_id}")
# 定义 delete_session 视图函数,接收 session_id 路径参数和数据库会话 db(依赖注入)
def delete_session(session_id: int, db: Session = Depends(get_db)):
# 根据 session_id 获取会话对象
row = agent_chat_repository.get_session(db, session_id)
# 如果会话不存在,抛出 404 异常并提示“会话不存在”
if not row:
raise HTTPException(status_code=404, detail="会话不存在")
# 会话存在,调用仓库方法删除该会话
agent_chat_repository.delete_session(db, row)
# 返回删除成功的响应
return {"ok": True}
# 定义 POST 路由,支持对话消息的流式传输,路径为 /api/agent-chat-sessions/{session_id}/messages/stream
@router.post("/api/agent-chat-sessions/{session_id}/messages/stream")
# 异步视图函数 stream_message,接收 session_id、payload 以及依赖注入的数据库会话 db
async def stream_message(session_id, payload: schemas.AgentChatSendRequest, db = Depends(get_db)):
# 根据会话ID在数据库中查询对应的会话记录
session_row = agent_chat_repository.get_session(db, session_id)
# 如果未找到会话,抛出 404 异常并提示“会话不存在”
if not session_row:
raise HTTPException(status_code=404, detail="会话不存在")
# 根据会话记录的 agent_id 查询对应的智能体对象
agent = agent_repository.get_agent(db, session_row.agent_id)
# 如果智能体不存在,抛出 404 异常并提示“智能体不存在”
if not agent:
raise HTTPException(status_code=404, detail="智能体不存在")
# 按照智能体的llm_provider_name查询所需的大语言模型(LLM)
llm = llm_repository.get_llm_by_provider_name(db, agent.llm_provider_name)
# 如果找不到 LLM 提供商,抛出 400 异常,并输出相关的提示信息
if not llm:
raise HTTPException(status_code=400, detail=f"找不到大语言模型提供商: {agent.llm_provider_name}")
# 获取大语言模型的 API 基础地址,去除首尾空白后保存为字符串
llm_api_base_url = str(llm.api_base_url or "").strip()
# 获取大语言模型的 API 密钥,去除首尾空白后保存为字符串
llm_api_key = str(llm.api_key or "").strip()
# 获取当前智能体配置的大语言模型名称,去除首尾空白后保存为字符串
llm_model_name = str(agent.llm_model_name or "").strip()
# 获取当前智能体的系统提示词,去除首尾空白后保存为字符串
system_prompt = str(agent.system_prompt or "").strip()
# 初始化用于存储 MCP 服务信息的空列表
mcp_services = []
# 遍历配置的 MCP 服务 ID 列表(如果有)
for sid in agent.mcp_service_ids or []:
# 根据 MCP 服务 ID 查询对应服务的数据库记录
row = mcp_repository.get_mcp_service(db, int(sid))
# 如果查到服务,则将其信息追加到 mcp_services 列表
if row:
mcp_services.append(
{
"id": int(row.id),
"name": str(row.name or ""),
"protocol": str(row.protocol or ""),
"config": row.config or {},
}
)
# 获取用户消息内容,并去除其首尾空白
user_content = payload.content.strip()
# 如果消息内容为空,抛出 400 错误,提示“消息内容不能为空”
if not user_content:
raise HTTPException(status_code=400, detail="消息内容不能为空")
# 查询当前会话已存在的历史消息记录
existing_rows = agent_chat_repository.list_messages(db, session_id)
# 如果没有历史消息,且会话标题为空或为“新对话”,则自动生成并更新会话标题
if not existing_rows and str(session_row.title or "").strip() in {"", "新对话"}:
agent_chat_repository.update_session_title(db, session_row, _build_session_title_from_question(user_content))
# 创建一条新的用户消息到当前会话
+ agent_chat_repository.create_message(db, session_id, "user", user_content)
# 查询当前会话历史消息列表
+ history_rows = agent_chat_repository.list_messages(db, session_id)
# 初始化消息列表,将系统提示词作为第一条消息
+ messages = [{"role": "system", "content": system_prompt}]
# 遍历历史消息,将有效的消息(role为user/assistant/tool)追加到消息列表中
+ for m in history_rows:
# 获取该条消息的角色,去除首尾空格并统一小写
+ role = str(m.role or "").strip().lower()
# 如果角色不在支持的角色集合,则跳过
+ if role not in {"user", "assistant", "tool"}:
+ continue
# 构建消息字典并追加到消息列表
+ item = {"role": role, "content": m.content}
+ messages.append(item)
# 定义异步事件流生成器,用于 SSE 流式响应
+ async def _event_stream():
# 存储完整文本的分片
+ full_text_parts = []
# 推送“开始”事件到客户端
+ yield f"data: {json.dumps({'type': 'start'})}\n\n"
+ try:
# 示例:本次流式推送的结果文本,仅作演示用途
+ chunk = """
+ 春季3-4月最佳,核心玩中轴线+市郊S2线花海。高铁往返约5小时/576元。住西单或亚运村,4天3晚预算约1550元(含住宿餐饮门票市内交通)。务必提前7天约故宫,清明人流极大,注意温差穿衣。
+ """
# 追加本次分片到完整文本
+ full_text_parts.append(chunk)
+ for char in chunk:
# 发送本次分片数据给客户端,类型为“delta”
+ yield f"data: {json.dumps({'type': 'delta', 'text': char}, ensure_ascii=False)}\n\n"
+ await asyncio.sleep(0.05)
# 合并所有分片为最终完整输出
+ full_text = "".join(full_text_parts).strip()
# 将助手的回复消息保存到数据库
+ agent_chat_repository.create_message(db, session_id, "assistant", full_text or "")
# 更新会话的最近活动时间
+ agent_chat_repository.touch_session(db, session_row)
# 推送“完成”事件到客户端
+ yield f"data: {json.dumps({'type': 'done'})}\n\n"
+ except Exception as e: # noqa: BLE001
# 如果发生异常,则发送错误事件
+ yield f"data: {json.dumps({'type': 'error', 'message': str(e)}, ensure_ascii=False)}\n\n"
# 返回基于 _event_stream 生成器函数的 StreamingResponse 响应,SSE文本类型
return StreamingResponse(_event_stream(), media_type="text/event-stream")