C0726N01-FastAPI 路由系统
概述
FastAPI的路由系统是其核心功能之一,负责将HTTP请求映射到相应的处理函数。本章将深入分析FastAPI路由系统的设计原理、实现机制以及高级特性。
1. 路由系统架构
1.1 核心组件
FastAPI路由系统由以下核心组件构成:
# 核心路由类
class APIRoute(Route):
"""API路由类,继承自Starlette的Route"""
class APIRouter:
"""API路由器,用于组织和管理路由"""
class APIWebSocketRoute(WebSocketRoute):
"""WebSocket路由类"""
设计层次:
- Route层: 基础路由功能(来自Starlette)
- APIRoute层: API特定功能(参数验证、文档生成)
- APIRouter层: 路由组织和管理
- FastAPI层: 应用程序级路由集成
1.2 路由匹配机制
路由匹配遵循以下优先级:
- 精确匹配: 完全匹配的静态路径
- 参数匹配: 包含路径参数的动态路径
- 通配符匹配: 使用通配符的路径模式
- 默认处理: 404错误处理
# 路由匹配示例
"/users/123" # 精确匹配
"/users/{user_id}" # 参数匹配
"/files/{file_path:path}" # 通配符匹配
2. APIRoute 类详解
2.1 初始化参数
def __init__(
self,
path: str,
endpoint: Callable[..., Any],
*,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
methods: Optional[Union[Set[str], List[str]]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[IncEx] = None,
response_model_exclude: Optional[IncEx] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Union[Callable[[APIRoute], str], DefaultPlaceholder] = Default(generate_unique_id),
) -> None:
参数分类解析:
基础路由参数
path: URL路径模式endpoint: 处理函数methods: HTTP方法列表name: 路由名称(用于URL反向解析)
响应控制参数
response_model: 响应数据模型status_code: 默认HTTP状态码response_class: 响应类型(JSON、HTML、文件等)response_model_*: 响应序列化控制
文档生成参数
summary: API摘要description: 详细描述tags: 分组标签deprecated: 废弃标记responses: 多状态码响应定义
高级功能参数
dependencies: 路由级依赖callbacks: 回调定义openapi_extra: 自定义OpenAPI扩展
2.2 路由处理流程
async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
"""ASGI应用程序接口实现"""
assert scope["type"] == "http"
request = Request(scope, receive)
# 1. 依赖解析
dependant = get_dependant(path=self.path_regex, call=self.endpoint)
# 2. 参数提取和验证
values, errors, background_tasks, sub_response = await solve_dependencies(
request=request,
dependant=dependant,
body=body,
dependency_overrides_provider=dependency_overrides_provider,
)
# 3. 错误处理
if errors:
raise RequestValidationError(errors, body=body)
# 4. 调用端点函数
raw_response = await run_endpoint_function(
dependant=dependant, values=values, is_coroutine=is_coroutine
)
# 5. 响应处理
if isinstance(raw_response, Response):
response = raw_response
else:
response = await serialize_response(
field=self.response_field,
response_content=raw_response,
include=self.response_model_include,
exclude=self.response_model_exclude,
by_alias=self.response_model_by_alias,
exclude_unset=self.response_model_exclude_unset,
)
# 6. 发送响应
await response(scope, receive, send)
处理流程详解:
- 请求解析: 将ASGI scope转换为Request对象
- 依赖解析: 分析函数签名,构建依赖图
- 参数验证: 提取并验证路径、查询、请求体参数
- 函数调用: 执行用户定义的端点函数
- 响应序列化: 将返回值转换为HTTP响应
- 响应发送: 通过ASGI接口发送响应
2.3 参数提取机制
路径参数提取
def get_path_params(path: str, path_regex: Pattern) -> Dict[str, Any]:
"""从URL路径中提取参数"""
match = path_regex.match(path)
if match:
return match.groupdict()
return {}
查询参数提取
def get_query_params(query_string: bytes) -> Dict[str, Any]:
"""从查询字符串中提取参数"""
return dict(parse_qsl(query_string.decode()))
请求体参数提取
async def get_body_params(request: Request, body_field: ModelField) -> Any:
"""从请求体中提取参数"""
content_type = request.headers.get("content-type", "")
if "application/json" in content_type:
body = await request.json()
elif "application/x-www-form-urlencoded" in content_type:
body = await request.form()
elif "multipart/form-data" in content_type:
body = await request.form()
else:
body = await request.body()
return body
3. APIRouter 类详解
3.1 路由器设计原理
class APIRouter:
def __init__(
self,
*,
prefix: str = "",
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
default_response_class: Type[Response] = Default(JSONResponse),
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
callbacks: Optional[List[BaseRoute]] = None,
routes: Optional[List[BaseRoute]] = None,
redirect_slashes: bool = True,
default: Optional[ASGIApp] = None,
dependency_overrides_provider: Optional[Any] = None,
route_class: Type[APIRoute] = APIRoute,
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
lifespan: Optional[Lifespan[Any]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
generate_unique_id_function: Callable[[APIRoute], str] = Default(generate_unique_id),
) -> None:
路由器特性:
- 路径前缀: 为所有子路由添加统一前缀
- 标签继承: 子路由继承路由器标签
- 依赖继承: 子路由继承路由器依赖
- 响应配置: 统一的响应类和错误处理
- 生命周期: 独立的启动和关闭事件
3.2 路由注册机制
def add_api_route(
self,
path: str,
endpoint: Callable[..., Any],
**kwargs
) -> None:
"""添加API路由"""
# 1. 路径前缀处理
if self.prefix:
path = self.prefix + path
# 2. 标签合并
tags = kwargs.get("tags", []) or []
if self.tags:
tags.extend(self.tags)
kwargs["tags"] = tags
# 3. 依赖合并
dependencies = kwargs.get("dependencies", []) or []
if self.dependencies:
dependencies.extend(self.dependencies)
kwargs["dependencies"] = dependencies
# 4. 创建路由对象
route = self.route_class(path, endpoint, **kwargs)
# 5. 添加到路由列表
self.routes.append(route)
3.3 路由器嵌套
def include_router(
self,
router: "APIRouter",
*,
prefix: str = "",
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
default_response_class: Type[Response] = Default(JSONResponse),
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
callbacks: Optional[List[BaseRoute]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
generate_unique_id_function: Callable[[APIRoute], str] = Default(generate_unique_id),
) -> None:
"""包含子路由器"""
# 路由器嵌套逻辑
for route in router.routes:
# 复制路由并应用新的配置
new_route = self._copy_route_with_new_config(route, prefix, tags, dependencies)
self.routes.append(new_route)
嵌套特性:
- 支持无限层级嵌套
- 配置参数层层继承
- 路径前缀自动拼接
- 依赖和标签累积合并
4. 路径参数系统
4.1 路径参数类型
# 基础类型参数
@app.get("/users/{user_id}")
async def get_user(user_id: int):
pass
# 字符串参数
@app.get("/items/{item_name}")
async def get_item(item_name: str):
pass
# 路径参数
@app.get("/files/{file_path:path}")
async def get_file(file_path: str):
pass
# 枚举参数
class ModelName(str, Enum):
alexnet = "alexnet"
resnet = "resnet"
lenet = "lenet"
@app.get("/models/{model_name}")
async def get_model(model_name: ModelName):
pass
4.2 参数验证机制
def validate_path_param(param_name: str, param_value: str, param_type: Type) -> Any:
"""路径参数验证"""
try:
if param_type == int:
return int(param_value)
elif param_type == float:
return float(param_value)
elif param_type == bool:
return param_value.lower() in ("true", "1", "yes")
elif issubclass(param_type, Enum):
return param_type(param_value)
else:
return param_value
except (ValueError, TypeError) as e:
raise ValidationError(f"Invalid value for {param_name}: {param_value}")
4.3 路径模式编译
def compile_path_pattern(path: str) -> Tuple[Pattern, List[str]]:
"""编译路径模式为正则表达式"""
param_names = []
pattern_parts = []
for part in path.split("/"):
if part.startswith("{") and part.endswith("}"):
# 提取参数名和类型
param_spec = part[1:-1]
if ":" in param_spec:
param_name, param_type = param_spec.split(":", 1)
if param_type == "path":
pattern_parts.append(r"(?P<{}>.+)".format(param_name))
else:
pattern_parts.append(r"(?P<{}>[^/]+)".format(param_name))
else:
param_name = param_spec
pattern_parts.append(r"(?P<{}>[^/]+)".format(param_name))
param_names.append(param_name)
else:
pattern_parts.append(re.escape(part))
pattern = re.compile("/".join(pattern_parts))
return pattern, param_names
5. 查询参数系统
5.1 查询参数定义
from fastapi import Query
@app.get("/items/")
async def read_items(
q: Optional[str] = Query(None, description="Search query"),
skip: int = Query(0, ge=0, description="Skip items"),
limit: int = Query(10, le=100, description="Limit items"),
tags: List[str] = Query([], description="Filter by tags")
):
pass
5.2 查询参数验证
class QueryParam:
def __init__(
self,
default: Any = Ellipsis,
*,
alias: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
regex: Optional[str] = None,
example: Any = Undefined,
examples: Optional[Dict[str, Any]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
):
self.default = default
self.alias = alias
self.title = title
self.description = description
# 数值验证
self.gt = gt
self.ge = ge
self.lt = lt
self.le = le
# 字符串验证
self.min_length = min_length
self.max_length = max_length
self.regex = regex
# 文档相关
self.example = example
self.examples = examples
self.deprecated = deprecated
self.include_in_schema = include_in_schema
5.3 列表参数处理
def parse_query_list(query_string: str, param_name: str) -> List[str]:
"""解析查询字符串中的列表参数"""
# 支持多种格式:
# ?tags=python&tags=fastapi (重复参数)
# ?tags=python,fastapi (逗号分隔)
# ?tags[]=python&tags[]=fastapi (数组格式)
parsed = parse_qsl(query_string, keep_blank_values=True)
values = []
for key, value in parsed:
if key == param_name or key == f"{param_name}[]":
if "," in value:
values.extend(value.split(","))
else:
values.append(value)
return values
6. 请求体处理
6.1 JSON请求体
from pydantic import BaseModel
class Item(BaseModel):
name: str
description: Optional[str] = None
price: float
tax: Optional[float] = None
@app.post("/items/")
async def create_item(item: Item):
return item
6.2 表单数据处理
from fastapi import Form
@app.post("/login/")
async def login(
username: str = Form(...),
password: str = Form(...)
):
return {"username": username}
6.3 文件上传处理
from fastapi import File, UploadFile
@app.post("/uploadfile/")
async def create_upload_file(file: UploadFile = File(...)):
contents = await file.read()
return {
"filename": file.filename,
"content_type": file.content_type,
"size": len(contents)
}
6.4 混合参数处理
@app.post("/items/{item_id}")
async def update_item(
item_id: int, # 路径参数
q: Optional[str] = None, # 查询参数
item: Item = Body(...), # 请求体
timestamp: datetime = Header(None) # 头部参数
):
pass
7. 响应处理系统
7.1 响应模型
class ItemResponse(BaseModel):
id: int
name: str
created_at: datetime
@app.get("/items/{item_id}", response_model=ItemResponse)
async def get_item(item_id: int):
# 返回的数据会自动按照ItemResponse模型序列化
return {
"id": item_id,
"name": "Sample Item",
"created_at": datetime.now(),
"internal_field": "This will be excluded" # 不在模型中的字段会被排除
}
7.2 多状态码响应
@app.get(
"/items/{item_id}",
responses={
200: {"model": ItemResponse, "description": "Item found"},
404: {"model": ErrorResponse, "description": "Item not found"},
422: {"model": ValidationErrorResponse, "description": "Validation error"}
}
)
async def get_item(item_id: int):
pass
7.3 自定义响应类
from fastapi.responses import HTMLResponse, FileResponse, StreamingResponse
@app.get("/html", response_class=HTMLResponse)
async def get_html():
return "<html><body><h1>Hello World</h1></body></html>"
@app.get("/download")
async def download_file():
return FileResponse("path/to/file.pdf", filename="document.pdf")
@app.get("/stream")
async def stream_data():
def generate():
for i in range(1000):
yield f"data chunk {i}\n"
return StreamingResponse(generate(), media_type="text/plain")
8. WebSocket路由
8.1 WebSocket路由定义
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
try:
while True:
data = await websocket.receive_text()
await websocket.send_text(f"Message text was: {data}")
except WebSocketDisconnect:
print("Client disconnected")
8.2 WebSocket参数处理
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(
websocket: WebSocket,
client_id: int,
token: str = Query(...)
):
# 验证token
if not validate_token(token):
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
await websocket.accept()
# WebSocket逻辑
8.3 WebSocket依赖注入
async def get_current_user(websocket: WebSocket, token: str = Query(...)):
user = authenticate_user(token)
if not user:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
return user
@app.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
current_user: User = Depends(get_current_user)
):
await websocket.accept()
# 已认证的WebSocket连接
9. 路由性能优化
9.1 路由缓存
class RouteCache:
def __init__(self):
self._cache = {}
self._compiled_routes = {}
def get_route(self, path: str, method: str) -> Optional[APIRoute]:
cache_key = f"{method}:{path}"
if cache_key in self._cache:
return self._cache[cache_key]
# 路由匹配逻辑
for route in self.routes:
if route.matches(path, method):
self._cache[cache_key] = route
return route
return None
9.2 路由预编译
def precompile_routes(routes: List[APIRoute]) -> Dict[str, Pattern]:
"""预编译所有路由模式"""
compiled = {}
for route in routes:
pattern, param_names = compile_path_pattern(route.path)
compiled[route.path] = {
"pattern": pattern,
"param_names": param_names,
"route": route
}
return compiled
9.3 路由树优化
class RouteTree:
"""路由树结构,用于快速路由匹配"""
def __init__(self):
self.static_routes = {} # 静态路由
self.dynamic_routes = [] # 动态路由
self.wildcard_routes = [] # 通配符路由
def add_route(self, route: APIRoute):
if "{" not in route.path:
# 静态路由
self.static_routes[route.path] = route
elif ":path}" in route.path:
# 通配符路由
self.wildcard_routes.append(route)
else:
# 动态路由
self.dynamic_routes.append(route)
def match(self, path: str) -> Optional[APIRoute]:
# 1. 首先检查静态路由
if path in self.static_routes:
return self.static_routes[path]
# 2. 检查动态路由
for route in self.dynamic_routes:
if route.path_regex.match(path):
return route
# 3. 检查通配符路由
for route in self.wildcard_routes:
if route.path_regex.match(path):
return route
return None
10. 错误处理和调试
10.1 路由调试信息
def debug_routes(app: FastAPI):
"""打印所有路由信息"""
print("Registered Routes:")
for route in app.routes:
if isinstance(route, APIRoute):
print(f" {route.methods} {route.path} -> {route.endpoint.__name__}")
if route.dependencies:
print(f" Dependencies: {[dep.dependency.__name__ for dep in route.dependencies]}")
if route.tags:
print(f" Tags: {route.tags}")
10.2 路由冲突检测
def detect_route_conflicts(routes: List[APIRoute]) -> List[str]:
"""检测路由冲突"""
conflicts = []
route_patterns = {}
for route in routes:
for method in route.methods:
key = f"{method}:{route.path_regex.pattern}"
if key in route_patterns:
conflicts.append(
f"Conflict: {method} {route.path} conflicts with {route_patterns[key].path}"
)
else:
route_patterns[key] = route
return conflicts